ZeRO:Zero Redundancy Optimizer(无冗余优化器),主要目标是为了降低 GPU 显存的冗余占用,通过分片的模式来提高训练任务并行效率或者提高资源利用率。
ZeRO 和 DeepSpeed
背景:
ZeRO:Zero Redundancy Optimizer(无冗余优化器),主要目标是为了降低 GPU 显存的冗余占用,通过分片的模式来提高训练任务并行效率或者提高资源利用率。
显存内容:
- Gradients:梯度;
- Activations:激活函数信息,比如 ReLU 里 0 值的位置等;
- Optimizer 状态,比如 Adam 优化器里的一些加权平均值等参数;
[图片]
ZeRO 的不同阶段,对应了不同状态的冗余(Optimizer、Gradients、Parameters)处理:
Stage | 分片参数 | 过程 |
---|---|---|
ZeRO-1 | Optimizer 状态分片: - 如 Adam Optimizer,动量和均方信息; | - Backward 过程中,每个 GPU 只负责自己的 optimizer 参数,比如:m[p] = beta * m'[p] + (1-beta)*grad[p] |
ZeRO-2 | Optimizer 的状态和 Gradients 分片 | - Backward 过程中,gradients 被 reduce 到各自 rank 上,取代了 all reduce,更新完成后进行 all gather 操作; - 由于梯度被切分,更新参数前需要 All-Reduce → Reduce-Scatter → All-Gather 这样的通信操作: - 反向传播时: - 每张 GPU 算出本地梯度。 - 通过 Reduce-Scatter 把梯度分摊到不同 GPU 上。 - 每张 GPU 只保存自己那部分梯度。 - 优化器更新时: - 优化器只在本地存储的梯度和状态上更新对应的参数分片。 - 前向传播时: - 参数依旧全量存在,所以不用 All-Gather(这是 ZeRO-3 才需要的)。 |
ZeRO-3 | Optimizer 的状态、Gradients、Parameters 分片 - 相当于数据并行 + 模型并行(因为 parameters 做了分片) | - Backward - 按层执行,gather-reduce-scatter,gather 是把所有梯度放到当前 GPU,reduce 是汇总并聚合所有梯度,scatter 是释放不属于当前 gpu 的梯度; - 参数更新时因为 parameters/gradients/optimizer 都是一一对应的,所以不用进行通信,本地完成更新; - Forward - 按层(Layer)动态进行 all-gather,gather-compute-free,算完之后就进行释放(这里如何保证效率?) |
初始化
代码主要在 DeepSpeedEngine.py,初始化过程中主要干了以下几个事情:
- dataloader 初始化
self.training_dataloader = self.deepspeed_io(training_data)
这一部分主要使用了 DeepSpeed 的 DataLoader,相比于 PyTorch DataLoader 区别是
对比项 | PyTorch DataLoader | DeepSpeedDataLoader |
---|---|---|
核心定位 | 通用数据加载器(单机/多进程) | 专为 分布式训练 优化的 DataLoader 封装 |
分布式支持 | 需手动设置 DistributedSampler,每进程各自初始化 | 自动集成分布式采样与 sharding,适配 DeepSpeed engine 的 rank/world_size |
数据划分 | 手动控制每 GPU 的样本切分 | 内部自动根据 rank 划分 dataset,确保样本不重复不遗漏 |
动态 batch 管理 | 固定 batch_size,需自行调整 | 支持 动态批大小(dynamic batch size) 与 gradient accumulation 配合 |
异步与性能优化 | 基本依赖 Python 多进程/线程加载 | 支持 异步预取(async prefetch)、pipeline 式数据加载,减少 I/O 等待 |
容错/重启 | 不支持状态恢复 | 可与 DeepSpeed Checkpoint 联动,恢复 DataLoader 状态(epoch、iteration) |
典型使用场景 | 单机或基本 DDP | 大规模分布式训练(ZeRO、MoE、pipeline 等) |
- Optimizer 初始化
if has_optimizer:
self._configure_optimizer(optimizer, model_parameters)
这里面根据业务的各种配置,以及 optimizer 的参数进行初始化,本文要讲的 ZeRO Optimizer 就是在这部分进行初始化完成,其中 ZeRO 有 4 个 Enum,对应四种模式:
class ZeroStageEnum(int, Enum):
""" Enum class for possible zero stages """
disabled = 0
optimizer_states = 1 # 只缓存 optimizer states,对应 ZeRO-1
gradients = 2 # 缓存 optimizer states + gradients,对应 ZeRO-2
weights = 3 # 缓存 optimizer states + gradients + weights,对应 ZeRO-3
max_stage = 3
初始化部分主要是上面两段内容,另外会针对 deepspeed 内部的一些优化特性进行初始化,比如常见的:
- Torch AMP:
- 它解决的是在不显著损失数值精度的前提下,显著提升训练/推理速度 & 降低显存占用的问题。核心做法是:对不同算子用不同精度(如 matmul/conv 用 FP16/BF16,加法/归一化/损失计算用 FP32 累加),并在训练时用动态 loss scaling防止 FP16 下溢。
- 对不同算子使用不同精度
- Matmul / Conv(矩阵乘法,卷积):计算量大,乘法相对误差小;
- 加法 / 归一化 / loss 计算:涉及的计算精度要求高,使用 FP32;
- 训练时使用 loss scaling
- ZenFlow:ZenFlow 是 DeepSpeed 在 2025 年推出的一个 扩展模块 / 引擎,定位为 “stall-free offloading engine”(无停顿的 offloading 引擎)。 arxiv.org+3PyTorch+3DeepSpeed+3,核心问题是解决 参数 offload 到 CPU 过程中 GPU 的阻塞问题。
它基于 ZeRO-Offload,对其进行了 解耦(decoupling)、异步更新(asynchronous updates) 的增强,用来解决 ZeRO-Offload 在 CPU–GPU 同步更新时造成的 GPU 空闲/等待问题(stall)DeepSpeed+2PyTorch+2
简单说:它让 重要梯度 在 GPU 上立刻更新,而把不那么关键的梯度推到 CPU,做异步累计 & 更新,从而让 CPU 和 PCIe 传输的工作可以在 GPU 计算期间重叠,减少 GPU 空闲时间。
ZeRO-1
初始化
代码在 DeepSpeedZeroOptimizer 类中,这里要注意,ZeRO-1 和 ZeRO-2 使用的是同一个 Optimizer,但是通过 partition_gradients 来区分(即是否要对 gradients 梯度进行分片)
# ZeRO stage 1 (False) or 2 (True)
self.partition_gradients = partition_grads
self.zero_stage_string = "ZeRO-2" if partition_grads else "ZeRO-1"
核心过程:
- 初始化参数组,
self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))]
self.partition_count = [dp_size for i in range(len(self.optimizer.param_groups))]
一个 optimizer 内部的参数可以被分为多个参数组,共享不同的超参;
- 遍历所有的参数组信息,一系列预处理:
打散
# 遍历所有的参数组
for i, param_group in enumerate(self.optimizer.param_groups):
# 当前进程在特定数据并行组中的 rank,用于确定当前进程负责哪一部分参数
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
...
# 这里把需要 train 的参数都放进了 bit16_groups 里
if self.round_robin_gradients:
# 把第 i 个参数组的参数按照轮询的方式,拆到同一个进程组的不同进程里
round_robin_tensors, round_robin_indices = self._round_robin_reorder(self.bit16_groups[i],dist.get_world_size(group=self.real_dp_process_group[i]))
else:
round_robin_tensors = self.bit16_groups[i]
round_robin_indices = list(range(len(self.bit16_groups[i])))
这里的 round_robin_gradients 主要是为了梯度打散,比如 transformer 里多个 layer,不同 layer 的参数更新负载可能是不一样的,所以这里要根据 partition 做一个打散;
Padding:
# 把 tensor 拍平,然后根据 NCCL 的对齐要求,对齐到边界,多余的补上 torch.zeros(..)
flattened_buffer = self.flatten_dense_tensors_aligned(
self.round_robin_bit16_groups[i],
self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i]),
use_cpu_data=True)
...
# 尽可能等分作切割,每个 partition 里存的是一部分 flat tensor
data_parallel_partitions = self.get_data_parallel_partitions(self.bit16_groups_flat[i], i)
self.parallel_partitioned_bit16_groups.append(data_parallel_partitions)
...
# partition_id之前的所有 tensor 的元素数量,求一个 left bound,算一下当前的长度
left_boundary = sum([t.numel() for t in data_parallel_partitions[:partition_id]])
curr_partition_size = data_parallel_partitions[partition_id].numel()
# 计算 padding 的大小; 这里是为了确保每个 partition 的大小是 nccl_start_alignment_factor 的倍数
if orig_group_numel <= left_boundary:
padding = curr_partition_size
elif orig_group_numel < left_boundary + curr_partition_size:
padding = left_boundary + curr_partition_size - orig_group_numel
else:
padding = 0
self.groups_padding.append(padding)
这里用 round robin 后的参数,进行拍平,flatten 操作就是把参数 tensor 转换成一维,然后根据 NVLink 的通信要求对齐到边界,padding 的补上 torch.zeros(…);
明确关键变量
partition_size = len(self.bit16_groups_flat[i]) / dist.get_world_size(group=self.real_dp_process_group[i])
# 取到当前的 partition id 分片的 params
params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(
self.round_robin_bit16_groups[i], partition_size, partition_id)
self.partition_size.append(partition_size)
self.params_in_partition.append(params_in_partition)
self.params_not_in_partition.append(params_not_in_partition)
self.first_offset.append(first_offset)
这里主要是明确 params_in_partition 和 params_not_in_partition,在后续 backward 和 step 过程里会使用到。
Step()
因为 ZeRO-1 主要是在 optimizer stats 上有变化,所以主要改变在 step() 的步骤(具体代码参考 DeepSpeedZeroOptimizer 中的 step() 方法):
- 判断是否 overflow,对于 fp16 指数位比较少的情况会出现,如果 overflow 了,通过 all-reduce 广播给所有节点,降级或者放弃这次 params 更新
# 这里是如何判断overflow的?
# 1. 先判断当前的 partition 是否 overflow
# 2. 如果当前的 partition 没有 overflow,再判断是否有其他的 partition 溢出
# 4. 做一个all reduce, 把所有的 partition 的 overflow 标志位都同步起来
def has_overflow(self, partition_gradients=True):
overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial()
overflow_gpu = get_accelerator().ByteTensor([overflow]) if self.cpu_offload else overflow.byte().to(
get_accelerator().current_device_name())
if partition_gradients:
'''This will capture overflow across all data parallel and expert parallel process
Since expert parallel process are a subset of data parallel process'''
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group)
# Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs
self._model_parallel_all_reduce(tensor=overflow_gpu, op=dist.ReduceOp.MAX)
overflow = overflow_gpu[0].item()
return bool(overflow)
- 计算梯度
# 释放掉所有不在当前 partition 中的 params 的梯度
self.free_grad_in_param_list(self.params_not_in_partition[i])
# 这里是取到当前 partition 中的 params 的梯度
single_grad_partition = self.flatten(self.averaged_gradients[i]).to(self.single_partition_of_fp32_groups[i].dtype)
关于 averaged_gradients 是如何计算出来的,可以追踪源码看一下,有很多种配置和情况,具体代码在(简单来说都是通过 all-reduce 来完成 gradients 同步的)
# ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well
if self.zero_optimization_partition_gradients():
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
# Communicate only at gradient accumulation boundaries
elif self.is_gradient_accumulation_boundary():
if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states and hasattr(
self.optimizer, 'reduce_gradients'):
self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism)
else:
grads = None
self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size)
elif self.zenflow:
self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism)
- 获取梯度后进行 optimize()
# Step 3:- run the optimizer if no offloading
self.timers(OPTIMIZER_STEP_TIMER).start()
self._optimizer_step(i)
def _optimizer_step(self, group_no):
original_param_groups = self.optimizer.param_groups
self.optimizer.param_groups = [original_param_groups[group_no]]
...
self.optimizer.step()
self.optimizer.param_groups = original_param_groups
注意这里,先把 optimizer 内部的 param_groups 进行暂存,然后重置为当前分片的 optimizer params(optimizer states 分片,即优化器只更新属于自己分片的参数,通过 all-gather 获取其他节点更新的参数),进行参数调整后,再重置为原始的所有 param_groups
- All-Gather 完成参数收集
# 这里是 all gather 所有的 params, 然后更新到当前的 partition 中
all_gather_dp_groups(groups_flat=self.bit16_groups_flat,
partitioned_param_groups=self.parallel_partitioned_bit16_groups,
dp_process_group=self.real_dp_process_group,
start_alignment_factor=self.nccl_start_alignment_factor,
allgather_bucket_size=self.allgather_bucket_size)
ZeRO-2
复用 ZeRO-1 的 optimizer,但是通过 partition_gradients 参数来区分两者,整体的代码非常相似,这里只列举一下两边不同的部分,主要是在 backward() 里。(因为 gradient 的分片主要是在反向过程中计算梯度的时候用到)
backward()
调用链:
def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
...
# 这里是直接diaoyong调用 optimizer 的 backward 方法
self._do_optimizer_backward(loss, retain_graph)
# 这里是对 backward() 后的数据做处理
self._backward_epilogue()
def _backward_epilogue(self):
self.allreduce_gradients()
def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
...
# ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well
if self.zero_optimization_partition_gradients():
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
def independent_gradient_partition_epilogue(self):
...
self.reduce_ipg_grads()
...
这里有一个核心方法 reduce_ipg_grads(),ipg 的含义是 independent partitioned gradients,这个方法里通过对 grads 进行分桶,以及 reduce 等操作,间接实现了分片/scatter 的功能,核心方法是 average_tensor。
详细分桶规则参考下面:
- rank_and_offsets 是将 grads 打平成一维后,不同 rank 和每个 rank grads 的 offset;
- dst 是目标 rank,bucket_offset 是 grads 起始 offset,numel 是 grads 数量;
- 产出的 buckets,bucket_key 是 dst + 进程组,value 是 grad tensor;
buckets = {}
for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets):
grad_slice = tensor.narrow(0, int(bucket_offset), int(numel))
bucket_key = real_dp_process_group[i] if self.use_multi_rank_bucket_allreduce else (
dst, real_dp_process_group[i])
if bucket_key not in buckets:
buckets[bucket_key] = []
if self.use_multi_rank_bucket_allreduce:
buckets[bucket_key].append((dst, grad_slice))
else:
buckets[bucket_key].append(grad_slice)
遍历所有 buckets(注意这里的 buckets 里包含要发送给所有 rank 节点的所有 grads 梯度信息),代码如下。在 ipg 相关代码里,主要做两个事情:
- 针对 dst rank,把所有的 grads 发送给 dst(通过 dist.reduce)
- 把属于当前的 partition 的 grads 储存下来,其他的 clear() ;(通过 bucket.clear())
for bucket_key in buckets:
self.allreduce_no_retain(buckets[bucket_key],
communication_data_type,
numel_per_bucket=self.reduce_bucket_size,
rank=dst,
divide=False,
process_group=process_group)
ZeRO-3
训练过程中 ZeRO-3 的参数加载、释放、all-gather 生命周期都是由 deepspeed 内部的各种 hook 实现自动管理的;其中几个典型的 hook:
# Pre forward hook
self.forward_hooks.append(module.register_forward_pre_hook(_pre_forward_module_hook))
# Post forward hook
self.forward_hooks.append(module.register_forward_hook(_post_forward_module_hook))
PartitionedParameterCoordinator
对于 ZeRO-3,最核心的是 PartitionedParameterCoordinator,和 ZeRO-1,2 不同,这里对于 params 做了分片来最大程度节约显存空间,那么 forward 和 backward 过程中,对于 params 的处理和协调,就是在这个 Coordinator 里实现的。
这里描述了 Coordinator 的所有功能
class PartitionedParameterCoordinator:
FORWARD_FETCH_SUBMIT = 'forward_fetch_submit'
FORWARD_FETCH_WAIT = 'forward_fetch_wait'
FORWARD_PREFETCH_SUBMIT = 'forward_prefetch_submit'
BACKWARD_FETCH_SUBMIT = 'backward_fetch_submit'
BACKWARD_FETCH_WAIT = 'backward_fetch_wait'
BACKWARD_PREFETCH_SUBMIT = 'backward_prefetch_submit'
FORWARD_ALL_GATHER = 'forward_all_gather'
BACKWARD_ALL_GATHER = 'backward_all_gather'
"""Handles partitioning and gathering of parameters."""
为了加速 ZeRO-3 参数分片带来的性能损失(作为显存下降的 trade-off),DeepSpeed 也提供了一些特性:
- pre-fetch:预拉取 next layer 的参数
- Fast fetch:跳过一些依赖校验,直接拿参数;(这个原理后面再看..)
- Trace 能力:通过 trace 来做一些加速;(这个原理后面再看..)
对于 params 的参数的状态追踪有三种(主要还是为了保持一致性):
- AVAILABLE:当前分片可用
- NOT_AVAILABLE:当前分片不可用
- INFLIGHT:传输中
class ZeroParamStatus(Enum):
# parameters are fully present and ready for use on all processes
AVAILABLE = 1
# parameters are either partitioned or remote in some or all process
NOT_AVAILABLE = 2
# parameters are being gathered.
INFLIGHT = 3
Forward
- Pre Forward:
Forward 对于每个 layer / nn module 的调用,会通过 fetch 来取到当前 module 需要的参数,同时 pre-fetch 加速后续 module 的参数获取,代码如下:
def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
...
params_to_fetch = set(iter_params(current_submodule, recurse=z3_leaf_module(current_submodule)))
fetch_numel = sum([p.partition_numel() for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
if fetch_numel > 0:
...
self.__all_gather_params(params_to_fetch, forward)
- Post Forward:
PreForward 用来获取参数,PostForward 就是用来裁剪参数。
def release_sub_module(self, submodule: Module, forward=False) -> None:
"""release the parameters of a sub module, assuming they meet conditions to
be released."""
#print_rank_0(f"release_sub_module {'fwd' if forward else 'bwd'}: {debug_module2name_id(submodule)}", force=False)
params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule))))
free_data = not z3_leaf_module(submodule) or not self.fast_sharding_for_leaf_module
if not free_data:
# wait for the computation to finish and launch as early as possible.
empty_buffer = torch.empty(1, device=get_accelerator().current_device())
for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
param.ds_active_sub_modules.discard(submodule.ds_id)
if param.ds_id in params_to_release and not param.is_external_param:
self.__release_param(param, free_data)
if not free_data:
if param.ds_id in params_to_release and not param.is_external_param:
# empty buffer ensures that all computations are complete
param.data = empty_buffer
Backward
累了…