【cs229-Lecture2】Linear Regression with One Variable (Week 1)(含测试数据和源码)
从Ⅱ到Ⅳ都在讲的是线性回归,其中第Ⅱ章讲得是简单线性回归(simple linear regression, SLR)(单变量),第Ⅲ章讲的是线代基础,第Ⅳ章讲的是多元回归(大于一个自变量)。
本文的目的主要是对Ⅱ章中出现的一些算法进行实现,适合的人群为已经看完本章节Stanford课程的学者。本人只是一名初学者,尽可能以白话的方式来说明问题。不足之处,还请指正。
在开始讨论具体步骤之前,首先给出简要的思维路线:
1.拥有一个点集,为了得到一条最佳拟合的直线;
2.通过“最小二乘法”来衡量拟合程度,得到代价方程;
3.利用“梯度下降算法”使得代价方程取得极小值点;
首先,介绍几个概念:
回归在数学上来说是给定一个点集,能够用一条曲线去拟合之。如果这个曲线是一条直线,那就被称为线性回归;如果曲线是一条二次曲线,就被称为二次回归,回归还有很多的变种,如locally weighted回归,logistic回归等等。
课程中得到的h就是线性回归方程:
下面,首先来介绍一下单变量的线性回归:
问题是这样的:给定一个点集,找出一条直线去拟合,要求拟合的效果达到最佳(最佳拟合)。
既然是直线,我们先假设直线的方程为:
如图:
点集有了,直线方程有了,接下来,我们要做的就是计算出
和
,使得拟合效果达到最佳(最佳拟合)。
那么,拟合效果的评判标准是什么呢?换句话说,我们需要知道一种对拟合效果的度量。
在这里,我们提出“最小二乘法”:(以下摘自wiki)
最小二乘法(又称最小平方法)是一种数学优化技术。它通过最小化误差的平方和寻找数据的最佳函数匹配。
利用最小二乘法可以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小。
对于“最小二乘法”就不再展开讨论,只要知道他是一个度量标准,我们可以用它来评判计算出的直线方程是否达到了最佳拟合就够了。
那么,回到问题上来,在单变量的线性回归中,这个拟合效果的表达式是利用最小二乘法将未知量残差平方和最小化:
结合课程,定义了一个成本函数:
其实,到这里,要是把点集的具体数值代入到成本函数中,就已经完全抽象出了一个高等数学问题(解一个二元函数的最小值问题)。
其中,a,b,c,d,e,f均为已知。
课程中介绍了一种叫“Gradient descent”的方法——梯度下降算法
两张图说明算法的基本思想:
所谓梯度下降算法(一种求局部最优解的方法),举个例子就好比你现在在一座山上,你想要尽快地到达山底(极小值点),这是一个下降的过程,这里就涉及到了两个问题:1)你下山的时候,跨多大的步子(当然,肯定不是越大越好,因为有一种可能就是你一步跨地太大,正好错过了极小的位置);2)你朝哪个方向跨步(注意,这个方向是不断变化的,你每到一个新的位置,要判断一下下一步朝那个方向走才是最好的,但是有一点可以肯定的是,要想尽快到达最低点,应从最陡的地方下山)。
那么,什么时候算是你到了一个极小点呢,显然,当你所处的位置发生的变化不断减小,直至收敛于某一位置,就说明那个位置就是一个极小值点。
so,我们来看
的变化,则我们需要让
对
求偏导,倒数代表变化率。也就是要朝着对陡的地方下山(因为沿着最陡显然比较快),就得到了
的变化情况:
简化之后:
步长不宜过大或过小
梯度下降法是按下面的流程进行的:(转自:http://blog.sina.com.cn/s/blog_62339a2401015jyq.html)
1)首先对θ赋值,这个值可以是随机的,也可以让θ是一个全零的向量。
2)改变θ的值,使得J(θ)按梯度下降的方向进行减少。
{
为了方便大家的理解,首先给出单变量的例子:
eg:求
的最小值。(注:
)
java代码如下:
·
package OneVariable; public class OneVariable{
public static void main(String[] args){
double e=0.00001;//定义迭代精度
double alpha=0.5;//定义迭代步长
double x=0; //初始化x
double y0=2*x*x+3*x+1;//与初始化x对应的y值
double y1=0;//定义变量,用于保存当前值
while (true)
{
x=x-alpha*(4.0*x+3.0);
y1=2*x*x+3*x+1;
if (Math.abs(y1-y0)<e)//如果2次迭代的结果变化很小,结束迭代
{
break;
}
y0=y1;//更新迭代的结果
}
System.out.println("Min(f(x))="+y0);
System.out.println("minx="+x);
}
} //输出
Min(f(x))=1.0
minx=-1.5}
两个变量的时候,为了更清楚,给出下面的图:
这是一个表示参数θ与误差函数J(θ)的关系图,红色的部分是表示J(θ)有着比较高的取值,我们需要的是,能够让J(θ)的值尽量的低。也就是深蓝色的部分。θ0,θ1表示θ向量的两个维度。
在上面提到梯度下降法的第一步是给θ给一个初值,假设随机给的初值是在图上的十字点。
然后我们将θ按照梯度下降的方向进行调整,就会使得J(θ)往更低的方向进行变化,如图所示,算法的结束将是在θ下降到无法继续下降为止。
当然,可能梯度下降的最终点并非是全局最小点,可能是一个局部最小点,可能是下面的情况:
上面这张图就是描述的一个局部最小点,这是我们重新选择了一个初始点得到的,看来我们这个算法将会在很大的程度上被初始点的选择影响而陷入局部最小点
一个很重要的地方值得注意的是,梯度是有方向的,对于一个向量θ,每一维分量θi都可以求出一个梯度的方向,我们就可以找到一个整体的方向,在变化的时候,我们就朝着下降最多的方向进行变化就可以达到一个最小点,不管它是局部的还是全局的。
理论的知识就讲到这,下面,我们就用java去实现这个算法:
梯度下降有两种:批量梯度下降和随机梯度下降。详见:http://blog.csdn.net/lilyth_lilyth/article/details/8973972
测试数据就用课后题中的数据(ex1data1.txt),用matlab打开作图得到:
首先说明:以下源码是不正确的,具体为什么不正确我还没搞清楚!非常希望各位高手能够指正!
测试数据及源码下载:http://pan.baidu.com/s/1mgiIVm4
OneVariable.javapackage OneVariableVersion; import java.io.IOException;
import java.util.List; /**
* Linear Regression with One Variable
* @author XBW
* @date 2014年8月17日
*/ public class OneVariable{
public static final Double e=0.00001;
public static List<Data> DS;
public static Double step;
public static Double m; /**
* 计算当前参数是否符合
* @param ans
* @param datalist
* @return
*/
public static Ans calc(Ans ans){
Double costfun;
do{
costfun=calcAccuracy(ans);
ans=update(ans);
step*=0.3;
}while(Math.abs(costfun-calcAccuracy(ans))>e);
ans.ifConvergence=true;
return ans;
} /**
* 判断当前ans是否满足精度,y=t0+t1*x
* @param ans
* @return
*/
public static Double calcAccuracy(Ans ans){
Double cost=0.0;
Double tmp;
for(int i=0;i<m;i++){
tmp=DS.get(i).y-(ans.theta[0]*DS.get(i).x[0]+ans.theta[1]*DS.get(i).x[1]);
cost+=tmp*tmp;
}
cost/=(2*m);
return cost;
} /**
* 更新ans
* @param ans,学习速率为step,m为数据量
* @return
*/
public static Ans update(Ans ans){
Double[] tmp=new Double[100] ;
for(int i=0;i<2;i++){
tmp[i]=ans.theta[i]-step*fun(ans,i);
}
for(int i=0;i<2;i++){
ans.theta[i]=tmp[i];
}
return ans;
} /**
* 计算偏导
* @return
*/
public static Double fun(Ans ans,int xi){
Double ret = 0.0;
for(int i=0;i<m;i++){
ret+=(ans.theta[0]*DS.get(i).x[0]+ans.theta[1]*DS.get(i).x[1]-DS.get(i).y)*DS.get(i).x[xi];
}
ret/=m;
return ret;
} public static void main(String[] args) throws IOException{
DS=new DataSet().ds;
step=1.0;
m=(double)DS.size(); Double[] theta={0.0,0.0}; //初始设定theta0=0,theta1=0
Ans ans=new Ans(theta,false);
Ans answer;
answer=calc(ans);
System.out.println("theta1= "+answer.theta[0]+" theta2="+answer.theta[1]);
}
}DataSet.java
package OneVariableVersion; import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List; /**
* 数据处理
* @author XBW
* @date 2014年8月17日
*/ public class DataSet{
String defaultpath="D:\\MachineLearning\\StanfordbyAndrewNg\\II.LinearRegressionwithOneVariable(Week1)\\homework\\ex1data1.txt"; List<Data> ds=new ArrayList<Data>(); public DataSet() throws IOException{
File dataset=new File(defaultpath);
BufferedReader br = new BufferedReader(new FileReader(dataset));
String tsing;
while((tsing=br.readLine())!=null){
String[] dlist=tsing.split(",");
Data dtmp=new Data(Double.parseDouble(dlist[0]),Double.parseDouble(dlist[1]));
this.ds.add(dtmp);
}
br.close();
}
}Ans.java
package OneVariableVersion; /**
* 保存结果,y=t0+t1*x
* @author XBW
* @date 2014年8月17日
*/ public class Ans {
Double[] theta;
boolean ifConvergence; public Ans(Double[] tmp,boolean ifCon){
this.theta=tmp;
this.ifConvergence=ifCon;
}
}Data.java
package OneVariableVersion; /**
* 一条数据
* @author XBW
* @date 2014年8月17日
*/
public class Data {
Double[] x=new Double[2];
Double y; public Data(Double xtmp,Double ytmp){
this.x[0]=1.0;
this.x[1]=xtmp;
this.y=ytmp;
}
}总结:写代码的时候有几个讲究:
- 步长是否需要动态变化,按照coursera公开课上讲的是不必要动态改变的,因为偏导数会越来越小,但在实际情况下,按照一定的比值缩小或者自己定义一种缩小的方式可能是有必要的,所以具体情况具体分析;
- 初始步长的设定也是很重要的,大了就不会得到结果,因为发散了;步长越大,下降速率越快,但是也会导致震荡,所以,还是哪句话:具体问题具体分析;
【cs229-Lecture2】Linear Regression with One Variable (Week 1)(含测试数据和源码)的更多相关文章
- Stanford机器学习---第二讲. 多变量线性回归 Linear Regression with multiple variable
原文:http://blog.csdn.net/abcjennifer/article/details/7700772 本栏目(Machine learning)包括单参数的线性回归.多参数的线性回归 ...
- Stanford机器学习---第一讲. Linear Regression with one variable
原文:http://blog.csdn.net/abcjennifer/article/details/7691571 本栏目(Machine learning)包括单参数的线性回归.多参数的线性回归 ...
- 机器学习笔记1——Linear Regression with One Variable
Linear Regression with One Variable Model Representation Recall that in *regression problems*, we ar ...
- Machine Learning 学习笔记2 - linear regression with one variable(单变量线性回归)
一.Model representation(模型表示) 1.1 训练集 由训练样例(training example)组成的集合就是训练集(training set), 如下图所示, 其中(x,y) ...
- Ng第二课:单变量线性回归(Linear Regression with One Variable)
二.单变量线性回归(Linear Regression with One Variable) 2.1 模型表示 2.2 代价函数 2.3 代价函数的直观理解 2.4 梯度下降 2.5 梯度下 ...
- MachineLearning ---- lesson 2 Linear Regression with One Variable
Linear Regression with One Variable model Representation 以上篇博文中的房价预测为例,从图中依次来看,m表示训练集的大小,此处即房价样本数量:x ...
- 斯坦福第二课:单变量线性回归(Linear Regression with One Variable)
二.单变量线性回归(Linear Regression with One Variable) 2.1 模型表示 2.2 代价函数 2.3 代价函数的直观理解 I 2.4 代价函数的直观理解 I ...
- 机器学习 (一) 单变量线性回归 Linear Regression with One Variable
文章内容均来自斯坦福大学的Andrew Ng教授讲解的Machine Learning课程,本文是针对该课程的个人学习笔记,如有疏漏,请以原课程所讲述内容为准.感谢博主Rachel Zhang的个人笔 ...
- Lecture0 -- Introduction&&Linear Regression with One Variable
Introduction What is machine learning? Tom Mitchell provides a more modern definition: "A compu ...
随机推荐
- Spring JDBC多批次操作
以下示例将演示如何使用spring jdbc在单个调用中进行多批次更新. 我们将在批量大小为1的多批次操作中更新student表中的记录. student表的结果如下 - CREATE TABLE s ...
- 在Ubuntu14.04上编译Android4.0.1出现的几个问题
一. 工具 sudo apt-get install git-core gnupg flex bison gperf build-essential \ zip curl libc6-de ...
- js中找string中重复项最多的字符个数
// split():字符串中的方法,把字符串转成数组. // sort():数组中的排序方法,按照ACALL码进行排序. // join():数组中的方法,把数组转换为字符串 function de ...
- Linux环境下Redis安装配置步骤[转]
在LInux下安装Redis的步骤如下: 1.首先下载一个Redis安装包,官网下载地址为:https://redis.io/ 2.在Linux下解压redis: tar -zxvf redis-2. ...
- python pip 更换国内安装源(windows)
1.点击此电脑,在最上面的的文件夹窗口输入 : %APPDATA% 2.按回车跳转到以下目录,新建pip文件夹 3.创建pip.ini文件 4.打开文件夹,输入以下内容,关闭即可(注意:源镜像可替换) ...
- Kafka配置说明
Broker Configs Property Default Description broker.id 每个broker都可以用一个唯一的非负整数id进行标识:这个id可以作为broker的 ...
- git push 后 链接总是灰色点击没有反应
情况描述: mymon是openfalcon的监控mysql插件,从GitHub拉下来后,改动源码后,提交到公司内部的gitlab上,发现提交上去的图标总是灰色的,点击进不去,如下图所示.怎么解决? ...
- ubuntu-16.04.2-desktop-amd64.iso:安装Oracle11gR2
特点: 使用ubuntu-16.04.2-desktop-amd64.iso 不降级默认的gcc版本,(liveCD 自带默认为 gcc 5.4):仅需要建立“gcc -Wl,--no-as-need ...
- BarTender复合条形码中的分隔符模式详解
在BarTender 10.1中,支持使用BarTender分隔符模式的复合条形码符号体系包括GS1 Composite和GS1 DataBar (RSS).本文小编给大家详细讲解BarTender分 ...
- C++ 使用vector时遇到的一个问题
我在测试程序中定义一个存储三维点的结构体,并定义该结构体的vector,当我在向vector插入元素时,编译一直提示错误: 代码片段如下: C++ Code 1234567891011121314 ...