k-NN——算法实现
k-NN 没有特别的训练过程,给定训练集,标签,k,计算待预测特征到训练集的所有距离,选取前k个距离最小的训练集,k个中标签最多的为预测标签
约会类型分类、手写数字识别分类
- 计算输入数据到每一个训练数据的距离
- 选择前k个,判断其中类别最多的类作为预测类
import numpy as np
import operator
import matplotlib
import matplotlib.pyplot as plt
# inX: test data, N features (1xN)
# dataSet: M samples, N features (MxN)
# label: for M samples (1xM)
# k: k-Nearest Neighbor
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
distances = np.sum(diffMat**2, axis=1)**0.5
sortDistances = distances.argsort() # 计算距离
classCount = {}
for i in range(k):
voteLable = labels[sortDistances[i]]
classCount[voteLable] = classCount.get(voteLable, 0) + 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) # 找出最多投票的类
result = sortedClassCount[0][0]
# print("Predict: ", result)
return result
# 将一个文件写入矩阵,文件有4列,最后一列为labels,以\t间隔
def file2matrix(filename):
with open(filename) as f:
arrayLines = f.readlines()
# print(arrayLines) # 有\n
numberOfLines = len(arrayLines) # 将txt文件按行读入为一个list,一行为一个元素
returnMat = np.zeros((numberOfLines, 3))
classLabelVector = []
index = 0
for line in arrayLines:
line = line.strip()
listFromLine = line.split('\t')
returnMat[index,:] = listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat, classLabelVector
# 画一些图
def ex3():
datingDateMat, datingLables = file2matrix("datingTestSet2.txt")
fig = plt.figure()
ax = fig.add_subplot(1,2,1)
ax.scatter(datingDateMat[:,1], datingDateMat[:,2], s=15.0*np.array(datingLables), c=15.0*np.array(datingLables))
ax2 = fig.add_subplot(1,2,2)
ax2.scatter(datingDateMat[:,0], datingDateMat[:,1], s=15.0*np.array(datingLables), c=15.0*np.array(datingLables))
plt.show()
# 将数据集归一化[0 1]之间 (value - min)/(max - min)
def autoNorm(dataSet):
minVals = dataSet.min(axis=0)
maxVals = dataSet.max(axis=0)
ranges = maxVals - minVals
m = dataSet.shape[0]
normDataSet = dataSet - np.tile(minVals, (m,1))
normDataSet = normDataSet/np.tile(ranges, (m,1))
return normDataSet, ranges, minVals
# 分类器,输入数据集,归一化参数,labels,70%作为训练集,30%测试集
def datingClassTest(normDataSet, ranges, minVals, labels):
m = normDataSet.shape[0]
numOfTrain = int(m*0.7)
trainIndex = np.arange(m)
np.random.shuffle(trainIndex)
dataSet = normDataSet[trainIndex[0:numOfTrain],:]
testSet = normDataSet[trainIndex[numOfTrain:],:]
labels = np.array(labels)
dataSetLabels = labels[trainIndex[0:numOfTrain]]
testSetLabels = labels[trainIndex[numOfTrain:]]
k = int(input("Input k: "))
results = []
for inX in testSet:
result = classify0(inX, dataSet, dataSetLabels, k)
results.append(result)
compResultsAndLable = np.argwhere(results==testSetLabels)
acc = len(compResultsAndLable)/len(testSetLabels)
print("Accuracy: {:.2f}".format(acc))
print("Error: {:.2f}".format(1-acc))
classList = ['not at all', 'in small doses', 'in large doses']
inX1 = float(input("1: percentage of time spent playing video games? "))
inX2 = float(input("2: frequent flier miles earned per year? "))
inX3 = float(input("3: liters of ice cream consumed per year? "))
inXUser = [inX1,inX2,inX3]
inXUser = (inXUser - minVals)/ranges
result = classify0(inXUser, dataSet, dataSetLabels, k)
print("Predict: ", classList[result])
if __name__ == '__main__':
# # -- ex1 --
# inX = [1, 1]
# dataSet = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
# labels = ['A', 'A', 'B', 'B']
# k = 3
# classify0(inX, dataSet, labels, k)
# # -- ex2 --
datingDateMat, datingLables = file2matrix("datingTestSet2.txt")
# # -- ex3 --
# ex3()
# #-- ex4 --
# normDataSet, ranges, minVals = autoNorm(datingDateMat)
# # -- ex5 --
# datingClassTest(normDataSet, ranges, minVals, datingLables)
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import os
import operator
def img2vector(filename):
with open(filename) as f:
lines = f.readlines()
return_vector = []
for line in lines:
line = line.strip()
for j in line:
return_vector.append(int(j))
return return_vector
# inX: test data, N features (1xN)
# dataSet: M samples, N features (MxN)
# label: for M samples (1xM)
# k: k-Nearest Neighbor
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
distances = np.sum(diffMat**2, axis=1)**0.5
sortDistances = distances.argsort() # 计算距离
classCount = {}
for i in range(k):
voteLable = labels[sortDistances[i]]
classCount[voteLable] = classCount.get(voteLable, 0) + 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) # 找出最多投票的类
result = sortedClassCount[0][0]
# print("Predict: ", result)
return result
def handwriting_class_test(data_set, training_labels, test_set, test_labels, k):
results = []
for i in range(len(test_set)):
result = classify0(test_set[i], data_set, training_labels, k)
results.append(result)
# print('predict: ', result, 'answer: ', test_labels[i])
compare_results = np.argwhere(results==test_labels)
acc = len(compare_results)/len(test_labels)
print("Accuracy: {:.5f}".format(acc))
print("Error: {:.5f}".format(1-acc))
if __name__ == '__main__':
dir_path = r'H:\ML\MachineLearninginAction\02kNN\digits'
training_path = os.path.join(dir_path, r'trainingDigits')
test_path = os.path.join(dir_path, r'testDigits')
training_files_list = os.listdir(training_path)
test_files_list = os.listdir(test_path)
# 计算训练集矩阵与labels
m = len(training_files_list)
# m = 5
data_set = np.zeros((m, 1024))
training_labels = np.zeros(m)
for i in range(m):
data_set[i] = img2vector(os.path.join(training_path, training_files_list[i]))
training_labels[i] = training_files_list[i].split('_')[0]
# 测试集矩阵与labels
mt = len(test_files_list)
test_set = np.zeros((mt,1024))
test_labels = np.zeros(mt)
for i in range(mt):
test_set[i] = img2vector(os.path.join(test_path, test_files_list[i]))
test_labels[i] = test_files_list[i].split('_')[0]
k = 3
handwriting_class_test(data_set, training_labels, test_set, test_labels, k)
k-NN——算法实现的更多相关文章
- kaggle赛题Digit Recognizer:利用TensorFlow搭建神经网络(附上K邻近算法模型预测)
一.前言 kaggle上有传统的手写数字识别mnist的赛题,通过分类算法,将图片数据进行识别.mnist数据集里面,包含了42000张手写数字0到9的图片,每张图片为28*28=784的像素,所以整 ...
- 机器学习实战笔记--k近邻算法
#encoding:utf-8 from numpy import * import operator import matplotlib import matplotlib.pyplot as pl ...
- 《机器学习实战》学习笔记一K邻近算法
一. K邻近算法思想:存在一个样本数据集合,称为训练样本集,并且每个数据都存在标签,即我们知道样本集中每一数据(这里的数据是一组数据,可以是n维向量)与所属分类的对应关系.输入没有标签的新数据后,将 ...
- [Machine-Learning] K临近算法-简单例子
k-临近算法 算法步骤 k 临近算法的伪代码,对位置类别属性的数据集中的每个点依次执行以下操作: 计算已知类别数据集中的每个点与当前点之间的距离: 按照距离递增次序排序: 选取与当前点距离最小的k个点 ...
- k近邻算法的Java实现
k近邻算法是机器学习算法中最简单的算法之一,工作原理是:存在一个样本数据集合,即训练样本集,并且样本集中的每个数据都存在标签,即我们知道样本集中每一数据和所属分类的对应关系.输入没有标签的新数据之后, ...
- 基本分类方法——KNN(K近邻)算法
在这篇文章 http://www.cnblogs.com/charlesblc/p/6193867.html 讲SVM的过程中,提到了KNN算法.有点熟悉,上网一查,居然就是K近邻算法,机器学习的入门 ...
- 聚类算法:K-means 算法(k均值算法)
k-means算法: 第一步:选$K$个初始聚类中心,$z_1(1),z_2(1),\cdots,z_k(1)$,其中括号内的序号为寻找聚类中心的迭代运算的次序号. 聚类中心的向量值可任意设 ...
- 从K近邻算法谈到KD树、SIFT+BBF算法
转自 http://blog.csdn.net/v_july_v/article/details/8203674 ,感谢july的辛勤劳动 前言 前两日,在微博上说:“到今天为止,我至少亏欠了3篇文章 ...
- Python实现kNN(k邻近算法)
Python实现kNN(k邻近算法) 运行环境 Pyhton3 numpy科学计算模块 计算过程 st=>start: 开始 op1=>operation: 读入数据 op2=>op ...
- 机器学习之K近邻算法(KNN)
机器学习之K近邻算法(KNN) 标签: python 算法 KNN 机械学习 苛求真理的欲望让我想要了解算法的本质,于是我开始了机械学习的算法之旅 from numpy import * import ...
随机推荐
- Servlet-ServletConfig类使用介绍
ServletConfig类(Servlet程序的配置信息类) Servlet 程序和 ServletConfig对象都是由 Tomcat负责创建,我们负责使用. Servlet 程序默认是第一次访问 ...
- Typora基础快捷键使用流程
Typora简介 Typora是一个所见即所得的Markdown格式文本编辑器,支持windows.macOS和GNU\Linux操作系统,包括对GitHub Flavored Markdown扩展格 ...
- linux正则sed实战案例详解
目录 1. 将nginx.conf中的注释全部去掉 2. 将nginx.com中每一行之前增加注释 3.要求一键修改本机的ip 4.将/etc/passwd中的root修改成ROOT 1. 将ngin ...
- linux文件权限全面解析
目录 linux文件权限全面解析 一:linux文件的权限有哪些? 1,权限分为3个部分 2,权限位 3,每一个权限拥有一个数字编号 4,在添加权限的时候,可以将权限加起来 5,linux添加权限命令 ...
- 学习Java第4天
今天所作的工作: 1.类 2.类的构造方法 3.静态变量 4.类的主方法 5.对象 今天没有完成昨天的工作安排,因为发现进入类之后的编程思想发生的变化,相对与c++的逻辑既有较大的相似性又有不同的性质 ...
- Redis性能分析思路
Redis性能分析有几个大的方向.分别是 (1)基准对比 (2)配置优化 (3)数据持久化 (4)键值优化 (5)缓存淘汰 (6)Redis集群 基准对比 在没有业务实例运行的情况下,在服务器上通过测 ...
- AT2645 [ARC076D] Exhausted?
解法一 引理:令一个二分图两部分别为 \(X, Y(|X| \le |Y|)\),若其存在完美匹配当且仅当 \(\forall S \subseteq X, f(S) \ge |S|\)(其中 \(f ...
- 如何修改TOMCAT的默认主页为你自己项目的主页
感谢作者:xxs673076773 原文链接:https://www.iteye.com/blog/xxs673076773-1134805 (最合适的) 最直接的办法是,删掉tomcat下原有Roo ...
- js判断变量是否为空字符串、null、undefined
let _isEmpty = (input) => { return input + '' === 'null' || input + '' === 'undefined' || input.t ...
- 增删改查简单的sql语句
insert INSERT INTO t_stu (name,age) VALUES ('wang',12) INSERT INTO t_stu VALUES(NULL,' ...