机器学习 K-近邻算法(KNN)

关注公众号“轻松学编程”了解更多。

以下命令都是在浏览器中输入。

cmd命令窗口输入:jupyter notebook

后打开浏览器输入网址http://localhost:8888/

导引

如何进行电影分类

众所周知,电影可以按照题材分类,然而题材本身是如何定义的?由谁来判定某部电影属于哪 个题材?也就是说同一题材的电影具有哪些公共特征?这些都是在进行电影分类时必须要考虑的问 题。没有哪个电影人会说自己制作的电影和以前的某部电影类似,但我们确实知道每部电影在风格 上的确有可能会和同题材的电影相近。那么动作片具有哪些共有特征,使得动作片之间非常类似, 而与爱情片存在着明显的差别呢?动作片中也会存在接吻镜头,爱情片中也会存在打斗场景,我们 不能单纯依靠是否存在打斗或者亲吻来判断影片的类型。但是爱情片中的亲吻镜头更多,动作片中 的打斗场景也更频繁,基于此类场景在某部电影中出现的次数可以用来进行电影分类。

一个机器学习算法:K-近邻算法,它非常有效而且易于掌握。

一、k-近邻算法原理

简单地说,K-近邻算法采用测量不同特征值之间的距离方法进行分类。

  • 优点:精度高、对异常值不敏感、无数据输入假定。
  • 缺点:时间复杂度高、空间复杂度高。
  • 适用数据范围:数值型和标称型。

工作原理

存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输人没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。

一般来说,我们只选择样本数据集中前K个最相似的数据,这就是K-近邻算法中K的出处,

通常K是不大于20的整数。
最后 ,选择K个最相似数据中出现次数最多的分类,作为新数据的分类。

回到前面电影分类的例子,使用K-近邻算法分类爱情片和动作片。有人曾经统计过很多电影的打斗镜头和接吻镜头,下图显示了6部电影的打斗和接吻次数。假如有一部未看过的电影,如何确定它是爱情片还是动作片呢?我们可以使用K-近邻算法来解决这个问题。

首先我们需要知道这个未知电影存在多少个打斗镜头和接吻镜头,上图中问号位置是该未知电影出现的镜头数图形化展示,具体数字参见下表。

电影名称 打斗镜头 接吻镜头 电影类型
California Man 3 104 爱情片
He’s Not Really into Dudes 2 100 爱情片
Beautiful Woman 1 81 爱情片
Kevin Longblade 101 10 动作片
Robo Slayer 3000 99 5 动作片
Amped II 98 2 动作片
? 18 90 未知

即使不知道未知电影属于哪种类型,我们也可以通过某种方法计算出来。首先计算未知电影与样本集中其他电影的距离,如下表所示。

电影名称 与未知电影的距离
California Man 20.5
He’s Not Really into Dudes 18.7
Beautiful Woman 19.2
Kevin Longblade 115.3
Robo Slayer 3000 117.4
Amped II 118.9

现在我们得到了样本集中所有电影与未知电影的距离,按照距离递增排序,可以找到K个距离最近的电影。假定k=3,则三个最靠近的电影依次是California Man、He’s Not Really into Dudes、Beautiful Woman。K-近邻算法按照距离最近的三部电影的类型,决定未知电影的类型,而这三部电影全是爱情片,因此我们判定未知电影是爱情片。

欧几里得距离(Euclidean Distance)

欧氏距离是最常见的距离度量,衡量的是多维空间中各个点之间的绝对距离。公式如下:

二、在scikit-learn库中使用k-近邻算法

  • 分类问题:from sklearn.neighbors import KNeighborsClassifier
  • 回归问题:from sklearn.neighbors import KNeighborsRegressor

1)用于分类

范例:动作爱情电影分类分析

1、导包:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pandas import DataFrame,Series
#导入分类算法模型包
from sklearn.neighbors import KNeighborsClassifier %matplotlib inline

2、获取样本集

#获取样本数据
data=pd.read_excel('./my_films.xlsx')
data

3、提取特征数据和目标数据

#样本集中的特征数据
features=data[['Action Lens','Love Lens']]
features

#样本集中的目标数据
target=data['target']
target

4、创建算法模型对象

#创建分类算法模型对象
knn=KNeighborsClassifier(n_neighbors=5)

5、使用样本数据训练模型

#使用样本数据训练模型
knn.fit(features,target)

6、预测

#上铺的兄弟(动作:40 爱情:20)
knn.predict([40,20])

n_neighbors一般取奇数值,取值不同,预测结果有可能不同,可以使用score函数对模型进行精度的评分。

7、使用score函数对模型进行精度的评分

必须先对样本数据进行拆分:训练样本数据和测试样本数据

7.1、导包
#对样本数据进行拆分:使用该函数train_test_split对样本数据进行拆分
from sklearn.model_selection import train_test_split
7.2、拆分样本数据
#参数:样本数据的特征数据 目标数据  拆分比例
#返回值:训练集的特征 测试集的特征 训练集的目标 测试集的目标
x_train,x_test,y_train,y_test=train_test_split(features,target,
test_size=0.2)
7.3、创建算法模型对象
knn_=KNeighborsClassifier(n_neighbors=5)
7.4、使用样本数据训练模型
#将训练集数据带入到fit中
knn_.fit(x_train,y_train)

7.5、对模型进行评分
#给分类算法模型进行精度评分:参数为测试集的特征数据和目标数据
knn_.score(x_test,y_test)

分数范围:0~1.

7.6、调整模型
knn_=KNeighborsClassifier(n_neighbors=3)
knn_.fit(x_train,y_train)
knn_.score(x_test,y_test)

7.7、重新预测
#上铺的兄弟(动作:40 爱情:39)
knn.predict([40,39])

2)用于回归

回归用于对趋势的预测 。

1、生成样本数据

# 设置随机种子,确保每次运行得到相同结果
np.random.seed(0) #生成40个随机数据,这是训练数据
X_test = np.sort(5*np.random.rand(40,1),axis = 0)
#这是要进行预测的数据
T = np.linspace(0,5,100)[:,np.newaxis]
y = np.sin(X_test).ravel()
#添加噪声
y[::5] += (0.5 - np.random.rand(8))

2、定义一个KNN回归模型

# 定义一个KNN回归模型
knn = KNeighborsRegressor(n_neighbors=5)

3、训练模型

#训练数据
knn.fit(X_test,y)

4、预测数据

# 预测数据
y_ = knn.predict(T)

5、画图

# 画图
plt.scatter(X_test,y,c='k',label = 'data')
plt.plot(T,y_,c='g',label = 'prediction')
plt.axis('tight')
plt.legend()

三、绘制分类边界图

使用sklearn中自带的鸢尾花数据。

1、导包

#导入自带的数据集中的鸢尾花数据集
from sklearn.datasets import load_iris

2、获取数据样本集

iris=load_iris()
iris_df=DataFrame(data=iris.data,columns=iris.feature_names)
iris_df['target']=Series(data=iris.target)
iris_df.head()

3、获取特征数据和目标数据

3.1 特征数据

根据计算原始样本数据中的四种特征数据的方差,摘选出方差最大的两种特征作为新样本集的特征数据。

feature_std=iris_df[['sepal length (cm)','sepal width (cm)',
'petal length (cm)','petal width (cm)']].std(axis=0)
feature_std

features=iris_df[[0,2]]
features.head()

3.2 目标数据
#目标数据
target=iris_df['target']

4、拆分数据集

#对样本集数据进行一个拆分(为了后续实现模型算法的评分)
from sklearn.model_selection import train_test_split #x_train:训练集中的特征数据
#x_test:测试集中的特征数据
#y_train:训练集中的目标数据
#y_test:测试集中的目标数据 #test_size:一般把样本集按二八比例分 #训练集数据需要传递给fit函数
#测试集数据需要传递给score函数
x_train,x_test,y_train,y_test=train_test_split(features,target,
test_size=0.2)

5、创建分类模型对象

knn=KNeighborsClassifier(n_neighbors=5)

####6、训练模型

knn.fit(x_train,y_train)

7、对模型进行评分,调整k取值参数

knn.score(x_test,y_test)

n_neighbors=5对应的分数达到0.93分,比较高。

如果分数较低,需要调整n_neighbors(k)取值。

8、绘图显示分类情况

8.1原始数据的分类情况

#显示训练样本集中原始数据的分类情况
plt.scatter(x_train.iloc[:,0],x_train.iloc[:,1],c=y_train)

8.2 模型预测后的分类情况
#显示预测后的分类情况,和上图比对
yy=knn.predict(x_train)
plt.scatter(x_train.iloc[:,0],x_train.iloc[:,1],c=yy)

两图差不多,说明模型预测结果较为精准。

9 绘制边界图

9.1 获取坐标系边界
x_train.head()

#获取x_train第1列最小、最大值
xmin,xmax=x_train.iloc[:,0].min(),x_train.iloc[:,0].max()
#获取x_train第2列最小、最大值
ymin,ymax=x_train.iloc[:,1].min(),x_train.iloc[:,1].max()
#在最小、最大值中间分成300个点
x=np.linspace(xmin,xmax,300)
y=np.linspace(ymin,ymax,300)
9.2 网格交叉(获取坐标系所有的散点)
9.2.1 网格交叉
#进行网格交叉(获取整个坐标系中所有的散点)
xx,yy=np.meshgrid(x,y)

######9.2.2 合并坐标点

#把xx,yy合并成两列的二维数组
grid_test=np.c_[xx.ravel(),yy.ravel()]
grid_test

9.2.3 显示坐标点
plt.scatter(grid_test[:,0],grid_test[:,1])

从上图中可以看出进行网格交叉后的散点已经布满了整个坐标系,然后用训练后的模型对这些散点做一个预测就可以显示分类后的边界图了。

9.2.4 对坐标点进行预测
grid_y=knn_.predict(grid_test)

#####9.3 绘制分类后的边界图

#使用固定颜色
from matplotlib.colors import ListedColormap
cmap=ListedColormap(['red','blue','green'])
plt.scatter(grid_test[:,0],grid_test[:,1],c=grid_y,cmap=cmap)
plt.scatter(x_train.iloc[:,0],x_train.iloc[:,1],c=yy)

后记

【后记】为了让大家能够轻松学编程,我创建了一个公众号【轻松学编程】,里面有让你快速学会编程的文章,当然也有一些干货提高你的编程水平,也有一些编程项目适合做一些课程设计等课题。

也可加我微信【1257309054】,拉你进群,大家一起交流学习。
如果文章对您有帮助,请我喝杯咖啡吧!

公众号

关注我,我们一起成长~~

python机器学习实现K-近邻算法(KNN)的更多相关文章

  1. 机器学习之K近邻算法(KNN)

    机器学习之K近邻算法(KNN) 标签: python 算法 KNN 机械学习 苛求真理的欲望让我想要了解算法的本质,于是我开始了机械学习的算法之旅 from numpy import * import ...

  2. k近邻算法(KNN)

    k近邻算法(KNN) 定义:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别. from sklearn.model_selection ...

  3. 【机器学习】k近邻算法(kNN)

    一.写在前面 本系列是对之前机器学习笔记的一个总结,这里只针对最基础的经典机器学习算法,对其本身的要点进行笔记总结,具体到算法的详细过程可以参见其他参考资料和书籍,这里顺便推荐一下Machine Le ...

  4. 机器学习(四) 分类算法--K近邻算法 KNN (上)

    一.K近邻算法基础 KNN------- K近邻算法--------K-Nearest Neighbors 思想极度简单 应用数学知识少 (近乎为零) 效果好(缺点?) 可以解释机器学习算法使用过程中 ...

  5. 机器学习(四) 机器学习(四) 分类算法--K近邻算法 KNN (下)

    六.网格搜索与 K 邻近算法中更多的超参数 七.数据归一化 Feature Scaling 解决方案:将所有的数据映射到同一尺度 八.scikit-learn 中的 Scaler preprocess ...

  6. 用Python从零开始实现K近邻算法

    KNN算法的定义: KNN通过测量不同样本的特征值之间的距离进行分类.它的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别.K通 ...

  7. 一看就懂的K近邻算法(KNN),K-D树,并实现手写数字识别!

    1. 什么是KNN 1.1 KNN的通俗解释 何谓K近邻算法,即K-Nearest Neighbor algorithm,简称KNN算法,单从名字来猜想,可以简单粗暴的认为是:K个最近的邻居,当K=1 ...

  8. 机器学习中 K近邻法(knn)与k-means的区别

    简介 K近邻法(knn)是一种基本的分类与回归方法.k-means是一种简单而有效的聚类方法.虽然两者用途不同.解决的问题不同,但是在算法上有很多相似性,于是将二者放在一起,这样能够更好地对比二者的异 ...

  9. 《机器学习实战》---第二章 k近邻算法 kNN

    下面的代码是在python3中运行, # -*- coding: utf-8 -*- """ Created on Tue Jul 3 17:29:27 2018 @au ...

  10. 机器学习之K近邻算法

    K 近邻 (K-nearest neighbor, KNN) 算法直接作用于带标记的样本,属于有监督的算法.它的核心思想基本上就是 近朱者赤,近墨者黑. 它与其他分类算法最大的不同是,它是一种&quo ...

随机推荐

  1. 如何自动填充SQL语句中的公共字段

    1. 前言 我们在设计数据库的时候一定会带上新增.更新的时间.操作者等审计信息. 之所以带这些信息是因为假如有一天公司的数据库被人为删了,尽管可能有数据库备份可以恢复数据.但是我们仍然需要追踪到这个事 ...

  2. Go-归档文件-tar

    文件归档 tar 1. 创建一个tar头部并自动填充tar头部信息 tar.FileInfoHeader() 联合 os.Stat() 方法 2. 手动填写 tar头部信息 tar.Header{} ...

  3. const pointers

    1 指针 p对应的地址是常量,但是里面存放的data不是常量 2 地址里存放的data是常量,但是地址不是常量 3 地址和指针都是常量

  4. 【题解】[ZJOI2009]假期的宿舍

    \(\color{red}{Link}\) \(\text{Solution:}\) 把人和床看成点,问题转化为二分图. 于是,对于每一个在校生,我们建立出他的床点:然后对于每一个在校生,他们自己可以 ...

  5. 【题解】[APIO2010]特别行动队

    Link 题目大意:一段区间的贡献是\(ax^2+bx+c,x=\sum v\),求一个划分让总区间的价值最大.分段必须连续. \(\text{Solution:}\) 设计\(dp[i]\)表示前\ ...

  6. Android设备上的逐像素碰撞检测

    介绍 我正在我的Android设备上开发一款游戏,不用说,因为我想要接触到尽可能多的用户,我做到了 省略了硬件加速.因此,我需要编写能够在大多数设备上运行的最快的代码.我从一个简单的表面视图开始 并使 ...

  7. intelliJ 软件项目打开运行

    1.导入项目 2.首先更改数据库,找到application-dev.yml文件,更改数据源 3.配置tomcat端口  找到application.yml 文件 然后打开pom.xml 更改版本号 ...

  8. Linux就该这么学28期——Day05 vim编辑器与Shell命令脚本 (yum配置 网卡配置)

    vim 三种模式: 命令模式 按行操作 dd 剪切.删除 5dd dG   全删 yy 复制光标所在行 p 粘贴 u 撤销操作 / 搜索 /ab n  下一个 N   上一个 输入模式 a 当前光标处 ...

  9. golang xpath解析网页

    https://github.com/antchfx/htmlquery package main import ( "fmt" "github.com/antchfx/ ...

  10. gin+gorm 用户服务

    package main import ( "fmt" "github.com/gin-gonic/gin" "github.com/jinzhu/g ...