分为三个文件:mnist_inference.py:定义前向传播的过程以及神经网络中的参数,抽象成为一个独立的库函数;mnist_train.py:定义神经网络的训练过程,在此过程中,每个一段时间保存一次模型训练的中间结果;mnist_eval.py:定义测试过程。

  1. mnist_inference.py
  1. #coding=utf8
  2. import tensorflow as tf
  3.  
  4. #1. 定义神经网络结构相关的参数。
  5.  
  6. INPUT_NODE = 784
  7. OUTPUT_NODE = 10
  8. LAYER1_NODE = 500
  9.  
  10. #2. 通过tf.get_variable函数来获取变量。
  11. def get_weight_variable(shape, regularizer):
  12. weights = tf.get_variable("weights", shape, initializer=tf.truncated_normal_initializer(stddev=0.1))
  13. if regularizer != None: tf.add_to_collection('losses', regularizer(weights))
  14. return weights
  15.  
  16. #3. 定义神经网络的前向传播过程。使用命名空间方式,不需要把所有的变量都作为变量传递到不同的函数中提高程序的可读性
  17.  
  18. def inference(input_tensor, regularizer):
  19. with tf.variable_scope('layer1'):
  20.  
  21. weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
  22. biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
  23. layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
  24.  
  25. with tf.variable_scope('layer2'):
  26. weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
  27. biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
  28. layer2 = tf.matmul(layer1, weights) + biases
  29.  
  30. return layer2
  31.  
  1. mnist_train.py

#coding=utf8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference

import os

  1. #1. 定义神经网络结构相关的参数。
  2. BATCH_SIZE = 100
  3. LEARNING_RATE_BASE = 0.8
  4. LEARNING_RATE_DECAY = 0.99
  5. REGULARIZATION_RATE = 0.0001
  6. TRAINING_STEPS = 30000
  7. MOVING_AVERAGE_DECAY = 0.99
  8. MODEL_SAVE_PATH="MNIST_model/"
  9. MODEL_NAME="mnist_model"
  10. #2. 定义训练过程。
  11. def train(mnist):
  12. # 定义输入输出placeholder。
  13. x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
  14. y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
  15. regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
  16. y = mnist_inference.inference(x, regularizer)
  17. global_step = tf.Variable(0, trainable=False)
  18. # 定义损失函数、学习率、滑动平均操作以及训练过程。
  19. variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
  20. variables_averages_op = variable_averages.apply(tf.trainable_variables())
  21. cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
  22. cross_entropy_mean = tf.reduce_mean(cross_entropy)
  23. loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
  24. learning_rate = tf.train.exponential_decay(
  25. LEARNING_RATE_BASE,
  26. global_step,
  27. mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,
  28. staircase=True)
  29. train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
  30. with tf.control_dependencies([train_step, variables_averages_op]):
  31. train_op = tf.no_op(name='train')
  32. # 初始化TensorFlow持久化类。
  33. saver = tf.train.Saver()
  34. with tf.Session() as sess:
  35. tf.global_variables_initializer().run()
  36. for i in range(TRAINING_STEPS):
  37. xs, ys = mnist.train.next_batch(BATCH_SIZE)
  38. _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
  39. if i % 1000 == 0:
  40. print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
  41. saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
  42. def main(argv=None):
  43. mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
  44. train(mnist)
  45. if __name__ == '__main__':
  46. main()
  47. 结果如下:

mnist_eval.py:

  1.  
  2. import time
  3. import tensorflow as tf
  4. from tensorflow.examples.tutorials.mnist import input_data
  5. import mnist_inference
  6. #coding=utf8
  7. import mnist_train
  8.  
  9. #1. 每10秒加载一次最新的模型
  10.  
  11. # 加载的时间间隔。
  12. EVAL_INTERVAL_SECS = 10
  13.  
  14. def evaluate(mnist):
  15. with tf.Graph().as_default() as g:
  16. x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
  17. y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
  18. validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
  19.  
  20. y = mnist_inference.inference(x, None)
  21. correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  22. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  23.  
  24. variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
  25. variables_to_restore = variable_averages.variables_to_restore()
  26. saver = tf.train.Saver(variables_to_restore)
  27.  
  28. while True:
  29. with tf.Session() as sess:
  30. ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)
  31. if ckpt and ckpt.model_checkpoint_path:
  32. saver.restore(sess, ckpt.model_checkpoint_path)
  33. global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
  34. accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
  35. print("After %s training step(s), validation accuracy = %g" % (global_step, accuracy_score))
  36. else:
  37. print('No checkpoint file found')
  38. return
  39. time.sleep(EVAL_INTERVAL_SECS)
  40.  
  41. def main(argv=None):
  42. mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
  43. evaluate(mnist)
  44.  
  45. if __name__ == '__main__':
  46. main()

结果如下:

Tensorflow 解决MNIST问题的重构程序的更多相关文章

  1. 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门

    2017年2月16日,Google正式对外发布Google TensorFlow 1.0版本,并保证本次的发布版本API接口完全满足生产环境稳定性要求.这是TensorFlow的一个重要里程碑,标志着 ...

  2. tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)

    tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...

  3. 一个简单的TensorFlow可视化MNIST数据集识别程序

    下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev)) # -*- coding: utf-8 -*- ...

  4. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  5. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  6. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  7. win10下通过Anaconda安装TensorFlow-GPU1.3版本,并配置pycharm运行Mnist手写识别程序

    折腾了一天半终于装好了win10下的TensorFlow-GPU版,在这里做个记录. 准备安装包: visual studio 2015: Anaconda3-4.2.0-Windows-x86_64 ...

  8. Tensorflow之MNIST的最佳实践思路总结

    Tensorflow之MNIST的最佳实践思路总结   在上两篇文章中已经总结出了深层神经网络常用方法和Tensorflow的最佳实践所需要的知识点,如果对这些基础不熟悉,可以返回去看一下.在< ...

  9. tensorflow处理mnist(二)

    用卷积神经网络解决mnist的分类问题. 简单的例子 一行一行解释这个代码. 这个不是google官方的例子,但是很简洁,便于入门.tensorflow是先定义模型,最后赋值,计算.为了讨论问题方便, ...

随机推荐

  1. 【Golang 接口自动化05】使用yml管理自动化用例

    我们在前面几篇文章中学习怎么发送数据请求,怎么处理解析接口返回的结果,接下来我们一起来学习怎么进行测试用例管理,今天我们介绍的是使用yml文件进行用例管理,所以首先我们一起来了解一下YAML和它的简单 ...

  2. 云服务器ECS挖矿木马病毒处理和解决方案

    云服务器ECS挖矿木马病毒处理和解决方案 最近由于网络环境安全意识低的原因,导致一些云服务器ECS中了挖矿病毒的坑. 总结了一些解决挖矿病毒的一些思路.由于病毒更新速度快仅供参考. 1.查看cpu爆满 ...

  3. Confluence 6 使用 LDAP 授权连接一个内部目录 - Schema 设置

    基本 DN(Base DN) 根专有名称(DN),这个名称在你对目录服务器上进行查询的时候使用.例如: o=example,c=com cn=users,dc=ad,dc=example,dc=com ...

  4. HDU-1232 畅通工程 (并查集、判断图中树的棵数)

    Description 某省调查城镇交通状况,得到现有城镇道路统计表,表中列出了每条道路直接连通的城镇.省政府“畅通工程”的目标是使全省任何两个城镇间都可以实现交通(但不一定有直接的道路相连,只要互相 ...

  5. Oracle11g温习-第一章 2、ORACLE 物理结构

    2013年4月27日 星期六 10:26 物理操作系统文件的集合.主要包括: 控制文件(参数文件init$ORACLE_SID.ora记录了控制文件的位置) 二进制文件,控制文件由参数control_ ...

  6. Pandas DataFrame 数据选取和过滤

    This would allow chaining operations like: pd.read_csv('imdb.txt') .sort(columns='year') .filter(lam ...

  7. 简话Angular 05 Angular表单验证

    一句话: 可以使用所有html5表单验证功能,同时Angular还增强了部分验证,支持动态验证 1. 上源码 <div ng-controller="ExampleController ...

  8. xmind visio mindmanager edraw比较

    xmind visio mindmanager edraw比较   xmind visio mindmanager Edraw 中心主题 有 无 有   泳道图 无 有 有   结构上讲 [思维导图] ...

  9. ECC算法整理纪要

    初始ECC算法 1.用户A 密钥生成 (1):用随机数发生器产生随机数k∈[1,n-1]: (2):计算椭圆曲线点PA=[k]G,为公钥,k为用户A私钥: 2. 用户B加密算法及流程 设需要发送的消息 ...

  10. jsp jsp常用指令

    jsp指令是为jsp引擎设计的,他们并不直接产生任何可见输出,而只是告诉引擎如何处理jsp页面中的其余部分. jsp中的指令 page指令 include指令 taglib指令 jsp指令的基本语法 ...