难度
前沿
从单卡瓶颈出发,系统理解大模型训练中的三种核心并行策略,以及它们在 FSDP、DeepSpeed 等工程框架中的落地方式。
难度
前沿
阅读时长
约 125 分钟
更新日期
2026/03/24
主题
训练工程 / 分布式训练 / 数据并行 / 张量并行 / 流水线并行
读完这篇文章后,你应该能:
这篇教程的重点不是某一个框架的命令行参数,而是建立一张“并行策略地图”。
大模型训练的资源消耗,不只是模型权重本身。真正占内存和算力的通常至少有四块:
如果你用 Adam,一份参数往往对应多份状态;如果序列变长,激活和注意力开销也会迅速膨胀。于是很快会出现两类问题:
分布式训练的目标,就是把“放不下”和“跑不动”同时拆解掉。
理解并行策略最简单的方法,是先问一句:
我是把数据切开了,还是把模型切开了,还是把执行时间切开了?
对应起来就是:
三者并不冲突,真实的大模型训练往往是组合使用。
数据并行的直觉非常简单:
它的优点是实现直观,对模型结构侵入小,因此常作为分布式训练的起点。
DistributedDataParallel(DDP)最经典的职责,就是在反向传播阶段做梯度同步。你可以把它理解成:
all-reduce 把梯度聚合DDP 的问题也很明确:虽然数据切开了,但模型副本仍然是完整复制的,所以显存压力没有从根本上消失。
当模型越来越大时,只复制完整权重会非常浪费。ZeRO 和 FSDP 的核心思想,都是进一步把“参数、梯度、优化器状态”也分散到多卡上,而不是每张卡都保留完整副本。
可以粗略理解为:
这些方法本质上仍属于“数据并行家族”,只是它们把内存管理做得更激进、更精细。
当单层矩阵已经大到一张卡放不下,或者单卡算力无法满足吞吐要求时,仅靠数据并行就不够了。这时要考虑张量并行。
张量并行做的事情是:把同一层的线性计算切到多张卡上一起算。
最常见的方式有两类:
例如一个非常大的线性层 Y = XW,如果 W 太大,就可以让不同 GPU 各自持有 W 的一部分,然后把结果再拼接或聚合。
所以张量并行不是“越早越好”,而是在模型尺寸真的逼近硬件边界时才更有价值。
流水线并行的出发点是:既然模型有很多层,那不如把前几层放在设备 A,中间层放在设备 B,后几层放在设备 C。
这样一来,不同 micro-batch 就可以像工厂流水线一样穿过各个阶段。
例如:
当第一个 micro-batch 进入 GPU1 时,GPU0 已经可以开始处理第二个 micro-batch 了。
最典型的问题叫 pipeline bubble,也就是某些阶段暂时空转,没有被充分喂满。如果 micro-batch 太少、阶段划分不均匀,设备利用率就会下降。
所以流水线并行不仅是“切层”,更是“切得是否均衡”的问题。
真实工业训练通常不是三选一,而是三维组合。你可以把它想成:
很多框架里的“3D parallelism”说的就是这个思路。
一个粗略示例:
这听起来复杂,但本质上就是在不同维度同时切分。
下面这个判断框架很实用:
优先考虑数据并行,因为它最直接提升吞吐,工程改造成本也相对低。
先看 FSDP 或 ZeRO-3。这类方案往往能在不深改模型结构的情况下,把显存利用率抠出来。
需要张量并行,因为问题不再只是“副本浪费”,而是“层本体过大”。
可以引入流水线并行,但要仔细处理 stage 划分和 micro-batch 调度。
如果你是第一次真正落地分布式训练,最常见的路线通常不是上来就做完整 3D 并行,而是先在以下两条里选一条:
PyTorch FSDP + torchrunDeepSpeed ZeRO + accelerate / deepspeed launcher前者更贴近 PyTorch 原生生态,后者在大模型训练里有很成熟的工程经验。一个最小 torchrun 启动方式大致如下:
torchrun --nproc_per_node=4 train.py \
--model_name Qwen/Qwen2.5-7B \
--fsdp "full_shard auto_wrap" \
--bf16 true \
--gradient_checkpointing true
这条命令背后的几个关键信号是:
nproc_per_node=4 表示本机 4 个进程对齐 4 张卡full_shard 表示参数和梯度会被分片管理gradient_checkpointing 用时间换显存它不是最终最优配置,但足够让你从单卡走到多卡。
初学分布式训练时,很多人只关注 GPU 数量,却忽视了卡间通信。实际上:
如果设备间互联很弱,例如只靠普通 PCIe,而不是 NVLink、InfiniBand 之类高速连接,那么“理论上可以并行”不等于“实践里真的更快”。
一个经验判断是:
你会发现很多多卡训练命令都会同时开启 gradient checkpointing 或 activation recomputation。原因很简单:
激活检查点的思路是:前向时不保存所有中间结果,反向时再局部重算,以此节省显存。它和并行策略并不冲突,往往是配套出现的。
当你开始做多卡训练后,判断系统是否健康,不能只看 loss。建议同步观测:
很多训练项目卡住,并不是模型本身有问题,而是某个并行维度配置失衡。
一旦通信、I/O 或调度成为瓶颈,卡数增加可能只会带来更复杂的问题,而不会带来等比例收益。
很多时候 FSDP 或 ZeRO 已经足够把显存问题解决掉,没必要过早引入更重的模型结构改造。
分布式训练里常见的 global batch size = per_device_batch * world_size * grad_accumulation。如果这些量之间没有想清楚,训练稳定性和吞吐都可能出问题。
某些配置平均速度不错,但偶发抖动严重、recover 成本高。上线前一定要看稳定性,而不是只看最高速度。
从相近主题继续深入,建立连续学习链路。