从头学pytorch(十一):自定义层
自定义layer
https://www.cnblogs.com/sdu20112013/p/12132786.html一文里说了怎么写自定义的模型.本篇说怎么自定义层.
分两种:
- 不含模型参数的layer
- 含模型参数的layer
核心都一样,自定义一个继承自nn.Module的类
,在类的forward函数里实现该layer的计算,不同的是,带参数的layer需要用到nn.Parameter
不含模型参数的layer
直接继承nn.Module
import torch
from torch import nn
class CenteredLayer(nn.Module):
def __init__(self, **kwargs):
super(CenteredLayer, self).__init__(**kwargs)
def forward(self, x):
return x - x.mean()
layer = CenteredLayer()
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
y = net(torch.rand(4, 8))
y.mean().item()
含模型参数的layer
- Parameter
- ParameterList
- ParameterDict
Parameter
类其实是Tensor
的子类,如果一个Tensor
是Parameter
,那么它会自动被添加到模型的参数列表里。所以在自定义含模型参数的层时,我们应该将参数定义成Parameter
,除了直接定义成Parameter
类外,还可以使用ParameterList
和ParameterDict
分别定义参数的列表和字典。
ParameterList用法和list类似
class MyDense(nn.Module):
def __init__(self):
super(MyDense,self).__init__()
self.params = nn.ParameterList([nn.Parameter(torch.randn(4,4)) for i in range(4)])
self.params.append(nn.Parameter(torch.randn(4,1)))
def forward(self,x):
for i in range(len(self.params)):
x = torch.mm(x,self.params[i])
return x
net = MyDense()
print(net)
输出
MyDense(
(params): ParameterList(
(0): Parameter containing: [torch.FloatTensor of size 4x4]
(1): Parameter containing: [torch.FloatTensor of size 4x4]
(2): Parameter containing: [torch.FloatTensor of size 4x4]
(3): Parameter containing: [torch.FloatTensor of size 4x4]
(4): Parameter containing: [torch.FloatTensor of size 4x1]
)
)
ParameterDict用法和python dict类似.也可以用.keys(),.items()
class MyDictDense(nn.Module):
def __init__(self):
super(MyDictDense, self).__init__()
self.params = nn.ParameterDict({
'linear1': nn.Parameter(torch.randn(4, 4)),
'linear2': nn.Parameter(torch.randn(4, 1))
})
self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))}) # 新增
def forward(self, x, choice='linear1'):
return torch.mm(x, self.params[choice])
net = MyDictDense()
print(net)
print(net.params.keys(),net.params.items())
x = torch.ones(1, 4)
net(x, 'linear1')
输出
MyDictDense(
(params): ParameterDict(
(linear1): Parameter containing: [torch.FloatTensor of size 4x4]
(linear2): Parameter containing: [torch.FloatTensor of size 4x1]
(linear3): Parameter containing: [torch.FloatTensor of size 4x2]
)
)
odict_keys(['linear1', 'linear2', 'linear3']) odict_items([('linear1', Parameter containing:
tensor([[-0.2275, -1.0434, -1.6733, -1.8101],
[ 1.7530, 0.0729, -0.2314, -1.9430],
[-0.1399, 0.7093, -0.4628, -0.2244],
[-1.6363, 1.2004, 1.4415, -0.1364]], requires_grad=True)), ('linear2', Parameter containing:
tensor([[ 0.5035],
[-0.0171],
[-0.8580],
[-1.1064]], requires_grad=True)), ('linear3', Parameter containing:
tensor([[-1.2078, 0.4364],
[-0.8203, 1.7443],
[-1.7759, 2.1744],
[-0.8799, -0.1479]], requires_grad=True))])
使用自定义的layer构造模型
layer1 = MyDense()
layer2 = MyDictDense()
net = nn.Sequential(layer2,layer1)
print(net)
print(net(x))
输出
Sequential(
(0): MyDictDense(
(params): ParameterDict(
(linear1): Parameter containing: [torch.FloatTensor of size 4x4]
(linear2): Parameter containing: [torch.FloatTensor of size 4x1]
(linear3): Parameter containing: [torch.FloatTensor of size 4x2]
)
)
(1): MyDense(
(params): ParameterList(
(0): Parameter containing: [torch.FloatTensor of size 4x4]
(1): Parameter containing: [torch.FloatTensor of size 4x4]
(2): Parameter containing: [torch.FloatTensor of size 4x4]
(3): Parameter containing: [torch.FloatTensor of size 4x4]
(4): Parameter containing: [torch.FloatTensor of size 4x1]
)
)
)
tensor([[-4.7566]], grad_fn=<MmBackward>)
从头学pytorch(十一):自定义层的更多相关文章
- 从头学pytorch(一):数据操作
跟着Dive-into-DL-PyTorch.pdf从头开始学pytorch,夯实基础. Tensor创建 创建未初始化的tensor import torch x = torch.empty(5,3 ...
- 从头学pytorch(三) 线性回归
关于什么是线性回归,不多做介绍了.可以参考我以前的博客https://www.cnblogs.com/sdu20112013/p/10186516.html 实现线性回归 分为以下几个部分: 生成数据 ...
- 从头学pytorch(九):模型构造
模型构造 nn.Module nn.Module是pytorch中提供的一个类,是所有神经网络模块的基类.我们自定义的模块要继承这个基类. import torch from torch import ...
- 从头学pytorch(六):权重衰减
深度学习中常常会存在过拟合现象,比如当训练数据过少时,训练得到的模型很可能在训练集上表现非常好,但是在测试集上表现不好. 应对过拟合,可以通过数据增强,增大训练集数量.我们这里先不介绍数据增强,先从模 ...
- 从头学pytorch(七):dropout防止过拟合
上一篇讲了防止过拟合的一种方式,权重衰减,也即在loss上加上一部分\(\frac{\lambda}{2n} \|\boldsymbol{w}\|^2\),从而使得w不至于过大,即不过分偏向某个特征. ...
- 从头学pytorch(十二):模型保存和加载
模型读取和存储 总结下来,就是几个函数 torch.load()/torch.save() 通过python的pickle完成序列化与反序列化.完成内存<-->磁盘转换. Module.s ...
- 从头学pytorch(十五):AlexNet
AlexNet AlexNet是2012年提出的一个模型,并且赢得了ImageNet图像识别挑战赛的冠军.首次证明了由计算机自动学习到的特征可以超越手工设计的特征,对计算机视觉的研究有着极其重要的意义 ...
- 从头学pytorch(十九):批量归一化batch normalization
批量归一化 论文地址:https://arxiv.org/abs/1502.03167 批量归一化基本上是现在模型的标配了. 说实在的,到今天我也没搞明白batch normalize能够使得模型训练 ...
- 从头学pytorch(二十):残差网络resnet
残差网络ResNet resnet是何凯明大神在2015年提出的.并且获得了当年的ImageNet比赛的冠军. 残差网络具有里程碑的意义,为以后的网络设计提出了一个新的思路. googlenet的思路 ...
随机推荐
- tensorflow学习笔记(三十四):Saver(保存与加载模型)
Savertensorflow 中的 Saver 对象是用于 参数保存和恢复的.如何使用呢? 这里介绍了一些基本的用法. 官网中给出了这么一个例子: v1 = tf.Variable(..., nam ...
- oracle中的闪回
项目中运用: 首先说明:闪回方法有一个前提,就是需要尽早的发现问题,果断的采取行动.若误操作的记录已经在UNDO表空间中被清除,则此方法就不可行了,需要另寻他法. 例如: SELECT * FROM ...
- 解锁当前XXX用户
pam_tally2 查看当前锁账户 pam_tally2 --user=XXX用户 --reset 解锁当前XXX用户
- day5_python之协程函数
一.yield 1:把函数的执行结果封装好__iter__和__next__,即得到一个迭代器2:与return功能类似,都可以返回值,但不同的是,return只能返回一次值,而yield可以返回多次 ...
- behavior planning——15.cost function design weightTweaking
Designing cost functions is difficult and getting them all to cooperate to produce reasionable vehic ...
- Android Animation动画实战(一): 从布局动画引入ListView滑动时,每一Item项的显示动画
前言: 之前,我已经写了两篇博文,给大家介绍了Android的基础动画是如何实现的,如果还不清楚的,可以点击查看:Android Animation动画详解(一): 补间动画 及 Android An ...
- BERT大火却不懂Transformer?读这一篇就够了
https://zhuanlan.zhihu.com/p/54356280 大数据文摘与百度NLP联合出品 编译:张驰.毅航.Conrad.龙心尘 来源:https://jalammar.github ...
- Java多线程遍历文件夹,广度遍历加多线程加深度遍历结合
复习IO操作,突然想写一个小工具,统计一下电脑里面的Java代码量还有注释率,最开始随手写了一个递归算法,遍历文件夹,比较简单,而且代码层次清晰,相对易于理解,代码如下:(完整代码贴在最后面,前面是功 ...
- SuperSocket通过 SessionID 获取 Session
前面提到过,如果你获取了连接的 Session 实例,你就可以通过 "Send()" 方法向客户端发送数据.但是在某些情况下,你无法直接获取 Session 实例. SuperSo ...
- [转][ASP.NET Core 3框架揭秘] 跨平台开发体验: Windows [下篇]
由于ASP.NET Core框架在本质上就是由服务器和中间件构建的消息处理管道,所以在它上面构建的应用开发框架都是建立在某种类型的中间件上,整个ASP.NET Core MVC开发框架就是建立在用来实 ...