《机学五》KNN算法及实例
一、概述
【定义】如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。
二、距离计算公式
两个样本的距离可以通过如下公式计算,又叫【欧式距离】
设有特征,a(a1,a2,a3),b(b1,b2,b3),那么:
$$\sqrt{(a1-b1){2}+(a2-b2){2}+(a3-b3)^{2}}$$
三、sklearn k-近邻算法API
sklearn.neighbors.KNeighborsClassifier(n_neighbors=5,algorithm='auto')
- n_neighbors:int,可选(默认= 5),k_neighbors查询默认使用的邻居数
- algorithm:{‘auto’,‘ball_tree’,‘kd_tree’,‘brute’},可选用于计算最近邻居的算法:
‘ball_tree’将会使用 BallTree
‘kd_tree’将使用 KDTree
‘auto’将尝试根据传递给fit方法的值来决定最合适的算法 (不同实现方式影响效率)
四、实战
数据位置:https://www.kaggle.com/c/facebook-v-predicting-check-ins/data
五、数据的处理
1、缩小数据集范围
DataFrame.query()
2、处理日期数据
pd.to_datetime
pd.DatetimeIndex
3、增加分割的日期数据
4、删除没用的日期数据
pd.drop
5、将签到位置少于3个用户的删除
#按place_id列进行分组,并数出每个地方有多少次签到
place_count =data.groupby('place_id').aggregate(np.count_nonzero)
#入住次数>3个的拿出来,place_id放到最后做为一个特征,前面加一个新索引列0-n
tf = place_count[place_count.row_id > 3].reset_index()
#原表的place_id,存在上一步新建的表里的,单独拿出来做为新数据
data = data[data['place_id'].isin(tf.place_id)]
六、Knn算法实战
数据集示例(trian.csv):
6.1.初步knn,暂时不用特征工程(标准化)
此处分为两步:
- 因为数据集太大选择一部分数据进行处理choisecsv()
- 进行knn算法knndemo()
- 可以看到此时准确率仅为2.7%
import pandas as pd
from sklearn.model_selection import train_test_split #数据集分割成:测试集、训练集模块
from sklearn.neighbors import KNeighborsClassifier #knn近邻算法模块
def choisecsv():
"""
选取一部分数据保存为新表格
:return:
"""
# 【1】读取数据
path = "D:\\a\\data\\facebook-v-predicting-check-ins\\train.csv"
data = pd.read_csv(path)
# print(data.head(10))
# 【2】因为数据集太大,只选取一部分数据,保存一个新表格
data = data.query("x > 1.0 & x < 1.25 & y > 2.5 & y < 2.75")
print(data.head(10))
data.to_csv('train_s.csv',index=False) #第2参数,不要在新csv里最前列自动加索引
def demoknn():
"""
根据输入坐标,预测入住地点
:return:
"""
data=pd.read_csv("train_s.csv")
#【1】时间戳(785470):转化时间列的时间戳为时间类似:1970-1-1 12:00:00
time_value=pd.to_datetime(data['time'],unit='s')
#print(time_value)
#(1.1)把日期转化为字典格式
time_value=pd.DatetimeIndex(time_value)
# (1.2)根据时间字典构造一些新特征:月,天,时,星期几
data['mouth']=time_value.month
data['day'] = time_value.day
data['hour'] = time_value.hour
data['weekday'] = time_value.weekday
#(1.3)删除时间戳列。pd里0是行,1是列
data=data.drop(['time'],axis=1)
#print(data.head(10))
#【2】删除入驻数量少于3次的地方
place_count=data.groupby('place_id').count() #按place_id列进行分组,并数出每个地方有多少次签到
tf = place_count[place_count.row_id > 3].reset_index() #入住次数>3个的拿出来,place_id放到最后做为一个特征,前面加一个新索引列0-n
data = data[data['place_id'].isin(tf.place_id)] #原表的place_id,存在上一步新建的表里的,单独拿出来做为新数据
#print(data.sample(10))
#【3】取得特征值x、目标值y
y=data['place_id']
x=data.drop('place_id',axis=1)
#【4】分割数据集
x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.25)
#【5】特征工程(标准化)(暂时跳过)
#【6】执行Knn算法
knn=KNeighborsClassifier(n_neighbors=5)
#【7】fit(处理),执行算法(到此模型已训练完成)
knn.fit(x_train,y_train)
#【8】预测结果predict(输入测试集进行预测)
y_predict=knn.predict(x_test)
print('预测的位置为:',y_predict)
#【9】算出准确率scroe(对模型进行测试集的预测,的准确率计算):
print('准确率为:',knn.score(x_test,y_test))
if __name__=="__main__":
#choisecsv() #选取部分数据创建一个新表并保存
demoknn() #knn实例
'''结果(可以看到此次准确率很低):
预测的位置为: [1097200869 4423196276 4932578245 ... 3312463746 1479000473 1097200869]
准确率为: 0.027659574468085105
'''
6.2提升预测准确率1:进行特征工程处理(标准化)
接上例代码,在第5步加上(特征工程——标准化)处理后再预测,可以看到准确率一下由上例的2.7%,一下提升到41.3%
#【5】特征工程(标准化)对测试集、训练集的特征值进行,标准化
std=StandardScaler()
x_train=std.fit_transform(x_train)
x_test=std.fit_transform(x_test)
'''结果:
预测的位置为: [3312463746 7803770431 2327054745 ... 1602053545 5270522918 1267801529]
准确率为: 0.4134751773049645
'''
6.3 提升预测准确率2:删除无关特征
排查发现row_id这列和预测数据结果无关,会对结果造成影响,因此删除
此时发现准确率由:41%变47%
关键代码:
# (2.1)此处打印发现row_id这列特征值与预测没有相关性,会影响模型预测结果;因此把它删除:训练特征集、测试特征集都要删除
data=data.drop(['row_id'],axis=1)
源码如下:
import pandas as pd
from sklearn.model_selection import train_test_split #数据集分割成:测试集、训练集模块
from sklearn.neighbors import KNeighborsClassifier #knn近邻算法模块
from sklearn.preprocessing import StandardScaler #特征工程 标准化 模块
def choisecsv():
"""
选取一部分数据保存为新表格
:return:
"""
# 【1】读取数据
path = "D:\\a\\data\\facebook-v-predicting-check-ins\\train.csv"
data = pd.read_csv(path)
# print(data.head(10))
# 【2】因为数据集太大,只选取一部分数据,保存一个新表格
data = data.query("x > 1.0 & x < 1.25 & y > 2.5 & y < 2.75")
print(data.head(10))
data.to_csv('train_s.csv',index=False) #第2参数,不要在新csv里最前列自动加索引
def demoknn():
"""
根据输入坐标,预测入住地点
:return:
"""
data=pd.read_csv("train_s.csv")
#【1】时间戳处理:转化时间列的时间戳为时间
time_value=pd.to_datetime(data['time'],unit='s')
#print(time_value)
#(1.1)把日期转化为字典格式
time_value=pd.DatetimeIndex(time_value)
# (1.2)根据时间字典构造一些新特征:月,天,时,星期几
data['mouth']=time_value.month
data['day'] = time_value.day
data['hour'] = time_value.hour
data['weekday'] = time_value.weekday
#(1.3)删除时间戳列。pd里0是行,1是列
data=data.drop(['time'],axis=1)
#print(data.head(10))
#【2】删除入驻数量少于3次的地方
place_count=data.groupby('place_id').count() #数出一共有多少个地方
tf = place_count[place_count.row_id > 3].reset_index() #入住次数>3个的拿出来,place_id放到最后做为一个特征,前面加一个新索引列0-n
data = data[data['place_id'].isin(tf.place_id)] #原表的place_id,存在上一步新建的表里的,单独拿出来做为新数据
#print(data.sample(10))
# (2.1)此处打印发现row_id这列特征值与预测没有相关性,会影响模型预测结果;因此把它删除:训练特征集、测试特征集都要删除
data=data.drop(['row_id'],axis=1)
#【3】取得特征值x、目标值y
y=data['place_id']
x=data.drop('place_id',axis=1)
#【4】分割数据集
x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.25)
# print(x_train)
#【5】特征工程(标准化)对测试集、训练集的特征值进行,标准化
std=StandardScaler()
x_train=std.fit_transform(x_train)
x_test=std.fit_transform(x_test)
#【6】执行Knn算法
knn=KNeighborsClassifier(n_neighbors=5)
#【7】fit(处理),执行算法(到此模型已训练完成)
knn.fit(x_train,y_train)
#【8】预测结果predict(输入测试集进行预测)
y_predict=knn.predict(x_test)
print('预测的位置为:',y_predict)
#【9】算出准确率scroe(对模型进行测试集的预测,的准确率计算):
print('准确率为:',knn.score(x_test,y_test))
if __name__=="__main__":
#choisecsv() #选取部分数据创建一个新表并保存
demoknn() #knn实例
'''结果
预测的位置为: [7803770431 1097200869 3533177779 ... 4932578245 5606572086 1097200869]
准确率为: 0.4761229314420804
'''
6.4提升预测准确率3:K值调整等
1.调整k值关键代码:
1、k值取多大?有什么影响?
- k值取很小:容易受异常点影响
- k值取很大:容易受最近数据太多导致比例变化
2、性能影响:k近邻每个都要进行数值运算,当数据非常多时,非常耗时
#【6】执行Knn算法
knn=KNeighborsClassifier(n_neighbors=6)
'''结果:经多次调试发现
k值调整成6预测准确率最高变为:
48%
'''
七、K近邻总结
k-近邻算法步骤
- 导入相关模块
- 对数据进行初步处理(处理时间戳成新特征、删除无关特征)
- 对数据进行划分:分成训练集、测试集
- 进行特征工程处理:一般用标准化
- 进行knn算法
import pandas as pd
from sklearn.model_selection import train_test_split #数据集分割成:测试集、训练集模块
from sklearn.preprocessing import StandardScaler #特征工程 标准化 模块
from sklearn.neighbors import KNeighborsClassifier #knn近邻算法模块
详情见上一节代码……
k-近邻算法优缺点
优点:
- 简单,易于理解,易于实现,无需估计参数,无需训练
缺点:
- 懒惰算法,对测试样本分类时的计算量大,内存开销大
- 必须指定K值,K值选择不当则分类精度不能保证
- 使用场景:小数据场景,几千~几万样本,具体场景具体业务去测试
《机学五》KNN算法及实例的更多相关文章
- KNN算法基本实例
KNN算法是机器学习领域中一个最基本的经典算法.它属于无监督学习领域的算法并且在模式识别,数据挖掘和特征提取领域有着广泛的应用. 给定一些预处理数据,通过一个属性把这些分类坐标分成不同的组.这就是KN ...
- 机器学习经典算法具体解释及Python实现--K近邻(KNN)算法
(一)KNN依旧是一种监督学习算法 KNN(K Nearest Neighbors,K近邻 )算法是机器学习全部算法中理论最简单.最好理解的.KNN是一种基于实例的学习,通过计算新数据与训练数据特征值 ...
- KNN算法 - 数据挖掘算法(3)
(2017-04-10 银河统计) KNN算法即K Nearest Neighbor算法.这个算法是机器学习里面一个比较经典的.相对比较容易理解的算法.其中的K表示最接近自己的K个数据样本.KNN算法 ...
- 机器学*——K*邻算法(KNN)
1 前言 Kjin邻法(k-nearest neighbors,KNN)是一种基本的机器学*方法,采用类似"物以类聚,人以群分"的思想.比如,判断一个人的人品,只需观察他来往最密切 ...
- 【机器学*】k-*邻算法(kNN) 学*笔记
[机器学*]k-*邻算法(kNN) 学*笔记 标签(空格分隔): 机器学* kNN简介 kNN算法是做分类问题的.思想如下: KNN算法的思想总结一下:就是在训练集中数据和标签已知的情况下,输入测试数 ...
- 【十大算法实现之KNN】KNN算法实例(含测试数据和源码)
KNN算法基本的思路是比较好理解的,今天根据它的特点写了一个实例,我会把所有的数据和代码都写在下面供大家参考,不足之处,请指正.谢谢! update:工程代码全部在本页面中,测试数据已丢失,建议去UC ...
- 一步步教你轻松学关联规则Apriori算法
一步步教你轻松学关联规则Apriori算法 (白宁超 2018年10月22日09:51:05) 摘要:先验算法(Apriori Algorithm)是关联规则学习的经典算法之一,常常应用在商业等诸多领 ...
- 机器学习笔记--KNN算法1
前言 Hello ,everyone. 我是小花.大四毕业,留在学校有点事情,就在这里和大家吹吹我们的狐朋狗友算法---KNN算法,为什么叫狐朋狗友算法呢,在这里我先卖个关子,且听我慢慢道来. 一 K ...
- 机器学习之二:K-近邻(KNN)算法
一.概述 K最近邻(k-Nearest Neighbor,KNN)分类算法,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一.该方法的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中 ...
随机推荐
- Java--对象与类(二)
final 实例域 可以将实例域定义为final.构建对象时必须初始化这样的域.也就是说在一个构造器执行之后,这个域被设置,并且之后无法对其修改 final 修饰符大多应用于基本(primitive) ...
- python爬虫(九) requests库之post请求
1.方法: response=requests.post("https://www.baidu.com/s",data=data) 2.拉勾网职位信息获取 因为拉勾网设置了反爬虫机 ...
- Spring Cloud Hystrix 请求熔断与服务降级
在Java中,每一个HTTP请求都会开启一个新线程.而下游服务挂了或者网络不可达,通常线程会阻塞住,直到Timeout.你想想看,如果并发量多一点,这些阻塞的线程就会占用大量的资源,很有可能把自己本身 ...
- Linux命令:vi | vim命令
vim - vi 增强版.文本编辑器 格式:vim [options] [file ..] 说明:如果file存在,文件被打开并显示内容,如果文件不存在,当编辑后第一次存盘时创建它 [options] ...
- nodeJS - 定义全局变量
定义 : global.变量名=‘xxxx’; 取出 : global.变量名
- Linux centos7 linux任务计划cron、chkconfig工具、systemd管理服务、unit介绍、 target介绍
一.linux任务计划cron crontab -u -e -l -r 格式;分 时 日 月 周 user command 文件/var/spool/corn/username 分范围0-59,时范 ...
- ES6 && ECMAScript2015 新特性
ECMAScript 6(以下简称ES6)是JavaScript语言的下一代标准.因为当前版本的ES6是在2015年发布的,所以又称ECMAScript 2015. 也就是说,ES6就是ES201 ...
- 用Navicat连接阿里云ECS服务器上的MySQL数据库,连接不上,并且报10060错误
设置远程访问(使用root密码): grant all privileges on . to 'root' @'%' identified by '123456'; flush privileges; ...
- SpringBoot--⼯具表达式对象
⼯具表达式对象除了这些基本的对象之外,Thymeleaf将为我们提供⼀组⼯具对象,这些对象将帮助我们在表达式中执⾏常⻅任务.#execInfo:有关正在处理的模板的信息.#messages:⽤于在变量 ...
- 142、Java内部类之在普通方法里面定义内部类
01.代码如下: package TIANPAN; class Outer { // 外部类 private String msg = "Hello World !"; publi ...