下文都将torch.nn简写成nn

  • Module: 就是我们常用的torch.nn.Module类,你定义的所有网络结构都必须继承这个类。
  • Buffer: buffer和parameter相对,就是指那些不需要参与反向传播的参数

    示例如下:
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.my_tensor = torch.randn(1) # 参数直接作为模型类成员变量
self.register_buffer('my_buffer', torch.randn(1)) # 参数注册为 buffer
self.my_param = nn.Parameter(torch.randn(1))
def forward(self, x):
return x model = MyModel()
print(model.state_dict())
>>>OrderedDict([('my_param', tensor([1.2357])), ('my_buffer', tensor([-0.9982]))])
  • Parameter: 是nn.parameter.Paramter,也就是组成Module的参数。例如一个nn.Linear通常由weightbias参数组成。它的特点是默认requires_grad=True,也就是说训练过程中需要反向传播的,就需要使用这个
import torch.nn as nn
fc = nn.Linear(2,2) # 读取参数的方式一
fc._parameters
>>> OrderedDict([('weight', Parameter containing:
tensor([[0.4142, 0.0424],
[0.3940, 0.0796]], requires_grad=True)),
('bias', Parameter containing:
tensor([-0.2885, 0.5825], requires_grad=True))]) # 读取参数的方式二(推荐这种)
for n, p in fc.named_parameters():
print(n,p)
>>>weight Parameter containing:
tensor([[0.4142, 0.0424],
[0.3940, 0.0796]], requires_grad=True)
bias Parameter containing:
tensor([-0.2885, 0.5825], requires_grad=True) # 读取参数的方式三
for p in fc.parameters():
print(p)
>>>Parameter containing:
tensor([[0.4142, 0.0424],
[0.3940, 0.0796]], requires_grad=True)
Parameter containing:
tensor([-0.2885, 0.5825], requires_grad=True)

通过上面的例子可以看到,nn.parameter.Paramterrequires_grad属性值默认为True。另外上面例子给出了三种读取parameter的方法,推荐使用后面两种(这两种的区别可参阅Pytorch: parameters(),children(),modules(),named_*区别),因为是以迭代生成器的方式来读取,第一种方式是一股脑的把参数全丢给你,要是模型很大,估计你的电脑会吃不消。

另外需要介绍的是_parametersnn.Module__init__()函数中就定义了的一个OrderDict类,这个可以通过看下面给出的部分源码看到,可以看到还初始化了很多其他东西,其实原理都大同小异,你理解了这个之后,其他的也是同样的道理。

class Module(object):
...
def __init__(self):
self._backend = thnn_backend
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
self.training = True

每当我们给一个成员变量定义一个nn.parameter.Paramter的时候,都会自动注册到_parameters,具体的步骤如下:

import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
# 下面两种定义方式均可
self.p1 = nn.paramter.Paramter(torch.tensor(1.0))
print(self._parameters)
self.p2 = nn.Paramter(torch.tensor(2.0))
print(self._parameters)
  • 首先运行super(MyModel, self).__init__(),这样MyModel就初始化了_paramters等一系列的OrderDict,此时所有变量还都是空的。
  • self.p1 = nn.paramter.Paramter(torch.tensor(1.0)): 这行代码会触发nn.Module预定义好的__setattr__函数,该函数部分源码如下,:
def __setattr__(self, name, value):
...
params = self.__dict__.get('_parameters')
if isinstance(value, Parameter):
if params is None:
raise AttributeError(
"cannot assign parameters before Module.__init__() call")
remove_from(self.__dict__, self._buffers, self._modules)
self.register_parameter(name, value)
...

__setattr__函数作用简单理解就是判断你定义的参数是否正确,如果正确就继续调用register_parameter函数进行注册,这个函数简单概括就是做了下面这件事

def register_parameter(self,name,param):
...
self._parameters[name]=param

下面我们实例化这个模型看结果怎样

model = MyModel()
>>>OrderedDict([('p1', Parameter containing:
tensor(1., requires_grad=True))])
OrderedDict([('p1', Parameter containing:
tensor(1., requires_grad=True)), ('p2', Parameter containing:
tensor(2., requires_grad=True))])

结果和上面分析的一致。

MARSGGBO♥原创

如有意合作,欢迎私戳

邮箱:marsggbo@foxmail.com



2019-12-20 21:11:02

Pytorch中Module,Parameter和Buffer的区别的更多相关文章

  1. node.js中module.export与export的区别。

    对module.exports和exports的一些理解 可能是有史以来最简单通俗易懂的有关Module.exports和exports区别的文章了. exports = module.exports ...

  2. nodejs 中 module.exports 和 exports 的区别

    1. module应该是require方法中,上下文中的对象 2. exports对象应该是上下文中引用module.exports的新对象 3. exports.a = xxx 会将修改更新到mod ...

  3. module中module.exports与exports的区别(转)

    转https://cnodejs.org/topic/55ccace5b25bd72150842c0a require 用来加载代码,而 exports 和 module.exports 则用来导出代 ...

  4. pytorch 中的重要模块化接口nn.Module

    torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络 nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法 对于自己 ...

  5. [转载]Pytorch中nn.Linear module的理解

    [转载]Pytorch中nn.Linear module的理解 本文转载并援引全文纯粹是为了构建和分类自己的知识,方便自己未来的查找,没啥其他意思. 这个模块要实现的公式是:y=xAT+*b 来源:h ...

  6. 前端后台以及游戏中使用Google Protocol Buffer详解

    前端后台以及游戏中使用Google Protocol Buffer详解 0.什么是protoBuf protoBuf是一种灵活高效的独立于语言平台的结构化数据表示方法,与XML相比,protoBuf更 ...

  7. 转pytorch中训练深度神经网络模型的关键知识点

    版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_42279044/articl ...

  8. 详解Pytorch中的网络构造,模型save和load,.pth权重文件解析

    转载:https://zhuanlan.zhihu.com/p/53927068 https://blog.csdn.net/wangdongwei0/article/details/88956527 ...

  9. PyTorch 中 weight decay 的设置

    先介绍一下 Caffe 和 TensorFlow 中 weight decay 的设置: 在 Caffe 中, SolverParameter.weight_decay 可以作用于所有的可训练参数, ...

随机推荐

  1. anaconda换源及创建虚拟环境

    0x01:换源,依次输入一下两条命令 conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/fr ...

  2. Ubuntu环境下打开Firefox报错: Firefox is already running, but is not responding.

    在ubuntu下启动firefox可能会报错 Firefox is already running, but is not responding. To open a new window, you ...

  3. 新建全色或者resize(毫无价值,只是做记录)

    import glob import os,sys import shutil import numpy as np import cv2 import matplotlib.pyplot as pl ...

  4. 洛谷 P1381 单词背诵

    洛谷 P1381 单词背诵 洛谷传送门 题目描述 灵梦有n个单词想要背,但她想通过一篇文章中的一段来记住这些单词. 文章由m个单词构成,她想在文章中找出连续的一段,其中包含最多的她想要背的单词(重复的 ...

  5. appium--python启动appium服务

    前戏 前面我们都是在cmd下通过输入appium加端口号来启动服务的,在我们做自动化的时候,我们当然不希望我们手动启动appium服务,而是希望通过脚本自动启动appium服务. 我们可以使用subp ...

  6. QFramework 学习

    github地址: https://github.com/liangxiegame/QFramework 框架官网: http://qf.liangxiegame.com/ 视频教程: http:// ...

  7. 微服务 SpringCloud + docker

    最近看到微服务很火,也是未来的趋势,所以就去学习下 好,接下来我们来认识下spring cloud.一.什么是spring cloud?它的中文官网这样说: 微服务架构集大成者,云计算最佳业务实践. ...

  8. 十一、Spring之事件监听

    Spring之事件监听 ApplicationListener ApplicationListener是Spring事件机制的一部分,与抽象类ApplicationEvent类配合来完成Applica ...

  9. pandas使用大全--数据与处理

    1.首先导入pandas库,一般都会用到numpy库,所以我们先导入备用: import numpy as np import pandas as pd 导入CSV或者xlsx文件: df = pd. ...

  10. 镭神激光雷达对于Autoware的适配

    1. 前言 我们的自动驾驶采用镭神激光雷达,在使用Autoware的时候,需要对镭神激光雷达的底层驱动,进行一些改变以适配Autoware. 2. 修改 (1)首先修改lslidar_c32.laun ...