摘要
本文提出了一种无需依赖手工设计的数据增强方法即可学习高度语义图像表示的技术。我们引入了一种基于图像的联合嵌入预测架构(Image-based Joint-Embedding Predictive Architecture,简称 I-JEPA),这是一种非生成式的图像自监督学习方法。I-JEPA 的核心思想很简单:从单个上下文块出发,预测同一图像中多个目标块的表示。引导 I-JEPA 生成语义表示的关键设计选择是其掩码策略;具体而言,关键在于 (a) 采样具有足够大尺度(语义性)的目标块,以及 (b) 使用具有足够信息量(空间分布广泛)的上下文块。实证表明,当与 Vision Transformers 结合使用时,I-JEPA 具有极强的可扩展性。例如,我们使用 16 个 A100 GPU 在不到 72 小时内在 ImageNet 上训练了一个 ViT-Huge/14 模型,并在从线性分类到目标计数和深度预测等多种下游任务中表现出色。
1. 引言
在计算机视觉中,从图像进行自监督学习的常见方法大致可以分为两类:基于不变性的方式 [1,4,10,17,18,24,35,37,74] 和生成式方法 [8, 28, 36, 57]。
基于不变性的预训练方法通过优化编码器,使其对同一图像的两个或多个视图生成相似的嵌入表示 [15, 20],而这些图像视图通常是通过一组手工设计的数据增强方式(如随机缩放、裁剪和颜色扰动 [20],以及其他方式 [35])构造的。这些预训练方法可以生成语义层次较高的表示 [4, 18],但它们也引入了强烈的偏差,这些偏差可能对某些下游任务甚至对具有不同数据分布的预训练任务造成不利影响 [2]。通常情况下,这些偏差很难泛化到需要不同抽象层次的任务中。例如,图像分类和实例分割就不需要相同的不变性 [11]。此外,将这些特定于图像的数据增强方式泛化到其他模态(如音频)也并不容易。
认知学习理论提出,生物系统中表征学习的驱动机制之一是内部模型的适应,用于预测感官输入的反应 [31, 59]。这一思想是自监督生成方法的核心,这些方法通过移除或破坏输入的部分内容,并学习预测被破坏的内容 [9, 36, 57, 67, 68, 71]。特别是,掩码去噪方法通过重建来自输入的随机掩码块来学习表征,这些块可以是像素级别或token级别的。与视图不变性方法相比,掩码预训练任务需要的先验知识较少,并且可以轻松泛化到图像模态之外 [8]。然而,生成的表示通常处于较低的语义层次,并且在现成评估(例如线性探测)和有限监督的语义分类任务的迁移设置中,通常表现不如基于不变性的预训练方法 [4]。因此,为了充分利用这些方法,通常需要更复杂的适应机制(例如端到端微调)。
在本研究中,我们探索如何在不使用通过图像变换编码的额外先验知识的情况下,提升自监督表示的语义水平。为此,我们提出了一种用于图像的联合嵌入预测架构(Image-based Joint-Embedding Predictive Architecture,I-JEPA)[48]。图3展示了该方法的示意图。I-JEPA 的核心思想是在抽象的表示空间中预测缺失的信息;例如,给定一个context block,预测同一图像中多个target block的表示,其中target表示由一个训练得到的target-encoder网络计算。
与在像素/token空间中进行预测的生成方法相比,I-JEPA 使用抽象的预测目标,可能会去除不必要的像素级细节,从而引导模型学习更具语义性的特征。另一个指导 I-JEPA 学习语义表示的核心设计选择是我们提出的multi-block掩码策略。具体来说,我们展示了在图像中预测足够大的target block的重要性,同时使用一个信息丰富(空间上分布广泛)的context block。
通过大量实证评估,我们证明了:
- I-JEPA 可以在不使用手工设计的视图增强的情况下学习出色的现成表示(参见图1)。在ImageNet-1K的线性探测、1%样本的半监督学习以及语义迁移任务中,I-JEPA 的表现优于像素重建方法如 MAE [36]。
- I-JEPA 在语义任务中具有与视图不变预训练方法相当的竞争力,并在诸如目标计数和深度预测等低层视觉任务中表现更佳(第5节和第6节)。通过使用一个结构更简单、归纳偏差更少的模型,I-JEPA 可应用于更广泛的任务。
- I-JEPA 同时具有良好的可扩展性和高效性(第7节)。在 ImageNet 上预训练一个 ViT-H/14 仅需不到 1200 GPU 小时,这比使用 iBOT [79] 预训练的 ViT-S/16 快超过 2.5 倍,比使用 MAE 预训练的 ViT-H/14 高效超过 10 倍。在表示空间中进行预测显著减少了自监督预训练所需的总计算量。
2. 背景
自监督学习是一种表示学习方法,系统通过学习输入之间的关系来进行学习。该目标可以通过能量基模型(Energy-Based Models, EBMs)[49] 的框架来描述,其中自监督目标是为不兼容的输入分配高能量,为兼容的输入分配低能量。许多现有的生成和非生成方法的自监督学习实际上都可以在该框架下进行建模;参见图2。
联合嵌入架构(Joint-Embedding Architectures)
基于不变性的预训练方法可以在能量基模型(EBMs)的框架下表述为联合嵌入架构(Joint-Embedding Architecture,JEA),该架构学习为兼容的输入对 x, y 生成相似的嵌入表示,为不兼容的输入对生成不同的嵌入表示;参见图2a。在图像预训练的背景下,兼容的x, y 对通常通过对同一张图像随机应用手工设计的数据增强操作来构造 [20]。
JEA 面临的主要挑战是表示坍塌(representation collapse),即能量空间过于平坦(即编码器对任意输入均输出恒定结果)。近年来,已经提出了多种策略来防止表示坍塌,例如:
- 对比损失(contrastive losses),通过显式拉开负样本的嵌入距离 [15, 24, 37];
- 非对比损失(non-contrastive losses),通过最小化嵌入间的信息冗余 [10, 74];
- 聚类方法,通过最大化平均嵌入的熵来提升多样性 [4, 5, 18];
- 启发式设计,例如采用不对称的架构设计(x-encoder 和 y-encoder 不同)以防止坍塌 [8, 24, 35]。
生成架构(Generative Architectures)
基于重建的自监督学习方法也可以在能量基模型的框架中建模为生成架构(参见图2b)。生成架构直接学习从兼容信号 xx 重建信号 yy,通过一个条件解码器网络完成,该网络依赖一个额外的(可能是潜在的)变量 zz 来协助重建。在图像预训练中,计算机视觉领域常见的方法是使用掩码生成兼容的 x,yx, y 对 [9, 38],其中 xx 是图像 yy 的一个副本,但部分patch被遮蔽。条件变量z 对应于一组(可能是可学习的)mask与位置信息,用于告知解码器应重建哪些patch。
在生成架构中,只要 zz 的信息容量远低于信号y,则通常不会面临表示坍塌的问题。
联合嵌入预测架构(Joint-Embedding Predictive Architectures)
如图2c所示,联合嵌入预测架构(Joint-Embedding Predictive Architectures, JEPA)在概念上与生成架构类似,但关键区别在于其损失函数是施加在嵌入空间中,而非输入空间中。JEPA 学习从兼容信号 x 的嵌入中预测信号 y 的嵌入,预测过程依赖一个额外的(可能是潜在的)变量z 来提供辅助信息。
我们提出的 I-JEPA 是在图像任务中采用掩码机制的 JEPA 实例,详见图3。
与 JEA 不同,JEPA 不追求对一组手工设计的数据增强操作不变的表示,而是学习在条件信息z 给定的情况下彼此具有预测性的表示。然而,与 JEA 一样,JEPA 也可能面临表示坍塌问题;因此,我们同样采用了 不对称的 x-encoder 与 y-encoder 架构设计以避免坍塌。
3. 方法
我们在本节中介绍所提出的图像级联合嵌入预测架构(Image-based Joint-Embedding Predictive Architecture,I-JEPA),如图3所示。总体目标如下:给定一个上下文块(context block),预测同一图像中多个目标块(target blocks)的表示。我们为 context-encoder、target-encoder 和 predictor 均采用 Vision Transformer(ViT)架构 [29, 63]。ViT 由若干堆叠的 transformer 层组成,每层包括一个自注意力机制(self-attention) [66] 和一个全连接 MLP。我们的编码器/预测器架构与生成式的掩码自编码器(masked autoencoders,MAE)方法 [36] 在结构上有些相似,但一个关键区别在:I-JEPA 是非生成式的(non-generative),并且其预测发生在表示空间(representation space)中,而非像 MAE 那样在像素空间中重建原始输入。
目标
我们首先描述如何在 I-JEPA 框架中生成目标:在 I-JEPA 中,目标对应于图像块的表示。给定输入图像 y,我们将其转换为N 个不重叠的图像块(patch),然后通过目标编码器(target-encoder) f θ ˉ f _ { \bar { \theta } } fθˉ将其送入,以获得相应的块级表示 s y = { s y 1 , … , s y N } s _ { y } \ =\{ s _ { y _ { 1 } } , \dots , s _ { y _ { N } } \} sy ={sy1,…,syN},其中 s y k s _ { y _ { k } } syk 是与第 k 个图像块相关联的表示。
为了获得损失函数的目标,我们从目标表示 s y s_y sy 中随机采样 MM 个(可能重叠的)块。我们用 B i B_i Bi 来表示第 i 个块的掩码,并用 s y ( i ) = { s y j } j ∈ B i s _ { y } ( i ) \, = \, \{ s _ { y _ { j } } \} _ { j \in B _ { i } } sy(i)={syj}j∈Bi表示其块级表示。通常,我们将 M 设置为 4,并且在(0.75, 1.5) 范围内随机选择长宽比,在 (0.15, 0.2) 范围内随机选择尺度。需要注意的是,目标块是通过掩码操作应用于目标编码器的输出,而非输入。这一区别至关重要,以确保目标表示具有较高的语义水平;例如,参见 [8]。
上下文
回顾一下,I-JEPA 的目标是从单个上下文块预测目标块的表示。为了获得 I-JEPA 中的上下文,我们首先从图像中随机采样一个上下文块 c c c,其尺度范围在 (0.85, 1.0) 之间,长宽比为 1。我们用 B x B_x Bx 表示与上下文块 c c c 相关联的掩码。由于目标块是从上下文块独立采样的,因此它们可能会有显著的重叠。为了确保任务具有非平凡的预测难度,我们从上下文块中移除任何重叠区域。图 4 展示了实际应用中各种上下文块和目标块的示例。接下来,带掩码的上下文块 c c c 被送入上下文编码器 f θ f_\theta fθ,以获得相应的块级表示 s x = { s x j } j ∈ B x s_x = \{ s_{x_j} \}_{j \in B_x} sx={sxj}j∈Bx。
预测
给定上下文编码器的输出 s x s_x sx,我们希望预测 M M M 个目标块的表示 s y ( 1 ) , … , s y ( M ) s_y(1), \ldots, s_y(M) sy(1),…,sy(M)。为此,对于给定的目标块 s y ( i ) s_y(i) sy(i),对应于目标掩码 B i B_i Bi,预测器 g ϕ ( ⋅ , ⋅ ) g_\phi(\cdot, \cdot) gϕ(⋅,⋅) 接受上下文编码器的输出 s x s_x sx 和我们希望预测的每个块的掩码标记 { m j } j ∈ B i \{ m_j \}_{j \in B_i} {mj}j∈Bi 作为输入,并输出一个块级的预测值 s ^ y ( i ) = { s ^ y j } j ∈ B i = g ϕ ( s x , { m j } j ∈ B i ) \hat{s}_y(i) = \{ \hat{s}_{y_j} \}_{j \in B_i} = g_\phi(s_x, \{ m_j \}_{j \in B_i}) s^y(i)={s^yj}j∈Bi=gϕ(sx,{mj}j∈Bi)。这些掩码标记由一个共享的可学习向量进行参数化,并添加位置嵌入。由于我们希望对 M M M 个目标块进行预测,因此我们将预测器应用 M 次,每次都条件化在我们希望预测的目标块位置的掩码标记上,并得到预测值 s ^ y ( 1 ) , … , s ^ y ( M ) \hat{s}_y(1), \dotsc, \hat{s}_y(M) s^y(1),…,s^y(M)。
损失
损失函数是预测的块级表示 s ^ y ( i ) \hat{s}_y(i) s^y(i) 与目标块级表示 s y ( i ) s_y(i) sy(i) 之间的平均 L 2 L_2 L2 距离;即:
1 M ∑ i = 1 M D ( s ^ y ( i ) , s y ( i ) ) = 1 M ∑ i = 1 M ∑ j ∈ B i ∥ s ^ y j − s y j ∥ 2 2 \frac { 1 } { M } \sum _ { i = 1 } ^ { M } D \left( \hat { s } _ { y } ( i ) , s _ { y } ( i ) \right) = \frac { 1 } { M } \sum _ { i = 1 } ^ { M } \sum _ { j \in { \cal B } _ { i } } \| \hat { s } _ { y _ { j } } - s _ { y _ { j } } \| _ { 2 } ^ { 2 } M1i=1∑MD(s^y(i),sy(i))=M1i=1∑Mj∈Bi∑∥s^yj−syj∥22
参数 Φ \Phi Φ 和上下文编码器 θ \theta θ 的参数通过基于梯度的优化进行学习,而目标编码器的参数通过上下文编码器参数的指数滑动平均进行更新。使用指数滑动平均的目标编码器对训练 JEAs 和 Vision Transformers(如 [18, 25, 79] 所示)至关重要,我们发现对于 I-JEPA 同样适用。
4. 相关工作
在视觉表示学习领域,许多研究探索了通过预测缺失或损坏的感官输入的值来学习视觉表示。去噪自编码器使用随机噪声作为输入损坏的手段[67]。上下文编码器基于周围的上下文回归整个图像区域[57]。其他工作将图像上色任务视为一种去噪任务[46, 47, 77]。
最近,图像去噪的思想在掩蔽图像建模的背景下得到了重新审视[9, 36, 71],其中使用 Vision Transformer [29] 来重建缺失的输入图像块。Masked Autoencoders (MAE) [36] 提出了一个高效的架构,仅需要编码器处理可见的图像块。通过在像素空间中重建缺失的图像块,MAE 在大规模标注数据集上进行端到端微调时取得了强大的性能,并展示了良好的扩展性。BEiT [9] 在经过离散VAE编码的标记空间中预测缺失图像块的值,具体地,通过在包含2.5亿张图像的数据集上训练的离散VAE对图像块进行标记。然而,像素级预训练已被证明在微调时优于 BEiT [36]。另一个工作,SimMIM [71],探索基于经典梯度直方图(Histogram of Gradients,HOG)特征空间的重建目标,并展示了在像素空间重建上的一些优势。与这些工作不同,我们的方法通过联合嵌入预测架构(Joint-Embedding Predictive Architecture,J-EPA)在训练过程中学习表示空间。我们的目标是学习不需要大量微调的语义表示,这些表示可以直接应用于下游任务。
与我们的工作最为接近的是 data2vec [8] 和 Context Autoencoders [25]。data2vec 方法学习通过在线目标编码器预测缺失图像块的表示;通过避免使用手工设计的增强方法,这一方法可以应用于各种模态,并在视觉、文本和语音任务中取得了良好的结果。Context Autoencoders 使用编码器/解码器架构,通过重建损失和对齐约束进行优化,从而强制要求在表示空间中预测缺失的图像块。与这些方法相比,I-JEPA 在计算效率上表现出显著的改进,并且能够学习更多语义层面的表示。与我们工作同时进行的 data2vec-v2 [7] 探讨了高效的架构,以便在不同模态上进行学习。
我们还将 I-JEPA 与基于联合嵌入架构的各种方法进行比较,例如 DINO [18]、MSN [4] 和 iBOT [79]。这些方法依赖于在预训练过程中使用手工设计的数据增强来学习语义图像表示。MSN [4] 使用掩蔽作为预训练中的额外数据增强,而 iBOT 将 data2vec 风格的图像块级重建损失与 DINO 的视图不变性损失结合起来。这些方法的共同点是需要处理每个输入图像的多个用户生成的视图,从而限制了其可扩展性。相比之下,I-JEPA 只需要处理每张图像的单个视图。我们发现,使用 I-JEPA 训练的 ViT-Huge/14 比使用 iBOT 训练的 ViT-Small/16 所需的计算资源要少。
温馨提示:
阅读全文请访问"AI深语解构" I-JEPA:基于联合嵌入预测架构的图像自监督学习