『开发技巧』Keras自定义对象(层、评价函数与损失)
1.自定义层
对于简单、无状态的自定义操作,你也许可以通过 layers.core.Lambda
层来实现。但是对于那些包含了可训练权重的自定义层,你应该自己实现这种层。
这是一个 Keras2.0 中,Keras 层的骨架(如果你用的是旧的版本,请更新到新版)。你只需要实现三个方法即可:
build(input_shape)
: 这是你定义权重的地方。这个方法必须设self.built = True
,可以通过调用super([Layer], self).build()
完成。call(x)
: 这里是编写层的功能逻辑的地方。你只需要关注传入call
的第一个参数:输入张量,除非你希望你的层支持masking。compute_output_shape(input_shape)
: 如果你的层更改了输入张量的形状,你应该在这里定义形状变化的逻辑,这让Keras能够自动推断各层的形状。
from keras import backend as K
from keras.engine.topology import Layer
class MyLayer(Layer):
def __init__(self, output_dim, **kwargs):
self.output_dim = output_dim
super(MyLayer, self).__init__(**kwargs)
def build(self, input_shape):
# 为该层创建一个可训练的权重
self.kernel = self.add_weight(name='kernel',
shape=(input_shape[1], self.output_dim),
initializer='uniform',
trainable=True)
super(MyLayer, self).build(input_shape) # 一定要在最后调用它
def call(self, x):
return K.dot(x, self.kernel)
def compute_output_shape(self, input_shape):
return (input_shape[0], self.output_dim)
还可以定义具有多个输入张量和多个输出张量的 Keras 层。 为此,你应该假设方法 build(input_shape)
,call(x)
和 compute_output_shape(input_shape)
的输入输出都是列表。 这里是一个例子,与上面那个相似:
from keras import backend as K
from keras.engine.topology import Layer
class MyLayer(Layer):
def __init__(self, output_dim, **kwargs):
self.output_dim = output_dim
super(MyLayer, self).__init__(**kwargs)
def build(self, input_shape):
assert isinstance(input_shape, list)
# 为该层创建一个可训练的权重
self.kernel = self.add_weight(name='kernel',
shape=(input_shape[0][1], self.output_dim),
initializer='uniform',
trainable=True)
super(MyLayer, self).build(input_shape) # 一定要在最后调用它
def call(self, x):
assert isinstance(x, list)
a, b = x
return [K.dot(a, self.kernel) + b, K.mean(b, axis=-1)]
def compute_output_shape(self, input_shape):
assert isinstance(input_shape, list)
shape_a, shape_b = input_shape
return [(shape_a[0], self.output_dim), shape_b[:-1]]
已有的 Keras 层就是实现任何层的很好例子。不要犹豫阅读源码!
2.自定义评价函数
自定义评价函数应该在编译的时候(compile)传递进去。该函数需要以 (y_true, y_pred)
作为输入参数,并返回一个张量作为输出结果。
import keras.backend as K
def mean_pred(y_true, y_pred):
return K.mean(y_pred)
model.compile(optimizer='rmsprop',
loss='binary_crossentropy',
metrics=['accuracy', mean_pred])
3.自定义损失函数
自定义损失函数也应该在编译的时候(compile)传递进去。该函数需要以 (y_true, y_pred)
作为输入参数,并返回一个张量作为输出结果。
import keras.backend as K
def my_loss(y_true, y_pred):
return K.mean(K.squre(y_pred-y_true))#以平方差举例
model.compile(optimizer='rmsprop',
loss=my_loss,
metrics=['accuracy'])
4.处理已保存模型中的自定义层(或其他自定义对象)
如果要加载的模型包含自定义层或其他自定义类或函数,则可以通过 custom_objects
参数将它们传递给加载机制:
from keras.models import load_model
# 假设你的模型包含一个 AttentionLayer 类的实例
model = load_model('my_model.h5', custom_objects={'AttentionLayer': AttentionLayer})
或者,你可以使用 自定义对象作用域:
from keras.utils import CustomObjectScope
with CustomObjectScope({'AttentionLayer': AttentionLayer}):
model = load_model('my_model.h5')
『开发技巧』Keras自定义对象(层、评价函数与损失)的更多相关文章
- 『开发技巧』Python音频操作工具PyAudio上手教程
『开发技巧』Python音频操作工具PyAudio上手教程 0.引子 当需要使用Python处理音频数据时,使用python读取与播放声音必不可少,下面介绍一个好用的处理音频PyAudio工具包. ...
- [开发技巧]·TensorFlow&Keras GPU使用技巧
[开发技巧]·TensorFlow&Keras GPU使用技巧 1.问题描述 在使用TensorFlow&Keras通过GPU进行加速训练时,有时在训练一个任务的时候需要去测试结果 ...
- 『开发技术』GPU训练加速原理(附KerasGPU训练技巧)
0.深入理解GPU训练加速原理 我们都知道用GPU可以加速神经神经网络训练(相较于CPU),具体的速度对比可以参看我之前写的速度对比博文: [深度应用]·主流深度学习硬件速度对比(CPU,GPU,TP ...
- 『电脑技巧』破解Win7/Win8登录密码
Pic via baidu 0x 00 破解思路 用户的明文密码经过单向Hash加密生成Hash散列,Hash散列又被加密存放在系统盘\Windiws\System32\config文件下 要获得明文 ...
- 『电脑技巧』浅谈Win7的文件共享设置
随着移动存储设备的普及,很少有小伙伴喜欢使用局域网“文件共享”这一捷径了 而且自从XP之后,Windows系列主机共享设置貌似比较麻烦 虽然事实并不是看上去那样(Win7也很Easy的说 = =) 现 ...
- 『开发技术』Docker开发教程(一)安装与测试(Windows 家庭版)
0.前言 针对其他系统和版本,Docker都很容易安装,可以参考官方教程:https://docs.docker.com/docker-hub/ 由于Windows10家庭版无法安装docker,因此 ...
- 『开发技术』Windows极简安装使用face_recognition
face_recognition是一个强大.简单.易上手的人脸识别开源项目,并且配备了完整的开发文档和应用案例,特别是兼容树莓派系统.此项目是世界上最简洁的人脸识别库,你可以使用Python和命令行工 ...
- 『开发技术』Ubuntu与Windows如何查看CPU&GPU&内存占用量
0 序·简介 在使用Ubuntu或者Windows执行一些复杂数据运算时,需要关注下CPU.GPU以及内存占用量,如果数据运算超出了负荷,会产生难以预测的错误.本文将演示如何用简单地方式,实时监控Ub ...
- Keras处理已保存模型中的自定义层(或其他自定义对象)
如果要加载的模型包含自定义层或其他自定义类或函数,则可以通过 custom_objects 参数将它们传递给加载机制: from keras.models import load_model # 假设 ...
随机推荐
- JNDI(Java Naming and Directory Interface)
# 前言 内容基本拷贝,整理出来,方便以后回忆. # What The Java Naming and Directory Interface™ (JNDI) is an application pr ...
- WPF读取和显示word
引言 在项目开发中,word的读取和显示会经常出现在客户的需求中.特别是一些有关法律规章制度.通知.红头文件等,都是用word发布的. 在WPF中,对显示WORD没有特定的控件,这对开发显示WORD的 ...
- Qt中使用Boost
编译BOOST库 bjam stage --toolset=qcc --without-graph --without-graph_parallel --without-math --without- ...
- 图像滤镜艺术---Swirl滤镜
原文:图像滤镜艺术---Swirl滤镜 Swirl Filter Swirl 滤镜是实现图像围绕中心点(cenX,cenY)扭曲旋转的效果,效果图如下: 原图 效果图 代码如下: // ...
- Win10《芒果TV - Preview》更新至v3.1.57.0:热门节目和电视台直播回归
Win10<芒果TV - Preview>是Win10<芒果TV>官方唯一指定内测预览版,最新的改进和功能更新将会在此版本优先体验. 为了想让大家能在12月31日看到<湖 ...
- win32内存调用图
https://msdn.microsoft.com/en-us/library/ms810603.aspxhttps://www.codeproject.com/Articles/14525/Hea ...
- 有效地查找SAP增强点
找SAP增强点一直都是SAP开发的重点难点,增强开发的代码一般不会很多,但是需要花费比较多的时间在查找增强点上 网上也流传了很多查找SAP增强的方法: 1.利用TCODE寻找增强 2.利用系统函数寻找 ...
- 跨平台网络通信与服务器框架 acl 3.2.0 发布,acl_cpp 是基于 acl 库的 C++ 库
acl 3.2.0 版本发布了,acl 是 one advanced C/C++ library 的简称,主要包括网络通信库以及服务器框架库等功能,支持 Linux/Windows/Solaris/F ...
- ABP开发框架前后端开发系列---(8)ABP框架之Winform界面的开发过程
在前面随笔介绍的<ABP开发框架前后端开发系列---(7)系统审计日志和登录日志的管理>里面,介绍了如何改进和完善审计日志和登录日志的应用服务端和Winform客户端,由于篇幅限制,没有进 ...
- 读书笔记——《谁说菜鸟不会数据分析—Python篇》
最近刚读完一本新书,关注的公众号作者出的“谁说菜鸟不会数据分析—Python篇”,话说现在很多微信公众号大牛都在出书,这貌似是一个趋势.. 说说这本书吧,我之前看过一些网文,对于数据分析这一块也有过一 ...