我最近在学使用Pytorch写GAN代码,发现有些代码在训练部分细节有略微不同,其中有的人用到了detach()函数截断梯度流,有的人没用detch(),取而代之的是在损失函数在反向传播过程中将backward(retain_graph=True),本文通过两个 gan 的代码,介绍它们的作用,并分析,不同的更新策略对程序效率的影响。

  这两个 GAN 的实现中,有两种不同的训练策略:

  • 先训练判别器(discriminator),再训练生成器(generator),这是原始论文Generative Adversarial Networks 中的算法
  • 先训练generator,再训练discriminator

  为了减少网络垃圾,GAN的原理网上一大堆,我这里就不重复赘述了,想要详细了解GAN原理的朋友,可以参考我专题文章:神经网络结构:生成式对抗网络(GAN)

需要了解的知识:

  detach():截断node反向传播的梯度流,将某个node变成不需要梯度的Varibale,因此当反向传播经过这个node时,梯度就不会从这个node往前面传播

更新策略

  我们直接下面进入本文正题,即,在 pytorch 中,detach 和 retain_graph 是干什么用的?本文将借助三段 GAN 的实现代码,来举例介绍它们的作用。

先训练判别器,再训练生成器

策略一

我们分析循环中一个 step 的代码:

valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device)  # 真实标签,都是1
fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device) # 假标签,都是0 # ########################
# 训练判别器 #
# ########################
real_imgs = imgs.to(device) # 真实图片
z = torch.randn((imgs.shape[0], 100)).to(device) # 噪声 gen_imgs = generator(z) # 从噪声中生成假数据
pred_gen = discriminator(gen_imgs) # 判别器对假数据的输出
pred_real = discriminator(real_imgs) # 判别器对真数据的输出 optimizer_D.zero_grad() # 把判别器中所有参数的梯度归零
real_loss = adversarial_loss(pred_real, valid) # 判别器对真实样本的损失
fake_loss = adversarial_loss(pred_gen, fake) # 判别器对假样本的损失
d_loss = (real_loss + fake_loss) / 2 # 两项损失相加取平均 # 下面这行代码十分重要,将在正文着重讲解
d_loss.backward(retain_graph=True) # retain_graph=True 十分重要,否则计算图内存将会被释放
optimizer_D.step() # 判别器参数更新 # ########################
# 训练生成器 #
# ########################
g_loss = adversarial_loss(pred_gen, valid) # 生成器的损失函数
optimizer_G.zero_grad() # 生成器参数梯度归零
g_loss.backward() # 生成器的损失函数梯度反向传播
optimizer_G.step() # 生成器参数更新

代码讲解

  鉴别器的损失函数d_loss是由real_loss和fake_loss组成的,而fake_loss又是noise经过generator来的。这样一来我们对d_loss进行反向传播,不仅会计算discriminator 的梯度还会计算generator 的梯度(虽然这一步optimizer_D.step()只更新 discriminator 的参数),因此下面在更新generator参数时,要先将generator参数的梯度清零,避免受到discriminator loss 回传过来的梯度影响。

  generator 的 损失在回传时,同样要经过 discriminator 网络才能传递回自身(系统从输入噪声到 Discriminator 输出,从头到尾只有一次前向传播,而有两次反向传播,故在第一次反向传播时,鉴别器要设置 backward(retain graph=True),保持计算图不被释放。因为 pytorch 默认一个计算图只计算一次反向传播,反向传播后,这个计算图的内存就会被释放,所以用这个参数控制计算图不被释放。因此,在回传梯度时,同样也计算了一遍 discriminator 的参数梯度,只不过这次 discriminator 的参数不更新,只更新 generator 的参数,即 optimizer_G.step()。同时,我们看到,下一个 step 首先将 discriminator 的梯度重置为 0,就是为了防止 generator loss 反向传播时顺带计算的梯度对其造成影响(还有上一步 discriminator loss 回传时累积的梯度)。

  综上,我们看到,为了完成一步参数更新,我们进行了两次反向传播,第一次反向传播为了更新 discriminator 的参数,但多余计算了 generator 的梯度。第二次反向传播为了更新 generator 的参数,但是计算了 discriminator 的梯度,因此在写一个step,需要立即清零discriminator梯度。

  如果你实在看不懂,就照着这个形式写代码就行了,反正形式都帮你们写好了

策略二

  这种策略我遇到的比较多,也是先训练鉴别器,再训练生成器

  鉴别器训练阶段,noise 从 generator 输入,输出 fake data,然后 detach 一下,随着 true data 一起输入 discriminator,计算 discriminator 损失,并更新 discriminator 参数。生成器训练阶段,把没经过 detach 的 fake data 输入到discriminator 中,计算 generator loss,再反向传播梯度,更新 generator 的参数。这种策略,计算了两次 discriminator 梯度,一次 generator 梯度。感觉这种比较符合先更新 discriminator 的习惯。缺点是,之前的 generator 生成的计算图得保留着,直到 discriminator 更新完,再释放。

valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device)  # 真实标签,都是1
fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device) # 假标签,都是0 # ########################
# 训练判别器 #
# ########################
real_imgs = imgs.to(device) # 真实图片
z = torch.randn((imgs.shape[0], 100)).to(device) # 噪声 gen_imgs = generator(z) # 从噪声中生成假数据
pred_gen = discriminator(gen_imgs.detach()) # 假数据detach(),判别器对假数据的输出
pred_real = discriminator(real_imgs) # 判别器对真数据的输出 optimizer_D.zero_grad() # 把判别器中所有参数的梯度归零
real_loss = adversarial_loss(pred_real, valid) # 判别器对真实样本的损失
fake_loss = adversarial_loss(pred_gen, fake) # 判别器对假样本的损失
d_loss = (real_loss + fake_loss) / 2 # 两项损失相加取平均 # 下面这行代码十分重要,将在正文着重讲解
d_loss.backward() # retain_graph=True 十分重要,否则计算图内存将会被释放
optimizer_D.step() # 判别器参数更新 # ########################
# 训练生成器 #
# ########################
g_loss = adversarial_loss(pred_gen, valid) # 生成器的损失函数
optimizer_G.zero_grad() # 生成器参数梯度归零
g_loss.backward() # 生成器的损失函数梯度反向传播
optimizer_G.step() # 生成器参数更新

先训练生成器,再训练判别器

我们分析循环中一个 step 的代码:

valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)  # 真实样本的标签,都是 1
fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) # 生成样本的标签,都是 0
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) # 噪声
real_imgs = Variable(imgs.type(Tensor)) # 真实图片 # ########################
# 训练生成器 #
# ########################
optimizer_G.zero_grad() # 生成器参数梯度归零
gen_imgs = generator(z) # 根据噪声生成虚假样本
g_loss = adversarial_loss(discriminator(gen_imgs), valid) # 用真实的标签+假样本,计算生成器损失
g_loss.backward() # 生成器梯度反向传播,反向传播经过了判别器,故此时判别器参数也有梯度
optimizer_G.step() # 生成器参数更新,判别器参数虽然有梯度,但是这一步不能更新判别器 # ########################
# 训练判别器 #
# ########################
optimizer_D.zero_grad() # 把生成器损失函数梯度反向传播时,顺带计算的判别器参数梯度清空
real_loss = adversarial_loss(discriminator(real_imgs), valid) # 真样本+真标签:判别器损失
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) # 假样本+假标签:判别器损失
d_loss = (real_loss + fake_loss) / 2 # 判别器总的损失函数
d_loss.backward() # 判别器损失回传
optimizer_D.step() # 判别器参数更新

  为了更新生成器参数,用生成器的损失函数计算梯度,然后反向传播,传播图中经过了判别器,根据链式法则,不得不顺带计算一下判别器的参数梯度,虽然在这一步不会更新判别器参数。反向传播过后,noise 到 fake image 再到 discriminator 的输出这个前向传播的计算图就被释放掉了,后面也不会再用到。

  接着更新判别器参数,此时注意到,我们输入判别器的是两部分,一部分是真实数据,另一部分是生成器的输出,也就是假数据。注意观察细节,在判别器前向传播过程,输入的假数据被 detach 了,detach 的意思是,这个数据和生成它的计算图“脱钩”了,即梯度传到它那个地方就停了,不再继续往前传播(实际上也不会再往前传播了,因为 generator 的计算图在第一次反向传播过后就被释放了)。因此,判别器梯度反向传播,就到它自己身上为止。

  因此,比起第一种策略,这种策略要少计算一次 generator 的所有参数的梯度,同时,也不必刻意保存一次计算图,占用不必要的内存。

  但需要注意的是,在第一种策略中,noise 从 generator 输入,到 discriminator 输出,只经历了一次前向传播,discriminator 端的输出,被用了两次,一次是计算 discriminator 的损失函数,另一次是计算 generator 的损失函数。

  而在第这种策略中,noise 从 generator 输入,到discriminator 输出,计算 generator 损失,回传,这一步更新了 generator 的参数,并释放了计算图。下一步更新 discriminator 的参数时,generator 的输出经过 detach 后,又通过了一遍 discriminator,相当于,generator 的输出前后两次通过了 discriminator ,得到相同的输出。显然,这也是冗余的。

总结

综上,这两段代码各有利弊:

  第一段代码,好处是 noise 只进行了一次前向传播,缺点是,更新 discriminator 参数时,多计算了一次 generator 的梯度,同时,第一次更新 discriminator 需要保留计算图,保证算 generator loss 时计算图不被销毁。

  第三段代码,好处是通过先更新 generator ,使更新后的前向传播计算图可以放心被销毁,因此不用保留计算图占用内存。同时,在更新 discriminator 的时候,也不会像上面的那段代码,计算冗余的 generator 的梯度。缺点是,在 discriminator 上,对 generator 的输出算了两次前向传播,第二次又产生了新的计算图(但比第一次的小)。

一个多计算了一次 generator 梯度,一个多计算一次 discriminator 前向传播。因此,两者差别不大。如果 discriminator 比generator 复杂,那么应该采取第一种策略,如果 discriminator 比 generator 简单,那么应该采取第三种策略,通常情况下,discriminator 要比 generator 简单,故如果效果差不多尽量采取第三种策略。

  但是第三种先更新generator,再更新 discriminator 总是给人感觉怪怪得,因为 generator 的更新需要 discriminator 提供准确的 loss 和 gradient,否则岂不是在瞎更新?

  但是策略三,马上用完马上释放。综合来说,还是策略三最好,策略二其次,策略一最差(差在多计算一次 generator gradient 上,而通常多计算一次 generator gradient 的运算量比多计算一次 discriminator 前向传播的运算量大),因此,detach 还是很有必要的。

参考

Pytorch: detach 和 retain_graph

使用PyTorch进行GAN训练时对于梯度截断的思考.detach()

pytorch训练GAN时的detach()的更多相关文章

  1. Pytorch训练时显存分配过程探究

    对于显存不充足的炼丹研究者来说,弄清楚Pytorch显存的分配机制是很有必要的.下面直接通过实验来推出Pytorch显存的分配过程. 实验实验代码如下: import torch from torch ...

  2. 深度学习识别CIFAR10:pytorch训练LeNet、AlexNet、VGG19实现及比较(三)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com VGGNet在2014年ImageNet图像分类任务竞赛中有出色的表现.网络结构如下图所示: 同样的, ...

  3. 深度学习识别CIFAR10:pytorch训练LeNet、AlexNet、VGG19实现及比较(二)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com AlexNet在2012年ImageNet图像分类任务竞赛中获得冠军.网络结构如下图所示: 对CIFA ...

  4. darknet训练yolov3时的一些注意事项

    训练需要用到的文件: 1)       .data文件.该文件包含一些配置信息,具体为训练的总类别数,训练数据和验证数据的路径,类别名称,模型存放路径等. 例如coco.data classes= 8 ...

  5. 一个简洁、好用的Pytorch训练模板

    一个简洁.好用的Pytorch训练模板 代码地址:https://github.com/KinglittleQ/Pytorch-Template 怎么使用 1) 更改template.py 替换 __ ...

  6. 怎么选取训练神经网络时的Batch size?

    怎么选取训练神经网络时的Batch size? - 知乎 https://www.zhihu.com/question/61607442 深度学习中的batch的大小对学习效果有何影响? - 知乎 h ...

  7. visdom可视化pytorch训练过程

    一.前言 在深度学习模型训练的过程中,常常需要实时监听并可视化一些数据,如损失值loss,正确率acc等.在Tensorflow中,最常使用的工具非Tensorboard莫属:在Pytorch中,也有 ...

  8. 一套兼容win和Linux的PyTorch训练MNIST的算法代码(CNN)

    第一次,调了很久.它本来已经很OK了,同时适用CPU和GPU,且可正常运行的. 为了用于性能测试,主要改了三点: 一,每一批次显示处理时间. 二,本地加载测试数据. 三,兼容LINUX和WIN 本地加 ...

  9. 深度学习识别CIFAR10:pytorch训练LeNet、AlexNet、VGG19实现及比较(一)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 前面几篇文章介绍了MINIST,对这种简单图片的识别,LeNet-5可以达到99%的识别率. CIFA ...

随机推荐

  1. Linux系统如何在离线环境或内网环境安装部署Docker服务和其他服务

    如何在离线环境或纯内网环境的Linux机器上安装部署Docker服务或其他服务.本次我们以Docker服务和Ansible服务为例. 获取指定服务的所有rpm包 保证要获取rpm包的机器能够上网. 本 ...

  2. 实时,异步网页使用jTable, SignalR和ASP。NET MVC

    下载source code - 984.21 KB 图:不同客户端的实时同步表. 点击这里观看现场演示. 文章概述 介绍使用的工具演示实现 模型视图控制器 遗言和感谢参考历史 介绍 HTTP(即web ...

  3. Consul 学习笔记-服务注册

    Consul简介: Consul是一种服务网格解决方案,提供具有服务发现,配置和分段功能的全功能控制平面.这些功能中的每一个都可以根据需要单独使用,也可以一起使用以构建完整的服务网格.Consul需要 ...

  4. 在Linux下如何根据域名自签发OpenSSL证书与常用证书转换

    在Linux下如何根据域名自签发各种SSL证书,这里我们以Apache.Tomcat.Nginx为例. openssl自签发泛域名(通配符)证书 首先要有openssl工具,如果没有那么使用如下命令安 ...

  5. Linux就该这么学28期——Day05 vim编辑器与Shell命令脚本 (yum配置 网卡配置)

    vim 三种模式: 命令模式 按行操作 dd 剪切.删除 5dd dG   全删 yy 复制光标所在行 p 粘贴 u 撤销操作 / 搜索 /ab n  下一个 N   上一个 输入模式 a 当前光标处 ...

  6. Monolog - Logging for PHP

    github地址:https://github.com/Seldaek/monolog 使用 Monolog 安装 核心概念 日志级别 配置一个日志服务 为记录添加额外的数据 使用通道 自定义日志格式 ...

  7. zabbix:以主动模式添加一台受监控主机 (zabbix5.0)

    一,zabbix被动模式和主动模式的区别? zabbix-agent默认的模式是被动模式, zabbix agent被动地接受zabbix server发来的指令, 获取数据后再返回给zabbix s ...

  8. centos8上配置openssh的安全

    一,openssh服务版本号的查看 1,查看当前sshd的版本号 : [root@yjweb ~]# sshd --help unknown option -- - OpenSSH_7.8p1, Op ...

  9. centos8平台redis cluster集群添加/删除node节点(redis5.0.7)

    一,当前redis cluster的node情况: 我们的添加删除等操作都是以这个cluster作为demo cluster采用六台redis,3主3从 redis1 : ip: 172.17.0.2 ...

  10. java Error opening registry key 'Software\JavaSoft\Java Runtime Environment'安装jdk1.7遇到的问题

    最近开发项目要求jdk在1.7以上,我先卸载了jdk1.6,下载1.7下来安装好,配置下环境变量,可以是在输入java -version的时候发现: java Error opening regist ...