在上一个博文中,我们用learning_curve函数来确定应该拥有多少的训练集能够达到效果,就像一个人进行学习时需要做多少题目就能拥有较好的考试成绩了。

本次我们来看下如何调整学习中的参数,类似一个人是在早上7点钟开始读书好还是晚上8点钟读书好。

加载数据

数据仍然利用手写数字识别作为训练数据:

  1. from sklearn.datasets import load_digits
  2. # 加载数据
  3. digits = load_digits()
  4. X = digits.data
  5. y = digits.target

调整参数

我们想要调整·SVC(gamma=0.001)·SVC中的gamma参数,看到底把gamma参数设置成哪个值是最优的。

因此需要定义测试的参数范围,这里设置了参数值的范围为从10的-6次方到10的-2.3次方,总共5个值:

  1. import numpy as np
  2. # 定义gamma参数的可能取值范围,从10**-6, 到10**-2.3,总共5个参数值
  3. param_range = np.logspace(-6, -2.3, 5)

validation_curve不停尝试在不同参数值下的损失函数值:

  1. from sklearn.model_selection import validation_curve
  2. from sklearn.svm import SVC
  3. # param_name中指定了修改SVC中的哪个参数值,这里修改的是gamma参数值;param_range参数指定了具体参数值的可选范围
  4. train_loss, test_loss = validation_curve(SVC(), X, y, param_name="gamma", param_range=param_range, cv=10, scoring='neg_mean_squared_error')
  5. train_loss_mean = -np.mean(train_loss, axis=1)
  6. test_loss_mean = -np.mean(test_loss, axis=1)

可视化图形

可视化图形,横坐标为参数可选值的范围,纵坐标为在各参数下的损失函数值

  1. # 可视化图形,横坐标为参数可选值的范围,纵坐标为在各参数下的损失函数值
  2. import matplotlib.pyplot as plt
  3. plt.plot(param_range, train_loss_mean, label="Train")
  4. plt.plot(param_range, test_loss_mean, label="Test")
  5. plt.legend()
  6. plt.show()

图形显示为:

在这个图形中,我们发现gamma值有一个转折点,当其在0.001之后,测试集的误差值就开始扩大了,因此,从图形上看,一个比较好的学习参数值是gamma=0.001或者再往前一点点,大概在0.0007左右。

完整代码

完整的代码如下:

  1. from sklearn.datasets import load_digits
  2. # 加载数据
  3. digits = load_digits()
  4. X = digits.data
  5. y = digits.target
  6. import numpy as np
  7. # 定义gamma参数的可能取值范围,从10**-6, 到10**-2.3,总共5个参数值
  8. param_range = np.logspace(-6, -2.3, 5)
  9. from sklearn.model_selection import validation_curve
  10. from sklearn.svm import SVC
  11. # param_name中指定了修改SVC中的哪个参数值,这里修改的是gamma参数值;param_range参数指定了具体参数值的可选范围
  12. train_loss, test_loss = validation_curve(SVC(), X, y, param_name="gamma", param_range=param_range, cv=10, scoring='neg_mean_squared_error')
  13. train_loss_mean = -np.mean(train_loss, axis=1)
  14. test_loss_mean = -np.mean(test_loss, axis=1)
  15. # 可视化图形,横坐标为参数可选值的范围,纵坐标为在各参数下的损失函数值
  16. import matplotlib.pyplot as plt
  17. plt.plot(param_range, train_loss_mean, label="Train")
  18. plt.plot(param_range, test_loss_mean, label="Test")
  19. plt.legend()
  20. plt.show()

sklearn交叉验证3-【老鱼学sklearn】的更多相关文章

  1. sklearn交叉验证-【老鱼学sklearn】

    交叉验证(Cross validation),有时亦称循环估计, 是一种统计学上将数据样本切割成较小子集的实用方法.于是可以先在一个子集上做分析, 而其它子集则用来做后续对此分析的确认及验证. 一开始 ...

  2. sklearn交叉验证2-【老鱼学sklearn】

    过拟合 过拟合相当于一个人只会读书,却不知如何利用知识进行变通. 相当于他把考试题目背得滚瓜烂熟,但一旦环境稍微有些变化,就死得很惨. 从图形上看,类似下图的最右图: 从数学公式上来看,这个曲线应该是 ...

  3. sklearn保存模型-【老鱼学sklearn】

    训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...

  4. sklearn数据库-【老鱼学sklearn】

    在做机器学习时需要有数据进行训练,幸好sklearn提供了很多已经标注好的数据集供我们进行训练. 本节就来看看sklearn提供了哪些可供训练的数据集. 这些数据位于datasets中,网址为:htt ...

  5. sklearn模型的属性与功能-【老鱼学sklearn】

    本节主要讲述模型中的各种属性及其含义. 例如上个博文中,我们有用线性回归模型来拟合房价. # 创建线性回归模型 model = LinearRegression() # 训练模型 model.fit( ...

  6. sklearn标准化-【老鱼学sklearn】

    在前面的一篇博文中关于计算房价中我们也大致提到了标准化的概念,也就是比如对于影响房价的参数中有面积和户型,面积的取值范围可以很广,它可以从0-500平米,而户型一般也就1-5. 标准化就是要把这两种参 ...

  7. 二分类问题续 - 【老鱼学tensorflow2】

    前面我们针对电影评论编写了二分类问题的解决方案. 这里对前面的这个方案进行一些改进. 分批训练 model.fit(x_train, y_train, epochs=20, batch_size=51 ...

  8. tensorflow卷积神经网络-【老鱼学tensorflow】

    前面我们曾有篇文章中提到过关于用tensorflow训练手写2828像素点的数字的识别,在那篇文章中我们把手写数字图像直接碾压成了一个784列的数据进行识别,但实际上,这个图像是2828长宽结构的,我 ...

  9. 机器学习- Sklearn (交叉验证和Pipeline)

    前面一节咱们已经介绍了决策树的原理已经在sklearn中的应用.那么这里还有两个数据处理和sklearn应用中的小知识点咱们还没有讲,但是在实践中却会经常要用到的,那就是交叉验证cross_valid ...

随机推荐

  1. linux文件目录权限和系统基础优化命令(yum源配置)

    一.用户 1.介绍 我们都知道linux中有root用户和普通用户,但是同样是普通用户,为什么有些用户的权限却不一样呢?其实这就类似于我们的QQ群,root用户就是QQ群主,他拥有最高的权利,想干什么 ...

  2. python之旅九【第九篇】socket

    什么是socket 建立网络通信连接至少要一对端口号(socket).socket本质是编程接口(API),对TCP/IP的封装,TCP/IP也要提供可供程序员做网络开发所用的接口,这就是Socket ...

  3. 初识并发编程 MPI

    MPI是一个跨语言的通讯协议,用于并发编程.MPI标准定义了一组具有可移植性的编程接口. 安装环境 MPICH 是开源的消息传递接口(MPI)标准的实现. 下载地址 # 解压文件 tar -xzvf ...

  4. FastDFS分布式文件系统客户端安装

    软件安装前提:服务器已配置好LNMP环境安装libfastcommon见FastDFS服务器安装文档(http://www.cnblogs.com/Mrhuangrui/p/8316481.html) ...

  5. [FJOI2016]建筑师

    题目描述 小 Z 是一个很有名的建筑师,有一天他接到了一个很奇怪的任务:在数轴上建 n 个建筑,每个建筑的高度是 1 到 n 之间的一个整数. 小 Z 有很严重的强迫症,他不喜欢有两个建筑的高度相同. ...

  6. 为什么会有这么多python?其实python并不是编程语言!

    Python是出类拔萃的 然而,这是一句非常模棱两可的话.这里的"Python"到底指的是什么? 是Python的抽象接口吗?是Python的通用实现CPython吗(不要把CPy ...

  7. django系列7:修改404页面展示,优化模板,降低urlconf和模板之间的耦合,命名app将模板和app绑定

    为了增加程序的友好和健壮性,修改view代码,处理以下如果出现404,页面的UI展示. 修改view代码 from django.http import Http404 from django.sho ...

  8. 金融量化分析【day112】:因子选股

    一.因子选股基础 二.因子选股策略实现代码 # 导入函数库 import jqdata import psutil #初始化函数,设定基准等等 def initialize(context): set ...

  9. BFC块级格式化上下文

    BFC块级格式化上下文 触发条件 overflow 值不为 visible 的块元素 根元素 html 元素 浮动元素(元素的 float 不是 none) 绝对定位元素(元素的 position 为 ...

  10. CAS实现单点登录

    1.简介 SSO单点登录 在多个相互信任的系统中,用户只需要登录一次就可以访问其他受信任的系统. 新浪微博与新浪博客是相互信任的应用系统. *当用户首次访问新浪微博时,新浪微博识别到用户未登录,将请求 ...