跟我学算法-吴恩达老师(mini-batchsize,指数加权平均,Momentum 梯度下降法,RMS prop, Adam 优化算法, Learning rate decay)
1.mini-batch size
表示每次都只筛选一部分作为训练的样本,进行训练,遍历一次样本的次数为(样本数/单次样本数目)
当mini-batch size 的数量通常介于1,m 之间
当为1时,称为随机梯度下降
一般我们选择64,128, 256等样本数目
import numpy as np
import math
def random_mini_batch(X, Y, mini_batch = 64, seed=0): np.random.seed(seed)
m = X.shape[1] # 表示X样本的数量
mini_batches = [] # step 1 shuffle(X, Y) np.random.permutation(m) 随机排列一个数字
permutation = list(np.random.permutation(m)) X_shuffle = X[:, permutation]
Y_shuffle = Y[:, permutation] num_complete_minibatch = math.floor(m/mini_batch) for k in range(0, num_complete_minibatch): mini_batch_x = X_shuffle[:, k*mini_batch:(k+1)*mini_batch]
mini_batch_y = Y_shuffle[:, k * mini_batch:(k + 1) * mini_batch] mini_batches.append((mini_batch_x, mini_batch_y)) if m%mini_batch != 0:
mini_batch_x = X_shuffle[:, num_complete_minibatch*mini_batch:m]
mini_batch_y = Y_shuffle[:, num_complete_minibatch*mini_batch:m]
mini_batches.append((mini_batch_x, mini_batch_y)) return mini_batches
2. 指数加权平均
v0 = 0
v1 = 0.9 * v0 + 0.1 * θ1 v0表示前一次的数值,θ1表示当前的数值
v2 = 0.9 * v1 + 0.1 * θ2
v3 = 0.9 * v2 + 0.1 * θ3
v4 = 0.9 * v3 + 0.1 * θ4
vt = β * vt-1 + (1-β) * θt
举个例子:
v100 = 0.1*θ100 + 0.1*0.9*θ99 + 0.1*0.9^2*θ98 ...
指数加权的偏差修正
vt / (1-β^t) β 通常表示 0.9, t表示时间
(1-ξ)^(1/ξ) = 1/e
3. Momentum 梯度下降法, 加快梯度下降的速度,在横轴方向上进行了加权,因为方向相同,在纵轴上进行了削减,因为方向相反,因此梯度下降前进的方向更快
动量梯度下降法, 前一次的方向与当前次的方向进行指数加权,得到当前此的方向
vdw = β * vdw(forward) + (1-β) * dw
vdb = β * vdb(forward) + (1-β) * db
w: = w - α * vdw
b: = b - α * bdw
def update_parameters_with_Momentum(parameter, grade, v, beta, learning_rate):
L = len(parameter) // 2
for i in range(L):
v['dW' + str(i+1)] = beta*v['dW' + str(i+1)] + (1-beta) * grade['dW' + str(i+1)]
v['db' + str(i+1)] = beta*v['db' + str(i+1)] + (1-beta) * grade['db' + str(i+1)]
parameter['dW' + str(i+1)] = parameter['dW' + str(i+1)] - learning_rate * v['dW' + str(i+1)]
parameter['db' + str(i+1)] = parameter['db' + str(i+1)] - learning_rate * v['db' + str(i+1)]
return parameter, grade, v
4. RMS prop
Sdw = β * sdw(forward) + (1 - β) * dw^2
Sdb = β * sdb(forward) + (1 - β) * db^2
w: = w - α * dw/(sqrt(sdw)+ε)
b: = b - α * db/(sqrt(sdb)+ε)
5. Adam 优化算法,是将动量梯度下降法与RMS prop 结合
vdw = 0
sdw = 0
vab = 0
sab = 0
vdw = β1 * vdw(forward) + (1-β1) * dw
vdb = β1 * vdb(forward) + (1-β1) * db
Sdw = β2 * sdw(forward) + (1 - β2) * dw^2
Sdb = β2 * sdb(forward) + (1 - β2) * db^2
vdw(correct) = vdw / (1-β1^t)
vdb(correct) = vdb / (1-β1^t)
Sdw(correct) = Sdw / (1-β2^t)
Sdb(correct) = Sdb / (1-β2^t)
w: = w - α * vdw(correct)/(sqrt(Sdw(correct))+ε)
b: = b - α * vdb(correct)/(sqrt( Sdb(correct) )+ε)
β1 = 0.9
β2 = 0.999
ε = 10^-8
def update_parameters_with_Adam(parameter, grade, v, s, t, learning_rate, beta1=0.9, beta2=0.999, g=1e-8):
L = len(parameter) // 2
for i in range(L):
v['dW' + str(i + 1)] = (beta1 * v['dW' + str(i+1)] + (1 - beta1) * grade['dw' + str(i+1)]) / (1-beta1 ** t)
v['db' + str(i + 1)] = (beta1 * v['db' + str(i + 1)] + (1 - beta1) * grade['db' + str(i + 1)]) / (1-beta1 ** t)
s['dW' + str(i + 1)] = (beta2 * s['dW' + str(i + 1)] + (1 - beta2) * grade['dw' + str(i + 1)] ** 2) / (1-beta1 ** t)
s['db' + str(i + 1)] = (beta2 * s['db' + str(i + 1)] + (1 - beta2) * grade['db' + str(i + 1)] ** 2) / (1-beta1 ** t)
parameter['W' + str(i + 1)] = parameter['W' + str(i + 1)] - learning_rate*(v['dW' + str(i + 1)]) \
/ (s['dW' + str(i + 1)] + g)
parameter['b' + str(i + 1)] = parameter['b' + str(i + 1)] - learning_rate * (v['db'] + str(i + 1)) \
/ (s['db' + str(i + 1)] + g)
return v, s, parameter, grade
6. Learning rate decay
根据迭代的次数,加快学习率的降低,使得样本参数更容易发生收敛,但是一般情况下不使用
3种更新α的公式
α = 1 / (1 + decay-rate * epoch-num) * α0 α0表示初始学习率, decay-rate 表示衰减层度, epoch-num 表示迭代次数
α = 0.95^epoch_num * α0
α = k / sqrt(epoch_num) * α0
跟我学算法-吴恩达老师(mini-batchsize,指数加权平均,Momentum 梯度下降法,RMS prop, Adam 优化算法, Learning rate decay)的更多相关文章
- 跟我学算法-吴恩达老师(超参数调试, batch归一化, softmax使用,tensorflow框架举例)
1. 在我们学习中,调试超参数是非常重要的. 超参数的调试可以是a学习率,(β1和β2,ε)在Adam梯度下降中使用, layers层数, hidden units 隐藏层的数目, learning_ ...
- 跟我学算法-吴恩达老师的logsitic回归
logistics回归是一种二分类问题,采用的激活函数是sigmoid函数,使得输出值转换为(0,1)之间的概率 A = sigmoid(np.dot(w.T, X) + b ) 表示预测函数 dz ...
- 机器学习爱好者 -- 翻译吴恩达老师的机器学习课程字幕 http://www.ai-start.com/
机器学习爱好者 -- 翻译吴恩达老师的机器学习课程字幕 GNU Octave 开源 MatLab http://www.ai-start.com/ https://zhuanlan.zhihu ...
- 吴恩达《深度学习》-课后测验-第一门课 (Neural Networks and Deep Learning)-Week 3 - Shallow Neural Networks(第三周测验 - 浅层神 经网络)
Week 3 Quiz - Shallow Neural Networks(第三周测验 - 浅层神经网络) \1. Which of the following are true? (Check al ...
- 吴恩达《深度学习》-课后测验-第一门课 (Neural Networks and Deep Learning)-Week 2 - Neural Network Basics(第二周测验 - 神经网络基础)
Week 2 Quiz - Neural Network Basics(第二周测验 - 神经网络基础) 1. What does a neuron compute?(神经元节点计算什么?) [ ] A ...
- 吴恩达《深度学习》-课后测验-第一门课 (Neural Networks and Deep Learning)-Week 4 - Key concepts on Deep Neural Networks(第四周 测验 – 深层神经网络)
Week 4 Quiz - Key concepts on Deep Neural Networks(第四周 测验 – 深层神经网络) \1. What is the "cache" ...
- 神经网络优化算法:Dropout、梯度消失/爆炸、Adam优化算法,一篇就够了!
1. 训练误差和泛化误差 机器学习模型在训练数据集和测试数据集上的表现.如果你改变过实验中的模型结构或者超参数,你也许发现了:当模型在训练数据集上更准确时,它在测试数据集上却不⼀定更准确.这是为什么呢 ...
- 吴恩达讲了干货满满的一节全新AI课,全程手写板书充满诚意非常干货
吴恩达讲了干货满满的一节全新AI课,全程手写板书充满诚意非常干货 摘要: 目前,AI技术做出的经济贡献几乎都来自监督学习,也就是学习从A到B,从输入到输出的映射.现在,监督学习.迁移学习.非监督学习. ...
- 吴恩达最新TensorFlow专项课程开放注册,你离TF Boy只差这一步
不需要 ML/DL 基础,不需要深奥数学背景,初学者和软件开发者也能快速掌握 TensorFlow.掌握人工智能应用的开发秘诀. 以前,吴恩达的机器学习课程和深度学习课程会介绍很多概念与知识,虽然也会 ...
随机推荐
- java中进行四舍五入
在oracle中有一个很好的函数进行四舍五入,round(), select round(111112.23248987,6) from dual; 但是java的Number本身不提供四舍五入的方法 ...
- ES6学习一 JS语言增强篇
一 背景 JavaScript经过二十来年年的发展,由最初简单的交互脚本语言,发展到今天的富客户端交互,后端服务器处理,跨平台(Native),以及小程序等等的应用.JS的角色越来越重要,处理场景越来 ...
- 为什么样本方差自由度(分母)为n-1
一.概念.条件及目的 1.概念 要理解样本方差的自由度为什么是n-1,得先理解自由度的概念: 自由度,是指附加给独立的观测值的约束或限制的个数,即一组数据中可以自由取值的个数. 2.成立条件 所谓自由 ...
- STL标准库-容器-set与multiset
技术在于交流.沟通,转载请注明出处并保持作品的完整性. set与multiset关联容器 结构如下 set是一种关联容器,key即value,value即key.它是自动排序,排序特点依据key se ...
- OMAP4之DSP核(Tesla)软件开发学习(一)
目的: 目前手上正在OMAP4上做东西,由于涉及到大量运算,交给arm A9双核发现运算速度很慢,不能满足需求.故考虑将大量运算任务(比如FIR.FFT.卷积.图像处理.向量运算等)交给O ...
- Nginx 静态资源缓存配置
示例 # Media: images, icons, video, audio, HTC location ~* \.(?:jpg|jpeg|gif|png|ico|cur|gz|svg|svgz|m ...
- Oracle:"ORA-00942: 表或视图不存在"
情景 项目中使用Powerdesigner设计数据结构,在Powerdesigner中数据表和字段都区分了大小写,并生成了Oracle表,在执行Sql脚本时遇到以下问题:“ORA-00942: 表或视 ...
- POI2014题解
POI2014题解 [BZOJ3521][Poi2014]Salad Bar 把p当作\(1\),把j当作\(-1\),然后做一遍前缀和. 一个合法区间\([l,r]\)要满足条件就需要满足所有前缀和 ...
- 便捷的Jenkins jswidgets
很多时候我们在构建完成之后需要查看构建的状态,类似github 中的build Status 插件安装 搜索插件 使用 目前好像只支持自由项目的构建 代码集成 <!DOCTYPE html> ...
- pthread访问调用信号线程的掩码(pthread_sigmask )
掩码: 信号掩码 在POSIX下,每个进程有一个信号掩码(signal mask).简单地说,信号掩码是一个"位图",其中每一位都对应着一种信号.如果位图中的某一位为1,就表示在执 ...