Building an LLM Inference Engine from ScratchPart I / Module 2
Part I · Foundations

Module 2:Transformer Block,算子逐个看

从 embedding 和 residual stream 出发,准确推导 RMSNorm、linear、attention、GQA、RoPE、SwiGLU 与 LM head 的数学形状。

学习目标

  • 能从张量 shape 角度解释 decoder layer 中每个算子的输入和输出。
  • 理解 RMSNorm、RoPE、GQA、SwiGLU 的数学定义和工程直觉。
  • 能把公式对应到 xinfer_core::cpu 的参考实现。
  • 能解释为什么 matmul / GEMV 是推理性能的核心工作负载。
  • 完成 Lab 2:实现每个 op 的 CPU 参考函数并写单元测试。

2.1Embedding 与 residual stream

tokenizer 输出的是整数 token id。进入 transformer 前,模型用 embedding table 做一次查表,将每个 id 映射为 hidden vector:

X0[t]=E[xt],ERV×dmodelX_0[t] = E[x_t], \qquad E \in \mathbb{R}^{|V|\times d_{\text{model}}}

这里 X0RS×HX_0\in\mathbb{R}^{S\times H}SS 是序列长度,H=dmodelH=d_{\text{model}} 是 hidden size。Qwen2.5-0.5B-Instruct 的 config.json 给出 H=896H=896V=151936|V|=151936、24 个 decoder layer、14 个 query head、2 个 KV head、intermediate size 4864、rope_theta=1000000.0rms_norm_eps=1e-6tie_word_embeddings=true。这条 S×HS\times H 张量在各层之间通过 residual add 传递,通常称为 residual stream

Xℓ [S,H] RMSNorm attention norm Attention QKV + RoPE + GQA + Uℓ [S,H] RMSNorm mlp norm MLP gate/up + SwiGLU + Xℓ+1 [S,H] 橙色虚线是 residual skip:子层输出只负责提供“要加上的变化量”
图 2-1:一个 decoder layer 由 attention 子层和 MLP 子层组成,每个子层都带 residual add。

Qwen2 的 decoder layer 采用 pre-norm 结构。归一化先作用于子层输入,子层输出再加回 residual stream:

U=X+AttentionBlock(Norm(X))X+1=U+MLP(Norm(U))\begin{aligned} U_\ell &= X_\ell + \operatorname{AttentionBlock}_\ell(\operatorname{Norm}(X_\ell)) \\ X_{\ell+1} &= U_\ell + \operatorname{MLP}_\ell(\operatorname{Norm}(U_\ell)) \end{aligned}

一层 Qwen2 decoder 的完整展开

图 2-1 省略了具体投影矩阵。按 xinfer-modelrun_forward 顺序,一层 Qwen2 decoder 先对 XX_\ell 做 attention RMSNorm,再计算 Q/K/V;Q 和 K 应用 RoPE;GQA attention 的输出经 o_proj 回到 hidden size,并与原始 XX_\ell 相加得到 UU_\ell

A=RMSNorm(X;gattn)Q=AWQ+bQ,K=AWK+bK,V=AWV+bVQ~=RoPE(Q),K~=RoPE(K)C=GQA(Q~,K~,V)O=CWOU=X+O\begin{aligned} A &= \operatorname{RMSNorm}(X_\ell; g_{\text{attn}}) \\ Q &= A W_Q^\top + b_Q,\quad K = A W_K^\top + b_K,\quad V = A W_V^\top + b_V \\ \tilde Q &= \operatorname{RoPE}(Q),\quad \tilde K = \operatorname{RoPE}(K) \\ C &= \operatorname{GQA}(\tilde Q,\tilde K,V) \\ O &= C W_O^\top \\ U_\ell &= X_\ell + O \end{aligned}

随后对 UU_\ell 做 MLP RMSNorm。Qwen2 的 MLP 使用 gated 结构:gate_projup_proj 产生两个 [S,4864][S,4864] 张量,SwiGLU 逐元素相乘后,再由 down_proj 投回 [S,896][S,896]

B=RMSNorm(U;gmlp)G=BWgate,P=BWupS=SiLU(G)PD=SWdownX+1=U+D\begin{aligned} B &= \operatorname{RMSNorm}(U_\ell; g_{\text{mlp}}) \\ G &= B W_{\text{gate}}^\top,\quad P = B W_{\text{up}}^\top \\ S &= \operatorname{SiLU}(G)\odot P \\ D &= S W_{\text{down}}^\top \\ X_{\ell+1} &= U_\ell + D \end{aligned}
Qwen2 decoder layer:Attention 子层 + MLP 子层的算子位置 Xℓ RMSNorm q_proj k_proj v_proj RoPE RoPE GQAsoftmax(V) o_proj + Uℓ RMSNorm gate_proj up_proj SwiGLUSiLU(gate) ⊙ up down_proj + Xℓ+1 紫色是线性层(matmul/GEMV);白色是逐元素或归一化算子;橙色圆圈是 residual add。
图 2-2:每个算子在真实 Qwen2 decoder layer 中的位置。Attention 子层先产生 $U_\ell$,MLP 子层再产生 $X_{\ell+1}$。
步骤符号典型 shape(Qwen2.5-0.5B)对应算子
输入 residualX_\ell[S,896][S,896]上一层输出
Attention normAA[S,896][S,896]RMSNorm
Q 投影QQ[S,14,64][S,14,64]linear + bias
K/V 投影K,VK,V[S,2,64][S,2,64]linear + bias(GQA)
位置编码\tilde Q,\tilde K同 Q/KRoPE
注意力输出CC[S,14,64]\Rightarrow[S,896]GQA attention
输出投影 + 残差U_\ell[S,896][S,896]o_proj + add
MLP normBB[S,896][S,896]RMSNorm
门控 MLPSS[S,4864][S,4864]gate/up + SwiGLU
down 投影 + 残差X_{\ell+1}[S,896][S,896]down_proj + add

2.2RMSNorm:只归一化均方根

Qwen2 使用 RMSNorm。对一行 hidden vector xRHx\in\mathbb{R}^H,它按均方根缩放每个分量:

RMSNorm(x)i=xi1Hj=1Hxj2+ε  gi\operatorname{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{H}\sum_{j=1}^{H}x_j^2+\varepsilon}}\; g_i

其中 gRHg\in\mathbb{R}^H 是可学习的 scale 参数。RMSNorm 不减均值,也不使用 bias;它只用平方均值控制向量尺度。与 LayerNorm 相比,这少了一次 mean reduction。

实现对应

xinfer_core::cpu::rms_norm 的外层循环遍历 token row,内层先计算 mean_sq,再写出 row[i] * inv * weight[i]shaders/rms_norm.hlsl 用 64 个线程组成一个 threadgroup 处理一行,在 groupshared memory 中归约平方和。

2.3Linear layer / matmul:主要计算负载

Transformer block 中的大部分浮点计算来自线性层。给定输入 XRS×KX\in\mathbb{R}^{S\times K},HuggingFace 权重通常存为 WRN×KW\in\mathbb{R}^{N\times K}(out_features 在前),因此实现中计算的是:

Y=XW+b,YRS×NY = XW^\top + b,\qquad Y\in\mathbb{R}^{S\times N}

prefill 阶段 SS 等于 prompt 长度,线性层表现为 GEMM;decode 阶段通常 S=1S=1,线性层退化为 GEMV,即一个 activation vector 乘大矩阵。Qwen2.5-0.5B 的 LM head 为 896151936896\rightarrow151936,单个 token 需要约 896×1519361.36×108896\times151936\approx1.36\times10^8 次乘加。

图 2-3:decode 阶段许多线性层是 GEMV;LM head 因词表巨大而特别重。

2.4Self-attention:Q / K / V、缩放点积与因果 mask

Self-attention 先从同一组 hidden states 生成 query、key、value:

Q=XWQ,K=XWK,V=XWVQ=XW_Q^\top,\qquad K=XW_K^\top,\qquad V=XW_V^\top

对单个 head,query qtq_t 与每个 key kjk_j 做点积,得到相似度分数:

st,j=qtkjdheads_{t,j} = \frac{q_t k_j^\top}{\sqrt{d_{\text{head}}}}

除以 dhead\sqrt{d_{\text{head}}} 可以控制点积方差,避免 head_dim 增大后 softmax 过早饱和。加上 causal mask 后,位置 tt 只能读取 0..t0..t 的 key:

st,j={qtkj/dhead,jt,j>ts_{t,j} = \begin{cases} q_tk_j^\top/\sqrt{d_{\text{head}}}, & j\le t \\ -\infty, & j>t \end{cases}

然后对 jj 做 softmax,得到权重 αt,j\alpha_{t,j},再加权求和 value:

Attn(qt,K,V)=jtαt,jvj\operatorname{Attn}(q_t,K,V)=\sum_{j\le t}\alpha_{t,j}v_j
Causal Attention Mask(深色 = 可见,浅色 = 被 mask) query 0 query 1 query 2 query 3 query 4 key 0 key 1 key 2 key 3 key 4 第 t 行只能看 key 0..t:这就是 decoder-only 模型的自回归约束。
图 2-4:因果 mask 保证模型不能偷看未来 token。

2.5Multi-head 与 Grouped-Query Attention(GQA)

Multi-head attention 将 hidden 维度拆成若干 head。每个 query head 在独立子空间中计算 attention:

H=nheadsdheadH = n_{\text{heads}}\cdot d_{\text{head}}

Qwen2.5-0.5B 中 nheads=14n_{\text{heads}}=14dhead=64d_{\text{head}}=64,所以 H=896H=896。它只使用 nkv=2n_{\text{kv}}=2 个 key/value head;每 7 个 query head 共享 1 个 KV head。这种结构称为 Grouped-Query Attention(GQA)。

kv_head(h)=hnheads/nkv\text{kv\_head}(h)=\left\lfloor \frac{h}{n_{\text{heads}}/n_{\text{kv}}} \right\rfloor

GQA 保留较多 query head,同时减少需要存入 KV cache 的 K/V head 数。对该模型,KV cache 的 head 数从 14 降到 2,容量与读取带宽按比例降低为 MHA 的 2/142/14

图 2-5:GQA 减少 KV head 数,直接降低 KV cache 显存。

2.6RoPE:把位置编码成旋转

纯 attention 对序列顺序没有内建感知。RoPE(Rotary Position Embedding)不把位置向量加到 hidden 上,而是在 Q/K 的 head 维度中按二维平面旋转。对一对维度 (a,b)(a,b),位置 pp 对应的旋转为:

[ab]=[cosθpsinθpsinθpcosθp][ab]\begin{bmatrix} a'\\ b' \end{bmatrix} = \begin{bmatrix} \cos\theta_p & -\sin\theta_p\\ \sin\theta_p & \cos\theta_p \end{bmatrix} \begin{bmatrix} a\\ b \end{bmatrix}

xinfer_core::cpu::ropeshaders/rope.hlsl 采用 Qwen2/HuggingFace 的 GPT-NeoX rotate_half 约定:前半维与后半维配对。频率由 rope_theta 控制,Qwen2.5-0.5B 中该值为 1000000.0:

θp,i=pθbase2i/dhead\theta_{p,i}=p\cdot \theta_{\text{base}}^{-2i/d_{\text{head}}}
RoPE 的作用

RoPE 将绝对位置写入 Q/K 的相位。两个位置的相对距离会反映到 qkq^\top k 的相位差中,使 attention score 能包含相对位置信息。

2.7MLP / FFN 与 SwiGLU

Attention 在 token 之间聚合信息;MLP 对每个 token 的 hidden vector 独立做非线性变换。Qwen2 使用 gated MLP,其逐元素非线性为 SwiGLU:

SiLU(x)=xσ(x),SwiGLU(x)=SiLU(xWg)(xWu)\operatorname{SiLU}(x)=x\cdot\sigma(x),\qquad \operatorname{SwiGLU}(x)=\operatorname{SiLU}(xW_g^\top)\odot(xW_u^\top)

然后再投影回 hidden size:

MLP(x)=SwiGLU(x)Wd\operatorname{MLP}(x)=\operatorname{SwiGLU}(x)W_d^\top

其中 \odot 是逐元素乘法。up_proj 产生内容分支,gate_proj 经 SiLU 后产生门控分支;两者相乘后再由 down_proj 回到 hidden size。shaders/swiglu.hlsl 对每个元素执行 silu(gate) * up

2.8LM head 与 tied embeddings

24 个 decoder layer 结束后,模型得到 XLX_L。final RMSNorm 之后,LM head 将每个 hidden vector 投影到词表维度:

Z=Norm(XL)Wlm,ZRS×VZ = \operatorname{Norm}(X_L)W_{\text{lm}}^\top,\qquad Z\in\mathbb{R}^{S\times |V|}

自回归生成只使用最后一行 Z[S1]Z[S-1],因为该行给出下一个 token 的 logits。Qwen2.5-0.5B 的 tie_word_embeddings=truexinfer-model 在这种情况下直接使用 embedding 权重作为 LM head,即 Wlm=EW_{\text{lm}}=E

性能提醒

LM head 的输出维度等于词表大小。Qwen2.5-0.5B 的 V=151936|V|=151936;若完整 logits 以 f32 读回 CPU,每个 token 约为 151936×4608151936\times4\approx608 KB。greedy decoding 使用 GPU argmax 时,只需读回一个 u32 token id。

小结

本模块把 Qwen2.5-0.5B-Instruct 的 decoder layer 拆成可验证的算子链:embedding 查表得到 [S,896][S,896] residual stream;每层先执行 RMSNorm、Q/K/V projection、RoPE、GQA attention、o_proj 与 residual add,再执行 RMSNorm、gate/up projection、SwiGLU、down_proj 与 residual add。GQA 使 14 个 query head 共享 2 个 KV head,降低 KV cache 容量与读取带宽。最终的 LM head 使用 tied embeddings,将 hidden vector 投影到 151936 维 logits。

Lab 2实现每个 op 的 CPU 参考函数

本实验对应 xinfer-core::cpu。目标是实现可读、确定的 FP32 reference op,作为 DirectML/HLSL kernel 的 correctness oracle。

函数输入 shape输出 shape测试建议
rms_norm[seq, hidden][seq, hidden]单位权重时 RMS 接近 1
matmul[m,k]×[k,n][m,n]乘单位矩阵不变
rope[seq, heads, head_dim]in-place旋转保持 L2 norm
attentionQ,K,V[seq_q, heads, head_dim]mask 后不能看未来
swiglugate, upsame与手算 SiLU 对比
pub fn rms_norm(x: &[f32], weight: &[f32], seq: usize, hidden: usize, eps: f32) -> Vec<f32> {
    let mut out = vec![0.0; x.len()];
    for t in 0..seq {
        let row = &x[t * hidden..(t + 1) * hidden];
        let mean_sq = row.iter().map(|v| v * v).sum::<f32>() / hidden as f32;
        let inv = 1.0 / (mean_sq + eps).sqrt();
        for i in 0..hidden {
            out[t * hidden + i] = row[i] * inv * weight[i];
        }
    }
    out
}

思考与练习

基础写出 RMSNorm 与 LayerNorm 的主要区别。

LayerNorm 同时减去均值并除以标准差:xμσ2+εγ+β\frac{x-\mu}{\sqrt{\sigma^2+\varepsilon}}\gamma+\beta,需要计算均值和方差,并有 scale 和 bias 两个参数。RMSNorm 不减均值,只用均方根缩放:x1Hxj2+εg\frac{x}{\sqrt{\frac{1}{H}\sum x_j^2+\varepsilon}}g,只有一个 scale 参数 gg,没有 bias。RMSNorm 少一次 mean reduction,计算更简单、更省,Qwen2 与多数现代 LLM 都用它。

基础XR4×8X\in\mathbb{R}^{4\times 8}WR16×8W\in\mathbb{R}^{16\times 8},则 XWXW^\top 的 shape 是什么?

WR8×16W^\top\in\mathbb{R}^{8\times16},所以 XWR4×16XW^\top\in\mathbb{R}^{4\times16}。这里 XX 的列数 8 是 in_features KKWW 的行数 16 是 out_features NN,结果是 [seq=4, out=16]

进阶证明二维旋转矩阵保持向量长度:a2+b2=a2+b2a^2+b^2=a'^2+b'^2

a=acosθbsinθa'=a\cos\theta-b\sin\thetab=bcosθ+asinθb'=b\cos\theta+a\sin\theta。则

a2+b2=(acosθbsinθ)2+(bcosθ+asinθ)2a'^2+b'^2=(a\cos\theta-b\sin\theta)^2+(b\cos\theta+a\sin\theta)^2

=a2cos2θ2abcosθsinθ+b2sin2θ+b2cos2θ+2abcosθsinθ+a2sin2θ=a^2\cos^2\theta-2ab\cos\theta\sin\theta+b^2\sin^2\theta + b^2\cos^2\theta+2ab\cos\theta\sin\theta+a^2\sin^2\theta

交叉项相消,余下 a2(cos2θ+sin2θ)+b2(cos2θ+sin2θ)=a2+b2a^2(\cos^2\theta+\sin^2\theta)+b^2(\cos^2\theta+\sin^2\theta)=a^2+b^2。因此 RoPE 旋转保持向量长度,这也是它不会改变 token 表示尺度的原因。

进阶Qwen2.5-0.5B 的 nheads=14n_{\text{heads}}=14nkv=2n_{\text{kv}}=2。每个 KV head 被多少个 query head 共享?

group size =nheads/nkv=14/2=7=n_{\text{heads}}/n_{\text{kv}}=14/2=7。即每个 KV head 被 7 个 query head 共享。映射为 kv_head(h)=h/7\text{kv\_head}(h)=\lfloor h/7\rfloor:query head 0–6 用 KV head 0,query head 7–13 用 KV head 1。这样 KV cache 只需存 2 个 head,而不是 14 个,显著减少缓存与带宽。

挑战实现 GQA attention 的 CPU 版本,支持 seq_q != seq_kq_pos_base

核心循环对每个 (head hh, query tqt_q):先求 kv head kv=h/groupkv=h/\text{group};可见 key 数 key_count = min(q_pos_base + tq + 1, seq_k);对 0..key_count0..\text{key\_count} 计算 qk/dq\cdot k/\sqrt{d},取 max 后做稳定 softmax,再加权求和 VV。要点:

seq_qseq_k 分开传入,使其同时支持 prefill(seqq=seqkseq_q=seq_k)和单步 decode(seqq=1seq_q=1seqk=seq_k= 历史长度+1);② q_pos_base 表示 query 的绝对起始位置,用于 causal 上限;③ K/V 用 n_kv_heads 索引而 Q 用 n_heads 索引。可对照 xinfer_core::cpu::attention 的参考实现。