重写轮子之 GaussionNB
我仿照sk-learn 中 GaussionNB 的结构, 重写了该算法的轮子,命名为 MyGaussionNB, 如下:
# !/usr/bin/python
# -*- coding:utf-8 -*-
"""
Reimplement Gaussion naive Bayes algorithm as a practice
"""
# Author: 相忠良(Zhong-Liang Xiang) <ugoood@163.com>
# Finished at June 3, 2017
import numpy as np
from sklearn import datasets, cross_validation
import math
import matplotlib.pyplot as plt
from sklearn import naive_bayes
def load_data():
iris = datasets.load_iris()
return cross_validation.train_test_split(iris.data, iris.target, test_size=0.50, random_state=0)
class MyGaussianNB:
"""
注意: 使用该分类器前, 必须把标签处理成 0,1,2,3.... 这样的形式
"""
class_prior_dic = {}
class_prior_arr = []
class_count_dic = {}
class_count_arr = []
theta_ = []
sigma_ = []
predict_label = [] # 最终预测值
def __init__(self):
pass
def fit(self, X, y):
"""
Fit Gaussian naive Bayes according to X, y
Parameters
----------
X : array-like, shape(n-samples, m-features)
X part of training data
y : array-like, shape(n-samples,)
labels of training data
Returns:
--------
self : object
Return self.
"""
# calculate class_prior and class_count-------------
dic = {}
for item in y:
if item in dic.keys():
dic[item] += 1
else:
dic[item] = 1
dic_temp = dic.copy()
self.class_count_dic = dic_temp
self.class_count_arr = dic_temp.values()
for item in dic:
dic[item] = float(dic[item]) / y.shape[0]
self.class_prior_dic = dic
self.class_prior_arr = dic.values()
# --------------------------------------------------
# 调用本类 私有方法
self.__cal_theta_sigma_arr(X_train, y_train)
def predict(self, X):
"""
Predict class labels of X
Parameters
----------
X : array-like, shape(n-samples, m-features)
X a set of test data
Returns:
--------
a list of class labels of X
"""
post_arr_matrix = []
for c in self.class_prior_dic.keys():
post_per_sample = []
for sample in X:
i = 0
temp = 0.0
for element in sample:
# 注意: 使用该分类器前, 必须把标签处理成 0,1,2,3.... 这样的形式
# 原因在 下面的 theta_[c][] 处
# 重要: 用 np.log(x)相加的形式, 因为有的概率值特别特别小, 导致后验概率为0
# log sum 越大, post 概率 越大!
# 注意: 我们并未采用 -np.log(x)的形式
temp = temp + np.log(self.__Gaussion_function(element, self.theta_[c][i], self.sigma_[c][i]))
i += 1
# print '在某类, 一个样例结束'
temp = temp + np.log(self.class_prior_dic[c]) # temp - log(p(c))
post_per_sample.append(temp) # 某类下, X 中所有 sample 的 post 概率, shape(n-samples,)
# print '某类, 所有样本概率', post_per_sample
post_arr_matrix.append(post_per_sample) # 各类下, X 中所有 sample 的 post 概率, shape(n-classes, n-samples)
self.predict_label = np.argmax(post_arr_matrix, 0) # 返回 matrix 每列最大值索引. 这里, 索引值恰好是每个 sample 的预测 label.
return self.predict_label
def score(self, X, y):
# 返回正确率
temp_1 = list(X)
temp = list(temp_1 == y)
return 1.0 * temp.count(True) / temp.__len__()
# 私有方法: 计算 每类 各列的 均值theta 和 标准差sigma
def __cal_theta_sigma_arr(self, X, y):
theta_arr = []
sigma_arr = []
xxx = [] # including (X,y)
for item in X:
xxx.append(list(item))
ii = 0
for item in xxx:
item.append(y[ii])
ii += 1
# 担心改了原数据
sss = np.array(xxx).copy()
ssss = np.array(xxx).copy()
for k in self.class_count_dic.keys():
row_mask = np.array(sss[:, -1] == k, dtype = bool) # 行网子
temp = sss[row_mask, :] # 用 行 网 子 !
theta_arr.append(np.mean(temp, axis = 0)) # axis=0 表示列
row_mask_1 = np.array(ssss[:, -1] == k, dtype = bool) # 行网子
temp_1 = ssss[row_mask_1, :] # 用 行 网 子 !
sigma_arr.append(np.std(temp_1, axis = 0))
self.theta_ = theta_arr
self.sigma_ = sigma_arr
return theta_arr, sigma_arr
# Gaussian function
def __Gaussion_function(self, x, theta, sigma): # private method
return np.exp(-(x - theta) ** 2 / (2 * sigma ** 2)) / (np.sqrt(2 * np.pi) * sigma)
X_train, X_test, y_train, y_test = load_data()
MGN = MyGaussianNB()
MGN.fit(X_train, y_train)
a = MGN.predict(X_test)
b = np.array(a)
# print b
# print X_test
print '预测值: ', a
print '实际值: ', y_test
print a == y_test
print 'MyGaussionNB 预测正确率: ', MGN.score(MGN.predict_label, y_test)
# sk-learn 中的 GaussionNB 的性能, 且和我的实现 比较一下, 验证我的 implementation 的正确性.
cls = naive_bayes.GaussianNB()
cls.fit(X_train, y_train)
result = cls.predict(X_test)
print 'sklearn 的 GaussionNB 预测正确率: ', MGN.score(result, y_test)
# 结果几乎完全一致, 但在 test_size=0.95 及 训练集更小时, 我的程序会出现问题 !
'''
下面是编程过程中留下的经验
'''
# 重要1: 判断column value真假,用mask,取想要rows的方法
# row_mask = np.array(a[:, -1] == 0, dtype=bool)
# print a[row_mask, :]
# print np.mean(a[mask, :], axis=0)
# 重要2: 提取字典的keys集合和values集合
# print MGN.class_count_dic.keys()
# print MGN.class_count_dic.values()
# 重要3: 用 np.log(x)相加的形式, 因为有的概率值特别特别小, 导致后验概率为0
# log sum 越大, post 概率 越大!
# 注意: 我们并未采用 -np.log(x)的形式
# 重要4: Numpy中找出array中最大值所对应的行和列
# a = np.array([[.5, 2, 0],
# [5, 3, 6],
# [.5, 1, 0]])
#
# re = np.where(a == np.max(a[:,1])) a中第一列最大元素 在a中的坐标
# print re
# 重要5: 找出 列 or 行 的最大值索引 np.argmax(a,0), a 是矩阵, 0:列, 1:行
# 重要6: 必须要有用于测试的小数据, 来探测每一个func的正确性.
# 下面是我用来测试的小矩阵. 不仅仅测试自己编写的函数,
# 还得对numpy, python 中的函数探测其功能和使用方法.
# a = np.array([[.5, 2, 0],
# [.25, 3, 6],
# [.51, 1, 0]])
#
# b = np.array([11, 22, 33])
# aa = np.array(zip(a, b))
#
# cc = [True, False, True, False, True, True]
# print cc.__len__()
# print cc.count(True)
################################################
# 以下内容是我编程过程中用于测试和探查的各种乱七八糟的代码
#
# 我抛弃了这种做法----->: Gaussion这种东西, 算出的值 极有可能非常小, 得用 -log 相加 处理.
# log sum 越小, post 概率 越大!
# 取而代之的是------->: 直接 log后 相加, 取和的最大值的 为 那个样例 应得的标签.
# def Gaussion_function(x, u, sig):
# return np.exp(-(x - u) ** 2 / (2 * sig ** 2)) / (math.sqrt(2 * math.pi) * sig)
# x1 = Gaussion_function(6., 5.006, 0.34894699)
# x2 = Gaussion_function(2.2, 3.418, 0.37719491)
# x3 = Gaussion_function(4., 1.464, 0.17176728)
# x4 = Gaussion_function(1., 0.244, 0.10613199)
# x1 = Gaussion_function(5., 4.99574468, 0.35247299)
# x2 = Gaussion_function(2.2, 3.418, 0.37719491)
# x3 = Gaussion_function(4., 1.464, 0.17176728)
# x4 = Gaussion_function(1., 0.244, 0.10613199)
# print 'x1 ', x1
# print 'x2 ', x2
# print 'x3 ', x3
# print 'x4 ', x4
#
# x1 = -np.log(x1)
# x2 = -np.log(x2)
# x3 = -np.log(x3)
# x4 = -np.log(x4)
#
# pc = -np.log(0.33098591549295775)
# print 'x1 ', x1
# print 'x2 ', x2
# print 'x3 ', x3
# print 'x4 ', x4
# print 'pc ', pc
# print "x1-x4 log sum:", x1 + x2 + x3 + x4 + pc
#
# print MGN.class_prior_dic
# print MGN.theta_
# [array([ 5.006, 3.418, 1.464, 0.244, 0. ]),
# array([ 5.93469388, 2.78163265, 4.26530612, 1.33265306, 1. ]),
# array([ 6.60408163, 2.97755102, 5.56122449, 2.01836735, 2. ])]
#
# print MGN.sigma_
# [array([ 0.34894699, 0.37719491, 0.17176728, 0.10613199, 0. ]),
# array([ 0.51608851, 0.30282579, 0.4684107 , 0.19207541, 0. ]),
# array([ 0.62562921, 0.32151764, 0.54802664, 0.26929497, 0. ])]
# print 'x=0, 均值为0, 方差为1', Gaussion_function(0, 0, 1)
# x1 = Gaussion_function(5.8, 5.006, 0.34894699)
# x2 = Gaussion_function(2.8, 3.418, 0.37719491)
# x3 = Gaussion_function(5.1, 1.464, 0.17176728)
# x4 = Gaussion_function(2.4, 0.244, 0.10613199)
# print 'x1 ', x1
# print 'x2 ', x2
# print 'x3 ', x3
# print 'x4 ', x4
# print "x1-x4乘积:", x1 * x2 * x3 * x4 * 1.0 # x1-x4乘积: 2.53455055621e-188
# print X_test
# print MGN.class_prior_dic.keys()
# print "标准差", MGN.sigma_
# print "均值", MGN.theta_
# print MGN.class_count_dic
# print MGN.class_prior_dic
# print 'theta: ', MGN.theta_[0][0]
# print 'sigma: ', MGN.sigma_
# print a.__len__()
# print len(X_test)*4
# print MGN.theta_
# print X_train
# print np.mean(X_train, axis=0)
# print MGN.class_count_dic.keys()
# print 'x_train: ', X_train
# print 'y_train: ', y_train
#
# print 'MGN.class_prior_arr: ', MGN.class_prior_arr
# print 'MGN.class_prior_dic: ', MGN.class_prior_dic
# print 'MGN.class_count_arr: ', MGN.class_count_arr
# print 'MGN.class_count_dic: ', MGN.class_count_dic
# print np.argmax(a, 0)
#
# re = np.where(a == np.max(a[:, 0]))
# print re
# print int(re[0])
# print a[[True,False,True],:]
重写轮子之 GaussionNB的更多相关文章
- 重写轮子之 ID3
这是半成品, 已完成了 fit() 部分, 形成了包含一棵完整树的 node 对象. 后续工作是需解析该 node对象, 完成 predict() 工作. # !/usr/bin/python # - ...
- 重写轮子之 kNN
# !/usr/bin/python # -*- coding:utf-8 -*- """ Re-implement kNN algorithm as a practic ...
- 【转】C# 重写WndProc 拦截 发送 系统消息 + windows消息常量值(1)
C# 重写WndProc 拦截 发送 系统消息 + windows消息常量值(1) #region 截获消息 /// 截获消息 处理XP不能关机问题 protected ...
- Asp.net Mvc 请求是如何到达 MvcHandler的——UrlRoutingModule、MvcRouteHandler分析,并造个轮子
这个是转载自:http://www.cnblogs.com/keyindex/archive/2012/08/11/2634005.html(那个比较容易忘记,希望博主不要生气的) 前言 本文假定读者 ...
- 拆解轮子之XRecyclerView
简介 这个轮子是对RecyclerView的封装,主要完成了下拉刷新.上拉加载更多.RecyclerView头部.在我的Material Design学习项目中使用到了项目地址,感觉还不错.趁着毕业答 ...
- 跨平台技术实践案例: 用 reactxp 重写墨刀的移动端
Authors: Gao Cong, Perry Poon Illustrators: Shena Bian April 20, 2019 重新编写,又一次,我们又一次重新编写了移动端应用和移动端 ...
- 星级评分原理 N次重写的分析
使用的是雪碧图,用的软件是CSS Sprite Tools 第一次实现与分析: <!DOCTYPE html> <html> <head> <meta cha ...
- [18/11/29] 继承(extends)和方法的重写(override,不是重载)
一.何为继承?(对原有类的扩充) 继承让我们更加容易实现类的扩展. 比如,我们定义了人类,再定义Boy类就只需要扩展人类即可.实现了代码的重用,不用再重新发明轮子(don’t reinvent w ...
- C# 重写WndProc 拦截 发送 系统消息 + windows消息常量值
接收拦截+发送消息 对于处理所有消息.net 提供了wndproc进行重写 WndProc(ref Message m)protected override void WndProc(ref Mess ...
随机推荐
- 在ABPZERO中,扩展实体的方法。
内容 介绍 扩展的抽象实体 将新属性添加给用户 添加迁移 在界面上显示地址 在用户编辑/添加功能中添加地址 扩展的非抽象类实体 获得版本的派生实体 添加迁移 在界面上添加价格 在创建/编辑版本功能中加 ...
- AOV网络和Kahn算法拓扑排序
1.AOV与DAG 活动网络可以用来描述生产计划.施工过程.生产流程.程序流程等工程中各子工程的安排问题. 一般一个工程可以分成若干个子工程,这些子工程称为活动(Activity).完成了这些活动 ...
- POJ-1258 Agri-Net---MST裸题Prim
题目链接: https://vjudge.net/problem/POJ-1258 题目大意: 求MST 思路: 由于给的是邻接矩阵,直接prim算法 #include<iostream> ...
- POJ-2586 Y2K Accounting Bug贪心,区间盈利
题目链接: https://vjudge.net/problem/POJ-2586 题目大意: MS公司(我猜是微软)遇到了千年虫的问题,导致数据大量数据丢失.比如财务报表.现在知道这个奇特的公司每个 ...
- 编程基础学习JS的入门教程
将JavaScript 插入网页的方法 使用<script>标签在网页中插入Javascript代码. 插入JavaScript 与在网页中插入CSS的方式相似.使用下面的代码可以在网页中 ...
- 关于ES6 的对象解构赋值
之 前写了关于ES6数组的解构 现在 go on ; 解构不仅可以用于数组,还可以用于对象: 对象的解构和数组有一个重要的不同.数组的元素是按次序排列的,变量的取值是由他的位置决定的:而对象的属性没有 ...
- ReactNative Android之原生UI组件动态addView不显示问题解决
ReactNative Android之原生UI组件动态addView不显示问题解决 版权声明:本文为博主原创文章,未经博主允许不得转载. 转载请表明出处:http://www.cnblogs.com ...
- Redis常用命令--SortedSet
SortedSet是一个类似于Set的集合数据类型,里面的每个字符串元素都关联到一个score(整数或浮点数),并且总是通过score来进行排序着. 并且可以取得一定范围内的元素. 在Redis中大概 ...
- [BZOJ 2169]连边
Description 有N个点(编号1到N)组成的无向图,已经为你连了M条边.请你再连K条边,使得所有的点的度数都是偶数.求有多少种连的方法.要求你连的K条边中不能有重边,但和已经连好的边可以重.不 ...
- Prison 监狱
[题目描述]Caima 王国中有一个奇怪的监狱,这个监狱一共有 P 个牢房,这些牢房一字排开,第 i 个仅挨着第 i+1 个(最后一个除外).现在正好牢房是满的.上级下发了一个释放名单,要求每天释放名 ...