系列博客,原文在笔者所维护的github上:https://aka.ms/beginnerAI
点击star加星不要吝啬,星越多笔者越努力。

4.2 梯度下降法

有了上一节的最小二乘法做基准,我们这次用梯度下降法求解w和b,从而可以比较二者的结果。

4.2.1 数学原理

在下面的公式中,我们规定x是样本特征值(单特征),y是样本标签值,z是预测值,下标 \(i\) 表示其中一个样本。

预设函数(Hypothesis Function)

为一个线性函数:

\[z_i = x_i \cdot w + b \tag{1}\]

损失函数(Loss Function)

为均方差函数:

\[loss(w,b) = \frac{1}{2} (z_i-y_i)^2 \tag{2}\]

与最小二乘法比较可以看到,梯度下降法和最小二乘法的模型及损失函数是相同的,都是一个线性模型加均方差损失函数,模型用于拟合,损失函数用于评估效果。

区别在于,最小二乘法从损失函数求导,直接求得数学解析解,而梯度下降以及后面的神经网络,都是利用导数传递误差,再通过迭代方式一步一步逼近近似解。

4.2.2 梯度计算

计算z的梯度

根据公式2:
\[
{\partial loss \over \partial z_i}=z_i - y_i \tag{3}
\]

计算w的梯度

我们用loss的值作为误差衡量标准,通过求w对它的影响,也就是loss对w的偏导数,来得到w的梯度。由于loss是通过公式2->公式1间接地联系到w的,所以我们使用链式求导法则,通过单个样本来求导。

根据公式1和公式3:

\[
{\partial{loss} \over \partial{w}} = \frac{\partial{loss}}{\partial{z_i}}\frac{\partial{z_i}}{\partial{w}}=(z_i-y_i)x_i \tag{4}
\]

计算b的梯度

\[
\frac{\partial{loss}}{\partial{b}} = \frac{\partial{loss}}{\partial{z_i}}\frac{\partial{z_i}}{\partial{b}}=z_i-y_i \tag{5}
\]

4.2.3 代码实现

if __name__ == '__main__':

    reader = SimpleDataReader()
    reader.ReadData()
    X,Y = reader.GetWholeTrainSamples()

    eta = 0.1
    w, b = 0.0, 0.0
    for i in range(reader.num_train):
        # get x and y value for one sample
        xi = X[i]
        yi = Y[i]
        # 公式1
        zi = xi * w + b
        # 公式3
        dz = zi - yi
        # 公式4
        dw = dz * xi
        # 公式5
        db = dz
        # update w,b
        w = w - eta * dw
        b = b - eta * db

    print("w=", w)
    print("b=", b)

大家可以看到,在代码中,我们完全按照公式推导实现了代码,所以,大名鼎鼎的梯度下降,其实就是把推导的结果转化为数学公式和代码,直接放在迭代过程里!另外,我们并没有直接计算损失函数值,而只是把它融入在公式推导中。

4.2.4 运行结果

w= [1.71629006]
b= [3.19684087]

读者可能会注意到,上面的结果和最小二乘法的结果(w1=2.056827, b1=2.965434)相差比较多,这个问题我们留在本章稍后的地方解决。

代码位置

ch04, Level2

[ch04-02] 用梯度下降法解决线性回归问题的更多相关文章

  1. C / C ++ 基于梯度下降法的线性回归法(适用于机器学习)

    写在前面的话: 在第一学期做项目的时候用到过相应的知识,觉得挺有趣的,就记录整理了下来,基于C/C++语言 原贴地址:https://helloacm.com/cc-linear-regression ...

  2. tensorflow实现svm多分类 iris 3分类——本质上在使用梯度下降法求解线性回归(loss是定制的而已)

    # Multi-class (Nonlinear) SVM Example # # This function wll illustrate how to # implement the gaussi ...

  3. tensorflow实现svm iris二分类——本质上在使用梯度下降法求解线性回归(loss是定制的而已)

    iris二分类 # Linear Support Vector Machine: Soft Margin # ---------------------------------- # # This f ...

  4. 机器学习中梯度下降法原理及用其解决线性回归问题的C语言实现

    本文讲梯度下降(Gradient Descent)前先看看利用梯度下降法进行监督学习(例如分类.回归等)的一般步骤: 1, 定义损失函数(Loss Function) 2, 信息流forward pr ...

  5. 梯度下降法及一元线性回归的python实现

    梯度下降法及一元线性回归的python实现 一.梯度下降法形象解释 设想我们处在一座山的半山腰的位置,现在我们需要找到一条最快的下山路径,请问应该怎么走?根据生活经验,我们会用一种十分贪心的策略,即在 ...

  6. 最小二乘法 及 梯度下降法 分别对存在多重共线性数据集 进行线性回归 (Python版)

    网上对于线性回归的讲解已经很多,这里不再对此概念进行重复,本博客是作者在听吴恩达ML课程时候偶然突发想法,做了两个小实验,第一个实验是采用最小二乘法对数据进行拟合, 第二个实验是采用梯度下降方法对数据 ...

  7. 梯度下降法实现最简单线性回归问题python实现

    梯度下降法是非常常见的优化方法,在神经网络的深度学习中更是必会方法,但是直接从深度学习去实现,会比较复杂.本文试图使用梯度下降来优化最简单的LSR线性回归问题,作为进一步学习的基础. import n ...

  8. 机器学习---用python实现最小二乘线性回归算法并用随机梯度下降法求解 (Machine Learning Least Squares Linear Regression Application SGD)

    在<机器学习---线性回归(Machine Learning Linear Regression)>一文中,我们主要介绍了最小二乘线性回归算法以及简单地介绍了梯度下降法.现在,让我们来实践 ...

  9. 简单线性回归(梯度下降法) python实现

    grad_desc .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { bord ...

随机推荐

  1. HDU5950 矩阵快速幂(巧妙的递推)

    题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=5950 题意:f[n] = 2*f[n-2] + f[n-1] + n^4 思路:对于递推题而言,如果递 ...

  2. 前端技术之:如何在控制台将JS class实例输出为JSON格式

    有一个类: class Point { constructor(x, y) { this.x = x; this.y = y; } } 如果我们在控制台中输出其实例: console.log(new ...

  3. 图片瀑布流,so easy!

    什么是图片瀑布流 用一张花瓣网页的图片布局可以很清楚看出图片瀑布流的样子: 简单来说,就是有很多图片平铺在页面上,每张图片的宽度相同,但是高度不同,这样错落有致的排列出 n 列的样子很像瀑布,于是就有 ...

  4. [Neo4j]Conda虚拟环境中安装python-igraph

    neo4j算法需要用到python-igraph包,但试过很多方法,都失败了 pip install python-igraph 安装失败, 提示C core of igraph 没有安装. 在con ...

  5. 机器学习环境搭建安装TensorFlow1.13.1+Anaconda3.5.3+Python3.7.1+Win10

    安装Python3.7.1 此处不再赘述安装过程,作为记录 安装Anaconda3.5.3 Anaconda3-5.3.0-Windows-x86_64.exe 方案1. 可以直接从官网https:/ ...

  6. Spring Boot 2.X(十七):应用监控之 Spring Boot Admin 使用及配置

    Admin 简介 Spring Boot Admin 是 Spring Boot 应用程序运行状态监控和管理的后台界面.最新UI使用vue.js重写里. Spring Boot Admin 为已注册的 ...

  7. RocketMQ ACL使用指南

    目录 1.什么是ACL? 2.ACL基本流程图 3.如何配置ACL 3.1 acl配置文件 3.2 RocketMQ ACL权限可选值 3.3.权限验证流程 4.使用示例 4.1 Broker端安装 ...

  8. css3 svg 物体跟随路径动画教程

    css3 svg 物体跟随路径动画教程https://www.jianshu.com/p/992488f3f3fc

  9. 深入理解计算机系统 第三章 程序的机器级表示 part2

    这周由于时间和精力有限,只读一小节:3.4.4  压入和弹出栈数据 栈是一种特殊的数据结构,遵循“后进先出”的原则,可以用数组实现,总是从数组的一端插入和删除元素,这一端被称为栈顶. 栈有两个常用指令 ...

  10. man 与 help

    man帮助文档被划分为节 序号 节号 说明 1 1 命令帮助信息 2 2 系统调用函数的帮助信息(内核提供的接口函数) 3 3 库函数帮助信息 4 4 设备文件帮助信息 5 5 配置文档帮助说明 6 ...