Transformer在NLP任务中表现很好,但是在CV任务中应用还很有限,基本都是作为CNN的一个辅助,Vit尝试使用纯Transformer结构解决CV的任务,并成功将其应用到了CV的基本任务--图像分类中。

因此,简单而言,这篇论文的主旨就是,用Transformer结构完成图像分类任务。

结构概述

基本结构如下:

核心要点:

  • 图像切patch
  • Patch0
  • Position Embedding
  • Multi-Head Attention

图像切patch

在NLP任务中,将自然语言使用Word2Vec转为向量(Embedding)送入模型进行处理,在CV中没有对应的序列化token,因此作者采用将原始图像切分为多个小块,然后将每个小块儿内的信息展平的方式。

假设输入的shape为:(1, 3, 288, 288)

切分为9个小块,则每个小块的shape为:(1, 3, 32, 32)

然后将每个小块展平,则每个小块为(1, 3072),有9个小块,所以Linear Projection of Flattened Patched的shape为:(1, 9, 3072)输出shape为(1, 9, 1024),再加上Position Embedding,Transformer Encoder的输入shape为(1, 10, 1024),也就是图中Embedded Patches的shape。

Patch0

为什么需要有Patch0?

这是因为需要对1-9个patches信息的整合,最后送入MLP Head的只有Patch0。

Position Embedding

图像被切分和展开后,丢失了位置信息,对于图像处理任务来说,这是很怪异的,因此,作者这里采用在每个Patch上增加一个位置信息的方式,将位置信息纳入考虑。

Multi-Head Attention

参考Attention的基本结构。[Todo, Link]

代码[Pytorch]

使用repo pytorch_vit

import torch
from vit_pytorch import ViT v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
) img = torch.randn(1, 3, 256, 256) preds = v(img)
print(preds.shape) # 1000,与ViT定义的num_classes一致

ViT类参数解析:

  • dim:Linear Projection的输出维度:1024
  • depth:有多少个Transformer Blocks
  • heads:Multi-Head的Head数
  • mlp_dim:Transformer Encoder内部的MLP的维度
  • dropout
  • ......

ViT的forward函数:

def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x) x = self.transformer(x) x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] x = self.to_latent(x)
return self.mlp_head(x)

输入端的切分主要由下面这句话完成:

x = self.to_patch_embedding(img)

==>

self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
#由传入参数: image_size = 256, patch_size = 32
# Rearrange完成的shape变换为(b, c, 256, 256) -> (b, 64, 1024*c)
# nn.LayerNorm
# nn.Linear: (b, 64, 1024*c) --> (b, 64, 1024)

Rearrange用更加可理解的方式实现transpose的功能:

We don't write:

y = x.transpose(0, 2, 3, 1)

We write comprehensible code:

y = rearrange(x, 'b c h w -> b h w c')

ViT简述【Transformer】的更多相关文章

  1. VIT Vision Transformer | 先从PyTorch代码了解

    文章原创自:微信公众号「机器学习炼丹术」 作者:炼丹兄 联系方式:微信cyx645016617 代码来自github [前言]:看代码的时候,也许会不理解VIT中各种组件的含义,但是这个文章的目的是了 ...

  2. ICCV2021 | Tokens-to-Token ViT:在ImageNet上从零训练Vision Transformer

    ​  前言  本文介绍一种新的tokens-to-token Vision Transformer(T2T-ViT),T2T-ViT将原始ViT的参数数量和MAC减少了一半,同时在ImageNet上从 ...

  3. Transformer详解

    0 简述 Transformer改进了RNN最被人诟病的训练慢的缺点,利用self-attention机制实现快速并行. 并且Transformer可以增加到非常深的深度,充分发掘DNN模型的特性,提 ...

  4. 从零搭建Pytorch模型教程(三)搭建Transformer网络

    ​ 前言 本文介绍了Transformer的基本流程,分块的两种实现方式,Position Emebdding的几种实现方式,Encoder的实现方式,最后分类的两种方式,以及最重要的数据格式的介绍. ...

  5. 论文阅读 | Transformer-XL: Attentive Language Models beyond a Fixed-Length Context

    0 简述 Transformer最大的问题:在语言建模时的设置受到固定长度上下文的限制. 本文提出的Transformer-XL,使学习不再仅仅依赖于定长,且不破坏时间的相关性. Transforme ...

  6. attention、self-attention、transformer和bert模型基本原理简述笔记

    attention 以google神经机器翻译(NMT)为例 无attention: encoder-decoder在无attention机制时,由encoder将输入序列转化为最后一层输出state ...

  7. ICCV2021 | TransFER:使用Transformer学习关系感知的面部表情表征

    ​  前言  人脸表情识别(FER)在计算机视觉领域受到越来越多的关注.本文介绍了一篇在人脸表情识别方向上使用Transformer来学习关系感知的ICCV2021论文,论文提出了一个TransFER ...

  8. ICCV2021 | PnP-DETR:用Transformer进行高效的视觉分析

    ​  前言  DETR首创了使用transformer解决视觉任务的方法,它直接将图像特征图转化为目标检测结果.尽管很有效,但由于在某些区域(如背景)上进行冗余计算,输入完整的feature maps ...

  9. pycaffe︱caffe中fine-tuning模型三重天(函数详解、框架简述)

    本文主要参考caffe官方文档[<Fine-tuning a Pretrained Network for Style Recognition>](http://nbviewer.jupy ...

  10. 带你读Paper丨分析ViT尚存问题和相对应的解决方案

    摘要:针对ViT现状,分析ViT尚存问题和相对应的解决方案,和相关论文idea汇总. 本文分享自华为云社区<[ViT]目前Vision Transformer遇到的问题和克服方法的相关论文汇总& ...

随机推荐

  1. 【Java SE进阶】Day02 Collection、Iterator、泛型

    一.Collection集合 1.概述 数组存元素,集合存对象(类型可以不一样) 2.框架分类 单列:Collection List ArrayList LinkedList Set HashSet ...

  2. 【每日一题】【字符串与数字互转】【去除空格】【大数处理】2021年12月12日-8. 字符串转换整数 (atoi)

    请你来实现一个 myAtoi(string s) 函数,使其能将字符串转换成一个 32 位有符号整数(类似 C/C++ 中的 atoi 函数). 函数 myAtoi(string s) 的算法如下: ...

  3. 把盏言欢,款款而谈,ChatGPT结合钉钉机器人(outgoing回调)打造人工智能群聊/单聊场景,基于Python3.10

    就像黑火药时代里突然诞生的核弹一样,OpenAI的ChatGPT语言模型的横空出世,是人工智能技术发展史上的一个重要里程碑.这是一款无与伦比.超凡绝伦的模型,能够进行自然语言推理和对话,并且具有出色的 ...

  4. AStar寻路算法示例

    概述 AStar算法是一种图形搜索算法,常用于寻路.他是以广度优先搜索为基础,集Dijkstra算法和最佳优先(best fit)于一身的一种算法. 示例1:4向 示例2:8向 思路 递归的通过估值函 ...

  5. Jmeter 逻辑控制器之吞吐量控制器(Throughput Controller)

    吞吐量控制器(Throughput Controller)用来控制其下元件的执行次数,并无控制吞吐量的功能,想要控制吞吐量可以用Constant Throughput Timer,吞吐量控制器有两种模 ...

  6. STL set容器常用API

    set容器,容器内部将数据自动排序(平衡二叉树),不能插入重复元素.multiset可以插入重复元素.不能修改容器中的值,通过删除值,在插入. #define _CRT_SECURE_NO_WARNI ...

  7. [Unity]限制一个值的大小(Clamp以及Mathf)

    如何限制一个物体的运动范围? 代码实例 public float xMin, xMax, zMin, zMax; rigidbody.position = new Vector3( Mathf.Cla ...

  8. MySQL 不四舍五入取整、取小数、四舍五入取整、取小数、向下、向上取整

    总结了MySQL中取整和取小数中遇到的问题和解决的几个方法:不四舍五入取整.取小数.四舍五入取整.取小数.向下.向上取整. 其中: 不四舍五入取整(截取整数部分)就是'向下取整': 除了用trunca ...

  9. 今天遇到的报错Babel noteThe code generator has deoptimised the styling of ...as it exceeds the max of 500KB.

    解决办法如下: { test: /.js$/, exclude: /node_modules/, use: 'babel-loader' } 然并卵,我已经设置了这个东西了,突然发现我的文件并不在no ...

  10. PHP转Go实践:xjson解析神器「开源工具集」

    前言 近期会更新一系列开源项目的文章,新的一年会和大家做更多的开源项目,也欢迎大家加入进来. xutil 今天分享的文章源自于开源项目jinzaigo/xutil的封装. 在封装过程中,劲仔将实现原理 ...