08.手写KNN算法测试
导入库
import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt
导入数据
iris = datasets.load_iris()
数据准备
X = iris.data
y = iris.target
X.shape, y.shape
((150, 4), (150,))
数据分割(28开)
# 因为训练集矩阵和标签向量是分割的,不能单独对某一个进行乱序
# 需要将其合并整体乱序再分割
X_join_y = np.hstack([X, y.reshape(-1,1)])
# 随机,导致每次数据分割结果都会改变
# 如果有debug需求,需要保证每次运行的分割结果一致
# 则需要对random进行seed设置
np.random.seed(1)
np.random.shuffle(X_join_y)
train,test = np.vsplit(X_join_y, [int(0.8*len(X_join_y))])
train.shape,test.shape
((120, 5), (30, 5))
准备data和target
# X_train, y_train, X_test, y_test 成功拿到了训练集(数据+标签)和测试集(数据+标签)
X_train = train[:,0:4]
y_train = train[:,-1]
X_test = test[:,0:4]
y_test = test[:,-1]
KNN手写算法
import numpy as np
from math import sqrt
from collections import Counter
class KNNClassifier: def __init__(self, k):
# 初始化KNN分类器
self.k = k
self._X_train = None
self._y_train = None def fit(self, X_train, y_train):
# 根据训练集X_train, Y_train训练分类器
self._X_train = X_train
self._y_train = y_train
return self def predict(self, X_predict):
# 给定待遇测的数据集X_predict,返回表示X_predict的结果向量
y_predict = [self._predict(x) for x in X_predict]
return np.array(y_predict) def _predict(self, x):
# 给定单个待遇测数据x,返回x的预测结果值
distances = [sqrt(np.sum((x_train - x) ** 2)) for x_train in self._X_train]
nearest = np.argsort(distances)
topK_y = [self._y_train[i] for i in nearest[:self.k]]
votes = Counter(topK_y)
return votes.most_common(1)[0][0] def __repr__(self):
return "KNN=(%d)" % self.k
from sklearn.model_selection import train_test_split result = train_test_split(X, y)
result
[array([[7.2, 3. , 5.8, 1.6],
[5.4, 3.9, 1.3, 0.4],
[6.5, 3.2, 5.1, 2. ],
[6.1, 3. , 4.6, 1.4],
[4.6, 3.2, 1.4, 0.2],
[6.9, 3.2, 5.7, 2.3],
[6.1, 2.8, 4. , 1.3],
[5.7, 3. , 4.2, 1.2],
[5.8, 2.7, 4.1, 1. ],
[5.5, 2.5, 4. , 1.3],
[5.7, 2.5, 5. , 2. ],
[4.6, 3.4, 1.4, 0.3],
[5.9, 3.2, 4.8, 1.8],
[6.3, 2.9, 5.6, 1.8],
[6.8, 3. , 5.5, 2.1],
[6.4, 2.7, 5.3, 1.9],
[6. , 2.9, 4.5, 1.5],
[6. , 2.2, 4. , 1. ],
[4.8, 3. , 1.4, 0.1],
[5.6, 2.5, 3.9, 1.1],
[7.1, 3. , 5.9, 2.1],
[6.7, 3.3, 5.7, 2.1],
[5.5, 2.6, 4.4, 1.2],
[6.3, 3.3, 4.7, 1.6],
[6.7, 3.1, 4.7, 1.5],
[4.3, 3. , 1.1, 0.1],
[4.8, 3.4, 1.9, 0.2],
[6.7, 3.3, 5.7, 2.5],
[6. , 2.7, 5.1, 1.6],
[6.5, 3. , 5.5, 1.8],
[4.9, 2.5, 4.5, 1.7],
[5. , 3.5, 1.3, 0.3],
[5.9, 3. , 4.2, 1.5],
[5.5, 2.4, 3.8, 1.1],
[6.2, 2.2, 4.5, 1.5],
[6.3, 2.7, 4.9, 1.8],
[4.4, 3. , 1.3, 0.2],
[7.7, 3. , 6.1, 2.3],
[7. , 3.2, 4.7, 1.4],
[6.4, 2.8, 5.6, 2.2],
[5.7, 2.8, 4.5, 1.3],
[6.4, 2.9, 4.3, 1.3],
[5.6, 3. , 4.1, 1.3],
[6.3, 2.8, 5.1, 1.5],
[4.9, 3.6, 1.4, 0.1],
[6. , 3.4, 4.5, 1.6],
[5.7, 4.4, 1.5, 0.4],
[4.8, 3. , 1.4, 0.3],
[5.4, 3.7, 1.5, 0.2],
[5.4, 3.4, 1.5, 0.4],
[5. , 2.3, 3.3, 1. ],
[6.9, 3.1, 4.9, 1.5],
[5.1, 3.8, 1.9, 0.4],
[6.4, 2.8, 5.6, 2.1],
[5.1, 3.8, 1.5, 0.3],
[5. , 3.4, 1.5, 0.2],
[5.1, 3.3, 1.7, 0.5],
[5.2, 2.7, 3.9, 1.4],
[6.1, 2.6, 5.6, 1.4],
[7.7, 2.8, 6.7, 2. ],
[5.8, 2.7, 5.1, 1.9],
[6.8, 2.8, 4.8, 1.4],
[4.4, 3.2, 1.3, 0.2],
[5.3, 3.7, 1.5, 0.2],
[6.9, 3.1, 5.4, 2.1],
[5.1, 2.5, 3. , 1.1],
[5.7, 2.8, 4.1, 1.3],
[6.4, 3.1, 5.5, 1.8],
[6.2, 3.4, 5.4, 2.3],
[5.8, 2.7, 5.1, 1.9],
[6.3, 2.5, 4.9, 1.5],
[5.8, 2.6, 4. , 1.2],
[4.6, 3.1, 1.5, 0.2],
[4.9, 3.1, 1.5, 0.2],
[5.6, 2.9, 3.6, 1.3],
[5.1, 3.7, 1.5, 0.4],
[5. , 3.2, 1.2, 0.2],
[6.5, 3. , 5.8, 2.2],
[7.3, 2.9, 6.3, 1.8],
[5.2, 3.4, 1.4, 0.2],
[4.5, 2.3, 1.3, 0.3],
[5.5, 2.3, 4. , 1.3],
[6.5, 3. , 5.2, 2. ],
[5.5, 2.4, 3.7, 1. ],
[7.6, 3. , 6.6, 2.1],
[5. , 3.6, 1.4, 0.2],
[5.9, 3. , 5.1, 1.8],
[6.3, 2.5, 5. , 1.9],
[6.1, 3. , 4.9, 1.8],
[4.9, 3. , 1.4, 0.2],
[6.7, 3. , 5.2, 2.3],
[5.1, 3.5, 1.4, 0.3],
[6.3, 2.3, 4.4, 1.3],
[4.4, 2.9, 1.4, 0.2],
[6.8, 3.2, 5.9, 2.3],
[5.1, 3.8, 1.6, 0.2],
[7.2, 3.6, 6.1, 2.5],
[5.7, 3.8, 1.7, 0.3],
[5. , 2. , 3.5, 1. ],
[5. , 3. , 1.6, 0.2],
[4.8, 3.4, 1.6, 0.2],
[4.8, 3.1, 1.6, 0.2],
[6.7, 3.1, 5.6, 2.4],
[5.8, 2.8, 5.1, 2.4],
[5.8, 4. , 1.2, 0.2],
[6.1, 2.8, 4.7, 1.2],
[5.4, 3.9, 1.7, 0.4],
[6.5, 2.8, 4.6, 1.5],
[4.9, 3.1, 1.5, 0.1],
[5.4, 3.4, 1.7, 0.2],
[4.9, 2.4, 3.3, 1. ],
[5.1, 3.4, 1.5, 0.2]]),
array([[6.2, 2.9, 4.3, 1.3],
[6.7, 3. , 5. , 1.7],
[5.2, 4.1, 1.5, 0.1],
[5.7, 2.6, 3.5, 1. ],
[7.4, 2.8, 6.1, 1.9],
[5.6, 3. , 4.5, 1.5],
[6.9, 3.1, 5.1, 2.3],
[6. , 2.2, 5. , 1.5],
[5.5, 3.5, 1.3, 0.2],
[6.7, 2.5, 5.8, 1.8],
[7.2, 3.2, 6. , 1.8],
[6. , 3. , 4.8, 1.8],
[5.2, 3.5, 1.5, 0.2],
[5.1, 3.5, 1.4, 0.2],
[5. , 3.3, 1.4, 0.2],
[5.6, 2.8, 4.9, 2. ],
[5.6, 2.7, 4.2, 1.3],
[5. , 3.5, 1.6, 0.6],
[7.9, 3.8, 6.4, 2. ],
[6.3, 3.4, 5.6, 2.4],
[5. , 3.4, 1.6, 0.4],
[6.2, 2.8, 4.8, 1.8],
[5.4, 3. , 4.5, 1.5],
[5.5, 4.2, 1.4, 0.2],
[4.6, 3.6, 1. , 0.2],
[6.1, 2.9, 4.7, 1.4],
[6.4, 3.2, 5.3, 2.3],
[5.7, 2.9, 4.2, 1.3],
[7.7, 2.6, 6.9, 2.3],
[7.7, 3.8, 6.7, 2.2],
[6.3, 3.3, 6. , 2.5],
[5.8, 2.7, 3.9, 1.2],
[6.6, 2.9, 4.6, 1.3],
[4.7, 3.2, 1.6, 0.2],
[6.7, 3.1, 4.4, 1.4],
[6.4, 3.2, 4.5, 1.5],
[4.7, 3.2, 1.3, 0.2],
[6.6, 3. , 4.4, 1.4]]),
array([2, 0, 2, 1, 0, 2, 1, 1, 1, 1, 2, 0, 1, 2, 2, 2, 1, 1, 0, 1, 2, 2,
1, 1, 1, 0, 0, 2, 1, 2, 2, 0, 1, 1, 1, 2, 0, 2, 1, 2, 1, 1, 1, 2,
0, 1, 0, 0, 0, 0, 1, 1, 0, 2, 0, 0, 0, 1, 2, 2, 2, 1, 0, 0, 2, 1,
1, 2, 2, 2, 1, 1, 0, 0, 1, 0, 0, 2, 2, 0, 0, 1, 2, 1, 2, 0, 2, 2,
2, 0, 2, 0, 1, 0, 2, 0, 2, 0, 1, 0, 0, 0, 2, 2, 0, 1, 0, 1, 0, 0,
1, 0]),
array([1, 1, 0, 1, 2, 1, 2, 2, 0, 2, 2, 2, 0, 0, 0, 2, 1, 0, 2, 2, 0, 2,
1, 0, 0, 1, 2, 1, 2, 2, 2, 1, 1, 0, 1, 1, 0, 1])]
my_knn_clf = KNNClassifier(k=3)
my_knn_clf.fit(result[0], result[2])
KNN=(3)
y_predict = my_knn_clf.predict(result[1])
sum(y_predict == result[3])
sum(y_predict == result[3])/len(result[3])
08.手写KNN算法测试的更多相关文章
- [纯C#实现]基于BP神经网络的中文手写识别算法
效果展示 这不是OCR,有些人可能会觉得这东西会和OCR一样,直接进行整个字的识别就行,然而并不是. OCR是2维像素矩阵的像素数据.而手写识别不一样,手写可以把用户写字的笔画时间顺序,抽象成一个维度 ...
- 用C实现单隐层神经网络的训练和预测(手写BP算法)
实验要求:•实现10以内的非负双精度浮点数加法,例如输入4.99和5.70,能够预测输出为10.69•使用Gprof测试代码热度 代码框架•随机初始化1000对数值在0~10之间的浮点数,保存在二维数 ...
- 手写KMeans算法
KMeans算法是一种无监督学习,它会将相似的对象归到同一类中. 其基本思想是: 1.随机计算k个类中心作为起始点. 将数据点分配到理其最近的类中心. 3.移动类中心. 4.重复2,3直至类中心不再改 ...
- 手写k-means算法
作为聚类的代表算法,k-means本属于NP难问题,通过迭代优化的方式,可以求解出近似解. 伪代码如下: 1,算法部分 距离采用欧氏距离.参数默认值随意选的. import numpy as np d ...
- Javascript 手写 LRU 算法
LRU 是 Least Recently Used 的缩写,即最近最少使用.作为一种经典的缓存策略,它的基本思想是长期不被使用的数据,在未来被用到的几率也不大,所以当新的数据进来时我们可以优先把这些数 ...
- 手写LRU算法
import java.util.LinkedHashMap; import java.util.Map; public class LRUCache<K, V> extends Link ...
- 手写hashmap算法
/** * 01.自定义一个hashmap * 02.实现put增加键值对,实现key重复时替换key的值 * 03.重写toString方法,方便查看map中的键值对信息 * 04.实现get方法, ...
- 基于kNN的手写字体识别——《机器学习实战》笔记
看完一节<机器学习实战>,算是踏入ML的大门了吧!这里就详细讲一下一个demo:使用kNN算法实现手写字体的简单识别 kNN 先简单介绍一下kNN,就是所谓的K-近邻算法: [作用原理]: ...
- 手写BP(反向传播)算法
BP算法为深度学习中参数更新的重要角色,一般基于loss对参数的偏导进行更新. 一些根据均方误差,每层默认激活函数sigmoid(不同激活函数,则更新公式不一样) 假设网络如图所示: 则更新公式为: ...
随机推荐
- mysql-mysqli_fetch_all(错误)
mysql-mysqli_fetch_all(错误) 问题:使用mysql-mysqli_fetch_all获取返回的数组时失败/报错 原因及解决方法: mysqli_fetch_all函数只存在于m ...
- Java 复习整理day03
变量或者是常量, 只能用来存储一个数据, 例如: 存储一个整数, 小数或者字符串等. 如果需要同时存储多个同类型的数据, 用变量或者常量来实现的话, 非常的繁琐. 针对于 这种情况, 我们就可以通过数 ...
- Luogu T7468 I liked Matrix!
题目链接 题目背景 无 题目描述 在一个n*m 的矩阵A 的所有位置中随机填入0 或1,概率比为x : y.令B[i]=a[i][1]+a[i][2]+......+a[i][m],求min{B[i] ...
- UESTC 1218 Pick The Sticks
Time Limit: 15000/10000MS (Java/Others) Memory Limit: 65535/65535KB (Java/Others) Submit Status ...
- Codeforces Round #274 (Div. 2) C. Exams (贪心)
题意:给\(n\)场考试的时间,每场考试可以提前考,但是记录的是原来的考试时间,问你如何安排考试,使得考试的记录时间递增,并且最后一场考试的时间最早. 题解:因为要满足记录的考试时间递增,所以我们用结 ...
- Codeforces Round #669 (Div. 2) B. Big Vova (枚举)
题意:有一个长度为\(n\)的序列,你需要对其重新排序,构造一个新数组\(c\),\(c_{i}=gcd(a_{1},...,a{i})\)并且使得\(c\)的字典序最小. 题解:直接跑\(n\)次, ...
- Educational Codeforces Round 94 (Rated for Div. 2) B. RPG Protagonist (数学)
题意:你和你的随从去偷剑和战斧,你可以最多可以拿\(p\)重的东西,随从可以拿\(f\)重的东西,总共有\(cnt_{s}\)把剑,\(cnt_{w}\)把战斧,每把剑重\(s\),战斧重\(w\), ...
- 使用Github+jsDelivr搭建图床和存储服务
使用元素 我的博客NLNet 并未搭建自己的博客,使用博客园(cnblogs),自定义了主题NLNet-Theme. 写作工具Typora 优秀的Markdown编辑器.参考NLNet-Theme,我 ...
- Xtrabackup 物理备份
目录 Xtrabackup 安装 Xtrabackup 备份介绍 Xtrabackup全量备份 准备备份目录 全量备份 查看全量备份内容 Xtrabackup 全量备份恢复数据 删除所有数据库 停止数 ...
- c++大整数
这里不是必须用c++的话不推荐用c++大整数,py和java的支持要好得多. 大整数类 (非负) #include <iostream> #include <vector> ...