Skip to content

[ET-VK] Add cooperative matrix dispatch for quantized linear#9

Draft
xuyanwen2012 wants to merge 10 commits into
mainfrom
yanwen/quant-dev
Draft

[ET-VK] Add cooperative matrix dispatch for quantized linear#9
xuyanwen2012 wants to merge 10 commits into
mainfrom
yanwen/quant-dev

Conversation

@xuyanwen2012

@xuyanwen2012 xuyanwen2012 commented May 29, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds KHR cooperative-matrix dispatch for quantized linear on the Vulkan backend, mirroring the fp16 coopmat path from pytorch#19009:

  • 4-bit weightlinear_q4gsw_coopmat (W4A16) and linear_dq8ca_q4gsw_coopmat (8da4w)
  • 8-bit weightlinear_dq8ca_q8csw_coopmat (8da8w)

Coopmat is gated on Adapter::supports_cooperative_matrix(), a wave64 subgroup, buffer output storage, half dtype, and M/N/K tile alignment; ineligible shapes (including any with a bias) fall back to the existing tiled shaders. Also fixes an fp16 overflow in the W8A16 qcs8w tiled/coop shaders by accumulating the K-loop in fp32.

Review order

QuantizedLinear.cpp (dispatch gate can_use_q4gsw_coopmat) → linear_*_coopmat.glsl shaders → op_registry.py / custom_ops_lib.py / patterns/ (registration) → the custom-op tests.

Testing

Built against current main with EXECUTORCH_BUILD_VULKAN=ON; custom_ops prototyping tests on an AMD Radeon 780M (RDNA3, wave64):

  • test_q4gsw_linear: 72/72 correctness pass
  • test_dq8ca_q8csw_linear: 22/22 correctness pass

Per the existing convention, fp16 (coopmat-only) correctness is not asserted against the fp32 CPU reference (the fp16 round-trip diverges at near-zero / overflowing elements); the coopmat path is exercised via build + dispatch + perf.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 32f742f3fb

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +134 to +137
const int row_m4 = div_4(m + m_row);
const int row_m4i = mod_4(m + m_row);
float row_scale = float(input_scales.data[row_m4][row_m4i]);
int row_zp = int(input_zps.data[row_m4][row_m4i]);

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Index dynamic qparams relative to the tile

For any tiled q8csw dispatch with M > 4, the second output tile has m = 4, but input_scales/input_zps only contain the TILE_M4 entries loaded for the current tile by load_int8_input_scales_and_zps(..., m4). Using the absolute div_4(m + m_row) therefore indexes past the local arrays (and reads the wrong per-row scale/zp); this should use tile-local indices like the q4 path (div_4(m_row) / mod_4(m_row)).

Useful? React with 👍 / 👎.

@xuyanwen2012 xuyanwen2012 force-pushed the yanwen/quant-dev branch 2 times, most recently from 074d8ec to d17ee15 Compare May 30, 2026 00:10
Adds coopmat shaders and dispatch for 4-bit (q4gsw, dq8ca_q4gsw) and 8-bit (dq8ca_q8csw) quantized linear, gated on Adapter::supports_cooperative_matrix(), wave64 subgroup size, buffer output storage, and coopmat tile alignment — mirroring the fp16 coopmat path from pytorch#19009. Ineligible shapes fall back to the existing tiled shaders.

Review order: QuantizedLinear.cpp for the dispatch gate (can_use_q4gsw_coopmat), then the linear_*_coopmat.glsl shaders, op_registry.py / custom_ops_lib.py / patterns for registration, then the tests.
xuyanwen2012 and others added 8 commits June 9, 2026 09:59
… crash (+ WIP)

The Samsung Xclipse / AMD-PAL Vulkan shader compiler crashes (null deref in
vkCreateComputePipelines) when a loop containing coopMatMulAdd has a trip count
derived from the runtime sizes UBO. Pass the loop bound (num_groups for q4gsw /
dq8ca_q4gsw, K-chunk count for dq8ca_q8csw) as a specialization constant instead,
so the bound is a compile-time constant. Verified on ERD9975 (Exynos S5E9975):
int4 coopmat ~1.19x and int8 coopmat ~1.67x faster prefill than tiled, no crash.

Also saves in-progress quantized-linear coopmat aggregation work and adds the
test_coopmat_probe capability probe.
…coopmat

The four coopmat linear shaders produced wrong results on Samsung Xclipse
970, exposed by the first strict CPU-reference validation. Three
independent bugs: (1) the int4 coopmat shaders read packed weight blocks
transposed relative to the pack_q4_linear_weight layout (ivec2(n8,k4)
instead of ivec2(k4,n8)); (2) the int8 coopmat shaders passed coopMatLoad
offsets/strides in int8-element units where the KHR spec defines them in
units of the backing array's element type (uint = 4 int8), so all LDS
matrix loads were 4x off; (3) the Xclipse/AMD-PAL compiler miscompiles
coopMatStore whose offset/stride derive from a UBO value (only the first
store per subgroup lands correctly; proven with a standalone raw-Vulkan
repro), worked around by passing the output width N as a 4th
specialization constant (out_N_arg) - the same pattern as the earlier
UBO-derived-loop-bound pipeline-crash fix.

With the shaders now numerically validated, linear_q8csw_coopmat (8w) is
wired into pick_linear_qw_shader. The consolidated bench gains an exact
fp32 reference for all four ops (the dq8ca correctness data uses
scale=1/16, zp=0, activations that are multiples of 1/16 so the fp16
dynamic-quant round-trip is exact), a {64,128,64} plus multi-workgroup
{128,256,128} correctness matrix, and a COOPMAT_BENCH_CORRECTNESS_ONLY
toggle; the test harness now reports a per-16x16-tile mismatch map and
continues past failing cases instead of aborting.

On ERD9975 all 16 correctness cases pass; prefill speedups vs tiled at
M=1024 Llama-8B shapes: 4w 1.76x, 8da4w 2.5x, 8w 2.9x, 8da8w 4.0-4.5x.

Authored with Claude Code.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…ilog hoist

Port the LDS layout from linear_dq8ca_q8csw_coopmat: split A/B into MMA_K
slabs and stage B ColumnMajor with a +1-uint skew per column, so each matB
lane reads its 4 K-contiguous bytes with one ds_load_b32 instead of the
per-byte load + repack chains a RowMajor int8 B forces. The int4 block
layout cooperates: one (component, parity) pair of a packed block is
exactly one N-column's 4 K-bytes, so the nibble-unpack instruction count
is unchanged and the gain is purely from the layout. Also hoist the
izp/ifs row broadcasts out of the per-group epilog (they are group
invariant) and reorder the epilog loops so wsum/wsc are loaded once per
N-subtile, cutting broadcast coopMatLoads per group from 16 to 4.

On ERD9975 at M=1024 Llama-8B shapes the 8da4w coopmat path goes from
2658-2730 to 3261-3370 GFLOP/s (2.4-2.6x -> 3.0-3.2x vs tiled); the other
three coopmat ops are unchanged and the 16-case correctness matrix passes.

Authored with Claude Code.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Same two driver-bug workarounds already applied to the quantized coopmat
shaders: the Xclipse/AMD-PAL compiler crashes in vkCreateComputePipelines
when a loop containing coopMatMulAdd has a UBO-derived trip count, and
miscompiles coopMatStore whose offset/stride derive from a UBO value. With
these, the fp16 coopmat shader runs correctly on Xclipse 970 for the first
time (validated via test_fp16_gemm_bench on ERD9975).

Authored with Claude.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…eference

Ports NVIDIA's shmem_double_buf4.comp (double-buffered shared memory,
store-first single barrier per iteration, 128x128 tile, K-step 16,
8 subgroups x 32 threads forced via REQUIRED_SUBGROUP_SIZE) into the
custom_ops prototyping framework as etvk.gemm_double_buf, and adds
test_fp16_gemm_bench comparing it against matmul_coopmat and the tiled
texture baseline at Llama 3.1 8B prefill shapes.

On ERD9975 (Xclipse 970) the double-buffered reference reaches 5.6-6.2
TFLOP/s vs ~3.95 for matmul_coopmat (1.44-1.56x) and ~0.9 for tiled;
all small-shape correctness cases pass against a CPU fp32 reference.

Authored with Claude.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…rence

All four quantized-linear coopmat shaders adopt the loop structure of the
NVIDIA shmem_double_buf4 reference (see gemm_double_buf.glsl): prologue
register prefetch, one barrier per K-chunk, ping-pong shared-memory slices,
quant unpack at the store stage so the prefetch is pure loads in flight
during the MMA. The fp16-MMA shaders (q4gsw, q8csw) move to a 128x128 tile
with K-step 16 at 8 subgroups x 32 threads (subgroup size 32 forced via
REQUIRED_SUBGROUP_SIZE); the int8-MMA shaders (dq8ca_q4gsw, dq8ca_q8csw)
move to 128x64 with K-step 32 on wave64, since the Xclipse PAL compiler
crashes compiling int8 WMMA at forced subgroup size 32. dq8ca_q4gsw keeps a
nested groups-by-chunks loop with the group epilog at the group tail: the
flattened form with a conditional coopmat epilog also crashes that compiler
at large spec-resolved trip counts. Per-shader tile geometry moves into a
coopmat_tile_dims() table consulted by the shared gate and dispatch sizing.

Measured on ERD9975 (Xclipse 970) at M=1024 Llama 3.1 8B shapes, vs the
tiled texture baseline: q4gsw 1.76x -> 2.2x, dq8ca_q4gsw 3.1x -> 4.0-4.4x,
q8csw 2.9x -> 3.3x, dq8ca_q8csw 4.1-4.9x -> 6.9-7.9x. All correctness
cases pass; the bench gains shape coverage that also exposes a pre-existing
upstream int_input_sums allocation bug (fixed separately on
yanwen/fix-dq8ca-input-sums-alloc), plus a fix for the test harness
printing GPU data instead of the reference when dumping fp16 failures.

Authored with Claude.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
The four quantized-linear coopmat shaders share one double-buffered
skeleton since the previous change, so fold them into one template per MMA
family, mirroring how coopmat_mm.glsl hosts its three variants:

  linear_qw_coopmat.glsl        -> linear_q4gsw_coopmat_*  (WEIGHT_NBITS=4)
                                   linear_q8csw_coopmat_*  (WEIGHT_NBITS=8)
  linear_dq8ca_qw_coopmat.glsl  -> linear_dq8ca_q4gsw_coopmat_* (4)
                                   linear_dq8ca_q8csw_coopmat_* (8)

Generated variant names are unchanged, so dispatch code needs no changes.
The WEIGHT_NBITS conditionals cover only the weight unpack and the quant
bookkeeping: per-channel INT8 is expressed as the single-group special case
of the grouped INT4 path (num_groups = 1), which collapses the nested
groups-by-chunks loop, the wsum/wsc group-parity ping-pong, and the group
epilog to exactly the previous per-channel control flow.

Regression on ERD9975: all coopmat correctness cases pass and the M=1024
Llama-shape numbers match the pre-refactor run (4w 2.2x, 8da4w 4.0-4.4x,
8w 3.3x, 8da8w 7.0-8.0x vs tiled).

Authored with Claude.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
The test/custom_ops reference port of NVIDIA's shmem_double_buf4.comp is
the comparison baseline for the coopmat shaders, so name it after what it
is (the coopmat_mm reference) rather than how it buffers. Shader, yaml,
op (etvk.coopmat_mm_ref), impl file, bench labels, and comment references
updated; validated on device (fp16 GEMM bench: 6/6 correctness PASS,
numbers unchanged).

Authored with Claude.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
@xuyanwen2012 xuyanwen2012 marked this pull request as draft June 24, 2026 22:35
@xuyanwen2012 xuyanwen2012 force-pushed the yanwen/quant-dev branch 2 times, most recently from 1b8c843 to 9b498d9 Compare June 24, 2026 23:03
Adds cooperative-matrix (WMMA) coopmat shaders and dispatch for the two
int4-weight quantized linear ops on the Vulkan backend:
  - linear_q4gsw_coopmat        (int4 group-symmetric weight-only)
  - linear_dq8ca_q4gsw_coopmat  (int8 dynamic act x int4 group-sym weight)

Dispatch in QuantizedLinear.cpp selects the coopmat shader when the device
supports cooperative matrix, subgroup size is 64, the output is buffer-stored
half, and M/N/K are tile-aligned; falls back to the tiled path otherwise.

The int8-input variant additionally gates on a new
Adapter::supports_int8_cooperative_matrix() (backed by a
vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR query in Device.cpp that
checks for VK_COMPONENT_TYPE_SINT8_KHR), so it is never selected on devices
that expose fp16 coopmat but not int8. The gate also rejects batched
(rank > 2) outputs, since the shaders dispatch over gl_WorkGroupID.xy only.

Tests: test_q4gsw_linear and test_coopmat_linear_bench cover both ops against
a CPU reference at coopmat-eligible shapes; test_coopmat_probe reports the
device's enumerated coopmat properties.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant