Pytorch实现卷积神经网络CNN
Pytorch是torch的Python版本,对TensorFlow造成很大的冲击,TensorFlow无疑是最流行的,但是Pytorch号称在诸多性能上要优于TensorFlow,比如在RNN的训练上,所以Pytorch也吸引了很多人的关注。之前有一篇关于TensorFlow实现的CNN可以用来做对比。
下面我们就开始用Pytorch实现CNN。
step 0 导入需要的包
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as data
import matplotlib.pyplot as plt
step 1 数据预处理
这里需要将training data转化成torch能够使用的DataLoader,这样可以方便使用batch进行训练。
import torchvision #数据库模块 torch.manual_seed(1) #reproducible #Hyper Parameters
EPOCH = 1
BATCH_SIZE = 50
LR = 0.001 train_data = torchvision.datasets.MNIST(
root='/mnist/', #保存位置
train=True, #training set
transform=torchvision.transforms.ToTensor(), #converts a PIL.Image or numpy.ndarray
#to torch.FloatTensor(C*H*W) in range(0.0,1.0)
download=True
) test_data = torchvision.datasets.MNIST(root='/MNIST/')
#如果是普通的Tensor数据,想使用torch_dataset = data.TensorDataset(data_tensor=x, target_tensor=y)
#将Tensor转换成torch能识别的dataset
#批训练, 50 samples, 1 channel, 28*28, (50, 1, 28 ,28)
train_loader = data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000]/255.
test_y = test_data.test_lables[:2000]
step 2 定义网络结构
需要指出的几个地方:1)class CNN需要继承Module ; 2)需要调用父类的构造方法:super(CNN, self).__init__() ;3)在Pytorch中激活函数Relu也算是一层layer; 4)需要实现forward()方法,用于网络的前向传播,而反向传播只需要调用Variable.backward()即可。
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential( #input shape (1,28,28)
nn.Conv2d(in_channels=1, #input height
out_channels=16, #n_filter
kernel_size=5, #filter size
stride=1, #filter step
padding=2 #con2d出来的图片大小不变
), #output shape (16,28,28)
nn.ReLU(),
nn.MaxPool2d(kernel_size=2) #2x2采样,output shape (16,14,14) )
self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), #output shape (32,7,7)
nn.ReLU(),
nn.MaxPool2d(2))
self.out = nn.Linear(32*7*7,10) def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1) #flat (batch_size, 32*7*7)
output = self.out(x)
return output
step 3 查看网络结构
使用print(cnn)可以看到网络的结构详细信息,ReLU()真的是一层layer。
cnn = CNN()
print(cnn)

step 4 训练
指定optimizer,loss function,需要特别指出的是记得每次反向传播前都要清空上一次的梯度,optimizer.zero_grad()。
#optimizer
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) #loss_fun
loss_func = nn.CrossEntropyLoss() #training loop
for epoch in range(EPOCH):
for i, (x, y) in enumerate(train_loader):
batch_x = Variable(x)
batch_y = Variable(y)
#输入训练数据
output = cnn(batch_x)
#计算误差
loss = loss_func(output, batch_y)
#清空上一次梯度
optimizer.zero_grad()
#误差反向传递
loss.backward()
#优化器参数更新
optimizer.step()
step 5 预测结果
test_output =cnn(test_x[:10])
pred_y = torch.max(test_output,1)[1].data.numpy().squeeze()
print(pred_y, 'prediction number')
print(test_y[:10])
reference:
Pytorch实现卷积神经网络CNN的更多相关文章
- 写给程序员的机器学习入门 (八) - 卷积神经网络 (CNN) - 图片分类和验证码识别
这一篇将会介绍卷积神经网络 (CNN),CNN 模型非常适合用来进行图片相关的学习,例如图片分类和验证码识别,也可以配合其他模型实现 OCR. 使用 Python 处理图片 在具体介绍 CNN 之前, ...
- 卷积神经网络(CNN)前向传播算法
在卷积神经网络(CNN)模型结构中,我们对CNN的模型结构做了总结,这里我们就在CNN的模型基础上,看看CNN的前向传播算法是什么样子的.重点会和传统的DNN比较讨论. 1. 回顾CNN的结构 在上一 ...
- 卷积神经网络(CNN)反向传播算法
在卷积神经网络(CNN)前向传播算法中,我们对CNN的前向传播算法做了总结,基于CNN前向传播算法的基础,我们下面就对CNN的反向传播算法做一个总结.在阅读本文前,建议先研究DNN的反向传播算法:深度 ...
- 卷积神经网络CNN总结
从神经网络到卷积神经网络(CNN)我们知道神经网络的结构是这样的: 那卷积神经网络跟它是什么关系呢?其实卷积神经网络依旧是层级网络,只是层的功能和形式做了变化,可以说是传统神经网络的一个改进.比如下图 ...
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- 深度学习之卷积神经网络(CNN)详解与代码实现(二)
用Tensorflow实现卷积神经网络(CNN) 本文系作者原创,转载请注明出处:https://www.cnblogs.com/further-further-further/p/10737065. ...
- 深度学习之卷积神经网络(CNN)详解与代码实现(一)
卷积神经网络(CNN)详解与代码实现 本文系作者原创,转载请注明出处:https://www.cnblogs.com/further-further-further/p/10430073.html 目 ...
- 【深度学习系列】卷积神经网络CNN原理详解(一)——基本原理
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- 卷积神经网络(CNN)学习笔记1:基础入门
卷积神经网络(CNN)学习笔记1:基础入门 Posted on 2016-03-01 | In Machine Learning | 9 Comments | 14935 Vie ...
随机推荐
- VS2010程序打包操作(结合图片详细讲解)
附视频教程:http://www.cnblogs.com/mengdesen/archive/2011/06/14/2080312.html 1. 在vs2010 选择“新建项目”----“其他项 ...
- mac上的webStorm上配置gitHub
一,webStorm下,首先打开Preferences; 二,在Version Control目录下,选择GitHub,填写有边的内容; 注意:填写完Login和Password的以后,点击Test一 ...
- 【BZOJ2212】[Poi2011]Tree Rotations 线段树合并
[BZOJ2212][Poi2011]Tree Rotations Description Byteasar the gardener is growing a rare tree called Ro ...
- [Algorithms] Longest Common Substring
The Longest Common Substring (LCS) problem is as follows: Given two strings s and t, find the length ...
- 《挑战程序设计竞赛》2.6 数学问题-辗转相除法 AOJ0005 POJ2429 1930(1)
AOJ0005 http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=0005 题意 给定两个数,求其最大公约数GCD以及最小公倍数LCM. ...
- Taylor series
w用有限来表达无限,由已知到未知,化未知为已知. https://en.wikipedia.org/wiki/Taylor_series The Greek philosopher Zeno cons ...
- sqlalchemy(二)高级用法 2
转自:https://www.cnblogs.com/coder2012/p/4746941.html 外键以及relationship 首先创建数据库,在这里一个user对应多个address,因此 ...
- 转!!SpringMVC与Struts2区别与比较总结
1.Struts2是类级别的拦截, 一个类对应一个request上下文,SpringMVC是方法级别的拦截,一个方法对应一个request上下文,而方法同时又跟一个url对应,所以说从架构本身上Spr ...
- A4纸网页打印中对应像素的设定和换算
最近开发项目时遇到了网页打印的问题,这是问题之二,打印宽度设置 在公制长度单位与屏幕分辨率进行换算时,必须用到一个DPI(Dot PerInch)指标. 经过我仔细的测试,发现了网页打印中,默认采用 ...
- 我的Android进阶之旅------>Android颜色值(#AARRGGBB)透明度百分比和十六进制对应关系以及计算方法
我的Android进阶之旅-->Android颜色值(RGB)所支持的四种常见形式 透明度百分比和十六进制对应关系表格 透明度 十六进制 100% FF 99% FC 98% FA 97% F7 ...