Skip to content

[ExecuTorch][WebGPU] Register-tile the SDPA QK/AV kernels#20507

Merged
JulianCloudNTH merged 4 commits into
mainfrom
gh/JulianCloudNTH/49/orig
Jun 25, 2026
Merged

[ExecuTorch][WebGPU] Register-tile the SDPA QK/AV kernels#20507
JulianCloudNTH merged 4 commits into
mainfrom
gh/JulianCloudNTH/49/orig

Conversation

@pytorchbot

Copy link
Copy Markdown
Collaborator

This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #20405 by @JulianCloudNTH
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/49/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/49/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/49/orig

@diff-train-skip-merge

Pull Request resolved: #20405

**+32% SDPA attention-compute (AV +40%)** — register-tile the QK and AV kernels (isolated GPU-timestamp A/B, decode S=1, Chrome Canary / M4 Pro). A kernel-time win, not a wall-clock `forward()` win — `forward()` stays bound by the submit/sync/readback floor (the separate fusion axis).

**Problem**: The naive QK/AV kernels compute one output element per thread, so each thread re-loads Q/K/V and the dot products are scalar — poor register reuse, ALU/latency-bound.

**Solution**: Each thread computes a 4×4 output tile with the dot products vec4-packed in registers:
- **Before**: one thread per output element; scalar accumulate over D (QK) / context (AV).
- **After**: one thread per `(head, S-tile, {ctx,D}-tile)`; 4×4 register tile, vec4 dot products. A floating-point accumulation reorder of the same products — no algorithm change.

**Implementation**:
- `sdpa_compute_attn_weights.wgsl` (QK): one thread per `(head, S-tile, ctx-tile)`, grid `Hq · ceil(S/4) · ceil(ctx/4)`; tile registers are `array<vec4<f32>, TM/TN>` loaded via `for` loops.
- `sdpa_compute_out.wgsl` (AV): one thread per `(head, S-tile, D-tile)`, grid `Hq · ceil(S/4) · ceil(D/4)`.
- `Sdpa.cpp`: dispatch math moves from an element count to a tile count (`kSdpaTileM/N=4`, shared `utils::div_up`), keeping the uint32 scratch-overflow guard.
- Mirrors the Vulkan register-tiled SDPA kernels; the shared `utils::div_up` mirrors Vulkan's `utils::div_up`.

**Constraints**:
- softmax, `update_cache`, the bind-group layouts, and the scratch-buffer sizes (`Hq*S*ctx`) are unchanged.
- Scope is tiling only — causal tile-skip, V-cache coalescing, and branchless aligned/tail loads are separate follow-ups; this diff intentionally omits the Vulkan causal tile-skip since it is correctness-neutral (the per-element mask in `store_qk` is identical). See DESIGN_DECISIONS.md.
- Output matches the naive kernels within fp tolerance (accumulation reorder only).
ghstack-source-id: 396792505
@exported-using-ghexport

Differential Revision: [D109081409](https://our.internmc.facebook.com/intern/diff/D109081409/)
@pytorch-bot

pytorch-bot Bot commented Jun 25, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20507

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

⏳ No Failures, 142 Pending

As of commit 431acb0 with merge base e03f777 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 25, 2026
@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

…head-dim

Pull Request resolved: #20459

**~19% faster SDPA attention-output (AV) stage** — 393→317 µs on llama3 prefill (Chrome Canary / M4 Pro).

**Problem**: V-cache reads load 4 strided context rows × 1 head-dim lane, missing coalescing.

**Solution**: Flip access pattern to read 4 contiguous head-dim lanes per context row:
- **Before**: `load_v_vec4(d, kvh, c4)` → 4 strided rows, `dot()` along D
- **After**: `load_v_d4(c, kvh, d0)` → 4 contiguous D-lanes (16-byte texel), scalar broadcast

**Implementation**:
- Reindex `load_v` helper to read contiguous head-dim
- Replace `dot(A, V)` with `acc += A[c] * V_vec4(d0:d0+3)`
- Mirrors Vulkan `load_v_cache_d4` coalescing pattern

**Constraints**:
- No KV-cache layout change (still `[C, Hkv, D]`)
- Output numerically identical (FP-reassociated, max abs diff 1.43e-6 vs torch)
ghstack-source-id: 396792504
@exported-using-ghexport

Differential Revision: [D109339276](https://our.internmc.facebook.com/intern/diff/D109339276/)
…l tiles

Pull Request resolved: #20492

**Skip the QK contraction for fully-masked causal tiles** — at S=128 prefill ~48% of the (query, key) tiles are entirely above the diagonal and contribute nothing; this elides their dot products (prefill-only; bit-identical output).

**Problem**: For causal prefill, ~half the (query S-tile, key context-tile) pairs are entirely above the diagonal, yet the kernel still computes their full `d4` dot product before masking the result to `NEG_INF`.

**Solution**: Skip the contraction for fully-masked tiles; the existing per-element mask still writes the sentinel:
- **Before**: every `(s0, c0)` tile runs the full `d4` dot-product loop, then `store_qk` masks above-diagonal elements to `NEG_INF`.
- **After**: a fully-masked tile (`c0 > s0 + TM-1 + input_pos`) breaks the `d4` loop immediately (`acc` stays 0); `store_qk` masks every element to `NEG_INF` exactly as before.

**Implementation**:
- Add `skip_tile = c0 > s0 + (TM - 1) + params.input_pos`, folded into the `d4` loop break condition.
- Store loop unchanged — runs unconditionally, so no scratch entry is left stale.
- Mirrors Vulkan `sdpa_compute_attn_weights_tiled.glsl` (`tile_in_mask_region`).

**Constraints**:
- No KV-cache, host, dispatch, or uniform change (all tiles still launch; the skip is in-shader).
- Prefill-only: decode `S=1` never triggers it (`c0 <= input_pos < input_pos + TM - 1`).
- `NEG_INF` stays the WGSL-safe `-1.0e30` (WGSL forbids a literal `-inf`); does not copy Vulkan's `-1.0/0.0`.

Co-authored with Claude Code.
ghstack-source-id: 396792509
@exported-using-ghexport

Differential Revision: [D109517773](https://our.internmc.facebook.com/intern/diff/D109517773/)
… kernels

Pull Request resolved: #20493

**Branchless aligned/tail loads + vec4 storage bindings** — drop the always-true per-lane bounds checks in the tiled QK/AV hot loops, split the AV context contraction into a branch-free aligned body plus a checked tail, and declare the head-dim-indexed SDPA storage buffers as `array<vec4<f32>>` so the loads/stores are forced-vectorized (addresses review feedback to mirror Vulkan's vec4 bindings).

**Problem**: The tiled QK/AV vec4 loaders run 4 per-lane `if` bounds checks on every load, every contraction iteration (8 loads/iter). But `head_dim` is always a multiple of 4, so the D-axis checks never fire, and the AV context axis only needs a bounds check on the last ragged chunk. Separately the storage buffers were declared `array<f32>`, so the 4-lane loads/stores were not guaranteed to compile to aligned 128-bit vector accesses.

**Solution**: Remove the dead checks, split the ragged axis, and vectorize the bindings:
- **Before**: `load_q_vec4`/`load_k_vec4` (and AV `load_a_vec4`/`load_v_d4`) do 4 per-lane bounds `if`s per call; the AV `c4` loop runs checked loads for every chunk; `t_q`/`t_k_cache`/`t_v_cache`/`t_out` are `array<f32>` accessed element-by-element.
- **After**: QK loads are a plain unchecked `vec4` (D%4==0, host-guarded); AV runs a branch-free aligned body over `c4 in [0, context_len - context_len%4)` then a 0-or-1 checked tail; the head-dim-indexed buffers `t_q`/`t_k_cache`/`t_v_cache`/`t_out` are `array<vec4<f32>>` indexed `[base/4u]`, and AV writes a single aligned `store_out_vec4`.

**Implementation**:
- QK: `load_q_vec4`/`load_k_vec4` drop the per-lane D checks and return `t_q[base/4u]` / `t_k_cache[base/4u]`.
- AV: branch-free `load_a_vec4_nc`/`load_v_d4_nc` for the aligned body; checked `load_a_vec4`/`load_v_d4` for the tail; V reads `t_v_cache[base/4u]`; output is one aligned `store_out_vec4`.
- Bindings: `t_q`, `t_k_cache` (QK) and `t_v_cache`, `t_out` (AV) are `array<vec4<f32>>`. `t_attn_weights` and the softmax buffer stay `array<f32>` — they are `context_len`-indexed (row stride not 4-aligned) and written per-element under the causal mask, so a `vec4` binding there would need a padded scratch row.
- Host: add a `D % 4 == 0` guard in `Sdpa.cpp` — WGSL has no `SDPA_PAD_D` pad-load, so fail loud rather than read past the row; this guard also makes every `[base/4u]` index 4-aligned and every buffer a 16-byte multiple.
- Test: add a `reject_d6` (head_dim=6) config + an `expect_reject` harness branch asserting the guard rejects a non-aligned head_dim at load.
- Mirrors Vulkan `sdpa_compute_out_tiled.glsl` (aligned/tail split) and Vulkan's `array<vec4>` SDPA bindings.

**Constraints**:
- Requires `head_dim % 4 == 0` (true for every Llama config, D=64); enforced by a loud host throw, not a silent narrowing.
- Bit-identical output: the aligned body processes the same chunks in the same accumulation order as the scalar loop, the tail's out-of-range lanes contribute 0, and the `vec4` bindings read/write the same bytes as the scalar version.
- No KV-cache layout, dispatch, or uniform change.

Co-authored with Claude Code.
ghstack-source-id: 396792517
@exported-using-ghexport

Differential Revision: [D109521069](https://our.internmc.facebook.com/intern/diff/D109521069/)
@JulianCloudNTH JulianCloudNTH merged commit a0a730a into main Jun 25, 2026
182 of 184 checks passed
@JulianCloudNTH JulianCloudNTH deleted the gh/JulianCloudNTH/49/orig branch June 25, 2026 16:45
JulianCloudNTH added a commit to JulianCloudNTH/executorch that referenced this pull request Jun 26, 2026
The register-tile change (pytorch#20507) rewrote the `update_cache`, QK (`sdpa_compute_attn_weights`), `sdpa_softmax`, and AV (`sdpa_compute_out`) `build_dispatch` call sites and dropped the per-dispatch `kernel_name` labels originally added in pytorch#20167. With the labels gone, `WEBGPU_TIMESTAMP_QUERY` profiling can no longer attribute on-GPU time to the attention stage that produced it (every dispatch reports as the default "dispatch").

This re-threads `kernel_name` through `build_dispatch` (defaulted to `""`, so all other callers are unaffected) into the existing `WebGPUDispatch::kernel_name` field that `WebGPUQueryPool` already reads, and re-applies the four SDPA stage labels. No behavior change when profiling is off; the production `execute()` path is byte-identical.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants