背景与原理:

线性回归是机器学习建模中最为简单的模型,也是计算起来最为直观的模型

所谓线性回归,我们要建立的是这样的模型:对一组数据,每组数据形如$(x_{1},...,x_{n},y)$,我们希望构造一个线性函数$h_{\theta}(X)=\sum_{i=0}^{n}\theta_{i}x_{i}$,使得$h_{\theta}(X)$与$y$的差距最小

对于上述式子,我们给一个记号:$X=(x_{0},...,x_{n})$,为了体现线性函数中常数项的存在,我们令$x_{0}=1$,因此$X=(1,x_{1},...,x_{n})$,$\theta=(\theta_{0},...,\theta_{n})$

那么我们把二者都看做列向量(常用的约定),上式实际上在说:$h_{\theta}(X)=\theta^{T}X$

我们约定损失函数为$L^{2}$损失函数(这里可以用最大似然估计证明选取该损失函数的合理性,之后有空填坑吧...),即对一组参数$\theta$,损失函数$J(\theta)=\dfrac{1}{2}\sum_{i=1}^{m}(y_{i}-h_{\theta}(X_{i}))^{2}$

这表示一共有$m$组形如$(X_{i},y_{i})$的数据的总损失

那么我们令$Y=(y_{1},...,y_{m})$(视作一个列向量),$X=(X_{1},...,X_{m})$($X_{i}$为一个$(n+1)$维列向量,$X$是一个$(n+1)*m$的矩阵)上式即为:$J(\theta)=\dfrac{1}{2}(X^{T}\theta-Y)^{T}(X^{T}\theta-Y)$(目前前面出现的$X_{i},Y,\theta$三个向量都视作列向量,始终不要忘记这一点)

那么我们现在想求出$J(\theta)$最小时$\theta$的取值,那么我们对$J(\theta)$求偏导,得到:

$\dfrac{\partial J(\theta)}{\partial \theta}=\dfrac{1}{2}\dfrac{\partial(\theta^{T}XX^{T}\theta -Y^{T}X^{T}\theta-\theta^{T} XY+Y^{T}Y)}{\partial \theta}$

那么也就是:(梯度向量如果按列向量解释):

$\dfrac{\partial J(\theta)}{\partial \theta}=\dfrac{1}{2}(2XX^{T}\theta -XY- XY)=XX^{T}\theta-XY$

对这步推导感兴趣的话可以参考多元函数微分学,这里我们只应用了一些非常简单的结论,如:

$\dfrac{\partial \alpha^{T} \beta}{\partial \alpha}=\dfrac{\partial \beta^{T}\alpha}{\partial \alpha}=\beta$(当然,由于我们先天规定了梯度向量是行向量还是列向量,所以结果可能是$\beta$或$\beta^{T}$)

欲使$J(\theta)$取得极小值,我们要求各个偏导数为0,因此我们有:

$XX^{T}\theta-XY=0$,即:$\theta=(XX^{T})^{-1}XY$

但是当我们的$X$很大的时候,无论是矩阵乘法还是矩阵求逆的计算量都太大了(仅考虑一个含有$10000$组$1000$维数据的数据集,其矩阵乘法的运算量就很巨大,更遑论矩阵求逆的计算量了)

梯度下降:

考虑一个一元函数,其在$x_{0}$点处的导数是$f^{'}(x_{0})$,那么这个导数值决定了这个函数的走势:如果导数值小于零,那么我们沿正向前进一小段函数值就会下降,而反之我们沿负向前进一小段函数值就会下降(这一点可以由一元微分公式决定)

那么对于多元函数也是同理的,但是对于多元函数而言,我们有很多个可以前进的方向,那么我们选取一个能使函数值下降最快的方向,由多元函数微分学告诉我们,这实际上就是函数的在这一点梯度的方向,所以我们使用的公式是:

$\hat{\theta}=\theta-\alpha \dfrac{\partial J(\theta)}{\partial \theta}$

这里的$\alpha$被称为学习率,表示我们具体会在梯度方向上前进多远

当然,我们也可以具体到每个参数:

$\hat{\theta_{j}}=\theta_{j}-\alpha \dfrac{\partial J(\theta)}{\partial \theta_{j}}$

而我们总是有:

$\dfrac{\partial J(\theta)}{\partial \theta_{j}}=\dfrac{\partial \sum_{i=1}^{n}(h_{\theta}(X_{i})-Y_{I})^{2})}{2\partial \theta_{j}}=\dfrac{\sum_{i=1}^{m}(h_{\theta}(X_{i})-Y_{i}) \partial (h_{\theta}(X_{i})-Y_{i})}{\partial  \theta_{j} }$

同时我们最开始就有:

$h_{\theta}(X_{i})=\sum_{k=0}^{n}\theta_{k}x_{ik}$,其中$X_{i}=(1,x_{i1},...,x_{in})$

于是我们就得到了表达式:

$\dfrac{\partial (h_{\theta}(X_{i})-Y_{i})}{\partial  \theta_{j}}=x_{ij}$

这样我们就有:

$\dfrac{\partial J(\theta)}{\partial \theta_{j}}=\sum_{i=1}^{m}(h_{\theta}(X_{i})-Y_{i})x_{ij}$

于是最后的迭代式即为:

$\hat{\theta_{j}}=\theta_{j}-\alpha \sum_{i=1}^{m}(h_{\theta}(X_{i})-Y_{i})x_{ij}$

(当然你也可以写成$\hat{\theta_{j}}=\theta_{j}+\alpha \sum_{i=1}^{m}(Y_{i}-h_{\theta}(X_{i}))x_{ij}$,但这就是个人喜好的形式问题了)

在迭代过程中,一般我们设置一个边界,当相邻两次迭代的差距小于这个边界时就认为收敛。

在梯度下降的过程中,由于初值和学习率$\alpha$都是人工设定的,而梯度下降法对这些东西又是敏感的,因为存在全局最小值/局部极小值/鞍点(梯度为0但甚至不为局部极小值的点,在一元函数中可以考察$y=x^{3}$在$x=0$的情况)的影响,同时如果学习率过大会在最小值周围跳跃而无法收敛,如果过小则学习速度太慢。

代码实现:

import numpy as np
from scipy import stats
import matplotlib.pyplot as plt def my_linear(X,Y,theta,alpha,siz,eps,dep):
theta=(np.matrix(X@X.T).I)@X@Y.T
return np.array(theta.T) delta=1.0 while delta>eps:
#temp=Y-(theta @ X)
#new_theta=theta.copy() #for i in range(0,siz):
# new_theta+=alpha*temp[0,i]*X[:,i] new_theta=(theta.T-alpha*(X@X.T@theta.T-X@Y.T)).T delt=new_theta-theta
delta=(delt.T @ delt)[0,0] theta=new_theta return theta x=np.arange(0.,10.,0.2)
siz=len(x)
y=2*x+0.5+np.random.randn(siz)
y=np.vstack([y])
x0=np.full(siz,1.)
X=np.vstack([x0,x])
theta=np.full(2,1.)
theta=np.vstack([theta]) my_theta= my_linear(X,y,theta,1e-3,siz,1e-3,0) predict_y=(my_theta@X) plt.plot(x,y[0,:],'g*')
plt.plot(x,predict_y[0,:],'r')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()

用numpy来实现这样的功能,给出了三种实现方式,一种是直接使用矩阵求逆(在numpy中,对于矩阵A和B,矩阵转置是A.T,矩阵的逆是A.I,矩阵乘法是A@B)来实现,可以看到这样的效果是最好的;

另两种则是两种不同的迭代方式,一种是通过矩阵计算直接整体迭代,另一种则是逐个迭代,这两种方式其实是等价的

需要注意的是,这里线性回归的表现好坏和初值与学习率的选取有直接关系,可以看到如果学习率选择0.1那么整个过程无法收敛,而这里取的初值会让迭代法最后得到的结果处于一个局部极小值或鞍点而无法得到全局最小值(初值换成0.1的话能取得明显更好的效果)

当然,这样的常用方法也是有库函数的,我们也可以使用不同的包

import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression x=np.arange(0.,10.,0.2)
siz=len(x)
y=2*x+0.5+np.random.randn(siz)
Y=np.vstack([y]).T
X=np.vstack([x]).T l=LinearRegression()
l.fit(X,Y) plt.scatter(x,y,c='g')
plt.plot(X,l.predict(X),c='r')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()

如果我们使用sklearn的话,我们可以用LinearRegression构造一个线性模型,然后直接把数据fit进去(这里数据的解释方法是这样的:每个$X_{i}$视作行向量,对应于一个$y_{i}$,这样拼起来之后得到的$Y$应当是一个列向量,而$X$则是一个矩阵,行数为数据个数,列数为维数,这是常用记法,和前面推导中使用的记法有一定区别)

实际上,我们生成一个LinearRegression对象之后,可以调用fit方法按最小化均方误差作为损失函数来训练一个线性模型,而训练结果被储存在成员变量coef__中,即如果想查看训练出的参数,我们应该调用l.coef__,而如果我们想对一组数据集进行预测,我们要是用的是l.predict。

import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression x=np.arange(0.,10.,0.2)
siz=len(x)
y=2*x+0.5+np.random.randn(siz)
Y=np.vstack([y]).T
X=np.vstack([x]).T slope,intercept,r_value,p_value,slope_std_error=stats.linregress(x,y) plt.scatter(x,y,c='g')
plt.plot(x,slope*x+intercept,c='r')
plt.show()

另一方面,我们也可以用scipy里的stats来实现线性回归,问题是这里只支持一维线性回归,但是可以同时算出r,p,标准差,得到的slope是斜率,intercept是截距。

小结与改进:

利用普通最小二乘法进行多元线性回归时,在出现数据共线性等问题时会有较大误差,因此在改进时可以选用偏最小二乘法——对一组数据$(x_{i1},...,x_{in},y_{i})$,令$\hat{x_{ij}}=\dfrac{x_{ij}-\overline{x_{j}}}{s_{j}}$,$\hat{y_{i}}=\dfrac{y_{i}-\overline{y}}{s_{y}}$(即每一维减去自己的均值后再除以自己的标准差,这样可以把一组数据变成均值为$0$,方差为$1$的数据),然后再进行最小二乘法,这样得到的结果效果一般更优

python机器学习——线性回归方法的更多相关文章

  1. python机器学习---线性回归案例和KNN机器学习案例

    散点图和KNN预测 一丶案例引入 # 城市气候与海洋的关系研究 # 导包 import numpy as np import pandas as pd from pandas import Serie ...

  2. 吴裕雄 python 机器学习——线性回归模型

    import numpy as np from sklearn import datasets,linear_model from sklearn.model_selection import tra ...

  3. 机器学习经典算法具体解释及Python实现--线性回归(Linear Regression)算法

    (一)认识回归 回归是统计学中最有力的工具之中的一个. 机器学习监督学习算法分为分类算法和回归算法两种,事实上就是依据类别标签分布类型为离散型.连续性而定义的. 顾名思义.分类算法用于离散型分布预測, ...

  4. 机器学习|线性回归算法详解 (Python 语言描述)

    原文地址 ? 传送门 线性回归 线性回归是一种较为简单,但十分重要的机器学习方法.掌握线性的原理及求解方法,是深入了解线性回归的基本要求.除此之外,线性回归也是监督学习回归部分的基石. 线性回归介绍 ...

  5. Python机器学习/LinearRegression(线性回归模型)(附源码)

    LinearRegression(线性回归) 2019-02-20  20:25:47 1.线性回归简介 线性回归定义: 百科中解释 我个人的理解就是:线性回归算法就是一个使用线性函数作为模型框架($ ...

  6. Python机器学习笔记:常用评估指标的用法

    在机器学习中,性能指标(Metrics)是衡量一个模型好坏的关键,通过衡量模型输出y_predict和y_true之间的某种“距离”得出的. 对学习器的泛化性能进行评估,不仅需要有效可行的试验估计方法 ...

  7. 只需十四步:从零开始掌握 Python 机器学习(附资源)

    分享一篇来自机器之心的文章.关于机器学习的起步,讲的还是很清楚的.原文链接在:只需十四步:从零开始掌握Python机器学习(附资源) Python 可以说是现在最流行的机器学习语言,而且你也能在网上找 ...

  8. Python机器学习笔记:sklearn库的学习

    网上有很多关于sklearn的学习教程,大部分都是简单的讲清楚某一方面,其实最好的教程就是官方文档. 官方文档地址:https://scikit-learn.org/stable/ (可是官方文档非常 ...

  9. Python机器学习笔记:不得不了解的机器学习面试知识点(1)

    机器学习岗位的面试中通常会对一些常见的机器学习算法和思想进行提问,在平时的学习过程中可能对算法的理论,注意点,区别会有一定的认识,但是这些知识可能不系统,在回答的时候未必能在短时间内答出自己的认识,因 ...

  10. Python & 机器学习之项目实践

    机器学习是一项经验技能,经验越多越好.在项目建立的过程中,实践是掌握机器学习的最佳手段.在实践过程中,通过实际操作加深对分类和回归问题的每一个步骤的理解,达到学习机器学习的目的. 预测模型项目模板不能 ...

随机推荐

  1. Task :app:lintVitalRelease FAILED

    错误信息:Task :app:lintVitalRelease FAILED 问题原因:dl.google.com 无法连接 解决办法: 修改hosts(推荐)通过在线查询ip网站,找到dl.goog ...

  2. 【博客】如何在Github上创建博客

    [博客]如何在Github上创建博客 1. 安装nodejs windows安装npm教程--nodejs 2. 安装hexo npm install -g hexo-cli 3. 搭建博客 $ he ...

  3. oracle中 null 和 '' 和' '的比对

    SELECT LENGTH(''),LENGTH(NULL),LENGTH(' '),LENGTH(TRIM(' ')) FROM dual; 返回结果为 null,null,1,null 也就是在o ...

  4. 【服务器数据恢复】Linux服务器分区不能挂载的数据恢复案例

    服务器数据恢复环境:某品牌PowerEdge系列服务器,磁盘阵列存储型号为该品牌MD3200系列存储,分配lun:linux centos 7操作系统,EXT4文件系统. 服务器故障:服务器在工作中由 ...

  5. 使用ADB拷贝Android设备的文件夹

    在当前目录下执行,拷贝到当前目录.   拷贝照片 adb pull sdcard/DCIM   删除照片 adb shell rm -rvf sdcard/DCIM   拷贝图片 adb pull s ...

  6. top单核与32C--CPU爆表

    linux的cpu使用频率是根据cpu个数和核数决定的 top, 然后你按一下键盘的1,这就是单个核心的负载,不然是所有核心的负载相加,自然会超过100 单核为100%,服务器是32核的,下面基本用了 ...

  7. ROS2

    ROS2核心概念 节点 创建节点流程 编程接口初始化 创建节点并初始化 实现节点功能 销毁节点并关闭接口 #!/usr/bin/env python3 import rclpy # ROS2 Pyth ...

  8. OTP: gen_server的简单应用

    1.确定回调模块名 2.编写接口函数 3.在回调模块里编写六个必需的回调函数 1.确定回调模块名 my_bank 2.编写接口方法 start() 打开银行 stop() 关闭银行 new_accou ...

  9. iOS MacOS 系统时间(时间戳)格式化

    #pragma mark -原始数据是20220608155116,加工成2022/06/08 15:51:16 -(NSString *)timeString:(NSString *)toIndex ...

  10. 记录解决方案(sqlserver篇)

    一个月的补卡次数不超过三次(即统计一个月内某人的补卡次数) 表结构是某人一天内的四次打卡状态,这样是统计当月补卡的天数了(错误) select count(*) from [Proc_HR_Punch ...