0602-nn.Module

pytorch完整教程目录:https://www.cnblogs.com/nickchen121/p/14662511.html

一、nn.Module

torch.nn 的核心数据结构就是 Module,它可以看做是某一层,也可以看做是整个神经网络。最常见的做法就是直接继承 nn.module,然后构建自己的网络模型结构。

1.1 构建一层网络——全连接层

接下来我们通过使用 nn.Module 实现一个全连接层(仿射层),输出 y 和输入 x 满足 \(y=Wx+b\),其中 w 和 b 是可学习参数。

import torch as t
from torch import nn
from torch.autograd import Variable as V
class Linear(nn.Module):
def __init__(self, in_features, out_features): # 输入的数据维度,输出的数据维度
super(Linear,
self).__init__() # 等价于 nn.Module.__init__(self),继承父类的init构造函数
self.w = nn.Parameter(t.randn(in_features, out_features))
self.b = nn.Parameter(t.randn(out_features)) def forward(self, x):
x = x.mm(self.w)
return x + self.b.expand_as(x)
layer = Linear(4, 3)
input = V(t.randn(2, 4))
output = layer(input) # y = Wx + b 的形状是(2,3) = (2,4)*(4*3)+(1,3).expanda_as(x)
output
tensor([[ 1.1407, -0.1323,  0.3659],
[ 2.4265, -1.2330, -0.9984]], grad_fn=<AddBackward0>)
for name, parameter in layer.named_parameters():
print(name, parameter)
w Parameter containing:
tensor([[-1.3990, -1.9669, -0.0430],
[ 0.8150, 0.8829, -1.0932],
[-0.3793, 0.2708, 0.9691],
[-0.9613, -0.3259, 0.5103]], requires_grad=True)
b Parameter containing:
tensor([ 0.9333, -0.7481, -0.6074], requires_grad=True)

从上述代码可以看出实现一个全连接层非常简单,但是需要注意以下几点:

  • 自定义 Linear 必须要继承 nn.Module,并且自定义类的构造函数需要继承 nn.Module 的构造函数
  • 在构造函数中必须自己定义可学习的参数,并且要封装为 Parameter,上述代码则是把 w 和 b 封装成 Parameter,并且可以发现 Parameter 这种数据结构默认 requires_grad=True
  • forward 函数的作用是实现前向传播过程,其输入可以是一个或多个 variable,对 x 的任何操作也必须是 variable 支持的操作
  • 不需要自己写一个反向传播函数,因为它的前向传播都是对 variable 进行操作,nn.Module 能够利用 autograd 自动进行反向传播
  • 调用 layer(input) 时就能得到 input 的结果,其实它的内部是做了 layer.__call__(input) 操作,在 call 函数中,主要调用了 layer.forward(x),另外还对钩子做了一定的处理,因此直接使用 layer(x),而不是使用 layer.forward(x),钩子的具体内容会在接下来讲解。对于 __call__的作用,可以参考这篇文章:详解__call__

1.2 构建多层网络——多层感知机

上述只是实现了一个一层网络结构的模型,下面我们通过更复杂的网络——多层感知机,来感受下 Module 的模块真正强大的地方。多层感知机的网络结构如下图所示:

从多层感知机的网络结构,我们可以看出它由两个全连接层组成,并且它采用 sigimoid 函数作为激活函数。

class Perceptron(nn.Module):
def __init__(self, in_features, hidden_features, out_features):
nn.Module.__init__(self)
self.layer1 = Linear(in_features,
hidden_features) # 此处的 Linear 是前面定义的全连接层
self.layer2 = Linear(hidden_features, out_features) def forward(self, x):
x = self.layer1(x)
x = t.sigimoid(x)
return self.layer2(x) perceptron = Perceptron(3, 4, 1)
for name, param in perceptron.named_parameters():
print(name, param.size())
layer1.w torch.Size([3, 4])
layer1.b torch.Size([4])
layer2.w torch.Size([4, 1])
layer2.b torch.Size([1])

从上述代码中,可以看出多层感知机也非常容易,但是也要注意以下两点:

  • 构造函数中,可以利用前面自定义的 Linear 层作为当前 module 对象的一个子 module,并且它的可学习参数也会称为当前 module 的可学习参数,也就是说主 module 可以递归查找子 module 中的 parameter
  • 在前向传播过程中,我们将输出变量都命名为 x,是为了让 Python 回收一些中间层的输出,从而节省内存,但是有些 variable 虽然名字被覆盖,但是由于它在反向传播过程中仍然需要用到,此时 Python 不会回收这部分数据

对于 parameter的命名有如下规范:

  • 如果没有子模块,parameter 直接命名。例如 self.param_name = nn.Parameter(t.randn(3,4)),则会命名称为 param_name
  • 对于子模块的 parameter,会在它的名字前面加上当前 module 的名字。例如 self.sub_module = SubModel(),SubModel 中也有个名字叫做 param_name 的 parameter,则它的实际名字为 sub_module.param_name

虽然我们自己定义神经网络的层(layer)看起来不是特别费力,但是 torch 为了让用书使用起来更方便,它对绝大多数的 layer 都做了封装,此处不做延伸,有兴趣的可以去参照官方文档,或者参考这一篇文章:0802_转载-nn模块中的网络层介绍

阅读上述介绍的文章时,需要注意下面三点:

  • 构造函数的参数,如 nn.Linear(in_features, out_features, bias),需要关注这三个参数的作用
  • 属性、可学习参数和子 module。例如 nn.Linear 中有 weight 和 bias 两个可学习参数,不包含子 module
  • 输入输出的形状,如 nn.linear 的输入形状是 (N,input_features),输出是 (N, output_features),其中 N 是 batch_size

注:这些自定义的 layer 对输入性状都有一定的假设:输入的不是一个数据,而是一个 batch。如果想要输入一个数据,必须调用 unsqueeze(0) 函数将数据伪装成 batch_size=1 的batch

0602-nn.Module的更多相关文章

  1. pytroch nn.Module源码解析(1)

    今天在写一个分类网络时,要使用nn.Sequential中的一个模块,因为nn.Sequential中模块都没有名字,我一时竟无从下笔.于是决定写这篇博客梳理pytorch的nn.Module类,看完 ...

  2. pytorch 中的重要模块化接口nn.Module

    torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络 nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法 对于自己 ...

  3. 『PyTorch』第十五弹_torch.nn.Module的属性设置&查询

    一.背景知识 python中两个属相相关方法 result = obj.name 会调用builtin函数getattr(obj,'name')查找对应属性,如果没有name属性则调用obj.__ge ...

  4. 『PyTorch』第十四弹_torch.nn.Module类属性

    nn.Module基类的构造函数: def __init__(self): self._parameters = OrderedDict() self._modules = OrderedDict() ...

  5. 『PyTorch x TensorFlow』第八弹_基本nn.Module层函数

    『TensorFlow』网络操作API_上 『TensorFlow』网络操作API_中 『TensorFlow』网络操作API_下 之前也说过,tf 和 t 的层本质区别就是 tf 的是层函数,调用即 ...

  6. 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())

    在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...

  7. 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)

    我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...

  8. [PyTorch 学习笔记] 3.1 模型创建步骤与 nn.Module

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson3/module_containers.py 这篇文章来看下 ...

  9. pytorch(11)模型创建步骤与nn.Module

    模型创建与nn.Module 网络模型创建步骤 nn.Module graph LR 模型 --> 模型创建 模型创建 --> 构建网络层 构建网络层 --> id[卷积层,池化层, ...

  10. 深度学习--魔法类nn.Module

    深度学习--魔法类nn.Module 作用 pytorch 封装了一些基本的网络类,可以直接调用 好处: 可以直接调用现有的类 容器机制:self.net = nn.Sequential() 参数返回 ...

随机推荐

  1. 【SQL】 牛客网SQL训练Part2 中等难度

    查找当前薪水详情以及部门编号dept_no 查找 1.各个部门当前领导的薪水详情以及其对应部门编号dept_no, 2.输出结果以salaries.emp_no升序排序, 3.并且请注意输出结果里面d ...

  2. B站上教虚幻引擎做游戏的博主 —— 谌嘉诚

    个人主页地址: https://space.bilibili.com/31898841/ 课程地址: https://www.bilibili.com/video/BV164411Y732/

  3. 给大家降降火 —— AI养殖是否夸大功效 —— 深大学生用AI养乌骨鸡增产6万只

    看到一个新闻: 地址: https://export.shobserver.com/baijiahao/html/705726.html 这个新闻里面说的就是这个腾讯的对口培养的大学生搞了一个AI养殖 ...

  4. 强化学习中经典算法 —— reinforce算法 —— (进一步理解, 理论推导出的计算模型和实际应用中的计算模型的区别)

    在奖励折扣率为1的情况下,既没有折扣的情况下,reinforce算法理论上可以写为: 但是在有折扣的情况下,reinforce算法理论上可以写为: 以上均为理论模型. ================ ...

  5. 区块链共识机制 —— PoW共识的Python实现

    原始实现(python2 版本) https://github.com/santisiri/proof-of-work 依据python3特性改进后: #!/usr/bin/env python # ...

  6. SpringBoot Session共享,配置不生效问题排查 → 你竟然在代码里下毒!

    开心一刻 快 8 点了,街边卖油条的还没来,我只能给他打电话 大哥在电话中说到:劳资卖了这么多年油条,从来都是自由自在,自从特么认识了你,居然让我有了上班的感觉! Session 共享 SpringB ...

  7. Singleton bean creation not allowed while singletons of this factory are in destruction

    1.背景 一直都是正常运行的程序,检查日志发现有一条报错如下: org.springframework.beans.factory.BeanCreationNotAllowedException: E ...

  8. 3.2.0 终极预告!云原生支持新增 Spark on k8S 支持

    视频贡献者 | 王维饶 视频制作者 | 聂同学 编辑整理 | Debra Chen Apache DolphinScheduler 3.2.0 版本将发布,为了让大家提前了解到此版本更新的主要内容,我 ...

  9. element-UI tree树形控件 修改小三角图标

    .el-tree /deep/ .el-tree-node__expand-icon.expanded{ -webkit-transform: rotate(0deg); transform: rot ...

  10. MySQL 5.7 DDL 与 GH-OST 对比分析

    作者:来自 vivo 互联网存储研发团队- Xia Qianyong 本文首先介绍MySQL 5.7 DDL以及GH-OST的原理,然后从效率.空间占用.锁阻塞.binlog日志产生量.主备延时等方面 ...