[ExecuTorch][WebGPU] Add et_vk.embedding_q4gsw (4-bit groupwise-symmetric quantized embedding)#20263
f3d16c3 into gh/JulianCloudNTH/25/base
#20292)

Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.15.0) (oldest at bottom):
* __->__ #20292
* #20265
* #20290
* #20264
* #20289
* #20263

Test suite for the `et_vk.prepack` constant-materialization op, split into its own diff (op below, tests above) per the per-op test-split convention.

The prepack op is how a serialized constant becomes a GPU tensor: the constant arrives as a CPU-side reference (sizes + a pointer into the .pte bytes), and the prepack node is the sole materialization: one CPU->GPU transfer straight into the consumer's buffer. The model `M(x) = x + w` (w a constant) routes `w` through a prepack node, so the delegate must run the materialization for the output to equal `x + w` rather than `x + 0`.

@exported-using-ghexport

Differential Revision: [D108678631](https://our.internmc.facebook.com/intern/diff/D108678631/)
…tric quantized embedding) (#20414)

This PR was created by the merge bot to help merge the original PR into the main branch.

ghstack PR number: #20263 by @JulianCloudNTH
^ Please use this as the source of truth for the PR details, comments, and reviews

ghstack PR base: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/25/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/25/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/25/orig

@diff-train-skip-merge

---------

Co-authored-by: Julian Ng-Thow-Hing <juliannth@meta.com>
Stack from ghstack (oldest at bottom):
Adds the WebGPU backend handler for `et_vk.embedding_q4gsw.default` (a 4-bit groupwise-symmetric quantized embedding gather) plus the host-side integer-input infra it requires.

The op is a single compute dispatch composed of one stage: one thread per 32-element block of each gathered row dequantizes the packed 4-bit table (`q = (nibble - 8) * scale`; even dim = high nibble, odd dim = low) into the fp32 output, mirroring the Vulkan `embedding_q4gsw` reference (flat buffer-backed weight; `is_linear_weight=true` is unsupported and throws). The workgroup size is a `wg_size` pipeline-override constant clamped to the device limit via `WebGPUUtils::clamp_workgroup_size`, the 1D dispatch count goes through `WebGPUUtils::compute_1d_workgroup_count` (validated before any GPU-object allocation), and the embedded WGSL string header is generated by `gen_wgsl_headers.py`.

Embedding indices arrive as int64 at the program boundary but the serialized graph stores them as int32, so the shared input path is extended with a host-side `InputData` view (`{data, nbytes, host_is_int64}`) and `copy_inputs` gains three branches: a byte-for-byte fast path when host and GPU sizes match, an int64->int32 narrowing copy when the buffer is int32 and the host input is twice as wide (mirrors the Vulkan `kLong` -> `kInt` staging cast), and a fail-loud throw otherwise. `WebGPUTensor` gains `elem_size`/`is_int` to drive the narrowing decision, and `update_symints_from_inputs` takes the same `InputData` vector so `execute()` builds a single input list consumed by both.

@exported-using-ghexport
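For readers who want to check the dequantization rule by hand, here is a minimal CPU-side sketch in Python of `q = (nibble - 8) * scale` with the even-dim-high / odd-dim-low nibble order described above. It is not the WGSL shader or the Vulkan reference. The function names, the per-row `bytes` table layout, the per-row list of group scales, and the `group_size` parameter are all assumptions made for illustration.

```python
from typing import List, Sequence


def dequant_row(packed: bytes, scales: Sequence[float], group_size: int) -> List[float]:
    """Dequantize one packed 4-bit row (hypothetical layout: 2 dims per byte)."""
    out: List[float] = []
    for byte_idx, byte in enumerate(packed):
        # Even dim sits in the high nibble, odd dim in the low nibble.
        nibbles = ((byte >> 4) & 0xF, byte & 0xF)
        for half, nibble in enumerate(nibbles):
            dim = 2 * byte_idx + half
            # Assumed: one scale per contiguous group of `group_size` dims.
            scale = scales[dim // group_size]
            # Symmetric 4-bit code: stored value 0..15 maps to -8..7.
            out.append((nibble - 8) * scale)
    return out


def embedding_q4gsw_ref(
    indices: Sequence[int],
    table: Sequence[bytes],
    scales: Sequence[Sequence[float]],
    group_size: int,
) -> List[List[float]]:
    """Gather rows by index and dequantize each one."""
    return [dequant_row(table[i], scales[i], group_size) for i in indices]
```

With `packed = bytes([0x8F, 0x07])` and a single scale of `0.5`, the stored nibbles are 8, 15, 0, 7, so the row dequantizes to `[0.0, 3.5, -4.0, -0.5]`.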
Differential Revision: D108428753
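The three `copy_inputs` branches from the description can be sketched on the host side as follows. This is a Python illustration of the decision logic only, not the C++ code in the PR. The function name `copy_input_sketch`, its parameters, and the little-endian byte order are assumptions. The description does not say how values outside the int32 range are handled, so this sketch simply lets `struct.pack` raise on them.

```python
import struct


def copy_input_sketch(
    host: bytes, gpu_nbytes: int, gpu_is_int32: bool, host_is_int64: bool
) -> bytes:
    """Return the bytes that would be written to the GPU buffer (hypothetical helper)."""
    # Branch 1: sizes match, so copy byte-for-byte.
    if len(host) == gpu_nbytes:
        return bytes(host)
    # Branch 2: int32 buffer fed by an int64 host input that is twice as wide.
    if gpu_is_int32 and host_is_int64 and len(host) == 2 * gpu_nbytes:
        count = len(host) // 8
        values = struct.unpack(f"<{count}q", host)
        # Assumption: out-of-range values are an error here (struct.error).
        return struct.pack(f"<{count}i", *values)
    # Branch 3: anything else fails loudly rather than copying garbage.
    raise ValueError(
        f"input size mismatch: host={len(host)} bytes, gpu={gpu_nbytes} bytes"
    )
```

For example, three int64 indices (24 bytes) passed against a 12-byte int32 buffer take the narrowing branch, while a 24-byte input against a 16-byte buffer raises.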