【第三十一周】ViT 论文阅读笔记

发布于:2025-04-15 ⋅ 阅读:(19) ⋅ 点赞:(0)

摘要

本篇博客介绍了Vision Transformer(ViT),这是一种突破性的图像分类模型,其核心思想是将图像分割为固定大小的块(如16×16像素),并将这些块序列化后输入标准的Transformer架构,从而替代传统卷积神经网络(CNN)对视觉特征的局部归纳偏置依赖 。针对图像数据难以直接适配序列模型的问题,ViT提出图像块嵌入(Patch Embedding)技术,通过线性投影将每个块展平为向量,并引入可学习的位置编码(Position Embedding)保留空间信息,同时添加分类标识符(Class Token)以聚合全局特征。ViT采用多层Transformer Encoder堆叠,通过自注意力机制捕捉跨区域的全局依赖,最终由MLP Head输出分类结果。实验表明,当在大规模数据集(如JFT-300M)预训练后,ViT在ImageNet等任务上超越同期CNN模型,且训练资源需求更低。其优势在于全局建模能力与模型扩展性,但依赖大量预训练数据且计算复杂度随分辨率呈平方增长。未来改进方向包括轻量化设计、动态位置编码优化,以及结合局部-全局注意力机制以提升实际场景的实用性。

Abstract

This blog introduces Vision Transformer (ViT), a groundbreaking image classification model that replaces the local inductive bias of traditional convolutional neural networks (CNNs) by segmenting images into fixed-size patches (e.g., 16×16 pixels) and processing them as sequential inputs through a standard Transformer architecture. To address the challenge of adapting grid-structured images to sequence-based modeling, ViT employs patch embedding to linearly project flattened patches into vectors, learnable positional embeddings to encode spatial relationships, and a class token to aggregate global features. By stacking multiple Transformer encoder layers with self-attention mechanisms, ViT captures long-range dependencies across image regions, culminating in classification predictions via an MLP head. Experiments demonstrate that when pretrained on large-scale datasets like JFT-300M, ViT outperforms contemporary CNNs on tasks such as ImageNet while requiring fewer computational resources. Despite its advantages in global feature modeling and scalability, ViT heavily relies on extensive pretraining data and suffers from quadratic computational complexity relative to input resolution. Future research may focus on lightweight architectures, dynamic positional encoding, and hybrid local-global attention mechanisms to enhance its practicality in real-world applications.


文章信息

Title:AN IMAGE IS WORTH 16X16 WORDS:TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE
Author:Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
Source:https://arxiv.org/abs/2010.11929


引言

在ViT的提出之前,基于自注意力的架构,特别是 Transformers ,已成为自然语言处理领域的首选模型,主要的方法是在大型文本语料上进行预训练,然后在较小的任务特定数据集上进行微调。Transformer的计算效率高(并行计算)、可扩展性强,且随着模型和数据集的增长,性能仍然没有饱和的迹象。

然而,在计算机视觉领域,卷积架构仍占主导地位。受 NLP 成功的启发,一些作品尝试将类似 CNN 的架构与自注意力相结合,也有一些作品尝试完全取代卷积,但后一种模型虽然理论上有效,但由于使用了专门的注意力模式,因此尚未在现代硬件加速器上有效扩展。

CNN依赖两种先验知识——局部性(Locality)和平移不变性(Translation Equivariance),这使得CNN在小数据集上表现优异,但也限制了其全局建模能力。CNN的逐层卷积操作难以直接建模图像中远距离像素或区域之间的关系。

ViT摒弃了卷积操作,消除了视觉任务中的归纳偏置,将图像分割为固定大小的块(Patch),通过线性投影转化为序列输入,完全依赖自注意力机制建模全局关系。

方法

在 ViT 的设计中,尽可能与原始的 Transformer 结构保持一致,这种简单的设计的一个好处是可以使用现有的 Transformer 架构的高效实现。
ViT 的网络架构如下图所示:
在这里插入图片描述
ViT 的架构简单来说有三部分组成:

  • Linear Projection of Flattened Patches(Embedding层,将二维图像转换为适合Transformer处理的序列数据)
  • Transformer Encoder(通过自注意力机制和前馈网络提取全局特征)
  • MLP Head(最终用于分类的层结构)

ViT总体流程:
对于输入图像,先按照预先指定的尺寸进行分割,分割后的小图像是 patch ,然后每个patch经过 Linear Projection of Flattened Patches层进行embeding,得到 patch embedding。然后对每个 patch embedding 分别加上对应 patch 的位置编码 position embedding,并拼接(concat)上一个 cls token (cls token 来自BERT,用于最终的分类层的输入)的信息。上述得到的信息输入到 Transformer encoder (堆叠的多层,每层的结构一致,参数独立)中提取全局信息。经过堆叠的Transformer encoder后,cls token 可以得到整图(所有patch)的信息,所以其输出作为分类层 MLP head 的输入,进行类别概率计算。
ViT的流程动图如下:
在这里插入图片描述

Patch Embedding

对于图像数据而言,其数据格式为[H, W, C]是三维矩阵明显不是Transformer想要的。所以需要先通过一个Embedding层来对数据做个变换。为方便描述网络中的数据流,本博客在具体的数据上都以ViT-B/16为例(patch大小为16*16)。
首先将输入的图片按照预先指定的patch大小分割为若干patch,如输入图片大小为 224 ∗ 224 ∗ 3 224*224*3 2242243,划分后有 ( 224 ∗ 224 ) / ( 16 ∗ 16 ) = 196 (224*224)/(16*16)= 196 224224/1616=196 个 patch。每个patch大小为 16 × 16 × 3 16\times16\times3 16×16×3,通过映射得到一个长度为768的向量,数据形状变化:[16, 16, 3] -> [768]。

在实际的实现中,是通过卷积来完成上述操作的。
在这里插入图片描述
使用卷积核为 16 ∗ 16 16*16 1616,stride=16,padding=0,卷积核个数为 768 的卷积操作,可将原图的 224 ∗ 224 ∗ 3 224*224*3 2242243 转化为 14 ∗ 14 ∗ 768 14*14*768 1414768,然后将 H 个 W 两个维度展平得到二维矩阵,形状为 196*768,其中 196 是patch token 的个数,768 是 token 的维度。

Q:为何要处理成Patch
A:主要有以下两个原因:
第一,减少计算量,在Transformer中,假设输入的序列长度为N,则经过attention时计算复杂度为 O ( N 2 ) O(N^2) O(N2),因为注意力机制下,每个token都要和包括自己在内的所有token做一次attention score计算。在ViT中,分割的每个Patch作为一个token输入到Transformer encoder,序列长度 N = ( H × W ) / P 2 N=(H\times W)/{P^2} N=(H×W)/P2,其中P是patch的大小,patch越大,序列越短,计算量越小。
第二,和语言数据中蕴含的丰富语义不同,像素本身含有大量的冗余信息。比如,相邻的两个像素格子间的取值往往是相似的。因此并不需要特别精准的计算粒度(比如把P设为1)。

Patch + Position Embedding

与BERT一样,ViT 中 Transformer 的输入需要有位置信息(position embeding)和 class token([class]token是一个可训练的参数,数据格式和之前计算的patch token一样)。
得到的 patch token 先与 class token 进行 concat 拼接: Cat([1, 768], [196, 768]) -> [197, 768]。然后加上可训练Position Embedding,是直接在token上进行sum运算,前后数据格式不变。
对于Position Embedding,在源码中默认使用的是1D Pos. Emb.,对比不使用Position Embedding准确率提升了大概3个点,和2D Pos. Emb.比起来没太大差别。
在这里插入图片描述

Transformer Encoder

Transformer Encoder其实就是重复堆叠Encoder Block L次,Encoder Block其具体结构如下图左侧所示。
在这里插入图片描述
Encoder Block主要由以下几部分组成:

  • Layer Norm,这种Normalization方法主要是针对NLP领域提出的,这里是对每个token进行Norm处理。
  • 多头自注意力(Multi-Head Self-Attention),捕捉不同位置patch间的依赖关系。
  • Dropout/DropPath,在原论文的代码中是直接使用的Dropout层,在但rwightman实现的代码中使用的是DropPath(stochastic depth)。
  • MLP Block,如图右侧所示,就是全连接+GELU激活函数+Dropout组成,需要注意的是第一个全连接层会把输入节点个数翻4倍[197, 768] -> [197, 3072],第二个全连接层会还原回原节点个数[197, 3072] -> [197, 768]。

MLP Head

Transformer Encoder前后的数据格式不变,在其后还有一个layer norm。对于MLP,其输入为提取出的[class]token生成的对应结果,即[197, 768]中抽取出[class]token对应的[1, 768],MLP Head原论文中说在训练ImageNet21K时是由Linear+tanh激活函数+Linear组成。

整体架构

下面是ViT-B/16的详细网络结构
在这里插入图片描述
论文中计算的数学表达如下:
在这里插入图片描述
其中, x p i x_p^i xpi表示第i个patch, E E E E p o s E_{pos} Epos分别表示Token Embedding和Positional Embedding, Z 0 Z_0 Z0是Transformer Encoder的输入,公式(2)是计算multi-head attention的过程,公式(3)是计算MLP的过程,公式(4)是最终分类任务,LN表示是一个简单的线性分类模型, Z L 0 Z_L^0 ZL0是得到的 cls token 对应的输出结果。

CNN的归纳偏置

归纳偏置就是一种假设,或者说一种先验知识。有了这种先验,就能知道哪一种方法更适合解决哪一类任务。所以归纳偏置是一种统称,不同的任务其归纳偏置下包含的具体内容不一样。

对图像任务来说,它的归纳偏置有以下两点:

  • 空间局部性(locality) :假设一张图片中,相邻的区域是有相关特征的。比如太阳和天空就经常一起出现。
  • 平移等边性(translation equivariance):无论是先做卷积还是先做平移,其结果都是一样的,即 f ( g ( x ) ) = g ( f ( x ) ) f(g(x))=g(f(x)) f(g(x))=g(f(x))

基于这两种先验知识,CNN成为了图像任务最佳的方案之一。卷积核能最大程度保持空间局部性(保存相关物体的位置信息)和平移等边性,使得在训练过程中,最大限度学习和保留原始图片信息。
而本文介绍的ViT没有使用卷积(除了在patch embedding时),完全丢弃了图像的归纳偏置。

代码实现

下面的代码来自rwightman的实现,这也是被官方认可的实现。
patch embedding的实现:

class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        """
        初始化 PatchEmbed 模块

        Args:
            img_size (int or tuple): 输入图像的尺寸,默认为 224
            patch_size (int or tuple): 图像块的尺寸,默认为 16
            in_c (int): 输入图像的通道数,默认为 3
            embed_dim (int): 嵌入维度,默认为 768
            norm_layer (nn.Module): 归一化层,默认为 None
        """
        super().__init__()
        # 将图像尺寸转换为元组形式
        img_size = (img_size, img_size)
        # 将图像块尺寸转换为元组形式
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # 计算网格尺寸
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        # 计算图像块的数量
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        # 定义卷积层,用于将图像分割成图像块并进行嵌入
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        # 定义归一化层,如果提供了则使用,否则使用恒等映射
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """
        前向传播

        Args:
            x (torch.Tensor): 输入图像张量,形状为 [B, C, H, W]

        Returns:
            torch.Tensor: 处理后的张量,形状为 [B, num_patches, embed_dim]
        """
        # 获取输入图像的形状
        B, C, H, W = x.shape
        # 检查输入图像的尺寸是否与模型设置的尺寸一致
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # 通过卷积层将图像分割成图像块并进行嵌入
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2)
        # 进行归一化处理
        x = self.norm(x)
        return x

使用卷积层将图像分割成多个图像块,并将每个图像块映射到一个固定维度的嵌入向量。

Attention模块:实现多头自注意力机制(Multi-Head Self-Attention),用于捕捉输入序列中不同位置之间的依赖关系。

class Attention(nn.Module):
    def __init__(self,
                 dim,   # 输入token的dim
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        """
        初始化 Attention 模块

        Args:
            dim (int): 输入 token 的维度
            num_heads (int): 注意力头的数量,默认为 8
            qkv_bias (bool): 是否使用偏置项,默认为 False
            qk_scale (float): 缩放因子,默认为 None
            attn_drop_ratio (float): 注意力矩阵的丢弃概率,默认为 0.
            proj_drop_ratio (float): 投影层的丢弃概率,默认为 0.
        """
        super(Attention, self).__init__()
        self.num_heads = num_heads
        # 计算每个注意力头的维度
        head_dim = dim // num_heads
        # 计算缩放因子,如果未提供则使用默认值
        self.scale = qk_scale or head_dim ** -0.5
        # 定义线性层,用于生成查询(Q)、键(K)和值(V)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # 定义注意力矩阵的丢弃层
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        # 定义投影层
        self.proj = nn.Linear(dim, dim)
        # 定义投影层的丢弃层
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        """
        前向传播

        Args:
            x (torch.Tensor): 输入张量,形状为 [batch_size, num_patches + 1, total_embed_dim]

        Returns:
            torch.Tensor: 处理后的张量,形状为 [batch_size, num_patches + 1, total_embed_dim]
        """
        # 获取输入张量的形状
        B, N, C = x.shape

        # 通过线性层生成查询(Q)、键(K)和值(V)
        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # 分离查询(Q)、键(K)和值(V)
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # 计算注意力分数
        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # 对注意力分数进行 softmax 操作,得到注意力矩阵
        attn = attn.softmax(dim=-1)
        # 对注意力矩阵进行丢弃操作
        attn = self.attn_drop(attn)

        # 根据注意力矩阵对值(V)进行加权求和
        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # 通过投影层进行线性变换
        x = self.proj(x)
        # 对投影层的输出进行丢弃操作
        x = self.proj_drop(x)
        return x

MLP模块:由两个全连接层和一个激活函数组成,通过对输入进行线性变换和非线性激活,得到输出。

class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        """
        初始化 MLP 模块

        Args:
            in_features (int): 输入特征的维度
            hidden_features (int): 隐藏层特征的维度,默认为 None
            out_features (int): 输出特征的维度,默认为 None
            act_layer (nn.Module): 激活函数层,默认为 nn.GELU
            drop (float): 丢弃概率,默认为 0.
        """
        super().__init__()
        # 如果未提供输出特征的维度,则使用输入特征的维度
        out_features = out_features or in_features
        # 如果未提供隐藏层特征的维度,则使用输入特征的维度
        hidden_features = hidden_features or in_features
        # 定义第一个全连接层
        self.fc1 = nn.Linear(in_features, hidden_features)
        # 定义激活函数层
        self.act = act_layer()
        # 定义第二个全连接层
        self.fc2 = nn.Linear(hidden_features, out_features)
        # 定义丢弃层
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """
        前向传播

        Args:
            x (torch.Tensor): 输入张量

        Returns:
            torch.Tensor: 处理后的张量
        """
        # 通过第一个全连接层
        x = self.fc1(x)
        # 通过激活函数层
        x = self.act(x)
        # 进行丢弃操作
        x = self.drop(x)
        # 通过第二个全连接层
        x = self.fc2(x)
        # 进行丢弃操作
        x = self.drop(x)
        return x

Block模块:实现 Transformer 编码器中的一个块,包含多头自注意力机制和多层感知机。对输入进行归一化处理,然后依次通过多头自注意力机制和多层感知机,最后使用残差连接将输入和输出相加。

class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        """
        初始化 Block 模块

        Args:
            dim (int): 输入特征的维度
            num_heads (int): 注意力头的数量
            mlp_ratio (float): MLP 隐藏层维度与输入维度的比例,默认为 4.
            qkv_bias (bool): 是否使用偏置项,默认为 False
            qk_scale (float): 缩放因子,默认为 None
            drop_ratio (float): 丢弃概率,默认为 0.
            attn_drop_ratio (float): 注意力矩阵的丢弃概率,默认为 0.
            drop_path_ratio (float): 随机深度丢弃概率,默认为 0.
            act_layer (nn.Module): 激活函数层,默认为 nn.GELU
            norm_layer (nn.Module): 归一化层,默认为 nn.LayerNorm
        """
        super(Block, self).__init__()
        # 定义第一个归一化层
        self.norm1 = norm_layer(dim)
        # 定义注意力模块
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # 定义随机深度丢弃层,如果丢弃概率大于 0 则使用,否则使用恒等映射
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        # 定义第二个归一化层
        self.norm2 = norm_layer(dim)
        # 计算 MLP 隐藏层的维度
        mlp_hidden_dim = int(dim * mlp_ratio)
        # 定义 MLP 模块
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        """
        前向传播

        Args:
            x (torch.Tensor): 输入张量

        Returns:
            torch.Tensor: 处理后的张量
        """
        # 先进行归一化,再通过注意力模块,最后加上随机深度丢弃和残差连接
        x = x + self.drop_path(self.attn(self.norm1(x)))
        # 先进行归一化,再通过 MLP 模块,最后加上随机深度丢弃和残差连接
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

VisionTransformer模块:实现完整的 Vision Transformer 模型,包括图像块嵌入、位置编码、Transformer 编码器和分类头。具体做法为:将输入图像通过PatchEmbed模块转换为嵌入向量,添加位置编码后,通过多个Block模块进行特征提取,最后通过分类头进行分类。

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): 输入图像的尺寸,如果是整数则表示正方形图像的边长
            patch_size (int, tuple): 图像分块的尺寸,如果是整数则表示正方形分块的边长
            in_c (int): 输入图像的通道数,通常彩色图像为3
            num_classes (int): 分类任务的类别数
            embed_dim (int): 嵌入向量的维度
            depth (int): 变压器(Transformer)的层数
            num_heads (int): 多头注意力机制中的头数
            mlp_ratio (int): 多层感知机(MLP)隐藏层维度与嵌入维度的比例
            qkv_bias (bool): 是否在查询(Q)、键(K)、值(V)的线性变换中使用偏置
            qk_scale (float): 自定义的查询和键的缩放因子,如果未设置则使用默认值
            representation_size (Optional[int]): 如果设置,则启用并将表示层(预对数层)的维度设置为该值
            distilled (bool): 模型是否包含蒸馏令牌和头,如DeiT模型
            drop_ratio (float): 随机失活(Dropout)的概率
            attn_drop_ratio (float): 注意力机制中的随机失活概率
            drop_path_ratio (float): 随机深度(Stochastic Depth)的概率
            embed_layer (nn.Module): 用于图像分块嵌入的层
            norm_layer: (nn.Module): 归一化层
        """
        super(VisionTransformer, self).__init__()
        # 分类任务的类别数
        self.num_classes = num_classes
        # 特征维度,与嵌入维度保持一致,便于与其他模型统一接口
        self.num_features = self.embed_dim = embed_dim
        # 令牌数量,如果使用蒸馏则为2(分类令牌和蒸馏令牌),否则为1(分类令牌)
        self.num_tokens = 2 if distilled else 1
        # 如果未提供归一化层,则使用默认的LayerNorm层,设置eps为1e-6
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        # 如果未提供激活函数层,则使用默认的GELU激活函数
        act_layer = act_layer or nn.GELU

        # 图像分块嵌入层,将输入图像分割成多个分块并进行嵌入
        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        # 分块的数量
        num_patches = self.patch_embed.num_patches

        # 分类令牌,可学习的参数,形状为 [1, 1, embed_dim]
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # 蒸馏令牌,如果使用蒸馏则为可学习的参数,形状为 [1, 1, embed_dim],否则为None
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        # 位置嵌入,可学习的参数,形状为 [1, num_patches + num_tokens, embed_dim]
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        # 位置嵌入后的随机失活层
        self.pos_drop = nn.Dropout(p=drop_ratio)

        # 随机深度衰减规则,从0到drop_path_ratio线性插值生成depth个值
        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]
        # 变压器块序列,包含多个Block层
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        # 归一化层,用于对变压器块的输出进行归一化
        self.norm = norm_layer(embed_dim)

        # 表示层(预对数层)
        if representation_size and not distilled:
            # 如果设置了表示层维度且不使用蒸馏,则启用表示层
            self.has_logits = True
            # 更新特征维度为表示层维度
            self.num_features = representation_size
            # 表示层,包含一个线性层和一个Tanh激活函数
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            # 否则不启用表示层,使用恒等映射
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # 分类头
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        # 蒸馏头,如果使用蒸馏则为线性层,否则为None
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # 权重初始化
        # 位置嵌入的权重使用截断正态分布初始化,标准差为0.02
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            # 蒸馏令牌的权重使用截断正态分布初始化,标准差为0.02
            nn.init.trunc_normal_(self.dist_token, std=0.02)
        # 分类令牌的权重使用截断正态分布初始化,标准差为0.02
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        # 应用自定义的权重初始化函数
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        """
        前向传播特征提取部分
        Args:
            x (torch.Tensor): 输入图像,形状为 [B, C, H, W]

        Returns:
            torch.Tensor: 特征向量,如果使用蒸馏则返回分类令牌和蒸馏令牌的特征向量
        """
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768]
        # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            # 如果不使用蒸馏,将分类令牌和分块嵌入拼接
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            # 如果使用蒸馏,将分类令牌、蒸馏令牌和分块嵌入拼接
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)

        # 位置嵌入并进行随机失活
        x = self.pos_drop(x + self.pos_embed)
        # 通过变压器块序列
        x = self.blocks(x)
        # 归一化
        x = self.norm(x)
        if self.dist_token is None:
            # 如果不使用蒸馏,返回分类令牌的特征向量
            return self.pre_logits(x[:, 0])
        else:
            # 如果使用蒸馏,返回分类令牌和蒸馏令牌的特征向量
            return x[:, 0], x[:, 1]

    def forward(self, x):
        """
        前向传播函数
        Args:
            x (torch.Tensor): 输入图像,形状为 [B, C, H, W]

        Returns:
            torch.Tensor: 分类结果,如果使用蒸馏则返回分类结果和蒸馏结果的平均值
        """
        # 提取特征
        x = self.forward_features(x)
        if self.head_dist is not None:
            # 如果使用蒸馏,分别通过分类头和蒸馏头
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # 训练时返回分类结果和蒸馏结果
                return x, x_dist
            else:
                # 推理时返回分类结果和蒸馏结果的平均值
                return (x + x_dist) / 2
        else:
            # 如果不使用蒸馏,通过分类头得到分类结果
            x = self.head(x)
        return x

实验结果

论文中训练了三种模型,如下表:
在这里插入图片描述
ViT与传统的CNN网络的比较:
在这里插入图片描述
VIT和卷积神经网络相比,表现基本一致,ViT的训练成本更低。
采用了不同数量的数据集,对VIT进行训练,效果如下:
在这里插入图片描述
当数据集较小时,ViT的效果不如CNN,但随着训练数据的增多,ViT的效果会逐步超越ViT。

总结

Vision Transformer(ViT)通过将图像分割为固定大小的图像块并转化为序列数据,结合可学习的位置编码与全局分类标识符(Class Token),利用多层Transformer Encoder的自注意力机制实现图像特征的全局建模,最终通过MLP Head输出分类结果。其核心工作流程以图像块嵌入为起点,通过位置编码赋予空间信息,在Transformer中逐层融合跨区域的语义关联,最终由Class Token汇聚全局信息完成分类。ViT的优势在于突破卷积的局部限制、实现长距离依赖建模,且模型扩展性强,但依赖大规模预训练数据且计算复杂度随图像分辨率陡增。未来研究可探索轻量化设计、局部-全局注意力结合、动态位置编码优化,以及跨模态与多任务的高效适配,进一步推动视觉Transformer在复杂场景下的实用化进程。ViT的成功不仅验证了Transformer在视觉任务中的普适性,更为跨模态统一模型开辟了新路径。


网站公告

今日签到

点亮在社区的每一天
去签到