机器学习03-sklearn.LinearRegression 源码学习
在上次的代码重写中使用了sklearn.LinearRegression 类进行了线性回归之后猜测其使用的是常用的梯度下降+反向传播算法实现,所以今天来学习它的源码实现。但是在看到源码的一瞬间突然有种怀疑人生的感觉,我是谁?我在哪?果然大佬的代码只能让我膜拜。
在一目十行地看完代码之后,我发现了一个问题,梯度的单词是gradient,一般在代码中会使用缩写grad 来表示梯度,而在这个代码中除了Gram 之外竟然没有一个以'g' 开头的单词,更不用说gradient 了。那么代码中包括注释压根没提到过梯度,是不是说明这里根本没有使用梯度下降算法呢,换言之就是是否还有其他方法来实现最小二乘法的线性回归呢?带着这个疑问,我开始仔细阅读LinearRegression.fit() 函数。
首先都是参数处理,这里虽然看不太懂但是也能大概知道他在做什么,所以可以跳过。
然后来到了核心代码,核心代码中使用的几个判断:
1 if self.positive:
2 if y.ndim < 2:
3 pass
4 else:
5 pass
6 elif sp.issparse(X):
7 if y.ndim < 2:
8 pass
9 else:
10 pass
11 else:
12 pass
self.positive 是在使用密集矩阵的时候设置的参数,y.ndim 表示y 的维度,简单来说就是y 中有几个[],所以大概能知道代码将密集矩阵与稀疏矩阵区分开,并且将一维矩阵与多维矩阵区分开,意味着不同的类别使用不同的方法。
if self.positive 分段解析
1 if self.positive:
2 if y.ndim < 2:
3 self.coef_, self._residues = optimize.nnls(X, y)
4 else:
5 # scipy.optimize.nnls cannot handle y with shape (M, K)
6 outs = Parallel(n_jobs=n_jobs_)(
7 delayed(optimize.nnls)(X, y[:, j])
8 for j in range(y.shape[1]))
9 self.coef_, self._residues = map(np.vstack, zip(*outs))
可以看出y 的维度小于2 的话使用optimize.nnls() 方法,否则进行其他处理,因为“scipy.optimize.nnls cannot handle y with shape (M, K)”,但看到之后也调用了optimize.nnls,所以应该是将矩阵处理成可以使用的样子。
并且值得注意的是这里使用了Parallel(n_jobs=n_jobs_)(delayed(optimize.nnls)(X, y[:, j])for j in range(y.shape[1])) 的调用方式,也就是形如fun(x)(y) 的方式,这意味着函数内定义了另一个函数,第一个括号是fun 的参数,第二个括号是给fun 函数内定义的函数的参数。
elif sp.issparse(X) 分段解析
1 elif sp.issparse(X):
2 X_offset_scale = X_offset / X_scale
3
4
5 def matvec(b):
6 return X.dot(b) - b.dot(X_offset_scale)
7
8
9 def rmatvec(b):
10 return X.T.dot(b) - X_offset_scale * np.sum(b)
11
12
13 X_centered = sparse.linalg.LinearOperator(shape=X.shape,
14 matvec=matvec,
15 rmatvec=rmatvec)
16
17 if y.ndim < 2:
18 out = sparse_lsqr(X_centered, y)
19 self.coef_ = out[0]
20 self._residues = out[3]
21 else:
22 # sparse_lstsq cannot handle y with shape (M, K)
23 outs = Parallel(n_jobs=n_jobs_)(
24 delayed(sparse_lsqr)(X_centered, y[:, j].ravel())
25 for j in range(y.shape[1]))
26 self.coef_ = np.vstack([out[0] for out in outs])
27 self._residues = np.vstack([out[3] for out in outs])
可以看到先是对数据进行了处理,然后调用了sparse_lsqr() 函数。
剩余分段解析
1 else:
2 self.coef_, self._residues, self.rank_, self.singular_ = \
3 linalg.lstsq(X, y)
4 self.coef_ = self.coef_.T
5
6 if y.ndim == 1:
7 self.coef_ = np.ravel(self.coef_)
8 self._set_intercept(X_offset, y_offset, X_scale)
使用了linalg.lstsq() 函数。
以上我们可以看到代码中一共使用了3 个方法来实现线性回归:optimize.nnls()、sparse_lsqr()、linalg.lstsq()
optimize.nnls() 分析
NNLS 即非负正则化最小二乘法,代码实现由scipy.optimize.nnls提供,这里只是将其封装起来,在源码的注释中提到该算法的FORTRAN 代码在Charles L. Lawson 与Richard J. Hanson 两位教授于1987 年所著的《Solving Least Squares Problems》中发布。“The algorithm is an active set method. It solves the KKT (Karush-Kuhn-Tucker) conditions for the non-negative least squares problem.” 可惜由于本人水平有限,并不能从书中或者此处的代码中学到该算法的精髓,只能先挖一个坑,以后有所提高了再来研究该算法。
sparse_lsqr() 分析
LSQR 即最小二乘QR分解算法,代码实现由scipy.sparse.linalg.lsqr 类提供,这里只是将其封装起来,在文档中可以看到:
LSQR uses an iterative method to approximate the solution. The number of iterations required to reach a certain accuracy depends strongly on the scaling of the problem. Poor scaling of the rows or columns of A should therefore be avoided where possible.
同时也给出了参考文献:
[1] C. C. Paige and M. A. Saunders (1982a). "LSQR: An algorithm for sparse linear equations and sparse least squares", ACM TOMS 8(1), 43-71.
[2] C. C. Paige and M. A. Saunders (1982b). "Algorithm 583. LSQR: Sparse linear equations and least squares problems", ACM TOMS 8(2), 195-209.
[3] M. A. Saunders (1995). "Solution of sparse rectangular systems using LSQR and CRAIG", BIT 35, 588-604.
可以看到LSQR 算法是Paige 和Saunders 于1982 年提出的一种方法,但水平有限,暂时并不清楚其中原理。
linalg.lstsq() 分析
LSTSQ 是 LeaST SQuare (最小二乘)的意思,也就是普通的最小二乘法。代码实现由scipy.linalg.lstsq 提供,这里只是将其封装起来。通过官方文档提供的源代码链接,我找到了lstsq 函数的源代码,注释中提到了:
Which LAPACK driver is used to solve the least-squares problem.
Options are ``'gelsd'``, ``'gelsy'``, ``'gelss'``. Default(``'gelsd'``) is a good choice. However, ``'gelsy'`` can be slightly faster on many problems. ``'gelss'`` was used historically. It is generally slow but uses less memory.
也就是说有三个选项:gelsd(默认推荐)、gelsy(可能稍快)、gelss(使用内存少),那么来看看他们分别使用什么方法来解决最小二乘法吧。
1 if driver in ('gelss', 'gelsd'):
2 if driver == 'gelss':
3 lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
4 v, x, s, rank, work, info = lapack_func(a1, b1, cond, lwork,
5 overwrite_a=overwrite_a,
6 overwrite_b=overwrite_b)
7
8 elif driver == 'gelsd':
9 if real_data:
10 lwork, iwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
11 x, s, rank, info = lapack_func(a1, b1, lwork,
12 iwork, cond, False, False)
13 else: # complex data
14 lwork, rwork, iwork = _compute_lwork(lapack_lwork, m, n,
15 nrhs, cond)
16 x, s, rank, info = lapack_func(a1, b1, lwork, rwork, iwork,
17 cond, False, False)
18 elif driver == 'gelsy':
19 lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
20 jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
21 v, x, j, rank, info = lapack_func(a1, b1, jptv, cond,
22 lwork, False, False)
看到调用了lapack_func,但是找了一下spicy 并没有发现这个函数,于是搜索lapack_func 找到了定义:
lapack_func, lapack_lwork = get_lapack_funcs((driver,'%s_lwork' % driver),(a1, b1))
原来它调用了get_lapack_funcs 函数,于是去找该函数的文档,从文档中看到了该方法的用处:“This routine automatically chooses between Fortran/C interfaces. Fortran code is used whenever possible for arrays with column major order. In all other cases, C code is preferred.”,原来这是一个Fortran 语言和C 语言的接口且首选C 语言,并且从源码中我看到了它是调用_get_funcs 实现,那3 个单词是C 语言的函数名,所幸我对C 语言的熟悉程度打过Python,于是去找C 语言API 进行学习。
gelsd:Computes the minimum-norm solution to a linear least squares problem using the singular value decomposition of A and a divide and conquer method. 分治法的奇异值分解找出最优解。
gelss:Computes the minimum-norm solution to a linear least squares problem using the singular value decomposition of A. 使用奇异值分解找出最优解。
gelsy:Computes the minimum-norm solution to a linear least squares problem using a complete orthogonal factorization of A. 使用完全正交分解的方式找出最优解。
至此对于所有函数的分析基本上是结束了,密集矩阵使用的是NNLS 算法,稀疏矩阵使用的是LSQR 算法,其他的使用的则是最常用的最小二乘法算法。回到开头的问题,为什么没有使用梯度下降呢?难道梯度下降并不是解决最小二乘或者说是线性回归的算法吗?在网上查阅了很多材料之后我发现自己的问题:我陷入了误区。
先来说说最小二乘法,最小二乘法在我初高中的时候学习简单线性回归的时候就接触了,根据百度百科的词条,最小二乘法(又称最小平方法)是一种数学优化技术。它通过最小化误差的平方和寻找数据的最佳函数匹配。利用最小二乘法可以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小,也就是说最小二乘法的公式是目标函数 = MIN( ∑ (预测值 – 实际值)² )。这就意味着如果人工计算的话就需要穷举所有的函数,计算他们的损失然后找出最小的那个函数。
再来说说梯度下降,梯度下降是迭代法的一种,通过更新梯度来找到损失最小的那个函数,不知你发现问题了吗?最小二乘法与梯度下降是在做同一件事情,也就是最优化问题,两个是并行的关系,并不存在谁解决谁。百度百科中关于梯度下降的词条中提到:“梯度下降是迭代法的一种,可以用于求解最小二乘问题(线性和非线性都可以)。在求解机器学习算法的模型参数,即无约束优化问题时,梯度下降(Gradient Descent)是最常采用的方法之一,另一种常用的方法是最小二乘法。”这里很清晰的指出了最小二乘法和梯度下降法的关系。
再来说说另一个我之前并不知道的概念:最小二乘准则。百度百科中提到这是一种对于偏差程度的评估准则,与上两者不同,上述的算法都是基于最小二乘准则提出的对于最小二乘法优化问题的解决方案,也就是如果不穷举的话如何找到最小二乘法的最优解。
列出我查阅的资料:
总结
机器学习03-sklearn.LinearRegression 源码学习的更多相关文章
- 【iScroll源码学习03】iScroll事件机制与滚动条的实现
前言 想不到又到周末了,周末的时间要抓紧学习才行,前几天我们学习了iScroll几点基础知识: 1. [iScroll源码学习02]分解iScroll三个核心事件点 2. [iScroll源码学习01 ...
- Qt Creator 源码学习笔记03,大型项目如何管理工程
阅读本文大概需要 6 分钟 一个项目随着功能开发越来越多,项目必然越来越大,工程管理成本也越来越高,后期维护成本更高.如何更好的组织管理工程,是非常重要的 今天我们来学习下 Qt Creator 是如 ...
- 【iScroll源码学习04】分离IScroll核心
前言 最近几天我们前前后后基本将iScroll源码学的七七八八了,文章中未涉及的各位就要自己去看了 1. [iScroll源码学习03]iScroll事件机制与滚动条的实现 2. [iScroll源码 ...
- (转)干货|这篇TensorFlow实例教程文章告诉你GANs为何引爆机器学习?(附源码)
干货|这篇TensorFlow实例教程文章告诉你GANs为何引爆机器学习?(附源码) 该博客来源自:https://mp.weixin.qq.com/s?__biz=MzA4NzE1NzYyMw==& ...
- Vue源码学习(一):调试环境搭建
最近开始学习Vue源码,第一步就是要把调试环境搭好,这个过程遇到小坑着实费了点功夫,在这里记下来 一.调试环境搭建过程 1.安装node.js,具体不展开 2.下载vue项目源码,git或svn等均可 ...
- JDK1.8源码分析01之学习建议(可以延伸其他源码学习)
序言:目前有个计划就是准备看一下源码,来提升自己的技术实力.同时现在好多面试官都喜欢问源码,问你是否读过JDK源码等等? 针对如何阅读源码,也请教了我的老师.下面就先来看看老师的回答,也许会有帮助呢. ...
- Mybatis源码学习第八天(总结)
源码学习到这里就要结束了; 来总结一下吧 Mybatis的总体架构 这次源码学习我们,学习了重点的模块,在这里我想说一句,源码的学习不是要所有的都学,一行一行的去学,这是错误的,我们只需要学习核心,专 ...
- Mybatis源码学习第六天(核心流程分析)之Executor分析
今Executor这个类,Mybatis虽然表面是SqlSession做的增删改查,其实底层统一调用的是Executor这个接口 在这里贴一下Mybatis查询体系结构图 Executor组件分析 E ...
- [阿里DIN]从论文源码学习 之 embedding_lookup
[阿里DIN]从论文源码学习 之 embedding_lookup 目录 [阿里DIN]从论文源码学习 之 embedding_lookup 0x00 摘要 0x01 DIN代码 1.1 Embedd ...
随机推荐
- Vue框架简介及简单使用
目录 一.前端框架介绍 二.vue框架简介 三.vue使用初体验 1. vue如何在页面中引入 2. 插值表达式 3. 文本指令 4. 方法指令(事件指令) 5. 属性指令 四.js数据类型补充 1. ...
- Java数组练习(打印杨辉数组)
打印杨辉数组 package com.kangkang.array; import java.util.Scanner; public class demo02 { public static voi ...
- C#连接Excel读取与写入数据库SQL ( 下 )
接上期 dataset简而言之可以理解为 虚拟的 数据库或是Excel文件.而dataset里的datatable 可以理解为数据库中的table活着Excel里的sheet(Excel里面不是可以新 ...
- 测试平台系列(2) 给Pity添加配置
给Pity添加配置 回顾 还记得上篇文章创立的「Flask」实例吗?我们通过这个实例,给根路由 「/」 绑定了一个方法,从而使得用户访问不同路由的时候可以执行不同的方法. 配置 要知道,在一个「Web ...
- C语言之三字棋的简单实现及扩展
C语言之三字棋的简单实现及扩展 在我们学习完数组之后,我们完全可以利用数组相关知识来写一个微小型的游戏,比如说今天所说的--三子棋. 大纲: 文件组成 实现 完整代码展示 扩展 即: 一.文件 ...
- 关于 FreeBSD 老版本如何安装软件
关于 FreeBSD 不被支持版本如何安装软件: ALLOW_UNSUPPORTED_SYSTEM=yes写到/etc/ make.conf 如果提示没有make.conf 请手动新建一个文 ...
- BIMFACE二次开发【C#系列】
本系列文章主要介绍使用 C# .ASP.NET(MVC)技术对 BIMFACE 平台进行二次开发,以满足本公司针对建筑行业施工图审查系统的业务需求,例如图纸模型(PDF 文件.二维 CAD 模型.三维 ...
- 安卓Media相关类测试demo
最近在研究安卓系统给app开发者提供的标准Media相关的工具类,本人做了一些demo来测试这些工具的使用方法. 本demo包含若干apk源码,需要说明以下几点: 1. 构建方式 Makefile使用 ...
- 1.mysql读写
一.数据库读取(mysql) 参数 接受 作用 默认 sql or table_name string 读取的表名,或sql语句 无 con 数据库连接 数据库连接信息 无 index_col Int ...
- CSS-clear属性的作用
1 <!DOCTYPE html> 2 <html lang="en"> 3 <head> 4 <meta charset="U ...