一、背景知识

python中两个属相相关方法

result = obj.name 会调用builtin函数getattr(obj,'name')查找对应属性,如果没有name属性则调用obj.__getattr__('name')方法,再无则报错

obj.name = value 会调用builtin函数setattr(obj,'name',value)设置对应属性,如果设置了__setattr__('name',value)方法则优先调用此方法,而非直接将值存入__dict__并新建属性

二、nn.Module的__setattr__()方法逻辑

nn.Module中实现了__setattr__()方法,当再class的初始化__init__()中执行module.name=value时,会在其中判断value是否属于Parameters或者nn.Module对象,是则将之存储进入__dict__._parameters和__dict__._modules两个字典中;如果是其他对象诸如Variable、List、dict等等,则调用默认操作,将值直接存入__dict__中。

示例

nn.Module的新建Parameter属性,在._parameters中可以查询到,在.__dict__中没有,属于.__dict__._parameters中

import torch as t
import torch.nn as nn module = nn.Module()
module.param = nn.Parameter(t.ones(2,2)) print(module._parameters) """
OrderedDict([('param', Parameter containing:
1 1
1 1
[torch.FloatTensor of size 2x2])])
""" print(module.__dict__)
"""
{'_backend': <torch.nn.backends.thnn.THNNFunctionBackend at 0x7f5dbcf8c160>,
'_backward_hooks': OrderedDict(),
'_buffers': OrderedDict(),
'_forward_hooks': OrderedDict(),
'_forward_pre_hooks': OrderedDict(),
'_modules': OrderedDict(),
'_parameters': OrderedDict([('param', Parameter containing:
1 1
1 1
[torch.FloatTensor of size 2x2])]),
'training': True}
"""

以通常List的格式传入的子Module直接从属于属于.__dict__,并未被_modules识别

submodule1 = nn.Linear(2,2)
submodule2 = nn.Linear(2,2)
module_list = [submodule1,submodule2]
module.submodules = module_list print('_modules:',module_list)
# _modules: [Linear (2 -> 2), Linear (2 -> 2)]
print('__dict__[submodules]:',module.__dict__.get('submodules'))
# __dict__[submodules]: [Linear (2 -> 2), Linear (2 -> 2)]
print('__dict__[submodules]:',module.__dict__['submodules'])
# __dict__[submodules]: [Linear (2 -> 2), Linear (2 -> 2)]

以ModuleList格式传入的子Module可被._modules识别,而不直接从属于.__dict__

module_list = nn.ModuleList(module_list)
module.submodules = module_list print(isinstance(module_list,nn.Module))
# True print(module._modules)
"""
OrderedDict([('submodules', ModuleList (
(0): Linear (2 -> 2)
(1): Linear (2 -> 2)
))])
"""
print(module.__dict__.get('submodules'))
# None
print(module.__dict__['submodules'])
"""
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-19-d4344afabcbf> in <module>()
----> 1 print(module.__dict__['submodules']) KeyError: 'submodules'
"""

三、属性查询函数__getattr__相关特性

nn.Module的.__getattr__()方法会对__dict__._module、__dict__._parameters和__dict__._buffers这三个字典中的key进行查询。当nn.Module进行属性查询时,会先在__dict__进行查询(仅查询本级),查询不到对应属性值时,就会调用.__getattr__()方法,再无结果就报错。

示例

对于__dict__中的属性.training,可以看到.__getattr__('training')查询时就没有结果,

print(module.__dict__.get('submodules'))
# None getattr(module,'training')
# True module.training
# True module.__getattr__('training')
"""
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
……
AttributeError: 'Module' object has no attribute 'training'
"""

另外,我们可以看到.__getattr__可以查询到的结果如下,都是nn.Module自建的属性,

module.__getattr__
"""
<bound method Module.__getattr__ of Module (
(submodules): ModuleList (
(0): Linear (2 -> 2)
(1): Linear (2 -> 2)
)
)>
"""

对于普通的新建属性,其实和nn.Module自建的没什么不同,不同查询方式输出相似,

module.attr1 = 2
getattr(module,'attr1')
# 2 module.__getattr__('attr1')
"""
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
……
AttributeError: 'Module' object has no attribute 'attr1'
"""

对于nn.Module的特殊属性,可以看到,getattr和.__getattr__均可查到,这也是由于getattr一次查找无果后,调用.__getattr__的结果,

getattr(module,'param')
"""
Parameter containing:
1 1
1 1
[torch.FloatTensor of size 2x2]
""" module.__getattr__('param')
"""
Parameter containing:
1 1
1 1
[torch.FloatTensor of size 2x2]
"""

『PyTorch』第十五弹_torch.nn.Module的属性设置&查询的更多相关文章

  1. 『PyTorch』第十四弹_torch.nn.Module类属性

    nn.Module基类的构造函数: def __init__(self): self._parameters = OrderedDict() self._modules = OrderedDict() ...

  2. 『PyTorch』第十二弹_nn.Module和nn.functional

    大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...

  3. 『PyTorch』第十六弹_hook技术

    由于pytorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用钩子函数. 钩子函数包括Variable的钩子和nn.Module钩子,用法相似. 一.register_hook impo ...

  4. 『PyTorch』第十弹_循环神经网络

    RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...

  5. 『MXNet』第十二弹_再谈新建计算节点

    上一节我们已经谈到了计算节点,但是即使是官方文档介绍里面相关内容也过于简略,我们使用Faster-RCNN代码中的新建节点为例,重新介绍一下新建节点的调用栈. 1.调用新建节点 参数分为三部分,op_ ...

  6. 『PyTorch』第九弹_前馈网络简化写法

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下 在前面的例子中,基本上都是将每一层的输出直接作为下一层的 ...

  7. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  8. 『PyTorch』第三弹重置_Variable对象

    『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...

  9. 『PyTorch』第二弹重置_Tensor对象

    『PyTorch』第二弹_张量 Tensor基础操作 简单的初始化 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x -2. ...

随机推荐

  1. SpringMVC Maven项目 java.lang.ClassNotFoundException: org.springframework.web.servlet.DispatcherServle

    今天在搭建SpringMVC开发框架时,遇到了一个问题,尽管在maven的POM.xml文件中配置了项目所依赖的jar包,但在启动项目是已然报错如下: 信息: Starting Servlet Eng ...

  2. 认识拨号计划-dialplan

    拨号计划是 FreeSWITCH 中至关重要的一部分.它的主要作用就是对电话进行路由(从这一点上来说,相当于一个路由表).说的简明一点,就是当一个用户拨号时,对用户所拨的号码进行分析,进而决定下一步该 ...

  3. js 获取时区

    js的时区函数: 设datename为创建的一个Date对象 ====================datename.getTimezoneOffset()--取得当地时间和GMT时间(格林威治时间 ...

  4. 家庭记账本之微信小程序(七)

    最后成果 在经过对微信小程序的简单学习后,对于微信小程序也稍有理解,在浏览学习过别人的东西后自己也制作了一个,觉得就是有点low,在今后的学习中会继续完善这个微信小程序 //index.js //获取 ...

  5. IP-v4&IP-v6

    IPv6与IPv4区别: 1:IPv6的地址空间更大.IPv4中规定IP地址长度为32,即有2^32-1个地址: 而IPv6中IP地址的长度为128,即有2^128-1个地址. 2.IPv6的路由表更 ...

  6. Nginx技术研究系列7-Azure环境中Nginx高可用性和部署架构设计

    前几篇文章介绍了Nginx的应用.动态路由.配置.在实际生产环境部署时,我们需要同时考虑Nginx的高可用性和部署架构. Nginx自身不支持集群以保证自身的高可用性,商业版本的Nginx+推荐: T ...

  7. PHP XAMPP windows环境安装扩展redis 致命错误: Class 'Redis' not found解决方法

    PHP XAMPP windows环境安装扩展redis 致命错误: Class 'Redis' not found解决方法 1.电脑需要先安装redis服务端环境,并在安装目录下打开客户端redis ...

  8. Python基础(三)文件操作

    [对文件进行循环操作] fw = open('nhy','w') for line in fw: print('line:',line)   #直接循环文件对象,每次循环的时候就是取每一行的数据 fw ...

  9. vue-i18n安装配置,运行

    需求:根据浏览器语言自动切换语言 1.安装vue-i18n, yarn安装 $ yarn add vue-i18n npm安装 $ npm install vue-i18n 2.导入语言包 src下创 ...

  10. [C++ Primer Plus] 第5章、循环和关系表达式(二)课后习题

    1.编写一个要求用户输入两个整数的程序,将程序将计算并输出这两个整数之间(包括这两个整数)所有的整数的和.这里假设先输入较小的整数,例如如果用户输入的是2和9,则程序将指出2-9之间所有整数的和为44 ...