PLSA.py

 # coding:utf8
from pyspark import SparkContext
from pyspark import RDD
import numpy as np
from numpy.random import RandomState import sys
reload(sys)
#设置默认编码为utf8,从spark rdd中取出中文词汇时需要编码为中文编码,否则不能保存成功
sys.setdefaultencoding('utf8') """
总结:
broadcast变量和需要用到broadcast变量的方法需要位于同一作用域 broadcast变量的unpersist会将存储broadcast变量的文件立即删除,
而此时rdd并未被触发执行,当rdd执行时会发现没有broadcast变量,所以会报错,
建议只在程序运行完成后,将broadcast变量 unpersist
""" class PLSA: def __init__(self, data, sc, k, is_test=False, max_itr=1000, eta=1e-6): """
init the algorithm :type data RDD
:param data: 输入文章rdd,每条记录为一系列用空格分隔的词,如"我 爱 蓝天 我 爱 白云"
:type max_itr int
:param max_itr: 最大EM迭代次数
:type is_test bool
:param is_test: 是否为测试,是则rd = RandomState(1),否则 rd = RandomState()
:type sc SparkContext
:param sc: spark context
:type k int
:param k : 主题个数
:type eta float
:param : 阈值,当log likelyhood的变化小于eta时,停止迭代
:return : PLSA object
"""
self.max_itr = max_itr
self.k = sc.broadcast(k)
self.ori_data = data.map(lambda x: x.split(' '))
self.sc = sc
self.eta = eta self.rd = sc.broadcast(RandomState(1) if is_test else RandomState()) def train(self): #获取词汇字典 ,如{"我":1}
self.word_dict_b = self._init_dict_()
#将文本中词汇,转成词典中的index
self._convert_docs_to_word_index()
#初始化,每个主题下的单词分布
self._init_probility_word_topic_() pre_l= self._log_likelyhood_() print "L(%d)=%.5f" %(0,pre_l) for i in range(self.max_itr):
#更新每个单词主题的后验分布
self._E_step_()
#最大化下界
self._M_step_()
now_l = self._log_likelyhood_() improve = np.abs((pre_l-now_l)/pre_l)
pre_l = now_l print "L(%d)=%.5f with %.6f%% improvement" %(i+1,now_l,improve*100)
if improve <self.eta:
break def _M_step_(self):
"""
更新参数 p(z=k|d),p(w|z=k)
:return: None
"""
k = self.k
v = self.v def update_probility_of_doc_topic(doc):
"""
更新文章的主题分布
"""
doc['topic'] = doc['topic'] - doc['topic'] topic_doc = doc['topic']
words = doc['words']
for (word_index,word) in words.items():
topic_doc += word['count']*word['topic_word']
topic_doc /= np.sum(topic_doc) return {'words':words,'topic':topic_doc} self.data = self.data.map(update_probility_of_doc_topic)
"""
rdd相当于一系列操作过程的结合,且前面的操作过程嵌套在后面的操作过程里,当这个嵌套超过大约60,spark会报错,
这里每次M step都通过cache将前面的操作执行掉
"""
self.data.cache() def update_probility_word_given_topic(doc):
"""
更新每个主题下的单词分布
"""
probility_word_given_topic = np.matrix(np.zeros((k.value,v.value))) words = doc['words']
for (word_index,word) in words.items():
probility_word_given_topic[:,word_index] += np.matrix(word['count']*word['topic_word']).T return probility_word_given_topic probility_word_given_topic = self.data.map(update_probility_word_given_topic).sum()
probility_word_given_topic_row_sum = np.matrix(np.sum(probility_word_given_topic,axis=1)) #使每个主题下单词概率和为1
probility_word_given_topic = np.divide(probility_word_given_topic,probility_word_given_topic_row_sum) self.probility_word_given_topic = self.sc.broadcast(probility_word_given_topic) def _E_step_(self):
"""
更新隐变量 p(z|w,d)-给定文章,和单词后,该单词的主题分布
:return: None
"""
probility_word_given_topic = self.probility_word_given_topic
k = self.k def update_probility_of_word_topic_given_word(doc):
topic_doc = doc['topic']
words = doc['words'] for (word_index,word) in words.items():
topic_word = word['topic_word']
for i in range(k.value):
topic_word[i] = probility_word_given_topic.value[i,word_index]*topic_doc[i]
#使该单词各主题分布概率和为1
topic_word /= np.sum(topic_word)
return {'words':words,'topic':topic_doc} self.data = self.data.map(update_probility_of_word_topic_given_word) def _init_probility_word_topic_(self):
"""
init p(w|z=k)
:return: None
"""
#dict length(words in dict)
m = self.v.value probility_word_given_topic = self.rd.value.uniform(0,1,(self.k.value,m))
probility_word_given_topic_row_sum = np.matrix(np.sum(probility_word_given_topic,axis=1)).T #使每个主题下单词概率和为1
probility_word_given_topic = np.divide(probility_word_given_topic,probility_word_given_topic_row_sum) self.probility_word_given_topic = self.sc.broadcast(probility_word_given_topic) def _convert_docs_to_word_index(self): word_dict_b = self.word_dict_b
k = self.k
rd = self.rd
'''
I wonder is there a better way to execute function with broadcast varible
'''
def _word_count_doc_(doc):
wordcount ={}
word_dict = word_dict_b.value
for word in doc:
if wordcount.has_key(word_dict[word]):
wordcount[word_dict[word]]['count'] += 1
else:
#first one is the number of word occurance
#second one is p(z=k|w,d)
wordcount[word_dict[word]] = {'count':1,'topic_word': rd.value.uniform(0,1,k.value)} topics = rd.value.uniform(0, 1, k.value)
topics = topics/np.sum(topics)
return {'words':wordcount,'topic':topics} self.data = self.ori_data.map(_word_count_doc_) def _init_dict_(self):
"""
init word dict of the documents,
and broadcast it
:return: None
"""
words = self.ori_data.flatMap(lambda d: d).distinct().collect()
word_dict = {w: i for w, i in zip(words, range(len(words)))}
self.v = self.sc.broadcast(len(word_dict))
return self.sc.broadcast(word_dict) def _log_likelyhood_(self):
probility_word_given_topic = self.probility_word_given_topic
k = self.k def likelyhood(doc):
l = 0.0
topic_doc = doc['topic']
words = doc['words'] for (word_index,word) in words.items():
l += word['count']*np.log(np.matrix(topic_doc)*probility_word_given_topic.value[:,word_index])
return l
return self.data.map(likelyhood).sum() def save(self,f_word_given_topic,f_doc_topic):
"""
保存模型结果 TODO 添加分布式保存结果
:param f_word_given_topic: 文件路径,用于给定主题下词汇分布
:param f_doc_topic: 文件路径,用于保存文档的主题分布
:return:
"""
doc_topic = self.data.map(lambda x:' '.join([str(q) for q in x['topic'].tolist()])).collect()
probility_word_given_topic = self.probility_word_given_topic.value word_dict = self.word_dict_b.value
word_given_topic = [] for w,i in word_dict.items():
word_given_topic.append('%s %s' %(w,' '.join([str(q[0]) for q in probility_word_given_topic[:,i].tolist()]))) f1 = open (f_word_given_topic, 'w') for line in word_given_topic:
f1.write(line)
f1.write('\n')
f1.close() f2 = open (f_doc_topic, 'w') for line in doc_topic:
f2.write(line)
f2.write('\n')
f2.close()

调用

 from PLSA import PLSA
from pyspark import SparkContext if __name__=="__main__":
sc = SparkContext('local')
data = sc.textFile("E:/github/FGYML4/data/news_seg/news_seg.txt")
plsa = PLSA(data,sc,3,max_itr=1)
plsa.train()
plsa.save('D:/topic_word','D:/doc_topic')

基于spark的plsa实现的更多相关文章

  1. 基于Spark ALS构建商品推荐引擎

    基于Spark ALS构建商品推荐引擎   一般来讲,推荐引擎试图对用户与某类物品之间的联系建模,其想法是预测人们可能喜好的物品并通过探索物品之间的联系来辅助这个过程,让用户能更快速.更准确的获得所需 ...

  2. 【基于spark IM 的二次开发笔记】第一天 各种配置

    [基于spark IM 的二次开发笔记]第一天 各种配置 http://juforg.iteye.com/blog/1870487 http://www.igniterealtime.org/down ...

  3. 大数据实时处理-基于Spark的大数据实时处理及应用技术培训

    随着互联网.移动互联网和物联网的发展,我们已经切实地迎来了一个大数据 的时代.大数据是指无法在一定时间内用常规软件工具对其内容进行抓取.管理和处理的数据集合,对大数据的分析已经成为一个非常重要且紧迫的 ...

  4. 基于Spark和SparkSQL的NetFlow流量的初步分析——scala语言

    基于Spark和SparkSQL的NetFlow流量的初步分析--scala语言 标签: NetFlow Spark SparkSQL 本文主要是介绍如何使用Spark做一些简单的NetFlow数据的 ...

  5. UserView--第二种方式(避免第一种方式Set饱和),基于Spark算子的java代码实现

      UserView--第二种方式(避免第一种方式Set饱和),基于Spark算子的java代码实现   测试数据 java代码 package com.hzf.spark.study; import ...

  6. UserView--第一种方式set去重,基于Spark算子的java代码实现

    UserView--第一种方式set去重,基于Spark算子的java代码实现 测试数据 java代码 package com.hzf.spark.study; import java.util.Ha ...

  7. 基于Spark自动扩展scikit-learn (spark-sklearn)(转载)

    转载自:https://blog.csdn.net/sunbow0/article/details/50848719 1.基于Spark自动扩展scikit-learn(spark-sklearn)1 ...

  8. 苏宁基于Spark Streaming的实时日志分析系统实践 Spark Streaming 在数据平台日志解析功能的应用

    https://mp.weixin.qq.com/s/KPTM02-ICt72_7ZdRZIHBA 苏宁基于Spark Streaming的实时日志分析系统实践 原创: AI+落地实践 AI前线 20 ...

  9. 基于Spark Mllib的文本分类

    基于Spark Mllib的文本分类 文本分类是一个典型的机器学习问题,其主要目标是通过对已有语料库文本数据训练得到分类模型,进而对新文本进行类别标签的预测.这在很多领域都有现实的应用场景,如新闻网站 ...

随机推荐

  1. qt反走样(简选)

    # -*- coding: utf-8 -*- # python:2.x __author__ = 'Administrator' #qt反走样(简选) #概念 """ ...

  2. IOS 用drawRect 画表格

    自定义一个View DrawLine DrawLine.h #import <UIKit/UIKit.h> @protocol gridTouchDelete <NSObject&g ...

  3. JavaScript ----------------- 原型式继承

    思想:借助原型可以基于已有的对象创建新对象,同时还不必因此创建自定义类型.为了达到这个目的,看看下面的实现方式 function object(o){ function F(){ } F.protot ...

  4. CSU 1808 地铁

    题意: ICPCCamp 有 n 个地铁站,用 1,2,-,n 编号. m 段双向的地铁线路连接 n 个地铁站,其中第 i 段地铁属于 ci 号线,位于站 ai,bi 之间,往返均需要花费 ti 分钟 ...

  5. Oracle - 使用序列+触发器实现主键自增长

    Oracle中的自增,不如Sql server那般方便. --.创建序列 CREATE SEQUENCE "TABLE_NAME"."SQ_NAME" MINV ...

  6. Windows Phone 学习教程(一)

    http://www.cnblogs.com/webabcd/category/385852.html Windows Phone 7 教程 Windows Phone 8.1 Windows Pho ...

  7. codeforces 340E Iahub and Permutations(错排or容斥)

    转载请注明出处: http://www.cnblogs.com/fraud/          ——by fraud Iahub and Permutations Iahub is so happy ...

  8. css 实现进度条

    <select id="progress" onchange="changeProgress(this)"> <option value=&q ...

  9. css3实现手机qq空间菜单按钮

    工作之余写的一个类似于QQzone的菜单效果 先上截图: 图一为点击按钮前界面: 图二为点击按钮后的界面 下面上代码: <!--css部分--> <style type=" ...

  10. Scrapy学习系列(一):网页元素查询CSS Selector和XPath Selector

    这篇文章主要介绍创建一个简单的spider,顺便介绍一下对网页元素的选取方式(css selector, xpath selector). 第一步:创建spider工程 打开命令行运行以下命令: sc ...