【源码解读】pix2pix(一):训练
源码地址:https://github.com/mrzhu-cool/pix2pix-pytorch
相比于朱俊彦的版本,这一版更加简单易读
训练的代码在train.py,开头依然是很多代码的共同三板斧,加载参数,加载数据,加载模型
命令行参数
# Training settings
parser = argparse.ArgumentParser(description='pix2pix-pytorch-implementation')
parser.add_argument('--dataset', required=True, help='facades')
parser.add_argument('--batch_size', type=int, default=1, help='training batch size')
parser.add_argument('--test_batch_size', type=int, default=1, help='testing batch size')
parser.add_argument('--direction', type=str, default='b2a', help='a2b or b2a')
parser.add_argument('--input_nc', type=int, default=3, help='input image channels')
parser.add_argument('--output_nc', type=int, default=3, help='output image channels')
parser.add_argument('--ngf', type=int, default=64, help='generator filters in first conv layer')
parser.add_argument('--ndf', type=int, default=64, help='discriminator filters in first conv layer')
parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count')
parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine')
parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='use cuda?')
parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use')
parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
parser.add_argument('--lamb', type=int, default=10, help='weight on L1 term in objective')
opt = parser.parse_args()
数据
print('===> Loading datasets')
root_path = "dataset/"
train_set = get_training_set(root_path + opt.dataset, opt.direction)
test_set = get_test_set(root_path + opt.dataset, opt.direction)
training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.test_batch_size, shuffle=False)
模型
print('===> Building models')
net_g = define_G(opt.input_nc, opt.output_nc, opt.ngf, 'batch', False, 'normal', 0.02, gpu_id=device)
net_d = define_D(opt.input_nc + opt.output_nc, opt.ndf, 'basic', gpu_id=device)
优化器,损失函数
criterionGAN = GANLoss().to(device)
criterionL1 = nn.L1Loss().to(device)
criterionMSE = nn.MSELoss().to(device) # setup optimizer
optimizer_g = optim.Adam(net_g.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizer_d = optim.Adam(net_d.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
net_g_scheduler = get_scheduler(optimizer_g, opt)
net_d_scheduler = get_scheduler(optimizer_d, opt)
接着按批次读取数据,首先更新判别器,判别器的输入是图像对(真,真)(真,假)
######################
# (1) Update D network
###################### optimizer_d.zero_grad() # train with fake
fake_ab = torch.cat((real_a, fake_b), 1)
pred_fake = net_d.forward(fake_ab.detach())
loss_d_fake = criterionGAN(pred_fake, False) # train with real
real_ab = torch.cat((real_a, real_b), 1)
pred_real = net_d.forward(real_ab)
loss_d_real = criterionGAN(pred_real, True) # Combined D loss
loss_d = (loss_d_fake + loss_d_real) * 0.5 loss_d.backward() optimizer_d.step()
然后更新生成器,生成器的损失由判别器产生的损失函数和真假图像之间的L1约束组成
######################
# (2) Update G network
###################### optimizer_g.zero_grad() # First, G(A) should fake the discriminator
fake_ab = torch.cat((real_a, fake_b), 1)
pred_fake = net_d.forward(fake_ab)
loss_g_gan = criterionGAN(pred_fake, True) # Second, G(A) = B
loss_g_l1 = criterionL1(fake_b, real_b) * opt.lamb loss_g = loss_g_gan + loss_g_l1 loss_g.backward() optimizer_g.step()
最后更新学习率
update_learning_rate(net_g_scheduler, optimizer_g)
update_learning_rate(net_d_scheduler, optimizer_d)
比较核心的代码是网络构造,以及一些工具函数,放在后面写
【源码解读】pix2pix(一):训练的更多相关文章
- Bert系列(二)——源码解读之模型主体
本篇文章主要是解读模型主体代码modeling.py.在阅读这篇文章之前希望读者们对bert的相关理论有一定的了解,尤其是transformer的结构原理,网上的资料很多,本文内容对原理部分就不做过多 ...
- Bert系列(三)——源码解读之Pre-train
https://www.jianshu.com/p/22e462f01d8c pre-train是迁移学习的基础,虽然Google已经发布了各种预训练好的模型,而且因为资源消耗巨大,自己再预训练也不现 ...
- [源码分析] Facebook如何训练超大模型 --- (2)
[源码分析] Facebook如何训练超大模型 --- (2) 目录 [源码分析] Facebook如何训练超大模型 --- (2) 0x00 摘要 0x01 回顾 1.1 ZeRO 1.1.1 Ze ...
- SDWebImage源码解读之SDWebImageDownloaderOperation
第七篇 前言 本篇文章主要讲解下载操作的相关知识,SDWebImageDownloaderOperation的主要任务是把一张图片从服务器下载到内存中.下载数据并不难,如何对下载这一系列的任务进行设计 ...
- SDWebImage源码解读 之 NSData+ImageContentType
第一篇 前言 从今天开始,我将开启一段源码解读的旅途了.在这里先暂时不透露具体解读的源码到底是哪些?因为也可能随着解读的进行会更改计划.但能够肯定的是,这一系列之中肯定会有Swift版本的代码. 说说 ...
- SDWebImage源码解读 之 UIImage+GIF
第二篇 前言 本篇是和GIF相关的一个UIImage的分类.主要提供了三个方法: + (UIImage *)sd_animatedGIFNamed:(NSString *)name ----- 根据名 ...
- SDWebImage源码解读 之 SDWebImageCompat
第三篇 前言 本篇主要解读SDWebImage的配置文件.正如compat的定义,该配置文件主要是兼容Apple的其他设备.也许我们真实的开发平台只有一个,但考虑各个平台的兼容性,对于框架有着很重要的 ...
- SDWebImage源码解读_之SDWebImageDecoder
第四篇 前言 首先,我们要弄明白一个问题? 为什么要对UIImage进行解码呢?难道不能直接使用吗? 其实不解码也是可以使用的,假如说我们通过imageNamed:来加载image,系统默认会在主线程 ...
- SDWebImage源码解读之SDWebImageCache(上)
第五篇 前言 本篇主要讲解图片缓存类的知识,虽然只涉及了图片方面的缓存的设计,但思想同样适用于别的方面的设计.在架构上来说,缓存算是存储设计的一部分.我们把各种不同的存储内容按照功能进行切割后,图片缓 ...
- SDWebImage源码解读之SDWebImageCache(下)
第六篇 前言 我们在SDWebImageCache(上)中了解了这个缓存类大概的功能是什么?那么接下来就要看看这些功能是如何实现的? 再次强调,不管是图片的缓存还是其他各种不同形式的缓存,在原理上都极 ...
随机推荐
- 11.Linux date命令的用法
date命令常的日常应用 修改时间 date -s “2008/05/23 19:20″ 打包文件 tar zcvf log-$(date +$F).gz /home/admin/logs 同步阿 ...
- js for循环中i++与++i有什么区别
平时都是这样写的for循环, 1 2 3 for(var i = 0; i < 20 ; i++){ .... } 但我看有的人这样写 for (var i = 0; ...
- GIT的工作原理和基本命令
1.GIT的工作原理 工作区:我们写代码的地方. 暂存区:临时存储用的. 历史区:生成历史版本的地方. 提交流程:工作区->暂存区->历史区 图示: 2.GIT的全局配置 3.创建仓库完成 ...
- 获取免费的https证书
可以通过网站获取免费的https证书 首先到https://freessl.org注册一个账号 然后就可以开始创建免费证书了 获取的证书里面通常只有pem后缀文件 nodejs使用的时候需要crt文件 ...
- Python的datetime与Decimal数据进行json序列化的简单说明
我们在Python的json.JSONEncoder类中可以查看Python数据序列化为JSON格式的数据时数据类型的对应关系: class JSONEncoder(object): "&q ...
- 洛谷P1190 接水问题
题目名称:接水问题 题目来源 [洛谷P1190] (https://www.luogu.org/problemnew/show/P1190) 题目描述 学校里有一个水房,水房里一共有\(m\)个龙头 ...
- 转 实例具体解释DJANGO的 SELECT_RELATED 和 PREFETCH_RELATED 函数对 QUERYSET 查询的优化(二)
https://blog.csdn.net/cugbabybear/article/details/38342793 这是本系列的第二篇,内容是 prefetch_related() 函数的用途.实现 ...
- ubuntu通过windows下的ccproxy代理上网
网上教程很多,需要注意的是将ubuntu的ip和windows的Ip设置到同一个网段,即子网掩码是1的对应的部分要相同.由于没有配置到同一个网段,折腾了我好久.
- GitHub - 解决 GitHub Page 404
带有下划线的文件报 404 解决:在仓库文件夹根目录添加.nojekyll文件 参见: Bypassing Jekyll on GitHub Pages - The GitHub Blog How t ...
- Openstack 实现技术分解 (2) 虚拟机初始化工具 — Cloud-Init & metadata & userdata
目录 目录 前文列表 扩展阅读 系统环境 前言 Cloud-init Cloud-init 的配置文件 metadata userdata metadata 和 userdata 的区别 metada ...