本文是PyTorch使用过程中的的一些总结,有以下内容:

  • 构建网络模型的方法
  • 网络层的遍历
  • 各层参数的遍历
  • 模型的保存与加载
  • 从预训练模型为网络参数赋值

主要涉及到以下函数的使用

  • add_module,ModulesList,Sequential 模型创建
  • modules(),named_modules(),children(),named_children() 访问模型的各个子模块
  • parameters(),named_parameters() 网络参数的遍历
  • save(),load()state_dict() 模型的保存与加载

构建网络

torch.nn.Module是所有网络的基类,在Pytorch实现的Model都要继承该类。而且,Module是可以包含其他的Module的,以树形的结构来表示一个网络结构。

简单的定义一个网络Model

class Model(nn.Module):
def __init__(self):
super(Model,self).__init__()
self.conv1 = nn.Conv2d(3,64,3)
self.conv2 = nn.Conv2d(64,64,3) def forward(self,x):
x = self.conv1(x)
x = self.conv2(x)
return x

Model中两个属性conv1conv2是两个卷积层,在正向传播的过程中,再依次调用这两个卷积层。

除了使用Model的属性来为网络添加层外,还可以使用add_module将网络层添加到网络中。

class Model(nn.Module):
def __init__(self):
super(Model,self).__init__()
self.conv1 = nn.Conv2d(3,64,3)
self.conv2 = nn.Conv2d(64,64,3) self.add_module("maxpool1",nn.MaxPool2d(2,2))
self.add_module("covn3",nn.Conv2d(64,128,3))
self.add_module("conv4",nn.Conv2d(128,128,3)) def forward(self,x):
x = self.conv1(x)
x = self.conv2(x)
x = self.maxpool1(x)
x = self.conv3(x)
x = self.conv4(x)
return x

add_module(name,layer)在正向传播的过程中可以使用添加时的name来访问改layer。

这样一个个的添加layer,在简单的网络中还行,但是对于负责的网络层很多的网络来说就需要敲很多重复的代码了。 这就需要使用到torch.nn.ModuleListtorch.nn.Sequential

使用ModuleListSequential可以方便添加子网络到网络中,但是这两者还是有所不同的。

ModuleList

ModuleList是以list的形式保存sub-modules或者网络层,这样就可以先将网络需要的layer构建好保存到一个list,然后通过ModuleList方法添加到网络中。

class MyModule(nn.Module):
def __init__(self):
super(MyModule,self).__init__() # 构建layer的list
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)]) def forward(self,x): # 正向传播,使用遍历每个Layer
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x) return x

使用[nn.Linear(10, 10) for i in range(10)]构建要给Layer的list,然后使用ModuleList添加到网络中,在正向传播的过程中,遍历该list

更为方便的是,可以提前配置后,所需要的各个Layer的属性,然后读取配置创建list,然后使用ModuleList将配置好的网络层添加到网络中。 以VGG为例:

vgg_cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',
512, 512, 512, 'M'] def vgg(cfg, i, batch_norm=False):
layers = []
in_channels = i
for v in cfg:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
elif v == 'C':
layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
return layers class Model1(nn.Module):
def __init__(self):
super(Model1,self).__init__() self.vgg = nn.ModuleList(vgg(vgg_cfg,3)) def forward(self,x): for l in self.vgg:
x = l(x)
m1 = Model1()
print(m1)

读取配置好的网络结构vgg_cfg然后,创建相应的Layer List,使用ModuleList加入到网络中。这样就可以很灵活的创建不同的网络。

这里需要注意的是,ModuleList是将Module加入网络中,需要自己手动的遍历进行每一个Moduleforward

Sequential

一个时序容器。Modules 会以他们传入的顺序被添加到容器中。当然,也可以传入一个OrderedDict一个时序容器。Modules 会以他们传入的顺序被添加到容器中。当然,也可以传入一个OrderedDict

Sequential也是一次加入多个Module到网络中中,和ModuleList不同的是,它接受多个Module依次加入到网络中,还可以接受字典作为参数,例如:

# Example of using Sequential
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
) # Example of using Sequential with OrderedDict
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))

另一个是,Sequential中实现了添加Module的forward,不需要手动的循环调用了。这点相比ModuleList较为方便。

总结

常见的有三种方法来添加子Module到网络中

  • 单独添加一个Module,可以使用属性或者add_module方法。
  • ModuleList可以将一个Module的List加入到网络中,自由度较高,但是需要手动的遍历ModuleList进行forward
  • Sequential按照顺序将将Module加入到网络中,也可以处理字典。 相比于ModuleList不需要自己实现forward

遍历网络结构

可以使用以下2对4个方法来访问网络层所有的Modules

  • modules()named_modules()
  • children()named_children()

modules方法

简单的定义一个如下网络:

class Model(nn.Module):
def __init__(self):
super(Model,self).__init__()
self.conv1 = nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3)
self.conv2 = nn.Conv2d(64,64,3)
self.maxpool1 = nn.MaxPool2d(2,2) self.features = nn.Sequential(OrderedDict([
('conv3', nn.Conv2d(64,128,3)),
('conv4', nn.Conv2d(128,128,3)),
('relu1', nn.ReLU())
])) def forward(self,x):
x = self.conv1(x)
x = self.conv2(x)
x = self.maxpool1(x)
x = self.features(x) return x

modules()方法,返回一个包含当前模型所有模块的迭代器,这个是递归的返回网络中的所有Module。使用如下语句

    m = Model()
for idx,m in enumerate(m.modules()):
print(idx,"-",m)

其结果为:

0 - Model(
(conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
(maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(features): Sequential(
(conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
(conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
(relu1): ReLU()
)
)
1 - Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
2 - Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
3 - MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
4 - Sequential(
(conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
(conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
(relu1): ReLU()
)
5 - Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
6 - Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
7 - ReLU()

输出结果解析:

  • 0-Model 整个网络模块
  • 1-2-3-4 为网络的4个子模块,注意4 - Sequential仍然包含有子模块
  • 5-6-7为模块4 - Sequential的子模块

可以看出modules()是递归的返回网络的各个module,从最顶层直到最后的叶子module。

named_modules()的功能和modules()的功能类似,不同的是它返回内容有两部分:module的名称以及module。

children()方法

modules()不同,children()只返回当前模块的子模块,不会递归子模块。

    for idx,m in enumerate(m.children()):
print(idx,"-",m)

其输出为:

0 - Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
1 - Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
2 - MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
3 - Sequential(
(conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
(conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
(relu1): ReLU()
)

子模块3-Sequential仍然有子模块,children()没有递归的返回。

named_children()children()的功能类似,不同的是其返回两部分内容:模块的名称以及模块本身。

网络的参数

方法parameters()返回一个包含模型所有参数的迭代器。一般用来当作optimizer的参数。

    for p in m.parameters():
print(type(p.data),p.size())

其输出为:

<class 'torch.Tensor'> torch.Size([128, 64, 3, 3])
<class 'torch.Tensor'> torch.Size([128])
<class 'torch.Tensor'> torch.Size([128, 128, 3, 3])
<class 'torch.Tensor'> torch.Size([128])

包含网络中的所有的权值矩阵参数以及偏置参数。 对网络进行训练时需要将parameters()作为优化器optimizer的参数。

optimizer = torch.optim.SGD(m1.parameters(),lr = args.lr,momentum=args.momentum,weight_decay=args.weight_decay)

parameters()返回网络的所有参数,主要是提供给optimizer用的。而要取得网络某一层的参数或者参数进行一些特殊的处理(如fine-tuning),则使用named_parameters()更为方便些。

named_parameters()返回参数的名称及参数本身,可以按照参数名对一些参数进行处理。

以上面的vgg网络为例:

for k,v in m1.named_parameters():
print(k,v.size())

named_parameters返回的是键值对,k为参数的名称 ,v为参数本身。输出结果为:

vgg.0.weight torch.Size([64, 3, 3, 3])
vgg.0.bias torch.Size([64])
vgg.2.weight torch.Size([64, 64, 3, 3])
vgg.2.bias torch.Size([64])
vgg.5.weight torch.Size([128, 64, 3, 3])
vgg.5.bias torch.Size([128])
vgg.7.weight torch.Size([128, 128, 3, 3])
vgg.7.bias torch.Size([128])
vgg.10.weight torch.Size([256, 128, 3, 3])
vgg.10.bias torch.Size([256])
vgg.12.weight torch.Size([256, 256, 3, 3])
vgg.12.bias torch.Size([256])
vgg.14.weight torch.Size([256, 256, 3, 3])
vgg.14.bias torch.Size([256])
vgg.17.weight torch.Size([512, 256, 3, 3])
vgg.17.bias torch.Size([512])
vgg.19.weight torch.Size([512, 512, 3, 3])
vgg.19.bias torch.Size([512])
vgg.21.weight torch.Size([512, 512, 3, 3])
vgg.21.bias torch.Size([512])
vgg.24.weight torch.Size([512, 512, 3, 3])
vgg.24.bias torch.Size([512])
vgg.26.weight torch.Size([512, 512, 3, 3])
vgg.26.bias torch.Size([512])
vgg.28.weight torch.Size([512, 512, 3, 3])
vgg.28.bias torch.Size([512])

参数名的命名规则属性名称.参数属于的层的编号.weight/bias。 这在fine-tuning的时候,给一些特定的层的参数赋值是非常方便的,这点在后面在加载预训练模型时会看到。

模型的保存与加载

PyTorch使用torch.savetorch.load方法来保存和加载网络,而且网络结构和参数可以分开的保存和加载。

  • 保存网络结构及其参数
torch.save(model,'model.pth') # 保存
model = torch.load("model.pth") # 加载
  • 只加载模型参数,网络结构从代码中创建
torch.save(model.state_dict(),"model.pth") # 保存参数
model = model() # 代码中创建网络结构
params = torch.load("model.pth") # 加载参数
model.load_state_dict(params) # 应用到网络结构中

加载预训练模型

PyTorch中的torchvision里有很多常用网络的预训练模型,例如:vgg,resnet,googlenet等,可以方便的使用这些预训练模型进行微调。

# PyTorch中的torchvision里有很多常用的模型,可以直接调用:
import torchvision.models as models resnet101 = models.resnet18(pretrained=True)
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()

有时候只需要加载预训练模型的部分参数,可以使用参数名作为过滤条件,如下

resnet152 = models.resnet152(pretrained=True)
pretrained_dict = resnet152.state_dict()
"""加载torchvision中的预训练模型和参数后通过state_dict()方法提取参数
也可以直接从官方model_zoo下载:
pretrained_dict = model_zoo.load_url(model_urls['resnet152'])"""
model_dict = model.state_dict()
# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)

model.state_dict()返回一个python的字典对象,将每一层与它的对应参数建立映射关系(如model的每一层的weights及偏置等等)。注意,只有有参数训练的层才会被保存。

上述的加载方式,是按照参数名类匹配过滤的,但是对于一些参数名称无法完全匹配,或者在预训练模型的基础上新添加的一些层,这些层无法从预训练模型中获取参数,需要初始化。

仍然以上述的vgg为例,在标准的vgg16的特征提取后面,新添加两个卷积层,这两个卷积层的参数需要进行初始化。

vgg = torch.load("vgg.pth") # 加载预训练模型

for k,v in m1.vgg.named_parameters():
k = "features.{}".format(k) # 参数名称
if k in vgg.keys():
v.data = vgg[k].data # 直接加载预训练参数
else:
if k.find("weight") >= 0:
nn.init.xavier_normal_(v.data) # 没有预训练,则使用xavier初始化
else:
nn.init.constant_(v.data,0) # bias 初始化为0

PyTorch-网络的创建,预训练模型的加载的更多相关文章

  1. js动态创建的select2标签样式加载不上解决办法

    js动态创建的select2标签样式加载不上:调用select2的select2()函数来初始化一下: js抛出了Uncaught query function not defined for Sel ...

  2. pytorch中修改后的模型如何加载预训练模型

    问题描述 简单来说,比如你要加载一个vgg16模型,但是你自己需要的网络结构并不是原本的vgg16网络,可能你删掉某些层,可能你改掉某些层,这时你去加载预训练模型,就会报错,错误原因就是你的模型和原本 ...

  3. 【小白学PyTorch】5 torchvision预训练模型与数据集全览

    文章来自:微信公众号[机器学习炼丹术].一个ai专业研究生的个人学习分享公众号 文章目录: 目录 torchvision 1 torchvision.datssets 2 torchvision.mo ...

  4. Swift微博项目--Swift中通过类名字符串创建类以及动态加载控制器的实现

    Swift中用类名字符串创建类(用到了命名空间) OC中可以直接通过类名的字符串转换成对应的类来操作,但是Swift中必须用到命名空间,也就是说Swift中通过字符串获取类的方式为NSClassFro ...

  5. Android Handler 异步消息处理机制的妙用 创建强大的图片加载类(转)

    转载请标明出处:http://blog.csdn.net/lmj623565791/article/details/38476887 ,本文出自[张鸿洋的博客] 最近创建了一个群,方便大家交流,群号: ...

  6. React(九)create-react-app创建项目 + 按需加载Ant Design

    (1)create-react-app如何创建项目我前面第一章介绍过了,这里就不过多写了, (2)我们主要来说说按需加载的问题 1. 引入antd npm install antd --save 2. ...

  7. 从整体上理解进程创建、可执行文件的加载和进程执行进程切换,重点理解分析fork、execve和进程切换

    学号后三位<168> 原创作品转载请注明出处https://github.com/mengning/linuxkernel/ 1.分析fork函数对应的内核处理过程sys_clone,理解 ...

  8. 如何用Swift创建一个复杂的加载动画

    现在在苹果应用商店上有超过140万的App,想让你的app事件非常具有挑战的事情.你有这样一个机会,在你的应用的数据完全加载出来之前,你可以通过一个很小的窗口来捕获用户的关注. 没有比这个更好的地方让 ...

  9. DLL动态库的创建,隐式加载和显式加载

    动态库的创建 打开VS,创建如下控制台工程,工程命名为DllTest: 在弹出的对话框中选择"DLL"后单击"完成"按钮: 在工程中新建DllTest.h和Dl ...

随机推荐

  1. web 开发常用字符串表达式匹配

    记录一下 web 开发中常用的一些字符串模式,这是我遇到的一些,后面如果还有的话,欢迎大神提出,会继续更新. 正则表达式 这个主要用在前端的验证,nginx 路径匹配,shell 脚本文本处理,后端感 ...

  2. pyhton3 之 time模块实例小结

    一.实例1:实现秒表: import time print('按下回车开始计时,按下 Ctrl + C 停止计时.') while True: try: input() # 如果是 python 2. ...

  3. <<代码大全>>阅读笔记之一 使用变量的一般事项

    一.使用变量的一般事项 1.把变量引用局部化 变量应用局部化就是把变量的引用点尽可能集中在一起,这样做的目的是增加代码的可读性 衡量不同引用点靠近程度的一种方法是计算该变量的跨度(span) 示例 a ...

  4. Linux常见的Shell命令

    1.具体的shell命令用法可以通过help或man命令进入手册来查询其具体的用法.2.终端本质上对应着linux上的/dev/tty设备,linux的多用户登录就是通过不同的/dev/tty设备完成 ...

  5. 【tf.keras】实现 F1 score、precision、recall 等 metric

    tf.keras.metric 里面竟然没有实现 F1 score.recall.precision 等指标,一开始觉得真不可思议.但这是有原因的,这些指标在 batch-wise 上计算都没有意义, ...

  6. centos6安装pxc

    Percona XtraDB Cluster是一种高可用性解决方案,可帮助企业避免停机和中断. Percona XtraDB Cluster具有以下MySQL群集优势: • 具有成本效益的HA和MyS ...

  7. MATLAB数值计算——0

    目录 MATLAB数值计算 1.solve() 2.fzero() 3.fsolve() MATLAB数值计算 MATLAB中文论坛基础板块常见问题归纳(出处: MATLAB中文论坛) 登录http: ...

  8. C语言I—2019秋作业03

    这个作业属于那个课程 C语言程序设计II 这个作业要求在哪里 C语言I-2019秋作业03 我在这个课程的目标是 掌握if-else语句,运算关系 这个作业在那个具体方面帮助我实现目标 row 2 c ...

  9. c语言l博客作业09

    问题 答案 这个作业属于那个课程 C语言程序设计II 这个作业要求在哪里 https://edu.cnblogs.com/campus/zswxy/CST2019-2/homework/8655 我在 ...

  10. mysql的两阶段协议(封锁定理,虫洞事务)

    我们都知道数据库的事务具有ACID的四个属性:原子性,一致性,隔离性和持久性.然后在多线程操作的情况下,如果不能保证事务的隔离性,就会造成数据的修改丢失(事务2覆盖了事务1的修改结果).读到脏数据(事 ...