Distributed Top-K
稀疏注意力与 MoE 场景下分布式 Top-K 的原语设计与通信代价
核心要点:
- 非经典原语 (NCCL 标准库未提供),稀疏 attention / MoE 普及催生需求
- 朴素 AllGather + 本地排序在 $n$ 大时不可行 ($O(n)$ 通信量)
- Tree-based 把通信量压到 $O(k \log N)$,利用 top-k 集合的可结合性
- 放大因子 $c = k' / k$ 权衡通信量与召回率,$c = 4$ 通常 ~98% 召回
- 索引编码必须全局,FP16 / BF16 tie-breaking 需附加字典序
语义
核心问题:Distributed Top-K 的语义是什么?与 AllGather + 本地排序的朴素方案有什么区别?
$N$ rank,每 rank $i$ 持长度 $n_i$ 的 score 向量 $\mathbf{s}_i$ (KV entry / token / candidate 的相关度),全局总长度 $n = \sum_i n_i$。
输出全局 top-k 索引集合:
$$\begin{equation} \mathcal{T} = \mathrm{Top}_k \left( \bigcup_{i=0}^{N-1} \mathbf{s}_i \right), \quad |\mathcal{T}| = k \label{eq:cc-topk-semantic} \end{equation}$$典型场景 $k = 1024$, $n_i = 250 \text{K}$, $k \ll n$。
与经典原语的区别
| 原语 | 输入 | 输出 | 与 top-k 关系 |
|---|---|---|---|
| AllGather | 每 rank 一段 | 所有 rank 持全量 | 朴素方案 = AllGather + 本地排 |
| AllReduce | 每 rank 一向量 | 逐元素归约 | 不适用 (top-k 不是逐元素) |
| ArgMax 归约 | 每 rank 一向量 | 全局 argmax | 是 top-1 特例,扩到 top-k 需新原语 |
@tbl-cc-topk-vs-others Top-k 与经典原语的关系
Top-k 是集合选择而非逐元素归约,不能套 AllReduce 模板。
朴素方案:AllGather + 本地排序
核心问题:最直接做法的代价?
1. AllGather: 每 rank 把 s_i 发给所有, 总通信 O(n)
2. 本地: 每 rank 对长度 n 向量做 top-k, O(n log k)
$n = 250 \text{K}$ FP32 score 每 rank 1 MB;每层每 query 一次,1M token query 每层 1 TB 级,不可行。
朴素方案仅当 $n < 1000$ 或 top-k 偶发时可用。
Tree-Based Top-K 归约
核心问题:能否利用 top-k 的代数性质把通信压到 $O(k)$?
代数基础
两个 top-$k'$ 集合的并集再取 top-$k$ 等于全局 top-$k$ (条件 $k' \ge k$),故 top-k 可按 reduction tree 结构归约。
基本算法
Stage 1 (本地): 每 rank 对 s_i 做本地 top-k', 得 (index, value) 列表
Stage 2 (tree reduce): 二叉树 (或环) 归约
- 相邻 rank 交换各自 top-k' 列表
- 合并后再取 top-k', 向上层传
Stage 3: 根持有全局 top-k, 视需要 broadcast
代价
每轮通信仅 $k'$ 个 (index, value) 对 (~12 字节 = FP32 index + FP32 value),$\log_2 N$ 轮:
$$\begin{equation} T_{\text{tree}} = \log_2 N \cdot \left( \alpha + \frac{12 k'}{\beta} \right) \label{eq:cc-topk-tree-time} \end{equation}$$比朴素 $O(n)$ 压到 $O(k')$, $k' \ll n$ 时收益巨大。
精确 vs 近似:放大因子 $c$
核心问题:$k' = k$ 通信最省但召回低,放大多少够用?
召回率公式 (均匀切分假设)
$$\begin{equation} \mathbb{E}[\text{Recall}] \approx 1 - \exp(-c), \quad c = k' / k \label{eq:cc-topk-recall} \end{equation}$$| $c$ | 期望召回率 |
|---|---|
| 1 | 63.2% |
| 2 | 86.5% |
| 4 | 98.2% |
| 8 | 99.97% |
@tbl-cc-topk-recall 放大因子与召回率 (均匀切分假设)
实际 score 分布常呈幂律 (少数 token 显著高分),真实召回率通常高于上表理论下界。
风险与权衡
- $k' = k$:通信最小,可能丢解 (本地排序中位于 $k+1$ 位之后的真全局 top-k 会被剪)
- $k' = c k$ ($c > 1$):保留更多候选,召回率上升,通信量正比上升
- 稀疏 attention 本身就是近似的,少 2% 召回换 20× 通信收益划算
单 rank 写回 vs 全 rank 复制
核心问题:Tree reduce 后是否要 broadcast 给所有 rank?
| 下游 | 是否 broadcast |
|---|---|
| KV cache 集中存储,root 负责 gather KV | 不需 |
| KV cache 分片存储,每 rank 自 gather 本地段 | 需 (广播 top-k 索引) |
| Top-k 索引用作稀疏 AllToAll 路由表 | 需 |
@tbl-cc-topk-broadcast 是否广播 top-k 索引的判断
Broadcast 通信量小 ($k$ 个 index = $4 k$ 字节), 但增加一个同步点。
拓扑相关实现
Ring 拓扑:Recursive Halving
第 $t$ 轮每 rank 与距离 $2^t$ 的 rank 交换 top-$k'$, $\log_2 N$ 轮后所有 rank 持全局。通信复杂度同 tree,但带宽利用率受 Ring 钳制。
Fat-Tree / NVSwitch 全连接:All-to-One + Broadcast
$N$ 较小 (单 pod 8–72 GPU) 时直接:
- 每 rank 一次性发本地 top-$k'$ 到 root
- Root 合并出全局 top-k
- Broadcast 给所有
Dragonfly / 多跳:两级 reduce
跨 group 时延迟代价高,两级:
- Group 内 reduce 到 group leader
- Leader 间再 reduce
与 hierarchical AllReduce 思路一致。
复杂度对比
核心问题:各 Top-K 算法变体在通信复杂度和计算复杂度上如何对比?
$N$ rank,每 rank $n / N$ score,目标 top-k, $k' = c k$:
| 方案 | 通信量 | 通信轮数 | 召回率 |
|---|---|---|---|
| AllGather + 本地排 | $O(n)$ | $\log_2 N$ | 100% |
| Tree top-k ($c = 1$) | $O(k \log N)$ | $\log_2 N$ | ~63% |
| Tree top-k ($c = 4$) | $O(4 k \log N)$ | $\log_2 N$ | ~98% |
| Tree top-k ($c = 8$) | $O(8 k \log N)$ | $\log_2 N$ | ~99.97% |
| 两级 (group + global) | $O(k \sqrt{N})$ 量级 | 2 阶段 | 取决于 $c$ |
@tbl-cc-topk-complexity Distributed top-k 方案复杂度对比
$k' = 4 k$ 在 $k = 1024$ 时单次通信 ~48 KB,比 AllGather 的 1 MB 降 20×, 98% 召回在稀疏 attention 场景足够。
应用场景
核心问题:Distributed Top-K 在稀疏注意力、MoE routing 等 LLM 场景中有哪些典型应用?
| 场景 | 典型参数 | 备注 |
|---|---|---|
| DeepSeek V3.2 DSA | $n = 160 \text{K}$, $k = 2048$ | KV 未压缩,每层 top-k |
| DeepSeek V4 CSA | $n \approx n_{\text{ctx}} / m$ (1M ctx 时 $n = 250 \text{K}$), $k = 512/1024$ | KV 压缩 4×, indexer 选择 |
| MoE gate (top-2 / top-6) | $n = E$ (expert 数),$k = 2 \sim 6$ | $n$ 很小,AllGather + 本地排 |
| Speculative decoding | $n = $ candidate 树宽,$k = $ beam | 通常单卡 |
| Vector DB 跨分片检索 | $n = $ 亿级 embedding, $k = $ 几百 | Tree-based 标配 |
@tbl-cc-topk-use-cases Top-k 典型应用场景
实现细节
核心问题:Distributed Top-K 在实际实现中有哪些关键工程决策和陷阱?
核心问题:实现时哪些坑会让 top-k 结果不对?
- 索引必须全局:tree-reduce 中 score 携带的 index 必须是 rank_id 拼 local offset,否则合并后索引冲突
- 数值稳定性:FP16 / BF16 同值 tie-breaking 可能不一致,需附
(value, global_index)字典序保证确定性 - GPU 实现:CUTLASS / FlashInfer 提供 device-side top-k kernel,通信层暴露
ncclSend / ncclRecv让上层组装 - 稀疏 send list 衔接:top-k 结果常作 sparse AllToAll 路由表,见 4.14 不规则与流水化 AllToAll
Takeaway
| 知识点 | 核心结论 |
|---|---|
| 性质 | 集合选择,非逐元素,不能复用 AllReduce |
| 朴素 AllGather | $O(n)$, $n$ 大时不可行 |
| Tree-based | $O(k \log N)$, top-k 可结合性是基础 |
| 放大因子 $c$ | 召回率 $\approx 1 - e^{-c}$, $c = 4$ 取 98% |
| 是否 broadcast | KV 分片场景需要,集中场景不需 |
| 拓扑实现 | Ring → recursive halving, FC → All-to-One + Broadcast |
| 实现关键 | 全局索引 + tie-breaking 字典序 |
| 工程权衡 | 稀疏 attention 本身近似,牺牲 2% 召回换 20× 通信省合理 |
@tbl-cc-topk-takeaway Distributed Top-K 要点
参考资料
- DeepSeek-V3.2 / V4 技术报告 — DSA / CSA 稀疏注意力机制
- FlashInfer — device-side top-k kernel
- 下一节:4.14 不规则与流水化 AllToAll — sparse / wave-scheduled AllToAll