一、不含参数层

通过继承Block自定义了一个将输入减掉均值的层:CenteredLayer类,并将层的计算放在forward函数里,

from mxnet import nd, gluon
from mxnet.gluon import nn class CenteredLayer(nn.Block):
def __init__(self, **kwargs):
super(CenteredLayer, self).__init__(**kwargs) def forward(self, x):
return x - x.mean() # 直接使用这个层
layer = CenteredLayer()
# layer(nd.array([1, 2, 3, 4, 5])) # 构建更复杂模型
net = nn.Sequential()
net.add(nn.Dense(128))
net.add(nn.Dense(10))
net.add(CenteredLayer()) # 初始化、运行……
net.initialize()
y = net(nd.random.uniform(shape=(4, 8)))

二、含参数层

注意,本节实现的自定义层不能自动推断输入尺寸,需要手动指定

见上节『MXNet』第三弹_Gluon模型参数在自定义层的时候我们常使用Block自带的ParameterDict类添加成员变量params,如下,

from mxnet import gluon
from mxnet.gluon import nn class MyDense(nn.Block):
def __init__(self, units, in_units, **kwargs):
super(MyDense, self).__init__(**kwargs)
self.weight = self.params.get('weight', shape=(in_units, units))
self.bias = self.params.get('bias', shape=(units,)) def forward(self, x):
linear = nd.dot(x, self.weight.data()) + self.bias.data()
return nd.relu(linear) # 实际运行
dense = MyDense(5, in_units=10)

如果不想使用ParameterDict类则需要一下操作

# self.weight = self.params.get('weight', shape=(in_units, units))
self.weight = gluon.Parameter('weight', shape=(in_units, units))
self.params.update({'weight':self.weight})

否则在net.initialize()初始化时是初始化不到ParameterDict外变量的。

有关这一点详见下面:

    def __init__(self, conv_arch, dropout_keep_prob, **kwargs):
super(SSD, self).__init__(**kwargs)
self.vgg_conv = nn.Sequential()
self.vgg_conv.add(repeat(*conv_arch[0], pool=False))
[self.vgg_conv.add(repeat(*conv_arch[i])) for i in range(1, len(conv_arch))]
# 迭代器对象只能进行单次迭代,所以将之转化为tuple,否则识别参数处迭代后forward再次迭代直接跳出循环
# self.vgg_conv = tuple([repeat(*conv_arch[i])
# for i in range(len(conv_arch))])
# 只能识别实例属性直接为mx层函数或者mx序列对象的参数,如果使用其他容器,需要将参数收集进参数字典
# _ = [self.params.update(block.collect_params()) for block in self.vgg_conv] def forward(self, x, feat_layers):
end_points = {'block0': x}
for (index, block) in enumerate(self.vgg_conv):
end_points.update({'block{:d}'.format(index+1): block(end_points['block{:d}'.format(index)])})
return end_points

属性对象是mxnet的对象时才能默认识别层中的参数,否则需要显式收集进self.params中。

测试代码:

if __name__ == '__main__':

    ssd = SSD(conv_arch=((2, 64), (2, 128), (3, 256), (3, 512), (3, 512)),
dropout_keep_prob=0.5)
ssd.initialize()
X = mx.ndarray.random.uniform(shape=(1, 1, 304, 304))
import pprint as pp
pp.pprint([x[1].shape for x in ssd(X).items()])

自行验证即可。

『MXNet』第四弹_Gluon自定义层的更多相关文章

  1. 『MXNet』第三弹_Gluon模型参数

    MXNet中含有init包,它包含了多种模型初始化方法. from mxnet import init, nd from mxnet.gluon import nn net = nn.Sequenti ...

  2. 『MXNet』第六弹_Gluon性能提升

    一.符号式编程 1.命令式编程和符号式编程 命令式: def add(a, b): return a + b def fancy_func(a, b, c, d): e = add(a, b) f = ...

  3. 『MXNet』第六弹_Gluon性能提升 静态图 动态图 符号式编程 命令式编程

    https://www.cnblogs.com/hellcat/p/9084894.html 目录 一.符号式编程 1.命令式编程和符号式编程 2.MXNet的符号式编程 二.惰性计算 用同步函数实际 ...

  4. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  5. 『MXNet』第十弹_物体检测SSD

    全流程地址 一.辅助API介绍 mxnet.image.ImageDetIter 图像检测迭代器, from mxnet import image from mxnet import nd data_ ...

  6. 『MXNet』第八弹_数据处理API_下_Image IO专题

    想学习MXNet的同学建议看一看这位博主的博客,受益良多. 在本节中,我们将学习如何在MXNet中预处理和加载图像数据. 在MXNet中加载图像数据有4种方式. 使用 mx.image.imdecod ...

  7. 『MXNet』第八弹_数据处理API_上

    一.Gluon数据加载 下面的两个dataset处理类一般会成对出现,两个都可做预处理,但是由于后面还可能用到原始图片,.ImageFolderDataset不加预处理的话可以满足,所以建议在.Dat ...

  8. 『MXNet』第十一弹_符号式编程初探

    一.符号分类 符号对我们想要进行的计算进行了描述, 下图展示了符号如何对计算进行描述. 我们定义了符号变量A, 符号变量B, 生成了符号变量C, 其中, A, B为参数节点, C为内部节点! mxne ...

  9. 『MXNet』第七弹_多GPU并行程序设计

    资料原文 一.概述思路 假设一台机器上有个GPU.给定需要训练的模型,每个GPU将分别独立维护一份完整的模型参数. 在模型训练的任意一次迭代中,给定一个小批量,我们将该批量中的样本划分成份并分给每个G ...

随机推荐

  1. git删除远程分支文件,不改变本地文件

    git提交项目时候踩的Git的坑 特别 由于准备春招,所以希望各位看客方便的话,能去github上面帮我Star一下项目 https://github.com/Draymonders/Campus-S ...

  2. P3455 [POI2007]ZAP-Queries(莫比乌斯反演)

    思路 和YY的GCD类似但是更加简单了 类似的推一波公式即可 \[ F(n)=\sum_{n|d}f(d) \] \[ f(n)=\sum_{n|d}\mu(\frac{d}{n})F(d) \] \ ...

  3. Linux 解决 firefox 中文页面乱码问题

    1.由于 firefox 默认是允许网页自己选择字体,在 Linux 上便会出现部分网站的乱码情况.因此可以取消允许页面自己选择字体这个选项便能解决部分乱码情况.

  4. js字符串与十六进制相互转换

    1.字符串(汉字)转换为十六进制 主要使用字符串.charCodeAt()方法,此方法返回一个字符的Unicode值,再用toString(16)方法,该方法是先将数字对象转换为二进制,再把二进制转化 ...

  5. WebStorm破解方法

    http://www.jianshu.com/p/85266fa16639 http://idea.lanyus.com/ webstorm 入门指南 破解方法 1. 下载的WebStorm http ...

  6. DataGrip激活码

    引言: 网上有有很多datagirp的激活码,但是经过尝试很多都失效了,找了半天终于 找到了一个可用的激活码! 1. 激活码 适用版本: DataGrip2018.2.3,2018.1.1,其他版本没 ...

  7. win10常用命令和设置总结

    1.常用命令 exit:退出cmd面板; cls:清除cmd面板; 2.常用设置 2.1 services.msc 禁用:以后怎样都不会运行;手动:是打开某些用到它的程序要用到该服务时才会运行; 自动 ...

  8. 前端阶段_html部分

    HTML 1.html5的第一行一定是<!DOCTYPE html>,h4太长,而且一般ide中会自动加载,了解即可. 2.h5的整个页面被<html></html> ...

  9. [原][osg][osgearth]倾斜摄影2.文件格式分析:OSGB

    倾斜摄影三维模型格式包含:*.osgb,*.dae等 文件格式包含:*.xml, *.desc, *.lfp等 例如:LocaSpace Viewer软件把osgb分块模型文件建立索引生成一个lfp文 ...

  10. 如何编写一个d.ts文件

    这篇文章主要讲怎么写一个typescript的描述文件(以d.ts结尾的文件名,比如xxx.d.ts). 2018.12.18更新说明: 1.增加了全局声明的原理说明. 2.增加了es6的import ...