摘要
本周阅读了一篇25年二月份发表于CVPR 的论文《Attention Distillation: A Unified Approach to Visual Characteristics Transfer》,论文开发了Attention Distillation引导采样,这是一种改进的分类器引导方法,将注意力蒸馏损失整合到去噪过程中,大大加快了合成速度,并支持广泛的视觉特征迁移和合成应用。
abstract
This week I read a paper published in CVPR in February, "Attention Distillation: A Unified Approach to Visual Characteristics Transfer, this paper develops the Attention Distillation guided sampling, which is an improved classifier guided method to integrate the attention distillation loss into the denoising process. It greatly speeds up synthesis and supports a wide range of visual feature migration and synthesis applications.
下图中是给定参考图,文生图的示例:
论文摘要
最近扩散模型方面的进展显示了对图像风格和语义的内在理解。论文提出了一种新颖的注意力蒸馏损失,通过在潜在空间中反向传播来优化合成图像,同时改进了一个分类器引导,它将注意力蒸馏损失集成到去噪采样过程中,进一步加速合成过程。
简介
论文解决问题: 现有生成扩散模型在图像风格和语义理解方面虽然有进展,但在将参考图像的视觉特征转移到生成图像中时,使用即插即用注意力特征的方法存在局限性。
传统的方法通常将纹理定义为重复的局部模式,并通过从源图像中复制局部补丁来合成新的纹理。通常归结为以下三个原因导致的局限性:
- 域差距:当两幅图像存在显著差异时,目标Q(合成图像的查询)与参考图像的K,V之间的相似性较低且不可靠,导致错误的聚合结果(AdaIN和注意力能缓解这个问题)
- 误差积累:虽然扩散模型中的迭代采样过程可以改善目标Q和参考图中的K,V之间的巨大差异,但误差也可能积累。来自不同扩散模型层的特征集中于不同的信息,如语义和几何。不正确的匹配将会错误传播到马尔科夫链的后续层,并降低最终图像质量。
- 框架限制:在去噪网络的剩余分支内实现自注意力机制,参考图像中的自注意力特征可能对目标图像有潜在的影响,降低了合成的效力。
为了解决上述局限性,本篇论文中引入一种新的注意力蒸馏损失AD loss,在此基础上,通过反向传播直接更新合成的图像。
提出方案: 首先,提出了一种新颖的注意力蒸馏损失,用于在理想和当前风格化结果之间计算损失,并在隐空间中通过反向传播优化合成图像。其次,开发了一种改进的分类器引导方法,即注意力蒸馏引导采样,将注意力蒸馏损失整合到去噪采样过程中。
方法
预备知识
隐空间扩散模型(LDM),如Stable Diffusion,由于其对复杂数据分布的强大建模能力,在图像生成方面达到了最先进的性能。在LDM中,首先使用预训练的VAE 将图像x压缩到一个学习到的隐空间中。随后,基于UNet的去噪网络被训练用于在扩散过程中预测噪声,通过最小化预测噪声与实际添加噪声之间的均方误差来实现。
L L D M = E z ∼ E ( x ) , y , ϵ ∼ N ( 0 , 1 ) , t [ ∥ ϵ θ ( z t , t , y ) − ϵ ∥ 2 2 ] \mathcal{L}_{\mathrm{LDM}}=\mathbb{E}_{z\sim\mathcal{E}(x),y,\epsilon\sim\mathcal{N}(0,1),t}\left[\|\epsilon_\theta(z_t,t,y)-\epsilon\|_2^2\right] LLDM=Ez∼E(x),y,ϵ∼N(0,1),t[∥ϵθ(zt,t,y)−ϵ∥22]
其中 y 表示条件, 表示时间步长。去噪 UNet 通常由一系列卷积块和自注意力/交叉注力模块组成,所有这些都集成在残差架构的预测分支中。
KV注入在图像编辑、风格迁移和纹理合成中被广泛使用。它建立在自注意力机制之上,并将扩散模型中的自注意力特征用作即插即用的属性。自注意力机制的公式为:
S e l f − A t t n ( Q , K , V ) = s o f t m a x ( Q K T d ) V \mathrm{Self-Attn}(Q,K,V)=\mathrm{softmax}(\frac{QK^{T}}{\sqrt{d}})V Self−Attn(Q,K,V)=softmax(dQKT)V
在注意力机制的核心,是基于查询Q和键K之间的相似性计算权重矩阵,该矩阵用于对值V进行加权聚合。KV注入通过在不同的合成分支之间复制或共享KV特征来扩展这一机制。其关键假设是KV特征代表图像的视觉外观。在采样过程中,将合成分支中的KV特征替换为示例的相应时间步长的KV特征,可以实现从源图像到合成目标的外观转移。
注意力蒸馏损失
尽管KV注入取得了显著的效果,但由于残差机制的影响,它在保留参考的风格或纹理细节方面表现不足;例如,下图(a)中。KV注入仅作用于残差,这意味着信息流(红色箭头)随后受到恒等连接的影响,导致信息传递不完整。因此,采样输出无法完全再现所需的视觉细节。
本论文通过在自注意力机制中重新聚合特征来提取视觉元素。利用预训练的T2I扩散模型SD的UNet,从自注意力模块中提取图像特征。
上图中,首先根据目标分支的Q,从参考分支重新聚合KV特征(Ks和Vs)的视觉信息,这与KV注入相同。
将此注意力输出视为理想的风格化。然后,我们计算目标分支的注意力输出,并计算相对于理想注意力输出的L1损失,这定义了AD损失:
L A D = ∥ S e l f − A t t n ( Q , K , V ) − S e l f − A t t n ( Q , K s , V s ) ∥ 1 \mathcal{L}_{\mathrm{AD}}=\|\mathrm{Self-Attn}(Q,K,V)-\mathrm{Self-Attn}(Q,K_{s},V_{s})\|_{1} LAD=∥Self−Attn(Q,K,V)−Self−Attn(Q,Ks,Vs)∥1
可以使用提出的AD损失通过梯度下降来优化随机隐空间噪声,从而在输出中实现生动的纹理或风格再现;例如,参见上图(b)。这归因于优化中的反向传播,它不仅允许信息在(残差)自注意力模块中流动,还通过恒等连接流动。通过持续优化,Q和Ks之间的差距逐渐缩小,使得注意力越来越准确,最终特征被正确聚合以产生所需的视觉细节。
注意力引导采样
将注意力蒸馏损失以改进的分类器引导方式纳入扩散模型的采样过程中。
分类器引导在去噪过程中改变去噪方向,从而生成来自p(zt|c)的样本,其公式可以表示为:
ϵ ^ θ = ϵ θ ( z t , t , y ) − α σ t ∇ z t log p ( c ∣ z t ) \hat{\epsilon}_\theta=\epsilon_\theta(z_t,t,y)-\alpha\sigma_t\nabla_{z_t}\log p(c|z_t) ϵ^θ=ϵθ(zt,t,y)−ασt∇ztlogp(c∣zt)
其中,t是时间步长,y表示提示, ϵ θ \epsilon_\theta ϵθ和 z t \ {z_t} zt分别指去噪网络和LDM中的隐空间变量。 α \alpha α控制引导强度。使用基于注意力蒸馏损失的能量函数来引导扩散采样过程。
实验
由于补丁来源有限,使用传统方法合成超高分辨率纹理非常困难。在此,将注意力蒸馏引导的采样应用于MultiDiffusion模型,使纹理扩展到任意分辨率。尽管SD-1.5是在尺寸为512×512的图像上训练的,但令人惊讶的是,当结合注意力蒸馏时,它在大尺寸纹理合成中表现出了强大的能力。下图展示了将纹理扩展到512×1536的尺寸与GCD和GPDM的比较。
损失函数代码
def ad_loss(
q_list, ks_list, vs_list, self_out_list, scale=1, source_mask=None, target_mask=None
):
loss = 0
attn_mask = None
for q, ks, vs, self_out in zip(q_list, ks_list, vs_list, self_out_list):
if source_mask is not None and target_mask is not None:
w = h = int(np.sqrt(q.shape[2]))
mask_1 = torch.flatten(F.interpolate(source_mask, size=(h, w)))
mask_2 = torch.flatten(F.interpolate(target_mask, size=(h, w)))
attn_mask = mask_1.unsqueeze(0) == mask_2.unsqueeze(1)
attn_mask=attn_mask.to(q.device)
target_out = F.scaled_dot_product_attention(
q * scale,
torch.cat(torch.chunk(ks, ks.shape[0]), 2).repeat(q.shape[0], 1, 1, 1),
torch.cat(torch.chunk(vs, vs.shape[0]), 2).repeat(q.shape[0], 1, 1, 1),
attn_mask=attn_mask
)
loss += loss_fn(self_out, target_out.detach())
return loss
def q_loss(q_list, qc_list):
loss = 0
for q, qc in zip(q_list, qc_list):
loss += loss_fn(q, qc.detach())
return loss
# weight = 200
def qk_loss(q_list, k_list, qc_list, kc_list):
loss = 0
for q, k, qc, kc in zip(q_list, k_list, qc_list, kc_list):
scale_factor = 1 / math.sqrt(q.size(-1))
self_map = torch.softmax(q @ k.transpose(-2, -1) * scale_factor, dim=-1)
target_map = torch.softmax(qc @ kc.transpose(-2, -1) * scale_factor, dim=-1)
loss += loss_fn(self_map, target_map.detach())
return loss
# weight = 1
def qkv_loss(q_list, k_list, vc_list, c_out_list):
loss = 0
for q, k, vc, target_out in zip(q_list, k_list, vc_list, c_out_list):
self_out = F.scaled_dot_product_attention(q, k, vc)
loss += loss_fn(self_out, target_out.detach())
return loss
下面这段代码主要通过自适应特征提取和优化,将内容图像的潜变量 (latents) 调整为具有风格图像特征的潜变量,实现风格迁移(Style Transfer)或风格控制 (Style-Adaptive Denoising, AD)。
1.使用了一种基于 AdaIN (Adaptive Instance Normalization) 的方法对 latents 进行风格调整:
if self.adain:
noise = torch.randn_like(self.style_latent)
style_latent = self.scheduler.add_noise(self.style_latent, noise, t)
latents = utils.adain(latents, style_latent)
2.提取风格和内容特征:
qs_list, ks_list, vs_list, s_out_list = self.extract_feature(
self.style_latent,
t,
self.null_embeds_for_style,
add_noise=True,
)
if self.content_latent is not None:
qc_list, kc_list, vc_list, c_out_list = self.extract_feature(
self.content_latent,
t,
self.null_embeds,
add_noise=True,
)
3.优化 latents 使其匹配风格和内容特征:
optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr)
optimizer = self.accelerator.prepare(optimizer)
在 iters 轮优化中,计算损失 (style_loss 和 content_loss),并进行反向传播:
for j in range(iters):
style_loss = ad_loss(q_list, ks_list, vs_list, self_out_list, scale=self.attn_scale)
if self.content_latent is not None:
content_loss = q_loss(q_list, qc_list)
loss = style_loss + content_loss * weight
self.accelerator.backward(loss)
optimizer.step()
结论
这篇论文提出了一种统一的方法来处理各种视觉特征转移任务,包括风格/外观转移、特定风格的图像生成和纹理合成。该方法的关键是一种新颖的注意力蒸馏损失,它计算理想风格化与当前风格化之间的差异,并逐步修改合成。
总结
这篇论文提出了一种基于注意力蒸馏(Attention Distillation, AD)的新方法,用于改进扩散模型在视觉特征迁移任务中的表现。作者引入注意力蒸馏损失(AD Loss),通过反向传播优化合成图像,使其更好地匹配目标风格。此外,论文提出注意力蒸馏引导采样,将AD Loss整合到去噪过程中,加快图像合成速度,并提升细节保真度。实验表明,该方法在风格迁移、特定风格图像生成和纹理合成等任务中均优于现有技术,特别是在高分辨率纹理生成方面表现突出。该方法通过改进查询-键-值(Q-K-V)特征聚合,有效缓解域差距、误差积累和框架限制问题。
参考文献
[1] Attention Distillation: A Unified Approach to Visual Characteristics Transfer