PyTorch 的 unfold 函数:深入解析与在 Vision Transformer (ViT) 中的应用

发布于:2025-03-13 ⋅ 阅读:(22) ⋅ 点赞:(0)

PyTorch 的 unfold 函数:深入解析与在 Vision Transformer (ViT) 中的应用

作为一名深度学习研究者,你可能对 PyTorch 的灵活性和强大功能并不陌生。在计算机视觉领域,尤其是 Vision Transformer (ViT) 的实现中,torch.nn.functional.unfold(或 Tensor 的 unfold 方法)是一个非常实用的工具,用于将图像分割成固定大小的 Patch。本篇博客将深入介绍 unfold 函数的原理、用法,并结合 ViT 中的 Patch Embedding 模块,展示它在现代视觉模型中的具体应用。目标读者是对 PyTorch 和深度学习有一定基础的研究者,希望通过这篇文章加深你对 unfold 的理解及其在 ViT 中的作用。


什么是 unfold 函数?

unfold 是 PyTorch 提供的一个操作,用于从张量中提取滑动窗口(Sliding Window)的子区域。它在图像处理中特别有用,可以高效地将二维图像分割为多个局部块(Patch),而无需显式循环。这种操作类似于卷积神经网络(CNN)中的卷积操作,但它直接返回原始数据的子块,而不是计算加权和。

基本定义

对于一个张量,unfold 方法的签名如下:

tensor.unfold(dimension, size, step)
  • dimension:沿哪个维度展开(对于图像,通常是高度或宽度维度)。
  • size:窗口大小(即提取的子块大小)。
  • step:滑动步幅(Stride),控制窗口移动的距离。

返回的张量包含所有滑动窗口的子块,形状根据输入和参数调整。

工作原理

假设有一个二维张量 ( X ∈ R H × W X \in \mathbb{R}^{H \times W} XRH×W )(高 ( H H H ),宽 ( W W W )),沿高度维度应用 unfold

X.unfold(dimension=0, size=kh, step=stride)
  • ( k h kh kh ) 是窗口高度。
  • 输出形状为 ( ( H − k h + 1 ) / s t r i d e × W × k h (H - kh + 1) / stride \times W \times kh (Hkh+1)/stride×W×kh )(假设步幅正好整除)。
  • 每个窗口是一个长度为 ( k h kh kh ) 的向量,包含原始张量在该维度上的连续子区域。

如果再沿宽度维度应用 unfold,可以得到所有二维 Patch。

示例:二维图像的 unfold

假设有一个 ( 4 × 4 4 \times 4 4×4 ) 的张量:

import torch

x = torch.arange(16).reshape(4, 4).float()
print("原始张量:")
print(x)

# 沿高度维度展开,窗口大小为 2,步幅为 2
x_unfolded = x.unfold(0, 2, 2)
print("\n沿高度展开:")
print(x_unfolded)

# 再沿宽度维度展开,窗口大小为 2,步幅为 2
x_unfolded = x_unfolded.unfold(1, 2, 2)
print("\n沿宽度展开:")
print(x_unfolded)

输出:

原始张量:
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])

沿高度展开:
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.]],
        [[ 8.,  9., 10., 11.],
         [12., 13., 14., 15.]]])

沿宽度展开:
tensor([[[[ 0.,  1.],
          [ 4.,  5.]],
         [[ 2.,  3.],
          [ 6.,  7.]]],
        [[[ 8.,  9.],
          [12., 13.]],
         [[10., 11.],
          [14., 15.]]]])
  • 第一次 unfold 将 ( 4 × 4 4 \times 4 4×4 ) 张量分成 ( 2 × 4 × 2 2 \times 4 \times 2 2×4×2 ) 的张量(2 个高度块,每个块高 2,宽 4)。
  • 第二次 unfold 将每个高度块再分成 ( 2 × 2 × 2 2 \times 2 \times 2 2×2×2 ) 的块,最终得到 ( 2 × 2 × 2 × 2 2 \times 2 \times 2 \times 2 2×2×2×2 ) 的张量,表示 4 个 ( 2 × 2 2 \times 2 2×2 ) 的 Patch。

在 ViT 中的应用:Patch Embedding

Vision Transformer (ViT) 的核心创新是将图像视为 Patch 序列,而不是传统 CNN 的像素网格。Patch Embedding 模块负责将输入图像分割为固定大小的 Patch,并将其嵌入到 Transformer 可处理的向量空间中。unfold 在这里发挥了关键作用。具体可以参考笔者的另一篇博客:Vision Transformer (ViT):将Transformer带入计算机视觉的革命性尝试(代码实现)

ViT Patch Embedding 代码解析

以下是 ViT 中 Patch Embedding 的 PyTorch 实现(以 MNIST 为例,图像大小 ( 28 × 28 28 \times 28 28×28 )):

class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, patch_dim, dim, dropout):
        super().__init__()
        self.num_patches = (image_size // patch_size) ** 2
        # 线性投影:将 Patch 展平并映射到 dim 维度
        self.proj = nn.Linear(patch_dim, dim)
        # 位置编码
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, dim))
        # CLS Token
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B = x.shape[0]  # Batch Size
        # 将图像分割为 Patch 并展平
        x = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)  # (B, C, H/p, W/p, p, p)
        x = x.permute(0, 2, 3, 1, 4, 5).contiguous()  # (B, H/p, W/p, C, p, p)
        x = x.view(B, self.num_patches, -1)  # (B, num_patches, patch_dim)
        # 线性投影
        x = self.proj(x)  # (B, num_patches, dim)
        # 添加 CLS Token
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, num_patches + 1, dim)
        # 添加位置编码
        x = x + self.pos_embedding
        x = self.dropout(x)
        return x
代码逐步解析
  1. 输入张量

    • 输入 ( x x x ) 的形状为 ( ( B , C , H , W ) (B, C, H, W) (B,C,H,W) ),例如 MNIST 的 ( ( B , 1 , 28 , 28 ) (B, 1, 28, 28) (B,1,28,28) )(灰度图,通道数 ( C = 1 C=1 C=1 ))。
    • ( B B B ) 是批次大小,( H H H) 和 ( W W W ) 是图像高度和宽度。
  2. 使用 unfold 分割 Patch

    • 高度维度x.unfold(2, patch_size, patch_size) 将高度 ( H = 28 H=28 H=28 ) 分成 ( H / p = 28 / 7 = 4 H/p=28/7=4 H/p=28/7=4 ) 个块,每个块大小为 ( p = 7 p=7 p=7 ),步幅为 7(无重叠)。形状变为 ( ( B , C , 4 , W , 7 ) (B, C, 4, W, 7) (B,C,4,W,7) )。
    • 宽度维度:再次 unfold(3, patch_size, patch_size) 将宽度 ( W = 28 W=28 W=28 ) 分成 ( W / p = 4 W/p=4 W/p=4 ) 个块,形状变为 ( ( B , C , 4 , 4 , 7 , 7 ) (B, C, 4, 4, 7, 7) (B,C,4,4,7,7) )。
    • 结果是 ( 4 × 4 = 16 4 \times 4 = 16 4×4=16 ) 个 ( 7 × 7 7 \times 7 7×7 ) 的 Patch。
  3. 调整维度

    • permute(0, 2, 3, 1, 4, 5) 调整顺序为 ( ( B , 4 , 4 , C , 7 , 7 ) (B, 4, 4, C, 7, 7) (B,4,4,C,7,7) ),将空间维度 ( ( H / p , W / p ) (H/p, W/p) (H/p,W/p) ) 放在前面。
    • view(B, num_patches, -1) 将每个 Patch 展平为向量,形状变为 ( ( B , 16 , 49 ) (B, 16, 49) (B,16,49) )(对于灰度图,( p a t c h _ d i m = 7 × 7 × 1 = 49 patch\_dim = 7 \times 7 \times 1 = 49 patch_dim=7×7×1=49 ))。
  4. 线性投影

    • self.proj 将 ( 49 49 49 ) 维 Patch 映射到目标维度 ( d i m dim dim )(例如 64),形状变为 ( ( B , 16 , 64 ) (B, 16, 64) (B,16,64) )。
  5. 添加 CLS Token 和位置编码

    • CLS Token 扩展为 ( ( B , 1 , 64 ) (B, 1, 64) (B,1,64) ),与 Patch 拼接后为 ( ( B , 17 , 64 ) (B, 17, 64) (B,17,64) )。
    • 位置编码 ( pos_embedding )(形状 ( ( 1 , 17 , 64 ) (1, 17, 64) (1,17,64) ))加到序列上,保留空间信息。
unfold 的作用
  • 高效性unfold 避免了手动循环或切片,直接提取所有 Patch,计算效率高。
  • 灵活性:通过调整 patch_sizestep,可以轻松改变 Patch 大小和重叠程度。
  • 与 ViT 的契合:ViT 需要将图像转为序列,unfold 提供了自然的分块方式,符合 Transformer 的输入需求。

unfold 的优势与局限

优势

  1. 计算效率unfold 是 PyTorch 的原生操作,底层优化良好,适合大批量图像处理。
  2. 内存友好:返回的张量是输入的视图(View),不复制数据,节省内存。
  3. 直观性:与卷积类似,易于理解和调试。

局限

  1. 边界处理:如果图像大小不能被 patch_sizestep 整除,部分区域可能被截断(需填充或调整输入)。
  2. 步幅限制:当前实现中步幅等于 Patch 大小(无重叠),若需重叠 Patch,需调整逻辑。
  3. 维度管理:多次 unfold 后维度复杂,需配合 permuteview 调整。

在 ViT 中的意义

ViT 的创新在于将图像视为 Patch 序列,unfold 是实现这一思想的关键步骤:

  • 序列化:通过 unfold 分割,图像从二维网格变为 Transformer 可处理的一维序列。
  • 全局建模:与 CNN 的局部卷积不同,ViT 依赖自注意力全局整合信息,unfold 提供了基础输入。
  • 可扩展性:对于更高分辨率图像(如 224x224 的 ImageNet),只需调整 patch_size(如 16),即可生成更多 Patch。

示例:MNIST 上的可视化

假设输入一张 ( 28 × 28 28 \times 28 28×28 ) 的 MNIST 图像,unfold 生成了 16 个 ( 7 × 7 7 \times 7 7×7 ) 的 Patch。每个 Patch 展平为 49 维向量,投影后作为 Transformer 的输入 token。这种分块方式保留了局部信息,同时通过位置编码和自注意力捕捉全局关系。


扩展与优化建议

  1. 支持重叠 Patch
    • step 设为小于 patch_size(如 step=5),生成重叠 Patch,增强局部细节捕捉。
  2. 动态输入大小
    • 对非整除的图像大小,添加填充(Padding)确保完整分割。
  3. 替代实现
    • 使用 torch.nn.functional.unfold(而不是 Tensor 的 unfold 方法),一次性提取所有 Patch,可能更高效:
      x = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # (B, C*p*p, num_patches)
      x = x.transpose(1, 2)  # (B, num_patches, C*p*p)
      

结语

unfold 函数是 PyTorch 中一个强大而低调的工具,在 ViT 的 Patch Embedding 中扮演了至关重要的角色。它高效地将图像分割为 Patch,为 Transformer 的序列化输入奠定了基础。通过理解 unfold 的原理和用法,你不仅能更好地实现 ViT,还能将其应用于其他需要滑动窗口的任务(如时间序列分析或特征提取)。希望这篇博客为你提供了清晰的思路,欢迎留言讨论你的应用场景或优化想法!


参考

后记

2025年3月12日19点45分于上海,在Grok 3大模型辅助下完成。


网站公告

今日签到

点亮在社区的每一天
去签到