基于sklearn和keras的数据切分与交叉验证
在训练深度学习模型的时候,通常将数据集切分为训练集和验证集.Keras提供了两种评估模型性能的方法:
- 使用自动切分的验证集
- 使用手动切分的验证集
一.自动切分
在Keras中,可以从数据集中切分出一部分作为验证集,并且在每次迭代(epoch)时在验证集中评估模型的性能.
具体地,调用model.fit()训练模型时,可通过validation_split参数来指定从数据集中切分出验证集的比例.
# MLP with automatic validation set
from keras.models import Sequential
from keras.layers import Dense
import numpy
# fix random seed for reproducibility
numpy.random.seed(7)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# create model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Fit the model
model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10)
validation_split:0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等。
注意,validation_split的划分在shuffle之前,因此如果你的数据本身是有序的,需要先手工打乱再指定validation_split,否则可能会出现验证集样本不均匀。
二.手动切分
Keras允许在训练模型的时候手动指定验证集.
例如,用sklearn库中的train_test_split()函数将数据集进行切分,然后在keras的model.fit()的时候通过validation_data参数指定前面切分出来的验证集.
# MLP with manual validation set
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import numpy
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# split into 67% for train and 33% for test
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.33, random_state=seed)
# create model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Fit the model
model.fit(X_train, y_train, validation_data=(X_test,y_test), epochs=150, batch_size=10)
三.K折交叉验证(k-fold cross validation)
将数据集分成k份,每一轮用其中(k-1)份做训练而剩余1份做验证,以这种方式执行k轮,得到k个模型.将k次的性能取平均,作为该算法的整体性能.k一般取值为5或者10.
- 优点:能比较鲁棒性地评估模型在未知数据上的性能.
- 缺点:计算复杂度较大.因此,在数据集较大,模型复杂度较高,或者计算资源不是很充沛的情况下,可能不适用,尤其是在训练深度学习模型的时候.
sklearn.model_selection提供了KFold以及RepeatedKFold, LeaveOneOut, LeavePOut, ShuffleSplit, StratifiedKFold, GroupKFold, TimeSeriesSplit等变体.
下面的例子中用的StratifiedKFold采用的是分层抽样,它保证各类别的样本在切割后每一份小数据集中的比例都与原数据集中的比例相同.
# MLP for Pima Indians Dataset with 10-fold cross validation
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import StratifiedKFold
import numpy
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# define 10-fold cross validation test harness
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
cvscores = []
for train, test in kfold.split(X, Y):
# create model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Fit the model
model.fit(X[train], Y[train], epochs=150, batch_size=10, verbose=0)
# evaluate the model
scores = model.evaluate(X[test], Y[test], verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
cvscores.append(scores[1] * 100)
print("%.2f%% (+/- %.2f%%)" % (numpy.mean(cvscores), numpy.std(cvscores)))
参考:
Evaluate the Performance Of Deep Learning Models in Keras
3.1. Cross-validation: evaluating estimator performance — scikit-learn 0.19.1 documentation
基于sklearn和keras的数据切分与交叉验证的更多相关文章
- 机器学习 - 案例 - 样本不均衡数据分析 - 信用卡诈骗 ( 标准化处理, 数据不均处理, 交叉验证, 评估, Recall值, 混淆矩阵, 阈值 )
案例背景 银行评判用户的信用考量规避信用卡诈骗 ▒ 数据 数据共有 31 个特征, 为了安全起见数据已经向了模糊化处理无法读出真实信息目标 其中数据中的 class 特征标识为是否正常用户 (0 代表 ...
- 莫烦sklearn学习自修第七天【交叉验证】
1. 什么是交叉验证 所谓交叉验证指的是将样本分为两组,一组为训练样本,一组为测试样本:对于哪些数据分为训练样本,哪些数据分为测试样本,进行多次拆分,每次将整个样本进行不同的拆分,对这些不同的拆分每个 ...
- 基于sklearn的分类器实战
已迁移到我新博客,阅读体验更佳基于sklearn的分类器实战 完整代码实现见github:click me 一.实验说明 1.1 任务描述 1.2 数据说明 一共有十个数据集,数据集中的数据属性有全部 ...
- 客户流失?来看看大厂如何基于spark+机器学习构建千万数据规模上的用户留存模型 ⛵
作者:韩信子@ShowMeAI 大数据技术 ◉ 技能提升系列:https://www.showmeai.tech/tutorials/84 行业名企应用系列:https://www.showmeai. ...
- MySQL数据切分的相关概念和原理详解
对于数据切分,我们可能还不是很熟悉,但是它对于MySQL数据库来说也是相当重要的一门技术,本文我们就详细介绍一下MySQL数据库的数据切分的相关知识,接下来就让我们一起来了解一下这部分内容. 什么是数 ...
- MySql(十四):MySql架构设计——可扩展性设计之数据切分
一.前言 通过 MySQL Replication 功能所实现的扩展总是会受到数据库大小的限制,一旦数据库过于庞大,尤其是当写入过于频繁,很难由一台主机支撑的时候,我们还是会面临到扩展瓶颈.这时候,我 ...
- 机器学习入门-交叉验证选择参数(数据切分)train_test_split(under_x, under_y, test_size, random_state), (交叉验证的数据切分)KFold, recall_score(召回率)
1. train_test_split(under_x, under_y, test_size=0.3, random_state=0) # under_x, under_y 表示输入数据, tes ...
- MySQL性能调优与架构设计——第 14 章 可扩展性设计之数据切分
第 14 章 可扩展性设计之数据切分 前言 通过 MySQL Replication 功能所实现的扩展总是会受到数据库大小的限制,一旦数据库过于庞大,尤其是当写入过于频繁,很难由一台主机支撑的时候,我 ...
- 如何基于Go搭建一个大数据平台
如何基于Go搭建一个大数据平台 - Go中国 - CSDN博客 https://blog.csdn.net/ra681t58cjxsgckj31/article/details/78333775 01 ...
随机推荐
- Beaglebone板子修改usb连接时的默认IP192.168.0.2
首先除了有个USB线外,你还需要一个USB转串口的线(目的是防止修改错误,无法使用原来的usb的IP地址登陆,心大的可以跳过这步直接进入重点),串口线连接方法如下图: 将USB以及串口和PC机相连 ...
- Spring笔记 #01# 一个小而生动的IOC例子代码
索引 Spring容器的最小可用依赖 用XML定义元数据 实例化容器&使用容器 例子中仅包含两种类:英雄类Hero和武器类Weapon. 演示DI:给Hero初始化Weapon 演示AOP:法 ...
- Centosphp安装cassandra扩展
一.准备 当前php版本PHP Version 5.5.10,首先去http://pecl.php.net/package/cassandra,找到对应的php版本 二.下载安装 # wget htt ...
- opencv学习之路(25)、轮廓查找与绘制(四)——正外接矩形
一.简介 二.外接矩形的查找绘制 #include "opencv2/opencv.hpp" using namespace cv; void main() { //外接矩形的查找 ...
- 逐步构建循环神经网络 RNN
rnn.utils.py import numpy as np def softmax(x): e_x = np.exp(x - np.max(x)) return e_x / e_x.sum(axi ...
- Qt信号和槽机制
概述 信号和槽机制是QT的核心机制,要精通QT编程就必须对信号和槽有所了解.信号和槽是一种高级接口,应用于对象之间的通信,他是QT的核心特性,也是QT差别于其他工具包的重要地方.信号和槽是QT自行定义 ...
- spring cloud 版本号与 boot版本之间的对应关系(版本不对,会导致pom无法引入)
版本号规则 Spring Cloud并没有熟悉的数字版本号,而是对应一个开发代号. 开发代号看似没有什么规律,但实际上首字母是有顺序的,比如:Dalston版本,我们可以简称 D 版本,对应的 Edg ...
- [easyUI] lazyload 懒加载
1.使用<img>标签将图片都写在网页上. <div style="height:450px;"><h1>请往下看,有图片的吆!</h1& ...
- HDU - 1061-快速幂签到题
快速幂百度百科:快速幂就是快速算底数的n次幂.其时间复杂度为 O(log₂N), 与朴素的O(N)相比效率有了极大的提高. HDU - 1061 代码实现如下: import java.util.Sc ...
- String.format(String format, Object... args)方法详解
很多次见到同事使用这个方法,同时看到https://blog.csdn.net/qq_27298687/article/details/68921934这位仁兄写的非常仔细,我也记录一下,好加深印象. ...