pytorch深度学习神经网络实现手写字体识别
利用平pytorch搭建简单的神经网络实现minist手写字体的识别,采用三层线性函数迭代运算,使得其具备一定的非线性转化与运算能力,其数学原理如下:
其具体实现代码如下所示:
import torch
import matplotlib.pyplot as plt
def plot_curve(data): #曲线输出函数构建
fig=plt.figure()
plt.plot(range(len(data)),data,color="blue")
plt.legend(["value"],loc="upper right")
plt.xlabel("step")
plt.ylabel("value")
plt.show() def plot_image(img,label,name): #输出二维图像灰度图
fig=plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt.tight_layout()
plt.imshow(img[i][0]*0.3081+0.1307,cmap="gray",interpolation="none")
plt.title("{}:{}".format(name, label[i].item()))
plt.xticks([])
plt.yticks([])
plt.show()
def one_hot(label,depth=10): #根据分类结果的数目将结果转换为一定的矩阵形式[n,1],n为分类结果的数目
out=torch.zeros(label.size(0),depth)
idx=torch.LongTensor(label).view(-1,1)
out.scatter_(dim=1,index=idx,value=1)
return out batch_size=512
import torch
from torch import nn #完成神经网络的构建包
from torch.nn import functional as F #包含常用的函数包
from torch import optim #优化工具包
import torchvision #视觉工具包
import matplotlib.pyplot as plt
from utils import plot_curve,plot_image,one_hot
#step1 load dataset 加载数据包
train_loader=torch.utils.data.DataLoader(
torchvision.datasets.MNIST("minist_data",train=True,download=True,transform=torchvision.transforms.Compose(
[torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
])),
batch_size=batch_size,shuffle=True)
test_loader=torch.utils.data.DataLoader(
torchvision.datasets.MNIST("minist_data",train=True,download=False,transform=torchvision.transforms.Compose(
[torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
])),
batch_size=batch_size,shuffle=False)
x,y=next(iter(train_loader))
print(x.shape,y.shape)
plot_image(x,y,"image")
print(x)
print(y) #构建神经网络结构
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
#xw+b
self.fc1=nn.Linear(28*28,256)
self.fc2=nn.Linear(256,64)
self.fc3=nn.Linear(64,10)
def forward(self, x):
#x:[b,1,28,28]
#h1=relu(xw1+b1)
x=F.relu(self.fc1(x))
#h2=relu(h1w2+b2)
x=F.relu(self.fc2(x))
#h3=h2w3+b3
x=(self.fc3(x))
return x net=Net()
#[w1,b1,w2,b2,w3,b3]
optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)
train_loss=[]
for epoch in range(3):
for batch_idx,(x,y) in enumerate(train_loader):
#x:[b,1,28,28],y:[512]
x=x.view(x.size(0),28*28)
# => [b,10]
out =net(x)
# [b,10]
y_onehot=one_hot(y)
#loss=mse(out,y_onehot)
loss= F.mse_loss(out,y_onehot) optimizer.zero_grad()
loss.backward()
#w'=w-lr*grad
optimizer.step()
train_loss.append(loss.item()) if batch_idx %10==0:
print(epoch,batch_idx,loss.item()) #输出其预测loss损失函数的变化曲线
plot_curve(train_loss)
#get optimal [w1,b1,w2,b2,w3,b3] total_correct=0
for x,y in test_loader:
x=x.view(x.size(0),28*28)
out=net(x)
pred=out.argmax(dim=1)
correct=pred.eq(y).sum().float().item()
total_correct+=correct
total_num=len(test_loader.dataset)
acc=total_correct/total_num
print("test.acc:",acc) #输出整体预测的准确度 x,y=next(iter(test_loader))
out=net(x.view(x.size(0),28*28))
pred=out.argmax(dim=1)
plot_image(x,pred,"test")
实现结果如下所示:
pytorch深度学习神经网络实现手写字体识别的更多相关文章
- 深度学习之 mnist 手写数字识别
深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识
深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...
- 深度学习---手写字体识别程序分析(python)
我想大部分程序员的第一个程序应该都是“hello world”,在深度学习领域,这个“hello world”程序就是手写字体识别程序. 这次我们详细的分析下手写字体识别程序,从而可以对深度学习建立一 ...
- 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别
深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...
- BP神经网络的手写数字识别
BP神经网络的手写数字识别 ANN 人工神经网络算法在实践中往往给人难以琢磨的印象,有句老话叫“出来混总是要还的”,大概是由于具有很强的非线性模拟和处理能力,因此作为代价上帝让它“黑盒”化了.作为一种 ...
- TensorFlow卷积神经网络实现手写数字识别以及可视化
边学习边笔记 https://www.cnblogs.com/felixwang2/p/9190602.html # https://www.cnblogs.com/felixwang2/p/9190 ...
- 利用c++编写bp神经网络实现手写数字识别详解
利用c++编写bp神经网络实现手写数字识别 写在前面 从大一入学开始,本菜菜就一直想学习一下神经网络算法,但由于时间和资源所限,一直未展开比较透彻的学习.大二下人工智能课的修习,给了我一个学习的契机. ...
- 第二节,TensorFlow 使用前馈神经网络实现手写数字识别
一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...
随机推荐
- python ui自动化之元素定位和常用操作
做ui自动化的最基础的就是页面元素定位了,如果连页面元素都定位不到,自动化从何谈起呢?接下来我们就看看页面元素定位的方法吧!(这里就用百度页面来进行演示) 一.最通用的几种定位方式: 1.通过id定位 ...
- mvn 搭建临时仓库批量下载依赖jar包
1.新建文件夹temp,在temp下新建setup.bat ,pom.xml 2.编辑setup.bat 和pom.xml bsetup.bat call mvn -f pom.xml depende ...
- Ubuntu python3 与 python2 的 pip调用
ubuntu 是默认装有pytthon2.x 与 python3.x 共存的 通常终端里 python 表示 python2 版本 python3 表示 python3 版本 (如果你没更改软链接设置 ...
- cmd添加管理员账号
net user 用户名 密码 /add net localgroup Administrators 用户名 /add
- laravel 创建授权策略
用户只能编辑自己的资料 在完成对未登录用户的限制之后,接下来我们要限制的是已登录用户的操作,当 id 为 1 的用户去尝试更新 id 为 2 的用户信息时,我们应该返回一个 403 禁止访问的异常.在 ...
- Sqlmap 工具用法详解
Sqlmap 工具用法详解 sqlmap是一款自动化的sql注入工具. 1.主要功能:扫描.发现.利用给定的url的sql注入漏 ...
- idea整合scala
scala依赖java环境,首先下载jdk1.8 64位 1.windows安装scala环境 下载scala环境,执行 进入doc窗口输入scala -version查看scala版本号,出现版本号 ...
- linux与python3安装redis
1.linux安装redis服务 apt-get install redis* 进入客户端管理 redis-cli 启动服务 service redis startservice redis rest ...
- web打开本地文件并读取内容
<!DOCTYPE html> <html> <head> <meta charset="UTF-8"> <title> ...
- LeetCode简单题(三)
题目一: 给定一个数组,它的第 i 个元素是一支给定股票第 i 天的价格. 如果你最多只允许完成一笔交易(即买入和卖出一支股票),设计一个算法来计算你所能获取的最大利润. 注意你不能在买入股票前卖出股 ...