Diffusion Policy 代码详解

发布于:2025-07-02 ⋅ 阅读:(21) ⋅ 点赞:(0)

Diffusion Policy 代码详解

写在前面

Diffusion Policy这个开源项目从24年的十月份断断续续的关注了有大半年,中间跑过源代码也看了几遍代码,但是对代码和代码结构目前还是一知半解。这个暑假发誓要做出来点东西,因此,重拾DP这个项目,来认真仔细的学习下整个Diffusion Policy为之后的模仿学习打下基础。

预备知识

基础的机器学习知识

【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现

Diffusion Model的相关知识

之前写的VAE相关知识:VAE(Variational auto encoder)原理推导及代码部分实现

DDPM 李宏毅老师的课程:【李宏毅】2024年公认最好的扩散模型【Diffusion Model】教程!

模仿学习的相关知识

模仿学习简洁教程

RSS 2024 Workshop 教程: 真实机器人的监督策略学习

Diffusion Policy

原作者talk:【LeRobot】中文字幕|Diffusion Policy: LeRobot Research Presentation 2 by Cheng Chi

知乎文章:Diffusion Policy: 将扩散模型加噪-去噪的看家本领用于生成机器人动作啦!

代码结构

从train.py入手:

(robodiff)[diffusion_policy]$ python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_cnn.yaml training.seed=42 training.device=cuda:0 hydra.run.dir='data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}'

关于训练的配置都放在了image_pusht_diffusion_policy_cnn.yaml中。

Config 文件

image_pusht_diffusion_policy_cnn.yaml config来初窥diffusion policy的训练参数配置

在这个文件中,总共配置了几类参数:

_target_:指定了train的WorkSpace参考diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py,在WorkSpace中定义了train的具体细节,包括加载数据集,设置模型参数,训练细节,评估过程,保存checkppoint

checkpoint:检查点设置

dataloader:数据加载设置

ema:EMA模型设置, 伏笔1:为什么要用EMA?

logging:log 日志设置

multi_run:多运行设置

optimizer:优化器设置,伏笔2:优化器在训练中的作用是什么?

policy

shape_meta

  • 注意到shape_meta在config中出现了三次,分别在全局,policy,task中出现
  • 它们并非巧合一致,而是为了在不同的组件和模块中提供统一的动作和观测空间信息,达到在整个训练流程中动作和观测空间的定义统一。
  • 在策略类的 __init__ 方法中,会解析这个 shape_meta 来获取动作和观测的维度信息,进而配置观测编码器、扩散模型等组件。
  • 全局的 shape_meta 是为整个训练工作流提供统一的动作和观测空间信息,方便其他组件在需要时引用。例如,某些组件可能需要根据这个信息来初始化数据加载器、环境等。
  • task中shape_meta 主要用于配置任务相关的组件,如数据集和环境运行器。在数据集类的初始化过程中,会根据这个 shape_meta 来解析观测和动作的信息,从而正确加载和处理数据。

task

  • 对使用的dataset及参数进行配置
  • 选择相应的env_runner类并对其进行配置
  • 伏笔4:后续阅读了解env_runner代码

training:训练参数的一些配置,包括checkpoint, learning_rate, epoch 等设置相关

val_dataloader:验证集设置

总结

整个的训练全靠这个config的配置文件来指定相关类和参数,达到了是作者所说更灵活的编程思想,作者原话

implementing N tasks and M methods will only require O(N+M) amount of code instead of O(N*M)

These design decisions come at the cost of code repetition between the tasks and the methods. However, we believe that the benefit of being able to add/modify task/methods without affecting the remainder and being able understand a task/method by reading the code linearly outweighs the cost of copying and pasting 😊.

这个 config 的配置是了解整个训练框架的入口,指定了很重要的workspace, policy, env_runner, dataset并且对关键的参数进行配置,从这里入手可以一点点的拨开diffusion policy的神秘面纱。

Workspace 代码粗读(6.30)

代码参考:diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py

Workspace是主要的训练程序,拥有一个最重要的 run 函数,承担了主要的训练功能,掌管了整个训练过程。

__init__ 部分:

  • 继承了一部分 BaseWorkspce 类的 __init__ method
  • self.model 是从Policy类中进行加载的
  • EMA模型的加载
  • 优化器的设置

run 函数包括了:

  • resume training
  • configure dataset (test and validation), lr scheduler, ema, env, logging, checkpoint
  • 训练的主循环:
    • 主循环中遍历每个epoch,使用tqdm对训练数据train_dataloader进度进行监视
    • 训练过程就是经典的深度学习的过程:计算损失->反向传播->梯度更新
  • 此外,训练还包括了ema的更新和log相关
  • 在每个epoch训练就进行eval和rollout更新log,这里的模型是从policy中加载出来的
  • 在每个cfg.training.val_every就跑一次验证集对模型进行评估
  • checkpoint保存和log日志相关

Policy 代码粗读 (6.30)

diffusion_policy/policy/diffusion_unet_hybrid_image_policy.py为例,对Policy的代码进行粗读。

DiffusionUnetHybridImagePolicy 是一个基于扩散模型的混合图像策略,它结合了图像编码器和条件 U-Net 1D 模型,用于处理机器人控制中的动作预测,反向扩散过程计算预测损失的,这里面包含了他们的具体实现。

上面Workspace中的self.model就是实例化这个policy得到的,因此,阅读这个代码是很有必要的。

__init__ 部分:

  • DiffusionUnetHybridImagePolicy 类继承自 BaseImagePolicy
  • __init__ 方法初始化了策略类的各种参数,包括观察和动作的形状信息、Unet相关设置,噪声调度器、预测时间步长、是否将观察作为全局条件等。
  • 创建**diffusion model **(Base Unet1D)

conditional_sample 函数:

conditional_sample 的主要作用是基于给定的条件数据(condition_data)和条件掩码(condition_mask),通过扩散模型的解码器生成上一时刻的trajectory

  • 加载扩散模型,生成与 trajectory 相同shape的噪声
  • 遍历 timesteps 反向扩散,使用model来获得模型输出,利用噪声调度器 self.noise_scheduler 来获得获得上一时刻的 trajectory

predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 函数:

基于obs来预测action。

  • 对action维度进行设置,调整为所需要的shape
  • 使用 conditional_sample 来获取结果中的动作预测部分,并对action进行 unnormalize 反归一化
  • 再对action的shape进行限制

set_normalizer(self, normalizer: LinearNormalizer): 函数:

将传入的 normalizer 对象的状态(即其内部的参数)复制到当前策略类实例的 self.normalizer 中

compute_loss(self, batch): 函数

对数据进行归一化后,对action添加噪声,并使用模型预测噪声残差,计算损失。这个函数包含了噪声添加,预测,计算损失的整个过程。

  • 对输入的数据和action进行归一化处理,这里的action也用trajectory来表示
  • 生成noise和timestep
  • 添加到trajecory中
  • 模型预测,得到预测的噪声残差
  • 计算预测值和真实值之间的loss

EMA Model(7.1)

EMA在训练中可以替代原始参数获得更平滑、更准确的预测结果,在Workspace中的代码体现:

# update ema
if cfg.training.use_ema:
	ema.step(self.model) # 更新EMA

env_runner (7.1)

仍然以pushT的实验为例解析 diffusion_policy/env_runner/pusht_image_runner.py 代码。

env_runner 主要是对policy进行训练可视化,包括创建悬链和测试的视频记录环境,运行rollout并记录视频。

__init__ 函数:

  • 创建记录pushT的训练和测试的运行环境
  • 各种参数(环境数量,前缀)的初始化

run 函数:

  • 计算需要运行的轮数
  • 运行rollout
  • 记录视频和log

Callback: 在Workspace中,env_runner在训练完每个epoch后进行调用:

# run rollout
if (self.epoch % cfg.training.rollout_every) == 0:
	runner_log = env_runner.run(policy)
	# log all
	step_log.update(runner_log)

dataset (7.1)

diffusion_policy/dataset/pusht_image_dataset.py 为例,对dataset类进行阅读。

这个dataset模块主要是对数据进行处理,包括:

  • 加载数据集
  • 划分数据集(训练集和验证集)
  • 对数据集进行归一化
  • 将数据集转换为标准的数据格式

更加详细的注释请阅读上面给出的代码链接

__init__ 函数:

从基类 BaseImageDataset 继承,加载指定zarr路径下的数据集,将数据集划分为训练集和验证集,并对数据集建立序列采样器,用于从 replay_buffer 回放缓冲区中采样序列数据。

get_validation_dataset(self) 函数

创建验证集数据集,使用与训练集相同的采样器,但使用验证集掩码,返回一个新的 PushTImageDataset 实例,包含验证集数据。

get_normalizer(self, mode='limits', **kwargs) 函数

为state-action pair数据和图像数据创建一个归一化器。

_sample_to_data(self, sample) 函数:

将采样得到的数据转换为标准的数据格式。

结构再梳理 (7.1)

现在把主要的代码都粗读了一遍,脑海里有了大概的印象了。现在再梳理一遍diffusion policy的代码逻辑,把分隔的这些代码都融会贯通起来,梳理出一个具体的逻辑链,每个部分都各司其什么职。

回到Workspace对代码进行梳理:

  • 从BaseWorkspace继承基类
  • Config文件参数设置:
    • 随机种子
    • [policy类](#Policy 代码粗读 (6.30))的配置,后续要用到,作为model参与训练(计算损失,扩散与反扩散过程)
    • ema模型设置,平滑上面的model参数
  • run 函数:这里不详细说一些config设置的事了,在[上面](#Workspace 代码粗读(6.30))已经分析过了,这里详细说一下训练的逻辑,直接进入 training loop
    • 加载policy实例化的model计算loss
    • 反向传播更新梯度和学习率
    • 使用ema对model进行平滑化处理
    • eval:
      • 运行rollout评估model
      • 在验证集上验证

总的来说,Workspace是整个diffusion policy的主程序,确实是diffusion policy work的space。这里面包含了policy(计算损失与扩散过程),dataset(提供数据集),env_runner(记录rollout), EMAModel(平滑model参数)