深度学习之 mnist 手写数字识别

开始学习深度学习,先来一个手写数字的程序

  1. import numpy as np
  2. import os
  3. import codecs
  4. import torch
  5. from PIL import Image
  6. lr = 0.01
  7. momentum = 0.5
  8. epochs = 10
  9. def get_int(b):
  10. return int(codecs.encode(b, 'hex'), 16)
  11. def read_label_file(path):
  12. with open(path, 'rb') as f:
  13. data = f.read()
  14. assert get_int(data[:4]) == 2049
  15. length = get_int(data[4:8])
  16. parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
  17. return torch.from_numpy(parsed).view(length).long()
  18. def read_image_file(path):
  19. with open(path, 'rb') as f:
  20. data = f.read()
  21. assert get_int(data[:4]) == 2051
  22. length = get_int(data[4:8])
  23. num_rows = get_int(data[8:12])
  24. num_cols = get_int(data[12:16])
  25. images = []
  26. parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
  27. return torch.from_numpy(parsed).view(length, num_rows, num_cols)
  28. def loadmnist(path, kind='train'):
  29. labels_path = os.path.join(path, 'mnist' ,'%s-labels.idx1-ubyte' % kind)
  30. images_path = os.path.join(path,'mnist' ,'%s-images.idx3-ubyte' % kind)
  31. labels = read_label_file(labels_path)
  32. images = read_image_file(images_path)
  33. return images, labels
  34. import torch.utils.data as data
  35. import torchvision.transforms as transforms
  36. class Loader(data.Dataset):
  37. def __init__(self, root, label, transforms):
  38. self.imgs = []
  39. imgs,labels = loadmnist(root, label)
  40. self.imgs = imgs
  41. self.labels = labels
  42. self.transforms = transforms
  43. def __getitem__(self, index):
  44. img, label = self.imgs[index],self.labels[index]
  45. img = Image.fromarray(img.numpy(), mode='L')
  46. if self.transforms:
  47. img = self.transforms(img)
  48. return img, label
  49. def __len__(self):
  50. return len(self.imgs)
  51. def getTrainDataset():
  52. return Loader('d:\\work\\yoho\\dl\\dl-study\\chapter0\\', 'train', transforms.Compose([
  53. transforms.ToTensor(),
  54. transforms.Normalize((0.1307,), (0.3081,)),
  55. ]))
  56. def getTestDataset():
  57. return Loader('d:\\work\\yoho\\dl\\dl-study\\chapter0\\', 't10k', transforms.Compose([
  58. transforms.ToTensor(),
  59. transforms.Normalize((0.1307,), (0.3081,)),
  60. ]))
  61. import torch as t
  62. import torch.nn as nn
  63. class Net(nn.Module):
  64. def __init__(self):
  65. super(Net, self).__init__()
  66. self.features = nn.Sequential(
  67. nn.Conv2d(1, 10, kernel_size=5),
  68. nn.MaxPool2d(2),
  69. nn.ReLU(inplace=True),
  70. nn.Conv2d(10, 20, kernel_size=5),
  71. nn.Dropout2d(),
  72. nn.MaxPool2d(2),
  73. nn.ReLU(inplace=True),
  74. )
  75. self.classifier = nn.Sequential(
  76. nn.Linear(320, 50),
  77. nn.ReLU(inplace=True),
  78. nn.Dropout(),
  79. nn.Linear(50, 10),
  80. nn.LogSoftmax(dim=1)
  81. )
  82. def forward(self, x):
  83. x = self.features(x)
  84. x = x.view(x.size(0), -1)
  85. x = self.classifier(x)
  86. return x
  87. net = Net()
  88. import torch.optim as optim
  89. from torch.nn.modules import loss
  90. optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
  91. criterion = loss.CrossEntropyLoss()
  92. train_dataset = getTrainDataset()
  93. test_dataset = getTestDataset()
  94. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)
  95. test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False)
  96. from torch.autograd import Variable as V
  97. def train(epoch):
  98. for i, (inputs, labels) in enumerate(train_loader):
  99. inputs_var, labels_var = V(inputs), V(labels)
  100. outputs = net(inputs_var)
  101. losses = criterion(outputs, labels_var)
  102. optimizer.zero_grad()
  103. losses.backward()
  104. optimizer.step()
  105. def test(epoch):
  106. for i, (inputs, labels) in enumerate(test_loader):
  107. inputs_var = V(inputs)
  108. outputs = net(inputs_var)
  109. _, pred = outputs.data.topk(5, 1, True, True)
  110. batch_size = labels.size(0)
  111. pred = pred.t()
  112. corrent = pred.eq(labels.view(1, -1).expand_as(pred))
  113. res = []
  114. for k in (1,5):
  115. correct_k = corrent[:k].view(-1).float().sum(0, keepdim=True)
  116. res.append(correct_k.mul_(100.0 / batch_size))
  117. print('{} {} top1 {} top5 {}'.format(epoch, i ,res[0][0], res[1][0]))
  118. def main():
  119. for epoch in range(0, epochs):
  120. train(epoch)
  121. test(epoch)
  122. main()

学习之后的,正确率很高,这种问题对于深度学习已经解决了。

深度学习之 mnist 手写数字识别的更多相关文章

  1. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  2. 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别

    用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 http://phunter.farbox.com/post/mxnet-tutorial1 用MXnet实战深度学 ...

  3. mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

  4. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  5. 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

    持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...

  6. Tensorflow之MNIST手写数字识别:分类问题(1)

    一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点:   1.将离散特征的取值扩展 ...

  7. Tensorflow实现MNIST手写数字识别

    之前我们讲了神经网络的起源.单层神经网络.多层神经网络的搭建过程.搭建时要注意到的具体问题.以及解决这些问题的具体方法.本文将通过一个经典的案例:MNIST手写数字识别,以代码的形式来为大家梳理一遍神 ...

  8. Pytorch入门——手把手教你MNIST手写数字识别

    MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...

  9. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

随机推荐

  1. mysql初步学习

    1.insert_select 的使用:从一个表复制数据给另一个表 INSERT INTO students(name,sex,LikeBooksNUM,LikesportNUM,average) S ...

  2. C语言最后一次作业--总结报告

    1.当初你是如何做出选择计算机专业的决定的? 经过一个学期,你的看法改变了么,为什么? 你觉得计算机是你喜欢的领域吗,它是你擅长的领域吗? 为什么? 当时选择计算机专业,是基于自己的高考分数和想出省的 ...

  3. JSP中动态include和静态include区别

    静态include(<%@ include file=""%>): 静态include(静态导入)是指将一个其他文件(一个jsp/html)嵌入到本页面 jsp的inc ...

  4. 目标检测网络之 YOLOv2

    YOLOv1基本思想 YOLO将输入图像分成SxS个格子,若某个物体 Ground truth 的中心位置的坐标落入到某个格子,那么这个格子就负责检测出这个物体. 每个格子预测B个bounding b ...

  5. DOM生成XML文档

    import java.io.File; import javax.xml.parsers.DocumentBuilder; import javax.xml.parsers.DocumentBuil ...

  6. vuex的学习笔记

    什么是Vuex? vuex是一个专门为vue.js设计的集中式状态管理架构.状态?我把它理解为在data中的属性需要共享给其他vue组件使用的部分,就叫做状态.简单的说就是data中需要共用的属性. ...

  7. C++基于范围循环(range-based for loop)的陷阱

    C++的基于范围的循环是C++11出现的新特性,很方便,一定程度上替代了使用迭代器的for循环用法.不过基于范围的for循环有一个隐藏的陷阱,如果不注意可能会出现严重的内存错误. 举例说明 看下面这个 ...

  8. 【JS】 Javascript与HTML DOM的互动 寻路

    JS HTML DOM DOM的全程是Document Object Module,即文档对象模型.一般来说,当一个页面被加载时,浏览器会在内部创建一个当前文档的DOM.就像用python的Etree ...

  9. canvas星空和图形变换

    图形变换. 一.画一片星空 先画一片canvas.width宽canvas.height高的黑色星空,再画200个随机位置,随机大小,随机旋转角度的星星. window.onload=function ...

  10. 如何让shell脚本自杀

    有些时候我们写的shell脚本中有一些后台任务,当脚本的流程已经执行到结尾处并退出时,这些后台任务会直接挂靠在init/systemd进程下,而不会随着脚本退出而停止. 例如: [root@maria ...