从零搭建Pytorch模型教程(三)搭建Transformer网络
前言 本文介绍了Transformer的基本流程,分块的两种实现方式,Position Emebdding的几种实现方式,Encoder的实现方式,最后分类的两种方式,以及最重要的数据格式的介绍。
本文来自公众号CV技术指南的技术总结系列
欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。
在讲如何搭建之前,先回顾一下Transformer在计算机视觉中的结构是怎样的。这里以最典型的ViT为例。
如图所示,对于一张图像,先将其分割成NxN个patches,把patches进行Flatten,再通过一个全连接层映射成tokens,对每一个tokens加入位置编码(position embedding),会随机初始化一个tokens,concate到通过图像生成的tokens后,再经过transformer的Encoder模块,经过多层Encoder后,取出最后的tokens(即随机初始化的tokens),再通过全连接层作为分类网络进行分类。
下面我们就根据这个流程来一步一步介绍如何搭建一个Transformer模型。、
分块
目前有两种方式实现分块,一种是直接分割,一种是通过卷积核和步长都为patch大小的卷积来分割。
直接分割
直接分割即把图像直接分成多块。在代码实现上需要使用einops这个库,完成的操作是将(B,C,H,W)的shape调整为(B,(H/P *W/P),P*P*C)。
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
这里简单介绍一下Rearrange。
Rearrange用于对张量的维度进行重新变换排序,可用于替换pytorch中的reshape,view,transpose和permute等操作。举几个例子
#假设images的shape为[32,200,400,3]
#实现view和reshape的功能
Rearrange(images,'b h w c -> (b h) w c')#shape变为(32*200, 400, 3)
#实现permute的功能
Rearrange(images, 'b h w c -> b c h w')#shape变为(32, 3, 200, 400)
#实现这几个都很难实现的功能
Rearrange(images, 'b h w c -> (b c w) h')#shape变为(32*3*400, 200)
从这几个例子看可以看出,Rearrange非常简单好用,这里的b, c, h, w都可以理解为表示符号,用来表示操作变化。通过这几个例子似乎也能理解下面这行代码是如何将图像分割的。
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width)
这里需要解释的是,一个括号内的两个变量相乘表示的是该维度的长度,因此不要把"h"和"w"理解成图像的宽和高。这里实际上h = H/p1, w = W/p2,代表的是高度上有几块,宽度上有几块。h和w都不需要赋值,代码会自动根据这个表达式计算,b和c也会自动对应到输入数据的B和C。
后面的"b (h w) (p1 p2 c)"表示了图像分块后的shape: (B,(H/P *W/P),P*P*C)
这种方式在分块后还需要通过一层全连接层将分块的向量映射为tokens。
在ViT中使用的就是这种直接分块方式。
卷积分割
卷积分割比较容易理解,使用卷积核和步长都为patch大小的卷积对图像卷积一次就可以了。
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
在swin transformer中即使用的是这种卷积分块方式。在swin transformer中卷积后没有再加全连接层。
Position Embedding
Position Embedding可以分为absolute position embedding和relative position embedding。
在学习最初的transformer时,可能会注意到用的是正余弦编码的方式,但这只适用于语音、文字等1维数据,图像是高度结构化的数据,用正余弦不合适。
在ViT和swin transformer中都是直接随机初始化一组与tokens同shape的可学习参数,与tokens相加,即完成了absolute position embedding。
在ViT中实现方式:
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
x += self.pos_embedding[:, :(n + 1)]
#之所以是n+1,是因为ViT中选择随机初始化一个class token,与分块得到的tokens拼接。所以patches的数量为num_patches+1。
在swin transformer中的实现方式:
from timm.models.layers import trunc_normal_
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
在TimeSformer中的实现方式:
self.pos_emb = torch.nn.Embedding(num_positions + 1, dim)
以上就是简单的使用方法,这种方法属于absolute position embedding。
还有更复杂一点的方法,以后有机会单独搞一篇文章来介绍。
感兴趣的读者可以先去看看这篇论文《ICCV2021 | Vision Transformer中相对位置编码的反思与改进》。
Encoder
Encoder由Multi-head Self-attention和FeedForward组成。
Multi-head Self-attention
Multi-head Self-attention主要是先把tokens分成q、k、v,再计算q和k的点积,经过softmax后获得加权值,给v加权,再经过全连接层。
用公式表示如下:
所谓Multi-head是指把q、k、v再dim维度上分成head份,公式里的dk为每个head的维度。
具体代码如下:
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
这里没有太多可以解释的地方,介绍一下q、k、v的来源,由于这是self-attention,因此q=k=v(即tokens),若是普通attention,则k= v,而q是其它的东西,例如可以是另一个尺度的tokens,或视频领域中的其它帧的tokens。
FeedForward
这里不用多介绍。
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
把上面两者组合起来就是Encoder了。
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
depth指的是Encoder的数量。PreNorm指的是层归一化。
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
分类方法
数据通过Encoder后获得最后的预测向量的方法有两种典型。在ViT中是随机初始化一个cls_token,concate到分块后的token后,经过Encoder后取出cls_token,最后将cls_token通过全连接层映射到最后的预测维度。
#生成cls_token部分
from einops import repeat
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
################################
#分类部分
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
在swin transformer中,没有选择cls_token。而是直接在经过Encoder后将所有数据取了个平均池化,再通过全连接层。
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
x = self.head(x)
组合以上这些就成了一个完整的模型
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
数据的变换
以上的代码都是比较简单的,整体上最麻烦的地方在于理解数据的变换。
首先输入的数据为(B, C, H, W),在经过分块后,变成了(B, n, d)。
在CNN模型中,很好理解(H,W)就是feature map,C是指feature map的数量,那这里的n,d哪个是通道,哪个是图像特征?
回顾一下分块的部分
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width)
根据这个可以知道n为分块的数量,d为每一块的内容。因此,这里的n相当于CNN模型中的C,而d相当于features。
一般情况下,在Encoder中,我们都是以(B, n, d)的形式。
在swin transformer中这种以卷积的形式分块,获得的形式为(B, C, L),然后做了一个transpose得到(B, L, C),这与ViT通过直接分块方式获得的形式实际上完全一样,在Swin transformer中的L即为ViT中的n,而C为ViT中的d。
因此,要注意的是在Multi-head self-attention中,数据的形式是(Batchsize, Channel, Features),分成多个head的是Features。
前面提到,在ViT中会concate一个随机生成的cls_token,该cls_token的维度即为(B, 1, d)。可以理解为通道数多了个1。
以上就是Transformer的模型搭建细节了,整体上比较简单,大家看完这篇文章后可以找几篇Transformer的代码来理解理解。如ViT, swin transformer, TimeSformer等。
ViT:https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
swin: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
TimeSformer:https://github.com/lucidrains/TimeSformer-pytorch/blob/main/timesformer_pytorch/timesformer_pytorch.py
下一篇我们将介绍如何写train函数,以及包括设置优化方式,设置学习率,不同层设置不同学习率,解析参数等。
欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。
CV技术指南创建了一个交流氛围很不错的群,除了太偏僻的问题,几乎有问必答。关注公众号添加编辑的微信号可邀请加交流群。
其它文章
StyleGAN大汇总 | 全面了解SOTA方法、架构新进展
目标检测、实例分割、多目标跟踪的Anchor-free应用方法总结
从零搭建Pytorch模型教程(三)搭建Transformer网络的更多相关文章
- 从零搭建Pytorch模型教程(四)编写训练过程--参数解析
前言 训练过程主要是指编写train.py文件,其中包括参数的解析.训练日志的配置.设置随机数种子.classdataset的初始化.网络的初始化.学习率的设置.损失函数的设置.优化方式的设置. ...
- 从零搭建Pytorch模型教程(一)数据读取
前言 本文介绍了classdataset的几个要点,由哪些部分组成,每个部分需要完成哪些事情,如何进行数据增强,如何实现自己设计的数据增强.然后,介绍了分布式训练的数据加载方式,数据读取的整个 ...
- 手把手教从零开始在GitHub上使用Hexo搭建博客教程(三)-使用Travis自动部署Hexo(1)
前言 前面两篇文章介绍了在github上使用hexo搭建博客的基本环境和hexo相关参数设置等. 基于目前,博客基本上是可以完美运行了. 但是,有一点是不太好,就是源码同步问题,如果在不同的电脑上写文 ...
- PyTorch 系列教程之空间变换器网络
在本教程中,您将学习如何使用称为空间变换器网络的视觉注意机制来扩充您的网络.你可以在DeepMind paper 阅读更多有关空间变换器网络的内容. 空间变换器网络是对任何空间变换的差异化关注的概括. ...
- keras入门(三)搭建CNN模型破解网站验证码
项目介绍 在文章CNN大战验证码中,我们利用TensorFlow搭建了简单的CNN模型来破解某个网站的验证码.验证码如下: 在本文中,我们将会用Keras来搭建一个稍微复杂的CNN模型来破解以上的 ...
- 物理引擎Havok教程(一)搭建开发环境
物理引擎Havok教程(一)搭建开发环境 网上关于Havok的教程实在不多,并且Havok学习起来还是有一定难度的,所以这里写了一个系列教程,希望可以帮到读者.这是第一期. 一.Havok物理引擎简单 ...
- 【入门教程】kafka环境搭建以及基础教程
问题导读 1.Kafka独特设计在什么地方?2.Kafka如何搭建及创建topic.发送消息.消费消息?3.如何书写Kafka程序?4.数据传输的事务定义有哪三种?5.Kafka判断一个节点是否活着有 ...
- DWR搭建以及使用教程
DWR搭建以及使用教程 DWR(Direct Web Remoting)是一个Ajax的开源框架,用于改善web页面与Java类交互的远程服务器端的交互体验,可以帮助开发人员开发包含AJAX技术的 ...
- hexo搭建博客系列(三)美化主题
文章目录 其他搭建 1. 添加博客图标 2. 鼠标点击特效(二选一) 2.1 红心特效 2.2 爆炸烟花 3. 设置头像 4. 侧边栏社交小图标设置 5. 文章末尾的标签图标修改 6. 访问量统计 7 ...
随机推荐
- 6月20日 Django中ORM介绍和字段、字段参数、相关操作
一.Django中ORM介绍和字段及字段参数 二.Django ORM 常用字段和参数 三.Django ORM执行原生SQL.在Python脚本中调用Django环境.Django终端打印SQL语句 ...
- RabbitMQ Go客户端教程6——RPC
本文翻译自RabbitMQ官网的Go语言客户端系列教程,本文首发于我的个人博客:liwenzhou.com,教程共分为六篇,本文是第六篇--RPC. 这些教程涵盖了使用RabbitMQ创建消息传递应用 ...
- java线程池之newFixedThreadPool定长线程池
newFixedThreadPool 创建一个定长线程池,可控制线程最大并发数,超出的线程会在队列中等待. 线程池的作用: 线程池作用就是限制系统中执行线程的数量. 根 据系统的环境情况,可以 ...
- Java并发编程虚假唤醒问题(生产者和消费者关系)
何为虚假唤醒: 当一个条件满足时,很多线程都被唤醒了,但是只有其中部分是有用的唤醒,其它的唤醒都是无用功:比如买货:如果商品本来没有货物,突然进了一件商品,这是所有的线程都被唤醒了,但是只能一个人买, ...
- MySQL JDBC常用知识,封装工具类,时区问题配置,SQL注入问题
JDBC JDBC介绍 Sun公司为了简化开发人员的(对数据库的统一)操作,提供了(Java操作数据库的)规范,俗称JDBC,这些规范的由具体由具体的厂商去做 对于开发人员来说,我们只需要掌握JDBC ...
- 关于Gradle 6.x及以上版本发布到仓库有很多CheckSum文件,想去除?
写在前边 今天写的这个博客和平时的博客有点区别,大多数都是告诉你解决问题的,而本文不完全是. 经常使用Gradle的都知道 Gradle有两个发布制品到Maven仓库的插件:maven 与 maven ...
- Lua协程的一个例子
很久没记录笔记了,还是养成不了记录的习惯 下面是来自 programming in lua的一个协程的例(生产者与用户的例子) 帖代码,慢慢理解 -- Programming in Lua Corou ...
- 数据结构-二叉树、B树、B+树、B*树(整理版)
1. 二叉树 二叉树的特点: ① 所有非叶子节点至多拥有两个儿子(Left和Right): ② 所有节点存储一个关键字: ③ 非叶子节点的左指针指向小于其关键字的子树,右指针指向大于其关键字的子树: ...
- Lock 与 Synchronized 的区别?
首先两者都保持了并发场景下的原子性和可见性,区别则是synchronized的释放锁机制是交由其自身控制,且互斥性在某些场景下不符合逻辑,无法进行干预,不可人为中断等.而lock常用的则有Reentr ...
- Java并发机制(8)--concurrent包下辅助类的使用
Java并发编程:concurrent包下辅助类的使用 整理自:博客园-海子-http://www.cnblogs.com/dolphin0520/p/3920397.html 1.CountDown ...