Building an LLM Inference Engine from ScratchPart IV / Module 13
Part IV · Building xinfer

Module 13:Phase 4 — Runtime:KV Cache、Decode 与 Sampling

把完整 forward 接入推理 runtime:per-layer KV cache、prefill/decode 循环、采样器、tokenizer 与真实 Qwen2.5 端到端生成。

学习目标

  • 实现 per-layer KV cache,并理解每个 layer 都有自己的 K/V 缓存。
  • 理解 CopyBufferRegion 如何把新 token 的 K/V 追加到 cache。
  • 把 attention kernel 泛化到 seq_q != seq_k,支持 prefill 与单步 decode。
  • 实现 prefill → autoregressive decode loop。
  • 实现 greedy、temperature、top-k、top-p 采样和可复现 RNG。
  • 集成 HuggingFace tokenizer、Qwen chat template 与 EOS token。

13.1Per-layer KV cache:每层一份 K/V

Phase 3 的 forward 已经能对一段 token 序列输出 logits。若每生成一个 token 都重新计算完整历史,前面 token 的 key/value 会被反复生成。Phase 4 引入 KV cache:每个 decoder layer 各自保存历史 token 的 key 和 value,decode 时只计算新 token 的 Q/K/V,并让它读取已有缓存。

对第 \ell 层,cache 结构可以写成:

Kcache()RT×nkv×dhead,Vcache()RT×nkv×dheadK^{(\ell)}_{\text{cache}}\in\mathbb{R}^{T\times n_{\text{kv}}\times d_{\text{head}}}, \qquad V^{(\ell)}_{\text{cache}}\in\mathbb{R}^{T\times n_{\text{kv}}\times d_{\text{head}}}

其中 TT 是 cache capacity。QwenModel::make_cache 为每一层各分配一个 K buffer 与一个 V buffer,长度为 capacity * num_key_value_heads * head_dim。当新 token 进入第 \ell 层时,模型计算 Knew,VnewK_{\text{new}},V_{\text{new}},再写入从 pos_base 开始的槽位:

Kcache()[pos]Knew(),Vcache()[pos]Vnew()K^{(\ell)}_{\text{cache}}[\text{pos}] \leftarrow K^{(\ell)}_{\text{new}}, \qquad V^{(\ell)}_{\text{cache}}[\text{pos}] \leftarrow V^{(\ell)}_{\text{new}}

run_forward 中,这个追加由 exec.copy_region 记录为 D3D12 CopyBufferRegion:目标偏移为 pos_base * row_bytes,拷贝长度为 seq * row_bytes。新 K/V 与 cache 中的目标范围都是连续 GPU buffer 区间,因此无需把 K/V 下载到 CPU。

每个 decoder layer 都有自己的 K cache 与 V cache Layer ℓ K cache k0 k1 k2 new free Layer ℓ V cache v0 v1 v2 new free CopyBufferRegion dst offset = pos * row_bytes bytes = seq * row_bytes append 是纯 GPU buffer-to-buffer copy;不需要把 K/V 读回 CPU。
图 13-1:KV cache append:新 token 的 K/V 被拷贝到当前 cache 位置。

13.2泛化 attention:query 长度与 key 长度可以不同

无 KV cache 的 full forward 常可假设 seqq=seqk=Sseq_q=seq_k=S。加入缓存后,query 来自当前输入 token,而 key/value 来自“历史缓存 + 当前输入”,两者长度不再相同。单步 decode 的典型形态为:

seqq=1,seqk=past length+1seq_q = 1,\qquad seq_k = \text{past length}+1

因此 attention kernel 必须接受两个长度:seq_qseq_k。对 query row tqt_q,它的绝对位置是 q_pos_base+tqq\_\text{pos\_base}+t_q,可见 key 数为:

key_count=min(q_pos_base+tq+1,  seqk)\text{key\_count}=\min(q\_\text{pos\_base}+t_q+1,\;seq_k)

这个公式同时支持 prefill 和 decode:prefill 时 q_pos_base=0q\_\text{pos\_base}=0,decode 时它等于 cache 中已有 token 数。

图 13-2:prefill 与 decode 的 query/key 长度不同。

13.3Prefill → decode loop

xinfer-runtime::generate 将模型、tokenizer、KV cache 与 sampler 连接为一个生成循环。实际实现用 GenerateOptions { max_new_tokens, sampling, seed, eos_token_id } 控制生成长度、采样策略、随机种子与停止条件。流程可写成以下伪代码:

fn generate(prompt_tokens, max_new_tokens):
    cache = model.make_cache(prompt_len + max_new_tokens)

    // 1. prefill:一次处理整个 prompt
    logits_or_argmax = model.decode(cache, prompt_tokens, pos_base=0)
    next = sample(last_logits)

    // 2. decode:一次追加一个 token
    output = []
    for step in 0..max_new_tokens:
        if next == eos: break
        output.push(next)
        logits_or_argmax = model.decode(cache, [next], pos_base=cache.len)
        next = sample(last_logits)

    return tokenizer.decode(output)

工程中的 model.decode 表示“将 tokens 追加到 KV cache 后执行一次 decoder forward”。prefill 调用输入多个 prompt token;decode loop 每次输入一个新 token。pos_base 必须等于当前 cache.len(),否则位置编码和 cache 写入都会失去一致性。

Prompt tokenslength = S Prefillfill KV cache Sample nexttoken y0 Decode oneappend K/V Sample nexttoken yi Prefill 建立上下文;decode 每轮只处理新 token,但会读取越来越长的 KV cache。
图 13-3:runtime 先 prefill,再进入单 token decode 循环。

13.4Sampling:Rust 中的可复现采样器

采样器把最后一行 logits 转换为下一个 token id。SamplingConfig 只有三个字段:temperaturetop_ktop_ptemperature <= 0.0 选择 greedy。Phase 4 实现的策略包括:

  • **Greedy:**取最大 logit;
  • **Temperature:**缩放 logits;
  • **Top-k:**只保留最高的 kk 个候选;
  • **Top-p:**只保留累积概率达到 pp 的最小候选集合;
  • **可复现 RNG:**使用简单 xorshift,使实验结果可重复。

temperature 公式:

pi=softmax(ziτ)p_i = \operatorname{softmax}\left(\frac{z_i}{\tau}\right)

top-p 集合:

Sp=min{S:iSpip}\mathcal{S}_p = \min\left\{S:\sum_{i\in S}p_i\ge p\right\}
工程优化

Greedy 路径使用 QwenModel::decode_argmax:GPU 对最后一行 logits 做 argmax,CPU 只读回一个 u32 token id。Qwen2.5 的词表为 151936,完整 f32 logits 约 608 KB/token;该 fast path 避免了这次 readback。temperature、top-k、top-p 仍调用 decode 读回完整 logits,并在 CPU 上采样。

13.5Tokenizer 与 Qwen chat template

Runtime 的输入输出边界是文本,模型内部处理的是 token id。xinfer-tokenizer 封装 HuggingFace tokenizers,从模型目录读取 tokenizer.json,并提供 encode/decode:

let tokenizer = Tokenizer::from_dir(model_dir)?;
let ids = tokenizer.encode(prompt_text, false)?;
let text = tokenizer.decode(&generated_ids, true)?;

对 instruct 模型,CLI 默认使用 qwen_chat_prompt(system, user) 包装 system、user 与 assistant 起始标记:

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
... user prompt ...<|im_end|>
<|im_start|>assistant

read_eos_tokengeneration_config.jsonconfig.json 读取 eos_token_id。生成循环在采样到 EOS 时停止,并在统计中标记 hit_eos

13.6第一次真实端到端 demo

Phase 4 的验收对象是真实 Qwen2.5 checkpoint 的端到端路径:加载配置与 safetensors,构建 tokenizer 和 GPU model,分配 KV cache,执行 prefill/decode,并在 CLI 中流式输出生成文本。

xinfer generate --model models\Qwen2.5-0.5B-Instruct \
  --prompt "What is the capital of France?" --max-tokens 16

The capital of France is Paris.

一次实际运行还会打印加载时间、GPU resident weights、prompt token 数、KV cache 大小以及 prefill/decode 吞吐。例如 Qwen2.5-0.5B 的一次记录为:Loaded in 4.0s (942 MiB GPU weights)Prompt tokens: 26 (KV cache 2 MiB)prefill: 26 tok in 69 ms (378 tok/s) | decode: 7 tok in 0.23 s (30.5 tok/s) [eos]。从系统角度看,命令经过如下路径:

CLIprompt Tokenizerids Runtimeprefill/decode QwenModelKV cache DML/HLSLGPU kernels decode textstream output
图 13-4:第一次端到端 demo 证明所有模块已经连成完整系统。

Lab 13生成文本,并验证 incremental decode == full forward

本实验包含两个验收目标。第一,真实模型能完成端到端生成。第二,KV cache incremental decode 必须与 full forward 的最后一行 logits 匹配;该测试直接覆盖 cache 追加、RoPE 位置偏移和 causal mask。

# 真实模型生成
cargo run --release -p xinfer-cli -- generate \
  --model models\Qwen2.5-0.5B-Instruct \
  --prompt "What is the capital of France?" \
  --max-tokens 16

# KV cache correctness
cargo test -p xinfer-model --test forward_parity -- --nocapture

正确性测试的数学意义是:

logitsfull[S1]logitsprefill+decode[0]\operatorname{logits}_{\text{full}}[S-1] \approx \operatorname{logits}_{\text{prefill+decode}}[0]

也就是说,先一次性 forward 全部 tokens,和先 prefill 前 S1S-1 个 token 再 decode 最后一个 token,最后位置的 logits 应该一致。

小结

本章把 Phase 3 的完整 forward 扩展为可生成文本的 runtime。KvCache 为每层保存 K/V;attention kernel 接受不同的 seq_qseq_kgenerate 使用 GenerateOptionsSamplingConfig 管理 prefill、decode、采样与 EOS;greedy 路径通过 GPU argmax 只读回一个 token id;tokenizer 与 Qwen chat template 构成文本输入输出边界。至此,引擎已经能够对真实 Qwen2.5 checkpoint 执行端到端推理。

思考与练习

基础为什么每一层都需要自己的 KV cache?

因为每层有各自独立的 K/V 投影权重,产生不同的 key/value 表示。第 \ell 层的 attention 只能用第 \ell 层历史 token 的 K/V,不能跨层共用。所以要为 24 层各保存一份 KV cache,decode 时各层把新 token 的 K/V 追加到对应层的缓存里。

基础prefill 与 decode 调用的是同一个 model.decode,为什么输入 shape 不同?

prefill 一次喂入整个 prompt,seq_q = prompt 长度(如 128),同时把这 128 个 K/V 写入空缓存(seq_k 也是 128)。decode 每步只喂 1 个新 token,seq_q=1,而 seq_k = 已有历史长度+1。同一函数通过 seq_q≠seq_kq_pos_base 两个参数统一处理两种情形,attention kernel 据此设置 causal 上限。

进阶推导 key_count = min(q_pos_base + tq + 1, seq_k) 的含义。

query 在序列中的绝对位置是 p=q_pos_base+tqp=q\_pos\_base+t_qtq 是本批内的相对索引)。causal 规则下它能看见位置 0..p0..p,即 p+1p+1 个 key,故可见数为 q_pos_base+tq+1q\_pos\_base+t_q+1。再对 seq_k(缓存中实际存在的 key 数)取 min,防止越界。decode 时 tq=0t_q=0q_pos_baseq\_pos\_base = 历史长度,于是 key_count = 历史长度+1 = 全部缓存,符合“新 token 能看到所有过去 + 自己”。

进阶解释 greedy fast path 为什么可以只读回一个 u32

greedy 解码只需要 argmax(logits) 这一个 token id,不需要整条概率分布。于是在 GPU 上做归约求出最大 logit 的下标,只把这个 4 字节的 u32 拷回 CPU,而不是把 151936 个 f32(约 608 KB)全读回再在 CPU 上找 max。这把每 token 的 readback 从约 608 KB 降到 4 字节,是 greedy decode 达到约 30.5 tok/s 的关键优化之一。temperature、top-k、top-p 采样仍需概率分布信息,因此保留完整 logits readback 与 CPU 采样路径。

挑战实现一个测试:随机 prompt 下 full forward 最后一行 logits 与 incremental decode logits 的 max abs diff 小于阈值。

步骤:① 随机生成一段 token 序列作为 prompt;② 路径 A:对完整序列做一次 full forward,取最后一个位置的 logits;③ 路径 B:先 prefill 前 n1n-1 个 token 建立 KV cache,再对第 nn 个 token 调 decode 拿 logits;④ 断言 max(abs(A - B)) < eps

理论上 incremental decode 与 full forward 在数学上等价,f32 路径 diff 应为 0(xinfer 的 kv_cache_decode_matches_full 实测 diff=0);若用 f16 kernel 则放宽到 ~1e-2。这个测试守护 KV cache 追加、位置编码偏移与 causal mask 的正确性——它们是 decode 最易出错的环节。