MXNET:深度学习计算-自定义层
虽然 Gluon 提供了大量常用的层,但有时候我们依然希望自定义层。本节将介绍如何使用 NDArray 来自定义一个 Gluon 的层,从而以后可以被重复调用。
不含模型参数的自定义层
我们先介绍如何定义一个不含模型参数的自定义层。事实上,这和 “模型构造” 中介绍的使用 Block 构造模型类似。
通过继承 Block 自定义了一个将输入减掉均值的层:CenteredLayer 类,并将层的计算放在 forward 函数里。
class CenteredLayer(nn.Block):
def __init__(self, **kwargs):
super(CenteredLayer, self).__init__(**kwargs)
def forward(self, x):
return x - x.mean()
layer = CenteredLayer()
layer(nd.array([1, 2, 3, 4, 5]))
# output
[-2. -1. 0. 1. 2.]
<NDArray 5 @cpu(0)>
我们也可以用它来构造更复杂的模型
net = nn.Sequential()
net.add(nn.Dense(128))
net.add(nn.Dense(10))
net.add(CenteredLayer())
net.initialize()
y = net(nd.random.uniform(shape=(4, 8)))
y.mean()
含模型参数的自定义层
在自定义层的时候我们还可以使用 Block 自带的 ParameterDict 类型的成员变量 params。顾名思义,这是一个由字符串类型的参数名字映射到 Parameter 类型的模型参数的字典。我们可以通过 get 函数从 ParameterDict 创建 Parameter。
params = gluon.ParameterDict()
params.get('param2', shape=(2, 3))
params
# ouput
(
Parameter param2 (shape=(2, 3), dtype=<class 'numpy.float32'>)
)
现在我们看下如何实现一个含权重参数和偏差参数的全连接层。它使用 ReLU 作为激活函数。其中 in_units 和 units 分别是输入单元个数和输出单元个数。
class MyDense(nn.Block):
def __init__(self, units, in_units, **kwargs):
super(MyDense, self).__init__(**kwargs)
self.weight = self.params.get('weight', shape=(in_units, units))
self.bias = self.params.get('bias', shape=(units,))
def forward(self, x):
linear = nd.dot(x, self.weight.data()) + self.bias.data()
return nd.relu(linear)
我们实例化 MyDense 类来看下它的模型参数。
# units:该层的输出个数;in_units:该层的输入个数。
dense = MyDense(units=5, in_units=10)
dense.params
# output
mydense0_ (
Parameter mydense0_weight (shape=(10, 5), dtype=<class 'numpy.float32'>)
Parameter mydense0_bias (shape=(5,), dtype=<class 'numpy.float32'>)
)
我们也可以使用自定义层构造模型。它用起来和 Gluon 的其他层很类似。
net = nn.Sequential()
net.add(MyDense(32, in_units=64))
net.add(MyDense(2, in_units=32))
net.initialize()
net(nd.random.uniform(shape=(2, 64)))
MXNET:深度学习计算-自定义层的更多相关文章
- MXNet深度学习库简介
MXNet深度学习库简介 摘要: MXNet是一个深度学习库, 支持C++, Python, R, Scala, Julia, Matlab以及JavaScript等语言; 支持命令和符号编程; 可以 ...
- MXNET:深度学习计算-模型构建
进入更深的层次:模型构造.参数访问.自定义层和使用 GPU. 模型构建 在多层感知机的实现中,我们首先构造 Sequential 实例,然后依次添加两个全连接层.其中第一层的输出大小为 256,即隐藏 ...
- 深度学习中 batchnorm 层是咋回事?
作者:Double_V_ 来源:CSDN 原文:https://blog.csdn.net/qq_25737169/article/details/79048516 版权声明:本文为博主原创文章,转载 ...
- Caffe深度学习计算框架
Caffe | Deep Learning Framework是一个清晰而高效的深度学习框架,其作者是博士毕业于UC Berkeley的 Yangqing Jia,目前在Google工作.Caffe是 ...
- MXNET:深度学习计算-模型参数
我们将深入讲解模型参数的访问和初始化,以及如何在多个层之间共享同一份参数. 之前我们一直在使用默认的初始函数,net.initialize(). from mxnet import init, nd ...
- MXNET:深度学习计算-GPU
mxnet的设备管理 MXNet 使用 context 来指定用来存储和计算的设备,例如可以是 CPU 或者 GPU.默认情况下,MXNet 会将数据创建在主内存,然后利用 CPU 来计算.在 MXN ...
- mxnet深度学习实战学习笔记-9-目标检测
1.介绍 目标检测是指任意给定一张图像,判断图像中是否存在指定类别的目标,如果存在,则返回目标的位置和类别置信度 如下图检测人和自行车这两个目标,检测结果包括目标的位置.目标的类别和置信度 因为目标检 ...
- DeepLearning.ai学习笔记(一)神经网络和深度学习--Week3浅层神经网络
介绍 DeepLearning课程总共五大章节,该系列笔记将按照课程安排进行记录. 另外第一章的前两周的课程在之前的Andrew Ng机器学习课程笔记(博客园)&Andrew Ng机器学习课程 ...
- deeplearning.ai 神经网络和深度学习 week3 浅层神经网络 听课笔记
1. 第i层网络 Z[i] = W[i]A[i-1] + B[i],A[i] = f[i](Z[i]). 其中, W[i]形状是n[i]*n[i-1],n[i]是第i层神经元的数量: A[i-1]是第 ...
随机推荐
- POJ 2437 Muddy roads【贪心】(区间覆盖)
题目链接:https://vjudge.net/contest/194475#problem/C 题目大意: 有n滩泥 木板长度为L 求至少需要多少木板才能覆盖这些泥 解题思路: 把泥块的起点升序排序 ...
- Python中 各种数字类型的判别(numerica, digital, decimal)
一. 全角和半角 全角:是指一个全角字符占用两个标准字符(或两个半角字符)的位置. 全角占两个字节. 汉字字符和规定了全角的英文字符及国标GB2312-80中的图形符号和特殊字符都是全角字符.在全角中 ...
- 树莓派(Raspbian系统)中使用pyinstaller封装Python代码为可执行程序
一.前言 将做好的Python软件运行在树莓派上时,不想公开源码,就需要对文件进行封装(或称打包),本文主要介绍使用pyinstaller封装Python代码为可执行程序. Python是一个脚本语言 ...
- xadmin2.0(for Django2.0) 基本设置
一.下载xadmin 1.使用安装工具安装: pip install git+git://github.com/sshwsfc/xadmin.git@django2 2.下载源码: git clone ...
- 算法初级面试题05——哈希函数/表、生成多个哈希函数、哈希扩容、利用哈希分流找出大文件的重复内容、设计RandomPool结构、布隆过滤器、一致性哈希、并查集、岛问题
今天主要讨论:哈希函数.哈希表.布隆过滤器.一致性哈希.并查集的介绍和应用. 题目一 认识哈希函数和哈希表 1.输入无限大 2.输出有限的S集合 3.输入什么就输出什么 4.会发生哈希碰撞 5.会均匀 ...
- vue-particles粒子动画效果
1.安装vue-particles依赖包 npm install vue-particles --save-dev 2.在main.js文件中引入并使用 import Vue from 'vue' i ...
- SQL 查询 技巧
一.使用SELECT检索数据 数据查询是SQL语言的中心内容,SELECT 语句的作用是让数据库服务器根据客户要求检索出所需要的信息资料,并按照规定的格式进行整理,返回给客户端. SELECT 语句的 ...
- [TC14126]BagAndCards
[TC14126]BagAndCards 题目大意: 有\(n(n\le500)\)个袋子,第\(i\)个袋子里有\(count[i][j]\)张值为\(j(j\le m\le500)\)的牌.给一个 ...
- Windows 文件名的书写规范
下面是不能用于文件名的: / \ : * " < > | windows系统下文件名长度为: 255个英文字符(DOS下8.3格式),包括文件名和扩展名在内,或者是255/2=1 ...
- GPSCamera隐私声明
GPSCamera获取了以下敏感隐私权限用于照片水印展示: 1:修改或删除您的SD卡中的内容 2:访问确切位置信息(使用 GPS 和网络进行定位) 3:访问大致位置信息(使用网络进行定位) 4:拍摄照 ...