问题

在用pytorch生成对抗网络的时候,出现错误Runtime Error: one of the variables needed for gradient computation has been modified by an inplace operation,特记录排坑记录。

环境

windows10 2004

python 3.7.4

pytorch 1.7.0 + cpu

解决过程

  • 尝试一

这段错误代码看上去不难理解,意思为:计算梯度所需的某变量已被一就地操作修改。什么是就地操作呢,举个例子如x += 1就是典型的就地操作,可将其改为y = x + 1。但很遗憾,这样并没有解决我的问题,这种方法的介绍如下。

在网上搜了很多相关博客,大多原因如下:

由于0.4.0把Varible和Tensor融合为一个Tensor,inplace操作,之前对Varible能用,但现在对Tensor,就会出错了。

所以解决方案很简单:将所有inplace操作转换为非inplace操作。如将x += 1换为y = x + 1

仍然有一个问题,即如何找到inplace操作,这里提供一个小trick:分阶段调用y.backward(),若报错,则说明这之前有问题;反之则说明错误在该行之后。

  • 尝试二

在我的代码里根本就没有找到任何inplace操作,因此上面这种方法行不通。自己盯着代码,debug,啥也看不出来,好久......

忽然有了新idea。我的训练阶段的代码如下:

  1. for epoch in range(1, epochs + 1):
  2. for idx, (lr, hr) in enumerate(traindata_loader):
  3. lrs = lr.to(device)
  4. hrs = hr.to(device)
  5. # update the discriminator
  6. netD.zero_grad()
  7. logits_fake = netD(netG(lrs).detach())
  8. logits_real = netD(hrs)
  9. # Label smoothing
  10. real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
  11. fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
  12. d_loss = bce(logits_real, real) + bce(logits_fake, fake)
  13. d_loss.backward(retain_graph=True)
  14. optimizerD.step()
  15. # update the generator
  16. netG.zero_grad()
  17. # !!!问题出错行
  18. g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
  19. g_loss.backward()
  20. optimizerG.step()

判别器loss的backward是正常的,生成器loss的backward有问题。观察到g_loss由两项组成,所以很自然的想法就是删掉其中一项看是否正常。结果为:只保留第一项程序正常运行;g_loss中包含第二项程序就出错。

因此去看了adversarialLoss的代码:

  1. class AdversarialLoss(nn.Module):
  2. def __init__(self):
  3. super(AdversarialLoss, self).__init__()
  4. self.bec_loss = nn.BCELoss()
  5. def forward(self, logits_fake):
  6. # Adversarial Loss
  7. # !!! 问题在这,logits_fake加上detach后就可以正常运行
  8. adversarial_loss = self.bec_loss(logits_fake, torch.ones_like(logits_fake))
  9. return 0.001 * adversarial_loss

看不出来任何问题,只能挨个试。这里只有两个变量:logits_faketorch.ones_like(logits_fake)。后者为常量,所以试着固定logits_fake,不让其参与训练,程序竟能运行了!

  1. class AdversarialLoss(nn.Module):
  2. def __init__(self):
  3. super(AdversarialLoss, self).__init__()
  4. self.bec_loss = nn.BCELoss()
  5. def forward(self, logits_fake):
  6. # Adversarial Loss
  7. # !!! 问题在这,logits_fake加上detach后就可以正常运行
  8. adversarial_loss = self.bec_loss(logits_fake.detach(), torch.ones_like(logits_fake))
  9. return 0.001 * adversarial_loss

由此知道了被修改的变量是logits_fake。尽管程序可以运行了,但这样做不一定合理。类AdversarialLoss中没有对logits_fake进行修改,所以返回刚才的训练程序中。

  1. for epoch in range(1, epochs + 1):
  2. for idx, (lr, hr) in enumerate(traindata_loader):
  3. lrs = lr.to(device)
  4. hrs = hr.to(device)
  5. # update the discriminator
  6. netD.zero_grad()
  7. logits_fake = netD(netG(lrs).detach())
  8. logits_real = netD(hrs)
  9. # Label smoothing
  10. real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
  11. fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
  12. d_loss = bce(logits_real, real) + bce(logits_fake, fake)
  13. d_loss.backward(retain_graph=True)
  14. # 这里进行的更新操作
  15. optimizerD.step()
  16. # update the generator
  17. netG.zero_grad()
  18. # !!!问题出错行
  19. g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
  20. g_loss.backward()
  21. optimizerG.step()

注意到Discriminator在出错行之前进行了更新操作,因此真相呼之欲出————optimizerD.step()logits_fake进行了修改。直接将其挪到倒数第二行即可,修改后代码为:

  1. for epoch in range(1, epochs + 1):
  2. for idx, (lr, hr) in enumerate(traindata_loader):
  3. lrs = lr.to(device)
  4. hrs = hr.to(device)
  5. # update the discriminator
  6. netD.zero_grad()
  7. logits_fake = netD(netG(lrs).detach())
  8. logits_real = netD(hrs)
  9. # Label smoothing
  10. real = (torch.rand(logits_real.size()) * 0.25 + 0.85).clone().detach().to(device)
  11. fake = (torch.rand(logits_fake.size()) * 0.15).clone().detach().to(device)
  12. d_loss = bce(logits_real, real) + bce(logits_fake, fake)
  13. d_loss.backward(retain_graph=True)
  14. # update the generator
  15. netG.zero_grad()
  16. g_loss = contentLoss(netG(lrs), hrs) + adversarialLoss(logits_fake)
  17. g_loss.backward()
  18. optimizerD.step()
  19. optimizerG.step()

程序终于正常运行了,耶( •̀ ω •́ )y!

总结

原因:在计算生成器网络梯度之前先对判别器进行更新,修改了某些值,导致Generator网络的梯度计算失败。

解决方法:将Discriminator的更新步骤放到Generator的梯度计算步骤后面。

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation的更多相关文章

  1. RuntimeError: one of the variables needed for gradient computation has been modified by an inplace

    vgg里面的 ReLU默认的参数inplace=True 当我们调用vgg结构的时候注意 要将inplace改成 False 不然会报错 RuntimeError: one of the variab ...

  2. one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [3, 1280, 28, 28]], which is output 0 of LeakyReluBackward1, is at version 2;

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace o ...

  3. TensorFlow 学习(八)—— 梯度计算(gradient computation)

    maxpooling 的 max 函数关于某变量的偏导也是分段的,关于它就是 1,不关于它就是 0: BP 是反向传播求关于参数的偏导,SGD 则是梯度更新,是优化算法: 1. 一个实例 relu = ...

  4. pytorch .detach() .detach_() 和 .data用于切断反向传播

    参考:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-autograd/#detachsource 当我们再训 ...

  5. PyTorch学习笔记及问题处理

    1.torch.nn.state_dict(): 返回一个字典,保存着module的所有状态(state). parameters和persistent_buffers都会包含在字典中,字典的key就 ...

  6. pytorch的自动求导机制 - 计算图的建立

    一.计算图简介 在pytorch的官网上,可以看到一个简单的计算图示意图, 如下. import torchfrom torch.autograd import Variable x = Variab ...

  7. [源码解析]PyTorch如何实现前向传播(2) --- 基础类(下)

    [源码解析]PyTorch如何实现前向传播(2) --- 基础类(下) 目录 [源码解析]PyTorch如何实现前向传播(2) --- 基础类(下) 0x00 摘要 0x01 前文回顾 0x02 Te ...

  8. Coursera Deep Learning 2 Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization - week1, Assignment(Gradient Checking)

    声明:所有内容来自coursera,作为个人学习笔记记录在这里. Gradient Checking Welcome to the final assignment for this week! In ...

  9. 课程二(Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization),第一周(Practical aspects of Deep Learning) —— 4.Programming assignments:Gradient Checking

    Gradient Checking Welcome to this week's third programming assignment! You will be implementing grad ...

随机推荐

  1. Java面试题之计算字符/字符串出现的次数

    一.计算字符在给定字符串中出现的次数 二.计算字符串在给定字符串中出现的次数 1 import java.util.HashMap; 2 import java.util.Map; 3 4 publi ...

  2. 【Hadoop】伪分布式安装

    创建hadoop用户 创建用户命令: sudo useradd -m hadoop -s /bin/bash 创建好后需要更改hadoop用户的密码,命令如下: sudo passwd hadoop ...

  3. 089 01 Android 零基础入门 02 Java面向对象 02 Java封装 01 封装的实现 03 # 088 01 Android 零基础入门 02 Java面向对象 02 Java封装 01 封装的实现 03 使用包进行类管理(1)——创建包

    089 01 Android 零基础入门 02 Java面向对象 02 Java封装 01 封装的实现 03 # 088 01 Android 零基础入门 02 Java面向对象 02 Java封装 ...

  4. 079 01 Android 零基础入门 02 Java面向对象 01 Java面向对象基础 01 初识面向对象 04 实例化对象

    079 01 Android 零基础入门 02 Java面向对象 01 Java面向对象基础 01 初识面向对象 04 实例化对象 本文知识点:实例化对象 说明:因为时间紧张,本人写博客过程中只是对知 ...

  5. 047 01 Android 零基础入门 01 Java基础语法 05 Java流程控制之循环结构 09 嵌套while循环应用

    047 01 Android 零基础入门 01 Java基础语法 05 Java流程控制之循环结构 09 嵌套while循环应用 本文知识点:嵌套while循环应用 什么是循环嵌套? 什么是循环嵌套? ...

  6. CAD常用知识点

    1.Ctrl+9:打开命令窗口: 2.删除标注或者其他(选择对象过滤器):输入fi后回车会出现对象选择过滤器窗口,以删除标注为例,点击选择过滤器-----标注 按以下顺序点击后回车, 框选要去掉的标注 ...

  7. windev的内部窗口传参方式及其与类的相似性

    最近的应用,需要向一个内部窗口(internal window)传参,因为官方文档的说明较为宽泛,虽然结果只有两小段代码,但也费了很大的劲.把所有关于procedure的文档看一遍,又是重新学习了一遍 ...

  8. Java安全之Commons Collections1分析(一)

    Java安全之Commons Collections1分析(一) 0x00 前言 在CC链中,其实具体执行过程还是比较复杂的.建议调试前先将一些前置知识的基础给看一遍. Java安全之Commons ...

  9. 玩转控件:GDI+动态绘制流程图

       前言 今天,要跟大家一起分享是"GDI+动态生成流程图"的功能.别看名字高大上(也就那样儿--!),其实就是动态生成控件,然后GDI+绘制直线连接控件罢了.实际项目效果图如下 ...

  10. Python基本语法之数据类型(总览)

    Python的八种数据类型 Number,数值类型 String,字符串,主要用于描述文本 List,列表,一个包含元素的序列 Tuple,元组,和列表类似,但其是不可变的 Set,一个包含元素的集合 ...