目录

Batch Normalization笔记

我们将会用MNIST数据集来演示这个batch normalization的使用, 以及他所带来的效果:

引包

  1. import tensorflow as tf
  2. import os
  3. from tensorflow.examples.tutorials.mnist import input_data
  4. from tensorflow.contrib.layers import flatten
  5. import numpy as np
  6. import tensorflow.contrib.slim as slim

构建模型:

  1. def model1(input, is_training, keep_prob):
  2. input = tf.reshape(input, shape=[-1, 28, 28, 1])
  3. batch_norm_params = {
  4. 'decay': 0.95,
  5. 'updates_collections': None
  6. }
  7. with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training):
  8. with slim.arg_scope([slim.conv2d, slim.fully_connected],
  9. weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
  10. normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params,
  11. activation_fn=tf.nn.crelu):
  12. conv1 = slim.conv2d(input, 16, 5, scope='conv1')
  13. pool1 = slim.max_pool2d(conv1, 2, scope='pool1')
  14. conv2 = slim.conv2d(pool1, 32, 5, scope='conv2')
  15. pool2 = slim.max_pool2d(conv2, 2, scope='pool2')
  16. flatten = slim.flatten(pool2)
  17. fc = slim.fully_connected(flatten, 1024, scope='fc1')
  18. print(fc.get_shape())
  19. drop = slim.dropout(fc, keep_prob=keep_prob)
  20. logits = slim.fully_connected(drop, 10, activation_fn=None, scope='logits')
  21. return logits
  1. def model2(input, is_training, keep_prob):
  2. input = tf.reshape(input, shape=[-1, 28, 28, 1])
  3. with slim.arg_scope([slim.conv2d, slim.fully_connected],
  4. weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
  5. normalizer_fn=None, activation_fn=tf.nn.crelu):
  6. with slim.arg_scope([slim.dropout], is_training=is_training):
  7. conv1 = slim.conv2d(input, 16, 5, scope='conv1')
  8. pool1 = slim.max_pool2d(conv1, 2, scope='pool1')
  9. conv2 = slim.conv2d(pool1, 32, 5, scope='conv2')
  10. pool2 = slim.max_pool2d(conv2, 2, scope='pool2')
  11. flatten = slim.flatten(pool2)
  12. fc = slim.fully_connected(flatten, 1024, scope='fc1')
  13. print(fc.get_shape())
  14. drop = slim.dropout(fc, keep_prob=keep_prob)
  15. logits = slim.fully_connected(drop, 10, activation_fn=None, scope='logits')
  16. return logits

构建训练函数

  1. def train(model, model_path, train_log_path, test_log_path):
  2. # 计算图
  3. graph = tf.Graph()
  4. with graph.as_default():
  5. X = tf.placeholder(dtype=tf.float32, shape=[None, 28 * 28])
  6. Y = tf.placeholder(dtype=tf.float32, shape=[None, 10])
  7. is_training = tf.placeholder(dtype=tf.bool)
  8. logit = model(X, is_training, 0.7)
  9. loss =tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=Y))
  10. accuray = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logit, 1), tf.argmax(Y, 1)), tf.float32))
  11. global_step = tf.Variable(0, trainable=False)
  12. learning_rate = tf.train.exponential_decay(0.1, global_step, 1000, 0.95, staircase=True)
  13. optimizer = tf.train.AdagradOptimizer(learning_rate=learning_rate)
  14. update = slim.learning.create_train_op(loss, optimizer, global_step)
  15. mnist = input_data.read_data_sets("tmp", one_hot=True)
  16. saver = tf.train.Saver()
  17. tf.summary.scalar("loss", loss)
  18. tf.summary.scalar("accuracy", accuray)
  19. merged_summary_op = tf.summary.merge_all()
  20. train_summary_writter = tf.summary.FileWriter(train_log_path, graph=tf.get_default_graph())
  21. test_summary_writter = tf.summary.FileWriter(test_log_path, graph=tf.get_default_graph())
  22. init = tf.global_variables_initializer()
  23. iter_num = 10000
  24. batch_size = 1024
  25. os.environ["CUDA_VISIBLE_DEVICES"] = '2' # 选择cuda的设备
  26. gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2) # gpu显存使用
  27. with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
  28. sess.run(init)
  29. if not os.path.exists(os.path.dirname(model_path)):
  30. os.makedirs(os.path.dirname(model_path))
  31. else:
  32. try:
  33. saver.restore(sess, model_path)
  34. except:
  35. pass
  36. for i in range(iter_num):
  37. x, y = mnist.train.next_batch(batch_size)
  38. sess.run(update, feed_dict={X:x, Y:y, is_training:True})
  39. if i % 100 == 0:
  40. x_test, y_test = mnist.test.next_batch(batch_size)
  41. print("train:", sess.run(accuray, feed_dict={X: x, Y: y, is_training:False}))
  42. print("test:", sess.run(accuray, feed_dict={X: x_test, Y: y_test, is_training:False}))
  43. saver.save(sess, model_path)
  44. g, summary = sess.run([global_step, merged_summary_op], feed_dict={X: x, Y: y, is_training:False})
  45. train_summary_writter.add_summary(summary, g)
  46. train_summary_writter.flush()
  47. g, summary = sess.run([global_step, merged_summary_op], feed_dict={X: x_test, Y: y_test, is_training:False})
  48. test_summary_writter.add_summary(summary, g)
  49. test_summary_writter.flush()
  50. train_summary_writter.close()
  51. test_summary_writter.close()

下面我们来进行计算:

  1. train(model1, "model1/model", "model1_train_log", "model1_test_log")
  1. train(model2, "model2/model", "model2_train_log", "model2_test_log")

结论

我们发现, 加了batch norm的似乎收敛的更快一些, 这个我们可以从对比上可以很清楚的看到, 所以这个bn是我们一个很好的技术, 前提是你选的参数比较适合.

以下是两个注意点:

The keys to use batch normalization in slim are:

Set proper decay rate for BN layer. Because a BN layer uses EMA (exponential moving average) to approximate the population mean/variance, it takes sometime to warm up, i.e. to get the EMA close to real population mean/variance. The default decay rate is 0.999, which is kind of high for our little cute MNIST dataset and needs ~1000 steps to get a good estimation. In my code, decay is set to 0.95, then it learns the population statistics very quickly. However, a large value of decay does have it own advantage: it gathers information from more mini-batches thus is more stable.

Use slim.learning.create_train_op to create train op instead of tf.train.GradientDescentOptimizer(0.1).minimize(loss) or something else!.

深度学习中batch normalization的更多相关文章

  1. 深度学习中 Batch Normalization

    深度学习中 Batch Normalization为什么效果好?(知乎) https://www.zhihu.com/question/38102762

  2. 深度学习中 Batch Normalization为什么效果好

    看mnist数据集上其他人的CNN模型时了解到了Batch Normalization 这种操作.效果还不错,至少对于训练速度提升了很多. batch normalization的做法是把数据转换为0 ...

  3. zz详解深度学习中的Normalization,BN/LN/WN

    详解深度学习中的Normalization,BN/LN/WN 讲得是相当之透彻清晰了 深度神经网络模型训练之难众所周知,其中一个重要的现象就是 Internal Covariate Shift. Ba ...

  4. 深度学习中的Normalization模型

    Batch Normalization(简称 BN)自从提出之后,因为效果特别好,很快被作为深度学习的标准工具应用在了各种场合.BN 大法虽然好,但是也存在一些局限和问题,诸如当 BatchSize ...

  5. [优化]深度学习中的 Normalization 模型

    来源:https://www.chainnews.com/articles/504060702149.htm 机器之心专栏 作者:张俊林 Batch Normalization (简称 BN)自从提出 ...

  6. 深度学习之Batch Normalization

    在机器学习领域中,有一个重要的假设:独立同分布假设,也就是假设训练数据和测试数据是满足相同分布的,否则在训练集上学习到的模型在测试集上的表现会比较差.而在深层神经网络的训练中,当中间神经层的前一层参数 ...

  7. 深度学习中优化【Normalization】

    深度学习中优化操作: dropout l1, l2正则化 momentum normalization 1.为什么Normalization?     深度神经网络模型的训练为什么会很困难?其中一个重 ...

  8. 深度学习中的batch、epoch、iteration的含义

    深度学习的优化算法,说白了就是梯度下降.每次的参数更新有两种方式. 第一种,遍历全部数据集算一次损失函数,然后算函数对各个参数的梯度,更新梯度.这种方法每更新一次参数都要把数据集里的所有样本都看一遍, ...

  9. 深度学习中 --- 解决过拟合问题(dropout, batchnormalization)

    过拟合,在Tom M.Mitchell的<Machine Learning>中是如何定义的:给定一个假设空间H,一个假设h属于H,如果存在其他的假设h’属于H,使得在训练样例上h的错误率比 ...

随机推荐

  1. 新建maven项目,JRE System Library[J2SE-1.5]

    上篇博文中搭建了maven多模块项目,发现全是JRE System Library[J2SE-1.5],如图. 怎么避免这种情况呢? windows-preferences-maven-user se ...

  2. 转-Windows路由表配置:双网卡路由分流

    原文链接:http://www.cnblogs.com/lightnear/archive/2013/02/03/2890835.html 一.windows 路由表解释 route print -4 ...

  3. Servlet--j2e中文乱码解决

    我们在写项目的时候经常会传递一些中文参数,但是j2e默认使用ISO-8859-1来编码和解码,所以很容易出现中文乱码问题.这里我做一个统一的整理,其实这里的中文乱码问题和上一篇的路径问题都是j2e经常 ...

  4. 07_jquery入门第一天

    视频来源:麦子学院 讲师:魏畅然 补充:JSON.stringify()函数 [https://www.cnblogs.com/damonlan/archive/2012/03/13/2394787. ...

  5. redis数据类型-有序集合

    有序集合类型 在集合类型的基础上有序集合类型为集合中的每个元素都关联了一个分数,这使得我们不仅可以完成插入.删除和判断元素是否存在等集合类型支持的操作,还能够获得分数最高(或最低)的前N个元素.获得指 ...

  6. 小谈ConcurrentHashMap

    面试的时候被面试官问了点相关知识,再次记录一些自己的总结 一. 1.HashTable也可实现线程安全,但是它是用synchronized实现的,所以其他线程访问HashTable的同步方法时,可能会 ...

  7. 一步一步从原理跟我学邮件收取及发送 9.多行结果与socket的阻塞

    前几篇的文章发表后,有网友留言说没有涉及到阻塞的问题吗?在 socket 的编程当中,这确实是个很重要的问题.结合目前我们文章的内容进度,我们来看看为什么说阻塞概念很重要. 接着上篇的内容,当我们发送 ...

  8. VNC配置

    简介 VNC (Virtual Network Console)是虚拟网络控制台的缩写.它 是一款优秀的远程控制工具软件,由著名的 AT&T 的欧洲研究实验室开发的.VNC 是在基于 UNIX ...

  9. 2. getline()和get()

    1.面向行输入:getline() ---其实还可以接受第三个参数. getline()函数读取整行,调用该方法 使用cin.getline().该函数有两个参数, 第一个参数是是用来存储输入行的数组 ...

  10. 细数Python Flask微信公众号开发中遇到的那些坑

    最近两三个月的时间,断断续续边学边做完成了一个微信公众号页面的开发工作.这是一个快递系统,主要功能有用户管理.寄收件地址管理.用户下单,订单管理,订单查询及一些宣传页面等.本文主要细数下开发过程中遇到 ...