话不多说,直接上代码

 def stacking_first(train, train_y, test):
savepath = './stack_op{}_dt{}_tfidf{}/'.format(args.option, args.data_type, args.tfidf)
os.makedirs(savepath, exist_ok=True) count_kflod = 0
num_folds = 6
kf = KFold(n_splits=num_folds, shuffle=True, random_state=10)
# 测试集上的预测结果
predict = np.zeros((test.shape[0], config.n_class))
# k折交叉验证集的预测结果
oof_predict = np.zeros((train.shape[0], config.n_class))
scores = []
f1s = [] for train_index, test_index in kf.split(train):
# 训练集划分为6折,每一折都要走一遍。那么第一个是5份的训练集索引,第二个是1份的测试集,此处为验证集是索引 kfold_X_train = {}
kfold_X_valid = {} # 取数据的标签
y_train, y_test = train_y[train_index], train_y[test_index]
# 取数据
kfold_X_train, kfold_X_valid = train[train_index], train[test_index] # 模型的前缀
model_prefix = savepath + 'DNN' + str(count_kflod)
if not os.path.exists(model_prefix):
os.mkdir(model_prefix) M = 4 # number of snapshots
alpha_zero = 1e-3 # initial learning rate
snap_epoch = 16
snapshot = SnapshotCallbackBuilder(snap_epoch, M, alpha_zero) # 使用训练集的size设定维度,fit一个模型出来
res_model = get_model(train)
res_model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# res_model.fit(train_x, train_y, batch_size=BATCH_SIZE, epochs=EPOCH, verbose=1, class_weight=class_weight)
res_model.fit(kfold_X_train, y_train, batch_size=BATCH_SIZE, epochs=snap_epoch, verbose=1,
validation_data=(kfold_X_valid, y_test),
callbacks=snapshot.get_callbacks(model_save_place=model_prefix)) # 找到这个目录下所有已经训练好的深度学习模型,通过".h5"
evaluations = []
for i in os.listdir(model_prefix):
if '.h5' in i:
evaluations.append(i) # 给测试集和当前的验证集开辟空间,就是当前折的数据预测结果构建出这么多的数据集[数据个数,类别]
preds1 = np.zeros((test.shape[0], config.n_class))
preds2 = np.zeros((len(kfold_X_valid), config.n_class))
# 遍历每一个模型,用他们分别预测当前折数的验证集和测试集,N个模型的结果求平均
for run, i in enumerate(evaluations):
res_model.load_weights(os.path.join(model_prefix, i))
preds1 += res_model.predict(test, verbose=1) / len(evaluations)
preds2 += res_model.predict(kfold_X_valid, batch_size=128) / len(evaluations) # 测试集上预测结果的加权平均
predict += preds1 / num_folds
# 每一折的预测结果放到对应折上的测试集中,用来最后构建训练集
oof_predict[test_index] = preds2 # 计算精度和F1
accuracy = mb.cal_acc(oof_predict[test_index], np.argmax(y_test, axis=1))
f1 = mb.cal_f_alpha(oof_predict[test_index], np.argmax(y_test, axis=1), n_out=config.n_class)
print('the kflod cv is : ', str(accuracy))
print('the kflod f1 is : ', str(f1))
count_kflod += 1 # 模型融合的预测结果,存起来,用以以后求平均值
scores.append(accuracy)
f1s.append(f1)
# 指标均值,最为最后的预测结果
print('total scores is ', np.mean(scores))
print('total f1 is ', np.mean(f1s))
return predict

深度学习模型stacking模型融合python代码,看了你就会使的更多相关文章

  1. 时间序列深度学习:seq2seq 模型预测太阳黑子

    目录 时间序列深度学习:seq2seq 模型预测太阳黑子 学习路线 商业中的时间序列深度学习 商业中应用时间序列深度学习 深度学习时间序列预测:使用 keras 预测太阳黑子 递归神经网络 设置.预处 ...

  2. 【转】[caffe]深度学习之图像分类模型AlexNet解读

    [caffe]深度学习之图像分类模型AlexNet解读 原文地址:http://blog.csdn.net/sunbaigui/article/details/39938097   本文章已收录于: ...

  3. [caffe]深度学习之图像分类模型VGG解读

    一.简单介绍 vgg和googlenet是2014年imagenet竞赛的双雄,这两类模型结构有一个共同特点是go deeper.跟googlenet不同的是.vgg继承了lenet以及alexnet ...

  4. 深度学习 vs. 概率图模型 vs. 逻辑学

    深度学习 vs. 概率图模型 vs. 逻辑学 摘要:本文回顾过去50年人工智能(AI)领域形成的三大范式:逻辑学.概率方法和深度学习.文章按时间顺序展开,先回顾逻辑学和概率图方法,然后就人工智能和机器 ...

  5. 深度学习的seq2seq模型——本质是LSTM,训练过程是使得所有样本的p(y1,...,yT‘|x1,...,xT)概率之和最大

    from:https://baijiahao.baidu.com/s?id=1584177164196579663&wfr=spider&for=pc seq2seq模型是以编码(En ...

  6. 推荐系统遇上深度学习(十)--GBDT+LR融合方案实战

    推荐系统遇上深度学习(十)--GBDT+LR融合方案实战 0.8012018.05.19 16:17:18字数 2068阅读 22568 推荐系统遇上深度学习系列:推荐系统遇上深度学习(一)--FM模 ...

  7. 深入浅出深度学习:原理剖析与python实践_黄安埠(著) pdf

    深入浅出深度学习:原理剖析与python实践 目录: 第1 部分 概要 1 1 绪论 2 1.1 人工智能.机器学习与深度学习的关系 3 1.1.1 人工智能——机器推理 4 1.1.2 机器学习—— ...

  8. 一文看懂Stacking!(含Python代码)

    一文看懂Stacking!(含Python代码) https://mp.weixin.qq.com/s/faQNTGgBZdZyyZscdhjwUQ

  9. 风炫安全web安全学习第三十二节课 Python代码执行以及代码防御措施

    风炫安全web安全学习第三十二节课 Python代码执行以及代码防御措施 Python 语言可能发生的命令执行漏洞 内置危险函数 eval和exec函数 eval eval是一个python内置函数, ...

随机推荐

  1. Android 使用开源库StickyGridHeaders来实现带sections和headers的GridView显示本地图片效果

    大家好!过完年回来到现在差不多一个月没写文章了,一是觉得不知道写哪些方面的文章,没有好的题材来写,二是因为自己的一些私事给耽误了,所以过完年的第一篇文章到现在才发表出来,2014年我还是会继续在CSD ...

  2. SNF快速开发平台MVC-表格单元格合并组件

    1.   表格单元格合并组件 1.1.      效果展示 1.1.1.    页面展现表格合并单元格 图 4.1 1.1.2.    导出excel合并单元格 图 4.2 1.2.      调用说 ...

  3. 【R作图】蜜蜂群图beeswarm和jitter的使用

    最近经常要画好看的盒形图,还要在上面加入散点,所以总结了两个方法. 第一种方法是,利用beeswarm函数: library(beeswarm) beeswarm 蜜蜂群图 http://rgm3.l ...

  4. .NET开发微信公众号之创建自定义菜单

    一.简介 微信公众平台服务号以及之前成功申请内测资格的订阅号都具有自定义菜单的功能.开发者可利用该功能为公众账号的会话界面底部增加自定义菜单,用户点击菜单中的选项,可以调出相应的回复信息或网页链接.自 ...

  5. T-Pot平台cowrie蜜罐暴力破解探测及实现自动化邮件告警

    前言:Cowrie是基于kippo更改的中交互ssh蜜罐, 可以对暴力攻击账号密码等记录,并提供伪造的文件系统环境记录黑客操作行为, 并保存通过wget/curl下载的文件以及通过SFTP.SCP上传 ...

  6. 设置全局git忽略文件 gitconfig

    cat ~/.gitconfig [user] email = yuanhuikai@liquidnetwork.com name = yuanhuikai[core] excludesfile = ...

  7. 【css】zSass - 用 sass 编写 css

    zSass 是自己整理的一个 sass 库,参考了 sassCore. 目录结构 variables.scss 默认值设置. reset.scss 重置浏览器样式.(参考:normalize) com ...

  8. An SPI class of type org.apache.lucene.codecs.PostingsFormat with name 'Lucene50' does not exist. You need to add the corresponding JAR file supporting this SPI to your classpath. The current classp

    背景介绍: 当ES中guava库与hive等组件的库冲突时,对Elasticsearch库进行shade,relocate解决库冲突问题. 当使用"org.apache.maven.plug ...

  9. Linux 目录结构_004

    前言 Linux文件系统层次标准,英文全称Filesystem Hierarchy Standard,英文简称FHS. 由于利用Linux来开发产品的团队和个人实在太多了,如果每个人都以自己的想法来配 ...

  10. [Laravel] 02 - Route and MVC

    前言 一.良心资料 英文 Laravel 框架:https://laravel.com/ 教程:https://laracasts.com/series/ laravel-from-scratch-2 ...