DeepSpeed ZeRO 主流程源码和流程解析

ZeRO:Zero Redundancy Optimizer(无冗余优化器),主要目标是为了降低 GPU 显存的冗余占用,通过分片的模式来提高训练任务并行效率或者提高资源利用率。

ZeRO 和 DeepSpeed

背景:

ZeRO:Zero Redundancy Optimizer(无冗余优化器),主要目标是为了降低 GPU 显存的冗余占用,通过分片的模式来提高训练任务并行效率或者提高资源利用率。

显存内容:

  • Gradients:梯度;
  • Activations:激活函数信息,比如 ReLU 里 0 值的位置等;
  • Optimizer 状态,比如 Adam 优化器里的一些加权平均值等参数;
    [图片]

ZeRO 的不同阶段,对应了不同状态的冗余(Optimizer、Gradients、Parameters)处理:

Stage 分片参数 过程
ZeRO-1 Optimizer 状态分片: - 如 Adam Optimizer,动量和均方信息; - Backward 过程中,每个 GPU 只负责自己的 optimizer 参数,比如:m[p] = beta * m'[p] + (1-beta)*grad[p]
ZeRO-2 Optimizer 的状态和 Gradients 分片 - Backward 过程中,gradients 被 reduce 到各自 rank 上,取代了 all reduce,更新完成后进行 all gather 操作;
- 由于梯度被切分,更新参数前需要 All-Reduce → Reduce-Scatter → All-Gather 这样的通信操作:
  - 反向传播时:
    - 每张 GPU 算出本地梯度。
    - 通过 Reduce-Scatter 把梯度分摊到不同 GPU 上。
    - 每张 GPU 只保存自己那部分梯度。
  - 优化器更新时:
    - 优化器只在本地存储的梯度和状态上更新对应的参数分片。
  - 前向传播时:
    - 参数依旧全量存在,所以不用 All-Gather(这是 ZeRO-3 才需要的)。
ZeRO-3 Optimizer 的状态、Gradients、Parameters 分片 - 相当于数据并行 + 模型并行(因为 parameters 做了分片) - Backward
  - 按层执行,gather-reduce-scatter,gather 是把所有梯度放到当前 GPU,reduce 是汇总并聚合所有梯度,scatter 是释放不属于当前 gpu 的梯度;
  - 参数更新时因为 parameters/gradients/optimizer 都是一一对应的,所以不用进行通信,本地完成更新;
- Forward
  - 按层(Layer)动态进行 all-gather,gather-compute-free,算完之后就进行释放(这里如何保证效率?)

初始化

代码主要在 DeepSpeedEngine.py,初始化过程中主要干了以下几个事情:

  1. dataloader 初始化
self.training_dataloader = self.deepspeed_io(training_data)

这一部分主要使用了 DeepSpeed 的 DataLoader,相比于 PyTorch DataLoader 区别是

对比项 PyTorch DataLoader DeepSpeedDataLoader
核心定位 通用数据加载器(单机/多进程) 专为 分布式训练 优化的 DataLoader 封装
分布式支持 需手动设置 DistributedSampler,每进程各自初始化 自动集成分布式采样与 sharding,适配 DeepSpeed engine 的 rank/world_size
数据划分 手动控制每 GPU 的样本切分 内部自动根据 rank 划分 dataset,确保样本不重复不遗漏
动态 batch 管理 固定 batch_size,需自行调整 支持 动态批大小(dynamic batch size) 与 gradient accumulation 配合
异步与性能优化 基本依赖 Python 多进程/线程加载 支持 异步预取(async prefetch)、pipeline 式数据加载,减少 I/O 等待
容错/重启 不支持状态恢复 可与 DeepSpeed Checkpoint 联动,恢复 DataLoader 状态(epoch、iteration)
典型使用场景 单机或基本 DDP 大规模分布式训练(ZeRO、MoE、pipeline 等)
  1. Optimizer 初始化
if has_optimizer:
    self._configure_optimizer(optimizer, model_parameters)

这里面根据业务的各种配置,以及 optimizer 的参数进行初始化,本文要讲的 ZeRO Optimizer 就是在这部分进行初始化完成,其中 ZeRO 有 4 个 Enum,对应四种模式:

class ZeroStageEnum(int, Enum):
    """ Enum class for possible zero stages """
    disabled = 0    
    optimizer_states = 1    # 只缓存 optimizer states,对应 ZeRO-1
    gradients = 2           # 缓存 optimizer states + gradients,对应 ZeRO-2
    weights = 3             # 缓存 optimizer states + gradients + weights,对应 ZeRO-3
    max_stage = 3

初始化部分主要是上面两段内容,另外会针对 deepspeed 内部的一些优化特性进行初始化,比如常见的:

  • Torch AMP:
    • 它解决的是在不显著损失数值精度的前提下,显著提升训练/推理速度 & 降低显存占用的问题。核心做法是:对不同算子用不同精度(如 matmul/conv 用 FP16/BF16,加法/归一化/损失计算用 FP32 累加),并在训练时用动态 loss scaling防止 FP16 下溢。
    1. 对不同算子使用不同精度
    • Matmul / Conv(矩阵乘法,卷积):计算量大,乘法相对误差小;
    • 加法 / 归一化 / loss 计算:涉及的计算精度要求高,使用 FP32;
    1. 训练时使用 loss scaling
  • ZenFlow:ZenFlow 是 DeepSpeed 在 2025 年推出的一个 扩展模块 / 引擎,定位为 “stall-free offloading engine”(无停顿的 offloading 引擎)。 arxiv.org+3PyTorch+3DeepSpeed+3,核心问题是解决 参数 offload 到 CPU 过程中 GPU 的阻塞问题。
    它基于 ZeRO-Offload,对其进行了 解耦(decoupling)、异步更新(asynchronous updates) 的增强,用来解决 ZeRO-Offload 在 CPU–GPU 同步更新时造成的 GPU 空闲/等待问题(stall)DeepSpeed+2PyTorch+2
    简单说:它让 重要梯度 在 GPU 上立刻更新,而把不那么关键的梯度推到 CPU,做异步累计 & 更新,从而让 CPU 和 PCIe 传输的工作可以在 GPU 计算期间重叠,减少 GPU 空闲时间。

ZeRO-1

初始化

代码在 DeepSpeedZeroOptimizer 类中,这里要注意,ZeRO-1 和 ZeRO-2 使用的是同一个 Optimizer,但是通过 partition_gradients 来区分(即是否要对 gradients 梯度进行分片)

# ZeRO stage 1 (False) or 2 (True)
self.partition_gradients = partition_grads
self.zero_stage_string = "ZeRO-2" if partition_grads else "ZeRO-1"

核心过程:

  1. 初始化参数组,
self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))]
self.partition_count = [dp_size for i in range(len(self.optimizer.param_groups))]

一个 optimizer 内部的参数可以被分为多个参数组,共享不同的超参;

  1. 遍历所有的参数组信息,一系列预处理:

打散

# 遍历所有的参数组
for i, param_group in enumerate(self.optimizer.param_groups):
    # 当前进程在特定数据并行组中的 rank,用于确定当前进程负责哪一部分参数
    partition_id = dist.get_rank(group=self.real_dp_process_group[i])
    ...
    # 这里把需要 train 的参数都放进了 bit16_groups 里
    if self.round_robin_gradients:
        # 把第 i 个参数组的参数按照轮询的方式,拆到同一个进程组的不同进程里
        round_robin_tensors, round_robin_indices = self._round_robin_reorder(self.bit16_groups[i],dist.get_world_size(group=self.real_dp_process_group[i]))
    else:
        round_robin_tensors = self.bit16_groups[i]
        round_robin_indices = list(range(len(self.bit16_groups[i])))

这里的 round_robin_gradients 主要是为了梯度打散,比如 transformer 里多个 layer,不同 layer 的参数更新负载可能是不一样的,所以这里要根据 partition 做一个打散;

Padding:

# 把 tensor 拍平,然后根据 NCCL 的对齐要求,对齐到边界,多余的补上 torch.zeros(..)
flattened_buffer = self.flatten_dense_tensors_aligned(
    self.round_robin_bit16_groups[i],
    self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i]),
    use_cpu_data=True)
...
# 尽可能等分作切割,每个 partition 里存的是一部分 flat tensor
data_parallel_partitions = self.get_data_parallel_partitions(self.bit16_groups_flat[i], i)
self.parallel_partitioned_bit16_groups.append(data_parallel_partitions)
...
# partition_id之前的所有 tensor 的元素数量,求一个 left bound,算一下当前的长度
left_boundary = sum([t.numel() for t in data_parallel_partitions[:partition_id]])
curr_partition_size = data_parallel_partitions[partition_id].numel()

# 计算 padding 的大小; 这里是为了确保每个 partition 的大小是 nccl_start_alignment_factor 的倍数
if orig_group_numel <= left_boundary:
    padding = curr_partition_size
elif orig_group_numel < left_boundary + curr_partition_size:
    padding = left_boundary + curr_partition_size - orig_group_numel
else:
    padding = 0
self.groups_padding.append(padding)

这里用 round robin 后的参数,进行拍平,flatten 操作就是把参数 tensor 转换成一维,然后根据 NVLink 的通信要求对齐到边界,padding 的补上 torch.zeros(…);

明确关键变量

partition_size = len(self.bit16_groups_flat[i]) / dist.get_world_size(group=self.real_dp_process_group[i])
# 取到当前的 partition id 分片的 params
params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(
    self.round_robin_bit16_groups[i], partition_size, partition_id)

self.partition_size.append(partition_size)
self.params_in_partition.append(params_in_partition)
self.params_not_in_partition.append(params_not_in_partition)
self.first_offset.append(first_offset)

这里主要是明确 params_in_partition 和 params_not_in_partition,在后续 backward 和 step 过程里会使用到。

Step()

因为 ZeRO-1 主要是在 optimizer stats 上有变化,所以主要改变在 step() 的步骤(具体代码参考 DeepSpeedZeroOptimizer 中的 step() 方法):

  1. 判断是否 overflow,对于 fp16 指数位比较少的情况会出现,如果 overflow 了,通过 all-reduce 广播给所有节点,降级或者放弃这次 params 更新
# 这里是如何判断overflow的?
# 1. 先判断当前的 partition 是否 overflow
# 2. 如果当前的 partition 没有 overflow,再判断是否有其他的 partition 溢出
# 4. 做一个all reduce, 把所有的 partition 的 overflow 标志位都同步起来
def has_overflow(self, partition_gradients=True):
    overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial()
    overflow_gpu = get_accelerator().ByteTensor([overflow]) if self.cpu_offload else overflow.byte().to(
        get_accelerator().current_device_name())

    if partition_gradients:
        '''This will capture overflow across all data parallel and expert parallel process
        Since expert parallel process are a subset of data parallel process'''
        dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group)

    # Since each model parallel GPU carries only part of the model,
    # make sure overflow flag is synced across all the model parallel GPUs
    self._model_parallel_all_reduce(tensor=overflow_gpu, op=dist.ReduceOp.MAX)

    overflow = overflow_gpu[0].item()
    return bool(overflow)
  1. 计算梯度
# 释放掉所有不在当前 partition 中的 params 的梯度
self.free_grad_in_param_list(self.params_not_in_partition[i])

# 这里是取到当前 partition 中的 params 的梯度
single_grad_partition = self.flatten(self.averaged_gradients[i]).to(self.single_partition_of_fp32_groups[i].dtype)

关于 averaged_gradients 是如何计算出来的,可以追踪源码看一下,有很多种配置和情况,具体代码在(简单来说都是通过 all-reduce 来完成 gradients 同步的)

# ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well
if self.zero_optimization_partition_gradients():
    self.optimizer.overlapping_partition_gradients_reduce_epilogue()

# Communicate only at gradient accumulation boundaries
elif self.is_gradient_accumulation_boundary():
    if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states and hasattr(
            self.optimizer, 'reduce_gradients'):
        self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism)
    else:
        grads = None
        self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size)
elif self.zenflow:
    self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism)
  1. 获取梯度后进行 optimize()
# Step 3:- run the optimizer if no offloading
self.timers(OPTIMIZER_STEP_TIMER).start()
self._optimizer_step(i)

def _optimizer_step(self, group_no):
    original_param_groups = self.optimizer.param_groups
    self.optimizer.param_groups = [original_param_groups[group_no]]
    ...
    self.optimizer.step()
    self.optimizer.param_groups = original_param_groups

注意这里,先把 optimizer 内部的 param_groups 进行暂存,然后重置为当前分片的 optimizer params(optimizer states 分片,即优化器只更新属于自己分片的参数,通过 all-gather 获取其他节点更新的参数),进行参数调整后,再重置为原始的所有 param_groups

  1. All-Gather 完成参数收集
# 这里是 all gather 所有的 params, 然后更新到当前的 partition 中
all_gather_dp_groups(groups_flat=self.bit16_groups_flat,
                     partitioned_param_groups=self.parallel_partitioned_bit16_groups,
                     dp_process_group=self.real_dp_process_group,
                     start_alignment_factor=self.nccl_start_alignment_factor,
                     allgather_bucket_size=self.allgather_bucket_size)

ZeRO-2

复用 ZeRO-1 的 optimizer,但是通过 partition_gradients 参数来区分两者,整体的代码非常相似,这里只列举一下两边不同的部分,主要是在 backward() 里。(因为 gradient 的分片主要是在反向过程中计算梯度的时候用到)

backward()

调用链:

def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
    ...
    # 这里是直接diaoyong调用 optimizer 的 backward 方法
    self._do_optimizer_backward(loss, retain_graph)
    # 这里是对 backward() 后的数据做处理
    self._backward_epilogue()

def _backward_epilogue(self):
    self.allreduce_gradients()

def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
    ...
    # ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well
    if self.zero_optimization_partition_gradients():
        self.optimizer.overlapping_partition_gradients_reduce_epilogue()

def independent_gradient_partition_epilogue(self):
    ...
    self.reduce_ipg_grads()
    ...

这里有一个核心方法 reduce_ipg_grads(),ipg 的含义是 independent partitioned gradients,这个方法里通过对 grads 进行分桶,以及 reduce 等操作,间接实现了分片/scatter 的功能,核心方法是 average_tensor。
详细分桶规则参考下面:

  • rank_and_offsets 是将 grads 打平成一维后,不同 rank 和每个 rank grads 的 offset;
    • dst 是目标 rank,bucket_offset 是 grads 起始 offset,numel 是 grads 数量;
  • 产出的 buckets,bucket_key 是 dst + 进程组,value 是 grad tensor;
buckets = {}
for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets):
    grad_slice = tensor.narrow(0, int(bucket_offset), int(numel))
    bucket_key = real_dp_process_group[i] if self.use_multi_rank_bucket_allreduce else (
        dst, real_dp_process_group[i])
    if bucket_key not in buckets:
        buckets[bucket_key] = []
    if self.use_multi_rank_bucket_allreduce:
        buckets[bucket_key].append((dst, grad_slice))
    else:
        buckets[bucket_key].append(grad_slice)

遍历所有 buckets(注意这里的 buckets 里包含要发送给所有 rank 节点的所有 grads 梯度信息),代码如下。在 ipg 相关代码里,主要做两个事情:

  1. 针对 dst rank,把所有的 grads 发送给 dst(通过 dist.reduce)
  2. 把属于当前的 partition 的 grads 储存下来,其他的 clear() ;(通过 bucket.clear())
for bucket_key in buckets:
    self.allreduce_no_retain(buckets[bucket_key],
                             communication_data_type,
                             numel_per_bucket=self.reduce_bucket_size,
                             rank=dst,
                             divide=False,
                             process_group=process_group)

ZeRO-3

训练过程中 ZeRO-3 的参数加载、释放、all-gather 生命周期都是由 deepspeed 内部的各种 hook 实现自动管理的;其中几个典型的 hook:

# Pre forward hook
self.forward_hooks.append(module.register_forward_pre_hook(_pre_forward_module_hook))
# Post forward hook
self.forward_hooks.append(module.register_forward_hook(_post_forward_module_hook))

PartitionedParameterCoordinator

对于 ZeRO-3,最核心的是 PartitionedParameterCoordinator,和 ZeRO-1,2 不同,这里对于 params 做了分片来最大程度节约显存空间,那么 forward 和 backward 过程中,对于 params 的处理和协调,就是在这个 Coordinator 里实现的。

这里描述了 Coordinator 的所有功能

class PartitionedParameterCoordinator:
    FORWARD_FETCH_SUBMIT = 'forward_fetch_submit'
    FORWARD_FETCH_WAIT = 'forward_fetch_wait'
    FORWARD_PREFETCH_SUBMIT = 'forward_prefetch_submit'
    BACKWARD_FETCH_SUBMIT = 'backward_fetch_submit'
    BACKWARD_FETCH_WAIT = 'backward_fetch_wait'
    BACKWARD_PREFETCH_SUBMIT = 'backward_prefetch_submit'
    FORWARD_ALL_GATHER = 'forward_all_gather'
    BACKWARD_ALL_GATHER = 'backward_all_gather'
    """Handles partitioning and gathering of parameters."""

为了加速 ZeRO-3 参数分片带来的性能损失(作为显存下降的 trade-off),DeepSpeed 也提供了一些特性:

  • pre-fetch:预拉取 next layer 的参数
  • Fast fetch:跳过一些依赖校验,直接拿参数;(这个原理后面再看..)
  • Trace 能力:通过 trace 来做一些加速;(这个原理后面再看..)

对于 params 的参数的状态追踪有三种(主要还是为了保持一致性):

  • AVAILABLE:当前分片可用
  • NOT_AVAILABLE:当前分片不可用
  • INFLIGHT:传输中
class ZeroParamStatus(Enum):
    # parameters are fully present and ready for use on all processes
    AVAILABLE = 1
    # parameters are either partitioned or remote in some or all process
    NOT_AVAILABLE = 2
    # parameters are being gathered.
    INFLIGHT = 3

Forward

  1. Pre Forward:

Forward 对于每个 layer / nn module 的调用,会通过 fetch 来取到当前 module 需要的参数,同时 pre-fetch 加速后续 module 的参数获取,代码如下:

def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
    ...
    params_to_fetch = set(iter_params(current_submodule, recurse=z3_leaf_module(current_submodule)))
    fetch_numel = sum([p.partition_numel() for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])

    if fetch_numel > 0:
        ...
        self.__all_gather_params(params_to_fetch, forward)
  1. Post Forward:

PreForward 用来获取参数,PostForward 就是用来裁剪参数。

def release_sub_module(self, submodule: Module, forward=False) -> None:
    """release the parameters of a sub module, assuming they meet conditions to
    be released."""
    #print_rank_0(f"release_sub_module {'fwd' if forward else 'bwd'}: {debug_module2name_id(submodule)}", force=False)
    params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
        p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule))))

    free_data = not z3_leaf_module(submodule) or not self.fast_sharding_for_leaf_module
    if not free_data:
        # wait for the computation to finish and launch as early as possible.
        empty_buffer = torch.empty(1, device=get_accelerator().current_device())

    for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
        param.ds_active_sub_modules.discard(submodule.ds_id)
        if param.ds_id in params_to_release and not param.is_external_param:
            self.__release_param(param, free_data)
        if not free_data:
            if param.ds_id in params_to_release and not param.is_external_param:
                # empty buffer ensures that all computations are complete
                param.data = empty_buffer

Backward

累了…