k-NN has no explicit training phase. Given a training set, its labels, and a value of k, the classifier computes the distance from the query sample to every training sample, selects the k training samples with the smallest distances, and predicts the label that occurs most often among those k neighbors.

Two worked examples follow: dating-preference classification and handwritten digit recognition.

  1. Compute the distance from the input sample to every training sample
  2. Take the k nearest samples and predict the majority class among them
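The two steps above can be sketched on toy data before reading the full listing (the arrays and labels below are illustrative, not from the dataset):

```python
import numpy as np
from collections import Counter

def knn_predict(query, X, y, k=3):
    # Step 1: Euclidean distance from the query to every training sample
    dists = np.sqrt(((X - query) ** 2).sum(axis=1))
    # Step 2: majority vote among the k nearest samples
    nearest = np.argsort(dists)[:k]
    votes = Counter(y[i] for i in nearest)
    return votes.most_common(1)[0][0]

X = np.array([[1.0, 1.1], [1.0, 1.0], [0.0, 0.0], [0.0, 0.1]])
y = ['A', 'A', 'B', 'B']
print(knn_predict(np.array([0.1, 0.1]), X, y, k=3))  # → B
```

Broadcasting (`X - query`) does the same job as the `np.tile` call used in the listing below.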
```python
import numpy as np
import operator
import matplotlib
import matplotlib.pyplot as plt

# inX: test data, N features (1xN)
# dataSet: M samples, N features (MxN)
# labels: labels for the M samples (1xM)
# k: number of nearest neighbors
def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
    distances = np.sum(diffMat**2, axis=1)**0.5  # Euclidean distance to every sample
    sortDistances = distances.argsort()
    classCount = {}
    for i in range(k):
        voteLabel = labels[sortDistances[i]]
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
    # the class with the most votes wins
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    result = sortedClassCount[0][0]
    # print("Predict: ", result)
    return result

# Read a tab-separated file into a matrix; the file has 4 columns,
# the last of which is the label.
def file2matrix(filename):
    with open(filename) as f:
        arrayLines = f.readlines()  # one list element per line (trailing '\n' included)
    numberOfLines = len(arrayLines)
    returnMat = np.zeros((numberOfLines, 3))
    classLabelVector = []
    index = 0
    for line in arrayLines:
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index, :] = listFromLine[0:3]
        classLabelVector.append(int(listFromLine[-1]))
        index += 1
    return returnMat, classLabelVector

# Plot two feature pairs, colored and sized by label
def ex3():
    datingDataMat, datingLabels = file2matrix("datingTestSet2.txt")
    fig = plt.figure()
    ax = fig.add_subplot(1, 2, 1)
    ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2],
               s=15.0*np.array(datingLabels), c=15.0*np.array(datingLabels))
    ax2 = fig.add_subplot(1, 2, 2)
    ax2.scatter(datingDataMat[:, 0], datingDataMat[:, 1],
                s=15.0*np.array(datingLabels), c=15.0*np.array(datingLabels))
    plt.show()

# Min-max normalize each feature into [0, 1]: (value - min) / (max - min)
def autoNorm(dataSet):
    minVals = dataSet.min(axis=0)
    maxVals = dataSet.max(axis=0)
    ranges = maxVals - minVals
    m = dataSet.shape[0]
    normDataSet = dataSet - np.tile(minVals, (m, 1))
    normDataSet = normDataSet/np.tile(ranges, (m, 1))
    return normDataSet, ranges, minVals

# Classifier test: split the normalized data 70% training / 30% test,
# report accuracy, then classify a user-supplied sample.
def datingClassTest(normDataSet, ranges, minVals, labels):
    m = normDataSet.shape[0]
    numOfTrain = int(m*0.7)
    trainIndex = np.arange(m)
    np.random.shuffle(trainIndex)
    dataSet = normDataSet[trainIndex[0:numOfTrain], :]
    testSet = normDataSet[trainIndex[numOfTrain:], :]
    labels = np.array(labels)
    dataSetLabels = labels[trainIndex[0:numOfTrain]]
    testSetLabels = labels[trainIndex[numOfTrain:]]
    k = int(input("Input k: "))
    results = []
    for inX in testSet:
        result = classify0(inX, dataSet, dataSetLabels, k)
        results.append(result)
    correct = np.argwhere(results == testSetLabels)
    acc = len(correct)/len(testSetLabels)
    print("Accuracy: {:.2f}".format(acc))
    print("Error: {:.2f}".format(1-acc))
    classList = ['not at all', 'in small doses', 'in large doses']
    # Prompt in the same order as the dataset columns
    # (miles, game time, ice cream), so the features line up
    inX1 = float(input("1: frequent flier miles earned per year? "))
    inX2 = float(input("2: percentage of time spent playing video games? "))
    inX3 = float(input("3: liters of ice cream consumed per year? "))
    inXUser = np.array([inX1, inX2, inX3])
    inXUser = (inXUser - minVals)/ranges
    result = classify0(inXUser, dataSet, dataSetLabels, k)
    print("Predict: ", classList[result - 1])  # labels are 1..3, list is 0-indexed

if __name__ == '__main__':
    # # -- ex1 --
    # inX = [1, 1]
    # dataSet = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    # labels = ['A', 'A', 'B', 'B']
    # k = 3
    # classify0(inX, dataSet, labels, k)
    # # -- ex2 --
    datingDataMat, datingLabels = file2matrix("datingTestSet2.txt")
    # # -- ex3 --
    # ex3()
    # # -- ex4 --
    # normDataSet, ranges, minVals = autoNorm(datingDataMat)
    # # -- ex5 --
    # datingClassTest(normDataSet, ranges, minVals, datingLabels)
```
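Because the three dating features live on very different scales (tens of thousands of miles vs. a few liters of ice cream), the min-max step `(value - min) / (max - min)` is what keeps any one feature from dominating the distance. A quick check on made-up numbers:

```python
import numpy as np

# Illustrative rows in the dataset's column order: miles, game %, ice cream
data = np.array([[40000.,  8.0, 0.5],
                 [14000.,  2.0, 1.9],
                 [75000., 11.0, 0.1]])
minVals = data.min(axis=0)
ranges = data.max(axis=0) - minVals
norm = (data - minVals) / ranges   # broadcasting does the work of np.tile
print(norm.min(axis=0))  # [0. 0. 0.]
print(norm.max(axis=0))  # [1. 1. 1.]
```

Every column now spans exactly [0, 1], so each feature contributes comparably to the Euclidean distance.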
```python
import numpy as np
import os
import operator

# Read a 32x32 text "image" of 0/1 characters into a flat 1024-element vector
def img2vector(filename):
    with open(filename) as f:
        lines = f.readlines()
    return_vector = []
    for line in lines:
        line = line.strip()
        for j in line:
            return_vector.append(int(j))
    return return_vector

# inX: test data, N features (1xN)
# dataSet: M samples, N features (MxN)
# labels: labels for the M samples (1xM)
# k: number of nearest neighbors
def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
    distances = np.sum(diffMat**2, axis=1)**0.5  # Euclidean distance to every sample
    sortDistances = distances.argsort()
    classCount = {}
    for i in range(k):
        voteLabel = labels[sortDistances[i]]
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
    # the class with the most votes wins
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    result = sortedClassCount[0][0]
    # print("Predict: ", result)
    return result

def handwriting_class_test(data_set, training_labels, test_set, test_labels, k):
    results = []
    for i in range(len(test_set)):
        result = classify0(test_set[i], data_set, training_labels, k)
        results.append(result)
        # print('predict: ', result, 'answer: ', test_labels[i])
    compare_results = np.argwhere(results == test_labels)
    acc = len(compare_results)/len(test_labels)
    print("Accuracy: {:.5f}".format(acc))
    print("Error: {:.5f}".format(1-acc))

if __name__ == '__main__':
    dir_path = r'H:\ML\MachineLearninginAction\02kNN\digits'
    training_path = os.path.join(dir_path, r'trainingDigits')
    test_path = os.path.join(dir_path, r'testDigits')
    training_files_list = os.listdir(training_path)
    test_files_list = os.listdir(test_path)
    # Build the training matrix and labels
    # (the digit before '_' in each filename is the label)
    m = len(training_files_list)
    # m = 5
    data_set = np.zeros((m, 1024))
    training_labels = np.zeros(m)
    for i in range(m):
        data_set[i] = img2vector(os.path.join(training_path, training_files_list[i]))
        training_labels[i] = training_files_list[i].split('_')[0]
    # Build the test matrix and labels
    mt = len(test_files_list)
    test_set = np.zeros((mt, 1024))
    test_labels = np.zeros(mt)
    for i in range(mt):
        test_set[i] = img2vector(os.path.join(test_path, test_files_list[i]))
        test_labels[i] = test_files_list[i].split('_')[0]
    k = 3
    handwriting_class_test(data_set, training_labels, test_set, test_labels, k)
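The digits dataset isn't bundled here, so the `img2vector` idea can be checked against a synthetic file in the same format (a small 4x4 grid and a made-up filename stand in for the real 32x32 digit files):

```python
import numpy as np
import os
import tempfile

# Write a fake "image": one row of 0/1 characters per line,
# with the label encoded before the '_' in the filename (here: 3)
content = "0110\n1001\n1001\n0110\n"
path = os.path.join(tempfile.mkdtemp(), "3_0.txt")
with open(path, "w") as f:
    f.write(content)

def img2vector(filename):
    with open(filename) as f:
        return [int(ch) for line in f for ch in line.strip()]

vec = img2vector(path)
label = os.path.basename(path).split('_')[0]
print(len(vec))  # 16 here; 1024 for the real 32x32 digit files
print(label)     # '3'
```

Flattening each grid into one long vector is what lets `classify0` treat every pixel as just another feature in the Euclidean distance.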
