Pytorch:利用torch.nn.Modules.parameters修改模型参数
1. 关于parameters()方法
Pytorch中继承了torch.nn.Module
的模型类具有named_parameters()/parameters()
方法,这两个方法都会返回一个用于迭代模型参数的迭代器(named_parameters
还包括参数名字):
import torch
net = torch.nn.LSTM(input_size=512, hidden_size=64)
print(net.parameters())
print(net.named_parameters())
# <generator object Module.parameters at 0x12a4e9890>
# <generator object Module.named_parameters at 0x12a4e9890>
我们可以将net.parameters()
迭代器和将net.named_parameters()
转化为列表类型,前者列表元素是模型参数,后者是包含参数名和模型参数的元组。
当然,我们更多的是对迭代器直接进行迭代:
for param in net.parameters():
print(param.shape)
# torch.Size([256, 512])
# torch.Size([256, 64])
# torch.Size([256])
# torch.Size([256])
for name, param in net.named_parameters():
print(name, param.shape)
# weight_ih_l0 torch.Size([256, 512])
# weight_hh_l0 torch.Size([256, 64])
# bias_ih_l0 torch.Size([256])
# bias_hh_l0 torch.Size([256])
我们知道,Pytorch在进行优化时需要给优化器传入这个参数迭代器,如:
from torch.optim import RMSprop
optimizer = RMSprop(net.parameters(), lr=0.01)
2. 关于参数修改
那么底层具体是怎么对参数进行修改的呢?
我们在博客《Python对象模型与序列迭代陷阱》中介绍过,Python序列中本身存放的就是对象的引用,而迭代器返回的是序列中的对象的二次引用,如果序列的引用指向基础数据类型,则是不可以通过遍历序列进行修改的,如:
my_list = [1, 2, 3, 4]
for x in my_list:
x += 1
print(my_list) #[1, 2, 3, 4]
而序列中的引用指向复合数据类型,则可以通过遍历序列来完成修改操作,如:
my_list = [[1, 2],[3, 4]]
for sub_list in my_list:
sub_list[0] += 1
print(my_list)
# [1, 2, 3, 4]
# [[2, 2], [4, 4]]
具体原理可参照该篇博客,此处我就不在赘述。这里想提到的是,用net.parameters()/net.named_parameters()
来迭代并修改参数,本质上就是上述第二种对复合数据类型序列的修改。我们可以如下写:
for param in net.parameters():
with torch.no_grad():
param += 1
with torch.no_grad():
表示将将所要修改的张量关闭梯度计算。所增加的1会广播到param
张量的中的每一个元素上。上述操作本质上为:
for param in net.parameters():
with torch.no_grad():
param += torch.ones(param.shape)
但是需要注意,如果我们想让参数全部置为0,切不可像下列这样写:
for param in net.parameters():
with torch.no_grad():
param = torch.zeros(param.shape)
param
是二次引用,param=0
操作再语义上会被解释为让param
这个二次引用去指向新的全0张量对象,但是对参数张量本身并不会产生任何变动。该操作实际上类似下列这种操作:
list_1 = [1, 2]
list_2 = list_1
list_2 = [0, 0]
print(list_1) # [1, 2]
修改二次引用list_2
自然不会影响到list_1
引用的对象。
下面让我们纠正这种错误,采用下列方法直接来将参数张量中的所有数值置0:
for param in net.parameters():
with torch.no_grad():
param[:] = 0 #张量类型自带广播操作,等效于param[:] = torch.zeros(param.shape)
这时语义上就类似
list_1 = [1, 2]
list_2 = list_1
list_2[:] = [0, 0]
print(list_1) # [0, 0]
自然就能完成修改的操作了。
参考
Pytorch:利用torch.nn.Modules.parameters修改模型参数的更多相关文章
- pytorch中torch.nn构建神经网络的不同层的含义
主要是参考这里,写的很好PyTorch 入门实战(四)--利用Torch.nn构建卷积神经网络 卷积层nn.Con2d() 常用参数 in_channels:输入通道数 out_channels:输出 ...
- [pytorch笔记] torch.nn vs torch.nn.functional; model.eval() vs torch.no_grad(); nn.Sequential() vs nn.moduleList
1. torch.nn与torch.nn.functional之间的区别和联系 https://blog.csdn.net/GZHermit/article/details/78730856 nn和n ...
- PyTorch官方中文文档:torch.nn
torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...
- 到底什么是TORCH.NN?
该教程是在notebook上运行的,而不是脚本,下载notebook文件. PyTorch提供了设计优雅的模块和类:torch.nn, torch.optim, Dataset, DataLoader ...
- 从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系
从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系 relu多种实现之间的关系 relu 函数在 pytorch 中总共有 3 次出现: torc ...
- 小白学习之pytorch框架(3)-模型训练三要素+torch.nn.Linear()
模型训练的三要素:数据处理.损失函数.优化算法 数据处理(模块torch.utils.data) 从线性回归的的简洁实现-初始化模型参数(模块torch.nn.init)开始 from torc ...
- pytorch保存模型等相关参数,利用torch.save(),以及读取保存之后的文件
本文分为两部分,第一部分讲如何保存模型参数,优化器参数等等,第二部分则讲如何读取. 假设网络为model = Net(), optimizer = optim.Adam(model.parameter ...
- 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())
学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...
- 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)
我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...
- [深度学习] Pytorch学习(二)—— torch.nn 实践:训练分类器(含多GPU训练CPU加载预测的使用方法)
Learn From: Pytroch 官方Tutorials Pytorch 官方文档 环境:python3.6 CUDA10 pytorch1.3 vscode+jupyter扩展 #%% #%% ...
随机推荐
- 做了5年开源项目,我总结了以下提PR经验!
如何优雅地参与开源贡献,向顶级开源项目提交 PR(Pull Request),如何更好地提交 PR? 针对这些问题和疑惑,我们邀请了 OpenAtom OpenHarmony(以下简称"Op ...
- MogDB学习笔记之 -- 了解pagewriter线程
MogDB 学习笔记之 -- 了解 pagewriter 线程 本文出处:https://www.modb.pro/db/183172 在前面的 MogDB 学习系列中,我们了解了核心的 bgwrit ...
- mysql 必知必会整理—子查询与连接表[八]
前言 简单介绍一下子查询与连接表. 正文 什么是子查询呢? 列出订购物品TNT2的所有客户. select cust_id from orders where order_num IN (SELECT ...
- 给picgo上传的图片加个水印
之前给大家介绍了picgo和免费的图床神器.我们本可以开开心心的进行markdown写作了. 但是总是会有那么一些爬虫网站过来爬你的文章,还把你的文章标明是他们的原著.咋办呢?这里有一个好的办法就是把 ...
- 纯CSS实现带小三角提示框
要实现在页面上点击指定元素时,弹出一个信息提示框.在前面的文章中,我们已经简单介绍了如何使用纯 CSS 创建一个三角形.本文在此基础上,记录如何使用 CSS 创建带三角形的提示框. 实现的原理是创建一 ...
- [Violation] 'click' handler took 429ms
问题 violation 意思为侵权,违背,违反,也就是说明click函数执行违反了某些规则 原因测试 当click事件中执行的程序耗时过长,超过160ms左右的时候就会显示该信息,测试最低155ms ...
- “让专业的人做专业的事”,畅捷通与阿里云的云原生故事 | 云原生 Talk
简介: 如何借助阿里云强大的 IaaS 和 PaaS 能力去构建新一代的 SaaS 企业应用,从而给客户提供更好.更强的服务,这是畅捷通一直在思考和实践的方向.最终,畅捷通选定阿里云企业级分布式应用服 ...
- 项目版本管理的最佳实践:云效飞流Flow篇
简介: 飞流Flow的最佳实践(使用阿里云云效)为了更好地使用飞流Flow,接下来将结合阿里云云效来讲解飞流Flow的最佳实践 目录 一.分支规约 二.版本号规约 2.1 主版本号(首位版本号) 2. ...
- 云上安全保护伞--SLS威胁情报集成实战
简介: 威胁情报是某种基于证据的知识,包括上下文.机制.标示.含义和能够执行的建议. 什么是威胁情报 根据Gartner对威胁情报的定义,威胁情报是某种基于证据的知识,包括上下文.机制.标示.含义和能 ...
- [Blockchain] 前后端完全去中心化的思路, IPFS 与 Ethereum Contract
我们在使用智能合约的时候,一般是把它当成去中心.减少信任依赖的后端存在. 如果没有特殊后端功能要求,一个 DApp 只需要前端驱动 web3js 就可以实现了. 可以看到,现在前端部分依旧是一个中心化 ...