一、不含参数层

通过继承Block自定义了一个将输入减掉均值的层:CenteredLayer类,并将层的计算放在forward函数里,

from mxnet import nd, gluon
from mxnet.gluon import nn class CenteredLayer(nn.Block):
def __init__(self, **kwargs):
super(CenteredLayer, self).__init__(**kwargs) def forward(self, x):
return x - x.mean() # 直接使用这个层
layer = CenteredLayer()
# layer(nd.array([1, 2, 3, 4, 5])) # 构建更复杂模型
net = nn.Sequential()
net.add(nn.Dense(128))
net.add(nn.Dense(10))
net.add(CenteredLayer()) # 初始化、运行……
net.initialize()
y = net(nd.random.uniform(shape=(4, 8)))

二、含参数层

注意,本节实现的自定义层不能自动推断输入尺寸,需要手动指定

见上节『MXNet』第三弹_Gluon模型参数在自定义层的时候我们常使用Block自带的ParameterDict类添加成员变量params,如下,

from mxnet import gluon
from mxnet.gluon import nn class MyDense(nn.Block):
def __init__(self, units, in_units, **kwargs):
super(MyDense, self).__init__(**kwargs)
self.weight = self.params.get('weight', shape=(in_units, units))
self.bias = self.params.get('bias', shape=(units,)) def forward(self, x):
linear = nd.dot(x, self.weight.data()) + self.bias.data()
return nd.relu(linear) # 实际运行
dense = MyDense(5, in_units=10)

如果不想使用ParameterDict类则需要一下操作

# self.weight = self.params.get('weight', shape=(in_units, units))
self.weight = gluon.Parameter('weight', shape=(in_units, units))
self.params.update({'weight':self.weight})

否则在net.initialize()初始化时是初始化不到ParameterDict外变量的。

有关这一点详见下面:

    def __init__(self, conv_arch, dropout_keep_prob, **kwargs):
super(SSD, self).__init__(**kwargs)
self.vgg_conv = nn.Sequential()
self.vgg_conv.add(repeat(*conv_arch[0], pool=False))
[self.vgg_conv.add(repeat(*conv_arch[i])) for i in range(1, len(conv_arch))]
# 迭代器对象只能进行单次迭代,所以将之转化为tuple,否则识别参数处迭代后forward再次迭代直接跳出循环
# self.vgg_conv = tuple([repeat(*conv_arch[i])
# for i in range(len(conv_arch))])
# 只能识别实例属性直接为mx层函数或者mx序列对象的参数,如果使用其他容器,需要将参数收集进参数字典
# _ = [self.params.update(block.collect_params()) for block in self.vgg_conv] def forward(self, x, feat_layers):
end_points = {'block0': x}
for (index, block) in enumerate(self.vgg_conv):
end_points.update({'block{:d}'.format(index+1): block(end_points['block{:d}'.format(index)])})
return end_points

属性对象是mxnet的对象时才能默认识别层中的参数,否则需要显式收集进self.params中。

测试代码:

if __name__ == '__main__':

    ssd = SSD(conv_arch=((2, 64), (2, 128), (3, 256), (3, 512), (3, 512)),
dropout_keep_prob=0.5)
ssd.initialize()
X = mx.ndarray.random.uniform(shape=(1, 1, 304, 304))
import pprint as pp
pp.pprint([x[1].shape for x in ssd(X).items()])

自行验证即可。

『MXNet』第四弹_Gluon自定义层的更多相关文章

  1. 『MXNet』第三弹_Gluon模型参数

    MXNet中含有init包,它包含了多种模型初始化方法. from mxnet import init, nd from mxnet.gluon import nn net = nn.Sequenti ...

  2. 『MXNet』第六弹_Gluon性能提升

    一.符号式编程 1.命令式编程和符号式编程 命令式: def add(a, b): return a + b def fancy_func(a, b, c, d): e = add(a, b) f = ...

  3. 『MXNet』第六弹_Gluon性能提升 静态图 动态图 符号式编程 命令式编程

    https://www.cnblogs.com/hellcat/p/9084894.html 目录 一.符号式编程 1.命令式编程和符号式编程 2.MXNet的符号式编程 二.惰性计算 用同步函数实际 ...

  4. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  5. 『MXNet』第十弹_物体检测SSD

    全流程地址 一.辅助API介绍 mxnet.image.ImageDetIter 图像检测迭代器, from mxnet import image from mxnet import nd data_ ...

  6. 『MXNet』第八弹_数据处理API_下_Image IO专题

    想学习MXNet的同学建议看一看这位博主的博客,受益良多. 在本节中,我们将学习如何在MXNet中预处理和加载图像数据. 在MXNet中加载图像数据有4种方式. 使用 mx.image.imdecod ...

  7. 『MXNet』第八弹_数据处理API_上

    一.Gluon数据加载 下面的两个dataset处理类一般会成对出现,两个都可做预处理,但是由于后面还可能用到原始图片,.ImageFolderDataset不加预处理的话可以满足,所以建议在.Dat ...

  8. 『MXNet』第十一弹_符号式编程初探

    一.符号分类 符号对我们想要进行的计算进行了描述, 下图展示了符号如何对计算进行描述. 我们定义了符号变量A, 符号变量B, 生成了符号变量C, 其中, A, B为参数节点, C为内部节点! mxne ...

  9. 『MXNet』第七弹_多GPU并行程序设计

    资料原文 一.概述思路 假设一台机器上有个GPU.给定需要训练的模型,每个GPU将分别独立维护一份完整的模型参数. 在模型训练的任意一次迭代中,给定一个小批量,我们将该批量中的样本划分成份并分给每个G ...

随机推荐

  1. (转)开源项目miaosha(下)

    石墨文档:https://shimo.im/docs/2XlwliBQAYsKCHbq/ (二期)20.开源秒杀项目miaosha解读(下) [课程20]jmeter.xmind81.5KB [课程2 ...

  2. (转载)Rime输入法—鼠须管(Squirrel)词库添加及配置

    为什么用Rime 13年底的时候,日本爆出百度的日本版本输入法的问题,要求政府人员停用,没当回事,反正我没用,当然了,有关搜狗和用户隐私有关的问题就一直没有中断过,也没太在意.但,前几天McAfee爆 ...

  3. UVa 11488 超级前缀集合(Trie的应用)

    https://vjudge.net/problem/UVA-11488 题意: 给定一个字符串集合S,定义P(s)为所有字符串的公共前缀长度与S中字符串个数的乘积.比如P( {000, 001, 0 ...

  4. HDU 5727 Necklace(全排列+二分图匹配)

    http://acm.split.hdu.edu.cn/showproblem.php?pid=5727 题意:现在有n个阳珠子和n个阴珠子,现在要把它们串成项链,要求是阴阳珠子间隔串,但是有些阴阳珠 ...

  5. Codeforces 729E Subordinates

    题目链接:http://codeforces.com/problemset/problem/729/E 既然每一个人都有一个顶头上司,考虑一个问题: 如果这些人中具有上司数目最多的人有$x$个上司,那 ...

  6. javascript 创建video元素

    <!DOCTYPE html> <html> <body> <h3>演示如何创建 VIDEO 元素</h3> <p>请点击按钮来 ...

  7. Java+selenium 爬Boss直聘中职位信息,薪资水平和职位描述

      需要下载合适的selenium webdirver jar包和对应浏览器的驱动jar包 import org.openqa.selenium.By; import org.openqa.selen ...

  8. NoSQL(not only struts query language)的简单介绍

    为什么需要NoSQL? 互联网自扩大规模来一直面临3个问题 1.High performance高并发 一个网站开发实时生成动态页面可能会存在高并发请求的需求,硬盘IO已经无法接受 2.Huge St ...

  9. Bootstrap 4正式发布还有意义吗?

    历经三年开发,前端框架Bootstrap 4正式发布了.然而今天的Web世界已经和当初Mark Otto发布Bootstrap时的情况大为不同,一些开发者由此质疑它的更新是否还有意义. V4版本的主要 ...

  10. 力扣(LeetCode)219. 存在重复元素 II

    给定一个整数数组和一个整数 k,判断数组中是否存在两个不同的索引 i 和 j,使得 nums [i] = nums [j],并且 i 和 j 的差的绝对值最大为 k. 示例 1: 输入: nums = ...