Spark 多项式逻辑回归__二分类
- package Spark_MLlib
- import org.apache.spark.ml.Pipeline
- import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
- import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
- import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
- import org.apache.spark.ml.linalg.Vectors
- import org.apache.spark.sql.SparkSession
/**
 * Binary classification with a multinomial-family logistic regression, built as a
 * Spark ML pipeline.
 *
 * Reads a comma-separated text file whose lines hold four numeric features followed by a
 * string label, drops every row labelled `soyo2` so that two classes remain, trains the
 * pipeline on a random 70% of the rows and reports accuracy, error rate and the fitted
 * model's parameters for the remaining 30%.
 */
object 多项式逻辑回归__二分类 {

  /**
   * One input row: the feature vector and its string class label.
   *
   * The field names become the DataFrame column names (`features`, `label`) that the SQL
   * filter and the indexers below refer to. The vector type is fully qualified because only
   * `Vectors` (the factory) is imported at the top of the file.
   */
  case class data_schema(features: org.apache.spark.ml.linalg.Vector, label: String)

  val spark = SparkSession.builder().master("local").getOrCreate()
  import spark.implicits._ // enables the implicit RDD -> DataFrame conversion (.toDF())

  /** Input file used when no path is passed on the command line. */
  private val DefaultInputPath = "file:///home/soyo/桌面/spark编程测试数据/soyo.txt"

  /**
   * Upper bound on optimizer iterations.
   * NOTE(review): the argument was missing in the listing; 10 is the value used in the
   * Spark ML logistic regression examples - confirm against the intended setting.
   */
  private val MaxIterations = 10

  /**
   * Entry point.
   *
   * @param args optional; `args(0)` overrides the input file path, otherwise the default
   *             local file is read
   */
  def main(args: Array[String]): Unit = {
    val inputPath = args.headOption.getOrElse(DefaultInputPath)
    // Always release the session, even when training or evaluation fails.
    try run(inputPath)
    finally spark.stop()
  }

  /** Loads the data, fits the pipeline, and prints predictions, metrics and model parameters. */
  private def run(inputPath: String): Unit = {
    val df = spark.sparkContext.textFile(inputPath)
      .filter(_.trim.nonEmpty) // a blank (e.g. trailing) line would otherwise fail in toDouble
      .map(_.split(","))
      .map { fields =>
        // Columns 0-3 are the numeric features, column 4 is the class label.
        data_schema(
          Vectors.dense(fields(0).toDouble, fields(1).toDouble, fields(2).toDouble, fields(3).toDouble),
          fields(4))
      }
      .toDF()
    df.show()

    df.createOrReplaceTempView("data_schema")
    // The label literal must be single-quoted inside the SQL text, otherwise the query fails.
    val twoClassData = spark.sql("select * from data_schema where label !='soyo2'")
    twoClassData.show()

    // Both indexers are fitted on the full two-class data set so that every label and
    // feature value is known to them, whichever side of the random split it lands on.
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(twoClassData)
    // Builds category indices for the components of the feature vector.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .fit(twoClassData)

    val Array(trainData, testData) = twoClassData.randomSplit(Array(0.7, 0.3))

    // Elastic-net mixing parameter 0.8; family "multinomial" requests multinomial
    // logistic regression.
    val lr = new LogisticRegression()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")
      .setMaxIter(MaxIterations)
      .setRegParam(0.3)
      .setElasticNetParam(0.8)
      .setFamily("multinomial")

    // Maps the numeric prediction back to the original string label.
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictionLabel")
      .setLabels(labelIndexer.labels)

    val lrPipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, lr, labelConverter))
    val lrPipelineModel = lrPipeline.fit(trainData)
    val lrPrediction = lrPipelineModel.transform(testData)
    lrPrediction.show(false)

    // Model evaluation. The metric is set explicitly: the evaluator's default is F1, while
    // the figure printed below is labelled as accuracy.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val lrAccuracy = evaluator.evaluate(lrPrediction)
    println(s"准确率为: $lrAccuracy")
    val lrError = 1 - lrAccuracy
    println(s"错误率为: $lrError")

    // Look the fitted model up by type instead of by position plus an unchecked cast.
    val lrModel = lrPipelineModel.stages
      .collectFirst { case model: LogisticRegressionModel => model }
      .getOrElse(throw new IllegalStateException("Fitted pipeline contains no LogisticRegressionModel stage"))
    println(s"二项逻辑回归模型系数矩阵: ${lrModel.coefficientMatrix}")
    println(s"二项逻辑回归模型的截距向量: ${lrModel.interceptVector}")
    println(s"类的数量(标签可以使用的值): ${lrModel.numClasses}")
    println(s"模型所接受的特征的数量: ${lrModel.numFeatures}")
  }
}
结果:
+-----------------+-----+
| features|label|
+-----------------+-----+
|[5.1,3.5,1.4,0.2]|soyo1|
|[4.9,3.0,1.4,0.2]|soyo1|
|[4.7,3.2,1.3,0.2]|soyo1|
|[4.6,3.1,1.5,0.2]|soyo1|
|[5.0,3.6,1.4,0.2]|soyo1|
|[5.4,3.9,1.7,0.4]|soyo1|
|[4.6,3.4,1.4,0.3]|soyo1|
|[5.0,3.4,1.5,0.2]|soyo1|
|[4.4,2.9,1.4,0.2]|soyo1|
|[4.9,3.1,1.5,0.1]|soyo1|
|[5.4,3.7,1.5,0.2]|soyo1|
|[4.8,3.4,1.6,0.2]|soyo1|
|[4.8,3.0,1.4,0.1]|soyo1|
|[4.3,3.0,1.1,0.1]|soyo1|
|[5.8,4.0,1.2,0.2]|soyo1|
|[5.7,4.4,1.5,0.4]|soyo1|
|[5.4,3.9,1.3,0.4]|soyo1|
|[5.1,3.5,1.4,0.3]|soyo1|
|[5.7,3.8,1.7,0.3]|soyo1|
|[5.1,3.8,1.5,0.3]|soyo1|
+-----------------+-----+
only showing top 20 rows
+-----------------+-----+------------+------------------+------------------------------------------+----------------------------------------+----------+---------------+
|features |label|indexedLabel|indexedFeatures |rawPrediction |probability |prediction|predictionLabel|
+-----------------+-----+------------+------------------+------------------------------------------+----------------------------------------+----------+---------------+
|[4.6,3.1,1.5,0.2]|soyo1|0.0 |[4.6,3.1,1.5,1.0] |[0.3841092104753886,-0.384109210475388] |[0.6831353764654857,0.3168646235345142] |0.0 |soyo1 |
|[4.6,3.2,1.4,0.2]|soyo1|0.0 |[4.6,3.2,1.4,1.0] |[0.4118074545189242,-0.41180745451892353] |[0.6950031457169539,0.3049968542830461] |0.0 |soyo1 |
|[4.6,3.4,1.4,0.3]|soyo1|0.0 |[4.6,3.4,1.4,2.0] |[0.41345332780578103,-0.41345332780578037]|[0.6957004614212158,0.30429953857878417]|0.0 |soyo1 |
|[4.7,3.2,1.6,0.2]|soyo1|0.0 |[4.7,3.2,1.6,1.0] |[0.39085103161962165,-0.390851031619621] |[0.6860468315498303,0.31395316845016974]|0.0 |soyo1 |
|[4.9,3.0,1.4,0.2]|soyo1|0.0 |[4.9,3.0,1.4,1.0] |[0.37736738933115554,-0.377367389331155] |[0.6802095073085258,0.3197904926914742] |0.0 |soyo1 |
|[4.9,3.1,1.5,0.1]|soyo1|0.0 |[4.9,3.1,1.5,0.0] |[0.4169034023763003,-0.4169034023762997] |[0.697159256477463,0.302840743522537] |0.0 |soyo1 |
|[5.0,3.0,1.6,0.2]|soyo1|0.0 |[5.0,3.0,1.6,1.0] |[0.356410966431853,-0.35641096643185244] |[0.6710244037082002,0.32897559629179984]|0.0 |soyo1 |
|[5.0,3.4,1.5,0.2]|soyo1|0.0 |[5.0,3.4,1.5,1.0] |[0.4357693082570414,-0.4357693082570408] |[0.705065751202206,0.2949342487977939] |0.0 |soyo1 |
|[5.0,3.4,1.6,0.4]|soyo1|0.0 |[5.0,3.4,1.6,3.0] |[0.35970271300556683,-0.35970271300556617]|[0.6724760743873281,0.3275239256126718] |0.0 |soyo1 |
|[5.1,3.4,1.5,0.2]|soyo1|0.0 |[5.1,3.4,1.5,1.0] |[0.4357693082570414,-0.4357693082570408] |[0.705065751202206,0.2949342487977939] |0.0 |soyo1 |
|[5.4,3.4,1.7,0.2]|soyo1|0.0 |[5.4,3.4,1.7,1.0] |[0.4148128853577389,-0.41481288535773825] |[0.6962757951954652,0.3037242048045349] |0.0 |soyo1 |
|[5.6,2.8,4.9,2.0]|soyo3|1.0 |[5.6,2.8,4.9,12.0]|[-0.3845461875044362,0.38454618750443703] |[0.3166754764713344,0.6833245235286656] |1.0 |soyo3 |
|[5.7,3.8,1.7,0.3]|soyo1|0.0 |[5.7,3.8,1.7,2.0] |[0.45089882383236457,-0.4508988238323638] |[0.7113187796385543,0.2886812203614457] |0.0 |soyo1 |
|[5.7,4.4,1.5,0.4]|soyo1|0.0 |[5.7,4.4,1.5,3.0] |[0.5423812503940613,-0.5423812503940606] |[0.7473941839256351,0.25260581607436505]|0.0 |soyo1 |
|[5.8,2.8,5.1,2.4]|soyo3|1.0 |[5.8,2.8,5.1,16.0]|[-0.5366793780073855,0.5366793780073863] |[0.2547648665744027,0.7452351334255972] |1.0 |soyo3 |
|[6.0,2.2,5.0,1.5]|soyo3|1.0 |[6.0,2.2,5.0,7.0] |[-0.3343736350128348,0.33437363501283546] |[0.3387774047228901,0.6612225952771099] |1.0 |soyo3 |
|[6.2,2.8,4.8,1.8]|soyo3|1.0 |[6.2,2.8,4.8,10.0]|[-0.3084795922529615,0.30847959225296234] |[0.3504733529544735,0.6495266470455265] |1.0 |soyo3 |
|[6.3,2.9,5.6,1.8]|soyo3|1.0 |[6.3,2.9,5.6,10.0]|[-0.3750852512562874,0.3750852512562882] |[0.3207841503157466,0.6792158496842534] |1.0 |soyo3 |
|[6.3,3.3,6.0,2.5]|soyo3|1.0 |[6.3,3.3,6.0,17.0]|[-0.5776773099857371,0.577677309985738] |[0.23951239936093965,0.7604876006390604]|1.0 |soyo3 |
|[6.3,3.4,5.6,2.4]|soyo3|1.0 |[6.3,3.4,5.6,16.0]|[-0.485750239692336,0.4857502396923369] |[0.2745815258875292,0.7254184741124707] |1.0 |soyo3 |
+-----------------+-----+------------+------------------+------------------------------------------+----------------------------------------+----------+---------------+
only showing top 20 rows
准确率为: 1.0
错误率为: 0.0
二项逻辑回归模型系数矩阵: 0.0 0.17220032593884316 -0.1047821144965127 -0.03279419190091169
0.0 -0.172200325938843 0.10478211449651276 0.03279419190091169
二项逻辑回归模型的截距向量: [0.04025556371065551,-0.04025556371065551]
类的数量(标签可以使用的值): 2
模型所接受的特征的数量: 4
Spark 多项式逻辑回归__二分类的更多相关文章
- Spark 多项式逻辑回归__多分类
package Spark_MLlib import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.{B ...
- Spark 二项逻辑回归__二分类
package Spark_MLlib import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.{B ...
- scikit-learn机器学习(二)逻辑回归进行二分类(垃圾邮件分类),二分类性能指标,画ROC曲线,计算acc,recall,precision,f1
数据来自UCI机器学习仓库中的垃圾信息数据集 数据可从http://archive.ics.uci.edu/ml/datasets/sms+spam+collection下载 转成csv载入数据 im ...
- 机器学习---逻辑回归(二)(Machine Learning Logistic Regression II)
在《机器学习---逻辑回归(一)(Machine Learning Logistic Regression I)》一文中,我们讨论了如何用逻辑回归解决二分类问题以及逻辑回归算法的本质.现在 ...
- stanford coursera 机器学习编程作业 exercise 3(逻辑回归实现多分类问题)
本作业使用逻辑回归(logistic regression)和神经网络(neural networks)识别手写的阿拉伯数字(0-9) 关于逻辑回归的一个编程练习,可参考:http://www.cnb ...
- Logistic Regression(逻辑回归)(二)—深入理解
(整理自AndrewNG的课件,转载请注明.整理者:华科小涛@http://www.cnblogs.com/hust-ghtao/) 上一篇讲解了Logistic Regression的基础知识,感觉 ...
- 【原】Spark之机器学习(Python版)(二)——分类
写这个系列是因为最近公司在搞技术分享,学习Spark,我的任务是讲PySpark的应用,因为我主要用Python,结合Spark,就讲PySpark了.然而我在学习的过程中发现,PySpark很鸡肋( ...
- Spark Mllib逻辑回归算法分析
原创文章,转载请注明: 转载自http://www.cnblogs.com/tovin/p/3816289.html 本文以spark 1.0.0版本MLlib算法为准进行分析 一.代码结构 逻辑回归 ...
- Spark LogisticRegression 逻辑回归之建模
导入包 import org.apache.spark.sql.SparkSession import org.apache.spark.sql.Dataset import org.apache.s ...
随机推荐
- IDEA的Maven Projects无法显示
记一个小坑: 前两天重装了一下电脑系统,下载了个最新的IDEA2018.3.5,把Maven.JDK.TomCat都设置好了 今天打开IDEA创建一个新的Maven项目,项目没有显示让我导入Maven ...
- 【转】SQLServer连接字符串配置:MultipleActiveResultSets
ADO.NET 1.x 利用SqlDataReader读取数据,针对每个结果集需要一个独立的连接.当然,你还必须管理这些连接并且要付出相应的内存和潜在的应用程序中的高度拥挤的瓶颈代价-特别是在数据集中 ...
- c语音 dll断点调试方法
转自:https://blog.csdn.net/qingzai_/article/details/45348613 dll调试方法: 1.把最新生成的dll和pdb放到 启动这个dll 的进程目录下 ...
- * SPOJ PGCD Primes in GCD Table (需要自己推线性筛函数,好题)
题目大意: 给定n,m,求有多少组(a,b) 0<a<=n , 0<b<=m , 使得gcd(a,b)= p , p是一个素数 这里本来利用枚举一个个素数,然后利用莫比乌斯反演 ...
- HDU 2147 找规律博弈
题目大意: 从右上角出发一直到左下角,每次左移,下移或者左下移,到达左下角的人获胜 到达左下角为必胜态,那么到达它的所有点都为必败态,每个点的局势都跟左,下,左下三个点有关 开始写了一个把所有情况都计 ...
- Django用法补充
1. 自定义Admin from django.contrib import admin from xx import models # 自定义操作 class CustomerAdmin(admin ...
- ie下php session不能用(域名的合法定义)
今天遇到了一个奇怪的问题.应用程序的后台ie下居然无法登陆,老是提示验证码不正确,明明输入是正确的.于是抓包.测试.调试,最终发现罪魁祸首phpsessionid在ie下没有办法写入.研究了一下,发现 ...
- Linux下汇编语言学习笔记0 --- 前期准备工作
这是17年暑假学习Linux汇编语言的笔记记录,参考书目为清华大学出版社 Jeff Duntemann著 梁晓辉译<汇编语言基于Linux环境>的书,喜欢看原版书的同学可以看<Ass ...
- sql将日期按照年月分组并统计数量
SELECT DATE_FORMAT(releaseDate,"%Y年%m月") AS dates,COUNT(*) FROM t_diary GROUP BY DATE_FORM ...
- 搬砖--杭电校赛(dfs)
搬砖 Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 65535/65535 K (Java/Others)Total Submissi ...