The project consists mainly of four files: mnist_train.py (training), mnist_eval.py (evaluation), mnist_inference.py (the forward-pass definitions), and mobilenet_v1.py (the slim MobileNet v1 implementation).

mnist_train.py

#coding: utf-8
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_inference

BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 10000
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "./mobilenet_v1_model/"
MODEL_NAME = "model.ckpt"
channels = 1

def train_MLP(mnist):
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)

    y = mnist_inference.inference_MLP(x, regularizer)

    global_step = tf.Variable(0, trainable=False)

    # maintain shadow (moving-average) copies of all trainable variables
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())
    # labels are one-hot, so argmax converts them to the sparse class index
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step,
                                               mnist.train.num_examples / BATCH_SIZE,
                                               LEARNING_RATE_DECAY)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

    # group the weight update and the moving-average update into a single op
    with tf.control_dependencies([train_step, variable_averages_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})

            if i % 1000 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)

def train_mobilenet(mnist):
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)

    # MobileNet: reshape the flat 784-vector into an image and pad it to
    # 112x112 (28*4), since 28x28 is too small for the net's downsampling stack
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    x_image = tf.image.resize_image_with_crop_or_pad(x_image, 28 * 4, 28 * 4)
    y = mnist_inference.inference_mobilenet(x_image, regularizer)

    global_step = tf.Variable(0, trainable=False)

    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    loss = cross_entropy_mean  # + tf.add_n(tf.get_collection('losses'))
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step,
                                               mnist.train.num_examples / BATCH_SIZE,
                                               LEARNING_RATE_DECAY)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

    with tf.control_dependencies([train_step, variable_averages_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})

            if i % 1000 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
            else:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))

def main(argv=None):
    mnist = input_data.read_data_sets("../MNIST_data", one_hot=True)
    train_mobilenet(mnist)

if __name__ == '__main__':
    tf.app.run()
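
One caveat (my addition, not in the original post): in TF 1.x, tf.train.Saver.save() raises an error if the parent directory of the checkpoint path does not exist, so create MODEL_SAVE_PATH before the first save, e.g.:

import os

MODEL_SAVE_PATH = "./mobilenet_v1_model/"
# create the checkpoint directory up front; Saver.save() will not create it
if not os.path.exists(MODEL_SAVE_PATH):
    os.makedirs(MODEL_SAVE_PATH)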

mnist_eval.py

#coding: utf-8
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_inference
import mnist_train

# reload the newest model every 10 seconds
EVAL_INTERVAL_SECS = 10

def evaluate_MLP(mnist):
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}

        y = mnist_inference.inference_MLP(x, None)

        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # restore the moving-average (shadow) values of the weights
        variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        # while True:
        if 1:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    # load the model
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
                    print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score))
                else:
                    print('No checkpoint file found')
                    return
            # time.sleep(EVAL_INTERVAL_SECS)

def evaluate_mobilenet(mnist):
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')

        # MobileNet: reshape the flat input into an image and pad it to the
        # same 112x112 size used during training
        x_image = tf.reshape(x, [-1, 28, 28, 1])
        x_image = tf.image.resize_image_with_crop_or_pad(x_image, 28 * 4, 28 * 4)
        # is_training=False so batch norm and dropout run in inference mode
        y = mnist_inference.inference_mobilenet(x_image, None, is_training=False)

        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        images = mnist.validation.images
        labels = mnist.validation.labels
        batch_size = 100
        TEST_STEPS = images.shape[0] // batch_size
        sum_accuracy = 0.0
        # while True:
        if 1:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    # load the model
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    for i in range(TEST_STEPS):
                        input_batch = images[i * batch_size : (i + 1) * batch_size, :]
                        label_batch = labels[i * batch_size : (i + 1) * batch_size, :]
                        validate_feed = {x: input_batch, y_: label_batch}
                        # evaluate one batch at a time
                        accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
                        sum_accuracy += accuracy_score
                        print("test %s batch steps, validation accuracy = %g" % (i, accuracy_score))
                else:
                    print('No checkpoint file found')
                    return
            # time.sleep(EVAL_INTERVAL_SECS)
            print("After %s training steps, all validation accuracy = %g" % (global_step, sum_accuracy / TEST_STEPS))

def main(argv=None):
    mnist = input_data.read_data_sets("../MNIST_data", one_hot=True)
    evaluate_mobilenet(mnist)

if __name__ == '__main__':
    tf.app.run()
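
The while True / time.sleep lines are commented out above, but the pattern they come from is a small polling loop: while training runs in another process, re-evaluate the newest checkpoint every EVAL_INTERVAL_SECS seconds. A minimal sketch of that variant of main(), building on the functions above:

def main(argv=None):
    mnist = input_data.read_data_sets("../MNIST_data", one_hot=True)
    while True:
        # each call builds a fresh graph and restores the newest checkpoint
        evaluate_mobilenet(mnist)
        time.sleep(EVAL_INTERVAL_SECS)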

mnist_inference.py

#coding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

import mobilenet_v1

slim = tf.contrib.slim

# network size constants
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500

def get_weight_variable(shape, regularizer):
    weights = tf.get_variable("weights", shape, initializer=tf.truncated_normal_initializer(stddev=0.1))

    if regularizer is not None:
        tf.add_to_collection('losses', regularizer(weights))

    return weights

# forward pass of a single-hidden-layer MLP
def inference_MLP(input_tensor, regularizer):
    with tf.variable_scope('layer1'):
        weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
        biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)

    with tf.variable_scope('layer2'):
        weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
        biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
        layer2 = tf.matmul(layer1, weights) + biases

    return layer2

# forward pass of mobilenet_v1; pass is_training=False when evaluating so
# that batch norm and dropout run in inference mode
def inference_mobilenet(input_tensor, regularizer, is_training=True):
    # inputs = tf.random_uniform((batch_size, height, width, 3))
    with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
                        normalizer_fn=slim.batch_norm):
        logits, end_points = mobilenet_v1.mobilenet_v1(
            input_tensor,
            num_classes=OUTPUT_NODE,
            dropout_keep_prob=0.8,
            is_training=is_training,
            min_depth=8,
            depth_multiplier=1.0,
            conv_defs=None,
            prediction_fn=tf.contrib.layers.softmax,
            spatial_squeeze=True,
            reuse=None,
            scope='MobilenetV1',
            global_pool=False)

    return logits
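
For reference, the building block behind mobilenet_v1 is the depthwise separable convolution: a 3x3 depthwise filter applied per channel, followed by a 1x1 pointwise convolution that mixes channels, at a fraction of the multiply-adds of a dense 3x3 convolution. A minimal slim sketch of one such block (illustrative only; the helper name separable_block is mine, and the real layer definitions live in mobilenet_v1.py):

def separable_block(net, out_channels, stride=1):
    # 3x3 depthwise convolution: num_outputs=None makes it depthwise-only,
    # one filter per input channel
    net = slim.separable_conv2d(net, None, [3, 3],
                                depth_multiplier=1, stride=stride,
                                normalizer_fn=slim.batch_norm)
    # 1x1 pointwise convolution mixes the channels
    net = slim.conv2d(net, out_channels, [1, 1],
                      normalizer_fn=slim.batch_norm)
    return net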

mobilenet_v1.py

Download it from:

https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py
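
Once downloaded, a quick smoke test (my own sketch, not part of the original post) checks that the file imports and builds on a dummy batch shaped like the padded MNIST input:

#coding: utf-8
import tensorflow as tf

import mobilenet_v1

# dummy batch with the same shape the training script feeds: 112x112, 1 channel
inputs = tf.random_uniform((2, 112, 112, 1))
logits, end_points = mobilenet_v1.mobilenet_v1(inputs, num_classes=10)
print(logits.shape)  # expect (2, 10)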
