https://blog.csdn.net/a350203223/article/details/77449630

在使用深度学习平台时,光会使用其中已定义好的操作有时候是满足不了实际使用的,一般需要我们自己定义新的操作。但是,绝大多数深度平台都是编译好的,很难再次编写。本文以Mxnet为例,官方给出四种定义新操作的方法,

分别调用:

1、mx.operator.CustomOp

2、mx.operator.NDArrayOp

3、mx.operator.NumpyOp

4、使用 C++ 定义底层

并且给出了重新定义softmax层的例子。但是sofetmax操作只有前向操作,也没有参数,与我们通常需要需要使用的情况不符,官方文档也没有一个有参数的中间层例子。在此博主给出了一个重新定义全连接操作的例子,希望能够给大家带来帮助。

# pylint: skip-file
import os
from data import mnist_iterator
import mxnet as mx
import numpy as np
import logging
from numpy import *

class Dense(mx.operator.CustomOp):

def __init__(self, num_hidden):
self.num_hidden = num_hidden

def forward(self, is_train, req, in_data, out_data, aux):
x = in_data[0]
w = in_data[1]
b = in_data[2]
y = out_data[0]
y[:] = mx.nd.add(mx.nd.dot(x, w.T), b)
self.assign(out_data[0], req[0], mx.nd.array(yy))

def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
dx = in_grad[0]
dw = in_grad[1]
db = in_grad[2]
dy = out_grad[0]
x = in_data[0]
w = in_data[1]
dw[:] = mx.nd.dot(dy.T, x)
dx[:] = mx.nd.dot(dy, w)
db[:] = mx.nd.sum(dy, axis=0)
self.assign(in_grad[0], req[0], dx)
self.assign(in_grad[1], req[0], dw)
self.assign(in_grad[2], req[0], db)

@mx.operator.register("dense")
class DenseProp(mx.operator.CustomOpProp):
def __init__(self, num_hidden):
super(DenseProp, self).__init__(True)
# we use constant bias here to illustrate how to pass arguments
# to operators. All arguments are in string format so you need
# to convert them back to the type you want.
self.num_hidden = long(num_hidden)

def list_arguments(self):
return ['data', 'weight', 'bias']

def list_outputs(self):
# this can be omitted if you only have 1 output.
return ['output']

def infer_shape(self, in_shapes):
data_shape = in_shapes[0]
weight_shape = (self.num_hidden, in_shapes[0][1])
bias_shape = (self.num_hidden,)
output_shape = (data_shape[0], self.num_hidden)
return [data_shape, weight_shape, bias_shape], [output_shape], []

def infer_type(self, in_type):
dtype = in_type[0]
return [dtype, dtype, dtype], [dtype], []

def create_operator(self, ctx, in_shapes, in_dtypes):
# create and return the CustomOp class.
return Dense(self.num_hidden)

# define mlp
data = mx.symbol.Variable('data')
##This is the new defined layer
fc1 = mx.symbol.Custom(data, name='fc1', op_type='dense', num_hidden=128)
act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(data=act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
mlp = mx.symbol.Softmax(data = fc3, name = 'softmax')
train, val = mnist_iterator(batch_size=100, input_shape = (784,))
logging.basicConfig(level=logging.DEBUG)
model = mx.model.FeedForward(
ctx = mx.gpu(1), symbol = mlp, num_epoch = 20,
learning_rate = 0.1, momentum = 0.9, wd = 0.00001)
model.fit(X=train, eval_data=val,
batch_end_callback=mx.callback.Speedometer(100,100))
---------------------
作者:启功
来源:CSDN
原文:https://blog.csdn.net/a350203223/article/details/77449630
版权声明:本文为博主原创文章,转载请附上博文链接!

Mxnet:以全连接层为例子自定义新的操作(层)的更多相关文章

  1. 基于MNIST数据集使用TensorFlow训练一个包含一个隐含层的全连接神经网络

    包含一个隐含层的全连接神经网络结构如下: 包含一个隐含层的神经网络结构图 以MNIST数据集为例,以上结构的神经网络训练如下: #coding=utf-8 from tensorflow.exampl ...

  2. Tensorflow 多层全连接神经网络

    本节涉及: 身份证问题 单层网络的模型 多层全连接神经网络 激活函数 tanh 身份证问题新模型的代码实现 模型的优化 一.身份证问题 身份证号码是18位的数字[此处暂不考虑字母的情况],身份证倒数第 ...

  3. caffe之(四)全连接层

    在caffe中,网络的结构由prototxt文件中给出,由一些列的Layer(层)组成,常用的层如:数据加载层.卷积操作层.pooling层.非线性变换层.内积运算层.归一化层.损失计算层等:本篇主要 ...

  4. caffe中全卷积层和全连接层训练参数如何确定

    今天来仔细讲一下卷基层和全连接层训练参数个数如何确定的问题.我们以Mnist为例,首先贴出网络配置文件: name: "LeNet" layer { name: "mni ...

  5. 基于深度学习和迁移学习的识花实践——利用 VGG16 的深度网络结构中的五轮卷积网络层和池化层,对每张图片得到一个 4096 维的特征向量,然后我们直接用这个特征向量替代原来的图片,再加若干层全连接的神经网络,对花朵数据集进行训练(属于模型迁移)

    基于深度学习和迁移学习的识花实践(转)   深度学习是人工智能领域近年来最火热的话题之一,但是对于个人来说,以往想要玩转深度学习除了要具备高超的编程技巧,还需要有海量的数据和强劲的硬件.不过 Tens ...

  6. resnet18全连接层改成卷积层

    想要尝试一下将resnet18最后一层的全连接层改成卷积层看会不会对网络效果和网络大小有什么影响 1.首先先对train.py中的更改是: train.py代码可见:pytorch实现性别检测 # m ...

  7. tensorflow 1.0 学习:池化层(pooling)和全连接层(dense)

    池化层定义在 tensorflow/python/layers/pooling.py. 有最大值池化和均值池化. 1.tf.layers.max_pooling2d max_pooling2d( in ...

  8. mnist全连接层网络权值可视化

    一.数据准备 网络结构:lenet_lr.prototxt 训练好的模型:lenet_lr_iter_10000.caffemodel 下载地址:链接:https://pan.baidu.com/s/ ...

  9. Caffe源码阅读(1) 全连接层

    Caffe源码阅读(1) 全连接层 发表于 2014-09-15   |   今天看全连接层的实现.主要看的是https://github.com/BVLC/caffe/blob/master/src ...

随机推荐

  1. Jmeter计数器的使用-转载

    说一下jmeter中,配置元件-计数器的使用. 如果需要引用的数据量较大,且要求不能重复或者需要自增,那么可以使用计数器来实现. 1.启动jmeter,添加线程组,右键添加配置元件——计数器,如下图: ...

  2. ubuntu16.04+GTX2080Ti+torch7安装记录

    环境说明 ubuntu16.04 cuda10.0 2080Ti显卡 拉取代码和修改编译脚本 拉取代码 用户先clone代码: git clone https://github.com/torch/d ...

  3. golang静态编译

    golang 的编译(不涉及 cgo 编译的前提下)默认使用了静态编译,不依赖任何动态链接库. 这样可以任意部署到各种运行环境,不用担心依赖库的版本问题.只是体积大一点而已,存储时占用了一点磁盘,运行 ...

  4. golang错误处理

    1. 错误 错误用内建的error类型来表示. type error interface { Error() string } error 有了一个签名为 Error() string 的方法.所有实 ...

  5. k8s node节点部署(v1.13.10)

    系统环境: node节点 操作系统: CentOS-7-x86_64-DVD-1908.iso node节点 IP地址: 192.168.1.204 node节点 hostname(主机名, 请和保持 ...

  6. Linux之RHEL7root密码破解(一)

    很多时候我们都会有这样的经历,各种密码,各种复杂,忘记了怎么办???Windows的有关密码忘记了是可以通过相关的邮箱啊手机号等等是可以 找回的,那么Linux的root密码忘记了,该怎么办呢?那么接 ...

  7. beta版本——第七次冲刺

    第七次冲刺 (1)SCRUM部分☁️ 成员描述: 姓名 李星晨 完成了哪个任务 编写个人信息修改界面的js 花了多少时间 3h 还剩余多少时间 0h 遇到什么困难 密码验证部分出现问题 这两天解决的进 ...

  8. Django REST framework —— 权限组件源码分析

    在上一篇文章中我们已经分析了认证组件源码,我们再来看看权限组件的源码,权限组件相对容易,因为只需要返回True 和False即可 代码 class ShoppingCarView(ViewSetMix ...

  9. 项目Alpha冲刺(团队)-测试篇

    格式描述 课程名称:软件工程1916|W(福州大学) 作业要求:项目Alpha冲刺(团队)-代码规范.冲刺任务与计划 团队名称:为了交项目干杯 测试用例:测试用例文档.zip 作业目标:描述项目的测试 ...

  10. 如何使用project制定项目计划?(附详细步骤截图)

    使用project制定项目计划可以分为六个步骤,如下图(1): 图(1)-project制定项目计划步骤 下面我们就以project2010为例,按上图所示步骤对如何制定项目计划进行详细说明: 一.创 ...