Module 11:Phase 2 — 把 Transformer 算子写成 Kernel
CPU reference 先行:RMSNorm、RoPE、SwiGLU、Residual Add、GQA Attention、Matmul 的 GPU kernel 与数值校验。
学习目标
- 理解为什么每个 GPU kernel 都要先有 CPU reference 作为 correctness oracle。
- 能把 RMSNorm、RoPE、SwiGLU、residual add、GQA attention 拆成可测试的输入/输出 shape。
- 理解 matmul / transposed-weight linear 的 shape 约定,以及为什么后来改成 custom HLSL matmul。
- 理解数值误差来源:浮点舍入、归约顺序、f16 权重、softmax 稳定性。
- 完成 Lab 11:实现 kernel 并通过 GPU-vs-CPU parity tests。
11.1CPU reference first:正确性 oracle
GPU kernel 的主要风险不在于 dispatch 失败,而在于结果以很小或很隐蔽的方式偏离数学定义。GPU 调试成本高,因此每个 kernel 都应先有一个简单、串行、可审查的 CPU reference。
xinfer 的 CPU reference 位于 crates\xinfer-core\src\cpu.rs,作为 xinfer-dml parity tests 的 correctness oracle。该文件采用统一的 row-major contiguous layout:
- hidden states:
[seq, hidden] - Q:
[seq_q, n_heads, head_dim] - K/V:
[seq_k, n_kv_heads, head_dim]
crates\xinfer-dml\tests\ops_parity.rs 对 RMSNorm、RoPE、SwiGLU、add 和 attention 都采用同一结构:
11.2HLSL kernels:逐个算子落地
Phase 2 将 Module 2 的数学算子逐个落到 HLSL。设计 kernel 时需要先固定三个条件:输入/输出 shape、线程或 threadgroup 到数据的映射、是否需要组内同步。
| 算子 | 数学定义 | GPU 切分方式 | 文件 |
|---|---|---|---|
| RMSNorm | 每行一个 threadgroup,组内归约 | rms_norm.hlsl | |
| RoPE | 二维旋转 | 每个 (token, head) 一个 thread | rope.hlsl |
| SwiGLU | 每个元素一个 thread | swiglu.hlsl | |
| Residual Add | 每个元素一个 thread | add.hlsl | |
| GQA Attention | masked softmax + weighted V | 每个 (query, head) 一个 thread | attention.hlsl |
| Linear | Phase 2 先走 DML GEMM;后续改 HLSL | gemm.rs / linear*.hlsl |
RMSNorm kernel
CPU 公式是:
实际 rms_norm.hlsl 采用一个 64-thread threadgroup 处理一个 token row。线程按 stride 遍历 hidden 维,先把平方和写入 groupshared 数组并做树形归约,再用同一个 threadgroup 写回归一化后的整行。测试输入和输出 shape 均为 [seq, hidden]。
RoPE kernel
RoPE 对 Q/K 的每个 head 做旋转。GPT-NeoX / Qwen2 约定把 head_dim 的前半和后半配对:
rope.hlsl 是 in-place kernel:RWStructuredBuffer<float> X 被直接改写。一个线程负责一个 (token, head),在该 head 内遍历 head_dim/2 对元素;pos_base 使同一 kernel 可用于 prefill 与 KV-cache decode。
GQA attention kernel
注意力 kernel 要处理 seq_q 与 seq_k 可能不同的情况,这为 KV cache 做准备。 对 query row ,可见 key 数是:
attention.hlsl 使用一个线程处理一个 (query token, query head),并以 online softmax 方式单次流式读取 K/V:维护 running max、denominator 和加权 V 累加器。Phase 2 的重点是 GQA head mapping、causal limit 与 softmax 稳定性正确。
11.3Matmul 与 transposed-weight linear
HuggingFace 权重通常按 [out_features, in_features] 存储,因此线性层是:
如果 ,,则:
Phase 2 先用 DirectML GEMM / transposed-B 路径验证 linear 层。项目后续在部分 RDNA4 shape 上遇到 DirectML GEMM device removed,因此 Phase 3 之后 matmul 迁 移到 custom HLSL。linear_f16.hlsl 采用 64-thread threadgroup 计算一个输出元素,沿 K 维做 coalesced strided reduction,并用 2D grid(idx = gy*grid_x + gx)绕过单轴 65535 threadgroup 上限。
11.4数值 parity:为什么很少 bit-exact?
CPU 与 GPU 执行同一数学计算时,通常不要求逐 bit 相同。差异主要来自:
- 浮点加法不满足结合律: 与 可能不同;
- GPU 归约顺序可能与 CPU loop 不同;
- f16 / bf16 权重会引入舍入;
- softmax / exp 对微小误差敏感;
- 不同硬件可能使用不同 fused multiply-add 行为。
因此测试通常设定容差:
选择 时要结合 dtype、算子和归约方式。纯 f32 的 elementwise add 可以使用较小容差;attention、matmul 和 f16 权重通常需要更宽容差。项目级验证还应覆盖整模型:xinfer 的 README 记录 full Qwen2 forward GPU vs CPU 的最大 logit diff 约为 ,incremental KV-cache decode 与 full forward 差异为 0,streamed layers 与 resident layers 差异为 0。
Lab 11实现每个 kernel,通过 GPU-vs-CPU parity
本实验要求为每个 CPU reference 提供 GPU 版本,并通过 ops_parity。建议按依赖和调试难度递增的顺序实现:
add.hlsl:逐元素 add,最简单。swiglu.hlsl:逐元素 SiLU 与乘法。rms_norm.hlsl:先写单线程版本,再改成 threadgroup 归约。rope.hlsl:in-place 旋转,测试 L2 norm 保持。attention.hlsl:GQA + causal mask + softmax 稳定性。linear:先验证 GEMM shape,再实现/替换 custom kernel。
# 关键验收命令
cargo test -p xinfer-dml --test ops_parity -- --nocapture
# 你应该看到:
rms_norm_matches_cpu ... ok
rope_matches_cpu ... ok
swiglu_matches_cpu ... ok
add_matches_cpu ... ok
attention_matches_cpu ... ok
parity 失败时,优先检查 shape、row-major offset、head mapping、causal limit 和 softmax 是否减 max。多数 kernel 错误来自 indexing、边界或 resource state,而不是 GPU 驱动本身。
小结
Module 11 将 transformer 的核心算子落实为可单独验证的 GPU kernel。xinfer_core::cpu 提供 RMSNorm、RoPE、SwiGLU、add、matmul 和 GQA attention 的参考实现;xinfer-dml 中的 HLSL kernel 逐项通过 parity tests。数值验证不追求 bit-exact,而是以 shape、layout、head mapping、causal mask 和容差为中心建立证据链;整模型级别还需要比较 full forward、KV-cache decode 与 streamed/resident layers 的一致性。
思考与练习
基础为什么 CPU reference 必须比 GPU kernel 更简单?
因为 CPU reference 是正确性的基准。它应使用直白循环、串行控制流和 f32 计算,便于审查和手算对照。GPU kernel 为性能引入并行、分块、组内归约、f16 权重或 online softmax,复杂度更高;只有 CPU 侧足够简单,parity 失败时才能把问题主要定位到 GPU 实现。
基础写出 transposed-weight linear 的 shape 关系。
权重以 (out×in)存储,计算 :,,。即对每个输出元素 ——注意 W 的两个下标都按行存,正好让同一输出列读取 W 的连续一行,利于 coalescing。
进阶解释 GQA 中 query head 到 KV head 的映射公式。
设 group size 。query head 使用的 KV head 为 。Qwen2.5-0.5B 中 :head 0–6 → KV 0,head 7–13 → KV 1。attention 计算时 Q 用 索引、K/V 用 索引。
进阶为什么 attention softmax 必须减去最大 score?
数值稳定性。softmax 要算 ,若 较大, 会溢出成 inf。减去最大值 后算 ,最大项变成 ,其余 ,不会上溢;由于分子分母同乘 ,结果数学上不变。这就是“safe softmax”,flash attention 的 online softmax 也据此维护 running max 并 rescale。
挑战构造一个小输入,手算 causal attention 输出,并用它作为单元测试。
取 1 head,,两个 token。令 ,,,缩放 。
token 0(只能看自己):score=,softmax 单元素=1,输出 。
token 1(看 0,1):scores ;,;输出 。
代入具体数(如 ):token1 scores=[1,2],,输出 。把这些常数硬 编码进测试,断言 kernel 输出在容差内匹配 [3.0, 4.46],并验证 token0 不会看到 token1(causal)。