『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上

『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

在前面的例子中,基本上都是将每一层的输出直接作为下一层的输入,这种网络称为前馈传播网络(feedforward neural network)。对于此类网络如果每次都写复杂的forward函数会有些麻烦,在此就有两种简化方式,ModuleList和Sequential。其中Sequential是一个特殊的module,它包含几个子Module,前向传播时会将输入一层接一层的传递下去。ModuleList也是一个特殊的module,可以包含几个子module,可以像用list一样使用它,但不能直接把输入传给ModuleList。下面举例说明。

一、nn.Sequential()对象

nn.Sequential()对象是类似keras的前馈模型的对象,可以为之添加层实现前馈神经网络。

1、模型建立方式

第一种写法:

nn.Sequential()对象.add_module(层名,层class的实例)

  1. net1 = nn.Sequential()
  2. net1.add_module('conv', nn.Conv2d(3, 3, 3))
  3. net1.add_module('batchnorm', nn.BatchNorm2d(3))
  4. net1.add_module('activation_layer', nn.ReLU())

第二种写法:

nn.Sequential(*多个层class的实例)

  1. net2 = nn.Sequential(
  2. nn.Conv2d(3, 3, 3),
  3. nn.BatchNorm2d(3),
  4. nn.ReLU()
  5. )

第三种写法:

nn.Sequential(OrderedDict([*多个(层名,层class的实例)]))

  1. from collections import OrderedDict
  2. net3= nn.Sequential(OrderedDict([
  3. ('conv', nn.Conv2d(3, 3, 3)),
  4. ('batchnorm', nn.BatchNorm2d(3)),
  5. ('activation_layer', nn.ReLU())
  6. ]))

2、检查以及调用模型

查看模型

print对象即可

  1. print('net1:', net1)
  2. print('net2:', net2)
  3. print('net3:', net3)
  1. net1: Sequential(
  2. (conv): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
  3. (batchnorm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)
  4. (activation_layer): ReLU()
  5. )
  6. net2: Sequential(
  7. (0): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
  8. (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)
  9. (2): ReLU()
  10. )
  11. net3: Sequential(
  12. (conv): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
  13. (batchnorm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)
  14. (activation_layer): ReLU()
  15. )

提取子Module对象

  1. # 可根据名字或序号取出子module
  2. net1.conv, net2[0], net3.conv
  1. (Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)),
  2. Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)),
  3. Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1)))

调用模型

可以直接网络对象(输入数据),也可以使用上面的Module子对象分别传入(input)。

  1. input = V(t.rand(1, 3, 4, 4))
  2. output = net1(input)
  3. output = net2(input)
  4. output = net3(input)
  5. output = net3.activation_layer(net1.batchnorm(net1.conv(input)))

二、nn.ModuleList()对象

ModuleListModule的子类,当在Module中使用它的时候,就能自动识别为子module。

建立以及使用方法如下,

  1. modellist = nn.ModuleList([nn.Linear(3,4), nn.ReLU(), nn.Linear(4,2)])
  2. input = V(t.randn(1, 3))
  3. for model in modellist:
  4. input = model(input)
  5. # 下面会报错,因为modellist没有实现forward方法
  6. # output = modelist(input)

和普通list不一样,它和torch的其他机制结合紧密,继承了nn.Module的网络模型class可以使用nn.ModuleList并识别其中的parameters,当然这只是个list,不会自动实现forward方法,

  1. class MyModule(nn.Module):
  2. def __init__(self):
  3. super(MyModule, self).__init__()
  4. self.list = [nn.Linear(3, 4), nn.ReLU()]
  5. self.module_list = nn.ModuleList([nn.Conv2d(3, 3, 3), nn.ReLU()])
  6. def forward(self):
  7. pass
  8. model = MyModule()
  9. print(model)
  1. MyModule(
  2. (module_list): ModuleList(
  3. (0): Conv2d (3, 3, kernel_size=(3, 3), stride=(1, 1))
  4. (1): ReLU()
  5. )
  6. )
  1. for name, param in model.named_parameters():
  2. print(name, param.size())
  1. ('module_list.0.weight', torch.Size([3, 3, 3, 3]))
  2. ('module_list.0.bias', torch.Size([3]))

可见,list中的子module并不能被主module所识别,而ModuleList中的子module能够被主module所识别。这意味着如果用list保存子module,将无法调整其参数,因其未加入到主module的参数中。

除ModuleList之外还有ParameterList,其是一个可以包含多个parameter的类list对象。在实际应用中,使用方式与ModuleList类似。如果在构造函数__init__中用到list、tuple、dict等对象时,一定要思考是否应该用ModuleList或ParameterList代替。

『PyTorch』第九弹_前馈网络简化写法的更多相关文章

  1. 『MXNet』第九弹_分类器以及迁移学习DEMO

    解压文件命令: with zipfile.ZipFile('../data/kaggle_cifar10/' + fin, 'r') as zin: zin.extractall('../data/k ...

  2. 『TensorFlow』第九弹_图像预处理_不爱红妆爱武装

    部分代码单独测试: 这里实践了图像大小调整的代码,值得注意的是格式问题: 输入输出图像时一定要使用uint8编码, 但是数据处理过程中TF会自动把编码方式调整为float32,所以输入时没问题,输出时 ...

  3. 『PyTorch』第二弹_张量

    参考:http://www.jianshu.com/p/5ae644748f21# 几个数学概念: 标量(Scalar)是只有大小,没有方向的量,如1,2,3等 向量(Vector)是有大小和方向的量 ...

  4. 『PyTorch』第一弹_静动态图构建if逻辑对比

    对比TensorFlow和Pytorch的动静态图构建上的差异 静态图框架设计好了不能够修改,且定义静态图时需要使用新的特殊语法,这也意味着图设定时无法使用if.while.for-loop等结构,而 ...

  5. 『PyTorch』第二弹重置_Tensor对象

    『PyTorch』第二弹_张量 Tensor基础操作 简单的初始化 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x -2. ...

  6. 『MXNet』第一弹_基础架构及API

    MXNet是基础,Gluon是封装,两者犹如TensorFlow和Keras,不过得益于动态图机制,两者交互比TensorFlow和Keras要方便得多,其基础操作和pytorch极为相似,但是方便不 ...

  7. 『TensorFlow』第二弹_线性拟合&神经网络拟合_恰是故人归

    Step1: 目标: 使用线性模拟器模拟指定的直线:y = 0.1*x + 0.3 代码: import tensorflow as tf import numpy as np import matp ...

  8. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  9. 『PyTorch』第十二弹_nn.Module和nn.functional

    大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...

随机推荐

  1. python2.7运行selenium webdriver api报错Unable to find a matching set of capabilities

    在火狐浏览器33版本,python2.7运行selenium webdriver api报错:SessionNotCreatedException: Message: Unable to find a ...

  2. Java eclipse下 Ant build.xml实例详解 附完整项目源码

    在有eclipse集成环境下ant其实不是很重要,但有些项目需要用到,另外通过eclipse来学习和理解ant是个很好的途径,所以写他demo总结下要点,希望能够帮到大家. 一.本人测试环境eclip ...

  3. 01: 安装zabbix server

    目录:Django其他篇 01: 安装zabbix server 02:zabbix-agent安装配置 及 web界面管理 03: zabbix API接口 对 主机.主机组.模板.应用集.监控项. ...

  4. Android实践项目汇报总结(上)修改

    微博客户端的设计与实现(上) 第一章 绪论 1.1课题背景 微博可以说是时下最受人们所喜爱的一种社交方式,它是一种通过关注机制分享简短实时信息的广播式的社交网络平台.通过微博我们可以了解最新的时事新闻 ...

  5. svn的下载链接

    想要下载svn结果网上出来都是tortoisesvn 正确的链接是 源代码 http://subversion.apache.org/ 安装包 http://www.collab.net/downlo ...

  6. C# 用Linq查询DataGridView行中的数据是否包含(各种操作)

    http://blog.csdn.net/xht555/article/details/38685845 https://www.cnblogs.com/wuchao/archive/2012/12/ ...

  7. 基于大规模语料的新词发现算法【转自matix67】

    最近需要对商品中的特有的词识别,因此需新词发现算法,matrix的这篇算法很好. 对中文资料进行自然语言处理时,我们会遇到很多其他语言不会有的困难,例如分词——汉语的词与词之间没有空格,那计算机怎么才 ...

  8. HDU 3549 Flow Problem(最大流模板)

    http://acm.hdu.edu.cn/showproblem.php?pid=3549 刚接触网络流,感觉有点难啊,只好先拿几道基础的模板题来练练手. 最大流的模板题. #include< ...

  9. Linux——用户管理简单学习笔记(三)

    用户组管理命令: groupadd -g 888 webadmin 创建用户组webadmin,其GID为888 删除用户组: groupdel 组名 修改用户组信息 groupmod groupmo ...

  10. JavaScript页面跳转的一些实现方法

    第一种 <script language=”javascript” type=”text/javascript”> window.location.href=”login.jsp?back ...