Transformer在NLP任务中表现很好,但是在CV任务中应用还很有限,基本都是作为CNN的一个辅助,Vit尝试使用纯Transformer结构解决CV的任务,并成功将其应用到了CV的基本任务--图像分类中。

因此,简单而言,这篇论文的主旨就是,用Transformer结构完成图像分类任务。

结构概述

基本结构如下:

核心要点:

  • 图像切patch
  • Patch0
  • Position Embedding
  • Multi-Head Attention

图像切patch

在NLP任务中,将自然语言使用Word2Vec转为向量(Embedding)送入模型进行处理,在CV中没有对应的序列化token,因此作者采用将原始图像切分为多个小块,然后将每个小块儿内的信息展平的方式。

假设输入的shape为:(1, 3, 288, 288)

切分为9个小块,则每个小块的shape为:(1, 3, 32, 32)

然后将每个小块展平,则每个小块为(1, 3072),有9个小块,所以Linear Projection of Flattened Patched的shape为:(1, 9, 3072)输出shape为(1, 9, 1024),再加上Position Embedding,Transformer Encoder的输入shape为(1, 10, 1024),也就是图中Embedded Patches的shape。

Patch0

为什么需要有Patch0?

这是因为需要对1-9个patches信息的整合,最后送入MLP Head的只有Patch0。

Position Embedding

图像被切分和展开后,丢失了位置信息,对于图像处理任务来说,这是很怪异的,因此,作者这里采用在每个Patch上增加一个位置信息的方式,将位置信息纳入考虑。

Multi-Head Attention

参考Attention的基本结构。[Todo, Link]

代码[Pytorch]

使用repo pytorch_vit

import torch
from vit_pytorch import ViT v = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
) img = torch.randn(1, 3, 256, 256) preds = v(img)
print(preds.shape) # 1000,与ViT定义的num_classes一致

ViT类参数解析:

  • dim:Linear Projection的输出维度:1024
  • depth:有多少个Transformer Blocks
  • heads:Multi-Head的Head数
  • mlp_dim:Transformer Encoder内部的MLP的维度
  • dropout
  • ......

ViT的forward函数:

def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x) x = self.transformer(x) x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] x = self.to_latent(x)
return self.mlp_head(x)

输入端的切分主要由下面这句话完成:

x = self.to_patch_embedding(img)

==>

self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
#由传入参数: image_size = 256, patch_size = 32
# Rearrange完成的shape变换为(b, c, 256, 256) -> (b, 64, 1024*c)
# nn.LayerNorm
# nn.Linear: (b, 64, 1024*c) --> (b, 64, 1024)

Rearrange用更加可理解的方式实现transpose的功能:

We don't write:

y = x.transpose(0, 2, 3, 1)

We write comprehensible code:

y = rearrange(x, 'b c h w -> b h w c')

ViT简述【Transformer】的更多相关文章

  1. VIT Vision Transformer | 先从PyTorch代码了解

    文章原创自:微信公众号「机器学习炼丹术」 作者:炼丹兄 联系方式:微信cyx645016617 代码来自github [前言]:看代码的时候,也许会不理解VIT中各种组件的含义,但是这个文章的目的是了 ...

  2. ICCV2021 | Tokens-to-Token ViT:在ImageNet上从零训练Vision Transformer

    ​  前言  本文介绍一种新的tokens-to-token Vision Transformer(T2T-ViT),T2T-ViT将原始ViT的参数数量和MAC减少了一半,同时在ImageNet上从 ...

  3. Transformer详解

    0 简述 Transformer改进了RNN最被人诟病的训练慢的缺点,利用self-attention机制实现快速并行. 并且Transformer可以增加到非常深的深度,充分发掘DNN模型的特性,提 ...

  4. 从零搭建Pytorch模型教程(三)搭建Transformer网络

    ​ 前言 本文介绍了Transformer的基本流程,分块的两种实现方式,Position Emebdding的几种实现方式,Encoder的实现方式,最后分类的两种方式,以及最重要的数据格式的介绍. ...

  5. 论文阅读 | Transformer-XL: Attentive Language Models beyond a Fixed-Length Context

    0 简述 Transformer最大的问题:在语言建模时的设置受到固定长度上下文的限制. 本文提出的Transformer-XL,使学习不再仅仅依赖于定长,且不破坏时间的相关性. Transforme ...

  6. attention、self-attention、transformer和bert模型基本原理简述笔记

    attention 以google神经机器翻译(NMT)为例 无attention: encoder-decoder在无attention机制时,由encoder将输入序列转化为最后一层输出state ...

  7. ICCV2021 | TransFER:使用Transformer学习关系感知的面部表情表征

    ​  前言  人脸表情识别(FER)在计算机视觉领域受到越来越多的关注.本文介绍了一篇在人脸表情识别方向上使用Transformer来学习关系感知的ICCV2021论文,论文提出了一个TransFER ...

  8. ICCV2021 | PnP-DETR:用Transformer进行高效的视觉分析

    ​  前言  DETR首创了使用transformer解决视觉任务的方法,它直接将图像特征图转化为目标检测结果.尽管很有效,但由于在某些区域(如背景)上进行冗余计算,输入完整的feature maps ...

  9. pycaffe︱caffe中fine-tuning模型三重天(函数详解、框架简述)

    本文主要参考caffe官方文档[<Fine-tuning a Pretrained Network for Style Recognition>](http://nbviewer.jupy ...

  10. 带你读Paper丨分析ViT尚存问题和相对应的解决方案

    摘要:针对ViT现状,分析ViT尚存问题和相对应的解决方案,和相关论文idea汇总. 本文分享自华为云社区<[ViT]目前Vision Transformer遇到的问题和克服方法的相关论文汇总& ...

随机推荐

  1. 如何使用 IdGen 生成 UID

    在分布式系统中,雪花 ID 是一种常用的唯一 ID 生成算法.它通过结合时间戳.机器码和自增序列来生成 64 位整数 ID,可以保证 ID 的唯一性和顺序性. 在.Net 项目中,我们可以使用 IdG ...

  2. Pytorch框架详解之一

    Pytorch基础操作 numpy基础操作 定义数组(一维与多维) 寻找最大值 维度上升与维度下降 数组计算 矩阵reshape 矩阵维度转换 代码实现 import numpy as np a = ...

  3. vue 组件之间传值(父传子,子传父)

    1.父传子 基本就用一个方式,props Father.vue(用v-bind(简写 : )  将父组件传的值绑定到子组件上) <template> <div> 我是爸爸:{{ ...

  4. JavaScript:操作符:赋值运算符和空赋值(??=)

    =号是赋值运算,即返回符号右边的结果,同时将结果赋值给符号左边的变量,考虑下面代码的运行结果: 赋值运算b = 1 + 1,做了两件事,先返回符号右边的结果,即2,这个2将参与a = 1 + 2的计算 ...

  5. STL list容器API

    list容器:链表容器,不支持随机遍历.不能用通用的sort算法(要有随机访问迭代器),容器自己有排序算法 #define _CRT_SECURE_NO_WARNINGS #include<io ...

  6. 初识argparse 模块

    # 1引入模块 import argparse # 2建立解析对象 parser = argparse.ArgumentParser() # 3增加属性:给xx实例增加一个aa属性 # xx.add_ ...

  7. 终于弄明白了 RocketMQ 的存储模型

    RocketMQ 优异的性能表现,必然绕不开其优秀的存储模型 . 这篇文章,笔者按照自己的理解 , 尝试分析 RocketMQ 的存储模型,希望对大家有所启发. 1 整体概览 首先温习下 Rocket ...

  8. [深度学习] CNN的基础结构与核心思想

    1. 概述 卷积神经网络是一种特殊的深层的神经网络模型,它的特殊性体现在两个方面,一方面它的神经元间的连接是非全连接的, 另一方面同一层中某些神经元之间的连接的权重是共享的(即相同的).它的非全连接和 ...

  9. MQ系列9:高可用架构分析

    MQ系列1:消息中间件执行原理 MQ系列2:消息中间件的技术选型 MQ系列3:RocketMQ 架构分析 MQ系列4:NameServer 原理解析 MQ系列5:RocketMQ消息的发送模式 MQ系 ...

  10. python之路54 forms组件 渲染 展示 参数补充 modelform组件 django中间件

    forms组件渲染标签 <p>forms组件渲染标签的方式1(封装程度高 扩展性差 主要用于本地测试):</p> {# {{ form_obj.as_p }}#} {# {{ ...