仿scikit-learn模式写的kNN算法
一、什么是kNN算法
k邻近是指每个样本都可以用它最接近的k个邻居来代表。
核心思想:如果一个样本在特征空间中的k个最相邻的样本中大多数属于一个某类别,则该样本也属于这个类别。
二、将kNN封装成kNNClassifier
1、训练样本的特征在二维空间中的表示
、
2、kNN的训练过程如下图
3、完整代码(kNN.py)
import numpy as np
from math import sqrt
from collections import Counter
from metrics import accuracy_score class kNNClassifier():
def __init__(self, k):
"""初始化kNN分类器"""
assert k >= 1, "k must be valid"
self.k = k
self._x_train = None
self._y_train = None def fit(self, x_train, y_train):
"""根据训练集x_train和y_train训练kNN分类器"""
assert x_train.shape[0] == y_train.shape[0], \
"the size of x_train must be equal to the size of y_train"
assert x_train.shape[0] >= self.k, "the size of x_train must be at least k"
self._x_train = x_train
self._y_train = y_train
return self def predict(self, X_predict):
"""给定待预测数据集X_train,返回表示x_train的结果向量"""
assert self._x_train is not None and self._y_train is not None, \
"must fit before predict"
assert X_predict.shape[1] == self._x_train.shape[1] , \
"the feature number of X_predict must be equal to x_train"
y_predict = [self._predict(x) for x in X_predict]
return np.array(y_predict) def _predict(self, x):
"""给定待预测数据x,返回x预测的结果值"""
assert x.shape[0] == self._x_train.shape[1], \
"the feature number of x must be equal tu x_train"
distances = [sqrt(np.sum((x_train-x)**2)) for x_train in self._x_train]
nearest = np.argsort(distances)
topK_y = [self._y_train[i] for i in nearest[:self.k]]
votes = Counter(topK_y)
return votes.most_common(1)[0][0] def score(self, X_test, y_test):
"""根据数据集X_test 和y_test 得到当前模型的准确度"""
y_predict = self.predict(X_test)
return accuracy_score(y_test, y_predict) def __repr__(self):
return "kNN(k=%d)" % self.k if __name__ == "__main__":
x_train = np.array([[0.31864691, 0.99608349],
[0.8609734 , 0.40706129],
[0.86746155, 0.20136923],
[0.4346735 , 0.17677379],
[0.42842348, 0.68055183],
[0.70661963, 0.76155652],
[0.73379517, 0.6123456 ],
[0.68330672, 0.52193524],
[0.11192091, 0.07885633],
[0.99273292, 0.62484263]])
y_train = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
k = 6
x = np.array([0.756789,0.6123456])
knn = kNNClassifier(k)
knn.fit(x_train,y_train)
x_predict = x.reshape(1,-1)
print(knn.predict(x_predict))
三、测试结果
[1]
四、问题
1、如果直接将上面训练得到的模型直接放在真实环境中使用,但是模型没有得到验证,会造成模型很差,会有真实损失。
2、真实环境下很难拿到符合条件的数据去测试
解决办法:
1、将训练数据拿出一部分作为测试数据,通过测试数据直接判断模型好坏。
2、在模型进入真实环境前改进模型
1、train_test_split.py
import numpy as np def train_test_split(X, Y, train_ratio=0.8, seed=None):
"""将数据X和Y按照train_ratio分割成x_train,y_train,x_test,y_test"""
assert X.shape[0] == Y.shape[0], "the size of X must equal to the size of Y"
assert 0.0 <= train_ratio <= 1.0, "train_ratio must be valid" if seed:
np.random.seed(seed) shuffled_indexes = np.random.permutation(len(X))
train_size = int(len(X) * train_ratio)
train_indexes = shuffled_indexes[:train_size]
test_indexes = shuffled_indexes[train_size:] x_train = X[train_indexes]
y_train = Y[train_indexes] x_test = X[test_indexes]
y_test = Y[test_indexes] return x_train,y_train,x_test,y_test
2、实际操作
2、从最终的结果来看,该模型与原始数据的标签的吻合达到100%。
五、scikit-learn中的train_test_split
仿scikit-learn模式写的kNN算法的更多相关文章
- 吴裕雄--天生自然python机器学习实战:K-NN算法约会网站好友喜好预测以及手写数字预测分类实验
实验设备与软件环境 硬件环境:内存ddr3 4G及以上的x86架构主机一部 系统环境:windows 软件环境:Anaconda2(64位),python3.5,jupyter 内核版本:window ...
- Python 手写数字识别-knn算法应用
在上一篇博文中,我们对KNN算法思想及流程有了初步的了解,KNN是采用测量不同特征值之间的距离方法进行分类,也就是说对于每个样本数据,需要和训练集中的所有数据进行欧氏距离计算.这里简述KNN算法的特点 ...
- 机器学习--kNN算法识别手写字母
本文主要是用kNN算法对字母图片进行特征提取,分类识别.内容如下: kNN算法及相关Python模块介绍 对字母图片进行特征提取 kNN算法实现 kNN算法分析 一.kNN算法介绍 K近邻(kNN,k ...
- KNN算法识别手写数字
需求: 利用一个手写数字“先验数据”集,使用knn算法来实现对手写数字的自动识别: 先验数据(训练数据)集: ♦数据维度比较大,样本数比较多. ♦ 数据集包括数字0-9的手写体. ♦每个数字大约有20 ...
- 基于OpenCV的KNN算法实现手写数字识别
基于OpenCV的KNN算法实现手写数字识别 一.数据预处理 # 导入所需模块 import cv2 import numpy as np import matplotlib.pyplot as pl ...
- KNN 算法-实战篇-如何识别手写数字
公号:码农充电站pro 主页:https://codeshellme.github.io 上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数 ...
- Python实现KNN算法及手写程序识别
1.Python实现KNN算法 输入:inX:与现有数据集(1xN)进行比较的向量 dataSet:已知向量的大小m数据集(NxM) 个标签:数据集标签(1xM矢量) k:用于比较的邻居数 ...
- Scikit Learn: 在python中机器学习
转自:http://my.oschina.net/u/175377/blog/84420#OSC_h2_23 Scikit Learn: 在python中机器学习 Warning 警告:有些没能理解的 ...
- 【机器学习】机器学习入门01 - kNN算法
0. 写在前面 近日加入了一个机器学习的学习小组,每周按照学习计划学习一个机器学习的小专题.笔者恰好近来计划深入学习Python,刚刚熟悉了其基本的语法知识(主要是与C系语言的差别),决定以此作为对P ...
随机推荐
- 解析JSON有俩种方式:JSONObject和GSON
JSONObject: //JSONObject解析JSON文件 private void parseJSONWithJSONObject(String json_data) { try { JSON ...
- multiple users to one ec2 instance setup
http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/managing-users.html usually when use pem file as ...
- PixelShuffle
- Python基础面试题库
Python基础面试题库 Python是一门学习曲线较为容易的编程语言,随着人工智能时代的到来,Python迎来了新一轮的高潮.目前,国内知乎.网易(游戏).腾讯(某些网站).搜狐(邮箱).金山. ...
- linux下面用Mingw编译libx264
linux下面用Mingw编译libx264 首先要先安装好mingw 我用的是Ubuntu 编译ffmpeg的时候 ,官方上面有一个自动化脚本能够把mingw安装好 这里就不说了 新版本的libx2 ...
- leetcode-mid-design-297. Serialize and Deserialize Binary Tree¶-NO -??
mycode 将list转换成树的时候没有思路 参考: deque 是双边队列(double-ended queue),具有队列和栈的性质,在 list 的基础上增加了移动.旋转和增删等 class ...
- iOS堆栈内存区别
堆和栈的区别: · 1> 堆空间的内存是动态分配的,一般存放对象,并且需要手动释放内存. · 2> 栈空间的内存由系统自动分配,一般存放局部变量等,不需要手动管理内存. 接下来我将从以下几 ...
- C语言字符数组详解
字符串的存储方式有字符数组和字符指针,我们先来看看字符数组. 因为字符串是由多个字符组成的序列,所以要想存储一个字符串,可以先把它拆成一个个字符,然后分别对这些字符进行存储,即通过字符数组存储.字符数 ...
- 2、node-webkit运行web应用,node-webkit把web应用打包成桌面应用
下面我通过一个简单的demo来介绍怎么样把一个web应用打包成一个可执行文件(这里只介绍windows环境) 首先新建一个index.html文件,作为我们这个demo的入口页面,我们暂且就把这个页面 ...
- 细数EDM营销中存在的两大盲点
国庆节了,祝大家国庆快乐,转眼博客至今已有三年了.下面博主为大家介绍EDM营销中存在的两大盲点,供大家参考. 一是忽略用户友好.用户友好策略是Email营销成功的关键要素,具体包括内容友好策略.方式友 ...