Pytorch是torch的Python版本,对TensorFlow造成很大的冲击,TensorFlow无疑是最流行的,但是Pytorch号称在诸多性能上要优于TensorFlow,比如在RNN的训练上,所以Pytorch也吸引了很多人的关注。之前有一篇关于TensorFlow实现的CNN可以用来做对比。

下面我们就开始用Pytorch实现CNN。

step 0 导入需要的包

  1. import torch
  2. import torch.nn as nn
  3. from torch.autograd import Variable
  4. import torch.utils.data as data
  5. import matplotlib.pyplot as plt

step 1  数据预处理

这里需要将training data转化成torch能够使用的DataLoader,这样可以方便使用batch进行训练。

  1. import torchvision #数据库模块
  2.  
  3. torch.manual_seed(1) #reproducible
  4.  
  5. #Hyper Parameters
  6. EPOCH = 1
  7. BATCH_SIZE = 50
  8. LR = 0.001
  9.  
  10. train_data = torchvision.datasets.MNIST(
  11. root='/mnist/', #保存位置
  12. train=True, #training set
  13. transform=torchvision.transforms.ToTensor(), #converts a PIL.Image or numpy.ndarray
  14. #to torch.FloatTensor(C*H*W) in range(0.0,1.0)
  15. download=True
  16. )
  17.  
  18. test_data = torchvision.datasets.MNIST(root='/MNIST/')
  19. #如果是普通的Tensor数据,想使用torch_dataset = data.TensorDataset(data_tensor=x, target_tensor=y)
  20. #将Tensor转换成torch能识别的dataset
  21. #批训练, 50 samples, 1 channel, 28*28, (50, 1, 28 ,28)
  22. train_loader = data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
  23.  
  24. test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000]/255.
  25. test_y = test_data.test_lables[:2000]

step 2 定义网络结构

需要指出的几个地方:1)class CNN需要继承Module ; 2)需要调用父类的构造方法:super(CNN, self).__init__()  ;3)在Pytorch中激活函数Relu也算是一层layer; 4)需要实现forward()方法,用于网络的前向传播,而反向传播只需要调用Variable.backward()即可。

  1. class CNN(nn.Module):
  2. def __init__(self):
  3. super(CNN, self).__init__()
  4. self.conv1 = nn.Sequential( #input shape (1,28,28)
  5. nn.Conv2d(in_channels=1, #input height
  6. out_channels=16, #n_filter
  7. kernel_size=5, #filter size
  8. stride=1, #filter step
  9. padding=2 #con2d出来的图片大小不变
  10. ), #output shape (16,28,28)
  11. nn.ReLU(),
  12. nn.MaxPool2d(kernel_size=2) #2x2采样,output shape (16,14,14)
  13.  
  14. )
  15. self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), #output shape (32,7,7)
  16. nn.ReLU(),
  17. nn.MaxPool2d(2))
  18. self.out = nn.Linear(32*7*7,10)
  19.  
  20. def forward(self, x):
  21. x = self.conv1(x)
  22. x = self.conv2(x)
  23. x = x.view(x.size(0), -1) #flat (batch_size, 32*7*7)
  24. output = self.out(x)
  25. return output

step 3 查看网络结构

使用print(cnn)可以看到网络的结构详细信息,ReLU()真的是一层layer。

  1. cnn = CNN()
  2. print(cnn)

step 4 训练

指定optimizer,loss function,需要特别指出的是记得每次反向传播前都要清空上一次的梯度,optimizer.zero_grad()。

  1. #optimizer
  2. optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
  3.  
  4. #loss_fun
  5. loss_func = nn.CrossEntropyLoss()
  6.  
  7. #training loop
  8. for epoch in range(EPOCH):
  9. for i, (x, y) in enumerate(train_loader):
  10. batch_x = Variable(x)
  11. batch_y = Variable(y)
  12. #输入训练数据
  13. output = cnn(batch_x)
  14. #计算误差
  15. loss = loss_func(output, batch_y)
  16. #清空上一次梯度
  17. optimizer.zero_grad()
  18. #误差反向传递
  19. loss.backward()
  20. #优化器参数更新
  21. optimizer.step()

step 5 预测结果

  1. test_output =cnn(test_x[:10])
  2. pred_y = torch.max(test_output,1)[1].data.numpy().squeeze()
  3. print(pred_y, 'prediction number')
  4. print(test_y[:10])

reference:

莫凡python pytorch 教程

Pytorch实现卷积神经网络CNN的更多相关文章

  1. 写给程序员的机器学习入门 (八) - 卷积神经网络 (CNN) - 图片分类和验证码识别

    这一篇将会介绍卷积神经网络 (CNN),CNN 模型非常适合用来进行图片相关的学习,例如图片分类和验证码识别,也可以配合其他模型实现 OCR. 使用 Python 处理图片 在具体介绍 CNN 之前, ...

  2. 卷积神经网络(CNN)前向传播算法

    在卷积神经网络(CNN)模型结构中,我们对CNN的模型结构做了总结,这里我们就在CNN的模型基础上,看看CNN的前向传播算法是什么样子的.重点会和传统的DNN比较讨论. 1. 回顾CNN的结构 在上一 ...

  3. 卷积神经网络(CNN)反向传播算法

    在卷积神经网络(CNN)前向传播算法中,我们对CNN的前向传播算法做了总结,基于CNN前向传播算法的基础,我们下面就对CNN的反向传播算法做一个总结.在阅读本文前,建议先研究DNN的反向传播算法:深度 ...

  4. 卷积神经网络CNN总结

    从神经网络到卷积神经网络(CNN)我们知道神经网络的结构是这样的: 那卷积神经网络跟它是什么关系呢?其实卷积神经网络依旧是层级网络,只是层的功能和形式做了变化,可以说是传统神经网络的一个改进.比如下图 ...

  5. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  6. 深度学习之卷积神经网络(CNN)详解与代码实现(二)

    用Tensorflow实现卷积神经网络(CNN) 本文系作者原创,转载请注明出处:https://www.cnblogs.com/further-further-further/p/10737065. ...

  7. 深度学习之卷积神经网络(CNN)详解与代码实现(一)

    卷积神经网络(CNN)详解与代码实现 本文系作者原创,转载请注明出处:https://www.cnblogs.com/further-further-further/p/10430073.html 目 ...

  8. 【深度学习系列】卷积神经网络CNN原理详解(一)——基本原理

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  9. 卷积神经网络(CNN)学习笔记1:基础入门

    卷积神经网络(CNN)学习笔记1:基础入门 Posted on 2016-03-01   |   In Machine Learning  |   9 Comments  |   14935  Vie ...

随机推荐

  1. el表达式的首字母大小写问题

    EL表达式获取对象属性的原理是这样的: 以表达式${user.name}为例 EL表达式会根据name去User类里寻找这个name的get方法,此时会自动把name首字母大写并加上get前缀,一旦找 ...

  2. spotlight on windows 监控

    1. spotlight on windows 安装 下载 https://pan.baidu.com/s/1qYi3lec Spotlight大家可以从其官方网站(http://www.quest. ...

  3. ModelShowDialog缓存上次浏览的URL

    1. 一种解决方法设置每次清楚浏览的页面. In IE7, go to Tools  |  Internet Options.  Click the Browsing History "Se ...

  4. 一个最简单的JStorm例子

    最简单的JStorm例子分为以下几个步骤: 1.生成Topology Map conf = new HashMp(); //topology所有自定义的配置均放入这个Map TopologyBuild ...

  5. 【BZOJ4260】Codechef REBXOR Trie树+贪心

    [BZOJ4260]Codechef REBXOR Description Input 输入数据的第一行包含一个整数N,表示数组中的元素个数. 第二行包含N个整数A1,A2,…,AN. Output ...

  6. 170110、Spring 事物机制总结

    spring两种事物处理机制,一是声明式事物,二是编程式事物 声明式事物 1)Spring的声明式事务管理在底层是建立在AOP的基础之上的.其本质是对方法前后进行拦截,然后在目标方法开始之前创建或者加 ...

  7. Powershell About LocalGroupMembership

    一: 结合active directory获取本地群组成员信息(包含本地用户和域用户,及域用户的情况 $DBServer = "xxxx" $DBDatabase = " ...

  8. IO流入门-第一章-FileInputStream

    FileInputStreamj基本用法和方法示例 import java.io.*; public class FileInputStreamTest01 { public static void ...

  9. DOM 综合练习(二)

    // 需求一: 二级联动菜单 <html> <head> <style type="text/css"> select{ width:100px ...

  10. VS c++ opencv画图

    任务:用c++在图片上画线 之前用过python的opencv,所以直接想到了用c++的opencv来画线. 但关键就是VS中如何配置c++ opencv库的问题: vs中opencv库的配置:htt ...