如何在scikit-learn模型中使用Keras

通过用 KerasClassifier 或 KerasRegressor 类包装Keras模型,可将其用于scikit-learn。

要使用这些包装,必须定义一个函数,以便按顺序模式创建并返回Keras,然后当构建 KerasClassifier 类时,把该函数传递给 build_fn 参数。

例如:

def create_model():
...
return model model = KerasClassifier(build_fn=create_model)

KerasClassifier类 的构建器为可以采取默认参数,并将其被传递给 model.fit() 的调用函数,比如 epochs数目和批尺寸(batch size)。

例如:

def create_model():
...
return model model = KerasClassifier(build_fn=create_model, nb_epoch=10)

KerasClassifier类的构造也可以使用新的参数,使之能够传递给自定义的create_model()函数。这些新的参数,也必须由使用默认参数的 create_model() 函数的签名定义。

例如:

def create_model(dropout_rate=0.0):
...
return model model = KerasClassifier(build_fn=create_model, dropout_rate=0.2)

pred = estimator.predict(X_test)#返回给定测试数据的类预测。
pred1=estimator.predict_proba(X_test)#返回给定测试数据的类概率估计。
# pred3=estimator.score(X_test,Y_test)#返回给定测试数据和标签的平均精度。
print(X_test)#
print(Y_test)#实际类别
print(pred)#预测类别

print(pred1)

[[0. 1. 0. ... 1. 0. 0.]
[0. 0. 1. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 1. 1. ... 0. 0. 0.]]
[[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
...
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 1. 0. 0. 0.]]
[5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5]
[[0.02377683 0.0266185 0.04945414 0.08426233 0.04495123 0.77093697]
[0.02115186 0.01721832 0.03360457 0.05283894 0.05303674 0.82214963]
[0.00838055 0.01647644 0.02293482 0.05378568 0.057558 0.8408645 ]
...
[0.01674003 0.01713392 0.03502046 0.03685626 0.03512193 0.85912746]
[0.0494712 0.0336375 0.05689533 0.03956604 0.04415505 0.77627486]
[0.04764625 0.04542363 0.08352048 0.15077472 0.10701337 0.5656215 ]]

estimator = KerasClassifier的更多相关文章

  1. 【Python与机器学习】:利用Keras进行多类分类

    多类分类问题本质上可以分解为多个二分类问题,而解决二分类问题的方法有很多.这里我们利用Keras机器学习框架中的ANN(artificial neural network)来解决多分类问题.这里我们采 ...

  2. Python机器学习笔记:利用Keras进行分类预测

    Keras是一个用于深度学习的Python库,它包含高效的数值库Theano和TensorFlow. 本文的目的是学习如何从csv中加载数据并使其可供Keras使用,如何用神经网络建立多类分类的数据进 ...

  3. Keras人工神经网络多分类(SGD)

    import numpy as np import pandas as pd from keras.models import Sequential from keras.layers import ...

  4. python多标签分类模版

    from sklearn.multioutput import MultiOutputClassifier from sklearn.ensemble import RandomForestClass ...

  5. np_utils.to_categorical

    https://blog.csdn.net/zlrai5895/article/details/79560353 多类分类问题本质上可以分解为多个二分类问题,而解决二分类问题的方法有很多.这里我们利用 ...

  6. 3.2. Grid Search: Searching for estimator parameters

    3.2. Grid Search: Searching for estimator parameters Parameters that are not directly learnt within ...

  7. 机器学习笔记5-Tensorflow高级API之tf.estimator

    前言 本文接着上一篇继续来聊Tensorflow的接口,上一篇中用较低层的接口实现了线性模型,本篇中将用更高级的API--tf.estimator来改写线性模型. 还记得之前的文章<机器学习笔记 ...

  8. [sklearn]官方例程-Imputing missing values before building an estimator 随机填充缺失值

    官方链接:http://scikit-learn.org/dev/auto_examples/plot_missing_values.html#sphx-glr-auto-examples-plot- ...

  9. tensorflow estimator API小栗子

    TensorFlow的高级机器学习API(tf.estimator)可以轻松配置,训练和评估各种机器学习模型. 在本教程中,您将使用tf.estimator构建一个神经网络分类器,并在Iris数据集上 ...

随机推荐

  1. day 06 列表去重, 数据类型的补充,编码,深浅copy

    因为重要,所以放前面 列表去重 l1 = [1, 2, 3, 4, 5] l2 = [3, 4, 5, 6, 7] set = list(set(l1 + l2)) # set自动去重,然后变成lis ...

  2. linux 常见基础知识(此文章将会在整个linux学习过程中,不断添加)

    1,linux 文件类型 普通文件 目录文件 链接文件 块设备 字符设备 Socket 管道文件 - d l b c s p 2,linux 文件属性 蓝色 绿色 浅蓝色 红色 灰色 目录 可执行文件 ...

  3. 半吊子的STM32 — IIC通信

    半双工通信模式:以字节模式发送(8位): 两线式串行总线,SDA(数据信号)和SCL(时钟信号)两条信号线都为高电平时,总线为空闲状态:起始时,SCL稳定为高电平,SDA电平由高向低跳变:停止时,SC ...

  4. Json解析数据

    Json数据解析(重点网址推荐:www.json.org   code.google.com/   https://www.json.com/) 1:什么是Json? 2:Json数据格式的特点? 3 ...

  5. git 本地仓库与远程仓库的连接

    在远程如github新建一个项目名称为blog, 本地项目为store,是一个laravel框架项目,首先用 git init初始化本目,然后用git remote add origin git@gi ...

  6. 2D情况下,复数的意义代表旋转

    4 x i x i = - 4 就是"4"在数轴上旋转了180度. 那么4 x i = 4i 就旋转了90度. 复数的意义就表示旋转 乘以-1,表示x正半轴的数,围绕原点,逆时针偏 ...

  7. ObjC.primitive-methods

    Primitive Method "When it comes to subclassing, knowing which methods are ‘primitive’ methods i ...

  8. svn conflict问题解决办法

    转自:http://www.cnblogs.com/aaronLinux/p/5521844.html 目录: 1. 同一处修改文件冲突 1.1. 解决方式一 1.2. 解决方式二 1.3. 解决总结 ...

  9. Windows 修改的hosts记录没有效果

    windows修改的hosts记录没有效果,新添加的也没有效果. 检查DNS设置相关的均正常, <Dns client为此计算机解析和缓冲域名系统 (DNS) 名称.> 为此计算机注册并更 ...

  10. HDU 4455.Substrings

    Substrings Time Limit: 10000/5000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others)Total ...