文章转载自微信公众号:机器学习炼丹术。欢迎大家关注,这是我的学习分享公众号,100+原创干货。

文章目录:

本文是对一些函数的学习。函数主要包括下面四个方便:

  • 模型构建的函数:add_module,add_module,add_module
  • 访问子模块:add_module,add_module,add_moduleadd_module
  • 网络遍历:

    add_module,add_module
  • 模型的保存与加载:add_module,add_module,add_module

1 模型构建函数

torch.nn.Module是所有网络的基类,在PyTorch实现模型的类中都要继承这个类(这个在之前的课程中已经提到)。在构建Module中,Module是一个包含其他的Module的,类似于,你可以先定义一个小的网络模块,然后把这个小模块作为另外一个网络的组件。因此网络结构是呈现树状结构

我们先简单定义一个网络:

  1. import torch.nn as nn
  2. import torch
  3. class MyNet(nn.Module):
  4. def __init__(self):
  5. super(MyNet,self).__init__()
  6. self.conv1 = nn.Conv2d(3,64,3)
  7. self.conv2 = nn.Conv2d(64,64,3)
  8. def forward(self,x):
  9. x = self.conv1(x)
  10. x = self.conv2(x)
  11. return x
  12. net = MyNet()
  13. print(net)

输出结果:



MyNet中有两个属性conv1conv2是两个卷积层,在正向传播forward的过程中,依次调用这两个卷积层实现网络的功能。

1.1 add_module

这种是最常见的定义网络的功能,在有些项目中,会看到这样的方法add_module。我们用这个方法来重写上面的网络:

  1. class MyNet(nn.Module):
  2. def __init__(self):
  3. super(MyNet,self).__init__()
  4. self.add_module('conv1',nn.Conv2d(3,64,3))
  5. self.add_module('conv2',nn.Conv2d(64,64,3))
  6. def forward(self,x):
  7. x = self.conv1(x)
  8. x = self.conv2(x)
  9. return x

其实add_module(name,layer)self.name=layer实现了相同的功能,个人感觉也许是因为add_module可以使用字符串来定义变量名字,所以可以放在循环中?反正这个先了解熟悉熟悉

上面的两种方法都是一层一层的添加layer,如果网络复杂的话,那就需要写很多重复的代码了。因此接下来来讲解一下网络模块的构建,torch.nn.ModuleListtorch.nn.Sequential

1.2 ModuleList

ModuleList按照字面意思是用list的形式保存网络层的。这样就可以先将网络需要的layer构建好,保存到一个list,然后通过ModuleList方法添加到网络中.

  1. class MyNet(nn.Module):
  2. def __init__(self):
  3. super(MyNet,self).__init__()
  4. self.linears = nn.ModuleList(
  5. [nn.Linear(10,10) for i in range(5)]
  6. )
  7. def forward(self,x):
  8. for l in self.linears:
  9. x = l(x)
  10. return x
  11. net = MyNet()
  12. print(net)

输出结果是:

这个ModuleList主要是用在读取config文件来构建网络模型中的,下面用VGG模型的构建为例子:

  1. vgg_cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',
  2. 512, 512, 512, 'M']
  3. def vgg(cfg, i, batch_norm=False):
  4. layers = []
  5. in_channels = i
  6. for v in cfg:
  7. if v == 'M':
  8. layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
  9. elif v == 'C':
  10. layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
  11. else:
  12. conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
  13. if batch_norm:
  14. layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
  15. else:
  16. layers += [conv2d, nn.ReLU(inplace=True)]
  17. in_channels = v
  18. return layers
  19. class Model1(nn.Module):
  20. def __init__(self):
  21. super(Model1,self).__init__()
  22. self.vgg = nn.ModuleList(vgg(vgg_cfg,3))
  23. def forward(self,x):
  24. for l in self.vgg:
  25. x = l(x)
  26. m1 = Model1()
  27. print(m1)

先读取网络结构的配置文件vgg_cfg然后根据这个文件创建对应的Layer list,然后使用ModuleList添加到网络中,这样可以快速创建不同的网络(用上面为例子的话,可以通过修改配置文件,然后快速修改网络结构

1.3 Sequential

在一些自己做的小项目中,Sequential其实用的更为频繁。

依然重写最初最简单的例子:

  1. class MyNet(nn.Module):
  2. def __init__(self):
  3. super(MyNet,self).__init__()
  4. self.conv = nn.Sequential(
  5. nn.Conv2d(3,64,3),
  6. nn.Conv2d(64,64,3)
  7. )
  8. def forward(self,x):
  9. x = self.conv(x)
  10. return x
  11. net = MyNet()
  12. print(net)

运行结果:

观察细致的朋友可以发现这个问题,Seqential内的网络层是默认用数字进行标号的,而一开始我们使用self.conv1self.conv2的时候,使用conv1和conv2作为标号的。

我们如何修改Sequential中网络层的名称呢?这里需要使用到collections.OrderedDict有序字典。Sequential是支持有序字典构建的。

  1. from collections import OrderedDict
  2. class MyNet(nn.Module):
  3. def __init__(self):
  4. super(MyNet,self).__init__()
  5. self.conv = nn.Sequential(OrderedDict([
  6. ('conv1',nn.Conv2d(3,64,3)),
  7. ('conv2',nn.Conv2d(64,64,3))
  8. ]))
  9. def forward(self,x):
  10. x = self.conv(x)
  11. return x
  12. net = MyNet()
  13. print(net)

输出结果:

1.4 小总结

  • 单独增加一个网络层或者子模块,可以用add_module或者直接赋予属性;
  • ModuleList可以将一个Module的List增加到网络中,自由度较高。
  • Sequential按照顺序产生一个Module模块。这里推荐习惯使用OrderedDict的方法进行构建。对网络层加上规范的名称,这样有助于后续查找与遍历

2 遍历模型结构

本章节使用下面的方法进行遍历之前提到的Module。(个人理解,Module是多个layer的合并,但是一个layer可以说成Module。

先定义一个网络吧,随便写一个:

  1. import torch.nn as nn
  2. import torch
  3. from collections import OrderedDict
  4. class MyNet(nn.Module):
  5. def __init__(self):
  6. super(MyNet,self).__init__()
  7. self.conv1 = nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3)
  8. self.conv2 = nn.Conv2d(64,64,3)
  9. self.maxpool1 = nn.MaxPool2d(2,2)
  10. self.features = nn.Sequential(OrderedDict([
  11. ('conv3', nn.Conv2d(64,128,3)),
  12. ('conv4', nn.Conv2d(128,128,3)),
  13. ('relu1', nn.ReLU())
  14. ]))
  15. def forward(self,x):
  16. x = self.conv1(x)
  17. x = self.conv2(x)
  18. x = self.maxpool1(x)
  19. x = self.features(x)
  20. return x
  21. net = MyNet()
  22. print(net)

输出结果是:

2.1 modules()

在第四课中初始化模型各个层的参数的时候,用到了这个方法,现在我们再来理解一下:

  1. for idx,m in enumerate(net.modules()):
  2. print(idx,"-",m)

运行结果:

上面那个网络构建的时候用到了Sequential,所以网络中其实是嵌套了一个小的Module,这就是之前提到的树状结构,然后上面便利的时候也是树状结构的便利过程,可以看出来应该是一个深度遍历的过程。

  • 首先第一个输出的是最大的那个Module,也就是整个网络,0-Model整个网络模块;
  • 1-2-3-4是网络的四个子模块,4-Sequential中间仍然包含子模块
  • 5-6-7是模块4-Sequential的子模块。

【总结】

modules()是递归的返回网络的各个module(深度遍历),从最顶层直到最后的叶子的module。

2.2 named_modules()

named_modules()module()类似,只是同时返回name和module。

  1. for idx,(name,m) in enumerate(net.named_modules()):
  2. print(idx,"-",name)

输出结果:

2.3 parameters()

  1. for p in net.parameters():
  2. print(type(p.data),p.size())

运行结果:

输出的是四个卷积层的权重矩阵参数和偏置参数。值得一提的是,对网络进行训练时需要将parameters()作为优化器optimizer的参数。

  1. optimizer = torch.optim.SGD(net.parameters(),
  2. lr = 0.001,
  3. momentum=0.9)

总之呢,这个parameters()是返回网络所有的参数,主要用在给optimizer优化器用的。而要对网络的某一层的参数做处理的时候,一般还是使用named_parameters()方便一些。

  1. for idx,(name,m) in enumerate(net.named_parameters()):
  2. print(idx,"-",name,m.size())

输出结果:

【小扩展】

我个人有时会使用下面的方法来获取参数:

  1. for idx,(name,m) in enumerate(net.named_modules()):
  2. if isinstance(m,nn.Conv2d):
  3. print(m.weight.shape)
  4. print(m.bias.shape)

先判断是否是卷积层,然后获取其参数,输出结果:

3 保存与载入

PyTorch使用torch.savetorch.load方法来保存和加载网络,而且网络结构和参数可以分开的保存和加载。

  1. torch.save(model,'model.pth') # 保存
  2. model = torch.load("model.pth") # 加载

pytorch中网络结构和模型参数是可以分开保存的。上面的方法是两者同时保存到了.pth文件中,当然,你也可以仅仅保存网络的参数来减小存储文件的大小。注意:如果你仅仅保存模型参数,那么在载入的时候,是需要通过运行代码来初始化模型的结构的。

  1. torch.save(model.state_dict(),"model.pth") # 保存参数
  2. model = MyNet() # 代码中创建网络结构
  3. params = torch.load("model.pth") # 加载参数
  4. model.load_state_dict(params) # 应用到网络结构中

至此,我们今天已经学习了不少的内容,大家对PyTorch的掌握更近一步了呢~

【小白学PyTorch】6 模型的构建访问遍历存储(附代码)的更多相关文章

  1. 【小白学PyTorch】18 TF2构建自定义模型

    [机器学习炼丹术]的炼丹总群已经快满了,要加入的快联系炼丹兄WX:cyx645016617 参考目录: 目录 1 创建自定义网络层 2 创建一个完整的CNN 2.1 keras.Model vs ke ...

  2. 【小白学PyTorch】20 TF2的eager模式与求导

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

  3. 【小白学PyTorch】19 TF2模型的存储与载入

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

  4. 【小白学PyTorch】15 TF2实现一个简单的服装分类任务

    [新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx64501661 ...

  5. 从头学pytorch(九):模型构造

    模型构造 nn.Module nn.Module是pytorch中提供的一个类,是所有神经网络模块的基类.我们自定义的模块要继承这个基类. import torch from torch import ...

  6. 【小白学PyTorch】12 SENet详解及PyTorch实现

    文章来自微信公众号[机器学习炼丹术].我是炼丹兄,有什么问题都可以来找我交流,近期建立了微信交流群,也在朋友圈抽奖赠书十多本了.我的微信是cyx645016617,欢迎各位朋友. 参考目录: @ 目录 ...

  7. 【小白学PyTorch】4 构建模型三要素与权重初始化

    文章目录: 目录 1 模型三要素 2 参数初始化 3 完整运行代码 4 尺寸计算与参数计算 1 模型三要素 三要素其实很简单 必须要继承nn.Module这个类,要让PyTorch知道这个类是一个Mo ...

  8. 小白学PyTorch 动态图与静态图的浅显理解

    文章来自公众号[机器学习炼丹术],回复"炼丹"即可获得海量学习资料哦! 目录 1 动态图的初步推导 2 动态图的叶子节点 3. grad_fn 4 静态图 本章节缕一缕PyTorc ...

  9. 【小白学PyTorch】5 torchvision预训练模型与数据集全览

    文章来自:微信公众号[机器学习炼丹术].一个ai专业研究生的个人学习分享公众号 文章目录: 目录 torchvision 1 torchvision.datssets 2 torchvision.mo ...

随机推荐

  1. Spring纯注解配置

    待改造的问题 我们发现,之所以我们现在离不开 xml 配置文件,是因为我们有一句很关键的配置: <!-- 告知spring框架在,读取配置文件,创建容器时,扫描注解,依据注解创建对象,并存入容器 ...

  2. 041_go语言中的panic

    代码演示: package main import "os" func main() { // panic("a problem") _, err := os. ...

  3. python操作Excel,你觉得哪个库更好呢?

    对比学习python,更高效~ Excel数据的类型及组织方式 很多人学习python,不知道从何学起.很多人学习python,掌握了基本语法过后,不知道在哪里寻找案例上手.很多已经做案例的人,却不知 ...

  4. Dubbo系列之 (一)SPI扩展

    一.基础铺垫 1.@SPI .@Activate. @Adaptive a.对于 @SPI,Dubbo默认的特性扩展接口,都必须打上这个@SPI,标识这是个Dubbo扩展点.如果自己需要新增dubbo ...

  5. OpenCV开发笔记(六十九):红胖子8分钟带你使用传统方法识别已知物体(图文并茂+浅显易懂+程序源码)

    若该文为原创文章,未经允许不得转载原博主博客地址:https://blog.csdn.net/qq21497936原博主博客导航:https://blog.csdn.net/qq21497936/ar ...

  6. git使用-远程仓库(github为例)

    1.登录github(没有先注册账号) 2.settings>SSH and GPG keys>New SSH key Title(自己填写即可) key需要git命令生成 ssh-key ...

  7. java Struts 多种表单写法

    1.html:form(struts标签) 缺点:必须指定一个有效的action属性. 优点:可以使用struts token机制. 调用方法通过submit的name属性. <table al ...

  8. python3.x与2.x中print输出不换行

    python3.x: print(i,end=' ') 循环输出: ... ------------------------- print(i,end='!') 循环输出:!!!... end=单引号 ...

  9. C#LeetCode刷题之#746-使用最小花费爬楼梯( Min Cost Climbing Stairs)

    问题 该文章的最新版本已迁移至个人博客[比特飞],单击链接 https://www.byteflying.com/archives/4016 访问. 数组的每个索引做为一个阶梯,第 i个阶梯对应着一个 ...

  10. 文件上传控件bootstrap-fileinput中文设置没有效果的情况

    1.引入zh.js顺序错误 zh.js需放到fileinput.js下面 2. 组件创建语法错误 (class=“file”) 如果你使用js初始化fileinput组件,那么在html元素中应删除 ...