nn.Module基类的构造函数:

  1. def __init__(self):
  2. self._parameters = OrderedDict()
  3. self._modules = OrderedDict()
  4. self._buffers = OrderedDict()
  5. self._backward_hooks = OrderedDict()
  6. self._forward_hooks = OrderedDict()
  7. self.training = True

其中每个属性的解释如下:

  • _parameters:字典,保存用户直接设置的parameter,self.param1 = nn.Parameter(t.randn(3, 3))会被检测到,在字典中加入一个key为'param',value为对应parameter的item。而self.submodule = nn.Linear(3, 4)中的parameter则不会存于此。
  • _modules:子module,通过self.submodel = nn.Linear(3, 4)指定的子module会保存于此。
  • _buffers:缓存。如batchnorm使用momentum机制,每次前向传播需用到上一次前向传播的结果。
  • _backward_hooks_forward_hooks:钩子技术,用来提取中间变量,类似variable的hook。
  • training:BatchNorm与Dropout层在训练阶段和测试阶段中采取的策略不同,通过判断training值来决定前向传播策略。

上述几个属性中,_parameters_modules_buffers这三个字典中的键值,都可以通过self.key方式获得,效果等价于self._parameters['key'].

定义一个Module,这个Module即包含自己的Parameters有包含子Module及其Parameters,

  1. import torch as t
  2. from torch import nn
  3. from torch.autograd import Variable as V
  4.  
  5. class Net(nn.Module):
  6. def __init__(self):
  7. super(Net, self).__init__()
  8. # 等价与self.register_parameter('param1' ,nn.Parameter(t.randn(3, 3)))
  9. self.param1 = nn.Parameter(t.rand(3, 3))
  10. self.submodel1 = nn.Linear(3, 4)
  11. def forward(self, input):
  12. x = self.param1.mm(input)
  13. x = self.submodel11(x)
  14. return x
  15. net = Net()

一、_modules

# 打印网络对象的话会输出子module结构
print(net)

  1. Net(
  2. (submodel1): Linear(in_features=3, out_features=4)
  3. )

# ._modules输出的也是子module结构,不过数据结构和上面的有所不同
print(net.submodel1)
print(net._modules) # 字典子类

  1. Linear(in_features=3, out_features=4)
  2. OrderedDict([('submodel1', Linear(in_features=3, out_features=4))])

for name, submodel in net.named_modules():
    print(name, submodel)

  1. Net(
  2. (submodel1): Linear(in_features=3, out_features=4)
  3. )
  4. submodel1 Linear(in_features=3, out_features=4)

print(list(net.named_modules())) # named_modules其实是包含了本层的module集合

  1. [('', Net(
  2. (submodel1): Linear(in_features=3, out_features=4)
  3. )), ('submodel1', Linear(in_features=3, out_features=4))]

二、_parameters

# ._parameters存储的也是这个结构
print(net.param1)
print(net._parameters) # 字典子类,仅仅包含直接定义的nn.Parameters参数

  1. Parameter containing:
  2. 0.6135 0.8082 0.4519
  3. 0.9052 0.5929 0.2810
  4. 0.6825 0.4437 0.3874
  5. [torch.FloatTensor of size 3x3]
  6. OrderedDict([('param1', Parameter containing:
  7. 0.6135 0.8082 0.4519
  8. 0.9052 0.5929 0.2810
  9. 0.6825 0.4437 0.3874
  10. [torch.FloatTensor of size 3x3]
  11. )])

for name, param in net.named_parameters():
    print(name, param.size())

  1. param1 torch.Size([3, 3])
  2. submodel1.weight torch.Size([4, 3])
  3. submodel1.bias torch.Size([4])

三、_buffers

  1. bn = nn.BatchNorm1d(2)
  2. input = V(t.rand(3, 2), requires_grad=True)
  3. output = bn(input)
  4. bn._buffers
  1. OrderedDict([('running_mean',
  2. 1.00000e-02 *
  3. 9.1559
  4. 1.9914
  5. [torch.FloatTensor of size 2]), ('running_var',
  6. 0.9003
  7. 0.9019
  8. [torch.FloatTensor of size 2])])

四、training

  1. input = V(t.arange(0, 12).view(3, 4))
  2. model = nn.Dropout()
  3. # 在训练阶段,会有一半左右的数被随机置为0
  4. model(input)
  1. Variable containing:
  2. 0 2 4 0
  3. 8 10 0 0
  4. 0 18 0 22
  5. [torch.FloatTensor of size 3x4]
  1. model.training = False
  2. # 在测试阶段,dropout什么都不做
  3. model(input)
  1. Variable containing:
  2. 0 1 2 3
  3. 4 5 6 7
  4. 8 9 10 11
  5. [torch.FloatTensor of size 3x4]

Module.train()、Module.eval() 方法和 Module.training属性的关系

  1. print(net.training, net.submodel1.training)
  2. net.train() # 将本层及子层的training设定为True
  3. net.eval() # 将本层及子层的training设定为False
  4. net.training = True # 注意,对module的设置仅仅影响本层,子module不受影响
  5. net.training, net.submodel1.training
  1. True True
  2. (True, False)

『PyTorch』第十四弹_torch.nn.Module类属性的更多相关文章

  1. 『PyTorch』第十五弹_torch.nn.Module的属性设置&查询

    一.背景知识 python中两个属相相关方法 result = obj.name 会调用builtin函数getattr(obj,'name')查找对应属性,如果没有name属性则调用obj.__ge ...

  2. 『PyTorch』第十二弹_nn.Module和nn.functional

    大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...

  3. 『PyTorch』第十六弹_hook技术

    由于pytorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用钩子函数. 钩子函数包括Variable的钩子和nn.Module钩子,用法相似. 一.register_hook impo ...

  4. 『PyTorch』第十弹_循环神经网络

    RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...

  5. 『MXNet』第十二弹_再谈新建计算节点

    上一节我们已经谈到了计算节点,但是即使是官方文档介绍里面相关内容也过于简略,我们使用Faster-RCNN代码中的新建节点为例,重新介绍一下新建节点的调用栈. 1.调用新建节点 参数分为三部分,op_ ...

  6. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  7. 『PyTorch』第九弹_前馈网络简化写法

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下 在前面的例子中,基本上都是将每一层的输出直接作为下一层的 ...

  8. 『PyTorch』第三弹重置_Variable对象

    『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...

  9. 『PyTorch』第二弹重置_Tensor对象

    『PyTorch』第二弹_张量 Tensor基础操作 简单的初始化 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x -2. ...

随机推荐

  1. 比特币、莱特币钱包下载和把数据迁移到C盘以外其他盘

    比特币是目前最热门和价格最高的虚拟币,国内外多个平台可以进行交易,有些商家可以用比特币进行支付有些国家可以在ATM取款. Bitcoin-Qt就是最早的比特币客户端,构建了比特币的骨干网络,具有高度的 ...

  2. 2016NOI冬令营day4

    上午:随机算法/近似算法与随机算法的分析方法与应用实例 不懂,完全滑水QAQ :( 下午:计算理论与NP问题 只有讲2-sat和3-sat的时候能听懂,其他的基本都在滑水:( 晚上说是什么中学生学术训 ...

  3. jenkin环境搭建

      Jenkins是一个用Java编写的开源的持续集成(CI)工具,可持续.自动地构建/测试软件项目,监控一些定时执行的任务.具有开源,支持多平台和插件扩展,安装简单,界面化管理等特点. 1.下载并解 ...

  4. 干货:Java并发编程系列之synchronized(一)

    1. 使用方法 synchronized 是 java 中最常用的保证线程安全的方式,synchronized 的作用主要有三方面: 确保线程互斥的访问代码块,同一时刻只有一个方法可以进入到临界区 保 ...

  5. mysql 触发器 trigger用法 four

    实验4 触发器 (1)实验目的 掌握数据库触发器的设计和使用方法 (2)实验内容和要求 定义BEFORE触发器和AFTER触发器.能够理解不同类型触发器的作用和执行原理,验证触发器的有效性. (3)实 ...

  6. P4289 [HAOI2008]移动玩具(bfs)

    P4289 [HAOI2008]移动玩具 双向bfs+状态压缩+记忆化搜索 双向bfs用于对bfs的优化,每次找到可扩展节点少的一边进行一次bfs,找到的第一个互相接触的点即为最短路径 矩阵范围仅4* ...

  7. sublime3 离线安装插件

    直接去:https://packagecontrol.io/installation搜索插件,插件一般会有个git网址(格式化html的插件可以用这个:https://github.com/victo ...

  8. 20145307陈俊达《网络对抗》Exp6 信息搜集与漏洞扫描

    20145307陈俊达<网络对抗>Exp6 信息搜集与漏洞扫描 基础问题回答 哪些组织负责DNS,IP的管理? 全球根服务器均由美国政府授权的ICANN统一管理,负责全球的域名根服务器.D ...

  9. ubuntu14.04禁止触摸板和恢复触摸板

    1.使用xinput list查看与触摸板相关的id,以下是本机的输出,没搞清楚为什么是Mouse!!! jello@jello:~$ xinput list⎡ Virtual core pointe ...

  10. 【第四章】 springboot + swagger

    注:本文参考自 http://www.jianshu.com/p/0465a2b837d2 swagger用于定义API文档. 好处: 前后端分离开发 API文档非常明确 测试的时候不需要再使用URL ...