数据准备

课程中获取数据的方法是从库中直接load_data

  1. from keras.datasets import mnist
  2. (x_train, y_train), (x_test, y_test) = mnist.load_data()

我尝试了一下,报这样的错误:[WinError 10054] 远程主机强迫关闭了一个现有的连接。so,我就直接去官网下载了数据集:http://yann.lecun.com/exdb/mnist/。该数据下载后得到的是idx格式数据,具体处理方法参考了这篇博客https://www.jianshu.com/p/84f72791806f,测试可用的源码如下(规则在注释里写得很详细),在后文中会直接调用里边的函数。

  1. import numpy as np
  2. import struct
  3. import matplotlib.pyplot as plt
  4.  
  5. # 训练集文件
  6. train_images_idx3_ubyte_file = 'C:\\Users\\小辉\\Desktop\\MNIST\\train-images.idx3-ubyte'
  7. # 训练集标签文件
  8. train_labels_idx1_ubyte_file = 'C:\\Users\\小辉\\Desktop\\MNIST\\train-labels.idx1-ubyte'
  9.  
  10. # 测试集文件
  11. test_images_idx3_ubyte_file = 'C:\\Users\\小辉\\Desktop\\MNIST\\t10k-images.idx3-ubyte'
  12. # 测试集标签文件
  13. test_labels_idx1_ubyte_file = 'C:\\Users\\小辉\\Desktop\\MNIST\\t10k-labels.idx1-ubyte'
  14.  
  15. def decode_idx3_ubyte(idx3_ubyte_file):
  16. """
  17. 解析idx3文件的通用函数
  18. :param idx3_ubyte_file: idx3文件路径
  19. :return: 数据集
  20. """
  21. # 读取二进制数据
  22. bin_data = open( train_images_idx3_ubyte_file, 'rb').read()
  23.  
  24. # 解析文件头信息,依次为魔数、图片数量、每张图片高、每张图片宽
  25. offset = 0
  26. fmt_header = '>iiii'
  27. magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)
  28. #print('魔数:%d, 图片数量: %d张, 图片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols))
  29.  
  30. # 解析数据集
  31. image_size = num_rows * num_cols
  32. offset += struct.calcsize(fmt_header)
  33. fmt_image = '>' + str(image_size) + 'B'
  34. images = np.empty((num_images, num_rows, num_cols))
  35. for i in range(num_images):
  36. #if (i + 1) % 10000 == 0:
  37. #print('已解析 %d' % (i + 1) + '张')
  38. images[i] = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((num_rows, num_cols))
  39. offset += struct.calcsize(fmt_image)
  40. return images
  41.  
  42. def decode_idx1_ubyte(idx1_ubyte_file):
  43. """
  44. 解析idx1文件的通用函数
  45. :param idx1_ubyte_file: idx1文件路径
  46. :return: 数据集
  47. """
  48. # 读取二进制数据
  49. bin_data = open(idx1_ubyte_file, 'rb').read()
  50.  
  51. # 解析文件头信息,依次为魔数和标签数
  52. offset = 0
  53. fmt_header = '>ii'
  54. magic_number, num_images = struct.unpack_from(fmt_header, bin_data, offset)
  55. #print('魔数:%d, 图片数量: %d张' % (magic_number, num_images))
  56.  
  57. # 解析数据集
  58. offset += struct.calcsize(fmt_header)
  59. fmt_image = '>B'
  60. labels = np.empty(num_images)
  61. for i in range(num_images):
  62. #if (i + 1) % 10000 == 0:
  63. # print('已解析 %d' % (i + 1) + '张')
  64. labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]
  65. offset += struct.calcsize(fmt_image)
  66. return labels
  67.  
  68. def load_train_images(idx_ubyte_file=train_images_idx3_ubyte_file):
  69. """
  70. TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
  71. [offset] [type] [value] [description]
  72. 0000 32 bit integer 0x00000803(2051) magic number
  73. 0004 32 bit integer 60000 number of images
  74. 0008 32 bit integer 28 number of rows
  75. 0012 32 bit integer 28 number of columns
  76. 0016 unsigned byte ?? pixel
  77. 0017 unsigned byte ?? pixel
  78. ........
  79. xxxx unsigned byte ?? pixel
  80. Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).
  81.  
  82. :param idx_ubyte_file: idx文件路径
  83. :return: n*row*col维np.array对象,n为图片数量
  84. """
  85. return decode_idx3_ubyte(idx_ubyte_file)
  86.  
  87. def load_train_labels(idx_ubyte_file=train_labels_idx1_ubyte_file):
  88. """
  89. TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
  90. [offset] [type] [value] [description]
  91. 0000 32 bit integer 0x00000801(2049) magic number (MSB first)
  92. 0004 32 bit integer 60000 number of items
  93. 0008 unsigned byte ?? label
  94. 0009 unsigned byte ?? label
  95. ........
  96. xxxx unsigned byte ?? label
  97. The labels values are 0 to 9.
  98.  
  99. :param idx_ubyte_file: idx文件路径
  100. :return: n*1维np.array对象,n为图片数量
  101. """
  102. return decode_idx1_ubyte(idx_ubyte_file)
  103.  
  104. def load_test_images(idx_ubyte_file=test_images_idx3_ubyte_file):
  105. """
  106. TEST SET IMAGE FILE (t10k-images-idx3-ubyte):
  107. [offset] [type] [value] [description]
  108. 0000 32 bit integer 0x00000803(2051) magic number
  109. 0004 32 bit integer 10000 number of images
  110. 0008 32 bit integer 28 number of rows
  111. 0012 32 bit integer 28 number of columns
  112. 0016 unsigned byte ?? pixel
  113. 0017 unsigned byte ?? pixel
  114. ........
  115. xxxx unsigned byte ?? pixel
  116. Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).
  117.  
  118. :param idx_ubyte_file: idx文件路径
  119. :return: n*row*col维np.array对象,n为图片数量
  120. """
  121. return decode_idx3_ubyte(idx_ubyte_file)
  122.  
  123. def load_test_labels(idx_ubyte_file=test_labels_idx1_ubyte_file):
  124. """
  125. TEST SET LABEL FILE (t10k-labels-idx1-ubyte):
  126. [offset] [type] [value] [description]
  127. 0000 32 bit integer 0x00000801(2049) magic number (MSB first)
  128. 0004 32 bit integer 10000 number of items
  129. 0008 unsigned byte ?? label
  130. 0009 unsigned byte ?? label
  131. ........
  132. xxxx unsigned byte ?? label
  133. The labels values are 0 to 9.
  134.  
  135. :param idx_ubyte_file: idx文件路径
  136. :return: n*1维np.array对象,n为图片数量
  137. """
  138. return decode_idx1_ubyte(idx_ubyte_file)
  139.  
  140. def run():
  141. train_images = load_train_images()
  142. train_labels = load_train_labels()
  143. test_images = load_test_images()
  144. test_labels = load_test_labels()
  145.  
  146. # 查看前十个数据及其标签以读取是否正确
  147. for i in range(10):
  148. print(train_labels[i])
  149. plt.imshow(train_images[i], cmap='gray')
  150. plt.show()
  151. print('done')
  152.  
  153. if __name__ == '__main__':
  154. run()

测试用的源码

数据预处理

导入相关包依赖及预处理函数

  1. import numpy as np
  2. from keras.models import Sequential
  3. from keras.layers.core import Dense, Dropout, Activation
  4. from keras.layers import Conv2D, MaxPooling2D, Flatten
  5. from keras.optimizers import SGD, Adam
  6. from keras.utils import np_utils
  7.  
  8. #clean data
  9. def load_dataset():
  10. x_train, y_train = load_train_images(), load_train_labels()
  11. x_test, y_test = load_test_images(), load_test_labels()
  12. number = 60000
  13. x_train, y_train = x_train[0:number], y_train[0:number]
  14. x_train = x_train.reshape(number, 28*28)
  15. x_test = x_test.reshape(x_test.shape[0], 28*28)
  16. x_train, x_test = x_train.astype('float32'), x_test.astype('float32')
  17. y_train, y_test = np_utils.to_categorical(y_train, 10), np_utils.to_categorical(y_test, 10)
  18. x_train, x_test = x_train / 255, x_test / 255
  19. return (x_train, y_train), (x_test, y_test)

到此,我们得到了训练和测试网络所需要的数据。

网络的搭建及训练结果

  1. 搭建网络训练结果
  2. (x_train, y_train), (x_test, y_test) = load_dataset()
  3. model = Sequential()
  4. #搭建三层网络
  5. model.add(Dense(input_dim=28*28,units=633,activation='sigmoid'))
  6. model.add(Dense(units=633,activation='sigmoid'))
  7. model.add(Dense(units=10,activation='softmax'))
  8.  
  9. model.compile(loss='mse',optimizer=SGD(lr=0.1),metrics=['accuracy'])
  10. model.fit(x_train,y_train,batch_size=100,epochs=20)
  11. result = model.evaluate(x_test,y_test)
  12. print('Test loss:', result[0])
  13. print('Accuracy:', result[1])

效果如下图所示:

改动地方主要为:

  • 激励函数由sigmoid改为relu
  • loss function由mse改为categorical_crossentropy
  • 增加了Dropout,防止过拟合

改动后构建模型代码:

  1. #搭建网络训练结果
  2. (x_train, y_train), (x_test, y_test) = load_dataset()
  3. model = Sequential()
  4. #搭建三层网络
  5. model.add(Dense(input_dim=28*28,units=700,activation='relu'))
  6. model.add(Dropout(0.2))
  7. model.add(Dense(units=700,activation='relu'))
  8. model.add(Dropout(0.2))
  9. model.add(Dense(units=10,activation='softmax'))
  10.  
  11. model.compile(loss='categorical_crossentropy',optimizer=SGD(lr=0.1),metrics=['accuracy'])
  12. model.fit(x_train,y_train,batch_size=100,epochs=20,validation_split=0.05)
  13. result = model.evaluate(x_test,y_test)
  14.  
  15. print('Test loss:', result[0])
  16. print('Accuracy:', result[1])

效果如下所示:

得到了比较好的测试结果。其中,最主要的还是激励函数影响。
1. 采用sigmoid等函数,算激活函数时(指数运算),计算量大,反向传播求误差梯度时,求导涉及除法,计算量相对大,而采用Relu激活函数,整个过程的计算量节省很多。
2. 对于深层网络,sigmoid函数反向传播时,很容易就会出现梯度消失的情况(在sigmoid接近饱和区时,变换太缓慢,导数趋于0,这种情况会造成信息丢失,从而无法完成深层网络的训练。
3. Relu会使一部分神经元的输出为0,这样就造成了网络的稀疏性,并且减少了参数的相互依存关系,缓解了过拟合问题的发生。

参考:https://blog.csdn.net/waple_0820/article/details/79415397

手写数字识别---demo的更多相关文章

  1. 【问题解决方案】Keras手写数字识别-ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接

    参考:台大李宏毅老师视频课程-Keras-Demo 在载入数据阶段报错: ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接 Google之 ...

  2. 【机器学习】李宏毅机器学习-Keras-Demo-神经网络手写数字识别与调参

    参考: 原视频:李宏毅机器学习-Keras-Demo 调参博文1:深度学习入门实践_十行搭建手写数字识别神经网络 调参博文2:手写数字识别---demo(有小错误) 代码链接: 编程环境: 操作系统: ...

  3. 利用神经网络算法的C#手写数字识别

    欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwritten_character_recognition.zip 下载源码 - 70. ...

  4. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  5. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  6. 利用神经网络算法的C#手写数字识别(一)

    利用神经网络算法的C#手写数字识别 转发来自云加社区,用于学习机器学习与神经网络 欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwri ...

  7. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

  8. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  9. 【深度学习系列】PaddlePaddle之手写数字识别

    上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...

随机推荐

  1. 使用jmeter工具测试上传接口

    1.方法选择post:上传都是post上传. 2.路径输入正确的上传接口路径,并勾选Use multipart/form-data for POST 3.添加文件,文件路径尽量不要有中文,防止编码问题 ...

  2. python使用ip代理抓取网页

    在抓取一个网站的信息时,如果我们进行频繁的访问,就很有可能被网站检测到而被屏蔽,解决这个问题的方法就是使用ip代理 .在我们接入因特网进行上网时,我们的电脑都会被分配一个全球唯一地ip地址供我们使用, ...

  3. NABCD模型--软件工程

    1.N (Need 需求) 我们通过网络调查问卷的方式,收集样本数据,并对其进行分析和总结. 1.你是否为在校学生? 7.如果用过,你觉得还应该需要添加什么功能 通过调查发现,大多数学生并不是特别了解 ...

  4. composer install 时,提示:Package yiisoft/yii2-codeception is abandoned, you should avoid using it. Use codeception/codeception instead.的解决

    由 SHUIJINGWAN · 2017/11/24 1.composer install 时,提示:Package yiisoft/yii2-codeception is abandoned, yo ...

  5. ubuntu16.04 LTS Server 安装mysql phpmyadmin apache2 php5.6环境

    1.安装apache sudo apt-get install apache2 为了测试apache2是否正常,访问http://localhost/或http://127.0.0.1/,出现It W ...

  6. 20155333 2016-2017-2 《Java程序设计》第五周学习总结

    20155333 2016-2017-2 <Java程序设计>第五周学习总结 教材学习内容总结 1.使用try.catch语法 与C语言中程序流程和错误处理混在一起不同,Java中把正常流 ...

  7. ORACLE 查看分区表分区大小

    SELECT *  FROM dba_segments t WHERE t.segment_name ='table_name'; pratition_name : 分区名 bytes : 分区大小( ...

  8. 2018.08.30 bzoj4318: OSU!(期望dp)

    传送门 简单期望dp. 感觉跟Easy差不多,就是把平方差量进阶成了立方差量,原本维护的是(x+1)2−x2" role="presentation" style=&qu ...

  9. hdu-1179(匈牙利算法)

    题目链接: 思路:找n个巫师和m个魔棒匹配的问题,匈牙利算法模板 匈牙利算法:https://blog.csdn.net/sunny_hun/article/details/80627351 #inc ...

  10. TableView编辑状态下跳转页面的崩溃处理

    29down votefavorite 12 I have a viewController with a UITableView, the rows of which I allow to edit ...