指数衰减学习率是先使用较大的学习率来快速得到一个较优的解,然后随着迭代的继续,逐步减小学习率,使得模型在训练后期更加稳定。在训练神经网络时,需要设置学习率(learning rate)控制参数的更新速度,学习速率设置过小,会极大降低收敛速度,增加训练时间;学习率太大,可能导致参数在最优解两侧来回振荡。

函数原型:

tf.train.exponential_decay(
    learning_rate,
    global_step,
    decay_steps,
    decay_rate,
    staircase=False,#默认为False
    name=None
)

staircase:布尔值。如果True以不连续的间隔衰减学习速率,最后曲线就是锯齿状

该函数返回衰退的学习速率。它被计算为:

decayed_learning_rate = learning_rate *                        decay_rate ^ (global_step / decay_steps)
指数衰减学习率的各种参数:

# 初始学习率
learning_rate = 0.1
# 衰减系数
decay_rate = 0.9
# decay_steps控制衰减速度
# 如果decay_steps大一些,(global_step / decay_steps)就会增长缓慢一些
# 从而指数衰减学习率decayed_learning_rate就会衰减得慢一些
# 否则学习率很快就会衰减为趋近于0
decay_steps = 50
# 迭代轮数
global_steps = 3000
此时的意思是学习率以基数0.9每50步进行衰减。例如当迭代次数从1到3000次时,迭代到最后一次时,3000/50=60. 则衰减到基数的60次方。
是初始化的学习率, 是随着 的递增而衰减。显然,当 为初值0时, 有下面等式:

用来控制衰减速度,如果 大一些, 就会增长缓慢一些。从而指数衰减学习率 就会衰减得慢一否则学习率很快就会衰减为趋近于0。

徒手实现指数衰减学习率:

import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False # 解决保存图像是负号'-'显示为方块的问题
X = []
Y = []
learning_rate=1
global_steps=3000
decay_steps=50
decay_rate=0.9
# 指数学习率衰减过程
for global_step in range(global_steps):
decayed_learning_rate = learning_rate * decay_rate**(global_step / decay_steps)
X.append(global_step / decay_steps)
Y.append(decayed_learning_rate)
#print("global step: %d, learning rate: %f" % (global_step,decayed_learning_rate))
plt.plot(X,Y,'b')
plt.ylabel(u"learning_rate学习率")
plt.xlabel('global_step / decay_steps')
plt.show()

---------------------
作者:亮亮兰
来源:CSDN
原文:https://blog.csdn.net/lyl771857509/article/details/79734107
版权声明:本文为博主原创文章,转载请附上博文链接!

【tensorflow】】模型优化(一)指数衰减学习率的更多相关文章

  1. TensorFlow 模型优化工具包  —  训练后整型量化

    模型优化工具包是一套先进的技术工具包,可协助新手和高级开发者优化待部署和执行的机器学习模型.自推出该工具包以来,  我们一直努力降低机器学习模型量化的复杂性 (https://www.tensorfl ...

  2. Tensorflow笔记——神经网络图像识别(四)搭建模块化的神经网络八股(正则化,指数衰减学习率,滑动平均等优化)

    实战案例: 数据X[x0,x1]为正太分布随机点, 标注Y_,当x0*x0+x1*x1<2时,y_=1(红),否则y_=0(蓝)  建立三个.py文件 1.  generateds.py生成数据 ...

  3. 使用TensorFlow Serving优化TensorFlow模型

    使用TensorFlow Serving优化TensorFlow模型 https://www.tensorflowers.cn/t/7464 https://mp.weixin.qq.com/s/qO ...

  4. 超详细的Tensorflow模型的保存和加载(理论与实战详解)

    1.Tensorflow的模型到底是什么样的? Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等.所以,Tensorflow模型有两个主要的文件: a) Meta graph: ...

  5. 移动端目标识别(3)——使用TensorFlow Lite将tensorflow模型部署到移动端(ssd)之Running on mobile with TensorFlow Lite (写的很乱,回头更新一个简洁的版本)

    承接移动端目标识别(2) 使用TensorFlow Lite在移动设备上运行         在本节中,我们将向您展示如何使用TensorFlow Lite获得更小的模型,并允许您利用针对移动设备优化 ...

  6. 移动端目标识别(1)——使用TensorFlow Lite将tensorflow模型部署到移动端(ssd)之TensorFlow Lite简介

    平时工作就是做深度学习,但是深度学习没有落地就是比较虚,目前在移动端或嵌入式端应用的比较实际,也了解到目前主要有 caffe2,腾讯ncnn,tensorflow,因为工作用tensorflow比较多 ...

  7. 一份快速完整的Tensorflow模型保存和恢复教程(译)(转载)

    该文章转自https://blog.csdn.net/sinat_34474705/article/details/78995196 我在进行图像识别使用ckpt文件预测的时候,这个文章给我提供了极大 ...

  8. TensorFlow 模型文件

    在这篇 TensorFlow 教程中,我们将学习如下内容: TensorFlow 模型文件是怎么样的? 如何保存一个 TensorFlow 模型? 如何恢复一个 TensorFlow 模型? 如何使用 ...

  9. TensorFlow+TVM优化NMT神经机器翻译

    TensorFlow+TVM优化NMT神经机器翻译 背景 神经机器翻译(NMT)是一种自动化的端到端方法,具有克服传统基于短语的翻译系统中的弱点的潜力.本文为全球电子商务部署NMT服务. 目前,将Tr ...

随机推荐

  1. tomcat9下载与安装

    tomcat9下载与安装 官网下载地址:https://tomcat.apache.org/ 百度云地址:链接:https://pan.baidu.com/s/109PYcSh-eqTctLAXIsb ...

  2. JavaScript-JQ初探实现自定义滚动条

    这是一个基本实现思路,如果有新手和我一样没什么事,喜欢瞎研究话,可以参考下. 一.Html <div class="scroll_con"> <div class ...

  3. 简单几招助您加速 ARM 容器应用开发和测试流程

    随着5G时代的临近,低延迟网络.AI硬件算力提升.和智能化应用快速发展,一个万物智联的时代必将到来.我们需要将智能决策.实时处理能力从云延展到边缘和IoT设备端.阿里云容器服务推出了边缘容器,支持云- ...

  4. Spring_事务

    事务管理: 用来确保数据的完整性和一致性 事务就是一系列的动作,它们被当做一个单独的工作单元.这些动作要么全部完成,要么全部不起作用 事务的四个关键属性 原子性 一致性 隔离性 持久性 Spring两 ...

  5. 《2019年上半年Web应用安全报告》发布:90%以上攻击流量来源于扫描器,IP身份不再可信

    Web应用安全依然是互联网安全的最大威胁来源之一,除了传统的网页和APP,API和各种小程序也作为新的流量入口快速崛起,更多的流量入口和更易用的调用方式在提高web应用开发效率的同时也带来了更多和更复 ...

  6. 【JZOJ4930】【NOIP2017提高组模拟12.18】C

    题目描述 给出一个H的行和W列的网格.第i行第j列的状态是由一个字母的A[i][j]表示,如下: "." 此格为空. "o" 此格包含一个机器人. " ...

  7. node项目搭建

    一:安装 1.简单安装法 下载.msi [编译好的nodejs]  ->  点击安装 [系统会自动配置系统变量]   2.复杂安装法(不推荐) 由于nodejs的源码由C++和js组成 同时需要 ...

  8. Directx11教程(32) 纹理映射(2)

    原文:Directx11教程(32) 纹理映射(2)     在写代码之前,我们先制作一个dds文件.从网上找到了一张照片,处理成为512*512,保存为jpg格式.     启动微软的directx ...

  9. linux 下 自己写的 html文件产生中文乱码问题 解决办法

    再文件顶部加上  <meta http-equiv="Content-Type" content="text/html; charset=utf-8" / ...

  10. Cmake在编译osgEarth时遇到的一个错误

    CMake Error at src/osgEarthDrivers/CMakeLists.txt:7 (PROJECT): The CMAKE_C_COMPILER: llvm-gcc-4.2 is ...