Building an LLM Inference Engine from ScratchPart I / Module 3
Part I · Foundations

Module 3:自回归生成、KV Cache 与采样

理解一个 LLM 如何从 prompt 开始逐 token 生成文本,以及为什么 KV cache、采样策略和吞吐指标决定推理系统的体验。

学习目标

  • 区分 prefill 与 decode,并解释它们的计算形态为什么不同。
  • 推导没有 KV cache 时的重复计算成本,以及 KV cache 如何用显存换计算。
  • 理解 greedy、temperature、top-k、top-p 采样的数学含义与行为差异。
  • 理解 tokenizer、chat template、special token / EOS token 在生成循环中的位置。
  • 能定义并测量 prefill tok/s、decode tok/s、latency 与 throughput。

3.1自回归生成:一次只预测一个 token

decoder-only LLM 采用 autoregressive generation(自回归生成):给定 prompt tokens,模型输出最后一个位置的 logits;采样器选出下一个 token;该 token 追加到上下文末尾后,模型继续预测下一步。循环在生成 EOS 或达到 max_new_tokens 时停止。xinfer-runtime::generate 按这一控制流实现 prefill 与逐 token decode。

xt+1Pθ(x1,,xt)x_{t+1} \sim P_\theta(\cdot \mid x_1,\ldots,x_t)

如果使用 greedy decoding,符号 \sim 可以改写成取最大值:

xt+1=argmaxizix_{t+1} = \arg\max_i z_i

其中 ziz_i 是最后一行 logits 中第 ii 个 token 的分数。Qwen2.5-0.5B 的词表大小为 151936,因此这一行 logits 含 151936 个分量。

Prompt tokens [x1, x2, ..., xt] Model forward last logits z Sampling choose xt+1 Append [x1..xt, xt+1] 循环直到 EOS 或 max_new_tokens
图 3-1:自回归生成循环:forward → sampling → append → 再 forward。

3.2Prefill 与 Decode:同一个模型,两种工作负载

推理过程通常分为 prefill 与 decode 两个阶段。二者调用同一组 decoder layer,但输入长度和系统瓶颈不同:

阶段输入目标主要工作负载
Prefill整个 prompt,长度 SS建立上下文表示与 KV cache较大的 GEMM,attention 处理 S\times S
Decode每次 1 个新 token生成下一个 tokenGEMV + 读 KV cache + LM head

prefill 一次处理整个 prompt,并把每层新产生的 K/V 写入 cache。decode 每步只输入刚生成的 1 个 token;第 t+2t+2 个 token 的输入依赖第 t+1t+1 个 token 的采样结果,因此不能直接并行生成多个未来 token。

性能报告需要分开给出两个吞吐量:

prefill tok/s=prompt tokensprefill seconds,decode tok/s=generated tokensdecode seconds\text{prefill tok/s} = \frac{\text{prompt tokens}}{\text{prefill seconds}}, \qquad \text{decode tok/s} = \frac{\text{generated tokens}}{\text{decode seconds}}
图 3-2:Prefill 通常吞吐更高;decode 受自回归依赖限制,延迟更关键。

3.3为什么需要 KV cache?

每个 decoder layer 都会由 hidden states 计算 key 和 value。设已有 tt 个历史 token,下一步只新增 1 个 token。若没有 KV cache,每一步都必须把完整上下文重新送入所有层,历史 token 的 Q/K/V projection 和 attention 输入会被反复计算。

使用 KV cache 后,模型只为新 token 计算一次 K,VK,V,并把结果追加到每层缓存的下一个槽位。xinfer-model::run_forward 在 RoPE 后通过 copy_region 将新的 K/V 写入 cache.k[li]cache.v[li]

Kcache()[t]Knew(),Vcache()[t]Vnew()K_{\text{cache}}^{(\ell)}[t] \leftarrow K_{\text{new}}^{(\ell)},\qquad V_{\text{cache}}^{(\ell)}[t] \leftarrow V_{\text{new}}^{(\ell)}

计算 attention 时,新 query 读取缓存中已经写入的 keys/values。decode 单步的典型形状为 Q:[1,14,64]Q:[1,14,64]K,V:[T,2,64]K,V:[T,2,64],输出再展开为 [1,896][1,896]

Attn(qt,Kt,Vt)=softmax(qtKtd)Vt\operatorname{Attn}(q_t,K_{\le t},V_{\le t}) = \operatorname{softmax}\left(\frac{q_tK_{\le t}^{\top}}{\sqrt{d}}\right)V_{\le t}
KV cache 的本质

KV cache 用显存保存过去 token 的 K/V,避免在每个 decode step 重新计算历史 token 的 K/V projection。只要位置、causal mask 与数值精度一致,它不改变模型的数学输出,只改变计算路径与内存占用。

每一层都有自己的 K cache 与 V cache K cache k0 k1 k2 k_new free V cache v0 v1 v2 v_new free new query q attend to slots 0..t append 只写一个新槽位;attention 读取所有已填槽位。
图 3-3:KV cache 的追加与读取:每层各有一份 K/V 缓存。

KV cache 的显存公式

对本项目的 f32 KV cache,容量按下式计算:

bytesKV=2LTnkvdheadbytes(dtype)\operatorname{bytes}_{KV} = 2 \cdot L \cdot T \cdot n_{\text{kv}} \cdot d_{\text{head}} \cdot \operatorname{bytes}(\text{dtype})

其中 2 表示 K 和 V,LL 是层数,TT 是 cache capacity,nkvn_{\text{kv}} 是 KV head 数。Qwen2.5-0.5B 取 L=24L=24nkv=2n_{\text{kv}}=2dhead=64d_{\text{head}}=64、dtype=f32。若 prompt 为 26 tokens、max_new_tokens=64,CLI 分配的容量为 T=90T=90,KV cache 约 224902644=2,211,8402\cdot24\cdot90\cdot2\cdot64\cdot4=2{,}211{,}840 bytes,即约 2.1 MiB。若没有 GQA 而使用 14 个 KV head,该项会扩大为 7 倍。

3.4采样:从 logits 到下一个 token

模型输出 logits zz 后,采样器把它转换为下一个 token id。xinfer-runtime::sampling 支持 greedy、temperature、top-k 与 top-p。默认配置为 temperature=0.0top_k=0top_p=1.0,即 greedy decoding。

Greedy / Argmax

xt+1=argmaxizix_{t+1} = \arg\max_i z_i

greedy 每步选择最大 logit 对应的 token,输出确定且可复现。本项目对 greedy 走 decode_argmax 快路径:shaders/argmax.hlsl 在 GPU 上归约最后一行 logits,只读回一个 u32 token id。若读回完整 f32 logits,Qwen2.5-0.5B 每步需要传输约 151936×4608151936\times4\approx608 KB。

Temperature

temperature > 0 时,采样器先按温度缩放 logits:

pi=softmax(ziτ)p_i = \operatorname{softmax}\left(\frac{z_i}{\tau}\right)

τ\tau 越小,概率分布越集中;τ\tau 越大,低 logit token 获得更高相对概率。代码中 temperature <= 0.0 被视为 greedy,不执行随机采样。

Top-k 与 Top-p

Top-k 在排序后只保留 logit 最高的 kk 个候选;top_k=0 表示不启用该截断。Top-p(nucleus sampling)在 softmax 后按概率从高到低累加,只保留累积概率首次达到 pp 的前缀;top_p>=1.0 表示不启用该截断:

Sp=min{S:iSpip}\mathcal{S}_p=\min\left\{S:\sum_{i\in S}p_i\ge p\right\}
图 3-4:temperature 改变 softmax 分布的尖锐程度;top-k/top-p 再截断候选集合。

3.5Tokenizer、chat template 与 special tokens

模型输入输出均为 token id,而不是 Unicode 字符串。xinfer-tokenizer 封装 HuggingFace tokenizers,从模型目录加载 tokenizer.json,提供 encodedecodetoken_to_id

text  tokenizer  [x1,,xS][y1,,yT]  decode  text\text{text}\;\xrightarrow{\text{tokenizer}}\;[x_1,\ldots,x_S] \qquad [y_1,\ldots,y_T]\;\xrightarrow{\text{decode}}\;\text{text}

对于 instruct 模型,裸 user prompt 通常需要先套用 chat template。qwen_chat_prompt(system, user) 在 CLI 默认启用,生成的格式为:

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is the capital of France?<|im_end|>
<|im_start|>assistant

EOS token 表示模型认为当前回答已经结束。read_eos_token 优先读取 generation_config.json;Qwen2.5-0.5B-Instruct 的 generation config 中 EOS 列表为 151645 与 151643,当前实现取列表中的第一个 id。生成循环在以下条件之一成立时停止:

  • 生成了 EOS token;
  • 达到 max_new_tokens
  • 外部系统要求中断(例如 UI 停止按钮)。

3.6吞吐、延迟与用户体验

推理性能不能只用单个平均值描述。至少需要区分首 token 前的等待、持续生成速度、显存占用与服务端吞吐:

指标含义用户感知
Prefill tok/s处理 prompt 的速度首 token 等待时间
Decode tok/s持续生成 token 的速度回答流畅度
Latency单次请求耗时是否“卡顿”
Throughput单位时间处理多少请求或 token服务端成本
Memory权重 + KV cache + 中间 buffer能否运行更大模型/更长上下文

本项目的示例运行报告为 prefill: 26 tok in 69 ms (378 tok/s) | decode: 7 tok in 0.23 s (30.5 tok/s)。decode 的关键优化是 greedy 模式下避免每 token 将完整 vocab logits 读回 CPU;temperature、top-k、top-p 路径仍调用 model.decode 读回 logits,并在 CPU 上执行采样。

图 3-5:同一模型、同一任务,系统优化会显著改变 decode 体验。

小结

自回归生成把一次回答拆成 prefill 和 decode。prefill 处理完整 prompt 并建立 KV cache;decode 每步处理 1 个新 token,从 cache 中读取历史 K/V,并将新 K/V 追加到每层缓存。Qwen2.5-0.5B 的 logits 长度为 151936;greedy 模式可用 GPU argmax 只读回 token id,而 temperature、top-k、top-p 仍需把 logits 读回 CPU 后采样。tokenizer 与 chat template 决定模型实际看到的 token 序列,EOS 与 max_new_tokens 决定生成何时结束。

Lab 3纯 CPU 生成循环:KV cache + samplers

本实验不需要 GPU。用 tiny 模型或假 logits 实现完整生成循环,重点检查 prefill、cache 追加、EOS 判断与采样器调用顺序。

def generate(prompt_tokens, max_new_tokens):
    cache = empty_kv_cache()
    logits = model_prefill(prompt_tokens, cache)
    next_token = sample(logits[-1])

    output = []
    for _ in range(max_new_tokens):
        if next_token == EOS:
            break
        output.append(next_token)
        logits = model_decode_one(next_token, cache)
        next_token = sample(logits[-1])
    return output

思考与练习

基础解释 prefill 和 decode 的区别。为什么 decode 不能一次并行生成 10 个未来 token?

prefill 一次性处理整个 prompt(多个 token 并行),建立上下文表示和 KV cache;decode 每次只处理一个新 token。decode 不能一次并行生成 10 个未来 token,是因为自回归依赖:第 t+2t+2 个 token 的输入依赖第 t+1t+1 个 token 的采样结果,而后者必须先生成。未来 token 之间存在串行因果链;在第一个 token 尚未采样时,后续 token 的输入并不存在。(speculative decoding 是用小模型“猜”多个 token 再验证,属于另一种思路。)

基础如果 prompt 有 128 个 token,生成 32 个 token,KV cache 最终 len 是多少?

128+32=160128+32=160。prefill 写入 128 个槽位,随后每 decode 一个 token 追加一个槽位,共 32 个。注意为避免溢出,cache capacity 至少要预留 160。

进阶推导 KV cache 显存公式,并代入 L=24,T=1024,nkv=2,d=64,dtype=f32L=24,T=1024,n_{\text{kv}}=2,d=64,\text{dtype}=f32

每层每个 token 要存 K 和 V,各 nkv×dn_{\text{kv}}\times d 个元素,所以:bytes=2LTnkvdbytes(dtype)\text{bytes}=2\cdot L\cdot T\cdot n_{\text{kv}}\cdot d\cdot \text{bytes(dtype)}

代入:2×24×1024×2×64×4=2×24×1024×512=25,165,8242\times24\times1024\times2\times64\times4=2\times24\times1024\times512=25{,}165{,}824 字节 24\approx 24 MiB。可见上下文 TT 越长,KV cache 线性增长;GQA 让 nkv=2n_{\text{kv}}=2 而非 14,已经把它压到 1/7。

进阶给定 logits [1,2,3][1,2,3],分别计算 temperature 1.01.00.50.5 下 softmax 分布的变化趋势。

τ=1.0\tau=1.0:softmax([1,2,3]) ≈ [0.090, 0.245, 0.665]。

τ=0.5\tau=0.5:先把 logits 除以 0.5 得到 [2,4,6],softmax([2,4,6]) ≈ [0.016, 0.117, 0.867]。

趋势:温度变小,分布变得更尖锐,最大 logit 对应 token 的概率更突出(0.665 → 0.867),低概率 token 被进一步压制。温度越大则相反,分布更平、更随机。

挑战实现 top-p sampling,并写一个测试证明保留集合的累积概率至少为 pp

实现步骤:① 对概率从高到低排序;② 累加,直到累积和首次 p\ge p,取到该位置为止的最小集合 Sp\mathcal{S}_p;③ 在 Sp\mathcal{S}_p 内重新归一化后按概率采样。

测试思路:构造若干随机分布,对每个分布运行 top-p,断言 sum(probs in S_p) >= p - 1e-6(留浮点容差)。由于算法是“累加到首次达到 pp 才停止”,集合累积概率必然 p\ge p,同时它是满足该条件的最小集合(去掉最后一个元素就 <p<p)。