pytorch中多个loss回传的参数影响示例
写了一段代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as F class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.fc1 = nn.Linear(5, 4)
self.fc2 = nn.Linear(4, 3)
self.fc3 = nn.Linear(4, 3) def forward(self, x):
mid = self.fc1(x)
out1 = self.fc2(mid)
out2 = self.fc3(mid)
return out1, out2 x = torch.randn((3, 5))
y = torch.torch.randint(3, (3,), dtype=torch.int64)
model = Test()
model.train()
optim = torch.optim.RMSprop(model.parameters(), lr=0.001) print(model.fc2.weight)
print(model.fc3.weight)
for i in range(5):
out1, out2 = model(x)
loss1 = F.cross_entropy(out1, y)
loss2 = F.cross_entropy(out2, y)
loss = loss1 + loss2
optim.zero_grad()
loss.backward()
optim.step()
print("-------------after-----------")
print(model.fc2.weight)
print(model.fc3.weight)
在loss.backward()处分别更换为loss1.backward()和loss2.backward(),观察fc2和fc3层的参数变化。
得出的结论为:loss2只影响fc3的参数,loss1只影响fc2的参数。
(粗略分析,抛砖引玉)
pytorch中多个loss回传的参数影响示例的更多相关文章
- 关于Pytorch中accuracy和loss的计算
这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚. 给出实例 def train(train_loader, model, criteon, optimizer, epoc ...
- 在ASP.NET MVC中以post方式传递数组参数的示例
最近在工作中用到了在ASP.NET MVC中以post方式传递数组参数的情况,记录下来,以供参考. 一.准备参数对象 在本例中,我会传递两个数组参数:一个字符串数组,一个自定义对象数组.这个自定义对象 ...
- 在ASP.NET MVC中以post方式传递数组参数的示例【转】
最近在工作中用到了在ASP.NET MVC中以post方式传递数组参数的情况,记录下来,以供参考. 一.准备参数对象 在本例中,我会传递两个数组参数:一个字符串数组,一个自定义对象数组.这个自定义对象 ...
- PyTorch中view的用法
相当于numpy中resize()的功能,但是用法可能不太一样. 我的理解是: 把原先tensor中的数据按照行优先的顺序排成一个一维的数据(这里应该是因为要求地址是连续存储的),然后按照参数组合成其 ...
- Pytorch中的自动求导函数backward()所需参数含义
摘要:一个神经网络有N个样本,经过这个网络把N个样本分为M类,那么此时backward参数的维度应该是[N X M] 正常来说backward()函数是要传入参数的,一直没弄明白backward需要传 ...
- ARTS-S pytorch中backward函数的gradient参数作用
导数偏导数的数学定义 参考资料1和2中对导数偏导数的定义都非常明确.导数和偏导数都是函数对自变量而言.从数学定义上讲,求导或者求偏导只有函数对自变量,其余任何情况都是错的.但是很多机器学习的资料和开源 ...
- Pytorch中torch.autograd ---backward函数的使用方法详细解析,具体例子分析
backward函数 官方定义: torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph ...
- 【PyTorch】PyTorch中的梯度累加
PyTorch中的梯度累加 使用PyTorch实现梯度累加变相扩大batch PyTorch中在反向传播前为什么要手动将梯度清零? - Pascal的回答 - 知乎 https://www.zhihu ...
- pytorch中tensorboardX的用法
在代码中改好存储Log的路径 命令行中输入 tensorboard --logdir /home/huihua/NewDisk1/PycharmProjects/pytorch-deeplab-xce ...
随机推荐
- Codeforces Round #626 (Div. 2) D. Present(位运算)
题意: 求n个数中两两和的异或. 思路: 逐位考虑,第k位只需考虑0~k-1位,可通过&(2k+1-1)得到一组新数. 将新数排序,当两数和在[2k,2k+1)和[2k+1+2k,2k+2)之 ...
- SCZ 20170812 T1 HKJ
因为题面实在是太过暴力,就不贴链接了--我自己重新写一下题面吧-- 题目描述 给定一张带权有向图,设起点为1,终点为n,每个点除编号外还有一个序号,要求输出从起点至终点的最短路经过的点的序号和最短距离 ...
- Codeforces Round #648 (Div. 2) E. Maximum Subsequence Value(鸽巢原理)
题目链接:https://codeforces.com/problemset/problem/1365/E 题意 有 $n$ 个元素,定义大小为 $k$ 的集合值为 $\sum2^i$,其中,若集合内 ...
- 【hdu 4859】海岸线(图论--网络流最小割)
题意:有一个区域,有'.'的陆地,'D'的深海域,'E'的浅海域.其中浅海域可以填充为陆地.这里的陆地区域不联通,并且整个地图都处在海洋之中.问填充一定浅海域之后所有岛屿的最长的海岸线之和. 解法:最 ...
- P4074 [WC2013]糖果公园 树上莫队带修改
题目链接 Candyland 有一座糖果公园,公园里不仅有美丽的风景.好玩的游乐项目,还有许多免费糖果的发放点,这引来了许多贪吃的小朋友来糖果公园游玩. 糖果公园的结构十分奇特,它由 nn 个游览点构 ...
- Logstash 日志收集(补)
收集 Tomcat 日志 安装 Tomcat # 安装 jdk [root@web01 ~]# rpm -ivh jdk-8u181-linux-x64.rpm # 下载 [root@web01 ~] ...
- LVS-DR 模式
SNAT(Source Network Address Translation)源地址转换,类似家里路由器设置,内网地址向外访问时,发起访问的内网ip地址转换为指定的 IP 地址 DNAT(Desti ...
- Shell 函数 & 数组
Shell 函数 函数介绍 # 什么是函数? 具备某一功能的工具 => 函数 事先准备工具的过程 => 函数的定义 遇到应用场景拿来就用 => 函数的调用 # 为何要用函数? 没有引 ...
- HTTP 请求过程以及报文结构
目录 HTTP 请求流程 HTTP 请求报文 请求行 方法字段(Request Method) URL字段(Uniform Resource Locator) HTTP 协议版本字段(略) 请求/响应 ...
- 忘记Mysql的root用户密码处理方法(以mysql 5.5.33为例)
1.修改mysql服务器的脚本 ~]#vi /etc/rc.d/init.d/mysqld #找到$bindir/mysqld_safe --datadir="$datadir" ...