【代码解读】阿里最新开源视频生成模型 Wan 2.1 实现解析

发布于:2025-02-28 ⋅ 阅读:(16) ⋅ 点赞:(0)

昨晚阿里巴巴开源了最新视频生成模型的代码和权重,从给出的 demo 效果来看还是非常惊艳的。这个模型目前也是在 VBench 榜单上排到了第一名,超越了 Sora 以及 HunyuanVideo 等一系列知名方法。
截至写文章时的 VBench 榜单
从官方给出的方法架构图来说,Wan 2.1 并没有使用 MMDiT 的架构,而是基于普通的 DiT 架构,而文本条件则是通过 Cross Attention 实现注入。在文本编码器方面,Wan 2.1 采用了支持多语言的 UMT5 作为编码器,因此 prompt 部分或许能够原生支持中文输入。图中的 Wan-Encoder 和 Wan-Decoder 实际上就是视频生成模型常用的 3D Causal VAE,根据官方的说法,其支持无损时序信息编解码任意时长1080P视频。在时间编码方面,模型的所有 block 采用了统一的时间步编码器,并采取了类似 AdaLN 的方式将时间步编码进行注入。
Wan 2.1 模型架构
Wan 2.1 公布了不同尺寸的多个变体,小型的为 1.3B,想必是为了支持消费级显卡推理推出的一款模型;大型的为 14B,超过了 HunyuanVideo 的尺度,并且支持 720P 分辨率视频的生成。从表中的信息来看,不仅支持文生视频,同时也能够支持图生视频。

模型 支持 480P 分辨率 支持 720P 分辨率
T2V-14B 支持 支持
I2V-14B-720P 不支持 支持
I2V-14B-480P 支持 不支持
T2V-1.3B 支持 不支持

同时官方也已经给出了一些定量指标,目前看到的生成质量指标是由人工评测得到的,所以暂时先不分析。个人感觉比较重要的信息是这张图里的推理成本。从表中可以看出,1.3B 模型的峰值显存占用仅为 8 GB,且在单张消费级显卡上推理约 4 分钟即可的得到一段视频(而且这个结果是将 T5 模型卸载到 CPU 上得到的,所以如果把文本提前做离线 embedding,这个性能应当还有进一步提升),还是很可观的。不过 14B 模型的推理成本就比较高了,在单卡上的显存占用已经接近 80 GB,推理时间也来到了几千秒的数量级。
Wan 2.1 的推理成本

代码实现分析

首先可以看到的是,和其他的方法一样,Wan 2.1 也使用了 Classifier-Free Guidance(代码链接):

noise_pred_cond = self.model(
    latent_model_input, t=timestep, **arg_c)[0]
noise_pred_uncond = self.model(
    latent_model_input, t=timestep, **arg_null)[0]

noise_pred = noise_pred_uncond + guide_scale * (
    noise_pred_cond - noise_pred_uncond)

对于图生视频任务,模型会使用 CLIP Vision Encoder 将图像进行编码作为 latents 中的第一帧,其余部分填充零,且加入一个 mask 通道(类似 inpainting 的做法,代码链接):

self.clip.model.to(self.device)
clip_context = self.clip.visual([img[:, None, :, :]])
if offload_model:
    self.clip.model.cpu()

y = self.vae.encode([
    torch.concat([
        torch.nn.functional.interpolate(
            img[None].cpu(), size=(h, w), mode='bicubic').transpose(
                0, 1),
        torch.zeros(3, 80, h, w)
    ],
                    dim=1).to(self.device)
])[0]
y = torch.concat([msk, y])

除了在 latents 上的调整,图生视频还会将图像的 CLIP 特征再次进行 embedding,并在 Cross Attention 时与文本图像拼接后共同作为条件进行生成(代码链接):

if clip_fea is not None:
    context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
    context = torch.concat([context_clip, context], dim=1)

进入模型内部,可以看到一个现象是模型的输入并不是 batched tensor,而是一个 tensor 的列表,相当于把同一个批次拆分成了多个单个视频。在推理时也是遍历整个列表,可能因为模型的推理显存比较高,通过把批次拆开来节省显存。以 Patch Embedding 为例(代码链接):

x = [self.patch_embedding(u.unsqueeze(0)) for u in x]

对于模型的每个 block,其内部由一组 self attention 与一组 cross attention 组成,并且都按照 DiT 的方式进行了 modulation 操作(代码链接):

# self-attention
y = self.self_attn(
    self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
    freqs)
with amp.autocast(dtype=torch.float32):
    x = x + y * e[2]

# cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e):
    x = x + self.cross_attn(self.norm3(x), context, context_lens)
    y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
    with amp.autocast(dtype=torch.float32):
        x = x + y * e[5]
    return x

x = cross_attn_ffn(x, context, context_lens, e)

在图生视频的 attention 中,将文本 token 与图像 token 进行了拆分,分别与 latent 计算 cross attention,然后再将两组结果相加得到最后的交叉注意力结果(代码链接):

def forward(self, x, context, context_lens):
    r"""
    Args:
        x(Tensor): Shape [B, L1, C]
        context(Tensor): Shape [B, L2, C]
        context_lens(Tensor): Shape [B]
    """
    context_img = context[:, :257]
    context = context[:, 257:]
    b, n, d = x.size(0), self.num_heads, self.head_dim

    # compute query, key, value
    q = self.norm_q(self.q(x)).view(b, -1, n, d)
    k = self.norm_k(self.k(context)).view(b, -1, n, d)
    v = self.v(context).view(b, -1, n, d)
    k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
    v_img = self.v_img(context_img).view(b, -1, n, d)
    img_x = flash_attention(q, k_img, v_img, k_lens=None)
    # compute attention
    x = flash_attention(q, k, v, k_lens=context_lens)

    # output
    x = x.flatten(2)
    img_x = img_x.flatten(2)
    x = x + img_x
    x = self.o(x)
    return x

在计算 self-attention 时,也使用了 RoPE(代码链接):

x = flash_attention(
    q=rope_apply(q, grid_sizes, freqs),
    k=rope_apply(k, grid_sizes, freqs),
    v=v,
    k_lens=seq_lens,
    window_size=self.window_size)

目前来说从代码里能够看到的比较有用的信息就是这些,由于具体的 report 还没有放出来,所以关于数据的细节目前不太清楚(据说用了 1.5B 视频数据和 10B 图像数据),也期待一下技术报告早日公布。