朴素贝叶斯文本分类java实现
- package com.data.ml.classify;
- import java.io.File;
- import java.util.ArrayList;
- import java.util.Collections;
- import java.util.HashMap;
- import java.util.HashSet;
- import java.util.List;
- import java.util.Map;
- import java.util.Map.Entry;
- import java.util.Set;
- import java.util.regex.Matcher;
- import java.util.regex.Pattern;
- import com.data.util.IoUtil;
- public class NativeBayes {
- /**
- * 默认频率
- */
- private double defaultFreq = 0.1;
- /**
- * 训练数据的比例
- */
- private Double trainingPercent = 0.8;
- private Map<String, List<String>> files_all = new HashMap<String, List<String>>();
- private Map<String, List<String>> files_train = new HashMap<String, List<String>>();
- private Map<String, List<String>> files_test = new HashMap<String, List<String>>();
- public NativeBayes() {
- }
- /**
- * 每个分类的频率
- */
- private Map<String, Integer> classFreq = new HashMap<String, Integer>();
- private Map<String, Double> ClassProb = new HashMap<String, Double>();
- /**
- * 特征总数
- */
- private Set<String> WordDict = new HashSet<String>();
- private Map<String, Map<String, Integer>> classFeaFreq = new HashMap<String, Map<String, Integer>>();
- private Map<String, Map<String, Double>> ClassFeaProb = new HashMap<String, Map<String, Double>>();
- private Map<String, Double> ClassDefaultProb = new HashMap<String, Double>();
- /**
- * 计算准确率
- * @param reallist 真实类别
- * @param pridlist 预测类别
- */
- public void Evaluate(List<String> reallist, List<String> pridlist){
- double correctNum = 0.0;
- for (int i = 0; i < reallist.size(); i++) {
- if(reallist.get(i) == pridlist.get(i)){
- correctNum += 1;
- }
- }
- double accuracy = correctNum / reallist.size();
- System.out.println("准确率为:" + accuracy);
- }
- /**
- * 计算精确率和召回率
- * @param reallist
- * @param pridlist
- * @param classname
- */
- public void CalPreRec(List<String> reallist, List<String> pridlist, String classname){
- double correctNum = 0.0;
- double allNum = 0.0;//测试数据中,某个分类的文章总数
- double preNum = 0.0;//测试数据中,预测为该分类的文章总数
- for (int i = 0; i < reallist.size(); i++) {
- if(reallist.get(i) == classname){
- allNum += 1;
- if(reallist.get(i) == pridlist.get(i)){
- correctNum += 1;
- }
- }
- if(pridlist.get(i) == classname){
- preNum += 1;
- }
- }
- System.out.println(classname + " 精确率(跟预测分类比较):" + correctNum / preNum + " 召回率(跟真实分类比较):" + correctNum / allNum);
- }
- /**
- * 用模型进行预测
- */
- public void PredictTestData() {
- List<String> reallist=new ArrayList<String>();
- List<String> pridlist=new ArrayList<String>();
- for (Entry<String, List<String>> entry : files_test.entrySet()) {
- String realclassname = entry.getKey();
- List<String> files = entry.getValue();
- for (String file : files) {
- reallist.add(realclassname);
- List<String> classnamelist=new ArrayList<String>();
- List<Double> scorelist=new ArrayList<Double>();
- for (Entry<String, Double> entry_1 : ClassProb.entrySet()) {
- String classname = entry_1.getKey();
- //先验概率
- Double score = Math.log(entry_1.getValue());
- String[] words = IoUtil.readFromFile(new File(file)).split(" ");
- for (String word : words) {
- if(!WordDict.contains(word)){
- continue;
- }
- if(ClassFeaProb.get(classname).containsKey(word)){
- score += Math.log(ClassFeaProb.get(classname).get(word));
- }else{
- score += Math.log(ClassDefaultProb.get(classname));
- }
- }
- classnamelist.add(classname);
- scorelist.add(score);
- }
- Double maxProb = Collections.max(scorelist);
- int idx = scorelist.indexOf(maxProb);
- pridlist.add(classnamelist.get(idx));
- }
- }
- Evaluate(reallist, pridlist);
- for (String cname : files_test.keySet()) {
- CalPreRec(reallist, pridlist, cname);
- }
- }
- /**
- * 模型训练
- */
- public void createModel() {
- double sum = 0.0;
- for (Entry<String, Integer> entry : classFreq.entrySet()) {
- sum+=entry.getValue();
- }
- for (Entry<String, Integer> entry : classFreq.entrySet()) {
- ClassProb.put(entry.getKey(), entry.getValue()/sum);
- }
- for (Entry<String, Map<String, Integer>> entry : classFeaFreq.entrySet()) {
- sum = 0.0;
- String classname = entry.getKey();
- for (Entry<String, Integer> entry_1 : entry.getValue().entrySet()){
- sum += entry_1.getValue();
- }
- double newsum = sum + WordDict.size()*defaultFreq;
- Map<String, Double> feaProb = new HashMap<String, Double>();
- ClassFeaProb.put(classname, feaProb);
- for (Entry<String, Integer> entry_1 : entry.getValue().entrySet()){
- String word = entry_1.getKey();
- feaProb.put(word, (entry_1.getValue() +defaultFreq) /newsum);
- }
- ClassDefaultProb.put(classname, defaultFreq/newsum);
- }
- }
- /**
- * 加载训练数据
- */
- public void loadTrainData(){
- for (Entry<String, List<String>> entry : files_train.entrySet()) {
- String classname = entry.getKey();
- List<String> docs = entry.getValue();
- classFreq.put(classname, docs.size());
- Map<String, Integer> feaFreq = new HashMap<String, Integer>();
- classFeaFreq.put(classname, feaFreq);
- for (String doc : docs) {
- String[] words = IoUtil.readFromFile(new File(doc)).split(" ");
- for (String word : words) {
- WordDict.add(word);
- if(feaFreq.containsKey(word)){
- int num = feaFreq.get(word) + 1;
- feaFreq.put(word, num);
- }else{
- feaFreq.put(word, 1);
- }
- }
- }
- }
- System.out.println(classFreq.size()+" 分类, " + WordDict.size()+" 特征词");
- }
- /**
- * 将数据分为训练数据和测试数据
- *
- * @param dataDir
- */
- public void splitData(String dataDir) {
- // 用文件名区分类别
- Pattern pat = Pattern.compile("\\d+([a-z]+?)\\.");
- dataDir = "testdata/allfiles";
- File f = new File(dataDir);
- File[] files = f.listFiles();
- for (File file : files) {
- String fname = file.getName();
- Matcher m = pat.matcher(fname);
- if (m.find()) {
- String cname = m.group(1);
- if (files_all.containsKey(cname)) {
- files_all.get(cname).add(file.toString());
- } else {
- List<String> tmp = new ArrayList<String>();
- tmp.add(file.toString());
- files_all.put(cname, tmp);
- }
- } else {
- System.out.println("err: " + file);
- }
- }
- System.out.println("统计数据:");
- for (Entry<String, List<String>> entry : files_all.entrySet()) {
- String cname = entry.getKey();
- List<String> value = entry.getValue();
- // System.out.println(cname + " : " + value.size());
- List<String> train = new ArrayList<String>();
- List<String> test = new ArrayList<String>();
- for (String str : value) {
- if (Math.random() <= trainingPercent) {// 80%用来训练 , 20%测试
- train.add(str);
- } else {
- test.add(str);
- }
- }
- files_train.put(cname, train);
- files_test.put(cname, test);
- }
- System.out.println("所有文件数:");
- printStatistics(files_all);
- System.out.println("训练文件数:");
- printStatistics(files_train);
- System.out.println("测试文件数:");
- printStatistics(files_test);
- }
- /**
- * 打印统计信息
- *
- * @param m
- */
- public void printStatistics(Map<String, List<String>> m) {
- for (Entry<String, List<String>> entry : m.entrySet()) {
- String cname = entry.getKey();
- List<String> value = entry.getValue();
- System.out.println(cname + " : " + value.size());
- }
- System.out.println("--------------------------------");
- }
- public static void main(String[] args) {
- NativeBayes bayes = new NativeBayes();
- bayes.splitData(null);
- bayes.loadTrainData();
- bayes.createModel();
- bayes.PredictTestData();
- }
- }
- 所有文件数:
sports : 1018
auto : 1020
business : 1028
--------------------------------
训练文件数:
sports : 791
auto : 812
business : 808
--------------------------------
测试文件数:
sports : 227
auto : 208
business : 220
--------------------------------
3 分类, 39613 特征词
准确率为:0.9801526717557252
sports 精确率(跟预测分类比较):0.9956140350877193 召回率(跟真实分类比较):1.0
auto 精确率(跟预测分类比较):0.9579439252336449 召回率(跟真实分类比较):0.9855769230769231
business 精确率(跟预测分类比较):0.9859154929577465 召回率(跟真实分类比较):0.9545454545454546- 统计数据:
所有文件数:
sports : 1018
auto : 1020
business : 1028
--------------------------------
训练文件数:
sports : 827
auto : 833
business : 825
--------------------------------
测试文件数:
sports : 191
auto : 187
business : 203
--------------------------------
3 分类, 39907 特征词
准确率为:0.9759036144578314
sports 精确率(跟预测分类比较):0.9894736842105263 召回率(跟真实分类比较):0.9842931937172775
auto 精确率(跟预测分类比较):0.9836956521739131 召回率(跟真实分类比较):0.9679144385026738
business 精确率(跟预测分类比较):0.9565217391304348 召回率(跟真实分类比较):0.9753694581280788
朴素贝叶斯文本分类java实现的更多相关文章
- Mahout朴素贝叶斯文本分类
Mahout朴素贝叶斯文本分类算法 Mahout贝叶斯分类器按照官方的说法,是按照<Tackling the PoorAssumptions of Naive Bayes Text Classi ...
- 朴素贝叶斯文本分类-在《红楼梦》作者鉴别的应用上(python实现)
朴素贝叶斯算法简单.高效.接下来我们来介绍其如何应用在<红楼梦>作者的鉴别上. 第一步,当然是先得有文本数据,我在网上随便下载了一个txt(当时急着交初稿...).分类肯定是要一个回合一个 ...
- 朴素贝叶斯文本分类(python代码实现)
朴素贝叶斯(naive bayes)法是基于贝叶斯定理与特征条件独立假设的分类方法. 优点:在数据较少的情况下仍然有效,可以处理多分类问题. 缺点:对入输入数据的准备方式较为敏感. 使用数据类型:标称 ...
- 朴素贝叶斯文本分类实现 python cherry分类器
贝叶斯模型在机器学习以及人工智能中都有出现,cherry分类器使用了朴素贝叶斯模型算法,经过简单的优化,使用1000个训练数据就能得到97.5%的准确率.虽然现在主流的框架都带有朴素贝叶斯模型算法,大 ...
- 详解使用EM算法的半监督学习方法应用于朴素贝叶斯文本分类
1.前言 对大量需要分类的文本数据进行标记是一项繁琐.耗时的任务,而真实世界中,如互联网上存在大量的未标注的数据,获取这些是容易和廉价的.在下面的内容中,我们介绍使用半监督学习和EM算法,充分结合大量 ...
- 利用朴素贝叶斯算法进行分类-Java代码实现
http://www.crocro.cn/post/286.html 利用朴素贝叶斯算法进行分类-Java代码实现 鳄鱼 3个月前 (12-14) 分类:机器学习 阅读(44) 评论(0) ...
- 朴素贝叶斯算法分析及java 实现
1. 先引入一个简单的例子 出处:http://www.ruanyifeng.com/blog/2013/12/naive_bayes_classifier.html 一.病人分类的例子 让我从一个例 ...
- Naive Bayes(朴素贝叶斯算法)[分类算法]
Naïve Bayes(朴素贝叶斯)分类算法的实现 (1) 简介: (2) 算法描述: (3) <?php /* *Naive Bayes朴素贝叶斯算法(分类算法的实现) */ /* *把. ...
- 芝麻HTTP:记scikit-learn贝叶斯文本分类的坑
基本步骤: 1.训练素材分类: 我是参考官方的目录结构: 每个目录中放对应的文本,一个txt文件一篇对应的文章:就像下面这样 需要注意的是所有素材比例请保持在相同的比例(根据训练结果酌情调整.不可比例 ...
随机推荐
- 深入理解计算机系统第二版习题解答CSAPP 2.19
在2.17的基础上完成下表: x 十六进制 T2U(x) -8 0x8 -3 0xD -2 0xE -1 0xF 0 0x0 5 0x5
- 我的第一篇——nginx+naxsi总结篇1
今天是我正式在Linux下安装nginx的第一天吧,搜索,查看,安装,这之间肯定是或多或少的遇到了很多的问题,不管是大的还是小的,都应该记录下来,或许以后还会用到,或许会帮到其他人. 首先,先说一下, ...
- Java栈实现
栈数组实现一:优点:入栈和出栈速度快,缺点:长度有限(有时候这也不能算是个缺点) public class Stack { private int top = -1; private Object[] ...
- gVim多标签页
我们一般使用的文本编辑器,如:editplus.ultraEdit等都是支持多标签页的,可以同时打开多个文件,方便切换,以前gVim只能打开多个窗口,或者一个窗口切出多个窗口来编辑,自从7.0以后Vi ...
- [MSDN]使用 REST 处理文件夹和文件
msdn: http://msdn.microsoft.com/zh-cn/library/dn292553.aspx 了解如何使用 SharePoint 2013 REST 界面对文件夹和文件执行基 ...
- RabbitMQ 原文译1.2--"Hello Word"
本系列文章均来自官网原文,属于个人翻译,如有雷同,权当个人归档,忽喷. .NET/C# RabbitMQ 客户端下载地址:https://github.com/rabbitmq/rabbitmq-do ...
- C# String 前面不足位数补零的方法 PadLeft
PadLeft(int totalWidth, char paddingChar) //在字符串左边用 paddingChar 补足 totalWidth 长度PadLeft(int totalWid ...
- 利用openssl进行RSA加密解密
openssl是一个功能强大的工具包,它集成了众多密码算法及实用工具.我们即可以利用它提供的命令台工具生成密钥.证书来加密解密文件,也可以在利用其提供的API接口在代码中对传输信息进行加密. RSA是 ...
- xadmin学习笔记(一)——编程准备
前言 xadmin是GitHub上的开源项目,它是Django admin的超强升级版,提供了强大的插件系统,丰富的内置功能,以及无与伦比的UI主题,使管理系统的实现变得异常简单.详情请参见官方网址. ...
- iOS 的一点理解(一) 代理delegate
做了一年的iOS,想记录自己对知识点的一点理解. 第一篇,想记录一下iOS中delegate(委托,也有人称作代理)的理解吧. 故名思议,delegate就是代理的含义, 一件事情自己不方便做,然后交 ...