本文实现了一个KNN算法,准备用作词频统计改进版本之中,这篇博文是从我另一个刚开的博客中copy过来的。

  KNN算法是一个简单的分类算法,它的动机特别简单:与一个样本点距离近的其他样本点绝大部分属于什么类别,这个样本就属于什么类别,算法的主要步骤如下:

1.计算新样本点与已知类别数据集中样本点的距离。

2.取前K个距离最近的(最相似的)点。

3.统计这K个点所在类别出现的频率。

4.选择出现频率最高的点作为新样本点的类别。

  KNN算法的优点在于一般精度高,对于异常的噪音数据不敏感。KNN一个明显的问题是当属于某个类别c的数据点在已知类别数据集中大量存在时,一个待预测的样本点的前K个最近的点总是存在很多类别c的点,解决这个问题的方法是计算类别的频率时,按照距离进行加权,使得离得近的点比离的远一些点更能影响类别频率排序的结果。

  KNN算法中K值的选定非常影响最后的结果,通常可以使用交叉检验来选取合适的k。下面是仿照sikit-learn的KNeighborsClassifier的调用方式写的KNN:

class KNN_Classifier:
def __init__(self, k):
self.k = k
self.train_data = None
self.train_labels = None def fit(self, train_data, train_labels):
self.train_data = normalize(train_data)
self.train_labels = train_labels def predict(self, test_data):
if (self.train_data is None) | (self.train_labels is None):
print 'fit train data first!'
pre_labels = []
train_data_size = len(self.train_labels)
# for every data point in test set
for x in normalize(test_data):
# calculate distance
sq_diff_mat = (np.tile(x, (train_data_size, 1)) - self.train_data) ** 2
distances = np.sum(sq_diff_mat, axis=1) ** .5
# get lowest k distances
sorted_dis_indices = distances.argsort()[0: self.k]
# count the times class occur
class_counts = {}
for idx in sorted_dis_indices:
label = labels[idx]
class_counts[label] = class_counts.get(label, 0) + 1
# sort class_count dict
sorted_class = sorted(class_counts.items(), key=lambda d: d[1], reverse=True)
# add max voted class to pre_labels
pre_labels.append(sorted_class[0][0])
return pre_labels

测试代码如下所示

    #  load data
data, labels = load_dating_data()
# split data into train set and test set
split_pos = int(len(labels) * 0.9)
train_data = normalize(data[0: split_pos])
train_labels = labels[0: split_pos]
test_data = normalize(data[split_pos: len(labels)])
test_labels = labels[split_pos: len(labels)]
# init classifier
classifier = KNN_Classifier(50)
# fit classifier
classifier.fit(train_data, train_labels)
# predict the class of test data and count error points
error_num = (test_labels != classifier.predict(test_data)).sum()
# calculate error rate and print
print 'error rate is %f' % (error_num * 1.0 / len(test_labels))

  这里使用machine learning in action中的提供的dating data,将90%的数据用作训练数据集,10%的数据用作测试集,选取k=50算法得到的错误率为0.08。

  下面我们来看一下如何使用scikit-learn提供的KNN实现。

scikit-learn中主要提供了2种KNN,KNeighborsClassifier和RadiusNeighborsClassifier。前者使用指定的前K个近邻来预测新样本点的类别,后者则是根据一个指定的半径,使用半径内所有的点来预测。创建一个KNN分类器时有这些重要的参数:

n_neighbors/radius: 使用近邻的个数K或半径

algorithm: 实现KNN的具体算法,如kd树等

metric: 距离的计算方法,默认为'minkowski'表示minkowski距离

p: minkowski距离中的参数p,p=1表示manhattan distance(l1范数),p=2表示euclidean_distance (l2范数)

  这里只列出了几个常用的参数,具体的可以参考链接。使用的方法和上面的测试代码类似,只需要将classifier替换成scikit-learn的实现就可以了。

KNN python实践的更多相关文章

  1. 机器学习算法与Python实践之(二)支持向量机(SVM)初级

    机器学习算法与Python实践之(二)支持向量机(SVM)初级 机器学习算法与Python实践之(二)支持向量机(SVM)初级 zouxy09@qq.com http://blog.csdn.net/ ...

  2. Python实践:开篇

    一.概述 Python实践 是应用Python解决实际问题的案例集合,这些案例中的Python应用通常 功能各异.大小不一. 该系列文章是本人应用Python的实践总结,会不定期更新. 二.目录 Py ...

  3. Python实践之(七)逻辑回归(Logistic Regression)

    机器学习算法与Python实践之(七)逻辑回归(Logistic Regression) zouxy09@qq.com http://blog.csdn.net/zouxy09 机器学习算法与Pyth ...

  4. 机器学习算法与Python实践之(四)支持向量机(SVM)实现

    机器学习算法与Python实践之(四)支持向量机(SVM)实现 机器学习算法与Python实践之(四)支持向量机(SVM)实现 zouxy09@qq.com http://blog.csdn.net/ ...

  5. 机器学习算法与Python实践之(三)支持向量机(SVM)进阶

    机器学习算法与Python实践之(三)支持向量机(SVM)进阶 机器学习算法与Python实践之(三)支持向量机(SVM)进阶 zouxy09@qq.com http://blog.csdn.net/ ...

  6. MapReduce 原理与 Python 实践

    MapReduce 原理与 Python 实践 1. MapReduce 原理 以下是个人在MongoDB和Redis实际应用中总结的Map-Reduce的理解 Hadoop 的 MapReduce ...

  7. 机器学习算法与Python实践之(五)k均值聚类(k-means)

    机器学习算法与Python实践这个系列主要是参考<机器学习实战>这本书.因为自己想学习Python,然后也想对一些机器学习算法加深下了解,所以就想通过Python来实现几个比较常用的机器学 ...

  8. KNN Python实现

    KNN Python实现 ''' k近邻(kNN)算法的工作机制比较简单,根据某种距离测度找出距离给定待测样本距离最小的k个训练样本,根据k个训练样本进行预测. 分类问题:k个点中出现频率最高的类别作 ...

  9. (转) K-Means聚类的Python实践

    本文转自: http://python.jobbole.com/87343/ K-Means聚类的Python实践 2017/02/11 · 实践项目 · K-means, 机器学习 分享到:1 原文 ...

随机推荐

  1. Qt读取文件时中文乱码问题

    在默认情况下,Qt 以 Unicode 格式处理文本字符,因此,字符本身是不会有问题的.之所以出现乱码,原因在于 Qt 不知道将 Unicode 字符以何种方式显示出来.        文本文件含有简 ...

  2. PAT A1076 Forwards on Weibo (30 分)——图的bfs

    Weibo is known as the Chinese version of Twitter. One user on Weibo may have many followers, and may ...

  3. linux公钥和私钥生成

    1. 路径 cd /root/.ssh 2. 输入命令 ssh-keygen -t rsa  按三次回车 3. 会依次生成  id_rsa 私钥和 id_rsa.pub 公钥 [root@insure ...

  4. TP4212 FM9836C

    今天拆了两个充电宝,发现充电宝的电路是由一个集成芯片控制的.芯片型号:TP4212, FM9836C,

  5. android 通讯类资料整理

    https://github.com/koush/AndroidAsync(websocket) https://github.com/loopj/android-async-http http:// ...

  6. java通过反射拷贝两个对象的同名同类型变量

    深拷贝和浅拷贝 首先对象的复制分为深拷贝和浅拷贝,关于这两者的区别,简单来说就是对于对象的引用,在拷贝的时候,是否会新开辟一块内存,还是直接复制引用. 两者的比较也有很多,具体可以看这篇文章: htt ...

  7. LiveCharts文档-3开始-5序列Series

    原文:LiveCharts文档-3开始-5序列Series LiveCharts文档-3开始-5序列Series Strokes和Fills 笔触和填充 所有的Series都有笔触和填充属来处理颜色, ...

  8. css实现按钮固定在底部

    实现类似如下图的功能: 采用如下的样式来控制:

  9. H+ 后台主题UI框架

    十年河东,十年河西,莫欺少年穷 学无止境,精益求精 今天得到了一个非常完美的后端管理系统框架:H+ 后台主题UI框架 H+ 后台主题UI框架 H+是一个完全响应式,基于Bootstrap3.3.6最新 ...

  10. Luogu P1337 [JSOI2004]平衡点 / 吊打XXX

    一道入门模拟退火的经典题,还是很考验RP的 首先我们发现神TM这道题又和物理扯上了关系,其实是一道求广义费马点的题目 首先我们可以根据物理知识得到,当系统处于平衡状态时,系统的总能量最小 又此时系统的 ...