TensorFlow笔记-07-神经网络优化-学习率,滑动平均

学习率

  • 学习率 learning_rate: 表示了每次参数更新的幅度大小。学习率过大,会导致待优化的参数在最小值附近波动,不收敛;学习率过小,会导致待优化的参数收敛缓慢
  • 在训练过程中,参数的更新向着损失函数梯度下降的方向
  • 参数的更新公式为:

    wn+1 = wn - learning_rate▽
  • 假设损失函数 loss = (w + 1)2。梯度是损失函数 loss 的导数为 ▽ = 2w + 2 。如参数初值为5,学习率为 0.2,则参数和损失函数更新如下:

1次 ·······参数w: 5 ·················5 - 0.2 * (2 * 5 + 2) = 2.6

2次 ·······参数w: 2.6 ··············2.6 - 0.2 * (2 * 2.6 + 2) = 1.16

3次 ·······参数w: 1.16 ············1.16 - 0.2 * (2 * 1.16 +2) = 0.296

4次 ·······参数w: 0.296

损失函数loss = (w + 1) 2 的图像为:



由图可知,损失函数 loss 的最小值会在(-1,0)处得到,此时损失函数的导数为 0,得到最终参数 w = -1。

代码 tf08learn 文件:https://xpwi.github.io/py/TensorFlow/tf08learn.py

  1. # coding: utf-8
  2. # 设损失函数loss = (w + 1)^2 , 令 w 是常数 5。反向传播就是求最小
  3. # loss 对应的 w 值
  4. import tensorflow as tf
  5. # 定义待优化参数 w 初值赋予5
  6. w = tf.Variable(tf.constant(5, dtype=tf.float32))
  7. # 定义损失函数 loss
  8. loss = tf.square(w + 1)
  9. # 定义反向传播方法
  10. train_step = tf.train.GradientDescentOptimizer(0.20).minimize(loss)
  11. # 生成会话,训练40轮
  12. with tf.Session() as sess:
  13. init_op = tf.global_variables_initializer()
  14. sess.run(init_op)
  15. for i in range(40):
  16. sess.run(train_step)
  17. W_val = sess.run(w)
  18. loss_val = sess.run(loss)
  19. print("After %s steps: w: is %f, loss: is %f." %(i, W_val, loss_val))

运行结果



运行结果分析: 由结果可知,随着损失函数值得减小,w 无线趋近于 -1

学习率的设置

  • 学习率过大,会导致待优化的参数在最小值附近波动,不收敛;学习率过小,会导致待优化的参数收敛缓慢
  • 例如:
  • (1) 对于上例的损失函数loss = (w + 1) 2,则将上述代码中学习率改为1,其余内容不变

    实验结果如下:

  • (2) 对于上例的损失函数loss = (w + 1) 2,则将上述代码中学习率改为0.0001,其余内容不变

    实验结果如下:



    由运行结果可知,损失函数 loss 值缓慢下降,w 值也在小幅度变化,收敛缓慢

指数衰减学习率

  • 指数衰减学习率:学习率随着训练轮数变化而动态更新



    其中,LEARNING_RATE_BASE 为学习率初始值,LEARNING_RATE_DECAY 为学习率衰减率,global_step 记录了当前训练轮数,为了不可训练型参数。学习率 learning_rate 更新频率为输入数据集总样本数除以每次喂入样本数。若 staircase 设置为 True 时,表示 global_step / learning rate step 取整数,学习率阶梯型衰减;若 staircase 设置为 False 时,学习率会是一条平滑下降的曲线。
  • 例如:

    在本例中,模型训练过程不设定固定的学习率,使用指数衰减学习率进行训练。其中,学习率初值设置为0.1,学习率衰减值设置为0.99,BATCH_SIZE 设置为1。
  • 代码 tf08learn2 文件:https://xpwi.github.io/py/TensorFlow/tf08learn2.py
  1. # coding: utf-8
  2. '''
  3. 设损失函数loss = (w + 1)^2 , 令 w 初值是常数5,
  4. 反向传播就是求最优 w,即求最小 loss 对应的w值。
  5. 使用指数衰减的学习率,在迭代初期得到较高的下降速度,
  6. 可以在较小的训练轮数下取得更有收敛度
  7. '''
  8. import tensorflow as tf
  9. LEARNING_RATE_BASE = 0.1 # 最初学习率
  10. LEARNING_RATE_DECAY = 0.99 # 学习率衰减率
  11. # 喂入多少轮 BATCH_SIZE 后,更新一次学习率,一般设置为:样本数/BATCH_SIZE
  12. LEARNING_RATE_STEP = 1
  13. # 运行了几轮 BATCH_SIZE 的计算器,初始值是0,设为不被训练
  14. global_step = tf.Variable(0, trainable=False)
  15. # 定义指数下降学习率
  16. learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step,
  17. LEARNING_RATE_STEP, LEARNING_RATE_DECAY, staircase=True)
  18. # 定义待优化参数 w 初值赋予5
  19. w = tf.Variable(tf.constant(5, dtype=tf.float32))
  20. # 定义损失函数 loss
  21. loss = tf.square(w+1)
  22. # 定义反向传播方法
  23. # 学习率为:0.2
  24. train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
  25. # 生成会话,训练40轮
  26. with tf.Session() as sess:
  27. init_op = tf.global_variables_initializer()
  28. sess.run(init_op)
  29. for i in range(40):
  30. sess.run(train_step)
  31. learning_rate_val = sess.run(learning_rate)
  32. global_step_val = sess.run(global_step)
  33. w_val = sess.run(w)
  34. loss_val = sess.run(loss)
  35. print("After %s steps: global_step is %f; : w: is %f;learn rate is %f; loss: is %f."
  36. %(i,global_step_val, w_val, learning_rate_val, loss_val))

运行结果



由结果可以看出,随着训练轮数增加学习率在不断减小

滑动平均

  • **滑动平均:记录了一段时间内模型中所有参数 w 和 b 各自的平均值,利用滑动平均值可以增强模型的泛化能力
  • **滑动平均值(影子)计算公式:影子 = 衰减率 * 参数
  • 其中衰减率 = min{AVERAGEDECAY(1+轮数/10+轮数)},影子初值 = 参数初值
  • 用 Tensorflow 函数表示:

**ema = tf.train.ExpoentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)

  • 其中 MOVING_AVERAGE_DECAY 表示滑动平均衰减率,一般会赋予接近1的值,global_step 表示当前训练了多少轮

**ema_op = ema.apply(tf.trainable_varables())

  • 其中 ema.apply() 函数实现对括号内参数的求滑动平均,tf.trainable_variables() 函数实现把所有待训练参数汇总为列表

with tf.control_dependencies([train_step, ema_op]):

**train_op = tf.no_op(name='train') **

  • 其中,该函数实现滑动平均和训练步骤同步运行
  • 查看模型中参数的平均值,可以用 ema.average() 函数
  • 例如:

    在神经网络中将 MOVING_AVERAGE_DECAY 设置为 0.9,参数 w1 设置为 0,w1 滑动平均值设置为 0
  • (1)开始时,轮数 global_step 设置为 0,参数 w1 更新为 1,则滑动平均值为:

**w1 滑动平均值 = min(0.99, 1/10)0+(1-min(0.99,1/10))1 = 0.9 **

  • (2)当轮数 global_step 设置为 0,参数 w1 更新为 10,以下代码 global_step 保持 100,每次执行滑动平均操作影子更新,则滑动平均值变为:

**w1 滑动平均值 = min(0.99, 101/110)0.9+(1-min(0.99,101/110))10 = 0.826+0.818 = 1.644 **

  • (3)再次运行,参数 w1 更新为 1.644,则滑动平均值变为:

**w1

更多文章链接:Tensorflow 笔记


- 本笔记不允许任何个人和组织转载

TensorFlow笔记-07-神经网络优化-学习率,滑动平均的更多相关文章

  1. TensorFlow笔记-06-神经网络优化-损失函数,自定义损失函数,交叉熵

    TensorFlow笔记-06-神经网络优化-损失函数,自定义损失函数,交叉熵 神经元模型:用数学公式比表示为:f(Σi xi*wi + b), f为激活函数 神经网络 是以神经元为基本单位构成的 激 ...

  2. 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用滑动平均

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

  3. 20180929 北京大学 人工智能实践:Tensorflow笔记07

    (完)

  4. Tensorflow 笔记

    TensorFlow笔记-08-过拟合,正则化,matplotlib 区分红蓝点 TensorFlow笔记-07-神经网络优化-学习率,滑动平均 TensorFlow笔记-06-神经网络优化-损失函数 ...

  5. tensorflow笔记之滑动平均模型

    tensorflow使用tf.train.ExponentialMovingAverage实现滑动平均模型,在使用随机梯度下降方法训练神经网络时候,使用这个模型可以增强模型的鲁棒性(robust),可 ...

  6. tensorflow笔记2(北大网课实战)

    1.正则化缓解过拟合 正则化在损失函数中引入模型复杂度指标,利用给w加权值,弱化了训练数据的噪声 一般不会正则化b. 2.matplotlib.pyplot 3.搭建模块化的神经网络八股: 前向传播就 ...

  7. tensorflow:实战Google深度学习框架第四章02神经网络优化(学习率,避免过拟合,滑动平均模型)

    1.学习率的设置既不能太小,又不能太大,解决方法:使用指数衰减法 例如: 假设我们要最小化函数 y=x2y=x2, 选择初始点 x0=5x0=5  1. 学习率为1的时候,x在5和-5之间震荡. im ...

  8. tensorflow入门笔记(二) 滑动平均模型

    tensorflow提供的tf.train.ExponentialMovingAverage 类利用指数衰减维持变量的滑动平均. 当训练模型的时候,保持训练参数的滑动平均是非常有益的.评估时使用取平均 ...

  9. TensorFlow+实战Google深度学习框架学习笔记(11)-----Mnist识别【采用滑动平均,双层神经网络】

    模型:双层神经网络 [一层隐藏层.一层输出层]隐藏层输出用relu函数,输出层输出用softmax函数 过程: 设置参数 滑动平均的辅助函数 训练函数 x,y的占位,w1,b1,w2,b2的初始化 前 ...

随机推荐

  1. 49 BOM 和DOM

    一.BOM window 对象  1.确认,输入,    window.alert(123) // 弹框    let ret = window.confirm("是否删除") / ...

  2. python-day5笔记

    一.python基础--基本数据类型 (无论用户输入什么内容,input 都会存成字符串格式) 1.基本数据类型 1)数字类型 整型(整数)int:年级,年纪,等级,身份证号,QQ号,手机号,leve ...

  3. python-day8-元组的内置方法

    #为何要有元组,存放多个值,元组不可变,更多的是用来做查询# t=(1,[1,3],'sss',(1,2)) #t=tuple((1,[1,3],'sss',(1,2)))# print(type(t ...

  4. UVA-10655 Contemplation! Algebra (矩阵)

    题目大意:给出a+b的值和ab的值,求a^n+b^n的值. 题目分析:有种错误的方法是这样的:利用已知的两个方程联立,求解出a和b,进而求出答案.这种方法之所以错,是因为这种方法有局限性.联立之后会得 ...

  5. SpringMVC实现RESTful服务

    SpringMVC实现RESTful服务 这里只说service,controller层的代码.Mapper层则直接继承Mapper<T>则可以,记住mybatis-config.xml一 ...

  6. 多线程私有数据pthread_key_create

    参照:http://blog.csdn.net/xiaohuangcat/article/details/18267561 在多线程的环境下,进程内的所有线程共享进程的数据空间.因此全局变量为所有线程 ...

  7. UVALive 5844 dfs暴力搜索

    题目链接:UVAive 5844 Leet DES:大意是给出两个字符串.第一个字符串里的字符可以由1-k个字符代替.问这两个字符串是不是相等.因为1<=k<=3.而且第一个字符串长度小于 ...

  8. TClientDataSet 的Filename 和 open方法

    cdsAdd.Open 不能用open,因为这个是DataSet的通用方法,不会检测文件名的,文件名是CDS特有的

  9. python 安装 scapy windows 10 64bit

    简介: 前段时间装的pypcap做嗅探.打包受阻.因为我都是在windows做的.也要打包到exe给别人用. 所以尝试了一下scapy,也可以嗅探,貌似功能更强大.先用sniff吧. 这个也不是在ve ...

  10. spring事务管理及相关知识

    最近在项目中遇到了spring事务的注解及相关知识,突然间感觉自己对于这部分知识只停留在表面的理解层次上,于是乎花些时间上网搜索了一些文章,以及对于源码的解读,整理如下: 一.既然谈到事务,那就先搞清 ...