nn.Module 函数详解

nn.Module是所有网络模型结构的基类,无论是pytorch自带的模型,还是要自定义模型,都需要继承这个类。这个模块包含了很多子模块,如下所示,_parameters存放的是模型的参数,_buffers也存放的是模型的参数,但是是那些不需要更新的参数。带hook的都是钩子函数,详见钩子函数部分。

  1. self._parameters = OrderedDict()
  2. self._buffers = OrderedDict()
  3. self._non_persistent_buffers_set = set()
  4. self._backward_hooks = OrderedDict()
  5. self._is_full_backward_hook = None
  6. self._forward_hooks = OrderedDict()
  7. self._forward_pre_hooks = OrderedDict()
  8. self._state_dict_hooks = OrderedDict()
  9. self._load_state_dict_pre_hooks = OrderedDict()
  10. self._modules = OrderedDict()

此外,每一个模块还内置了一些常用的方法来帮助访问和操作网络。

  1. load_state_dict() #加载模型权重参数
  2. parameters() #读取所有参数
  3. named_parameters() #读取参数名称和参数
  4. buffers() #读取self.named_buffers中的参数
  5. named_buffers() #读取self.named_buffers中的参数名称和参数
  6. children() #读取模型中,所有的子模型
  7. named_children() #读取子模型名称和子模型
  8. requires_grad_() #设置模型是否开启梯度反向传播

Parameter类

Parameter是Tensor子类,所以继承了Tensor类的属性。例如data和grad属性,可以根据data来访问参数数值,用grad来访问参数梯度。

  1. weight_0 = nn.Parameters(torch.randn(10,10))
  2. print(weight_0.data)
  3. print(weight_0.grad)

定义变量的时候,nn.Parameter会被自动加入到参数列表中去

  1. class MyModel(nn.Module):
  2. def __init__(self):
  3. super(MyModel,self).__init__()
  4. self.weight1 = nn.Parameter(torch.randn(10,10))
  5. self.weight2 = torch.randn(10,10)
  6. def forward(self,x):
  7. pass
  8. model = MyModel()
  9. for name,param in model.named_parameters():
  10. print(name)
  11. output: weight1

ParameterList

接定义成Parameter类外,还可以使用ParameterList和ParameterDict分别定义参数的列表和字典。ParameterList接收一个Parameter实例的列表作为输入然后得到一个参数列表,使用的时候可以用索引来访问某个参数,另外也可以使用append和extend在列表后面新增参数。

  1. params = nn.ParameterList(
  2. [nn.Parameter(torch.randn(10,10)) for i in range(5)]
  3. )
  4. params.append(nn.Parameter(torch.randn(3,3)))

ParameterDict

可以像添加字典数据那样添加参数

  1. params = nn.ParameterDict({
  2. 'linear1':nn.Parameter(torch.randn(10,5)),
  3. 'linear2':nn.Parameter(torch.randn(5,2))
  4. })

模型构建

使用Sequential构建模型

  1. # 写法一
  2. net = nn.Sequential(
  3. nn.Linear(num_inputs, 1)
  4. # 此处还可以传入其他层
  5. )
  6. # 写法二
  7. net = nn.Sequential()
  8. net.add_module('linear', nn.Linear(num_inputs, 1))
  9. # net.add_module ......
  10. # 写法三
  11. from collections import OrderedDict
  12. net = nn.Sequential(OrderedDict([
  13. ('linear', nn.Linear(num_inputs, 1))
  14. # ......
  15. ]))
  16. print(net)

自定义模型

  1. 无参数模型

下面是一个展开操作,比如将2维图像展开成一维

  1. class Flatten(nn.Module):
  2. def __init__(self):
  3. super(Flatten,self).__init__()
  4. def forward(self,input):
  5. return input.view(input.size(0),-1)
  1. 有参数模型

自定义一个Linear层

  1. class MLinear(nn.Module):
  2. def __init__(self,input,output):
  3. super(MyLinear,self).__init__()
  4. self.w = nn.Parameter(torch.randn(input,output))
  5. self.b = nn.Parameter(torch.randn(output))
  6. def foward(self,x):
  7. x = self.w @ x + self.b
  8. return x
  1. 组合模型
  1. class Model(nn.Module):
  2. def __init__(self):
  3. super(Model,self).__init__()
  4. self.l1 = nn.Linear(10,20)
  5. self.l2 = nn.Linear(20,5)
  6. def forward(self,x):
  7. x = self.l1(x)
  8. x = self.l2(x)
  9. return x

ModuleList & ModuleDict

ModuleList 和 ModuleDict都是继承与nn.Module, 与Seuqential不同的是,ModuleList 和 ModuleDict没有自带forward方法,所以只能作为一个模块和其他自定义方法进行组合。下面是使用示例:

  1. class MyModuleList(nn.Module):
  2. def __init__(self):
  3. super(MyModuleList, self).__init__()
  4. self.linears = nn.ModuleList(
  5. [nn.Linear(10, 10) for i in range(3)]
  6. )
  7. def forward(self, x):
  8. for linear in self.linears:
  9. x = linear(x)
  10. return x
  11. class MyModuleDict(nn.Module):
  12. def __init__(self):
  13. super(MyModuleDict, self).__init__()
  14. self.linears = nn.ModuleDict({
  15. "linear1":nn.Linear(10,10),
  16. "linear2":nn.Linear(10,10)
  17. })
  18. def forward(self, x):
  19. x = self.linears["linear1"](x)
  20. x = self.linears["linear2"](x)
  21. return x

Pytorch系列:(三)模型构建的更多相关文章

  1. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  2. pytorch入门2.1构建回归模型初体验(模型构建)

    pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...

  3. pytorch入门2.2构建回归模型初体验(开始训练)

    pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...

  4. pytorch入门2.0构建回归模型初体验(数据生成)

    pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...

  5. 前端构建大法 Gulp 系列 (三):gulp的4个API 让你成为gulp专家

    系列目录 前端构建大法 Gulp 系列 (一):为什么需要前端构建 前端构建大法 Gulp 系列 (二):为什么选择gulp 前端构建大法 Gulp 系列 (三):gulp的4个API 让你成为gul ...

  6. [深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题

    [深度学习] Pytorch(三)-- 多/单GPU.CPU,训练保存.加载预测模型问题 上一篇实践学习中,遇到了在多/单个GPU.GPU与CPU的不同环境下训练保存.加载使用使用模型的问题,如果保存 ...

  7. 【小白学PyTorch】6 模型的构建访问遍历存储(附代码)

    文章转载自微信公众号:机器学习炼丹术.欢迎大家关注,这是我的学习分享公众号,100+原创干货. 文章目录: 目录 1 模型构建函数 1.1 add_module 1.2 ModuleList 1.3 ...

  8. Web 开发人员和设计师必读文章推荐【系列三十】

    <Web 前端开发精华文章推荐>2014年第9期(总第30期)和大家见面了.梦想天空博客关注 前端开发 技术,分享各类能够提升网站用户体验的优秀 jQuery 插件,展示前沿的 HTML5 ...

  9. CSS3之简易的3D模型构建[原创开源]

    CSS3之简易的3D模型构建[开源分享] 先上一张图(成果图):这个是使用 3D建模空间[源码之一] 制作出来的模型之一 当然这是一部分模型特写, 之前还制作过枪的3D模型等等. 感兴趣的朋友可以自己 ...

随机推荐

  1. PHP中间件

    定义 首先什么是php的中间件? 根据zend-framework中的定义: 所谓中间件是指提供在请求和响应之间的,能够截获请求,并在其基础上进行逻辑处理,与此同时能够完成请求的响应或传递到下一个中间 ...

  2. Mysql训练:两个表中使用 Select 语句会导致产生 笛卡尔乘积 ,两个表的前后顺序决定查询之后的表顺序

    力扣:超过经理收入的员工 Employee 表包含所有员工,他们的经理也属于员工.每个员工都有一个 Id,此外还有一列对应员工的经理的 Id. +----+-------+--------+----- ...

  3. Python爬虫系统化学习(3)

    一般来说当我们爬取网页的整个源代码后,是需要对网页进行解析的. 正常的解析方法有三种 ①:正则匹配解析 ②:BeatuifulSoup解析 ③:lxml解析 正则匹配解析: 在之前的学习中,我们学习过 ...

  4. Linux-两种磁盘分区方式

    Linux文件设备 要理解Linux,首先要理解Linux文件结构 在Linux操作系统中,几乎所有的设备都位于/dev目录中 名称 作用 位置 SATA接口 电脑硬盘接口 /dev/sd[a-p] ...

  5. Java基本概念:类

    一.描述 类是一种抽象的数据类型,它是对某一类事物整体的描述或定义,但是并不能代表某一个具体的事物. 例如,我们生活中所说的词语:动物.植物.手机.电脑等等.这些也都是抽象的概念,而不是指的某一个 具 ...

  6. 第31天学习打卡(File类。字符流读写文件)

    File类 概念 文件,文件夹,一个file对象代表磁盘上的某个文件或者文件夹 构造方法  File(String pathname) File(String parent,String child) ...

  7. docker封装Spring Cloud(单机版)

    一.概述 微服务统一在一个git项目里面,项目的大致结构如下: ./ ├── auth-server │ ├── pom.xml │ └── src ├── common │ ├── pom.xml ...

  8. jdk 集合大家族之Map

    jdk 集合大家族之Map 前言: 之前章节复习了Collection接口相关,此次我们来一起回顾一下Map相关 .本文基于jdk1.8. 1. HashMap 1.1 概述 HashMap相对于Li ...

  9. FreeBSD 日常应用

    freebsd日常应用 办公libreoffice或者apache openoffice 设计 图像编辑:gimp 矢量图设计:lnkscape 视频剪辑:openshot 视频特效:natron 编 ...

  10. C# 应用 - 多线程 5) 死锁

    两个线程中的每一个线程都尝试锁定另外一个线程已锁定的资源时,就会发生死锁. 两个线程都不能继续执行. 托管线程处理类的许多方法都提供了超时设定,有助于检测死锁. 例如,下面的代码尝试在 lockObj ...