跳到主要内容

Distributed Top-K

稀疏注意力与 MoE 场景下分布式 Top-K 的原语设计与通信代价

核心要点

  • 非经典原语 (NCCL 标准库未提供),稀疏 attention / MoE 普及催生需求
  • 朴素 AllGather + 本地排序在 $n$ 大时不可行 ($O(n)$ 通信量)
  • Tree-based 把通信量压到 $O(k \log N)$,利用 top-k 集合的可结合性
  • 放大因子 $c = k' / k$ 权衡通信量与召回率,$c = 4$ 通常 ~98% 召回
  • 索引编码必须全局,FP16 / BF16 tie-breaking 需附加字典序

语义

核心问题:Distributed Top-K 的语义是什么?与 AllGather + 本地排序的朴素方案有什么区别?

$N$ rank,每 rank $i$ 持长度 $n_i$ 的 score 向量 $\mathbf{s}_i$ (KV entry / token / candidate 的相关度),全局总长度 $n = \sum_i n_i$

输出全局 top-k 索引集合:

$$\begin{equation} \mathcal{T} = \mathrm{Top}_k \left( \bigcup_{i=0}^{N-1} \mathbf{s}_i \right), \quad |\mathcal{T}| = k \label{eq:cc-topk-semantic} \end{equation}$$

典型场景 $k = 1024$, $n_i = 250 \text{K}$, $k \ll n$

与经典原语的区别

原语输入输出与 top-k 关系
AllGather每 rank 一段所有 rank 持全量朴素方案 = AllGather + 本地排
AllReduce每 rank 一向量逐元素归约不适用 (top-k 不是逐元素)
ArgMax 归约每 rank 一向量全局 argmax是 top-1 特例,扩到 top-k 需新原语

@tbl-cc-topk-vs-others Top-k 与经典原语的关系

Top-k 是集合选择而非逐元素归约,不能套 AllReduce 模板

朴素方案:AllGather + 本地排序

核心问题:最直接做法的代价?

1. AllGather: 每 rank 把 s_i 发给所有, 总通信 O(n)
2. 本地: 每 rank 对长度 n 向量做 top-k, O(n log k)

$n = 250 \text{K}$ FP32 score 每 rank 1 MB;每层每 query 一次,1M token query 每层 1 TB 级,不可行

朴素方案仅当 $n < 1000$ 或 top-k 偶发时可用。

Tree-Based Top-K 归约

核心问题:能否利用 top-k 的代数性质把通信压到 $O(k)$?

代数基础

两个 top-$k'$ 集合的并集再取 top-$k$ 等于全局 top-$k$ (条件 $k' \ge k$),故 top-k 可按 reduction tree 结构归约。

基本算法

Stage 1 (本地): 每 rank 对 s_i 做本地 top-k', 得 (index, value) 列表
Stage 2 (tree reduce): 二叉树 (或环) 归约
- 相邻 rank 交换各自 top-k' 列表
- 合并后再取 top-k', 向上层传
Stage 3: 根持有全局 top-k, 视需要 broadcast

代价

每轮通信仅 $k'$ 个 (index, value) 对 (~12 字节 = FP32 index + FP32 value),$\log_2 N$ 轮:

$$\begin{equation} T_{\text{tree}} = \log_2 N \cdot \left( \alpha + \frac{12 k'}{\beta} \right) \label{eq:cc-topk-tree-time} \end{equation}$$

比朴素 $O(n)$ 压到 $O(k')$, $k' \ll n$ 时收益巨大

精确 vs 近似:放大因子 $c$

核心问题$k' = k$ 通信最省但召回低,放大多少够用?

召回率公式 (均匀切分假设)

$$\begin{equation} \mathbb{E}[\text{Recall}] \approx 1 - \exp(-c), \quad c = k' / k \label{eq:cc-topk-recall} \end{equation}$$
$c$期望召回率
163.2%
286.5%
498.2%
899.97%

@tbl-cc-topk-recall 放大因子与召回率 (均匀切分假设)

实际 score 分布常呈幂律 (少数 token 显著高分),真实召回率通常高于上表理论下界

风险与权衡

  • $k' = k$:通信最小,可能丢解 (本地排序中位于 $k+1$ 位之后的真全局 top-k 会被剪)
  • $k' = c k$ ($c > 1$):保留更多候选,召回率上升,通信量正比上升
  • 稀疏 attention 本身就是近似的,少 2% 召回换 20× 通信收益划算

单 rank 写回 vs 全 rank 复制

核心问题:Tree reduce 后是否要 broadcast 给所有 rank?

下游是否 broadcast
KV cache 集中存储,root 负责 gather KV不需
KV cache 分片存储,每 rank 自 gather 本地段需 (广播 top-k 索引)
Top-k 索引用作稀疏 AllToAll 路由表

@tbl-cc-topk-broadcast 是否广播 top-k 索引的判断

Broadcast 通信量小 ($k$ 个 index = $4 k$ 字节), 但增加一个同步点

拓扑相关实现

Ring 拓扑:Recursive Halving

$t$ 轮每 rank 与距离 $2^t$ 的 rank 交换 top-$k'$, $\log_2 N$ 轮后所有 rank 持全局。通信复杂度同 tree,但带宽利用率受 Ring 钳制。

Fat-Tree / NVSwitch 全连接:All-to-One + Broadcast

$N$ 较小 (单 pod 8–72 GPU) 时直接:

  1. 每 rank 一次性发本地 top-$k'$ 到 root
  2. Root 合并出全局 top-k
  3. Broadcast 给所有

Dragonfly / 多跳:两级 reduce

跨 group 时延迟代价高,两级:

  1. Group 内 reduce 到 group leader
  2. Leader 间再 reduce

与 hierarchical AllReduce 思路一致。

复杂度对比

核心问题:各 Top-K 算法变体在通信复杂度和计算复杂度上如何对比?

$N$ rank,每 rank $n / N$ score,目标 top-k, $k' = c k$:

方案通信量通信轮数召回率
AllGather + 本地排$O(n)$$\log_2 N$100%
Tree top-k ($c = 1$)$O(k \log N)$$\log_2 N$~63%
Tree top-k ($c = 4$)$O(4 k \log N)$$\log_2 N$~98%
Tree top-k ($c = 8$)$O(8 k \log N)$$\log_2 N$~99.97%
两级 (group + global)$O(k \sqrt{N})$ 量级2 阶段取决于 $c$

@tbl-cc-topk-complexity Distributed top-k 方案复杂度对比

$k' = 4 k$$k = 1024$ 时单次通信 ~48 KB,比 AllGather 的 1 MB 降 20×, 98% 召回在稀疏 attention 场景足够

应用场景

核心问题:Distributed Top-K 在稀疏注意力、MoE routing 等 LLM 场景中有哪些典型应用?

场景典型参数备注
DeepSeek V3.2 DSA$n = 160 \text{K}$, $k = 2048$KV 未压缩,每层 top-k
DeepSeek V4 CSA$n \approx n_{\text{ctx}} / m$ (1M ctx 时 $n = 250 \text{K}$), $k = 512/1024$KV 压缩 4×, indexer 选择
MoE gate (top-2 / top-6)$n = E$ (expert 数),$k = 2 \sim 6$$n$ 很小,AllGather + 本地排
Speculative decoding$n = $ candidate 树宽,$k = $ beam通常单卡
Vector DB 跨分片检索$n = $ 亿级 embedding, $k = $ 几百Tree-based 标配

@tbl-cc-topk-use-cases Top-k 典型应用场景

实现细节

核心问题:Distributed Top-K 在实际实现中有哪些关键工程决策和陷阱?

核心问题:实现时哪些坑会让 top-k 结果不对?

  • 索引必须全局:tree-reduce 中 score 携带的 index 必须是 rank_id 拼 local offset,否则合并后索引冲突
  • 数值稳定性:FP16 / BF16 同值 tie-breaking 可能不一致,需附 (value, global_index) 字典序保证确定性
  • GPU 实现:CUTLASS / FlashInfer 提供 device-side top-k kernel,通信层暴露 ncclSend / ncclRecv 让上层组装
  • 稀疏 send list 衔接:top-k 结果常作 sparse AllToAll 路由表,见 4.14 不规则与流水化 AllToAll

Takeaway

知识点核心结论
性质集合选择,非逐元素,不能复用 AllReduce
朴素 AllGather$O(n)$, $n$ 大时不可行
Tree-based$O(k \log N)$, top-k 可结合性是基础
放大因子 $c$召回率 $\approx 1 - e^{-c}$, $c = 4$ 取 98%
是否 broadcastKV 分片场景需要,集中场景不需
拓扑实现Ring → recursive halving, FC → All-to-One + Broadcast
实现关键全局索引 + tie-breaking 字典序
工程权衡稀疏 attention 本身近似,牺牲 2% 召回换 20× 通信省合理

@tbl-cc-topk-takeaway Distributed Top-K 要点

参考资料