[ExecuTorch][WebGPU] SDPA: skip QK contraction for fully-masked causal tiles #20492
SS-JIA left a comment:
Review automatically exported from Phabricator review in Meta.
1a9fe0a into gh/JulianCloudNTH/62/base
…l tiles

Pull Request resolved: #20492

**Skip the QK contraction for fully-masked causal tiles**: at S=128 prefill, ~48% of the (query, key) tiles are entirely above the diagonal and contribute nothing; this elides their dot products (prefill-only; bit-identical output).

**Problem**: For causal prefill, ~half the (query S-tile, key context-tile) pairs are entirely above the diagonal, yet the kernel still computes their full `d4` dot product before masking the result to `NEG_INF`.

**Solution**: Skip the contraction for fully-masked tiles; the existing per-element mask still writes the sentinel:
- **Before**: every `(s0, c0)` tile runs the full `d4` dot-product loop, then `store_qk` masks above-diagonal elements to `NEG_INF`.
- **After**: a fully-masked tile (`c0 > s0 + TM-1 + input_pos`) breaks the `d4` loop immediately (`acc` stays 0); `store_qk` masks every element to `NEG_INF` exactly as before.

**Implementation**:
- Add `skip_tile = c0 > s0 + (TM - 1) + params.input_pos`, folded into the `d4` loop break condition.
- Store loop unchanged: it runs unconditionally, so no scratch entry is left stale.
- Mirrors Vulkan `sdpa_compute_attn_weights_tiled.glsl` (`tile_in_mask_region`).

**Constraints**:
- No KV-cache, host, dispatch, or uniform change (all tiles still launch; the skip is in-shader).
- Prefill-only: decode `S=1` never triggers it (`c0 <= input_pos < input_pos + TM - 1`).
- `NEG_INF` stays the WGSL-safe `-1.0e30` (WGSL forbids a literal `-inf`); does not copy Vulkan's `-1.0/0.0`.

Co-authored with Claude Code.

ghstack-source-id: 396792509
@exported-using-ghexport

Differential Revision: [D109517773](https://our.internmc.facebook.com/intern/diff/D109517773/)
Stack from ghstack (oldest at bottom):
Skip the QK contraction for fully-masked causal tiles — at S=128 prefill ~48% of the (query, key) tiles are entirely above the diagonal and contribute nothing; this elides their dot products (prefill-only; bit-identical output).
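As a rough check on the ~48% figure, the sketch below counts the tiles that satisfy the skip predicate quoted in this PR. The tile sizes (`tm = tn = 4`), `input_pos = 0`, and the helper name `count_fully_masked_tiles` are assumptions made for illustration; the actual `TM` and context-tile size used by the shader are not stated here.

```python
def count_fully_masked_tiles(seq_len, input_pos, tm, tn):
    """Count (query-tile, key-tile) pairs that lie entirely above the causal diagonal.

    tm and tn are hypothetical tile sizes; the predicate is the one quoted in the PR.
    """
    context_len = input_pos + seq_len  # keys visible during this prefill step
    total = 0
    skipped = 0
    for s0 in range(0, seq_len, tm):          # first query row of the tile
        for c0 in range(0, context_len, tn):  # first key column of the tile
            total += 1
            # Fully masked: the tile's first column is already past the
            # last position that the tile's last query row may attend to.
            if c0 > s0 + (tm - 1) + input_pos:
                skipped += 1
    return skipped, total


skipped, total = count_fully_masked_tiles(seq_len=128, input_pos=0, tm=4, tn=4)
print(f"{skipped}/{total} = {skipped / total:.1%}")  # 496/1024 = 48.4%
```

Under these assumed 4x4 tiles the count is 496 of 1024 tiles. With `seq_len=1` (decode) the same function reports zero skipped tiles, which matches the prefill-only claim.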
**Problem**: For causal prefill, ~half the (query S-tile, key context-tile) pairs are entirely above the diagonal, yet the kernel still computes their full `d4` dot product before masking the result to `NEG_INF`.

**Solution**: Skip the contraction for fully-masked tiles; the existing per-element mask still writes the sentinel:
- **Before**: every `(s0, c0)` tile runs the full `d4` dot-product loop, then `store_qk` masks above-diagonal elements to `NEG_INF`.
- **After**: a fully-masked tile (`c0 > s0 + TM-1 + input_pos`) breaks the `d4` loop immediately (`acc` stays 0); `store_qk` masks every element to `NEG_INF` exactly as before.

**Implementation**:
- Add `skip_tile = c0 > s0 + (TM - 1) + params.input_pos`, folded into the `d4` loop break condition.
- Store loop unchanged: it runs unconditionally, so no scratch entry is left stale.
- Mirrors Vulkan `sdpa_compute_attn_weights_tiled.glsl` (`tile_in_mask_region`).

**Constraints**:
- No KV-cache, host, dispatch, or uniform change (all tiles still launch; the skip is in-shader).
- Prefill-only: decode `S=1` never triggers it (`c0 <= input_pos < input_pos + TM - 1`).
- `NEG_INF` stays the WGSL-safe `-1.0e30` (WGSL forbids a literal `-inf`); does not copy Vulkan's `-1.0/0.0`.

Co-authored with Claude Code.
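The bit-identical claim can be illustrated with a small CPU model of the tiled QK pass. This is a sketch, not the WGSL shader: the function `qk_scores`, the 4x4 tile sizes, and the whole-row dot product (standing in for the `d4` loop) are assumptions, and scaling is omitted. It only demonstrates why skipping the contraction is safe when the store step still masks every above-diagonal element.

```python
import random

NEG_INF = -1.0e30  # finite sentinel, as in the PR (no literal -inf)


def qk_scores(q, k, input_pos, tm, tn, skip_masked_tiles):
    """Tiled causal QK scores; returns (scores, number of dot products computed)."""
    num_q, num_k, head_dim = len(q), len(k), len(q[0])
    out = [[0.0] * num_k for _ in range(num_q)]
    dots = 0
    for s0 in range(0, num_q, tm):
        for c0 in range(0, num_k, tn):
            rows = range(s0, min(s0 + tm, num_q))
            cols = range(c0, min(c0 + tn, num_k))
            # Predicate from the PR: the whole tile is above the diagonal.
            skip_tile = skip_masked_tiles and c0 > s0 + (tm - 1) + input_pos
            acc = {(s, c): 0.0 for s in rows for c in cols}
            if not skip_tile:  # a skipped tile leaves acc at 0
                for s in rows:
                    for c in cols:
                        acc[(s, c)] = sum(q[s][d] * k[c][d] for d in range(head_dim))
                        dots += 1
            # The store step runs unconditionally and applies the per-element mask.
            for s in rows:
                for c in cols:
                    out[s][c] = NEG_INF if c > s + input_pos else acc[(s, c)]
    return out, dots


random.seed(0)
seq_len, head_dim, input_pos = 16, 8, 0
q = [[random.uniform(-1, 1) for _ in range(head_dim)] for _ in range(seq_len)]
k = [[random.uniform(-1, 1) for _ in range(head_dim)] for _ in range(input_pos + seq_len)]

base, base_dots = qk_scores(q, k, input_pos, 4, 4, skip_masked_tiles=False)
fast, fast_dots = qk_scores(q, k, input_pos, 4, 4, skip_masked_tiles=True)
assert base == fast  # identical scores, including every masked entry
print(base_dots, fast_dots)
```

Every element of a skipped tile satisfies `c > s + input_pos`, so the store step overwrites the untouched `acc` value with `NEG_INF` in both variants; only the number of dot products changes.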
@exported-using-ghexport
Differential Revision: D109517773