http://www.cnblogs.com/wzm-xu/p/4062266.html

多元线性回归----Java简单实现

 

学习Andrew N.g的机器学习课程之后的简单实现.

课程地址:https://class.coursera.org/ml-007

不大会编辑公式,所以略去具体的推导,有疑惑的同学去看看Andrew 的课程吧,顺带一句,Andrew的课程实在是很赞。

如果还有疑问,feel free to contact me via emails or QQ.

LinearRegression.java

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException; public class LinearRegression {
/*
* 训练数据示例:
* x0 x1 x2 y
1.0 1.0 2.0 7.2
1.0 2.0 1.0 4.9
1.0 3.0 0.0 2.6
1.0 4.0 1.0 6.3
1.0 5.0 -1.0 1.0
1.0 6.0 0.0 4.7
1.0 7.0 -2.0 -0.6
注意!!!!x1,x2,y三列是用户实际输入的数据,x0是为了推导出来的公式统一,特地补上的一列。
x0,x1,x2是“特征”,y是结果 h(x) = theta0 * x0 + theta1* x1 + theta2 * x2 theta0,theta1,theta2 是想要训练出来的参数 此程序采用“梯度下降法” *
*/ private double [][] trainData;//训练数据,一行一个数据,每一行最后一个数据为 y
private int row;//训练数据 行数
private int column;//训练数据 列数 private double [] theta;//参数theta private double alpha;//训练步长
private int iteration;//迭代次数 public LinearRegression(String fileName)
{
int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的 行数
int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的 列数 trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1
this.row=rowoffile;
this.column=columnoffile+1; this.alpha = 0.001;//步长默认为0.001
this.iteration=100000;//迭代次数默认为 100000 theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + .......
initialize_theta(); loadTrainDataFromFile(fileName,rowoffile,columnoffile);
}
public LinearRegression(String fileName,double alpha,int iteration)
{
int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的 行数
int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的 列数 trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1
this.row=rowoffile;
this.column=columnoffile+1; this.alpha = alpha;
this.iteration=iteration; theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + .......
initialize_theta(); loadTrainDataFromFile(fileName,rowoffile,columnoffile);
} private int getRowNumber(String fileName)
{
int count =0;
File file = new File(fileName);
BufferedReader reader = null;
try {
reader = new BufferedReader(new FileReader(file));
while ( reader.readLine() != null)
count++;
reader.close();
} catch (IOException e) {
e.printStackTrace();
} finally {
if (reader != null) {
try {
reader.close();
} catch (IOException e1) {
}
}
}
return count; } private int getColumnNumber(String fileName)
{
int count =0;
File file = new File(fileName);
BufferedReader reader = null;
try {
reader = new BufferedReader(new FileReader(file));
String tempString = reader.readLine();
if(tempString!=null)
count = tempString.split(" ").length;
reader.close();
} catch (IOException e) {
e.printStackTrace();
} finally {
if (reader != null) {
try {
reader.close();
} catch (IOException e1) {
}
}
}
return count;
} private void initialize_theta()//将theta各个参数全部初始化为1.0
{
for(int i=0;i<theta.length;i++)
theta[i]=1.0;
} public void trainTheta()
{
int iteration = this.iteration;
while( (iteration--)>0 )
{
//对每个theta i 求 偏导数
double [] partial_derivative = compute_partial_derivative();//偏导数
//更新每个theta
for(int i =0; i< theta.length;i++)
theta[i]-= alpha * partial_derivative[i];
}
} private double [] compute_partial_derivative()
{
double [] partial_derivative = new double[theta.length];
for(int j =0;j<theta.length;j++)//遍历,对每个theta求偏导数
{
partial_derivative[j]= compute_partial_derivative_for_theta(j);//对 theta i 求 偏导
}
return partial_derivative;
}
private double compute_partial_derivative_for_theta(int j)
{
double sum=0.0;
for(int i=0;i<row;i++)//遍历 每一行数据
{
sum+=h_theta_x_i_minus_y_i_times_x_j_i(i,j);
}
return sum/row;
}
private double h_theta_x_i_minus_y_i_times_x_j_i(int i,int j)
{
double[] oneRow = getRow(i);//取一行数据,前面是feature,最后一个是y
double result = 0.0; for(int k=0;k< (oneRow.length-1);k++)
result+=theta[k]*oneRow[k];
result-=oneRow[oneRow.length-1];
result*=oneRow[j];
return result;
}
private double [] getRow(int i)//从训练数据中取出第i行,i=0,1,2,。。。,(row-1)
{
return trainData[i];
} private void loadTrainDataFromFile(String fileName,int row, int column)
{
for(int i=0;i< row;i++)//trainData的第一列全部置为1.0(feature x0)
trainData[i][0]=1.0; File file = new File(fileName);
BufferedReader reader = null;
try {
reader = new BufferedReader(new FileReader(file));
String tempString = null;
int counter = 0;
while ( (counter<row) && (tempString = reader.readLine()) != null) {
String [] tempData = tempString.split(" ");
for(int i=0;i<column;i++)
trainData[counter][i+1]=Double.parseDouble(tempData[i]);
counter++;
}
reader.close();
} catch (IOException e) {
e.printStackTrace();
} finally {
if (reader != null) {
try {
reader.close();
} catch (IOException e1) {
}
}
}
} public void printTrainData()
{
System.out.println("Train Data:\n");
for(int i=0;i<column-1;i++)
System.out.printf("%10s","x"+i+" ");
System.out.printf("%10s","y"+" \n");
for(int i=0;i<row;i++)
{
for(int j=0;j<column;j++)
{
System.out.printf("%10s",trainData[i][j]+" ");
}
System.out.println();
}
System.out.println();
} public void printTheta()
{
for(double a:theta)
System.out.print(a+" ");
} }

TestLinearRegression.java

public class TestLinearRegression {

    public static void main(String[] args) {
// TODO Auto-generated method stub
LinearRegression m = new LinearRegression("trainData",0.001,1000000);
m.printTrainData();
m.trainTheta();
m.printTheta();
} }

trainData文件中是训练数据,默认最后一列是y,比如:

1.0       2.0       7.2 
             2.0       1.0       4.9 
             3.0       0.0       2.6 
             4.0       1.0       6.3 
             5.0      -1.0       1.0 
            6.0       0.0       4.7 
            7.0      -2.0      -0.6

前两列是“feature”,最后一列,也就是第三列是y

Email: wuzimian2006@163.com

QQ:    726590906

多元线性回归----Java简单实现的更多相关文章

  1. 多元线性回归(Multivariate Linear Regression)简单应用

    警告:本文为小白入门学习笔记 数据集: http://openclassroom.stanford.edu/MainFolder/DocumentPage.php?course=DeepLearnin ...

  2. day-12 python实现简单线性回归和多元线性回归算法

    1.问题引入  在统计学中,线性回归是利用称为线性回归方程的最小二乘函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析.这种函数是一个或多个称为回归系数的模型参数的线性组合.一个带有一个自变 ...

  3. 【TensorFlow篇】--Tensorflow框架初始,实现机器学习中多元线性回归

    一.前述 TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运行原理.Tensor(张量)意味着N维数组,Flow(流)意味着基于数据流图的计算,T ...

  4. coursera机器学习笔记-多元线性回归,normal equation

    #对coursera上Andrew Ng老师开的机器学习课程的笔记和心得: #注:此笔记是我自己认为本节课里比较重要.难理解或容易忘记的内容并做了些补充,并非是课堂详细笔记和要点: #标记为<补 ...

  5. 多元线性回归 ——模型、估计、检验与预测

    一.模型假设 传统多元线性回归模型 最重要的假设的原理为: 1. 自变量和因变量之间存在多元线性关系,因变量y能够被x1,x2-.x{k}完全地线性解释:2.不能被解释的部分则为纯粹的无法观测到的误差 ...

  6. 多元线性回归模型的特征压缩:岭回归和Lasso回归

    多元线性回归模型中,如果所有特征一起上,容易造成过拟合使测试数据误差方差过大:因此减少不必要的特征,简化模型是减小方差的一个重要步骤.除了直接对特征筛选,来也可以进行特征压缩,减少某些不重要的特征系数 ...

  7. 多元线性回归公式推导及R语言实现

    多元线性回归 多元线性回归模型 实际中有很多问题是一个因变量与多个自变量成线性相关,我们可以用一个多元线性回归方程来表示. 为了方便计算,我们将上式写成矩阵形式: Y = XW 假设自变量维度为N W ...

  8. 【R】多元线性回归

    R中的线性回归函数比较简单,就是lm(),比较复杂的是对线性模型的诊断和调整.这里结合Statistical Learning和杜克大学的Data Analysis and Statistical I ...

  9. 斯坦福机器学习视频笔记 Week2 多元线性回归 Linear Regression with Multiple Variables

    相比于week1中讨论的单变量的线性回归,多元线性回归更具有一般性,应用范围也更大,更贴近实际. Multiple Features 上面就是接上次的例子,将房价预测问题进行扩充,添加多个特征(fea ...

随机推荐

  1. Autofac 组件、服务、自动装配(2)

    一.组件 创建出来的对象需要从组件中来获取,组件的创建有如下4种(延续第一篇的Demo,仅仅变动所贴出的代码)方式: 1.类型创建RegisterType AutoFac能够通过反射检查一个类型,选择 ...

  2. Linq第二讲

    这一讲,来说说集合.因为linq主要用于对数据源进行查询,集合是最常见的数据源. 集合 形式: 数组,列表List<T> Arraylist等. 特点: 可通过索引或键访问.可进行fore ...

  3. sublime Text 常用操作

    原文出处:http://www.php100.com/html/it/focus/2014/1030/7666.html 1. 多光标操作:只要按下Cmd(Windows系统下Ctrl)键,再用鼠标选 ...

  4. Java良葛格 学习笔记

    学习一个新的事物时,如果遇到一些概念无法很快理解,这可能是因为要理解概念会需要其它概念先建立起来,所以先暂时放下这个疑问也是一个学习方法,称之为“存疑” ,在以后的学习过程中待必要的概念学会后,目前的 ...

  5. AIR使用文件对象操作文件和目录

    文件对象是啥?文件对象(File对象)是在文件系统中指向文件或目录的指针.由于安全原因,只在AIR中可用. 文件对象能做啥? 获取特定目录,包括用户目录.用户文档目录.该应用程序启动的目录和程序目录 ...

  6. pulseaudio的交叉编译

    在/etc/profile里导入 export PATH==$PATH:/home/jack/arm-linux-gcc/x-tools/arm-unknown-linux-gnueabi/bin 配 ...

  7. 报表Reporting S而vice是 错误的解决

    Reporting Services 错误 报表服务器无法打开与报表服务器数据库的连接.所有请求和处理都要求与数据库建立连接. (rsReportServerDatabaseUnavailable)获 ...

  8. ccw-ide

    有bug,会把working set弄乱,整理后能重启一次正常,再次重启又乱了.

  9. setTranslucent

    在ios7中 如果setTranslucent=yes 默认的   则状态栏及导航栏底部为透明的,界面上的组件应该从屏幕顶部开始显示,因为是半透明的,可以看到,所以为了不和状态栏及导航栏重叠,第一个组 ...

  10. pch文件的配置与路径修改