分布式训练入门:数据并行、张量并行与流水线并行

从单卡瓶颈出发,系统理解大模型训练中的三种核心并行策略,以及它们在 FSDP、DeepSpeed 等工程框架中的落地方式。

难度

前沿

阅读时长

约 125 分钟

更新日期

2026/03/24

主题

训练工程 / 分布式训练 / 数据并行 / 张量并行 / 流水线并行

先修知识

GPU 训练常识矩阵乘法与张量维度概念大模型训练流水线总览

学习目标

读完这篇文章后,你应该能:

  1. 说清大模型训练为什么通常无法停留在单卡或单机。
  2. 区分数据并行、张量并行、流水线并行分别切分了什么。
  3. 理解 FSDP、ZeRO、Megatron-LM、DeepSpeed 这些词大致落在哪一层。
  4. 在面对一个新训练任务时,判断应该优先补显存、补带宽,还是补调度策略。

这篇教程的重点不是某一个框架的命令行参数,而是建立一张“并行策略地图”。

为什么单卡训练很快会碰到天花板

大模型训练的资源消耗,不只是模型权重本身。真正占内存和算力的通常至少有四块:

  • 模型参数
  • 梯度
  • 优化器状态
  • 激活值

如果你用 Adam,一份参数往往对应多份状态;如果序列变长,激活和注意力开销也会迅速膨胀。于是很快会出现两类问题:

  1. 模型根本放不下。
  2. 即使勉强放下,训练速度也太慢。

分布式训练的目标,就是把“放不下”和“跑不动”同时拆解掉。

先建立一个判断框架:你到底在切什么

理解并行策略最简单的方法,是先问一句:

我是把数据切开了,还是把模型切开了,还是把执行时间切开了?

对应起来就是:

  • 数据并行:每张卡持有同一份模型,不同卡处理不同数据。
  • 张量并行:同一层里的权重矩阵被拆到多张卡上。
  • 流水线并行:不同层或不同层组被放到不同设备上,像装配线一样串起来。

三者并不冲突,真实的大模型训练往往是组合使用。

数据并行:最容易理解,也是最先接触的方案

数据并行的直觉非常简单:

  • 每个 GPU 上放同一份模型副本
  • 把一个大 batch 切成多个小 batch
  • 每张卡各自前向和反向
  • 最后把梯度同步,再一起更新参数

它的优点是实现直观,对模型结构侵入小,因此常作为分布式训练的起点。

DDP 在解决什么

DistributedDataParallel(DDP)最经典的职责,就是在反向传播阶段做梯度同步。你可以把它理解成:

  • 各卡先独立算出梯度
  • 再通过 all-reduce 把梯度聚合
  • 保证所有副本最终保持一致

DDP 的问题也很明确:虽然数据切开了,但模型副本仍然是完整复制的,所以显存压力没有从根本上消失。

ZeRO 和 FSDP 为什么重要

当模型越来越大时,只复制完整权重会非常浪费。ZeRO 和 FSDP 的核心思想,都是进一步把“参数、梯度、优化器状态”也分散到多卡上,而不是每张卡都保留完整副本。

可以粗略理解为:

  • ZeRO-1:切优化器状态
  • ZeRO-2:再切梯度
  • ZeRO-3:连参数也切
  • FSDP:把参数按 shard 管理,按需聚合和释放

这些方法本质上仍属于“数据并行家族”,只是它们把内存管理做得更激进、更精细。

张量并行:一层都太大时,就把层内部拆开

当单层矩阵已经大到一张卡放不下,或者单卡算力无法满足吞吐要求时,仅靠数据并行就不够了。这时要考虑张量并行。

张量并行做的事情是:把同一层的线性计算切到多张卡上一起算

最常见的方式有两类:

  • 按列切分权重矩阵
  • 按行切分权重矩阵

例如一个非常大的线性层 Y = XW,如果 W 太大,就可以让不同 GPU 各自持有 W 的一部分,然后把结果再拼接或聚合。

张量并行的好处

  • 单层模型容量可以突破单卡限制。
  • 大矩阵乘法能利用更多设备并行计算。
  • 对超大模型尤其关键,很多百亿、千亿模型都离不开它。

张量并行的代价

  • 层内通信频繁,对高速互联依赖强。
  • 实现复杂度高,通常要和模型结构强耦合。
  • 如果卡间带宽不够,通信可能吃掉理论加速收益。

所以张量并行不是“越早越好”,而是在模型尺寸真的逼近硬件边界时才更有价值。

流水线并行:把层堆切成多段执行

流水线并行的出发点是:既然模型有很多层,那不如把前几层放在设备 A,中间层放在设备 B,后几层放在设备 C。

这样一来,不同 micro-batch 就可以像工厂流水线一样穿过各个阶段。

例如:

  • GPU0 负责第 1 到 12 层
  • GPU1 负责第 13 到 24 层
  • GPU2 负责第 25 到 36 层

当第一个 micro-batch 进入 GPU1 时,GPU0 已经可以开始处理第二个 micro-batch 了。

流水线并行解决什么问题

  • 单个设备不必存完整网络
  • 更容易按层切分显存负担
  • 对非常深的网络很友好

它最大的痛点是什么

最典型的问题叫 pipeline bubble,也就是某些阶段暂时空转,没有被充分喂满。如果 micro-batch 太少、阶段划分不均匀,设备利用率就会下降。

所以流水线并行不仅是“切层”,更是“切得是否均衡”的问题。

三种并行方式怎么组合

真实工业训练通常不是三选一,而是三维组合。你可以把它想成:

  • 数据并行负责横向扩展样本吞吐
  • 张量并行负责拆大层
  • 流水线并行负责拆层堆

很多框架里的“3D parallelism”说的就是这个思路。

一个粗略示例:

  • 8 卡做 2 路张量并行
  • 每 2 卡组成一个流水线 stage
  • stage 之间再做 2 路数据并行

这听起来复杂,但本质上就是在不同维度同时切分。

什么时候优先选哪种并行

下面这个判断框架很实用:

场景一:模型能放下,但训练太慢

优先考虑数据并行,因为它最直接提升吞吐,工程改造成本也相对低。

场景二:模型放不下,但层还没大到离谱

先看 FSDP 或 ZeRO-3。这类方案往往能在不深改模型结构的情况下,把显存利用率抠出来。

场景三:单层矩阵本身已经过大

需要张量并行,因为问题不再只是“副本浪费”,而是“层本体过大”。

场景四:模型很深,层间切分更自然

可以引入流水线并行,但要仔细处理 stage 划分和 micro-batch 调度。

一个常见工程组合:FSDP 或 DeepSpeed 作为第一步

如果你是第一次真正落地分布式训练,最常见的路线通常不是上来就做完整 3D 并行,而是先在以下两条里选一条:

  1. PyTorch FSDP + torchrun
  2. DeepSpeed ZeRO + accelerate / deepspeed launcher

前者更贴近 PyTorch 原生生态,后者在大模型训练里有很成熟的工程经验。一个最小 torchrun 启动方式大致如下:

torchrun --nproc_per_node=4 train.py \
  --model_name Qwen/Qwen2.5-7B \
  --fsdp "full_shard auto_wrap" \
  --bf16 true \
  --gradient_checkpointing true

这条命令背后的几个关键信号是:

  • nproc_per_node=4 表示本机 4 个进程对齐 4 张卡
  • full_shard 表示参数和梯度会被分片管理
  • gradient_checkpointing 用时间换显存

它不是最终最优配置,但足够让你从单卡走到多卡。

通信为什么会决定分布式训练成败

初学分布式训练时,很多人只关注 GPU 数量,却忽视了卡间通信。实际上:

  • 数据并行依赖梯度同步
  • 张量并行依赖层内中间结果交换
  • 流水线并行依赖阶段间激活传递

如果设备间互联很弱,例如只靠普通 PCIe,而不是 NVLink、InfiniBand 之类高速连接,那么“理论上可以并行”不等于“实践里真的更快”。

一个经验判断是:

  • 数据并行对通信较敏感,但通常最容易接受
  • 张量并行最依赖高带宽、低延迟互联
  • 流水线并行则更依赖 stage 均衡和调度设计

激活检查点为什么经常和分布式训练一起出现

你会发现很多多卡训练命令都会同时开启 gradient checkpointing 或 activation recomputation。原因很简单:

  • 分布式并没有消除激活开销
  • 长序列训练时,激活常常仍是主要显存压力

激活检查点的思路是:前向时不保存所有中间结果,反向时再局部重算,以此节省显存。它和并行策略并不冲突,往往是配套出现的。

训练时应该重点盯哪些指标

当你开始做多卡训练后,判断系统是否健康,不能只看 loss。建议同步观测:

  • 每卡显存占用是否均匀
  • GPU 利用率是否长期偏低
  • 通信等待时间是否过长
  • step time 是否抖动明显
  • 梯度同步是否成为瓶颈
  • 是否频繁出现 OOM 或 NCCL 超时

很多训练项目卡住,并不是模型本身有问题,而是某个并行维度配置失衡。

常见误区

1. 以为多加几张卡就一定线性加速

一旦通信、I/O 或调度成为瓶颈,卡数增加可能只会带来更复杂的问题,而不会带来等比例收益。

2. 模型放不下就立刻上张量并行

很多时候 FSDP 或 ZeRO 已经足够把显存问题解决掉,没必要过早引入更重的模型结构改造。

3. 忽视 batch 设计

分布式训练里常见的 global batch size = per_device_batch * world_size * grad_accumulation。如果这些量之间没有想清楚,训练稳定性和吞吐都可能出问题。

4. 只看平均吞吐,不看长尾

某些配置平均速度不错,但偶发抖动严重、recover 成本高。上线前一定要看稳定性,而不是只看最高速度。

练习与思考题

  1. 为什么说 FSDP/ZeRO 仍然属于数据并行思路的延伸?
  2. 在什么情况下,张量并行比流水线并行更自然?
  3. 如果你的 8 卡训练利用率很低,你会优先排查数据加载、通信还是 stage 划分?为什么?
  4. 当模型放不下、互联又一般时,应该怎样决定并行策略的优先顺序?

延伸阅读

相关阅读

从相近主题继续深入,建立连续学习链路。