数据集

iris.data

数据集概览

代码

  1. package org.apache.spark.examples.examplesforml
  2.  
  3. import org.apache.spark.ml.clustering.{KMeans, LDA}
  4. import org.apache.spark.SparkConf
  5. import org.apache.spark.ml.feature.VectorAssembler
  6. import org.apache.spark.sql.SparkSession
  7.  
  8. import scala.util.Random
  9.  
  10. object lLDA {
  11. def main(args: Array[String]): Unit = {
  12. val conf = new SparkConf().setMaster("local").setAppName("iris")
  13. val spark = SparkSession.builder().config(conf).getOrCreate()
  14.  
  15. val file = spark.read.format("csv").load("D:\\9-4LDA算法\\iris.data")
  16. file.show()
  17.  
  18. import spark.implicits._
  19. val random = new Random()
  20. val data = file.map(row => {
  21. val label = row.getString(4) match {
  22. case "Iris-setosa" => 0
  23. case "Iris-versicolor" => 1
  24. case "Iris-virginica" => 2
  25. }
  26.  
  27. (row.getString(0).toDouble,
  28. row.getString(1).toDouble,
  29. row.getString(2).toDouble,
  30. row.getString(3).toDouble,
  31. label,
  32. random.nextDouble())
  33. }).toDF("_c0", "_c1", "_c2", "_c3", "label", "rand").sort("rand")
  34. val assembler = new VectorAssembler()
  35. .setInputCols(Array("_c0", "_c1", "_c2", "_c3"))
  36. .setOutputCol("features")
  37.  
  38. val dataset = assembler.transform(data)
  39. val Array(train, test) = dataset.randomSplit(Array(0.8, 0.2))
  40. train.show()
  41. /*
  42. val kmeans = new KMeans().setFeaturesCol("features").setK(3).setMaxIter(20)
  43. val model = kmeans.fit(train)
  44. model.transform(train).show()
  45. */
  46. val lda = new LDA().setFeaturesCol("features").setK(3).setMaxIter(40)
  47. val model = lda.fit(train)
  48. val prediction = model.transform(train)
  49. //prediction.show()
  50. val ll = model.logLikelihood(train)
  51. val lp = model.logPerplexity(train)
  52. // Describe topics.
  53. val topics = model.describeTopics(3)
  54. prediction.select("label","topicDistribution").show(false)
  55. println("The topics described by their top-weighted terms:")
  56. topics.show(false)
  57. println(s"The lower bound on the log likelihood of the entire corpus: $ll")
  58. println(s"The upper bound on perplexity: $lp")
  59. }
  60. }

输出结果

掌握Spark机器学习库-09.6-LDA算法的更多相关文章

  1. 掌握Spark机器学习库-09.3-kmeans算法实现分类

     数据集 iris.data 数据集概览 代码 package org.apache.spark.examples.hust.hml.examplesforml import org.apache.s ...

  2. 掌握Spark机器学习库-07-线性回归算法概述

    1)简介 自变量,因变量,线性关系,相关系数,一元线性关系,多元线性关系(平面,超平面) 2)使用线性回归算法的前提 3)应用例子 沸点与气压 浮力与表面积

  3. 掌握Spark机器学习库(课程目录)

    第1章 初识机器学习 在本章中将带领大家概要了解什么是机器学习.机器学习在当前有哪些典型应用.机器学习的核心思想.常用的框架有哪些,该如何进行选型等相关问题. 1-1 导学 1-2 机器学习概述 1- ...

  4. UCI机器学习库和一些相关算法(转载)

    UCI机器学习库和一些相关算法 各种机器学习任务的顶级结果(论文)汇总 https://github.com//RedditSota/state-of-the-art-result-for-machi ...

  5. 掌握Spark机器学习库-07.14-保序回归算法实现房价预测

    数据集 house.csv 数据集概览 代码 package org.apache.spark.examples.examplesforml import org.apache.spark.ml.cl ...

  6. 掌握Spark机器学习库-08.2-朴素贝叶斯算法

    数据集 iris.data 数据集概览 代码 import org.apache.spark.SparkConf import org.apache.spark.ml.classification.{ ...

  7. 掌握Spark机器学习库-07-回归算法原理

    1)机器学习模型理解 统计学习,神经网络 2)预测结果的衡量 代价函数(cost function).损失函数(loss function) 3)线性回归是监督学习

  8. 掌握Spark机器学习库-07.6-线性回归实现房价预测

    数据集 house.csv 数据概览 代码 package org.apache.spark.examples.examplesforml import org.apache.spark.ml.fea ...

  9. Spark机器学习(11):协同过滤算法

    协同过滤(Collaborative Filtering,CF)算法是一种常用的推荐算法,它的思想就是找出相似的用户或产品,向用户推荐相似的物品,或者把物品推荐给相似的用户.怎样评价用户对商品的偏好? ...

随机推荐

  1. [Tue, 11 Aug 2015 ~ Mon, 17 Aug 2015] Deep Learning in arxiv

    Image Representations and New Domains inNeural Image Captioning we find that a state-of-theart neura ...

  2. JAVA程序员常用软件整理

    Java类软件:-------------------------------JDK7.0:http://pan.baidu.com/s/1jGFYvAYMyclipse8.5破解版:http://p ...

  3. ZOJ1610 Count the Colors —— 线段树 区间染色

    题目链接:https://vjudge.net/problem/ZOJ-1610 Painting some colored segments on a line, some previously p ...

  4. What's the difference between HEAD, working tree and index, in Git?

    What's the difference between HEAD, working tree and index, in Git?

  5. (转)JFreeChart教程

    JFreeChart教程 一.jFreeChart产生图形的流程 创建一个数据源(dataset)来包含将要在图形中显示的数据>>创建一个 JFreeChart 对象来代表要显示的图形&g ...

  6. 吃CPU的openmp 程序

    g++ -o eat -fopenmp eat.cpp #include "stdio.h" int main(int argc, char *argv[]) { #pragma ...

  7. 并不对劲的bzoj4827:loj2020:p3723:[AHOI/HNOI2017]礼物

    题目大意 有两个长度为\(n\)(\(n\leq5*10^4\))的数列\(x_1,x_2,...,x_n\)和\(y_1,y_2,...,y_n\),两个数列里的数都不超过\(m\)(\(m\leq ...

  8. mysqlnd cannot connect to MySQL 4.1+ using old authentication

    报这个错误主要是因为mysql使用了老的密码格式,而程序要求使用新的格式导致的,解决办法: SET old_passwords = 0; UPDATE mysql.user SET Password ...

  9. 在javascript中,我怎么得到下拉条顶端与当前可视的顶端高度的距离,不是和网页顶端的距离

    "滚动条顶端距离" + document.documentElement.scrollTop)

  10. HDU2604:Queuing(矩阵快速幂+递推)

    传送门 题意 长为len的字符串只由'f','m'构成,有2^len种情况,问在其中不包含'fmf','fff'的字符串有多少个,此处将队列换成字符串 分析 矩阵快速幂写的比较崩,手生了,多练! 用f ...