Diffusion Policy 代码详解
文章目录
写在前面
Diffusion Policy这个开源项目从24年的十月份断断续续的关注了有大半年,中间跑过源代码也看了几遍代码,但是对代码和代码结构目前还是一知半解。这个暑假发誓要做出来点东西,因此,重拾DP这个项目,来认真仔细的学习下整个Diffusion Policy为之后的模仿学习打下基础。
预备知识
基础的机器学习知识
【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现
Diffusion Model的相关知识
之前写的VAE相关知识:VAE(Variational auto encoder)原理推导及代码部分实现
DDPM 李宏毅老师的课程:【李宏毅】2024年公认最好的扩散模型【Diffusion Model】教程!
模仿学习的相关知识
RSS 2024 Workshop 教程: 真实机器人的监督策略学习
Diffusion Policy
原作者talk:【LeRobot】中文字幕|Diffusion Policy: LeRobot Research Presentation 2 by Cheng Chi
知乎文章:Diffusion Policy: 将扩散模型加噪-去噪的看家本领用于生成机器人动作啦!
代码结构
从train.py入手:
(robodiff)[diffusion_policy]$ python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_cnn.yaml training.seed=42 training.device=cuda:0 hydra.run.dir='data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}'
关于训练的配置都放在了image_pusht_diffusion_policy_cnn.yaml
中。
Config 文件
以 image_pusht_diffusion_policy_cnn.yaml config来初窥diffusion policy的训练参数配置
在这个文件中,总共配置了几类参数:
_target_
:指定了train的WorkSpace
参考diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py,在WorkSpace
中定义了train的具体细节,包括加载数据集,设置模型参数,训练细节,评估过程,保存checkppoint
。
checkpoint
:检查点设置
dataloader
:数据加载设置
ema
:EMA模型设置, 伏笔1:为什么要用EMA?
logging
:log 日志设置
multi_run
:多运行设置
optimizer
:优化器设置,伏笔2:优化器在训练中的作用是什么?
policy
:
- 指定使用的策略类
- 对策略类的参数进行设置,请对比 image_pusht_diffusion_policy_cnn.yaml 中Policy部分和diffusion_policy/policy/diffusion_unet_hybrid_image_policy.py You will see.
- 伏笔3:策略类起到什么作用?
shape_meta
:
- 注意到shape_meta在config中出现了三次,分别在全局,policy,task中出现
- 它们并非巧合一致,而是为了在不同的组件和模块中提供统一的动作和观测空间信息,达到在整个训练流程中动作和观测空间的定义统一。
- 在策略类的
__init__
方法中,会解析这个shape_meta
来获取动作和观测的维度信息,进而配置观测编码器、扩散模型等组件。 - 全局的
shape_meta
是为整个训练工作流提供统一的动作和观测空间信息,方便其他组件在需要时引用。例如,某些组件可能需要根据这个信息来初始化数据加载器、环境等。 - task中
shape_meta
主要用于配置任务相关的组件,如数据集和环境运行器。在数据集类的初始化过程中,会根据这个shape_meta
来解析观测和动作的信息,从而正确加载和处理数据。
task
:
- 对使用的
dataset
及参数进行配置 - 选择相应的
env_runner
类并对其进行配置 - 伏笔4:后续阅读了解
env_runner
代码
training
:训练参数的一些配置,包括checkpoint
, learning_rate
, epoch
等设置相关
val_dataloader
:验证集设置
总结
整个的训练全靠这个config的配置文件来指定相关类和参数,达到了是作者所说更灵活的编程思想,作者原话:
implementing
N
tasks andM
methods will only requireO(N+M)
amount of code instead ofO(N*M)
These design decisions come at the cost of code repetition between the tasks and the methods. However, we believe that the benefit of being able to add/modify task/methods without affecting the remainder and being able understand a task/method by reading the code linearly outweighs the cost of copying and pasting 😊.
这个 config
的配置是了解整个训练框架的入口,指定了很重要的workspace
, policy
, env_runner
, dataset
并且对关键的参数进行配置,从这里入手可以一点点的拨开diffusion policy的神秘面纱。
Workspace 代码粗读(6.30)
代码参考:diffusion_policy/workspace/train_diffusion_unet_hybrid_workspace.py
Workspace是主要的训练程序,拥有一个最重要的 run
函数,承担了主要的训练功能,掌管了整个训练过程。
__init__
部分:
- 继承了一部分
BaseWorkspce
类的__init__
method self.model
是从Policy
类中进行加载的- EMA模型的加载
- 优化器的设置
run
函数包括了:
- resume training
- configure
dataset (test and validation)
,lr scheduler
,ema
,env
,logging
,checkpoint
- 训练的主循环:
- 主循环中遍历每个epoch,使用tqdm对训练数据
train_dataloader
进度进行监视 - 训练过程就是经典的深度学习的过程:计算损失->反向传播->梯度更新
- 主循环中遍历每个epoch,使用tqdm对训练数据
- 此外,训练还包括了ema的更新和log相关
- 在每个epoch训练就进行eval和rollout更新log,这里的模型是从policy中加载出来的
- 在每个
cfg.training.val_every
就跑一次验证集对模型进行评估 checkpoint
保存和log
日志相关
Policy 代码粗读 (6.30)
以diffusion_policy/policy/diffusion_unet_hybrid_image_policy.py为例,对Policy的代码进行粗读。
DiffusionUnetHybridImagePolicy
是一个基于扩散模型的混合图像策略类,它结合了图像编码器和条件 U-Net 1D 模型,用于处理机器人控制中的动作预测,反向扩散过程,计算预测损失的,这里面包含了他们的具体实现。
上面Workspace中的self.model
就是实例化这个policy得到的,因此,阅读这个代码是很有必要的。
__init__
部分:
DiffusionUnetHybridImagePolicy
类继承自BaseImagePolicy
。__init__
方法初始化了策略类的各种参数,包括观察和动作的形状信息、Unet相关设置,噪声调度器、预测时间步长、是否将观察作为全局条件等。- 创建**diffusion model **(Base Unet1D)
conditional_sample
函数:
conditional_sample
的主要作用是基于给定的条件数据(condition_data
)和条件掩码(condition_mask
),通过扩散模型的解码器生成上一时刻的trajectory
。
- 加载扩散模型,生成与
trajectory
相同shape的噪声 - 遍历
timesteps
反向扩散,使用model来获得模型输出,利用噪声调度器self.noise_scheduler
来获得获得上一时刻的trajectory
predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
函数:
基于obs来预测action。
- 对action维度进行设置,调整为所需要的shape
- 使用
conditional_sample
来获取结果中的动作预测部分,并对action进行unnormalize
反归一化 - 再对action的shape进行限制
set_normalizer(self, normalizer: LinearNormalizer):
函数:
将传入的 normalizer 对象的状态(即其内部的参数)复制到当前策略类实例的 self.normalizer 中
compute_loss(self, batch):
函数
对数据进行归一化后,对action添加噪声,并使用模型预测噪声残差,计算损失。这个函数包含了噪声添加,预测,计算损失的整个过程。
- 对输入的数据和action进行归一化处理,这里的action也用trajectory来表示
- 生成noise和timestep
- 添加到trajecory中
- 模型预测,得到预测的噪声残差
- 计算预测值和真实值之间的loss
EMA Model(7.1)
EMA在训练中可以替代原始参数获得更平滑、更准确的预测结果,在Workspace中的代码体现:
# update ema
if cfg.training.use_ema:
ema.step(self.model) # 更新EMA
env_runner (7.1)
仍然以pushT的实验为例解析 diffusion_policy/env_runner/pusht_image_runner.py 代码。
env_runner
主要是对policy进行训练可视化,包括创建悬链和测试的视频记录环境,运行rollout并记录视频。
__init__
函数:
- 创建记录pushT的训练和测试的运行环境
- 各种参数(环境数量,前缀)的初始化
run
函数:
- 计算需要运行的轮数
- 运行rollout
- 记录视频和log
Callback: 在Workspace中,env_runner在训练完每个epoch后进行调用:
# run rollout
if (self.epoch % cfg.training.rollout_every) == 0:
runner_log = env_runner.run(policy)
# log all
step_log.update(runner_log)
dataset (7.1)
以diffusion_policy/dataset/pusht_image_dataset.py 为例,对dataset类进行阅读。
这个dataset模块主要是对数据进行处理,包括:
- 加载数据集
- 划分数据集(训练集和验证集)
- 对数据集进行归一化
- 将数据集转换为标准的数据格式
更加详细的注释请阅读上面给出的代码链接
__init__
函数:
从基类 BaseImageDataset
继承,加载指定zarr路径下的数据集,将数据集划分为训练集和验证集,并对数据集建立序列采样器,用于从 replay_buffer
回放缓冲区中采样序列数据。
get_validation_dataset(self)
函数:
创建验证集数据集,使用与训练集相同的采样器,但使用验证集掩码,返回一个新的 PushTImageDataset 实例,包含验证集数据。
get_normalizer(self, mode='limits', **kwargs)
函数:
为state-action pair数据和图像数据创建一个归一化器。
_sample_to_data(self, sample)
函数:
将采样得到的数据转换为标准的数据格式。
结构再梳理 (7.1)
现在把主要的代码都粗读了一遍,脑海里有了大概的印象了。现在再梳理一遍diffusion policy的代码逻辑,把分隔的这些代码都融会贯通起来,梳理出一个具体的逻辑链,每个部分都各司其什么职。
回到Workspace对代码进行梳理:
- 从BaseWorkspace继承基类
- Config文件参数设置:
- 随机种子
- [policy类](#Policy 代码粗读 (6.30))的配置,后续要用到,作为model参与训练(计算损失,扩散与反扩散过程)
- ema模型设置,平滑上面的model参数
run
函数:这里不详细说一些config设置的事了,在[上面](#Workspace 代码粗读(6.30))已经分析过了,这里详细说一下训练的逻辑,直接进入training loop
- 加载policy实例化的model计算loss
- 反向传播更新梯度和学习率
- 使用ema对model进行平滑化处理
- eval:
- 运行rollout评估model
- 在验证集上验证
总的来说,Workspace是整个diffusion policy的主程序,确实是diffusion policy work的space。这里面包含了policy(计算损失与扩散过程),dataset(提供数据集),env_runner(记录rollout), EMAModel(平滑model参数)