Vision Transformer (ViT):将Transformer带入计算机视觉的革命性尝试(代码实现)

发布于:2025-03-13 ⋅ 阅读:(21) ⋅ 点赞:(0)

Vision Transformer (ViT):将Transformer带入计算机视觉的革命性尝试

作为一名深度学习研究者,如果你对自然语言处理(NLP)领域的Transformer架构了如指掌,那么你一定不会对它在序列建模中的强大能力感到陌生。然而,2021年由Google Research团队在ICLR上发表的论文《AN IMAGE IS WORTH 16x16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE》却将这一熟悉的架构带入了一个全新的领域——计算机视觉,提出了Vision Transformer(ViT)。这篇博客将为你详细解析ViT的原理,结合你对Transformer的深厚理解,带你走进这一开创性的模型。

背景:从NLP到视觉的跨界思考

在NLP领域,Transformer(Vaswani et al., 2017)凭借其自注意力机制(Self-Attention)彻底改变了序列建模的范式。通过预训练大规模语言模型(如BERT、GPT),Transformer展现了惊艳的泛化能力和计算效率。然而,在计算机视觉领域,卷积神经网络(CNN)一直是无可争议的主宰,凭借其局部性、平移不变性等归纳偏置(Inductive Bias),在图像分类、目标检测等任务中占据主导地位。

ViT的核心思想大胆而简单:如果Transformer在NLP中能处理单词序列(Token Sequence),为什么不能将图像也看作一种序列呢?作者提出,通过将图像分割成固定大小的Patch,并将这些Patch作为输入序列直接交给标准Transformer处理,可以完全抛弃CNN的架构。这一尝试不仅挑战了CNN的统治地位,还揭示了大规模数据预训练对模型性能的深远影响。

ViT的架构:从图像到序列的转变

在这里插入图片描述
来源:https://arxiv.org/pdf/2010.11929

ViT的架构设计几乎是对NLP Transformer的“照搬”,但在输入处理上做了一些关键调整。以下是ViT的核心步骤,相信你会发现它与NLP中的处理流程有惊人的相似之处:

1. 图像分块与嵌入(Patch Embedding)

在NLP中,输入是一个单词序列,每个单词通过词嵌入(Word Embedding)映射为固定维度的向量。ViT将这一思想移植到图像上:

  • 图像分割:给定一张输入图像 ( x ∈ R H × W × C x \in \mathbb{R}^{H \times W \times C} xRH×W×C )(H、W为高宽,C为通道数,通常是RGB的3),ViT将其分割为固定大小的Patch,例如 ( P × P P \times P P×P )(论文中常用 ( 16 × 16 16 \times 16 16×16 ))。这会生成 ( N = H W / P 2 N = HW / P^2 N=HW/P2 ) 个Patch,每个Patch是一个 ( P 2 ⋅ C P^2 \cdot C P2C ) 维的向量。
  • 线性投影:这些Patch被展平后,通过一个可训练的线性层映射到一个固定维度 ( D D D ) 的嵌入空间,形成Patch Embedding。这与NLP中的词嵌入过程几乎一模一样,只是这里的“词”是图像Patch。
  • 位置编码(Position Embedding):与NLP类似,ViT为每个Patch添加位置编码,以保留空间信息。默认使用一维可学习位置编码(1D Positional Embedding),尽管论文也尝试了二维编码,但效果差别不大。

最终,输入序列为:
z 0 = [ x class ; x p 1 E ; x p 2 E ; ⋯   ; x p N E ] + E pos \mathbf{z}_0 = [\mathbf{x}_{\text{class}}; \mathbf{x}_p^1 \mathbf{E}; \mathbf{x}_p^2 \mathbf{E}; \cdots; \mathbf{x}_p^N \mathbf{E}] + \mathbf{E}_{\text{pos}} z0=[xclass;xp1E;xp2E;;xpNE]+Epos
其中,( E ∈ R ( P 2 ⋅ C ) × D \mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times D} ER(P2C)×D ) 是投影矩阵,( E pos ∈ R ( N + 1 ) × D \mathbf{E}_{\text{pos}} \in \mathbb{R}^{(N+1) \times D} EposR(N+1)×D ) 是位置编码。

2. 分类标记(Class Token)

如果你熟悉BERT(可以参考笔者的另一篇博客:BERT模型详解:双向Transformer的语言理解革命(预训练和微调代码实现)),你一定知道它的 [CLS] Token。ViT也借鉴了这一设计,在Patch序列前添加一个可学习的分类标记(Class Token),记为 ( x class \mathbf{x}_{\text{class}} xclass )。这个Token的作用是在Transformer编码后,作为整个图像的表示,用于后续分类任务。

3. Transformer编码器

接下来,ViT将序列 ( z 0 \mathbf{z}_0 z0 ) 输入标准Transformer编码器,与NLP中的架构完全一致:

  • 多头自注意力(Multi-Head Self-Attention, MSA):通过自注意力机制,ViT在全局范围内整合Patch之间的信息,而不像CNN那样局限于局部感受野。
  • MLP块:每个Transformer层包含一个前馈网络(MLP),带有GELU激活函数。
  • 层归一化与残差连接:LayerNorm(LN)和残差连接确保了训练的稳定性。

经过 ( L L L ) 层Transformer编码后,输出序列为 ( z L \mathbf{z}_L zL )。其中,( z L 0 \mathbf{z}_L^0 zL0 )(即Class Token的输出)被用作图像表示:
y = LN ⁡ ( z L 0 ) \mathbf{y} = \operatorname{LN}(\mathbf{z}_L^0) y=LN(zL0)

4. 分类头

在预训练阶段,( y \mathbf{y} y ) 被送入一个带有单隐藏层的MLP进行分类;在微调阶段,则简化为一个线性层,输出类别数 ( K K K ) 的预测。

关键特性:极简与归纳偏置的取舍

ViT的设计极简,几乎没有引入图像特有的归纳偏置:

  • 与CNN的对比:CNN通过卷积操作天然具有局部性、平移不变性等特性,而ViT仅在Patch分割和微调时的分辨率调整中引入了少量二维结构信息。其余部分完全依赖自注意力从数据中学习空间关系。
  • 全局性:自注意力使ViT从第一层起就能关注整个图像,而CNN的感受野需要通过深层堆叠逐步扩大。

这种“无偏置”设计带来了一个重要问题:ViT是否能在数据量不足时泛化良好?答案是否定的。论文指出,当在中小规模数据集(如ImageNet,1.3M图像)上从头训练时,ViT的表现不如同等规模的ResNet。然而,当预训练数据规模扩大到14M(ImageNet-21k)或300M(JFT-300M)时,ViT开始展现出超越CNN的潜力。这表明,大规模数据可以弥补归纳偏置的缺失。

性能表现:数据驱动的胜利

ViT在多个基准测试中取得了令人瞩目的成绩:

  • ImageNet:ViT-H/14(Huge模型,14×14 Patch)达到88.55% Top-1精度,接近Noisy Student(EfficientNet-L2)的88.5%。
  • CIFAR-100:94.55%,超越BiT-L的93.51%。
  • VTAB(19任务):77.63%,显著优于BiT-L的76.29%。

更重要的是,ViT的预训练计算成本远低于CNN。例如,ViT-H/14在JFT-300M上预训练耗时2500 TPUv3-core-days,而BiT-L需要9900,Noisy Student更是高达12300。这种效率得益于Transformer的并行性和可扩展性。

深入分析:ViT如何“看”图像?

为了理解ViT的内部机制,论文提供了一些可视化分析:

  • 注意力距离:在较低层,部分注意力头关注局部区域,类似CNN的早期卷积层;随着层数增加,注意力范围扩展至全局。
  • 位置编码:ViT学习到的位置编码反映了图像的二维拓扑结构,邻近Patch的编码更相似。
  • 注意力图:通过Attention Rollout方法,ViT能聚焦于与分类任务语义相关的区域,展现出强大的解释性。

自监督预训练的初步探索

如果你对BERT的掩码语言建模(Masked Language Modeling)情有独钟,那么ViT的初步自监督实验可能会让你兴奋。作者尝试了掩码Patch预测(Masked Patch Prediction),类似BERT的策略,将50%的Patch替换为掩码,并预测其均值颜色。在JFT-300M上预训练后,ViT-B/16的ImageNet精度从头训练的77.9%提升至79.9%,尽管仍落后于监督预训练的83.97%。这表明自监督ViT有潜力,但仍需进一步优化。

对研究者的启示

对于熟悉NLP的你,ViT不仅是一个视觉模型,更是一个跨领域思想的桥梁:

  • 架构复用:ViT证明了Transformer的通用性,提示我们可以在更多模态上尝试类似的序列化建模。
  • 数据依赖性:大规模预训练对ViT至关重要,这与NLP中的经验一致。你可以思考如何设计更高效的自监督任务来减少数据需求。
  • 扩展方向:论文提出将ViT应用于检测、分割等任务(后续研究如DETR已验证其可行性),这可能是你未来研究的一个切入点。

结语

Vision Transformer以其简洁而大胆的设计,打破了CNN在计算机视觉中的垄断地位。它告诉我们,当数据和算力足够时,模型可以从头学习复杂的空间关系,而无需依赖传统归纳偏置。作为一名NLP领域的深度学习研究者,你是否也从中看到了Transformer无限可能的未来?欢迎留言分享你的看法!


参考文献
Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021.

ViT代码实现

以下是一个基于 PyTorch 的 Vision Transformer (ViT) 的完整、可运行的代码实现。这个实现参考了原始论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》的核心思想,包含了 Patch Embedding、Multi-Head Self-Attention 和 Transformer Encoder 的主要组件,并以 MNIST 数据集为例进行训练和测试。为了确保代码可运行,尽量保持简洁并提供注释。

环境要求

  • Python 3.8+
  • PyTorch 2.0+
  • Torchvision

完整代码

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

# 超参数设置
image_size = 28  # MNIST 图像大小为 28x28
patch_size = 7   # Patch 大小为 7x7
num_patches = (image_size // patch_size) ** 2  # 16 个 Patch
patch_dim = patch_size * patch_size * 1  # 输入通道为 1 (灰度图)
dim = 64         # 嵌入维度
depth = 6        # Transformer 层数
heads = 8        # 注意力头数
mlp_dim = 128    # MLP 隐藏层维度
num_classes = 10 # MNIST 类别数
dropout = 0.1    # Dropout 率

# 设备设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Patch Embedding 模块
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, patch_dim, dim, dropout):
        super().__init__()
        self.num_patches = (image_size // patch_size) ** 2
        # 线性投影:将 Patch 展平并映射到 dim 维度
        self.proj = nn.Linear(patch_dim, dim)
        # 位置编码
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, dim))
        # CLS Token
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B = x.shape[0]  # Batch Size
        # 将图像分割为 Patch 并展平
        x = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)  # (B, C, H/p, W/p, p, p)
        x = x.permute(0, 2, 3, 1, 4, 5).contiguous()  # (B, H/p, W/p, C, p, p)
        x = x.view(B, self.num_patches, -1)  # (B, num_patches, patch_dim)
        # 线性投影
        x = self.proj(x)  # (B, num_patches, dim)
        # 添加 CLS Token
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, num_patches + 1, dim)
        # 添加位置编码
        x = x + self.pos_embedding
        x = self.dropout(x)
        return x

# 多头自注意力模块
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, heads, dropout):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)  # 查询、键、值投影
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)  # 输出投影

    def forward(self, x):
        B, N, C = x.shape  # (Batch, num_patches + 1, dim)
        # 生成 Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # (B, heads, N, dim/heads)
        # 注意力计算
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        # 加权求和
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # (B, N, dim)
        x = self.proj(x)
        x = self.dropout(x)
        return x

# Transformer Encoder 层
class TransformerEncoderLayer(nn.Module):
    def __init__(self, dim, heads, mlp_dim, dropout):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, heads, dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # 残差连接
        x = x + self.mlp(self.norm2(x))  # 残差连接
        return x

# Vision Transformer 模型
class ViT(nn.Module):
    def __init__(self, image_size, patch_size, patch_dim, dim, depth, heads, mlp_dim, num_classes, dropout):
        super().__init__()
        self.patch_embed = PatchEmbedding(image_size, patch_size, patch_dim, dim, dropout)
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(dim, heads, mlp_dim, dropout) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        cls_token = x[:, 0]  # 提取 CLS Token
        x = self.head(cls_token)
        return x

# 数据加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST 均值和标准差
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 初始化模型、损失函数和优化器
model = ViT(
    image_size=image_size,
    patch_size=patch_size,
    patch_dim=patch_dim,
    dim=dim,
    depth=depth,
    heads=heads,
    mlp_dim=mlp_dim,
    num_classes=num_classes,
    dropout=dropout
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练函数
def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:
            print(f'[Epoch {epoch+1}, Batch {i+1}] Loss: {running_loss / 100:.3f}')
            running_loss = 0.0

# 测试函数
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

# 主训练循环
num_epochs = 10
for epoch in range(num_epochs):
    train(model, train_loader, criterion, optimizer, epoch)
    test(model, test_loader)

代码说明

  1. Patch Embedding:

  2. Multi-Head Self-Attention:

    • 实现多头自注意力机制,Q、K、V 通过一个线性层生成。
    • 使用缩放点积注意力(Scaled Dot-Product Attention),并添加 Dropout。
  3. Transformer Encoder:

    • 包含 6 层 Transformer,每层有 MSA 和 MLP 块,带有残差连接和 LayerNorm。
    • MLP 使用 GELU 激活函数。
  4. 分类头:

    • 从最后一层提取 CLS Token,经过 LayerNorm 和线性层输出 10 个类别。
  5. 训练与测试:

    • 在 MNIST 数据集上训练 10 个 epoch,使用 Adam 优化器。
    • 每 100 个批次打印损失,并在每个 epoch 后测试准确率。

运行结果

在 CPU 或 GPU 上运行此代码,将下载 MNIST 数据集并开始训练。预期结果:

  • 训练损失逐渐下降。
  • 测试准确率在 10 个 epoch 后可能达到 95% 以上(具体取决于随机性和硬件)。

注意事项

  • 计算资源:如果在 CPU 上运行,可能较慢;建议使用 GPU 加速。
  • 超参数调整:当前设置适合 MNIST,处理更高分辨率图像(如 ImageNet)可能需要调整 patch_sizedimdepth
  • 扩展性:此代码是基础实现,未包含高级优化(如混合精度训练或预训练权重)。

参考

希望这个实现对你理解 ViT 的工作原理有所帮助!如果需要更复杂的版本(例如支持 ImageNet 数据集或预训练),可以进一步扩展。欢迎反馈或提问!


ViT 层的行为分析

引言

作为深度学习研究者,你对 Transformer 在 NLP 中的层级行为可能已非常熟悉:早期层关注语法和局部依赖,深层捕捉语义和长距离关系。那么,在计算机视觉的 Vision Transformer(ViT)中,层的行为是否类似?本文将深入探讨 ViT 的层级特征提取,特别关注其与 CNN 的对比,并分析第 31、32 层等深层的特性,结合现有研究提供全面见解。

CNN 的层级特征提取:从低级到高级

CNN (具体可以参考笔者的另一篇博客:卷积神经网络(CNN):深度解析其原理与特性)的强大之处在于其深层结构:

  • 浅层(如第 1、2 层):通过小卷积核提取低级特征,如边缘、纹理,感受野小,专注于局部信息。
  • 深层(如第 10 层或更深):通过堆叠卷积和池化层,感受野扩展,逐步学习高级语义特征,如对象部件(例如猫的耳朵)或整体形状(例如整只猫)。数学上,感受野扩展遵循公式:
    R F l = R F l − 1 + ( k − 1 ) ⋅ ∏ i = 1 l − 1 s i RF_l = RF_{l-1} + (k-1) \cdot \prod_{i=1}^{l-1} s_i RFl=RFl1+(k1)i=1l1si
    其中 ( R F l RF_l RFl ) 是第 ( l l l ) 层的感受野大小,( k k k ) 是卷积核大小,( s i s_i si ) 是之前各层的步幅。这使得深层 CNN 能捕捉全局上下文。

这种层次结构是 CNN 的归纳偏置(inductive bias),使其在数据量有限时表现良好。

ViT 的层级行为:从 Patch 到语义

ViT 的输入是将图像分割为固定大小的 Patch(如 16x16),每个 Patch 线性嵌入后添加位置编码,输入 Transformer 编码器。编码器由多头自注意力(MSA)和多层感知机(MLP)块交替组成。以下是层的行为分析:

早期层(如第 1、2 层)
  • 自注意力(MSA):从第一层起,MSA 允许每个 Patch 嵌入关注整个序列(所有 Patch),这与 CNN 的局部卷积不同。研究表明 [1],早期层的某些注意力头表现出全局行为,关注整个图像,而其他头则聚焦于局部区域,类似于 CNN 的早期卷积层。
  • MLP 块:MLP 是局部的,平移等变的,类似于卷积层,但作用于 Patch 嵌入。它通过 GELU 非线性添加非线性变换,初步精炼特征。
  • 位置编码:位置嵌入从一开始就编码 2D 空间结构,研究显示 [1],邻近 Patch 的位置编码更相似,反映行-列结构。
中间层(如第 15、16 层)
  • 注意力距离:研究 [2] 使用平均注意力距离分析,显示随着层数的增加,注意力范围扩大。中间层开始更多地整合跨 Patch 的信息,形成更复杂的空间关系。
  • 特征整合:MSA 层继续全局整合信息,MLP 层进一步非线性变换,逐步从 Patch 级别的原始信息向更高层次的表示过渡。
深层(如第 31、32 层)
  • 深层行为:对于深层(如第 31、32 层),需要注意 ViT 的层数通常较少(如 ViT-B/16 有 12 层),因此第 31、32 层可能超出了标准模型的深度。但假设模型有 32 层,研究 [3] 表明:
    • 深层注意力头几乎全部关注全局,平均注意力距离最大,专注于与任务相关的语义区域(如对象的关键部分)。
    • CLS Token(分类标记)的输出在深层更能代表整个图像的语义信息,适合分类任务。
  • 任务依赖:深层的具体特性高度依赖训练数据和任务。例如,在广义零样本学习(GZSL)中,研究 [4] 发现第 11 层(12 层模型)CLS 特征表现最佳,表明深层更适合提取属性相关信息。
与 CNN 的对比:层次结构的差异
  • CNN 的层次结构:CNN 从边缘到纹理,再到对象部件和整体,层次明确,归纳偏置强(如局部性、平移不变性)。深层逐步扩展感受野,构建明确的高级特征。
  • ViT 的灵活性:ViT 缺乏这种固有层次结构,早期层已能全局整合信息,深层更多是精炼注意力,聚焦语义相关区域。这种数据驱动的特性使其在大数据集上表现优异,但小数据集时可能不如 CNN。
研究论文与结论

以下是关键研究:

结论:

  • ViT 的层从浅到深确实有从局部到全局的转变,但不像 CNN 那样有严格的低级到高级特征层次。
  • 早期层(如第 1、2 层)关注局部和全局信息,深层(如第 31、32 层,假设模型足够深)更聚焦语义,具体特性依赖训练和任务。
  • 这种灵活性使 ViT 在大数据集上表现优异,弥补了缺乏 CNN 归纳偏置的不足。
表 1:ViT 与 CNN 层级行为的对比
特性 CNN ViT
早期层关注 低级特征(如边缘、纹理) 局部和全局信息,部分头全局关注
深层关注 高级语义(如对象部件、整体) 更全局,聚焦任务相关语义区域
层次结构 明确,低级到高级逐步构建 数据驱动,无严格层次,灵活性高
归纳偏置 强(如局部性、平移不变性) 弱,依赖大数据训练
深层(如第 31、32 层)特性 捕捉全局对象,明确语义 假设深,聚焦语义,任务依赖
讨论与未来方向

对于第 31、32 层,当前研究多集中于 12-24 层的标准 ViT 模型,深层(如 32 层)行为需更多实验验证。未来可探索自监督预训练(如掩码 Patch 预测)如何影响深层特征,及如何设计更高效的层级结构,结合 CNN 和 ViT 的优势。


后记

2025年3月12日19点34分于上海,在Grok 3大模型辅助下完成。


网站公告

今日签到

点亮在社区的每一天
去签到