前言
ChatGPT出来后的两年多,也是我疯狂写博的两年多(年初deepseek更引爆了下),比如从创业起步时的15年到后来22年之间 每年2-6篇的,干到了23年30篇、24年65篇、25年前两月18篇,成了我在大模型和具身的原始技术积累
如今一转眼已到25年3月初,时光走得太快,近期和团队接了好几个大客户订单,使得3月起 不得不全力加速落地,自己也得每天抠paper、搞代码
虽然今年可能没法像去年24年那样干65篇,不过,我还是争取保持月月更新
- 一方面,有些文章是之前既定计划中的,比如如此文《π0开源了且推出自回归版π0-FAST——打造机器人动作专用的高效Tokenizer:比扩散π0的训练速度快5倍但效果相当》最后所说的,对π0源码的解读
「至于什么是π0,详见此文《π0——用于通用机器人控制的VLA模型:一套框架控制7种机械臂(基于PaliGemma和流匹配的3B模型)》」 - 二方面,我司「七月在线」在做一系列工厂落地场景的过程中,我们也希望团结到可以和我们一块做的朋友,而若想团结,便需要对外分享我们每个季度在重点做的业务场景
比如过去一周,我把lerobot、reflect vlm、π0的仿真环境都在我自己本地电脑上跑了下(过程中,GitHub copilot这种AI编程工具在环境的安装上帮了我很大的忙——各种环境 只要几句命令,直接帮我装好,真心不错)
如此硬着头皮冥思苦想、摸索了好几天,随后使得我自己知道怎么带队完成『太多工厂希望实现的一个生产线任务』了,3月初先仿真训练,2-3个月内部署到真机
当然了,也不单纯只是「这几天的想」就能想出来的,这几天之前
- 有把过去一年当三年用的具身技术积累
- 有一年多来,和同事们 如姚博士,以及朋友们许多的讨论
- 有去年十几个工厂对我们的支持与信任
我们正在不断壮大队伍
- 有我司内部同事,亦有我带的北理、中南等985的具身研究生,及一块合作开发的朋友,很快会把多个生产线任务并行开发起来
- 且无论哪个项目,都是不断长期迭代的,故过程中少不了科研层面的突破,欢迎更多伙伴加入我们(全、兼、实习皆可,有意者,敬请私我),和我们一块开发
话休絮烦,本文便按照如下图所示的源码结构,重点解读一下π的整个源码「我身边的很多朋友目前都在做π0的微调及二次开发,相信本文无论对我身边的朋友,还是对更多人的学习与工作,都会起到比较大的提升」
第一部分 examples、packages、scripts等结构的分析
1.1 examples :各种机器人平台的示例实现
根据π0对应examples模块的结构
其涉及以下模块
- aloha_real/:真实机器人ALOHA的示例
- aloha_sim/:ALOHA模拟器的示例
- droid/:DROID机器人的示例
- libero/:LIBERO基准测试的示例
- simple_client/:简单客户端的示例
- ur5/:UR5机器人的示例
- inference.ipynb:推理示例的Jupyter Notebook
- policy_records.ipynb:策略记录示例的Jupyter Notebook
1.2 packages
该模块的目录结构如下
1.3 scripts:包含数据处理、模型训练/推理的多个脚本
根据下图
可知,scripts 目录包含多个 Python 脚本,这些脚本用于数据处理、模型训练和服务部署等任务,每个脚本通常对应一个特定的功能或任务
- __init__.py
- compute_norm_stats.py: 计算数据的归一化统计信息
- serve_policy.py: 启动策略服务,提供模型推理接口
- train_test.py: 训练和测试模型
- train.py: 训练模型
1.3.1 __init__.py
1.3.2 compute_norm_stats.py:计算数据的归一化统计信息
1.3.3 serve_policy.py:启动策略服务,用于模型推理
- 在这个代码片段中,首先导入了一些必要的模块和库,包括 `policy`、`policy_config`、`websocket_policy_server` 和 `config`,这些模块来自 `openpi` 项目
接下来定义了一个枚举类 `EnvMode`,它表示支持的环境类型,包括 `ALOHA`、`ALOHA_SIM`、`DROID` 和 `LIBERO`from openpi.policies import policy as _policy # 导入 openpi.policies.policy 模块并重命名为 _policy from openpi.policies import policy_config as _policy_config # 导入 openpi.policies.policy_config 模块并重命名为 _policy_config from openpi.serving import websocket_policy_server # 导入 openpi.serving.websocket_policy_server 模块 from openpi.training import config as _config # 导入 openpi.training.config 模块并重命名为 _config
class EnvMode(enum.Enum): """支持的环境。""" ALOHA = "aloha" # ALOHA 环境 ALOHA_SIM = "aloha_sim" # ALOHA 模拟环境 DROID = "droid" # DROID 环境 LIBERO = "libero" # LIBERO 环境
- 然后定义了几个数据类
`Checkpoint` 类用于从训练好的检查点加载策略,包含两个字段:`config`(训练配置名称)和 `dir`(检查点目录)
`Default` 类表示使用默认策略
`Args` 类定义了脚本的参数,包括环境类型、默认提示、端口、是否记录策略行为以及如何加载策略 - 接下来定义了一个字典 `DEFAULT_CHECKPOINT`,它为每个环境类型指定了默认的检查点配置
`create_default_policy` 函数根据环境类型创建默认策略,如果环境类型不支持,则抛出异常# 每个环境应使用的默认检查点 DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = { EnvMode.ALOHA: Checkpoint( config="pi0_aloha", dir="s3://openpi-assets/checkpoints/pi0_base", ), EnvMode.ALOHA_SIM: Checkpoint( config="pi0_aloha_sim", dir="s3://openpi-assets/checkpoints/pi0_aloha_sim", ), EnvMode.DROID: Checkpoint( config="pi0_fast_droid", dir="s3://openpi-assets/checkpoints/pi0_fast_droid", ), EnvMode.LIBERO: Checkpoint( config="pi0_fast_libero", dir="s3://openpi-assets/checkpoints/pi0_fast_libero", ), }
`create_policy` 函数根据传入的参数创建策略,如果参数中指定了检查点,则从检查点加载策略,否则使用默认策略def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy: """为给定环境创建默认策略 """ if checkpoint := DEFAULT_CHECKPOINT.get(env): # 获取环境对应的默认检查点 return _policy_config.create_trained_policy( _config.get_config(checkpoint.config), checkpoint.dir, default_prompt=default_prompt ) # 创建训练好的策略 raise ValueError(f"Unsupported environment mode: {env}") # 如果环境不支持,抛出异常
def create_policy(args: Args) -> _policy.Policy: """根据给定的参数创建策略 """ match args.policy: # 匹配策略类型 case Checkpoint(): # 如果是 Checkpoint 类型 return _policy_config.create_trained_policy( _config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt ) # 创建训练好的策略 case Default(): # 如果是 Default 类型 return create_default_policy(args.env, default_prompt=args.default_prompt) # 创建默认策略
- `main` 函数是脚本的入口点,它首先调用 `create_policy` 函数创建策略,然后记录策略的元数据
如果参数中指定了记录策略行为,则使用 `PolicyRecorder` 包装策略def main(args: Args) -> None: policy = create_policy(args) # 创建策略 policy_metadata = policy.metadata # 获取策略的元数据
接着获取主机名和本地 IP 地址# 记录策略的行为 if args.record: # 使用 PolicyRecorder 记录策略行为 policy = _policy.PolicyRecorder(policy, "policy_records")
并创建一个 WebSocket 服务器来提供策略服务,最后调用 `serve_forever` 方法启动服务器hostname = socket.gethostname() # 获取主机名 local_ip = socket.gethostbyname(hostname) # 获取本地 IP 地址 logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip) # 记录服务器创建信息
server = websocket_policy_server.WebsocketPolicyServer( policy=policy, host="0.0.0.0", port=args.port, metadata=policy_metadata, ) # 创建 WebSocket 策略服务器 server.serve_forever() # 启动服务器,永远运行
- 在脚本的最后,使用 `logging` 模块配置日志记录,并调用 `main` 函数启动脚本,参数通过 `tyro.cli` 解析
1.3.4 train_test.py:训练和测试模型
1.3.5 train.py:训练模型
1.3.6 scripts/docker
好的,下面是对 `openpi-main/scripts/docker` 目录的详细分析。这个目录通包含与 Docker 相关的脚本和配置文件,用于构建和管理 Docker 容器,具体而言,包含以下文件和子目录:
主要文件和功能如下所示
- docker/compose.yml
- docker/install_docker_ubuntu22.sh
- docker/install_nvidia_container_toolkit.sh
- docker/serve_policy.Dockerfile
// 待更
第二部分 核心模块src下models的全面分析与解读
接下来,我们来看核心src下的各个模块
首先是其中的src/openpi/models
2.1 models/pi0.py的实现
它结合了多模态输入(图像和文本)来生成机器人动作序列。下面是对代码的详细解析:
2.1.1 注意力掩码生成函数
这个函数生成transformer中使用的注意力掩码,控制 token 之间的注意力流动方式
def make_attn_mask(input_mask, mask_ar):
"""
从big_vision项目改编的注意力掩码生成函数
Token可以关注那些累积mask_ar小于等于自己的有效输入token。
这样`mask_ar` bool[?B, N]可用于设置几种类型的注意力,例如:
[[1 1 1 1 1 1]]: 纯因果注意力。
[[0 0 0 1 1 1]]: 前缀语言模型注意力。前3个token之间可以互相关注,
后3个token有因果注意力。第一个条目也可以是1,不改变行为。
[[1 0 1 0 1 0 0 1 0 0]]: 4个块之间的因果注意力。一个块的token可以
关注所有之前的块和同一块内的所有token。
参数:
input_mask: bool[B, N] 如果是输入的一部分则为true,如果是填充则为false
mask_ar: bool[?B, N] 如果前面的token不能依赖于它则为true,
如果它共享与前一个token相同的注意力掩码则为false
"""
# 将mask_ar广播到与input_mask相同的形状
mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
# 计算mask_ar在序列维度上的累积和
cumsum = jnp.cumsum(mask_ar, axis=1)
# 创建注意力掩码:当目标位置的累积值<=查询位置的累积值时,允许注意力流动
attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
# 创建有效掩码:只有有效的输入位置之间才能有注意力
valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
# 结合注意力掩码和有效掩码
return jnp.logical_and(attn_mask, valid_mask)
它支持多种注意力模式:
- 纯因果注意力(每个 token 只能关注自己和之前的 token)
- 前缀语言模型注意力(允许前缀内部自由注意,后缀部分使用因果注意力)
- 块状因果注意力(在块内自由注意,块之间是因果的)
2.1.2 位置编码函数
使用正弦余弦函数实现位置编码
def posemb_sincos(
pos: at.Real[at.Array, Any], embedding_dim: int, min_period: float, max_period: float
) -> at.Float[at.Array, f"b {embedding_dim}"]:
"""计算标量位置的正弦余弦位置嵌入向量"""
if embedding_dim % 2 != 0: # 检查嵌入维度是否为偶数
raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")
fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2) # 创建均匀分布的分数值
period = min_period * (max_period / min_period) ** fraction # 计算周期值,对数空间中均匀分布
sinusoid_input = jnp.einsum(
"i,j->ij",
pos,
1.0 / period * 2 * jnp.pi, # 计算角频率
precision=jax.lax.Precision.HIGHEST, # 使用最高精度进行计算
)
# 连接sin和cos值,形成完整的位置编码
return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1)
2.1.3 Pi0Config 配置类:含inputs_spec、get_freeze_filter
2.1.3.1 模型配置参数的定义
首先,这个类定义了模型的配置参数,比如PaLI-Gemma 变体:`gemma_2b
class Pi0Config(_model.BaseModelConfig):
dtype: str = "bfloat16" # 设置数据类型为bfloat16
paligemma_variant: _gemma.Variant = "gemma_2b" # 设置PaLI-Gemma变体为2B参数版本
action_expert_variant: _gemma.Variant = "gemma_300m" # 设置动作专家变体为300M参数版本
# 设置模型特定的默认值
action_dim: int = 32 # 设置动作维度为32
action_horizon: int = 50 # 设置动作序列长度为50步
max_token_len: int = 48 # 设置最大token长度为48
2.1.3.2 inputs_spec:定义了π0模型本身接收的输入数据格式
其次,通过inputs_spec函数定义了π0模型本身接收的输入数据格式,函数采用关键字参数 `batch_size`(默认为1),返回一个包含观察规格和动作规格的元组
def inputs_spec(self, *, batch_size: int = 1) -> Tuple[Type[_model.Observation], Type[_model.Actions]]
- 其支持多种输入,比如
视觉输入(三个不同视角的RGB图像)、语言输入(分词后的文本prompt)、状态输入(当前机器人状态) - 输出上
则是一个时序动作序列(包含50个连续的动作向量,每个动作向量有32个维度,可能对应关节角度或其他控制信号)
具体而言该函数先
创建图像规格
image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
其中的
- `[batch_size, *_model.IMAGE_RESOLUTION, 3]` 定义了图像张量的形状:比如
批次大小
图像分辨率(从 `_model.IMAGE_RESOLUTION` 获取,可能是如 [224, 224] 这样的值)
3 个颜色通道 (RGB)
- `jnp.float32` 指定了数据类型为 32 位浮点数
创建图像掩码规格
image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
其定义了图像掩码规格,每个批次中的每个图像都有一个布尔值,这个掩码用于指示哪些图像是有效的(`True`)或无效的(`False`)
创建观察规格:包含视觉输入、机器人状态、指令输入
`at.disable_typechecking()` 临时禁用类型检查,可能是因为这里创建的是类型规格而不是实际的数据,且观察规格包含多个组件:
- 多视角图像
base_0_rgb: 机器人底座/身体视角的RGB图像
left_wrist_0_rgb: 左手腕视角的RGB图像
right_wrist_0_rgb: 右手腕视角的RGB图像with at.disable_typechecking(): observation_spec = _model.Observation( images={ "base_0_rgb": image_spec, "left_wrist_0_rgb": image_spec, "right_wrist_0_rgb": image_spec, },
- 图像掩码
对应每个视角图像的有效性掩码 - 机器人状态:
形状为 `[batch_size, self.action_dim]` 的浮点数张量
`self.action_dim` 默认为32,表示状态向量的维度state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
- 分词后的文本prompt
形状为 `[batch_size, self.max_token_len]` 的整数张量
`self.max_token_len` 默认为48,表示最大token数量
数据类型为 `jnp.int32`,表示token ID - 提示掩码
与分词提示相同形状的布尔张量,用于指示哪些位置有有效的tokenstate=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32), tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32), tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool), )
创建动作规格
action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)
其定义了动作数据的形状和类型:
- `batch_size`: 批次大小
- `self.action_horizon`: 动作序列长度,默认为50
- `self.action_dim`: 每个动作的维度,默认为32
- `jnp.float32` 指定了数据类型为32位浮点数
然后返回
return observation_spec, action_spec
2.1.3.3 get_freeze_filter:针对是否LoRA的处理
此外,该配置类还实现了get_freeze_filter这个函数,作用是如果选择LoRA微调(冻结原始预训练模型的参数,只更新新添加的低秩适应层参数),则需要对模型中的某些参数做冻结
三种可能的情况:
- 只对 PaLI-Gemma 使用 LoRA:冻结 Gemma 参数(但排除动作专家参数)
- 只对动作专家使用 LoRA:冻结动作专家参数
- 对两者都使用 LoRA:冻结两者的基础参数
如此,可以选择性地微调模型的特定部分(语言部分或动作预测部分)
具体而言
- 首先,定义函数
def get_freeze_filter(self) -> nnx.filterlib.Filter: """返回基于模型配置的冻结过滤器"""
- 其次,初始化变量
filters = [] # 初始化过滤器列表 has_lora = False # 初始化LoRA标志
- 接着,创建参数过滤器
# 匹配所有LLM参数的正则表达式,用于选择 Gemma 语言模型的参数 gemma_params_filter = nnx_utils.PathRegex(".*llm.*") # 匹配动作专家参数的正则表达式 action_expert_params_filter = nnx_utils.PathRegex(".*llm.*_1.*")
- 接下来是对PaLI-Gemma变体的处理
# 如果PaLI-Gemma使用LoRA if "lora" in self.paligemma_variant: filters.append( gemma_params_filter, # 添加Gemma参数过滤器 ) if "lora" not in self.action_expert_variant: # 如果只冻结Gemma参数,排除动作专家参数 filters.append( nnx.Not(action_expert_params_filter), ) has_lora = True
- 再下来是对动作专家变体的处理
elif "lora" in self.action_expert_variant: # 如果动作专家使用LoRA filters.append( action_expert_params_filter, ) has_lora = True
2.1.4 Pi0 模型类
核心模型类,继承自 `_model.BaseModel`,实现了:
- 多模态输入处理
处理多视角图像(基础视角、左手腕视角、右手腕视角)
处理文本提示(如指令)
处理机器人当前状态 - 扩散过程
训练时:将干净动作添加噪声,让模型学习去噪
推理时:从纯噪声开始,逐步降噪生成动作序列 - 注意力机制
使用精心设计的注意力掩码控制信息流动
前缀(图像和文本)内部使用全注意力
后缀(状态和动作)使用特殊的注意力模式
2.1.4.1 初始化方法 `__init__`
class Pi0(_model.BaseModel):
def __init__(self, config: Pi0Config, rngs: nnx.Rngs):
# 初始化基类
super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
# 获取PaLI-Gemma和动作专家配置
paligemma_config = _gemma.get_config(config.paligemma_variant)
action_expert_config = _gemma.get_config(config.action_expert_variant)
其组合了多个核心组件:
一个是PaLI-Gemma 模型:结合了 Gemma 语言模型和 SigLIP 视觉模型
- 先是对语言模型的初始化
# 创建并初始化语言模型 # TODO: 用NNX重写Gemma,目前使用桥接 llm = nnx_bridge.ToNNX( _gemma.Module( configs=[paligemma_config, action_expert_config], # 配置两个Gemma模型 embed_dtype=config.dtype, # 设置嵌入数据类型 ) ) llm.lazy_init(rngs=rngs, method="init") # 延迟初始化LLM
- 然后是对视觉模型的初始化
# 创建并初始化图像模型 img = nnx_bridge.ToNNX( _siglip.Module( num_classes=paligemma_config.width, # 设置图像特征维度与语言模型宽度相匹配 variant="So400m/14", # 使用400M参数SigLIP模型 pool_type="none", # 不使用池化,保留所有图像标记 scan=True, # 启用扫描优化 dtype_mm=config.dtype, # 设置矩阵乘法数据类型 ) ) # 使用假观察中的图像初始化图像模型 img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
- 最后,把语言模型和视觉模型组合成PaLI-Gemma多模态模型
# 组合LLM和图像模型为PaLI-Gemma多模态模型 self.PaliGemma = nnx.Dict(llm=llm, img=img)
另一个是线性投影层:用于
- 状态投影
# 状态投影层:将机器人状态投影到模型维度 self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
- 动作投影
# 动作输入投影层:将动作投影到模型维度 self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
- 时间-动作混合等
# 动作-时间MLP输入层:将连接的动作和时间特征投影到模型维度 self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs) # 动作-时间MLP输出层 self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs) # 动作输出投影层:将模型输出投影回动作维度 self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs)
2.1.4.2 特征嵌入方法:embed_prefix、embed_suffix
- `embed_prefix`:处理图像和文本输入(图像通过SigLip模型编码,文本通过Gemma LLM编码),创建前缀 token,皆为双向注意力,用ar_mask = false表示
- `embed_suffix`:处理机器人状态信息
、噪声化的动作信息
(状态和噪声动作经过线性投影和MLP处理),创建后缀 token
其中
状态为单个token,和第一个动作token均设置为单向注意力,用ar_mask = true表示
其余动作tokens之间设置为双向注意力,用ar_mask = false表示
对于前者embed_prefix
def embed_prefix(
self, obs: _model.Observation
) -> Tuple[at.Float[at.Array, Any], at.Bool[at.Array, Any], at.Bool[at.Array, Any]]:
"""嵌入前缀部分(图像和文本)"""
input_mask = [] # 初始化输入掩码列表
ar_mask = [] # 初始化自回归掩码列表
tokens = [] # 初始化token列表
其工作流程为
- 图像处理:说白了,就是把图像token化
使用SigLip视觉模型处理每个图像,生成图像tokens序列# 嵌入图像 for name in obs.images: # 通过图像模型获取图像token image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False) tokens.append(image_tokens) # 添加图像token
- 图像掩码扩展
将图像掩码扩展到与图像tokens相同的序列长度,使用einops.repeat进行形状变换,这些掩码会指示哪些图像是有效的,而哪些是填充的# 重复图像掩码以匹配token维度 input_mask.append( einops.repeat( obs.image_masks[name], "b -> b s", # 调整形状:批次维度保持不变,添加序列维度 s=image_tokens.shape[1], # 序列长度等于图像token数 ) )
- 自回归掩码设置
设置图像tokens之间的注意力为双向(False表示双向注意力),原因在于图像内容通常是非时序性的数据# 图像token之间互相关注(非自回归) ar_mask += [False] * image_tokens.shape[1]
- 文本处理
使用LLM模型对文本输入tokenized_inputs进行嵌入
且同样设置为双向注意力,相当于语言token可以关注图像token,图像token反过来亦可关注语言token,最终实现多模态融合# 添加语言(即分词后的输入) if obs.tokenized_prompt is not None: # 通过语言模型嵌入分词后的提示 tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed") tokens.append(tokenized_inputs) # 添加文本token input_mask.append(obs.tokenized_prompt_mask) # 添加提示掩码
# 图像和语言输入之间完全关注(非自回归) ar_mask += [False] * tokenized_inputs.shape[1]
- 最后,连接所有token和掩码,其中包含了
多模态信息的融合表示tokens——图像token和语言token
以及指示哪些token是有效信息的input_mask
和如何在这些token之间进行注意力计算规则的ar_mask
# 连接所有token和掩码 tokens = jnp.concatenate(tokens, axis=1) # 在序列维度上连接token input_mask = jnp.concatenate(input_mask, axis=1) # 在序列维度上连接输入掩码 ar_mask = jnp.array(ar_mask) # 转换自回归掩码为数组 return tokens, input_mask, ar_mask # 返回token、输入掩码和自回归掩码
顺便,再回顾下此图
对于后者embed_suffix
def embed_suffix(
self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, Any]
) -> Tuple[at.Float[at.Array, Any], at.Bool[at.Array, Any], at.Bool[at.Array, Any]]:
"""嵌入后缀部分(状态和动作)"""
input_mask = [] # 初始化输入掩码列表
ar_mask = [] # 初始化自回归掩码列表
tokens = [] # 初始化token列表
其工作流程为
- 状态处理
将状态信息投影到embedding空间
并设置为单向注意力(True),表明图像和语言输入不能关注状态信息,因为image/language do not attend to state or actions# 添加单个状态token state_token = self.state_proj(obs.state)[:, None, :] # 投影状态并添加序列维度 tokens.append(state_token) # 添加状态token # 添加状态掩码(全为1),表示这个状态token是有效的 input_mask.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_))
# 图像/语言输入不关注状态或动作(自回归) ar_mask += [True]
- 时间步嵌入,使用正弦-余弦位置编码生成时间步嵌入
# 使用正弦余弦位置编码嵌入时间步,敏感度范围为[0, 1] time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)
- 动作和时间信息融合
# 混合时间步+动作信息,使用MLP action_tokens = self.action_in_proj(noisy_actions) # 投影带噪声的动作 # 重复时间嵌入以匹配动作序列长度 time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon) # 连接动作和时间token action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
- MLP处理
使用两层MLP和swish激活函数对「动作和时间的组合表示」进行非线性变换,以进一步融合:动作和时间信息# 通过MLP处理 action_time_tokens = self.action_time_mlp_in(action_time_tokens) # 输入层 action_time_tokens = nnx.swish(action_time_tokens) # Swish激活函数 action_time_tokens = self.action_time_mlp_out(action_time_tokens) # 输出层
- 注意力掩码设置
第一个动作token设置为单向注意力「上面说过了的,单向注意力,用ar_mask = true表示」,其余动作tokens之间设置为双向注意力# 添加动作时间token tokens.append(action_time_tokens) # 添加掩码(全为1),表示所有动作token都是有效的 input_mask.append(jnp.ones(action_time_tokens.shape[:2], dtype=jnp.bool_)) # 图像/语言/状态输入不关注动作token(动作第一个是自回归的——单向,其余不是——双向) ar_mask += [True] + ([False] * (self.action_horizon - 1))
- 最后连接所有token和掩码
# 连接所有token和掩码 tokens = jnp.concatenate(tokens, axis=1) # 在序列维度上连接token input_mask = jnp.concatenate(input_mask, axis=1) # 在序列维度上连接输入掩码 ar_mask = jnp.array(ar_mask) # 转换自回归掩码为数组 return tokens, input_mask, ar_mask # 返回token、输入掩码和自回归掩码
2.1.4.3 损失函数 `compute_loss`
实现了扩散模型的训练损失计算
- 对输入观察进行预处理,其中
preprocess_rng用于观察预处理(比如图像增强等)
noise_rng用于生成噪声
time_rng用于从beta分布采样时间步def compute_loss( self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False ) -> at.Float[at.Array, Any]: """计算扩散模型的损失函数""" # 分割随机数生成器为三部分,用于不同的随机操作 preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
- 生成随机噪声并采样时间点 t
# 获取动作的批次形状 batch_shape = actions.shape[:-2] # 生成与动作相同形状的高斯噪声 noise = jax.random.normal(noise_rng, actions.shape) # 从Beta分布采样时间点,范围为[0.001, 1],Beta(1.5, 1)偏向较低的值 time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001 # 扩展时间维度以匹配动作形状 time_expanded = time[..., None, None]
- 创建带噪动作序列 x_t,相当于x_t是噪声化的动作,随着时间从0到1,原始动作逐渐加噪,变为纯噪声
而u_t代表所加的真实噪声,而咱们就是要预测所添加的噪声(而所添加的噪声即等于加满噪声的动作 - 原始动作)
扩散策略diffusion policy的灵感来源于图像生成中的扩散模型DDPM,通过逐步去除噪声来生成目标数据(比如机器人的动作序列),如果对DDPM原理不太明白的,详见此文《图像生成发展起源:从VAE、扩散模型DDPM、DDIM到DETR、ViT、Swin transformer》# 创建带噪声的动作:t*noise + (1-t)*actions x_t = time_expanded * noise + (1 - time_expanded) * actions # 计算真实噪声减去动作的差异,这是模型需要预测的目标 u_t = noise - actions
- 嵌入前缀和后缀
# 一次性前向传递前缀+后缀 # 嵌入前缀(图像和文本) prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation) # 嵌入后缀(状态和带噪声的动作) suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(observation, x_t, time)
- 构建注意力掩码和位置编码
根据下图
可得# 连接掩码:通过链接前缀和后缀的掩码,从而创建完整的输入掩码 input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1) ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0) # 创建注意力掩码make_attn_mask,从而控制不同token之间的可见性 attn_mask = make_attn_mask(input_mask, ar_mask) # 计算位置编码 positions = jnp.cumsum(input_mask, axis=1) - 1
- 模型前向传播,即使用PaliGemma进行推理,处理前缀和后缀token
当然了,输出中我们只关注与后缀相关的部分,因为其中包含了我们想要的动作预测的部分# 通过PaLI-Gemma模型处理token _, suffix_out = self.PaliGemma.llm( [prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions )
- 预测噪声v_t
# 将模型输出投影回动作空间 v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
- 计算预测噪声与实际噪声间的均方误差
# 返回预测噪声和真实噪声之间的均方误差 return jnp.mean(jnp.square(v_t - u_t), axis=-1)
2.1.4.4 推理函数 `sample_actions`
使用扩散模型采样机器人动作序列:
- 首先从纯噪声开始 (t=1)
- 通过重复迭代降噪步骤,逐步将噪声转化为有意义的动作序列
- 使用KV缓存优化推理速度
- 实现了一个迭代降噪过程:
- 最终返回完全降噪后的动作序列 x_0
// 待更