Module 8:Backend 选择 — DirectML 与 Direct3D 12
理解 DirectML 能给我们什么、D3D12 要我们自己管理什么,以及为什么 xinfer 最终把核心 matmul 改成自写 HLSL kernel。
学习目标
- 区分 DirectML operator 与自写 HLSL compute kernel 的职责。
- 理解 DirectML operator 的生命周期:create → compile → initialize → bind → dispatch。
- 知道什么时候应该使用库算子,什么时候应写 custom kernel。
- 理解 RDNA4 / DirectML GEMM device-removal 案例对工程决策的启发。
- 理解 PC 上 DXGI device creation 与 Xbox 上
D3D12XboxCreateDevice的区别。
8.1DirectML 提供什么,我们自己写什么?
DirectML 是 Windows / Xbox 生态中的机器学习 operator API,运行在 D3D12 资源与命令队列之上。它提供 GEMM、卷积、softmax、elementwise 等通用 operator,并负责把这些 operator 编译成底层驱动可以执行的 dispatchable 对象。
LLM runtime 还包含 DirectML 不覆盖的部分:tokenizer、Qwen2 tensor naming、KV cache、采样、layer streaming,以及针对模型 shape 的 fused kernel。xinfer 因此把 backend 分成两层:D3D12 负责 device、buffer、queue、fence、barrier 等资源管理;DirectML 和自写 HLSL kernel 都是 D3D12 之上的执行路径。
| 能力 | DirectML operator | 自写 HLSL kernel |
|---|---|---|
| 通用矩阵乘 / softmax | 有现成 operator | 可针对 shape 定制 |
| RMSNorm / RoPE / SwiGLU | 通常需要组合多个 operator | 一个 kernel 直接实现 |
| KV cache append | 不是典型 ML operator | CopyBufferRegion 或专用 kernel |
| GPU argmax | 可用 reduce 类 operator 或自写 | 简单、可控、少 readback |
| 异常 shape / driver quirk | 可能遇到黑盒问题 | 可绕开并调试 |
当前代码保留了通用 DmlOperator 封装和 DirectML GEMM smoke test,用于理解 operator 生命周期与验证基础路径。核心 transformer 路径(linear、RMSNorm、RoPE、attention、SwiGLU、argmax)主要使用 shaders\*.hlsl 中的 custom kernel,以便控制 layout、f16 解包、barrier 与 dispatch 形状。
8.2DirectML operator 生命周期
DirectML operator 不是普通 HLSL kernel 的轻量替代品。创建一个可复用 operator 通常需要完成一组固定步骤,xinfer-dml 的 DmlOperator 将这些步骤封装在 new 与 execute 中:
- Create:用 operator descriptor 描述输入/输出 tensor shape、dtype、参数。
- Compile:把 operator 编译成可执行的 compiled operator。
- Initialize:为需要预处理的 operator 运行 initializer,并处理 persistent / temporary resource。
- Bind:创建或重置 binding table,把 input/output/temp/persistent buffer 绑定进去。
- Dispatch:通过 DML command recorder 把 operator 记录到 D3D12 command list。
这套流程适合语义稳定、shape 常规、库实现成熟的 operator。若一个操作需要多个小算子拼接、频繁产生中间 buffer,或需要与项目的内存布局紧密配合,自写 HLSL 往往更容易形成可测试的单一 kernel。
8.3什么时候用 DML operator,什么时候写 custom kernel?
判断 backend 路径时,应同时考虑语义匹配、shape、dtype、内存布局、融合收益和调试风险。标准 operator 且库实现稳定时,DirectML 可以减少手写代码量;当操作需要特殊 layout、f16 packed weights、KV cache 操作或算子融合时,custom HLSL 更适合控制性能细节。
| 情况 | 倾向 | 原因 |
|---|---|---|
| 大 GEMM,库实现稳定 | DML / vendor lib | 库可能有高度优化的 tiling |
| RMSNorm / RoPE | HLSL | 组合 DML operator 会产生多次 dispatch / 中间 buffer |
| KV cache append | D3D12 copy 或 HLSL | 更像内存操作,不是典型 ML op |
| f16 packed weights / custom layout | HLSL | 需要手动解包和 coalesced 访问 |
| driver / device removed | HLSL | 黑盒 operator 不可控,自写路径可调试 |
xinfer 的 nn.rs 将 RMSNorm、RoPE、SwiGLU、attention、linear_f16 和 argmax 都编译为 HLSL compute kernel。linear_f16.hlsl 负责从 f16 权重读取并转换到 f32 累加;argmax.hlsl 将贪心解码的最后一步留在 GPU 上,只回读 token id。
8.4案例:DirectML GEMM 与 RDNA4 device-removal
本项目的 backend 选择受到一个具体故障影响:在 AMD RX 9070 XT(RDNA4)上,DirectML 的 GEMM operator 对部分 shape 会触发 device removed,HRESULT 为 0x887A0005。例如 的 GEMM 可复现该问题。
D3D12 debug layer 输出过 “CreateMetaCommand parameters are not supported” 警告,说明 DirectML 的 metacommand 路径不可用, 回退路径在该驱动/shape 上有问题 。自写 HLSL matmul 则正常工作。
该故障使 transformer 主路径不能依赖 DML GEMM。xinfer 保留 DirectML GEMM 作为 smoke test 和 operator lifecycle 示例,但实际 matmul 使用 custom HLSL kernel(shaders\linear.hlsl、shaders\linear_f16.hlsl)。后续优化围绕该路径展开:f16 weights、coalesced GEMV、groupshared reduction、2D dispatch grid,以及 GPU argmax。
这个案例说明两点:第一,ML operator 的可靠性取决于 driver、硬件、shape、dtype 与库内部选择的执行路径;第二,项目必须保留可替换的 backend 结构,才能在库路径失效时把风险隔离到具体 kernel。
8.5Device creation:PC 上 DXGI,Xbox 上 GDK
Windows PC 与 Xbox 的 D3D12 device creation 入口不同。PC 版本可以通过 DXGI 枚举 hardware adapter,随后调用 D3D12CreateDevice 并创建 command queue;DmlDevice::new(false) 走的就是这一路径。
Xbox console 上没有 DXGI。GDK host 负责进程入口、平台初始化与设备创建,通过 D3D12XboxCreateDevice 得到 ID3D12Device,再创建 queue。Rust core 不能在 console 上自行枚举 DXGI adapter,因此 xinfer-ffi 暴露 xinfer_create_with_device,由 C++ host 把 ID3D12Device* 与 ID3D12CommandQueue* 传入 Rust,Rust 侧用 DmlDevice::from_raw_pointers 在已有 device 上创建 DirectML device。
| 平台 | 设备来源 | Rust core 如何获得 device |
|---|---|---|
| Windows PC | DXGI adapter → D3D12CreateDevice | DmlDevice::new(false) |
| Xbox GDK | C++ host → D3D12XboxCreateDevice | xinfer_create_with_device 注入 raw COM pointers |
Lab 8DirectML GEMM 与自写 kernel 的对比实验
本实验分两步。第一步运行 DirectML GEMM,并与 CPU reference 对比,确认 operator descriptor、binding table、temporary / persistent resource 和 dispatch 过程。第二步选择已知可能触发 device-removal 的 shape,记录 D3D12 debug layer 与 GetDeviceRemovedReason 的输出,并改用自写 HLSL matmul 路径完成同一计算。
# 运行 DML / HLSL 相关测试
cargo test -p xinfer-dml --test gemm_smoke
cargo test -p xinfer-dml --test linear
# 观察 D3D12 debug layer / device removed reason 的输出(如有)
实验报告应回答:
- DML GEMM operator 需要哪些 descriptor 和 binding?
- CPU reference 如何计算同一个结果?误差阈值应该怎么设?
- 如果 device removed,如何用 D3D12 debug layer 和
GetDeviceRemovedReason缩小问题范围? - 为什么自写 HLSL kernel 可以绕开该问题?它牺牲了什么,又获得了什么?
小结
本章讨论 xinfer 的 backend 边界。DirectML 提供标准 operator 与明确的 create → compile → initialize → bind → dispatch 生命周期;D3D12 提供资源、队列、同步与跨平台 device 基础;custom HLSL kernel 则承担 Qwen2 主路径中需要 layout 控制、f16 权重、融合或规避 driver 问题的部分。RDNA4 上 DML GEMM device-removal 的案例说明,推理引擎不能只依赖单一黑盒路径。PC 与 Xbox 的差异主要发生在 device creation:PC 通过 DXGI 枚举 adapter,Xbox 由 GDK host 创建 device 并注入 Rust core。
思考与练习
基础列出 DirectML operator 的五个生命周期阶段。
五个阶段是 create、compile、initialize、bind、dispatch。create 由 descriptor 定义 operator;compile 生成 compiled operator;initialize 处理 persistent / temporary resource;bind 将输入、输出和 scratch buffer 填入 binding table;dispatch 通过 DirectML command recorder 写入 D3D12 command list。
基础解释 SRV/UAV/temp/persistent buffer 在 DML binding 中的角色。
SRV 表示只读输入,例如权重或 activation;UAV 表示可写输出;temporary buffer 是一次 execute 期间的工作区,执行后可复用;persistent buffer 由 initializer 填充,在后续多次 execute 中继续使用。绑定时必须按 compiled operator 报告的大小分配这些资源。
进阶为什么 RMSNorm 更适合写成 custom kernel,而不是组合多个 DML operator?
用 DML 组合 RMSNorm 通常需要 square、reduce mean、rsqrt、broadcast multiply、scale 等多个 operator。每一步都会引入额外 dispatch 和中间 buffer。custom HLSL kernel 可以在一个 threadgroup 内完成归约和缩放,减少显存往返,并为后续融合保留空间;这对 decode 阶段的单 token 延迟尤其重要。
进阶解释 PC 和 Xbox device creation 的关键差异。
PC 路径通过 DXGI 枚举 hardware adapter,然后调用 D3D12CreateDevice;Xbox console 没有 DXGI,C++ GDK host 通过 D3D12XboxCreateDevice 创建设备。xinfer 的 FFI 提供 xinfer_create_with_device,允许 host 把 ID3D12Device* 和 ID3D12CommandQueue* 注入 Rust core,Rust 再在该 device 上创建 DirectML device。
挑战为一个新算子制定决策:用 DML 还是 HLSL?写出你考虑的 shape、dtype、融合、debug 风险。
若算子语义标准、shape 常规、dtype 受支持,并且库实现已经验证稳定,可优先评估 DML。若算子需要特殊 layout、f16 packed weights、自定义 KV cache 操作,或能通过融合显著减少 dispatch 与中间 buffer,应评估 HLSL。还要考虑硬件和 driver 风险:本项目中 DML GEMM 在 RDNA4 的特定 shape 下触发 0x887A0005 device removed,因此主路径改为 custom matmul,并用 CPU reference 与 parity test 守护正确性。