import torch
import numpy as np
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
batch_size = 64
learning_rate = 1e-2
num_epoches = 20
data_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5],[0.5])])
#transform.Compose() 将各种预处理操作组合在一起
#transform.ToTensor() 将数据转化为Tensor类型,并自动标准化,Tensor的取值是(0,1)
#transform.Normalize()是标准化操作,类似正太分布的标准化,第一个值是均值,第二个值是方差
#如果图像是三个通道,则transform.Normalize([a,b,c],[d,e,f])
train_dataset = datasets.MNIST(root = './mnist_data', train = True, transform = data_tf, download = True) #用datasets加载数据集,传入预处理
test_dataset = datasets.MNIST(root = './mnist_data', train = False,transform = data_tf)
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True) #利用DataLoader建立一个数据迭代器
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False)
class Batch_Net(nn.Module):
def __init__(self, inputdim, hidden1, hidden2, outputdim):
super(Batch_Net, self).__init__()
self.layer1 = nn.Sequential(nn.Linear(inputdim, hidden1), nn.BatchNorm1d(hidden1), nn.ReLU(True))
self.layer2 = nn.Sequential(nn.Linear(hidden1, hidden2), nn.BatchNorm1d(hidden2), nn.ReLU(True))
self.layer3 = nn.Sequential(nn.Linear(hidden2, outputdim)) def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x
model = Batch_Net(28*28, 300, 100, 10)
model

定义损失函数和优化器

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr = learning_rate)

训练模型

for epoch in range(num_epoches):
train_loss = 0
train_acc = 0
model.train() #这句话会自动调整batch_normalize和dropout值,很关键!
for img, label in train_loader:
img = img.view(img.size(0), -1) #将数据扁平化为一维
img = Variable(img)
label = Variable(label)
# 前向传播
out = model(img)
loss = criterion(out, label)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 记录误差
train_loss += loss.item()
# 计算分类的准确率
_, pred = out.max(1)
num_correct = (pred == label).sum().item()
acc = num_correct / img.shape[0]
train_acc += acc print('epoch:{},train_loss:{:.6f},acc:{:.6f}'.format(epoch+1, train_loss/len(train_loader), train_acc/len(train_loader)))
epoch:1,train_loss:0.002079,acc:0.999767
......  
epoch:19,train_loss:0.001532,acc:0.999917
epoch:20,train_loss:0.001670,acc:0.999850

测试集

model.eval()  #在评估模型时使用,固定BN 和 Dropout
eval_loss = 0
val_acc = 0
for img , label in test_loader:
img = img.view(img.size(0), -1)
img = Variable(img, volatile = True) #volatile=TRUE表示前向传播是不会保留缓存,因为测试集不需要反向传播
label = Variable(label, volatile = True)
out = model(img)
loss = criterion(out, label)
eval_loss += loss.item()
_,pred = torch.max(out, 1)
num_correct = (pred == label).sum().item()
print(num_correct)
eval_acc = num_correct / label.shape[0]
val_acc += eval_acc print('Test Loss:{:.6f}, Acc:{:.6f}'.format(eval_loss/len(test_loader), val_acc/len(test_loader)))
Test Loss:0.062413, Acc:0.981091

多层全连接神经网络实现minist手写数字分类的更多相关文章

  1. keras与卷积神经网络(CNN)实现识别minist手写数字

    在本篇博文当中,笔者采用了卷积神经网络来对手写数字进行识别,采用的神经网络的结构是:输入图片——卷积层——池化层——卷积层——池化层——卷积层——池化层——Flatten层——全连接层(64个神经元) ...

  2. Tensorflow 多层全连接神经网络

    本节涉及: 身份证问题 单层网络的模型 多层全连接神经网络 激活函数 tanh 身份证问题新模型的代码实现 模型的优化 一.身份证问题 身份证号码是18位的数字[此处暂不考虑字母的情况],身份证倒数第 ...

  3. python手写神经网络实现识别手写数字

    写在开头:这个实验和matlab手写神经网络实现识别手写数字一样. 实验说明 一直想自己写一个神经网络来实现手写数字的识别,而不是套用别人的框架.恰巧前几天,有幸从同学那拿到5000张已经贴好标签的手 ...

  4. matlab手写神经网络实现识别手写数字

    实验说明 一直想自己写一个神经网络来实现手写数字的识别,而不是套用别人的框架.恰巧前几天,有幸从同学那拿到5000张已经贴好标签的手写数字图片,于是我就尝试用matlab写一个网络. 实验数据:500 ...

  5. MNIST手写数字分类simple版(03-2)

    simple版本nn模型 训练手写数字处理 MNIST_data数据   百度网盘链接:https://pan.baidu.com/s/19lhmrts-vz0-w5wv2A97gg 提取码:cgnx ...

  6. Tensorflow-线性回归与手写数字分类

    线性回归 步骤 构造线性回归数据 定义输入层 设计神经网络中间层 定义神经网络输出层 计算二次代价函数,构建梯度下降 进行训练,获取预测值 画图展示 代码 import tensorflow as t ...

  7. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  8. Pytorch1.0入门实战一:LeNet神经网络实现 MNIST手写数字识别

    记得第一次接触手写数字识别数据集还在学习TensorFlow,各种sess.run(),头都绕晕了.自从接触pytorch以来,一直想写点什么.曾经在2017年5月,Andrej Karpathy发表 ...

  9. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

随机推荐

  1. E - 秋实大哥与战争

    秋实大哥与战争 Time Limit: 3000/1000MS (Java/Others)     Memory Limit: 65535/65535KB (Java/Others) Submit S ...

  2. P4868 Preprefix sum

    传送门 挺显然的一题?单点修改,前缀和数组前缀查询 树状数组就可以维护了 考虑每个位置对应询问的贡献,设询问的位置为 $x$,对于原数组 $a[i]$ 的某个位置 $i$,它会贡献 $(x-i+1)* ...

  3. [.net core]2.hello word(.net core web app模版简介)

    创建一个.net core web app project 弹出这个窗口 empty代表 最低依赖,  意味着往往需要手动按需添加依赖. web应用程序(模型视力控制器) 则会帮你创建好control ...

  4. Vue2 & ElementUI实现管理后台之input获得焦点

    Vue.directive('focus', function (el, option) { var defClass = 'el-input', defTag = 'input'; var valu ...

  5. vue单页应用首次加载太慢之性能优化

    问题描述: 最近开发了一个单页应用,上线后发现页面初始加载要20s才能完成,这就很影响用户体验了,于是分析原因,发现页面加载时有个 vendor.js达到了3000多kb,于是在网上查找了一下原因,是 ...

  6. Hive的架构(二)

    02 Hive的架构 1.Hive的架构图 2.Hive的服务(角色) 1.用户访问接口 ​ CLI(Command Line Interface):用户可以使用Hive自带的命令行接口执行Hive ...

  7. Atcoder Regular 098 区间Pre=Xor Q询问区间连续K去最小值最小极差

    C 用scanf("%s")就会WA..不知道为什么 /*Huyyt*/ #include<bits/stdc++.h> #define mem(a,b) memset ...

  8. Python核心技术与实战——十三|Python中参数传递机制

    我们在前面的章节里学习了Python的函数基础以及应用,那么现在想一想:传参,也就是把一些参数从一个函数传递到另一个函数,从而使其执行相应的任务,这个过程的底层是如何工作的,原理又是怎样的呢? 在实际 ...

  9. .NET CORE学习笔记系列 开篇介绍

    ASP.NET Core学习和使用了一段时间了,好记性不如烂笔头,通过查阅官网学习文档和一些大神们的博客总结一下.主要路线先总结一下ASP.NET Core的基础知识,然后是ASP.NET Core  ...

  10. CF 1272F Two Bracket Sequences (括号dp)

    题目地址 洛谷CF1272F Solution 首先题目中有两个括号串 \(s\) 和 \(t\) ,考虑先设计两维表示 \(s\) 匹配到的位置和 \(t\) 匹配到的位置. 接着根据 括号dp的一 ...