[ExecuTorch][WebGPU] Register-tile the SDPA QK/AV kernels#20507
Conversation
Pull Request resolved: #20405 **+32% SDPA attention-compute (AV +40%)** — register-tile the QK and AV kernels (isolated GPU-timestamp A/B, decode S=1, Chrome Canary / M4 Pro). A kernel-time win, not a wall-clock `forward()` win — `forward()` stays bound by the submit/sync/readback floor (the separate fusion axis). **Problem**: The naive QK/AV kernels compute one output element per thread, so each thread re-loads Q/K/V and the dot products are scalar — poor register reuse, ALU/latency-bound. **Solution**: Each thread computes a 4×4 output tile with the dot products vec4-packed in registers: - **Before**: one thread per output element; scalar accumulate over D (QK) / context (AV). - **After**: one thread per `(head, S-tile, {ctx,D}-tile)`; 4×4 register tile, vec4 dot products. A floating-point accumulation reorder of the same products — no algorithm change. **Implementation**: - `sdpa_compute_attn_weights.wgsl` (QK): one thread per `(head, S-tile, ctx-tile)`, grid `Hq · ceil(S/4) · ceil(ctx/4)`; tile registers are `array<vec4<f32>, TM/TN>` loaded via `for` loops. - `sdpa_compute_out.wgsl` (AV): one thread per `(head, S-tile, D-tile)`, grid `Hq · ceil(S/4) · ceil(D/4)`. - `Sdpa.cpp`: dispatch math moves from an element count to a tile count (`kSdpaTileM/N=4`, shared `utils::div_up`), keeping the uint32 scratch-overflow guard. - Mirrors the Vulkan register-tiled SDPA kernels; the shared `utils::div_up` mirrors Vulkan's `utils::div_up`. **Constraints**: - softmax, `update_cache`, the bind-group layouts, and the scratch-buffer sizes (`Hq*S*ctx`) are unchanged. - Scope is tiling only — causal tile-skip, V-cache coalescing, and branchless aligned/tail loads are separate follow-ups; this diff intentionally omits the Vulkan causal tile-skip since it is correctness-neutral (the per-element mask in `store_qk` is identical). See DESIGN_DECISIONS.md. - Output matches the naive kernels within fp tolerance (accumulation reorder only). ghstack-source-id: 396792505 @exported-using-ghexport Differential Revision: [D109081409](https://our.internmc.facebook.com/intern/diff/D109081409/)
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20507
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ⏳ No Failures, 142 PendingAs of commit 431acb0 with merge base e03f777 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
…head-dim Pull Request resolved: #20459 **~19% faster SDPA attention-output (AV) stage** — 393→317 µs on llama3 prefill (Chrome Canary / M4 Pro). **Problem**: V-cache reads load 4 strided context rows × 1 head-dim lane, missing coalescing. **Solution**: Flip access pattern to read 4 contiguous head-dim lanes per context row: - **Before**: `load_v_vec4(d, kvh, c4)` → 4 strided rows, `dot()` along D - **After**: `load_v_d4(c, kvh, d0)` → 4 contiguous D-lanes (16-byte texel), scalar broadcast **Implementation**: - Reindex `load_v` helper to read contiguous head-dim - Replace `dot(A, V)` with `acc += A[c] * V_vec4(d0:d0+3)` - Mirrors Vulkan `load_v_cache_d4` coalescing pattern **Constraints**: - No KV-cache layout change (still `[C, Hkv, D]`) - Output numerically identical (FP-reassociated, max abs diff 1.43e-6 vs torch) ghstack-source-id: 396792504 @exported-using-ghexport Differential Revision: [D109339276](https://our.internmc.facebook.com/intern/diff/D109339276/)
…l tiles Pull Request resolved: #20492 **Skip the QK contraction for fully-masked causal tiles** — at S=128 prefill ~48% of the (query, key) tiles are entirely above the diagonal and contribute nothing; this elides their dot products (prefill-only; bit-identical output). **Problem**: For causal prefill, ~half the (query S-tile, key context-tile) pairs are entirely above the diagonal, yet the kernel still computes their full `d4` dot product before masking the result to `NEG_INF`. **Solution**: Skip the contraction for fully-masked tiles; the existing per-element mask still writes the sentinel: - **Before**: every `(s0, c0)` tile runs the full `d4` dot-product loop, then `store_qk` masks above-diagonal elements to `NEG_INF`. - **After**: a fully-masked tile (`c0 > s0 + TM-1 + input_pos`) breaks the `d4` loop immediately (`acc` stays 0); `store_qk` masks every element to `NEG_INF` exactly as before. **Implementation**: - Add `skip_tile = c0 > s0 + (TM - 1) + params.input_pos`, folded into the `d4` loop break condition. - Store loop unchanged — runs unconditionally, so no scratch entry is left stale. - Mirrors Vulkan `sdpa_compute_attn_weights_tiled.glsl` (`tile_in_mask_region`). **Constraints**: - No KV-cache, host, dispatch, or uniform change (all tiles still launch; the skip is in-shader). - Prefill-only: decode `S=1` never triggers it (`c0 <= input_pos < input_pos + TM - 1`). - `NEG_INF` stays the WGSL-safe `-1.0e30` (WGSL forbids a literal `-inf`); does not copy Vulkan's `-1.0/0.0`. Co-authored with Claude Code. ghstack-source-id: 396792509 @exported-using-ghexport Differential Revision: [D109517773](https://our.internmc.facebook.com/intern/diff/D109517773/)
… kernels Pull Request resolved: #20493 **Branchless aligned/tail loads + vec4 storage bindings** — drop the always-true per-lane bounds checks in the tiled QK/AV hot loops, split the AV context contraction into a branch-free aligned body plus a checked tail, and declare the head-dim-indexed SDPA storage buffers as `array<vec4<f32>>` so the loads/stores are forced-vectorized (addresses review feedback to mirror Vulkan's vec4 bindings). **Problem**: The tiled QK/AV vec4 loaders run 4 per-lane `if` bounds checks on every load, every contraction iteration (8 loads/iter). But `head_dim` is always a multiple of 4, so the D-axis checks never fire, and the AV context axis only needs a bounds check on the last ragged chunk. Separately the storage buffers were declared `array<f32>`, so the 4-lane loads/stores were not guaranteed to compile to aligned 128-bit vector accesses. **Solution**: Remove the dead checks, split the ragged axis, and vectorize the bindings: - **Before**: `load_q_vec4`/`load_k_vec4` (and AV `load_a_vec4`/`load_v_d4`) do 4 per-lane bounds `if`s per call; the AV `c4` loop runs checked loads for every chunk; `t_q`/`t_k_cache`/`t_v_cache`/`t_out` are `array<f32>` accessed element-by-element. - **After**: QK loads are a plain unchecked `vec4` (D%4==0, host-guarded); AV runs a branch-free aligned body over `c4 in [0, context_len - context_len%4)` then a 0-or-1 checked tail; the head-dim-indexed buffers `t_q`/`t_k_cache`/`t_v_cache`/`t_out` are `array<vec4<f32>>` indexed `[base/4u]`, and AV writes a single aligned `store_out_vec4`. **Implementation**: - QK: `load_q_vec4`/`load_k_vec4` drop the per-lane D checks and return `t_q[base/4u]` / `t_k_cache[base/4u]`. - AV: branch-free `load_a_vec4_nc`/`load_v_d4_nc` for the aligned body; checked `load_a_vec4`/`load_v_d4` for the tail; V reads `t_v_cache[base/4u]`; output is one aligned `store_out_vec4`. - Bindings: `t_q`, `t_k_cache` (QK) and `t_v_cache`, `t_out` (AV) are `array<vec4<f32>>`. `t_attn_weights` and the softmax buffer stay `array<f32>` — they are `context_len`-indexed (row stride not 4-aligned) and written per-element under the causal mask, so a `vec4` binding there would need a padded scratch row. - Host: add a `D % 4 == 0` guard in `Sdpa.cpp` — WGSL has no `SDPA_PAD_D` pad-load, so fail loud rather than read past the row; this guard also makes every `[base/4u]` index 4-aligned and every buffer a 16-byte multiple. - Test: add a `reject_d6` (head_dim=6) config + an `expect_reject` harness branch asserting the guard rejects a non-aligned head_dim at load. - Mirrors Vulkan `sdpa_compute_out_tiled.glsl` (aligned/tail split) and Vulkan's `array<vec4>` SDPA bindings. **Constraints**: - Requires `head_dim % 4 == 0` (true for every Llama config, D=64); enforced by a loud host throw, not a silent narrowing. - Bit-identical output: the aligned body processes the same chunks in the same accumulation order as the scalar loop, the tail's out-of-range lanes contribute 0, and the `vec4` bindings read/write the same bytes as the scalar version. - No KV-cache layout, dispatch, or uniform change. Co-authored with Claude Code. ghstack-source-id: 396792517 @exported-using-ghexport Differential Revision: [D109521069](https://our.internmc.facebook.com/intern/diff/D109521069/)
The register-tile change (pytorch#20507) rewrote the `update_cache`, QK (`sdpa_compute_attn_weights`), `sdpa_softmax`, and AV (`sdpa_compute_out`) `build_dispatch` call sites and dropped the per-dispatch `kernel_name` labels originally added in pytorch#20167. With the labels gone, `WEBGPU_TIMESTAMP_QUERY` profiling can no longer attribute on-GPU time to the attention stage that produced it (every dispatch reports as the default "dispatch"). This re-threads `kernel_name` through `build_dispatch` (defaulted to `""`, so all other callers are unaffected) into the existing `WebGPUDispatch::kernel_name` field that `WebGPUQueryPool` already reads, and re-applies the four SDPA stage labels. No behavior change when profiling is off; the production `execute()` path is byte-identical.
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #20405 by @JulianCloudNTH
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/49/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/49/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/49/orig
@diff-train-skip-merge