1. '''
  2. Created on 2017年5月23日
  3.  
  4. @author: weizhen
  5. '''
  6. import os
  7. import tensorflow as tf
  8. from tensorflow.examples.tutorials.mnist import input_data
  9. # minist_inference中定义的常量和前向传播的函数不需要改变,
  10. # 因为前向传播已经通过tf.variable_scope实现了计算节点按照网络结构的划分
  11. import mnist_inference
  12. from mnist_train import MOVING_AVERAGE_DECAY, REGULARAZTION_RATE, \
  13. LEARNING_RATE_BASE, BATCH_SIZE, LEARNING_RATE_DECAY, TRAINING_STEPS, MODEL_SAVE_PATH, MODEL_NAME
  14. INPUT_NODE = 784
  15. OUTPUT_NODE = 10
  16. LAYER1_NODE = 500
  17. def train(mnist):
  18. # 将处理输入数据集的计算都放在名子为"input"的命名空间下
  19. with tf.name_scope("input"):
  20. x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
  21. y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-cinput')
  22. regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
  23. y = mnist_inference.inference(x, regularizer)
  24. global_step = tf.Variable(0, trainable=False)
  25.  
  26. # 将滑动平均相关的计算都放在名为moving_average的命名空间下
  27. with tf.name_scope("moving_average"):
  28. variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
  29. variable_averages_op = variable_averages.apply(tf.trainable_variables())
  30.  
  31. # 将计算损失函数相关的计算都放在名为loss_function的命名空间下
  32. with tf.name_scope("loss_function"):
  33. cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
  34. cross_entropy_mean = tf.reduce_mean(cross_entropy)
  35. loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
  36.  
  37. # 将定义学习率、优化方法以及每一轮训练需要执行的操作都放在名子为"train_step"的命名空间下
  38. with tf.name_scope("train_step"):
  39. learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,
  40. global_step,
  41. mnist.train._num_examples / BATCH_SIZE,
  42. LEARNING_RATE_DECAY,
  43. staircase=True)
  44. train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
  45.  
  46. with tf.control_dependencies([train_step, variable_averages_op]):
  47. train_op = tf.no_op(name='train')
  48.  
  49. # 训练模型。
  50. with tf.Session() as sess:
  51. tf.global_variables_initializer().run()
  52. for i in range(TRAINING_STEPS):
  53. xs, ys = mnist.train.next_batch(BATCH_SIZE)
  54.  
  55. if i % 1000 == 0:
  56. # 配置运行时需要记录的信息。
  57. run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
  58. # 运行时记录运行信息的proto。
  59. run_metadata = tf.RunMetadata()
  60. _, loss_value, step = sess.run(
  61. [train_op, loss, global_step], feed_dict={x: xs, y_: ys},
  62. options=run_options, run_metadata=run_metadata)
  63. print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
  64. writer = tf.summary.FileWriter("/log/modified_mnist_train.log", tf.get_default_graph())
  65. writer.add_run_metadata(run_metadata, "stop%03d" % i)
  66. writer.close()
  67. print("After %d training steps(s),loss on training batch is %g."%(step,loss_value))
  68. else:
  69. _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
  70. # 初始化Tensorflow持久化类
  71. # saver = tf.train.Saver()
  72. # with tf.Session() as sess:
  73. # tf.global_variables_initializer().run()
  74. #
  75. # 在训练过程中不再测试模型在验证数据上的表现,验证和测试的过程将会有一个独立的程序来完成
  76. # for i in range(TRAINING_STEPS):
  77. # xs, ys = mnist.train.next_batch(BATCH_SIZE)
  78. # _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x:xs, y_:ys})
  79.  
  80. # 每1000轮保存一次模型
  81. # if i % 1000 == 0:
  82. # 输出当前训练情况。这里只输出了模型在当前训练batch上的损失函数大小
  83. # 通过损失函数的大小可以大概了解训练的情况。在验证数据集上的正确率信息
  84. # 会有一个单独的程序来生成
  85. # print("After %d training step(s),loss on training batch is %g" % (step, loss_value))
  86.  
  87. # 保存当前的模型。注意这里给出了global_step参数,这样可以让每个被保存模型的文件末尾加上训练的轮数
  88. # 比如"model.ckpt-1000"表示训练1000轮之后得到的模型
  89. # saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
  90.  
  91. # 将当前的计算图输出到TensorBoard日志文件
  92. # writer=tf.summary.FileWriter("/path/to/log",tf.get_default_graph())
  93. # writer.close()
  94.  
  95. def main(argv=None):
  96. mnist = input_data.read_data_sets("/tmp/data", one_hot=True)
  97. train(mnist)
  98.  
  99. if __name__ == '__main__':
  100. tf.app.run()

88、使用tensorboard进行可视化学习,查看具体使用时间,训练轮数,使用内存大小的更多相关文章

  1. 87、使用TensorBoard进行可视化学习

    1.还是以手写识别为类,至于为什么一直用手写识别这个例子,原因很简单,因为书上只给出了这个类子呀,哈哈哈,好神奇 下面是可视化学习的标准函数 ''' Created on 2017年5月23日 @au ...

  2. linux学习--查看操作系统版本及cpu及内存信息

    查看版本当前操作系统内核信息 uname -a 查看当前操作系统版本信息 cat  /proc/version 查看物理cpu个数: cat /proc/cpuinfo| grep "phy ...

  3. 可视化学习Tensorboard

    可视化学习Tensorboard TensorBoard 涉及到的运算,通常是在训练庞大的深度神经网络中出现的复杂而又难以理解的运算.为了更方便 TensorFlow 程序的理解.调试与优化,发布了一 ...

  4. TensorBoard:可视化学习

    数据序列化 TensorBoard 通过读取 TensorFlow 的事件文件来运行.TensorFlow 的事件文件包括了你会在 TensorFlow 运行中涉及到的主要数据.下面是 TensorB ...

  5. Tensorflow学习笔记3:TensorBoard可视化学习

    TensorBoard简介 Tensorflow发布包中提供了TensorBoard,用于展示Tensorflow任务在计算过程中的Graph.定量指标图以及附加数据.大致的效果如下所示, Tenso ...

  6. Pytorch在colab和kaggle中使用TensorBoard/TensorboardX可视化

    在colab和kaggle内核的Jupyter notebook中如何可视化深度学习模型的参数对于我们分析模型具有很大的意义,相比tensorflow, pytorch缺乏一些的可视化生态包,但是幸好 ...

  7. 使用 TensorBoard 可视化模型、数据和训练

    使用 TensorBoard 可视化模型.数据和训练 在 60 Minutes Blitz 中,我们展示了如何加载数据,并把数据送到我们继承 nn.Module 类的模型,在训练数据上训练模型,并在测 ...

  8. R语言可视化学习笔记之添加p-value和显著性标记

    R语言可视化学习笔记之添加p-value和显著性标记 http://www.jianshu.com/p/b7274afff14f?from=timeline   上篇文章中提了一下如何通过ggpubr ...

  9. Tensorflow搭建神经网络及使用Tensorboard进行可视化

    创建神经网络模型 1.构建神经网络结构,并进行模型训练 import tensorflow as tfimport numpy as npimport matplotlib.pyplot as plt ...

随机推荐

  1. 服务器上的 IPProxy代理设置

    1.window 平台 CCProxy 安装包 传送门: http://www.xue51.com/soft/2794.html 该页面详细的说明了ccproxy怎么安装.怎么破jie.... 下面老 ...

  2. 剑指offer---4、序列化二叉树

    剑指offer---4.序列化二叉树 一.总结 一句话总结: 1. 对于序列化:使用前序遍历,递归的将二叉树的值转化为字符,并且在每次二叉树的结点不为空时,在转化val所得的字符之后添加一个' , ' ...

  3. 用 Flask 来写个轻博客 (32) — 使用 Flask-RESTful 来构建 RESTful API 之一

    目录 目录 前文列表 扩展阅读 RESTful API REST 原则 无状态原则 面向资源 RESTful API 的优势 REST 约束 前文列表 用 Flask 来写个轻博客 (1) - 创建项 ...

  4. linux执行时间段内日志关键字搜索

    sed -n '/起始时间/,/结束时间/p' 日志文件 | grep '关键字' 查询文件debug.log在2019-11-18 08:00:00~2019-11-18 08:21:00时间段内e ...

  5. Workbox使用策略

    1.什么是Workbox Strategies? 当service workers 首次被引入时,可以设定一组常见的缓存策略. 缓存策略是一种模式,用于确定service workers 在收到fet ...

  6. UVA - 11624 J - Fire! (BFS)

    题目传送门 J - Fire! Joe works in a maze. Unfortunately, portions of the maze have caught on fire, and the ...

  7. Opencv3.3(Linux)编译安装至python的坑

    编译安装OpenCV绝对是一件让人发狂的事情,CMake繁多的选项,国内蛋疼的网速,实在让人无力吐槽,然而为了使用contrib包,我不得不重新编译他. OpenCV的编译 其实OpenCV编译并不是 ...

  8. 绘图matplotlib

    前言 matplotlib是python的一个绘图库,如果你没有绘制过图,可以先试试js的绘图库http://www.runoob.com/highcharts/highcharts-line-lab ...

  9. 24.循环栅栏 CyclicBarrier

    import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.CyclicBarrier; /** * ...

  10. python3.x 匿名函数lambda_扩展sort

    #匿名函数lambda 参数: 表达式关键字 lambda 说明它是一个匿名函数,冒号 : 前面的变量是该匿名函数的参数,冒号后面是函数的返回值,注意这里不需使用 return 关键字. ambda只 ...