Building an LLM Inference Engine from ScratchPart IV / Module 11
Part IV · Building xinfer

Module 11:Phase 2 — 把 Transformer 算子写成 Kernel

CPU reference 先行:RMSNorm、RoPE、SwiGLU、Residual Add、GQA Attention、Matmul 的 GPU kernel 与数值校验。

学习目标

  • 理解为什么每个 GPU kernel 都要先有 CPU reference 作为 correctness oracle。
  • 能把 RMSNorm、RoPE、SwiGLU、residual add、GQA attention 拆成可测试的输入/输出 shape。
  • 理解 matmul / transposed-weight linear 的 shape 约定,以及为什么后来改成 custom HLSL matmul。
  • 理解数值误差来源:浮点舍入、归约顺序、f16 权重、softmax 稳定性。
  • 完成 Lab 11:实现 kernel 并通过 GPU-vs-CPU parity tests。

11.1CPU reference first:正确性 oracle

GPU kernel 的主要风险不在于 dispatch 失败,而在于结果以很小或很隐蔽的方式偏离数学定义。GPU 调试成本高,因此每个 kernel 都应先有一个简单、串行、可审查的 CPU reference。

xinfer 的 CPU reference 位于 crates\xinfer-core\src\cpu.rs,作为 xinfer-dml parity tests 的 correctness oracle。该文件采用统一的 row-major contiguous layout:

  • hidden states:[seq, hidden]
  • Q:[seq_q, n_heads, head_dim]
  • K/V:[seq_k, n_kv_heads, head_dim]

crates\xinfer-dml\tests\ops_parity.rs 对 RMSNorm、RoPE、SwiGLU、add 和 attention 都采用同一结构:

inputCPU referenceycpu,inputGPU kernelygpu,ygpuycpu<ε\text{input} \xrightarrow{\text{CPU reference}} y_{\text{cpu}}, \qquad \text{input} \xrightarrow{\text{GPU kernel}} y_{\text{gpu}}, \qquad \lVert y_{\text{gpu}}-y_{\text{cpu}}\rVert_\infty < \varepsilon
Test input deterministic tensors CPU reference simple loops GPU kernel HLSL dispatch Compare max abs diff Pass? ε tolerance CPU reference 不追求快;它追求简单、可信、容易审查。
图 11-1:每个 GPU kernel 都要通过 CPU reference 的数值校验。

11.2HLSL kernels:逐个算子落地

Phase 2 将 Module 2 的数学算子逐个落到 HLSL。设计 kernel 时需要先固定三个条件:输入/输出 shape、线程或 threadgroup 到数据的映射、是否需要组内同步。

算子数学定义GPU 切分方式文件
RMSNormx/mean(x2)+εgx/\sqrt{\operatorname{mean}(x^2)+\varepsilon}\cdot g每行一个 threadgroup,组内归约rms_norm.hlsl
RoPE二维旋转每个 (token, head) 一个 threadrope.hlsl
SwiGLUSiLU(g)u\operatorname{SiLU}(g)\odot u每个元素一个 threadswiglu.hlsl
Residual Adda+ba+b每个元素一个 threadadd.hlsl
GQA Attentionmasked softmax + weighted V每个 (query, head) 一个 threadattention.hlsl
LinearXWXW^\topPhase 2 先走 DML GEMM;后续改 HLSLgemm.rs / linear*.hlsl

RMSNorm kernel

CPU 公式是:

yi=xi1Hjxj2+εgiy_i = \frac{x_i}{\sqrt{\frac{1}{H}\sum_j x_j^2+\varepsilon}}g_i

实际 rms_norm.hlsl 采用一个 64-thread threadgroup 处理一个 token row。线程按 stride 遍历 hidden 维,先把平方和写入 groupshared 数组并做树形归约,再用同一个 threadgroup 写回归一化后的整行。测试输入和输出 shape 均为 [seq, hidden]

RoPE kernel

RoPE 对 Q/K 的每个 head 做旋转。GPT-NeoX / Qwen2 约定把 head_dim 的前半和后半配对:

a=acosϕbsinϕ,b=bcosϕ+asinϕa' = a\cos\phi - b\sin\phi,\qquad b' = b\cos\phi + a\sin\phi

rope.hlsl 是 in-place kernel:RWStructuredBuffer<float> X 被直接改写。一个线程负责一个 (token, head),在该 head 内遍历 head_dim/2 对元素;pos_base 使同一 kernel 可用于 prefill 与 KV-cache decode。

GQA attention kernel

注意力 kernel 要处理 seq_qseq_k 可能不同的情况,这为 KV cache 做准备。 对 query row tqt_q,可见 key 数是:

key_count=min(q_pos_base+tq+1,  seq_k)\text{key\_count} = \min(q\_\text{pos\_base}+t_q+1,\; \text{seq\_k})

attention.hlsl 使用一个线程处理一个 (query token, query head),并以 online softmax 方式单次流式读取 K/V:维护 running max、denominator 和加权 V 累加器。Phase 2 的重点是 GQA head mapping、causal limit 与 softmax 稳定性正确。

Phase 2:把每个数学算子变成可测试的 kernel X [S,H] RMSNorm q_proj k_proj v_proj RoPE RoPE GQA Attnsoftmax + V o_proj RMSNorm gate_proj up_proj SwiGLU down_proj Add 绿色:自写 HLSL kernel;紫色:线性/matmul 路径;橙色:residual add。
图 11-2:Phase 2 把 decoder layer 中的每个数学算子变成可单独测试的 GPU kernel。

11.3Matmul 与 transposed-weight linear

HuggingFace 权重通常按 [out_features, in_features] 存储,因此线性层是:

Y=XWY = XW^\top

如果 XRM×KX\in\mathbb{R}^{M\times K}WRN×KW\in\mathbb{R}^{N\times K},则:

Yi,j=p=0K1Xi,pWj,p,YRM×NY_{i,j}=\sum_{p=0}^{K-1}X_{i,p}W_{j,p}, \qquad Y\in\mathbb{R}^{M\times N}

Phase 2 先用 DirectML GEMM / transposed-B 路径验证 linear 层。项目后续在部分 RDNA4 shape 上遇到 DirectML GEMM device removed,因此 Phase 3 之后 matmul 迁移到 custom HLSL。linear_f16.hlsl 采用 64-thread threadgroup 计算一个输出元素,沿 K 维做 coalesced strided reduction,并用 2D grid(idx = gy*grid_x + gx)绕过单轴 65535 threadgroup 上限。

11.4数值 parity:为什么很少 bit-exact?

CPU 与 GPU 执行同一数学计算时,通常不要求逐 bit 相同。差异主要来自:

  • 浮点加法不满足结合律:(a+b)+c(a+b)+ca+(b+c)a+(b+c) 可能不同;
  • GPU 归约顺序可能与 CPU loop 不同;
  • f16 / bf16 权重会引入舍入;
  • softmax / exp 对微小误差敏感;
  • 不同硬件可能使用不同 fused multiply-add 行为。

因此测试通常设定容差:

maxiyigpuyicpu<ε\max_i |y^{gpu}_i-y^{cpu}_i| < \varepsilon

选择 ε\varepsilon 时要结合 dtype、算子和归约方式。纯 f32 的 elementwise add 可以使用较小容差;attention、matmul 和 f16 权重通常需要更宽容差。项目级验证还应覆盖整模型:xinfer 的 README 记录 full Qwen2 forward GPU vs CPU 的最大 logit diff 约为 2×1072\times10^{-7},incremental KV-cache decode 与 full forward 差异为 0,streamed layers 与 resident layers 差异为 0。

图 11-3:不同算子适合不同容差;图中数值是教学示意,不是硬规则。

Lab 11实现每个 kernel,通过 GPU-vs-CPU parity

本实验要求为每个 CPU reference 提供 GPU 版本,并通过 ops_parity。建议按依赖和调试难度递增的顺序实现:

  1. add.hlsl:逐元素 add,最简单。
  2. swiglu.hlsl:逐元素 SiLU 与乘法。
  3. rms_norm.hlsl:先写单线程版本,再改成 threadgroup 归约。
  4. rope.hlsl:in-place 旋转,测试 L2 norm 保持。
  5. attention.hlsl:GQA + causal mask + softmax 稳定性。
  6. linear:先验证 GEMM shape,再实现/替换 custom kernel。
# 关键验收命令
cargo test -p xinfer-dml --test ops_parity -- --nocapture

# 你应该看到:
rms_norm_matches_cpu ... ok
rope_matches_cpu ... ok
swiglu_matches_cpu ... ok
add_matches_cpu ... ok
attention_matches_cpu ... ok
调试建议

parity 失败时,优先检查 shape、row-major offset、head mapping、causal limit 和 softmax 是否减 max。多数 kernel 错误来自 indexing、边界或 resource state,而不是 GPU 驱动本身。

小结

Module 11 将 transformer 的核心算子落实为可单独验证的 GPU kernel。xinfer_core::cpu 提供 RMSNorm、RoPE、SwiGLU、add、matmul 和 GQA attention 的参考实现;xinfer-dml 中的 HLSL kernel 逐项通过 parity tests。数值验证不追求 bit-exact,而是以 shape、layout、head mapping、causal mask 和容差为中心建立证据链;整模型级别还需要比较 full forward、KV-cache decode 与 streamed/resident layers 的一致性。

思考与练习

基础为什么 CPU reference 必须比 GPU kernel 更简单?

因为 CPU reference 是正确性的基准。它应使用直白循环、串行控制流和 f32 计算,便于审查和手算对照。GPU kernel 为性能引入并行、分块、组内归约、f16 权重或 online softmax,复杂度更高;只有 CPU 侧足够简单,parity 失败时才能把问题主要定位到 GPU 实现。

基础写出 transposed-weight linear 的 shape 关系。

权重以 WRN×KW\in\mathbb{R}^{N\times K}(out×in)存储,计算 Y=XWY=XW^\topXRM×KX\in\mathbb{R}^{M\times K}WRK×NW^\top\in\mathbb{R}^{K\times N}YRM×NY\in\mathbb{R}^{M\times N}。即对每个输出元素 Ym,n=kXm,kWn,kY_{m,n}=\sum_k X_{m,k}W_{n,k}——注意 W 的两个下标都按行存,正好让同一输出列读取 W 的连续一行,利于 coalescing。

进阶解释 GQA 中 query head 到 KV head 的映射公式。

设 group size g=nheads/nkvg=n_{\text{heads}}/n_{\text{kv}}。query head hh 使用的 KV head 为 kv(h)=h/g\text{kv}(h)=\lfloor h/g\rfloor。Qwen2.5-0.5B 中 g=14/2=7g=14/2=7:head 0–6 → KV 0,head 7–13 → KV 1。attention 计算时 Q 用 hh 索引、K/V 用 h/g\lfloor h/g\rfloor 索引。

进阶为什么 attention softmax 必须减去最大 score?

数值稳定性。softmax 要算 esie^{s_i},若 sis_i 较大,esie^{s_i} 会溢出成 inf。减去最大值 m=maxisim=\max_i s_i 后算 esime^{s_i-m},最大项变成 e0=1e^0=1,其余 1\le1,不会上溢;由于分子分母同乘 eme^{-m},结果数学上不变。这就是“safe softmax”,flash attention 的 online softmax 也据此维护 running max 并 rescale。

挑战构造一个小输入,手算 causal attention 输出,并用它作为单元测试。

取 1 head,d=1d=1,两个 token。令 Q=[q0,q1]Q=[q_0,q_1]K=[k0,k1]K=[k_0,k_1]V=[v0,v1]V=[v_0,v_1],缩放 1/1=11/\sqrt{1}=1

token 0(只能看自己):score=q0k0q_0k_0,softmax 单元素=1,输出 =v0=v_0

token 1(看 0,1):scores s0=q1k0, s1=q1k1s_0=q_1k_0,\ s_1=q_1k_1w0=es0mes0m+es1mw_0=\frac{e^{s_0-m}}{e^{s_0-m}+e^{s_1-m}}w1=1w0w_1=1-w_0;输出 =w0v0+w1v1=w_0v_0+w_1v_1

代入具体数(如 q=[1,1],k=[1,2],v=[3,5]q=[1,1],k=[1,2],v=[3,5]):token1 scores=[1,2],w=softmax([1,2])[0.269,0.731]w=\text{softmax}([1,2])\approx[0.269,0.731],输出 0.2693+0.73154.46\approx0.269\cdot3+0.731\cdot5\approx4.46。把这些常数硬编码进测试,断言 kernel 输出在容差内匹配 [3.0, 4.46],并验证 token0 不会看到 token1(causal)。