Module 13:Phase 4 — Runtime:KV Cache、Decode 与 Sampling
把完整 forward 接入推理 runtime:per-layer KV cache、prefill/decode 循环、采样器、tokenizer 与真实 Qwen2.5 端到端生成。
学习目标
- 实现 per-layer KV cache,并理解每个 layer 都有自己的 K/V 缓存。
- 理解
CopyBufferRegion如何把新 token 的 K/V 追加到 cache。 - 把 attention kernel 泛化到
seq_q != seq_k,支持 prefill 与单步 decode。 - 实现 prefill → autoregressive decode loop。
- 实现 greedy、temperature、top-k、top-p 采样和可复现 RNG。
- 集成 HuggingFace tokenizer、Qwen chat template 与 EOS token。
13.1Per-layer KV cache:每层一份 K/V
Phase 3 的 forward 已经能对一段 token 序列输出 logits。若每生成一个 token 都重新计算完整历史,前面 token 的 key/value 会被反复生成。Phase 4 引入 KV cache:每个 decoder layer 各自保存历史 token 的 key 和 value,decode 时只计算新 token 的 Q/K/V,并让它读取已有缓存。
对第 层,cache 结构可以写成:
其中 是 cache capacity。QwenModel::make_cache 为每一层各分配一个 K buffer 与一个 V buffer,长度为 capacity * num_key_value_heads * head_dim。当新 token 进入第 层时,模型计算 ,再写入从 pos_base 开始的槽位:
在 run_forward 中,这个追加由 exec.copy_region 记录为 D3D12 CopyBufferRegion:目标偏移为 pos_base * row_bytes,拷贝长度为 seq * row_bytes。新 K/V 与 cache 中的目标范围都是连续 GPU buffer 区间,因此无需把 K/V 下载到 CPU。
13.2泛化 attention:query 长度与 key 长度可以不同
无 KV cache 的 full forward 常可假设 。加入缓存后,query 来自当前输入 token,而 key/value 来自“历史缓存 + 当前输入”,两者长度不再相同。单步 decode 的典型形态为:
因此 attention kernel 必须接受两个长度:seq_q 与 seq_k。对 query row ,它的绝对位置是 ,可见 key 数为:
这个公式同时支持 prefill 和 decode:prefill 时 ,decode 时它等于 cache 中已有 token 数。
13.3Prefill → decode loop
xinfer-runtime::generate 将模型、tokenizer、KV cache 与 sampler 连接为一个生成循环。实际实现用 GenerateOptions { max_new_tokens, sampling, seed, eos_token_id } 控制生成长度、采样策略、随机种子与停止条件。流程可写成以下伪代码:
fn generate(prompt_tokens, max_new_tokens):
cache = model.make_cache(prompt_len + max_new_tokens)
// 1. prefill:一次处理整个 prompt
logits_or_argmax = model.decode(cache, prompt_tokens, pos_base=0)
next = sample(last_logits)
// 2. decode:一次追加一个 token
output = []
for step in 0..max_new_tokens:
if next == eos: break
output.push(next)
logits_or_argmax = model.decode(cache, [next], pos_base=cache.len)
next = sample(last_logits)
return tokenizer.decode(output)
工程中的 model.decode 表示“将 tokens 追加到 KV cache 后执行一次 decoder forward”。prefill 调用输入多个 prompt token;decode loop 每次输入一个新 token。pos_base 必须等于当前 cache.len(),否则位置编码和 cache 写入都会失去一致性。
13.4Sampling:Rust 中的可复现采样器
采样器把最后一行 logits 转换为下一个 token id。SamplingConfig 只有三个字段:temperature、top_k、top_p;temperature <= 0.0 选择 greedy。Phase 4 实现的策略包括:
- **Greedy:**取最大 logit;
- **Temperature:**缩放 logits;
- **Top-k:**只保留最高的 个候选;
- **Top-p:**只保留累积概率达到 的最小候选集合;
- **可复现 RNG:**使用简单 xorshift,使实验结果可重复。
temperature 公式:
top-p 集合:
Greedy 路径使用 QwenModel::decode_argmax:GPU 对最后一行 logits 做 argmax,CPU 只读回一个 u32 token id。Qwen2.5 的词表为 151936,完整 f32 logits 约 608 KB/token;该 fast path 避免了这次 readback。temperature、top-k、top-p 仍调用 decode 读回完整 logits,并在 CPU 上采样。
13.5Tokenizer 与 Qwen chat template
Runtime 的输入输出边界是文本,模型内部处理的是 token id。xinfer-tokenizer 封装 HuggingFace tokenizers,从模型目录读取 tokenizer.json,并提供 encode/decode:
let tokenizer = Tokenizer::from_dir(model_dir)?;
let ids = tokenizer.encode(prompt_text, false)?;
let text = tokenizer.decode(&generated_ids, true)?;
对 instruct 模型,CLI 默认使用 qwen_chat_prompt(system, user) 包装 system、user 与 assistant 起始标记:
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
... user prompt ...<|im_end|>
<|im_start|>assistant
read_eos_token 从 generation_config.json 或 config.json 读取 eos_token_id。生成循环在采样到 EOS 时停止,并在统计中标记 hit_eos。
13.6第一次真实端到端 demo
Phase 4 的验收对象是真实 Qwen2.5 checkpoint 的端到端路径:加载配置与 safetensors,构建 tokenizer 和 GPU model,分配 KV cache,执行 prefill/decode,并在 CLI 中流式输出生成文本。
xinfer generate --model models\Qwen2.5-0.5B-Instruct \
--prompt "What is the capital of France?" --max-tokens 16
The capital of France is Paris.
一次实际运行还会打印加载时间、GPU resident weights、prompt token 数、KV cache 大小以及 prefill/decode 吞吐。例如 Qwen2.5-0.5B 的一次记录为:Loaded in 4.0s (942 MiB GPU weights)、Prompt tokens: 26 (KV cache 2 MiB)、prefill: 26 tok in 69 ms (378 tok/s) | decode: 7 tok in 0.23 s (30.5 tok/s) [eos]。从系统角度看,命令经过如下路径:
Lab 13生成文本,并验证 incremental decode == full forward
本实验包含两个验收目标。第一,真实模型能完成端到端生成。第二,KV cache incremental decode 必须与 full forward 的最后一行 logits 匹配;该测试直接覆盖 cache 追加、RoPE 位置偏移和 causal mask。
# 真实模型生成
cargo run --release -p xinfer-cli -- generate \
--model models\Qwen2.5-0.5B-Instruct \
--prompt "What is the capital of France?" \
--max-tokens 16
# KV cache correctness
cargo test -p xinfer-model --test forward_parity -- --nocapture
正确性测试的数学意义是:
也就是说,先一次性 forward 全部 tokens,和先 prefill 前 个 token 再 decode 最后一个 token,最后位置的 logits 应该一致。
小结
本章把 Phase 3 的完整 forward 扩展为可生成文本的 runtime。KvCache 为每层保存 K/V;attention kernel 接受不同的 seq_q 与 seq_k;generate 使用 GenerateOptions 和 SamplingConfig 管理 prefill、decode、采样与 EOS;greedy 路径通过 GPU argmax 只读回一个 token id;tokenizer 与 Qwen chat template 构成文本输入输出边界。至此,引擎已经能够对真实 Qwen2.5 checkpoint 执行端到端推理。
思考与练习
基础为什么每一层都需要自己的 KV cache?
因为每层有各自独立的 K/V 投影权重,产生不同的 key/value 表示。第 层的 attention 只能用第 层历史 token 的 K/V,不能跨层共用。所以要为 24 层各保存一份 KV cache,decode 时各层把新 token 的 K/V 追加到对应层的缓存里。
基础prefill 与 decode 调用的是同一个 model.decode,为什么输入 shape 不同?
prefill 一次喂入整个 prompt,seq_q = prompt 长度(如 128),同时把这 128 个 K/V 写入空缓存(seq_k 也是 128)。decode 每步只喂 1 个新 token,seq_q=1,而 seq_k = 已有历史长度+1。同一函数通过 seq_q≠seq_k 与 q_pos_base 两个参数统一处理两种情形,attention kernel 据此设置 causal 上限。
进阶推导 key_count = min(q_pos_base + tq + 1, seq_k) 的含义。
query 在序列中的绝对位置是 (tq 是本批内的相对索引)。causal 规则下它能看见位置 ,即 个 key,故可见数为 。再对 seq_k(缓存中实际存在的 key 数)取 min,防止越界。decode 时 , = 历史长度,于是 key_count = 历史长度+1 = 全部缓存,符合“新 token 能看到所有过去 + 自己”。
进阶解释 greedy fast path 为什么可以只读回一个 u32。
greedy 解码只需要 argmax(logits) 这一个 token id,不需要整条概率分布。于是在 GPU 上做归约求出最大 logit 的下标,只把这个 4 字节的 u32 拷回 CPU,而不是把 151936 个 f32(约 608 KB)全读回再在 CPU 上找 max。这把每 token 的 readback 从约 608 KB 降到 4 字节,是 greedy decode 达到约 30.5 tok/s 的关键优化之一。temperature、top-k、top-p 采样仍需概率分布信息,因此保留完整 logits readback 与 CPU 采样路径。
挑战实现一个测试:随机 prompt 下 full forward 最后一行 logits 与 incremental decode logits 的 max abs diff 小于阈值。
步骤:① 随机生成一段 token 序列作为 prompt;② 路径 A:对完整序列做一次 full forward,取最后一个位置的 logits;③ 路径 B:先 prefill 前 个 token 建立 KV cache,再对第 个 token 调 decode 拿 logits;④ 断言 max(abs(A - B)) < eps。
理论上 incremental decode 与 full forward 在数学上等价,f32 路径 diff 应为 0(xinfer 的 kv_cache_decode_matches_full 实测 diff=0);若用 f16 kernel 则放宽到 ~1e-2。这个测试守护 KV cache 追加、位置编码偏移与 causal mask 的正确性——它们是 decode 最易出错的环节。