问题

在用pytorch生成对抗网络的时候,出现错误Runtime Error: one of the variables needed for gradient computation has been modified by an inplace operation,特记录排坑记录。

环境

windows10 2004

python 3.7.4

pytorch 1.7.0 + cpu

解决过程

  • 尝试一

这段错误代码看上去不难理解,意思为:计算梯度所需的某变量已被一就地操作修改。什么是就地操作呢,举个例子如x += 1就是典型的就地操作,可将其改为y = x + 1。但很遗憾,这样并没有解决我的问题,这种方法的介绍如下。

在网上搜了很多相关博客,大多原因如下:

由于0.4.0把Varible和Tensor融合为一个Tensor,inplace操作,之前对Varible能用,但现在对Tensor,就会出错了。

所以解决方案很简单:将所有inplace操作转换为非inplace操作。如将x += 1换为y = x + 1

仍然有一个问题,即如何找到inplace操作,这里提供一个小trick:分阶段调用y.backward(),若报错,则说明这之前有问题;反之则说明错误在该行之后。

  • 尝试二

在我的代码里根本就没有找到任何inplace操作,因此上面这种方法行不通。自己盯着代码,debug,啥也看不出来,好久......

忽然有了新idea。我的训练阶段的代码如下:

for epoch in range(1, epochs + 1):
for idx, (lr, hr) in enumerate(traindata_loader):
lrs = lr.to(device)
hrs = hr.to(device) # update the discriminator
netD.zero_grad()
logits_fake = netD(netG(lrs).detach())
logits_real = netD(hrs)
# Label smoothing
real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
d_loss = bce(logits_real, real) + bce(logits_fake, fake)
d_loss.backward(retain_graph=True)
optimizerD.step() # update the generator
netG.zero_grad()
# !!!问题出错行
g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
g_loss.backward()
optimizerG.step()

判别器loss的backward是正常的,生成器loss的backward有问题。观察到g_loss由两项组成,所以很自然的想法就是删掉其中一项看是否正常。结果为:只保留第一项程序正常运行;g_loss中包含第二项程序就出错。

因此去看了adversarialLoss的代码:

class AdversarialLoss(nn.Module):
def __init__(self):
super(AdversarialLoss, self).__init__()
self.bec_loss = nn.BCELoss() def forward(self, logits_fake):
# Adversarial Loss
# !!! 问题在这,logits_fake加上detach后就可以正常运行
adversarial_loss = self.bec_loss(logits_fake, torch.ones_like(logits_fake))
return 0.001 * adversarial_loss

看不出来任何问题,只能挨个试。这里只有两个变量:logits_faketorch.ones_like(logits_fake)。后者为常量,所以试着固定logits_fake,不让其参与训练,程序竟能运行了!

class AdversarialLoss(nn.Module):
def __init__(self):
super(AdversarialLoss, self).__init__()
self.bec_loss = nn.BCELoss() def forward(self, logits_fake):
# Adversarial Loss
# !!! 问题在这,logits_fake加上detach后就可以正常运行
adversarial_loss = self.bec_loss(logits_fake.detach(), torch.ones_like(logits_fake))
return 0.001 * adversarial_loss

由此知道了被修改的变量是logits_fake。尽管程序可以运行了,但这样做不一定合理。类AdversarialLoss中没有对logits_fake进行修改,所以返回刚才的训练程序中。

for epoch in range(1, epochs + 1):
for idx, (lr, hr) in enumerate(traindata_loader):
lrs = lr.to(device)
hrs = hr.to(device) # update the discriminator
netD.zero_grad()
logits_fake = netD(netG(lrs).detach())
logits_real = netD(hrs)
# Label smoothing
real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
d_loss = bce(logits_real, real) + bce(logits_fake, fake)
d_loss.backward(retain_graph=True)
# 这里进行的更新操作
optimizerD.step() # update the generator
netG.zero_grad()
# !!!问题出错行
g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
g_loss.backward()
optimizerG.step()

注意到Discriminator在出错行之前进行了更新操作,因此真相呼之欲出————optimizerD.step()logits_fake进行了修改。直接将其挪到倒数第二行即可,修改后代码为:

for epoch in range(1, epochs + 1):
for idx, (lr, hr) in enumerate(traindata_loader):
lrs = lr.to(device)
hrs = hr.to(device) # update the discriminator
netD.zero_grad()
logits_fake = netD(netG(lrs).detach())
logits_real = netD(hrs)
# Label smoothing
real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
d_loss = bce(logits_real, real) + bce(logits_fake, fake)
d_loss.backward(retain_graph=True) # update the generator
netG.zero_grad()
g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
g_loss.backward()
optimizerD.step()
optimizerG.step()

程序终于正常运行了,耶( •̀ ω •́ )y!

总结

原因:在计算生成器网络梯度之前先对判别器进行更新,修改了某些值,导致Generator网络的梯度计算失败。

解决方法:将Discriminator的更新步骤放到Generator的梯度计算步骤后面。

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation的更多相关文章

  1. RuntimeError: one of the variables needed for gradient computation has been modified by an inplace

    vgg里面的 ReLU默认的参数inplace=True 当我们调用vgg结构的时候注意 要将inplace改成 False 不然会报错 RuntimeError: one of the variab ...

  2. one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [3, 1280, 28, 28]], which is output 0 of LeakyReluBackward1, is at version 2;

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace o ...

  3. TensorFlow 学习(八)—— 梯度计算(gradient computation)

    maxpooling 的 max 函数关于某变量的偏导也是分段的,关于它就是 1,不关于它就是 0: BP 是反向传播求关于参数的偏导,SGD 则是梯度更新,是优化算法: 1. 一个实例 relu = ...

  4. pytorch .detach() .detach_() 和 .data用于切断反向传播

    参考:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-autograd/#detachsource 当我们再训 ...

  5. PyTorch学习笔记及问题处理

    1.torch.nn.state_dict(): 返回一个字典,保存着module的所有状态(state). parameters和persistent_buffers都会包含在字典中,字典的key就 ...

  6. pytorch的自动求导机制 - 计算图的建立

    一.计算图简介 在pytorch的官网上,可以看到一个简单的计算图示意图, 如下. import torchfrom torch.autograd import Variable x = Variab ...

  7. [源码解析]PyTorch如何实现前向传播(2) --- 基础类(下)

    [源码解析]PyTorch如何实现前向传播(2) --- 基础类(下) 目录 [源码解析]PyTorch如何实现前向传播(2) --- 基础类(下) 0x00 摘要 0x01 前文回顾 0x02 Te ...

  8. Coursera Deep Learning 2 Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization - week1, Assignment(Gradient Checking)

    声明:所有内容来自coursera,作为个人学习笔记记录在这里. Gradient Checking Welcome to the final assignment for this week! In ...

  9. 课程二(Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization),第一周(Practical aspects of Deep Learning) —— 4.Programming assignments:Gradient Checking

    Gradient Checking Welcome to this week's third programming assignment! You will be implementing grad ...

随机推荐

  1. APPCNA 指纹验证登录

    今天在APP中集成了指纹与手势登录功能,本文章分两部分进行记录.一是手势功能的逻辑.二是代码实现.该APP是采用APPCAN开发,直接用其已写好的插件,调用相当接口就要可以了. 1.在APP的个人中心 ...

  2. Metasploit之令牌窃取

    令牌简介及原理 令牌(Token) 就是系统的临时密钥,相当于账户名和密码,用来决定是否允) 许这次请求和判断这次请求是属于哪一个用户的.它允许你在不提供密码或其他凭证的前提下,访问网络和系统资源.这 ...

  3. Go-常量-const

    常量:只能读,不能修改,编译前就是确定的值 关键字: const 常量相关类型:int8,16,32,64 float32,64 bool string 可计算结果数学表达式 常量方法 iota pa ...

  4. js简单数据类型和复杂数据类型

    var timer = null;  //简单数据类型null 返回的是一个空的对象 object console.log(typeof timer); 1.简单数据类型 在内存中存放在栈中,在里面开 ...

  5. 重拾H5小游戏之入门篇(二)

    上一篇,水了近千字,很酸爽,同时表达了"重拾"一项旧本领并不容易,还有点题之效果.其实压缩起来就一句话:经过了一番记忆搜索,以及try..catch的尝试后,终于选定了Phaser ...

  6. localhost与127.0.0.1与0.0.0.0

    localhost localhost其实是域名,一般系统默认将localhost指向127.0.0.1,但是localhost并不等于127.0.0.1,localhost指向的IP地址是可以配置的 ...

  7. 【树】HNOI2014 米特运输

    题目大意 洛谷链接 给出一课点带权的树,修改一些点的权值使该树满足: 同一个父亲的儿子权值必须相同 父亲的取值必须是所有儿子权值之和 输入格式 第一行是一个正整数\(N\),表示节点的数目. 接下来\ ...

  8. Zookeeper(1)---初识

    一.ZK简述 Zookeeper,它是一个分布式程序的协调服务,它主要是用来解决分布式应用中的一些数据管理问题,比如集群管理,分布式应用配置,分布式锁,服务注册/发现等等. 它是一个类似于文件系统的树 ...

  9. js拖拽上传 文件上传之拖拽上传

    由于项目需要上传文件到服务器,于是便在文件上传的基础上增加了拖拽上传.拖拽上传当然属于文件上传的一部分,只不过在文件上传的基础上增加了拖拽的界面,主要在于前台的交互, 从拖拽的文件中获取文件列表然后调 ...

  10. 第三十四章 Linux常规练习题(一)参考答案

    一.练习题一 1.超级用户(管理员用户)提示符是___#_,普通用户提示符是___$_. 2.linux关机重启的命令有哪些 ? 关机命令 重启命令 shutdown -h now shutdown ...