Mxnet学习笔记(3)--自定义Op
https://blog.csdn.net/u011765306/article/details/54562282
前言
今天因为要用到tile操作(类似np.tile,将数据沿axises进行数据扩充),结果发现mxnet中没有,而且很多操作都没实现,详细完成
度可以参看issue,还在完成中,不过这并不影响我们要用的操作,这里我们
需要实现自己的Op。当然,在官方的example/numpy-ops中已经给出部分例子。这里具体的记录一下。
自定义Op
自定义op都是去继承operator.py中的类,其中提供如下几类:
operator.py
CustomOp(object)
CustomOpProp(object)
NDArrayOp(PythonOp)
NumpyOp(PythonOp)
PythonOp(object)
这里很清晰的可以看出,operator分为两条路线,一条路线为CustomOp, 另外一条路线为继承PythonOp,这里我们就分为两部分分别介绍这两条路线。
CustomOp类
这条路线是有三步组成,第一步继承CustomOp,重写方法forward()和backward(),然后继承CustomOpProp,重写成员方法,并在方法create_operator中
调用之前写好的Op,第三步调用operator.register()对操作进行注册。具体我们结合官方代码example/numpy-ops/custom_softmax.py来解释,代码如下:
class Softmax(mx.operator.CustomOp):
def forward(self, is_train, req, in_data, out_data, aux):
x = in_data[0].asnumpy()
y = np.exp(x - x.max(axis=1).reshape((x.shape[0], 1)))
y /= y.sum(axis=1).reshape((x.shape[0], 1))
self.assign(out_data[0], req[0], mx.nd.array(y))
def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
l = in_data[1].asnumpy().ravel().astype(np.int)
y = out_data[0].asnumpy()
y[np.arange(l.shape[0]), l] -= 1.0
self.assign(in_grad[0], req[0], mx.nd.array(y))
@mx.operator.register("softmax")
class SoftmaxProp(mx.operator.CustomOpProp):
def __init__(self):
super(SoftmaxProp, self).__init__(need_top_grad=False)
def list_arguments(self):
return ['data', 'label']
def list_outputs(self):
return ['output']
def infer_shape(self, in_shape):
data_shape = in_shape[0]
label_shape = (in_shape[0][0],)
output_shape = in_shape[0]
return [data_shape, label_shape], [output_shape], []
def create_operator(self, ctx, shapes, dtypes):
return Softmax()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
上述代码是对softmax的自定义,在类Softmax中重写forward()和backward(),这里与caffe中定义层操作类似,forward()中定义层的前向操作,backward()中
定义反向传播的梯度计算。在完成定义之后,在类SoftmaxProp中create_operator()调用并返回Softmax()实例。那么第三步register如何实现,可以看到,
在SoftmaxProp中带有装饰器mx.operator.register(),等价于操作register("custom_op")(CustomOpProp),这里即在代码运行前即完成了该Op的
实例化,与optimazer的装饰器类似。
PythonOp类
这条路线,PythonOp类为基类,而我们大多定义Op时不会去继承它,而是使用他的subclass: NDarrayOp、NumpyOp。这条路线不会像继承CustomOp那样需要三步,这里我们也是只讨论如何继承并定义操作,不去探究
这两个类的实现细节。还是拿官网例子来讲。上代码:
class NDArraySoftmax(mx.operator.NDArrayOp):
def __init__(self):
super(NDArraySoftmax, self).__init__(False)
self.fwd_kernel = None
self.bwd_kernel = None
def list_arguments(self):
return ['data', 'label']
def list_outputs(self):
return ['output']
def infer_shape(self, in_shape):
data_shape = in_shape[0]
label_shape = (in_shape[0][0],)
output_shape = in_shape[0]
return [data_shape, label_shape], [output_shape]
def forward(self, in_data, out_data):
x = in_data[0]
y = out_data[0]
if self.fwd_kernel is None:
self.fwd_kernel = mx.rtc('softmax', [('x', x)], [('y', y)])
self.fwd_kernel.push([x], [y], (1, 1, 1), (x.shape[0], 1, 1))
def backward(self, out_grad, in_data, out_data, in_grad):
l = in_data[1]
y = out_data[0]
dx = in_grad[0]
if self.bwd_kernel is None:
self.bwd_kernel = mx.rtc('softmax_grad', [('y', y), ('l', l)], [('dx', dx)])
self.bwd_kernel.push([y,l], [dx], (y.shape[0],1,1), (y.shape[1], 1, 1))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
继承NDArrayOp其实和NumpyOp类似,不同之处在于forward()和backward()重写方式使用函数不同,NDArrayOp中需要使用mx.nd中的操作,而
NumpyOp可以使用numpy中的操作。总之重点在forward()和backward()。当然,如此的自定义方法在使用时需要先定义类对象才可以使用。即与CunstomOp
的定义时间不同。
成员方法list_arguments,list_outpus,infer_shape
虽然继承方法不同,但是效果是一样的,forward()和backward()是对Op操作的定义,剩余三个成员方法则是对Op接口的描述。
list_arguments
该方法主要是对该Op定义时形参的命名,如上述多为['data', 'label'],那么该Op在使用时形参必须为data和label。这里也可以看出mxnet是用过名字
寻找变量的,DataIter,optimazer也是如此。
list_outputs
同样的,该方法定义了输出变量的名字,一般为opname+’_output’。
infer_shape
该方法用于在给定输入时,获取该Op的输出shape。当然,在我们自定义时,需要自己设计Op的输入和输出shape。
以上就是自定义Op时需要做的事情,重点还是forward()和backward(),有时候无头绪的时候可以参考caffe的写法获得灵感。接下来我用例子来说描述一下上述方法。
import mxnet as mx
import numpy as np
class TileLayer(mx.operator.NumpyOp):
def __init__(self, tiles, axis):
super(TileLayer, self).__init__(False)
# tiles可以为list或者一个数
self.tiles = tiles
self.axis = axis
def list_arguments(self):
return ['input']
def list_outputs(self):
return ['output']
def infer_shape(self, in_shape):
data_shape = in_shape[0]
output_shape = in_shape[0] + [self.tiles]
return [data_shape], [output_shape]
def forward(self, in_data, out_data):
x = in_data[0]
y = out_data[0]
y = np.tile(x, reps=self.tiles)
def backward(self, out_grad, in_data, out_data, in_grad):
bottom_diff = in_grad[0]
top_diff = np.sum(out_grad[0], axis=self.axis)
bottom_diff = top_diff
if __name__ == '__main__':
import logging
from collections import namedtuple
Batch = namedtuple('Batch', ['input'])
logging.basicConfig(level=logging.INFO)
a = mx.sym.Variable('data')
custie = TileLayer(tiles=10, axis=2)
tiles_a = custie(input=a, name='tileop')
arg_shapes, out_shape, aux_shape = tiles_a.infer_shape(data=(2, 3))
logging.info('arg_shape:{}\n, out_shape:{}\n, aux_shape:{}\n, output_blob:{}'.format(arg_shapes, out_shape, aux_shape, tiles_a.list_outputs()))
exe = mx.module.Module(symbol=tiles_a, logger=logging)
exe.bind(data_shapes=[('data', (1, 10, 10))], inputs_need_grad=True)
# exe.init_params()
# exe.init_optimizer()
# data1 = [mx.nd.ones((1, 10, 10))]
# exe.forward(Batch(data1))
# print exe.get_outputs()[0].asnumpy().shape
# top_grads =np.random.random(size=(1, 10, 10, 10))
# exe.backward(out_grads=top_grads)
# print exe.get_input_grads()[0].asnumpy()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
以上为定义的tile操作,这里没有做完全的tile操作,只是可以在最后的axis进行数据的tile操作。forward中用numpy.tile实现,backward中参考caffe
中的TileLayer实现,这里代码运行结果:
INFO:root:arg_shape:[(2L, 3L)]
out_shape:[(2L, 3L, 10L)]
aux_shape:[]
output_blob:['tileop_output']
1
2
3
4
上述代码因为在list_arguments中定义了形参名字为input,因此在使用是形参必须为input,结果中也可以看到,infer_shape以及list_output的结果,基本细节就是上述。
在我们定义好Op后,我们需要通过mx.mod.Moudle()将Op进行整合,并通过bind()来申请内存,在此之后,我们可以通过以下两种方法训练它:
分别调用init_params()初始化参数(当然这里没有参数需要初始化),init_optimazer()初始化optimazer,接下来就可以通过forward()和backward()进行前向反向传播训练模块。
或者直接调用fit()方法进行训练,因为fit()中包含初始化操作。
关于Moudle可以参看mx.mod.Module
---------------------
作者:我只是空气
来源:CSDN
原文:https://blog.csdn.net/u011765306/article/details/54562282
版权声明:本文为博主原创文章,转载请附上博文链接!
Mxnet学习笔记(3)--自定义Op的更多相关文章
- [转载]SharePoint 2013搜索学习笔记之自定义结果源
搜索中心新建好之后在搜索结果页上会默认有所有内容,人员,对话,视频这四个结果分类,每个分类会返回指定范围的搜索结果,这里我再添加了部门日志结果分类,搜索这个分类只会返回部门日志内容类型的搜索结果,要实 ...
- Hadoop学习笔记—5.自定义类型处理手机上网日志
转载自http://www.cnblogs.com/edisonchou/p/4288737.html Hadoop学习笔记—5.自定义类型处理手机上网日志 一.测试数据:手机上网日志 1.1 关于这 ...
- shiro学习笔记_0600_自定义realm实现授权
博客shiro学习笔记_0400_自定义Realm实现身份认证 介绍了认证,这里介绍授权. 1,仅仅通过配置文件来指定权限不够灵活且不方便.在实际的应用中大多数情况下都是将用户信息,角色信息,权限信息 ...
- ASP.NET MVC 学习笔记-7.自定义配置信息 ASP.NET MVC 学习笔记-6.异步控制器 ASP.NET MVC 学习笔记-5.Controller与View的数据传递 ASP.NET MVC 学习笔记-4.ASP.NET MVC中Ajax的应用 ASP.NET MVC 学习笔记-3.面向对象设计原则
ASP.NET MVC 学习笔记-7.自定义配置信息 ASP.NET程序中的web.config文件中,在appSettings这个配置节中能够保存一些配置,比如, 1 <appSettin ...
- SpringBoot学习笔记:自定义拦截器
SpringBoot学习笔记:自定义拦截器 快速开始 拦截器类似于过滤器,但是拦截器提供更精细的的控制能力,它可以在一个请求过程中的两个节点进行拦截: 在请求发送到Controller之前 在响应发送 ...
- Netty学习笔记(三) 自定义编码器
编写一个网络应用程序需要实现某种编解码器,编解码器的作用就是讲原始字节数据与自定义的消息对象进行互转.网络中都是以字节码的数据形式来传输数据的,服务器编码数据后发送到客户端,客户端需要对数据进行解码, ...
- ASP.NET MVC 学习笔记-7.自定义配置信息(后续)
自定义配置信息的高级应用 通过上篇博文对简单的自定义配置信息的学习,使得更加灵活的控制系统配置信息.实际项目中,这种配置的灵活度往往无法满足项目的灵活度和扩展性. 比如,一个配置信息有三部分组成,而每 ...
- swift学习笔记之—自定义函数的规则说明
原文出自:www.hangge.com 转载请保留原文链接:http://www.hangge.com/blog/cache/detail_517.html 1,无返回值的函数 func test( ...
- shiro学习笔记_0400_自定义realm实现身份认证
自定义Realm实现身份认证 先来看下Realm的类继承关系: Realm接口有三个方法,最重要的是第三个方法: a) String getName():返回此realm的名字 b) boolean ...
随机推荐
- linux /bin/bash^M: bad interpreter的解决办法
linux下执行shell脚本时报错:-bash: ./a.sh: /bin/bash^M: bad interpreter: No such file or directory. 原因是window ...
- 二进制搭建Kubernetes集群(最新v1.16.0版本)
目录 1.生产环境k8s平台架构 2.官方提供三种部署方式 3.服务器规划 4.系统初始化 5.Etcd集群部署 5.1.安装cfssl工具 5.2.生成etcd证书 5.2.1 创建用来生成 CA ...
- Tomcat+Nginx+Memcached综合案例
Tomcat+Nginx+Memcached综合案例 说明 通过Nginx解析静态页面并将动态负载均衡调度给后面的多个Tomcat,Tomcat解析java动态程序. 由于http是无状态的协议,你访 ...
- 关闭firefox火狐浏览器下载完成时自动扫描(49.0.2以后版本)
本人自己找到的方法,亲测有效,如下:1.在火狐浏览器地址里输入about:config回车,可能会提示“这可能使质量保证失效”,点击[我了解此风险!]2.在搜索框里输入browser.safebrow ...
- php中函数的类型提示和文件读取功能
这个没有深入. <?php function addNumbers(int $a, int $b, bool $printSum): int { $sum = $a + $b; if ($pri ...
- 小程序页面收录 sitemap
微信现已开放小程序内搜索,你的小程序页面将可能展示在微信搜索等多个公开场景中.当开发者允许微信索引时,微信会通过爬虫的形式,为小程序的页面内容建立索引. 若小程序中存在不适合展示信息如用户个人信息.商 ...
- [转] C++ explicit关键字详解
本文转自tiankong19999 首先, C++中的explicit关键字只能用于修饰只有一个参数的类构造函数, 它的作用是表明该构造函数是显示的, 而非隐式的, 跟它相对应的另一个关键字是impl ...
- 《逆袭团队》第九次团队作业【Beta】Scrum Meeting 1
项目 内容 软件工程 任课教师博客主页链接 作业链接地址 团队作业9:Beta冲刺与团队项目验收 团队名称 逆袭团队 具体目标 (1)掌握软件黑盒测试技术:(2)学会编制软件项目总结PPT.项目验收报 ...
- 用PHP的fopen函数读写robots.txt文件
以前介绍了用PHP读写文本文档制作最简单的访问计数器不需要数据库,仅仅用文本文档就可以实现网页访问计数功能.同样我们可以拓展一下这个思路,robots.txt文件对于我们网站来说非常重要,有时候我们需 ...
- HDU - 4059: The Boss on Mars (容斥 拉格朗日 小小的优化搜索)
pro: T次询问,每次给出N(N<1e8),求所有Σi^4 (i<=N,且gcd(i,N)==1) ; sol: 因为N比较小,我们可以求出素因子,然后容斥. 主要问题就是求1到P的 ...