自监督图像论文复现 | BYOL(pytorch)| 2020
继续上一篇的内容,上一篇讲解了Bootstrap Your Onw Latent自监督模型的论文和结构:
https://juejin.cn/post/6922347006144970760
现在我们看看如何用pytorch来实现这个结构,并且在学习的过程中加深对论文的理解。
github:https://github.com/lucidrains/byol-pytorch
【前沿】:这个代码我没有实际跑过,毕竟我只是一个没有GPU的小可怜。
主要模型代码
class BYOL(nn.Module):
def __init__(
self,
net,
image_size,
hidden_layer = -2,
projection_size = 256,
projection_hidden_size = 4096,
augment_fn = None,
augment_fn2 = None,
moving_average_decay = 0.99,
use_momentum = True
):
super().__init__()
self.net = net
# default SimCLR augmentation
DEFAULT_AUG = torch.nn.Sequential(
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
T.RandomGrayscale(p=0.2),
T.RandomHorizontalFlip(),
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
T.RandomResizedCrop((image_size, image_size)),
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
)
self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, self.augment1)
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer)
self.use_momentum = use_momentum
self.target_encoder = None
self.target_ema_updater = EMA(moving_average_decay)
self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
# get device of network and make wrapper same device
device = get_module_device(net)
self.to(device)
# send a mock image tensor to instantiate singleton parameters
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
@singleton('target_encoder')
def _get_target_encoder(self):
target_encoder = copy.deepcopy(self.online_encoder)
set_requires_grad(target_encoder, False)
return target_encoder
def reset_moving_average(self):
del self.target_encoder
self.target_encoder = None
def update_moving_average(self):
assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
assert self.target_encoder is not None, 'target encoder has not been created yet'
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
def forward(self, x, return_embedding = False):
if return_embedding:
return self.online_encoder(x)
image_one, image_two = self.augment1(x), self.augment2(x)
online_proj_one, _ = self.online_encoder(image_one)
online_proj_two, _ = self.online_encoder(image_two)
online_pred_one = self.online_predictor(online_proj_one)
online_pred_two = self.online_predictor(online_proj_two)
with torch.no_grad():
target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
target_proj_one, _ = target_encoder(image_one)
target_proj_two, _ = target_encoder(image_two)
target_proj_one.detach_()
target_proj_two.detach_()
loss_one = loss_fn(online_pred_one, target_proj_two.detach())
loss_two = loss_fn(online_pred_two, target_proj_one.detach())
loss = loss_one + loss_two
return loss.mean()
- 先看
forward()
函数,发现输入一个图片给模型,然后返回值是这个图片计算的loss - 如果是推理过程,那么
return_embedding=True
,那么返回的值就是online network中的encoder部分输出的东西,不用在考虑后面的predictor,这里需要注意代码中的encoder其实是论文中的encoder+projector; - 图片经过self.augment1和self.augment2处理成两个不同的图片,在上一篇中,我们称之为view;
- 两个图片都经过online-encoder,这里可能会有疑问:不是应该一个图片经过online network,另外一个经过target network吗?为什么这两个都经过online-encoder,你说的没错,这里只是方便后面计算symmetric loss,因为要计算对称损失,所以两个图片都要经过online network和target network。
- 在target network中推理的内容,都不需要记录梯度,因为target network是根据online network的参数更新的
- 如果
self.use_momentum=False
,那么就不使用论文中的更新target network的方式,而是直接把online network复制给target network,不过我发现!这个github代码虽然有600多stars,但是这里的就算你的self.use_momentum=True,其实也是把online network复制给了target network啊哈哈,那么就不在这里深究了。 - 最后计算通过
loss_fn
计算损失,然后return loss.mean()
所以,目前位置,我们发现这个BYOL的结构其实很简单,目前还有疑点的地方有4个:
- online_encoder如何定义?
- predictor如何定义?
- 图像增强方法如何定义?
- loss_fn损失函数如何定义?
augment
从上面的代码中可以看到这一段:
# default SimCLR augmentation
DEFAULT_AUG = torch.nn.Sequential(
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
T.RandomGrayscale(p=0.2),
T.RandomHorizontalFlip(),
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
T.RandomResizedCrop((image_size, image_size)),
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
)
self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, self.augment1)
可以看到:
- 这个就是图像增强的pipeline,而augment1和augment2可以自定义,默认的话就是augment1和augment2都是上面的DEFAULT_AUG;
from torchvision import transforms as T
比较陌生的可能就是torchvision.transforms.ColorJitter()
这个方法了。
从官方API上可以看到,这个方法其实就是随机的修改图片的亮度,对比度,饱和度和色调
encoder+projector
class NetWrapper(nn.Module):
def __init__(self, net, projection_size, projection_hidden_size, layer = -2):
super().__init__()
self.net = net
self.layer = layer
self.projector = None
self.projection_size = projection_size
self.projection_hidden_size = projection_hidden_size
self.hidden = None
self.hook_registered = False
def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None
def _hook(self, _, __, output):
self.hidden = flatten(output)
def _register_hook(self):
layer = self._find_layer()
assert layer is not None, f'hidden layer ({self.layer}) not found'
handle = layer.register_forward_hook(self._hook)
self.hook_registered = True
@singleton('projector')
def _get_projector(self, hidden):
_, dim = hidden.shape
projector = MLP(dim, self.projection_size, self.projection_hidden_size)
return projector.to(hidden)
def get_representation(self, x):
if self.layer == -1:
return self.net(x)
if not self.hook_registered:
self._register_hook()
_ = self.net(x)
hidden = self.hidden
self.hidden = None
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden
def forward(self, x, return_embedding = False):
representation = self.get_representation(x)
if return_embedding:
return representation
projector = self._get_projector(representation)
projection = projector(representation)
return projection, representation
这个就是基本的encoder+projector,里面包含encoder和projector。
encoder
这个在初始化NetWrapper的时候,需要作为参数传递进来,所以看了训练文件,发现这个模型为:
from torchvision import models, transforms
resnet = models.resnet50(pretrained=True)
所以encoder和论文中说的一样,是一个resnet50。如果我记得没错,这个resnet输出的是一个(batch_size,1000)这样子的tensor。
projector
调用到了MLP这个东西:
class MLP(nn.Module):
def __init__(self, dim, projection_size, hidden_size = 4096):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_size),
nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, projection_size)
)
def forward(self, x):
return self.net(x)
是全连接层+BN+激活层的结构。和论文中说的差不多,并且在最后的全连接层后面没有加上BN+relu。经过这个MLP,返回的是一个(batch_size,projection_size)这样形状的tensor。
predictor
self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
这个predictor,其实就是和projector一模一样的东西,可以看到predictor的输入和输出的特征数量都是projection_size
。
这里因为我对自监督的体系没有完整的阅读论文,只是最先看了这个BYOL,所以我无法说明这个predictor为什么存在。从表现来看,是为了防止online network和target network的结构完全相同,如果完全相同的话可能会让两个模型训练出完全一样的效果,也就是loss=0的情况。假设
loss_fn
def loss_fn(x, y):
x = F.normalize(x, dim=-1, p=2)
y = F.normalize(y, dim=-1, p=2)
return 2 - 2 * (x * y).sum(dim=-1)
这部分和论文中一致。
综上所属,这个BYOL框架是一个简单,又有趣的无监督架构。
自监督图像论文复现 | BYOL(pytorch)| 2020的更多相关文章
- Visualizing and Understanding Convolutional Networks论文复现笔记
目录 Visualizing and Understanding Convolutional Networks 论文复现笔记 Abstract Introduction Approach Visual ...
- Facebook 发布深度学习工具包 PyTorch Hub,让论文复现变得更容易
近日,PyTorch 社区发布了一个深度学习工具包 PyTorchHub, 帮助机器学习工作者更快实现重要论文的复现工作.PyTorchHub 由一个预训练模型仓库组成,专门用于提高研究工作的复现性以 ...
- 图像风格迁移(Pytorch)
图像风格迁移 最后要生成的图片是怎样的是难以想象的,所以朴素的监督学习方法可能不会生效, Content Loss 根据输入图片和输出图片的像素差别可以比较损失 \(l_{content} = \fr ...
- 小白经典CNN论文复现系列(一):LeNet1989
小白的经典CNN复现系列(一):LeNet-1989 之前的浙大AI作业的那个系列,因为后面的NLP的东西我最近大概是不会接触到,所以我们先换一个系列开始更新博客,就是现在这个经典的CNN复现啦(。・ ...
- GAN生成图像论文总结
GAN Theory Modifyingthe Optimization of GAN 题目 内容 GAN DCGAN WGAN Least-square GAN Loss Sensi ...
- 化繁为简,弱监督目标定位领域的新SOTA - 伪监督目标定位方法(PSOL) | CVPR 2020
论文提出伪监督目标定位方法(PSOL)来解决目前弱监督目标定位方法的问题,该方法将定位与分类分开成两个独立的网络,然后在训练集上使用Deep descriptor transformation(DDT ...
- 训练一个图像分类器demo in PyTorch【学习笔记】
[学习源]Tutorials > Deep Learning with PyTorch: A 60 Minute Blitz > Training a Classifier 本文相当于 ...
- 库、教程、论文实现,这是一份超全的PyTorch资源列表(Github 2.2K星)
项目地址:https://github.com/bharathgs/Awesome-pytorch-list 列表结构: NLP 与语音处理 计算机视觉 概率/生成库 其他库 教程与示例 论文实现 P ...
- 复现ICCV 2017经典论文—PyraNet
. 过去几年发表于各大 AI 顶会论文提出的 400 多种算法中,公开算法代码的仅占 6%,其中三分之一的论文作者分享了测试数据,约 54% 的分享包含“伪代码”.这是今年 AAAI 会议上一个严峻的 ...
随机推荐
- maven项目修改名称后,打包名称和现在名称不一致
将pom.xm文件中 <artifactId>health</artifactId> 修改成现在项目名称,然后 maven clean ->maven install 如 ...
- Centos7对外开放端口
(1)查看对外开放的端口状态 查询已开放的端口 netstat -anp 查询指定端口是否已开 firewall-cmd --query-port=666/tcp 提示 yes,表示开启:no表示未开 ...
- Linux嵌入式学习-远程过程调用-Binder系统
Binder系统的C程序使用示例IPC : Inter-Process Communication, 进程间通信RPC : Remote Procedure Call, 远程过程调用 这里我们直接只用 ...
- Redis集群搭建采坑总结
背景 先澄清一下,整个过程问题都不是我解决的,我在里面就是起了个打酱油的角色.因为实际上我负责这个项目,整个过程也比较清楚.之前也跟具体负责的同事说过,等过段时间带他做做项目复盘.结果一直忙,之前做的 ...
- File类的特点?如何创建File类对象?Java中如何操作文件内容,什么是Io流Io流如何读取和写入文件?字节缓冲流使用原则?
重难点提示 学习目标 1.能够了解File类的特点(存在的意义,构造方法,常见方法) 2.能够了解什么是IO流以及分类(IO流的概述以及分类) 3.能够掌握字节输出流的使用(继承体系结构介绍以及常见的 ...
- web攻防环境--一句话木马
任务一.基于centos7搭建dvwa web服务靶机 1.在centos7安装LAMP并启动,访问phpinfo页面 也即安装httpd.php.mysql服务. 直接进行yum安装即可,完成后检查 ...
- CRM、DMP、CDP概念解析
CRM.DMP.CDP,都是什么鬼?有什么区别差异?别说你都懂 摘自https://maxket.com/crm-dmp-cdp/ 如果您不想多花人生中宝贵的十分钟,那么不用多考虑了,上CDP吧.如果 ...
- 学习一下 SpringCloud (三)-- 服务调用、负载均衡 Ribbon、OpenFeign
(1) 相关博文地址: 学习一下 SpringCloud (一)-- 从单体架构到微服务架构.代码拆分(maven 聚合): https://www.cnblogs.com/l-y-h/p/14105 ...
- 达梦数据库学习(一、linux操作系统安装及数据库安装)
达梦数据库学习(一.linux操作系统安装及数据库安装) 环境介绍: 使用VM12+中标麒麟V7.0操作系统+达梦8数据库 一.linux系统搭建 本部分没有需要着重介绍,注意安装时基本环境选择&qu ...
- Openstack glance 镜像服务 (五)
Openstack glance 镜像服务 (五) 引用: 官方文档glance安装 https://docs.openstack.org/ocata/zh_CN/install-guide-rdo/ ...