keras中自定义Layer
最近在学习SSD的源码,其中有两个自定的层,特此学习一下并记录。
import keras.backend as K
from keras.engine.topology import InputSpec
from keras.engine.topology import Layer
import numpy as np class L2Normalization(Layer):
'''
Performs L2 normalization on the input tensor with a learnable scaling parameter
as described in the paper "Parsenet: Looking Wider to See Better" (see references)
and as used in the original SSD model. Arguments:
gamma_init (int): The initial scaling parameter. Defaults to 20 following the
SSD paper. Input shape:
4D tensor of shape `(batch, channels, height, width)` if `dim_ordering = 'th'`
or `(batch, height, width, channels)` if `dim_ordering = 'tf'`. Returns:
The scaled tensor. Same shape as the input tensor.
''' def __init__(self, gamma_init=20, **kwargs):
if K.image_dim_ordering() == 'tf':
self.axis = 3
else:
self.axis = 1
self.gamma_init = gamma_init
super(L2Normalization, self).__init__(**kwargs) def build(self, input_shape):
self.input_spec = [InputSpec(shape=input_shape)]
gamma = self.gamma_init * np.ones((input_shape[self.axis],))
self.gamma = K.variable(gamma, name='{}_gamma'.format(self.name))
self.trainable_weights = [self.gamma]
super(L2Normalization, self).build(input_shape) def call(self, x, mask=None):
output = K.l2_normalize(x, self.axis)
output *= self.gamma
return output
首先说一下这个层是用来做什么的。就是对于每一个通道进行归一化,不过通道使用的是不同的归一化参数,也就是说这个参数是需要进行学习的,因此需要通过 自定义层来完成。
在keras中,每个层都是对象,真的,可以通过dir(Layer对象)来查看具有哪些属性。
具体说来:
__init__():用来进行初始化的(这不是废话么),gamma就是要学习的参数。
bulid():是用来创建这层的权重向量的,也就是要学习的参数“壳”。
33:设置该层的input_spec,这个是通过InputSpec函数来实现。
34:分配权重“壳”的实际空间大小
35,:由于底层使用的Tensorflow来进行实现的,因此这里使用Tensorflow中的variable来保存变量。
36:根据keras官网的要求,可训练的权重是要添加至trainable_weights列表中的
37:我不想说了,官网给的实例都是这么做的。
call():用来进行具体实现操作的。
40:沿着指定的轴对输入数据进行L2正则化
41:使用学习的gamma来对正则化后的数据进行加权
42:将最后的数据最为该层的返回值,这里由于是和输入形式相同的,因此就没有了compute_output_shape函数,如果输入和输出的形式不同,就需要进行输入的调整。
就这样子吧。
keras中自定义Layer的更多相关文章
- keras中保存自定义层和loss
在keras中保存模型有几种方式: (1):使用callbacks,可以保存训练中任意的模型,或选择最好的模型 logdir = './callbacks' if not os.path.exists ...
- keras中的loss、optimizer、metrics
用keras搭好模型架构之后的下一步,就是执行编译操作.在编译时,经常需要指定三个参数 loss optimizer metrics 这三个参数有两类选择: 使用字符串 使用标识符,如keras.lo ...
- keras中的mask操作
使用背景 最常见的一种情况, 在NLP问题的句子补全方法中, 按照一定的长度, 对句子进行填补和截取操作. 一般使用keras.preprocessing.sequence包中的pad_sequenc ...
- iOS开发UI篇—CAlayer(自定义layer)
iOS开发UI篇—CAlayer(自定义layer) 一.第一种方式 1.简单说明 以前想要在view中画东西,需要自定义view,创建一个类与之关联,让这个类继承自UIView,然后重写它的Draw ...
- iOS 自定义layer的两种方式
在iOS中,你能看得见摸得着的东西基本都是UIView,比如一个按钮,一个标签,一个文本输入框,这些都是UIView: 其实UIView之所以能显示在屏幕上,完全是因为它内部的一个图层 在创建UIVi ...
- Keras中RNN不定长输入的处理--padding and masking
在使用RNN based model处理序列的应用中,如果使用并行运算batch sample,我们几乎一定会遇到变长序列的问题. 通常解决变长的方法主要是将过长的序列截断,将过短序列用0补齐到一个固 ...
- Keras网络层之“关于Keras的层(Layer)”
关于Keras的“层”(Layer) 所有的Keras层对象都有如下方法: layer.get_weights():返回层的权重(numpy array) layer.set_weights(weig ...
- IOS 自定义Layer(图层)
方式1: @interface NJViewController () @end @implementation NJViewController - (void)viewDidLoad { [sup ...
- iOS开发UI篇—自定义layer
一.第一种方式 1.简单说明 以前想要在view中画东西,需要自定义view,创建一个类与之关联,让这个类继承自UIView,然后重写它的DrawRect:方法,然后在该方法中画图. 绘制图形的步骤: ...
随机推荐
- [笔记]Go语言实现同一结构体适配多种消息源
问题: 提供天气信息的网站有很多,每家的数据及格式都不同,为了适配各种不同的天气接口,写了如下程序. 代码如下: package main import ( "encoding/json&q ...
- LeetCode:组合总数III【216】
LeetCode:组合总数III[216] 题目描述 找出所有相加之和为 n 的 k 个数的组合.组合中只允许含有 1 - 9 的正整数,并且每种组合中不存在重复的数字. 说明: 所有数字都是正整数. ...
- LeetCode:逆波兰表达式求值【150】
LeetCode:逆波兰表达式求值[150] 题目描述 根据逆波兰表示法,求表达式的值. 有效的运算符包括 +, -, *, / .每个运算对象可以是整数,也可以是另一个逆波兰表达式. 说明: 整数除 ...
- Python 中全局变量的实现
一.概述 Python 中全局变量的使用场景不多,但偶尔也有用武之处. 如在函数中的初始化,有时需要从外部传入一个全局变量加以控制.或者在函数中,使用连接池时,也可能有使用全局变量的需要. 广义上的全 ...
- Linux 搭建 SVN
一.yum 安装 subversion yum -y install subversion 二.创建svn版本库所在路径(建议放在opt.usr.home下) mkdir -p /usr/local/ ...
- Ubuntu16.04安装Appium
准备工作 1.安装Node 下载地址:https://nodejs.org/en/download/ 下载完后解压,设置环境变量 配置Node环境变量$sudo vim /etc/profile 在文 ...
- 基于jQuery和Bootstrap的手风琴垂直菜单
在线演示 本地下载
- Oracle大总结
maven的常见两个指令说明 mvn install 是将你打好的jar包安装到你的本地库中,一般没有设置过是在 用户目录下的 .m2\下面.mvn package 只是将你的代码打包到输出目录,一般 ...
- 解决org.apache.hadoop.io.nativeio.NativeIO$Windows.access0(Ljava/lang/String;I)Z
这个问题来的有点莫名奇妙,之前我的hadoop运行一直是正常的,某一天开始运行Mapreduce就报这个错. 试过很多种方法都没有用,比如 1.path环境变量2.Hadoop bin目录下hadoo ...
- UVA 1640 The Counting Problem(按位dp)
题意:给你整数a.b,问你[a,b]间每个数字分解成单个数字后,0.1.2.3.4.5.6.7.8.9,分别有多少个 题解:首先找到[0,b]与[0,a-1]进行区间减法,接着就只是求[0,x] 对于 ...