下文都将torch.nn简写成nn

  • Module: 就是我们常用的torch.nn.Module类,你定义的所有网络结构都必须继承这个类。
  • Buffer: buffer和parameter相对,就是指那些不需要参与反向传播的参数

    示例如下:
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.my_tensor = torch.randn(1) # 参数直接作为模型类成员变量
self.register_buffer('my_buffer', torch.randn(1)) # 参数注册为 buffer
self.my_param = nn.Parameter(torch.randn(1))
def forward(self, x):
return x model = MyModel()
print(model.state_dict())
>>>OrderedDict([('my_param', tensor([1.2357])), ('my_buffer', tensor([-0.9982]))])
  • Parameter: 是nn.parameter.Paramter,也就是组成Module的参数。例如一个nn.Linear通常由weightbias参数组成。它的特点是默认requires_grad=True,也就是说训练过程中需要反向传播的,就需要使用这个
import torch.nn as nn
fc = nn.Linear(2,2) # 读取参数的方式一
fc._parameters
>>> OrderedDict([('weight', Parameter containing:
tensor([[0.4142, 0.0424],
[0.3940, 0.0796]], requires_grad=True)),
('bias', Parameter containing:
tensor([-0.2885, 0.5825], requires_grad=True))]) # 读取参数的方式二(推荐这种)
for n, p in fc.named_parameters():
print(n,p)
>>>weight Parameter containing:
tensor([[0.4142, 0.0424],
[0.3940, 0.0796]], requires_grad=True)
bias Parameter containing:
tensor([-0.2885, 0.5825], requires_grad=True) # 读取参数的方式三
for p in fc.parameters():
print(p)
>>>Parameter containing:
tensor([[0.4142, 0.0424],
[0.3940, 0.0796]], requires_grad=True)
Parameter containing:
tensor([-0.2885, 0.5825], requires_grad=True)

通过上面的例子可以看到,nn.parameter.Paramterrequires_grad属性值默认为True。另外上面例子给出了三种读取parameter的方法,推荐使用后面两种(这两种的区别可参阅Pytorch: parameters(),children(),modules(),named_*区别),因为是以迭代生成器的方式来读取,第一种方式是一股脑的把参数全丢给你,要是模型很大,估计你的电脑会吃不消。

另外需要介绍的是_parametersnn.Module__init__()函数中就定义了的一个OrderDict类,这个可以通过看下面给出的部分源码看到,可以看到还初始化了很多其他东西,其实原理都大同小异,你理解了这个之后,其他的也是同样的道理。

class Module(object):
...
def __init__(self):
self._backend = thnn_backend
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
self.training = True

每当我们给一个成员变量定义一个nn.parameter.Paramter的时候,都会自动注册到_parameters,具体的步骤如下:

import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
# 下面两种定义方式均可
self.p1 = nn.paramter.Paramter(torch.tensor(1.0))
print(self._parameters)
self.p2 = nn.Paramter(torch.tensor(2.0))
print(self._parameters)
  • 首先运行super(MyModel, self).__init__(),这样MyModel就初始化了_paramters等一系列的OrderDict,此时所有变量还都是空的。
  • self.p1 = nn.paramter.Paramter(torch.tensor(1.0)): 这行代码会触发nn.Module预定义好的__setattr__函数,该函数部分源码如下,:
def __setattr__(self, name, value):
...
params = self.__dict__.get('_parameters')
if isinstance(value, Parameter):
if params is None:
raise AttributeError(
"cannot assign parameters before Module.__init__() call")
remove_from(self.__dict__, self._buffers, self._modules)
self.register_parameter(name, value)
...

__setattr__函数作用简单理解就是判断你定义的参数是否正确,如果正确就继续调用register_parameter函数进行注册,这个函数简单概括就是做了下面这件事

def register_parameter(self,name,param):
...
self._parameters[name]=param

下面我们实例化这个模型看结果怎样

model = MyModel()
>>>OrderedDict([('p1', Parameter containing:
tensor(1., requires_grad=True))])
OrderedDict([('p1', Parameter containing:
tensor(1., requires_grad=True)), ('p2', Parameter containing:
tensor(2., requires_grad=True))])

结果和上面分析的一致。

MARSGGBO♥原创

如有意合作,欢迎私戳

邮箱:marsggbo@foxmail.com



2019-12-20 21:11:02

Pytorch中Module,Parameter和Buffer的区别的更多相关文章

  1. node.js中module.export与export的区别。

    对module.exports和exports的一些理解 可能是有史以来最简单通俗易懂的有关Module.exports和exports区别的文章了. exports = module.exports ...

  2. nodejs 中 module.exports 和 exports 的区别

    1. module应该是require方法中,上下文中的对象 2. exports对象应该是上下文中引用module.exports的新对象 3. exports.a = xxx 会将修改更新到mod ...

  3. module中module.exports与exports的区别(转)

    转https://cnodejs.org/topic/55ccace5b25bd72150842c0a require 用来加载代码,而 exports 和 module.exports 则用来导出代 ...

  4. pytorch 中的重要模块化接口nn.Module

    torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络 nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法 对于自己 ...

  5. [转载]Pytorch中nn.Linear module的理解

    [转载]Pytorch中nn.Linear module的理解 本文转载并援引全文纯粹是为了构建和分类自己的知识,方便自己未来的查找,没啥其他意思. 这个模块要实现的公式是:y=xAT+*b 来源:h ...

  6. 前端后台以及游戏中使用Google Protocol Buffer详解

    前端后台以及游戏中使用Google Protocol Buffer详解 0.什么是protoBuf protoBuf是一种灵活高效的独立于语言平台的结构化数据表示方法,与XML相比,protoBuf更 ...

  7. 转pytorch中训练深度神经网络模型的关键知识点

    版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_42279044/articl ...

  8. 详解Pytorch中的网络构造,模型save和load,.pth权重文件解析

    转载:https://zhuanlan.zhihu.com/p/53927068 https://blog.csdn.net/wangdongwei0/article/details/88956527 ...

  9. PyTorch 中 weight decay 的设置

    先介绍一下 Caffe 和 TensorFlow 中 weight decay 的设置: 在 Caffe 中, SolverParameter.weight_decay 可以作用于所有的可训练参数, ...

随机推荐

  1. 检索式chatbot:

    小夕从7月份开始收到第一场面试邀请,到9月初基本结束了校招(面够了面够了T_T),深深的意识到今年的对话系统/chatbot方向是真的超级火呀.从微软主打情感计算的小冰,到百度主打智能家庭(与车联网? ...

  2. 【CodeForces】CodeForcesRound576 Div1 解题报告

    点此进入比赛 \(A\):MP3(点此看题面) 大致题意: 让你选择一个值域区间\([L,R]\),使得序列中满足\(L\le a_i\le R\)的数的种类数不超过\(2^{\lfloor\frac ...

  3. Linux上error while loading shared libraries问题解决方法

    在Linux环境执行程序时经常会遇到提示程序依赖动态库.so文件不存在的情况,出现报错"error while loading shared libraries: XXXX.so.XX: c ...

  4. [CEOI2019]Cubeword(暴力)

    没错,标签就是暴力. 首先发现棱上的所有词长度都相等,枚举长度 \(len\). 然后发现这些词中只有第一个字符和最后一个字符比较重要(只有这两个位置会与别的串衔接,中间的是啥无所谓). 令 \(cn ...

  5. Python爬取拉勾网招聘信息并写入Excel

    这个是我想爬取的链接:http://www.lagou.com/zhaopin/Python/?labelWords=label 页面显示如下: 在Chrome浏览器中审查元素,找到对应的链接: 然后 ...

  6. jvm 性能调优工具之 jmap

    概述 命令jmap是一个多功能的命令.它可以生成 java 程序的 dump 文件, 也可以查看堆内对象示例的统计信息.查看 ClassLoader 的信息以及 finalizer 队列. jmap ...

  7. 小米笔记本pro 黑苹果系统无法进入系统,频繁重启故障解决记录

    问题1:频繁重启,然后clover丢失 表现情况:开机没有选择macos 或windos的界面 解决办法:进入windows使用工具easyefi,直接添加一个clover start boot,选择 ...

  8. mongodb 导出制定的查询结果

    1.mongo查询语句: db.quarkContext.find({"submitTime":{"$gt":ISODate("2019-07-13T ...

  9. 【shell脚本】打印九九乘法表

    打印九九乘法表 一.seq介绍 seq命令用于以指定增量从首数开始打印数字到尾数,即产生从某个数到另外一个数之间的所有整数,并且可以对整数的格式.宽度.分割符号进行控制 语法: [1] seq [选项 ...

  10. mysql8 安装

    准备工作: 首先安装这些依赖 yum install -y flex yum install gcc gcc-c++ cmake  ncurses ncurses-devel bison libaio ...