Vision Transformer 笔记

本文主要从代码角度记录使用transformer实现图像分类的流程. 代码vit-pytorch/

总体结构

结合上图与代码展开:

前向传播过程代码如下:

def forward(self, img, mask=None):
    
    x = self.to_patch_embedding(img)
   
    b, n, _ = x.shape

    
    cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
   
    x = torch.cat((cls_tokens, x), dim=1)
  
    x += self.pos_embedding[:, :(n + 1)]
  
    x = self.dropout(x)

  
    x = self.transformer(x, mask)

  
    x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

  
    x = self.to_latent(x)
  
    return self.mlp_head(x)

主要由5步构成:

  • 将图片变成图片块并进行embedding
  • 生成用于分类的token
  • 编码位置信息 pos_embedding
  • transformer 编码器模块进行attention计算
  • mlp 分类头进行分类

patch_embedding

num_patches = (image_size // patch_size) ** 2

patch_dim = channels * patch_size ** 2

assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'


self.to_patch_embedding = nn.Sequential(
range('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
    
    nn.Linear(patch_dim, dim),
)

假如输入图片大小为 H * W * C,将图片切分为 9 块, 每块子图的宽高为(p, p), 将每个子图展开为一维向量大小为 p * p * c, 总输入变为 9 * (p * p * c), 每张图片变成了9个一维向量。然后连接一个全连接层对每个向量进行embedding, 输出维度为 dim

分类token

在vit中,如上图,进过编码器后得到9个图片块对应的向量,通俗说这9个图像块已经互相建立了联系,看到了彼此,但是该用哪一个去进行分类呢?因此增加了一个专用于分类的向量,该向量时可学习的。大白话理解就是:这9个也不要争了,增加一个来统领全局,由它来学的整个图片的信息。

self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
# 维度和patch embedding的维度一样,前向传播时会concat 到embedding 前面

pos_embedding

与transformer一样,对图像进行分块后失去了空间位置信息。原始的Transformer引入了一个 Positional encoding来加入序列的位置信息,transformer中的 Positional encoding 是写死的,在这里也引入的pos_embedding,是用一个可训练的变量替代。

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

num_patches + 1 , 是因为增加了一个分类token

transformer 编码器模块进行attention计算

原版transformer的encoder一样进行编码

mlp 分类头进行分类

self.mlp_head = nn.Sequential(
nn.layerNorm(dim),
nn.Linear(dim, num_classes))

即一个全连接层进行映射到世界类别进行分类, 在vit前向传播过程中,通常只取第一个分类token $x[:0]$ 输入到 mlp 中.

x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

vs sota

REF