[ET-VK] Add cooperative matrix dispatch for quantized linear #19892
xuyanwen2012 wants to merge 10 commits into
Adds coopmat shaders and dispatch for 4-bit (q4gsw, dq8ca_q4gsw) and 8-bit (dq8ca_q8csw) quantized linear, gated on Adapter::supports_cooperative_matrix(), wave64 subgroup size, buffer output storage, and coopmat tile alignment, mirroring the fp16 coopmat path from pytorch#19009. Ineligible shapes fall back to the existing tiled shaders.

Review order: QuantizedLinear.cpp for the dispatch gate (can_use_q4gsw_coopmat), then the linear_*_coopmat.glsl shaders, then op_registry.py / custom_ops_lib.py / patterns for registration, then the tests.
… crash (+ WIP)

The Samsung Xclipse / AMD-PAL Vulkan shader compiler crashes (null deref in vkCreateComputePipelines) when a loop containing coopMatMulAdd has a trip count derived from the runtime sizes UBO. Pass the loop bound (num_groups for q4gsw / dq8ca_q4gsw, K-chunk count for dq8ca_q8csw) as a specialization constant instead, so the bound is a compile-time constant.

Verified on ERD9975 (Exynos S5E9975): int4 coopmat ~1.19x and int8 coopmat ~1.67x faster prefill than tiled, no crash.

Also saves in-progress quantized-linear coopmat aggregation work and adds the test_coopmat_probe capability probe.
…coopmat
The four coopmat linear shaders produced wrong results on Samsung Xclipse
970, exposed by the first strict CPU-reference validation. Three
independent bugs: (1) the int4 coopmat shaders read packed weight blocks
transposed relative to the pack_q4_linear_weight layout (ivec2(n8,k4)
instead of ivec2(k4,n8)); (2) the int8 coopmat shaders passed coopMatLoad
offsets/strides in int8-element units where the KHR spec defines them in
units of the backing array's element type (uint = 4 int8), so all LDS
matrix loads were 4x off; (3) the Xclipse/AMD-PAL compiler miscompiles
coopMatStore whose offset/stride derive from a UBO value (only the first
store per subgroup lands correctly; proven with a standalone raw-Vulkan
repro), worked around by passing the output width N as a 4th
specialization constant (out_N_arg) - the same pattern as the earlier
UBO-derived-loop-bound pipeline-crash fix.
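The unit mismatch in bug (2) can be illustrated with plain arithmetic. The sketch below is not the shader code: `to_uint_units` and the 32-byte row width are hypothetical, chosen only to show how an offset and stride counted in int8 elements differ by 4x from the same position counted in elements of a `uint` backing array.

```python
# Illustrative only: offsets/strides for a matrix of int8 values stored in an
# array declared as uint (four int8 values per array element).
INT8_PER_UINT = 4  # one 32-bit uint holds four int8 values


def to_uint_units(int8_count: int) -> int:
    """Convert a count of int8 elements into a count of uint array elements."""
    if int8_count % INT8_PER_UINT != 0:
        raise ValueError("int8 count must be a multiple of 4 to map onto uints")
    return int8_count // INT8_PER_UINT


# Hypothetical tile: rows of 32 int8 values; sub-tile starts at row 2, column 16.
row_stride_int8 = 32
offset_int8 = 2 * row_stride_int8 + 16  # position counted in int8 elements

# The same position expressed in units of the uint backing array.
stride_uint = to_uint_units(row_stride_int8)
offset_uint = to_uint_units(offset_int8)

print(f"int8 units: offset={offset_int8}, stride={row_stride_int8}")
print(f"uint units: offset={offset_uint}, stride={stride_uint}")
```

Passing the int8-unit values where uint-unit values are expected makes every load land four times further into the array than intended, which matches the "4x off" symptom described above.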
With the shaders now numerically validated, linear_q8csw_coopmat (8w) is
wired into pick_linear_qw_shader. The consolidated bench gains an exact
fp32 reference for all four ops (the dq8ca correctness data uses
scale=1/16, zp=0, activations that are multiples of 1/16 so the fp16
dynamic-quant round-trip is exact), a {64,128,64} plus multi-workgroup
{128,256,128} correctness matrix, and a COOPMAT_BENCH_CORRECTNESS_ONLY
toggle; the test harness now reports a per-16x16-tile mismatch map and
continues past failing cases instead of aborting.
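Why this test data gives an exact reference can be checked on the CPU. The following is a minimal sketch, assuming a fixed scale of 1/16 and zero point 0 rather than reproducing the bench's dynamic quantization; `fp16_round_trip`, `quantize`, and `dequantize` are hypothetical helper names.

```python
import struct

SCALE = 1.0 / 16.0  # assumed fixed here; the bench derives it from the data
ZERO_POINT = 0


def fp16_round_trip(x: float) -> float:
    """Convert to IEEE-754 half and back using struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def quantize(x: float) -> int:
    """Affine int8 quantization with the fixed scale and zero point above."""
    q = round(x / SCALE) + ZERO_POINT
    return max(-128, min(127, q))


def dequantize(q: int) -> float:
    return (q - ZERO_POINT) * SCALE


# Every multiple of 1/16 that maps onto the int8 range [-128, 127].
activations = [k / 16.0 for k in range(-128, 128)]

all_exact = all(
    fp16_round_trip(a) == a and dequantize(quantize(a)) == a
    for a in activations
)
print("all activations survive fp16 and int8 round-trips:", all_exact)
```

With inputs like these, a mismatch between the GPU output and the fp32 reference points at the shader rather than at rounding in the test data.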
On ERD9975 all 16 correctness cases pass; prefill speedups vs tiled at
M=1024 Llama-8B shapes: 4w 1.76x, 8da4w 2.5x, 8w 2.9x, 8da8w 4.0-4.5x.
Authored with Claude Code.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…ilog hoist

Port the LDS layout from linear_dq8ca_q8csw_coopmat: split A/B into MMA_K slabs and stage B ColumnMajor with a +1-uint skew per column, so each matB lane reads its 4 K-contiguous bytes with one ds_load_b32 instead of the per-byte load + repack chains a RowMajor int8 B forces. The int4 block layout cooperates: one (component, parity) pair of a packed block is exactly one N-column's 4 K-bytes, so the nibble-unpack instruction count is unchanged and the gain is purely from the layout.

Also hoist the izp/ifs row broadcasts out of the per-group epilog (they are group invariant) and reorder the epilog loops so wsum/wsc are loaded once per N-subtile, cutting broadcast coopMatLoads per group from 16 to 4.

On ERD9975 at M=1024 Llama-8B shapes the 8da4w coopmat path goes from 2658-2730 to 3261-3370 GFLOP/s (2.4-2.6x -> 3.0-3.2x vs tiled); the other three coopmat ops are unchanged and the 16-case correctness matrix passes.

Authored with Claude Code.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
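The column-major staging described in this commit can be modelled on the CPU. The sketch below is an illustration under assumptions, not the shader: `MMA_K = 16` and all helper names are hypothetical, and it only shows the addressing, where each column occupies `K/4` uints plus one padding uint and a single uint read yields four K-contiguous int8 values.

```python
# Assumed sizes for illustration; the shader's real MMA_K may differ.
MMA_K = 16                 # K extent of one slab, in int8 elements
K_UINTS = MMA_K // 4       # uints per column (four int8 values per uint)
COL_STRIDE = K_UINTS + 1   # +1-uint skew per column


def pack4(values):
    """Pack four int8 values into one little-endian 32-bit word."""
    return int.from_bytes(bytes(v & 0xFF for v in values), "little")


def unpack4(word):
    """Unpack a 32-bit word into four signed int8 values."""
    return [b - 256 if b >= 128 else b for b in word.to_bytes(4, "little")]


def stage_b_column_major(b, n_cols):
    """b[k][n] is an MMA_K x n_cols int8 matrix; returns a flat uint list."""
    lds = [0] * (n_cols * COL_STRIDE)
    for n in range(n_cols):
        for k4 in range(K_UINTS):
            column_bytes = [b[4 * k4 + i][n] for i in range(4)]
            lds[n * COL_STRIDE + k4] = pack4(column_bytes)
    return lds


def load_k4(lds, n, k4):
    """One word read returns the four K-contiguous values of column n."""
    return unpack4(lds[n * COL_STRIDE + k4])


# Small signed test matrix: 16 rows (K) by 8 columns (N).
b_matrix = [[((k * 8 + n) % 17) - 8 for n in range(8)] for k in range(MMA_K)]
lds_words = stage_b_column_major(b_matrix, 8)
print(load_k4(lds_words, 5, 2))
```

In a row-major int8 layout the same four K-values would sit in four different words, which is the per-byte load and repack cost the commit message refers to.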
Same two driver-bug workarounds already applied to the quantized coopmat shaders: the Xclipse/AMD-PAL compiler crashes in vkCreateComputePipelines when a loop containing coopMatMulAdd has a UBO-derived trip count, and miscompiles coopMatStore whose offset/stride derive from a UBO value. With these, the fp16 coopmat shader runs correctly on Xclipse 970 for the first time (validated via test_fp16_gemm_bench on ERD9975).

Authored with Claude.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…eference

Ports NVIDIA's shmem_double_buf4.comp (double-buffered shared memory, store-first single barrier per iteration, 128x128 tile, K-step 16, 8 subgroups x 32 threads forced via REQUIRED_SUBGROUP_SIZE) into the custom_ops prototyping framework as etvk.gemm_double_buf, and adds test_fp16_gemm_bench comparing it against matmul_coopmat and the tiled texture baseline at Llama 3.1 8B prefill shapes.

On ERD9975 (Xclipse 970) the double-buffered reference reaches 5.6-6.2 TFLOP/s vs ~3.95 for matmul_coopmat (1.44-1.56x) and ~0.9 for tiled; all small-shape correctness cases pass against a CPU fp32 reference.

Authored with Claude.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…rence

All four quantized-linear coopmat shaders adopt the loop structure of the NVIDIA shmem_double_buf4 reference (see gemm_double_buf.glsl): prologue register prefetch, one barrier per K-chunk, ping-pong shared-memory slices, quant unpack at the store stage so the prefetch is pure loads in flight during the MMA.

The fp16-MMA shaders (q4gsw, q8csw) move to a 128x128 tile with K-step 16 at 8 subgroups x 32 threads (subgroup size 32 forced via REQUIRED_SUBGROUP_SIZE); the int8-MMA shaders (dq8ca_q4gsw, dq8ca_q8csw) move to 128x64 with K-step 32 on wave64, since the Xclipse PAL compiler crashes compiling int8 WMMA at forced subgroup size 32. dq8ca_q4gsw keeps a nested groups-by-chunks loop with the group epilog at the group tail: the flattened form with a conditional coopmat epilog also crashes that compiler at large spec-resolved trip counts. Per-shader tile geometry moves into a coopmat_tile_dims() table consulted by the shared gate and dispatch sizing.

Measured on ERD9975 (Xclipse 970) at M=1024 Llama 3.1 8B shapes, vs the tiled texture baseline: q4gsw 1.76x -> 2.2x, dq8ca_q4gsw 3.1x -> 4.0-4.4x, q8csw 2.9x -> 3.3x, dq8ca_q8csw 4.1-4.9x -> 6.9-7.9x. All correctness cases pass; the bench gains shape coverage that also exposes a pre-existing upstream int_input_sums allocation bug (fixed separately on yanwen/fix-dq8ca-input-sums-alloc), plus a fix for the test harness printing GPU data instead of the reference when dumping fp16 failures.

Authored with Claude.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
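The loop structure named in this commit (prologue prefetch, store stage, one barrier per K-chunk, ping-pong slices) can be sketched as a sequential CPU model. This is an assumption-laden illustration, not a port of the shader: `load_chunk` and `matmul_double_buffered` are hypothetical names, and the barrier is only marked by a comment because the model is single-threaded.

```python
def matmul_reference(a, b):
    """Plain triple-loop matmul used as the correctness baseline."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]


def load_chunk(a, b, chunk, k_step):
    """Model of the global-memory loads for one K-chunk of A and B."""
    lo, hi = chunk * k_step, (chunk + 1) * k_step
    return [row[lo:hi] for row in a], b[lo:hi]


def matmul_double_buffered(a, b, k_step):
    m, k, n = len(a), len(b), len(b[0])
    if k % k_step != 0:
        raise ValueError("K must be a multiple of the K-step")
    num_chunks = k // k_step
    acc = [[0] * n for _ in range(m)]
    shared = [None, None]                 # two ping-pong "shared memory" slices
    regs = load_chunk(a, b, 0, k_step)    # prologue: prefetch chunk 0
    for chunk in range(num_chunks):
        slot = chunk % 2
        shared[slot] = regs               # store stage (quant unpack would go here)
        # The single barrier per iteration would sit here in a real shader.
        if chunk + 1 < num_chunks:
            # Prefetch the next chunk; on a GPU these loads overlap the MMA below.
            regs = load_chunk(a, b, chunk + 1, k_step)
        a_slice, b_slice = shared[slot]
        for i in range(m):
            for j in range(n):
                acc[i][j] += sum(a_slice[i][p] * b_slice[p][j]
                                 for p in range(k_step))
    return acc


a_mat = [[(i * 3 + p) % 5 - 2 for p in range(8)] for i in range(4)]
b_mat = [[(p * 2 + j) % 7 - 3 for j in range(4)] for p in range(8)]
result = matmul_double_buffered(a_mat, b_mat, k_step=2)
print(result == matmul_reference(a_mat, b_mat))
```

Because iteration `c + 1` writes the other slice, the store for the next chunk never touches the slice the current MMA is reading, which is presumably why one barrier per iteration is enough in the reference design.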
The four quantized-linear coopmat shaders share one double-buffered
skeleton since the previous change, so fold them into one template per MMA
family, mirroring how coopmat_mm.glsl hosts its three variants:
  linear_qw_coopmat.glsl       -> linear_q4gsw_coopmat_*       (WEIGHT_NBITS=4)
                                  linear_q8csw_coopmat_*       (WEIGHT_NBITS=8)
  linear_dq8ca_qw_coopmat.glsl -> linear_dq8ca_q4gsw_coopmat_* (4)
                                  linear_dq8ca_q8csw_coopmat_* (8)
Generated variant names are unchanged, so dispatch code needs no changes.
The WEIGHT_NBITS conditionals cover only the weight unpack and the quant
bookkeeping: per-channel INT8 is expressed as the single-group special case
of the grouped INT4 path (num_groups = 1), which collapses the nested
groups-by-chunks loop, the wsum/wsc group-parity ping-pong, and the group
epilog to exactly the previous per-channel control flow.
Regression on ERD9975: all coopmat correctness cases pass and the M=1024
Llama-shape numbers match the pre-refactor run (4w 2.2x, 8da4w 4.0-4.4x,
8w 3.3x, 8da8w 7.0-8.0x vs tiled).
Authored with Claude.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
The test/custom_ops reference port of NVIDIA's shmem_double_buf4.comp is the comparison baseline for the coopmat shaders, so name it after what it is (the coopmat_mm reference) rather than how it buffers. Shader, yaml, op (etvk.coopmat_mm_ref), impl file, bench labels, and comment references updated; validated on device (fp16 GEMM bench: 6/6 correctness PASS, numbers unchanged).

Authored with Claude.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds cooperative-matrix (coopmat) shader variants and dispatch gating for quantized linear ops on Vulkan, plus introduces a new dynamic-activation INT8 × per-channel INT8 weight path (linear_dq8ca_q8csw) and expands test coverage to fp16 activation variants.
Changes:
- Add coopmat shaders (and dispatch gating / WG sizing) for `linear_q4gsw` and `linear_dq8ca_*` variants.
- Introduce `linear_dq8ca_q8csw` end-to-end: pattern rewrite, op registration, runtime implementation, and custom op schema.
- Update test utilities and benchmarks to treat fp16 data as IEEE-754 half and to exercise `_half` shader variants.
Reviewed changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| backends/vulkan/test/custom_ops/utils.cpp | Fix fp16 data generation/printing to use real half bit patterns. |
| backends/vulkan/test/custom_ops/test_q4gsw_linear.cpp | Expand tests to fp16 activations and adjust tolerances / shapes. |
| backends/vulkan/test/custom_ops/test_dq8ca_q8csw_linear.cpp | New benchmark + reference for dq8ca×q8csw path. |
| backends/vulkan/test/custom_ops/CMakeLists.txt | Build the new dq8ca×q8csw test target. |
| backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp | Add coopmat gating, WG sizing, and register linear_dq8ca_q8csw. |
| backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coopmat.{yaml,glsl} | New fp16×int4 coopmat shader variants. |
| backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q8csw_tiled.{yaml,glsl} | New tiled dq8ca×q8csw shader variants. |
| backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q8csw_coopmat.{yaml,glsl} | New int8 coopmat dq8ca×q8csw shader variants. |
| backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_coopmat.{yaml,glsl} | New int8 coopmat dq8ca×q4gsw shader variants. |
| backends/vulkan/patterns/quantized_linear.py | Add rewrite to linear_dq8ca_q8csw + broaden per-channel scales detection. |
| backends/vulkan/op_registry.py | Register feature constraints for linear_dq8ca_q8csw. |
| backends/vulkan/custom_ops_lib.py | Define/implement Python custom op for linear_dq8ca_q8csw. |
@update_features(exir_ops.edge.et_vk.linear_dq8ca_q8csw.default)
def register_linear_dq8ca_q8csw():
    return OpFeatures(
        inputs_storage=[
            utils.CONTIGUOUS_ANY,  # input
            utils.WIDTH_PACKED_TEXTURE,  # input_scale
            utils.WIDTH_PACKED_TEXTURE,  # input_zero_point
            utils.NO_STORAGE,  # weight (prepacked)
            utils.NO_STORAGE,  # weight_sums (prepacked)
            utils.NO_STORAGE,  # weight_scales (prepacked)
            utils.NO_STORAGE,  # bias (prepacked)
        ],
        inputs_dtypes=utils.FP_T,
        supports_prepacking=True,
    )

def linear_dq8ca_q8csw(
    x: torch.Tensor,
    input_scale: torch.Tensor,
    input_zero_point: torch.Tensor,
    weights: torch.Tensor,
    weight_sums: torch.Tensor,
    weight_scales: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
):
    # Per-channel symmetric INT8 weight: dequant = weight.to(fp) * scales (per output channel)
    weights_dq = weights.to(x.dtype) * weight_scales.unsqueeze(-1)
    return torch.nn.functional.linear(x, weights_dq, bias)

# Per-channel symmetric INT8 weight: no group_size, no nibble packing.
# Align width to 4 so GPU shader reads don't go OOB.
utils.align_width_and_update_state_dict(
    ep,
    match.weight_node,
    weight_tensor,
    align_to=1,
    force_update=True,
)

if (input_dtype == vkapi::kHalf) {
  // INT8 dot4 → INT32 accum → fp32 dequant → fp16 store; the only fp16
  // rounding is at the final store. Per-row dynamic act scale gives
  // O(1) magnitudes pre-store, so a few ULPs of fp16 jitter is normal.
  test_case.set_abs_tolerance(5.0f);
  test_case.set_rel_tolerance(2e-1f);
} else {
  test_case.set_abs_tolerance(1e-2f);
  test_case.set_rel_tolerance(1e-2f);
}

// CPU reference uses fp32 throughout; comparing against an fp16 GPU output
// hits inherent rounding mismatches on edge-case (near-zero) elements that
// exceed any practical tolerance. Match q4gsw_linear.cpp's convention and
// skip correctness for kHalf — performance timings still run.
if (input_spec.dtype == vkapi::kHalf) {
  throw std::invalid_argument(
      "Reference impl skipped for kHalf — fp16 round-trip diverges from "
      "the fp32 CPU reference at near-zero elements.");
}

const auto* adapter = graph->context()->adapter_ptr();
if (!adapter->supports_cooperative_matrix()) {
  return false;
}
if (adapter->subgroup_size() != 64) {
  return false;
}

const uint A_STRIDE_VEC4 = (WG_TILE_K + FP16_PER_VEC4) / FP16_PER_VEC4;
const uint B_STRIDE_VEC4 = (WG_TILE_N + FP16_PER_VEC4) / FP16_PER_VEC4;

💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 05a4f505ef
const std::vector<int64_t> out_sizes = graph->sizes_of(output);
const int64_t N = utils::val_at(-1, out_sizes);
const int64_t M = utils::val_at(-2, out_sizes);
const std::vector<int64_t> in_sizes = graph->sizes_of(fp_input);
const int64_t K = utils::val_at(-1, in_sizes);

Reject batched outputs before selecting coopmat
When out has rank > 2, this gate can still select the new coopmat shaders as long as the last two dimensions are 64-aligned. Those shaders use only gl_WorkGroupID.xy and store gi * N + gj (no batch/z offset), and the existing is_coopmat_eligible() helper in GemmCoopmat.h explicitly rejects graph.dim_of(out) > 2 for that reason. On coopmat-capable devices, a batched quantized linear with no bias will therefore compute only the first 2-D tile region / overwrite the wrong buffer region instead of falling back to the tiled path.
ack -- looking now
f828f68 to 1b8c843
1b8c843 to 9b498d9
Adds cooperative-matrix (WMMA) coopmat shaders and dispatch for the two int4-weight quantized linear ops on the Vulkan backend:
- linear_q4gsw_coopmat (int4 group-symmetric weight-only)
- linear_dq8ca_q4gsw_coopmat (int8 dynamic act x int4 group-sym weight)

Dispatch in QuantizedLinear.cpp selects the coopmat shader when the device supports cooperative matrix, subgroup size is 64, the output is buffer-stored half, and M/N/K are tile-aligned; falls back to the tiled path otherwise. The int8-input variant additionally gates on a new Adapter::supports_int8_cooperative_matrix() (backed by a vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR query in Device.cpp that checks for VK_COMPONENT_TYPE_SINT8_KHR), so it is never selected on devices that expose fp16 coopmat but not int8. The gate also rejects batched (rank > 2) outputs, since the shaders dispatch over gl_WorkGroupID.xy only.

Tests: test_q4gsw_linear and test_coopmat_linear_bench cover both ops against a CPU reference at coopmat-eligible shapes; test_coopmat_probe reports the device's enumerated coopmat properties.
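The gating conditions listed in this commit message can be written as a single predicate. The sketch below is a Python model for discussion, not the C++ in QuantizedLinear.cpp: `DeviceCaps` and `can_use_coopmat` are hypothetical names, the 64x64x64 tile is an assumption taken from the "64-aligned" wording in the review comment, and requiring exactly rank 2 is stricter than the stated "rejects rank > 2".

```python
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass(frozen=True)
class DeviceCaps:
    """Hypothetical stand-in for the Adapter capability queries."""
    cooperative_matrix: bool
    int8_cooperative_matrix: bool
    subgroup_size: int


def can_use_coopmat(
    caps: DeviceCaps,
    out_sizes: Sequence[int],
    k: int,
    *,
    out_is_buffer: bool,
    out_is_half: bool,
    int8_input: bool,
    tile_mnk: Tuple[int, int, int] = (64, 64, 64),  # assumed tile sizes
) -> bool:
    if not caps.cooperative_matrix:
        return False
    if int8_input and not caps.int8_cooperative_matrix:
        return False  # fp16 coopmat alone is not enough for the int8 variant
    if caps.subgroup_size != 64:
        return False
    if not (out_is_buffer and out_is_half):
        return False
    if len(out_sizes) != 2:
        return False  # batched outputs are rejected; this sketch wants rank 2
    m, n = out_sizes
    tile_m, tile_n, tile_k = tile_mnk
    return m % tile_m == 0 and n % tile_n == 0 and k % tile_k == 0


fp16_only = DeviceCaps(cooperative_matrix=True,
                       int8_cooperative_matrix=False,
                       subgroup_size=64)
common = dict(out_is_buffer=True, out_is_half=True)
print(can_use_coopmat(fp16_only, [128, 256], 128, int8_input=False, **common))
print(can_use_coopmat(fp16_only, [128, 256], 128, int8_input=True, **common))
```

Any `False` result corresponds to falling back to the tiled shader, so the gate only needs to be conservative, never exhaustive.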
9b498d9 to 33cbbea
Summary

Adds KHR cooperative-matrix dispatch for quantized linear on the Vulkan backend, extending the fp16 coopmat path from #19009 to quantized weights:
- `linear_q4gsw_coopmat` (fp16 act × INT4 weight) and `linear_dq8ca_q4gsw_coopmat` (8-bit dynamic act × INT4 weight)
- `linear_dq8ca_q8csw_coopmat` (8-bit dynamic act × INT8 weight), plus its tiled V_DOT4 fallback and op registration

Coopmat is gated on `Adapter::supports_cooperative_matrix()`, a wave64 subgroup, buffer output storage, half dtype, and M/N/K tile alignment. Ineligible shapes, including any with a bias, fall back to the existing tiled shaders.

Review order

`QuantizedLinear.cpp` (dispatch gate `can_use_q4gsw_coopmat`) → the `linear_*_coopmat.glsl` shaders → `op_registry.py` / `custom_ops_lib.py` / `patterns/quantized_linear.py` (registration) → the custom-op tests.

Test plan

Built against `main` with `EXECUTORCH_BUILD_VULKAN=ON`; ran the custom_ops prototyping tests on an AMD Radeon 780M (RDNA3, wave64):
- `test_q4gsw_linear`: 72/72 correctness pass
- `test_dq8ca_q8csw_linear`: 22/22 correctness pass

Per the existing convention, fp16 (coopmat-only) correctness is not asserted against the fp32 CPU reference (the fp16 round-trip diverges at near-zero / overflowing elements); the coopmat path is exercised via build + dispatch + perf.

Open questions (draft)