自监督图像论文复现 | BYOL(pytorch)| 2020
继续上一篇的内容,上一篇讲解了Bootstrap Your Onw Latent自监督模型的论文和结构:
https://juejin.cn/post/6922347006144970760
现在我们看看如何用pytorch来实现这个结构,并且在学习的过程中加深对论文的理解。
github:https://github.com/lucidrains/byol-pytorch
【前沿】:这个代码我没有实际跑过,毕竟我只是一个没有GPU的小可怜。
主要模型代码
class BYOL(nn.Module):
def __init__(
self,
net,
image_size,
hidden_layer = -2,
projection_size = 256,
projection_hidden_size = 4096,
augment_fn = None,
augment_fn2 = None,
moving_average_decay = 0.99,
use_momentum = True
):
super().__init__()
self.net = net
# default SimCLR augmentation
DEFAULT_AUG = torch.nn.Sequential(
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
T.RandomGrayscale(p=0.2),
T.RandomHorizontalFlip(),
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
T.RandomResizedCrop((image_size, image_size)),
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
)
self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, self.augment1)
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer)
self.use_momentum = use_momentum
self.target_encoder = None
self.target_ema_updater = EMA(moving_average_decay)
self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
# get device of network and make wrapper same device
device = get_module_device(net)
self.to(device)
# send a mock image tensor to instantiate singleton parameters
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
@singleton('target_encoder')
def _get_target_encoder(self):
target_encoder = copy.deepcopy(self.online_encoder)
set_requires_grad(target_encoder, False)
return target_encoder
def reset_moving_average(self):
del self.target_encoder
self.target_encoder = None
def update_moving_average(self):
assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
assert self.target_encoder is not None, 'target encoder has not been created yet'
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
def forward(self, x, return_embedding = False):
if return_embedding:
return self.online_encoder(x)
image_one, image_two = self.augment1(x), self.augment2(x)
online_proj_one, _ = self.online_encoder(image_one)
online_proj_two, _ = self.online_encoder(image_two)
online_pred_one = self.online_predictor(online_proj_one)
online_pred_two = self.online_predictor(online_proj_two)
with torch.no_grad():
target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
target_proj_one, _ = target_encoder(image_one)
target_proj_two, _ = target_encoder(image_two)
target_proj_one.detach_()
target_proj_two.detach_()
loss_one = loss_fn(online_pred_one, target_proj_two.detach())
loss_two = loss_fn(online_pred_two, target_proj_one.detach())
loss = loss_one + loss_two
return loss.mean()
- 先看
forward()
函数,发现输入一个图片给模型,然后返回值是这个图片计算的loss - 如果是推理过程,那么
return_embedding=True
,那么返回的值就是online network中的encoder部分输出的东西,不用在考虑后面的predictor,这里需要注意代码中的encoder其实是论文中的encoder+projector; - 图片经过self.augment1和self.augment2处理成两个不同的图片,在上一篇中,我们称之为view;
- 两个图片都经过online-encoder,这里可能会有疑问:不是应该一个图片经过online network,另外一个经过target network吗?为什么这两个都经过online-encoder,你说的没错,这里只是方便后面计算symmetric loss,因为要计算对称损失,所以两个图片都要经过online network和target network。
- 在target network中推理的内容,都不需要记录梯度,因为target network是根据online network的参数更新的
- 如果
self.use_momentum=False
,那么就不使用论文中的更新target network的方式,而是直接把online network复制给target network,不过我发现!这个github代码虽然有600多stars,但是这里的就算你的self.use_momentum=True,其实也是把online network复制给了target network啊哈哈,那么就不在这里深究了。 - 最后计算通过
loss_fn
计算损失,然后return loss.mean()
所以,目前位置,我们发现这个BYOL的结构其实很简单,目前还有疑点的地方有4个:
- online_encoder如何定义?
- predictor如何定义?
- 图像增强方法如何定义?
- loss_fn损失函数如何定义?
augment
从上面的代码中可以看到这一段:
# default SimCLR augmentation
DEFAULT_AUG = torch.nn.Sequential(
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
T.RandomGrayscale(p=0.2),
T.RandomHorizontalFlip(),
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
T.RandomResizedCrop((image_size, image_size)),
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
)
self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, self.augment1)
可以看到:
- 这个就是图像增强的pipeline,而augment1和augment2可以自定义,默认的话就是augment1和augment2都是上面的DEFAULT_AUG;
from torchvision import transforms as T
比较陌生的可能就是torchvision.transforms.ColorJitter()
这个方法了。
从官方API上可以看到,这个方法其实就是随机的修改图片的亮度,对比度,饱和度和色调
encoder+projector
class NetWrapper(nn.Module):
def __init__(self, net, projection_size, projection_hidden_size, layer = -2):
super().__init__()
self.net = net
self.layer = layer
self.projector = None
self.projection_size = projection_size
self.projection_hidden_size = projection_hidden_size
self.hidden = None
self.hook_registered = False
def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None
def _hook(self, _, __, output):
self.hidden = flatten(output)
def _register_hook(self):
layer = self._find_layer()
assert layer is not None, f'hidden layer ({self.layer}) not found'
handle = layer.register_forward_hook(self._hook)
self.hook_registered = True
@singleton('projector')
def _get_projector(self, hidden):
_, dim = hidden.shape
projector = MLP(dim, self.projection_size, self.projection_hidden_size)
return projector.to(hidden)
def get_representation(self, x):
if self.layer == -1:
return self.net(x)
if not self.hook_registered:
self._register_hook()
_ = self.net(x)
hidden = self.hidden
self.hidden = None
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden
def forward(self, x, return_embedding = False):
representation = self.get_representation(x)
if return_embedding:
return representation
projector = self._get_projector(representation)
projection = projector(representation)
return projection, representation
这个就是基本的encoder+projector,里面包含encoder和projector。
encoder
这个在初始化NetWrapper的时候,需要作为参数传递进来,所以看了训练文件,发现这个模型为:
from torchvision import models, transforms
resnet = models.resnet50(pretrained=True)
所以encoder和论文中说的一样,是一个resnet50。如果我记得没错,这个resnet输出的是一个(batch_size,1000)这样子的tensor。
projector
调用到了MLP这个东西:
class MLP(nn.Module):
def __init__(self, dim, projection_size, hidden_size = 4096):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_size),
nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, projection_size)
)
def forward(self, x):
return self.net(x)
是全连接层+BN+激活层的结构。和论文中说的差不多,并且在最后的全连接层后面没有加上BN+relu。经过这个MLP,返回的是一个(batch_size,projection_size)这样形状的tensor。
predictor
self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
这个predictor,其实就是和projector一模一样的东西,可以看到predictor的输入和输出的特征数量都是projection_size
。
这里因为我对自监督的体系没有完整的阅读论文,只是最先看了这个BYOL,所以我无法说明这个predictor为什么存在。从表现来看,是为了防止online network和target network的结构完全相同,如果完全相同的话可能会让两个模型训练出完全一样的效果,也就是loss=0的情况。假设
loss_fn
def loss_fn(x, y):
x = F.normalize(x, dim=-1, p=2)
y = F.normalize(y, dim=-1, p=2)
return 2 - 2 * (x * y).sum(dim=-1)
这部分和论文中一致。
综上所属,这个BYOL框架是一个简单,又有趣的无监督架构。
自监督图像论文复现 | BYOL(pytorch)| 2020的更多相关文章
- Visualizing and Understanding Convolutional Networks论文复现笔记
目录 Visualizing and Understanding Convolutional Networks 论文复现笔记 Abstract Introduction Approach Visual ...
- Facebook 发布深度学习工具包 PyTorch Hub,让论文复现变得更容易
近日,PyTorch 社区发布了一个深度学习工具包 PyTorchHub, 帮助机器学习工作者更快实现重要论文的复现工作.PyTorchHub 由一个预训练模型仓库组成,专门用于提高研究工作的复现性以 ...
- 图像风格迁移(Pytorch)
图像风格迁移 最后要生成的图片是怎样的是难以想象的,所以朴素的监督学习方法可能不会生效, Content Loss 根据输入图片和输出图片的像素差别可以比较损失 \(l_{content} = \fr ...
- 小白经典CNN论文复现系列(一):LeNet1989
小白的经典CNN复现系列(一):LeNet-1989 之前的浙大AI作业的那个系列,因为后面的NLP的东西我最近大概是不会接触到,所以我们先换一个系列开始更新博客,就是现在这个经典的CNN复现啦(。・ ...
- GAN生成图像论文总结
GAN Theory Modifyingthe Optimization of GAN 题目 内容 GAN DCGAN WGAN Least-square GAN Loss Sensi ...
- 化繁为简,弱监督目标定位领域的新SOTA - 伪监督目标定位方法(PSOL) | CVPR 2020
论文提出伪监督目标定位方法(PSOL)来解决目前弱监督目标定位方法的问题,该方法将定位与分类分开成两个独立的网络,然后在训练集上使用Deep descriptor transformation(DDT ...
- 训练一个图像分类器demo in PyTorch【学习笔记】
[学习源]Tutorials > Deep Learning with PyTorch: A 60 Minute Blitz > Training a Classifier 本文相当于 ...
- 库、教程、论文实现,这是一份超全的PyTorch资源列表(Github 2.2K星)
项目地址:https://github.com/bharathgs/Awesome-pytorch-list 列表结构: NLP 与语音处理 计算机视觉 概率/生成库 其他库 教程与示例 论文实现 P ...
- 复现ICCV 2017经典论文—PyraNet
. 过去几年发表于各大 AI 顶会论文提出的 400 多种算法中,公开算法代码的仅占 6%,其中三分之一的论文作者分享了测试数据,约 54% 的分享包含“伪代码”.这是今年 AAAI 会议上一个严峻的 ...
随机推荐
- mysql 提示 vcruntime140_1.dll丢失
百度网盘:https://pan.baidu.com/s/1vbVexHs1eRfGlnTbr8U53Q 提取码:59tm 将两个文件同时放到路径:C:\Windows\System32 下,运行ba ...
- [LeetCode]231. Power of Two判断是不是2\3\4的幂
/* 用位操作,乘2相当于左移1位,所以2的幂只有最高位是1 所以问题就是判断你是不是只有最高位是1,怎判断呢 这些数-1后形成的数,除了最高位,后边都是1,如果n&n-1就可以判断了 如果是 ...
- 如何理解SQL的可重复读和幻读之间的区别?
从本源来理解比较容易理解,如果只是描述概念和定义,容易让人云里雾里找不到方向.正好这两天在浏览mysql的文档,我可以简单在这里总结一下,帮助其他还没有理解的朋友,如果有错误也麻烦帮忙指正. 先讲一点 ...
- 冒泡排序算法JAVA实现版
/***关于冒泡排序,从性能最低版本实现到性能最优版本实现*/public class BubbleSortDemo { public static void sort(int array[]) { ...
- 实战Git命令(界面操作+命令行)
先说明下公司的发版步骤,当需要开发一个新的功能,先从master分支中拉出一个自己的分支a(假设分支为a),在a分支开发功能完后,需要切换到dev分支,然后把自己的分支a合到dev分支,部署测试环境让 ...
- SpringBoot入门及深入
一:SpringBoot简介 当前互联网后端开发中,JavaEE占据了主导地位.对JavaEE开发,首选框架是Spring框架.在传统的Spring开发中,需要使用大量的与业务无关的XML配置才能使S ...
- vmstat参数详解
vmstat 5 可以使用ctrl+c停止vmstat,可以看到输出依赖于所用的操作系统,因此可能需要阅读一下手册来解读报告 第一行的值是显示子系统启动以来的平均值,第二行开始展示现在正在发生的情况, ...
- 利用sql_tuning_Advisor调优sql
1.赋权给调优用户 grant ADVISOR to xxxxxx; 2.创建调优任务 使用sql_text创建 DECLARE my_task_name VARCHAR2 (30); my_sqlt ...
- allator 对springBoot进行加密
1.对springboot项目添加jar包和xml文件 allatori.xml: <config> <input> <jar in="target/sprin ...
- Java自学笔记1206
字符串比较string1.equals(string2) 代码如下: 1 package Demo_1206; 2 3 import java.util.Scanner; 4 5 public cla ...