1. 1 import tensorflow as tf
  2. import tensorflow.examples.tutorials.mnist.input_data as input_data
  3. import matplotlib.pyplot as plt
  4. import numpy as np
  5. mnist=input_data.read_data_sets("MNIST_data/",one_hot=True) #下载据数
  6. print('train images:',mnist.train.images.shape, #查看数据
  7. 'labels:',mnist.train.labels.shape)
  8. print('validation images:',mnist.validation.images.shape,
  9. 'labels:',mnist.validation.labels.shape)
  10. print('test images:',mnist.test.images.shape,
  11. 'labels:',mnist.test.labels.shape
  12. #定义显示多项图像的函数
  13. def plot_images_labels_prediction_3(images,labels,prediction,idx,num=):
  14. fig=plt.gcf()
  15. fig.set_size_inches(,)
  16. if num>:num=
  17. for i in range(,num):
  18. ax=plt.subplot(,,i+)
  19. ax.imshow(np.reshape(images[idx],(,)),cmap='binary')
  20. title='lable='+str(np.argmax(labels[idx]))
  21. if len(prediction)>:
  22. title+=",prediction="+str(prediction[idx])
  23. ax.set_title(title,fontsize=)
  24. ax.set_xticks([]);ax.set_yticks([])
  25. idx+=
  26. plt.show()
  27.  
  28. plot_images_labels_prediction_3(mnist.train.images,mnist.train.labels,[],)
  29. #定义layer函数,构建多层感知器模型
  30. def layer(output_dim,input_dim,inputs,activation=None):
  31. W=tf.Variable(tf.random_normal([input_dim,output_dim]))
  32. b=tf.Variable(tf.random_normal([,output_dim]))
  33. XWb=tf.matmul(inputs,W)+b
  34. if activation is None:
  35. outputs=XWb
  36. else:
  37. outputs=activation(XWb)
  38. return outputs
  39. #建立输入层
  40. x=tf.placeholder("float",[None,])
  41. #建立隐藏层
  42. h1=layer(output_dim=,input_dim=,inputs=x,
  43. activation=tf.nn.relu)
  44. #建立输出层
  45. y_predict=layer(output_dim=,input_dim=,inputs=h1,
  46. activation=None)
  47. y_label=tf.placeholder("float",[None,])
  48. #定义损失函数
  49. loss_function=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits
  50. (logits=y_predict,
  51. labels=y_label))
  52. #定义优化器
  53. optimizer=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss_function)
  54. #计算每一项数据是否预测正确
  55. correct_prediction=tf.equal(tf.argmax(y_label,),
  56. tf.argmax(y_predict,))
  57. #计算预测正确结果的平均值
  58. accuracy=tf.reduce_mean(tf.cast(correct_prediction,"float"))
  59. #、定义训练参数
  60. trainEpochs= #设置执行15个训练周期
  61. batchSize= #每一批次项数为100
  62. totalBatchs=int(mnist.train.num_examples/batchSize) #计算每个训练周期
  63. loss_list=[];epoch_list=[];accuracy_list=[] #初始化训练周期、误差、准确率
  64. from time import time #导入时间模块
  65. startTime=time() #开始计算时间
  66. sess=tf.Session() #建立Session
  67. sess.run(tf.global_variables_initializer()) #初始化TensorFlow global 变量
  68. #、进行训练
  69. for epoch in range(trainEpochs):
  70. for i in range(totalBatchs):
  71. batch_x,batch_y=mnist.train.next_batch(batchSize) #使用mnist.train.next_batch方法读取批次数据,传入参数batchSize是100
  72. sess.run(optimizer,feed_dict={x:batch_x,
  73. y_label:batch_y}) #执行批次训练
  74. loss,acc=sess.run([loss_function,accuracy], #使用验证数据计算准确率
  75. feed_dict={x:mnist.validation.images,
  76. y_label:mnist.validation.labels})
  77. epoch_list.append(epoch); #加入训练周期列表
  78. loss_list.append(loss) #加入误差列表
  79. accuracy_list.append(acc) #加入准确率列表
  80. print("Train Epoch:",'%02d' % (epoch+),"Loss=",\
  81. "{:.9f}".format(loss),"Accuracy=",acc)
  82. duration=time()-startTime
  83. print("Train Finished takes:",duration) #计算并显示全部训练所需时间
  84. #画出误差执行结果
  85.  
  86. fig=plt.gcf()
  87. fig.set_size_inches(,)
  88. plt.plot(epoch_list,loss_list,label='loss')
  89. plt.ylabel('loss')
  90. plt.xlabel('epoch')
  91. plt.legend(['loss'],loc='upper left')
  92. #画出准确率执行结果
  93. plt.plot(epoch_list,accuracy_list,label="accuracy")
  94. fig=plt.gcf()
  95. fig.set_size_inches(,)
  96. plt.ylim(0.8,)
  97. plt.ylabel('accuracy')
  98. plt.xlabel('epoch')
  99. plt.legend()
  100. plt.show()
  101. #评估模型准确率
  102. print("accuracy:",sess.run(accuracy,
  103. feed_dict={x:mnist.test.images,
  104. y_label:mnist.test.labels}))
  105. #进行预测
  106. #.执行预测
  107. prediction_result=sess.run(tf.argmax(y_predict,),
  108. feed_dict={x:mnist.test.images})
  109. #.预测结果
  110. print(prediction_result[:])
  111. #.显示前10项预测结果
  112. plot_images_labels_prediction_3(mnist.test.images,
  113. mnist.test.labels,
  114. prediction_result,)

运行结果:

TensorFlow—多层感知器—MNIST手写数字识别的更多相关文章

  1. 【TensorFlow-windows】(三) 多层感知器进行手写数字识别(mnist)

    主要内容: 1.基于多层感知器的mnist手写数字识别(代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

  2. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  3. Tensorflow实现MNIST手写数字识别

    之前我们讲了神经网络的起源.单层神经网络.多层神经网络的搭建过程.搭建时要注意到的具体问题.以及解决这些问题的具体方法.本文将通过一个经典的案例:MNIST手写数字识别,以代码的形式来为大家梳理一遍神 ...

  4. TensorFlow——MNIST手写数字识别

    MNIST手写数字识别 MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/   一.数据集介绍: MNIST是一个入门级的计算机视觉数据集 下载下来的数据集 ...

  5. mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

  6. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  7. Tensorflow之MNIST手写数字识别:分类问题(1)

    一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点:   1.将离散特征的取值扩展 ...

  8. 基于TensorFlow的MNIST手写数字识别-初级

    一:MNIST数据集    下载地址 MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件 分别是test set images,test set labels,training se ...

  9. 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

    持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...

随机推荐

  1. RouterOS 设定NAT loopback (Hairpin NAT)回流

    In the below network topology a web server behind a router is on private IP address space, and the r ...

  2. USB接口程序编写

    copy from http://blog.csdn.net/luckywang1103/article/details/12393889# HID是Human Interface Devices的缩 ...

  3. php如何高效的读取大文件

    通常来说在php读取大文件的时候,我们采用的方法一般是一行行来讲取,而不是一次性把文件全部写入内存中,这样会导致php程序卡死,下面就给大家介绍这样一个例子. 需求:有一个800M的日志文件,大约有5 ...

  4. sencha touch 小米3无法点击问题 修复

    修改源码文件夹下event/publisher/Dom.js中的attachListener方法,代码如下 attachListener: function(eventName, doc) { if ...

  5. LinkedHashMap唯一,存储取出有序

    package cn.itcast_03; import java.util.LinkedHashMap; import java.util.Set; /* * LinkedHashMap:是Map接 ...

  6. django中使用Form组件

    内容: 1.Form组件介绍 2.Form组件常用字段 3.Form组件校验功能 4.Form组件内置正则校验 参考:https://www.cnblogs.com/liwenzhou/p/87478 ...

  7. php中点击链接直接下载图片

    最近需要一个功能,是点击链接,直接把图片下载下来,一般情况下,图片是在新页直接打开的,不会自动提示下载,在网上找来找,用这个挺好使,代码如下: $filename = basename($downfi ...

  8. hdu 4370 0 or 1,最短路

    题目描述 给定n * n矩阵C ij(1 <= i,j <= n),我们要找到0或1的n * n矩阵X ij(1 <= i,j <= n). 此外,X ij满足以下条件: 1. ...

  9. mysql数据库的维护,备份和复制

    在数据库运行时维护数据库 执行mysql数据库维护的方法之一就是连接mysql服务器,并告诉它做什么事, 如对myisam数据表进行检查或者修复, 可以使用check table tbname或rep ...

  10. JAVA 读取配置文件 xxx.properties

    package config_demo; import java.io.InputStream; import java.util.Properties; public class UrlDemo { ...