1. package com.data.ml.classify;
  2.  
  3. import java.io.File;
  4. import java.util.ArrayList;
  5. import java.util.Collections;
  6. import java.util.HashMap;
  7. import java.util.HashSet;
  8. import java.util.List;
  9. import java.util.Map;
  10. import java.util.Map.Entry;
  11. import java.util.Set;
  12. import java.util.regex.Matcher;
  13. import java.util.regex.Pattern;
  14.  
  15. import com.data.util.IoUtil;
  16.  
  17. public class NativeBayes {
  18. /**
  19. * 默认频率
  20. */
  21. private double defaultFreq = 0.1;
  22.  
  23. /**
  24. * 训练数据的比例
  25. */
  26. private Double trainingPercent = 0.8;
  27.  
  28. private Map<String, List<String>> files_all = new HashMap<String, List<String>>();
  29.  
  30. private Map<String, List<String>> files_train = new HashMap<String, List<String>>();
  31.  
  32. private Map<String, List<String>> files_test = new HashMap<String, List<String>>();
  33.  
  34. public NativeBayes() {
  35.  
  36. }
  37.  
  38. /**
  39. * 每个分类的频率
  40. */
  41. private Map<String, Integer> classFreq = new HashMap<String, Integer>();
  42.  
  43. private Map<String, Double> ClassProb = new HashMap<String, Double>();
  44.  
  45. /**
  46. * 特征总数
  47. */
  48. private Set<String> WordDict = new HashSet<String>();
  49.  
  50. private Map<String, Map<String, Integer>> classFeaFreq = new HashMap<String, Map<String, Integer>>();
  51.  
  52. private Map<String, Map<String, Double>> ClassFeaProb = new HashMap<String, Map<String, Double>>();
  53.  
  54. private Map<String, Double> ClassDefaultProb = new HashMap<String, Double>();
  55.  
  56. /**
  57. * 计算准确率
  58. * @param reallist 真实类别
  59. * @param pridlist 预测类别
  60. */
  61. public void Evaluate(List<String> reallist, List<String> pridlist){
  62. double correctNum = 0.0;
  63. for (int i = 0; i < reallist.size(); i++) {
  64. if(reallist.get(i) == pridlist.get(i)){
  65. correctNum += 1;
  66. }
  67. }
  68. double accuracy = correctNum / reallist.size();
  69. System.out.println("准确率为:" + accuracy);
  70. }
  71.  
  72. /**
  73. * 计算精确率和召回率
  74. * @param reallist
  75. * @param pridlist
  76. * @param classname
  77. */
  78. public void CalPreRec(List<String> reallist, List<String> pridlist, String classname){
  79. double correctNum = 0.0;
  80. double allNum = 0.0;//测试数据中,某个分类的文章总数
  81. double preNum = 0.0;//测试数据中,预测为该分类的文章总数
  82.  
  83. for (int i = 0; i < reallist.size(); i++) {
  84. if(reallist.get(i) == classname){
  85. allNum += 1;
  86. if(reallist.get(i) == pridlist.get(i)){
  87. correctNum += 1;
  88. }
  89. }
  90. if(pridlist.get(i) == classname){
  91. preNum += 1;
  92. }
  93. }
  94. System.out.println(classname + " 精确率(跟预测分类比较):" + correctNum / preNum + " 召回率(跟真实分类比较):" + correctNum / allNum);
  95. }
  96.  
  97. /**
  98. * 用模型进行预测
  99. */
  100. public void PredictTestData() {
  101. List<String> reallist=new ArrayList<String>();
  102. List<String> pridlist=new ArrayList<String>();
  103.  
  104. for (Entry<String, List<String>> entry : files_test.entrySet()) {
  105. String realclassname = entry.getKey();
  106. List<String> files = entry.getValue();
  107.  
  108. for (String file : files) {
  109. reallist.add(realclassname);
  110.  
  111. List<String> classnamelist=new ArrayList<String>();
  112. List<Double> scorelist=new ArrayList<Double>();
  113. for (Entry<String, Double> entry_1 : ClassProb.entrySet()) {
  114. String classname = entry_1.getKey();
  115. //先验概率
  116. Double score = Math.log(entry_1.getValue());
  117.  
  118. String[] words = IoUtil.readFromFile(new File(file)).split(" ");
  119. for (String word : words) {
  120. if(!WordDict.contains(word)){
  121. continue;
  122. }
  123.  
  124. if(ClassFeaProb.get(classname).containsKey(word)){
  125. score += Math.log(ClassFeaProb.get(classname).get(word));
  126. }else{
  127. score += Math.log(ClassDefaultProb.get(classname));
  128. }
  129. }
  130.  
  131. classnamelist.add(classname);
  132. scorelist.add(score);
  133. }
  134.  
  135. Double maxProb = Collections.max(scorelist);
  136. int idx = scorelist.indexOf(maxProb);
  137. pridlist.add(classnamelist.get(idx));
  138. }
  139. }
  140.  
  141. Evaluate(reallist, pridlist);
  142.  
  143. for (String cname : files_test.keySet()) {
  144. CalPreRec(reallist, pridlist, cname);
  145. }
  146.  
  147. }
  148.  
  149. /**
  150. * 模型训练
  151. */
  152. public void createModel() {
  153. double sum = 0.0;
  154. for (Entry<String, Integer> entry : classFreq.entrySet()) {
  155. sum+=entry.getValue();
  156. }
  157. for (Entry<String, Integer> entry : classFreq.entrySet()) {
  158. ClassProb.put(entry.getKey(), entry.getValue()/sum);
  159. }
  160.  
  161. for (Entry<String, Map<String, Integer>> entry : classFeaFreq.entrySet()) {
  162. sum = 0.0;
  163. String classname = entry.getKey();
  164. for (Entry<String, Integer> entry_1 : entry.getValue().entrySet()){
  165. sum += entry_1.getValue();
  166. }
  167. double newsum = sum + WordDict.size()*defaultFreq;
  168.  
  169. Map<String, Double> feaProb = new HashMap<String, Double>();
  170. ClassFeaProb.put(classname, feaProb);
  171.  
  172. for (Entry<String, Integer> entry_1 : entry.getValue().entrySet()){
  173. String word = entry_1.getKey();
  174. feaProb.put(word, (entry_1.getValue() +defaultFreq) /newsum);
  175. }
  176. ClassDefaultProb.put(classname, defaultFreq/newsum);
  177. }
  178. }
  179.  
  180. /**
  181. * 加载训练数据
  182. */
  183. public void loadTrainData(){
  184. for (Entry<String, List<String>> entry : files_train.entrySet()) {
  185. String classname = entry.getKey();
  186. List<String> docs = entry.getValue();
  187.  
  188. classFreq.put(classname, docs.size());
  189.  
  190. Map<String, Integer> feaFreq = new HashMap<String, Integer>();
  191. classFeaFreq.put(classname, feaFreq);
  192.  
  193. for (String doc : docs) {
  194. String[] words = IoUtil.readFromFile(new File(doc)).split(" ");
  195. for (String word : words) {
  196.  
  197. WordDict.add(word);
  198.  
  199. if(feaFreq.containsKey(word)){
  200. int num = feaFreq.get(word) + 1;
  201. feaFreq.put(word, num);
  202. }else{
  203. feaFreq.put(word, 1);
  204. }
  205. }
  206. }
  207.  
  208. }
  209. System.out.println(classFreq.size()+" 分类, " + WordDict.size()+" 特征词");
  210. }
  211.  
  212. /**
  213. * 将数据分为训练数据和测试数据
  214. *
  215. * @param dataDir
  216. */
  217. public void splitData(String dataDir) {
  218. // 用文件名区分类别
  219. Pattern pat = Pattern.compile("\\d+([a-z]+?)\\.");
  220. dataDir = "testdata/allfiles";
  221. File f = new File(dataDir);
  222. File[] files = f.listFiles();
  223. for (File file : files) {
  224. String fname = file.getName();
  225. Matcher m = pat.matcher(fname);
  226. if (m.find()) {
  227. String cname = m.group(1);
  228. if (files_all.containsKey(cname)) {
  229. files_all.get(cname).add(file.toString());
  230. } else {
  231. List<String> tmp = new ArrayList<String>();
  232. tmp.add(file.toString());
  233. files_all.put(cname, tmp);
  234. }
  235. } else {
  236. System.out.println("err: " + file);
  237. }
  238. }
  239.  
  240. System.out.println("统计数据:");
  241. for (Entry<String, List<String>> entry : files_all.entrySet()) {
  242. String cname = entry.getKey();
  243. List<String> value = entry.getValue();
  244. // System.out.println(cname + " : " + value.size());
  245.  
  246. List<String> train = new ArrayList<String>();
  247. List<String> test = new ArrayList<String>();
  248.  
  249. for (String str : value) {
  250. if (Math.random() <= trainingPercent) {// 80%用来训练 , 20%测试
  251. train.add(str);
  252. } else {
  253. test.add(str);
  254. }
  255. }
  256.  
  257. files_train.put(cname, train);
  258. files_test.put(cname, test);
  259. }
  260.  
  261. System.out.println("所有文件数:");
  262. printStatistics(files_all);
  263. System.out.println("训练文件数:");
  264. printStatistics(files_train);
  265. System.out.println("测试文件数:");
  266. printStatistics(files_test);
  267.  
  268. }
  269.  
  270. /**
  271. * 打印统计信息
  272. *
  273. * @param m
  274. */
  275. public void printStatistics(Map<String, List<String>> m) {
  276. for (Entry<String, List<String>> entry : m.entrySet()) {
  277. String cname = entry.getKey();
  278. List<String> value = entry.getValue();
  279. System.out.println(cname + " : " + value.size());
  280. }
  281. System.out.println("--------------------------------");
  282. }
  283.  
  284. public static void main(String[] args) {
  285. NativeBayes bayes = new NativeBayes();
  286. bayes.splitData(null);
  287. bayes.loadTrainData();
  288. bayes.createModel();
  289. bayes.PredictTestData();
  290.  
  291. }
  292.  
  293. }
  294.  
  295. 所有文件数:
    sports : 1018
    auto : 1020
    business : 1028
    --------------------------------
    训练文件数:
    sports : 791
    auto : 812
    business : 808
    --------------------------------
    测试文件数:
    sports : 227
    auto : 208
    business : 220
    --------------------------------
    3 分类, 39613 特征词
    准确率为:0.9801526717557252
    sports 精确率(跟预测分类比较):0.9956140350877193 召回率(跟真实分类比较):1.0
    auto 精确率(跟预测分类比较):0.9579439252336449 召回率(跟真实分类比较):0.9855769230769231
    business 精确率(跟预测分类比较):0.9859154929577465 召回率(跟真实分类比较):0.9545454545454546
  296.  
  297. 统计数据:
    所有文件数:
    sports : 1018
    auto : 1020
    business : 1028
    --------------------------------
    训练文件数:
    sports : 827
    auto : 833
    business : 825
    --------------------------------
    测试文件数:
    sports : 191
    auto : 187
    business : 203
    --------------------------------
    3 分类, 39907 特征词
    准确率为:0.9759036144578314
    sports 精确率(跟预测分类比较):0.9894736842105263 召回率(跟真实分类比较):0.9842931937172775
    auto 精确率(跟预测分类比较):0.9836956521739131 召回率(跟真实分类比较):0.9679144385026738
    business 精确率(跟预测分类比较):0.9565217391304348 召回率(跟真实分类比较):0.9753694581280788
  298.  

朴素贝叶斯文本分类java实现的更多相关文章

  1. Mahout朴素贝叶斯文本分类

    Mahout朴素贝叶斯文本分类算法 Mahout贝叶斯分类器按照官方的说法,是按照<Tackling the PoorAssumptions of Naive Bayes Text Classi ...

  2. 朴素贝叶斯文本分类-在《红楼梦》作者鉴别的应用上(python实现)

    朴素贝叶斯算法简单.高效.接下来我们来介绍其如何应用在<红楼梦>作者的鉴别上. 第一步,当然是先得有文本数据,我在网上随便下载了一个txt(当时急着交初稿...).分类肯定是要一个回合一个 ...

  3. 朴素贝叶斯文本分类(python代码实现)

    朴素贝叶斯(naive bayes)法是基于贝叶斯定理与特征条件独立假设的分类方法. 优点:在数据较少的情况下仍然有效,可以处理多分类问题. 缺点:对入输入数据的准备方式较为敏感. 使用数据类型:标称 ...

  4. 朴素贝叶斯文本分类实现 python cherry分类器

    贝叶斯模型在机器学习以及人工智能中都有出现,cherry分类器使用了朴素贝叶斯模型算法,经过简单的优化,使用1000个训练数据就能得到97.5%的准确率.虽然现在主流的框架都带有朴素贝叶斯模型算法,大 ...

  5. 详解使用EM算法的半监督学习方法应用于朴素贝叶斯文本分类

    1.前言 对大量需要分类的文本数据进行标记是一项繁琐.耗时的任务,而真实世界中,如互联网上存在大量的未标注的数据,获取这些是容易和廉价的.在下面的内容中,我们介绍使用半监督学习和EM算法,充分结合大量 ...

  6. 利用朴素贝叶斯算法进行分类-Java代码实现

    http://www.crocro.cn/post/286.html 利用朴素贝叶斯算法进行分类-Java代码实现  鳄鱼  3个月前 (12-14)  分类:机器学习  阅读(44)  评论(0) ...

  7. 朴素贝叶斯算法分析及java 实现

    1. 先引入一个简单的例子 出处:http://www.ruanyifeng.com/blog/2013/12/naive_bayes_classifier.html 一.病人分类的例子 让我从一个例 ...

  8. Naive Bayes(朴素贝叶斯算法)[分类算法]

    Naïve Bayes(朴素贝叶斯)分类算法的实现 (1) 简介: (2)   算法描述: (3) <?php /* *Naive Bayes朴素贝叶斯算法(分类算法的实现) */ /* *把. ...

  9. 芝麻HTTP:记scikit-learn贝叶斯文本分类的坑

    基本步骤: 1.训练素材分类: 我是参考官方的目录结构: 每个目录中放对应的文本,一个txt文件一篇对应的文章:就像下面这样 需要注意的是所有素材比例请保持在相同的比例(根据训练结果酌情调整.不可比例 ...

随机推荐

  1. 深入理解计算机系统第二版习题解答CSAPP 2.19

    在2.17的基础上完成下表: x 十六进制 T2U(x) -8 0x8 -3 0xD -2 0xE -1 0xF 0 0x0 5 0x5

  2. 我的第一篇——nginx+naxsi总结篇1

    今天是我正式在Linux下安装nginx的第一天吧,搜索,查看,安装,这之间肯定是或多或少的遇到了很多的问题,不管是大的还是小的,都应该记录下来,或许以后还会用到,或许会帮到其他人. 首先,先说一下, ...

  3. Java栈实现

    栈数组实现一:优点:入栈和出栈速度快,缺点:长度有限(有时候这也不能算是个缺点) public class Stack { private int top = -1; private Object[] ...

  4. gVim多标签页

    我们一般使用的文本编辑器,如:editplus.ultraEdit等都是支持多标签页的,可以同时打开多个文件,方便切换,以前gVim只能打开多个窗口,或者一个窗口切出多个窗口来编辑,自从7.0以后Vi ...

  5. [MSDN]使用 REST 处理文件夹和文件

    msdn: http://msdn.microsoft.com/zh-cn/library/dn292553.aspx 了解如何使用 SharePoint 2013 REST 界面对文件夹和文件执行基 ...

  6. RabbitMQ 原文译1.2--"Hello Word"

    本系列文章均来自官网原文,属于个人翻译,如有雷同,权当个人归档,忽喷. .NET/C# RabbitMQ 客户端下载地址:https://github.com/rabbitmq/rabbitmq-do ...

  7. C# String 前面不足位数补零的方法 PadLeft

    PadLeft(int totalWidth, char paddingChar) //在字符串左边用 paddingChar 补足 totalWidth 长度PadLeft(int totalWid ...

  8. 利用openssl进行RSA加密解密

    openssl是一个功能强大的工具包,它集成了众多密码算法及实用工具.我们即可以利用它提供的命令台工具生成密钥.证书来加密解密文件,也可以在利用其提供的API接口在代码中对传输信息进行加密. RSA是 ...

  9. xadmin学习笔记(一)——编程准备

    前言 xadmin是GitHub上的开源项目,它是Django admin的超强升级版,提供了强大的插件系统,丰富的内置功能,以及无与伦比的UI主题,使管理系统的实现变得异常简单.详情请参见官方网址. ...

  10. iOS 的一点理解(一) 代理delegate

    做了一年的iOS,想记录自己对知识点的一点理解. 第一篇,想记录一下iOS中delegate(委托,也有人称作代理)的理解吧. 故名思议,delegate就是代理的含义, 一件事情自己不方便做,然后交 ...