梯度下降法(gradient descent),又名最速下降法(steepest descent)是求解无约束最优化问题最常用的方法,它是一种迭代方法,每一步主要的操作是求解目标函数的梯度向量,将当前位置的负梯度方向作为搜索方向(因为在该方向上目标函数下降最快,这也是最速下降法名称的由来)。
梯度下降法特点:越接近目标值,步长越小,下降速度越慢。
直观上来看如下图所示:

这里每一个圈代表一个函数梯度,最中心表示函数极值点,每次迭代根据当前位置求得的梯度(用于确定搜索方向以及与步长共同决定前进速度)和步长找到一个新的位置,这样不断迭代最终到达目标函数局部最优点(如果目标函数是凸函数,则到达全局最优点)。

下面我们将通过公式来具体说明梯度下降法
下面这个h(θ)是我们的拟合函数

也可以用向量的形式进行表示:

下面函数是我们需要进行最优化的风险函数,其中的每一项都表示在已有的训练集上我们的拟合函数与y之间的残差,计算其平方损失函数作为我们构建的风险函数(参见最小二乘法及其Python实现)

这里我们乘上1/2是为了方便后面求偏导数时结果更加简洁,之所以能乘上1/2是因为乘上这个系数后对求解风险函数最优值没有影响。
我们的目标就是要最小化风险函数,使得我们的拟合函数能够最大程度的对目标函数y进行拟合,即:

后面的具体梯度求解都是围绕这个目标来进行。

批量梯度下降BGD
按照传统的思想,我们需要对上述风险函数中的每个求其偏导数,得到每个对应的梯度

这里表示第i个样本点的第j分量,即h(θ)中的

接下来由于我们要最小化风险函数,故按照每个参数的负梯度方向来更新每一个

这里的α表示每一步的步长

从上面公式可以注意到,它得到的是一个全局最优解,但是每迭代一步,都要用到训练集所有的数据,如果m很大,那么可想而知这种方法的迭代速度!!所以,这就引入了另外一种方法,随机梯度下降。

随机梯度下降SGD
因为批量梯度下降在训练集很大的情况下迭代速度非常之慢,所以在这种情况下再使用批量梯度下降来求解风险函数的最优化问题是不具有可行性的,在此情况下,提出了——随机梯度下降
我们将上述的风险函数改写成以下形式:

其中,

称为样本点的损失函数

接下来我们对每个样本的损失函数,对每个求其偏导数,得到每个对应的梯度

然后根据每个参数的负梯度方向来更新每一个

与批量梯度下降相比,随机梯度下降每次迭代只用到了一个样本,在样本量很大的情况下,常见的情况是只用到了其中一部分样本数据即可将θ迭代到最优解。因此随机梯度下降比批量梯度下降在计算量上会大大减少。

SGD有一个缺点是,其噪音较BGD要多,使得SGD并不是每次迭代都向着整体最优化方向。而且SGD因为每次都是使用一个样本进行迭代,因此最终求得的最优解往往不是全局最优解,而只是局部最优解。但是大的整体的方向是向全局最优解的,最终的结果往往是在全局最优解附近。

下面是两种方法的图形展示:

从上述图形可以看出,SGD因为每次都是用一个样本点进行梯度搜索,因此其最优化路径看上去比较盲目(这也是随机梯度下降名字的由来)。

对比其优劣点如下:
批量梯度下降:
优点:全局最优解;易于并行实现;总体迭代次数不多
缺点:当样本数目很多时,训练过程会很慢,每次迭代需要耗费大量的时间。

随机梯度下降:
优点:训练速度快,每次迭代计算量不大
缺点:准确度下降,并不是全局最优;不易于并行实现;总体迭代次数比较多。

============ 分割分割 =============
上面我们讲解了什么是梯度下降法,以及如何求解梯度下降,下面我们将通过Python来实现梯度下降法。

  1. # _*_ coding: utf-8 _*_
  2. # 作者: yhao
  3. # 博客: http://blog.csdn.net/yhao2014
  4. # 邮箱: yanhao07@sina.com
  5. # 训练集
  6. # 每个样本点有3个分量 (x0,x1,x2)
  7. x = [(1, 0., 3), (1, 1., 3), (1, 2., 3), (1, 3., 2), (1, 4., 4)]
  8. # y[i] 样本点对应的输出
  9. y = [95.364, 97.217205, 75.195834, 60.105519, 49.342380]
  10. # 迭代阀值,当两次迭代损失函数之差小于该阀值时停止迭代
  11. epsilon = 0.0001
  12. # 学习率
  13. alpha = 0.01
  14. diff = [0, 0]
  15. max_itor = 1000
  16. error1 = 0
  17. error0 = 0
  18. cnt = 0
  19. m = len(x)
  20. # 初始化参数
  21. theta0 = 0
  22. theta1 = 0
  23. theta2 = 0
  24. while True:
  25. cnt += 1
  26. # 参数迭代计算
  27. for i in range(m):
  28. # 拟合函数为 y = theta0 * x[0] + theta1 * x[1] +theta2 * x[2]
  29. # 计算残差
  30. diff[0] = (theta0 + theta1 * x[i][1] + theta2 * x[i][2]) - y[i]
  31. # 梯度 = diff[0] * x[i][j]
  32. theta0 -= alpha * diff[0] * x[i][0]
  33. theta1 -= alpha * diff[0] * x[i][1]
  34. theta2 -= alpha * diff[0] * x[i][2]
  35. # 计算损失函数
  36. error1 = 0
  37. for lp in range(len(x)):
  38. error1 += (y[lp]-(theta0 + theta1 * x[lp][1] + theta2 * x[lp][2]))**2/2
  39. if abs(error1-error0) < epsilon:
  40. break
  41. else:
  42. error0 = error1
  43. print ' theta0 : %f, theta1 : %f, theta2 : %f, error1 : %f' % (theta0, theta1, theta2, error1)
  44. print 'Done: theta0 : %f, theta1 : %f, theta2 : %f' % (theta0, theta1, theta2)
  45. print '迭代次数: %d' % cnt

结果(截取部分):

  1. theta0 : 2.782632, theta1 : 3.207850, theta2 : 7.998823, error1 : 7.508687
  2. theta0 : 4.254302, theta1 : 3.809652, theta2 : 11.972218, error1 : 813.550287
  3. theta0 : 5.154766, theta1 : 3.351648, theta2 : 14.188535, error1 : 1686.507256
  4. theta0 : 5.800348, theta1 : 2.489862, theta2 : 15.617995, error1 : 2086.492788
  5. theta0 : 6.326710, theta1 : 1.500854, theta2 : 16.676947, error1 : 2204.562407
  6. theta0 : 6.792409, theta1 : 0.499552, theta2 : 17.545335, error1 : 2194.779569
  7. theta0 : 74.892395, theta1 : -13.494257, theta2 : 8.587471, error1 : 87.700881
  8. theta0 : 74.942294, theta1 : -13.493667, theta2 : 8.571632, error1 : 87.372640
  9. theta0 : 74.992087, theta1 : -13.493079, theta2 : 8.555828, error1 : 87.045719
  10. theta0 : 75.041771, theta1 : -13.492491, theta2 : 8.540057, error1 : 86.720115
  11. theta0 : 75.091349, theta1 : -13.491905, theta2 : 8.524321, error1 : 86.395820
  12. theta0 : 75.140820, theta1 : -13.491320, theta2 : 8.508618, error1 : 86.072830
  13. theta0 : 75.190184, theta1 : -13.490736, theta2 : 8.492950, error1 : 85.751139
  14. theta0 : 75.239442, theta1 : -13.490154, theta2 : 8.477315, error1 : 85.430741
  15. theta0 : 97.986390, theta1 : -13.221172, theta2 : 1.257259, error1 : 1.553781
  16. theta0 : 97.986505, theta1 : -13.221170, theta2 : 1.257223, error1 : 1.553680
  17. theta0 : 97.986620, theta1 : -13.221169, theta2 : 1.257186, error1 : 1.553579
  18. theta0 : 97.986735, theta1 : -13.221167, theta2 : 1.257150, error1 : 1.553479
  19. theta0 : 97.986849, theta1 : -13.221166, theta2 : 1.257113, error1 : 1.553379
  20. theta0 : 97.986963, theta1 : -13.221165, theta2 : 1.257077, error1 : 1.553278
  21. Done: theta0 : 97.987078, theta1 : -13.221163, theta2 : 1.257041
  22. 迭代次数: 3443

可以看到最后收敛到稳定的参数值。

注意:这里在选取alpha和epsilon时需要谨慎选择,可能不适的值会导致最后无法收敛。

参考文档:

python实现梯度下降算法

(转)梯度下降法及其Python实现的更多相关文章

  1. 梯度下降法的python代码实现(多元线性回归)

    梯度下降法的python代码实现(多元线性回归最小化损失函数) 1.梯度下降法主要用来最小化损失函数,是一种比较常用的最优化方法,其具体包含了以下两种不同的方式:批量梯度下降法(沿着梯度变化最快的方向 ...

  2. 梯度下降法实现-python[转载]

    转自:https://www.jianshu.com/p/c7e642877b0e 梯度下降法,思想及代码解读. import numpy as np # Size of the points dat ...

  3. 固定学习率梯度下降法的Python实现方案

    应用场景 优化算法经常被使用在各种组合优化问题中.我们可以假定待优化的函数对象\(f(x)\)是一个黑盒,我们可以给这个黑盒输入一些参数\(x_0, x_1, ...\),然后这个黑盒会给我们返回其计 ...

  4. paper 166:梯度下降法及其Python实现

    参考来源:https://blog.csdn.net/yhao2014/article/details/51554910 梯度下降法(gradient descent),又名最速下降法(steepes ...

  5. 简单线性回归(梯度下降法) python实现

    grad_desc .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { bord ...

  6. 梯度下降法VS随机梯度下降法 (Python的实现)

    # -*- coding: cp936 -*- import numpy as np from scipy import stats import matplotlib.pyplot as plt # ...

  7. 梯度下降法原理与python实现

    梯度下降法(Gradient descent)是一个一阶最优化算法,通常也称为最速下降法. 要使用梯度下降法找到一个函数的局部极小值,必须向函数上当前点对应梯度(或者是近似梯度)的反方向的规定步长距离 ...

  8. 最小二乘法 及 梯度下降法 运行结果对比(Python版)

    上周在实验室里师姐说了这么一个问题,对于线性回归问题,最小二乘法和梯度下降方法所求得的权重值是一致的,对此我颇有不同观点.如果说这两个解决问题的方法的等价性的确可以根据数学公式来证明,但是很明显的这个 ...

  9. 最小二乘法 及 梯度下降法 分别对存在多重共线性数据集 进行线性回归 (Python版)

    网上对于线性回归的讲解已经很多,这里不再对此概念进行重复,本博客是作者在听吴恩达ML课程时候偶然突发想法,做了两个小实验,第一个实验是采用最小二乘法对数据进行拟合, 第二个实验是采用梯度下降方法对数据 ...

随机推荐

  1. SQL Server 查询分析器键盘快捷方式

    下表列出 SQL Server 查询分析器提供的所有键盘快捷方式. 活动 快捷方式 书签:清除所有书签. CTRL-SHIFT-F2 书签:插入或删除书签(切换). CTRL+F2 书签:移动到下一个 ...

  2. Qt信号槽的一些事

    注:此文是站在Qt5的角度说的,对于Qt4部分是不适用的. 1.先说Qt信号槽的几种连接方式和执行方式. 1)Qt信号槽给出了五种连接方式: Qt::AutoConnection 0 自动连接:默认的 ...

  3. ubuntu下android源码的下载(最新)

    在ubuntu下下载android源码我断断续续搞了好几个月,希望大家不要向我学习啊!一次性搞定! 这里给大家一些建议啊,如果是看书的话看下书的出版日期,超过一年的基本上失效,网上的也是,特别是在国内 ...

  4. ZABBIX API简介及使用

    API简介 Zabbix API开始扮演着越来越重要的角色,尤其是在集成第三方软件和自动化日常任务时.很难想象管理数千台服务器而没有自动化是多么的困难.Zabbix API为批量操作.第三方软件集成以 ...

  5. Unity3D-光照贴图技术

    概念 Lightmapping光照贴图技术是一种增强静态场景光照效果的技术,其优点是可以通过较少的性能消耗使静态场景看上去更加真实,丰富,更加具有立体感:缺点是不能用来实时地处理动态光照.当游戏场景包 ...

  6. struts2将数据通过Json格式显示于EasyUI-datagrid数据表格

    1.搭建ssh开发环境 2.写好Dao.service等方法 3.建立DTO数据传输对象: package com.beichende.sshwork.user.web.dto; import jav ...

  7. Linux配置防火墙,开启80port、3306port 可能会遇到的小问题

     vi /etc/sysconfig/iptables -A INPUT -m state –state NEW -m tcp -p tcp –dport 80 -j ACCEPT(同意80端口通 ...

  8. 左侧固定宽度,右侧自适应宽度的CSS布局

    BI上有高手专门讨论了这种布局方法,但他用了较多的hack,还回避了IE6的dtd.我在实际使用中,发现回避掉IE6的dtd定义后,会导致ajax模态框无法居中(VS的一个控件,自动生成的代码,很难修 ...

  9. linux中,通过crontab -e编辑生成的定时任务,写在哪个文件中

    环境描述: 操作系统:Red Hat Enterprise Linux Server release 6.6 (Santiago) 内核版本:2.6.32-504.el6.x86_64 需求描述: 一 ...

  10. 安装并配置ROS环境1

    ros学习之路(原创博文,转载请标明出处-周学伟http://www.cnblogs.com/zxouxuewei/) 一.ros核心教程    1.安装并配置ROS环境: 注意: 学习这节课之前请按 ...