在上次的代码重写中使用了sklearn.LinearRegression 类进行了线性回归之后猜测其使用的是常用的梯度下降+反向传播算法实现,所以今天来学习它的源码实现。但是在看到源码的一瞬间突然有种怀疑人生的感觉,我是谁?我在哪?果然大佬的代码只能让我膜拜。

在一目十行地看完代码之后,我发现了一个问题,梯度的单词是gradient,一般在代码中会使用缩写grad 来表示梯度,而在这个代码中除了Gram 之外竟然没有一个以'g' 开头的单词,更不用说gradient 了。那么代码中包括注释压根没提到过梯度,是不是说明这里根本没有使用梯度下降算法呢,换言之就是是否还有其他方法来实现最小二乘法的线性回归呢?带着这个疑问,我开始仔细阅读LinearRegression.fit() 函数。

首先都是参数处理,这里虽然看不太懂但是也能大概知道他在做什么,所以可以跳过。

然后来到了核心代码,核心代码中使用的几个判断:

 1 if self.positive:
2 if y.ndim < 2:
3 pass
4 else:
5 pass
6 elif sp.issparse(X):
7 if y.ndim < 2:
8 pass
9 else:
10 pass
11 else:
12 pass

self.positive 是在使用密集矩阵的时候设置的参数,y.ndim 表示y 的维度,简单来说就是y 中有几个[],所以大概能知道代码将密集矩阵与稀疏矩阵区分开,并且将一维矩阵与多维矩阵区分开,意味着不同的类别使用不同的方法。

if self.positive 分段解析

1 if self.positive:
2 if y.ndim < 2:
3 self.coef_, self._residues = optimize.nnls(X, y)
4 else:
5 # scipy.optimize.nnls cannot handle y with shape (M, K)
6 outs = Parallel(n_jobs=n_jobs_)(
7 delayed(optimize.nnls)(X, y[:, j])
8 for j in range(y.shape[1]))
9 self.coef_, self._residues = map(np.vstack, zip(*outs))

可以看出y 的维度小于2 的话使用optimize.nnls() 方法,否则进行其他处理,因为“scipy.optimize.nnls cannot handle y with shape (M, K)”,但看到之后也调用了optimize.nnls,所以应该是将矩阵处理成可以使用的样子。

并且值得注意的是这里使用了Parallel(n_jobs=n_jobs_)(delayed(optimize.nnls)(X, y[:, j])for j in range(y.shape[1])) 的调用方式,也就是形如fun(x)(y) 的方式,这意味着函数内定义了另一个函数,第一个括号是fun 的参数,第二个括号是给fun 函数内定义的函数的参数。

elif sp.issparse(X) 分段解析

 1 elif sp.issparse(X):
2 X_offset_scale = X_offset / X_scale
3
4
5 def matvec(b):
6 return X.dot(b) - b.dot(X_offset_scale)
7
8
9 def rmatvec(b):
10 return X.T.dot(b) - X_offset_scale * np.sum(b)
11
12
13 X_centered = sparse.linalg.LinearOperator(shape=X.shape,
14 matvec=matvec,
15 rmatvec=rmatvec)
16
17 if y.ndim < 2:
18 out = sparse_lsqr(X_centered, y)
19 self.coef_ = out[0]
20 self._residues = out[3]
21 else:
22 # sparse_lstsq cannot handle y with shape (M, K)
23 outs = Parallel(n_jobs=n_jobs_)(
24 delayed(sparse_lsqr)(X_centered, y[:, j].ravel())
25 for j in range(y.shape[1]))
26 self.coef_ = np.vstack([out[0] for out in outs])
27 self._residues = np.vstack([out[3] for out in outs])

可以看到先是对数据进行了处理,然后调用了sparse_lsqr() 函数。

剩余分段解析

1 else:
2 self.coef_, self._residues, self.rank_, self.singular_ = \
3 linalg.lstsq(X, y)
4 self.coef_ = self.coef_.T
5
6 if y.ndim == 1:
7 self.coef_ = np.ravel(self.coef_)
8 self._set_intercept(X_offset, y_offset, X_scale)

使用了linalg.lstsq() 函数。

以上我们可以看到代码中一共使用了3 个方法来实现线性回归:optimize.nnls()、sparse_lsqr()、linalg.lstsq()

optimize.nnls() 分析

NNLS 即非负正则化最小二乘法,代码实现由scipy.optimize.nnls提供,这里只是将其封装起来,在源码的注释中提到该算法的FORTRAN 代码在Charles L. Lawson 与Richard J. Hanson 两位教授于1987 年所著的《Solving Least Squares Problems》中发布。“The algorithm is an active set method. It solves the KKT (Karush-Kuhn-Tucker) conditions for the non-negative least squares problem.” 可惜由于本人水平有限,并不能从书中或者此处的代码中学到该算法的精髓,只能先挖一个坑,以后有所提高了再来研究该算法。

sparse_lsqr() 分析

LSQR 即最小二乘QR分解算法,代码实现由scipy.sparse.linalg.lsqr 类提供,这里只是将其封装起来,在文档中可以看到:

LSQR uses an iterative method to approximate the solution.  The number of iterations required to reach a certain accuracy depends strongly on the scaling of the problem.  Poor scaling of the rows or columns of A should therefore be avoided where possible.

同时也给出了参考文献:

[1] C. C. Paige and M. A. Saunders (1982a). "LSQR: An algorithm for sparse linear equations and sparse least squares", ACM TOMS 8(1), 43-71.
[2] C. C. Paige and M. A. Saunders (1982b). "Algorithm 583. LSQR: Sparse linear equations and least squares problems", ACM TOMS 8(2), 195-209.
[3] M. A. Saunders (1995). "Solution of sparse rectangular systems using LSQR and CRAIG", BIT 35, 588-604.

可以看到LSQR 算法是Paige 和Saunders 于1982 年提出的一种方法,但水平有限,暂时并不清楚其中原理。

linalg.lstsq() 分析

LSTSQ 是 LeaST SQuare (最小二乘)的意思,也就是普通的最小二乘法。代码实现由scipy.linalg.lstsq 提供,这里只是将其封装起来。通过官方文档提供的源代码链接,我找到了lstsq 函数的源代码,注释中提到了:

Which LAPACK driver is used to solve the least-squares problem.
Options are ``'gelsd'``, ``'gelsy'``, ``'gelss'``. Default(``'gelsd'``) is a good choice. However, ``'gelsy'`` can be slightly faster on many problems. ``'gelss'`` was used historically. It is generally slow but uses less memory.

也就是说有三个选项:gelsd(默认推荐)、gelsy(可能稍快)、gelss(使用内存少),那么来看看他们分别使用什么方法来解决最小二乘法吧。

 1 if driver in ('gelss', 'gelsd'):
2 if driver == 'gelss':
3 lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
4 v, x, s, rank, work, info = lapack_func(a1, b1, cond, lwork,
5 overwrite_a=overwrite_a,
6 overwrite_b=overwrite_b)
7
8 elif driver == 'gelsd':
9 if real_data:
10 lwork, iwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
11 x, s, rank, info = lapack_func(a1, b1, lwork,
12 iwork, cond, False, False)
13 else: # complex data
14 lwork, rwork, iwork = _compute_lwork(lapack_lwork, m, n,
15 nrhs, cond)
16 x, s, rank, info = lapack_func(a1, b1, lwork, rwork, iwork,
17 cond, False, False)
18 elif driver == 'gelsy':
19 lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
20 jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
21 v, x, j, rank, info = lapack_func(a1, b1, jptv, cond,
22 lwork, False, False)

看到调用了lapack_func,但是找了一下spicy 并没有发现这个函数,于是搜索lapack_func 找到了定义:

lapack_func, lapack_lwork = get_lapack_funcs((driver,'%s_lwork' % driver),(a1, b1))

原来它调用了get_lapack_funcs 函数,于是去找该函数的文档,从文档中看到了该方法的用处:“This routine automatically chooses between Fortran/C interfaces. Fortran code is used whenever possible for arrays with column major order. In all other cases, C code is preferred.”,原来这是一个Fortran 语言和C 语言的接口且首选C 语言,并且从源码中我看到了它是调用_get_funcs 实现,那3 个单词是C 语言的函数名,所幸我对C 语言的熟悉程度打过Python,于是去找C 语言API 进行学习。

gelsdComputes the minimum-norm solution to a linear least squares problem using the singular value decomposition of A and a divide and conquer method. 分治法的奇异值分解找出最优解。

gelssComputes the minimum-norm solution to a linear least squares problem using the singular value decomposition of A. 使用奇异值分解找出最优解。

gelsy:Computes the minimum-norm solution to a linear least squares problem using a complete orthogonal factorization of A. 使用完全正交分解的方式找出最优解。

至此对于所有函数的分析基本上是结束了,密集矩阵使用的是NNLS 算法,稀疏矩阵使用的是LSQR 算法,其他的使用的则是最常用的最小二乘法算法。回到开头的问题,为什么没有使用梯度下降呢?难道梯度下降并不是解决最小二乘或者说是线性回归的算法吗?在网上查阅了很多材料之后我发现自己的问题:我陷入了误区。

先来说说最小二乘法,最小二乘法在我初高中的时候学习简单线性回归的时候就接触了,根据百度百科的词条,最小二乘法(又称最小平方法)是一种数学优化技术。它通过最小化误差的平方和寻找数据的最佳函数匹配。利用最小二乘法可以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小,也就是说最小二乘法的公式是目标函数 = MIN( ∑ (预测值 – 实际值)² )。这就意味着如果人工计算的话就需要穷举所有的函数,计算他们的损失然后找出最小的那个函数。

再来说说梯度下降,梯度下降是迭代法的一种,通过更新梯度来找到损失最小的那个函数,不知你发现问题了吗?最小二乘法与梯度下降是在做同一件事情,也就是最优化问题,两个是并行的关系,并不存在谁解决谁。百度百科中关于梯度下降的词条中提到:“梯度下降是迭代法的一种,可以用于求解最小二乘问题(线性和非线性都可以)。在求解机器学习算法的模型参数,即无约束优化问题时,梯度下降(Gradient Descent)是最常采用的方法之一,另一种常用的方法是最小二乘法。”这里很清晰的指出了最小二乘法和梯度下降法的关系。

再来说说另一个我之前并不知道的概念:最小二乘准则。百度百科中提到这是一种对于偏差程度的评估准则,与上两者不同,上述的算法都是基于最小二乘准则提出的对于最小二乘法优化问题的解决方案,也就是如果不穷举的话如何找到最小二乘法的最优解。

列出我查阅的资料:

1、知乎-最小二乘法和梯度下降法有哪些区别?

2、百度百科-梯度下降

3、百度百科-最小二乘法

总结

1、虽然找到了sklearn.LinearRegression 类中对于线性回归的算法及实现,但发现并没有使用到梯度下降法,而是使用最小二乘法找到最优解,解开了我对最小二乘法与梯度下降到误解,但由于之前并未从事过算法研究与数学分析,对相应的算法一知半解,所以这里的代码难以看懂,只能就此作罢,学习了相应的算法之后再来学习代码实现。
2、在学习源码的过程中以及写这篇文章的过程中发现对于python的有些概念还是不太清晰,比如函数和方法还有fun(x)(y) 的调用方式,所以能看到上文中有些使用“函数”有些使用“方法”,可能并不对应,但之后熟悉了才能修改。
3、源码中的注释充斥着许多数学词汇,读起来让我异常头疼,几乎都要使用翻译软件才能理解,同时有些平常使用的词汇我也不懂,这个时候英语的作用就十分必要了。
4、总的来说,基础不扎实,水平不高,所以对于其中的精髓难以理解,同时可能文章中错漏百出,但我并未发现,这就是目前的问题或者说困境,勤加学习才能脱离。

机器学习03-sklearn.LinearRegression 源码学习的更多相关文章

  1. 【iScroll源码学习03】iScroll事件机制与滚动条的实现

    前言 想不到又到周末了,周末的时间要抓紧学习才行,前几天我们学习了iScroll几点基础知识: 1. [iScroll源码学习02]分解iScroll三个核心事件点 2. [iScroll源码学习01 ...

  2. Qt Creator 源码学习笔记03,大型项目如何管理工程

    阅读本文大概需要 6 分钟 一个项目随着功能开发越来越多,项目必然越来越大,工程管理成本也越来越高,后期维护成本更高.如何更好的组织管理工程,是非常重要的 今天我们来学习下 Qt Creator 是如 ...

  3. 【iScroll源码学习04】分离IScroll核心

    前言 最近几天我们前前后后基本将iScroll源码学的七七八八了,文章中未涉及的各位就要自己去看了 1. [iScroll源码学习03]iScroll事件机制与滚动条的实现 2. [iScroll源码 ...

  4. (转)干货|这篇TensorFlow实例教程文章告诉你GANs为何引爆机器学习?(附源码)

    干货|这篇TensorFlow实例教程文章告诉你GANs为何引爆机器学习?(附源码) 该博客来源自:https://mp.weixin.qq.com/s?__biz=MzA4NzE1NzYyMw==& ...

  5. Vue源码学习(一):调试环境搭建

    最近开始学习Vue源码,第一步就是要把调试环境搭好,这个过程遇到小坑着实费了点功夫,在这里记下来 一.调试环境搭建过程 1.安装node.js,具体不展开 2.下载vue项目源码,git或svn等均可 ...

  6. JDK1.8源码分析01之学习建议(可以延伸其他源码学习)

    序言:目前有个计划就是准备看一下源码,来提升自己的技术实力.同时现在好多面试官都喜欢问源码,问你是否读过JDK源码等等? 针对如何阅读源码,也请教了我的老师.下面就先来看看老师的回答,也许会有帮助呢. ...

  7. Mybatis源码学习第八天(总结)

    源码学习到这里就要结束了; 来总结一下吧 Mybatis的总体架构 这次源码学习我们,学习了重点的模块,在这里我想说一句,源码的学习不是要所有的都学,一行一行的去学,这是错误的,我们只需要学习核心,专 ...

  8. Mybatis源码学习第六天(核心流程分析)之Executor分析

    今Executor这个类,Mybatis虽然表面是SqlSession做的增删改查,其实底层统一调用的是Executor这个接口 在这里贴一下Mybatis查询体系结构图 Executor组件分析 E ...

  9. [阿里DIN]从论文源码学习 之 embedding_lookup

    [阿里DIN]从论文源码学习 之 embedding_lookup 目录 [阿里DIN]从论文源码学习 之 embedding_lookup 0x00 摘要 0x01 DIN代码 1.1 Embedd ...

随机推荐

  1. Pandas初体验

    目录 Pandas 一.简介 1.安装 2.引用方法 二.series 1.创建方法 2.缺失数据处理 2.1 什么是缺失值 2.2 NaN特性 2.3 填充NaN 2.4 删除NaN 2.5 其他方 ...

  2. 查看浏览器 请求网页 中 header body cookie

    command + alt + i   进入开发者工具 重新刷新页面进行请求URL 进入Network  选中某个url 右侧会展示详细信息

  3. 【图像处理】使用OpenCV+Python进行图像处理入门教程(三)色彩空间

    这篇随笔介绍使用OpenCV进行图像处理的第三章 色彩空间. 3  色彩空间 之前的介绍,大多是基于BGR色彩空间进行的,但针对不同的实际情况,研究人员提出了许多色彩空间,它们都有各自擅长处理的领域. ...

  4. 【转载】UML类图中箭头和线条的含义和用法

    文章转载自 http://blog.csdn.net/hewei0241/article/details/7674450 https://blog.csdn.net/iamherego/article ...

  5. CVE-2018-2628-WLS Core Components 反序列化

    漏洞参考 https://blog.csdn.net/csacs/article/details/87122472 漏洞概述:在 WebLogic 里,攻击者利用其他rmi绕过weblogic黑名单限 ...

  6. 不一样的软件们——GitHub 热点速览 v.21.10

    作者:HelloGitHub-小鱼干 创意,是程序员的一个身份代名词,一样的软件有不一样的玩法.比如,你可以像用 git 一样操作一个 SQL 数据库,dolt 就是这样的数据库.又比如,你可以只写文 ...

  7. JVM线上问题排查

    前言 本文介绍服务器内运行的 Java 应用产生的 OOM 问题 和 CPU 100% 的问题定位 1. 内存 OOM 问题定位 某Java服务(比如进程id pid 为 3320)出现OOM,常见的 ...

  8. 【工具】 memtester内存压力测试工具

    作者:李春港 出处:https://www.cnblogs.com/lcgbk/p/14497838.html 目录 一.简介 二.Memtester安装 三.使用说明 四.测试示例 一.简介 mem ...

  9. webpack4.x 从零开始配置vue 项目(二)基础搭建loader 配置 css、scss

    序 上一篇已经把基本架子搭起来了,现在来增加css.scss.自动生成html.css 提取等方面的打包.webpack 默认只能处理js模块,所以其他文件类型需要做下转换,而loader 恰恰是做这 ...

  10. POJ_1458 Common Subsequence 【LCS】

    一.题目 Common Subsequence 二.分析 比较基础的求最长升序子序列. $DP[i][j]$表示的是字符串$S1[1...i]$与$S2[1...j]$的最长公共子序列长度. 状态转移 ...