通信 overlap 实现
AG/RS 与 GEMM 有数据依赖,三条路线怎么分块流水实现内核级融合
核心要点:
- AG/RS 与 GEMM 有数据依赖,不能简单异步发起
- 三条路线按分块粒度从粗到细排列
- Userbuffers 用 P2P ring-exchange 分块流水
- Flux 用 tile 级 kernel fusion,async-TP 用 SymmetricMemory 双流
- 小 GEMM 下 overlap 收益消失甚至变负
前置阅读:
- SP 为什么产生 AG/RS、它们在层内的位置 → 4.2 Megatron-SP 机制
- overlap 可行性判据(AllReduce 时间 / MatMul 时间)与跨策略通用原理 → 9.2 计算通信 Overlap
名词定义
| 名词 | 定义 |
|---|---|
| Userbuffers (UB) | TransformerEngine 的通信 backend:预分配跨 rank 共享缓冲区,用 P2P ring-exchange 分步搬运,支持与 GEMM 流水 |
| bulk overlap | 把与当前计算无数据依赖的通信,整块与另一段计算并发执行 |
| pipelined overlap | 把与当前计算有数据依赖的 AG/RS 拆成多步 P2P,每步数据就绪即喂给 GEMM 的对应分块 |
| prologue / epilogue fusion | 在 GEMM kernel 入口(prologue)等待输入 tile 通信就绪,或在出口(epilogue)把输出 tile 直接 P2P 写到远端的融合方式 |
| SymmetricMemory | PyTorch 的对称共享内存抽象:各 GPU 预分配可被对端直接读写的缓冲区,绕过 NCCL 走 CUDA P2P |
| overlap efficiency | 实际被隐藏的通信时间占理论可隐藏量的比例,受 tail latency 限制 |
@tbl-par-spovl-glossary 本文专属名词定义
为什么 AG/RS 不能像 DP 梯度那样简单异步发起?
AG 的输出是 GEMM 的输入、GEMM 的输出是 RS 的输入,存在直接数据依赖,必须在 kernel 级拆分依赖链才能 overlap[1]。
- DP 梯度的情形:梯度 AllReduce 与反向计算无依赖,发起后异步等待即可,是最容易藏的通信。
- SP 的 AG/RS 的情形:AG 必须先于 GEMM 完成(提供完整输入),RS 必须在 GEMM 之后(消费完整输出),整块串行发起就退化成串行执行,没有 overlap。
- 解法的共同思路:把"全量 AG → 整块 GEMM → 全量 RS"拆成分块流水——AG 搬完一块就让 GEMM 算这一块,GEMM 算完一块就 RS 这一块。三条主流路线的区别在分块粒度与实现层次。
| 路线 | 分块粒度 | 实现层次 | 主要载体 |
|---|---|---|---|
| TE Userbuffers | 设备数量级(ring step) | 库 + 预分配缓冲 | TransformerEngine / Megatron-LM |
| ByteDance Flux | GEMM tile(thread block) | CUTLASS kernel 内融合 | byte-flux |
| PyTorch async-TP | matmul chunk | torch.compile 自动变换 | PyTorch + SymmetricMemory |
@tbl-par-spovl-routes 三条 overlap 路线的粒度与层次
TransformerEngine Userbuffers 怎么分块流水?
UB 区分 bulk 与 pipelined 两种模式:无依赖通信用 bulk 整块并发,有依赖通信用 pipelined 把 AG/RS 拆成多步 P2P ring-exchange 与 GEMM 交织[2]。
- bulk overlap(默认):通信与一段无关计算并发。例如做某个 projection GEMM 的同时发起另一处的 AG。
- pipelined overlap:把 AG 替换为多步 P2P ring-exchange——每步收一块分片,收到即对该分片做 GEMM,不等全量 AG;RS 对称地替换为多步输出 P2P + 逐块 reduction。Megatron-LM / NeMo 的开关为
tp_comm_overlap,底层直接调用 TE 的 UB backend(Megatron-LM 不独立实现 overlap kernel)。
硬件与调优约束:
- UB 依赖 NVLink P2P 内存访问;TransformerEngine v2.2 增加跨节点 NVLink 支持。
- 更快的 CUDA Multicast 路径(
multimem做 in-switch reduction)需要 NVSwitch V3+ 与匹配 driver;不满足时设UB_SKIPMC=1退回 CUDA IPC(Issue #1923 记录了 A100 + CUDA 12.2 不满足的案例)[3]。 - 分配给通信的 SM 数量可调:过多挤占 GEMM、过少暴露通信延迟,无通用公式,需在目标配置上调优[2]。
Flux 的 tile 级融合解决了什么?
Flux 把通信分块对齐到 GEMM 的 thread block tile,融进同一个 kernel,避免「拆成多个小 GEMM kernel」导致的 SM 利用率下降[1]。
- AG+GEMM(prologue fusion):host 侧异步发起 AG,每完成一个 tile 置位 signal;GEMM thread block 在 prologue 等待对应 signal 后即算该 tile。支持 pull / push 两种变体。
- GEMM+RS(epilogue fusion):thread block 在 epilogue 用 P2P 指令把结果直接写到目标 GPU 远端内存(跨节点用 NVSHMEM put)。tile coordinate swizzling 把 thread block 索引按 rank 偏移,避免多 GPU 同写一个内存控制器。
- 为什么比串小 kernel 好:UB 类方案把一个大 GEMM 拆成多个小 GEMM kernel 串联,GEMM 形状变小后 SM 利用率下降;Flux 在单 kernel 内用 warp 级并发处理通信等待 / 写入,GEMM 效率接近原始单 kernel[1]。
实测加速(论文摘要 headline 数字)[1]:
| 场景 | 加速 | 基准 | 配置 |
|---|---|---|---|
| 训练(model 级) | 1.24× | vs Megatron-LM | 128 GPU(多代 GPU 与互联) |
| prefill 推理 | 1.66× | vs vLLM | 8 GPU |
| decode 推理 | 1.30× | vs vLLM | 8 GPU |
@tbl-par-spovl-flux Flux 实测加速(论文另报最高可隐藏 96% 通信)
论文正文 §5 另有按 operation 粒度、按硬件(A100 PCIe / H800 NVLink)细分的 vs TransformerEngine 对照,倍数随 GEMM 形状与互联差异较大,需查原表对应行,本表只取摘要给出的端到端 headline。
开源于 byte-flux(支持 sm80/sm89/sm90),需 CUDA 12.4 + CUTLASS + NVSHMEM;2025-03 扩展出 MoE 版本 COMET[4]。
PyTorch async-TP 怎么做到改模型代码无感?
async-TP 用 SymmetricMemory 做 P2P 直写、双流交替把 matmul 拆 chunk 流水,并由 torch.compile 自动识别 TP 模式重写,无需手改模型[5]。
- SymmetricMemory:各 GPU 预分配对称共享缓冲区,对端可直接读写,绕过 NCCL 的 SM 内核开销,走 CUDA P2P + copy engine。
- micro-pipelining:matmul 拆成 N 个 chunk,两条流交替——流 A 算第 $i$ 块 GEMM,流 B 同时 RS 第 $i$ 块结果,流 A 接着算第 $i+1$ 块。
- 自动变换:torch.compile 的 inductor pass(
_micro_pipeline_tp=True)识别 AG+matmul / matmul+RS 模式并重写为 chunk 交织,用户只需对 TP process group 调enable_symm_mem_for_group()。
实测(TorchTitan,64×H100 NVSwitch,bf16)[5]:Llama3 7B forward +29% / E2E +8%;Llama3 70B forward +20% / E2E +8%。SymmetricMemory 在 PyTorch 2.9 正式文档化,主要面向单节点,跨节点开发中。
overlap 的代价和适用边界在哪?
overlap 不是免费的:通信占 SM 拖慢 GEMM、小 GEMM 下分块效率崩塌、尾部通信藏不住,三者共同决定收益边界。
- SM 竞争:通信 kernel 占用的 SM 会拖慢同时运行的 GEMM。T3 论文(AMD GPU + Accel-Sim 仿真)的微基准显示:GEMM 主导时 overlap 完成时间略慢于纯 GEMM、但显著快于串行执行;通信主导时 overlap 引入的额外开销很小[6]。即 overlap 在通信占比高时几乎稳赚,GEMM 占比高时收益有限。
- 小 GEMM 失效:GEMM 越小(decode 小 batch、或高 TP 度导致分片小),分块 overlap 的 SM 利用率越低。Flux 在 $m \le 64$ 时 TMA 效率下降、overlap 收益消失甚至变负[1]——这与 4.2 Megatron-SP 机制 提到的 decode 退化、4.1 总览 推理场景分析一致。
- tail latency:最后一波 chunk 的通信无法与计算重叠,理想延迟为 $\max(T_{\text{GEMM}}, T_{\text{comm}})$,实际为 $T_{\text{GEMM}} + T_{\text{comm-tail}}$。Flux 96% overlap efficiency 即 4% 通信藏不住[1]。
- 建模含义:通信能否被藏住的条件是 $T_{\text{comm}} \le T_{\text{GEMM}}$(判据详见 9.2 计算通信 Overlap);高 TP 度下 $T_{\text{GEMM}}$ 随分片变小而下降、$T_{\text{comm}}$ 不等比缩小,条件更难满足,overlap 收益随 TP 度增大而递减。
Takeaway
| 知识点 | 核心结论 |
|---|---|
| 依赖约束 | AG→GEMM→RS 有数据依赖,必须分块流水,不能整块异步 |
| 三路线粒度 | UB(ring step)< Flux(GEMM tile)< async-TP(matmul chunk),粒度越细 SM 利用率越高 |
| Userbuffers | bulk(无依赖整块并发)/ pipelined(有依赖 P2P 分步),Megatron tp_comm_overlap 调用它 |
| Flux | tile 级 prologue/epilogue 融进单 kernel,避免小 GEMM kernel 串联;实测训练 1.24× vs Megatron / prefill 1.66× vs vLLM |
| async-TP | SymmetricMemory P2P + 双流,torch.compile 自动重写;Llama3 forward +20~29% |
| 代价边界 | SM 竞争(GEMM 主导时 +7.5%)、小 GEMM 失效($m\le64$ 负收益)、tail latency(4% 藏不住) |
@tbl-par-spovl-takeaway 通信 overlap 实现核心知识点
参考资料
- Chang et al., FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion, arXiv:2406.06858, 2024(tile 级融合机制、§5.2 实测、overlap efficiency、小 GEMM 限制). https://arxiv.org/abs/2406.06858
- NVIDIA NeMo Framework, Communication Overlap 文档(bulk / pipelined overlap、SM 调优). https://docs.nvidia.com/nemo-framework/user-guide/24.09/nemotoolkit/features/optimizations/communication_overlap.html
- TransformerEngine Issue #1923, CUDA Multicast 要求与 UB_SKIPMC 退回路径. https://github.com/NVIDIA/TransformerEngine/issues/1923
- ByteDance, flux GitHub 仓库(byte-flux,COMET MoE 扩展). https://github.com/bytedance/flux
- PyTorch, Distributed w/ TorchTitan: Introducing Async Tensor Parallelism in PyTorch, 2024(SymmetricMemory、micro-pipelining、torch.compile pass、TorchTitan 实测). https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487
- T3: Transparent Tracking & Triggering for Fine-grained Overlap, arXiv:2401.16677, 2024(SM 竞争微基准). https://arxiv.org/abs/2401.16677