MLlib之NaiveBayes算法源码学习
package org.apache.spark.mllib.classification import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} import org.apache.spark.{SparkException, Logging}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD /**
* Model for Naive Bayes Classifiers.
*
* @param labels list of labels
* @param pi log of class priors, whose dimension is C, number of labels
* @param theta log of class conditional probabilities, whose dimension is C-by-D,
* where D is number of features
*/
class NaiveBayesModel private[mllib] (
val labels: Array[Double],
val pi: Array[Double],
val theta: Array[Array[Double]]) extends ClassificationModel with Serializable { private val brzPi = new BDV[Double](pi)
private val brzTheta = new BDM[Double](theta.length, theta(0).length) {
// Need to put an extra pair of braces to prevent Scala treating `i` as a member.
var i = 0
while (i < theta.length) {
var j = 0
while (j < theta(i).length) {
brzTheta(i, j) = theta(i)(j)
j += 1
}
i += 1
}
} override def predict(testData: RDD[Vector]): RDD[Double] = {
val bcModel = testData.context.broadcast(this)
testData.mapPartitions { iter =>
val model = bcModel.value
iter.map(model.predict)
}
} override def predict(testData: Vector): Double = {
labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
}
} /**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
*
* This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
* discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
* document classification. By making every vector a 0-1 vector, it can also be used as
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative.
*/
class NaiveBayes private (private var lambda: Double) extends Serializable with Logging { def this() = this(1.0) /** Set the smoothing parameter. Default: 1.0. */
def setLambda(lambda: Double): NaiveBayes = {
this.lambda = lambda
this
} /**
* Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
*
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
*/
def run(data: RDD[LabeledPoint]) = {
val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
val values = v match {
case sv: SparseVector =>
sv.values
case dv: DenseVector =>
dv.values
}
if (!values.forall(_ >= 0.0)) {
throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
}
} // Aggregates term frequencies per label.
// TODO: Calling combineByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage.
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
createCombiner = (v: Vector) => {
requireNonnegativeValues(v)
(1L, v.toBreeze.toDenseVector)
},
mergeValue = (c: (Long, BDV[Double]), v: Vector) => {
requireNonnegativeValues(v)
(c._1 + 1L, c._2 += v.toBreeze)
},
mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
(c1._1 + c2._1, c1._2 += c2._2)
).collect()
val numLabels = aggregated.length
var numDocuments = 0L
aggregated.foreach { case (_, (n, _)) =>
numDocuments += n
}
val numFeatures = aggregated.head match { case (_, (_, v)) => v.size }
val labels = new Array[Double](numLabels)
val pi = new Array[Double](numLabels)
val theta = Array.fill(numLabels)(new Array[Double](numFeatures))
val piLogDenom = math.log(numDocuments + numLabels * lambda)
var i = 0
aggregated.foreach { case (label, (n, sumTermFreqs)) =>
labels(i) = label
val thetaLogDenom = math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
pi(i) = math.log(n + lambda) - piLogDenom
var j = 0
while (j < numFeatures) {
theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
j += 1
}
i += 1
} new NaiveBayesModel(labels, pi, theta)
}
} /**
* Top-level methods for calling naive Bayes.
*/
object NaiveBayes {
/**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
*
* This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
* discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
* document classification. By making every vector a 0-1 vector, it can also be used as
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
*
* This version of the method uses a default smoothing parameter of 1.0.
*
* @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
* vector or a count vector.
*/
def train(input: RDD[LabeledPoint]): NaiveBayesModel = {
new NaiveBayes().run(input)
} /**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
*
* This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
* discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
* document classification. By making every vector a 0-1 vector, it can also be used as
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
*
* @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
* vector or a count vector.
* @param lambda The smoothing parameter
*/
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
new NaiveBayes(lambda).run(input)
}
}
package org.apache.spark.mllib.classification import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD /**
* :: Experimental ::
* Represents a classification model that predicts to which of a set of categories an example
* belongs. The categories are represented by double values: 0.0, 1.0, 2.0, etc.
*/
@Experimental
trait ClassificationModel extends Serializable {
/**
* Predict values for the given data set using the model trained.
*
* @param testData RDD representing data points to be predicted
* @return an RDD[Double] where each entry contains the corresponding prediction
*/
def predict(testData: RDD[Vector]): RDD[Double] /**
* Predict values for a single data point using the model trained.
*
* @param testData array representing a single data point
* @return predicted category from the trained model
*/
def predict(testData: Vector): Double /**
* Predict values for examples stored in a JavaRDD.
* @param testData JavaRDD representing data points to be predicted
* @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
*/
def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
}
朴素贝叶斯分类算法
何为分类算法?简单来说,就是将具有某些特性的物体归类对应到一个已知的类别集合中的某个类别上。从数学角度来说,可以做如下定义:
已知集合: C={y1,y2,..,yn} 和 I={x1,x2,..,xm,..} ,确定映射规则 y=f(x),使得任意 xi∈I 有且仅有一个 yj∈C 使得 yj=f(xi) 成立。
其中,C为类别集合,I为待分类的物体,f则为分类器,分类算法的主要任务就是构造分类器f。
分类算法的构造通常需要一个已知类别的集合来进行训练,通常来说训练出来的分类算法不可能达到100%的准确率。分类器的质量往往与训练数据、验证数据、训练数据样本大小等因素相关。
举个例子,我们日常生活中看到一个陌生人,要做的第一件事情就是判断其性别,判断性别的过程就是一个分类的过程。根据以往的生活经验,通常经过头发长短、服饰和体型这三个要素就能判断出来一个人的性别。这里的“生活经验”就是一个训练好的关于性别判断的模型,其训练数据是日常生活中遇到的形形色色的人。突然有一天,一个娘炮走到了你面前,长发飘飘,穿着紧身的衣裤,可是体型却很man,于是你就疑惑了,根据以往的经验——也就是已经训练好的模型,无法判断这个人的性别。于是你学会了通过喉结来判断其性别,这样你的模型被训练的质量更高了。但不可否认的是,永远会出现一个让你无法判断性别的人。所以模型永远无法达到100%的准确,只会随着训练数据的不断增多而无限接近100%的准确。
贝叶斯公式
贝叶斯公式,或者叫做贝叶斯定理,是贝叶斯分类的基础。而贝叶斯分类是一类分类算法的统称,这一类算法的基础都是贝叶斯公式。目前研究较多的四种贝叶斯分类算法有:Naive Bayes、TAN、BAN和GBN。
理工科的学生在大学应该都学过概率论,其中最重要的几个公式中就有贝叶斯公式——用来描述两个条件概率之间的关系,比如P(A|B)和P(B|A)。如何在已知事件A和B分别发生的概率,和事件B发生时事件A发生的概率,来求得事件A发生时事件B发生的概率,这就是贝叶斯公式的作用。其表述如下:
P(B|A)=P(A|B)×P(B)P(A)
朴素贝叶斯分类
朴素贝叶斯分类,Naive Bayes,你也可以叫它NB算法。其核心思想非常简单:对于某一预测项,分别计算该预测项为各个分类的概率,然后选择概率最大的分类为其预测分类。就好像你预测一个娘炮是女人的可能性是40%,是男人的可能性是41%,那么就可以判断他是男人。
Naive Bayes的数学定义如下:
- 设 x={a1,a2,..,am} 为一个待分类项,而每个 ai 为 x 的一个特征属性
- 已知类别集合 C={y1,y2,..,yn}
- 计算 x 为各个类别的概率: P(y1|x),P(y2|x),..,P(yn|x)
- 如果 P(yk|x)=max{P(y1|x),P(y2|x),..,P(yn|x)} ,则 x 的类别为 yk
如何获取第四步中的最大值,也就是如何计算第三步中的各个条件概率最为重要。可以采用如下做法:
- 获取训练数据集,即分类已知的数据集
- 统计得到在各类别下各个特征属性的条件概率估计,即:P(a1|y1),P(a2|y1),...,P(am|y1);P(a1|y2),P(a2|y2),...,P(am|y2);...;P(a1|yn),P(a2|yn),...,P(am|yn),其中的数据可以是离散的也可以是连续的
- 如果各个特征属性是条件独立的,则根据贝叶斯定理有如下推导: P(yi|x)=P(x|yi)P(yi)P(x)
对于某x来说,分母是固定的,所以只要找出分子最大的即为条件概率最大的。又因为各特征属性是条件独立的,所以有: P(x|yi)P(yi)=P(a1|yi)P(a2|yi)...P(am|yi)P(yi)=P(yi)∏mj=1P(aj|yi)
Additive smoothing
Additive smoothing,又叫Laplacian smoothing或Lidstone smoothing。
当某个类别下某个特征项划分没有出现时, P(ai|yj)=0 ,这样最后乘出来的结果会让精确度大大的降低,所以引入Additive smoothing来解决这个问题。其思想是对于这样等于0的情况,将其计数值加1,这样如果训练样本集数量充分大时,并不会对结果产生影响,并且解决了上述频率为0的尴尬局面。
MLlib之NaiveBayes算法源码学习的更多相关文章
- MLlib之LR算法源码学习
/** * :: DeveloperApi :: * GeneralizedLinearModel (GLM) represents a model trained using * Generaliz ...
- [算法1-排序](.NET源码学习)& LINQ & Lambda
[算法1-排序](.NET源码学习)& LINQ & Lambda 说起排序算法,在日常实际开发中我们基本不在意这些事情,有API不用不是没事找事嘛.但必要的基础还是需要了解掌握. 排 ...
- [算法2-数组与字符串的查找与匹配] (.NET源码学习)
[算法2-数组与字符串的查找与匹配] (.NET源码学习) 关键词:1. 数组查找(算法) 2. 字符串查找(算法) 3. C#中的String(源码) 4. 特性Attribute 与内 ...
- Java集合专题总结(1):HashMap 和 HashTable 源码学习和面试总结
2017年的秋招彻底结束了,感觉Java上面的最常见的集合相关的问题就是hash--系列和一些常用并发集合和队列,堆等结合算法一起考察,不完全统计,本人经历:先后百度.唯品会.58同城.新浪微博.趣分 ...
- Redis源码学习:字符串
Redis源码学习:字符串 1.初识SDS 1.1 SDS定义 Redis定义了一个叫做sdshdr(SDS or simple dynamic string)的数据结构.SDS不仅用于 保存字符串, ...
- 基于jdk1.8的HashMap源码学习笔记
作为一种最为常用的容器,同时也是效率比较高的容器,HashMap当之无愧.所以自己这次jdk源码学习,就从HashMap开始吧,当然水平有限,有不正确的地方,欢迎指正,促进共同学习进步,就是喜欢程序员 ...
- Vue源码学习1——Vue构造函数
Vue源码学习1--Vue构造函数 这是我第一次正式阅读大型框架源码,刚开始的时候完全不知道该如何入手.Vue源码clone下来之后这么多文件夹,Vue的这么多方法和概念都在哪,完全没有头绪.现在也只 ...
- zookeeper集群搭建及Leader选举算法源码解析
第一章.zookeeper概述 一.zookeeper 简介 zookeeper 是一个开源的分布式应用程序协调服务器,是 Hadoop 的重要组件. zooKeeper 是一个分布式的,开放源码的分 ...
- Vue2.1.7源码学习
原本文章的名字叫做<源码解析>,不过后来想想,还是用“源码学习”来的合适一点,在没有彻底掌握源码中的每一个字母之前,“解析”就有点标题党了.建议在看这篇文章之前,最好打开2.1.7的源码对 ...
随机推荐
- Vue框架H5商城类项目商品详情点击返回弹出推荐商品弹窗的实现方案
需求场景: 非推荐商品详情页返回的时候弹出弹窗推荐商品,点击弹窗按钮可以直接访问推荐商品: 只有直接进入商品详情页返回才会弹出推荐商品弹窗: 每个用户访问只能弹一次(除非清除缓存). 需求分析: 1. ...
- jq修改导航栏样式(选中、使用两张图片替代的是否选中效果)
<footer class="toolbar"> <ul> <li> <a href="{:url('Index/home')} ...
- Spring InitializingBean 接口以及Aware接口实现的原理
关于Spring InitializingBean 接口以及Aware接口实现的其实都在 第11步中: finishBeanFactoryInitialization() 方法中完成了3部分的内容: ...
- freemaker超详细 讲解 配置
一.FreeMarker简介 二.第一个FreeMark示例 2.1.新建一个Maven项目 2.2.添加依赖 2.3.添加存放模板的文件夹 2.4.添加模板 2.5.解析模板 2.6.运行结果 三. ...
- eclipse中集成python开发环境
转载:https://www.cnblogs.com/mywood/p/7272487.html Eclipse简介 Eclipse是java开发最常用的IDE,功能强大,可以在MAC和Windos上 ...
- android app主程序启动前加载图片
android app加载启动图片需要新创建一个activity,在主activity先加载图片activity,展示过程结束后,显示主activity.具体流程如下: 一.创建图片activity的 ...
- javaweb开发3.基于Servlet+JSP+JavaBean开发模式的用户登录注册
转载孤傲苍狼博客http://www.cnblogs.com/xdp-gacl/p/3902537.html 1.层次比较分明的项目结构图
- NGS NGS ngs(hisat,stringtie,ballgown)
NGS ngs(hisat,stringtie,ballgown) #HISAT (hierarchical indexing for spliced alignment of transcripts ...
- 转 node.js和 android中java加密解密一致性问题;
原文地址,请大家去原文博客了解; http://blog.csdn.net/linminqin/article/details/19972751 我保留一份,防止删除: var crypto = re ...
- 20172306 2018-2019《Java程序设计与数据结构课堂测试补充报告》
学号 2017-2018-2 <程序设计与数据结构>课堂测试补充报告 课程:<程序设计与数据结构> 班级: 1723 姓名: 刘辰 学号:20172306 实验教师:王志强 必 ...