1. 1 #coding:utf-8
  2. # 日期 2017年9月4日 环境 Python 3.5  TensorFlow 1.3 win10开发环境。
  3. import tensorflow as tf
  4. from tensorflow.examples.tutorials.mnist import input_data
  5. import os
  6.  
  7. # 基础的学习率
  8. LEARNING_RATE_BASE = 0.8
  9.  
  10. # 学习率的衰减率
  11. LEARNING_RATE_DECAY = 0.99
  12.  
  13. # 描述模型复杂度的正则化项在损失函数中的系数
  14. REGULARIZATION_RATE = 0.0001
  15.  
  16. # 训练轮数
  17. TRAINING_STEPS = 30000
  18.  
  19. # 滑动平均衰减率
  20. MOVING_AVERAGE_DECAY = 0.99
  21.  
  22. # 模型持久化保存路径
  23. MODEL_SAVE_PATH = "MNIST_model/"
  24. # 模型持久化保存文件名称
  25. MODEL_NAME = "mnist_model"
  26.  
  27. # 输入层节点数(对于数据集,相当于整个图片的像素数目)
  28. INPUT_NODE = 784
  29.  
  30. # 输出层的节点数(根据10个数字决定的)
  31. OUTPUT_NODE = 10
  32.  
  33. # 隐藏层的节点数,此例程中,隐藏层为一层。
  34. LAYER1_NODE = 500
  35.  
  36. # 一个训练batch中的训练数据个数,数字越小的时候,训练过程越接近随机梯度下降。
  37. BATCH_SIZE = 100
  38.  
  39. def train(mnist):
  40. # 定义输入输出placeholder。
  41. x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
  42. y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')
  43. # 正则化损失函数
  44. regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
  45. # 使用定义的向前传播过程
  46. y = inference(x, regularizer)
  47.  
  48. # 定义存储训练轮数的变量。这个变量不需要计算滑动的平均值,所以这里指定这个变量为不可训练的变量(trainable=False)。
  49. # 在tensorflow中训练神经网络的时候,一般会将代表训练轮数的变量指定为不可训练的参数。
  50. global_step = tf.Variable(0, trainable=False)
  51.  
  52. # 定义损失函数、学习率、滑动平均操作以及训练过程。
  53. variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
  54. # 在所有代表神经网络参数的变量上使用滑动平均。其它辅助变量(如global_step)就不需要了
  55. variables_averages_op = variable_averages.apply(tf.trainable_variables())
  56. # 计算交叉熵作为刻画预测值和真实值之间差距的损失函数。(第一个参数是神经网络不包含softmax层的前向传播结果,第二个是训练数据的正确答案)
  57. # 因为标准答案是一个长度为10的一维数组,二该函数需要提供的是一个正确答案的数字,所以需要使用tf.argmax函数来得到正确答案对应的类别编号。
  58. cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
  59. # 计算当前batch中所有样例的交叉熵平均值
  60. cross_entropy_mean = tf.reduce_mean(cross_entropy)
  61. # 总损失等于交叉熵和
  62. loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
  63.  
  64. # 设置指数衰减的学习率
  65. learning_rate = tf.train.exponential_decay(
  66. LEARNING_RATE_BASE, # 基础的学习率,随着迭代的进行,更新变量时使用的学习率在这个基础上递减
  67. global_step, # 当前迭代的轮数
  68. mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY, # 过完所有的训练数据需要的迭代次数
  69. staircase=True)
  70.  
  71. # 使用tf.train.GradientDescentOptimizer优化算法来优化损失函数。注意这里损失函数包含了交叉熵和正则损失
  72. train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
  73.  
  74. with tf.control_dependencies([train_step, variables_averages_op]):
  75. train_op = tf.no_op(name='train')
  76.  
  77. # 初始化TensorFlow持久化类。
  78. saver = tf.train.Saver()
  79. with tf.Session() as sess:
  80. tf.global_variables_initializer().run()
  81.  
  82. # 在训练过程中,不在测试模型在验证数据上的表现,验证和测试的过程将会有一个独立的程序来完成。
  83. for i in range(TRAINING_STEPS):
  84. xs, ys = mnist.train.next_batch(BATCH_SIZE)
  85. _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
  86. if i % 1000 == 0:
  87. # 输出当前的训练情况,这里只输出了模型在当前训练batch上的损失函数大小,通过损失函数的大小可以大概了解训练的情况。在验证数据集上的正确
  88. # 率信息会有一个单独的程序来生成。
  89. print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
  90. # 保存当前的模型。global_step参数,这样可以让每个被保存模型的文件名末尾加上训练的轮数,如model.ckpt-1000表示训练1000轮之后得到的模型
  91. saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
  92.  
  93. # 通过tf.get_variable函数来获取变量 在测试是会通过保存的模型加载这些变量的取值。而且更加方便的是,因为可以在变量加载时将滑动平均变量重命名
  94. # 所以可以直接通过同样的名字在训练时使用变量自身,而在测试时使用变量的滑动平均值。这个函数中会将变量的正则化损失加损失集合。
  95. def get_weight_variable(shape, regularizer):
  96. weights = tf.get_variable("weights", shape, initializer=tf.truncated_normal_initializer(stddev=0.1))
  97. # 当给出正则化生产函数时,将当前变量的正则化损失加入名字为Losses的集合。在这里使用了add_to_collection函数将一个张量加入一个集合,
  98. # 而这个集合的名称为losses.这是自定义集合,不在Tensorflow自动管理的集合列表中
  99. if regularizer != None: tf.add_to_collection('losses', regularizer(weights))
  100. return weights
  101.  
  102. # 定义神经网络的前向传播过程(初始化所有参数的辅助函数,给定神经网络中的参数)
  103. def inference(input_tensor, regularizer):
  104. # 声明第一层神经网络的变量并完成前向传播过程
  105. with tf.variable_scope('layer1'):
  106. # 通过tf.get_variable 和tf.Variable没有本质区别,因为在训练或是测试中没有在同一个程序中多次调用这个函数。如果在同一个过程多次调用,
  107. # 在第一调用的之后需要将resuse参数设置为True
  108. weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
  109. biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
  110. layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
  111.  
  112. # 声明第二层神经网络的变量并完成向前传播的过程
  113. with tf.variable_scope('layer2'):
  114. weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
  115. biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
  116. layer2 = tf.matmul(layer1, weights) + biases
  117.  
  118. return layer2
  119.  
  120. # 2.主程序部分
  121. def main(argv=None):
  122. # 获取数据集(根据谷歌的例程中相关的获取路径)
  123. mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)
  124. # 根据数据集训练模型
  125. train(mnist)
  126.  
  127. # 1 .程序入口
  128. if __name__ == '__main__':
  129. main()

对Tensorflow中经典的MNIST模型的学习,程序整个过程进行了注释,摘自《实战google深度学习框架》中代码,并进行修改后注释。

Tensorflow学习笔记(对MNIST经典例程的)的代码注释与理解的更多相关文章

  1. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

  2. 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别

    深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...

  3. TensorFlow学习笔记(MNIST报错修正 适用Tensorflow1.3)

    在Tensorflow实战Google框架下的深度学习这本书的MNIST的图像识别例子中,每次都要报错   错误如下: Only call `sparse_softmax_cross_entropy_ ...

  4. tensorflow学习笔记————分类MNIST数据集

    在使用tensorflow分类MNIST数据集中,最容易遇到的问题是下载MNIST样本的问题. 一般是通过使用tensorflow内置的函数进行下载和加载, from tensorflow.examp ...

  5. tensorflow学习笔记(10) mnist格式数据转换为TFrecords

    本程序 (1)mnist的图片转换成TFrecords格式 (2) 读取TFrecords格式 # coding:utf-8 # 将MNIST输入数据转化为TFRecord的格式 # http://b ...

  6. Tensorflow学习笔记No.5

    tf.data卷积神经网络综合应用实例 使用tf.data建立自己的数据集,并使用CNN卷积神经网络实现对卫星图像的二分类问题. 数据下载链接:https://pan.baidu.com/s/141z ...

  7. tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)

    tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...

  8. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  9. TensorFlow学习笔记——LeNet-5(训练自己的数据集)

    在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...

  10. tensorflow学习笔记——自编码器及多层感知器

    1,自编码器简介 传统机器学习任务很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难了,工程师必须在这 ...

随机推荐

  1. [BZOJ 1297][SCOI2009]迷路

    1297: [SCOI2009]迷路 Time Limit: 10 Sec  Memory Limit: 162 MBSubmit: 1418  Solved: 1017[Submit][Status ...

  2. 201621123060《JAVA程序设计》第一周学习总结

    1.本周学习总结 1.讲述了JAVA的发展史,关于JDK.JRE.JVM的联系和区别 2.JDK是用JAVA开发工具.做项目的关键.JRE是JAVA的运行环境(JAVA也是JAVA语言开发的).JVM ...

  3. android数据库持久化框架, ormlite框架,

    前言 Android中内置了SQLite,但是对于数据库操作这块,非常的麻烦.其实可以试用第3方的数据库持久化框架对之进行结构上调整, 摆脱了访问数据库操作的细节,不用再去写复杂的SQL语句.虽然这样 ...

  4. a标签传递参数

    a标签传递参数 单个参数:参数名称前面跟   ? <a href="localhost:8080/arguments?id=1">单个参数</a> 多个参数 ...

  5. js中多维数组转一维

    法一:使用数组map()方法,对数组中的每一项运行给定函数,返回每次函数调用的结果组成的数组. var arr = [1,[2,[[3,4],5],6]]; function unid(arr){ v ...

  6. Java8-如何构建一个Stream

    Stream的创建方式有很多种,除了最常见的集合创建,还有其他几种方式. List转Stream List继承自Collection接口,而Collection提供了stream()方法. List& ...

  7. Linux入门(1)_VMware和系统分区和系统安装和远程登陆管理

    1 VMware的安装和使用 注意有 快照 和 克隆 的功能. 快照相当于建立一个 系统还原点, 可以随时恢复到原来状态. 克隆功能可以复制一个和当前一样的系统,并可以选择链接安装,只使用很少的空间就 ...

  8. Formdata 图片上传 Ajax

    /*图片上传*/ $("点击对象").bind("click", function(e){ $('#form-upload').remove(); $('bod ...

  9. Python内置函数(24)——set

    英文文档: class set([iterable]) Return a new set object, optionally with elements taken from iterable. s ...

  10. 新概念英语(1-35)Our village

    新概念英语(1-35)Our village Are the children coming out of the park or going into it ? This is a photogra ...