跳到主要内容

长上下文的第一性挑战

核心要点

  • Attention 计算 $O(n^2)$:prefill 算力随序列长度二次增长,是 TTFT 恶化的主因
  • KVCache 显存 $O(n)$:随序列长度线性增长,长上下文下首先撞单卡 HBM 上限
  • 位置编码外推困难:超出训练长度的位置模型未见过,效果退化
  • 中段遗忘 (Lost in the Middle):模型对上下文首尾敏感、中段不敏感,长上下文中段信息利用率低
  • 四个挑战分别从算力 / 显存 / 表征 / 利用率四个维度限制可用上下文长度

在讨论"怎么实现长上下文"之前,先把"为什么长上下文难"讲清楚。本文不涉及具体优化方案(那些方案分散在第 03-08 章),只立 4 个第一性问题,后续每章都对应解一类。

挑战 1:Attention 计算量随序列长度 $O(n^2)$ 增长

来源:QK 矩阵的全对比

原版 Transformer[1] 的 attention 计算(单头、batch 1)为:

$$\begin{equation} \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V \label{eq:lc-ch-attention} \end{equation}$$

其中 $Q, K, V \in \mathbb{R}^{n \times d}$$n$ 是序列长度,$d$ 是头维度。三步关键操作的复杂度见

操作形状计算量显存
$S = Q K^\top$$n \times n$$O(n^2 d)$$O(n^2)$
$P = \text{softmax}(S / \sqrt{d_k})$$n \times n$$O(n^2)$$O(n^2)$
$O = P V$$n \times d$$O(n^2 d)$$O(n^2)$

@tbl-longctx-challenge-attn-complexity 原版 attention 三步操作的算力与显存复杂度

主导项 $O(n^2 d)$ 来自 $QK^\top$$PV$序列翻倍 → 算力翻 4 倍。这是长上下文 prefill 阶段算力压力的根本来源。

Prefill 与 Decode 的非对称性

阶段一次性处理单步计算量累计计算量算力 / 访存特征
Prefill整段 prompt 共 $n$ 个 token$O(n^2 d)$$O(n^2 d)$计算密集,FLOP/Byte 高
Decode单 token$O(n d)$输出 $m$ token 后 $O(n m d)$访存密集,FLOP/Byte 低

@tbl-longctx-challenge-prefill-decode prefill 与 decode 的算力特征

关键观察

  • Prefill 算力随 $n^2$:TTFT 随上下文长度恶化。1M token prompt 的 prefill 在 H100 上要数十秒级。
  • Decode 每步 $O(nd)$:每生成一个 token 都要扫一遍 KVCache,TPOT 随 $n$ 线性涨,但绝对延迟由显存带宽决定。
  • 两阶段算术强度差异大:roofline 上 prefill 在算力 roof 附近、decode 在带宽 roof 附近。优化手段不能"一招通吃"。

FlashAttention 的边界

FlashAttention[2] 等 IO-aware 实现消除了显存上$O(n^2)$ —— 不再物化 $n \times n$ 的 attention 矩阵,而是按块 (tile) 流式计算并融合 softmax。但计算量仍是 $O(n^2 d)$:FLOPs 没变,只是从 HBM 搬到 SRAM 算,访存压力降到 $O(n)$

所以 FlashAttention 让长上下文"放得下"(显存)但没让它"算得快"(算力)。算力 $O(n^2)$ 的本质削减需要注意力机制变体(→ 04-注意力机制变体)。

挑战 2:KVCache 显存随序列长度 $O(n)$ 增长

KVCache 占用公式

每生成一个 token 都要存它的 K 和 V,供后续 decode 重用。单 batch 单序列的 KVCache 总占用:

$$\begin{equation} M_{\text{KV}} = 2 \cdot L \cdot H_{\text{kv}} \cdot d_h \cdot n \cdot b \label{eq:lc-ch-kv-memory} \end{equation}$$
符号含义
$L$层数
$H_{\text{kv}}$KV head 数(GQA 下小于 query head 数)
$d_h$头维度
$n$序列长度
$b$bytes per value(FP16/BF16 = 2,FP8 = 1)
系数 2K 和 V 两份

@tbl-longctx-challenge-kv-formula KVCache 占用计算

数量级直观

以 LLaMA 3 70B($L=80, H_{\text{kv}}=8, d_h=128$,BF16 $b=2$)为例。序列长度按 HPC 习惯按 2 的幂展开($1\text{K} = 2^{10}$$1\text{M} = 2^{20}$):

序列长度单 batch KV (GiB)单卡 H100 80GiB 装得下吗
8K2.5轻松
32K10.0可以
128K40.0紧张(模型权重已占 ~35GiB / 卡,TP=4 时)
1M320单卡放不下,必须切分

@tbl-longctx-challenge-kv-size LLaMA 3 70B 在不同序列长度下的 KVCache 占用

表中"装得下吗"基于 TP=4 部署:模型权重约 35 GiB / 卡,激活与 framework 开销另需 ~10 GiB,留给 KV 的空间约 35 GiB。

关键观察

  • KV 占用是模型参数量级以外的独立显存压力,且 batch 大时还要乘 batch。
  • 1M token 单 batch 已 300 GiB+,远超任何单 HBM。这是架构层 KV 压缩05-kv-cache架构压缩)和推理层 KV 管理07-推理-kv管理)必要性的根本来源。
  • 互联视角:单卡放不下 → 上下文并行(CP)把序列切到多卡,KV 跨卡分布 → attention 全局性带来跨卡通信。系统层面的详细分析见 上下文并行 (CP)KV 跨节点传输瓶颈

KV 是 Decode 阶段的隐性瓶颈

Decode 每生成一个 token,需要把全量 KV 从 HBM 读到 SRAM 算 attention。单 token TPOT ≈ KV 总量 / HBM 带宽

序列长度KV (GiB, TP=4 单卡)H100 带宽 3.35 TB/s 单 token 延迟
32K2.5~0.7 ms
128K10.0~3.0 ms
1M78~23 ms

@tbl-longctx-challenge-decode-bw KV 大小对 decode TPOT 的影响(理论下界)

数值仅按"扫一遍 KV"估算 attention 部分,不含 FFN 等其他算子,实际延迟更高。

关键观察:长上下文下 decode 的"每 token 延迟"由 KV 大小和 HBM 带宽决定,与算力无关。算力堆得再高也救不了 decode——这是长上下文部署偏好"大显存 + 高带宽"的根本原因,也是超节点架构的核心价值所在。

挑战 3:位置编码外推困难

训练长度限制

Transformer 的注意力本身是位置无关的(softmax 是集合运算),位置信息靠位置编码注入。常见做法(RoPE、ALiBi 等)在训练时给每个位置一个表示。

问题:训练时只在 $[0, n_{\text{train}})$ 区间见过这些位置编码。推理时若位置超出 $n_{\text{train}}$,模型从未学过这部分编码的语义,直接外推效果退化——困惑度(perplexity)爆炸、生成质量崩溃。

举例:LLaMA 2 训练长度 4K,直接用于 8K 上下文时 perplexity 上升数倍;用于 32K 时几乎不可用。

三类应对思路

思路代表方法训练成本
设计外推友好的编码ALiBi、xPos训练时已具备一定外推性
推理时调整编码(无训练)NTK-aware RoPE、Position Interpolation零成本,效果有限
短训长推(少量长数据微调)YaRN、LongRoPE千-万级 token 微调即可达 256K+

@tbl-longctx-challenge-extrapolation 位置编码外推的三类应对

详细对比在 03-位置编码与外推

关键观察:位置外推是模型能力问题,不是系统问题——单纯堆显存、堆算力解决不了。这是为什么"模型上下文长度"和"硬件能跑多长"是两个独立指标,前者由训练 / 外推决定,后者由显存 / 互联决定。

挑战 4:中段遗忘 (Lost in the Middle)

现象

Liu et al., 2023[3] 系统观测到:即使模型宣称支持很长上下文,对上下文中段的信息利用率低于首尾。一个典型实验:把答案藏在 N 个文档中的不同位置,正确率与位置呈 U 形——首尾高、中段低。

这意味着:

  • "宣称上下文长度"$\neq$"有效上下文长度"
  • 长上下文下若关键信息恰好落在中段,模型可能视而不见
  • 简单堆长度不能保证长上下文质量

与外推问题的区别

维度位置外推问题中段遗忘
触发条件推理位置超过训练长度推理位置在训练范围内但靠中段
失败模式perplexity 爆炸、乱码输出流畅但忽略关键信息
缓解手段YaRN / NTK / 长数据微调长数据训练 + 中段重点训练 + 评测覆盖中段
评测方式困惑度、长输出连贯性needle-in-a-haystack 多位置变体、RULER

@tbl-longctx-challenge-extrapolation-vs-lost 位置外推与中段遗忘的区别

详见 09-评测与现状 中评测体系对中段质量的覆盖。

四个挑战的归因总结

挑战受限维度主要缓解章节互联视角链接
Attention $O(n^2)$算力 (Prefill TTFT)04-注意力机制变体08-推理-调度 (chunked prefill)无(纯模型 / 计算层)
KVCache $O(n)$显存 / 带宽 (Decode TPOT)05-KV 架构压缩07-推理-KV 管理上下文并行KV 传输瓶颈
位置外推表征能力03-位置编码与外推06-训练侧无(纯模型层)
中段遗忘上下文利用率06-训练侧09-评测无(纯模型层)

@tbl-longctx-challenge-summary 四个第一性挑战的归因与对应章节

全局观察

  • 4 个挑战分布在 4 个不同维度(算力、显存、表征、利用率),任何单一手段都解决不了所有问题。
  • 算力与显存两个维度的解药同时涉及"模型层"(变体 / 压缩)和"系统层"(CP / KV 传输 / 超节点)。本章节其余文档讲模型层,互联视角链接到 docs/interconnect/
  • 表征与利用率两个维度纯模型层,与系统无关——但是长上下文最终效果的天花板,再快的系统也救不了"看到了但不会用"。

Takeaway

知识点核心结论
Attention $O(n^2)$prefill 算力随序列长度二次增长,序列翻倍算力翻 4 倍,是 TTFT 恶化主因,解药在 04-注意力机制变体
KVCache $O(n)$显存随序列长度线性增长,1M token 单 batch 已 300 GiB+,并制约 decode TPOT,解药在 05 / 07
位置外推超训练长度的位置模型未学过,外推困惑度爆炸,属模型能力问题,解药在 03-位置编码与外推
中段遗忘中段信息利用率低于首尾,宣称长度不等于有效长度,解药在 06-训练侧 / 09-评测
四维度归因算力 / 显存 / 表征 / 利用率四个独立维度,无单一手段通解

@tbl-longctx-challenge-takeaway 全文要点

参考资料

  1. Vaswani et al., Attention Is All You Need, 2017. https://arxiv.org/abs/1706.03762
  2. Dao et al., FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, 2022. https://arxiv.org/abs/2205.14135
  3. Liu et al., Lost in the Middle: How Language Models Use Long Contexts, 2023. https://arxiv.org/abs/2307.03172

延伸阅读