1、Torch构建简单的模型

  1. # coding:utf-8
  2. import torch
  3.  
  4. class Net(torch.nn.Module):
  5. def __init__(self,img_rgb=3,img_size=32,img_class=13):
  6. super(Net, self).__init__()
  7. self.conv1 = torch.nn.Sequential(
  8. torch.nn.Conv2d(in_channels=img_rgb, out_channels=img_size, kernel_size=3, stride=1,padding= 1), #
  9. torch.nn.ReLU(),
  10. torch.nn.MaxPool2d(2),
  11. # torch.nn.Dropout(0.5)
  12. )
  13. self.conv2 = torch.nn.Sequential(
  14. torch.nn.Conv2d(28, 64, 3, 1, 1),
  15. torch.nn.ReLU(),
  16. torch.nn.MaxPool2d(2)
  17. )
  18. self.conv3 = torch.nn.Sequential(
  19. torch.nn.Conv2d(64, 64, 3, 1, 1),
  20. torch.nn.ReLU(),
  21. torch.nn.MaxPool2d(2)
  22. )
  23. self.dense = torch.nn.Sequential(
  24. torch.nn.Linear(64 * 3 * 3, 128),
  25. torch.nn.ReLU(),
  26. torch.nn.Linear(128, img_class)
  27. )
  28.  
  29. def forward(self, x):
  30. conv1_out = self.conv1(x)
  31. conv2_out = self.conv2(conv1_out)
  32. conv3_out = self.conv3(conv2_out)
  33. res = conv3_out.view(conv3_out.size(0), -1)
  34. out = self.dense(res)
  35. return out
  36.  
  37. CUDA = torch.cuda.is_available()
  38.  
  39. model = Net(1,28,13)
  40. print(model)
  41.  
  42. optimizer = torch.optim.Adam(model.parameters())
  43. loss_func = torch.nn.MultiLabelSoftMarginLoss()#nn.CrossEntropyLoss()
  44.  
  45. if CUDA:
  46. model.cuda()
  47.  
  48. def batch_training_data(x_train,y_train,batch_size,i):
  49. n = len(x_train)
  50. left_limit = batch_size*i
  51. right_limit = left_limit+batch_size
  52. if n>=right_limit:
  53. return x_train[left_limit:right_limit,:,:,:],y_train[left_limit:right_limit,:]
  54. else:
  55. return x_train[left_limit:, :, :, :], y_train[left_limit:, :]

  

2、奉献训练过程的代码

  1. # coding:utf-8
  2. import time
  3. import os
  4. import torch
  5. import numpy as np
  6. from data_processing import get_DS
  7. from CNN_nework_model import cnn_face_discern_model
  8. from torch.autograd import Variable
  9. from use_torch_creation_model import optimizer, model, loss_func, batch_training_data,CUDA
  10. from sklearn.metrics import accuracy_score
  11.  
  12. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  13.  
  14. st = time.time()
  15. # 获取训练集与测试集以 8:2 分割
  16. x_,y_,y_true,label = get_DS()
  17.  
  18. label_number = len(label)
  19.  
  20. x_train,y_train = x_[:960,:,:,:].reshape((960,1,28,28)),y_[:960,:]
  21.  
  22. x_test,y_test = x_[960:,:,:,:].reshape((340,1,28,28)),y_[960:,:]
  23.  
  24. y_test_label = y_true[960:]
  25.  
  26. print(time.time() - st)
  27. print(x_train.shape,x_test.shape)
  28.  
  29. batch_size = 100
  30. n = int(len(x_train)/batch_size)+1
  31.  
  32. for epoch in range(100):
  33. global loss
  34. for batch in range(n):
  35. x_training,y_training = batch_training_data(x_train,y_train,batch_size,batch)
  36. batch_x,batch_y = Variable(torch.from_numpy(x_training)).float(),Variable(torch.from_numpy(y_training)).float()
  37. if CUDA:
  38. batch_x=batch_x.cuda()
  39. batch_y=batch_y.cuda()
  40.  
  41. out = model(batch_x)
  42. loss = loss_func(out, batch_y)
  43.  
  44. optimizer.zero_grad()
  45. loss.backward()
  46. optimizer.step()
  47. # 测试精确度
  48. if epoch%9==0:
  49. global x_test_tst
  50. if CUDA:
  51. x_test_tst = Variable(torch.from_numpy(x_test)).float().cuda()
  52. y_pred = model(x_test_tst)
  53.  
  54. y_predict = np.argmax(y_pred.cpu().data.numpy(),axis=1)
  55.  
  56. acc = accuracy_score(y_test_label,y_predict)
  57.  
  58. print("loss={} aucc={}".format(loss.cpu().data.numpy(),acc))

  

3、总结

通过博主通过TensorFlow、keras、pytorch进行训练同样的模型同样的图像数据,结果发现,pyTorch快了很多倍,特别是在导入模型的时候比TensorFlow快了很多。合适部署接口和集成在项目中。

奉献pytorch 搭建 CNN 卷积神经网络训练图像识别的模型,配合numpy 和matplotlib 一起使用调用 cuda GPU进行加速训练的更多相关文章

  1. pytorch 8 CNN 卷积神经网络

    # library # standard library import os # third-party library import torch import torch.nn as nn impo ...

  2. 用Keras搭建神经网络 简单模版(三)—— CNN 卷积神经网络(手写数字图片识别)

    # -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) #for reproducibility再现性 from keras.d ...

  3. Keras(四)CNN 卷积神经网络 RNN 循环神经网络 原理及实例

    CNN 卷积神经网络 卷积 池化 https://www.cnblogs.com/peng8098/p/nlp_16.html 中有介绍 以数据集MNIST构建一个卷积神经网路 from keras. ...

  4. Deep Learning模型之:CNN卷积神经网络(一)深度解析CNN

    http://m.blog.csdn.net/blog/wu010555688/24487301 本文整理了网上几位大牛的博客,详细地讲解了CNN的基础结构与核心思想,欢迎交流. [1]Deep le ...

  5. [转]Theano下用CNN(卷积神经网络)做车牌中文字符OCR

    Theano下用CNN(卷积神经网络)做车牌中文字符OCR 原文地址:http://m.blog.csdn.net/article/details?id=50989742 之前时间一直在看 Micha ...

  6. Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现(转)

    Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现 zouxy09@qq.com http://blog.csdn.net/zouxy09          自己平时看了一些论文, ...

  7. CNN(卷积神经网络)、RNN(循环神经网络)、DNN(深度神经网络)的内部网络结构有什么区别?

    https://www.zhihu.com/question/34681168 CNN(卷积神经网络).RNN(循环神经网络).DNN(深度神经网络)的内部网络结构有什么区别?修改 CNN(卷积神经网 ...

  8. day-16 CNN卷积神经网络算法之Max pooling池化操作学习

    利用CNN卷积神经网络进行训练时,进行完卷积运算,还需要接着进行Max pooling池化操作,目的是在尽量不丢失图像特征前期下,对图像进行downsampling. 首先看下max pooling的 ...

  9. cnn(卷积神经网络)比较系统的讲解

    本文整理了网上几位大牛的博客,详细地讲解了CNN的基础结构与核心思想,欢迎交流. [1]Deep learning简介 [2]Deep Learning训练过程 [3]Deep Learning模型之 ...

随机推荐

  1. beautifulsoup4 用法一二

    声明一个beautifulsoup4对象 bs = ( url,//路由 html_parser,//解析html代码 encoding//编码)//另一种请求解析方法 import requests ...

  2. Python 面向对象Ⅴ

    基础重载方法 下表列出了一些通用的功能,你可以在自己的类重写: 运算符重载 Python同样支持运算符重载,实例如下: 以上代码执行结果如下所示: 类属性与方法 类的私有属性 __private_at ...

  3. 关于C#的学习

    长期以来对C#的认识一直停留在微软件开发的完全面向对象的语言的模糊印象上,对其工程也缺乏多文件以上级别的修改能力,而当前流行度的驱使下,想深入了解它并运用. 于是从git上下载了一个C#开源项目,打开 ...

  4. jquery which事件 语法

    jquery which事件 语法 作用:which 属性指示按了哪个键或按钮.大理石平台精度等级 语法:event.whic 参数: 参数 描述 event     必需.规定要检查的事件.这个 e ...

  5. bytes和bytearray总结

    The core built-in types for manipulating binary data are bytes and bytearray. They are supported by ...

  6. selenium实现chrome分屏截图的合并

    selenium的截图功能在chrome下无法实现,但是可以操作滚动条来一屏一屏的截图,然后再合并成一张图,合并图片的代码在网上找的,十分感谢那位朋友,具体解决方案如下:直接上代码: def capt ...

  7. AtCoder AGC017C Snuke and Spells

    题目链接 https://atcoder.jp/contests/agc017/tasks/agc017_c 题解 很久前不会做看了题解,现在又看了一下,只想说,这种智商题真的杀我... 转化成如果现 ...

  8. Zookeeper入门(六)之zkCli.sh对节点的增删改查

    参考地址为:https://www.cnblogs.com/sherrykid/p/5813148.html 1.连接 在 bin 目录下的  zkCli.sh  就是ZooKeeper客户端 ./z ...

  9. 安装Dubbo 并且安装注册中心(Zookeeper-3.3.6)

    安装zookeeper 安装Tomcat 载dubbo-admin-2.5.4.war 进入Apache ZooKeeper官方网站进行下载,https://zookeeper.apache.org/ ...

  10. 第六周总结&实验报告四

    这周是放国庆节的假,所有没有进行深入的学习,只是写了个实验的题目,也发现了自己在基础上还是要加强学习. 实验四 类的继承 一. 实验目的 (1) 掌握类的继承方法: (2) 变量的继承和覆盖,方法的继 ...