卷积神经网络目前被广泛地用在图片识别上, 已经有层出不穷的应用, 如果你对卷积神经网络充满好奇心,这里为你带来pytorch实现cnn一些入门的教程代码

#首先导入包

import torch
from torch.autograd import Variable
import torch.nn as nn
import torchvision
import torch.utils.data as Data

#一、数据准备

#训练数据:用了torchvision.datasets.MNIST,root是文件路径,train为True(这是训练数据),transform是把图像数据转换为张量,download(如果本地已有该文件选择false,没有就选择true)

train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,transform=torchvision.transforms.ToTensor(),download=False)

#训练数据:同上,train为False(这是测试数据)

test_data = torchvision.datasets.MNIST(root='./mnist/',train=False)

# "训练数据加载器":dataset为训练数据,shuflle为打乱数据的顺序,batch_size是让数据50个为一组

train_loader = Data.DataLoader(dataset=train_data,shuffle=True,batch_size=50)

test_data.test_data.size()

torch.Size([10000, 28, 28])

#测试数据 test_data下的test_data为测试数据,因为下面conv2d输入的为4维数据,所以此处用torch.unsqueeze升维

test_x = Variable(torch.unsqueeze(test_data.test_data,dim=1),volatile=True).type(torch.FloatTensor)

#测试数据目标值

test_y = test_data.test_labels

#二、实现模型

class CNN(nn.Module):
  def __init__(self):
    super(CNN,self).__init__()

    #conv2d参数:输入1维,输出16维,5个卷积核(kernel),步长(stride)为1,padding是2(如果想要 con2d 出来的图片长宽没有变化, padding=(kernel_size-1)/2 当 stride=1)
    self.conv1 = nn.Sequential(nn.Conv2d(1,16,5,1,2),nn.ReLU(),nn.MaxPool2d(2))
    self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),nn.ReLU(),nn.MaxPool2d(2))

    #Linear参数:输入维数,输出分的种类数
    self.out = nn.Linear(32*7*7,10)
  def forward(self,x):
    x1 = self.conv1(x)
    x2 = self.conv2(x1)

    #这里给x3降为2维可以让linear函数使用
    x3 = x2.view(x2.size(0),-1)
    out = self.out(x3)
    return out

#自动调整参数,最优化模型

cnn = CNN()

optimizer = torch.optim.Adam(cnn.parameters(),lr = 0.02)
loss_func = nn.CrossEntropyLoss()

#三、训练模型

for step,(x,y) in enumerate(train_loader):
  x = Variable(x)
  y = Variable(y)
  out = cnn(x)
  loss = loss_func(out,y)

  #以下为固定操作,为了训练每一条数据,不断调整参数
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

#四、测试

predict = cnn(test_x[:10])
res = torch.max(predict,1)[1]

res #测试数据

tensor([7, 2, 1, 0, 4, 1, 4, 9, 9, 9])

test_y[:10] #真实数据

tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])

#在这里我们发现前十个数据分类准确率达到90

Pytorch卷积神经网络识别手写数字集的更多相关文章

  1. 如何用卷积神经网络CNN识别手写数字集?

    前几天用CNN识别手写数字集,后来看到kaggle上有一个比赛是识别手写数字集的,已经进行了一年多了,目前有1179个有效提交,最高的是100%,我做了一下,用keras做的,一开始用最简单的MLP, ...

  2. Tensorflow搭建卷积神经网络识别手写英语字母

    更新记录: 2018年2月5日 初始文章版本 近几天需要进行英语手写体识别,查阅了很多资料,但是大多数资料都是针对MNIST数据集的,并且主要识别手写数字.为了满足实际的英文手写识别需求,需要从训练集 ...

  3. PyTorch基础——使用卷积神经网络识别手写数字

    一.介绍 实验内容 内容包括用 PyTorch 来实现一个卷积神经网络,从而实现手写数字识别任务. 除此之外,还对卷积神经网络的卷积核.特征图等进行了分析,引出了过滤器的概念,并简单示了卷积神经网络的 ...

  4. Python实现神经网络算法识别手写数字集

    最近忙里偷闲学习了一点机器学习的知识,看到神经网络算法时我和阿Kun便想到要将它用Python代码实现.我们用了两种不同的方法来编写它.这里只放出我的代码. MNIST数据集基于美国国家标准与技术研究 ...

  5. 使用TensorFlow的卷积神经网络识别手写数字(3)-识别篇

    from PIL import Image import numpy as np import tensorflow as tf import time bShowAccuracy = True # ...

  6. 使用TensorFlow的卷积神经网络识别手写数字(2)-训练篇

    import numpy as np import tensorflow as tf import matplotlib import matplotlib.pyplot as plt import ...

  7. 使用TensorFlow的卷积神经网络识别手写数字(1)-预处理篇

    功能: 将文件夹下的20*20像素黑白图片,根据重心位置绘制到28*28图片上,然后保存.经过预处理的图片有利于数字的准确识别.参见MNIST对图片的要求. 此处可下载已处理好的图片: https:/ ...

  8. 李宏毅 Keras手写数字集识别(优化篇)

    在之前的一章中我们讲到的keras手写数字集的识别中,所使用的loss function为‘mse’,即均方差.那我们如何才能知道所得出的结果是不是overfitting?我们通过运行结果中的trai ...

  9. TensorFlow卷积神经网络实现手写数字识别以及可视化

    边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...

随机推荐

  1. httpclient cer

    X509Certificate2 cer = new X509Certificate2(@"path", "********", X509KeyStorageF ...

  2. Vue3.0报错error: Unexpected console statement (no-console) 解决办法

    写项目过程中用ESLint遵守代码规范很有必要,但是对于一些规范也很是无语,比如:‘Unexpected console statement (no-console)’,连console都不能用,这就 ...

  3. 解决Ubuntu18.10 网络图标经常消失连不上网问题

    我不知道是什么原因,Ubuntu虚拟机经常会出现无法上网的问题? 此时右上角没有网络标志,Settings->NetWork也只有VPN一项,不知道咋用. 在网上终于找到了方法,亲测有效:htt ...

  4. CSRF漏洞的挖掘与利用

    0x01 CSRF的攻击原理 CSRF 百度上的意思是跨站请求伪造,其实最简单的理解我们可以这么讲,假如一个微博关注用户的一个功能,存在CSRF漏洞,那么此时黑客只需要伪造一个页面让受害者间接或者直接 ...

  5. 内网漫游之SOCKS代理大结局

    0×01 引言 在实际渗透过程中,我们成功入侵了目标服务器.接着我们想在本机上通过浏览器或者其他客户端软件访问目标机器内部网络中所开放的端口,比如内网的3389端口.内网网站8080端口等等.传统的方 ...

  6. Buffer、核心API、npm

      Buffer基本操作 Buffer对象是Node处理二进制数据的一个接口.它是Node原生提供的全局对象,可以直接使用,不需要require(‘buffer’). 实例化 Buffer.from( ...

  7. iOS 关于NavigationController返回的一些笔记

    1.理解NavigationController返回机制 一般NavigationController下的子view只有一层或者有很多层,子view返回最顶层则可以直接用 [self.navigati ...

  8. 【恢复】Redo日志文件丢失的恢复

    第一章 Redo文件丢失的恢复 1.1  online redolog file 丢失 联机Redo日志是Oracle数据库中比较核心的文件,当Redo日志文件异常之后,数据库就无法正常启动,而且有丢 ...

  9. python调用C语言接口

    python调用C语言接口 注:本文所有示例介绍基于linux平台 在底层开发中,一般是使用C或者C++,但是有时候为了开发效率或者在写测试脚本的时候,会经常使用到python,所以这就涉及到一个问题 ...

  10. (五)Kubernetes Pod状态和生命周期管理

    什么是Pod Pod是kubernetes中你可以创建和部署的最小也是最简的单位.Pod代表着集群中运行的进程. Pod中封装着应用的容器(有的情况下是好几个容器),存储.独立的网络IP,管理容器如何 ...