本文转载自:https://juejin.im/post/5a924df16fb9a0634514d6e1

机器学习之线性回归(纯python实现)

线性回归是机器学习中最基本的一个算法,大部分算法都是由基本的算法演变而来。本文着重用很简单的语言说一下线性回归。

线性回归

包括一元线性回归和多元线性回归,一元指的是只有一个x和一个y。通过一元对于线性回归有个基本的理解。

一元线性回归就是在数据中找到一条直线,以最小的误差来(Loss)来拟和数据。

上面提到的误差可以这样表示,假设那条直线如下图:

理想情况是所有点都落在直线上。退一步,希望所有点离直线的距离最近。简单起见,将距离求平方,误差可以表示为:

上面的i表示第i个数据。一般情况下对Loss求平均,来当作最终的损失。

最小化误差

找到最能拟合数据的直线,也就是最小化误差。

最小二乘法

上述公式只有m, b未知,因此可以看最一个m, b的二次方程,求Loss的问题就转变成了求极值问题。 这里不做详细说明。

另每个变量的偏导数为0, 求方程组的解。

求出m,b即可得到所要的直线。

梯度下降法

没有梯度下降就没有现在的深度学习。 最小二乘法可以一步到位,直接求出m,b。在大部分公式中是无法简单的直接计算的。而梯度下降通过一步一步的迭代,慢慢的去靠近那条最优的直线,因此需要不断的优化。 Loss的函数图像可以类比成一个碗。

要求的最小值就在碗底,随意给出一点往下走,即沿着下降最快的方向(梯度)往下走,定义每一步移动的步长,移动的次数来逼近最优值。 下面用算法来实现:

初始化:

def init_data():
data = np.loadtxt('data.csv', delimiter=',')
return data def linear_regression():
learning_rate = 0.01 #步长
initial_b = 0
initial_m = 0
num_iter = 1000 #迭代次数 data = init_data()
[b, m] = optimizer(data, initial_b, initial_m, learning_rate, num_iter)
plot_data(data,b,m)
print(b, m)
return b, m
复制代码

优化器去做梯度下降:

def optimizer(data, initial_b, initial_m, learning_rate, num_iter):
b = initial_b
m = initial_m for i in range(num_iter):
b, m = compute_gradient(b, m, data, learning_rate)
# after = computer_error(b, m, data)
if i % 100 == 0:
print(i, computer_error(b, m, data)) # 损失函数,即误差
return [b, m]
复制代码

每次迭代计算梯度做参数更新:

def compute_gradient(b_cur, m_cur, data, learning_rate):
b_gradient = 0
m_gradient = 0 N = float(len(data))
#
# 偏导数, 梯度
for i in range(0, len(data)):
x = data[i, 0]
y = data[i, 1] b_gradient += -(2 / N) * (y - ((m_cur * x) + b_cur))
m_gradient += -(2 / N) * x * (y - ((m_cur * x) + b_cur)) #偏导数 new_b = b_cur - (learning_rate * b_gradient)
new_m = m_cur - (learning_rate * m_gradient)
return [new_b, new_m]
复制代码

Loss值的计算:

def computer_error(b, m, data):
totalError = 0
x = data[:, 0]
y = data[:, 1]
totalError = (y - m * x - b) ** 2
totalError = np.sum(totalError, axis=0)
return totalError / len(data)
复制代码

执行函数计算结果:

if __name__ == '__main__':
linear_regression()
复制代码

运算结果如下:

0 3.26543633854
100 1.41872132865
200 1.36529867423
300 1.34376973304
400 1.33509372632
500 1.33159735872
600 1.330188348
700 1.32962052693
800 1.32939169917
900 1.32929948325
1.23930380135 1.86724196887
复制代码

可以看到,随着迭代次数的增加,Loss函数越来越逼近最小值,而m,b也越来越逼近最优解。

注意:

在上面的方法中,还是通过计算Loss的偏导数来最小化误差。上述方法在梯度已知的情况下,即肯定按照下降最快的方法到达碗底。那么在公式非常难以计算的情况下怎么去求最优解。此时求偏导数可以使用导数的定义,看另一个函数。

def optimizer_two(data, initial_b, initial_m, learning_rate, num_iter):
b = initial_b
m = initial_m while True:
before = computer_error(b, m, data)
b, m = compute_gradient(b, m, data, learning_rate)
after = computer_error(b, m, data)
if abs(after - before) < 0.0000001: #不断减小精度
break
return [b, m] def compute_gradient_two(b_cur, m_cur, data, learning_rate):
b_gradient = 0
m_gradient = 0 N = float(len(data)) delta = 0.0000001 for i in range(len(data)):
x = data[i, 0]
y = data[i, 1]
# 利用导数的定义来计算梯度
b_gradient = (error(x, y, b_cur + delta, m_cur) - error(x, y, b_cur - delta, m_cur)) / (2*delta)
m_gradient = (error(x, y, b_cur, m_cur + delta) - error(x, y, b_cur, m_cur - delta)) / (2*delta) b_gradient = b_gradient / N
m_gradient = m_gradient / N
#
new_b = b_cur - (learning_rate * b_gradient)
new_m = m_cur - (learning_rate * m_gradient)
return [new_b, new_m] def error(x, y, b, m):
return (y - (m * x) - b) ** 2
复制代码

上述两种中,迭代次数足够多都可以逼近最优解。 分别求得的最优解为: 1: 1.23930380135 1.86724196887 2: 1.24291450769 1.86676417482

简单比较

sklearn中有相应的方法求线性回归,其直接使用最小二乘法求最优解。简单实现以做个比较。

def scikit_learn():
data = init_data()
y = data[:, 1]
x = data[:, 0]
x = (x.reshape(-1, 1))
linreg = LinearRegression()
linreg.fit(x, y)
print(linreg.coef_)
print(linreg.intercept_) if __name__ == '__main__':
# linear_regression()
scikit_learn()
复制代码

此时求的解为: 1.24977978176 1.86585571。 可以说明上述计算结果比较满意,通过后期调整参数,可以达到比较好的效果。

源码和数据参考: github.com/yunshuipiao…

感谢并参考博文: 线性回归理解(附纯python实现)
Gradient Descent 梯度下降法
梯度下降(Gradient Descent)小结

作者:swensun
链接:https://juejin.im/post/5a924df16fb9a0634514d6e1
来源:掘金
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

机器学习之线性回归(纯python实现)][转]的更多相关文章

  1. 机器学习之线性回归使用Python和tensorflow实现

    导入依赖包 import tensorflow as tf import numpy as np import matplotlib.pylab as plt from pylab import mp ...

  2. 机器学习3- 一元线性回归+Python实现

    目录 1. 线性模型 2. 线性回归 2.1 一元线性回归 3. 一元线性回归的Python实现 3.1 使用 stikit-learn 3.1.1 导入必要模块 3.1.2 使用 Pandas 加载 ...

  3. 机器学习4- 多元线性回归+Python实现

    目录 1 多元线性回归 2 多元线性回归的Python实现 2.1 手动实现 2.1.1 导入必要模块 2.1.2 加载数据 2.1.3 计算系数 2.1.4 预测 2.2 使用 sklearn 1 ...

  4. JavaScript机器学习之线性回归

    译者按: AI时代,不会机器学习的JavaScript开发者不是好的前端工程师. 原文: Machine Learning with JavaScript : Part 1 译者: Fundebug ...

  5. 干货 | 请收下这份2018学习清单:150个最好的机器学习,NLP和Python教程

    机器学习的发展可以追溯到1959年,有着丰富的历史.这个领域也正在以前所未有的速度进化.在之前的一篇文章中,我们讨论过为什么通用人工智能领域即将要爆发.有兴趣入坑ML的小伙伴不要拖延了,时不我待! 在 ...

  6. 机器学习1—简介及Python机器学习环境搭建

    简介 前置声明:本专栏的所有文章皆为本人学习时所做笔记而整理成篇,转载需授权且需注明文章来源,禁止商业用途,仅供学习交流.(欢迎大家提供宝贵的意见,共同进步) 正文: 机器学习,顾名思义,就是研究计算 ...

  7. 是AI就躲个飞机-纯Python实现人工智能

    你要的答案或许都在这里:小鹏的博客目录 代码下载:Here. 很久以前微信流行过一个小游戏:打飞机,这个游戏简单又无聊.在2017年来临之际,我就实现一个超级弱智的人工智能(AI),这货可以躲避从屏幕 ...

  8. 深入理解Python中协程的应用机制: 使用纯Python来实现一个操作系统吧!!

    本文参考:http://www.dabeaz.com/coroutines/   作者:David Beazley 缘起: 本人最近在学习python的协程.偶然发现了David Beazley的co ...

  9. 机器学习、NLP、Python和Math最好的150余个教程(建议收藏)

    编辑 | MingMing 尽管机器学习的历史可以追溯到1959年,但目前,这个领域正以前所未有的速度发展.最近,我一直在网上寻找关于机器学习和NLP各方面的好资源,为了帮助到和我有相同需求的人,我整 ...

随机推荐

  1. 如何给MFC的菜单项添加快捷键

    我们一起分享,如何给MFC的菜单项添加快捷键.[程序在VC6.0编译环境下编译通过.(VS2010的编译环境大同小异)] 1.程序演示环境 1.1新建一个[对话框(Dialog)]的程序.然后,New ...

  2. HDU_5538_House Building

    House Building Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 262144/262144 K (Java/Others) ...

  3. Codeforces Round #427 (Div. 2)—A,B,C,D题

    A. Key races 题目链接:http://codeforces.com/contest/835/problem/A 题目意思:两个比赛打字,每个人有两个参数v和t,v秒表示他打每个字需要多久时 ...

  4. UNION DISTINCT

    w同结构表读写合并. DROP PROCEDURE IF EXISTS w_ww_amzasin; DELIMITER /w/ CREATE PROCEDURE w_ww_amzasin() BEGI ...

  5. 利用wget批量下载http目录下文件

    原理:下载你需要down的目录页面的index.html,可能名字不是如此!!!之后用wget下载该文件里包含的所有链接! 例如:wget -vE -rLnp -nH --tries=20 --tim ...

  6. birt 日志打印

    在birt初始initialize 方法里,定义日志输出方法 importPackage(Packages.java.util.logging); importPackage(Packages.log ...

  7. jQuery中on()方法用法实例

    这篇文章主要介绍了jQuery中on()方法用法,实例分析了on()方法的功能.定义及在匹配元素上绑定一个或者多个事件处理函数的使用技巧,需要的朋友可以参考下 本文实例讲述了jQuery中on()方法 ...

  8. ALV tree DUMP 问题处理-20180328

    Category ABAP Programming Error Runtime Errors MESSAGE_TYPE_X ABAP Program SAPLOLEA Application Comp ...

  9. 使用Spring注解获取配置文件信息

    需要加载的配置文件内容(resource.properties): #FTP相关配置 #FTP的IP地址 FTP_ADDRESS=192.168.1.121 FTP_PORT=21 FTP_USERN ...

  10. HMM、MEMM、CRF模型比较和标注偏置问题(Label Bias Problem)

    本文转自:http://www.cnblogs.com/syx-1987/p/4077325.html 路径1-1-1-1的概率:0.4*0.45*0.5=0.09 路径2-2-2-2的概率:0.01 ...