【源码解读】cycleGAN(一):网络
源码地址:https://github.com/aitorzip/PyTorch-CycleGAN

如图所示,cycleGAN的网络结构包括两个生成器G(X->Y)和F(Y->X),两个判别器Dx和Dy
生成器部分:网络整体上经过一个降采样然后上采样的过程,中间是一系列残差块,数目由实际情况确定,根据论文中所说,当输入分辨率为128x128,采用6个残差块,当输入分辨率为256x256甚至更高时,采用9个残差块,其源代码如下,
class Generator(nn.Module):
def __init__(self, input_nc, output_nc, n_residual_blocks=9):
super(Generator, self).__init__() # Initial convolution block
model = [ nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, 64, 7),
nn.InstanceNorm2d(64),
nn.ReLU(inplace=True) ] # Downsampling
in_features = 64
out_features = in_features*2
for _ in range(2):
model += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True) ]
in_features = out_features
out_features = in_features*2 # Residual blocks
for _ in range(n_residual_blocks):
model += [ResidualBlock(in_features)] # Upsampling
out_features = in_features//2
for _ in range(2):
model += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True) ]
in_features = out_features
out_features = in_features//2 # Output layer
model += [ nn.ReflectionPad2d(3),
nn.Conv2d(64, output_nc, 7),
nn.Tanh() ] self.model = nn.Sequential(*model) def forward(self, x):
return self.model(x)
其中,值得注意的网络层是nn.ReflectionPad2d和nn.InstanceNorm2d,前者搭配7x7卷积,先在特征图周围以反射的方式补长度,使得卷积后特征图尺寸不变,示例如下,输出结果就是以特征图边界为反射边,向外补充

nn.InstanceNorm2d是相比于batchNorm更加适合图像生成,风格迁移的归一化方法,相比于batchNorm跨样本,单通道统计,InstanceNorm采用单样本,单通道统计,括号中的参数代表通道数。
判别器部分:结构比生成器更加简单,经过5层卷积,通道数缩减为1,最后池化平均,尺寸也缩减为1x1,最最后reshape一下,变为(batchsize,1)
class Discriminator(nn.Module):
def __init__(self, input_nc):
super(Discriminator, self).__init__() # A bunch of convolutions one after another
model = [ nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True) ] model += [ nn.Conv2d(64, 128, 4, stride=2, padding=1),
nn.InstanceNorm2d(128),
nn.LeakyReLU(0.2, inplace=True) ] model += [ nn.Conv2d(128, 256, 4, stride=2, padding=1),
nn.InstanceNorm2d(256),
nn.LeakyReLU(0.2, inplace=True) ] model += [ nn.Conv2d(256, 512, 4, padding=1),
nn.InstanceNorm2d(512),
nn.LeakyReLU(0.2, inplace=True) ] # FCN classification layer
model += [nn.Conv2d(512, 1, 4, padding=1)] self.model = nn.Sequential(*model) def forward(self, x):
x = self.model(x)
# Average pooling and flatten
return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0])
【源码解读】cycleGAN(一):网络的更多相关文章
- Alamofire源码解读系列(七)之网络监控(NetworkReachabilityManager)
Alamofire源码解读系列(七)之网络监控(NetworkReachabilityManager) 本篇主要讲解iOS开发中的网络监控 前言 在开发中,有时候我们需要获取这些信息: 手机是否联网 ...
- SDWebImage源码解读 之 UIImage+GIF
第二篇 前言 本篇是和GIF相关的一个UIImage的分类.主要提供了三个方法: + (UIImage *)sd_animatedGIFNamed:(NSString *)name ----- 根据名 ...
- AFNetworking 3.0 源码解读 总结(干货)(下)
承接上一篇AFNetworking 3.0 源码解读 总结(干货)(上) 21.网络服务类型NSURLRequestNetworkServiceType 示例代码: typedef NS_ENUM(N ...
- AFNetworking 3.0 源码解读 总结(干货)(上)
养成记笔记的习惯,对于一个软件工程师来说,我觉得很重要.记得在知乎上看到过一个问题,说是人类最大的缺点是什么?我个人觉得记忆算是一个缺点.它就像时间一样,会自己消散. 前言 终于写完了 AFNetwo ...
- AFNetworking 3.0 源码解读(十)之 UIActivityIndicatorView/UIRefreshControl/UIImageView + AFNetworking
我们应该看到过很多类似这样的例子:某个控件拥有加载网络图片的能力.但这究竟是怎么做到的呢?看完这篇文章就明白了. 前言 这篇我们会介绍 AFNetworking 中的3个UIKit中的分类.UIAct ...
- AFNetworking 3.0 源码解读(九)之 AFNetworkActivityIndicatorManager
让我们的APP像艺术品一样优雅,开发工程师更像是一名匠人,不仅需要精湛的技艺,而且要有一颗匠心. 前言 AFNetworkActivityIndicatorManager 是对状态栏中网络激活那个小控 ...
- AFNetworking 3.0 源码解读(八)之 AFImageDownloader
AFImageDownloader 这个类对写DownloadManager有很大的借鉴意义.在平时的开发中,当我们使用UIImageView加载一个网络上的图片时,其原理就是把图片下载下来,然后再赋 ...
- AFNetworking 3.0 源码解读(七)之 AFAutoPurgingImageCache
这篇我们就要介绍AFAutoPurgingImageCache这个类了.这个类给了我们临时管理图片内存的能力. 前言 假如说我们要写一个通用的网络框架,除了必备的请求数据的方法外,必须提供一个下载器来 ...
- AFNetworking 3.0 源码解读(六)之 AFHTTPSessionManager
AFHTTPSessionManager相对来说比较好理解,代码也比较短.但却是我们平时可能使用最多的类. AFNetworking 3.0 源码解读(一)之 AFNetworkReachabilit ...
- AFNetworking 3.0 源码解读(一)之 AFNetworkReachabilityManager
做ios开发,AFNetworking 这个网络框架肯定都非常熟悉,也许我们平时只使用了它的部分功能,而且我们对它的实现原理并不是很清楚,就好像总是有一团迷雾在眼前一样. 接下来我们就非常详细的来读一 ...
随机推荐
- PCL智能指针疑云 <二> 使用同一智能指针作为PCL预处理API的输入和输出
问题介绍: slam构建地图,先进行降采样,再进行可视化或存储.然而经过降采样后,代码没有报错的情况下,点云数据散成一团.将代码和点云数据展示如下, pcl::VoxelGrid<Lidar:: ...
- Springboot入门:
Springboot入门: 1.springboot是基于spring的全新框架,设计目的:简化spring应用配置和开发过程. 该框架遵循“约定大于配置”原则,采用特定的方式进行配置,从而事开发者无 ...
- java跨越请求实例
使用Access-Control-Allow-Origin解决跨域 什么是跨域 当两个域具有相同的协议(如http), 相同的端口(如80),相同的host(如www.google.com),那么我们 ...
- Found duplicate classes/resources
很可能是多个三方依赖重复了,依赖个插件,这个插件能查找出依赖关系, duplicate-finder-maven-plugin 使用命令显示 mvn dependency:tree [INFO] \- ...
- c++ 指针类型转换
1.数据类型转换(static_cast) //数据类型转换printf("%d\n", static_cast<int>(10.2)); 2.指针类型转换(reint ...
- leetcode 234 回文链表 Palindrome Linked List
要求用O(n)时间,和O(1)空间,因此思路是用本身链表进行判断,既然考虑回文,本方法思想是先遍历一次求链表长度,然后翻转前半部分链表:然后同时对前半部分链表和后半部分链表遍历,来判断对应节点的值是否 ...
- ThreadLocal 源码分析
线程局部变量 ThreadLocal 用于实现线程隔离和类间变量共享. 创建实例 /** * 当前 ThreadLocal 实例的哈希值 */ private final int threadLoca ...
- leetcode 56. Merge Intervals 、57. Insert Interval
56. Merge Intervals是一个无序的,需要将整体合并:57. Insert Interval是一个本身有序的且已经合并好的,需要将新的插入进这个已经合并好的然后合并成新的. 56. Me ...
- 4.2.k8s.Ingress-Nginx
Ingress-Nginx ingress-nginx为7层代理,通过配置域名访问后端服务 ingress-nginx容器和kubernetes api交互,动态生成nginx配置 ingress服务 ...
- DedeCMS调取其他织梦CMS站点数据库数据方法
第1步:打开网站include\taglib文件夹中找到sql.lib.php文件,并直接复制一些此文件出来,并把复制出来的这个文件重命名为mysql.lib.php.注:mysql.lib.php, ...