Vision Transformer(ViT)模型实例化PyTorch逐行实现

发布于:2025-08-01 ⋅ 阅读:(12) ⋅ 点赞:(0)

为了让大家更好地理解,我们将从零开始,逐步构建 ViT 的各个核心组件,并最终将它们组合成一个完整的模型。我们会以一个在 CIFAR-10 数据集上应用的实例来贯穿整个讲解过程。

ViT 核心思想

在讲解代码之前,我们先快速回顾一下 ViT 的核心思想,这有助于理解代码每一部分的目的。

图片切块 (Image to Patches): 传统 CNN 逐像素处理图像,而 ViT 模仿 NLP 中处理单词 (Token) 的方式。它将一幅图像 (H*W*C) 切割成一个个小块 (Patch),每个小块大小为 P*P*C。

展平与线性投射 (Patch Flattening & Linear Projection): 将每个小块展平成一个一维向量,然后通过一个全连接层(线性投射)将其映射到一个固定的维度 D,这个向量就成为了 Transformer 的 "Token"。

类别令牌 (Class Token): 模仿 BERT 的 [CLS] 令牌,在所有 Patch Token 的最前面加入一个可学习的 [CLS] Token。这个 Token 最终将用于图像分类。

位置编码 (Positional Embedding): Transformer 本身不包含位置信息。为了让模型知道每个 Patch 的原始位置,我们需要为每个 Token(包括 [CLS] Token)添加一个可学习的位置编码。

Transformer 编码器 (Transformer Encoder): 将带有位置编码的 Token 序列输入到标准的 Transformer Encoder 中。Encoder 由多层堆叠而成,每一层都包含一个多头自注意力模块 (Multi-Head Self-Attention) 和一个前馈网络 (Feed-Forward Network)

分类头 (MLP Head): 将 Transformer Encoder 输出的 [CLS] Token 对应的向量,送入一个简单的多层感知机(MLP),最终输出分类结果。

实例设定

我们将以 CIFAR-10 数据集为例。

图片尺寸 (image_size): 32*32*3

Patch 尺寸 (patch_size): 4*4 (我们可以选择 8x8 或 16x16,这里用 4x4 举例)

类别数 (num_classes): 10

嵌入维度 (dim): 512 (每个 Patch 展平后映射到的维度)

Transformer Encoder 层数 (depth): 6

多头注意力头数 (heads): 8

MLP 内部维度 (mlp_dim): 2048

根据这些设定,我们可以计算出:

每张图片的 Patch 数量 (num_patches): (32/4)x(32/4)=8x8=64

PyTorch 代码逐行实现

我们将按照 ViT 的思想,一步步构建代码。

1. Patch Embedding (图像切块与线性投射)

这是 ViT 的第一步,我们的目标是将一个 (B, C, H, W) 的图像张量,转换成一个 (B, N, D) 的 Token 序列张量,其中 B 是批量大小,N 是 Patch 数量,D 是嵌入维度。

一个巧妙高效的实现方法是使用二维卷积

思想: 我们可以设置一个卷积层,其卷积核大小 (kernel_size)步长 (stride) 都等于 patch_size。这样,卷积核每次滑动的区域恰好就是一个不重叠的 Patch。卷积的输出通道数设为我们想要的嵌入维度 dim

import torch
from torch import nn

class PatchEmbedding(nn.Module):
    """
    将图像分割成块并进行线性嵌入。
    
    参数:
        image_size (int): 输入图像的尺寸 (假设为正方形)。
        patch_size (int): 每个图像块的尺寸 (假设为正方形)。
        in_channels (int): 输入图像的通道数。
        dim (int): 线性投射后的嵌入维度。
    """
    def __init__(self, image_size, patch_size, in_channels, dim):
        super().__init__()
        self.patch_size = patch_size
        
        # 检查图像尺寸是否能被 patch 尺寸整除
        if not (image_size % patch_size == 0):
            raise ValueError("error")
            
        # 计算 patch 的数量
        self.num_patches = (image_size // patch_size) ** 2
        
        # 核心:使用 Conv2d 实现 patch 化和线性投射
        # kernel_size 和 stride 都设为 patch_size,实现不重叠的块分割
        # out_channels 设为嵌入维度 dim
        self.projection = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # 输入 x 的形状: (B, C, H, W)
        # 例如: (B, 3, 32, 32)
        
        # 经过卷积层,将图像转换为 patch 的特征图
        # 输出形状: (B, dim, H/P, W/P)
        # 例如: (B, 512, 8, 8)
        x = self.projection(x)
        
        # 将特征图展平
        # .flatten(2) 将从第2个维度开始展平 (H/P 和 W/P 维度)
        # 输出形状: (B, dim, N) 其中 N = (H/P) * (W/P)
        # 例如: (B, 512, 64)
        x = x.flatten(2)
        
        # 交换维度,以匹配 Transformer 输入格式 (B, N, D)
        # 输出形状: (B, N, dim)
        # 例如: (B, 64, 512)
        x = x.transpose(1, 2)
        
        return x
2. Transformer Encoder Block

Transformer Encoder 由多个相同的块 (Block) 堆叠而成。每个块包含两个主要部分:

多头自注意力 (Multi-Head Self-Attention)

前馈网络 (Feed-Forward Network / MLP)

每个部分都伴随着残差连接 (Residual Connection) 和层归一化 (Layer Normalization)。

class TransformerEncoderBlock(nn.Module):
    """
    标准的 Transformer Encoder 块。
    
    参数:
        dim (int): 输入的 token 维度。
        heads (int): 多头注意力的头数。
        mlp_dim (int): MLP 层的隐藏维度。
        dropout (float): Dropout 的概率。
    """
    def __init__(self, dim, heads, mlp_dim, dropout=0.1):
        super().__init__()
        # 第一个 LayerNorm
        self.norm1 = nn.LayerNorm(dim)
        
        # 多头自注意力模块
        # PyTorch 内置的 MultiheadAttention 期望输入形状为 (N, B, D),
        # 但我们通常使用 (B, N, D)。设置 batch_first=True 可以解决这个问题。
        self.attention = nn.MultiheadAttention(
            embed_dim=dim, 
            num_heads=heads, 
            dropout=dropout, 
            batch_first=True
        )
        
        # 第二个 LayerNorm
        self.norm2 = nn.LayerNorm(dim)
        
        # MLP / 前馈网络
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),  # ViT 论文中使用的激活函数
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # x 的形状: (B, N, D)
        
        # 1. 多头自注意力部分
        # 残差连接: x + Attention(LayerNorm(x))
        x_norm = self.norm1(x)
        # 注意力模块返回 attn_output 和 attn_weights,我们只需要前者
        attn_output, _ = self.attention(x_norm, x_norm, x_norm)
        x = x + attn_output
        
        # 2. 前馈网络部分
        # 残差连接: x + MLP(LayerNorm(x))
        x_norm = self.norm2(x)
        mlp_output = self.mlp(x_norm)
        x = x + mlp_output
        
        return x
3. 完整的 Vision Transformer 模型

现在,我们将所有组件整合在一起。

class VisionTransformer(nn.Module):
    """
    Vision Transformer 模型。
    
    参数:
        image_size (int): 输入图像尺寸。
        patch_size (int): Patch 尺寸。
        in_channels (int): 输入通道数。
        num_classes (int): 分类类别数。
        dim (int): 嵌入维度。
        depth (int): Transformer Encoder 层数。
        heads (int): 多头注意力头数。
        mlp_dim (int): MLP 隐藏维度。
        dropout (float): Dropout 概率。
    """
    def __init__(self, image_size, patch_size, in_channels, num_classes,
                 dim, depth, heads, mlp_dim, dropout=0.1):
        super().__init__()
        
        # 1. Patch Embedding
        self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, dim)
        
        # 计算 patch 数量
        num_patches = self.patch_embedding.num_patches
        
        # 2. Class Token
        # 这是一个可学习的参数,维度为 (1, 1, D)
        # '1' 个 batch,'1' 个 token,'D' 维
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        
        # 3. Positional Embedding
        # 这也是一个可学习的参数
        # 长度为 num_patches + 1 (为了包含 cls_token)
        # 维度为 (1, N+1, D)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        
        self.dropout = nn.Dropout(dropout)
        
        # 4. Transformer Encoder
        # 使用 nn.Sequential 将多个 Encoder Block 堆叠起来
        self.transformer_encoder = nn.Sequential(
            *[TransformerEncoderBlock(dim, heads, mlp_dim, dropout) for _ in range(depth)]
        )
        
        # 5. MLP Head (分类头)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim), # 在送入分类头前先进行一次 LayerNorm
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        # img 形状: (B, C, H, W)
        
        # 1. 获取 Patch Embedding
        # x 形状: (B, N, D)
        x = self.patch_embedding(img)
        b, n, d = x.shape  # b: batch_size, n: num_patches, d: dim
        
        # 2. 添加 Class Token
        # 将 cls_token 复制 b 份,拼接到 x 的最前面
        # cls_tokens 形状: (B, 1, D)
        cls_tokens = self.cls_token.expand(b, -1, -1) 
        # x 形状变为: (B, N+1, D)
        x = torch.cat((cls_tokens, x), dim=1)
        
        # 3. 添加 Positional Embedding
        # pos_embedding 形状是 (1, N+1, D),利用广播机制直接相加
        x += self.pos_embedding
        x = self.dropout(x)
        
        # 4. 通过 Transformer Encoder
        # x 形状不变: (B, N+1, D)
        x = self.transformer_encoder(x)
        
        # 5. 提取 Class Token 的输出用于分类
        # 只取序列的第一个 token (cls_token) 的输出
        # x 形状: (B, D)
        cls_token_output = x[:, 0]
        
        # 6. 通过 MLP Head 得到最终的分类 logits
        # output 形状: (B, num_classes)
        output = self.mlp_head(cls_token_output)
        
        return output

完整模型与实例

现在我们把所有代码放在一起,并用我们之前设定的 CIFAR-10 参数来实例化模型,看看它的输入和输出。

import torch
from torch import nn

# --- 组件 1: PatchEmbedding ---
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, dim):
        super().__init__()
        if not (image_size % patch_size == 0):
            raise ValueError("Image dimensions must be divisible by the patch size.")
        self.num_patches = (image_size // patch_size) ** 2
        self.projection = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.projection(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        return x

# --- 组件 2: TransformerEncoderBlock ---
class TransformerEncoderBlock(nn.Module):
    def __init__(self, dim, heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attention = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        attn_output, _ = self.attention(self.norm1(x), self.norm1(x), self.norm1(x))
        x = x + attn_output
        mlp_output = self.mlp(self.norm2(x))
        x = x + mlp_output
        return x

# --- 主模型: VisionTransformer ---
class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, num_classes,
                 dim, depth, heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, dim)
        num_patches = self.patch_embedding.num_patches
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.dropout = nn.Dropout(dropout)
        self.transformer_encoder = nn.Sequential(
            *[TransformerEncoderBlock(dim, heads, mlp_dim, dropout) for _ in range(depth)]
        )
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.patch_embedding(img)
        b, n, d = x.shape
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x = self.dropout(x)
        x = self.transformer_encoder(x)
        cls_token_output = x[:, 0]
        output = self.mlp_head(cls_token_output)
        return output

# --- 实例化并测试 ---

# CIFAR-10 实例参数
BATCH_SIZE = 4
IMAGE_SIZE = 32
IN_CHANNELS = 3
PATCH_SIZE = 4
NUM_CLASSES = 10
DIM = 512
DEPTH = 6
HEADS = 8
MLP_DIM = 2048

# 创建模型实例
vit_model = VisionTransformer(
    image_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    in_channels=IN_CHANNELS,
    num_classes=NUM_CLASSES,
    dim=DIM,
    depth=DEPTH,
    heads=HEADS,
    mlp_dim=MLP_DIM
)

# 创建一个假的输入图像张量 (Batch, Channels, Height, Width)
dummy_img = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)

# 将图像输入模型
logits = vit_model(dummy_img)

# 打印输出的形状
print(f"输入图像形状: {dummy_img.shape}")
print(f"模型输出 (Logits) 形状: {logits.shape}")

# 检查输出形状是否正确
assert logits.shape == (BATCH_SIZE, NUM_CLASSES)
print("\n模型构建成功,输入输出形状正确!")


网站公告

今日签到

点亮在社区的每一天
去签到