继续上一篇的内容,上一篇讲解了Bootstrap Your Onw Latent自监督模型的论文和结构:

https://juejin.cn/post/6922347006144970760

现在我们看看如何用pytorch来实现这个结构,并且在学习的过程中加深对论文的理解。

github:https://github.com/lucidrains/byol-pytorch

【前沿】:这个代码我没有实际跑过,毕竟我只是一个没有GPU的小可怜。

主要模型代码

class BYOL(nn.Module):
def __init__(
self,
net,
image_size,
hidden_layer = -2,
projection_size = 256,
projection_hidden_size = 4096,
augment_fn = None,
augment_fn2 = None,
moving_average_decay = 0.99,
use_momentum = True
):
super().__init__()
self.net = net # default SimCLR augmentation DEFAULT_AUG = torch.nn.Sequential(
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
T.RandomGrayscale(p=0.2),
T.RandomHorizontalFlip(),
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
T.RandomResizedCrop((image_size, image_size)),
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
) self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, self.augment1) self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer) self.use_momentum = use_momentum
self.target_encoder = None
self.target_ema_updater = EMA(moving_average_decay) self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size) # get device of network and make wrapper same device
device = get_module_device(net)
self.to(device) # send a mock image tensor to instantiate singleton parameters
self.forward(torch.randn(2, 3, image_size, image_size, device=device)) @singleton('target_encoder')
def _get_target_encoder(self):
target_encoder = copy.deepcopy(self.online_encoder)
set_requires_grad(target_encoder, False)
return target_encoder def reset_moving_average(self):
del self.target_encoder
self.target_encoder = None def update_moving_average(self):
assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
assert self.target_encoder is not None, 'target encoder has not been created yet'
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder) def forward(self, x, return_embedding = False):
if return_embedding:
return self.online_encoder(x) image_one, image_two = self.augment1(x), self.augment2(x) online_proj_one, _ = self.online_encoder(image_one)
online_proj_two, _ = self.online_encoder(image_two) online_pred_one = self.online_predictor(online_proj_one)
online_pred_two = self.online_predictor(online_proj_two) with torch.no_grad():
target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
target_proj_one, _ = target_encoder(image_one)
target_proj_two, _ = target_encoder(image_two)
target_proj_one.detach_()
target_proj_two.detach_() loss_one = loss_fn(online_pred_one, target_proj_two.detach())
loss_two = loss_fn(online_pred_two, target_proj_one.detach()) loss = loss_one + loss_two
return loss.mean()
  • 先看forward()函数,发现输入一个图片给模型,然后返回值是这个图片计算的loss
  • 如果是推理过程,那么return_embedding=True,那么返回的值就是online network中的encoder部分输出的东西,不用在考虑后面的predictor,这里需要注意代码中的encoder其实是论文中的encoder+projector
  • 图片经过self.augment1和self.augment2处理成两个不同的图片,在上一篇中,我们称之为view;
  • 两个图片都经过online-encoder,这里可能会有疑问:不是应该一个图片经过online network,另外一个经过target network吗?为什么这两个都经过online-encoder,你说的没错,这里只是方便后面计算symmetric loss,因为要计算对称损失,所以两个图片都要经过online network和target network。
  • 在target network中推理的内容,都不需要记录梯度,因为target network是根据online network的参数更新的
  • 如果self.use_momentum=False,那么就不使用论文中的更新target network的方式,而是直接把online network复制给target network,不过我发现!这个github代码虽然有600多stars,但是这里的就算你的self.use_momentum=True,其实也是把online network复制给了target network啊哈哈,那么就不在这里深究了。
  • 最后计算通过loss_fn计算损失,然后return loss.mean()

所以,目前位置,我们发现这个BYOL的结构其实很简单,目前还有疑点的地方有4个:

  • online_encoder如何定义?
  • predictor如何定义?
  • 图像增强方法如何定义?
  • loss_fn损失函数如何定义?

augment

从上面的代码中可以看到这一段:

# default SimCLR augmentation

        DEFAULT_AUG = torch.nn.Sequential(
RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
T.RandomGrayscale(p=0.2),
T.RandomHorizontalFlip(),
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
T.RandomResizedCrop((image_size, image_size)),
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
) self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, self.augment1)

可以看到:

  • 这个就是图像增强的pipeline,而augment1和augment2可以自定义,默认的话就是augment1和augment2都是上面的DEFAULT_AUG;
  • from torchvision import transforms as T

比较陌生的可能就是torchvision.transforms.ColorJitter()这个方法了。

从官方API上可以看到,这个方法其实就是随机的修改图片的亮度,对比度,饱和度和色调

encoder+projector

class NetWrapper(nn.Module):
def __init__(self, net, projection_size, projection_hidden_size, layer = -2):
super().__init__()
self.net = net
self.layer = layer self.projector = None
self.projection_size = projection_size
self.projection_hidden_size = projection_hidden_size self.hidden = None
self.hook_registered = False def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None def _hook(self, _, __, output):
self.hidden = flatten(output) def _register_hook(self):
layer = self._find_layer()
assert layer is not None, f'hidden layer ({self.layer}) not found'
handle = layer.register_forward_hook(self._hook)
self.hook_registered = True @singleton('projector')
def _get_projector(self, hidden):
_, dim = hidden.shape
projector = MLP(dim, self.projection_size, self.projection_hidden_size)
return projector.to(hidden) def get_representation(self, x):
if self.layer == -1:
return self.net(x) if not self.hook_registered:
self._register_hook() _ = self.net(x)
hidden = self.hidden
self.hidden = None
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden def forward(self, x, return_embedding = False):
representation = self.get_representation(x) if return_embedding:
return representation projector = self._get_projector(representation)
projection = projector(representation)
return projection, representation

这个就是基本的encoder+projector,里面包含encoder和projector。

encoder

这个在初始化NetWrapper的时候,需要作为参数传递进来,所以看了训练文件,发现这个模型为:

from torchvision import models, transforms
resnet = models.resnet50(pretrained=True)

所以encoder和论文中说的一样,是一个resnet50。如果我记得没错,这个resnet输出的是一个(batch_size,1000)这样子的tensor。

projector

调用到了MLP这个东西:

class MLP(nn.Module):
def __init__(self, dim, projection_size, hidden_size = 4096):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_size),
nn.BatchNorm1d(hidden_size),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, projection_size)
) def forward(self, x):
return self.net(x)

是全连接层+BN+激活层的结构。和论文中说的差不多,并且在最后的全连接层后面没有加上BN+relu。经过这个MLP,返回的是一个(batch_size,projection_size)这样形状的tensor。

predictor

self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)

这个predictor,其实就是和projector一模一样的东西,可以看到predictor的输入和输出的特征数量都是projection_size

这里因为我对自监督的体系没有完整的阅读论文,只是最先看了这个BYOL,所以我无法说明这个predictor为什么存在。从表现来看,是为了防止online network和target network的结构完全相同,如果完全相同的话可能会让两个模型训练出完全一样的效果,也就是loss=0的情况。假设

loss_fn

def loss_fn(x, y):
x = F.normalize(x, dim=-1, p=2)
y = F.normalize(y, dim=-1, p=2)
return 2 - 2 * (x * y).sum(dim=-1)

这部分和论文中一致。

综上所属,这个BYOL框架是一个简单,又有趣的无监督架构。

自监督图像论文复现 | BYOL(pytorch)| 2020的更多相关文章

  1. Visualizing and Understanding Convolutional Networks论文复现笔记

    目录 Visualizing and Understanding Convolutional Networks 论文复现笔记 Abstract Introduction Approach Visual ...

  2. Facebook 发布深度学习工具包 PyTorch Hub,让论文复现变得更容易

    近日,PyTorch 社区发布了一个深度学习工具包 PyTorchHub, 帮助机器学习工作者更快实现重要论文的复现工作.PyTorchHub 由一个预训练模型仓库组成,专门用于提高研究工作的复现性以 ...

  3. 图像风格迁移(Pytorch)

    图像风格迁移 最后要生成的图片是怎样的是难以想象的,所以朴素的监督学习方法可能不会生效, Content Loss 根据输入图片和输出图片的像素差别可以比较损失 \(l_{content} = \fr ...

  4. 小白经典CNN论文复现系列(一):LeNet1989

    小白的经典CNN复现系列(一):LeNet-1989 之前的浙大AI作业的那个系列,因为后面的NLP的东西我最近大概是不会接触到,所以我们先换一个系列开始更新博客,就是现在这个经典的CNN复现啦(。・ ...

  5. GAN生成图像论文总结

    GAN Theory Modifyingthe Optimization of GAN 题目 内容 GAN   DCGAN   WGAN   Least-square GAN   Loss Sensi ...

  6. 化繁为简,弱监督目标定位领域的新SOTA - 伪监督目标定位方法(PSOL) | CVPR 2020

    论文提出伪监督目标定位方法(PSOL)来解决目前弱监督目标定位方法的问题,该方法将定位与分类分开成两个独立的网络,然后在训练集上使用Deep descriptor transformation(DDT ...

  7. 训练一个图像分类器demo in PyTorch【学习笔记】

    [学习源]Tutorials > Deep Learning with PyTorch: A 60 Minute Blitz > Training a Classifier   本文相当于 ...

  8. 库、教程、论文实现,这是一份超全的PyTorch资源列表(Github 2.2K星)

    项目地址:https://github.com/bharathgs/Awesome-pytorch-list 列表结构: NLP 与语音处理 计算机视觉 概率/生成库 其他库 教程与示例 论文实现 P ...

  9. 复现ICCV 2017经典论文—PyraNet

    . 过去几年发表于各大 AI 顶会论文提出的 400 多种算法中,公开算法代码的仅占 6%,其中三分之一的论文作者分享了测试数据,约 54% 的分享包含“伪代码”.这是今年 AAAI 会议上一个严峻的 ...

随机推荐

  1. Kotlin 简单使用手册

    在昨天和做android的前辈一番交谈后,觉得很惭愧,许多东西还只是知其然而不知其所以然,也深感自己的技术还太浅薄.以后要更加努力地学习,要着重学习原理.方法论,不能只停留在会用的阶段. 今天又要献丑 ...

  2. 白嫖JetBrains正版全家桶!

    使用自己的开源项目,是可以白嫖JetBrains正版全家桶的! 前言 之前在学Go的时候,想着要用什么编辑器,网上的大佬都讲,想省事直接用Goland,用VsCode配置会存在一些未知的使用体验问题, ...

  3. 深入浅出Dotnet Core的项目结构变化

    有时候,越是基础的东西,越是有人不明白.   前几天Review一个项目的代码,发现非常基础的内容,也会有人理解出错. 今天,就着这个点,写一下Dotnet Core的主要类型的项目结构,以及之间的转 ...

  4. python的默认参数的一个坑

    前言 pass 正文 在 https://docs.python.org/3/tutorial/controlflow.html#default-argument-values 中,有这样一段话 Im ...

  5. Python 中的行向量、列向量和矩阵

    1.一维数组 一维数组既不是行向量,也不是列向量. import numpy as npa=np.array([1,2,3])print(np.shape(a))>>>(3,) 2. ...

  6. 【MySQL 高级】索引优化分析

    MySQL高级 索引优化分析 SQL 的效率问题 出现性能下降,SQL 执行慢,执行时间长,等待时间长等情况,可能的原因有: 查询语句写的不好 索引失效 单值索引:在 user 表中给 name 属性 ...

  7. LeetCode747 至少是其他数字两倍的最大数

    在一个给定的数组nums中,总是存在一个最大元素 . 查找数组中的最大元素是否至少是数组中每个其他数字的两倍. 如果是,则返回最大元素的索引,否则返回-1. 示例 1: 输入: nums = [3, ...

  8. 【C++】《C++ Primer 》第十三章

    第十三章 拷贝控制 定义一个类时,需要显式或隐式地指定在此类型地对象拷贝.移动.赋值和销毁时做什么. 一个类通过定义五种特殊的成员函数来控制这些操作.即拷贝构造函数(copy constructor) ...

  9. Linux学习笔记 | 常见错误之VMware启动linux后一直黑屏

    方法1: 宿主机(windows)管理员模式运行cmd 输入netsh winsock reset 然后重启电脑 netsh winsock reset命令,作用是重置 Winsock 目录.如果一台 ...

  10. 一次snapshot迁移引发的Hbase RIT(hbase2.1.0-cdh6.3.0)

    1. 问题起因 通过snapshot做跨集群数据同步时,在执行拷贝脚本里没有指定所有者及所有组,导致clone时没有权限,客户端卡死.master一直报错,经过一系列操作后,导致RIT异常. 2. 异 ...