PyTorch API 5 - 全分片数据并行、流水线并行、概率分布

发布于:2025-05-10 ⋅ 阅读:(20) ⋅ 点赞:(0)


全分片数据并行 (FullyShardedDataParallel)

class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)

一个用于在数据并行工作节点间分片模块参数的包装器。

该设计灵感来源于Xu等人的论文以及DeepSpeed的ZeRO第三阶段技术。

FullyShardedDataParallel通常简称为FSDP。

要了解FSDP内部实现原理,请参阅FSDP技术说明


示例:

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()

使用FSDP需要先包装你的模块,然后在之后初始化优化器。这是必要的,因为FSDP会改变参数变量。

在设置FSDP时,你需要考虑目标CUDA设备。如果设备有ID(dev_id),你有三个选项:

  • 将模块放在该设备上
  • 使用torch.cuda.set_device(dev_id)设置设备
  • dev_id传入device_id构造函数参数

这确保了FSDP实例的计算设备是目标设备。对于选项1和3,FSDP初始化始终在GPU上进行。对于选项2,FSDP初始化发生在模块的当前设备上,可能是CPU。

如果你使用sync_module_states=True标志,需要确保模块在GPU上,或者使用device_id参数指定FSDP在构造函数中将模块移动到的CUDA设备。这是必要的,因为sync_module_states=True需要GPU通信。

FSDP还会负责将输入张量移动到前向方法的GPU计算设备上,因此你不需要手动将它们从CPU移动。

对于use_orig_params=TrueShardingStrategy.SHARD_GRAD_OP会暴露未分片的参数,而不是像ShardingStrategy.FULL_SHARD那样在前向之后的分片参数。如果你想检查梯度,可以使用summon_full_params方法并设置with_grads=True

使用limit_all_gathers=True时,你可能会在FSDP前向之前看到一个CPU线程没有发出任何内核的间隙。这是有意为之,显示了速率限制器在起作用。以这种方式同步CPU线程可以防止为后续的all-gather操作过度分配内存,实际上不会延迟GPU内核的执行。

出于与自动梯度相关的原因,FSDP在前向和后向计算期间会用torch.Tensor视图替换托管模块的参数。如果你的模块的前向依赖于保存的参数引用而不是每次迭代重新获取引用,那么它将看不到FSDP新创建的视图,自动梯度将无法正常工作。

最后,当使用sharding_strategy=ShardingStrategy.HYBRID_SHARD且分片进程组为节点内、复制进程组为节点间时,设置NCCL_CROSS_NIC=1可以帮助在某些集群设置中提高复制进程组的all-reduce时间。


限制

使用FSDP时有几个需要注意的限制:

  • 在使用CPU卸载时,FSDP目前不支持在no_sync()之外进行梯度累积。这是因为FSDP使用新减少的梯度而不是与任何现有梯度累积,这可能导致不正确的结果。
  • FSDP不支持运行包含在FSDP实例中的子模块的前向传递。这是因为子模块的参数会被分片,但子模块本身不是FSDP实例,因此其前向传递不会适当地all-gather完整参数。
  • 由于FSDP注册后向钩子的方式,它不支持双重后向。
  • FSDP在冻结参数时有一些限制。对于use_orig_params=False,每个FSDP实例必须管理全部冻结或全部未冻结的参数。对于use_orig_params=True,FSDP支持混合冻结和未冻结参数,但建议避免这样做以防止高于预期的梯度内存使用。
  • 截至PyTorch 1.12,FSDP对共享参数的支持有限。如果你的用例需要增强的共享参数支持,请在此问题中发帖。
  • 你应该避免在不使用summon_full_params上下文的情况下在前向和后向之间修改参数,因为这些修改可能不会持久化。

参数

  • module (nn.Module) – 这是要用FSDP包装的模块。
  • process_group (Optional[Union[ProcessGroup*, Tuple[ProcessGroup*, ProcessGroup]]]) – 这是模型分片的进程组,因此也是用于FSDP的all-gather和reduce-scatter集体通信的进程组。如果为None,则FSDP使用默认进程组。对于混合分片策略如ShardingStrategy.HYBRID_SHARD,用户可以传入一个进程组元组,分别表示分片和复制的组。如果为None,则FSDP为用户构建进程组以在节点内分片和在节点间复制。(默认:None
  • sharding_strategy (Optional[ShardingStrategy]) – 这配置分片策略,可能会在内存节省和通信开销之间进行权衡。详情参见ShardingStrategy。(默认:FULL_SHARD
  • cpu_offload (Optional[CPUOffload]) – 这配置CPU卸载。如果设置为None,则不进行CPU卸载。详情参见CPUOffload。(默认:None
  • auto_wrap_policy (Optional[Union[Callable[[nn.Module,* [bool],* int ],* [bool]], ModuleWrapPolicy*, CustomPolicy]]) – 这指定一个策略将FSDP应用于module的子模块,这对于通信和计算重叠是必要的,从而影响性能。如果为None,则FSDP仅应用于module,用户应手动将FSDP应用于父模块(自底向上)。为方便起见,这直接接受ModuleWrapPolicy,允许用户指定要包装的模块类(例如transformer块)。否则,这应该是一个可调用对象,接受三个参数module: nn.Modulerecurse: boolnonwrapped_numel: int,并返回一个bool,指定如果recurse=False是否应对传入的module应用FSDP,或者如果recurse=True是否应继续遍历模块的子树。用户可以添加额外的参数到可调用对象。torch.distributed.fsdp.wrap.py中的size_based_auto_wrap_policy提供了一个示例可调用对象,如果模块子树中的参数超过100M numel,则应用FSDP。我们建议在应用FSDP后打印模型并根据需要进行调整。

示例:


>>> def custom_auto_wrap_policy(
>>>     module: nn.Module, >>    recurse: bool, >>    nonwrapped_numel: int, >>    # Additional custom arguments
>>>     min_num_params: int = int(1e8), >>) -bool:
>>>     return nonwrapped_numel >= min_num_params
>>> # Configure a custom `min_num_params`
>>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))

  • backward_prefetch (Optional[BackwardPrefetch]) – 该参数用于配置所有-gather操作的显式反向预取。如果设为None,FSDP将不执行反向预取,导致反向传播过程中没有通信与计算重叠。详情参见BackwardPrefetch。(默认值:BACKWARD_PRE
  • mixed_precision (Optional[MixedPrecision]) – 该参数用于配置FSDP的原生混合精度。如果设为None,则不使用混合精度。否则可以设置参数、缓冲区和梯度缩减的数据类型。详情参见MixedPrecision。(默认值:None
  • ignored_modules (Optional[Iterable[torch.nn.Module ]]) – 该参数指定的模块及其子模块的参数和缓冲区将被当前FSDP实例忽略。直接列在ignored_modules中的模块不应是FullyShardedDataParallel实例,且任何已构建的FullyShardedDataParallel子模块即使嵌套在当前实例下也不会被忽略。该参数可用于:1) 使用auto_wrap_policy时避免在模块粒度上分片特定参数;2) 当参数分片不由FSDP管理时。(默认值:None
  • param_init_fn (Optional[Callable[[nn.Module], None]]) – 一个可调用对象Callable[torch.nn.Module] -None,用于指定如何将当前位于meta设备上的模块初始化到实际设备。从v1.12开始,FSDP通过is_meta检测带有参数或缓冲区的meta设备模块,并执行以下操作:如果指定了param_init_fn则应用该函数,否则调用nn.Module.reset_parameters()。两种情况下,实现应初始化该模块的参数/缓冲区,而非其子模块的,以避免重复初始化。此外,FSDP还支持通过torchdistX的(https://github.com/pytorch/torchdistX) deferred_init() API进行延迟初始化——延迟模块会通过调用指定的param_init_fn或torchdistX默认的materialize_module()来初始化。如果指定了param_init_fn,它将应用于所有meta设备模块,因此可能需要根据模块类型进行条件判断。FSDP在参数扁平化和分片之前调用初始化函数。

示例:


>>> module = MyModule(device="meta")
>>> def my_init_fn(module: nn.Module):
>>>     # E.g. initialize depending on the module type
>>>     ...
>>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
>>> print(next(fsdp_model.parameters()).device) # current CUDA device
>>> # With torchdistX
>>> module = deferred_init.deferred_init(MyModule, device="cuda")
>>> # Will initialize via deferred_init.materialize_module().
>>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)

  • device_id (Optional[Union[int, torch.device]]) – 指定FSDP初始化所在的CUDA设备,可以是inttorch.device类型,包括模块初始化(如需要)和参数分片过程。当module位于CPU时指定该参数可提升初始化速度。若已设置默认CUDA设备(例如通过torch.cuda.set_device),可传入torch.cuda.current_device。(默认值:None
  • sync_module_states ([bool]) – 若为True,每个FSDP模块会从rank 0广播模块参数和缓冲区,确保跨rank数据一致(会增加构造函数的通信开销)。这有助于通过load_state_dict以内存高效方式加载state_dict检查点。示例用法参见FullStateDictConfig。(默认值:False
  • forward_prefetch ([bool]) – 若为True,FSDP会显式在当前前向计算完成前预取下一轮前向传播的all-gather操作。仅适用于CPU密集型工作负载,提前发起all-gather可能提升计算重叠度。由于预取遵循首轮迭代执行顺序,该参数仅适用于静态图模型。(默认值:False
  • limit_all_gathers ([bool]) – 若为True,FSDP会显式同步CPU线程,确保GPU内存仅被两个连续FSDP实例占用(当前执行计算的实例和预取了all-gather的下一个实例)。若为False,则允许CPU线程无额外同步地发起all-gather。(默认值:True)该特性常被称为"速率限制器",仅在内存压力低的CPU密集型场景下可设为False,此时CPU线程可激进提交所有内核而无需考虑GPU内存占用。
  • use_orig_params ([bool]) – 设为True时,FSDP将使用模块的原始参数。通过nn.Module.named_parameters()暴露原始参数而非内部FlatParameter,使得优化器基于原始参数运行(支持每个原始参数的独立超参)。FSDP会保留原始参数变量,并在未分片/分片状态间转换其数据(始终分别作为底层未分片/分片FlatParameter的视图)。当前算法中分片形式始终为1D,会丢失原始张量结构。原始参数的数据可能全部/部分/不存在于当前rank,不存在时其数据表现为空张量。用户不应编写依赖分片形式数据的程序。使用torch.compile()必须设为True。设为False会通过nn.Module.named_parameters()暴露内部FlatParameter。(默认值:False
  • ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]) – 指定不由该FSDP实例管理的参数或模块,意味着这些参数不会被分片且梯度不会跨rank规约。该参数与现有ignored_modules参数功能统一,未来可能弃用ignored_modules。为保持向后兼容,同时保留两个参数,但FSDP要求二者只能有一个为非None
  • device_mesh (Optional[DeviceMesh]) – 可作为process_group的替代方案。传入时FSDP会使用底层process_group执行all-gather和reduce-scatter集合通信,因此这两个参数需互斥。对于ShardingStrategy.HYBRID_SHARD等混合分片策略,可传入2D DeviceMesh替代process_group元组。2D FSDP+TP场景下必须使用device_mesh而非process_group。更多DeviceMesh信息请参阅:

DeviceMesh教程

apply(fn)

对自身以及每个子模块(通过.children()返回)递归应用fn函数。

典型用途包括初始化模型参数(参见torch.nn.init文档)。

torch.nn.Module.apply相比,此版本在应用fn前会先收集完整参数。注意不应在另一个summon_full_params上下文中调用该方法。

参数

  • fn (Module -None) – 要应用于每个子模块的函数

返回
自身

返回类型
Module


check_is_root()

检查此实例是否为根 FSDP 模块。

返回类型:bool


clip_grad_norm_(max_norm, norm_type=2.0)

对所有参数的梯度范数进行裁剪。

该范数计算时将所有权重参数的梯度视为单个向量,并原地修改这些梯度值。

参数

  • max_norm (float 或 int) – 梯度的最大范数值
  • norm_type (float 或 int) – 所用p-范数类型。可设为'inf'表示无穷范数

返回

参数的总范数值(视为单个向量)。

返回类型:Tensor

若所有FSDP实例都使用NO_SHARD策略(即梯度未跨rank分片),可直接使用torch.nn.utils.clip_grad_norm_()

若存在FSDP实例使用分片策略(即非NO_SHARD策略),则应改用本方法而非torch.nn.utils.clip_grad_norm_(),因为本方法能正确处理跨rank分片的梯度。

返回的总范数值将根据PyTorch的类型提升规则,采用所有参数/梯度中"最大"的数据类型。例如:若所有参数/梯度使用低精度类型,则返回范数保持该低精度类型;但只要存在至少一个FP32精度的参数/梯度,返回范数将采用FP32类型。

警告:由于涉及集合通信操作,必须在所有rank上调用本方法。


static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)

展平分片的优化器状态字典。

该API与shard_full_optim_state_dict()类似,唯一区别在于输入的sharded_optim_state_dict应来自sharded_optim_state_dict()的返回结果。因此,每个rank上都会执行all-gather调用来收集ShardedTensor

参数

  • sharded_optim_state_dict (Dict[str, Any]) - 与未展平参数对应的优化器状态字典,包含分片的优化器状态。
  • model ( torch.nn.Module ) - 参考shard_full_optim_state_dict()
  • optim ( torch.optim.Optimizer ) - 用于model参数的优化器。

返回值:参考shard_full_optim_state_dict()

返回类型:dict[str, Any]


forward(*args, **kwargs)

对封装模块执行前向传播,同时插入 FSDP 特有的前向分片与后向分片逻辑。

返回类型:Any


*static* fsdp_modules(module, root_only=False)

返回所有嵌套的FSDP实例。

这可能包含module本身,且当root_only=True时仅包含FSDP根模块。

参数

  • module ( torch.nn.Module ) – 根模块,可能是也可能不是一个FSDP模块。
  • root_only ([bool]) – 是否仅返回FSDP根模块。(默认值:False

返回

嵌套在输入module中的FSDP模块。

返回类型:List [FullyShardedDataParallel]


static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)

返回完整的优化器状态字典。

该方法会在 rank 0 上整合完整的优化器状态,并以 dict 形式返回,遵循 torch.optim.Optimizer.state_dict() 的规范,即包含 "state""param_groups" 键。model 中包含的 FSDP 模块中的扁平化参数会被映射回其原始的非扁平化参数。

由于使用了集体通信操作,此方法需要在所有 rank 上调用。但如果 rank0_only=True,则仅在 rank 0 上填充状态字典,其他所有 rank 返回空字典。

torch.optim.Optimizer.state_dict() 不同,本方法使用完整参数名作为键(而非参数 ID)。

torch.optim.Optimizer.state_dict() 类似,优化器状态字典中包含的张量不会被克隆,因此可能存在别名意外。建议立即保存返回的优化器状态字典(例如使用 torch.save())以获得最佳实践。

参数

  • model (torch.nn.Module) – 根模块(可能是也可能不是 FullyShardedDataParallel 实例),其参数已传入优化器 optim
  • optim (torch.optim.Optimizer) – 用于 model 参数的优化器。
  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传入优化器 optim 的输入,表示参数组的 list 或可迭代参数;如果为 None,则该方法假定输入为 model.parameters()。此参数已弃用,无需再传递。(默认值:None
  • rank0_only ([bool]) – 如果为 True,仅在 rank 0 上保存填充的字典;如果为 False,在所有 rank 上保存。(默认值:True
  • group (dist.ProcessGroup) – 模型的进程组,如果使用默认进程组则为 None。(默认值:None

返回值:一个 dict,包含 model 原始非扁平化参数的优化器状态,并遵循 torch.optim.Optimizer.state_dict() 规范包含 “state” 和 “param_groups” 键。如果 rank0_only=True,则非零 rank 返回空字典。

返回类型:Dict[str, Any]


static get_state_dict_type(module)

获取以 module 为根的 FSDP 模块的 state_dict_type 及其对应配置。

目标模块不必是 FSDP 模块。

返回值:返回一个 StateDictSettings 对象,包含当前设置的 state_dict_type 以及 state_dict / optim_state_dict 配置。

异常

  • 如果不同 FSDP 子模块的 StateDictSettings 不一致,抛出 AssertionError

返回类型:StateDictSettings


property module:  Module 

返回被包装的模块。


named_buffers(*args, **kwargs)

返回一个遍历模块缓冲区的迭代器,同时生成缓冲区的名称和缓冲区本身。

summon_full_params() 上下文管理器内部时,会拦截缓冲区名称并移除所有特定于FSDP的扁平化缓冲区前缀。

返回类型为 Iterator [tuple [str, torch.Tensor]]


named_parameters(*args, **kwargs)

返回一个遍历模块参数的迭代器,同时生成参数名称和参数本身。

summon_full_params() 上下文管理器内部时,会拦截参数名称并移除所有特定于FSDP的扁平化参数前缀。

返回类型为 Iterator [tuple [str, [torch.nn.parameter.Parameter]]


no_sync()

禁用跨FSDP实例的梯度同步。

在此上下文中,梯度将累积在模块变量中,这些梯度会在退出上下文后的首次前向-反向传播过程中同步。此功能应仅用于根FSDP实例,并将递归应用于所有子FSDP实例。

注意:这可能导致更高的内存使用量,因为FSDP会累积完整的模型梯度(而非梯度分片),直到最终同步完成。

注意:与CPU卸载功能同时使用时,在上下文管理器内部梯度不会被卸载到CPU。相反,它们只会在最终同步后立即被卸载。

返回类型:生成器


static optim_state_dict(model, optim, optim_state_dict=None, group=None)

转换分片模型对应的优化器状态字典。

给定的状态字典可转换为以下三种类型之一:

  1. 完整优化器状态字典 2) 分片优化器状态字典 3) 本地优化器状态字典

对于完整优化器状态字典,所有状态均未展平且未分片。可通过 state_dict_type() 指定仅限 Rank0 和仅限 CPU 以避免内存溢出。

对于分片优化器状态字典,所有状态均未展平但已分片。可通过 state_dict_type() 指定仅限 CPU 以进一步节省内存。

对于本地状态字典,不会执行任何转换。但状态会从 nn.Tensor 转换为 ShardedTensor 以表示其分片特性(当前尚未支持此功能)。


示例:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model, >>    StateDictType.FULL_STATE_DICT, >>    FullStateDictConfig(rank0_only=False), >>    FullOptimStateDictConfig(rank0_only=False), >>)
>>> state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model, >>    StateDictType.FULL_STATE_DICT, >>    FullStateDictConfig(rank0_only=False), >>    FullOptimStateDictConfig(rank0_only=False), >>)
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)

参数

  • model ( torch.nn.Module ) – 根模块(可能是也可能不是FullyShardedDataParallel实例),其参数已传入优化器optim
  • optim ( torch.optim.Optimizer ) – 用于model参数的优化器。
  • optim_state_dict (Dict[str, Any]) – 需要转换的目标优化器状态字典。若值为None,将使用optim.state_dict()。(默认值:None
  • group (dist.ProcessGroup) – 模型参数分片所在的进程组,若使用默认进程组则为None。(默认值:None

返回值:一个包含model优化器状态的dict。优化器状态的分片基于state_dict_type

返回类型:Dict[str, Any]


static optim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None)

将优化器状态字典转换为可加载到与FSDP模型关联的优化器中的格式。

给定一个通过 optim_state_dict() 转换得到的 optim_state_dict,该方法会将其转换为扁平化的优化器状态字典,该字典可加载到 model 的优化器 optim 中。注意:model 必须是通过 FullyShardedDataParallel 进行分片的模型。


>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model, >>    StateDictType.FULL_STATE_DICT, >>    FullStateDictConfig(rank0_only=False), >>    FullOptimStateDictConfig(rank0_only=False), >>)
>>> state_dict = model.state_dict()
>>> original_osd = optim.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(
>>>     model, >>    optim, >>    optim_state_dict=original_osd
>>> )
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model, >>    StateDictType.FULL_STATE_DICT, >>    FullStateDictConfig(rank0_only=False), >>    FullOptimStateDictConfig(rank0_only=False), >>)
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     model, optim, optim_state_dict
>>> )
>>> optim.load_state_dict(optim_state_dict)

参数

  • model (torch.nn.Module) – 根模块(可能是也可能不是FullyShardedDataParallel实例),其参数已传入优化器optim
  • optim (torch.optim.Optimizer) – 用于model参数的优化器。
  • optim_state_dict (Dict[str, Any]) – 待加载的优化器状态字典。
  • is_named_optimizer ([bool]) – 该优化器是否为NamedOptimizer或KeyedOptimizer。仅当optim是TorchRec的KeyedOptimizer或torch.distributed的NamedOptimizer时设为True。
  • load_directly ([bool]) – 若设为True,本API将在返回结果前自动调用optim.load_state_dict(result);否则用户需自行调用optim.load_state_dict()(默认值:False)。
  • group (dist.ProcessGroup) – 模型参数分片所在的进程组,若使用默认进程组则为None(默认值:None)。

返回类型:dict[str, Any]


register_comm_hook(state, hook)

注册一个通信钩子。

该功能是一项增强,为用户提供了一个灵活的钩子,可以指定FSDP如何在多个工作节点间聚合梯度。

这个钩子可用于实现多种算法,例如GossipGrad和梯度压缩,这些算法在使用FullyShardedDataParallel训练时涉及不同的参数同步通信策略。

警告:FSDP通信钩子必须在初始前向传播运行前注册,且只能注册一次。

参数

  • state ( object ) - 传递给钩子以在训练过程中维护任何状态信息。

示例包括梯度压缩中的误差反馈、GossipGrad中下一次通信的对等节点等。

该状态由每个工作节点本地存储,并由该工作节点上的所有梯度张量共享。

  • hook (Callable) - 可调用对象,具有以下签名之一:
  1. hook: Callable[torch.Tensor] -None:

该函数接收一个Python张量,表示与该FSDP单元包装的模型(未被其他FSDP子单元包装的部分)对应的所有变量的完整、展平、未分片的梯度。

然后执行所有必要的处理并返回None

  1. hook: Callable[torch.Tensor, torch.Tensor] -None:

该函数接收两个Python张量,第一个表示与该FSDP单元包装的模型(未被其他FSDP子单元包装的部分)对应的所有变量的完整、展平、未分片的梯度。第二个表示预分配大小的张量,用于存储归约后的分片梯度块。

在这两种情况下,可调用对象都会执行所有必要的处理并返回None

签名1的可调用对象预期处理NO_SHARD情况下的梯度通信。

签名2的可调用对象预期处理分片情况下的梯度通信。


static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)

将优化器状态字典 optim_state_dict 的键类型重新映射为 optim_state_key_type

该功能可用于实现带有 FSDP 实例的模型与普通模型之间优化器状态字典的兼容性。

若要将 FSDP 完整优化器状态字典(即来自 full_optim_state_dict() 的字典)重新映射为使用参数 ID 键,并使其可加载至未封装模型:

>>> wrapped_model, wrapped_optim = ...
>>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
>>> nonwrapped_model, nonwrapped_optim = ...
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
>>> nonwrapped_optim.load_state_dict(rekeyed_osd)

要将普通优化器状态字典(来自未封装模型)重新映射为可加载到封装模型中的格式:

>>> nonwrapped_model, nonwrapped_optim = ...
>>> osd = nonwrapped_optim.state_dict()
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
>>> wrapped_model, wrapped_optim = ...
>>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
>>> wrapped_optim.load_state_dict(sharded_osd)

返回

优化器状态字典,使用optim_state_key_type指定的参数键重新映射键名。

返回类型:Dict[str, Any]


static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)

将完整的优化器状态字典从 rank 0 分发到所有其他 ranks。

返回每个 rank 上的分片优化器状态字典。

返回值与 shard_full_optim_state_dict() 相同,且在 rank 0 上,
第一个参数应为 full_optim_state_dict() 的返回值。


示例:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
>>> # Define new model with possibly different world size
>>> new_model, new_optim, new_group = ...
>>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
>>> new_optim.load_state_dict(sharded_osd)

注意:shard_full_optim_state_dict()scatter_full_optim_state_dict() 均可用于获取待加载的分片优化器状态字典。假设完整优化器状态字典存储在CPU内存中:

  • 前者要求每个进程在CPU内存中保存完整字典,各进程独立进行分片且无需通信
  • 后者仅要求rank 0在CPU内存中保存完整字典,由rank 0将每个分片移至GPU内存(用于NCCL)并通过通信分发到对应进程

因此,前者总CPU内存开销更高,而后者通信开销更大。

参数说明

  • full_optim_state_dict (Optional[Dict[str, Any]]) - 对应未展平参数的优化器状态字典,在rank 0上保存完整非分片状态;非0 rank会忽略该参数
  • model (torch.nn.Module) - 根模块(可能是也可能不是FullyShardedDataParallel实例),其参数与full_optim_state_dict中的优化器状态对应
  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) - 传入优化器的输入,可以是参数组的list或参数的可迭代对象;若为None则默认使用model.parameters()。此参数已弃用,无需再传递(默认值:None
  • optim (Optional[torch.optim.Optimizer]) - 将加载本方法返回状态字典的优化器。推荐使用此参数替代optim_input(默认值:None
  • group (dist.ProcessGroup) - 模型使用的进程组,若为None则使用默认进程组(默认值:None

返回值:返回重构后的完整优化器状态字典,其中:

  • 参数映射为展平后的形式(而非原始未展平参数)
  • 仅包含当前rank对应的优化器状态部分

返回类型:Dict[str, Any]


static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)

设置目标模块所有子级FSDP模块的state_dict_type

同时支持(可选)配置模型和优化器的状态字典。

目标模块不必是FSDP模块。如果目标模块本身是FSDP模块,其state_dict_type也会被修改。

注意:此API应仅对顶层(根)模块调用。

注意:当根FSDP模块被其他nn.Module包裹时,此API允许用户透明地使用常规state_dict接口来保存模型检查点。例如以下场景:确保对非FSDP实例调用state_dict方法,同时对FSDP实例则转为分片状态字典实现:


示例:

>>> model = DDP(FSDP(...))
>>> FSDP.set_state_dict_type(
>>>     model, >>    StateDictType.SHARDED_STATE_DICT, >>    state_dict_config = ShardedStateDictConfig(offload_to_cpu=True), >>    optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True), >>)
>>> param_state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)

参数

  • module ( torch.nn.Module ) – 根模块。
  • state_dict_type (StateDictType) – 要设置的期望 state_dict_type
  • state_dict_config (Optional[StateDictConfig]) – 目标 state_dict_type 的配置。
  • optim_state_dict_config (Optional[OptimStateDictConfig]) – 优化器状态字典的配置。

返回值:返回一个包含模块先前状态字典类型和配置的 StateDictSettings 对象。

返回类型:StateDictSettings


static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)

切分完整优化器状态字典。

full_optim_state_dict 中的状态从非展平参数重新映射为展平参数,并限制为仅包含当前秩(rank)对应的优化器状态部分。

第一个参数应为 full_optim_state_dict() 的返回值。

示例:


>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)
>>> torch.save(full_osd, PATH)
>>> # Define new model with possibly different world size
>>> new_model, new_optim = ...
>>> full_osd = torch.load(PATH)
>>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
>>> new_optim.load_state_dict(sharded_osd)

注意:shard_full_optim_state_dict()scatter_full_optim_state_dict() 均可用于获取待加载的分片优化器状态字典。假设完整优化器状态字典驻留在CPU内存中:

  • 前者要求每个rank在CPU内存中保存完整字典,各rank独立进行分片且无需通信
  • 后者仅要求rank 0在CPU内存中保存完整字典,由rank 0将各分片移至GPU内存(用于NCCL)并通过通信分发到对应rank

因此,前者总CPU内存开销更高,而后者的通信开销更大。

参数说明

  • full_optim_state_dict (Dict[str, Any]) - 对应未展平参数的优化器状态字典,包含完整未分片的优化器状态
  • model ( torch.nn.Module ) - 根模块(可能是也可能不是FullyShardedDataParallel实例),其参数与full_optim_state_dict中的优化器状态相对应
  • optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) - 传递给优化器的输入,可以是参数组的list或可迭代参数集合;若为None则默认使用model.parameters()。此参数已弃用,无需再传递(默认值:None
  • optim (Optional[torch.optim.Optimizer ]) - 将加载本方法返回状态字典的优化器。推荐优先使用此参数而非optim_input(默认值:None

返回值:返回重构后的完整优化器状态字典,该字典:
1、已从未展平参数映射为展平参数
2、仅包含当前rank对应的优化器状态部分

返回类型:Dict[str, Any]


static sharded_optim_state_dict(model, optim, group=None)

返回优化器状态字典的分片形式。

该API与full_optim_state_dict()类似,但会将所有非零维状态分块为ShardedTensor以节省内存。

注意:只有当模型state_dict是通过上下文管理器with state_dict_type(SHARDED_STATE_DICT):导出时,才应使用此API。

具体用法请参考full_optim_state_dict()

警告:返回的状态字典包含ShardedTensor,不能直接被常规的optim.load_state_dict使用。

返回类型:dict[str, Any]


static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)

设置目标模块下所有FSDP子模块的state_dict_type

此上下文管理器功能与set_state_dict_type()相同。详情请参阅set_state_dict_type()的文档说明。


示例:

>>> model = DDP(FSDP(...))
>>> with FSDP.state_dict_type(
>>>     model, >>    StateDictType.SHARDED_STATE_DICT, >>):
>>>     checkpoint = model.state_dict()

参数

  • module ( torch.nn.Module ) – 根模块。
  • state_dict_type (StateDictType) – 要设置的期望 state_dict_type
  • state_dict_config (Optional[StateDictConfig ]) – 目标 state_dict_type 对应的模型 state_dict 配置。
  • optim_state_dict_config (Optional[OptimStateDictConfig ]) – 目标 state_dict_type 对应的优化器 state_dict 配置。

返回类型:生成器


static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)

通过此上下文管理器暴露FSDP实例的完整参数。

模型完成前向/反向传播后,可用于获取参数进行额外处理或检查。它可以接受非FSDP模块,并根据recurse参数递归地为所有包含的FSDP模块及其子模块召唤完整参数。

注意:可在内部FSDP上使用。

注意不可在前向或反向传播过程中使用,也不能在该上下文中启动前向/反向传播。

注意:上下文管理器退出后,参数将恢复为本地分片状态,存储行为与前向传播相同。

注意:完整参数可被修改,但只有对应本地参数分片的部分会在上下文退出后保留(除非设置writeback=False,此时修改会被丢弃)。当FSDP不对参数分片时(当前仅world_size == 1NO_SHARD配置),无论writeback如何设置,修改都会被保留。

注意:此方法适用于非FSDP模块(可能包含多个独立FSDP单元)。此时给定参数将应用于所有包含的FSDP单元。

警告:当前不支持rank0_only=Truewriteback=True同时使用,会触发错误。因为上下文中各rank的模型参数形状不同,退出时写入会导致跨rank不一致。

警告offload_to_cpurank0_only=False组合会导致完整参数被冗余复制到同一机器的CPU内存中,可能引发CPU OOM风险。建议配合rank0_only=True使用offload_to_cpu

参数说明

  • recurse ([bool], 可选) – 是否递归召唤嵌套FSDP实例的所有参数(默认:True)
  • writeback ([bool], 可选) – 若为False,上下文退出时丢弃参数修改;禁用此选项可略微提升效率(默认:True)
  • rank0_only ([bool], 可选) – 若为True,仅全局rank 0会物化完整参数,其他rank保持分片参数。注意rank0_only=Truewriteback=True的组合不被支持,因上下文中各rank参数形状不同,退出时写入会导致不一致
  • offload_to_cpu ([bool], 可选) – 若为True,完整参数将卸载到CPU。当前仅分片参数会触发卸载(world_size=1NO_SHARD配置除外)。建议配合rank0_only=True使用以避免重复卸载到同一CPU内存
  • with_grads ([bool], 可选) – 若为True,梯度会随参数一起解除分片。当前仅当FSDP构造器传入use_orig_params=True且本方法设置offload_to_cpu=False时支持(默认:False

返回类型
生成器


class torch.distributed.fsdp.BackwardPrefetch(value)

该配置启用了显式的反向预取功能,通过在后向传递中实现通信与计算的重叠来提升吞吐量,但会略微增加内存使用量。

  • BACKWARD_PRE:实现最大程度的重叠,但内存使用量也最高。该模式会在当前参数组的梯度计算之前预取下一组参数。这使得下一次全收集操作当前梯度计算能够重叠执行,内存峰值时会同时保留当前参数组、下一组参数以及当前梯度数据。
  • BACKWARD_POST:实现较少重叠,但内存需求更低。该模式在当前参数组的梯度计算完成后才预取下一组参数。这使得当前规约散射操作下一组梯度计算能够重叠执行,并在为下一组参数分配内存前释放当前参数组,内存峰值时仅保留下一组参数和当前梯度数据。
  • FSDP的backward_prefetch参数支持设为None以完全禁用反向预取。该模式无任何重叠效果且不会增加内存开销。通常不建议采用此设置,因为它可能显著降低吞吐性能。

技术背景说明:对于使用NCCL后端的单个进程组,所有集合操作(即使来自不同流)都会争用同一设备上的NCCL流,这意味着集合操作的发起顺序将直接影响重叠效果。上述两种反向预取值对应不同的操作发起顺序。


class torch.distributed.fsdp.ShardingStrategy(value)

这里指定了FullyShardedDataParallel用于分布式训练的分片策略。

  • FULL_SHARD:参数、梯度和优化器状态均进行分片。

对于参数,该策略在前向计算前执行解分片(通过all-gather操作),前向计算后重新分片,反向计算前再次解分片,反向计算后重新分片。对于梯度,在反向计算后通过reduce-scatter操作进行同步和分片。分片的优化器状态由每个rank本地更新。

  • SHARD_GRAD_OP:计算过程中梯度和优化器状态保持分片,此外参数在计算外也保持分片。

对于参数,该策略在前向计算前执行解分片,前向计算后不重新分片,仅在反向计算后重新分片。分片的优化器状态由每个rank本地更新。在no_sync()上下文中,反向计算后参数不会重新分片。

  • NO_SHARD:参数、梯度和优化器状态不进行分片,而是像PyTorch的DistributedDataParallelAPI那样跨rank复制。对于梯度,该策略在反向计算后通过all-reduce操作进行同步。未分片的优化器状态由每个rank本地更新。
  • HYBRID_SHARD:在节点内应用FULL_SHARD策略,同时在节点间复制参数。由于昂贵的all-gather和reduce-scatter操作仅在节点内执行,这可以减少通信量,对中等规模模型可能更具性能优势。
  • _HYBRID_SHARD_ZERO2:在节点内应用SHARD_GRAD_OP策略,同时在节点间复制参数。与HYBRID_SHARD类似,但由于前向计算后未释放解分片的参数,节省了反向计算前的all-gather操作,可能提供更高的吞吐量。

class torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<class 'torch.nn.modules.batchnorm._BatchNorm'>, ))

此配置用于FSDP原生的混合精度训练。

变量说明

  • param_dtype (Optional[torch.dtype]) - 指定前向传播和反向传播期间模型参数的数据类型,从而决定前向和反向计算的数据类型。在前向和反向之外,分片参数保持全精度(例如用于优化器步骤),而模型检查点时参数始终以全精度保存。(默认:None
  • reduce_dtype (Optional[torch.dtype]) - 指定梯度归约(即reduce-scatter或all-reduce)的数据类型。若为Noneparam_dtype不为None,则采用param_dtype值,仍以低精度运行梯度归约。允许与param_dtype不同,例如强制梯度归约以全精度运行。(默认:None
  • buffer_dtype (Optional[torch.dtype]) - 指定缓冲区的数据类型。FSDP不对缓冲区进行分片,而是在首次前向传播时将其转换为buffer_dtype并保持该类型。模型检查点时,除LOCAL_STATE_DICT外缓冲区均以全精度保存。(默认:None
  • keep_low_precision_grads ([bool]) - 若为False,FSDP在反向传播后将梯度提升至全精度以备优化器步骤使用;若为True,则保持梯度为归约时的数据类型,可节省支持低精度运行的自定义优化器的内存。(默认:False
  • cast_forward_inputs ([bool]) - 若为True,FSDP模块将其前向传播的args和kwargs转换为param_dtype,确保参数与输入数据类型匹配以满足多数运算要求。当仅对部分FSDP模块应用混合精度时可能需要设为True,此时混合精度子模块需重新转换输入。(默认:False
  • cast_root_forward_inputs ([bool]) - 若为True,根FSDP模块会覆盖cast_forward_inputs值,将其前向传播的args和kwargs转换为param_dtype。非根FSDP模块不受影响。(默认:True
  • _module_classes_to_ignore (collections.abc.Sequence[type[torch.nn.modules.module.Module]]) - 指定使用auto_wrap_policy时忽略混合精度的模块类:这些类的模块将单独应用FSDP且禁用混合精度(导致最终FSDP构造偏离指定策略)。未指定auto_wrap_policy时此参数无效。该API为实验性质可能变更。(默认:(_BatchNorm,)

注意事项

此API为实验性质,可能发生变化。

仅浮点张量会被转换为指定数据类型。

summon_full_params中,参数强制转为全精度,但缓冲区不受影响。

即使输入为float16bfloat16等低精度,层归一化和批归一化仍以float32累加。仅为这些归一化模块禁用混合精度意味着仿射参数保持float32,但会导致额外的all-gather和reduce-scatter操作,可能降低效率。若任务允许,建议仍对这些模块应用混合精度。

默认情况下,若用户传入包含_BatchNorm模块的模型并指定auto_wrap_policy,批归一化模块将单独应用FSDP且禁用混合精度。详见_module_classes_to_ignore参数。

MixedPrecision默认设置cast_root_forward_inputs=Truecast_forward_inputs=False。根FSDP实例的cast_root_forward_inputs优先于cast_forward_inputs,非根实例的cast_root_forward_inputs值被忽略。典型场景下(所有FSDP实例具有相同MixedPrecision配置且仅需在模型前向开始时转换输入至param_dtype),默认设置已足够。

对于具有不同MixedPrecision配置的嵌套FSDP实例,建议通过单独设置cast_forward_inputs来配置各实例前向传播前的输入转换。此时由于转换发生在各FSDP实例前向之前,父FSDP实例应使其非FSDP子模块在FSDP子模块之前运行,避免因不同MixedPrecision配置导致激活数据类型变化。


示例:

>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
>>> model[1] = FSDP(
>>>     model[1], >>    mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True), >>)
>>> model = FSDP(
>>>     model, >>    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True), >>)

上面的示例展示了正常运行的情况。反之,如果将 model[1] 替换为 model[0],即让使用不同 MixedPrecision 的子模块先执行前向计算,那么 model[1] 就会错误地接收到 float16 类型的激活值,而非预期的 bfloat16 类型。


class torch.distributed.fsdp.CPUOffload(offload_params=False)

此配置用于启用 CPU 卸载功能。

变量说明

  • offload_params ([bool]) – 指定是否在参数不参与计算时将其卸载到 CPU。若设为 True,则梯度也会被卸载至 CPU,这意味着优化器步骤将在 CPU 上执行。

class torch.distributed.fsdp.StateDictConfig(offload_to_cpu=False)

StateDictConfig 是所有 state_dict 配置类的基类。用户需要实例化其子类(例如 FullStateDictConfig)来配置 FSDP 所支持的对应 state_dict 类型的相关设置。

变量说明

  • offload_to_cpu ([bool]) – 若设为 True,FSDP 会将状态字典的值卸载到 CPU;若设为 False,则保留在 GPU 上。(默认值:False

class torch.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, rank0_only=False)

FullStateDictConfig 是一个配置类,专为配合 StateDictType.FULL_STATE_DICT 使用而设计。我们建议在保存完整状态字典时,同时启用 offload_to_cpu=Truerank0_only=True 参数,以分别节省 GPU 内存和 CPU 内存。该配置类需通过 state_dict_type() 上下文管理器使用,示例如下:

>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> fsdp = FSDP(model, auto_wrap_policy=...)
>>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
>>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
>>>     state = fsdp.state_dict()
>>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0、>># To reload checkpoint for inference, finetuning, transfer learning, etc:
>>> model = model_fn()  # Initialize model in preparation for wrapping with FSDP
>>> if dist.get_rank() == 0:
>>> # Load checkpoint only on rank 0 to avoid memory redundancy
>>>     state_dict = torch.load("my_checkpoint.pt")
>>>     model.load_state_dict(state_dict)
>>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
>>> # communicates loaded checkpoint states from rank 0 to rest of the world.
>>> fsdp = FSDP(
...     model, 
...     device_id=torch.cuda.current_device(), 
...     auto_wrap_policy=..., 
...     sync_module_states=True, 
... )
>>> # After this point, all ranks have FSDP model with loaded checkpoint.

变量

  • rank0_only ([bool]) – 如果设为True,则仅 rank 0 进程保存完整的状态字典,非零 rank 进程保存空字典。如果设为False,则所有 rank 进程都会保存完整的状态字典。(默认值:False

class torch.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu=False, _use_dtensor=False)

ShardedStateDictConfig 是一个配置类,专为与 StateDictType.SHARDED_STATE_DICT 配合使用而设计。

变量说明

  • _use_dtensor ([bool]) – 若设为 True,FSDP 会将状态字典值保存为 DTensor;若设为 False,则保存为 ShardedTensor。(默认值:False

警告:_use_dtensorShardedStateDictConfig 的私有字段,FSDP 通过该字段决定状态字典值的类型。用户不应手动修改 _use_dtensor


class torch.distributed.fsdp.LocalStateDictConfig(offload_to_cpu:  bool  = False)

class torch.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True)

OptimStateDictConfig 是所有 optim_state_dict 配置类的基类。用户应实例化子类(例如 FullOptimStateDictConfig)来配置 FSDP 支持的对应 optim_state_dict 类型的设置。

变量说明

  • offload_to_cpu ([bool]) – 若设为 True,FSDP 会将状态字典的张量值卸载到 CPU;若设为 False,则保留在原始设备上(除非启用了参数 CPU 卸载功能,否则原始设备为 GPU)。(默认值:True

class torch.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)

变量

  • rank0_only ([bool]) – 如果设为True,则仅rank 0会保存完整的状态字典,非零rank保存空字典。如果设为False,则所有rank都会保存完整的状态字典。(默认值:False

class torch.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu=True, _use_dtensor=False)

ShardedOptimStateDictConfig 是一个配置类,专为与 StateDictType.SHARDED_STATE_DICT 配合使用而设计。

变量说明

  • _use_dtensor ([bool]) – 若设为 True,FSDP 会将状态字典的值保存为 DTensor;若设为 False,则保存为 ShardedTensor。(默认值:False

警告
_use_dtensorShardedOptimStateDictConfig 的私有字段,FSDP 通过它来决定状态字典值的类型。用户不应手动修改此字段。


class torch.distributed.fsdp.LocalOptimStateDictConfig(offload_to_cpu:  bool  = False)

class torch.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config:torch.distributed.fsdp.api.StateDictConfig )

torch.distributed.fsdp.fully_shard


PyTorch FSDP2 (fully_shard)

PyTorch FSDP2 提供了一种完全分片数据并行(FSDP)实现,旨在实现高性能的即时执行模式,同时采用逐参数分片以提升易用性。

  • 如果您是 FSDP 的新用户,我们建议从 FSDP2 开始使用,因其具有更好的易用性。
  • 如果您当前正在使用 FSDP1,请评估以下差异以决定是否应切换到 FSDP2:

与 PyTorch FSDP1 (FullyShardedDataParallel) 相比:

  • FSDP2 使用基于 DTensor 的维度 0 逐参数分片,相比 FSDP1 的扁平参数分片提供了更简单的分片表示,同时保持了相似的吞吐性能。具体来说,FSDP2 在数据并行工作节点间对每个参数沿维度 0 进行分块(使用 torch.chunk(dim=0)),而 FSDP1 会将一组张量展平、拼接并一起分块,这使得理解每个工作节点上的数据以及重新分片到不同并行模式变得复杂。逐参数分片提供了更直观的用户体验,放宽了对冻结参数的限制,并允许无通信(分片)的状态字典,而在 FSDP1 中则需要全收集操作。
  • FSDP2 采用不同的内存管理方法来处理多流使用场景,避免了 torch.Tensor.record_stream。这确保了确定性和预期的内存使用,且不需要像 FSDP1 的 limit_all_gathers=True 那样阻塞 CPU。
  • FSDP2 提供了手动控制预取和集体调度的 API,允许高级用户进行更多自定义。详情请参阅下文 FSDPModule 的方法。
  • FSDP2 简化了部分 API 接口:例如,FSDP2 不直接支持完整状态字典。用户可以使用 DTensor API(如 DTensor.full_tensor())或更高级的 API(如 PyTorch 分布式检查点 的分布式状态字典 API)自行将包含 DTensor 的分片状态字典重新分片为完整状态字典。此外,一些其他参数已被移除;详情请参阅此处

如果您是首次使用 FSDP,或上述任何一点符合您的使用场景,我们建议您考虑使用 FSDP2。

有关系统设计和实现的详细信息,请参阅此 RFC


注意:torch.distributed.fsdp.fully_shard 目前处于原型阶段,正在开发中。核心 API 可能不会更改,但我们可能会根据需要进行一些 API 调整。

前端 API 是 fully_shard,可以在 module 上调用:

torch.distributed.fsdp.fully_shard(module, *, mesh=None, reshard_after_forward=True, shard_placement_fn=None, mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True), offload_policy=OffloadPolicy(), ignored_params=None)

module应用全分片数据并行(FSDP),其中FSDP将模块参数、梯度和优化器状态分片到数据并行工作节点上,以通信开销为代价节省内存。

初始化时,FSDP根据mesh指定的数据并行工作节点对模块参数进行分片。在前向计算前,FSDP通过全收集操作跨数据并行工作节点获取完整参数用于计算。若reshard_after_forwardTrue,则FSDP在前向计算后释放完整参数,并在反向计算前重新全收集。梯度计算完成后,FSDP释放完整参数并通过规约分散操作分发未分片梯度。

本实现使用DTensor表示分片参数(沿0维分片),而完整参数保持与原始模块参数相同类型(如原为torch.Tensor则仍为torch.Tensor)。模块的前向预钩子负责参数全收集,前向钩子负责释放参数(如需要)。类似的反向钩子处理参数收集与梯度分发。

为提高通信效率,本实现将多个张量分组进行集合操作。对module调用fully_shard()会构建包含module.parameters()中参数的分组(子模块已分组参数除外),因此应在模型上自底向上调用fully_shard()。每组参数通过单次集合操作完成全收集和梯度规约分散。分层分组(“逐层”)可实现内存峰值优化和通信/计算重叠。通常不应仅在顶层模块调用fully_shard()

参数说明

  • module (Union[nn.Module, List[nn.Module]) – 待分片的模块或模块列表,这些模块将被分组进行通信
  • mesh (Optional[[DeviceMesh](distributed.html#torch.distributed.device_mesh.DeviceMesh "torch.distributed.device_mesh.DeviceMesh")]) – 数据并行网格定义分片方式和设备:
    • 一维网格:参数完全分片(FSDP),采用(Shard(0),)布局
    • 二维网格:参数沿第1维分片且沿第0维复制(HSDP),采用(Replicate(), Shard(0))布局
    • 网格设备类型决定通信设备类型(如CUDA类设备使用当前设备)
  • reshard_after_forward (Union[[bool],* int ]) – 控制前向计算后的参数行为,平衡内存与通信:
    • True:前向后重新分片参数,反向时重新全收集
    • False:前向后保留完整参数,避免反向全收集
    • 整数值:指定前向后的分片规模(应为网格分片维度的非平凡除数,如节点内设备数torch.cuda.device_count()),以较小通信规模换取较高内存使用
    • 根FSDP状态默认设为False(因其参数通常需立即用于反向计算)
    • 前向后模块注册参数取决于该设置:分片参数(True时)、完整参数(False时)或缩网格分片参数(整数值时)。若需在前反向间修改参数,必须注册分片参数(对False或整数值可通过reshard()手动分片)
  • shard_placement_fn (Optional[Callable[[nn.Parameter],* Optional[Shard ]]]) – 自定义参数分片布局(如返回Shard(1)则沿1维分片)。当前非零维分片要求张量维度大小可被分片网格大小整除
  • mp_policy ( MixedPrecisionPolicy ) – 混合精度策略,控制该模块的参数/规约精度。详见MixedPrecisionPolicy
  • offload_policy (OffloadPolicy) – 卸载策略,控制参数/梯度/优化器状态卸载。详见OffloadPolicy及其子类
  • ignored_params (Optional[set[nn.Parameter]]) – 不需要FSDP分片的参数集合

返回值:返回应用FSDP后的模块(原地修改),类型为FSDPModule

调用fully_shard(module)会动态创建继承原模块类型和FSDPModule的新类。例如对linear: nn.Linear调用fully_shard(linear)会生成FSDPLinear类并转换模块类型。该方法不改变模块结构和参数全限定名,FSDPModule类提供特定于FSDP的方法支持。


class torch.distributed.fsdp.FSDPModule(*args, **kwargs) 

reshard()

对模块参数进行重新分片,如果未分片参数已分配则释放它们,并将分片后的参数注册到模块中。该方法不会递归执行。


set_all_reduce_hook(hook, *, stream=None)

参数

  • hook (Callable[[torch.Tensor], None]) – 用户自定义的all-reduce钩子函数,预期签名为hook(reduce_output: torch.Tensor) -> None

其中reduce_output表示:

  • 若仅使用FSDP则为reduce-scatter操作的输出
  • 若使用原生HSDP则为all-reduce操作的输出
  • stream (Optional[torch.cuda.Stream]) – 运行all-reduce钩子的CUDA流。注意:
    • 仅在不使用原生HSDP时需要设置此参数
    • 若使用原生HSDP,钩子将在HSDP内部定义的all-reduce流中自动执行

set_is_last_backward(is_last_backward)

设置下一次反向传播是否为最后一次。在最后一次反向传播时,FSDP会等待待处理的梯度归约操作完成,并清除用于反向预取的内部数据结构。这一特性对于微批次训练特别有用。


set_modules_to_backward_prefetch(modules)

设置当前FSDP模块在反向传播时应显式预取全聚集操作的FSDP模块。这会覆盖默认的反向预取实现(默认实现基于反向后序顺序预取下一个FSDP模块)。

传入包含前一个FSDP模块的单例列表,可获得与默认重叠行为相同的全聚集操作重叠效果。

若需更激进的重叠效果(会占用更多预留内存),则必须传入至少包含两个模块的列表。

参数:
modules (List[FSDPModule]) – 需要预取的FSDP模块列表。


set_modules_to_forward_prefetch(modules)

设置此FSDP模块在正向传播中应显式预取全收集操作的FSDP模块。预取操作会在本模块的全收集复制输出之后执行。

如果传入仅包含下一个FSDP模块的单例列表,将获得与默认重叠行为相同的全收集重叠效果,区别在于预取的全收集操作会从CPU端更早发起。要实现更激进的重叠效果(将占用更多预留内存),必须传入至少包含两个模块的列表。

参数

modules (List[FSDPModule]) – 需要预取的FSDP模块列表。


set_post_optim_event(event)

为根FSDP模块设置一个优化器步骤后事件,用于等待所有聚集流就绪。

默认情况下,根FSDP模块会在当前流上等待所有聚集流,以确保优化器步骤在开始全聚集前已完成。但如果优化器步骤后存在无关计算,这种方式可能会引入虚假依赖。该API允许用户提供自定义事件进行等待。当根模块完成事件等待后,该事件会被丢弃,因此每次迭代都应调用本API传入新事件。

参数

event (torch.Event) - 记录在优化器步骤后、用于等待所有聚集流的事件对象。


set_reduce_scatter_divide_factor(factor)

为reduce-scatter操作设置自定义的除法因子。这将通过NCCL的PreMulSum功能实现一个自定义的reduce操作,允许在归约前先乘以该因子。

参数

factor (float) – 自定义除法因子。


set_requires_all_reduce(requires_all_reduce, *, recurse=True)

设置该模块是否应执行梯度全归约操作。这可用于实现仅使用reduce-scatter而不进行全归约的梯度累积方案,适用于HSDP场景。


set_requires_gradient_sync(requires_gradient_sync, *, recurse=True)

设置模块是否应同步梯度。该功能可用于实现无需通信的梯度累积。对于HSDP,这将同时控制reduce-scatter和all-reduce操作。其功能等同于FSDP1中的no_sync。

参数说明

  • requires_gradient_sync ([bool]) – 控制是否对模块参数执行梯度归约操作。
  • recurse ([bool]) – 控制设置范围:仅作用于当前模块,还是递归作用于所有FSDP子模块。

set_reshard_after_backward(reshard_after_backward, *, recurse=True)

设置模块是否应在反向传播后重新分片参数。这在梯度累积期间可用于以更高内存为代价换取减少通信,因为在下一次前向传播前无需重新全收集未分片的参数。

参数

  • reshard_after_backward ([bool]) – 是否在反向传播后重新分片参数。
  • recurse ([bool]) – 是为所有FSDP子模块设置还是仅针对传入的模块设置。

set_unshard_in_backward(unshard_in_backward)

设置是否需要在反向传播时解除FSDP模块参数的共享状态。这一功能适用于专家级场景,当用户确定该FSDP模块参数组中的所有参数都不参与反向计算时(例如嵌入层),便可使用此设置。


unshard(async_op=False)

通过分配内存并全收集(all-gather)参数来解除模块参数的分片状态。此方法不会递归执行。解除分片操作遵循MixedPrecisionPolicy,因此如果设置了param_dtype,将按照该类型进行全收集。

参数

  • async_op ([bool]) - 若为True,则返回一个包含wait()方法的UnshardHandle对象用于等待解除分片操作;若为False,则返回None并在函数内部等待操作完成。

返回类型:Optional[UnshardHandle]

注意:当async_op=True时,FSDP会在模块的前向传播前自动等待待处理的解除分片操作。用户只需在需要前向传播前显式调用wait()方法即可。


class torch.distributed.fsdp.UnshardHandle 

一个用于等待 FSDPModule.unshard() 操作完成的句柄。


wait()

等待取消分片操作完成。这确保了当前流可以使用已取消分片的参数,这些参数现在已注册到模块中。


torch.distributed.fsdp.register_fsdp_forward_method(module, method_name)

module 上注册一个方法,使其被视为 FSDP 的前向传播方法。

FSDP 会在前向传播前执行参数的全收集(all-gather),并可选地在前向传播后释放参数(取决于 reshard_after_forward 的设置)。默认情况下,FSDP 仅对 nn.Module.forward() 执行此操作。此函数通过钩子机制,使用户指定的方法分别在执行前后运行前向/后向传播的预处理/后处理逻辑。如果 module 不是 FSDPModule 实例,则该操作无效。

参数

  • module (nn.Module) – 需要注册前向传播方法的模块。
  • method_name (str) – 前向传播方法的名称。

class torch.distributed.fsdp.MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True) 

该配置用于设置FSDP的混合精度。与autocast不同,这是在模块级别而非操作级别应用混合精度,意味着会保存低精度激活值用于反向传播,而高精度到低精度的转换仅发生在模块边界处。

FSDP与模块级混合精度配合良好,因为它始终在内存中保存高精度分片参数。换句话说,FSDP不需要额外内存来保存高精度参数副本用于优化器步骤。

变量说明:

  • param_dtype (Optional[torch.dtype]) - 指定未分片参数的数据类型,即前向/反向计算和参数全收集时使用的数据类型。若为None,则未分片参数使用原始数据类型。优化器步骤使用原始数据类型的已分片参数。(默认值:None
  • reduce_dtype (Optional[torch.dtype]) - 指定梯度规约(即reduce-scatter或all-reduce)时使用的数据类型。若为Noneparam_dtype不为None,则规约使用计算数据类型。该参数可用于在计算时使用低精度,同时保持梯度规约为全精度。若通过set_requires_gradient_sync()禁用梯度规约,FSDP将使用reduce_dtype累积梯度。(默认值:None
  • output_dtype (Optional[torch.dtype]) - 指定浮点前向输出结果的转换数据类型。可用于实现不同模块采用不同混合精度策略的场景。(默认值:None
  • cast_forward_inputs ([bool]) - 指定FSDP是否应将前向传播的浮点输入张量转换为param_dtype类型。

class torch.distributed.fsdp.OffloadPolicy 

这个基类表示不进行卸载的策略,仅用作 offload_policy 参数的默认值。


class torch.distributed.fsdp.CPUOffloadPolicy(pin_memory=True) 

该卸载策略将参数、梯度和优化器状态卸载到CPU。分片参数在all-gather操作前会从主机内存复制到设备内存。根据reshard_after_forward的设置,all-gather后的参数会被释放。

在反向传播过程中,分片梯度会从设备内存复制到主机内存,优化器步骤则在CPU上使用CPU优化器状态运行。

变量说明:

  • pin_memory ([bool]) – 是否固定分片参数和梯度的内存。固定内存可以实现更高效率的主机到设备/设备到主机的内存拷贝,并使拷贝操作与计算重叠。但固定内存无法被其他进程使用。若CPU内存不足,请将此参数设为False。(默认值:True

Tensor Parallelism - torch.distributed.tensor.parallel

Tensor Parallelism(张量并行,简称TP)构建在PyTorch DistributedTensor(DTensor)之上,提供多种并行风格:列并行(Colwise)、行并行(Rowwise)以及序列并行(Sequence Parallelism)。


警告:Tensor Parallelism API目前处于实验阶段,后续可能发生变更。

使用Tensor Parallelism并行化nn.Module的入口点是:

torch.distributed.tensor.parallel.parallelize_module(module, device_mesh=None, parallelize_plan=None, *, src_data_rank=0)

在PyTorch中通过基于用户指定方案并行化模块或子模块来应用张量并行。

我们根据parallelize_plan对模块或子模块进行并行化。该计划包含:

ParallelStyle,用于指示用户希望如何并行化模块或子模块。

用户还可以为每个模块的完全限定名称(FQN)指定不同的并行风格。

注意:parallelize_module仅接受一维DeviceMesh。如果使用二维或N维DeviceMesh,需先将DeviceMesh切片为一维子DeviceMesh再传入此API(例如device_mesh["tp"])。

参数

  • module (nn.Module) – 待并行化的模块。
  • device_mesh (DeviceMesh, 可选) – 描述DTensor设备网格拓扑的对象。若未指定,调用必须在DeviceMesh上下文中进行。
  • parallelize_plan (Union [ParallelStyle, Dict[str, ParallelStyle]], 可选) – 模块并行化方案。可以是包含张量并行输入/输出准备的ParallelStyle对象,或是模块FQN与其对应ParallelStyle对象的字典。若未指定,当前调用不会执行任何操作。

关键字参数

  • src_data_rank ( int , 可选) – 逻辑/全局张量的源数据秩,由distribute_tensor()用于将分片/副本分发到其他秩。默认使用每个DeviceMesh维度上的group_rank=0作为源数据以保持单设备语义。若显式传入Noneparallelize_module()将直接使用本地数据,而非通过分发/广播保持单设备语义。默认值:0

返回值:并行化后的nn.Module对象。

返回类型:Module


示例:

>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> >
>>> # Define the module.
>>> m = Model(...)
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
>>> >

注意:对于像Attention、MLP层这样的复杂模块架构,我们建议将不同的ParallelStyle组合使用(例如ColwiseParallelRowwiseParallel),并通过parallelize_plan传递,以实现所需的分片计算。

Tensor Parallelism支持以下并行风格:

class torch.distributed.tensor.parallel.ColwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)

以列式方式对兼容的 nn.Module 进行分区。当前支持 nn.Linearnn.Embedding

用户可将其与 RowwiseParallel 组合使用,以实现更复杂模块的分片(例如 MLP、Attention)。

关键字参数

  • input_layouts (Placement, 可选) – 输入张量在 nn.Module 中的 DTensor 布局,用于将输入张量标注为 DTensor。若未指定,则默认输入张量为副本形式。
  • output_layouts (Placement, 可选) – 输出张量在 nn.Module 中的 DTensor 布局,用于确保模块输出符合用户预期的布局。若未指定,输出张量将在最后一维分片。
  • use_local_output (bool, 可选) – 是否使用本地 torch.Tensor 而非 DTensor 作为模块输出,默认值:True。

返回

一个表示 nn.Module 列式分片的 ParallelStyle 对象。


示例

>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "w1" nn.Linear submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> >
>>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor
>>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim.
>>> >
>>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()})
>>> ...

注意:默认情况下,如果未指定 output_layoutsColwiseParallel 的输出会在最后一个维度上进行分片。如果存在需要特定张量形状的运算符(例如在配对的 RowwiseParallel 之前),请记住,若输出被分片,则可能需要根据分片后的尺寸调整该运算符。


class torch.distributed.tensor.parallel.RowwiseParallel(*, input_layouts=None, output_layouts=None, use_local_output=True)

以行方式对兼容的 nn.Module 进行分区。当前支持 nn.Linearnn.Embedding

用户可结合 ColwiseParallel 来实现更复杂模块的分片(例如 MLP、Attention)。

关键字参数

  • input_layouts (Placement, 可选) – 用于标注输入张量成为 DTensor 的布局参数。若未指定,则默认输入张量在最后一个维度分片。
  • output_layouts (Placement, 可选) – 确保模块输出符合用户预期布局的参数。若未指定,输出张量将被复制为全副本。
  • use_local_output (bool, 可选) – 是否使用本地 torch.Tensor 而非 DTensor 作为模块输出,默认值:True。

返回值
返回代表 nn.Module 行分片的 ParallelStyle 对象。


示例

>>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "w2" nn.Linear submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> >
>>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim
>>> # and the output of "w2" will return a replicated :class:`torch.Tensor`.
>>> >
>>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}), >>...

class torch.distributed.tensor.parallel.SequenceParallel(*, sequence_dim=1, use_local_output=False)

SequenceParallel(序列并行)会复制兼容的nn.Module参数,并在序列维度上对分片输入执行分片计算。当前支持nn.LayerNormnn.Dropout以及RMSNorm的Python实现

该模式实现了论文《减少大型Transformer模型中的激活重计算》中描述的操作。

若传入该nn.Module的输入是torch.Tensor,则假定输入已在序列维度分片,并将其转换为序列维度分片的DTensor。若传入的输入已是DTensor但未在序列维度分片,则会重新分配输入使其在序列维度分片。

nn.Module的输出将在序列维度分片。

关键字参数

  • sequence_dim (int, 可选) – 用于指定输入张量的序列维度,该参数会将输入张量标注为序列维度分片的DTensor,默认值:1
  • use_local_output (bool, 可选) – 是否对模块输出使用本地torch.Tensor而非DTensor,默认值:False

返回
一个代表nn.Module序列并行化的ParallelStyle对象。


示例

>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...)  # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> >
>>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
>>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`.
>>> >
>>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}), >>...

注意:SequenceParallel 风格假设 nn.Module 中的权重采用全1初始化(例如 nn.LayerNormRMSNorm,这些模块默认采用全1初始化)。如果这些模块的权重采用自定义初始化方式,则需要在并行化前后广播权重以确保权重被正确复制。

若只需为 nn.Module 的输入输出配置 DTensor 布局并执行必要的布局重分布,而无需将模块参数分发为 DTensor,在调用 parallelize_module 时可在 parallelize_plan 中使用以下 ParallelStyle

class torch.distributed.tensor.parallel.PrepareModuleInput(*, input_layouts=None, desired_input_layouts=None, input_kwarg_layouts=None, desired_input_kwarg_layouts=None, use_local_output=False)

配置 nn.Module 的输入参数,根据 input_layouts 在运行时将 nn.Module 的输入张量转换为 DTensor,并按照 desired_input_layouts 执行布局重分布。

关键字参数

  • input_layouts (Union[Placement, Tuple[Optional[Placement]]]) - 用于指定 nn.Module 输入张量的 DTensor 布局,该参数用于将输入张量转换为 DTensor。如果某些输入不是 torch.Tensor 或无需转换为 DTensor,需要用 None 作为占位符。默认值:None。
  • desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]) - 用于指定 nn.Module 输入张量的期望 DTensor 布局,该参数确保 nn.Module 的输入具有期望的 DTensor 布局。此参数需要与 input_layouts 长度相同。默认值:None。
  • input_kwarg_layouts (Dict[str, Placement]) - 用于指定 nn.Module 输入关键字参数的 DTensor 布局,该参数用于将输入关键字参数张量转换为 DTensor。默认值:None。
  • desired_input_kwarg_layouts – (Dict[str, Placement]) - 用于指定 nn.Module 输入关键字参数的期望 DTensor 布局,该参数确保 nn.Module 的输入具有期望的 DTensor 布局。默认值:None。
  • use_local_output ([bool], 可选) - 是否对模块输入使用本地 torch.Tensor 而非 DTensor。默认值:False。

返回值:返回一个 ParallelStyle 对象,用于准备 nn.Module 输入的分片布局。


示例:

>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> >
>>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor
>>> # and then redistributed to Replicated DTensor.
>>> parallelize_module(
>>>     block, # this can be a submodule or module
>>>     tp_mesh, >>    parallelize_plan={
>>>         "attn": PrepareModuleInput(
>>>             input_layouts=(Shard(0), None, None, 
...), >>            desired_input_layouts=(Replicate(), None, None, 
...)
>>>         ), >>    }
>>> )

class torch.distributed.tensor.parallel.PrepareModuleOutput(*, output_layouts, desired_output_layouts, use_local_output=True)

配置nn.Module的输出,在运行时根据output_layoutsnn.Module的输出张量转换为DTensor,并根据desired_output_layouts执行布局重分布。

关键字参数

  • output_layouts (Union[Placement , Tuple[Placement ]]) - 用于指定nn.Module输出张量的DTensor布局,当输出为torch.Tensor时将其转换为DTensor。如果某些输出不是torch.Tensor或无需转换,需要用None作为占位符。
  • desired_output_layouts (Union[Placement , Tuple[Placement ]]) - 指定nn.Module输出张量的目标DTensor布局,用于确保模块输出具有预期的DTensor布局。
  • use_local_output ([bool], 可选) - 是否对模块输出使用本地torch.Tensor而非DTensor,默认值为True。

返回值:返回一个ParallelStyle对象,用于设置nn.Module输出张量的分片布局。


示例:

>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> >
>>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor
>>> # and then redistributed to Sharded DTensor.
>>> parallelize_module(
>>>     block, # this can be a submodule or module
>>>     tp_mesh, >>    parallelize_plan = PrepareModuleOutput(
>>>         output_layouts=Replicate(), >>        desired_output_layouts=Shard(0)
>>>     )
>>> )

注意:当使用 Shard(dim) 作为上述 ParallelStyle 的输入/输出布局时,我们假设输入/输出激活张量在 TP 操作的 DeviceMesh 上沿张量维度 dim 均匀分片。例如,由于 RowwiseParallel 接受在最后一个维度分片的输入,它假设输入张量已在最后一个维度上均匀分片。对于非均匀分片的激活张量,用户可以直接将 DTensor 传入分区模块,并通过设置 use_local_output=False 使每个 ParallelStyle 处理后返回 DTensor,此时 DTensor 会记录非均匀分片信息。

对于 Transformer 这类模型,我们建议用户在 parallelize_plan 中同时使用 ColwiseParallelRowwiseParallel,以实现整个模型(包括注意力层和 MLP)的预期分片效果。

并行化的交叉熵损失计算(损失并行)可通过以下上下文管理器支持:

torch.distributed.tensor.parallel.loss_parallel()

一个支持损失并行计算的上下文管理器,当输入在类别维度上分片时,可以执行高效的并行化损失计算。目前仅支持交叉熵损失。

在该上下文管理器内,可以像往常一样使用 cross_entropy()CrossEntropyLoss,但需满足以下输入参数假设。

对应的 backward() 调用(如有)也需要在该上下文管理器下进行。

参数

  • input (DTensor) – 输入logits。假设在类别维度上分片。
  • target (Union [torch.Tensor, DTensor]) – 必须是真实类别索引(当前不支持类别概率)。假设在 DeviceMesh 上复制。
  • weight (Union [torch.Tensor, DTensor], 可选) – 如果提供,假设在 DeviceMesh 上复制。
  • label_smoothing – 当前不支持。

返回

一个复制的 DTensor


示例

这里手动创建了一个分片的DTensor来展示用法。实际应用中,它通常是TP模块的输出。


>>> from torch.distributed.tensor.parallel import loss_parallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> device_mesh = init_device_mesh("cuda", (8,))
>>> input = torch.randn(4, 16, device="cuda", requires_grad=True)
>>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)])
>>> target = torch.randint(16, (4,), device="cuda")
>>> with loss_parallel():
>>>     loss = F.cross_entropy(dist_input, target, reduction="mean")
>>>     loss.backward()
>>> ...

警告:loss_parallel API 目前处于实验阶段,后续可能会发生变化。


分布式优化器


警告:当前不支持在使用CUDA张量时使用分布式优化器

torch.distributed.optim提供了DistributedOptimizer,它接收一个远程参数列表(RRef)并在参数所在的worker节点上本地运行优化器。该分布式优化器可以使用任何本地优化器基类来在每个worker上应用梯度。


class torch.distributed.optim.DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs)

分布式优化器(DistributedOptimizer)接收分布在各个工作节点上的参数的远程引用,并为每个参数在本地应用指定的优化器。

该类通过 get_gradients() 方法来获取特定参数的梯度。

step() 的并发调用(无论来自相同或不同客户端)将在每个工作节点上串行执行——因为每个工作节点的优化器一次只能处理一组梯度。但无法保证完整的正向-反向-优化器序列会为一个客户端连续执行。这意味着应用的梯度可能不对应于给定工作节点上执行的最新正向传递。此外,不同工作节点之间也没有保证的执行顺序。

分布式优化器默认启用 TorchScript 来创建本地优化器,这样在多线程训练(如分布式模型并行)时,优化器更新不会被 Python 全局解释器锁(GIL)阻塞。目前大多数优化器都支持此功能。您也可以参考 PyTorch 教程中的这个示例来为自己的自定义优化器启用 TorchScript 支持。

参数

  • optimizer_class ([optim.Optimizer](optim.html#torch.optim.Optimizer "torch.optim.Optimizer")) – 要在每个工作节点上实例化的优化器类。
  • params_rref (list[RRef]) – 要优化的本地或远程参数的 RRef 列表。
  • args – 传递给每个工作节点上优化器构造函数的参数。
  • kwargs – 传递给每个工作节点上优化器构造函数的参数。

示例:

>>> import torch.distributed.autograd as dist_autograd
>>> import torch.distributed.rpc as rpc
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>> >
>>> with dist_autograd.context() as context_id:
>>>   # Forward pass.
>>>   rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>>   rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>>   loss = rref1.to_here() + rref2.to_here()
>>> >
>>>   # Backward pass.
>>>   dist_autograd.backward(context_id, [loss.sum()])
>>> >
>>>   # Optimizer.
>>>   dist_optim = DistributedOptimizer(
>>>      optim.SGD, >>     [rref1, rref2], >>     lr=0.05, >>  )
>>>   dist_optim.step(context_id)

step(context_id)

执行单次优化步骤。

该方法会在每个包含待优化参数的 worker 上调用 torch.optim.Optimizer.step(),并阻塞直至所有 worker 返回。提供的 context_id 将用于检索对应的 context,该上下文包含应应用于参数的梯度。

参数

  • context_id - 用于运行优化器步骤的自动求导上下文 ID。

class torch.distributed.optim.PostLocalSGDOptimizer(optim, averager)

封装一个任意的 torch.optim.Optimizer 并运行 post-local SGD。该优化器在每一步都运行本地优化器。

在预热阶段结束后,它会在应用本地优化器后定期对参数进行平均。

参数

  • optim ([Optimizer](optim.html#torch.optim.Optimizer "torch.optim.optimizer.Optimizer")) – 本地优化器。
  • averager (ModelAverager) – 用于运行 post-localSGD 算法的模型平均器实例。

示例

>>> import torch
>>> import torch.distributed as dist
>>> import torch.distributed.algorithms.model_averaging.averagers as averagers
>>> import torch.nn as nn
>>> from torch.distributed.optim import PostLocalSGDOptimizer
>>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
>>>   PostLocalSGDState, >>  post_localSGD_hook, >>)
>>> >
>>> model = nn.parallel.DistributedDataParallel(
>>>    module, device_ids=[rank], output_device=rank
>>> )
>>> >
>>> # Register a post-localSGD communication hook.
>>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
>>> model.register_comm_hook(state, post_localSGD_hook)
>>> >
>>> # Create a post-localSGD optimizer that wraps a local optimizer.
>>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as >># ``start_localSGD_iter`` used in ``PostLocalSGDState``.
>>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
>>> opt = PostLocalSGDOptimizer(
>>>     optim=local_optim, >>    averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
>>> )
>>> >
>>> # In the first 100 steps, DDP runs global gradient averaging at every step.
>>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default), >># and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
>>> for step in range(0, 200):
>>>    opt.zero_grad()
>>>    loss = loss_fn(output, labels)
>>>    loss.backward()
>>>    opt.step()

load_state_dict(state_dict)

这与 torch.optim.Optimizerload_state_dict() 方法功能相同,但还会将模型平均器的步长值恢复为提供的 state_dict 中保存的值。

如果 state_dict 中没有 "step" 条目,系统会发出警告并将模型平均器的步长初始化为 0。


state_dict()

这与 torch.optim.Optimizerstate_dict() 功能相同,但额外增加了一个条目用于记录模型平均器的步骤到检查点,以确保重新加载时不会再次引发不必要的预热过程。


step()

执行单次优化步骤(参数更新)。


class torch.distributed.optim.ZeroRedundancyOptimizer(params, optimizer_class, process_group=None, parameters_as_bucket_view=False, overlap_with_ddp=False, **defaults)

包装一个任意的 optim.Optimizer 并在组内各 rank 之间分片其状态。

分片方式遵循 ZeRO 论文描述。

每个 rank 的本地优化器实例仅负责更新约 1 / world_size 的参数,因此只需维护 1 / world_size 的优化器状态。本地参数更新完成后,每个 rank 会将其参数广播给所有其他节点,以保持所有模型副本的状态一致。

ZeroRedundancyOptimizer 可与 torch.nn.parallel.DistributedDataParallel 结合使用,以降低单 rank 的峰值内存消耗。

ZeroRedundancyOptimizer 使用排序贪心算法在每个 rank 上打包若干参数。每个参数仅属于单一 rank,不会被分割到多个 rank。这种划分是任意的,可能与参数注册顺序或使用顺序不一致。

参数

  • params (Iterable) - 包含所有待分片参数的 torch.Tensordict 的可迭代对象

关键字参数

  • optimizer_class (torch.nn.Optimizer) - 本地优化器的类

  • process_group (ProcessGroup, 可选) - torch.distributedProcessGroup(默认使用 torch.distributed.init_process_group() 初始化的 dist.group.WORLD

  • parameters_as_bucket_view ([bool], 可选) - 若为 True,参数会被打包到桶中以加速通信,且 param.data 字段指向桶视图的不同偏移量;若为 False,则单独通信每个参数,且每个 params.data 保持不变(默认:False

  • overlap_with_ddp ([bool], 可选) - 若为 Truestep() 将与 DistributedDataParallel 的梯度同步过程重叠执行,这要求:
    1、optimizer_class 必须是函数式优化器或具有等效功能
    2、需注册来自 ddp_zero_hook.py 的 DDP 通信钩子
    参数会打包为与 DistributedDataParallel 匹配的桶,此时 parameters_as_bucket_view 参数将被忽略。

    若为 Falsestep() 将在反向传播后独立执行(默认行为)(默认:False

  • **defaults - 其他尾部参数,将传递给本地优化器


示例

>>> import torch.nn as nn
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> model = nn.Sequential([nn.Linear(2000, 2000).to(rank) for _ in range(20)])
>>> ddp = DDP(model, device_ids=[rank])
>>> opt = ZeroRedundancyOptimizer(
>>>     ddp.parameters(), >>    optimizer_class=torch.optim.Adam, >>    lr=0.01
>>> )
>>> ddp(inputs).sum().backward()
>>> opt.step()

警告:目前 ZeroRedundancyOptimizer 要求所有传入参数必须是相同密集类型。

警告:如果设置 overlap_with_ddp=True,请注意以下情况:根据当前 DistributedDataParallelZeroRedundancyOptimizer 重叠的实现方式,前两到三次训练迭代不会执行优化器参数更新(具体次数取决于 static_graph=Falsestatic_graph=True)。这是因为需要获取 DistributedDataParallel 使用的梯度分桶策略信息——当 static_graph=False 时该信息在第二次前向传播后才会确定,而 static_graph=True 时则需等到第三次前向传播。解决方法之一是在训练数据前添加虚拟输入。

警告:ZeroRedundancyOptimizer 仍处于实验阶段,功能可能发生变化。


add_param_group(param_group)

Optimizerparam_groups 添加一个参数组。

在微调预训练网络时,这个方法非常有用——随着训练进行,原本冻结的层可以变为可训练状态,并添加到 Optimizer 中。

参数说明

  • param_group ( dict ) - 指定待优化的参数及该组特有的优化选项。

警告说明
此方法会更新所有分区的参数分片,但必须在所有计算节点上调用。若仅部分节点调用该方法,会导致训练挂起,因为通信原语的调用依赖于托管参数,且要求所有节点必须基于同一组参数参与计算。


consolidate_state_dict(to=0)

将各 rank 的 state_dict 列表(每个 rank 一个)合并到目标 rank 上。

参数

  • to (int) – 接收优化器状态的 rank(默认值:0)。

抛出异常
RuntimeError – 若 overlap_with_ddp=True 且此方法在 ZeroRedundancyOptimizer 实例完全初始化前被调用(完全初始化需等待 DistributedDataParallel 梯度桶重建完成)。

警告:必须在所有 rank 上调用此方法。


property join_device:  device 

返回默认设备。


join_hook(**kwargs)

返回 ZeRO 连接钩子。

该钩子通过遮蔽优化器步骤中的集体通信操作,实现在非均匀输入数据上的训练。

调用此钩子前必须正确设置梯度。

参数

  • kwargs ( dict ) – 一个包含运行时修改连接钩子行为的关键字参数的字典;所有共享同一连接上下文管理器的 Joinable 实例都会收到相同的 kwargs 值。

此钩子不支持任何关键字参数,即 kwargs 未被使用。


property join_process_group:  Any

返回进程组。


load_state_dict(state_dict)

从输入的 state_dict 中加载与指定 rank 相关的状态,并根据需要更新本地优化器。

参数

  • state_dict ( dict ) – 优化器状态;应为调用 state_dict() 返回的对象。

抛出异常

RuntimeError – 如果 overlap_with_ddp=True 且此方法在 ZeroRedundancyOptimizer 实例完全初始化之前被调用(完全初始化发生在 DistributedDataParallel 梯度桶重建完成之后)。


state_dict()

返回当前节点已知的最后一个全局优化器状态。

可能引发的异常

RuntimeError —— 当满足以下任一条件时抛出:
1、设置overlap_with_ddp=True时,在ZeroRedundancyOptimizer实例完全初始化前调用本方法(初始化完成标志是DistributedDataParallel梯度桶重建完成);
2、调用本方法前未先调用consolidate_state_dict()方法。

返回类型:dict[str , Any ]


step(closure=None, **kwargs)

执行单次优化器步骤并同步所有进程间的参数。

参数

  • closure (Callable) – 用于重新评估模型并返回损失值的闭包函数;大多数优化器可省略此参数。

返回值:取决于底层本地优化器的可选损失值。

返回类型:Optional[float]

注意:所有额外参数都将原样传递给基础优化器。


流水线并行


注意:torch.distributed.pipelining 目前处于 alpha 阶段且正在开发中。API 可能会发生变化。该功能是从 PiPPy 项目迁移而来。


为什么需要流水线并行?

流水线并行是深度学习中基础的并行方式之一。它允许将模型执行过程进行划分,使得多个微批次能够同时执行模型代码的不同部分。流水线并行在以下场景中尤为有效:

  • 大规模训练
  • 带宽受限的集群
  • 大模型推理

这些场景的共同特点是:单个设备的计算量无法掩盖传统并行方式(如FSDP的权重全收集操作)带来的通信开销。


什么是 torch.distributed.pipelining

虽然流水线并行在扩展性方面前景广阔,但其实现往往颇具挑战性,因为它不仅需要对模型权重进行划分,还需要拆分模型执行过程。执行过程的划分通常需要对模型代码进行侵入式修改。另一重复杂性来源于分布式环境中的微批次调度,同时还需考虑数据流依赖关系

pipelining 包提供了一套自动化工具链,能够自动完成上述操作,从而在通用模型上轻松实现流水线并行。

该工具包包含两个核心组件:拆分前端分布式运行时。拆分前端直接接收原始模型代码,将其分割为多个"模型分区",并捕获数据流关系。分布式运行时则在不同设备上并行执行流水线阶段,处理微批次划分、调度、通信和梯度传播等任务。

总体而言,pipelining 包提供以下核心功能:

  • 基于简单配置的模型代码自动拆分
  • 全面支持多种流水线调度策略(包括GPipe、1F1B、交错式1F1B和循环BFS),并提供自定义调度器开发基础设施
  • 原生支持跨主机流水线并行(这是PP技术最典型的应用场景,适用于低速网络互联环境)
  • 可与PyTorch其他并行技术(如数据并行DDP/FSDP或张量并行)组合使用。TorchTitan项目展示了在Llama模型上实现"3D并行"的应用案例。

第一步:构建 PipelineStage

在使用 PipelineSchedule 之前,我们需要先创建 PipelineStage 对象,这些对象封装了在该阶段运行的模型部分。PipelineStage 负责分配通信缓冲区,并创建发送/接收操作以与对等节点通信。它管理中间缓冲区,例如尚未被消费的前向输出,并提供运行阶段模型反向传播的实用工具。

PipelineStage 需要知道阶段模型的输入和输出形状,以便正确分配通信缓冲区。这些形状必须是静态的,即在运行时,形状不能每一步都变化。如果运行时形状与预期形状不匹配,将抛出 PipeliningShapeError 异常。在与其他并行技术组合或应用混合精度时,必须考虑这些技术,以便 PipelineStage 在运行时知道阶段模块输出的正确形状(和数据类型)。

用户可以直接构造 PipelineStage 实例,方法是传入一个 nn.Module,表示应在该阶段运行的模型部分。这可能需要对原始模型代码进行修改。具体示例请参见选项1:手动拆分模型

或者,拆分前端可以使用图分区技术自动将模型拆分为一系列 nn.Module。此技术要求模型可以通过 torch.Export 进行追踪。生成的 nn.Module 与其他并行技术的组合仍处于实验阶段,可能需要一些变通方法。如果用户难以修改模型代码,使用此前端可能更具吸引力。更多信息请参见选项2:自动拆分模型


步骤2:使用PipelineSchedule执行

现在我们可以将PipelineStage附加到流水线调度器上,并通过输入数据运行该调度器。以下是一个GPipe示例:

from torch.distributed.pipelining import ScheduleGPipe

# Create a schedule
schedule = ScheduleGPipe(stage, n_microbatches)

# Input data (whole batch)
x = torch.randn(batch_size, in_dim, device=device)

# Run the pipeline with input `x`
# `x` will be divided into microbatches automatically if rank == 0:
    schedule.step(x)
else:
    output = schedule.step()

请注意,上述代码需要在每个工作节点上启动,因此我们使用一个启动器服务来启动多个进程:

torchrun --nproc_per_node=2 example.py

模型分割方案


方案一:手动拆分模型

要直接构建一个PipelineStage,用户需要负责提供一个单独的nn.Module实例,该实例需包含相关的nn.Parametersnn.Buffers,并定义一个forward()方法来执行该阶段相关的操作。例如,Torchtitan中定义的Transformer类精简版展示了一种构建易于分区模型的模式。


class Transformer(nn.Module):
    def __init__(self, model_args: ModelArgs):
        super().__init__()

        self.tok_embeddings = nn.Embedding(...)

        # Using a ModuleDict lets us delete layers without affecting names,  # ensuring checkpoints will correctly save and load.
        self.layers = torch.nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(...)

        self.output = nn.Linear(...)

    def forward(self, tokens: torch.Tensor):
        # Handling layers being 'None' at runtime enables easy pipeline splitting
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

        for layer in self.layers.values():
            h = layer(h, self.freqs_cis)

        h = self.norm(h) if self.norm else h
        output = self.output(h).float() if self.output else h
        return output

以这种方式定义的模型可以轻松按阶段进行配置,具体步骤如下:

首先初始化整个模型(使用 meta-device 避免内存不足错误),然后删除该阶段不需要的层,最后创建一个封装模型的 PipelineStage。例如:

with torch.device("meta"):
    assert num_stages == 2, "This is a simple 2-stage example"

    # we construct the entire model, then delete the parts we do not need for this stage
    # in practice, this can be done using a helper function that automatically divides up layers across stages.
    model = Transformer()

    if stage_index == 0:
        # prepare the first stage model
        del model.layers["1"]
        model.norm = None
        model.output = None

    elif stage_index == 1:
        # prepare the second stage model
        model.tok_embeddings = None
        del model.layers["0"]

    from torch.distributed.pipelining import PipelineStage
    stage = PipelineStage(
        model,  stage_index,  num_stages,  device, )

当与其他数据或模型并行技术结合使用时,如果模型分块的输出形状/数据类型会受到影响,可能还需要指定 output_args


选项2:自动拆分模型

如果您拥有完整模型,且不想花费时间将其修改为一系列"模型分区",那么pipeline API可以为您提供帮助。以下是一个简单示例:

class Model(torch.nn.Module):
    def __init__(self) -None:
        super().__init__()
        self.emb = torch.nn.Embedding(10, 3)
        self.layers = torch.nn.ModuleList(
            Layer() for _ in range(2)
        )
        self.lm = LMHead()

    def forward(self, x: torch.Tensor) -torch.Tensor:
        x = self.emb(x)
        for layer in self.layers:
            x = layer(x)
        x = self.lm(x)
        return x

如果打印模型,我们会看到多个层级结构,这使得手动拆分变得困难:

Model(
  (emb): Embedding(10, 3)
  (layers): ModuleList(
    (0-1): 2 x Layer(
      (lin): Linear(in_features=3, out_features=3, bias=True)
    )
  )
  (lm): LMHead(
    (proj): Linear(in_features=3, out_features=3, bias=True)
  )
)

让我们看看 pipeline API 的工作原理:

from torch.distributed.pipelining import pipeline, SplitPoint

# An example micro-batch input
x = torch.LongTensor([1, 2, 4, 5])

pipe = pipeline(
    module=mod, mb_args=(x,), split_spec={
        "layers.1": SplitPoint.BEGINNING, }
)


pipeline API 根据给定的 split_spec 对模型进行分割,其中:

SplitPoint.BEGINNING 表示在 forward 函数中特定子模块执行之前添加分割点,类似地,SplitPoint.END 表示在此类子模块执行之后添加分割点。

如果我们执行 print(pipe),可以看到:

GraphModule(
  (submod_0): GraphModule(
    (emb): InterpreterModule()
    (layers): Module(
      (0): InterpreterModule(
        (lin): InterpreterModule()
      )
    )
  )
  (submod_1): GraphModule(
    (layers): Module(
      (1): InterpreterModule(
        (lin): InterpreterModule()
      )
    )
    (lm): InterpreterModule(
      (proj): InterpreterModule()
    )
  )
)

def forward(self, x):
    submod_0 = self.submod_0(x);  x = None
    submod_1 = self.submod_1(submod_0);  submod_0 = None
    return (submod_1,)


“模型分区”由子模块(submod_0submod_1)表示,每个子模块都使用原始模型的操作、权重和层次结构进行重构。此外,还重构了一个“根级别”的forward函数,用于捕获这些分区之间的数据流。后续将由流水线运行时以分布式方式重放这些数据流。

Pipe对象提供了一个用于检索“模型分区”的方法:

stage_mod : nn.Module = pipe.get_stage_module(stage_idx)

返回的 stage_mod 是一个 nn.Module,你可以用它来创建优化器、保存或加载检查点,或者应用其他并行策略。

Pipe 还允许你基于给定的 ProcessGroup 在设备上创建分布式阶段运行时环境。


stage = pipe.build_stage(stage_idx, device, group)

或者,如果您希望在修改 stage_mod 后稍后再构建 stage 运行时,可以使用 build_stage API 的函数式版本。例如:

from torch.distributed.pipelining import build_stage
from torch.nn.parallel import DistributedDataParallel

dp_mod = DistributedDataParallel(stage_mod)
info = pipe.info()
stage = build_stage(dp_mod, stage_idx, info, device, group)

注意:pipeline 前端使用追踪器 (torch.export) 将你的模型捕获为单一计算图。如果你的模型无法实现全图捕获,可以使用下方提供的手动前端。


Hugging Face 示例

在最初创建此包的 PiPPy 代码库中,我们保留了基于未修改的 Hugging Face 模型的示例。请参阅 examples/huggingface 目录。


示例包括:


技术深度解析


pipeline API 如何分割模型?

首先,pipeline API 通过追踪模型将其转换为有向无环图(DAG)。它使用 PyTorch 2 的全图捕获工具 torch.export 来追踪模型。

然后,它将一个阶段所需的操作和参数分组到重建的子模块中:submod_0submod_1 等。

与传统的子模块访问方法(如 Module.children())不同,pipeline API 不仅切割模型的模块结构,还会切割模型的 forward 函数。

这是必要的,因为像 Module.children() 这样的模型结构仅捕获 Module.__init__() 期间的信息,而不会捕获任何关于 Module.forward() 的信息。换句话说,Module.children() 缺少以下对流水线至关重要的信息:

  • forward 中子模块的执行顺序
  • 子模块之间的激活流
  • 子模块之间是否存在任何函数式操作(例如,reluadd 操作不会被 Module.children() 捕获)。

相反,pipeline API 确保 forward 行为被完整保留。它还捕获分区之间的激活流,帮助分布式运行时无需人工干预即可正确执行发送/接收调用。

pipeline API 的另一个灵活性是,分割点可以位于模型层次结构的任意级别。在分割后的分区中,与该分区相关的原始模型层次结构会被重建,且不会带来额外开销。因此,指向子模块或参数的完全限定名称(FQN)仍然有效,依赖 FQN 的服务(如 FSDP、TP 或检查点)几乎无需代码更改即可继续运行。


实现自定义调度策略

您可以通过扩展以下两个基类之一来实现自己的流水线调度策略:

  • PipelineScheduleSingle
  • PipelineScheduleMulti

PipelineScheduleSingle 适用于每个计算节点仅分配单个阶段的调度策略。
PipelineScheduleMulti 则适用于每个计算节点分配多个阶段的调度策略。

例如:

  • ScheduleGPipeSchedule1F1BPipelineScheduleSingle 的子类
  • ScheduleInterleaved1F1BScheduleLoopedBFSScheduleInterleavedZeroBubble 以及 ScheduleZBVZeroBubble 则是 PipelineScheduleMulti 的子类

日志记录

您可以通过设置torch._logging中的TORCH_LOGS环境变量来启用额外的日志记录功能:

  • TORCH_LOGS=+pp 会显示logging.DEBUG及以上级别的日志信息
  • TORCH_LOGS=pp 会显示logging.INFO及以上级别的日志信息
  • TORCH_LOGS=-pp 会显示logging.WARNING及以上级别的日志信息

API 参考


模型拆分 API

以下一组 API 可将您的模型转换为流水线表示形式。


class torch.distributed.pipelining.SplitPoint(value)

表示在子模块执行过程中可插入切分点的枚举类型。

:ivar BEGINNING: 表示在前向函数中某个子模块执行之前添加切分点。

:ivar END: 表示在前向函数中某个子模块执行之后添加切分点。


torch.distributed.pipelining.pipeline(module, mb_args, mb_kwargs=None, split_spec=None, split_policy=None)

根据规范拆分模块。

更多详情请参阅 Pipe。

参数

  • module ( Module ) – 待拆分的模块。
  • mb_args ( tuple [Any , ...]) – 示例位置输入,以微批次形式提供。
  • mb_kwargs (Optional[dict[str, Any ]]) – 示例关键字输入,以微批次形式提供。(默认值:None)
  • split_spec (Optional[dict[str, torch.distributed.pipelining._IR.SplitPoint]]) – 使用子模块名称作为拆分标记的字典。(默认值:None)
  • split_policy (Optional[Callable [[GraphModule)],GraphModule]]) – 用于拆分模块的策略。(默认值:None)

返回类型:返回 Pipe 类的流水线表示形式。


class torch.distributed.pipelining.Pipe(split_gm, num_stages, has_loss_and_backward, loss_spec)

torch.distributed.pipelining.pipe_split()

pipe_split 是一个特殊运算符,用于标记模块中各阶段之间的边界。它的作用是将模块拆分为多个阶段。如果你以即时执行模式运行带注解的模块,该运算符不会产生任何效果。


示例:

>>> def forward(self, x):
>>>     x = torch.mm(x, self.mm_param)
>>>     x = torch.relu(x)
>>>     pipe_split()
>>>     x = self.lin(x)
>>>     return x


上述示例将被拆分为两个阶段。


微批次工具集


class torch.distributed.pipelining.microbatch.TensorChunkSpec(split_dim)

用于指定输入分块的类


torch.distributed.pipelining.microbatch.split_args_kwargs_into_chunks(args, kwargs, chunks, args_chunk_spec=None, kwargs_chunk_spec=None)

根据给定的参数序列(args和kwargs),按照各自的分块规格将它们分割成多个块。

参数说明:

  • args (tuple[Any, ...]) - 参数元组
  • kwargs (Optional[dict[str, Any]]) - 关键字参数字典
  • chunks (int) - 要将args和kwargs分割成的块数
  • args_chunk_spec (Optional[tuple[torch.distributed.pipelining.microbatch.TensorChunkSpec, ...]]) - args的分块规格,形状与args相同
  • kwargs_chunk_spec (Optional[dict[str, torch.distributed.pipelining.microbatch.TensorChunkSpec]]) - kwargs的分块规格,形状与kwargs相同

返回值说明:

  • args_split: 分割后的参数列表
  • kwargs_split: 分割后的关键字参数字典列表
  • 返回类型: args_split

torch.distributed.pipelining.microbatch.merge_chunks(chunks, chunk_spec)

根据分块规范将给定的分块列表合并为单个值。

参数

  • chunks (list[Any ]) - 分块列表
  • chunk_spec - 分块的分块规范

返回值:合并后的值

返回类型:值


流水线阶段


class torch.distributed.pipelining.stage.PipelineStage(submodule, stage_index, num_stages, device, input_args=None, output_args=None, group=None, dw_builder=None)

一个表示流水线并行设置中流水线阶段的类。

PipelineStage 假设模型采用顺序分区方式,即模型被分割成多个块,其中一个块的输出作为下一个块的输入,不存在跳跃连接。

PipelineStage 通过按线性顺序将 stage0 的输出传播到 stage1 等方式,自动执行运行时形状/数据类型推断。若要绕过形状推断,需将 input_args 和 output_args 传递给每个 PipelineStage 实例。

参数

  • submodule (nn.Module) – 该阶段封装的 PyTorch 模块。
  • stage_index ( int ) – 本阶段的 ID。
  • num_stages ( int ) – 阶段总数。
  • device ( torch.device ) – 本阶段所在的设备。
  • input_args (Union[torch.Tensor, Tuple[torch.tensor]], 可选) – 子模块的输入参数。
  • output_args (Union[torch.Tensor, Tuple[torch.tensor]], 可选) – 子模块的输出参数。
  • group (dist.ProcessGroup, 可选) – 分布式训练的进程组。若为 None,则使用默认组。
  • dw_builder (Optional[Callable[[], Callable[...*, None]]) – 若提供,dw_builder 将构建一个新的 dw_runner 函数,该函数会为 F、I、W(前向、输入、权重)零气泡调度执行 W 动作(输入权重)。

torch.distributed.pipelining.stage.build_stage(stage_module, stage_index, pipe_info, device, group=None)

创建一个流水线阶段,给定需要被该阶段包装的stage_module以及流水线信息。


参数

  • stage_module ( torch.nn.Module ) – 需要被该阶段包装的模块
  • stage_index ( int ) – 该阶段在流水线中的索引
  • pipe_info (PipeInfo) – 关于流水线的信息,可通过pipe.info()获取
  • device ( torch.device ) – 该阶段使用的设备
  • group (Optional[dist.ProcessGroup]) – 该阶段使用的进程组

返回一个可与PipelineSchedules一起运行的流水线阶段。

返回类型:_PipelineStage


流水线调度


class torch.distributed.pipelining.schedules.ScheduleGPipe(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)

GPipe调度方案。

采用填充-排空的方式处理所有微批次数据。


class torch.distributed.pipelining.schedules.Schedule1F1B(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)

1F1B调度方案。

在稳定状态下,将对微批次执行一次前向和一次后向操作。


class torch.distributed.pipelining.schedules.ScheduleInterleaved1F1B(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)

交错式1F1B调度方案。

详情请参阅https://arxiv.org/pdf/2104.04473

在稳定状态下,该方案会对微批次执行一次前向和一次后向计算,并支持每个rank处理多个阶段。当微批次准备好进行多个本地阶段计算时,交错式1F1B会优先处理较早的微批次(也称为"深度优先")。

该调度方案与原始论文基本相似,主要区别在于放宽了num_microbatch % pp_size == 0的要求。使用flex_pp调度时,我们会得到num_rounds = max(1, n_microbatches // pp_group_size),只要满足n_microbatches % num_rounds == 0即可正常工作。例如:

1、pp_group_size = 4,n_microbatches = 10时,num_rounds = 2且n_microbatches % 2 == 0
2、pp_group_size = 4,n_microbatches = 3时,num_rounds = 1且n_microbatches % 1 == 0


class torch.distributed.pipelining.schedules.ScheduleLoopedBFS(stages, n_microbatches, loss_fn=None, output_merge_spec=None, scale_grads=True)

广度优先流水线并行。

详情请参阅https://arxiv.org/abs/2211.05953

与交错式1F1B类似,循环BFS支持每个rank运行多个阶段。

不同之处在于,当多个本地阶段的微批次准备就绪时,循环BFS会优先处理较早的阶段,一次性运行所有可用的微批次。


class torch.distributed.pipelining.schedules.ScheduleInterleavedZeroBubble(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)

零气泡计划(ZBV变体)。

详情请参见 https://arxiv.org/pdf/2401.10241 第6节。

此计划要求每个等级恰好有两个阶段。

该计划将在稳定状态下对微批次的输入执行一次前向传播和一次后向传播,并支持每个等级有多个阶段。使用相对于权重的后向传播来填补管道气泡。

只有当时间前向传播 == 时间后向传播输入 == 时间后向传播权重时,这个ZB-V计划才具有“零气泡”属性。

实际上,对于真实模型来说,这不太可能是真的,所以可以选择实现一个贪婪调度器来处理不平等或不平衡的时间。


class torch.distributed.pipelining.schedules.ScheduleZBVZeroBubble(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)

零气泡调度方案(ZBV变体)。

详情请参阅https://arxiv.org/pdf/2401.10241第6节。

该调度方案要求每个rank(计算节点)严格使用两个阶段。

在稳定状态下,该方案会对微批次的输入执行一次前向传播和一次反向传播,并支持每个rank多阶段处理。通过权重反向传播来填补流水线气泡。

只有当满足以下条件时,该ZB-V调度方案才具备"零气泡"特性:前向传播时间 == 输入反向传播时间 == 权重反向传播时间。

实际应用中,真实模型很难满足这一条件。因此,针对时间不均衡的情况,可以改用贪心调度器实现。


class torch.distributed.pipelining.schedules.PipelineScheduleSingle(stage, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, scale_grads=True)

单阶段计划的基础类。

实现了步骤方法。

派生类应该实现 _step_microbatches 方法。

根据 scale_grads 参数,梯度会根据 num_microbatches 进行缩放,默认为 True。这个设置应该与您的 loss_fn 的配置相匹配,loss_fn 可能是平均损失(scale_grads=True)或总和损失(scale_grads=False)。


step(*args, target=None, losses=None, **kwargs)

运行一次管道计划的迭代,使用 whole-batch 输入。

将自动将输入分块为微批次,并根据计划实现依次处理微批次。

args: 模型的位置参数(与非管道情况相同)。

kwargs: 模型的关键字参数(与非管道情况相同)。

target: 损失函数的目标。

losses: 用于存储每个微批次的损失的列表。


class torch.distributed.pipelining.schedules.PipelineScheduleMulti(stages, n_microbatches, loss_fn=None, args_chunk_spec=None, kwargs_chunk_spec=None, output_merge_spec=None, use_full_backward=None, scale_grads=True)

多阶段计划的基础类。

实现了步骤方法。

根据 scale_grads 参数,梯度会根据 num_microbatches 进行缩放,默认为 True。这个设置应该与您的 loss_fn 的配置相匹配,loss_fn 可能是平均损失(scale_grads=True)或总和损失(scale_grads=False)。


step(*args, target=None, losses=None, **kwargs)

运行管道调度的一次迭代,使用全批次输入。

该方法会自动将输入切分为微批次,并根据调度实现依次处理这些微批次。

参数说明:

  • args: 传递给模型的位置参数(与非管道式情况相同)
  • kwargs: 传递给模型的关键字参数(与非管道式情况相同)
  • target: 损失函数的目标值
  • losses: 用于存储每个微批次损失值的列表

分布式检查点 - torch.distributed.checkpoint

分布式检查点(DCP)支持并行地从多个计算节点加载和保存模型。它能够处理加载时的重分片操作,这使得在一个集群拓扑中保存的模型可以加载到另一个不同拓扑的集群中。

DCP与torch.save和torch.load在几个重要方面存在差异:

  • 每个检查点会生成多个文件,每个计算节点至少对应一个文件
  • 它以原地方式操作,这意味着模型需要先分配其数据存储空间,DCP会直接使用这些预分配的存储空间

以下是加载和保存检查点的主要入口方法:

附加资源:


class torch.distributed.checkpoint.state_dict_saver.AsyncCheckpointerType(value)

异步检查点类型的枚举。


torch.distributed.checkpoint.state_dict_saver.save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None, no_dist=False)

以SPMD风格保存分布式模型。

此函数与torch.save()不同,它通过让每个rank仅保存本地分片来处理ShardedTensorDTensor

对于每个Stateful对象(同时具有state_dictload_state_dict方法),保存操作会在序列化前调用state_dict


警告:不同PyTorch版本间保存的state_dict不保证具有向后兼容性。


警告:如果使用process_group参数,请确保只有该组的rank调用save_state_dict,且state_dict中的所有数据都属于该组。


注意:当为FSDP的ShardingStrategy.HYBRID_SHARD保存检查点时,只有一个shard_group应调用save_state_dict,且需要传入对应的process_group。


注意:

如果没有可用的进程组,此函数会假定意图是在本地进程中保存state_dict。


参数

  • state_dict (Dict[str, Any]) – 要保存的state_dict。
  • checkpoint_id (Union[str, os.PathLike, None]) – 检查点实例的ID。checkpoint_id的具体含义取决于存储类型,可以是文件夹路径、文件路径,或者键值存储中的键名(默认:None)。
  • storage_writer (Optional[StorageWriter]) – 用于执行写入操作的StorageWriter实例。如果未指定,DCP会根据checkpoint_id自动推断写入器。如果checkpoint_id也为None,将抛出异常(默认:None)。
  • planner (Optional[SavePlanner]) – SavePlanner实例。如果未指定,将使用默认planner(默认:None)。
  • process_group (Optional[ProcessGroup]) – 用于跨rank同步的进程组(默认:None)。
  • no_dist ([bool]) – 如果为True,此函数将假定意图是在不使用跨rank同步的情况下加载检查点(默认:False)。

返回

已保存检查点的元数据对象。

返回类型:Metadata


示例


>>> my_model = MyModule()


>>> state_dict = {"model": my_model}


>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
...     "/checkpoint/1"
... )
>>> torch.distributed.checkpoint.save(
>>>     state_dict=state_dict, >>    storage_writer=fs_storage_writer, >>)

注意save_state_dict 使用集合通信(collectives)来协调不同进程间的写入操作。

对于基于 NCCL 的进程组,对象的内部张量表示必须在通信发生前移至 GPU 设备。

此时,所使用的设备由 torch.cuda.current_device() 指定,用户需自行确保通过 torch.cuda.set_device() 正确设置,使每个进程独占一个 GPU。


torch.distributed.checkpoint.state_dict_saver.async_save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None, async_checkpointer_type=AsyncCheckpointerType.THREAD)

save 的异步版本。该代码首先将 state_dict 从暂存区移出到暂存存储(默认为 CPU 内存),然后在单独的线程中调用保存操作。

警告:此功能为实验性质,可能会发生变化。

参数

  • state_dict (Dict[str, Any]) – 要保存的 state_dict。
  • checkpoint_id (Union[str,* os.PathLike, None]) – 此检查点实例的 ID。checkpoint_id 的具体含义取决于存储类型。它可以是文件夹路径或文件路径。如果存储是键值存储,它也可以是键。(默认值:None
  • storage_writer (Optional[StorageWriter)]) – 用于执行 ‘stage’ 和 ‘save’ 的 StorageWriter 实例。如果未指定,DCP 将根据 checkpoint_id 自动推断写入器。如果 checkpoint_id 也为 None,则会抛出异常。(默认值:None
  • planner (Optional[SavePlanner]) – SavePlanner 实例。如果未指定,将使用默认的 planner。(默认值:None
  • process_group (Optional[ProcessGroup]) – 用于跨 rank 同步的 ProcessGroup。(默认值:None

返回值:一个持有保存操作返回的 Metadata 对象的 future。

返回类型:Future

示例


>>> my_model = MyModule()


>>> state_dict = {"model": my_model}


>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
...     "/checkpoint/1"
... )
>>> checkpoint_future = torch.distributed.checkpoint.async_save(
>>>     state_dict=state_dict, >>    storage_writer=fs_storage_writer, >>)
>>> >
>>> # ... do some work ...
>>> >
>>> checkpoint_future.result()


torch.distributed.checkpoint.state_dict_saver.save_state_dict(state_dict, storage_writer, process_group=None, coordinator_rank=0, no_dist=False, planner=None)

此方法已弃用。请改用 save

返回类型:Metadata


torch.distributed.checkpoint.state_dict_loader.load(state_dict, *, checkpoint_id=None, storage_reader=None, planner=None, process_group=None, no_dist=False)

以SPMD风格将检查点加载到分布式状态字典中。

每个进程提供的state_dict必须包含相同的键。键不匹配可能导致挂起或错误。如果不确定,可以使用utils._assert_same_keys API进行检查(但可能会产生通信开销)。

每个进程会尝试读取最少量的数据来填充请求的state_dict。当加载ShardedTensorDTensor实例时,每个进程仅读取其本地分片的数据。

对于每个Stateful对象(同时具有state_dictload_state_dict方法),加载操作会先调用state_dict,然后尝试反序列化,反序列化完成后调用load_state_dict

对于非Stateful对象,加载操作会反序列化对象,然后在state_dict中用反序列化后的对象替换原对象。


警告:state_dict中的所有张量必须在调用此函数之前分配到目标设备上。

所有非张量数据使用torch.load()加载,并在state_dict中就地修改。


警告:用户必须在根模块上调用load_state_dict,以确保加载后处理和非张量数据正确传播。


参数

  • state_dict (Dict[str, Any]) – 要加载检查点的状态字典。
  • checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的ID。checkpoint_id的含义取决于存储类型。可以是文件夹路径、文件路径,如果存储是键值存储也可以是键名。(默认: None)
  • storage_reader (Optional[[StorageReader](https://pytorch.org/docs/stable/data.html#torch.distributed.checkpoint.StorageReader "torch.distributed.checkpoint.StorageReader")]) – 用于执行读取操作的StorageWriter实例。如果未指定,DCP会根据checkpoint_id自动推断读取器。如果checkpoint_id也为None,则会抛出异常。(默认: None)
  • planner (Optional[LoadPlanner]) – LoadPlanner实例。如果未指定,将使用默认规划器。(默认: None)
  • process_group (Optional[ProcessGroup]) – 用于跨进程同步的ProcessGroup。(默认: None)
  • no_dist ([bool]) – 如果为True,此函数将假定目的是在不使用跨进程同步的情况下加载检查点。(默认: False)

返回

无。

返回类型:无


示例


>>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(
...     "/checkpoint/1"
... )


>>> torch.distributed.checkpoint.load_state_dict(
>>>     state_dict=model_state_dict, >>    storage_reader=fs_storage_reader, >>)


>>> # module.load_state_dict() function might have customized steps
>>> # to flush the state_dict, must call it to >># ensure correct behavior.
>>> my_model.load_state_dict(model_state_dict)

注意load_state_dict 使用集合通信来协调跨进程的读取操作。

对于基于 NCCL 的进程组,对象的内部张量表示必须在通信发生前移至 GPU 设备。

此时使用的设备由 torch.cuda.current_device() 指定,用户需自行确保通过 torch.cuda.set_device() 正确设置,使每个进程独占一个 GPU。


torch.distributed.checkpoint.state_dict_loader.load_state_dict(state_dict, storage_reader, process_group=None, coordinator_rank=0, no_dist=False, planner=None)

该方法已弃用,请改用 load

以下模块还可用于对异步检查点(torch.distributed.checkpoint.async_save)使用的暂存机制进行额外定制:

class torch.distributed.checkpoint.staging.AsyncStager(*args, **kwargs)

该协议旨在为dcp.async_save提供定制化和扩展能力,允许用户在并行执行常规dcp.save流程前自定义数据暂存方式。

预期操作顺序(具体定义于torch.distributed.state_dict_saver.async_save)如下:

1、AsyncStager.stage_data(state_dict):此调用为AsyncStager提供"暂存"state_dict的机会。此处的暂存预期目的是创建state_dict的"训练安全"表示形式,这意味着暂存完成后对模块数据的任何更新都不应反映在该方法返回的state_dict中。例如默认情况下,会在CPU内存中创建整个state_dict的副本并返回,从而允许用户继续训练而不影响正在被序列化的数据。

2、对暂存返回的state_dict并行调用dcp.save。该调用负责序列化state_dict并将其写入存储。

3、若AsyncStager.should_synchronize_after_execute为True,该方法将在序列化线程启动后、从dcp.async_save返回前立即调用。若设为False,则假定用户已定义自定义同步点以进一步优化训练循环中的保存延迟(例如通过将暂存与前向/反向传播重叠),此时用户需在适当时机调用AsyncStager.synchronize_staging


property should_synchronize_after_execute:  bool 

是否在执行阶段后进行同步。


stage(state_dict)

返回一个"暂存"状态的 state_dict 副本。该暂存副本的特性是:在 stage 调用完成后,不会受到任何后续更新的影响。

返回类型:dict[str , Union [~StatefulT, Any ]


synchronize_staging()

如果阶段以某种方式异步进行,应调用此方法以确保暂存完成,此时可以安全地开始修改原始 state_dict。


class torch.distributed.checkpoint.staging.BlockingAsyncStager(cache_staged_state_dict=False, type_check=False)

一个实现了 AsyncStager 的类,将 state_dict 暂存到 CPU 内存中,并阻塞直到复制完成。

该实现还提供了使用固定内存来优化暂存延迟的选项。

注意:在这种情况下,synchronize_staging 是一个空操作。


stage(state_dict)

返回一个位于CPU上的state_dict副本。

返回类型:dict[str, Union[~StatefulT, Any]]


synchronize_staging()

无操作函数,因为暂存是阻塞式的。

除了上述入口点外,如下所述的有状态对象在保存/加载过程中提供了额外的自定义功能

… automodule:: torch.distributed.checkpoint.stateful


class torch.distributed.checkpoint.stateful.Stateful(*args, **kwargs)

支持检查点(checkpoint)与恢复功能的对象状态协议。


load_state_dict(state_dict)

从提供的 state_dict 恢复对象的状态。

参数

  • state_dict ( dict[str, Any ]) – 用于恢复的状态字典

state_dict()


Objects should return their state_dict representation as a dictionary.
The output of this function will be checkpointed, and later restored in load_state_dict().


Warning: Because of the inplace nature of restoring a checkpoint, this function is also called during torch.distributed.checkpoint.load.

Returns
The objects state dict

Return type
Dict

This example shows how to use Pytorch Distributed Checkpoint to save a FSDP model.

The following types define the IO interface used during checkpoint:


class torch.distributed.checkpoint.StorageReader

Interface used by load_state_dict to read from storage.

One StorageReader instance acts as both the coordinator and the follower in a distributed checkpoint. As part of initialization, each instance is told its role.

A subclass should expected the following sequence of calls by load_state_dict:

0、(all ranks) set checkpoint_id if users pass a valid checkpoint_id.
1、(all ranks) read_metadata()
2、(all ranks) set_up_storage_reader()
3、(all ranks) prepare_local_plan()
4、(coordinator) prepare_global_plan()
5、(all ranks) read_data()


ABSTRACT  prepare_global_plan(plans)

Perform centralized planning of storage loading.

This method is only called on the coordinator instance.

While this method can produce a completely different plan, the preferred
way is to store storage specific data in LoadPlan::storage_data.


Parameters

  • plans (list[torch.distributed.checkpoint.planner.LoadPlan]) – A list of LoadPlan instances, one for each rank.

Returns
A list of transformed LoadPlan after storage global planning

Return type
list [torch.distributed.checkpoint.planner.LoadPlan]


ABSTRACT prepare_local_plan(plan)

Perform storage-specific local planning.

While this method can produce a completely different plan, the recommended
way is to store storage specific data in LoadPlan::storage_data.


Parameters

  • plan (LoadPlan) – The local plan from the LoadPlan in use.

Returns
A transformed LoadPlan after storage local planning

Return type
LoadPlan


ABSTRACT read_data(plan, planner)

Read all items from plan using planner to resolve the data.

A subclass should call LoadPlanner::load_bytes to deserialize a BytesIO
object into the right place.

A subclass should call LoadPlanner::resolve_tensor to get access to the tensors that in should load data into.

It’s the StorageLayer responsibility to properly schedule any cross device copies
required.


Parameters

  • plan (LoadPlan) – The local plan to execute on * planner (LoadPlanner) – The planner object to use to resolve items.

Returns
A future that completes once all reads are finished.

Return type
Future [None]


read_metadata()

摘要


Read the checkpoint metadata.

Returns
The metadata object associated with the checkpoint being loaded.

Return type
Metadata


ABSTRACT  reset(checkpoint_id=None)

Calls to indicates a brand new checkpoint read is going to happen.
A checkpoint_id may be present if users set the checkpoint_id for this checkpoint read. The meaning of the checkpiont_id is storage-dependent. It can be a path to a folder/file or a key for a key-value storage.


Parameters

  • checkpoint_id (Union[str,* os.PathLike, None]) – The ID of this checkpoint instance. The meaning of the checkpoint_id
    depends on the storage. It can be a path to a folder or to a file.
    It can also be a key if the storage is more like a key-value store.
    (Default: None)


ABSTRACT set_up_storage_reader(metadata, is_coordinator) 

Initialize this instance.


Parameters

  • metadata (Metadata) – The metadata schema to use.
  • is_coordinator ([bool]) – Whether this instance is responsible for coordinating the checkpoint.

Abstract Classmethod validate_checkpoint_id(checkpoint_id)

检查给定的 checkpoint_id 是否被存储支持。这允许我们启用自动存储选择。

返回类型:bool


class torch.distributed.checkpoint.StorageWriter

save_state_dict 用于写入存储的接口。

在分布式检查点中,一个 StorageWriter 实例同时充当协调者和跟随者角色。初始化时,每个实例都会被告知其角色。

子类应遵循以下调用顺序:

0、(所有进程)如果用户提供了有效的 checkpoint_id,则设置 checkpoint_id

1、(所有进程)调用 set_up_storage_writer()

2、(所有进程)调用 prepare_local_plan()

3、(协调者)调用 prepare_global_plan()

4、(所有进程)调用 write_data()

5、(协调者)调用 finish()


ABSTRACT  finish(metadata, results)

写入元数据并将当前检查点标记为成功。

用于序列化元数据的实际格式/模式是实现细节,唯一要求是能够还原为相同的对象图。

参数

  • metadata (Metadata) – 新检查点的元数据
  • results (list[list[torch.distributed.checkpoint.storage.WriteResult]]) – 来自所有进程的WriteResults列表

返回值:无

返回类型:无


ABSTRACT  prepare_global_plan(plans)

执行存储的集中规划。

此方法仅在协调器实例上调用。

虽然该方法可以生成完全不同的规划方案,但推荐的方式是将存储特定数据保存在 SavePlan::storage_data 中。

参数

  • plans (list[[torch.distributed.checkpoint.planner.SavePlan](https://pytorch.org/docs/stable/data.html#torch.distributed.checkpoint.SavePlan "torch.distributed.checkpoint.planner.SavePlan")]) – 一个包含各rank对应SavePlan实例的列表。

返回
经过存储全局规划处理后的SavePlan列表

返回类型
list[torch.distributed.checkpoint.planner.SavePlan]


ABSTRACT  prepare_local_plan(plan)

执行存储特定的本地规划。

虽然此方法可以生成完全不同的计划,但推荐的方式是将存储特定数据保存在 SavePlan::storage_data 中。

参数

  • plan ([SavePlan](https://pytorch.org/docs/stable/data.html#torch.distributed.checkpoint.SavePlan "torch.distributed.checkpoint.SavePlan")) – 当前使用的 SavePlanner 生成的本地计划。

返回

经过存储本地规划转换后的 SavePlan

返回类型

SavePlan


ABSTRACT  reset(checkpoint_id=None)

调用表示即将开始一次全新的检查点写入。

如果用户为本次检查点写入设置了checkpoint_id,则该参数可能存在。checkpoint_id的具体含义取决于存储实现,可能是指向文件夹/文件的路径,也可能是键值存储中的键名。

参数说明

  • checkpoint_id (Union[str, os.PathLike, None]) - 本次检查点实例的ID。checkpoint_id的具体含义取决于存储类型:
    • 对于文件系统存储,可以是文件夹路径或文件路径
    • 对于键值存储,可以是键名
      (默认值:None

ABSTRACT  set_up_storage_writer(is_coordinator)

初始化该实例。

参数

  • is_coordinator ([bool]) – 该实例是否负责协调检查点。

storage_meta()

返回存储特定的元数据。该方法用于在检查点中存储额外信息,这些信息有助于提供请求级别的可观测性。在保存调用期间,StorageMeta会被传递给SavePlanner。默认返回None。

TODO: 提供一个示例

返回类型:Optional[StorageMeta]


ABSTRACT classmethod* validate_checkpoint_id(checkpoint_id)

检查给定的 checkpoint_id 是否被存储系统支持。这让我们能够启用自动存储选择功能。

返回类型:bool


ABSTRACT  write_data(plan, planner)

使用 planner 解析数据,将 plan 中的所有条目写入。

子类应对计划中的每个条目调用 SavePlanner::resolve_data 方法,以获取待写入的底层对象访问权限。子类应惰性调用 resolve_data,因为该方法可能涉及内存分配。

对于张量,需遵循以下假设:

  • 张量可能位于任意设备上(包括与 WriteItem::tensor_data 设备不匹配的情况)
  • 张量可能是视图或非连续的,仅需保存其投影部分

参数

  • plan ([SavePlan](https://pytorch.org/docs/stable/data.html#torch.distributed.checkpoint.SavePlan "torch.distributed.checkpoint.SavePlan")) – 要执行的保存计划
  • planner ([SavePlanner](https://pytorch.org/docs/stable/data.html#torch.distributed.checkpoint.SavePlanner "torch.distributed.checkpoint.SavePlanner")) – 用于将条目解析为数据的规划器对象

返回值
一个最终返回 WriteResult 列表的 Future 对象

返回类型
Future [list [torch.distributed.checkpoint.storage.WriteResult]]

以下类型定义了检查点期间使用的规划器接口:

class torch.distributed.checkpoint.LoadPlanner

定义加载状态字典(load_state_dict)所用协议的抽象基类。

LoadPlanner是有状态对象,可用于自定义整个加载流程。它作为访问state_dict的代理,任何对字典的修改都会在整个流程中可见。

load_state_dict执行期间,规划器子类会按以下顺序接收调用:

1、set_up_planner - 所有rank节点都会调用。标志检查点加载开始
2、create_local_plan - 所有rank节点调用。处理state_dict并生成将用于全局规划的LoadPlan
3、create_global_plan - 仅协调者rank节点调用。汇总各rank的LoadPlan并做出全局决策
4、load_bytes - 每个rank节点会多次调用。对应state_dict中每个非张量值调用一次
5、resolve_tensorcommit_tensor - 每个rank节点成对调用。对应state_dict中每个张量值调用

建议用户继承DefaultLoadPlanner而非直接实现此接口,因为多数修改只需重写单个方法即可实现。

扩展规划器通常有两种模式:

重写state_dict。这是扩展加载流程最简单的方式,因为不需要理解LoadPlan的内部机制。由于加载是原地(in-place)操作,我们需要保留原始state_dict的引用,以便执行原地修改。


>>> class RenamePlanner(DefaultLoadPlanner):
>>>     def set_up_planner(
>>>         self, >>        state_dict: STATE_DICT_TYPE, >>        metadata: Metadata, >>        is_coordinator: bool, >>    ) -None:
>>>         self.original_state_dict = state_dict
>>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}
>>> >
>>>         if self.flatten_sharded_tensors:
>>>             state_dict = _flatten_sharded_tensors(state_dict)
>>> >
>>>         if self.flatten_state_dict:
>>>             state_dict, self.mappings = flatten_state_dict(state_dict)
>>> >
>>>         self.state_dict = state_dict
>>>         self.metadata = metadata
>>>         self.is_coordinator = is_coordinator
>>> >
>>>     def load_bytes(self, read_item, value):
>>> # Remove the "foo_" prefix
>>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value, weights_only=False)


修改 resolve_tensorcommit_tensor 方法以支持加载时转换。


>>> class MetaModelMaterialize(DefaultSavePlanner):
>>>     def resolve_tensor(self, read_item):
>>>         tensor = super().resolve_tensor(read_item)
>>>         return torch.empty_like(tensor, device="cpu")
>>> >
>>>     def commit_tensor(self, read_item, tensor):
>>>         self.state_dict[read_item.dest_index.fqn] = tensor


ABSTRACT  commit_tensor(read_item, tensor)

StorageReader完成将数据加载到tensor后调用一次。

提供的tensor与调用resolve_tensor返回的是同一个。

仅当该LoadPlanner需要在将tensor复制回state_dict之前进行后处理时,才需要此方法。

tensor的内容将遵循其设备同步模型。


ABSTRACT  create_global_plan(global_plan)

计算全局加载计划并返回每个rank的加载计划。

注意:此方法仅在协调器rank上调用。

返回类型:list [torch.distributed.checkpoint.planner.LoadPlan]


ABSTRACT  create_local_plan()

基于set_up_planner提供的state_dict和元数据创建加载计划。

注意:此方法会在每个rank上调用。

返回类型:LoadPlan


ABSTRACT  finish_plan(central_plan)

接受协调器的计划并返回最终的加载方案。

返回类型:LoadPlan


ABSTRACT  load_bytes(read_item, value)

加载由 read_itemvalue 描述的项。

该方法预期会就地修改底层的 state_dict。

value 的内容由用于生成待加载检查点的 SavePlanner 定义。


resolve_bytes(read_item)

返回供 StorageReader 用于加载 read_item 的 BytesIO 对象。

该 BytesIO 应与底层 state_dict 中的对象建立别名关系,因为 StorageReader 会替换其内容。

返回类型:BytesIO


ABSTRACT  resolve_tensor(read_item)

返回由 read_item 描述的张量,供 StorageReader 用于加载 read_item。

该张量应与底层 state_dict 中的某个张量建立别名关系,因为 StorageReader 会替换其内容。

如果因任何原因无法实现这一点,规划器可以使用 commit_tensor 方法将数据复制回 state_dict 中的对应张量。

返回类型:Tensor


ABSTRACT  set_up_planner(state_dict, metadata=None, is_coordinator=False)

初始化该实例以将数据加载到 state_dict 中。

注意:此操作会在每个 rank 上调用。


class torch.distributed.checkpoint.LoadPlan(items:  list [[torch.distributed.checkpoint.planner.ReadItem](https://pytorch.org/docs/stable/data.html#torch.distributed.checkpoint.ReadItem "torch.distributed.checkpoint.planner.ReadItem")], storage_data: Any = None, planner_data: Any = None)

class torch.distributed.checkpoint.ReadItem(type: torch.distributed.checkpoint.planner.LoadItemType, dest_index: torch.distributed.checkpoint.metadata.MetadataIndex, dest_offsets:  torch.Size , storage_index: torch.distributed.checkpoint.metadata.MetadataIndex, storage_offsets:  torch.Size , lengths:  torch.Size )

class torch.distributed.checkpoint.SavePlanner

定义保存状态字典(save_state_dict)所用协议的抽象类。

SavePlanner 是有状态对象,可用于自定义整个保存过程。它作为访问 state_dict 的代理,因此对其进行的任何转换都会对整个过程可见。

在 save_state_dict 过程中,规划器子类会按以下顺序调用方法:

1、set_up_planner - 在所有 rank 上调用。标志检查点保存开始
2、create_local_plan - 在所有 rank 上调用。处理 state_dict 并生成将用于全局规划的 SavePlan
3、create_global_plan - 仅在协调器 rank 上调用。汇总各 rank 的 SavePlan 并做出全局决策
4、finish_plan - 在所有 rank 上调用。使各 rank 能根据全局规划决策进行调整
5、resolve_data - 在每个 rank 上多次调用。为存储层查找 state_dict 中的值以供写入

建议用户直接继承 DefaultSavePlanner 而非本接口,因为大多数修改只需更改单个方法即可实现。

扩展通常有三种模式:

重写 state_dict。这是扩展保存过程最简单的方式,因为它不需要理解 SavePlan 的内部工作机制。


>>> class RenamePlanner(DefaultSavePlanner):
>>>     def set_up_planner(
>>>         self, >>        state_dict: STATE_DICT_TYPE, >>        storage_meta: Optional[StorageMeta], >>        is_coordinator: bool, >>    ) -None:
>>> # prefix all keys with `foo_``
>>>         super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, storage_meta, is_coordinator)


同步修改本地计划和查询。这在需要精细控制数据持久化方式时非常有用。


>>> class FP16Planner(DefaultSavePlanner):
>>>     def create_local_plan(self):
>>>         plan = super().create_local_plan()
>>>         for p in plan:
>>>             if p.tensor_data is not None:
>>>                 p.tensor_data.properties.dtype = torch.float16
>>>         return plan
>>> >
>>>     def resolve_data(self, write_item):
>>>         item = super().resolve_data(write_item)
>>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)


使用全局规划步骤来制定无法由每个节点单独做出的中心化决策


>>> from itertools import zip_longest
>>> from dataclasses import replace
>>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
>>> # This uses the default local plan behavior of having all non-sharded writes in rank 0
>>> # This sample doesn't handle ShardedTensors
>>>     def create_global_plan(self, all_plans):
>>>         iters = [iter(all_plans[0].items)] * len(all_plans)
>>>         items_per_rank = [
>>>             [item for item in items if item is not None]
>>>             for items in zip(zip_longest(iters), strict=True)
>>>         ]
>>>         all_plans = [
>>>             replace(plan, items=items)
>>>             for plan, items in zip(all_plans, items_per_rank, strict=True)
>>>         ]
>>>         return super().create_global_plan(all_plans)


最后,某些规划器需要在检查点中保存额外的元数据。实现方式是让每个节点在本地计划中贡献其数据项,然后由全局规划器进行聚合:

>>> class SaveExtraDataPlanner(DefaultSavePlanner):
>>>     def create_local_plan(self) -SavePlan:
>>>         plan = super().create_local_plan()
>>>         return replace(plan, planner_data="per-rank-data")
>>> >
>>>     def create_global_plan(self, all_plans: List[SavePlan]) -Tuple[List[SavePlan], Metadata]:
>>>         global_plan, metadata = super().create_global_plan(all_plans)
>>>         merged_data = [p.planner_data for p in global_plan]
>>>         metadata = replace(metadata, planner_data=merged_data)
>>>         return global_plan, metadata


ABSTRACT  create_global_plan(all_plans)

计算全局检查点计划并返回每个rank的本地计划。

此方法仅在协调器rank上调用。

返回类型:tuple [list [torch.distributed.checkpoint.planner.SavePlan], torch.distributed.checkpoint.metadata.Metadata]


ABSTRACT  create_local_plan()

计算当前秩的保存计划。

该计划将被聚合并传递给create_global_plan

可以通过SavePlan::planner_data传递规划器特定数据。

此操作在所有秩上调用。

返回类型:SavePlan


ABSTRACT  finish_plan(new_plan)

create_local_plan 创建的规划与 create_global_plan 的结果进行合并。

此方法在所有进程上调用。

返回类型:SavePlan


ABSTRACT  resolve_data(write_item)

转换并准备来自 state_dictwrite_item 以进行存储,确保操作的幂等性和线程安全性。

在存储层处理之前,从 state_dict 中查找与 write_item 关联的对象,并应用任何转换(例如序列化)。

该方法会在每个 rank 上被多次调用,最终 SavePlan 中的每个 WriteItem 至少调用一次。

此方法应具备幂等性和线程安全性。StorageWriter 实现可以按需自由调用它。

为了减少检查点操作所需的内存峰值,任何涉及内存分配的转换都应延迟到调用该方法时执行。

返回张量时,它们可以位于任何设备或格式上,也可以是视图。存储层需自行确定如何保存它们。

返回类型:
Union [Tensor, BytesIO]


ABSTRACT  set_up_planner(state_dict, storage_meta=None, is_coordinator=False)

初始化此规划器以保存 state_dict

实现时应保存这些值,因为在后续保存过程中不会再次提供这些数据。

该操作会在所有节点上调用。


class torch.distributed.checkpoint.SavePlan(items:  list [[torch.distributed.checkpoint.planner.WriteItem](https://pytorch.org/docs/stable/data.html#torch.distributed.checkpoint.planner.WriteItem "torch.distributed.checkpoint.planner.WriteItem")], storage_data: Any = None, planner_data: Any = None, usable:  bool  = True)

class torch.distributed.checkpoint.planner.WriteItem(index, type, tensor_data=None)

这是一个数据类,用于保存需要写入存储的信息。


tensor_storage_size()

计算底层张量的存储大小,如果不是张量写入则返回 None。

返回值:Optional[int] 底层张量的存储大小(以字节为单位),如果存在的话。

返回类型:Optional[int]

我们提供了一个基于文件系统的存储层:

class torch.distributed.checkpoint.FileSystemReader(path, _extension_registry=None)

property checkpoint_id:  Union [str , PathLike] 

返回将用于加载检查点的 checkpoint_id。


class torch.distributed.checkpoint.FileSystemWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000, cache_staged_state_dict=False, overwrite=True, _extensions=None)

使用文件IO实现StorageWriter的基础版本。

该实现基于以下假设和简化条件:

  • 检查点路径是一个空目录或不存在的目录
  • 文件创建操作是原子性的

每个检查点包含:每个写入请求对应一个文件,外加一个存储序列化元数据的.metadata文件。


stage(state_dict)

重写 AsyncStager.stage 方法

返回值类型:dict[str, Union[~StatefulT, Any]]

我们提供了 LoadPlanner 和 SavePlanner 的默认实现,能够处理所有 torch.distributed 结构,包括 FSDP、DDP、ShardedTensor 和 DistributedTensor。


class torch.distributed.checkpoint.DefaultSavePlanner(flatten_state_dict=True, flatten_sharded_tensors=True, dedup_replicated_tensors=None, dedup_save_to_lowest_rank=False, enable_plan_caching=False)

lookup_object(index)

从规划器接口扩展,便于扩展默认规划器。

返回类型:任意


transform_object(write_item, object)

从规划器接口扩展而来,便于扩展默认规划器。


class torch.distributed.checkpoint.DefaultLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)

LoadPlanner基础上添加多项功能的DefaultLoadPlanner

具体新增以下特性:

  • flatten_state_dict:支持处理包含嵌套字典的state_dict
  • flatten_sharded_tensors:针对2D并行模式下的FSDP优化
  • allow_partial_load:若设为False,当state_dict中的键存在于检查点时会抛出运行时错误

lookup_tensor(index)

从规划器接口扩展而来,便于扩展默认规划器。

返回类型:Tensor


transform_tensor(read_item, tensor)

从规划器接口扩展而来,便于扩展默认规划器。

由于历史设计决策,FSDP和DDP的状态字典可能具有不同的键或完全限定名称(例如layer1.weight),即使原始未并行化的模型完全相同。此外,FSDP提供多种类型的模型状态字典,例如完整和分片状态字典。另外,优化器状态字典使用参数ID而非完全限定名称来标识参数,这在使用并行技术(如流水线并行)时可能导致问题。

为解决这些挑战,我们提供了一组API,方便用户管理状态字典。get_model_state_dict()返回的模型状态字典,其键与未并行化模型状态字典返回的键保持一致。类似地,get_optimizer_state_dict()提供的优化器状态字典,其键在所有应用的并行技术中保持统一。为实现这种一致性,get_optimizer_state_dict()将参数ID转换为与未并行化模型状态字典中完全相同的完全限定名称。

请注意,这些API返回的结果可直接与torch.distributed.checkpoint.save()torch.distributed.checkpoint.load()方法配合使用,无需任何额外转换。

set_model_state_dict()set_optimizer_state_dict()用于加载由各自getter API生成的模型和优化器状态字典。

请注意,set_optimizer_state_dict()只能在优化器调用backward()之前或step()之后调用。

请注意,此功能为实验性质,未来API签名可能会发生变化。


torch.distributed.checkpoint.state_dict.get_state_dict(model, optimizers, *, submodules=None, options=None)

返回模型的状态字典(state_dict)和优化器的状态字典。

get_state_dict 能够处理任何通过 PyTorch 并行化的模块,包括 FSDP/fully_shard、DDP/replicate、tensor_parallel/parallelize_module 以及这些并行方式的任意组合。get_state_dict 的主要功能包括:

1、返回一个模型和优化器的状态字典,该字典可以在不同数量的训练器和/或不同并行方式下重新分片。
2、隐藏并行化特定的状态字典 API。用户无需调用这些 API。
3、对结果状态字典进行完整性检查。

结果状态字典的键是规范的完全限定名称(FQN)。规范的 FQN 指的是基于参数在 nn.Module 层次结构中的位置生成的 FQN。更具体地说,参数的规范 FQN 是当模块未被任何并行化方式分发时,通过 module.named_parameters()module.named_buffers() 返回的 FQN。由于优化器内部使用参数 ID 来表示参数,调用此 API 时会将参数 ID 转换为规范的 FQN。

get_state_dict 也可以处理未并行化的模块。在这种情况下,get_state_dict 仅执行一项功能——将优化器的参数 ID 转换为规范的 FQN。

示例


>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.checkpoint.state_dict import get_state_dict


>>> fsdp_model = FSDP(copy.deepcopy(model))
>>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_model = DDP(copy.deepcopy(model))
>>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)


>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(
...     fsdp_model, fsdp_optim
... )


>>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(), >># the asserts will fail.
>>> assert ddp_state_dict == fsdp_state_dict
>>> assert ddp_optim_state == fsdp_optim_state_dict


参数

  • model (nn.Module) - 需要获取状态字典的神经网络模型。
  • optimizers (Union[None, Optimizer, Iterable[Optimizer]]) - 用于优化model的优化器集合。
  • submodules (已弃用) - Optional[set[nn.Module]]: 仅返回属于指定子模块的模型参数。
  • options (StateDictOptions) - 控制如何返回模型状态字典和优化器状态字典的配置选项。详情参见StateDictOptions。

返回值:包含模型状态字典和优化器状态字典的Tuple元组。

返回类型:Tuple[Dict[str, ValueType], OptimizerStateType]


torch.distributed.checkpoint.state_dict.get_model_state_dict(model, *, submodules=None, options=None)

返回模型的model状态字典。

详细用法请参阅get_state_dict

参数

  • model (nn.Module) – 需要获取状态字典的nn.Module模型。
  • submodules (已弃用) – Optional[set[nn.Module]]: 仅返回属于指定子模块的模型参数。
  • options (StateDictOptions) – 控制如何返回模型状态字典和优化器状态字典的选项。详情参见StateDictOptions。

返回值:model的状态字典。

返回类型:Dict[str, ValueType]


torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optimizers, *, submodules=None, options=None)

返回优化器的组合状态字典。

有关详细用法,请参阅 get_state_dict

参数

  • model (nn.Module) – 用于模型的 nn.Module。
  • optimizers (Union[None*,* Optimizer, Iterable[Optimizer]]) – 用于优化 model 的优化器。
  • submodules (已弃用) – Optional[set[nn.Module]]: 仅返回属于子模块的模型参数。
  • options (StateDictOptions) – 控制如何返回模型状态字典和优化器状态字典的选项。详情请参阅 StateDictOptions。

返回值:optimizers 的状态字典。

返回类型:OptimizerStateType


torch.distributed.checkpoint.state_dict.set_state_dict(model, optimizers, *, model_state_dict, optim_state_dict, options=None)

加载模型状态字典(state_dict)和优化器状态字典。

这是与 get_state_dict 相对应的操作,用于将状态字典设置到模型和优化器中。给定的 model_state_dictoptim_state_dict 不必由 get_state_dict 返回,但必须满足以下要求:

  1. 所有 FQN(完全限定名)必须符合 get_state_dict 中定义的规范格式;
  2. 如果张量是分片的,则必须是 ShardedTensor 或 DTensor 类型;
  3. 优化器状态字典不能包含参数 ID,其键应为规范化的 FQN。

警告:set_state_dict 只能在调用 backward() 之前或优化器执行 step() 之后调用,否则优化器状态将无法正确初始化。

参数

  • model (nn.Module) – 目标模型(nn.Module 实例)。
  • optimizers (Union[Optimizer, Iterable[Optimizer]]) – 用于优化 model 的优化器(单个或可迭代集合)。
  • model_state_dict (Dict[str, ValueType]) – (联合类型 [Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]):要加载的模型状态字典。若 model_state_dict 的键为 nn.Module,则该键是 model 的子模块,其值应为该子模块的状态字典。加载时会将子模块前缀自动附加到状态字典键名。
  • optim_state_dict (OptimizerStateType) – 要加载的优化器状态字典(OptimizerStateType 类型)。
  • options (StateDictOptions) – 控制模型和优化器状态字典加载方式的选项,详见 StateDictOptions 说明。

返回值

  • missing_keys:字符串列表,包含模型状态字典中缺失的键。
  • unexpected_keys:字符串列表,包含模型状态字典中意外的键。

返回类型:包含 missing_keysunexpected_keys 字段的命名元组(NamedTuple)


torch.distributed.checkpoint.state_dict.set_model_state_dict(model, model_state_dict, *, options=None)

加载模型的状态字典(state_dict)。

这是get_model_state_dict的对应方法,用于将状态字典设置到模型上。详细用法请参考set_state_dict

参数

  • model (nn.Module) - 需要加载状态字典的nn.Module模型
  • model_state_dict Dict[str, ValueType]) - (Dict[str, ValueType]): 要加载的模型状态字典。如果model_state_dict的键是nn.Module类型,则该键是model的子模块,对应的值应该是该子模块的状态字典。加载时会将子模块的前缀附加到状态字典上。
  • options (StateDictOptions) - 控制如何加载模型状态字典和优化器状态字典的选项。详情请参阅StateDictOptions。

返回值

  • missing_keys 包含缺失键的字符串列表
  • unexpected_keys 包含意外键的字符串列表

返回类型:带有missing_keysunexpected_keys字段的NamedTuple


torch.distributed.checkpoint.state_dict.set_optimizer_state_dict(model, optimizers, optim_state_dict, *, options=None)

加载优化器的状态字典。

这是get_optimizer_state_dict的对应方法,用于将状态字典设置到优化器中。具体用法请参考set_state_dict

警告:set_optimizer_state_dict只能在优化器调用backward()之前或调用step()之后执行。否则,优化器状态将无法正确初始化。

参数

  • model (nn.Module) – 要操作的nn.Module模型。
  • optimizers (Union[Optimizer, Iterable[Optimizer]]) – 用于优化model的优化器或优化器集合。
  • optim_state_dict (OptimizerStateType) – OptimizerStateType类型:要加载的优化器状态字典。
  • options (StateDictOptions) – 控制如何加载模型状态字典和优化器状态字典的选项。详情请参阅StateDictOptions。

返回值:无

返回类型:无


class torch.distributed.checkpoint.state_dict.StateDictOptions(full_state_dict=False, cpu_offload=False, ignore_frozen_params=False, keep_submodule_prefixes=True, strict=True, broadcast_from_rank0=False, flatten_optimizer_state_dict=False, dsd_fqn_modifiers='_fqn_modifiers')

该数据类规定了 get_state_dict/set_state_dict 的工作机制:

  • full_state_dict:若设为 True,返回的 state_dict 中将收集所有张量,不会包含任何分片张量(ShardedTensor)或分布式张量(DTensor)。
  • cpu_offload:将所有张量卸载到 CPU。为防止 CPU 内存溢出(OOM),若同时启用 full_state_dict,则仅 rank0 会获取完整 state_dict,其他 rank 将获得空字典。
  • ignore_frozen_params:若为 True,返回的 state_dict 将排除所有冻结参数(即 requires_grad 为 False 的参数),默认值为 False。
  • keep_submodule_prefixes(已弃用):当指定 submodules 时,此选项决定是否保留 state_dict 键名中的子模块前缀。例如:若子模块为 module.pretrain 且参数完整限定名(FQN)为 pretrain.layer1.weight,启用该选项时返回的 state_dict 键名将保持为 pretrain.layer1.weight,禁用时则简化为 layer1.weight
    ⚠️ 注意:若禁用 keep_submodule_prefixes 可能导致 FQN 冲突,因此 submodules 应仅包含单个子模块。
  • strict:控制 set_state_dict 调用 model.load_state_dict() 时的严格模式。
  • broadcast_from_rank0:启用时,rank0 将接收完整 state_dict 并逐个广播其中的张量至其他 rank。其他 rank 会根据模型和优化器的本地分片情况接收并分片张量。使用此选项时必须启用 full_state_dict
    ⚠️ 当前仅支持 DTensor,不支持旧版 ShardedTensor。

针对习惯使用 torch.save 格式共享模型的用户,我们提供了以下离线工具方法用于格式转换:

torch.distributed.checkpoint.format_utils.dcp_to_torch_save(dcp_checkpoint_dir, torch_save_path)

给定一个包含DCP检查点的目录,此函数会将其转换为Torch保存文件。

参数

  • dcp_checkpoint_dir ( Union [str,* PathLike]) - 包含DCP检查点的目录。
  • torch_save_path ( Union [str,* PathLike]) - 用于存储转换后的Torch保存文件的文件名。

警告:为避免内存不足(OOM),建议仅在单个rank上运行此函数。


torch.distributed.checkpoint.format_utils.torch_save_to_dcp(torch_save_path, dcp_checkpoint_dir)

给定 Torch 保存文件的位置,将其转换为 DCP 检查点。


参数

  • torch_save_path ( Union [str,* PathLike]) – Torch 保存文件的文件名。
  • dcp_checkpoint_dir ( Union [str,* PathLike]) – 存储 DCP 检查点的目录。

警告:为避免内存不足(OOM),建议仅在单个 rank 上运行此函数。

以下类也可用于从 torch.save 格式在线加载和重新分片模型。


class torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader(checkpoint_id=None, coordinator_rank=0)

StorageReader 用于读取 Torch 保存文件。该读取器会在协调器节点上读取整个检查点,然后将每个张量广播并分片到所有节点。

注意:需与 DynamicMetaLoadPlanner 配合使用。


警告:当前实现仅支持加载张量。


>>> sd = {"mode": model}
>>> dcp.load(
>>>    sd, >>   storage_reader=BroadcastingTorchSaveReader(), >>   planner=DynamicMetaLoadPlanner(), >>   checkpoint_id="path_to_model.pt"
>>> )


prepare_global_plan(global_plan)

StorageReader 方法的实现

返回值类型:list [torch.distributed.checkpoint.planner.LoadPlan]


prepare_local_plan(plan)

StorageReader 方法的实现

返回类型:LoadPlan


read_data(plan, planner)

在协调器(coordinator)节点上读取 torch 保存的数据,随后进行广播

这会带来通信开销,但避免了在每个节点上加载完整检查点的需求,有望防止内存溢出(OOM)问题

返回类型:Future [None]


read_metadata()

扩展默认的 StorageReader 以支持构建元数据文件

返回类型:Metadata


reset(checkpoint_id=None)

StorageReader 方法的实现


set_up_storage_reader(metadata, is_coordinator)

StorageReader 方法的实现


CLASSMETHOD validate_checkpoint_id(checkpoint_id)

StorageReader 方法的实现

返回类型:bool


class torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True, allow_partial_load=False)

DefaultLoadPlanner的扩展实现,它会根据传入的状态字典创建新的元数据对象,从而避免从磁盘读取元数据的开销。这在读取没有独立元数据文件的格式(如Torch保存文件)时非常有用。

注意:该实现需与BroadcastingTorchSaveReader配合使用。

警告:当前实现仅支持加载张量(Tensors)。


>>> sd = {"mode": model}
>>> dcp.load(
>>>    sd, >>   storage_reader=BroadcastingTorchSaveReader(), >>   planner=DynamicMetaLoadPlanner(), >>   checkpoint_id="path_to_model.pt"
>>> )


set_up_planner(state_dict, metadata=None, is_coordinator=False)

以下是翻译结果:

规划器的设置,通过从状态字典创建元数据对象来扩展默认行为

以下实验性接口可用于提升生产环境中的可观测性:


概率分布 - torch.distributions

distributions 包包含可参数化的概率分布和采样函数。这使得构建随机计算图和用于优化的随机梯度估计器成为可能。该包总体上遵循 TensorFlow Distributions 包的设计理念。

无法直接通过随机样本进行反向传播。然而,有两种主要方法可以创建可反向传播的替代函数:评分函数估计器/似然比估计器/REINFORCE 和路径导数估计器。REINFORCE 通常被视为强化学习中策略梯度方法的基础,而路径导数估计器常见于变分自编码器的重参数化技巧中。评分函数仅需要样本值 f(x)f(x)f(x),而路径导数则需要导数 f′(x)f’(x)f′(x)。接下来的章节将通过强化学习示例讨论这两种方法。更多细节请参阅 使用随机计算图的梯度估计


评分函数

当概率密度函数对其参数可微时,我们只需要使用 sample()log_prob() 即可实现 REINFORCE 算法:

Δ θ = α r ∂ log ⁡ p ( a ∣ π θ ( s ) ) ∂ θ \Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta} Δθ=αrθlogp(aπθ(s))

其中 θ \theta θ 表示参数, α \alpha α 是学习率, r r r 代表奖励值, p ( a ∣ π θ ( s ) ) p(a|\pi^\theta(s)) p(aπθ(s)) 表示在状态 s s s 下根据策略 π θ \pi^\theta πθ 采取行动 a a a 的概率。

实际应用中,我们会从网络输出中采样一个动作,在环境中执行该动作,然后使用 log_prob 构建等效的损失函数。注意这里使用负号是因为优化器采用梯度下降法,而上述规则假设的是梯度上升。对于分类策略,实现 REINFORCE 的代码如下:

probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()

路径导数

另一种实现这些随机/策略梯度的方法是使用rsample()方法中的重参数化技巧。通过这种方式,参数化的随机变量可以转化为一个无参数随机变量的确定性函数。因此,重参数化后的样本变得可微分。以下是实现路径导数的代码示例:

params = policy_network(state)
m = Normal(params)
# Any distribution with .has_rsample == True could work based on the application
action = m.rsample()
next_state, reward = env.step(action)  # Assuming that reward is differentiable
loss = -reward
loss.backward()

分发


class torch.distributions.distribution.Distribution(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)

基类:object

Distribution 是概率分布的抽象基类。


property arg_constraints:  dict[str , torch.distributions.constraints.Constraint] 

返回一个从参数名到Constraint对象的字典,该字典应满足此分布每个参数的要求。非张量类型的参数无需出现在此字典中。


property batch_shape: Size 

返回参数批处理所应用的形状。


cdf(value)

返回在给定值处评估的累积密度/质量函数。

参数

  • value ( Tensor )

返回类型 : Tensor


entropy()

返回在 batch_shape 上批处理的分布熵。

返回值:形状为 batch_shape 的张量。

返回类型:Tensor


enumerate_support(expand=True)

返回包含离散分布所有可能取值的张量。结果将沿着第0维度进行枚举,因此输出形状为:(基数,) + 批次形状 + 事件形状(对于单变量分布,事件形状=())。

需注意:该方法会以同步锁步方式枚举所有批处理张量,例如[[0,0], [1,1], …]。当expand=False时,枚举仅沿第0维度进行,其余批次维度保持单一维度,形如[[0], [1], …]。

若要遍历完整的笛卡尔积,请使用itertools.product(m.enumerate_support())。

参数说明:

  • expand ([bool]) - 控制是否沿批次维度扩展支持集以匹配分布的batch_shape

返回值:
沿第0维度迭代的张量

返回类型:Tensor


property event_shape:  Size 

返回单个样本的形状(不包含批处理)。


expand(batch_shape, _instance=None)

返回一个新的分布实例(或填充由派生类提供的现有实例),并将批次维度扩展为batch_shape。该方法会在分布的参数上调用expand。因此,扩展后的分布实例不会分配新的内存。此外,当首次创建实例时,不会重复执行__init__.py中的任何参数检查或参数广播操作。

参数

  • batch_shape ( torch.Size ) – 期望扩展的尺寸。
  • _instance – 需要覆盖.expand方法的子类提供的新实例。

返回值:批次维度扩展至batch_size的新分布实例。


icdf(value)

返回在给定值处评估的逆累积密度/质量函数。

参数

  • value ( Tensor )

返回类型 : Tensor


log_prob(value)

返回在给定值处评估的概率密度/质量函数的对数。

参数

  • value ( Tensor )

返回类型 : Tensor


property mean:  Tensor 

返回该分布的均值。


property mode:  Tensor 

返回该分布的众数。


perplexity()

返回在 batch_shape 上批处理的分布困惑度。

返回值:形状为 batch_shape 的张量。

返回类型:Tensor


rsample(sample_shape=torch.Size([]))

生成一个形状为 sample_shape 的重参数化样本,或者当分布参数为批处理时,生成形状为 sample_shape 的批量重参数化样本。

返回类型:Tensor


sample(sample_shape=torch.Size([]))

生成一个形状为 sample_shape 的样本,如果分布参数是批处理的,则生成形状为 sample_shape 的批量样本。

返回类型:Tensor


sample_n(n)

生成 n 个样本,如果分布参数是批处理的,则生成 n 批样本。

返回类型:Tensor


static set_default_validate_args(value)

设置是否启用验证功能。

默认行为模仿 Python 的 assert 语句:验证功能默认开启,但如果 Python 以优化模式运行(通过 python -O 命令)则会自动关闭。由于验证过程可能消耗较多资源,当模型运行稳定后可以考虑禁用此功能。

参数说明

  • value ([bool]) – 控制是否启用验证的布尔值。

property stddev:  Tensor 

返回该分布的标准差。


property support: Optional[Constraint] 

返回一个表示该分布支撑集的 Constraint 对象。


property variance:  Tensor 

返回该分布的方差。


指数族分布


class torch.distributions.exp_family.ExponentialFamily(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)

基类:Distribution

ExponentialFamily 是指数族概率分布的抽象基类,其概率质量/密度函数定义如下:

p F ( x ; θ ) = exp ⁡ ( ⟨ t ( x ) , θ ⟩ − F ( θ ) + k ( x ) ) p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x)) pF(x;θ)=exp(⟨t(x),θF(θ)+k(x))

其中 θ \theta θ 表示自然参数, t ( x ) t(x) t(x) 表示充分统计量, F ( θ ) F(\theta) F(θ) 是该族的对数归一化函数, k ( x ) k(x) k(x) 为载体测度。

说明:该类是 Distribution 类与属于指数族的分布之间的中间层,主要用于验证 .entropy() 和解析 KL 散度方法的正确性。我们利用该类通过自动微分框架和 Bregman 散度来计算熵与 KL 散度(基于 Frank Nielsen 和 Richard Nock 的研究成果《指数族的熵与交叉熵》)。


entropy()

通过计算对数归一化器的Bregman散度来计算熵的方法。


伯努利


class torch.distributions.bernoulli.Bernoulli(probs=None, logits=None, validate_args=None)

基础分布:ExponentialFamily

创建一个由 probslogits 参数化的伯努利分布(但不可同时使用两者)。

样本为二元值(0 或 1)。以概率 p 取值为 1,以概率 1 - p 取值为 0。


示例:

>>> m = Bernoulli(torch.tensor([0.3]))
>>> m.sample()  # 30% chance 1; 70% chance 0
tensor([0.])

参数

  • probs (Number*,* Tensor ) – 采样结果为1的概率
  • logits (Number*,* Tensor ) – 采样结果为1的对数几率

arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}

entropy()

enumerate_support(expand=True)

expand(batch_shape, _instance=None)

has_enumerate_support = True

log_prob(value)

property logits:  Tensor 

property mean:  Tensor 

property mode:  Tensor 

property param_shape:  Size 

property probs:  Tensor 

sample(sample_shape=torch.Size([]))

support = Boolean()

property variance:  Tensor 

Beta


class torch.distributions.beta.Beta(concentration1, concentration0, validate_args=None)

基类:ExponentialFamily

concentration1concentration0 参数化的 Beta 分布。


示例:

>>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
>>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
tensor([0.1046])

参数

  • concentration1 (float 或 Tensor) - 分布的第一个浓度参数(通常称为 alpha)
  • concentration0 (float 或 Tensor) - 分布的第二个浓度参数(通常称为 beta)

arg_constraints = {'concentration0': GreaterThan(lower_bound=0.0), 'concentration1': GreaterThan(lower_bound=0.0)}

property concentration0:  Tensor 

property concentration1:  Tensor 

entropy()

expand(batch_shape, _instance=None)

has_rsample = True

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 
rsample(sample_shape=())

Return type : Tensor


support = Interval(lower_bound=0.0, upper_bound=1.0)

property variance: Tensor

Binomial


class torch.distributions.binomial.Binomial(total_count=1, probs=None, logits=None, validate_args=None)

Bases: Distribution

Creates a Binomial distribution parameterized by total_count and either probs or logits (but not both). total_count must be broadcastable with probs/logits.


Example:


>>> m = Binomial(100, torch.tensor([0 , .2, .8, 1]))

>>> x = m.sample()

tensor([ 0., 22., 71., 100.])

>>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8]))

>>> x = m.sample()

tensor([[4., 5.], [7., 6.]])

Parameters

  • total_count ( int or Tensor ) – number of Bernoulli trials
  • probs ( Tensor ) – Event probabilities
  • logits ( Tensor ) – Event log-odds

arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0), 'total_count': IntegerGreaterThan(lower_bound=0)}


entropy()

enumerate_support(expand=True)

expand(batch_shape, _instance=None)

has_enumerate_support = True

log_prob(value)

property logits: Tensor

property mean: Tensor

property mode: Tensor

property param_shape: Size


property probs:  Tensor
***

sample(sample_shape=torch.Size([]))

property support

Return type : _DependentProperty


property variance: Tensor

Categorical


class torch.distributions.categorical.Categorical(probs=None, logits=None, validate_args=None)

Bases: Distribution

Creates a categorical distribution parameterized by either probs or logits (but not both).


Note: It is equivalent to the distribution that torch.multinomial()
samples from.

Samples are integers from {0,…,K−1}\{0, \ldots, K-1\}{0,…,K−1} where K is probs.size(-1).

If probs is 1-dimensional with length-K, each element is the relative probability of sampling the class at that index.

If probs is N-dimensional, the first N-1 dimensions are treated as a batch of relative probability vectors.


Note: The probs argument must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1 along the last dimension. probs
will return this normalized value.
The logits argument will be interpreted as unnormalized log probabilities and can therefore be any real number. It will likewise be normalized so that the resulting probabilities sum to 1 along the last dimension. logits
will return this normalized value.

See also: torch.multinomial()


Example:

>>> m = Categorical(torch.tensor([0.25, 0.25, 0.25, 0.25 ]))

>>> m.sample()  # 0, 1, 2, 3 的采样概率均等

tensor(3)

Parameters

  • probs ( Tensor ) – event probabilities
  • logits ( Tensor ) – event log probabilities (unnormalized)

arg_constraints = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}

entropy()

enumerate_support(expand=True)

expand(batch_shape, _instance=None)

has_enumerate_support = True

log_prob(value)

property logits: Tensor

property mean: Tensor

property mode:  Tensor

property param_shape: Size

property probs: Tensor

sample(sample_shape=torch.Size([]))

property support

Return type : _DependentProperty


property variance: Tensor

Cauchy


class torch.distributions.cauchy.Cauchy(loc, scale, validate_args=None)

Bases: Distribution

Samples from a Cauchy (Lorentz) distribution. The distribution of the ratio of independent normally distributed random variables with means 0 follows a Cauchy distribution.


Example:

>>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))

>>> m.sample()  # 从位置参数为0、尺度参数为1的柯西分布中采样

tensor([2.3214])

Parameters

  • loc (float or Tensor ) – mode or median of the distribution.
  • scale (float or Tensor ) – half width at half maximum.

arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}

cdf(value)

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


icdf(value)

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

rsample(sample_shape=torch.Size([]))

Return type : Tensor


support = Real()

property variance: Tensor
***

Chi2


class torch.distributions.chi2.Chi2(df, validate_args=None)

Bases: Gamma

Creates a Chi-squared distribution parameterized by shape parameter df.
This is exactly equivalent to Gamma(alpha=0.5*df, beta=0.5)


Example:

>>> m = Chi2(torch.tensor([1.0]))

>>> m.sample()  # 自由度为1的卡方分布抽样

tensor([0.1046])

Parameters

  • df (float or Tensor ) – shape parameter of the distribution

arg_constraints = {'df': GreaterThan(lower_bound=0.0)}

property df: Tensor

expand(batch_shape, _instance=None)

连续伯努利分布


class torch.distributions.continuous_bernoulli.ContinuousBernoulli(probs=None, logits=None, lims=(0.499, 0.501), validate_args=None)

基类:ExponentialFamily

创建一个由 probslogits 参数化的连续伯努利分布(两者不可同时使用)。

该分布的支持区间为 [0, 1],可通过 ‘probs’(取值在 (0,1) 区间)或 ‘logits’(实数)进行参数化。需要注意的是,与伯努利分布不同,这里的 ‘probs’ 并不对应概率,‘logits’ 也不对应对数几率,但由于与伯努利分布的相似性而沿用了相同名称。更多细节请参阅文献 [1]。


示例:

>>> m = ContinuousBernoulli(torch.tensor([0.3]))
>>> m.sample()
tensor([0.2538])

参数

  • probs (Number*,* Tensor ) – 取值范围在(0,1)之间的参数
  • logits (Number*,* Tensor ) – 实数参数,其sigmoid值匹配’probs’

[1] 连续伯努利分布:修正变分自编码器中的一个普遍错误,Loaiza-Ganem G 和 Cunningham JP,NeurIPS 2019。https://arxiv.org/abs/1907.06845

arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}

cdf(value)
entropy()

expand(batch_shape, _instance=None)

has_rsample = True


icdf(value)
log_prob(value)

property logits: Tensor

(注:根据核心翻译原则第1条,代码块内容保持原样不翻译)


property mean: Tensor

property param_shape:  Size

property probs:  Tensor 

rsample(sample_shape=torch.Size([]))

Return type : Tensor


sample(sample_shape=torch.Size([]))

property stddev: Tensor


support = Interval(lower_bound=0.0, upper_bound=1.0)

property variance:  Tensor 

狄利克雷


class torch.distributions.dirichlet.Dirichlet(concentration, validate_args=None)

基类:ExponentialFamily

创建一个由浓度参数 concentration 参数化的狄利克雷分布。


示例:

>>> m = Dirichlet(torch.tensor([0.5, 0.5]))
>>> m.sample()  # Dirichlet distributed with concentration [0.5, 0.5]
tensor([0.1046, 0.8954])

参数

  • concentration ( Tensor ) - 分布的浓度参数(通常称为 alpha)

arg_constraints = {'concentration': IndependentConstraint(GreaterThan(lower_bound=0.0), 1)}

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

rsample(sample_shape=())

返回类型:Tensor


support = Simplex()

property variance:  Tensor 

指数函数


class torch.distributions.exponential.Exponential(rate, validate_args=None)

基类:ExponentialFamily

创建一个由 rate 参数化的指数分布。


示例:

>>> m = Exponential(torch.tensor([1.0]))
>>> m.sample()  # Exponential distributed with rate=1
tensor([0.1046])

参数

  • rate (float 或 Tensor) – 该分布的 rate = 1 / scale

arg_constraints = {'rate': GreaterThan(lower_bound=0.0)}

cdf(value)

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


icdf(value)

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

rsample(sample_shape=torch.Size([]))

返回类型:Tensor


property stddev:  Tensor 

support = GreaterThanEq(lower_bound=0.0)

property variance:  Tensor 

费希尔-斯涅克分布


class torch.distributions.fishersnedecor.FisherSnedecor(df1, df2, validate_args=None)

基类:Distribution

创建一个由 df1df2 参数化的 Fisher-Snedecor 分布。


示例:

>>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0]))
>>> m.sample()  # Fisher-Snedecor-distributed with df1=1 and df2=2
tensor([0.2453])

参数

  • df1 (float 或 Tensor) – 自由度参数1
  • df2 (float 或 Tensor) – 自由度参数2

arg_constraints = {'df1': GreaterThan(lower_bound=0.0), 'df2': GreaterThan(lower_bound=0.0)}

expand(batch_shape, _instance=None)

has_rsample = True

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

rsample(sample_shape=torch.Size([]))

返回类型:Tensor


support = GreaterThan(lower_bound=0.0)

property variance:  Tensor 

Gamma


class torch.distributions.gamma.Gamma(concentration, rate, validate_args=None)

基类:ExponentialFamily

创建一个由形状参数 concentration 和比率参数 rate 参数化的 Gamma 分布。


示例:

>>> m = Gamma(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # Gamma distributed with concentration=1 and rate=1
tensor([0.1046])

参数

  • concentration (float 或 Tensor) - 分布的形状参数(通常称为 alpha)
  • rate (float 或 Tensor) - 分布的速率参数(通常称为 beta),rate = 1 / scale

arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'rate': GreaterThan(lower_bound=0.0)}

cdf(value)

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

rsample(sample_shape=torch.Size([]))

返回类型:Tensor


support = GreaterThanEq(lower_bound=0.0)

property variance:  Tensor 

几何


class torch.distributions.geometric.Geometric(probs=None, logits=None, validate_args=None)

基类:Distribution

创建一个由 probs 参数化的几何分布,其中 probs 表示伯努利试验的成功概率。

概率质量函数为:
P(X=k)=(1−p)kp,k=0,1,…P(X=k) = (1-p)^{k} p, k = 0, 1,
…P(X=k)=(1−p)kp,k=0,1,…

注意:
torch.distributions.geometric.Geometric() 将第 (k+1)(k+1)(k+1) 次试验视为首次成功,因此采样范围为 {0,1,…}\{0, 1, \ldots\}{0,1,…};
torch.Tensor.geometric_() 将第 k 次试验视为首次成功,因此采样范围为 {1,2,…}\{1, 2, \ldots\}{1,2,…}。


示例:

>>> m = Geometric(torch.tensor([0.3]))
>>> m.sample()  # underlying Bernoulli has 30% chance 1; 70% chance 0
tensor([2.])

参数

  • probs (Number*,* Tensor ) – 采样结果为1的概率值,必须在(0, 1]范围内
  • logits (Number*,* Tensor ) – 采样结果为1的对数几率值

arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)} 

entropy()

expand(batch_shape, _instance=None)

log_prob(value)

property logits:  Tensor 

property mean:  Tensor 

property mode:  Tensor 

property probs:  Tensor 

sample(sample_shape=torch.Size([]))

support = IntegerGreaterThan(lower_bound=0)

property variance:  Tensor 

gumbel


class torch.distributions.gumbel.Gumbel(loc, scale, validate_args=None)

基类:TransformedDistribution

从Gumbel分布中采样。


示例:

>>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
>>> m.sample()  # sample from Gumbel distribution with loc=1, scale=2
tensor([1.0124])

参数

  • loc (float 或 Tensor) - 分布的位置参数
  • scale (float 或 Tensor) - 分布的尺度参数

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)} 

entropy()

expand(batch_shape, _instance=None)

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

property stddev:  Tensor 

support = Real()

property variance: Tensor

HalfCauchy


class torch.distributions.half_cauchy.HalfCauchy(scale, validate_args=None)

Bases: TransformedDistribution

Creates a half-Cauchy distribution parameterized by scale where:


X ~ Cauchy(0, scale)
Y = |X| ~ HalfCauchy(scale)

Example:

>>> m = HalfCauchy(torch.tensor([1.0]))

>>> m.sample()  # 从scale=1的半柯西分布中采样

tensor([2.3214])

Parameters

  • scale (float or Tensor ) – scale of the full Cauchy distribution

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'scale': GreaterThan(lower_bound=0.0)}

cdf(value)

entropy()


expand(batch_shape, _instance=None)

has_rsample = True


icdf(prob)

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

property scale:  Tensor 

support = GreaterThanEq(lower_bound=0.0)

property variance: Tensor

HalfNormal

class torch.distributions.half_normal.HalfNormal(scale, validate_args=None)

Bases: TransformedDistribution

Creates a half-normal distribution parameterized by scale where:

X ~ Normal(0, scale)
Y = |X| ~ HalfNormal(scale)

Example:

>>> m = HalfNormal(torch.tensor([1.0]))

>>> m.sample()  # 从scale=1的半正态分布中采样

tensor([0.1046])

Parameters

  • scale (float or Tensor ) – scale of the full Normal distribution

arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {'scale': GreaterThan(lower_bound=0.0)}

cdf(value)

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


icdf(prob)

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

property scale:  Tensor 

support = GreaterThanEq(lower_bound=0.0)

property variance:  Tensor 

独立


class torch.distributions.independent.Independent(base_distribution, reinterpreted_batch_ndims, validate_args=None)

基类:Distribution

将分布的部分批次维度重新解释为事件维度。

这一功能主要用于改变 log_prob() 返回结果的形状。例如,若想创建一个与多元正态分布形状相同的对角正态分布(使二者可互换),您可以:

>>> from torch.distributions.multivariate_normal import MultivariateNormal
>>> from torch.distributions.normal import Normal
>>> loc = torch.zeros(3)
>>> scale = torch.ones(3)
>>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
>>> [mvn.batch_shape, mvn.event_shape]
[torch.Size([]), torch.Size([3])]
>>> normal = Normal(loc, scale)
>>> [normal.batch_shape, normal.event_shape]
[torch.Size([3]), torch.Size([])]
>>> diagn = Independent(normal, 1)
>>> [diagn.batch_shape, diagn.event_shape]
[torch.Size([]), torch.Size([3])]

参数

  • base_distribution (torch.distributions.distribution.Distribution) – 基础分布
  • reinterpreted_batch_ndims ( int ) – 将被重新解释为事件维度的批次维度数量

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {} 

entropy()

enumerate_support(expand=True)


expand(batch_shape, _instance=None)

property has_enumerate_support:  bool 

property has_rsample:  bool 

log_prob(value)

property mean: Tensor

property mode: Tensor

rsample(sample_shape=torch.Size([]))

Return type : Tensor


sample(sample_shape=torch.Size([]))

Return type : Tensor


property support

Return type : _DependentProperty


property variance: Tensor

InverseGamma



class torch.distributions.inverse_gamma.InverseGamma(concentration, rate, validate_args=None)

Bases: TransformedDistribution

Creates an inverse gamma distribution parameterized by concentration and rate
where:

X ~ Gamma(concentration, rate)
Y = 1 / X ~ InverseGamma(concentration, rate)

Example:

>>> m = InverseGamma(torch.tensor([2.0]), torch.tensor([3.0]))

>>> m.sample()

tensor([1.2953])

Parameters

  • concentration (float or Tensor ) – shape parameter of the distribution
    (often referred to as alpha)
  • rate (float or Tensor ) – rate = 1 / scale of the distribution
    (often referred to as beta)

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'concentration': GreaterThan(lower_bound=0.0), 'rate': GreaterThan(lower_bound=0.0)}

property concentration:  Tensor

entropy()

expand(batch_shape, _instance=None)

has_rsample = True

property mean:  Tensor 

property mode:  Tensor 

property rate:  Tensor 

支持范围 = GreaterThan(下限=0.0)


property variance: Tensor

Kumaraswamy


class torch.distributions.kumaraswamy.Kumaraswamy(concentration1, concentration0, validate_args=None)

Bases: TransformedDistribution

Samples from a Kumaraswamy distribution.


Example:

>>> m = Kumaraswamy(torch.tensor([1.0]), torch.tensor([1.0]))

>>> m.sample()  # 从 alpha=1 和 beta=1 的 Kumaraswamy 分布中采样

tensor([0.1729])

Parameters

  • concentration1 (float or Tensor ) – 1st concentration parameter of the distribution
    (often referred to as alpha)
  • concentration0 (float or Tensor ) – 2nd concentration parameter of the distribution
    (often referred to as beta)

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'concentration0': GreaterThan(lower_bound=0.0), 'concentration1': GreaterThan(lower_bound=0.0)}

entropy()

expand(batch_shape, _instance=None)

has_rsample = True

property mean: Tensor

property mode: Tensor


support = Interval(lower_bound=0.0, upper_bound=1.0)

property variance:  Tensor 

LKJCholesky


class torch.distributions.lkj_cholesky.LKJCholesky(dim, concentration=1.0, validate_args=None)

基类:Distribution

LKJ分布用于描述相关矩阵的下三角Cholesky因子。

该分布由浓度参数η(concentration)控制,使得从Cholesky因子生成的相关矩阵M的概率与det(M)^{η-1}成正比。因此,当concentration == 1时,我们得到相关矩阵Cholesky因子的均匀分布。


L ~ LKJCholesky(dim, concentration)
X = L @ L' ~ LKJCorr(dim, concentration)

请注意,该分布是对相关矩阵的Cholesky因子进行采样,而非直接对相关矩阵本身采样,因此与文献[1]中关于LKJCorr分布的推导略有不同。在采样过程中,这里采用了文献[1]第3节所述的Onion方法。


示例:

>>> l = LKJCholesky(3, 0.5)
>>> l.sample()  # l @ l.T is a sample of a correlation 3x3 matrix
tensor([[1.0000, 0.0000, 0.0000], [0.3516, 0.9361, 0.0000], [-0.1899, 0.4748, 0.8593]])

参数

  • dimension (dim) – 矩阵的维度
  • concentration (float 或 Tensor) – 分布的形状参数/浓度参数(通常称为 eta)

参考文献

[1] Generating random correlation matrices based on vines and extended onion method (2009), Daniel Lewandowski, Dorota Kurowicka, Harry Joe.

Journal of Multivariate Analysis. 100、10.1016/j.jmva.2009.04.008


arg_constraints = {'concentration': GreaterThan(lower_bound=0.0)}

expand(batch_shape, _instance=None)

log_prob(value)

sample(sample_shape=torch.Size([]))

support = CorrCholesky()

Laplace


class torch.distributions.laplace.Laplace(loc, scale, validate_args=None)

Bases: Distribution

Creates a Laplace distribution parameterized by loc and scale.


Example:

>>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0]))

>>> m.sample()  # 服从拉普拉斯分布,位置参数=0,尺度参数=1

tensor([0.1046])

Parameters

  • loc (float or Tensor ) – mean of the distribution
  • scale (float or Tensor ) – scale of the distribution

arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}

cdf(value)

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


icdf(value)

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 


rsample(sample_shape=torch.Size([]))

Return type : Tensor


property stddev:  Tensor

support = Real()

property variance:  Tensor 

对数正态分布


class torch.distributions.log_normal.LogNormal(loc, scale, validate_args=None)

基类:TransformedDistribution

创建一个由locscale参数化的对数正态分布,其中:

***
X ~ Normal(loc, scale)
Y = exp(X) ~ LogNormal(loc, scale)

Example:


>>> m = LogNormal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # log-normal distributed with mean=0 and stddev=1
tensor([0.1046])

参数

  • loc (float 或 Tensor) - 分布对数的均值
  • scale (float 或 Tensor) - 分布对数的标准差

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}

entropy()

expand(batch_shape, _instance=None)

has_rsample = True

property loc:  Tensor 

property mean:  Tensor 

property mode:  Tensor 

property scale:  Tensor 

support = GreaterThanEq(lower_bound=0.0)

property variance: Tensor

LowRankMultivariateNormal


class torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal(loc, cov_factor, cov_diag, validate_args=None)

Bases: Distribution

Creates a multivariate normal distribution with covariance matrix having a low-rank form
parameterized by cov_factor and cov_diag:

covariance_matrix = cov_factor @ cov_factor.T + cov_diag

Example :

>>> m = LowRankMultivariateNormal(
...     torch.zeros(2), torch.tensor([[1.0], [0.0]]), torch.ones(2)
... )
>>> m.sample()  # 服从均值=`[0,0]`、协方差因子=`[[1],[0]]`、对角协方差=`[1,1]`的正态分布
tensor([-0.2102, -0.5429])

Parameters

  • loc ( Tensor ) – mean of the distribution with shape batch_shape + event_shape
  • cov_factor ( Tensor ) – factor part of low-rank form of covariance matrix with shape
    batch_shape + event_shape + (rank,)
  • cov_diag ( Tensor ) – diagonal part of low-rank form of covariance matrix with shape
    batch_shape + event_shape

Note: The computation for determinant and inverse of covariance matrix is avoided when
cov_factor.shape[1] << cov_factor.shape[0] thanks to Woodbury matrix identity and matrix determinant lemma.
Thanks to these formulas, we just need to compute the determinant and inverse of the small size “capacitance” matrix:

capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor

arg_constraints = {'cov_diag': IndependentConstraint(GreaterThan(lower_bound=0.0), 1), 'cov_factor': IndependentConstraint(Real(), 2), 'loc': IndependentConstraint(Real(), 1)}

property covariance_matrix:  Tensor 

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

property precision_matrix:  Tensor 

rsample(sample_shape=torch.Size([]))

Return type : Tensor


property scale_tril:  Tensor


support = IndependentConstraint(Real(), 1)

property variance:  Tensor 

混合相同族分布


class torch.distributions.mixture_same_family.MixtureSameFamily(mixture_distribution, component_distribution, validate_args=None)

基类:Distribution

MixtureSameFamily 分布实现了(批量)混合分布,其中所有组件都来自同一分布类型的不同参数化形式。它通过一个分类"选择分布"(覆盖k个组件)和一个组件分布进行参数化,其中组件分布是一个具有最右侧批量形状(等于[k])的Distribution,用于索引每个(批量的)组件。


示例:

>>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
>>> # weighted normal distributions
>>> mix = D.Categorical(torch.ones(5,))
>>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
>>> gmm = MixtureSameFamily(mix, comp)

>>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
>>> # weighted bivariate normal distributions
>>> mix = D.Categorical(torch.ones(5,))
>>> comp = D.Independent(D.Normal(
...          torch.randn(5,2), torch.rand(5,2)), 1)
>>> gmm = MixtureSameFamily(mix, comp)

>>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
>>> # consisting of 5 random weighted bivariate normal distributions
>>> mix = D.Categorical(torch.rand(3,5))
>>> comp = D.Independent(D.Normal(
...         torch.randn(3,5,2), torch.rand(3,5,2)), 1)
>>> gmm = MixtureSameFamily(mix, comp)

参数

  • mixture_distribution (Categorical) – 类似 torch.distributions.Categorical 的实例,用于管理选择组件的概率。类别数量必须与 component_distribution 最右侧的批次维度匹配。必须具有标量 batch_shape 或与 component_distribution.batch_shape[:-1] 匹配的 batch_shape。
  • component_distribution (Distribution) – 类似 torch.distributions.Distribution 的实例。最右侧的批次维度用于索引组件。

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {}

cdf(x)

property component_distribution: Distribution

expand(batch_shape, _instance=None)

has_rsample = False

log_prob(x)

property mean: Tensor

property mixture_distribution: Categorical


sample(sample_shape=torch.Size([]))

property support

Return type : _DependentProperty


property variance: Tensor

Multinomial


class torch.distributions.multinomial.Multinomial(total_count=1, probs=None, logits=None, validate_args=None)

Bases: Distribution

Creates a Multinomial distribution parameterized by total_count and either probs or logits (but not both). The innermost dimension of probs indexes over categories. All other dimensions index over batches.

Note that total_count need not be specified if only log_prob() is called (see example below)


Note: The probs argument must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1 along the last dimension. probs
will return this normalized value.
The logits argument will be interpreted as unnormalized log probabilities and can therefore be any real number. It will likewise be normalized so that the resulting probabilities sum to 1 along the last dimension. logits
will return this normalized value.

  • sample() requires a single shared total_count for all
    parameters and samples.
  • log_prob() allows different total_count for each parameter and sample.

Example:

>>> m = Multinomial(100, torch.tensor([1., 1., 1., 1.]))

>>> x = m.sample()  # 0, 1, 2, 3 的采样概率均等

tensor([21., 24., 30., 25.])

>>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)

tensor([-4.1338])

Parameters

  • total_count ( int ) – number of trials
  • probs ( Tensor ) – event probabilities
  • logits ( Tensor ) – event log probabilities (unnormalized)

arg_constraints = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}

entropy()

expand(batch_shape, _instance=None)

log_prob(value)

property logits: Tensor

property mean: Tensor

property param_shape:  Size

property probs: Tensor

sample(sample_shape=torch.Size([]))

property support 

返回类型:_DependentProperty

total_count:int


property variance:  Tensor 

多元正态分布


class torch.distributions.multivariate_normal.MultivariateNormal(loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)

基类:Distribution

创建一个由均值向量和协方差矩阵参数化的多元正态(也称为高斯)分布。

多元正态分布可以通过以下三种方式参数化:
1、正定协方差矩阵 Σ\mathbf{\Sigma}Σ
2、正定精度矩阵 Σ−1\mathbf{\Sigma}^{-1}Σ−1
3、具有正对角元素的下三角矩阵 L\mathbf{L}L(满足 Σ=LL⊤\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\topΣ=LL⊤)

该三角矩阵可以通过协方差矩阵的Cholesky分解等方法获得。


示例

>>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
>>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
tensor([-0.2102, -0.5429])

参数

  • loc ( Tensor ) – 分布的均值
  • covariance_matrix ( Tensor ) – 正定协方差矩阵
  • precision_matrix ( Tensor ) – 正定精度矩阵
  • scale_tril ( Tensor ) – 协方差的下三角因子,对角线元素为正

注意:只能指定 covariance_matrixprecision_matrixscale_tril 中的一个参数。

使用 scale_tril 会更高效:所有内部计算都基于 scale_tril。如果传入的是 covariance_matrixprecision_matrix,则仅用于通过 Cholesky 分解计算对应的下三角矩阵。


arg_constraints = {'covariance_matrix': PositiveDefinite(), 'loc': IndependentConstraint(Real(), 1), 'precision_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}

property covariance_matrix:  Tensor 

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

property precision_matrix:  Tensor 

rsample(sample_shape=torch.Size([]))

返回类型:Tensor


property scale_tril:  Tensor 

support = IndependentConstraint(Real(), 1)

property variance: Tensor

NegativeBinomial


class torch.distributions.negative_binomial.NegativeBinomial(total_count, probs=None, logits=None, validate_args=None)

Bases: Distribution

Creates a Negative Binomial distribution, i.e. distribution of the number of successful independent and identical Bernoulli trials
before total_count failures are achieved. The probability of success of each Bernoulli trial is probs.


Parameters

  • total_count (float or Tensor ) – non-negative number of negative Bernoulli
    trials to stop, although the distribution is still valid for real
    valued count
  • probs ( Tensor ) – Event probabilities of success in the half open interval [0, 1)
  • logits ( Tensor ) – Event log-odds for probabilities of success

arg_constraints = {'logits': Real(), 'probs': HalfOpenInterval(lower_bound=0.0, upper_bound=1.0), 'total_count': GreaterThanEq(lower_bound=0)}

expand(batch_shape, _instance=None)


log_prob(value)

property logits: Tensor

property mean: Tensor

property mode: Tensor

property param_shape:  Size

property probs: Tensor

sample(sample_shape=torch.Size([]))

support = IntegerGreaterThan(lower_bound=0)

property variance:  Tensor 

常规


class torch.distributions.normal.Normal(loc, scale, validate_args=None)

基类:ExponentialFamily

创建一个由locscale参数化的正态(也称为高斯)分布。


示例:

>>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # normally distributed with loc=0 and scale=1
tensor([0.1046])

参数

  • loc (float 或 Tensor) - 分布的均值(通常称为 mu)
  • scale (float 或 Tensor) - 分布的标准差(通常称为 sigma)

arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}

cdf(value)

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


icdf(value)

log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

rsample(sample_shape=torch.Size([]))

返回类型:Tensor


sample(sample_shape=torch.Size([]))

property stddev:  Tensor 

support = Real()

property variance:  Tensor 

OneHotCategorical


class torch.distributions.one_hot_categorical.OneHotCategorical(probs=None, logits=None, validate_args=None)

基类:Distribution

创建一个由 probslogits 参数化的 one-hot 分类分布。

样本是大小为 probs.size(-1) 的 one-hot 编码向量。

注意:probs 参数必须是非负、有限且具有非零和,它将在最后一个维度上被归一化为总和为 1。probs 将返回这个归一化后的值。

logits 参数将被解释为未归一化的对数概率,因此可以是任何实数。它同样会被归一化,使得最终概率在最后一个维度上总和为 1。logits 将返回这个归一化后的值。

另请参阅:torch.distributions.Categorical() 以了解 probslogits 的详细说明。


示例:

>>> m = OneHotCategorical(torch.tensor([0.25, 0.25, 0.25, 0.25 ]))
>>> m.sample()  # equal probability of 0, 1, 2, 3
tensor([0., 0., 0., 1.])

参数

  • probs ( Tensor ) – 事件概率
  • logits ( Tensor ) – 事件对数概率(未归一化)

arg_constraints = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}

entropy()

enumerate_support(expand=True)

expand(batch_shape, _instance=None)```

has_enumerate_support = True


log_prob(value)

property logits: Tensor


property mean:  Tensor

property mode: Tensor

property param_shape:  Size

property probs: Tensor

sample(sample_shape=torch.Size([]))

support = OneHot()

property variance: Tensor

Pareto


class torch.distributions.pareto.Pareto(scale, alpha, validate_args=None)

Bases: TransformedDistribution

Samples from a Pareto Type 1 distribution.


Example:

>>> m = Pareto(torch.tensor([1.0]), torch.tensor([1.0]))

>>> m.sample()  # 从scale=1且alpha=1的帕累托分布中采样

tensor([1.5623])

Parameters

  • scale (float or Tensor ) – Scale parameter of the distribution
  • alpha (float or Tensor ) – Shape parameter of the distribution

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'alpha': GreaterThan(lower_bound=0.0), 'scale': GreaterThan(lower_bound=0.0)}

entropy()

Return type : Tensor

expand(batch_shape, _instance=None)

Return type
Pareto


property mean: Tensor

property mode:  Tensor

property support: Constraint

Return type : _DependentProperty


property variance: Tensor

Poisson


class torch.distributions.poisson.Poisson(rate, validate_args=None)

Bases: ExponentialFamily

Creates a Poisson distribution parameterized by rate, the rate parameter.

Samples are nonnegative integers, with a pmf given by
rateke−ratek!\mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}

ratekk!e−rate​Example:

>>> m = Poisson(torch.tensor([4]))
>>> m.sample()
tensor([3.])

Parameters

  • rate (Number*,* Tensor ) – the rate parameter

arg_constraints = {'rate': GreaterThanEq(lower_bound=0.0)}

expand(batch_shape, _instance=None)

log_prob(value)

property mean: Tensor

property mode: Tensor

sample(sample_shape=torch.Size([]))

support = IntegerGreaterThan(lower_bound=0) 

property variance:  Tensor 

松弛伯努利分布

(注:根据技术文档翻译原则,此处保留原英文术语"RelaxedBernoulli"作为专有名词不翻译,仅对标题层级和格式符号进行本地化处理。技术文档中常见的分布名称通常保留原文以确保准确性。)


class torch.distributions.relaxed_bernoulli.RelaxedBernoulli(temperature, probs=None, logits=None, validate_args=None)

基类:TransformedDistribution

创建一个RelaxedBernoulli分布,参数化方式为temperature,以及probslogits(但不可同时使用)。这是伯努利分布的松弛版本,因此取值范围在(0, 1)之间,并且具有可重参数化的样本。


示例:

>>> m = RelaxedBernoulli(torch.tensor([2.2]), 
...                      torch.tensor([0.1, 0.2, 0.3, 0.99]))
>>> m.sample()
tensor([0.2951, 0.3442, 0.8918, 0.9021])

参数

  • temperature ( Tensor ) – 松弛温度
  • probs (Number*,* Tensor ) – 采样结果为1的概率
  • logits (Number*,* Tensor ) – 采样结果为1的对数几率

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}

expand(batch_shape, _instance=None)

has_rsample = True

property logits:  Tensor 

property probs:  Tensor 

support = Interval(lower_bound=0.0, upper_bound=1.0)

property temperature:  Tensor 

LogitRelaxedBernoulli


class torch.distributions.relaxed_bernoulli.LogitRelaxedBernoulli(temperature, probs=None, logits=None, validate_args=None)

基类:Distribution

创建一个由 probslogits(但不同时使用)参数化的 LogitRelaxedBernoulli 分布,这是 RelaxedBernoulli 分布的对数几率。

采样结果是 (0, 1) 区间值的对数几率。更多细节参见[1]。


参数

  • temperature ( Tensor ) – 松弛温度参数
  • probs (Number*,* Tensor ) – 采样结果为 1 的概率
  • logits (Number*,* Tensor ) – 采样结果为 1 的对数优势比

参考文献
[1] 《具体分布:离散随机变量的连续松弛方法》(Maddison 等人,2017)
[2] 《基于 Gumbel-Softmax 的类别重参数化方法》(Jang 等人,2017)

注:
1、保留所有代码块和链接原格式
2、技术术语如"logits"、“tensor"等保持英文
3、被动语态转为主动语态(如"parameterized by"译为"由…参数化的”)
4、数学符号区间(0,1)保留原格式
5、文献标题采用中文书名号并补充说明性文字"方法"


arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}

expand(batch_shape, _instance=None)

log_prob(value)

property logits:  Tensor 

property param_shape:  Size 

property probs:  Tensor 

rsample(sample_shape=torch.Size([]))

返回类型:Tensor


support = Real()

RelaxedOneHotCategorical


class torch.distributions.relaxed_categorical.RelaxedOneHotCategorical(temperature, probs=None, logits=None, validate_args=None)

基类:TransformedDistribution

创建一个由 temperature 以及 probslogits 参数化的 RelaxedOneHotCategorical 分布。

这是 OneHotCategorical 分布的松弛版本,因此其样本位于单纯形上,并且可重新参数化。


示例:

>>> m = RelaxedOneHotCategorical(torch.tensor([2.2]), 
...                              torch.tensor([0.1, 0.2, 0.3, 0.4]))
>>> m.sample()
tensor([0.1294, 0.2324, 0.3859, 0.2523])

参数

  • temperature ( Tensor ) – 松弛温度
  • probs ( Tensor ) – 事件概率
  • logits ( Tensor ) – 每个事件的未归一化对数概率

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}

expand(batch_shape, _instance=None)

has_rsample = True

property logits:  Tensor 

property probs:  Tensor 

support = Simplex()

property temperature:  Tensor 

StudentT 分布


class torch.distributions.studentT.StudentT(df, loc=0.0, scale=1.0, validate_args=None)

基类:Distribution

创建一个由自由度 df、均值 loc 和尺度参数 scale 参数化的学生t分布。


示例:

>>> m = StudentT(torch.tensor([2.0]))
>>> m.sample()  # Student's t-distributed with degrees of freedom=2
tensor([0.1046])

参数

  • df (float 或 Tensor) – 自由度
  • loc (float 或 Tensor) – 分布的平均值
  • scale (float 或 Tensor) – 分布的尺度参数

arg_constraints = {'df': GreaterThan(lower_bound=0.0), 'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}

entropy()

expand(batch_shape, _instance=None)


 

has_rsample = True



log_prob(value)

property mean: Tensor

property mode: Tensor

rsample(sample_shape=torch.Size([]))

返回类型:Tensor


support = Real()

property variance: Tensor

TransformedDistribution


class torch.distributions.transformed_distribution.TransformedDistribution(base_distribution, transforms, validate_args=None)

(说明:根据核心翻译原则第1条"代码保护",所有代码块保持原内容不处理,因此上述Python类定义未作翻译,完整保留原始格式和内容)


Bases: Distribution

Extension of the Distribution class, which applies a sequence of Transforms to a base distribution. Let f be the composition of transforms applied:

X ~ BaseDistribution

Y = f(X) ~ TransformedDistribution(BaseDistribution, f)

log p(Y) = log p(X) + log |det (dX/dY)|

Note that the .event_shape of a TransformedDistribution is the maximum shape of its base distribution and its transforms, since transforms can introduce correlations among events.

An example for the usage of TransformedDistribution would be:

# Building a Logistic Distribution
# X ~ Uniform(0, 1)
# f = a + b * logit(X)
# Y ~ f(X) ~ Logistic(a, b)
base_distribution = Uniform(0, 1)
transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
logistic = TransformedDistribution(base_distribution, transforms)

For more examples, please look at the implementations of Gumbel, HalfCauchy, HalfNormal, LogNormal, Pareto, Weibull, RelaxedBernoulli and RelaxedOneHotCategorical


arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {}

cdf(value)

通过反转变换并计算基础分布的得分来计算累积分布函数。


expand(batch_shape, _instance=None)

property has_rsample:  bool 

icdf(value)

通过变换计算逆累积分布函数,并得出基础分布的评分值。


log_prob(value)

通过逆变换计算样本得分,利用基础分布的得分和对数绝对雅可比行列式进行评分。


rsample(sample_shape=torch.Size([]))

生成一个形状为 sample_shape 的重参数化样本,或者当分布参数为批处理时,生成形状为 sample_shape 的批量重参数化样本。首先从基础分布中采样,然后对列表中的每个变换应用 transform() 方法。

返回类型:Tensor


sample(sample_shape=torch.Size([]))

生成一个形状为 sample_shape 的样本,如果分布参数是批处理的,则生成形状为 sample_shape 的样本批次。首先从基础分布中采样,然后对列表中的每个变换应用 transform() 方法。


property support 

返回类型:_DependentProperty


统一性


class torch.distributions.uniform.Uniform(low, high, validate_args=None)

基础分布:Distribution

生成在半开区间 [low, high) 内均匀分布的随机样本。


示例:

>>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
>>> m.sample()  # uniformly distributed in the range [0.0, 5.0)
tensor([2.3418])

参数

  • low (float 或 Tensor) - 下限值(包含)
  • high (float 或 Tensor) - 上限值(不包含)

arg_constraints = {'high': Dependent(), 'low': Dependent()}

cdf(value)

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


icdf(value)

log_prob(value)

property mean: Tensor

property mode: Tensor

rsample(sample_shape=torch.Size([]))

返回类型:Tensor


property stddev:  Tensor 

property support 

返回类型:_DependentProperty


property variance:  Tensor 

冯·米塞斯


class torch.distributions.von_mises.VonMises(loc, concentration, validate_args=None)

基类:Distribution

圆形冯·米塞斯分布。

该实现采用极坐标系。locvalue参数可以是任意实数(以便进行无约束优化),但会被解释为对2π取模的角度值。


示例:

>>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # von Mises distributed with loc=1 and concentration=1
tensor([1.9777])

参数

  • loc (torch.Tensor) - 以弧度表示的角度值
  • concentration (torch.Tensor) - 集中度参数

arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'loc': Real()}

expand(batch_shape, _instance=None)


has_rsample = False 

log_prob(value)

property mean:  Tensor 

提供的平均值为循环平均值。


property mode:  Tensor 

sample(sample_shape=torch.Size([]))

The sampling algorithm for the von Mises distribution is based on the following paper: D.J. Best and N.I. Fisher, “Efficient simulation of the von Mises distribution.” Applied Statistics (1979): 152-157.

Sampling is always done in double precision internally to avoid a hang in _rejection_sample() for small values of the concentration, which starts to happen for single precision around 1e-4 (see issue #88443).


support = Real()

property variance:  Tensor 

提供的方差为圆形方差。


weibull


class torch.distributions.weibull.Weibull(scale, concentration, validate_args=None)

基类:TransformedDistribution

从双参数威布尔分布中采样的实现。


示例

>>> m = Weibull(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Weibull distribution with scale=1, concentration=1
tensor([0.4784])

参数

  • scale (float 或 Tensor) - 分布的尺度参数(lambda)。
  • concentration (float 或 Tensor) - 分布的集中度参数(k/shape)。

arg_constraints:  dict[str , torch.distributions.constraints.Constraint] = {'concentration': GreaterThan(lower_bound=0.0), 'scale': GreaterThan(lower_bound=0.0)}

entropy()

expand(batch_shape, _instance=None)

property mean:  Tensor 

property mode:  Tensor 


support = GreaterThan(lower_bound=0.0)

property variance:  Tensor 

wishart


class torch.distributions.wishart.Wishart(df, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)

基类:ExponentialFamily

创建一个由对称正定矩阵Σ\SigmaΣ或其Cholesky分解Σ=LL⊤\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\topΣ=LL⊤参数化的Wishart分布。


示例

>>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
>>> m.sample()  # Wishart distributed with mean=`df * I` and >># variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j

参数

  • df (float 或 Tensor) – 实值参数,需大于(方阵的维度)- 1
  • covariance_matrix (Tensor) – 正定协方差矩阵
  • precision_matrix (Tensor) – 正定精度矩阵
  • scale_tril (Tensor) – 协方差矩阵的下三角因子,其对角线元素为正

注意:只能指定 covariance_matrixprecision_matrixscale_tril 中的一个。

使用 scale_tril 会更高效:所有内部计算都基于 scale_tril。如果传入的是 covariance_matrixprecision_matrix,则仅用于通过 Cholesky 分解计算对应的下三角矩阵。

torch.distributions.LKJCholesky 是一种受限的 Wishart 分布。[1]

参考文献

[1] Wang, Z., Wu, Y. 和 Chu, H., 2018、关于 LKJ 分布与受限 Wishart 分布的等价性。

[2] Sawyer, S., 2007、Wishart 分布与逆 Wishart 采样。

[3] Anderson, T. W., 2003、多元统计分析导论(第 3 版)。

[4] Odell, P. L. 和 Feiveson, A. H., 1966、生成样本协方差矩阵的数值方法。JASA, 61(313):199-203。

[5] Ku, Y.-C. 和 Bloomfield, P., 2010、在 OX 中生成具有分数自由度的随机 Wishart 矩阵。


arg_constraints = {'covariance_matrix': PositiveDefinite(), 'df': GreaterThan(lower_bound=0), 'precision_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}

property covariance_matrix:  Tensor 

entropy()

expand(batch_shape, _instance=None)

has_rsample = True


log_prob(value)

property mean:  Tensor 

property mode:  Tensor 

property precision_matrix:  Tensor 

rsample(sample_shape=torch.Size([]), max_try_correction=None)

Warning: In some cases, sampling algorithm based on Bartlett decomposition may return singular matrix samples.
Several tries to correct singular samples are performed by default, but it may end up returning
singular matrix samples. Singular samples may return -inf values in .log_prob().
In those cases, the user should validate the samples and either fix the value of df or adjust max_try_correction value for argument in .rsample accordingly.

Return type : Tensor


property scale_tril:  Tensor

support = PositiveDefinite()

property variance: Tensor

KL Divergence

`torch.distributions.kl.kl_divergence(p, q)` 

Compute Kullback-Leibler divergence KL(p∥q)KL(p | q)KL(p∥q) between two distributions.

KL(p∥q)=∫p(x)log⁡p(x)q(x) dxKL(p | q) = \int p(x) \log\frac {p(x)} {q(x)} \,dxKL(p∥q)=∫p(x)logq(x)p(x)​dx


Parameters

  • p (Distribution) – A Distribution object.
  • q (Distribution) – A Distribution object.

Returns
A batch of KL divergences of shape batch_shape.

Return type : Tensor

Raises
NotImplementedError – If the distribution types have not been registered via [register_kl()`](https://pytorch.org/docs/stable/data.html#torch.distributions.kl.register_kl “torch.distributions.kl.register_kl”).

KL divergence is currently implemented for the following distribution pairs:* Bernoulli and Bernoulli

  • Bernoulli and Poisson
  • Beta and Beta
  • Beta and ContinuousBernoulli
  • Beta and Exponential
  • Beta and Gamma
  • Beta and Normal
  • Beta and Pareto
  • Beta and Uniform
  • Binomial and Binomial
  • Categorical and Categorical
  • Cauchy and Cauchy
  • ContinuousBernoulli and ContinuousBernoulli
  • ContinuousBernoulli and Exponential
  • ContinuousBernoulli and Normal
  • ContinuousBernoulli and Pareto
  • ContinuousBernoulli and Uniform
  • Dirichlet and Dirichlet
  • Exponential and Beta
  • Exponential and ContinuousBernoulli
  • Exponential and Exponential
  • Exponential and Gamma
  • Exponential and Gumbel
  • Exponential and Normal
  • Exponential and Pareto
  • Exponential and Uniform
  • ExponentialFamily and ExponentialFamily
  • Gamma and Beta
  • Gamma and ContinuousBernoulli
  • Gamma and Exponential
  • Gamma and Gamma
  • Gamma and Gumbel
  • Gamma and Normal
  • Gamma and Pareto
  • Gamma and Uniform
  • Geometric and Geometric
  • Gumbel and Beta
  • Gumbel and ContinuousBernoulli
  • Gumbel and Exponential
  • Gumbel and Gamma
  • Gumbel and Gumbel
  • Gumbel and Normal
  • Gumbel and Pareto
  • Gumbel and Uniform
  • HalfNormal and HalfNormal
  • Independent and Independent
  • Laplace and Beta
  • Laplace and ContinuousBernoulli
  • Laplace and Exponential
  • Laplace and Gamma
  • Laplace and Laplace
  • Laplace and Normal
  • Laplace and Pareto
  • Laplace and Uniform
  • LowRankMultivariateNormal and LowRankMultivariateNormal
  • LowRankMultivariateNormal and MultivariateNormal
  • MultivariateNormal and LowRankMultivariateNormal
  • MultivariateNormal and MultivariateNormal
  • Normal and Beta
  • Normal and ContinuousBernoulli
  • Normal and Exponential
  • Normal and Gamma
  • Normal and Gumbel
  • Normal and Laplace
  • Normal and Normal
  • Normal and Pareto
  • Normal and Uniform
  • OneHotCategorical and OneHotCategorical
  • Pareto and Beta
  • Pareto and ContinuousBernoulli
  • Pareto and Exponential
  • Pareto and Gamma
  • Pareto and Normal
  • Pareto and Pareto
  • Pareto and Uniform
  • Poisson and Bernoulli
  • Poisson and Binomial
  • Poisson and Poisson
  • TransformedDistribution and TransformedDistribution
  • Uniform and Beta
  • Uniform and ContinuousBernoulli
  • Uniform and Exponential
  • Uniform and Gamma
  • Uniform and Gumbel
  • Uniform and Normal
  • Uniform and Pareto
  • Uniform and Uniform

torch.distributions.kl.register_kl(type_p, type_q)

Decorator to register a pairwise function with kl_divergence().
Usage:

@register_kl(Normal, Normal)
def kl_normal_normal(p, q):
    # insert implementation here

Lookup returns the most specific (type,type) match ordered by subclass. If the match is ambiguous, a RuntimeWarning is raised. For example to resolve the ambiguous situation:

@register_kl(BaseP, DerivedQ)
def kl_version1(p, q): ...

@register_kl(DerivedP, BaseQ) 
def kl_version2(p, q): ...

you should register a third most-specific implementation, e.g.:

register_kl(DerivedP, DerivedQ)(kl_version1)  # 打破平局


Parameters

  • type_p (type) – A subclass of Distribution.
  • type_q (type) – A subclass of Distribution.

Transforms


class torch.distributions.transforms.AbsTransform(cache_size=0)

Transform via the mapping y=∣x∣y = |x|y=∣x∣.


class torch.distributions.transforms.AffineTransform(loc, scale, event_dim=0, cache_size=0)

Transform via the pointwise affine mapping y=loc+scale×xy = \text{loc} + \text{scale} \times xy=loc+scale×x.


Parameters

  • loc ( Tensor or float) – Location parameter.
  • scale ( Tensor or float) – Scale parameter.
  • event_dim ( int ) – Optional size of event_shape. This should be zerofor univariate random variables, 1 for distributions over vectors, 2 for distributions over matrices, etc.

class torch.distributions.transforms.CatTransform(tseq, dim=0, lengths=None, cache_size=0)

(注:根据核心翻译原则第1条"代码保护"规则,代码块内容保持原样不翻译)


Transform functor that applies a sequence of transforms tseq
component-wise to each submatrix at dim, of length lengths[dim], in a way compatible with torch.cat().


Example:

x0 = torch.cat([torch.range(1, 10), torch.range(1, 10)], dim=0)

x = torch.cat([x0, x0], dim=0)

t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10])

t = CatTransform([t0, t0], dim=0, lengths=[20, 20])

y = t(x)

class torch.distributions.transforms.ComposeTransform(parts, cache_size=0)

Composes multiple transforms in a chain.
The transforms being composed are responsible for caching.


Parameters

  • parts (list of Transform ) – A list of transforms to compose.
  • cache_size ( int ) – Size of cache. If zero, no caching is done. If one, the latest single value is cached. Only 0 and 1 are supported.

class torch.distributions.transforms.CorrCholeskyTransform(cache_size=0)

Transforms an uncontrained real vector xxx with length D∗(D−1)/2D*(D-1)/2D∗(D−1)/2 into the Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower
triangular matrix with positive diagonals and unit Euclidean norm for each row.
The transform is processed as follows:

1、First we convert x into a lower triangular matrix in row order.
2、For each row XiX_iXi​ of the lower triangular part, we apply a signed version of class StickBreakingTransform to transform XiX_iXi​ into a unit Euclidean length vector using the following steps:

  • Scales into the interval (−1,1)(-1, 1)(−1,1) domain: ri=tanh⁡(Xi)r_i = \tanh(X_i)ri​=tanh(Xi​).
  • Transforms into an unsigned domain: zi=ri2z_i = r_i^2zi​=ri2​.
  • Applies si=StickBreakingTransform(zi)s_i = StickBreakingTransform(z_i)si​=StickBreakingTransform(zi​).
  • Transforms back into signed domain: yi=sign(ri)∗siy_i = sign(r_i) * \sqrt{s_i}yi​=sign(ri​)∗si​​.

class torch.distributions.transforms.CumulativeDistributionTransform(distribution, cache_size=0)

Transform via the cumulative distribution function of a probability distribution.


Parameters

  • distribution (Distribution) – Distribution whose cumulative distribution function to use for the transformation.

Example:

# 从多元正态分布构建高斯Copula
base_dist = MultivariateNormal(
    loc=torch.zeros(2), scale_tril=LKJCholesky(2).sample(), )
transform = CumulativeDistributionTransform(Normal(0, 1))
copula = TransformedDistribution(base_dist, [transform])

class torch.distributions.transforms.ExpTransform(cache_size=0)

Transform via the mapping y=exp⁡(x)y = \exp(x)y=exp(x).


class torch.distributions.transforms.IndependentTransform(base_transform, reinterpreted_batch_ndims, cache_size=0)

Wrapper around another transform to treat
reinterpreted_batch_ndims-many extra of the right most dimensions as dependent. This has no effect on the forward or backward transforms, but
does sum out reinterpreted_batch_ndims-many of the rightmost dimensions in log_abs_det_jacobian().


Parameters

  • base_transform ( Transform ) – A base transform.
  • reinterpreted_batch_ndims ( int ) – The number of extra rightmost
    dimensions to treat as dependent.

class torch.distributions.transforms.LowerCholeskyTransform(cache_size=0)

Transform from unconstrained matrices to lower-triangular matrices with nonnegative diagonal entries.

This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization.


class torch.distributions.transforms.PositiveDefiniteTransform(cache_size=0)

(说明:根据核心翻译原则第1条"代码保护"规则,所有代码块保持原内容不处理)


Transform from unconstrained matrices to positive-definite matrices.


class torch.distributions.transforms.PowerTransform(exponent, cache_size=0)

Transform via the mapping y=xexponenty = x^{\text{exponent}}y=xexponent.


class torch.distributions.transforms.ReshapeTransform(in_shape, out_shape, cache_size=0)

Unit Jacobian transform to reshape the rightmost part of a tensor.

Note that in_shape and out_shape must have the same number of elements, just as for torch.Tensor.reshape().


Parameters

  • in_shape ( torch.Size ) – The input event shape.
  • out_shape ( torch.Size ) – The output event shape.
  • cache_size ( int ) – Size of cache. If zero, no caching is done. If one, the latest single value is cached. Only 0 and 1 are supported. (Default 0.)

class torch.distributions.transforms.SigmoidTransform(cache_size=0)

Transform via the mapping y=11+exp⁡(−x)y = \frac{1}{1 + \exp(-x)}y=1+exp(−x)1​ and x=logit(y)x = \text{logit}(y)x=logit(y).


class torch.distributions.transforms.SoftplusTransform(cache_size=0)

Transform via the mapping Softplus(x)=log⁡(1+exp⁡(x))\text{Softplus}(x) = \log(1 + \exp(x))Softplus(x)=log(1+exp(x)).
The implementation reverts to the linear function when x>20x 20x>20、


class torch.distributions.transforms.TanhTransform(cache_size=0)

Transform via the mapping y = t a n h ⁡ ( x ) y=tanh⁡(x) y=tanh(x).

It is equivalent to

ComposeTransform(
    [
        AffineTransform(0.0, 2.0),
        SigmoidTransform(),
        AffineTransform(-1.0, 2.0),
    ]
)

However this might not be numerically stable, thus it is recommended to use TanhTransform
instead.

Note that one should use cache_size=1 when it comes to NaN/Inf values.


class torch.distributions.transforms.SoftmaxTransform(cache_size=0)

SoftmaxTransform 是 PyTorch 分布变换类,用于实现 softmax 变换。该变换通常用于将未归一化的 logits 转换为概率分布。cache_size 参数控制变换结果的缓存大小,设置为 0 表示不缓存。


Transform from unconstrained space to the simplex via y=exp⁡(x)y = \exp(x)y=exp(x) then
normalizing.

This is not bijective and cannot be used for HMC. However this acts mostly
coordinate-wise (except for the final normalization), and thus is appropriate for coordinate-wise optimization algorithms.


class torch.distributions.transforms.StackTransform(tseq, dim=0, cache_size=0)

Transform functor that applies a sequence of transforms tseq
component-wise to each submatrix at dim in a way compatible with torch.stack().


Example:

x = torch.stack([torch.range(1, 10), torch.range(1, 10)], dim=1)

t = StackTransform([ExpTransform(), identity_transform], dim=1)

y = t(x)

class torch.distributions.transforms.StickBreakingTransform(cache_size=0)

Transform from unconstrained space to the simplex of one additional
dimension via a stick-breaking process.

This transform arises as an iterated sigmoid transform in a stick-breaking
construction of the Dirichlet distribution: the first logit is transformed via sigmoid to the first probability and the probability of everything else, and then the process recurses.

This is bijective and appropriate for use in HMC; however it mixes
coordinates together and is less appropriate for optimization.


class torch.distributions.transforms.Transform(cache_size=0)

Abstract class for invertable transformations with computable log
det jacobians. They are primarily used in torch.distributions.TransformedDistribution.

Caching is useful for transforms whose inverses are either expensive or numerically unstable. Note that care must be taken with memoized values
since the autograd graph may be reversed. For example while the following
works with or without caching:

y = t(x)

t.log_abs_det_jacobian(x, y).backward()  # x将接收梯度。


However the following will error when caching due to dependency reversal:

y = t(x)

z = t.inv(y)

grad(z.sum(), [y])  # 报错,因为 z 就是 x


Derived classes should implement one or both of _call() or _inverse(). Derived classes that set bijective=True should also
implement log_abs_det_jacobian().


Parameters

  • cache_size ( int ) – Size of cache. If zero, no caching is done. If one, the latest single value is cached. Only 0 and 1 are supported.

Variables

  • domain (Constraint) – The constraint representing valid inputs to this transform.
  • codomain (Constraint) – The constraint representing valid outputs to this transform
    which are inputs to the inverse transform.
  • bijective ([bool]) – Whether this transform is bijective. A transform
    t is bijective iff t.inv(t(x)) == x and t(t.inv(y)) == y for every x in the domain and y in the codomain. Transforms that are not bijective should at least
    maintain the weaker pseudoinverse properties
    t(t.inv(t(x)) == t(x) and t.inv(t(t.inv(y))) == t.inv(y).
  • sign ( int or Tensor ) – For bijective univariate transforms, this should be +1 or -1 depending on whether transform is monotone
    increasing or decreasing.

property inv: Transform

Returns the inverse Transform of this transform.
This should satisfy t.inv.inv is t.


property sign: int

Returns the sign of the determinant of the Jacobian, if applicable.
In general this only makes sense for bijective transforms.


log_abs_det_jacobian(x, y)

Computes the log det jacobian log |dy/dx| given input and output.



forward_shape(shape)

Infers the shape of the forward computation, given the input shape.
Defaults to preserving shape.



inverse_shape(shape)

Infers the shapes of the inverse computation, given the output shape.
Defaults to preserving shape.


Constraints


class torch.distributions.constraints.Constraint

Abstract base class for constraints.

A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.

Variables

  • is_discrete ([bool]) – Whether constrained space is discrete.
    Defaults to False.
  • event_dim ( int ) – Number of rightmost dimensions that together define an event. The check() method will remove this many dimensions
    when computing validity.


check(value)

Returns a byte tensor of sample_shape + batch_shape indicating
whether each event in value satisfies this constraint.


torch.distributions.constraints.cat

alias of _Cat


torch.distributions.constraints.dependent_property

alias of _DependentProperty


torch.distributions..constraints.greater_than

alias of _GreaterThan


torch.distributions..constraints.greater_than_eq

alias of _GreaterThanEq


torch.distributions..constraints.independent

alias of _IndependentConstraint


torch.distributions..constraints.integer_interval

alias of _IntegerInterval


torch.distributions..constraints.interval

alias of _Interval


torch.distributions..constraints.half_open_interval

alias of _HalfOpenInterval


torch.distributions..constraints.is_dependent(constraint)

Checks if constraint is a _Dependent object.


Parameters

  • constraint – A Constraint object.

Returns
True if constraint can be refined to the type _Dependent, False otherwise.

Return type
bool


Examples

>>> import torch

>>> from torch.distributions import Bernoulli

>>> from torch.distributions.constraints import is_dependent


>>> dist = Bernoulli(probs=torch.tensor([0.6], requires_grad=True))

>>> constraint1 = dist.arg_constraints["probs"]

>>> constraint2 = dist.arg_constraints["logits"]

>>> for constraint in [constraint1, constraint2]:
        if is_dependent(constraint):
            continue

torch.distributions.constraints.less_than 

alias of _LessThan


torch.distributions..constraints.multinomial 

alias of _Multinomial


torch.distributions..constraints.stack

alias of _Stack


Constraint Registry

PyTorch provides two global ConstraintRegistry objects that link Constraint objects to Transform objects. These objects both input constraints and return transforms, but they have different guarantees on bijectivity.

1、biject_to(constraint) looks up a bijective Transform from constraints.real to the given constraint. The returned transform is guaranteed to have .bijective = True and should implement .log_abs_det_jacobian().
2、transform_to(constraint) looks up a not-necessarily bijective Transform from constraints.real to the given constraint. The returned transform is not guaranteed to implement .log_abs_det_jacobian().

The transform_to() registry is useful for performing unconstrained optimization on constrained parameters of probability distributions, which are indicated by each distribution’s .arg_constraints dict. These transforms often overparameterize a space in order to avoid rotation; they are thus more suitable for coordinate-wise optimization algorithms like Adam:


loc = torch.zeros(100, requires_grad=True)

unconstrained = torch.zeros(100, requires_grad=True)

scale = transform_to(Normal.arg_constraints["scale"])(unconstrained)

loss = -Normal(loc, scale).log_prob(data).sum()

The biject_to() registry is useful for Hamiltonian Monte Carlo, where samples from a probability distribution with constrained .support are propagated in an unconstrained space, and algorithms are typically rotation invariant.:

dist = Exponential(rate)

unconstrained = torch.zeros(100, requires_grad=True)

sample = biject_to(dist.support)(unconstrained)

potential_energy = -dist.log_prob(sample).sum()

Note: An example where transform_to and biject_to differ is constraints.simplex: transform_to(constraints.simplex) returns a SoftmaxTransform that simply exponentiates and normalizes its inputs; this is a cheap and mostly coordinate-wise operation appropriate for algorithms like SVI. In contrast, biject_to(constraints.simplex) returns a StickBreakingTransform that bijects its input down to a one-fewer-dimensional space; this a more expensive less numerically stable transform but is needed for algorithms like HMC.

The biject_to and transform_to objects can be extended by user-defined constraints and transforms using their .register() method either as a function on singleton constraints:

transform_to.register(my_constraint, my_transform)

or as a decorator on parameterized constraints:


@transform_to.register(MyConstraintClass)
def my_factory(constraint):
    assert isinstance(constraint, MyConstraintClass)
    return MyTransform(constraint.param1, constraint.param2)

You can create your own registry by creating a new ConstraintRegistry
object.


class torch.distributions.constraint_registry.ConstraintRegistry

Registry to link constraints to transforms.


register(constraint, factory=None)

Registers a Constraint
subclass in this registry. Usage:


@my_registry.register(MyConstraintClass)
def construct_transform(constraint):
    assert isinstance(constraint, MyConstraint)
    return MyTransform(constraint.arg_constraints)

参数说明

  • constraint (Constraint的子类) - 可以是Constraint的子类,或是目标类的单例对象。
  • factory (可调用对象) - 一个可调用对象,接收约束对象作为输入并返回一个Transform对象。

2025-05-10(六)


网站公告

今日签到

点亮在社区的每一天
去签到