Pytorch是torch的Python版本,对TensorFlow造成很大的冲击,TensorFlow无疑是最流行的,但是Pytorch号称在诸多性能上要优于TensorFlow,比如在RNN的训练上,所以Pytorch也吸引了很多人的关注。之前有一篇关于TensorFlow实现的CNN可以用来做对比。

下面我们就开始用Pytorch实现CNN。

step 0 导入需要的包

 import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as data
import matplotlib.pyplot as plt

step 1  数据预处理

这里需要将training data转化成torch能够使用的DataLoader,这样可以方便使用batch进行训练。

 import torchvision  #数据库模块

 torch.manual_seed(1) #reproducible

 #Hyper Parameters
EPOCH = 1
BATCH_SIZE = 50
LR = 0.001 train_data = torchvision.datasets.MNIST(
root='/mnist/', #保存位置
train=True, #training set
transform=torchvision.transforms.ToTensor(), #converts a PIL.Image or numpy.ndarray
#to torch.FloatTensor(C*H*W) in range(0.0,1.0)
download=True
) test_data = torchvision.datasets.MNIST(root='/MNIST/')
#如果是普通的Tensor数据,想使用torch_dataset = data.TensorDataset(data_tensor=x, target_tensor=y)
#将Tensor转换成torch能识别的dataset
#批训练, 50 samples, 1 channel, 28*28, (50, 1, 28 ,28)
train_loader = data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000]/255.
test_y = test_data.test_lables[:2000]

step 2 定义网络结构

需要指出的几个地方:1)class CNN需要继承Module ; 2)需要调用父类的构造方法:super(CNN, self).__init__()  ;3)在Pytorch中激活函数Relu也算是一层layer; 4)需要实现forward()方法,用于网络的前向传播,而反向传播只需要调用Variable.backward()即可。

 class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential( #input shape (1,28,28)
nn.Conv2d(in_channels=1, #input height
out_channels=16, #n_filter
kernel_size=5, #filter size
stride=1, #filter step
padding=2 #con2d出来的图片大小不变
), #output shape (16,28,28)
nn.ReLU(),
nn.MaxPool2d(kernel_size=2) #2x2采样,output shape (16,14,14) )
self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), #output shape (32,7,7)
nn.ReLU(),
nn.MaxPool2d(2))
self.out = nn.Linear(32*7*7,10) def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1) #flat (batch_size, 32*7*7)
output = self.out(x)
return output

step 3 查看网络结构

使用print(cnn)可以看到网络的结构详细信息,ReLU()真的是一层layer。

 cnn = CNN()
print(cnn)

step 4 训练

指定optimizer,loss function,需要特别指出的是记得每次反向传播前都要清空上一次的梯度,optimizer.zero_grad()。

 #optimizer
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) #loss_fun
loss_func = nn.CrossEntropyLoss() #training loop
for epoch in range(EPOCH):
for i, (x, y) in enumerate(train_loader):
batch_x = Variable(x)
batch_y = Variable(y)
#输入训练数据
output = cnn(batch_x)
#计算误差
loss = loss_func(output, batch_y)
#清空上一次梯度
optimizer.zero_grad()
#误差反向传递
loss.backward()
#优化器参数更新
optimizer.step()

step 5 预测结果

 test_output =cnn(test_x[:10])
pred_y = torch.max(test_output,1)[1].data.numpy().squeeze()
print(pred_y, 'prediction number')
print(test_y[:10])

reference:

莫凡python pytorch 教程

Pytorch实现卷积神经网络CNN的更多相关文章

  1. 写给程序员的机器学习入门 (八) - 卷积神经网络 (CNN) - 图片分类和验证码识别

    这一篇将会介绍卷积神经网络 (CNN),CNN 模型非常适合用来进行图片相关的学习,例如图片分类和验证码识别,也可以配合其他模型实现 OCR. 使用 Python 处理图片 在具体介绍 CNN 之前, ...

  2. 卷积神经网络(CNN)前向传播算法

    在卷积神经网络(CNN)模型结构中,我们对CNN的模型结构做了总结,这里我们就在CNN的模型基础上,看看CNN的前向传播算法是什么样子的.重点会和传统的DNN比较讨论. 1. 回顾CNN的结构 在上一 ...

  3. 卷积神经网络(CNN)反向传播算法

    在卷积神经网络(CNN)前向传播算法中,我们对CNN的前向传播算法做了总结,基于CNN前向传播算法的基础,我们下面就对CNN的反向传播算法做一个总结.在阅读本文前,建议先研究DNN的反向传播算法:深度 ...

  4. 卷积神经网络CNN总结

    从神经网络到卷积神经网络(CNN)我们知道神经网络的结构是这样的: 那卷积神经网络跟它是什么关系呢?其实卷积神经网络依旧是层级网络,只是层的功能和形式做了变化,可以说是传统神经网络的一个改进.比如下图 ...

  5. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  6. 深度学习之卷积神经网络(CNN)详解与代码实现(二)

    用Tensorflow实现卷积神经网络(CNN) 本文系作者原创,转载请注明出处:https://www.cnblogs.com/further-further-further/p/10737065. ...

  7. 深度学习之卷积神经网络(CNN)详解与代码实现(一)

    卷积神经网络(CNN)详解与代码实现 本文系作者原创,转载请注明出处:https://www.cnblogs.com/further-further-further/p/10430073.html 目 ...

  8. 【深度学习系列】卷积神经网络CNN原理详解(一)——基本原理

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  9. 卷积神经网络(CNN)学习笔记1:基础入门

    卷积神经网络(CNN)学习笔记1:基础入门 Posted on 2016-03-01   |   In Machine Learning  |   9 Comments  |   14935  Vie ...

随机推荐

  1. Codeforces 448 C. Painting Fence

    递归.分治. . . C. Painting Fence time limit per test 1 second memory limit per test 512 megabytes input ...

  2. Eclipse UML插件

    Green UML http://green.sourceforge.net/ AmaterasUML http://amateras.sourceforge.jp/cgi-bin/fswiki_en ...

  3. 基于注解的Spring AOP拦截含有泛型的DAO

    出错场景 1.抽象类BaseDao public abstract class BaseDao<T> { public BaseDao() { entityClass = (Class&l ...

  4. RPC远程过程调用概念及实现

    RPC框架学习笔记 >>什么是RPC RPC 的全称是 Remote Procedure Call 是一种进程间通信方式. 它允许程序调用另一个地址空间(通常是共享网络的另一台机器上)的过 ...

  5. jquery刷新页面指定部位

    做好好几次了,经常忘记格式,这次记下来 $("#baseInfo").load("/KnowledgeLib/Personalization/QuestionUpdate ...

  6. sql---字段类型转换,sql获取当前时间

    一.字段类型转换 convert(要转换成的数据类型,字段名称)例如 convert(varchar(100),col_name)Convert(int,Order_no) 二.sql获取当前时间 s ...

  7. php 汉字验证码

    代码: captcha.php <?php //实现简单的验证码 //session_start session_start(); //画布 $image = imagecreatetrueco ...

  8. CSRF Cross-site request forgery

    w 跨站请求伪造目标站---无知用户---恶意站 http://fallensnow-jack.blogspot.com/2011/08/webgoat-csrf.html https://wiki. ...

  9. ArcGIS runtime sdk for wpf 授权

    这两天由于runtime sdk for wpf的授权和runtime sdk 其他产品的授权的不一样导致自己混乱不堪. 总结下吧. sdk 简介 当前ArcGIS runtime sdk 包括一系列 ...

  10. CMDB初步了解

    本节内容 浅谈ITIL CMDB介绍 Django自定义用户认证 Restful 规范 资产管理功能开发 浅谈ITIL TIL即IT基础架构库(Information Technology Infra ...