手写数字识别,神经网络领域的“hello world”例子,通过pytorch一步步构建,通过训练与调整,达到“100%”准确率

1、快速开始

1.1 定义神经网络类,继承torch.nn.Module,文件名为digit_recog.py

  1. import torch.nn as nn
  2.  
  3. class Net(nn.Module):
  4. def __init__(self):
  5. super(Net, self).__init__()
  6. self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 5, 1, 2)
  7. , nn.ReLU()
  8. , nn.MaxPool2d(2, 2))
  9. self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5)
  10. , nn.ReLU()
  11. , nn.MaxPool2d(2, 2))
  12. self.fc1 = nn.Sequential(
  13. nn.Linear(16 * 5 * 5, 120),
  14. # nn.Dropout2d(),
  15. nn.ReLU()
  16. )
  17. self.fc2 = nn.Sequential(
  18. nn.Linear(120, 84),
  19. nn.Dropout2d(),
  20. nn.ReLU()
  21. )
  22. self.fc3 = nn.Linear(84, 10)
  23.  
  24. # 前向传播
  25. def forward(self, x):
  26. x = self.conv1(x)
  27. x = self.conv2(x)
  28. # 线性层的输入输出都是一维数据,所以要把多维度的tensor展平成一维
  29. x = x.view(x.size()[0], -1)
  30. x = self.fc1(x)
  31. x = self.fc2(x)
  32. x = self.fc3(x)
  33. return x
  1.  

上面的类定义了一个3层的网络结构,根据问题类型,最后一层是确定的

1.2 开始训练:

  1. import torch
  2. import torchvision as tv
  3. import torchvision.transforms as transforms
  4. import torch.nn as nn
  5. import torch.optim as optim
  6. import os
  7. import copy
  8. import time
  9. from digit_recog import Net
  10. from digit_recog_mydataset import MyDataset
  11.  
  12. # 读取已保存的模型
  13. def getmodel(pth, net):
  14. state_filepath = pth
  15. if os.path.exists(state_filepath):
  16. # 加载参数
  17. nn_state = torch.load(state_filepath)
  18. # 加载模型
  19. net.load_state_dict(nn_state)
  20. # 拷贝一份
  21. return copy.deepcopy(nn_state)
  22. else:
  23. return net.state_dict()
  24.  
  25. # 构建数据集
  26. def getdataset(batch_size):
  27. # 定义数据预处理方式
  28. transform = transforms.ToTensor()
  29.  
  30. # 定义训练数据集
  31. trainset = tv.datasets.MNIST(
  32. root='./data/',
  33. train=True,
  34. download=True,
  35. transform=transform)
  36.  
  37. # 去掉注释,加入自己的数据集
  38. # trainset += MyDataset(os.path.abspath("./data/myimages/"), 'train.txt', transform=transform)
  39.  
  40. # 定义训练批处理数据
  41. trainloader = torch.utils.data.DataLoader(
  42. trainset,
  43. batch_size=batch_size,
  44. shuffle=True,
  45. )
  46.  
  47. # 定义测试数据集
  48. testset = tv.datasets.MNIST(
  49. root='./data/',
  50. train=False,
  51. download=True,
  52. transform=transform)
  53.  
  54. # 去掉注释,加入自己的数据集
  55. # testset += MyDataset(os.path.abspath("./data/myimages/"), 'test.txt', transform=transform)
  56.  
  57. # 定义测试批处理数据
  58. testloader = torch.utils.data.DataLoader(
  59. testset,
  60. batch_size=batch_size,
  61. shuffle=False,
  62. )
  63.  
  64. return trainloader, testloader
  65.  
  66. # 训练
  67. def training(device, net, model, dataset_loader, epochs, criterion, optimizer, save_model_path):
  68. trainloader, testloader = dataset_loader
  69. # 最佳模型
  70. best_model_wts = model
  71. # 最好分数
  72. best_acc = 0.0
  73. # 计时
  74. since = time.time()
  75. for epoch in range(epochs):
  76. sum_loss = 0.0
  77. # 训练数据集
  78. for i, data in enumerate(trainloader):
  79. inputs, labels = data
  80. inputs, labels = inputs.to(device), labels.to(device)
  81. # 梯度清零,避免带入下一轮累加
  82. optimizer.zero_grad()
  83. # 神经网络运算
  84. outputs = net(inputs)
  85. # 损失值
  86. loss = criterion(outputs, labels)
  87. # 损失值反向传播
  88. loss.backward()
  89. # 执行优化
  90. optimizer.step()
  91. # 损失值汇总
  92. sum_loss += loss.item()
  93. # 每训练完100条数据就显示一下损失值
  94. if i % 100 == 99:
  95. print('[%d, %d] loss: %.03f'
  96. % (epoch + 1, i + 1, sum_loss / 100))
  97. sum_loss = 0.0
  98. # 每训练完一轮测试一下准确率
  99. with torch.no_grad():
  100. correct = 0
  101. total = 0
  102. for data in testloader:
  103. images, labels = data
  104. images, labels = images.to(device), labels.to(device)
  105. outputs = net(images)
  106. # 取得分最高的
  107. _, predicted = torch.max(outputs.data, 1)
  108. # print(labels)
  109. # print(torch.nn.Softmax(dim=1)(outputs.data).detach().numpy()[0])
  110. # print(torch.nn.functional.normalize(outputs.data).detach().numpy()[0])
  111. total += labels.size(0)
  112. correct += (predicted == labels).sum()
  113.  
  114. print('测试结果:{}/{}'.format(correct, total))
  115. epoch_acc = correct.double() / total
  116. print('当前分数:{} 最高分数:{}'.format(epoch_acc, best_acc))
  117. if epoch_acc > best_acc:
  118. best_acc = epoch_acc
  119. best_model_wts = copy.deepcopy(net.state_dict())
  120. print('第%d轮的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))
  121.  
  122. time_elapsed = time.time() - since
  123. print('训练完成于 {:.0f}m {:.0f}s'.format(
  124. time_elapsed // 60, time_elapsed % 60))
  125. print('最高分数: {:4f}'.format(best_acc))
  126. # 保存训练模型
  127. if save_model_path is not None:
  128. save_state_path = os.path.join('model/', 'net.pth')
  129. torch.save(best_model_wts, save_state_path)
  130.  
  131. # 基于cpu还是gpu
  132. DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  133. NET = Net().to(DEVICE)
  134. # 超参数设置
  135. EPOCHS = 8# 训练多少轮
  136. BATCH_SIZE = 64 # 数据集批处理数量 64
  137. LR = 0.001 # 学习率
  138.  
  139. # 交叉熵损失函数,通常用于多分类问题上
  140. CRITERION = nn.CrossEntropyLoss()
  141. # 优化器
  142. # OPTIMIZER = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
  143. OPTIMIZER = optim.Adam(NET.parameters(), lr=LR)
  144. MODEL = getmodel(os.path.join('model/', 'net.pth'), NET)
  145. training(DEVICE, NET, MODEL, getdataset(BATCH_SIZE), 1, CRITERION, OPTIMIZER, os.path.join('model/', 'net.pth'))

利用标准的mnist数据集跑出来的识别率能达到99%

2、参与进来

目的是为了识别自己的图片,增加参与感

2.1 打开windows附件中的画图工具,用鼠标画几个数字,然后用截图工具保存下来

2.2 实现自己的数据集:

digit_recog_mydataset.py

  1. from PIL import Image
  2. import torch
  3. import os
  4.  
  5. # 实现自己的数据集
  6. class MyDataset(torch.utils.data.Dataset):
  7. def __init__(self, root, datafile, transform=None, target_transform=None):
  8. super(MyDataset, self).__init__()
  9. fh = open(os.path.join(root, datafile), 'r')
  10. datas = []
  11. for line in fh:
  12. # 删除本行末尾的字符
  13. line = line.rstrip()
  14. # 通过指定分隔符对字符串进行拆分,默认为所有的空字符,包括空格、换行、制表符等
  15. words = line.split()
  16. # words[0]是图片信息,words[1]是标签
  17. datas.append((words[0], int(words[1])))
  18.  
  19. self.datas = datas
  20. self.transform = transform
  21. self.target_transform = target_transform
  22. self.root = root
  23.  
  24. # 必须实现的方法,用于按照索引读取每个元素的具体内容
  25. def __getitem__(self, index):
  26. # 获取图片及标签,即上面每行中word[0]和word[1]的信息
  27. img, label = self.datas[index]
  28. # 打开图片,重设尺寸,转换为灰度图
  29. img = Image.open(os.path.join(self.root, img)).resize((28, 28)).convert('L')
  30.  
  31. # 数据预处理
  32. if self.transform is not None:
  33. img = self.transform(img)
  34. return img, label
  35.  
  36. # 必须实现的方法,返回数据集的长度
  37. def __len__(self):
  38. return len(self.datas)

2.3 在图片文件夹中新建两个文件,train.txt和test.txt,分别写上训练与测试集的数据,格式如下

训练与测试的数据要严格区分开,否则训练出来的模型会有问题

2.4 加入训练、测试数据集

反注释训练方法中的这两行

  1. # trainset += MyDataset(os.path.abspath("./data/myimages/"), 'train.txt', transform=transform)
  2.  
  3. # testset += MyDataset(os.path.abspath("./data/myimages/"), 'test.txt', transform=transform)

继续执行训练,这里我训练出来的最高识别率是98%

2.5 测试模型

  1. # -*- coding: utf-8 -*-
  2. # encoding:utf-8
  3.  
  4. import torch
  5. import numpy as np
  6. from PIL import Image
  7. import os
  8. import matplotlib
  9. import matplotlib.pyplot as plt
  10. import glob
  11. from digit_recog import Net
  12.  
  13. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  14. net = Net().to(device)
  15. # 加载参数
  16. nn_state = torch.load(os.path.join('model/', 'net.pth'))
  17. # 参数加载到指定模型
  18. net.load_state_dict(nn_state)
  19.  
  20. # 指定默认字体
  21. matplotlib.rcParams['font.sans-serif'] = ['SimHei']
  22. matplotlib.rcParams['font.family'] = 'sans-serif'
  23. # 解决负号'-'显示为方块的问题
  24. matplotlib.rcParams['axes.unicode_minus'] = False
  25.  
  26. # 要识别的图片
  27. file_list = glob.glob(os.path.join('data/test_image/', '*'))
  28. grid_rows = len(file_list) / 5 + 1
  29.  
  30. for i, file in enumerate(file_list):
  31. # 读取图片并重设尺寸
  32. image = Image.open(file).resize((28, 28))
  33. # 灰度图
  34. gray_image = image.convert('L')
  35. # 图片数据处理
  36. im_data = np.array(gray_image)
  37. im_data = torch.from_numpy(im_data).float()
  38. im_data = im_data.view(1, 1, 28, 28)
  39. # 神经网络运算
  40. outputs = net(im_data)
  41. # 取最大预测值
  42. _, pred = torch.max(outputs, 1)
  43. # print(torch.nn.Softmax(dim=1)(outputs).detach().numpy()[0])
  44. # print(torch.nn.functional.normalize(outputs).detach().numpy()[0])
  45. # 显示图片
  46. plt.subplot(grid_rows, 5, i + 1)
  47. plt.imshow(gray_image)
  48. plt.title(u"你是{}?".format(pred.item()), fontsize=8)
  49. plt.axis('off')
  50.  
  51. print('[{}]预测数字为: [{}]'.format(file, pred.item()))
  52.  
  53. plt.show()

可视化结果

这批图片是经过图片增强后识别的结果,准确率有待提高

3、优化

3.1 更多样本:

收集难度大

3.2 数据增强:

简单地处理一下自己手写的数字图片

  1. # -*- coding: utf-8 -*-
  2. # encoding:utf-8
  3.  
  4. import torch
  5. import numpy as np
  6. from PIL import Image
  7. import os
  8. import matplotlib
  9. import matplotlib.pyplot as plt
  10. import glob
  11. from scipy.ndimage import filters
  12.  
  13. class ImageProcceed:
  14. def __init__(self, image_folder):
  15. self.image_folder = image_folder
  16.  
  17. def save(self, rotate, filter=None, to_gray=True):
  18. file_list = glob.glob(os.path.join(self.image_folder, '*.png'))
  19. print(len(file_list))
  20. for i, file in enumerate(file_list):
  21. # 读取图片数据
  22. image = Image.open(file) # .resize((28, 28))
  23. # 灰度图
  24. if to_gray == True:
  25. image = image.convert('L')
  26. # 旋转
  27. image = image.rotate(rotate)
  28. if filter is not None:
  29. image = filters.gaussian_filter(image, 0.5)
  30. image = Image.fromarray(image)
  31. filename = os.path.basename(file)
  32. fileext = os.path.splitext(filename)[1]
  33. savefile = filename.replace(fileext, '-rt{}{}'.format(rotate, fileext))
  34. print(savefile)
  35. image.save(os.path.join(self.image_folder, savefile))
  36.  
  37. ip = ImageProcceed('data/myimages/')
  38. ip.save(20, filter=0.5)

3.3 改变网络大小:

比如把上面的Net类中的3层改为2层

3.4 调参:

改变学习率,训练更多次数等

后面我调整了Net类中的两个地方,准确率终于达到100%,这只是在我小批量测试集上的表现而已,而现实中预测是不可能达到100%的,每台机器可能有差异,每次运行的结果会有不同,再次帖出代码

  1. import torch.nn as nn
  2.  
  3. class Net(nn.Module):
  4. def __init__(self):
  5. super(Net, self).__init__()
  6. # 卷积: 1通道输入,6通道输出,卷积核5*5,步长1,前后补2个0
  7. # 激活函数一般用ReLU,后面改良的有LeakyReLU/PReLU
  8. # MaxPool2d池化,一般是2
  9. self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 5, 1, 2)
  10. , nn.PReLU()
  11. , nn.MaxPool2d(2, 2))
  12. self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5)
  13. , nn.PReLU()
  14. , nn.MaxPool2d(2, 2))
  15. self.fc1 = nn.Sequential(
  16. nn.Linear(16 * 5 * 5, 120), # 卷积输出16,乘以卷积核5*5
  17. # nn.Dropout2d(), # Dropout接收来自linear的数据,Dropout2d接收来自conv2d的数据
  18. nn.PReLU()
  19. )
  20. self.fc2 = nn.Sequential(
  21. nn.Linear(120, 84),
  22. nn.Dropout(p=0.2),
  23. nn.PReLU()
  24. )
  25. self.fc3 = nn.Linear(84, 10) # 输出层节点为10,代表数字0-9
  26.  
  27. # 前向传播
  28. def forward(self, x):
  29. x = self.conv1(x)
  30. x = self.conv2(x)
  31. # 线性层的输入输出都是一维数据,所以要把多维度的tensor展平成一维
  32. x = x.view(x.size()[0], -1)
  33. x = self.fc1(x)
  34. x = self.fc2(x)
  35. x = self.fc3(x)
  36. return x

上面改了两个地方,一个是激活函数ReLU改成了PReLU,正则化Dropout用0.2作为参数,下面是再次运行测试后的结果

识别手写数字增强版100% - pytorch从入门到入道(一)的更多相关文章

  1. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  2. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  3. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

  4. 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...

  5. 3 TensorFlow入门之识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  6. python手写神经网络实现识别手写数字

    写在开头:这个实验和matlab手写神经网络实现识别手写数字一样. 实验说明 一直想自己写一个神经网络来实现手写数字的识别,而不是套用别人的框架.恰巧前几天,有幸从同学那拿到5000张已经贴好标签的手 ...

  7. 用BP人工神经网络识别手写数字

    http://wenku.baidu.com/link?url=HQ-5tZCXBQ3uwPZQECHkMCtursKIpglboBHq416N-q2WZupkNNH3Gv4vtEHyPULezDb5 ...

  8. python机器学习使用PCA降维识别手写数字

    PCA降维识别手写数字 关注公众号"轻松学编程"了解更多. PCA 用于数据降维,减少运算时间,避免过拟合. PCA(n_components=150,whiten=True) n ...

  9. KNN 算法-实战篇-如何识别手写数字

    公号:码农充电站pro 主页:https://codeshellme.github.io 上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数 ...

随机推荐

  1. idea tomcat提示Unable to ping server at localhost:1099

    idea启动tomcat报错Unable to ping server at localhost:1099 是 IDEA配置的jdk版本 与 tomcat的jdk版本不配导致的

  2. 一次Commons-HttpClient的BindException排查

    线上有个老应用,在流量增长的时候,HttpClient抛出了BindException.部分的StackTrace信息如下: java.net.BindException: Address alrea ...

  3. 数据存储检索之B+树和LSM-Tree

    作为一名应用系统开发人员,为什么要关注数据内部的存储和检索呢?首先,你不太可能从头开始实现一套自己的存储引擎,往往需要从众多现有的存储引擎中选择一个适合自己应用的存储引擎.因此,为了针对你特定的工作负 ...

  4. CS184.1X 计算机图形学导论 第3讲L3V1

    二维空间的变换 L3V1这一课主要讲了二维空间的变换,包括平移.错切和旋转. 缩放 缩放矩阵 使用矩阵的乘法来完成缩放 缩放矩阵是一个对角矩阵,对角线上的值对应缩放倍数 错切(shear) 错切可以将 ...

  5. 《java编程思想》P22-P37(第二章一切都是对象)

    1.JAVA操纵的标识符实际上是对象的一个"引用";如String s;里的s是String类的引用并非对象. 2.程序运行时,有五个不同的地区可以存储数据. (1)寄存器:最快的 ...

  6. Blazor(一):运行初体验,全新的.net web的开发

    官网:https://dotnet.microsoft.com/apps/aspnet/web-apps/client 作者BBS:http://bbs.hslcommunication.cn/ 我们 ...

  7. Web安全之变量覆盖漏洞

    通常将可以用自定义的参数值替换原有变量值的情况称为变量覆盖漏洞.经常导致变量覆盖漏洞场景有:$$使用不当,extract()函数使用不当,parse_str()函数使用不当,import_reques ...

  8. C++中哪些函数不能声明为virtual?

    首先要明确,virtual是用于支持类多态的关键字,所以出现在类声明之外的地方都是错误的.由此可以断定下文的1. 普通函数(即非类成员函数)不能是virtual的,否则不能通过编译,virtual只能 ...

  9. 剑指Offer(十九)——顺时针打印矩阵

    题目描述 输入一个矩阵,按照从外向里以顺时针的顺序依次打印出每一个数字. 例如,如果输入如下4 X 4矩阵: 1   2    3     4 5   6    7     8 9   10  11  ...

  10. selenium驱动chrome浏览器问题

    selenium是一个浏览器自动化测试框架,以下介绍其如何驱动chrome浏览器? 1.下载与本地chrome版本对应的chromedriver.exe ,下载地址为http://npm.taobao ...