spark 线性回归算法(scala)
构建Maven项目,托管jar包
数据格式
//0.fp_nid,1.nsr_id,2.gf_id,2.hydm,3.djzclx_dm,4.kydjrq,5.xgrq,6.je,7.se,8.jshj,9.kpyf,10.kprq,11.zfbz,12.date_key,13.hwmc,14.ggxh,15.dw,16.sl,17.dj,18.je je1,19.se1,20.spbm,21.label
(fpid_10000201 115717 (2239 173 2011-07-12 00:00:00.0 2016-08-31 15:40:37.0 4123.08 700.92 4824.0 201704 2017-04-25 N) 201706 可视回油单向阀 HYS-1Φ1.5A 只 3.0 35.8974358974359 107.69 18.31 1090120040000000000) 0)
(fpid_10000324 253389 (7310 173 2016-01-04 00:00:00.0 2017-07-24 10:01:02.0 36609.76 6223.64 42833.4 201709 2017-09-08 N) 201711 电视机 三星743寸 台 1.0 2991.4529914529912 2991.45 508.55 1090522010000000000) 0)
(fpid_10000416 126378 (5175 173 1999-01-14 00:00:00.0 2016-05-27 14:50:49.0 25337.81 4307.39 29645.2 201612 2016-12-21 N) 201706 防水涂料 null 公斤 105.0 5.225885225885226 548.72 93.28 1070101060000000000) 0)
package Test.tett1 import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.regression.LinearRegressionModel
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.regression.LinearRegression object MLDemo3 { def main(args: Array[String]): Unit = {
val sess = SparkSession.builder().appName("ml").master("local[4]").getOrCreate();
val sc = sess.sparkContext;
val dataDir = "hdfs://weekend110:9000/user/hive/warehouse/nsr2_xfp"
//定义样例类(要分析数据的类属性)
case class FP(fp_nid:String,nsr_id:String,gf_id:String,hydm:String,djzclx_dm:String,kydjrq:String,xgrq:String,
je:String,se:String,jshj:String,kpyf:String,kprq:String,zfbz:String,
label:String) //变换()
//0.fp_nid,1.nsr_id,2.gf_id,2.hydm,3.djzclx_dm,4.kydjrq,5.xgrq,6.je,7.se,8.jshj,9.kpyf,10.kprq,11.zfbz,12.date_key,13.hwmc,14.ggxh,15.dw,16.sl,17.dj,18.je je1,19.se1,20.spbm,21.label
val fpDataRDD = sc.textFile(dataDir).map(_.split("\001")).map(f => FP(f(0).toString,
f(1).toString,f(2).toString,f(3).toString,f(4).toString,f(5).toString,f(6).toString,
f(7).toString, f(8).toString,f(9).toString,f(10).toString,f(11).toString,f(12).toString,
f(13).toString)) import sess.implicits._ def strToDouble(str: String): Double = {
val regex = """([0-9]+)""".r
val res = str match{
case regex(num) => num
case _ => "1"
}
val resDouble = res.toDouble
resDouble
} //转换RDD成DataFrame
//1.fp_nid 2.nsr_id 3.gf_id 4.zfbz 5.hydm 6.djzclx_dm 7.je 8.se 9.jshj 10.kpyf 11.date_key 12.sl 13.dj 14.je1 15.se1 16.spbm
val trainingDF = fpDataRDD.map(f => (f.label.replaceAll("[)]","").toDouble,
Vectors.dense(
if(f.zfbz.equals("N)")) 1 else 0,
f.hydm.replaceAll("[(]","").toDouble,
f.djzclx_dm.toDouble,
f.kpyf.toDouble,
strToDouble(f.je),
strToDouble(f.se),
strToDouble(f.jshj)
))).toDF("label", "features") //显式数据
trainingDF.show()
println("======================") //创建线性回归对象
val lr = new LinearRegression()
//设置最大迭代次数
lr.setMaxIter(50)
//通过线性回归拟合训练数据,生成模型
val model = lr.fit(trainingDF) //创建内存测试数据数据框
val testDF = sess.createDataFrame(Seq(
(0,Vectors.dense(3812,171,9401.71,1598.29,11000.0,201612,1)),
(0,Vectors.dense(4190,173,72200.0,12274.0,84474.0,201710,1)),
(0,Vectors.dense(7519,173,99999.99,3000.0,102999.99,201709,1)), (1,Vectors.dense(1951,173,19743.59,3356.41,23100.0,201612,1)),
(1,Vectors.dense(5219,173,41880.35,7119.65,49000.0,201705,1)),
(1,Vectors.dense(5189,173,1320.93,224.56,1545.49,201611,1)),
(1,Vectors.dense(1779,173,21911.4,3724.94,25636.34,201611,0))
)).toDF("label", "features") testDF.show() //创建临时视图
testDF.createOrReplaceTempView("test")
println("======================") //利用model对测试数据进行变化,得到新数据框,查询features", "label", "prediction方面值。
val tested = model.transform(trainingDF).select("features", "label", "prediction");
tested.show(); //将分析的数据导入数据库
import java.sql.DriverManager
tested.rdd.foreachPartition(
it =>{
var url = "jdbc:mysql://localhost:3306/data?useUnicode=true&characterEncoding=utf8"
val conn= DriverManager.getConnection(url,"root","123456")
val pstat = conn.prepareStatement ("INSERT INTO `test` (`label`, `pre`,`zfbz`,`hydm`, `djzclx_dm`, "
+"`kpyf`,`je`,`se`,`jshj`) "
+"VALUES (?,?,?,?,?,?,?,?,?)")
for (obj <-it){
pstat.setString(1,obj.get(1).toString())
pstat.setString(2,obj.get(2).toString())
pstat.setString(3,obj.get(0).toString().split(",")(0).replaceAll("[\\[]", ""))
pstat.setString(4,obj.get(0).toString().split(",")(1))
pstat.setString(5,obj.get(0).toString().split(",")(2))
pstat.setString(6,obj.get(0).toString().split(",")(3))
pstat.setString(7,obj.get(0).toString().split(",")(4))
pstat.setString(8,obj.get(0).toString().split(",")(5))
pstat.setString(9,obj.get(0).toString().split(",")(6) .replaceAll("[\\]]", ""))
pstat.addBatch
}
try{
pstat.executeBatch
}finally{
pstat.close
conn.close
}
}
)
}
}
maven的pom.xml配置文件
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>Test</groupId>
<artifactId>tett1</artifactId>
<version>0.0.1-SNAPSHOT</version>
<inceptionYear>2008</inceptionYear>
<properties>
<scala.version>2.7.0</scala.version>
</properties> <repositories>
<repository>
<id>scala-tools.org</id>
<name>Scala-Tools Maven2 Repository</name>
<url>http://scala-tools.org/repo-releases</url>
</repository>
</repositories> <pluginRepositories>
<pluginRepository>
<id>scala-tools.org</id>
<name>Scala-Tools Maven2 Repository</name>
<url>http://scala-tools.org/repo-releases</url>
</pluginRepository>
</pluginRepositories> <dependencies>
<!-- <dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>${scala.version}</version>
</dependency> -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.11</artifactId>
<version>2.1.0</version>
</dependency>
</dependencies> <build>
<sourceDirectory>src/main/scala</sourceDirectory>
<testSourceDirectory>src/test/scala</testSourceDirectory>
<pluginManagement>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<skip>true</skip>
</configuration>
</plugin> <plugin>
<groupId>org.scala-tools</groupId>
<artifactId>maven-scala-plugin</artifactId>
<executions>
<execution>
<goals>
<goal>compile</goal>
<goal>testCompile</goal>
</goals>
</execution>
</executions>
<configuration>
<scalaVersion>${scala.version}</scalaVersion>
<args>
<arg>-target:jvm-1.5</arg>
</args>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-eclipse-plugin</artifactId>
<configuration>
<downloadSources>true</downloadSources>
<buildcommands>
<buildcommand>ch.epfl.lamp.sdt.core.scalabuilder</buildcommand>
</buildcommands>
<additionalProjectnatures>
<projectnature>ch.epfl.lamp.sdt.core.scalanature</projectnature>
</additionalProjectnatures>
<classpathContainers>
<classpathContainer>org.eclipse.jdt.launching.JRE_CONTAINER</classpathContainer>
<classpathContainer>ch.epfl.lamp.sdt.launching.SCALA_CONTAINER</classpathContainer>
</classpathContainers>
</configuration>
</plugin>
</plugins>
</pluginManagement>
</build>
<reporting>
<plugins>
<plugin>
<groupId>org.scala-tools</groupId>
<artifactId>maven-scala-plugin</artifactId>
<configuration>
<scalaVersion>${scala.version}</scalaVersion>
</configuration>
</plugin>
</plugins>
</reporting>
</project>
spark 线性回归算法(scala)的更多相关文章
- Spark机器学习(1):线性回归算法
线性回归算法,是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法. 1. 梯度下降法 线性回归可以使用最小二乘法,但是速度比较慢,因此一般使用梯度下降法(Grad ...
- 在Spark上用Scala实验梯度下降算法
首先参考的是这篇文章:http://blog.csdn.net/sadfasdgaaaasdfa/article/details/45970185 但是其中的函数太老了.所以要改.另外出发点是我自己的 ...
- scikit-learn 线性回归算法库小结
scikit-learn对于线性回归提供了比较多的类库,这些类库都可以用来做线性回归分析,本文就对这些类库的使用做一个总结,重点讲述这些线性回归算法库的不同和各自的使用场景. 线性回归的目的是要得到输 ...
- 《BI那点儿事》Microsoft 线性回归算法
Microsoft 线性回归算法是 Microsoft 决策树算法的一种变体,有助于计算依赖变量和独立变量之间的线性关系,然后使用该关系进行预测.该关系采用的表示形式是最能代表数据序列的线的公式.例如 ...
- [机器学习Lesson 2]代价函数之线性回归算法
本章内容主要是介绍:单变量线性回归算法(Linear regression with one variable) 1. 线性回归算法(linear regression) 1.1 预测房屋价格 该问题 ...
- 通过机器学习的线性回归算法预测股票走势(用Python实现)
在本人的新书里,将通过股票案例讲述Python知识点,让大家在学习Python的同时还能掌握相关的股票知识,所谓一举两得.这里给出以线性回归算法预测股票的案例,以此讲述通过Python的sklearn ...
- 机器学习---用python实现最小二乘线性回归算法并用随机梯度下降法求解 (Machine Learning Least Squares Linear Regression Application SGD)
在<机器学习---线性回归(Machine Learning Linear Regression)>一文中,我们主要介绍了最小二乘线性回归算法以及简单地介绍了梯度下降法.现在,让我们来实践 ...
- 机器学习-线性回归算法(单变量)Linear Regression with One Variable
1 线性回归算法 http://www.cnblogs.com/wangxin37/p/8297988.html 回归一词指的是,我们根据之前的数据预测出一个准确的输出值,对于这个例子就是价格,回归= ...
- 梯度下降算法&线性回归算法
**机器学习的过程说白了就是让我们编写一个函数使得costfunction最小,并且此时的参数值就是最佳参数值. 定义 假设存在一个代价函数 fun:\(J\left(\theta_{0}, \the ...
随机推荐
- vs关于“当前不会命中断点 还没有为该文档加载任何符号”的解决方法
首先调式的时候确定在debug模式下, 解决方法:工具-选项-调试 -(启用“仅我的代码”)勾去掉.
- Python全栈开发记录_第十篇(反射及选课系统练习)
反射机制:反射就是通过字符串的形式,导入模块:通过字符串的形式,去模块中寻找指定函数,对其进行操作.也就是利用字符串的形式去对象(模块)中操作(查找or获取or删除or添加)成员,一种基于字符串的事件 ...
- 自定义Windows右击菜单调用Winform程序
U9_Git中ignore文件处理 背景 U9代码中有许多自动生成的文件,不需要上传Git必须BE Entity中的.target文件 .bak 文件 Enum.cs结尾的文件,还有许多 extand ...
- springboot学习一:快速搭建springboot项目
1.idea创建springboot工程 JDK选择1.8以上的版本 选择springboot的版本和添加配置项 新建一个HelloController,测试 访问 http://localhost: ...
- expdp/impdp数据泵分区表导入太慢了。添加不检查元数据参数提高效率:ACCESS_METHOD=DIRECT_PATH
分区表数据泵导入太慢,达不到客户的迁移要求导出语句如下:(10G单节点)userid='/ as sysdba'directory=milk_dirdumpfile=mon_%U.dmplogfile ...
- Linux on window初体验
参照来源: https://www.cnblogs.com/enet01/p/7458767.html 1:liunx on window 的配置不多说(百度网上很多)启动开发这模式,在应用和程序中勾 ...
- POI导入具有合并了单元格的Excel
POI进行单行单行地导入的数据在网上有许多的文章,但是要导入一个具有合并单元格的excel貌似比较难找.刚好最近完成了这样的一个需求,要求导入具有合并单元格的excel: /** * 读取excel数 ...
- 关于SQLserver2008索引超出了数据
由于公司只支持了2008.不支持2012的数据库.所以安装的2008.但在对表进行操作的时候出现如下异常: 这个问题是由于本地装的2008,但IT那边的测试机上面确装的2012.所以2008连接了20 ...
- 深度学习原理与框架- tf.nn.conv2d_transpose(反卷积操作) tf.nn.conv2d_transpose(进行反卷积操作) 对于stride的理解存在问题?
反卷积操作: 首先对需要进行维度扩张的feature_map 进行补零操作,然后使用3*3的卷积核,进行卷积操作,使得其维度进行扩张,图中可以看出,2*2的feature经过卷积变成了4*4. ...
- Spring Security 理解小记
JWT 框架图如下, 来自博客https://blog.csdn.net/shehun1/article/details/45394405 个人觉得还不错.. 在开发中Spring boot 启用 加 ...