线性回归

  • 需求:从文件读取数据对,计算回归函数及系数
  • 实现1:commons.math的SimpleRegression,定义函数getData从文件读取数据返回SimpleRegression类

 1 import java.io.File;
2 import java.io.FileNotFoundException;
3 import java.util.Scanner;
4 import org.apache.commons.math3.stat.regression.SimpleRegression;
5
6 public class Example1 {
7 public static void main(String[] args) {
8 SimpleRegression sr = getData("data/Data1.dat");
9 double m = sr.getSlope();
10 double b = sr.getIntercept();
11 double r = sr.getR(); // correlation coefficient
12 double r2 = sr.getRSquare();
13 double sse = sr.getSumSquaredErrors();
14 double tss = sr.getTotalSumSquares();
15
16 System.out.printf("y = %.6fx + %.4f%n", m, b);
17 System.out.printf("r = %.6f%n", r);
18 System.out.printf("r2 = %.6f%n", r2);
19 System.out.printf("EV = %.5f%n", tss - sse);
20 System.out.printf("UV = %.4f%n", sse);
21 System.out.printf("TV = %.3f%n", tss);
22 }
23
24 public static SimpleRegression getData(String data) {
25 SimpleRegression sr = new SimpleRegression();
26 try {
27 Scanner fileScanner = new Scanner(new File(data));
28 fileScanner.nextLine(); // read past title line
29 int n = fileScanner.nextInt();
30 fileScanner.nextLine(); // read past line of labels
31 fileScanner.nextLine(); // read past line of labels
32 for (int i = 0; i < n; i++) {
33 String line = fileScanner.nextLine();
34 Scanner lineScanner = new Scanner(line).useDelimiter("\\t");
35 double x = lineScanner.nextDouble();
36 double y = lineScanner.nextDouble();
37 sr.addData(x, y);
38 }
39 } catch (FileNotFoundException e) {
40 System.err.println(e);
41 }
42 return sr;
43 }
44 }
  • 实现2:直接计算统计量

 1 import java.io.File;
2 import java.io.FileNotFoundException;
3 import java.util.Scanner;
4
/**
 * Fits a least-squares line to the (x, y) pairs in a data file by
 * accumulating the sums of x, x^2, y, y^2 and x*y directly, then prints the
 * fitted line, the correlation coefficient and the variation statistics.
 */
public class Example2 {
    // Running sums filled in by getData and read by main.
    private static double sX=0, sXX=0, sY=0, sYY=0, sXY=0;
    private static int n=0;

    public static void main(String[] args) {
        getData("data/Data1.dat");
        double m = (n*sXY - sX*sY)/(n*sXX - sX*sX);
        double b = sY/n - m*sX/n;
        double r2 = m*m*(n*sXX - sX*sX)/(n*sYY - sY*sY);
        // sqrt alone is never negative; r must carry the sign of the slope,
        // otherwise a decreasing relationship is reported as a positive r.
        double r = Math.copySign(Math.sqrt(r2), m);
        double tv = sYY - sY*sY/n;
        double mX = sX/n;  // mean value of x
        double ev = (sXX - 2*mX*sX + n*mX*mX)*m*m;  // m^2 * sum((x - mX)^2)
        double uv = tv - ev;

        System.out.printf("y = %.6fx + %.4f%n", m, b);
        System.out.printf("r = %.6f%n", r);
        System.out.printf("r2 = %.6f%n", r2);
        System.out.printf("EV = %.5f%n", ev);
        System.out.printf("UV = %.4f%n", uv);
        System.out.printf("TV = %.3f%n", tv);
    }

    /**
     * Reads tab-separated (x, y) pairs from a file and stores their count and
     * running sums in the static fields of this class.
     *
     * <p>The layout consumed here is: one title line, one line starting with
     * the number of pairs, one line of labels, then one pair per line. The
     * accumulators are reset first, so calling this method again does not add
     * to the totals of a previous call. If the file does not exist the error
     * is printed to {@code System.err} and the accumulators stay at zero.
     *
     * @param data path of the data file
     */
    public static void getData(String data) {
        n = 0;
        sX = sXX = sY = sYY = sXY = 0;
        // try-with-resources closes the underlying file on every exit path.
        try (Scanner fileScanner = new Scanner(new File(data))) {
            fileScanner.nextLine();  // read past title line
            n = fileScanner.nextInt();
            fileScanner.nextLine();  // consume the rest of the count line
            fileScanner.nextLine();  // read past line of labels
            for (int i = 0; i < n; i++) {
                String line = fileScanner.nextLine();
                try (Scanner lineScanner = new Scanner(line).useDelimiter("\\t")) {
                    double x = lineScanner.nextDouble();
                    double y = lineScanner.nextDouble();
                    sX += x;
                    sXX += x*x;
                    sY += y;
                    sYY += y*y;
                    sXY += x*y;
                }
            }
        } catch (FileNotFoundException e) {
            System.err.println(e);
        }
    }
}

y = 0.882279x + 18.8739
r = 0.935222
r2 = 0.874641
EV = 1423.35676
UV = 204.0042
TV = 1627.361

  • 实现3:对辅助类进行实例化,并绘图

Example3.java

 1 import java.io.File;
2 import javax.swing.JFrame;
3
4 public class Example3 {
5 public static void main(String[] args) {
6 Data data = new Data(new File("data/Data1.dat"));
7 JFrame frame = new JFrame(data.getTitle());
8 frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
9 RegressionPanel panel = new RegressionPanel(data);
10 frame.add(panel);
11 frame.pack();
12 frame.setSize(500, 422);
13 frame.setResizable(false);
14 frame.setLocationRelativeTo(null); // center frame on screen
15 frame.setVisible(true);
16 }
17 }

Data.java

  1 import java.io.File;
2 import java.io.FileNotFoundException;
3 import java.util.Scanner;
4
/**
 * A paired (x, y) data set read from a file, together with the summary
 * statistics and least-squares line computed while reading it.
 *
 * <p>The layout consumed by the constructor is: a title line, the number of
 * pairs, the two axis names, then the pairs as whitespace-separated numbers.
 */
public class Data {
    private String title, xName, yName;
    private int n;
    // Start out empty so the getters never return null if loading fails.
    private double[] x = new double[0], y = new double[0];
    private double sX, sXX, sY, sYY, sXY, minX, minY, maxX, maxY;
    private double meanX, meanY, slope, intercept, corrCoef;

    /**
     * Reads the data set from {@code inputFile} and computes its statistics.
     *
     * <p>If the file does not exist the error is printed to
     * {@code System.err} and the object is left with zero pairs.
     *
     * @param inputFile the file to read
     */
    public Data(File inputFile) {
        // try-with-resources closes the underlying file on every exit path.
        try (Scanner input = new Scanner(inputFile)) {
            title = input.nextLine();
            n = input.nextInt();
            xName = input.next();
            yName = input.next();
            input.nextLine();  // consume the rest of the labels line
            x = new double[n];
            y = new double[n];
            minX = minY = Double.POSITIVE_INFINITY;
            maxX = maxY = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < n; i++) {
                double xi = x[i] = input.nextDouble();
                double yi = y[i] = input.nextDouble();
                sX += xi;
                sXX += xi*xi;
                sY += yi;
                sYY += yi*yi;
                sXY += xi*yi;
                minX = Math.min(minX, xi);
                minY = Math.min(minY, yi);
                maxX = Math.max(maxX, xi);
                maxY = Math.max(maxY, yi);
            }
            meanX = sX/n;
            meanY = sY/n;
            slope = (n*sXY - sX*sY)/(n*sXX - sX*sX);
            intercept = meanY - slope*meanX;
            // Multiplying by the slope keeps the sign of the correlation.
            corrCoef = slope*Math.sqrt((n*sXX - sX*sX)/(n*sYY - sY*sY));
        } catch (FileNotFoundException e) {
            System.err.println(e);
        }
    }

    /** @return the first line of the file */
    public String getTitle() {
        return title;
    }

    /** @return the name read for the x values */
    public String getXName() {
        return xName;
    }

    /** @return the name read for the y values */
    public String getYName() {
        return yName;
    }

    /** @return the number of (x, y) pairs */
    public int getN() {
        return n;
    }

    /** @return a copy of the x values, so callers cannot alter this object */
    public double[] getX() {
        return x.clone();
    }

    /** @return a copy of the y values, so callers cannot alter this object */
    public double[] getY() {
        return y.clone();
    }

    /** @return the mean of the x values */
    public double getMeanX() {
        return meanX;
    }

    /** @return the mean of the y values */
    public double getMeanY() {
        return meanY;
    }

    /** @return the slope of the least-squares line */
    public double getSlope() {
        return slope;
    }

    /** @return the intercept of the least-squares line */
    public double getIntercept() {
        return intercept;
    }

    /** @return the correlation coefficient, signed like the slope */
    public double getCorrCoef() {
        return corrCoef;
    }

    /** @return a new n-by-2 array whose rows are the (x, y) pairs */
    public double[][] getTable() {
        double[][] table = new double[n][2];
        for (int i = 0; i < n; i++) {
            table[i][0] = x[i];
            table[i][1] = y[i];
        }
        return table;
    }

    /** @return the smallest x value */
    public double getMinX() {
        return minX;
    }

    /** @return the smallest y value */
    public double getMinY() {
        return minY;
    }

    /** @return the largest x value */
    public double getMaxX() {
        return maxX;
    }

    /** @return the largest y value */
    public double getMaxY() {
        return maxY;
    }
}

RegressionPanel.java

import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Graphics;
import java.awt.Graphics2D;
import javax.swing.JPanel;

/**
 * Panel that draws a data set as a scatter plot over a labelled grid,
 * together with its regression line.
 *
 * <p>Data coordinates are mapped to pixels by {@link #f(double)} (x to
 * column) and {@link #g(double)} (y to row, inverted so larger y is higher).
 */
public class RegressionPanel extends JPanel {
    private static final int WIDTH=500, HEIGHT=400, BUFFER=28, MARGIN=40;
    private final Data data;
    private double xMin, xMax, yMin, yMax, xRange, yRange, gWidth, gHeight;
    private double slope, intercept;

    /**
     * Creates a panel for the given data set and caches its bounds and
     * regression line.
     *
     * @param data the data set to plot
     */
    public RegressionPanel(Data data) {
        this.data = data;
        this.setSize(WIDTH, HEIGHT);
        this.xMin = data.getMinX();
        this.xMax = data.getMaxX();
        this.yMin = data.getMinY();
        this.yMax = data.getMaxY();
        this.slope = data.getSlope();
        this.intercept = data.getIntercept();
        this.xRange = xMax - xMin;
        this.yRange = yMax - yMin;
        this.gWidth = WIDTH - 2*MARGIN - BUFFER;
        this.gHeight = HEIGHT - 2*MARGIN - BUFFER;
        setBackground(Color.WHITE);
    }

    /** Paints the grid first, then the points, then the regression line. */
    @Override
    public void paintComponent(Graphics g) {
        super.paintComponent(g);
        Graphics2D g2 = (Graphics2D)g;
        g2.setStroke(new BasicStroke(1));
        drawGrid(g2);
        drawPoints(g2, data.getX(), data.getY());
        drawLine(g2);
    }

    /**
     * Draws light-gray grid lines with black labels along both axes. The
     * spacing on each axis is the largest power of ten not exceeding that
     * axis's data range, rounded to an integer and never less than 1.
     */
    private void drawGrid(Graphics2D g2) {
        g2.setStroke(new BasicStroke(1));
        double xGd = Math.pow(10, Math.floor(Math.log10(xRange)));
        // A step of 0 (data range below 1) would make the loop never end.
        int xd = Math.max(1, dToI(xGd));
        int x0 = dToI(xGd*Math.floor(xMin/xGd));
        int xn = dToI(xGd*Math.ceil(xMax/xGd));
        for (int xi = x0; xi <= xn; xi += xd) {
            g2.setColor(Color.LIGHT_GRAY);
            int p = f(xi);
            g2.drawLine(p, 0, p, HEIGHT-18);  // vertical lines
            g2.setColor(Color.BLACK);
            g2.drawString(""+xi, p-8, HEIGHT-4);
        }
        double yGd = Math.pow(10, Math.floor(Math.log10(yRange)));
        int yd = Math.max(1, dToI(yGd));
        // The y bounds must come from the y spacing and the y extremes.
        int y0 = dToI(yGd*Math.floor(yMin/yGd));
        int yn = dToI(yGd*Math.ceil(yMax/yGd));
        for (int yi = y0; yi <= yn; yi += yd) {
            g2.setColor(Color.LIGHT_GRAY);
            int q = g(yi);
            g2.drawLine(BUFFER, q, WIDTH, q);  // horizontal lines
            g2.setColor(Color.BLACK);
            g2.drawString((yi<100?" ":"")+yi, 2, q+5);
        }
    }

    /** Draws each (x[i], y[i]) pair as a 6-pixel black dot centered on it. */
    private void drawPoints(Graphics2D g2, double[] x, double[] y) {
        g2.setColor(Color.BLACK);
        for (int i = 0; i < x.length; i++) {
            int u = f(x[i]);
            int v = g(y[i]);
            g2.fillOval(u-3, v-3, 6, 6);  // coordinates are at NW corners
        }
    }

    /** Draws the regression line in blue from column BUFFER to column WIDTH. */
    private void drawLine(Graphics2D g2) {
        g2.setColor(Color.BLUE);
        g2.setStroke(new BasicStroke(2));
        int p0 = BUFFER;
        int q0 = g(yLine(fInv(p0)));
        int p1 = WIDTH;
        int q1 = g(yLine(fInv(p1)));
        g2.drawLine(p0, q0, p1, q1);
    }

    /** Returns the regression line's y value at data coordinate x. */
    private double yLine(double x) {
        return slope*x + intercept;
    }

    /** Rounds a double to the nearest int. */
    private int dToI(double x) {
        return (int)Math.round(x);
    }

    /** Maps a data x value to a pixel column. */
    private int f(double x) {
        return dToI((x - xMin)*gWidth/xRange) + BUFFER + MARGIN;
    }

    /** Maps a data y value to a pixel row; larger y gives a smaller row. */
    private int g(double y) {
        return dToI(gHeight - (y - yMin)*gHeight/yRange) + MARGIN;
    }

    /** Inverse of f: maps a pixel column back to a data x value. */
    private double fInv(int p) {
        return (p - BUFFER - MARGIN)*xRange/gWidth + xMin;
    }

    /** Inverse of g: maps a pixel row back to a data y value. */
    private double gInv(int q) {
        return yMin + (gHeight + MARGIN - q)*yRange/gHeight;
    }
}

多项式回归

  • 需求:已知刹车速度和距离的数据,求解二次多项式回归方程,并预测给定速度下的刹车距离
  • 实现:最小二乘法,解方程组,LU分解

 1 import org.apache.commons.math3.linear.*;
2
3 public class Example4 {
4 static double[] x = {20, 30, 40, 50, 60, 70};
5 static double[] y = {52, 87, 136, 203, 290, 394};
6 static int n = y.length; // 6
7
8 public static void main(String[] args) {
9 double[][] a = new double[3][3];
10 double[] w = new double[3];
11 deriveNormalEquations(a, w);
12 printNormalEquations(a, w);
13 double[] b = solveNormalEquations(a, w);
14 printResults(b);
15 }
16
17 public static void deriveNormalEquations(double[][] a, double[] w) {
18 for (int i = 0; i < n; i++) {
19 double xi = x[i];
20 double yi = y[i];
21 a[0][0] = n;
22 a[0][1] = a[1][0] += xi;
23 a[0][2] = a[1][1] = a[2][0] += xi*xi;
24 a[1][2] = a[2][1] += xi*xi*xi;
25 a[2][2] += xi*xi*xi*xi;
26 w[0] += yi;
27 w[1] += xi*yi;
28 w[2] += xi*xi*yi;
29 }
30 }
31
32 public static void printNormalEquations(double[][] a, double[] w) {
33 for (int i = 0; i < 3; i++) {
34 System.out.printf("%8.0fb0 + %6.0fb1 + %8.0fb2 = %7.0f%n",
35 a[i][0], a[i][1], a[i][2], w[i]);
36 }
37 }
38
39 /* Solves the matrix equation a*b = w for b[], representing a[]
40 as RealMatrix m and b[] as RealVector v:
41 */
42 private static double[] solveNormalEquations(double[][] a, double[] w) {
43 RealMatrix m = new Array2DRowRealMatrix(a, false);
44 LUDecomposition lud = new LUDecomposition(m);
45 DecompositionSolver solver = lud.getSolver();
46 RealVector v = new ArrayRealVector(w, false);
47 return solver.solve(v).toArray();
48 }
49
50 private static void printResults(double[] b) {
51 System.out.printf("f(t) = %.2f + %.3ft + %.5ft^2%n", b[0], b[1], b[2]);
52 System.out.printf("f(55) = %.1f%n", f(55, b));
53 }
54
55 private static double f(double t, double[] b) {
56 return b[0] + b[1]*t + b[2]*t*t;
57 }
58 }

6b0 + 270b1 + 13900b2 = 1162
270b0 + 13900b1 + 783000b2 = 64220
13900b0 + 783000b1 + 46750000b2 = 3798800
f(t) = 40.73 + -1.170t + 0.08875t^2
f(55) = 244.8

多元线性回归

  • 需求:变量y依赖于多个变量
  • 实现:直接求解或通过Apache Commons

Example5.java

 1 import org.apache.commons.math3.linear.*;
2
3 public class Example5 {
4 static double[] x = {10, 9, 12, 10, 9, 10, 8, 11};
5 static double[] y = {59, 57, 61, 52, 48, 55, 51, 62};
6 static double[] z = {71, 68, 76, 56, 57, 77, 55, 67};
7 static int n = z.length; // 8
8
9 public static void main(String[] args) {
10 double[][] a = new double[3][3];
11 double[] w = new double[3];
12 deriveNormalEquations(a, w);
13 printNormalEquations(a, w);
14 double[] b = solveNormalEquations(a, w);
15 printResults(b);
16 }
17
18 public static void deriveNormalEquations(double[][] a, double[] w) {
19 for (int i = 0; i < n; i++) {
20 double xi = x[i];
21 double yi = y[i];
22 double zi = z[i];
23 a[0][0] = n;
24 a[0][1] = a[1][0] += xi;
25 a[0][2] = a[2][0] += yi;
26 a[1][1] += xi*xi;
27 a[1][2] = a[2][1] += xi*yi;
28 a[2][2] += yi*yi;
29 w[0] += zi;
30 w[1] += xi*zi;
31 w[2] += yi*zi;
32 }
33 }
34
35 public static void printNormalEquations(double[][] a, double[] w) {
36 for (int i = 0; i < 3; i++) {
37 System.out.printf("%6.0fx0 + %4.0fx1 + %5.0fx2 = %5.0f%n",
38 a[i][0], a[i][1], a[i][2], w[i]);
39 }
40 }
41
42 private static double[] solveNormalEquations(double[][] a, double[] w) {
43 RealMatrix m = new Array2DRowRealMatrix(a, false);
44 LUDecomposition lud = new LUDecomposition(m);
45 DecompositionSolver solver = lud.getSolver();
46 RealVector v = new ArrayRealVector(w, false);
47 return solver.solve(v).toArray();
48 }
49
50 private static void printResults(double[] b) {
51 System.out.printf("f(s, t) = %.2f + %.2fs + %.2ft%n", b[0], b[1], b[2]);
52 System.out.printf("f(10, 59) = %.1f%n", f(10, 59, b));
53 System.out.printf("f(9, 57) = %.1f%n", f(9, 57, b));
54 System.out.printf("f(11, 64) = %.1f%n", f(11, 64, b));
55 }
56
57 private static double f(double s, double t, double[] b) {
58 return b[0] + b[1]*s + b[2]*t;
59 }
60 }

Example6.java

 1 import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
2
3 public class Example6 {
4 static double[][] x = { {10, 59}, {9, 57}, {12, 61}, {10, 52}, {9, 48},
5 {10, 55}, {8, 51}, {11, 62} };
6 static double[] y = {71, 68, 76, 56, 57, 77, 55, 67};
7
8 public static void main(String[] args) {
9 OLSMultipleLinearRegression mlr = new OLSMultipleLinearRegression();
10 mlr.newSampleData(y, x);
11 double[] b = mlr.estimateRegressionParameters();
12 printResults(b);
13 }
14
15 private static void printResults(double[] b) {
16 System.out.printf("f(s, t) = %.2f + %.2fs + %.2ft%n", b[0], b[1], b[2]);
17 System.out.printf("f(10, 59) = %.1f%n", f(10, 59, b));
18 System.out.printf("f(9, 57) = %.1f%n", f(9, 57, b));
19 System.out.printf("f(11, 64) = %.1f%n", f(11, 64, b));
20 }
21
22 private static double f(double s, double t, double[] b) {
23 return b[0] + b[1]*s + b[2]*t;
24 }
25 }

8x0 + 79x1 + 445x2 = 527
79x0 + 791x1 + 4427x2 = 5254
445x0 + 4427x1 + 24929x2 = 29543
f(s, t) = -5.75 + 1.55s + 1.01t
f(10, 59) = 69.5
f(9, 57) = 65.9
f(11, 64) = 76.1

[Java] 数据分析 -- 回归分析的更多相关文章

  1. [Java] 数据分析 -- 大数据

    单词计数 需求:输入小说文本,输出每个单词出现的次数 实现:分map.combine.reduce三个阶段实现 1 /* Data Analysis with Java 2 * John R. Hub ...

  2. [Java] 数据分析 -- NoSQL数据库

    MongoDB概念:与关系型数据库对应 database(数据库):数据库 collection(集合):表 document(文档):行 field(域):列/字段 注意事项 文档是一组键值(key ...

  3. [Java]数据分析--聚类

    距离度量 需求:计算两点间的欧几里得距离.曼哈顿距离.切比雪夫距离.堪培拉距离 实现:利用commons.math3库相应函数 1 import org.apache.commons.math3.ml ...

  4. [Java] 数据分析--分类

    ID3算法 思路:分类算法的输入为训练集,输出为对数据进行分类的函数.ID3算法为分类函数生成分类树 需求:对水果训练集的一个维度(是否甜)进行预测 实现:决策树,熵函数,ID3,weka库 J48类 ...

  5. [Java] 数据分析--统计

    二项分布 需求:5个四面体筛子,筛子三面绿色,一面红色,模拟1000000次,统计每次试验红色落地筛子个数的分布 实现:用循环实现5个筛子和1000000次试验,定义函数numRedDown模拟5个筛 ...

  6. [Java]数据分析--数据可视化

    时间序列 需求:将一组字符顺序添加到时间序列中 实现:定义时间序列类TimeSeries,包含静态类Entry表示序列类中的各项,以及add,get,iterator,entry方法 TimeSeri ...

  7. [Java] 数据分析--数据预处理

    数据结构 键-值对:HashMap 1 import java.io.File; 2 import java.io.FileNotFoundException; 3 import java.util. ...

  8. Spark案例分析

    一.需求:计算网页访问量前三名 import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} /* ...

  9. 一元线性回归分析及java实现

    http://blog.csdn.net/hwwn2009/article/details/38414911 一元线性回归分析及java实现 2014-08-07 11:02 1072人阅读 评论(0 ...

随机推荐

  1. Python数据分析入门(十):数据清洗和准备

    数据清洗是数据分析关键的一步,直接影响之后的处理工作 数据需要修改吗?有什么需要修改的吗?数据应该怎么调整才能适用于接下来的分析和挖掘? 是一个迭代的过程,实际项目中可能需要不止一次地执行这些清洗操作 ...

  2. 【面试技巧】老生常谈之 n 种使用 CSS 实现三角形的技巧

    在一些面经中,经常能看到有关 CSS 的题目都会有一道如何使用 CSS 绘制三角形,而常见的回答通常也只有使用 border 进行绘制一种方法. 而 CSS 发展到今天,其实有很多有意思的仅仅使用 C ...

  3. HTML5和CSS3 PC端静态网页琐碎知识点

    1.PC端为了兼容IE9以及IE9以下,尽量要使用float进行布局,兼容性好,一般不要用flex进行布局. 2.问起CSS选择器的分类,先说id选择器,类选择器,属性选择器,伪类选择器,伪元素选择器 ...

  4. python基础(五):列表的使用(上)

    什么是列表 列表是一系列元素,按特定顺序排列组成.列表总的元素之间没有任何关系,既可以时字符串,也可以是数字,还可以是布尔值. 由此可以看出,列表通常包含多个元素,因此再给列表命名的时候,最好使用复数 ...

  5. 轻松理解 Spring AOP

    目录 Spring AOP 简介 Spring AOP 的基本概念 面向切面编程 AOP 的目的 AOP 术语和流程 术语 流程 五大通知执行顺序 例子 图例 实际的代码 使用 Spring AOP ...

  6. BUAA_OO_第二单元

    BUAA_OO_2020_UNIT2 一.程序结构分析 第五次作业 UML & Mertrics ​ 电梯的调度问题,实质上就是任务的请求与分配问题,笔者在第五次作业中采用简单的"生 ...

  7. 解决JDK9以上的非法反射访问警告

    1 问题描述 JDK9以上很多库都有这种非法反射访问的警告,比如protostuff: 解决方法两个: JDK降级 添加JVM参数 2 原因 降到JDK8能解决以上问题. 但是这不是本文的重点. 先说 ...

  8. poj_1700 题解

    题目描述:在漆黑的夜里,四位旅行者来到了一座狭窄而且没有护栏的桥边. 如果不借助手电筒的话,大家是无论如何也不敢过桥去的. 不幸的是,四个人一共只带了一只手电筒,而桥窄得只够让两个人同时过. 如果各自 ...

  9. 869. Reordered Power of 2

    Starting with a positive integer N, we reorder the digits in any order (including the original order ...

  10. Python输入与输出

    输出 print函数 语法: print(self, *args, sep=' ', end='\n', file=None) print函数是python中最常见的一个函数.用于将内容打印输出. p ...