Phantom 根据图片和文字描述,自动生成一段视频,并且动作、场景等内容会按照文字描述来呈现

发布于:2025-05-23 ⋅ 阅读:(23) ⋅ 点赞:(0)

Phantom 根据图片和文字描述,自动生成一段视频,并且动作、场景等内容会按照文字描述来呈现

flyfish

视频生成的实践效果展示

Phantom 视频生成的实践
Phantom 视频生成的流程
Phantom 视频生成的命令

Wan2.1 图生视频 支持批量生成
Wan2.1 文生视频 支持批量生成、参数化配置和多语言提示词管理
Wan2.1 加速推理方法
Wan2.1 通过首尾帧生成视频

AnyText2 在图片里玩文字而且还是所想即所得
Python 实现从 MP4 视频文件中平均提取指定数量的帧

配置

{
    "task": "s2v-1.3B",
    "size": "832*480",
    "frame_num": 81,
    "ckpt_dir": "./Wan2.1-T2V-1.3B",
    "phantom_ckpt": "./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth",
    "offload_model": false,
    "ulysses_size": 1,
    "ring_size": 1,
    "t5_fsdp": false,
    "t5_cpu": false,
    "dit_fsdp": false,
    "use_prompt_extend": false,
    "prompt_extend_method": "local_qwen",
    "prompt_extend_model": null,
    "prompt_extend_target_lang": "ch",
    "base_seed": 40,
    "sample_solver": "unipc",
    "sample_steps": null,
    "sample_shift": null,
    "sample_guide_scale": 5.0,
    "sample_guide_scale_img": 5.0,
    "sample_guide_scale_text": 7.5
}

参数

一、参数作用解析

1. 任务与模型路径
  • task: "s2v-1.3B"

    • 作用:指定任务类型为 文本到视频生成(Text-to-Video,T2V)1.3B 表示使用的基础模型(如 Wan2.1-T2V-1.3B)参数规模为 130亿
  • ckpt_dir: "./Wan2.1-T2V-1.3B"

    • 作用:指定基础模型的权重文件路径。根据你之前提供的文件夹内容,该路径下包含:
      • models_t5_umt5-xxl-enc-bf16.pthT5文本编码器的权重文件(用于处理文本提示)。
      • Wan2.1_VAE.pthVAE模型的权重(用于视频的时空压缩和重建)。
      • google/umt5-xxl文件夹:可能包含T5模型的结构定义或配置文件。
  • phantom_ckpt: "./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth"

    • 作用:指定 Phantom跨模态对齐模型 的权重路径,用于锁定参考图像的主体特征(如颜色、轮廓),确保生成视频中主体与参考图像一致。
2. 视频生成配置
  • size: "832*480"

    • 作用:生成视频的分辨率,格式为 宽度×高度,因此 832是宽度,480是高度。例如,常见的16:9分辨率中,宽度大于高度。
  • frame_num: 81

    • 作用:生成视频的总帧数。假设帧率为24fps,81帧约为3.375秒的视频(实际时长取决于帧率设置)。
3. 模型性能与资源配置
  • offload_model: false

    • 作用:是否将模型参数卸载到CPU或磁盘以节省GPU内存。设为false时,模型全程运行在GPU内存中,速度更快但需更大显存。
  • ulysses_size: 1ring_size: 1

    • 作用:与分布式训练(如FSDP)相关的参数,用于多卡并行计算。设为1时,表示 单卡运行,不启用分布式分片。
  • t5_fsdp: falset5_cpu: false

    • t5_fsdp:是否对T5文本编码器使用全分片数据并行(FSDP),false表示单卡加载T5模型。
    • t5_cpu:是否将T5模型放在CPU上运行,false表示运行在GPU上(推荐,速度更快)。
  • dit_fsdp: false

    • 作用:是否对扩散Transformer(DIT,Diffusion Transformer)使用FSDP,false表示单卡运行。
4. 提示与生成控制
  • use_prompt_extend: false

    • 作用:是否启用提示扩展功能(增强文本提示的语义丰富度)。设为false时,直接使用输入的文本提示,不进行扩展。
  • prompt_extend_method: "local_qwen"prompt_extend_model: null

    • 作用:提示扩展方法指定为本地Qwen模型,但prompt_extend_model设为null表示未加载该模型,因此扩展功能实际未启用。
  • prompt_extend_target_lang: "ch"

    • 作用:提示扩展的目标语言为中文,若启用扩展功能,会将中文提示转换为更复杂的语义表示。
5. 随机种子与生成算法
  • base_seed: 40
    • 作用:随机种子,用于复现相同的生成结果。固定种子后,相同提示和参数下生成的视频内容一致。

二、分辨率宽度确认

  • 832*480 中,832是宽度,480是高度
    • 分辨率的表示规则为 宽度×高度(Width×Height),例如:
      • 1080p是1920×1080(宽1920,高1080),
      • 这里的832×480接近16:9的比例(832÷480≈1.733,接近16:9的1.777)。

采样参数设置

一、核心参数解析

1. sample_solver: "unipc"
  • 作用:指定采样算法(扩散模型生成视频的核心求解器)。
    • UniPC(Unified Predictor-Corrector):一种高效的数值积分方法,适用于扩散模型采样,兼顾速度与生成质量,支持动态调整步长,在较少步数下可实现较好效果。
    • 对比其他 solver:相比传统的 DDIM/PLMS 等算法,UniPC 在相同步数下生成细节更丰富,尤其适合视频生成的时空连贯性优化。
2. sample_steps: 50
  • 作用:采样过程中执行的扩散步数(从噪声反向生成清晰样本的迭代次数)。
    • 数值影响
      • 50步:中等计算量,适合平衡速度与质量。步数不足可能导致细节模糊、动态不连贯;步数过高(如100+)会增加耗时,但收益可能边际递减。
      • 建议场景:若追求快速生成,可设为30-50;若需高保真细节(如复杂光影、精细纹理),可尝试60-80步。
3. sample_shift: 5.0
  • 作用控制跨帧生成时的时间步长偏移或运动连贯性约束。
    • 在视频生成中,相邻帧的生成需考虑时间序列的连续性,sample_shift 可能用于调整帧间采样的时间相关性(如抑制突然运动或增强动态平滑度)。
    • 数值较高(如5.0)可能增强帧间约束,减少闪烁或跳跃,但可能限制剧烈动作的表现力;数值较低(如1.0-2.0)允许更自由的动态变化。
4. 引导尺度参数(guide_scale 系列)

引导尺度控制文本提示和参考图像对生成过程的约束强度,数值越高,生成结果越贴近输入条件,但可能导致多样性下降或过拟合。

  • sample_guide_scale: 5.0(通用引导尺度):

    • 全局控制文本+图像引导的综合强度,若未单独设置 img/text 参数,默认使用此值。
  • sample_guide_scale_img: 5.0(图像引导尺度):

    • 参考图像对生成的约束强度(适用于图生视频 s2v 任务)。
    • 5.0 含义:中等强度,生成内容会保留参考图像的视觉特征(如颜色、构图、主体形态),但允许一定程度的变化(如视角调整、动态延伸)。
  • sample_guide_scale_text: 7.5(文本引导尺度):

    • 文本提示对生成的约束强度,数值显著高于图像引导(7.5 > 5.0),表明:
      • 优先遵循文本描述:生成内容会严格匹配文本语义(如“夕阳下的海滩”“机械恐龙奔跑”),可能牺牲部分图像参考的细节。
      • 风险与收益:高文本引导可能导致图像参考的视觉特征(如主体颜色、背景元素)被覆盖,需确保文本与图像语义一致(如文本描述需包含图像中的关键视觉元素)。

多组提示词+多参考图像输入

举例说明

[
    {
        "prompt": "内容",
        "image_paths": ["examples/1.jpg","examples/3.jpg"]
    },
    {
        "prompt": "内容",
        "image_paths": ["examples/2.jpg","examples/3.jpg"]
    }
    ,
    {
        "prompt": "内容",
        "image_paths": ["examples/3.png","examples/8.jpg"]
    }
]

一、JSON输入结构解析

prompt.json包含3组生成任务,每组结构为:

{
  "prompt": "文本提示词",       // 描述生成内容的语义(如“猫在草地跳跃”)
  "image_paths": ["图1路径", "图2路径"]  // 用于主体对齐的参考图像列表(支持多图)
}

关键特点

  • 每组任务可包含 1~N张参考图像(如examples/1.jpgexamples/3.jpg共同定义主体)。
  • 多图输入时,模型会自动融合多张图像的特征,适用于需要捕捉主体多角度、多姿态的场景(如生成人物行走视频时,用正面+侧面照片定义体型)。

二、处理多参考图像

1. 加载与预处理阶段(load_ref_images函数)
  • 输入image_paths列表(如["examples/1.jpg", "examples/3.jpg"])。
  • 处理逻辑
    1. 逐张加载图像,转换为RGB格式。
    2. 对每张图像进行保持比例缩放+中心填充,统一为目标尺寸(如832×480):
      • 若图像宽高比与目标尺寸不一致,先按比例缩放至最长边等于目标边,再用白色填充短边
    3. 输出ref_images列表,每张图像为PIL.Image对象,尺寸均为832×480
2. 模型生成阶段(Phantom_Wan_S2V.generate
  • 输入ref_images列表(多图) + prompt文本。
  • 核心逻辑
    1. 多图特征融合
      跨模态模型(Phantom-Wan)会提取每张参考图像的主体特征(如颜色、轮廓),并计算平均特征向量动态特征融合(根据图像顺序加权),形成对主体的综合描述。
    2. 动态对齐
      在生成视频的每一帧时,模型会同时参考所有输入图像的特征,确保主体在不同视角下的一致性(如正面图像约束面部特征,侧面图像约束身体比例)。

三、例子

场景1:复杂主体多角度定义
  • 需求:生成一个“机器人从左侧走向右侧”的视频,需要机器人正面和侧面外观一致。
  • 输入
    {
      "prompt": "银色机器人在灰色地面行走,头部有蓝色灯光",
      "image_paths": ["robot_front.jpg", "robot_side.jpg"]  // 正面+侧面图
    }
    
  • 效果
    • 视频中机器人正面视角时匹配robot_front.jpg的面部细节。
    • 转向侧面时匹配robot_side.jpg的身体轮廓和机械结构。
场景2:主体特征互补
  • 需求:修复单张图像缺失的细节(如证件照生成生活视频)。
  • 输入
    {
      "prompt": "穿蓝色衬衫的人在公园跑步,风吹动头发",
      "image_paths": ["id_photo.jpg", "hair_reference.jpg"]  // 证件照+发型参考图
    }
    
  • 效果
    • 主体面部和服装来自证件照,头发动态和颜色来自hair_reference.jpg,解决证件照中头发静止的问题。
场景3:多主体生成
  • 需求:生成“两个人握手”的视频,两人外观分别来自不同图像。
  • 输入
    {
      "prompt": "穿西装的男人和穿裙子的女人在会议室握手",
      "image_paths": ["man.jpg", "woman.jpg"]  // 两人的参考图像
    }
    
  • 效果
    • 模型自动识别图像中的两个主体,分别对齐到视频中的对应人物,确保两人外观与参考图像一致。

四、调用示例

1. 终端命令
python main.py --config_file config.json --prompt_file prompt.json
2. 生成结果示例

假设输入为:

[
  {
    "prompt": "戴帽子的狗在雪地里打滚",
    "image_paths": ["dog_front.jpg", "dog_side.jpg"]
  }
]

生成的视频中:

  • 狗的头部特征(如眼睛、鼻子)来自dog_front.jpg
  • 身体姿态和帽子形状来自dog_side.jpg
  • 雪地、打滚动作由文本提示驱动生成。

多参考图像输入是Phantom-Wan实现复杂主体动态生成的核心能力之一,通过融合多张图像的特征。

完整代码

import argparse
from datetime import datetime
import logging
import os
import sys
import warnings
import json
import time
from uuid import uuid4  # 新增:用于生成唯一标识符

warnings.filterwarnings('ignore')

import torch, random
import torch.distributed as dist
from PIL import Image, ImageOps

import phantom_wan
from phantom_wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from phantom_wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from phantom_wan.utils.utils import cache_video, cache_image, str2bool

def _validate_args(args):
    """参数验证函数"""
    # 基础检查
    assert args.ckpt_dir is not None, "请指定检查点目录"
    assert args.phantom_ckpt is not None, "请指定Phantom-Wan检查点"
    assert args.task in WAN_CONFIGS, f"不支持的任务: {args.task}"

    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize)
    
    # 尺寸检查["832*480", "480*832"]
    assert args.size in SUPPORTED_SIZES[args.task], \
        f"任务{args.task}不支持尺寸{args.size},支持尺寸:{', '.join(SUPPORTED_SIZES[args.task])}"

def _parse_args():
    """参数解析函数"""
    parser = argparse.ArgumentParser(description="使用Phantom生成视频")
    parser.add_argument("--config_file", type=str, default="config.json", help="配置JSON文件路径")
    parser.add_argument("--prompt_file", type=str, default="prompt.json", help="提示词JSON文件路径")
    
    args = parser.parse_args()
    
    # 从配置文件加载参数
    with open(args.config_file, 'r') as f:
        config = json.load(f)
    for key, value in config.items():
        setattr(args, key, value)
    
    _validate_args(args)
    return args

def _init_logging(rank):
    """日志初始化函数"""
    if rank == 0:
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)]
        )
    else:
        logging.basicConfig(level=logging.ERROR)

def load_ref_images(path, size):
    """加载参考图像并预处理"""
    h, w = size[1], size[0]  # 尺寸格式转换
    ref_images = []
    
    for image_path in path:
        with Image.open(image_path) as img:
            img = img.convert("RGB")
            img_ratio = img.width / img.height
            target_ratio = w / h

            # 保持比例缩放
            if img_ratio > target_ratio:
                new_width = w
                new_height = int(new_width / img_ratio)
            else:
                new_height = h
                new_width = int(new_height * img_ratio)
            
            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
            
            # 中心填充至目标尺寸
            delta_w = w - img.size[0]
            delta_h = h - img.size[1]
            padding = (delta_w//2, delta_h//2, delta_w-delta_w//2, delta_h-delta_h//2)
            new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
            ref_images.append(new_img)
    
    return ref_images

def generate(args):
    """主生成函数"""
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device = local_rank
    _init_logging(rank)

    # 分布式环境配置
    if world_size > 1:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size)
    
    # 模型并行配置
    if args.ulysses_size > 1 or args.ring_size > 1:
        assert args.ulysses_size * args.ring_size == world_size, "ulysses_size与ring_size乘积需等于总进程数"
        from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=args.ring_size,
            ulysses_degree=args.ulysses_size,
        )

    # 提示词扩展初始化
    prompt_expander = None
    if args.use_prompt_extend:
        if args.prompt_extend_method == "dashscope":
            prompt_expander = DashScopePromptExpander(
                model_name=args.prompt_extend_model, 
                is_vl="i2v" in args.task
            )
        elif args.prompt_extend_method == "local_qwen":
            prompt_expander = QwenPromptExpander(
                model_name=args.prompt_extend_model,
                is_vl="i2v" in args.task,
                device=rank
            )
        else:
            raise NotImplementedError(f"不支持的提示词扩展方法: {args.prompt_extend_method}")

    # 模型初始化(仅加载一次)
    cfg = WAN_CONFIGS[args.task]
    logging.info(f"初始化模型,任务类型: {args.task}")
    
    if "s2v" in args.task:
        # 视频生成(参考图像输入)
        wan = phantom_wan.Phantom_Wan_S2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            phantom_ckpt=args.phantom_ckpt,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
        )
    elif "t2v" in args.task or "t2i" in args.task:
        # 文本生成(图像/视频)
        wan = phantom_wan.WanT2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
        )
    else:
        # 图像生成视频(i2v)
        wan = phantom_wan.WanI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
        )

    # 加载提示词列表
    with open(args.prompt_file, 'r') as f:
        prompts = json.load(f)
    
    total_generation_time = 0
    generation_counter = 0  # 新增:生成计数器防止文件名重复

    for prompt_info in prompts:
        prompt = prompt_info["prompt"]
        image_paths = prompt_info.get("image_paths", [])  # 处理可能不存在的键
        start_time = time.time()

        # 分布式环境同步种子
        if dist.is_initialized():
            base_seed = [args.base_seed] if rank == 0 else [None]
            dist.broadcast_object_list(base_seed, src=0)
            args.base_seed = base_seed[0]

        # 提示词扩展处理
        if args.use_prompt_extend and rank == 0:
            logging.info("正在扩展提示词...")
            if "s2v" in args.task or "i2v" in args.task and image_paths:
                img = Image.open(image_paths[0]).convert("RGB")
                prompt_output = prompt_expander(prompt, image=img, seed=args.base_seed)
            else:
                prompt_output = prompt_expander(prompt, seed=args.base_seed)
            
            if not prompt_output.status:
                logging.warning(f"提示词扩展失败: {prompt_output.message}, 使用原始提示词")
                input_prompt = prompt
            else:
                input_prompt = prompt_output.prompt
            
            # 分布式广播扩展后的提示词
            input_prompt = [input_prompt] if rank == 0 else [None]
            if dist.is_initialized():
                dist.broadcast_object_list(input_prompt, src=0)
            prompt = input_prompt[0]
            logging.info(f"扩展后提示词: {prompt}")

        # 执行生成
        logging.info(f"开始生成,提示词: {prompt}")
        if "s2v" in args.task:
            ref_images = load_ref_images(image_paths, SIZE_CONFIGS[args.size])
            video = wan.generate(
                prompt,
                ref_images,
                size=SIZE_CONFIGS[args.size],
                frame_num=args.frame_num,
                shift=args.sample_shift,
                sample_solver=args.sample_solver,
                sampling_steps=args.sample_steps,
                guide_scale_img=args.sample_guide_scale_img,
                guide_scale_text=args.sample_guide_scale_text,
                seed=args.base_seed,
                offload_model=args.offload_model
            )
        elif "t2v" in args.task or "t2i" in args.task:
            video = wan.generate(
                prompt,
                size=SIZE_CONFIGS[args.size],
                frame_num=args.frame_num,
                shift=args.sample_shift,
                sample_solver=args.sample_solver,
                sampling_steps=args.sample_steps,
                guide_scale=args.sample_guide_scale,
                seed=args.base_seed,
                offload_model=args.offload_model
            )
        else:  # i2v任务
            img = Image.open(image_paths[0]).convert("RGB")
            video = wan.generate(
                prompt,
                img,
                max_area=MAX_AREA_CONFIGS[args.size],
                frame_num=args.frame_num,
                shift=args.sample_shift,
                sample_solver=args.sample_solver,
                sampling_steps=args.sample_steps,
                guide_scale=args.sample_guide_scale,
                seed=args.base_seed,
                offload_model=args.offload_model
            )

        # 计算生成时间
        generation_time = time.time() - start_time
        total_generation_time += generation_time
        logging.info(f"生成耗时: {generation_time:.2f}秒")

        # 主进程保存结果
        if rank == 0:
            generation_counter += 1  # 计数器递增
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            safe_prompt = prompt.replace(" ", "_").replace("/", "_")[:50]  # 安全文件名处理
            file_uuid = str(uuid4())[:8]  # 新增:添加UUID短标识
            suffix = '.png' if "t2i" in args.task else '.mp4'
            
            # 生成唯一文件名
            save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_" \
                        f"{safe_prompt}_{timestamp}_{generation_counter}_{file_uuid}{suffix}"
            
            logging.info(f"保存结果到: {save_file}")
            if "t2i" in args.task:
                cache_image(
                    tensor=video.squeeze(1)[None],
                    save_file=save_file,
                    nrow=1,
                    normalize=True,
                    value_range=(-1, 1)
                )
            else:
                cache_video(
                    tensor=video[None],
                    save_file=save_file,
                    fps=cfg.sample_fps,
                    nrow=1,
                    normalize=True,
                    value_range=(-1, 1)
                )

    logging.info(f"总生成耗时: {total_generation_time:.2f}秒")
    logging.info("生成完成")

if __name__ == "__main__":
    args = _parse_args()
    generate(args)

网站公告

今日签到

点亮在社区的每一天
去签到