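Experimenting with gradient descent in Scala on Spark
I started from this article: http://blog.csdn.net/sadfasdgaaaasdfa/article/details/45970185
The functions it uses are too old, though, so the code had to be changed. The other starting point was the gradient descent diagram in my own article: http://www.cnblogs.com/charlesblc/p/6206198.html
After many rounds of changes, with a lot of time spent on random vector initialization, I finally got it working. The code is as follows: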
package com.spark.my

import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import breeze.linalg.DenseVector
import breeze.numerics.exp

/**
 * Created by baidu on 16/11/28.
 */
object GradientDemo {
  case class DataPoint(x: DenseVector[Double], y: Double) // on case class, see below

  def parsePoint(x: Array[Double]): DataPoint = {
    // DataPoint(Vectors.dense(x.slice(0, x.size-2)), x(x.size-1))
    // The last column is the label. slice's end index is exclusive, so
    // slice(0, x.size - 2) keeps columns 0 .. size-3 and therefore also drops
    // the column right before the label; slice(0, x.size - 1) would keep it.
    DataPoint(DenseVector(x.slice(0, x.size - 2)), x(x.size - 1))
  }

  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    val conf = new SparkConf()
    val sc = new SparkContext(conf)

    println("Begin load gradient file")
    // Load the dataset: one example per line, space-separated numbers
    val text = sc.textFile("hdfs://master.Hadoop:8390/gradient_data/spam.data.txt")
    val lines = text.map { line =>
      line.split(" ").map(_.toDouble)
    }
    val points = lines.map(parsePoint(_)) // lines.map(parsePoint) appears to be equivalent
    // Random initial weights, same length as the feature vector built in parsePoint
    var w = DenseVector.rand(lines.first().size - 2)

    val iterations = 100
    for (i <- 1 to iterations) {
      // Per-point gradient, summed over the whole dataset
      val gradient = points.map(p =>
        (1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x)
        .reduce(_ + _)
      w -= gradient
    }

    println("Finish data loading, w num: " + w.length + "; w: " + w)
  }
}
Then, on machine m42n05, I first copied the file http://www-stat.stanford.edu/~tibs/ElemStatLearn/datasets/spam.data to Hadoop:
$ hadoop fs -mkdir /gradient_data
$ hadoop fs -put spam.data.txt /gradient_data/
$ hadoop fs -ls /gradient_data/
Found 1 items
-rw-r--r-- 3 work supergroup 698341 2016-12-21 17:59 /gradient_data/spam.data.txt
Then I copied the jar over as well and ran:
$ ./bin/spark-submit --class com.spark.my.GradientDemo --master spark://10.117.146.12:7077 myjars/scala-demo.jar
Output:
16/12/21 18:17:57 WARN util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
16/12/21 18:17:58 INFO util.log: Logging initialized @1689ms
16/12/21 18:17:58 INFO server.Server: jetty-9.2.z-SNAPSHOT
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@107ed6fc{/jobs,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@1643d68f{/jobs/json,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@186978a6{/jobs/job,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@2e029d61{/jobs/job/json,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@482d776b{/stages,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@4052274f{/stages/json,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@132ddbab{/stages/stage,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@297ea53a{/stages/stage/json,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@acb0951{/stages/pool,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@5bf22f18{/stages/pool/json,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@267f474e{/storage,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@7a7471ce{/storage/json,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@28276e50{/storage/rdd,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@62e70ea3{/storage/rdd/json,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@3efe7086{/environment,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@675d8c96{/environment/json,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@741b3bc3{/executors,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@2ed3b1f5{/executors/json,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@63648ee9{/executors/threadDump,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@68d6972f{/executors/threadDump/json,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@45be7cd5{/static,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@7651218e{/,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@3185fa6b{/api,null,AVAILABLE}
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@6d366c9b{/stages/stage/kill,null,AVAILABLE}
16/12/21 18:17:58 INFO server.ServerConnector: Started ServerConnector@53e211ee{HTTP/1.1}{0.0.0.0:4040}
16/12/21 18:17:58 INFO server.Server: Started @1811ms
16/12/21 18:17:58 INFO handler.ContextHandler: Started o.s.j.s.ServletContextHandler@6e0d4a8{/metrics/json,null,AVAILABLE}
Begin load gradient file
16/12/21 18:18:00 INFO mapred.FileInputFormat: Total input paths to process : 1
16/12/21 18:18:02 WARN netlib.BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
16/12/21 18:18:02 WARN netlib.BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
Finish data loading, w num: 56; w: DenseVector(0.5742670447735152, 0.3793477463119241, 0.9681722093411653, 0.5967720119758925, 1.513648869152009, 0.8246263930800145, 0.8513296345703405, 0.5016541916805365, 0.10371045067354999, 1.0622529560536655, 0.7333760424194737, 2.1149483032187897, 0.9299367625800867, 0.7255747859512406, 0.13008556580706143, 1.4831202765138185, 0.7729907277492736, 0.9723309264036033, 13.394753146641808, 0.5531526429090097, 2.7444722115693665, 0.11325813324181622, 0.5096129116641023, 0.7201439311127137, 0.44719912156747926, 0.8273500952621051, 0.6736417633922696, 0.046531684571481415, 0.017895929000231802, 0.4726397794671698, 0.394438566392741, 0.8438784726078483, 0.4144073806784945, 0.18873920886297268, 0.4760240368798872, 0.31604719205329873, 0.694745503752298, 0.721380820951884, 0.988535475648986, 0.13515871744899247, 0.15694652560543523, 0.6939378895510522, 0.9279201378471407, 0.3336083293555714, 0.38938263676999685, 0.17159756568171308, 0.18897754115255144, 0.7281027812135723, 0.7233165381530381, 1.1093715737790655, 0.15675561193336351, 2.059622965151493, 0.6839713282339183, 0.11528695729374866, 7.413534050555067, 23.13404922028611)
16/12/21 18:18:07 INFO server.ServerConnector: Stopped ServerConnector@53e211ee{HTTP/1.1}{0.0.0.0:4040}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@6d366c9b{/stages/stage/kill,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@3185fa6b{/api,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@7651218e{/,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@45be7cd5{/static,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@68d6972f{/executors/threadDump/json,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@63648ee9{/executors/threadDump,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@2ed3b1f5{/executors/json,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@741b3bc3{/executors,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@675d8c96{/environment/json,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@3efe7086{/environment,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@62e70ea3{/storage/rdd/json,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@28276e50{/storage/rdd,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@7a7471ce{/storage/json,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@267f474e{/storage,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@5bf22f18{/stages/pool/json,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@acb0951{/stages/pool,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@297ea53a{/stages/stage/json,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@132ddbab{/stages/stage,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@4052274f{/stages/json,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@482d776b{/stages,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@2e029d61{/jobs/job/json,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@186978a6{/jobs/job,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@1643d68f{/jobs/json,null,UNAVAILABLE}
16/12/21 18:18:07 INFO handler.ContextHandler: Stopped o.s.j.s.ServletContextHandler@107ed6fc{/jobs,null,UNAVAILABLE}
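The job ran to completion and printed the final w, so the data was processed normally.
To watch the process, add one more line inside the iteration loop: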
println("In data loading, w num: " + w.length + "; w: " + w)
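Then I copied the rebuilt jar over and ran it again. There is now a lot of intermediate output, but w changes little from one iteration to the next; for some components only the last few digits change: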
In data loading, w num: 56; w: DenseVector(0.8387794911469437, 0.041931950643148204, 0.610593576873822, 0.775693127624059, 0.9595814255406686, 0.8346753461732199, 1.3049939469403333, 0.7056665962054256, 0.4607139317388798, 0.7272237992038442, 0.658182563650663, 0.733627042229442, 0.49543528179048996, 0.43928474305383947, 0.7784540121519834, 3.3618947233533456, 0.8863247999385253, 0.4007587753541083, 2.0631977325748334, 0.8211289850510815, 1.2076387347473903, 0.43209585536401196, 0.8361371667999544, 0.3902040623717107, 0.9249800607229486, 0.9684655358995048, 0.7122113545634148, 0.7564214721597596, 0.9295754044438086, 0.0667831407627083, 0.8262226990678785, 0.9866253536733688, 0.7214690647928418, 0.5992067836236182, 0.801215365214358, 1.0206941788488395, 0.8887684894893382, 0.39696145592511084, 0.7994301499483707, 0.39766237687949973, 0.3213782652296576, 0.3959330364022269, 0.6573698429264838, 0.5725594506918451, 0.932872703406284, 0.4276515117478306, 0.8908902872993782, 0.6281143587881469, 0.5136752276267151, 1.0933173640821512, 0.10820509511118362, 1.9426418431339785, 0.2017114624971559, 0.9827542778431644, 5.224634203803431, 16.694903977208174)
In data loading, w num: 56; w: DenseVector(0.8387794911469437, 0.041931950643148204, 0.6105935768739001, 0.775693127624059, 0.9595814255414439, 0.8346753461732199, 1.3049939469403333, 0.7056665962054256, 0.4607139317388798, 0.7272237992038442, 0.658182563650663, 0.733627042229442, 0.49543528179048996, 0.43928474305383947, 0.7784540121519834, 3.3618947233534118, 0.8863247999385373, 0.4007587753541083, 2.0631977325749897, 0.8211289850510815, 1.2076387347474142, 0.43209585536401196, 0.8361371667999544, 0.3902040623717107, 0.9249800607229486, 0.9684655358995048, 0.7122113545634148, 0.7564214721597596, 0.9295754044438086, 0.0667831407627083, 0.8262226990678785, 0.9866253536733688, 0.7214690647928418, 0.5992067836236182, 0.801215365214358, 1.0206941788488395, 0.8887684894893382, 0.39696145592511084, 0.7994301499483707, 0.3976623768795117, 0.3213782652296576, 0.3959330364022269, 0.6573698429264838, 0.5725594506918451, 0.932872703406296, 0.4276515117478306, 0.8908902872993782, 0.6281143587881469, 0.5136752276267151, 1.093317364082217, 0.10820509511118362, 1.942641843152015, 0.2017114624971559, 0.982754277843168, 5.22463420411604, 16.694903977520784)
How gradient descent works
A good explanation of how gradient descent works can be found here:
http://blog.csdn.net/woxincd/article/details/7040944
And in this article:
http://www.cnblogs.com/maybe2030/p/5089753.html?utm_source=tuicool&utm_medium=referral
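Looking closely, the formulas in those articles do not seem to match the one in the code. The code is probably using the sigmoid function.
I still need to think this through properly.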
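One way to reconcile the two, as a sketch rather than a statement of what the referenced articles derive: if the labels are encoded as y in {-1, +1} (an assumption; the post does not state the encoding), the expression in the code is the gradient of the logistic loss, and the sigmoid shows up once the loss is differentiated. The symbols \ell, \sigma and the index i are introduced here for illustration.

```latex
\ell(w) = \sum_i \log\bigl(1 + e^{-y_i\, w^\top x_i}\bigr),
\qquad
\sigma(z) = \frac{1}{1 + e^{-z}}

\nabla_w \ell(w)
  = \sum_i \frac{-e^{-y_i\, w^\top x_i}}{1 + e^{-y_i\, w^\top x_i}}\; y_i\, x_i
  = \sum_i \bigl(\sigma(y_i\, w^\top x_i) - 1\bigr)\, y_i\, x_i
```

Each summand matches the per-point expression in the map step, and the sum over i corresponds to reduce(_ + _). Under this formula a point with y = 0 contributes nothing to the sum, so a 0/1 label column would have to be remapped to -1/+1 first.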
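The main formula used in the code above is:
(1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x
Here p.x is an n-dimensional vector and p.y is a scalar. reduce(_ + _) adds up the result for every row, so the gradient is again an n-dimensional vector. Then comes w -= gradient; the code uses no learning rate, so the step size is effectively 1.
After N iterations this yields a new w.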
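The same update can be tried without a cluster. The following is a minimal sketch using only the Scala standard library; the names LocalGradientDemo, Point, train and loss, the four toy points, and the stepSize parameter are assumptions made for illustration and are not part of the original program (which steps by the full gradient). Labels are assumed to be -1.0 or +1.0.

```scala
object LocalGradientDemo {
  // One labelled example; y is assumed to be -1.0 or +1.0.
  final case class Point(x: Vector[Double], y: Double)

  def dot(a: Vector[Double], b: Vector[Double]): Double =
    a.zip(b).map { case (u, v) => u * v }.sum

  // Same per-point expression as in the Spark job:
  // (1 / (1 + exp(-y * (w dot x))) - 1) * y * x
  def pointGradient(w: Vector[Double], p: Point): Vector[Double] = {
    val scale = (1.0 / (1.0 + math.exp(-p.y * dot(w, p.x))) - 1.0) * p.y
    p.x.map(_ * scale)
  }

  // Logistic loss summed over all points; used only to check progress.
  def loss(points: Seq[Point], w: Vector[Double]): Double =
    points.map(p => math.log(1.0 + math.exp(-p.y * dot(w, p.x)))).sum

  // Full-batch gradient descent. stepSize is hypothetical; the original
  // code corresponds to stepSize = 1.0.
  def train(points: Seq[Point], w0: Vector[Double],
            iterations: Int, stepSize: Double): Vector[Double] =
    (1 to iterations).foldLeft(w0) { (w, _) =>
      val gradient = points
        .map(pointGradient(w, _))
        .reduce((a, b) => a.zip(b).map { case (u, v) => u + v })
      w.zip(gradient).map { case (wi, gi) => wi - stepSize * gi }
    }

  def main(args: Array[String]): Unit = {
    // Made-up toy data, two features per point.
    val points = Seq(
      Point(Vector(1.0, 2.0), 1.0),
      Point(Vector(2.0, 1.5), 1.0),
      Point(Vector(-1.0, -1.0), -1.0),
      Point(Vector(-2.0, -0.5), -1.0)
    )
    val w0 = Vector(0.0, 0.0)
    val w = train(points, w0, iterations = 100, stepSize = 0.1)
    println(s"loss before: ${loss(points, w0)}, loss after: ${loss(points, w)}")
    println(s"w: $w")
  }
}
```

Printing the loss alongside w is an easier way to judge progress than comparing the raw weight vectors between iterations.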
case class
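For the differences between case class and class, see: http://www.tuicool.com/articles/yEZr6ve
A case class in Scala is essentially an ordinary class, with a few differences:
1. Instances can be created without new (you may still write it); an ordinary class requires new in Scala 2;
2. toString has a more readable implementation;
3. equals and hashCode are implemented by default;
4. Instances are serializable by default, that is, the class implements Serializable;
5. Some methods are inherited automatically from scala.Product;
6. Constructor parameters are public, so they can be accessed directly;
7. Pattern matching is supported.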
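The points in the list above can be checked with a short sketch. Sample and PlainSample are hypothetical types made up for this example; they are not the DataPoint class from the program.

```scala
object CaseClassDemo {
  // Hypothetical types for illustration only.
  case class Sample(id: Int, label: Double)
  class PlainSample(val id: Int, val label: Double)

  // Pattern matching on the constructor parameters (point 7).
  def describe(s: Sample): String = s match {
    case Sample(_, label) if label > 0 => "positive"
    case Sample(_, _)                  => "non-positive"
  }

  def main(args: Array[String]): Unit = {
    val a = Sample(1, 1.0) // no `new` needed (point 1)
    val b = Sample(1, 1.0)

    println(a)                                    // generated toString (point 2)
    println(a == b)                               // structural equals (point 3)
    println(a.hashCode == b.hashCode)             // consistent hashCode (point 3)
    println(a.isInstanceOf[java.io.Serializable]) // serializable (point 4)
    println(a.productArity)                       // inherited from scala.Product (point 5)
    println(a.label)                              // public constructor parameter (point 6)
    println(describe(a))

    // An ordinary class falls back to reference equality.
    val p = new PlainSample(1, 1.0)
    val q = new PlainSample(1, 1.0)
    println(p == q)
  }
}
```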
Breeze
The DenseVector used throughout the code above is the class from Breeze.
LinearRegressionWithSGD
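This is Spark's implementation of linear regression, based on stochastic gradient descent. Similar classes exist as well:
The linear regression algorithms available in MLlib are LinearRegressionWithSGD, RidgeRegressionWithSGD and LassoWithSGD; the main classes involved in MLlib regression are GeneralizedLinearAlgorithm and GradientDescent.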
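Using Java from Scala
The final version above uses DenseVector, so the snippet below was not needed in the end. It does show that Java classes can be used directly from Scala: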
import java.util.Random
val rand = new Random(53)
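As a sketch of how that Java class could have produced the random initial weights instead of DenseVector.rand: the object name RandomInitDemo and the helper randomWeights are hypothetical, and a plain Array[Double] stands in for the Breeze vector so the example needs no extra dependency.

```scala
import java.util.Random

object RandomInitDemo {
  // Builds n initial weights, each drawn uniformly from [0, 1).
  // A fixed seed makes the sequence reproducible across runs.
  def randomWeights(n: Int, seed: Long): Array[Double] = {
    val rand = new Random(seed)
    Array.fill(n)(rand.nextDouble())
  }

  def main(args: Array[String]): Unit = {
    val w = randomWeights(5, seed = 53L)
    println(w.mkString("w: [", ", ", "]"))
  }
}
```

Seeding the generator is mainly useful when comparing runs, since two runs with the same seed start from the same w.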