三种梯度下降法的对比(BGD & SGD & MBGD)
常用的梯度下降法分为:
- 批量梯度下降法(Batch Gradient Descent)
- 随机梯度下降法(Stochastic Gradient Descent)
- 小批量梯度下降法(Mini-Batch Gradient Descent)
简单的算法示例
数据
x = np.random.uniform(-3,3,100)
X = x.reshape(-1,1)
y = x * 2 + 5 + np.random.normal(0, 1, 100)
BGD
批量梯度下降法的简单实现:
def gradient_descent(X_b, y, initial_theta, eta, n_iters=1e4, epsilon=1e-8):
def J(theta):
return np.mean((X_b.dot(theta) - y) ** 2)
def dj(theta):
return X_b.T.dot((X_b.dot(theta) - y)) * (2 / len(y))
theta = initial_theta
for i in range(1, int(n_iters)):
gradient = dj(theta) # 获得梯度
last_theta = theta
theta = theta - eta * gradient # 迭代梯度
if np.absolute(J(theta) - J(last_theta)) < epsilon:
break # 满足条件就跳出
return theta
结果是:
X_b = np.hstack([np.ones((len(y), 1)), X])
initial_theta = np.ones(X_b.shape[1])
eta = 0.1
%time s_gradient_descent(X_b, y, initial_theta, eta, n_iters=1)
## array([4.72619109, 3.08239321])
SGD
这里n_iters表示将所有数据迭代的轮数。
def s_gradient_descent(X_b, y, initial_theta, eta, batch_size=10, n_iters=10, epsilon=1e-8):
def J(theta):
return np.mean((X_b.dot(theta) - y) ** 2)
# 这是随机梯度下降的,随机一个样本的梯度
def dj_sgd(X_b_i, y_i, theta):
# return X_b.T.dot((X_b.dot(theta) - y)) * (2 / len(y))
return 2 * X_b_i.T.dot(X_b_i.dot(theta) - y_i)
theta = initial_theta
for i in range(0, int(n_iters)):
for j in range(batch_size, len(y), batch_size):
gradient = dj_sgd(X_b[j,:], y[j], theta)
last_theta = theta
theta = theta - eta * gradient # 迭代梯度
if np.absolute(J(theta) - J(last_theta)) < epsilon:
break # 满足条件就跳出
return theta
结果是:
X_b = np.hstack([np.ones((len(y), 1)), X])
initial_theta = np.ones(X_b.shape[1])
eta = 0.1
%time s_gradient_descent(X_b, y, initial_theta, eta, n_iters=1)
## array([4.72619109, 3.08239321])
MBGD
在随机梯度下降的基础上,对dj做了一点点修改,batch_size指定批量的大小,dj每次计算batch_size个样本的梯度并取平均值。
不得不说,同样是迭代一轮数据,小批量梯度下降法的准确度要比随机梯度下降法高多了。
def b_gradient_descent(X_b, y, initial_theta, eta, batch_size=10, n_iters=10, epsilon=1e-8):
def J(theta):
return np.mean((X_b.dot(theta) - y) ** 2)
# 这是小批量梯度下降的,随机一个样本的梯度
def dj_bgd(X_b_b, y_b, theta):
# return X_b.T.dot((X_b.dot(theta) - y)) * (2 / len(y))
return X_b_b.T.dot(X_b_b.dot(theta) - y_b) * (2 / len(y_b))
theta = initial_theta
for i in range(0, int(n_iters)):
for j in range(batch_size, len(y), batch_size):
gradient = dj_bgd(X_b[j-batch_size:j,:], y[j-batch_size:j], theta)
last_theta = theta
theta = theta - eta * gradient # 迭代梯度
if np.absolute(J(theta) - J(last_theta)) < epsilon:
break # 满足条件就跳出
return theta
结果是:
X_b = np.hstack([np.ones((len(y), 1)), X])
initial_theta = np.ones(X_b.shape[1])
eta = 0.1
%time b_gradient_descent(X_b, y, initial_theta, eta, n_iters=1)
array([4.4649369 , 2.27164876])
三种梯度下降法的对比(BGD & SGD & MBGD)的更多相关文章
- 三种梯度下降算法的区别(BGD, SGD, MBGD)
前言 我们在训练网络的时候经常会设置 batch_size,这个 batch_size 究竟是做什么用的,一万张图的数据集,应该设置为多大呢,设置为 1.10.100 或者是 10000 究竟有什么区 ...
- 各种优化器对比--BGD/SGD/MBGD/MSGD/NAG/Adagrad/Adam
指数加权平均 (exponentially weighted averges) 先说一下指数加权平均, 公式如下: \[v_{t}=\beta v_{t-1}+(1-\beta) \theta_{t} ...
- python笔记-20 django进阶 (model与form、modelform对比,三种ajax方式的对比,随机验证码,kindeditor)
一.model深入 1.model的功能 1.1 创建数据库表 1.2 操作数据库表 1.3 数据库的增删改查操作 2.创建数据库表的单表操作 2.1 定义表对象 class xxx(models.M ...
- iOS- NSThread/NSOperation/GCD 三种多线程技术的对比及实现
1.iOS的三种多线程技术 1.NSThread 每个NSThread对象对应一个线程,量级较轻(真正的多线程) 2.以下两点是苹果专门开发的“并发”技术,使得程序员可以不再去关心线程的具体使用问题 ...
- iOS- NSThread/NSOperation/GCD 三种多线程技术的对比及实现 -- 转
1.iOS的三种多线程技术 1.NSThread 每个NSThread对象对应一个线程,量级较轻(真正的多线程) 2.以下两点是苹果专门开发的“并发”技术,使得程序员可以不再去关心线程的具体使用问题 ...
- 几种梯度下降方法对比(Batch gradient descent、Mini-batch gradient descent 和 stochastic gradient descent)
https://blog.csdn.net/u012328159/article/details/80252012 我们在训练神经网络模型时,最常用的就是梯度下降,这篇博客主要介绍下几种梯度下降的变种 ...
- Dynamics CRM2016 查询数据的三种方式的性能对比
之前写过一个博客,对非声明验证方式下连接组织服务的两种方式的性能进行了对比,但当时只是对比了实例化组织服务的时间,并没有对查询数据的时间进行对比,那有朋友也在我的博客中留言了反映了查询的时间问题,一直 ...
- 两个Map的对比,三种方法,将对比结果写入文件。
三种方法的思维都是遍历一个map的Key,然后2个Map分别取这2个Key值所得到的Value. #第一种用entry private void compareMap(Map<String, S ...
- java对象头信息和三种锁的性能对比
java头的信息分析 首先为什么我要去研究java的对象头呢? 这里截取一张hotspot的源码当中的注释 这张图换成可读的表格如下 |-------------------------------- ...
随机推荐
- Excel 查找函数(一):LOOKUP
序号 员工姓名 部门 职务 1 苏霞 法务部 法律顾问 2 包志林 财务部 财务总监 3 林娥云 安监部 部长 4 石少卿 质检部 质检员 5 于炳福 生产部 生产部 6 蒋琼志 仓储部 保管员 7 ...
- 弹簧高跷题解---双向DP---DD(XYX)的博客
三 . 弹簧高跷 时间限制: 1 Sec 内存限制: 128 MB 题目描述.输入.输出 ----------- 方法 这道题用DP是可以解决的.因为每一次跳跃都与前一次跳跃有关, ...
- MPI学习笔记(三):矩阵相乘的分块并行(行列划分法)
mpi矩阵乘法:C=αAB+βC 一.主从模式的行列划分并行法 1.实现方法 将可用于计算的进程数comm_sz分解为a*b,然后将矩阵A全体行划分为a个部分,将矩阵B全体列划分为b个部分,从而将整个 ...
- CCC3.0 NFC OWNER PAIRING
OWNER PAIRING 本篇只介绍所有操作都成功执行的场景,中间如果出现异常,需要翻看规范决定接下来的操作 一些密钥 公私密钥对(Vehicle.PK&Vehicle.SK) Endpoi ...
- BI如何实现用户身份集成自定义安全程序开发
统一身份认证是整个 IT 架构的最基本的组成部分,而账号则是实现统一身份认证的基础.做好账号的规划和设计直接决定着企业整个信息系统建设的便利与难易程度,决定着系统能否足够敏捷和快速赋能,也决定了在数字 ...
- 常用的SSH,你了解多少?(长文警告)
1.SSH工作原理 从ssh的加密方式说开去,看下文 1.1.对称加密 客户端和服务端采用相同的密钥进行数据的加解密,很难保证密钥不丢失,或者被截获.隐藏着中间人攻击的风险 如果攻击者插在用户与远程主 ...
- 如何使用Postman调试HMS Core推送接口?
HMS Core推送服务支持开发者使用HTTPS协议接入Push服务端.Postman是一款接口测试工具,它可以模拟用户发起的各类HTTP请求,将请求数据发送至服务端,获取对应的响应结果.Postma ...
- ACVF of ARMA(1, 1)
\(ARMA(1, ~ 1)\) process is a time series \(\left\{ X_{t} \right\}\) defined as: \[X_{t} - \phi X_{t ...
- WindowsApps目录占用大量空间
WindowsApps目录占用大量空间今天遇到一个客户端的问题.Windows 10的电脑100G的C盘空间几乎耗尽.但是选取所有文件后总大小只有不到40G.按常规,肯定是有一些没有权限的文件夹的体积 ...
- Springboot配置文件参数使用docker-compose实现动态配置
文章总结; Springboot配置文件中的一些参数可以写成变量的形式,具体变量的值可以从docker-compose.yml文件中设置来获取 在yml文件中,通过${Envirment_variab ...