RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
问题
在用pytorch跑生成对抗网络的时候,出现错误Runtime Error: one of the variables needed for gradient computation has been modified by an inplace operation
,特记录排坑记录。
环境
windows10 2004
python 3.7.4
pytorch 1.7.0 + cpu
解决过程
- 尝试一
这段错误代码看上去不难理解,意思为:计算梯度所需的某变量已被一就地操作修改。什么是就地操作呢,举个例子如x += 1
就是典型的就地操作,可将其改为y = x + 1
。但很遗憾,这样并没有解决我的问题,这种方法的介绍如下。
在网上搜了很多相关博客,大多原因如下:
由于0.4.0把Varible和Tensor融合为一个Tensor,inplace操作,之前对Varible能用,但现在对Tensor,就会出错了。
所以解决方案很简单:将所有inplace操作转换为非inplace操作。如将x += 1
换为y = x + 1
。
仍然有一个问题,即如何找到inplace操作,这里提供一个小trick:分阶段调用y.backward()
,若报错,则说明这之前有问题;反之则说明错误在该行之后。
- 尝试二
在我的代码里根本就没有找到任何inplace操作,因此上面这种方法行不通。自己盯着代码,debug,啥也看不出来,好久......
忽然有了新idea。我的训练阶段的代码如下:
for epoch in range(1, epochs + 1):
for idx, (lr, hr) in enumerate(traindata_loader):
lrs = lr.to(device)
hrs = hr.to(device)
# update the discriminator
netD.zero_grad()
logits_fake = netD(netG(lrs).detach())
logits_real = netD(hrs)
# Label smoothing
real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
d_loss = bce(logits_real, real) + bce(logits_fake, fake)
d_loss.backward(retain_graph=True)
optimizerD.step()
# update the generator
netG.zero_grad()
# !!!问题出错行
g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
g_loss.backward()
optimizerG.step()
判别器loss的backward是正常的,生成器loss的backward有问题。观察到g_loss由两项组成,所以很自然的想法就是删掉其中一项看是否正常。结果为:只保留第一项程序正常运行;g_loss中包含第二项程序就出错。
因此去看了adversarialLoss
的代码:
class AdversarialLoss(nn.Module):
def __init__(self):
super(AdversarialLoss, self).__init__()
self.bec_loss = nn.BCELoss()
def forward(self, logits_fake):
# Adversarial Loss
# !!! 问题在这,logits_fake加上detach后就可以正常运行
adversarial_loss = self.bec_loss(logits_fake, torch.ones_like(logits_fake))
return 0.001 * adversarial_loss
看不出来任何问题,只能挨个试。这里只有两个变量:logits_fake
和torch.ones_like(logits_fake)
。后者为常量,所以试着固定logits_fake
,不让其参与训练,程序竟能运行了!
class AdversarialLoss(nn.Module):
def __init__(self):
super(AdversarialLoss, self).__init__()
self.bec_loss = nn.BCELoss()
def forward(self, logits_fake):
# Adversarial Loss
# !!! 问题在这,logits_fake加上detach后就可以正常运行
adversarial_loss = self.bec_loss(logits_fake.detach(), torch.ones_like(logits_fake))
return 0.001 * adversarial_loss
由此知道了被修改的变量是logits_fake。尽管程序可以运行了,但这样做不一定合理。类AdversarialLoss
中没有对logits_fake
进行修改,所以返回刚才的训练程序中。
for epoch in range(1, epochs + 1):
for idx, (lr, hr) in enumerate(traindata_loader):
lrs = lr.to(device)
hrs = hr.to(device)
# update the discriminator
netD.zero_grad()
logits_fake = netD(netG(lrs).detach())
logits_real = netD(hrs)
# Label smoothing
real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
d_loss = bce(logits_real, real) + bce(logits_fake, fake)
d_loss.backward(retain_graph=True)
# 这里进行的更新操作
optimizerD.step()
# update the generator
netG.zero_grad()
# !!!问题出错行
g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
g_loss.backward()
optimizerG.step()
注意到Discriminator在出错行之前进行了更新操作,因此真相呼之欲出————optimizerD.step()
对logits_fake
进行了修改。直接将其挪到倒数第二行即可,修改后代码为:
for epoch in range(1, epochs + 1):
for idx, (lr, hr) in enumerate(traindata_loader):
lrs = lr.to(device)
hrs = hr.to(device)
# update the discriminator
netD.zero_grad()
logits_fake = netD(netG(lrs).detach())
logits_real = netD(hrs)
# Label smoothing
real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
d_loss = bce(logits_real, real) + bce(logits_fake, fake)
d_loss.backward(retain_graph=True)
# update the generator
netG.zero_grad()
g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
g_loss.backward()
optimizerD.step()
optimizerG.step()
程序终于正常运行了,耶( •̀ ω •́ )y!
总结
原因:在计算生成器网络梯度之前先对判别器进行更新,修改了某些值,导致Generator网络的梯度计算失败。
解决方法:将Discriminator的更新步骤放到Generator的梯度计算步骤后面。
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation的更多相关文章
- RuntimeError: one of the variables needed for gradient computation has been modified by an inplace
vgg里面的 ReLU默认的参数inplace=True 当我们调用vgg结构的时候注意 要将inplace改成 False 不然会报错 RuntimeError: one of the variab ...
- one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [3, 1280, 28, 28]], which is output 0 of LeakyReluBackward1, is at version 2;
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace o ...
- TensorFlow 学习(八)—— 梯度计算(gradient computation)
maxpooling 的 max 函数关于某变量的偏导也是分段的,关于它就是 1,不关于它就是 0: BP 是反向传播求关于参数的偏导,SGD 则是梯度更新,是优化算法: 1. 一个实例 relu = ...
- pytorch .detach() .detach_() 和 .data用于切断反向传播
参考:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-autograd/#detachsource 当我们再训 ...
- PyTorch学习笔记及问题处理
1.torch.nn.state_dict(): 返回一个字典,保存着module的所有状态(state). parameters和persistent_buffers都会包含在字典中,字典的key就 ...
- pytorch的自动求导机制 - 计算图的建立
一.计算图简介 在pytorch的官网上,可以看到一个简单的计算图示意图, 如下. import torchfrom torch.autograd import Variable x = Variab ...
- [源码解析]PyTorch如何实现前向传播(2) --- 基础类(下)
[源码解析]PyTorch如何实现前向传播(2) --- 基础类(下) 目录 [源码解析]PyTorch如何实现前向传播(2) --- 基础类(下) 0x00 摘要 0x01 前文回顾 0x02 Te ...
- Coursera Deep Learning 2 Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization - week1, Assignment(Gradient Checking)
声明:所有内容来自coursera,作为个人学习笔记记录在这里. Gradient Checking Welcome to the final assignment for this week! In ...
- 课程二(Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization),第一周(Practical aspects of Deep Learning) —— 4.Programming assignments:Gradient Checking
Gradient Checking Welcome to this week's third programming assignment! You will be implementing grad ...
随机推荐
- Java面试题之计算字符/字符串出现的次数
一.计算字符在给定字符串中出现的次数 二.计算字符串在给定字符串中出现的次数 1 import java.util.HashMap; 2 import java.util.Map; 3 4 publi ...
- 【Hadoop】伪分布式安装
创建hadoop用户 创建用户命令: sudo useradd -m hadoop -s /bin/bash 创建好后需要更改hadoop用户的密码,命令如下: sudo passwd hadoop ...
- 089 01 Android 零基础入门 02 Java面向对象 02 Java封装 01 封装的实现 03 # 088 01 Android 零基础入门 02 Java面向对象 02 Java封装 01 封装的实现 03 使用包进行类管理(1)——创建包
089 01 Android 零基础入门 02 Java面向对象 02 Java封装 01 封装的实现 03 # 088 01 Android 零基础入门 02 Java面向对象 02 Java封装 ...
- 079 01 Android 零基础入门 02 Java面向对象 01 Java面向对象基础 01 初识面向对象 04 实例化对象
079 01 Android 零基础入门 02 Java面向对象 01 Java面向对象基础 01 初识面向对象 04 实例化对象 本文知识点:实例化对象 说明:因为时间紧张,本人写博客过程中只是对知 ...
- 047 01 Android 零基础入门 01 Java基础语法 05 Java流程控制之循环结构 09 嵌套while循环应用
047 01 Android 零基础入门 01 Java基础语法 05 Java流程控制之循环结构 09 嵌套while循环应用 本文知识点:嵌套while循环应用 什么是循环嵌套? 什么是循环嵌套? ...
- CAD常用知识点
1.Ctrl+9:打开命令窗口: 2.删除标注或者其他(选择对象过滤器):输入fi后回车会出现对象选择过滤器窗口,以删除标注为例,点击选择过滤器-----标注 按以下顺序点击后回车, 框选要去掉的标注 ...
- windev的内部窗口传参方式及其与类的相似性
最近的应用,需要向一个内部窗口(internal window)传参,因为官方文档的说明较为宽泛,虽然结果只有两小段代码,但也费了很大的劲.把所有关于procedure的文档看一遍,又是重新学习了一遍 ...
- Java安全之Commons Collections1分析(一)
Java安全之Commons Collections1分析(一) 0x00 前言 在CC链中,其实具体执行过程还是比较复杂的.建议调试前先将一些前置知识的基础给看一遍. Java安全之Commons ...
- 玩转控件:GDI+动态绘制流程图
前言 今天,要跟大家一起分享是"GDI+动态生成流程图"的功能.别看名字高大上(也就那样儿--!),其实就是动态生成控件,然后GDI+绘制直线连接控件罢了.实际项目效果图如下 ...
- Python基本语法之数据类型(总览)
Python的八种数据类型 Number,数值类型 String,字符串,主要用于描述文本 List,列表,一个包含元素的序列 Tuple,元组,和列表类似,但其是不可变的 Set,一个包含元素的集合 ...