原文地址:智能单元

K-Nearest Neighbor分类器

大家可能注意到了,为什么只用最相似的一张图片的标签来作为测试图像的标签呢?这不是很奇怪吗!是的,使用K-Nearest Neighbor分类器就能做得更好。它的思想很简单:与其只找最相近的那1个图片的标签,我们找最相似的k个图片标签,然后让他们针对测试图片进行投票,最后把票数最高的标签作为对测试图片的预测。所以当k=1时候,k-Nearest Neighbor分类器就是Nearest Neighbor分类器。从直观感受上就可以看到,更高的k值可以让分类的效果更平滑,使得分类器对于异常值更具抵抗力。

上面示例展示了Nearest Neighbor分类器和5-Nearest Neighbor分类的区别。例子使用了2维点来表示,分成3类(红、蓝和绿)。不同的颜色区域代表的是使用L2距离的分类器的决策边界。白色的区域是分类模糊的例子(即图像与两个以上的分类标签绑定)。需要注意的是,在NN分类器中,异常的数据点(比如:在蓝色区域的绿色)制造一个不正确预测的孤岛。5-NN分类器将这些不规则都平滑了,使得它针对测试数据的泛化能力更好。注意,5-NN也存在一些灰色区域,这些区域是因为近邻标签的最高票数相同导致的(比如:2个邻居是红色,2个邻居是蓝色,还有1个是绿色)。

在实际中,大多使用k-NN分类器。但是k值如何确定呢?接下来就讨论这个问题。

我们可以选择不同的距离函数,比如L1范数和L2范数,那么选着哪一个更好呢?还有不少选择我们甚至连考虑都没有考虑(比如点积)。所以这些选择,被称为超参数。在基于数据进行学习的机器学习算法设计中,超参数是很正常见的。一般来说,这些超参数具体怎么设置或取值并不是显而易见的。

你可能会建议尝试不同的值,看哪个值表现最好就选哪个。好主意!我们就是这么做的,但这样做的时候要非常细心。特别注意:决不能使用测试集来进行调优。当你在设计机器学习算法的时候,应该把测试集看做非常珍贵的资源,不到最后一步,绝不使用它。如果你使用测试集来调优,而且算法看起来效果不错,那么真正的危险在于:算法实际部署后,性能可能会远低于预期。这种情况,称之为算法对测试集过拟合。从另一个角度来说,如果使用测试集来调优,实际上就是把测试集当做训练集,由测试集训练出来的算法再跑测试集,自然性能看起来会很好。这其实是过于乐观了,实际部署起来效果就会差很多。所以,最终测试的时候再使用测试集,可以很好地近似度量你所设计的分类器的泛化性能(在接下来的课程中会有很多关于泛化性能的讨论)。

测试数据集只使用一次,即在训练完成后评价最终的模型时使用

好在我们有不用测试集调优的方法。其思路是:从训练集中取出一部分数据用来调优,我们称之为验证集(validation set)。以CIFAR-10为例,我们可以用49000个图像作为训练集,用1000个图像作为验证集。验证集其实就是作为假的测试集来调优。下面就是代码:

  1. #coding=utf-8
  2. import numpy as np
  3. import cPickle as pickle
  4. import os
  5. from collections import Counter
  6.  
  7. def load_CIFAR_batch(filename):
  8. """ load single batch of cifar """
  9. #rb二进制读文件
  10. with open(filename, 'rb') as f:
  11. datadict = pickle.load(f)
  12. X = datadict['data']
  13. Y = datadict['labels']
  14. # 生成一个四维数组X,并用transpose对维度进行排序
  15. X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")
  16. Y = np.array(Y)
  17. return X, Y
  18. def load_CIFAR10(ROOT):
  19. """ load all of cifar """
  20. xs = []
  21. ys = []
  22. for b in range(1,6):
  23. f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
  24. X, Y = load_CIFAR_batch(f)
  25. xs.append(X)
  26. ys.append(Y)
  27. #用concatenate(array,axis=0)对xs的第一维度(即axis=0)进行合并处理,生成总数组Xtr
  28. Xtr = np.concatenate(xs)
  29. Ytr = np.concatenate(ys)
  30. del X, Y
  31. #处理test文件
  32. Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
  33. return Xtr, Ytr, Xte, Yte
  34.  
  35. class KNearestNeighbor(object):
  36.  
  37. def __init__(self):
  38. pass
  39. def train(self, X, y):
  40. self.X_train = X
  41. self.y_train = y
  42.  
  43. def predict(self, X, k=1, num_loops=1):
  44. if num_loops == 0:
  45. dists = self.compute_distances_no_loops(X)
  46. elif num_loops == 1:
  47. dists = self.compute_distances_one_loop(X)
  48. elif num_loops == 2:
  49. dists = self.compute_distances_two_loops(X)
  50. else:
  51. raise ValueError('Invalid value %d for num_loops' % num_loops)
  52. return self.predict_labels(dists, k=k)
  53.  
  54. def compute_distances_two_loops(self, X):
  55. num_test = X.shape[0]
  56. num_train = self.X_train.shape[0]
  57. dists = np.zeros((num_test, num_train))
  58. for i in xrange(num_test):
  59. for j in xrange(num_train):
  60. dict[i, j] = np.sqrt(np.sum(np.square(self.X_train[j, :] - X[i, :]), axis=1))
  61. return dists
  62.  
  63. def compute_distances_one_loop(self, X):
  64. num_test = X.shape[0]
  65. num_train = self.X_train.shape[0]
  66. dists = np.zeros((num_test, num_train))
  67. for i in xrange(num_test):
  68. dists[i, :] = np.sqrt(np.sum(np.square(self.X_train - X[i, :]), axis=1))
  69. return dists
  70.  
  71. def compute_distances_no_loops(self, X):
  72. num_test = X.shape[0]
  73. num_train = self.X_train.shape[0]
  74. dists = np.zeros((num_test, num_train))
  75. dists = np.multiply(np.dot(X,self.X_train.T),-2)#利用(x1-x2)^2=x1^2-2x1x2+x2^2
  76. distssqx=np.sum(np.square(X),axis=1)
  77. distssqxtr = np.sum(np.square(self.X_train), axis=1)
  78. dists=np.add(dists, distssqx)
  79. dists = np.add(dists, distssqxtr)
  80. return dists
  81.  
  82. def predict_labels(self, dists, k=1):
  83. num_test = dists.shape[0]
  84. y_pred = np.zeros(num_test)
  85. for i in xrange(num_test):
  86. closest_y = []
  87. closest_y = self.y_train[np.argsort(dists[i, :])[:k]].flatten()
  88. #计数器函数
  89. c = Counter(closest_y)
  90. y_pred[i]=c.most_common(1)[0][0]
  91. return y_pred
  92. Xtr, Ytr, Xte, Yte = load_CIFAR10('D:/python_projects/NN/cifar-10-batches-py')
  93. Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3)
  94. Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3)
  95. Xval_rows = Xtr_rows[:1000, :]
  96. Yval = Ytr[:1000]
  97. Xtr_rows = Xtr_rows[1000:, :]
  98. Ytr = Ytr[1000:]
  99. validation_accuracies = []
  100. for k in [1, 5, 20, 100]:
  101. nn = KNearestNeighbor()
  102. nn.train(Xtr_rows, Ytr)
  103. Yval_predict = nn.predict(Xval_rows, k=k)
  104. acc = np.mean(Yval_predict == Yval)
  105. print 'Acc: %f' % (acc,)
  106. validation_accuracies.append((k, acc))
  107. print validation_accuracies

程序结束后,我们会作图分析出哪个k值表现最好,然后用这个k值来跑真正的测试集,并作出对算法的评价。

把训练集分成训练集和验证集。使用验证集来对所有超参数调优。最后只在测试集上跑一次并报告结果。

交叉验证。有时候,训练集数量较小(因此验证集的数量更小),人们会使用一种被称为交叉验证的方法,这种方法更加复杂些。还是用刚才的例子,如果是交叉验证集,我们就不是取1000个图像,而是将训练集平均分成5份,其中4份用来训练,1份用来验证。然后我们循环着取其中4份来训练,其中1份来验证,最后取所有5次验证结果的平均值作为算法验证结果。

这就是5份交叉验证对k值调优的例子。针对每个k值,得到5个准确率结果,取其平均值,然后对不同k值的平均表现画线连接。本例中,当k=7的时算法表现最好(对应图中的准确率峰值)。如果我们将训练集分成更多份数,直线一般会更加平滑(噪音更少)。

在实际情况下,人们不是很喜欢用交叉验证,主要是因为它会消耗较多的计算资源。一般直接吧训练集按照50%-90%的比例分成训练集和验证集。但这也是根据具体情况来定的:如果超参数数量多,你可能就想用更大的验证的数量不够,那么最好还是用交叉验证吧。至于分成几份比较好,一般都是分成3、5和10份。

常用的数据分割模式。给出训练集和测试集后,训练集一般会被均分。这里是分成5份。前面4份用来训练,黄色那份用作验证集调优。如果采取交叉验证,那就各份轮流作为验证集。最后模型训练完毕,超参数都定好了,让模型跑一次(而且只跑一次)测试集,以此测试结果评价算法。

 Nearest Neighbor分类器的优劣

现在对Nearest Neighbor分类器的优缺点进行思考。首先,Nearest Neighbor分类器易于理解,实现简单。其次,算法的训练不需要花时间,因为其训练过程只是将训练集数据存储起来。然而测试要花费大量时间计算,因为每个测试图像需要和所有存储的训练图像进行比较,这显然是一个缺点。在实际应用中,我们关注测试效率远远高于训练效率。其实,我们后续要学习的卷积神经网络在这个权衡上走到了另一个极端:虽然训练花费很多时间,但是一旦训练完成,对新的测试数据进行分类非常快。这样的模式就符合实际使用需求。

Nearest Neighbor分类器的计算复杂度研究是一个活跃的研究领域,若干Approximate Nearest Neighbor (ANN)算法和库的使用可以提升Nearest Neighbor分类器在数据上的计算速度(比如:FLANN)。这些算法可以在准确率和时空复杂度之间进行权衡,并通常依赖一个预处理/索引过程,这个过程中一般包含kd树的创建和k-means算法的运用。

Nearest Neighbor分类器在某些特定情况(比如数据维度较低)下,可能是不错的选择。但是在实际的图像分类工作中,很少使用。因为图像都是高维度数据(他们通常包含很多像素),而高维度向量之间的距离通常是反直觉的。下面的图片展示了基于像素的相似和基于感官的相似是有很大不同的:

在高维度数据上,基于像素的的距离和感官上的非常不同。上图中,右边3张图片和左边第1张原始图片的L2距离是一样的。很显然,基于像素比较的相似和感官上以及语义上的相似是不同的。

这里还有个视觉化证据,可以证明使用像素差异来比较图像是不够的。z这是一个叫做t-SNE的可视化技术,它将CIFAR-10中的图片按照二维方式排布,这样能很好展示图片之间的像素差异值。在这张图片中,排列相邻的图片L2距离就小。

上图使用t-SNE的可视化技术将CIFAR-10的图片进行了二维排列。排列相近的图片L2距离小。可以看出,图片的排列是被背景主导而不是图片语义内容本身主导。

具体说来,这些图片的排布更像是一种颜色分布函数,或者说是基于背景的,而不是图片的语义主体。比如,狗的图片可能和青蛙的图片非常接近,这是因为两张图片都是白色背景。从理想效果上来说,我们肯定是希望同类的图片能够聚集在一起,而不被背景或其他不相关因素干扰。为了达到这个目的,我们不能止步于原始像素比较,得继续前进。

小结

简要说来:

  • 介绍了图像分类问题。在该问题中,给出一个由被标注了分类标签的图像组成的集合,要求算法能预测没有标签的图像的分类标签,并根据算法预测准确率进行评价。
  • 介绍了一个简单的图像分类器:最近邻分类器(Nearest Neighbor classifier)。分类器中存在不同的超参数(比如k值或距离类型的选取),要想选取好的超参数不是一件轻而易举的事。
  • 选取超参数的正确方法是:将原始训练集分为训练集和验证集,我们在验证集上尝试不同的超参数,最后保留表现最好那个。
  • 如果训练数据量不够,使用交叉验证方法,它能帮助我们在选取最优超参数的时候减少噪音。
  • 一旦找到最优的超参数,就让算法以该参数在测试集跑且只跑一次,并根据测试结果评价算法。
  • 最近邻分类器能够在CIFAR-10上得到将近40%的准确率。该算法简单易实现,但需要存储所有训练数据,并且在测试的时候过于耗费计算能力。
  • 最后,我们知道了仅仅使用L1和L2范数来进行像素比较是不够的,图像更多的是按照背景和颜色被分类,而不是语义主体分身。

在接下来的课程中,我们将专注于解决这些问题和挑战,并最终能够得到超过90%准确率的解决方案。该方案能够在完成学习就丢掉训练集,并在一毫秒之内就完成一张图片的分类。

小结:实际应用k-NN

如果你希望将k-NN分类器用到实处(最好别用到图像上,若是仅仅作为练手还可以接受),那么可以按照以下流程:

  1. 预处理你的数据:对你数据中的特征进行归一化(normalize),让其具有零平均值(zero mean)和单位方差(unit variance)。在后面的小节我们会讨论这些细节。本小节不讨论,是因为图像中的像素都是同质的,不会表现出较大的差异分布,也就不需要标准化处理了。
  2. 如果数据是高维数据,考虑使用降维方法,比如PCA(wiki refCS229refblog ref)或随机投影
  3. 将数据随机分入训练集和验证集。按照一般规律,70%-90% 数据作为训练集。这个比例根据算法中有多少超参数,以及这些超参数对于算法的预期影响来决定。如果需要预测的超参数很多,那么就应该使用更大的验证集来有效地估计它们。如果担心验证集数量不够,那么就尝试交叉验证方法。如果计算资源足够,使用交叉验证总是更加安全的(份数越多,效果越好,也更耗费计算资源)。
  4. 在验证集上调优,尝试足够多的k值,尝试L1和L2两种范数计算方式。
  5. 如果分类器跑得太慢,尝试使用Approximate Nearest Neighbor库(比如FLANN)来加速这个过程,其代价是降低一些准确率。
  6. 对最优的超参数做记录。记录最优参数后,是否应该让使用最优参数的算法在完整的训练集上运行并再次训练呢?因为如果把验证集重新放回到训练集中(自然训练集的数据量就又变大了),有可能最优参数又会有所变化。在实践中,不要这样做。千万不要在最终的分类器中使用验证集数据,这样做会破坏对于最优参数的估计。直接使用测试集来测试用最优参数设置好的最优模型,得到测试集数据的分类准确率,并以此作为你的kNN分类器在该数据上的性能表现。

CS231n学习笔记-图像分类笔记(下篇)的更多相关文章

  1. CS231n学习笔记-图像分类笔记(上篇)

    原文地址:智能单元 图像分类:所谓图像分类问题,就是已有固定的分类标签集合,然后对于输入的图像按照标签类别,将其打上标签. 下面先介绍一下一个简单的图像如何利用计算机进行分类: 例子:以下图为例,图像 ...

  2. CS231n课程笔记翻译2:图像分类笔记

    译者注:本文智能单元首发,译自斯坦福CS231n课程笔记image classification notes,由课程教师Andrej Karpathy授权进行翻译.本篇教程由杜客翻译完成.Shiqin ...

  3. hadoop2.5.2学习及实践笔记(二)—— 编译源代码及导入源码至eclipse

    生产环境中hadoop一般会选择64位版本,官方下载的hadoop安装包中的native库是32位的,因此运行64位版本时,需要自己编译64位的native库,并替换掉自带native库. 源码包下的 ...

  4. Python学习的个人笔记(基础语法)

    Python学习的个人笔记 题外话: 我是一个大二的计算机系的学生,这份python学习个人笔记是趁寒假这一周在慕课网,w3cschool,还有借鉴了一些博客,资料整理出来的,用于自己方便的时候查阅, ...

  5. 开始记录学习java的笔记

    今天开始记录学习java的笔记,加油

  6. 菜鸟教程之学习Shell script笔记(上)

    菜鸟教程之学习Shell script笔记 以下内容是,学习菜鸟shell教程整理的笔记 菜鸟教程之shell教程:http://www.runoob.com/linux/linux-shell.ht ...

  7. hadoop2.5.2学习及实践笔记(四)—— namenode启动过程源码概览

    对namenode启动时的相关操作及相关类有一个大体了解,后续深入研究时,再对本文进行补充 >实现类 HDFS启动脚本为$HADOOP_HOME/sbin/start-dfs.sh,查看star ...

  8. 深度学习Keras框架笔记之AutoEncoder类

    深度学习Keras框架笔记之AutoEncoder类使用笔记 keras.layers.core.AutoEncoder(encoder, decoder,output_reconstruction= ...

  9. 深度学习Keras框架笔记之TimeDistributedDense类

    深度学习Keras框架笔记之TimeDistributedDense类使用方法笔记 例: keras.layers.core.TimeDistributedDense(output_dim,init= ...

随机推荐

  1. HDU 2106 decimal system (进制转化求和)

    题意:给你n个r进制数,让你求和. 析:思路就是先转化成十进制,再加和. 代码如下: #include <iostream> #include <cstdio> #includ ...

  2. AirplaceLogger源代码解析

    将源代码添加进Eclipse中,右键-->Import-->Existing Projects into Workspace-->选择AirplaceLogger源代码文件夹即可导入 ...

  3. SSH整合 第五篇 struts2的到来

    struts2的好处,web层的显示,同时Action类相当于MVC模式的C.整合进来的话,是通过与Spring整合,减少重复代码,利用IoC和AOP. 1.struts-2.5.2.jar 以上是s ...

  4. Dalvik虚拟机java方法执行流程和Method结构体分析

    Method结构体是啥? 在Dalvik虚拟机内部,每个Java方法都有一个对应的Method结构体,虚拟机根据此结构体获取方法的所有信息. Method结构体是怎样定义的? 此结构体在不同的andr ...

  5. Informatica增量抽取时间的设置

    使用数据库或者系统变量的当前时间 Informatica中的$$SYSDATE是表示当前系统时间的系统变量. 通过这个变量,我们对每天抽取的数据可以使用以下表达式来实现增量抽取: 时间戳字段>= ...

  6. flume 整合kafka

    背景:系统的数据量越来越大,日志不能再简单的文件的保存,如此日志将会越来越大,也不方便查找与分析,综合考虑下使用了flume来收集日志,收集日志后向kafka传递消息,下面给出具体的配置 # The ...

  7. Git提取两次提交的差异文件

    1. 创建清单文件 获取两次提交之间的文件差异,并将文件清单保存到diff.txt文件中 Git diff --name-only 173d3010 18586360 > diff.txt 2. ...

  8. ASP.NET关于Session_End触发与否的问题

    项目背景: 要求开发一个篆文识别网站,由于之前做好了WinForm的,把系统直接移植到WebForm上就好.工作比较简单,但确实遇到不少问题. 核心问题是: 篆文识别涉及到用户对原始图片的预处理(例如 ...

  9. 关于Java连接SQL Sever数据库

    1.前提条件 需要: 1>本机上装有SQL Sever数据库(2005.2008或者更高版本) 2>eclipse或者myeclipse开发环境 3>jar文件(名为sql_jdbc ...

  10. 一个基于ASP.NET(C#)的ACCESS数据库操作类

    using System; using System.Collections; using System.Collections.Specialized; using System.Data; usi ...