一、什么是kNN算法

k邻近是指每个样本都可以用它最接近的k个邻居来代表。

核心思想:如果一个样本在特征空间中的k个最相邻的样本中大多数属于一个某类别,则该样本也属于这个类别。

二、将kNN封装成kNNClassifier

1、训练样本的特征在二维空间中的表示

  

2、kNN的训练过程如下图

  

3、完整代码(kNN.py)

import numpy as np
from math import sqrt
from collections import Counter
from metrics import accuracy_score class kNNClassifier():
def __init__(self, k):
"""初始化kNN分类器"""
assert k >= 1, "k must be valid"
self.k = k
self._x_train = None
self._y_train = None def fit(self, x_train, y_train):
"""根据训练集x_train和y_train训练kNN分类器"""
assert x_train.shape[0] == y_train.shape[0], \
"the size of x_train must be equal to the size of y_train"
assert x_train.shape[0] >= self.k, "the size of x_train must be at least k"
self._x_train = x_train
self._y_train = y_train
return self def predict(self, X_predict):
"""给定待预测数据集X_train,返回表示x_train的结果向量"""
assert self._x_train is not None and self._y_train is not None, \
"must fit before predict"
assert X_predict.shape[1] == self._x_train.shape[1] , \
"the feature number of X_predict must be equal to x_train"
y_predict = [self._predict(x) for x in X_predict]
return np.array(y_predict) def _predict(self, x):
"""给定待预测数据x,返回x预测的结果值"""
assert x.shape[0] == self._x_train.shape[1], \
"the feature number of x must be equal tu x_train"
distances = [sqrt(np.sum((x_train-x)**2)) for x_train in self._x_train]
nearest = np.argsort(distances)
topK_y = [self._y_train[i] for i in nearest[:self.k]]
votes = Counter(topK_y)
return votes.most_common(1)[0][0] def score(self, X_test, y_test):
"""根据数据集X_test 和y_test 得到当前模型的准确度"""
y_predict = self.predict(X_test)
return accuracy_score(y_test, y_predict) def __repr__(self):
return "kNN(k=%d)" % self.k if __name__ == "__main__":
x_train = np.array([[0.31864691, 0.99608349],
[0.8609734 , 0.40706129],
[0.86746155, 0.20136923],
[0.4346735 , 0.17677379],
[0.42842348, 0.68055183],
[0.70661963, 0.76155652],
[0.73379517, 0.6123456 ],
[0.68330672, 0.52193524],
[0.11192091, 0.07885633],
[0.99273292, 0.62484263]])
y_train = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
k = 6
x = np.array([0.756789,0.6123456])
knn = kNNClassifier(k)
knn.fit(x_train,y_train)
x_predict = x.reshape(1,-1)
print(knn.predict(x_predict))

三、测试结果

[1]

四、问题

1、如果直接将上面训练得到的模型直接放在真实环境中使用,但是模型没有得到验证,会造成模型很差,会有真实损失。

2、真实环境下很难拿到符合条件的数据去测试

解决办法:

1、将训练数据拿出一部分作为测试数据,通过测试数据直接判断模型好坏。

2、在模型进入真实环境前改进模型

1、train_test_split.py

import numpy as np

def train_test_split(X, Y, train_ratio=0.8, seed=None):
"""将数据X和Y按照train_ratio分割成x_train,y_train,x_test,y_test"""
assert X.shape[0] == Y.shape[0], "the size of X must equal to the size of Y"
assert 0.0 <= train_ratio <= 1.0, "train_ratio must be valid" if seed:
np.random.seed(seed) shuffled_indexes = np.random.permutation(len(X))
train_size = int(len(X) * train_ratio)
train_indexes = shuffled_indexes[:train_size]
test_indexes = shuffled_indexes[train_size:] x_train = X[train_indexes]
y_train = Y[train_indexes] x_test = X[test_indexes]
y_test = Y[test_indexes] return x_train,y_train,x_test,y_test

2、实际操作

2、从最终的结果来看,该模型与原始数据的标签的吻合达到100%。

五、scikit-learn中的train_test_split

 

仿scikit-learn模式写的kNN算法的更多相关文章

  1. 吴裕雄--天生自然python机器学习实战:K-NN算法约会网站好友喜好预测以及手写数字预测分类实验

    实验设备与软件环境 硬件环境:内存ddr3 4G及以上的x86架构主机一部 系统环境:windows 软件环境:Anaconda2(64位),python3.5,jupyter 内核版本:window ...

  2. Python 手写数字识别-knn算法应用

    在上一篇博文中,我们对KNN算法思想及流程有了初步的了解,KNN是采用测量不同特征值之间的距离方法进行分类,也就是说对于每个样本数据,需要和训练集中的所有数据进行欧氏距离计算.这里简述KNN算法的特点 ...

  3. 机器学习--kNN算法识别手写字母

    本文主要是用kNN算法对字母图片进行特征提取,分类识别.内容如下: kNN算法及相关Python模块介绍 对字母图片进行特征提取 kNN算法实现 kNN算法分析 一.kNN算法介绍 K近邻(kNN,k ...

  4. KNN算法识别手写数字

    需求: 利用一个手写数字“先验数据”集,使用knn算法来实现对手写数字的自动识别: 先验数据(训练数据)集: ♦数据维度比较大,样本数比较多. ♦ 数据集包括数字0-9的手写体. ♦每个数字大约有20 ...

  5. 基于OpenCV的KNN算法实现手写数字识别

    基于OpenCV的KNN算法实现手写数字识别 一.数据预处理 # 导入所需模块 import cv2 import numpy as np import matplotlib.pyplot as pl ...

  6. KNN 算法-实战篇-如何识别手写数字

    公号:码农充电站pro 主页:https://codeshellme.github.io 上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数 ...

  7. Python实现KNN算法及手写程序识别

    1.Python实现KNN算法 输入:inX:与现有数据集(1xN)进行比较的向量   dataSet:已知向量的大小m数据集(NxM)   个标签:数据集标签(1xM矢量)   k:用于比较的邻居数 ...

  8. Scikit Learn: 在python中机器学习

    转自:http://my.oschina.net/u/175377/blog/84420#OSC_h2_23 Scikit Learn: 在python中机器学习 Warning 警告:有些没能理解的 ...

  9. 【机器学习】机器学习入门01 - kNN算法

    0. 写在前面 近日加入了一个机器学习的学习小组,每周按照学习计划学习一个机器学习的小专题.笔者恰好近来计划深入学习Python,刚刚熟悉了其基本的语法知识(主要是与C系语言的差别),决定以此作为对P ...

随机推荐

  1. LeetCode_509.斐波那契数

    LeetCode-cn_509 509.斐波那契数 斐波那契数,通常用 F(n) 表示,形成的序列称为斐波那契数列.该数列由 0 和 1 开始,后面的每一项数字都是前面两项数字的和.也就是: F(0) ...

  2. centos下面配置key登录

    centos下需要配置使用key登录,并且要禁止root登录 下面的操作都是用root来设置的 1.添加新用户 例如用户名leisiyuan useradd leisiyuan 2.设置密码 pass ...

  3. 转 HTTP请求报文格式 GET和POST

    https://blog.csdn.net/h517604180/article/details/79802914 最近在做安卓客户端图片上传插件功能,供后台调用.其中涉及到了拼接HTTP请求报文,所 ...

  4. JS去重-删除连续重复的值

    function removeRepetition(str) { var result = "", unStr; for(var i=0,len=str.length;i<l ...

  5. struts2 2.5.16 通配符方式调用action中的方法报404

    1.问题描述 在struts.xml中配置用通配符方式调用action中的add()方法,访问 http://localhost:8080/Struts2Demo/helloworld_add.act ...

  6. Sklearn评估器选择

  7. 爬虫三之beautifulsoup

    基本使用 from bs4 import BeautifulSoup soup = BeautifulSoup(html#,'lxml','xml','html5lib') soup.prettify ...

  8. ActiveMQ学习教程/1.简要介绍与安装

    ActiveMQ学习教程(一)——简要介绍与安装 一.名词: 1.JMS:即Java消息服务(Java Message Service)应用程序接口,是一个Java平台中关于面向消息中间件(MOM)的 ...

  9. 堆排序 && Kth Largest Element in an Array

    堆排序 堆节点的访问 通常堆是通过一维数组来实现的.在数组起始位置为0的情形中: 父节点i的左子节点在位置(2*i+1); 父节点i的右子节点在位置(2*i+2); 子节点i的父节点在位置floor( ...

  10. Java本周总结1

    这两周我上认真的课应该就是李老师的课了/ 第一周主要跟我们讲述了java的发展史何java开发环境的搭建,带领我们走进了java,李老师的精彩讲述让我们对Java有了深刻的认识/. jdk下载安装包我 ...