matlab实现梯度下降法(Gradient Descent)的一个例子
在此记录使用matlab作梯度下降法(GD)求函数极值的一个例子:
问题设定:
1. 我们有一个$n$个数据点,每个数据点是一个$d$维的向量,向量组成一个data矩阵$\mathbf{X}\in \mathbb{R}^{n\times d}$,这是我们的输入特征矩阵。
2. 我们有一个响应的响应向量$\mathbf{y}\in \mathbb{R}^n$。
3. 我们将使用线性模型来fit上述数据。因此我们将优化问题形式化成如下形式:$$\arg\min_{\mathbf{w}}f(\mathbf{w})=\frac{1}{n}\|\mathbf{y}-\mathbf{\overline{X}}\mathbf{w}\|_2^2$$
其中$\mathbf{\overline{X}}=(\mathbf{1,X})\in \mathbb{R}^{n\times (d+1)}$ and $\mathbf{w}=(w_0,w_1,...,w_d)^\top\in \mathbb{R}^{d+1}$
显然这是一个回归问题,我们的目标从通俗意义上讲就是寻找合适的权重向量$\mathbf{w}$,使得线性模型能够拟合的更好。
预处理:
1. 按列对数据矩阵进行最大最小归一化,该操作能够加快梯度下降的速度,同时保证了输入的数值都在0和1之间。$\mathbf{x}_i$为$\mathbf{X}$的第i列。 $$z_{ij}\leftarrow \frac{x_{ij}-\min(\mathbf{x}_i)}{\max(\mathbf{x}_i)-\min(\mathbf{x}_i)}$$
这样我们的优化问题得到了转化:$$\arg\min_{\mathbf{u}}g(\mathbf{w})=\frac{1}{n}\|\mathbf{y}-\mathbf{\overline{Z}}\mathbf{u}\|_2^2$$
2. 考虑对目标函数的Lipschitz constants进行估计。因为我们使用线性回归模型,Lipschitz constants可以方便求得,这样便于我们在梯度下降法是选择合适的步长。假如非线性模型,可能要用其他方法进行估计(可选)。
问题解决:
使用梯度下降法进行问题解决,算法如下:
我们可以看到,这里涉及到求目标函数$f$对$\mathbf{x}_k$的梯度。显然在这里,因为是线性模型,梯度的求解十分的简单:$$\nabla f(\mathbf{x}_k)=-\frac{2}{n}\mathbf{\overline{X}}^\top(\mathbf{y}-\mathbf{\overline{X}}\mathbf{u}_k)$$
进行思考,还有没有其他办法可以把这个梯度给弄出来?假如使用Tensorflow,Pytorch这样可以自动保存计算图的东东,那么梯度是可以由机器自动求出来的。当然在这里我是用matlab实现,暂时没有发现这样的利器,所以我认为假如在这里想求出梯度,那么我们必须要把梯度的闭式解搞出来,不然没法继续进行。
下面是一段matlab的代码:
function [g_result,u_result] = GD(N_Z,y,alpha,u0)
%GD 梯度下降法
% Detailed explanation goes here
[n,~] = size(N_Z);
u = u0;
k = 0;
t = y-N_Z*u;
disp("g(u):");
while(合理的终止条件)
k = k + 1;
u = u - alpha * (-2/n)*N_Z'*t;
t = y-N_Z*u;
if(mod(k,10)==0)
disp(t'*t/n);
end
end
g_result = (y-N_Z * u)' * (y-N_Z * u)/n;
u_result = u;
end
当然假如初始化的时候$u_0$选择不当,而且因为没有正则项,以上的算法将会有很大的问题:梯度消失,导致优化到最后的时候非常慢。我花了好多个小时才将loss讲到0.19左右,而闭式解法能够使得loss为0.06几,运行时间也不会难以忍受。
问题推广:
在这里,我们的问题是线性模型,回归问题。能否有更广的应用?思考后认为,只要需要优化的目标是标量,且该目标函数对输入向量的梯度容易求得即可。只是因为该算法简单朴素,可能在实际应用的时候会碰见恼人的梯度消失问题。
matlab实现梯度下降法(Gradient Descent)的一个例子的更多相关文章
- (3)梯度下降法Gradient Descent
梯度下降法 不是一个机器学习算法 是一种基于搜索的最优化方法 作用:最小化一个损失函数 梯度上升法:最大化一个效用函数 举个栗子 直线方程:导数代表斜率 曲线方程:导数代表切线斜率 导数可以代表方向, ...
- <反向传播(backprop)>梯度下降法gradient descent的发展历史与各版本
梯度下降法作为一种反向传播算法最早在上世纪由geoffrey hinton等人提出并被广泛接受.最早GD由很多研究团队各自发表,可他们大多无人问津,而hinton做的研究完整表述了GD方法,同时hin ...
- 梯度下降法Gradient descent(最速下降法Steepest Descent)
最陡下降法(steepest descent method)又称梯度下降法(英语:Gradient descent)是一个一阶最优化算法. 函数值下降最快的方向是什么?沿负梯度方向 d=−gk
- 梯度下降(gradient descent)算法简介
梯度下降法是一个最优化算法,通常也称为最速下降法.最速下降法是求解无约束优化问题最简单和最古老的方法之一,虽然现在已经不具有实用性,但是许多有效算法都是以它为基础进行改进和修正而得到的.最速下降法是用 ...
- 机器学习(1)之梯度下降(gradient descent)
机器学习(1)之梯度下降(gradient descent) 题记:最近零碎的时间都在学习Andrew Ng的machine learning,因此就有了这些笔记. 梯度下降是线性回归的一种(Line ...
- 梯度下降(Gradient Descent)小结 -2017.7.20
在求解算法的模型函数时,常用到梯度下降(Gradient Descent)和最小二乘法,下面讨论梯度下降的线性模型(linear model). 1.问题引入 给定一组训练集合(training se ...
- 理解梯度下降法(Gradient Decent)
1. 什么是梯度下降法? 梯度下降法(Gradient Decent)是一种常用的最优化方法,是求解无约束问题最古老也是最常用的方法之一.也被称之为最速下降法.梯度下降法在机器学习中十分常见,多用 ...
- (二)深入梯度下降(Gradient Descent)算法
一直以来都以为自己对一些算法已经理解了,直到最近才发现,梯度下降都理解的不好. 1 问题的引出 对于上篇中讲到的线性回归,先化一个为一个特征θ1,θ0为偏置项,最后列出的误差函数如下图所示: 手动求解 ...
- CS229 2.深入梯度下降(Gradient Descent)算法
1 问题的引出 对于上篇中讲到的线性回归,先化一个为一个特征θ1,θ0为偏置项,最后列出的误差函数如下图所示: 手动求解 目标是优化J(θ1),得到其最小化,下图中的×为y(i),下面给出TrainS ...
随机推荐
- Android常用五大布局
一.说明 1.每个应用程序都默认包含一个主界面布局文件(.xml). 2.位于项目的app/src/main/res/layout目录. 3.宽度和高度的属性 match_parent:强制性的使使徒 ...
- ATT&CK如何落地到安全产品
科普:ATT&CK是什么 ATT&CK的提出是为了解决业界对黑客行为.事件的描述不一致.不直观的问题,换句话说它解决了描述黑客行为 (TTP) 的语言和词库,将描述黑客攻击的语言统一化 ...
- php序列化和反序列化学习
1.什么是序列化 序列化说通俗点就是把一个对象变成可以传输的字符串. 1.举个例子,不知道大家知不知道json格式,这就是一种序列化,有可能就是通过array序列化而来的.而反序列化就是把那串可以传输 ...
- 非阻塞赋值(Non-blocking Assignment)是个伪需求(2)
https://mp.weixin.qq.com/s/5NWvdK3T2X4dtyRqtNrBbg 13hope: 个人理解,Verilog本身只是“建模”语言.具体到阻塞/非阻塞,只规定了两种赋 ...
- Java实现 LeetCode 835 图像重叠(暴力)
835. 图像重叠 给出两个图像 A 和 B ,A 和 B 为大小相同的二维正方形矩阵.(并且为二进制矩阵,只包含0和1). 我们转换其中一个图像,向左,右,上,或下滑动任何数量的单位,并把它放在另一 ...
- Java实现 LeetCode 554 砖墙(缝隙可以放在数组?)
554. 砖墙 你的面前有一堵方形的.由多行砖块组成的砖墙. 这些砖块高度相同但是宽度不同.你现在要画一条自顶向下的.穿过最少砖块的垂线. 砖墙由行的列表表示. 每一行都是一个代表从左至右每块砖的宽度 ...
- Java实现 LeetCode 108 将有序数组转换为二叉搜索树
108. 将有序数组转换为二叉搜索树 将一个按照升序排列的有序数组,转换为一棵高度平衡二叉搜索树. 本题中,一个高度平衡二叉树是指一个二叉树每个节点 的左右两个子树的高度差的绝对值不超过 1. 示例: ...
- Java实现 蓝桥杯VIP 算法训练 矩阵乘方
算法提高 矩阵乘方 时间限制:1.0s 内存限制:512.0MB 问题描述 给定一个矩阵A,一个非负整数b和一个正整数m,求A的b次方除m的余数. 其中一个nxn的矩阵除m的余数得到的仍是一个nxn的 ...
- 第七届蓝桥杯JavaB组国(决)赛部分真题
解题代码部分来自网友,如果有不对的地方,欢迎各位大佬评论 题目1.愤怒小鸟 题目描述 X星球愤怒的小鸟喜欢撞火车! 一根平直的铁轨上两火车间相距 1000 米 两火车 (不妨称A和B) 以时速 10米 ...
- Java实现8枚硬币问题(减治法)
1 问题描述 在8枚外观相同的硬币中,有一枚是假币,并且已知假币与真币的重量不同,但不知道假币与真币相比较轻还是较重.可以通过一架天平来任意比较两组硬币,设计一个高效的算法来检测这枚假币. 2.1 减 ...