交叉验证(Cross validation),有时亦称循环估计, 是一种统计学上将数据样本切割成较小子集的实用方法。于是可以先在一个子集上做分析, 而其它子集则用来做后续对此分析的确认及验证。 一开始的子集被称为训练集。而其它的子集则被称为验证集或测试集。交叉验证是一种评估统计分析、机器学习算法对独立于训练数据的数据集的泛化能力(generalize)。

我们以分类花的例子来看下:

  1. # 加载iris数据集
  2. from sklearn.datasets import load_iris
  3. from sklearn.model_selection import train_test_split
  4. from sklearn.neighbors import KNeighborsClassifier
  5. iris = load_iris()
  6. X = iris.data
  7. y = iris.target
  8. # 分割训练集和测试集
  9. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
  10. # 建立模型
  11. model = KNeighborsClassifier()
  12. # 训练模型
  13. model.fit(X_train, y_train)
  14. # 将准确率打印出
  15. print(model.score(X_test, y_test))

这样这个模型的得分为:

0.911111111111

但是如果我再运行一下,这个得分又会变成:

0.955555555556

如果再进行多次运行,这个得分的结果就又会不一样。

为了能够得出一个相对比较准确的得分,一般是进行多次试验,并且是用不同的训练集和测试集进行。

这个叫做交叉验证,一般有留一法,也就是把原始数据分成十份,其中一份作为测试,其它的作为训练集,并且可以循环来选取其中的一份作为测试集,剩下的作为训练集。

当然,这里只是提供一个基本思想,具体你要分成几份可以自己来定义。

比如,下面的代码我们定义了5份并做了5次实验:

  1. # 加载iris数据集
  2. from sklearn.datasets import load_iris
  3. from sklearn.model_selection import train_test_split, cross_val_score
  4. from sklearn.neighbors import KNeighborsClassifier
  5. iris = load_iris()
  6. X = iris.data
  7. y = iris.target
  8. # 建立模型
  9. model = KNeighborsClassifier()
  10. # 使用K折交叉验证模块
  11. scores = cross_val_score(model, X, y, cv=5)
  12. # 将5次的预测准确率打印出
  13. print(scores)

输出为:

  1. [ 0.96666667 1. 0.93333333 0.96666667 1. ]

对这几次实验结果进行一下平均作为本次实验的最终得分:

  1. # 将5次的预测准确平均率打印出
  2. print(scores.mean())

结果为:

0.973333333333

在KNN算法中,其中有个neighbors参数,我们可以修改此参数的值:

  1. model = KNeighborsClassifier(n_neighbors=5)

但这个参数值选择哪个数字为最佳呢?

我们可以通过程序来不停选择这个值并看在不同数值下其对应的得分情况,最终可以选择得分较好对应的参数值:

  1. # 加载iris数据集
  2. from sklearn.datasets import load_iris
  3. from sklearn.model_selection import cross_val_score
  4. from sklearn.neighbors import KNeighborsClassifier
  5. # 可视化模块
  6. import matplotlib.pyplot as plt
  7. iris = load_iris()
  8. X = iris.data
  9. y = iris.target
  10. # 建立测试参数集
  11. k_range = range(1, 31)
  12. k_scores = []
  13. for k in k_range:
  14. # 建立模型
  15. model = KNeighborsClassifier(n_neighbors=k)
  16. # 使用K折交叉验证模块
  17. scores = cross_val_score(model, X, y, cv=10)
  18. # 计算10次的预测准确平均率
  19. k_scores.append(scores.mean())
  20. # 可视化数据
  21. plt.plot(k_range, k_scores)
  22. plt.show()

显示的图形为:

从这个结果图上看,n_neighbors太小或太大其精确度都会下降,因此比较好的取值是5-20之间。

另外对于回归算法,需要用损失函数来进行评估:

  1. loss = -cross_val_score(model, data_X, data_y, cv=10, scoring='neg_mean_squared_error')

sklearn交叉验证-【老鱼学sklearn】的更多相关文章

  1. sklearn标准化-【老鱼学sklearn】

    在前面的一篇博文中关于计算房价中我们也大致提到了标准化的概念,也就是比如对于影响房价的参数中有面积和户型,面积的取值范围可以很广,它可以从0-500平米,而户型一般也就1-5. 标准化就是要把这两种参 ...

  2. sklearn数据库-【老鱼学sklearn】

    在做机器学习时需要有数据进行训练,幸好sklearn提供了很多已经标注好的数据集供我们进行训练. 本节就来看看sklearn提供了哪些可供训练的数据集. 这些数据位于datasets中,网址为:htt ...

  3. sklearn交叉验证2-【老鱼学sklearn】

    过拟合 过拟合相当于一个人只会读书,却不知如何利用知识进行变通. 相当于他把考试题目背得滚瓜烂熟,但一旦环境稍微有些变化,就死得很惨. 从图形上看,类似下图的最右图: 从数学公式上来看,这个曲线应该是 ...

  4. sklearn交叉验证3-【老鱼学sklearn】

    在上一个博文中,我们用learning_curve函数来确定应该拥有多少的训练集能够达到效果,就像一个人进行学习时需要做多少题目就能拥有较好的考试成绩了. 本次我们来看下如何调整学习中的参数,类似一个 ...

  5. sklearn模型的属性与功能-【老鱼学sklearn】

    本节主要讲述模型中的各种属性及其含义. 例如上个博文中,我们有用线性回归模型来拟合房价. # 创建线性回归模型 model = LinearRegression() # 训练模型 model.fit( ...

  6. sklearn保存模型-【老鱼学sklearn】

    训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...

  7. 二分类问题续 - 【老鱼学tensorflow2】

    前面我们针对电影评论编写了二分类问题的解决方案. 这里对前面的这个方案进行一些改进. 分批训练 model.fit(x_train, y_train, epochs=20, batch_size=51 ...

  8. tensorflow卷积神经网络-【老鱼学tensorflow】

    前面我们曾有篇文章中提到过关于用tensorflow训练手写2828像素点的数字的识别,在那篇文章中我们把手写数字图像直接碾压成了一个784列的数据进行识别,但实际上,这个图像是2828长宽结构的,我 ...

  9. 机器学习- Sklearn (交叉验证和Pipeline)

    前面一节咱们已经介绍了决策树的原理已经在sklearn中的应用.那么这里还有两个数据处理和sklearn应用中的小知识点咱们还没有讲,但是在实践中却会经常要用到的,那就是交叉验证cross_valid ...

随机推荐

  1. Ubuntu shutdown

    gsettings set com.canonical.indicator.session suppress-logout-restart-shutdown true

  2. POJChallengeRound2 Tree 【数学期望】

    题目分析: 我们令$G(x)$表示前$x$个点的平均深度,$F(x)$表示第$x$个点的期望深度. 有$F(x) = G(x-1)+1$,$G(x) = G(x-1)+\frac{1}{x}$ 所以答 ...

  3. Android学习第8天

    进程的概念 a)        四大组件都运行在主线程中 b)        服务是没有界面的,可理解为没有界面的Activity c)         进程的优先级 i.              ...

  4. Codeforces 1037C Equalize

    原题 题目大意: 给你两个长度都为\(n\)的的\(01\)串\(a,b\),现在你可以对\(a\)串进行如下两种操作: 1.交换位置\(i\)和位置\(j\),代价为\(|i-j|\) 2.反转位置 ...

  5. 使用Eclipse创建动态的web工程

    使用Eclipse创建动态的web工程 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.修改工作区的编码 1>.点击Window选择Preferences 2>.将默 ...

  6. java Ajax跨域请求COOKIE无法带上的解决办法

    1.web.xml加入以下节点,,一定放在第一个filter <!--目录下所有文件可以跨域Begin--> <filter> <filter-name>CorsF ...

  7. LFYZ-OJ ID: 1021 邮票问题

    邮票问题 Problem Description 设有已知面额的邮票m种,每种有n张,用总数不超过n张的邮票,能从面额1开始,最多连续组成多少面额.(1≤m≤100,1≤n≤100,1≤邮票面额≤25 ...

  8. Contest2158 - 2019-3-14 高一noip基础知识点 测试3 题解版

    传送门 预计得分:0 实际得分:90 还行 T1 数学卡精 二分double卡精 反正就是卡精 怎么办?卡回去!! 将double*=1e4,变成一个long long 注意四舍五入的奇技淫巧 代码 ...

  9. MySQL学习4 - 数据类型一

    介绍 一.数值类型 二.浮点型 验证三种类型建表 验证三种类型的精度 三.日期类型 综合练习: 介绍 存储引擎决定了表的类型,而表内存放的数据也要有不同的类型,每种数据类型都有自己的宽度,但宽度是可选 ...

  10. Coursera, Big Data 3, Integration and Processing (week 4)

    Week 4 Big Data Precessing Pipeline 上图可以generalize 成下图,也就是Big data pipeline some high level processi ...