跳到主要内容

Megatron-SP 机制

f/g 算子如何插入通信、AllReduce 怎么拆成 AG+RS、激活显存怎么算

核心要点

  • 纯 TP 的激活冗余来自逐 token 算子被整序列复制
  • g/ḡ 共轭算子替换 f/f̄,AllReduce 拆为 AG+RS
  • 激活显存三档公式(TP / TP+SP / +selective recompute)
  • backward 多一次 AllGather,可被 wgrad 计算隐藏
  • SP 必须依附 TP,EP+TP 场景强制启用

前置阅读

名词定义

名词定义
$f$ / $\bar{f}$纯 TP 的共轭算子对:$f$ forward 为恒等、backward 为 AllReduce;$\bar{f}$ 相反。$f$ 包裹 column-parallel linear 输入,$\bar{f}$ 包裹 row-parallel linear 输出
$g$ / $\bar{g}$TP+SP 的共轭算子对:$g$ forward 为 AllGather、backward 为 ReduceScatter;$\bar{g}$ 相反。两者是 SP region 与 TP region 之间的转换算子
wgrad AGbackward 计算 weight gradient 前,把按序列分片存储的激活重新 AllGather 成完整序列的那次额外通信
RNG trackerMegatron-LM 的 CUDA 随机数状态管理器(get_cuda_rng_tracker),按并行域分 seed,保证 Dropout mask 在该复制的 rank 上一致、该不同的 rank 上不同

@tbl-par-spmech-glossary 本文专属名词定义

纯 TP 的激活冗余从哪来?

LayerNorm / Dropout / residual add 是逐 token 算子,本可以在任意序列分片上独立计算,但纯 TP 把它们的激活在每个 rank 上按完整序列复制了 $t$[1]

  • 逐 token 算子:LayerNorm 沿 hidden 维归一化、Dropout 逐元素置零、residual add 逐元素相加,三者在 token 之间没有任何依赖——这是 SP 能按序列维分片它们的数学前提。
  • 纯 TP 的处理:TP 只切分 attention / MLP 的权重,这些逐 token 算子落在 TP 切分范围之外,每个 rank 各算一遍完整序列,激活存储完全冗余。
  • 冗余的量:以 $s$ = 序列长度、$b$ = micro-batch、$h$ = hidden 维、$a$ = head 数、$t$ = TP 度,单层激活显存(BF16 训练,单位 byte)为:
$$\begin{equation} A_{\text{TP}} = sbh\left(10 + \frac{24}{t} + \frac{5as}{ht}\right) \label{eq:par-spmech-act-tp} \end{equation}$$

括号里 $24/t$(GEMM 类激活)和 $5as/(ht)$(attention score)随 $t$ 缩小,$10$ 这一项不随 $t$ 缩小——它正是 LayerNorm 输入、Dropout mask 与输出等逐 token 算子的激活[1]$t=8$ 时这一项占单层激活的 60% 以上,TP 越大占比越高,成为继续扩 TP 也消不掉的地板。

f/g 算子怎么定义?

SP 把纯 TP 的 $f$/$\bar{f}$ 共轭对替换为 $g$/$\bar{g}$ 共轭对,利用 AllReduce ≡ ReduceScatter + AllGather 的恒等关系,把通信拆到 SP region 的边界上[1]

算子ForwardBackward位置
$f$(纯 TP)恒等AllReducecolumn-parallel linear 输入
$\bar{f}$(纯 TP)AllReduce恒等row-parallel linear 输出
$g$(TP+SP)AllGatherReduceScatterSP region → TP region 边界
$\bar{g}$(TP+SP)ReduceScatterAllGatherTP region → SP region 边界

@tbl-par-spmech-fg f/f̄ 与 g/ḡ 共轭算子对

替换的代数基础是 ring 实现下的原语恒等:

$$\begin{equation} \text{AllReduce}(x) \equiv \text{AllGather}(\text{ReduceScatter}(x)) \label{eq:par-spmech-ar-identity} \end{equation}$$

纯 TP 每层 4 次 AllReduce(前向 2 + 反向 2),TP+SP 每层 4 次 AG + 4 次 RS,按 两者总通信量逐字节相同——SP 不省通信量,省的是激活显存,赚的是拆开后的 overlap 机会(见 4.3 通信 overlap 实现)。

通信插在层内什么位置?

每个 transformer 层有 2 个 $g$ + 2 个 $\bar{g}$:attention 子层和 MLP 子层各一对,AllGather 在 LayerNorm 之后、QKV/FC1 之前,ReduceScatter 在 attention-out/FC2 之后、Dropout 之前[1]

输入 [b, s/t, h](序列分片)
└─ LayerNorm ← SP region
└─ g: AllGather → [b, s, h]
└─ QKV 投影 (column-parallel) ← TP region
└─ Attention + 输出投影 (row-parallel)
└─ ḡ: ReduceScatter → [b, s/t, h]
└─ Dropout + residual add ← SP region
└─ LayerNorm
└─ g: AllGather → [b, s, h]
└─ FC1 (column-parallel)
└─ GeLU + FC2 (row-parallel)
└─ ḡ: ReduceScatter → [b, s/t, h]
└─ Dropout + residual add

单次 AG / RS 的有效数据量与纯 TP 的 AllReduce 相同,均针对 $[b, s, h]$ 张量:

$$\begin{equation} M_{\text{SP}} = b \cdot s \cdot h \cdot \text{dtype\_size} \label{eq:par-spmech-msg-size} \end{equation}$$

ring 实现下 AG 与 RS 各搬运 $\frac{t-1}{t} M_{\text{SP}}$,一对 $g$/$\bar{g}$ 合计 $\frac{2(t-1)}{t} M_{\text{SP}}$,与一次 ring AllReduce 相同。典型训练 shape 下单条消息 10 ~ 100 MB 量级,走 TP 所在的高带宽域(NVLink / C2C)。

激活显存能省多少?

TP+SP 把整层激活均匀除以 $t$;再叠加 selective recomputation 把 attention score 项消掉,只剩 $34sbh/t$[1]

无并行时的单层激活(公式中 $34 = 11$ (attention) $+ 19$ (MLP) $+ 4$ (两个 LayerNorm),$5as/h$ 为 attention score 类激活):

$$\begin{equation} A_{\text{base}} = sbh\left(34 + \frac{5as}{h}\right) \label{eq:par-spmech-act-base} \end{equation}$$

TP+SP(对比 ,不可切分的 $10$ 一项消失,整层均匀切分):

$$\begin{equation} A_{\text{TP+SP}} = \frac{sbh}{t}\left(34 + \frac{5as}{h}\right) \label{eq:par-spmech-act-sp} \end{equation}$$

TP+SP+selective recomputation(重算 $QK^{\top}$ / softmax / softmax dropout / attention-over-$V$ 四步,消去 $5as/h$ 项;$L$ 层合计):

$$\begin{equation} A_{\text{selective}} = \frac{34\,sbhL}{t} \label{eq:par-spmech-act-selective} \end{equation}$$
配置单层激活相对无并行
无并行$sbh(34 + 5as/h)$1
纯 TP$sbh(10 + 24/t + 5as/(ht))$地板 $10sbh$ 不随 $t$
TP + SP$\frac{sbh}{t}(34 + 5as/h)$整层 $1/t$
TP + SP + selective$34sbh/t$$1/t$ 且消去 $s^2$

@tbl-par-spmech-act 激活显存四档对比(per layer,byte)

消去了随 $s^2$ 增长的项,这是长序列训练中 SP + selective recompute 组合的关键:激活显存恢复为与 $s$ 线性。

backward 比纯 TP 多了什么通信?

多一次 wgrad AllGather:激活按 $s/t$ 分片存储后,反向算 weight gradient 需要完整序列的激活,必须先 AG 恢复;Megatron 把这次 AG 与梯度计算重叠,不进关键路径[1]

  • 为什么需要:row/column-parallel linear 的 weight gradient 是 $\nabla W = X^{\top} \nabla Y$,需要完整的输入激活 $X$;SP 为省显存只存了 $X$$s/t$ 分片。
  • 代价怎么藏:这次 AG 与"对激活的梯度"计算(dgrad GEMM)重叠执行,论文明确以此消除额外延迟;Megatron-Core 对应开关 tp_comm_bulk_wgrad
  • 建模要点:SP 的 forward 通信量与纯 TP 逐字节相同;backward 多一次 $\frac{t-1}{t}M_{\text{SP}}$ 的 AG,理想 overlap 下不增加迭代时间,但它占用 TP 域带宽——在带宽竞争模型里要把这部分流量计入。

Dropout 的 RNG 为什么要分域管理?

Dropout mask 的正确性要求「该相同的 rank 相同、该不同的 rank 不同」,Megatron-LM 用三套 RNG tracker 实现,SP region 的 Dropout 需要各 rank 不同的 seed[2]

RNG 域用途跨 TP rank
model-parallel-rngTP region 内 Dropout(attention 内部)各 rank 不同(按 rank 偏移 seed)
默认 RNGTP 切分外、各 rank 复制计算的 Dropout同 TP 组内相同
expert-parallel-rngMoE expert 区域各 EP rank 不同

@tbl-par-spmech-rng Megatron-LM 的三套 RNG 域

SP 引入的变化:原本 TP 组内各 rank 对完整序列复制计算 Dropout(mask 必须一致 → 默认 RNG),SP 后各 rank 只算自己的 $s/t$ 段(mask 必须不同 → 应切到 model-parallel-rng)。Megatron-Core 0.9.0 存在已记录的 bug:SP region 的 Dropout 未切换到正确的 RNG 域(Issue #1256)[3]——做数值对齐验证时要注意这一项。

工程约束与实测收益是什么?

SP 不能脱离 TP 使用;EP+TP 组合时强制启用;530B 实测迭代提速 29.7%、激活显存降 5×

工程约束(Megatron-LM 源码与官方文档)[4][5]

  • sequence_parallel=True 要求 tensor_model_parallel_size > 1,否则直接 raise ValueError——SP 并行度被绑死等于 TP 度的代码级证据
  • EP 与 TP 同时启用时,SP 是强制要求(官方文档:"you must enable Sequence Parallelism")
  • 官方推荐用 TP 就开 SP("Always enable when using TP");与 PP 无互斥约束

530B GPT 模型的论文实测(Table 5,无数据并行配置)[1]

指标full recomputationSP + selective recompute
迭代时间49.05 s37.83 s(提速 29.7%)
MFU56.0%56.0%
激活显存降低 5×
重算计算开销全量降低 90% 以上

@tbl-par-spmech-530b 530B 模型上 SP + selective recompute 的实测收益

MFU 两行相同是因为收益体现在迭代时间而非 FLOPs 利用率:full recomputation 多做了一遍前向冗余 FLOPs,selective recompute 省掉这些冗余计算,于是同样的 MFU 下迭代更快。论文摘要另给出 42.1% → 54.2% 的 MFU 对比,那是 2240 张 A100、8-way 数据并行配置的数字,与本表(无 DP)不是同一实验,不可混用。

Takeaway

知识点核心结论
冗余来源逐 token 算子(LayerNorm/Dropout)本可分片,纯 TP 却整序列复制,地板 $10sbh$ 不随 $t$
算子替换$f$/$\bar{f}$ (恒等/AllReduce) → $g$/$\bar{g}$ (AG/RS),基于 AR ≡ RS+AG 恒等式
通信位置每层 2 AG + 2 RS,在 SP↔TP region 边界(LayerNorm 后 / Dropout 前)
通信量forward 与纯 TP 逐字节相同;backward 多一次可被 overlap 的 wgrad AG
显存收益整层激活 $1/t$;叠加 selective recompute 后 $34sbh/t$,消去 $s^2$
RNGSP region 的 Dropout 需各 rank 不同 seed,mcore 0.9.0 有未切换 RNG 域的已知 bug
工程约束必须 TP>1;EP+TP 强制开 SP;530B 实测提速 29.7%、激活显存降 5×

@tbl-par-spmech-takeaway Megatron-SP 机制核心知识点

参考资料

  1. Korthikanti et al., Reducing Activation Recomputation in Large Transformer Models, arXiv:2205.05198, 2022(§4.1 激活显存推导、§4.2 f/g 算子与通信分析、Table 5 实测). https://arxiv.org/abs/2205.05198
  2. Megatron-LM, megatron/core/tensor_parallel/random.py(RNG tracker 实现). https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py
  3. Megatron-LM Issue #1256, Effect of sequence parallel with dropout rng context. https://github.com/NVIDIA/Megatron-LM/issues/1256
  4. Megatron-LM, megatron/core/model_parallel_config.pysequence_parallel 与 TP 度的强制校验). https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/model_parallel_config.py
  5. NVIDIA, Megatron-Core Parallelism Strategies Guide. https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/parallelism-guide.html