圣诞节玩的有点嗨,差点忘记更新。祝大家昨天圣诞节快乐,再过几天元旦节快乐。

来继续学习,在/home/your_name/TensorFlow/cifar10/ 下新建文件夹cifar10_train,用来保存训练时的日志logs,继续在/home/your_name/TensorFlow/cifar10/ cifar10.py中输入如下代码:

  1. def train():
  2. # global_step
  3. global_step = tf.Variable(0, name = 'global_step', trainable=False)
  4. # cifar10 数据文件夹
  5. data_dir = '/home/your_name/TensorFlow/cifar10/data/cifar-10-batches-bin/'
  6. # 训练时的日志logs文件,没有这个目录要先建一个
  7. train_dir = '/home/your_name/TensorFlow/cifar10/cifar10_train/'
  8. # 加载 images,labels
  9. images, labels = my_cifar10_input.inputs(data_dir, BATCH_SIZE)
  10.  
  11. # 求 loss
  12. loss = losses(inference(images), labels)
  13. # 设置优化算法,这里用 SGD 随机梯度下降法,恒定学习率
  14. optimizer = tf.train.GradientDescentOptimizer(LEARNING_RATE)
  15. # global_step 用来设置初始化
  16. train_op = optimizer.minimize(loss, global_step = global_step)
  17. # 保存操作
  18. saver = tf.train.Saver(tf.all_variables())
  19. # 汇总操作
  20. summary_op = tf.merge_all_summaries()
  21. # 初始化方式是初始化所有变量
  22. init = tf.initialize_all_variables()
  23.  
  24. os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
  25. config = tf.ConfigProto()
  26. # 占用 GPU 的 20% 资源
  27. config.gpu_options.per_process_gpu_memory_fraction = 0.2
  28. # 设置会话模式,用 InteractiveSession 可交互的会话,逼格高
  29. sess = tf.InteractiveSession(config=config)
  30. # 运行初始化
  31. sess.run(init)
  32.  
  33. # 设置多线程协调器
  34. coord = tf.train.Coordinator()
  35. # 开始 Queue Runners (队列运行器)
  36. threads = tf.train.start_queue_runners(sess = sess, coord = coord)
  37. # 把汇总写进 train_dir,注意此处还没有运行
  38. summary_writer = tf.train.SummaryWriter(train_dir, sess.graph)
  39.  
  40. # 开始训练过程
  41. try:
  42. for step in xrange(MAX_STEP):
  43. if coord.should_stop():
  44. break
  45. start_time = time.time()
  46. # 在会话中运行 loss
  47. _, loss_value = sess.run([train_op, loss])
  48. duration = time.time() - start_time
  49. # 确认收敛
  50. assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
  51. if step % 30 == 0:
  52. # 本小节代码设置一些花哨的打印格式,可以不用管
  53. num_examples_per_step = BATCH_SIZE
  54. examples_per_sec = num_examples_per_step / duration
  55. sec_per_batch = float(duration)
  56. format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
  57. 'sec/batch)')
  58. print (format_str % (datetime.now(), step, loss_value,
  59. examples_per_sec, sec_per_batch))
  60.  
  61. if step % 100 == 0:
  62. # 运行汇总操作, 写入汇总
  63. summary_str = sess.run(summary_op)
  64. summary_writer.add_summary(summary_str, step)
  65.  
  66. if step % 1000 == 0 or (step + 1) == MAX_STEP:
  67. # 保存当前的模型和权重到 train_dir,global_step 为当前的迭代次数
  68. checkpoint_path = os.path.join(train_dir, 'model.ckpt')
  69. saver.save(sess, checkpoint_path, global_step=step)
  70.  
  71. except Exception, e:
  72. coord.request_stop(e)
  73. finally:
  74. coord.request_stop()
  75. coord.join(threads)
  76.  
  77. sess.close()
  78.  
  79. def evaluate():
  80.  
  81. data_dir = '/home/your_name/TensorFlow/cifar10/data/cifar-10-batches-bin/'
  82. train_dir = '/home/your_name/TensorFlow/cifar10/cifar10_train/'
  83. images, labels = my_cifar10_input.inputs(data_dir, BATCH_SIZE, train = False)
  84.  
  85. logits = inference(images)
  86. saver = tf.train.Saver(tf.all_variables())
  87.  
  88. os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
  89. config = tf.ConfigProto()
  90. config.gpu_options.per_process_gpu_memory_fraction = 0.2
  91. sess = tf.InteractiveSession(config=config)
  92. coord = tf.train.Coordinator()
  93. threads = tf.train.start_queue_runners(sess = sess, coord = coord)
  94.  
  95. # 加载模型参数
  96. print("Reading checkpoints...")
  97. ckpt = tf.train.get_checkpoint_state(train_dir)
  98. if ckpt and ckpt.model_checkpoint_path:
  99. ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
  100. global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
  101. saver.restore(sess, os.path.join(train_dir, ckpt_name))
  102. print('Loading success, global_step is %s' % global_step)
  103.  
  104. try:
  105. # 对比分类结果,至于为什么用这个函数,后面详谈
  106. top_k_op = tf.nn.in_top_k(logits, labels, 1)
  107. true_count = 0
  108. step = 0
  109. while step < 157:
  110. if coord.should_stop():
  111. break
  112. predictions = sess.run(top_k_op)
  113. true_count += np.sum(predictions)
  114. step += 1
  115.  
  116. precision = true_count / 10000
  117. print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
  118. except tf.errors.OutOfRangeError:
  119. coord.request_stop()
  120. finally:
  121. coord.request_stop()
  122. coord.join(threads)
  123.  
  124. sess.close()
  125.  
  126. if __name__ == '__main__':
  127.  
  128. if TRAIN:
  129. train ()
  130. else:
  131. evaluate()

现在说明一下 in_top_k 这个函数的作用,官方文档介绍中: tf.nn.in_top_k(predictions, targets, k, name=None)这个函数返回一个 batch_size 大小的布尔矩阵 array,predictions 是一个 batch_size*classes 大小的矩阵,targets 是一个 batch_size 大小的类别 index 矩阵,这个函数的作用是,如果 targets[i] 是 predictions[i][:] 的前 k 个最大值,则返回的 array[i] = True, 否则,返回的 array[i] = False。可以看到,在上述评估程序 evaluate 中,这个函数没有用 softmax 的结果进行计算,而是用 inference 最后的输出结果(一个全连接层)进行计算。

写完之后,点击运行,可以看到,训练的 loss 值,从刚开始的 2.31 左右,下降到最终的 0.00 左右,在训练的过程中,/home/your_name/TensorFlow/cifar10/cifar10_train/ 文件夹下会出现12个文件,其中有 5 个 model.ckpt-0000 文件,这个是训练过程中保存的模型,后面的数字表示迭代次数,5 个 model.ckpt-0000.meta 文件,这个是训练过程中保存的元数据(暂时不清楚功能),TensorFlow 默认只保存近期的几个模型和几个元数据,删除前面没用的模型和元数据。还有个 checkpoint 的文本文档,和一个 out.tfevents 形式的文件,是summary 的日志文件。如果不想用 tensorboard 看网络结构和训练过程中的权重分布,损失情况等等,在程序中可以不写 summary 语句。

训练完成之后,我们用 tensorboard 进行可视化(事实上在训练的过程中,随时可以可视化)。在任意位置打开命令行终端,输入:

  1. tensorboard --logdir=/home/your_name/TensorFlow/cifar10/cifar10_train/

会出现如下指示:

根据指示,打开浏览器,输入 http://127.0.1.1:6006(有的浏览器可能不支持,建议多换几个浏览器试试)会看到可视化的界面,有六个选项卡:

EVENTS 对话框里面有两个图,一个是训练过程中的 loss 图,一个是队列 queue 的图;由于没有 image_summary() 和 audio_summary() 语句,所以,IMAGES 和 AUDIO 选项卡都没有内容;GRAPHS 选项卡包含了整个模型的流程图,如下图,可以展开和移动选定的 namespace;DISTRBUTIONS 和 HISTOGRAMS 包含了训练时的各种汇总的分布和柱状图。

训练完之后,设置 TRAIN = False,进行测试,得到如下结果:

可以看到,测试的精度只有 76%,测试结果不够高的原因可能是,测试的时候没有经过 softmax 层,直接用全连接层的权重(存疑?),另外官方的代码也给出了官方的运行结果,如下:

可以看到,经过 10 万次迭代,官方给出的正确率达到 83%,我们只进行了 5 万次,达到 76% 的正确率,相对来说,还算可以,效果没有官方好的原因可能是:

1. 官方使用了非固定的学习率;

2. 官方迭代比本代码迭代次数多一倍;

参考文献:

1. https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10

TF Boys (TensorFlow Boys ) 养成记(六)的更多相关文章

  1. TF Boys (TensorFlow Boys ) 养成记(六): CIFAR10 Train 和 TensorBoard 简介

    圣诞节玩的有点嗨,差点忘记更新.祝大家昨天圣诞节快乐,再过几天元旦节快乐. 来继续学习,在/home/your_name/TensorFlow/cifar10/ 下新建文件夹cifar10_train ...

  2. TF Boys (TensorFlow Boys ) 养成记(一)

    本资料是在Ubuntu14.0.4版本下进行,用来进行图像处理,所以只介绍关于图像处理部分的内容,并且默认TensorFlow已经配置好,如果没有配置好,请参考官方文档配置安装,推荐用pip安装.关于 ...

  3. TF Boys (TensorFlow Boys ) 养成记(一):TensorFlow 基本操作

    本资料是在Ubuntu14.0.4版本下进行,用来进行图像处理,所以只介绍关于图像处理部分的内容,并且默认TensorFlow已经配置好,如果没有配置好,请参考官方文档配置安装,推荐用pip安装.关于 ...

  4. TF Boys (TensorFlow Boys ) 养成记(五)

    有了数据,有了网络结构,下面我们就来写 cifar10 的代码. 首先处理输入,在 /home/your_name/TensorFlow/cifar10/ 下建立 cifar10_input.py,输 ...

  5. TF Boys (TensorFlow Boys ) 养成记(四)

    前面基本上把 TensorFlow 的在图像处理上的基础知识介绍完了,下面我们就用 TensorFlow 来搭建一个分类 cifar10 的神经网络. 首先准备数据: cifar10 的数据集共有 6 ...

  6. TF Boys (TensorFlow Boys ) 养成记(三)

    上次说到了 TensorFlow 从文件读取数据,这次我们来谈一谈变量共享的问题. 为什么要共享变量?我举个简单的例子:例如,当我们研究生成对抗网络GAN的时候,判别器的任务是,如果接收到的是生成器生 ...

  7. TF Boys (TensorFlow Boys ) 养成记(二)

    TensorFlow 的 How-Tos,讲解了这么几点: 1. 变量:创建,初始化,保存,加载,共享: 2. TensorFlow 的可视化学习,(r0.12版本后,加入了Embedding Vis ...

  8. TF Boys (TensorFlow Boys ) 养成记(二): TensorFlow 数据读取

    TensorFlow 的 How-Tos,讲解了这么几点: 1. 变量:创建,初始化,保存,加载,共享: 2. TensorFlow 的可视化学习,(r0.12版本后,加入了Embedding Vis ...

  9. TF Boys (TensorFlow Boys ) 养成记(三): TensorFlow 变量共享

    上次说到了 TensorFlow 从文件读取数据,这次我们来谈一谈变量共享的问题. 为什么要共享变量?我举个简单的例子:例如,当我们研究生成对抗网络GAN的时候,判别器的任务是,如果接收到的是生成器生 ...

随机推荐

  1. max10中对DDR数据的采样转换

    (1)发现IP是这样处理DDR的数据:上长沿采的数据放在低位,下降沿采的数据在高位 (2)对于视频的行场信号是在下降沿采集,再延时一拍才能与数据对齐.

  2. DrawerLayout学习,抽屉效果

    第一节: 注意事项 *主视图一定要是DrawerLayout的第一子视图 *主视图宽度和高度匹配父视图,因为当你显示主视图时,要铺满整个屏幕,用户体验度较高 *必须显示指定的抽屉视图的android: ...

  3. Android广播错误.MainActivity$MyReceiver; no empty constructor

    广播的定义,如果是内部类,必须为静态类. 下面总结一下作为内部类的广播接收者在注册的时候需要注意的地方:   1.清单文件注册广播接收者时,广播接收者的名字格式需要注意.因为是内部类,所以需要在内部类 ...

  4. Spring JDBC 访问MSSQL

    在Spring中对底层的JDBC做了浅层的封装即JdbcTemplate,在访问数据库的DAO层完全可以使用JdbcTemplate完成任何数据访问的操作,接下来我们重点说说Spring JDBC对S ...

  5. 如何进行正确的SQL性能优化

    在SQL查询中,为了提高查询的效率,我们常常采取一些措施对查询语句进行SQL性能优化.本文我们总结了一些优化措施,接下来我们就一一介绍. 1.查询的模糊匹配 尽量避免在一个复杂查询里面使用 LIKE ...

  6. Mac上Homebrew的使用 (Homebrew 使 OS X 更完整)

    0 Homebrew是啥? “Homebrew installs the stuff you need that Apple didn’t.——Homebrew 使 OS X 更完整”. Homebr ...

  7. java常用类

    String 字符串类 System 可得到系统信息 Runtime类 StringBuilder(StringBuffer)类 Thread 线程类 Math 与数学有关的工具类 Date 日期类( ...

  8. Mybatis中模糊查询的各种写法

    1. sql中字符串拼接 SELECT * FROM tableName WHERE name LIKE CONCAT(CONCAT('%', #{text}), '%'); 2. 使用 ${...} ...

  9. slick-pg v0.1.5 发布

    这个版本的更新主要是: 增加了对 json 的支持 (PostgreSQL 9.3 正式版已经发布了,所以我适时加入了对 pg json 的支持.功能其实前两个星期就已经开发测试好了,但公司跟联邦政府 ...

  10. 【原创】“借贷宝”砸钱,邀请码 GZZKZK2 (注册成功每人可得20现金,可直接提现)。。。而这只是开始

    作为IT/互联网资深码农的我,从专业技术角度剖析其流程,确认其各个环节控制严格,无欺诈嫌疑, 最佳运气邀请码 : GZZKZK2, 你在注册时值得拥有, 无邀请码无奖励, 亲一定要记住.对 APP操作 ...