1. 比较一般的自定义函数:

需要注意的是,不能像sklearn那样直接定义,因为这里的y_true和y_pred是张量,不是numpy数组。示例如下:

from keras import backend

def rmse(y_true, y_pred):
return backend.sqrt(backend.mean(backend.square(y_pred - y_true), axis=-1))

用的时候直接:

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[rmse])

2. 比较复杂的如AUC函数:

AUC的计算需要整体数据,如果直接在batch里算,误差就比较大,不能合理反映整体情况。这里采用回调函数写法,每个epoch计算一次:

from sklearn.metrics import roc_auc_score

class roc_callback(keras.callbacks.Callback):
def __init__(self,training_data, validation_data): self.x = training_data[0]
self.y = training_data[1]
self.x_val = validation_data[0]
self.y_val = validation_data[1] def on_train_begin(self, logs={}):
return def on_train_end(self, logs={}):
return def on_epoch_begin(self, epoch, logs={}):
return def on_epoch_end(self, epoch, logs={}):
y_pred = self.model.predict(self.x)
roc = roc_auc_score(self.y, y_pred) y_pred_val = self.model.predict(self.x_val)
roc_val = roc_auc_score(self.y_val, y_pred_val) print('\rroc-auc: %s - roc-auc_val: %s' % (str(round(roc,4)),str(round(roc_val,4))),end=100*' '+'\n')
return def on_batch_begin(self, batch, logs={}):
return def on_batch_end(self, batch, logs={}):
return

调用回调函数示例:

model.fit(X_train, y_train, epochs=10, batch_size=4,
callbacks = [roc_callback(training_data=[X_train, y_train], validation_data=[X_test, y_test])] )

整体示例:

from tensorflow import keras
from sklearn import datasets
from sklearn import model_selection
from sklearn.metrics import roc_auc_score def rmse(y_true, y_pred):
return keras.backend.sqrt(keras.backend.mean(keras.backend.square(y_pred - y_true), axis=-1)) class roc_callback(keras.callbacks.Callback):
def __init__(self,training_data, validation_data): self.x = training_data[0]
self.y = training_data[1]
self.x_val = validation_data[0]
self.y_val = validation_data[1] def on_train_begin(self, logs={}):
return def on_train_end(self, logs={}):
return def on_epoch_begin(self, epoch, logs={}):
return def on_epoch_end(self, epoch, logs={}):
y_pred = self.model.predict(self.x)
roc = roc_auc_score(self.y, y_pred) y_pred_val = self.model.predict(self.x_val)
roc_val = roc_auc_score(self.y_val, y_pred_val) print('\rroc-auc: %s - roc-auc_val: %s' % (str(round(roc,4)),str(round(roc_val,4))),end=100*' '+'\n')
return def on_batch_begin(self, batch, logs={}):
return def on_batch_end(self, batch, logs={}):
return X, y = datasets.make_classification(n_samples=100, n_features=4, n_classes=2, random_state=2018)
X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2, random_state=2018)
print("TrainSet", X_train.shape, "TestSet", X_test.shape) model = keras.models.Sequential()
model.add(keras.layers.Dense(20, input_shape=(4,), activation='relu'))
model.add(keras.layers.Dense(1, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[rmse]) model.fit(X_train, y_train, epochs=10, batch_size=4,
callbacks = [roc_callback(training_data=[X_train, y_train], validation_data=[X_test, y_test])] )

运行结果:

TrainSet (80, 4) TestSet (20, 4)
Epoch 1/10
roc-auc: 0.1604 - roc-auc_val: 0.2738
80/80 [==============================] - 0s - loss: 0.8132 - rmse: 0.5298
Epoch 2/10
roc-auc: 0.4874 - roc-auc_val: 0.619
80/80 [==============================] - 0s - loss: 0.7432 - rmse: 0.5049
Epoch 3/10
roc-auc: 0.7715 - roc-auc_val: 0.9643
80/80 [==============================] - 0s - loss: 0.6821 - rmse: 0.4807
Epoch 4/10
roc-auc: 0.9602 - roc-auc_val: 1.0
80/80 [==============================] - 0s - loss: 0.6268 - rmse: 0.4560
Epoch 5/10
roc-auc: 0.9842 - roc-auc_val: 1.0
80/80 [==============================] - 0s - loss: 0.5747 - rmse: 0.4301
Epoch 6/10
roc-auc: 0.9956 - roc-auc_val: 1.0
80/80 [==============================] - 0s - loss: 0.5230 - rmse: 0.4025
Epoch 7/10
roc-auc: 0.9975 - roc-auc_val: 1.0
80/80 [==============================] - 0s - loss: 0.4743 - rmse: 0.3739
Epoch 8/10
roc-auc: 0.9987 - roc-auc_val: 1.0
80/80 [==============================] - 0s - loss: 0.4289 - rmse: 0.3454
Epoch 9/10
roc-auc: 0.9987 - roc-auc_val: 1.0...] - ETA: 0s - loss: 0.4019 - rmse: 0.3301
80/80 [==============================] - 0s - loss: 0.3830 - rmse: 0.3149
Epoch 10/10
roc-auc: 0.9987 - roc-auc_val: 1.0
80/80 [==============================] - 0s - loss: 0.3424 - rmse: 0.2865

Keras自定义评估函数的更多相关文章

  1. keras 自定义 custom 函数

    转自: https://kexue.fm/archives/4493/,感谢分享! Keras是一个搭积木式的深度学习框架,用它可以很方便且直观地搭建一些常见的深度学习模型.在tensorflow出来 ...

  2. xgboost 自定义目标函数和评估函数

    https://zhpmatrix.github.io/2017/06/29/custom-xgboost/ https://www.cnblogs.com/silence-gtx/p/5812012 ...

  3. TensorFlow自定义训练函数

    本文记录了在TensorFlow框架中自定义训练函数的模板并简述了使用自定义训练函数的优势与劣势. 首先需要说明的是,本文中所记录的训练函数模板参考自https://stackoverflow.com ...

  4. 关于jqGrig如何写自定义格式化函数将JSON数据的字符串转换为表格各个列的值

    首先介绍一下jqGrid是一个jQuery的一个表格框架,现在有一个需求就是将数据库表的数据拿出来显示出来,分别有id,name,details三个字段,其中难点就是details字段,它的数据是这样 ...

  5. 自定义el函数

    1.1.1 自定义EL函数(EL调用Java的函数) 第一步:创建一个Java类.方法必须是静态方法. public static String sayHello(String name){ retu ...

  6. ORACLE 自定义聚合函数

    用户可以自定义聚合函数  ODCIAggregate,定义了四个聚集函数:初始化.迭代.合并和终止. Initialization is accomplished by the ODCIAggrega ...

  7. SQL Server 自定义聚合函数

    说明:本文依据网络转载整理而成,因为时间关系,其中原理暂时并未深入研究,只是整理备份留个记录而已. 目标:在SQL Server中自定义聚合函数,在Group BY语句中 ,不是单纯的SUM和MAX等 ...

  8. Matlab中如何将(自定义)函数作为参数传递给另一个函数

    假如我们编写了一个积分通用程序,想使它更具有通用性,那么可以把被积函数也作为一个参数.在c/c++中,可以使用函数指针来实现上边的功能,在matlab中如何实现呢?使用函数句柄--这时类似于函数指针的 ...

  9. python 自定义排序函数

    自定义排序函数 Python内置的 sorted()函数可对list进行排序: >>>sorted([36, 5, 12, 9, 21]) [5, 9, 12, 21, 36] 但 ...

随机推荐

  1. HDU1232——畅通工程

    #include<stdio.h> ]; int find(int x) //查找根节点 { int r=x; while (pre[r]!=r) //返回根节点 r r=pre[r]; ...

  2. C++解析(12):初始化列表与对象构造顺序、析构顺序

    0.目录 1.类成员的初始化 2.类中的const成员 3.对象的构造顺序 3.1 局部对象的构造顺序 3.2 堆对象的构造顺序 3.3 全局对象的构造顺序 4.对象的析构顺序 5.小结 1.类成员的 ...

  3. C++解析(3):布尔类型与三目运算符

    0.目录 1.布尔类型 2.三目运算符 3.小结 1.布尔类型 C++中的布尔类型: C++在C语言的基本类型系统之上增加了bool C++中的bool可取的值只有true和false 理论上bool ...

  4. [CQOI2011]动态逆序对 CDQ分治

    洛谷上有2道相同的题目(基本是完全相同的,输入输出格式略有不同) ---题面--- ---题面--- CDQ分治 首先由于删除是很不好处理的,所以我们把删除改为插入,然后输出的时候倒着输出即可 首先这 ...

  5. 【BZOJ4184】shallot(线段树分治,线性基)

    [BZOJ4184]shallot(线段树分治,线性基) 题面 权限题啊.....好烦.. Description 小苗去市场上买了一捆小葱苗,她突然一时兴起,于是她在每颗小葱苗上写上一个数字,然后把 ...

  6. 【BZOJ4828】【HNOI2017】大佬(动态规划)

    [BZOJ4828][HNOI2017]大佬(动态规划) 题面 BZOJ 洛谷 LOJ 人们总是难免会碰到大佬.他们趾高气昂地谈论凡人不能理解的算法和数据结构,走到任何一个地方,大佬的气场 就能让周围 ...

  7. linux服务之NTP时间服务器

    1. NTP简介 NTP(Network Time Protocol,网络时间协议)是用来使网络中的各个计算机时间同步的一种协议.它的用途是把计算机的时钟同步到世界协调时UTC,其精度在局域网内可达0 ...

  8. redis的Pub/Sub功能

    Pub/Sub功能(即Publish,Subscribe)意思是发布及订阅功能.简单的理解就像我们订阅blog一样,不同的是,这里的客户端与server端采用长连接建立推送机制,一个客户端发布消息,可 ...

  9. Spring MVC 向前台页面传值-ModelAndView

    ModelAndView 该对象中包含了一个model属性和一个view属性 model:其实是一个ModelMap类型.其实ModelMap是一个LinkedHashMap的子类 view:包含了一 ...

  10. Hibernate学习(3)- *.hbm.xml详解

    <?xml version="1.0" encoding="UTF-8"?> <!DOCTYPE hibernate-mapping PUBL ...