系列文章目录
前言
特征图 (Latent) 尺寸和注意力图(attention map)尺寸在扩散模型中有差异,是由于模型架构和注意力机制的特性决定的。
特征图和注意力图的尺寸差异原因
不同的功能目的
- 特征图(Feature Maps):承载图像的语义和视觉特征,维持空间结构
- 注意力图(Attention Maps):表示不同位置之间的关联强度,是一种关系矩阵
UNet架构中的特征图尺寸
在U-Net中,特征图的尺寸在不同层级有变化:- 输入图像通常是 512×512 或 256×256
- 下采样路径(Encoder):尺寸逐渐缩小 (512→256→128→64→32→16…)
- 上采样路径(Decoder):尺寸逐渐增大 (16→32→64→128→256→512…)
在Break-a-Scene代码中,我们看到特征图尺寸被下采样到64×64:
downsampled_mask = F.interpolate(input=max_masks, size=(64, 64))
注意力机制中的尺寸计算
注意力机制处理的是"token"之间的关系,其中:- 自注意力(Self-Attention):特征图中的每个位置视为一个token
- 交叉注意力(Cross-Attention):文本序列中的token与特征图中的位置建立关联
如果特征图尺寸是h×w,则自注意力矩阵的尺寸是(hw)×(hw),这是一个平方关系
在代码中,注意力图通常被下采样到16×16:
GT_masks = F.interpolate(input=batch["instance_masks"][batch_idx], size=(16, 16))
计算效率考虑
- 注意力计算的复杂度是O(n²),其中n是token数量
- 对于64×64的特征图,如果直接计算自注意力,需要处理4096×4096的矩阵
- 为了降低计算量,通常在较低分辨率(如16×16)的特征图上计算注意力,这样只需处理256×256的矩阵
在Break-a-Scene中的具体实现
在Break-a-Scene中,这些尺寸差异体现在:
两种不同的损失计算:
a. 掩码损失(Masked Loss):应用在64×64的 Latent 上
max_masks = torch.max(batch["instance_masks"], axis=1).values downsampled_mask = F.interpolate(input=max_masks, size=(64, 64)) model_pred = model_pred * downsampled_mask target = target * downsampled_mask
b. 注意力损失(Attention Loss):应用在16×16的注意力图上
GT_masks = F.interpolate(input=batch["instance_masks"][batch_idx], size=(16, 16)) agg_attn = self.aggregate_attention(res=16, from_where=("up", "down"), is_cross=True, select=batch_idx)
注意力存储的筛选:
在存储注意力图时,只保留小尺寸的注意力图:
def forward(self, attn, is_cross: bool, place_in_unet: str): key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" if attn.shape[1] <= 32**2: # 只保存小于或等于32×32的注意力图 self.step_store[key].append(attn) return attn
注意力聚合:
在聚合不同层的注意力时,确保只使用匹配目标分辨率的注意力图:
def aggregate_attention(self, res: int, from_where: List[str], is_cross: bool, select: int): # ... num_pixels = res**2 for location in from_where: for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: if item.shape[1] == num_pixels: # 只选择匹配分辨率的注意力图 cross_maps = item.reshape(self.args.train_batch_size, -1, res, res, item.shape[-1])[select] out.append(cross_maps) # ...
总结
特征图和注意力图尺寸的差异主要是因为:
- 它们在模型中的功能不同
- 注意力计算的计算复杂度要求在较低分辨率上进行
- UNet架构中的不同层级有不同的特征图尺寸
- 为了平衡精度和计算效率,Break-a-Scene使用不同分辨率的特征图和注意力图来计算不同类型的损失
这种设计使得Break-a-Scene能够有效地学习token与图像区域之间的对应关系,同时保持计算效率。