1. 利用平pytorch搭建简单的神经网络实现minist手写字体的识别,采用三层线性函数迭代运算,使得其具备一定的非线性转化与运算能力,其数学原理如下:

  1. 其具体实现代码如下所示:
    import torch
    import matplotlib.pyplot as plt
    def plot_curve(data): #曲线输出函数构建
    fig=plt.figure()
    plt.plot(range(len(data)),data,color="blue")
    plt.legend(["value"],loc="upper right")
    plt.xlabel("step")
    plt.ylabel("value")
    plt.show()
  2.  
  3. def plot_image(img,label,name): #输出二维图像灰度图
    fig=plt.figure()
    for i in range(6):
    plt.subplot(2,3,i+1)
    plt.tight_layout()
    plt.imshow(img[i][0]*0.3081+0.1307,cmap="gray",interpolation="none")
    plt.title("{}:{}".format(name, label[i].item()))
    plt.xticks([])
    plt.yticks([])
    plt.show()
    def one_hot(label,depth=10): #根据分类结果的数目将结果转换为一定的矩阵形式[n,1],n为分类结果的数目
    out=torch.zeros(label.size(0),depth)
    idx=torch.LongTensor(label).view(-1,1)
    out.scatter_(dim=1,index=idx,value=1)
    return out
  4.  
  5. batch_size=512
    import torch
    from torch import nn #完成神经网络的构建包
    from torch.nn import functional as F #包含常用的函数包
    from torch import optim #优化工具包
    import torchvision #视觉工具包
    import matplotlib.pyplot as plt
    from utils import plot_curve,plot_image,one_hot
    #step1 load dataset 加载数据包
    train_loader=torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("minist_data",train=True,download=True,transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
    ])),
    batch_size=batch_size,shuffle=True)
    test_loader=torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("minist_data",train=True,download=False,transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
    ])),
    batch_size=batch_size,shuffle=False)
    x,y=next(iter(train_loader))
    print(x.shape,y.shape)
    plot_image(x,y,"image")
    print(x)
    print(y)
  6.  
  7. #构建神经网络结构
    class Net(nn.Module):
    def __init__(self):
    super(Net,self).__init__()
    #xw+b
    self.fc1=nn.Linear(28*28,256)
    self.fc2=nn.Linear(256,64)
    self.fc3=nn.Linear(64,10)
    def forward(self, x):
    #x:[b,1,28,28]
    #h1=relu(xw1+b1)
    x=F.relu(self.fc1(x))
    #h2=relu(h1w2+b2)
    x=F.relu(self.fc2(x))
    #h3=h2w3+b3
    x=(self.fc3(x))
    return x
  8.  
  9. net=Net()
    #[w1,b1,w2,b2,w3,b3]
    optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)
    train_loss=[]
    for epoch in range(3):
    for batch_idx,(x,y) in enumerate(train_loader):
    #x:[b,1,28,28],y:[512]
    x=x.view(x.size(0),28*28)
    # => [b,10]
    out =net(x)
    # [b,10]
    y_onehot=one_hot(y)
    #loss=mse(out,y_onehot)
    loss= F.mse_loss(out,y_onehot)
  10.  
  11. optimizer.zero_grad()
    loss.backward()
    #w'=w-lr*grad
    optimizer.step()
    train_loss.append(loss.item())
  12.  
  13. if batch_idx %10==0:
    print(epoch,batch_idx,loss.item()) #输出其预测loss损失函数的变化曲线
    plot_curve(train_loss)
    #get optimal [w1,b1,w2,b2,w3,b3]
  14.  
  15. total_correct=0
    for x,y in test_loader:
    x=x.view(x.size(0),28*28)
    out=net(x)
    pred=out.argmax(dim=1)
    correct=pred.eq(y).sum().float().item()
    total_correct+=correct
    total_num=len(test_loader.dataset)
    acc=total_correct/total_num
    print("test.acc:",acc) #输出整体预测的准确度
  16.  
  17. x,y=next(iter(test_loader))
    out=net(x.view(x.size(0),28*28))
    pred=out.argmax(dim=1)
    plot_image(x,pred,"test")
    实现结果如下所示:

  1.  

pytorch深度学习神经网络实现手写字体识别的更多相关文章

  1. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  2. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  3. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

  4. 深度学习---手写字体识别程序分析(python)

    我想大部分程序员的第一个程序应该都是“hello world”,在深度学习领域,这个“hello world”程序就是手写字体识别程序. 这次我们详细的分析下手写字体识别程序,从而可以对深度学习建立一 ...

  5. 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别

    深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...

  6. BP神经网络的手写数字识别

    BP神经网络的手写数字识别 ANN 人工神经网络算法在实践中往往给人难以琢磨的印象,有句老话叫“出来混总是要还的”,大概是由于具有很强的非线性模拟和处理能力,因此作为代价上帝让它“黑盒”化了.作为一种 ...

  7. TensorFlow卷积神经网络实现手写数字识别以及可视化

    边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...

  8. 利用c++编写bp神经网络实现手写数字识别详解

    利用c++编写bp神经网络实现手写数字识别 写在前面 从大一入学开始,本菜菜就一直想学习一下神经网络算法,但由于时间和资源所限,一直未展开比较透彻的学习.大二下人工智能课的修习,给了我一个学习的契机. ...

  9. 第二节,TensorFlow 使用前馈神经网络实现手写数字识别

    一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...

随机推荐

  1. Jekyll+Github个人博客构建之路

    请参考: http://robotkang.cc/2017/03/HowToCreateBlog/

  2. mysql 存入数据库 中文乱码

    1.要保证数据库.表.字段都是utf-8的数据类型.排序一直即可. 数据库的在数据库属性里面改: 表的在设计表里面改: 字段的也是在设计表里面改: 常用命令: -- 检查字符集类型show varia ...

  3. Springboot中使用kafka

    注:kafka消息队列默认采用配置消息主题进行消费,一个topic中的消息只能被同一个组(groupId)的消费者中的一个消费者消费. 1.在pom.xml依赖下新添加一下kafka依赖ar包 < ...

  4. stl_string复习

    #include <iostream>#include <string>#include <algorithm>using namespace std; void ...

  5. Vue.js开发去哪儿网WebApp

    一.项目介绍 这个项目主要参考了去哪儿网的布局,完成了首页.城市选择页面.详情页面的开发. 首页:实现了多区域轮播的功能,以及多区域列表的展示: 城市选择页面:在这个页面实现了城市展示.城市搜索.城市 ...

  6. 对委托 以及 action func 匿名函数 以及 lambda表达式的简单记录

    class Program { public delegate void MyDelegate(string str); static void Main(string[] args) { // My ...

  7. 画风清奇!看看大佬怎么玩Python

    一提到Python,不少人脑海里都会浮现出几个关键词"数据分析""网络爬虫""人工智能"等,但Python的用法,远不止这些.让我们看看国内 ...

  8. Burpsuite 工具详解(常用模块之proxy、spider 、decoder)

    Burpsuite常用模块之proxy.spider .decoder                                                 是一款集成化渗透测试工具(jav ...

  9. SQL 查询每组的第一条记录

    CREATE TABLE [dbo].[test1]( [program_id] [int] NULL, [person_id] [int] NULL ) ON [PRIMARY] /*查询每组分组中 ...

  10. 怪异盒子模型和行内元素的float

    设置了float属性的行内元素的display值会变成inline-block 怪异盒子模型: box-sizing:border-box:元素content包含内间距和border