[ET-VK] Add cooperative matrix dispatch for quantized linear#19892

Open
xuyanwen2012 wants to merge 10 commits into
pytorch:mainfrom
sarc-acl:yanwen/quant-dev

Conversation

@xuyanwen2012

@xuyanwen2012 xuyanwen2012 commented May 30, 2026

Contributor

Summary

Adds KHR cooperative-matrix dispatch for quantized linear on the Vulkan backend, extending the fp16 coopmat path from #19009 to quantized weights:

  • 4-bit weight: linear_q4gsw_coopmat (fp16 act × INT4 weight) and linear_dq8ca_q4gsw_coopmat (8-bit dynamic act × INT4 weight)
  • 8-bit weight: linear_dq8ca_q8csw_coopmat (8-bit dynamic act × INT8 weight), plus its tiled V_DOT4 fallback and op registration

Coopmat is gated on Adapter::supports_cooperative_matrix(), a wave64 subgroup, buffer output storage, half dtype, and M/N/K tile alignment. Ineligible shapes — including any with a bias — fall back to the existing tiled shaders.
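As a reading aid, the gate described above can be sketched in Python. This is only an illustration under stated assumptions: the real check is the C++ function can_use_q4gsw_coopmat in QuantizedLinear.cpp, and the `DeviceCaps` type, the `can_use_coopmat` name, and the tile sizes below are hypothetical placeholders rather than values taken from the PR.

```python
from dataclasses import dataclass

# Hypothetical tile sizes; the actual per-shader values live in the C++ gate.
TILE_M, TILE_N, TILE_K = 64, 64, 32


@dataclass
class DeviceCaps:
    """Hypothetical stand-in for the two adapter queries the summary names."""
    supports_cooperative_matrix: bool
    subgroup_size: int


def can_use_coopmat(caps, M, N, K, *, output_is_buffer, dtype_is_half, has_bias):
    """Return True only when every condition listed in the summary holds."""
    if not caps.supports_cooperative_matrix:
        return False
    if caps.subgroup_size != 64:  # wave64 only
        return False
    if not output_is_buffer or not dtype_is_half:
        return False
    if has_bias:  # shapes with a bias fall back to the tiled shaders
        return False
    # M, N and K must all be multiples of the tile dimensions.
    return M % TILE_M == 0 and N % TILE_N == 0 and K % TILE_K == 0
```

Any `False` result here corresponds to falling back to the existing tiled shaders.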

Review order

QuantizedLinear.cpp (dispatch gate can_use_q4gsw_coopmat) → the linear_*_coopmat.glsl shaders → op_registry.py / custom_ops_lib.py / patterns/quantized_linear.py (registration) → the custom-op tests.

Test plan

Built against main with EXECUTORCH_BUILD_VULKAN=ON; ran the custom_ops prototyping tests on an AMD Radeon 780M (RDNA3, wave64):

  • test_q4gsw_linear: 72/72 correctness pass
  • test_dq8ca_q8csw_linear: 22/22 correctness pass

Per the existing convention, fp16 (coopmat-only) correctness is not asserted against the fp32 CPU reference (the fp16 round-trip diverges at near-zero / overflowing elements); the coopmat path is exercised via build + dispatch + perf.

Open questions (draft)

  • Reachability: the coopmat path requires buffer output storage, so it only fires when the partitioner selects buffer storage for the linear's output. Feedback is welcome on the preferred way to make it reachable in a typical export.

Adds coopmat shaders and dispatch for 4-bit (q4gsw, dq8ca_q4gsw) and 8-bit (dq8ca_q8csw) quantized linear, gated on Adapter::supports_cooperative_matrix(), wave64 subgroup size, buffer output storage, and coopmat tile alignment — mirroring the fp16 coopmat path from pytorch#19009. Ineligible shapes fall back to the existing tiled shaders.

Review order: QuantizedLinear.cpp for the dispatch gate (can_use_q4gsw_coopmat), then the linear_*_coopmat.glsl shaders, op_registry.py / custom_ops_lib.py / patterns for registration, then the tests.
@pytorch-bot

pytorch-bot Bot commented May 30, 2026


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19892

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There is 1 currently active SEV. If your PR is affected, please view it below:

⚠️ 12 Awaiting Approval

As of commit 33cbbea with merge base 5395f20:

AWAITING APPROVAL - The following workflows need approval before CI can run:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label May 30, 2026
@linux-foundation-easycla

linux-foundation-easycla Bot commented May 30, 2026


CLA Not Signed

One or more co-authors of this pull request were not found. You must specify co-authors in commit message trailer via:

Co-authored-by: name <email>

Supported Co-authored-by: formats include:

  1. Anything <id+login@users.noreply.github.com> - it will locate your GitHub user by id part.
  2. Anything <login@users.noreply.github.com> - it will locate your GitHub user by login part.
  3. Anything <public-email> - it will locate your GitHub user by public-email part. Note that this email must be made public on Github.
  4. Anything <other-email> - it will locate your GitHub user by other-email part but only if that email was used before for any other CLA as a main commit author.
  5. login <any-valid-email> - it will locate your GitHub user by login part, note that login part must be at least 3 characters long.

Alternatively, if the co-author should not be included, remove the Co-authored-by: line from the commit message.

Please update your commit message(s) by doing git commit --amend and then git push [--force] and then request re-running CLA check via commenting on this pull request:

/easycla

@github-actions


This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

xuyanwen2012 and others added 8 commits June 9, 2026 09:59
… crash (+ WIP)

The Samsung Xclipse / AMD-PAL Vulkan shader compiler crashes (null deref in
vkCreateComputePipelines) when a loop containing coopMatMulAdd has a trip count
derived from the runtime sizes UBO. Pass the loop bound (num_groups for q4gsw /
dq8ca_q4gsw, K-chunk count for dq8ca_q8csw) as a specialization constant instead,
so the bound is a compile-time constant. Verified on ERD9975 (Exynos S5E9975):
int4 coopmat ~1.19x and int8 coopmat ~1.67x faster prefill than tiled, no crash.

Also saves in-progress quantized-linear coopmat aggregation work and adds the
test_coopmat_probe capability probe.
…coopmat

The four coopmat linear shaders produced wrong results on Samsung Xclipse
970, exposed by the first strict CPU-reference validation. Three
independent bugs: (1) the int4 coopmat shaders read packed weight blocks
transposed relative to the pack_q4_linear_weight layout (ivec2(n8,k4)
instead of ivec2(k4,n8)); (2) the int8 coopmat shaders passed coopMatLoad
offsets/strides in int8-element units where the KHR spec defines them in
units of the backing array's element type (uint = 4 int8), so all LDS
matrix loads were 4x off; (3) the Xclipse/AMD-PAL compiler miscompiles
coopMatStore whose offset/stride derive from a UBO value (only the first
store per subgroup lands correctly; proven with a standalone raw-Vulkan
repro), worked around by passing the output width N as a 4th
specialization constant (out_N_arg) - the same pattern as the earlier
UBO-derived-loop-bound pipeline-crash fix.
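Bug (2) in this commit message comes down to a unit conversion, which a small Python sketch may make concrete. It assumes, as the message states, that the backing array holds 32-bit uints packing four int8 values each. The helper name and the example stride are hypothetical and not taken from the shaders.

```python
INT8_PER_UINT = 4  # one 32-bit uint packs four int8 values


def to_uint_units(int8_elements):
    """Convert an offset or stride counted in int8 elements to uint units."""
    if int8_elements % INT8_PER_UINT != 0:
        raise ValueError("offset/stride must be a multiple of 4 int8 elements")
    return int8_elements // INT8_PER_UINT


# Hypothetical row stride of 36 int8 elements (for example 32 data bytes plus 4 of padding).
stride_int8 = 36
stride_uint = to_uint_units(stride_int8)  # 9 uints

# Passing the int8 count where uint units are expected overshoots by exactly 4x.
overshoot = stride_int8 / stride_uint
```

Passing `stride_int8` where `stride_uint` is expected reproduces the "4x off" loads the message describes.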

With the shaders now numerically validated, linear_q8csw_coopmat (8w) is
wired into pick_linear_qw_shader. The consolidated bench gains an exact
fp32 reference for all four ops (the dq8ca correctness data uses
scale=1/16, zp=0, activations that are multiples of 1/16 so the fp16
dynamic-quant round-trip is exact), a {64,128,64} plus multi-workgroup
{128,256,128} correctness matrix, and a COOPMAT_BENCH_CORRECTNESS_ONLY
toggle; the test harness now reports a per-16x16-tile mismatch map and
continues past failing cases instead of aborting.
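The reason multiples of 1/16 give an exact reference can be checked with a short Python sketch. It is a standalone illustration, not the bench code: the helper names are hypothetical, and the fp16 round-trip uses the standard library's half-precision `struct` format.

```python
import struct

SCALE = 1.0 / 16.0
ZERO_POINT = 0


def fp16_round_trip(x):
    """Round x to IEEE-754 half precision and back via struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def quantize(x):
    """Quantize to int8 with the fixed scale and zero point above."""
    q = round(x / SCALE) + ZERO_POINT
    return max(-128, min(127, q))


def dequantize(q):
    return (q - ZERO_POINT) * SCALE


# Every int8 code k corresponds to the activation k / 16.
activations = [k * SCALE for k in range(-128, 128)]

# Each value survives both the fp16 round-trip and the int8 round-trip unchanged.
exact = all(
    fp16_round_trip(x) == x and dequantize(quantize(x)) == x
    for x in activations
)
```

Because each activation is an integer times a power of two with at most 8 significant bits, neither fp16 storage nor int8 quantization loses information, so the fp32 reference can be compared exactly.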

On ERD9975 all 16 correctness cases pass; prefill speedups vs tiled at
M=1024 Llama-8B shapes: 4w 1.76x, 8da4w 2.5x, 8w 2.9x, 8da8w 4.0-4.5x.

Authored with Claude Code.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…ilog hoist

Port the LDS layout from linear_dq8ca_q8csw_coopmat: split A/B into MMA_K
slabs and stage B ColumnMajor with a +1-uint skew per column, so each matB
lane reads its 4 K-contiguous bytes with one ds_load_b32 instead of the
per-byte load + repack chains a RowMajor int8 B forces. The int4 block
layout cooperates: one (component, parity) pair of a packed block is
exactly one N-column's 4 K-bytes, so the nibble-unpack instruction count
is unchanged and the gain is purely from the layout. Also hoist the
izp/ifs row broadcasts out of the per-group epilog (they are group
invariant) and reorder the epilog loops so wsum/wsc are loaded once per
N-subtile, cutting broadcast coopMatLoads per group from 16 to 4.

On ERD9975 at M=1024 Llama-8B shapes the 8da4w coopmat path goes from
2658-2730 to 3261-3370 GFLOP/s (2.4-2.6x -> 3.0-3.2x vs tiled); the other
three coopmat ops are unchanged and the 16-case correctness matrix passes.

Authored with Claude Code.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Same two driver-bug workarounds already applied to the quantized coopmat
shaders: the Xclipse/AMD-PAL compiler crashes in vkCreateComputePipelines
when a loop containing coopMatMulAdd has a UBO-derived trip count, and
miscompiles coopMatStore whose offset/stride derive from a UBO value. With
these, the fp16 coopmat shader runs correctly on Xclipse 970 for the first
time (validated via test_fp16_gemm_bench on ERD9975).

Authored with Claude.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…eference

Ports NVIDIA's shmem_double_buf4.comp (double-buffered shared memory,
store-first single barrier per iteration, 128x128 tile, K-step 16,
8 subgroups x 32 threads forced via REQUIRED_SUBGROUP_SIZE) into the
custom_ops prototyping framework as etvk.gemm_double_buf, and adds
test_fp16_gemm_bench comparing it against matmul_coopmat and the tiled
texture baseline at Llama 3.1 8B prefill shapes.

On ERD9975 (Xclipse 970) the double-buffered reference reaches 5.6-6.2
TFLOP/s vs ~3.95 for matmul_coopmat (1.44-1.56x) and ~0.9 for tiled;
all small-shape correctness cases pass against a CPU fp32 reference.

Authored with Claude.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…rence

All four quantized-linear coopmat shaders adopt the loop structure of the
NVIDIA shmem_double_buf4 reference (see gemm_double_buf.glsl): prologue
register prefetch, one barrier per K-chunk, ping-pong shared-memory slices,
quant unpack at the store stage so the prefetch is pure loads in flight
during the MMA. The fp16-MMA shaders (q4gsw, q8csw) move to a 128x128 tile
with K-step 16 at 8 subgroups x 32 threads (subgroup size 32 forced via
REQUIRED_SUBGROUP_SIZE); the int8-MMA shaders (dq8ca_q4gsw, dq8ca_q8csw)
move to 128x64 with K-step 32 on wave64, since the Xclipse PAL compiler
crashes compiling int8 WMMA at forced subgroup size 32. dq8ca_q4gsw keeps a
nested groups-by-chunks loop with the group epilog at the group tail: the
flattened form with a conditional coopmat epilog also crashes that compiler
at large spec-resolved trip counts. Per-shader tile geometry moves into a
coopmat_tile_dims() table consulted by the shared gate and dispatch sizing.

Measured on ERD9975 (Xclipse 970) at M=1024 Llama 3.1 8B shapes, vs the
tiled texture baseline: q4gsw 1.76x -> 2.2x, dq8ca_q4gsw 3.1x -> 4.0-4.4x,
q8csw 2.9x -> 3.3x, dq8ca_q8csw 4.1-4.9x -> 6.9-7.9x. All correctness
cases pass; the bench gains shape coverage that also exposes a pre-existing
upstream int_input_sums allocation bug (fixed separately on
yanwen/fix-dq8ca-input-sums-alloc), plus a fix for the test harness
printing GPU data instead of the reference when dumping fp16 failures.

Authored with Claude.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
The four quantized-linear coopmat shaders share one double-buffered
skeleton since the previous change, so fold them into one template per MMA
family, mirroring how coopmat_mm.glsl hosts its three variants:

  linear_qw_coopmat.glsl        -> linear_q4gsw_coopmat_*  (WEIGHT_NBITS=4)
                                   linear_q8csw_coopmat_*  (WEIGHT_NBITS=8)
  linear_dq8ca_qw_coopmat.glsl  -> linear_dq8ca_q4gsw_coopmat_* (4)
                                   linear_dq8ca_q8csw_coopmat_* (8)

Generated variant names are unchanged, so dispatch code needs no changes.
The WEIGHT_NBITS conditionals cover only the weight unpack and the quant
bookkeeping: per-channel INT8 is expressed as the single-group special case
of the grouped INT4 path (num_groups = 1), which collapses the nested
groups-by-chunks loop, the wsum/wsc group-parity ping-pong, and the group
epilog to exactly the previous per-channel control flow.
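The "single-group special case" can be illustrated on the dequantization arithmetic alone. The Python sketch below uses hypothetical helper names and ignores packing, tiling, and the loop structure of the actual shaders. It only shows that grouped scaling with one group spanning all of K equals per-channel scaling.

```python
def dequant_grouped(q_row, scales, group_size):
    """Dequantize one weight row with one scale per group of group_size K-elements."""
    return [q * scales[k // group_size] for k, q in enumerate(q_row)]


def dequant_per_channel(q_row, scale):
    """Dequantize one weight row with a single scale for the whole output channel."""
    return [q * scale for q in q_row]


row = [-8, 3, 7, -1, 0, 5, -6, 2]
K = len(row)

# Two groups of four elements: each half of the row gets its own scale.
two_groups = dequant_grouped(row, [0.5, 0.25], group_size=4)

# One group covering all of K (num_groups = 1) is the per-channel case.
one_group = dequant_grouped(row, [0.5], group_size=K)
same = one_group == dequant_per_channel(row, 0.5)
```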

Regression on ERD9975: all coopmat correctness cases pass and the M=1024
Llama-shape numbers match the pre-refactor run (4w 2.2x, 8da4w 4.0-4.4x,
8w 3.3x, 8da8w 7.0-8.0x vs tiled).

Authored with Claude.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
The test/custom_ops reference port of NVIDIA's shmem_double_buf4.comp is
the comparison baseline for the coopmat shaders, so name it after what it
is (the coopmat_mm reference) rather than how it buffers. Shader, yaml,
op (etvk.coopmat_mm_ref), impl file, bench labels, and comment references
updated; validated on device (fp16 GEMM bench: 6/6 correctness PASS,
numbers unchanged).

Authored with Claude.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
@xuyanwen2012 xuyanwen2012 marked this pull request as ready for review June 10, 2026 19:18
Copilot AI review requested due to automatic review settings June 10, 2026 19:18

Copilot AI left a comment

Contributor

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds cooperative-matrix (coopmat) shader variants and dispatch gating for quantized linear ops on Vulkan, plus introduces a new dynamic-activation INT8 × per-channel INT8 weight path (linear_dq8ca_q8csw) and expands test coverage to fp16 activation variants.

Changes:

  • Add coopmat shaders (and dispatch gating / WG sizing) for linear_q4gsw and linear_dq8ca_* variants.
  • Introduce linear_dq8ca_q8csw end-to-end: pattern rewrite, op registration, runtime implementation, and custom op schema.
  • Update test utilities and benchmarks to treat fp16 data as IEEE-754 half and to exercise _half shader variants.

Reviewed changes

Copilot reviewed 16 out of 16 changed files in this pull request and generated 7 comments.

Show a summary per file
| File | Description |
| --- | --- |
| backends/vulkan/test/custom_ops/utils.cpp | Fix fp16 data generation/printing to use real half bit patterns. |
| backends/vulkan/test/custom_ops/test_q4gsw_linear.cpp | Expand tests to fp16 activations and adjust tolerances / shapes. |
| backends/vulkan/test/custom_ops/test_dq8ca_q8csw_linear.cpp | New benchmark + reference for dq8ca×q8csw path. |
| backends/vulkan/test/custom_ops/CMakeLists.txt | Build the new dq8ca×q8csw test target. |
| backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp | Add coopmat gating, WG sizing, and register linear_dq8ca_q8csw. |
| backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coopmat.{yaml,glsl} | New fp16×int4 coopmat shader variants. |
| backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q8csw_tiled.{yaml,glsl} | New tiled dq8ca×q8csw shader variants. |
| backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q8csw_coopmat.{yaml,glsl} | New int8 coopmat dq8ca×q8csw shader variants. |
| backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_coopmat.{yaml,glsl} | New int8 coopmat dq8ca×q4gsw shader variants. |
| backends/vulkan/patterns/quantized_linear.py | Add rewrite to linear_dq8ca_q8csw + broaden per-channel scales detection. |
| backends/vulkan/op_registry.py | Register feature constraints for linear_dq8ca_q8csw. |
| backends/vulkan/custom_ops_lib.py | Define/implement Python custom op for linear_dq8ca_q8csw. |

Comment thread backends/vulkan/op_registry.py Outdated
Comment on lines +484 to +498
@update_features(exir_ops.edge.et_vk.linear_dq8ca_q8csw.default)
def register_linear_dq8ca_q8csw():
    return OpFeatures(
        inputs_storage=[
            utils.CONTIGUOUS_ANY,  # input
            utils.WIDTH_PACKED_TEXTURE,  # input_scale
            utils.WIDTH_PACKED_TEXTURE,  # input_zero_point
            utils.NO_STORAGE,  # weight (prepacked)
            utils.NO_STORAGE,  # weight_sums (prepacked)
            utils.NO_STORAGE,  # weight_scales (prepacked)
            utils.NO_STORAGE,  # bias (prepacked)
        ],
        inputs_dtypes=utils.FP_T,
        supports_prepacking=True,
    )
Comment thread backends/vulkan/custom_ops_lib.py Outdated
Comment on lines +315 to +326
def linear_dq8ca_q8csw(
    x: torch.Tensor,
    input_scale: torch.Tensor,
    input_zero_point: torch.Tensor,
    weights: torch.Tensor,
    weight_sums: torch.Tensor,
    weight_scales: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
):
    # Per-channel symmetric INT8 weight: dequant = weight.to(fp) * scales (per output channel)
    weights_dq = weights.to(x.dtype) * weight_scales.unsqueeze(-1)
    return torch.nn.functional.linear(x, weights_dq, bias)
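The reference above ignores input_scale, input_zero_point, and weight_sums, while the GPU path is described elsewhere in this PR as INT8 dot4 → INT32 accum → fp32 dequant. A plausible reason weight_sums exists is the standard zero-point correction identity, sketched below in plain Python for one output element. This is an assumption about how the shader uses it, and both function names are hypothetical.

```python
def reference_dot(q_x, zp, s_x, q_w, s_w):
    """Dequantize both operands first, then accumulate in floating point."""
    return sum(((a - zp) * s_x) * (w * s_w) for a, w in zip(q_x, q_w))


def integer_dot(q_x, zp, s_x, q_w, s_w, w_sum):
    """Accumulate in integers, then apply the zero-point correction and scales once."""
    acc = sum(a * w for a, w in zip(q_x, q_w))  # INT32-style accumulation
    return (acc - zp * w_sum) * (s_x * s_w)


q_x = [12, -7, 100, 3]   # quantized activations for one row
q_w = [-3, 8, 2, -128]   # quantized weights for one output channel
zp, s_x, s_w = 5, 0.5, 0.25
w_sum = sum(q_w)         # precomputed once per output channel

ref = reference_dot(q_x, zp, s_x, q_w, s_w)
fast = integer_dot(q_x, zp, s_x, q_w, s_w, w_sum)
```

Both paths give 41.125 here. The integer form needs only the raw int8 dot product plus one precomputed sum per output channel, which is why prepacking the sums is attractive.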
Comment on lines +502 to +510
    # Per-channel symmetric INT8 weight: no group_size, no nibble packing.
    # Align width to 4 so GPU shader reads don't go OOB.
    utils.align_width_and_update_state_dict(
        ep,
        match.weight_node,
        weight_tensor,
        align_to=1,
        force_update=True,
    )
Comment on lines +189 to +198
if (input_dtype == vkapi::kHalf) {
  // INT8 dot4 → INT32 accum → fp32 dequant → fp16 store; the only fp16
  // rounding is at the final store. Per-row dynamic act scale gives
  // O(1) magnitudes pre-store, so a few ULPs of fp16 jitter is normal.
  test_case.set_abs_tolerance(5.0f);
  test_case.set_rel_tolerance(2e-1f);
} else {
  test_case.set_abs_tolerance(1e-2f);
  test_case.set_rel_tolerance(1e-2f);
}
Comment on lines +291 to +299
// CPU reference uses fp32 throughout; comparing against an fp16 GPU output
// hits inherent rounding mismatches on edge-case (near-zero) elements that
// exceed any practical tolerance. Match q4gsw_linear.cpp's convention and
// skip correctness for kHalf — performance timings still run.
if (input_spec.dtype == vkapi::kHalf) {
  throw std::invalid_argument(
      "Reference impl skipped for kHalf — fp16 round-trip diverges from "
      "the fp32 CPU reference at near-zero elements.");
}
Comment on lines +137 to +143
const auto* adapter = graph->context()->adapter_ptr();
if (!adapter->supports_cooperative_matrix()) {
  return false;
}
if (adapter->subgroup_size() != 64) {
  return false;
}
Comment on lines +97 to +98
const uint A_STRIDE_VEC4 = (WG_TILE_K + FP16_PER_VEC4) / FP16_PER_VEC4;
const uint B_STRIDE_VEC4 = (WG_TILE_N + FP16_PER_VEC4) / FP16_PER_VEC4;

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 05a4f505ef


Comment on lines +151 to +155
const std::vector<int64_t> out_sizes = graph->sizes_of(output);
const int64_t N = utils::val_at(-1, out_sizes);
const int64_t M = utils::val_at(-2, out_sizes);
const std::vector<int64_t> in_sizes = graph->sizes_of(fp_input);
const int64_t K = utils::val_at(-1, in_sizes);


P1: Reject batched outputs before selecting coopmat

When out has rank > 2, this gate can still select the new coopmat shaders as long as the last two dimensions are 64-aligned. Those shaders use only gl_WorkGroupID.xy and store gi * N + gj (no batch/z offset), and the existing is_coopmat_eligible() helper in GemmCoopmat.h explicitly rejects graph.dim_of(out) > 2 for that reason. On coopmat-capable devices, a batched quantized linear with no bias will therefore compute only the first 2-D tile region / overwrite the wrong buffer region instead of falling back to the tiled path.
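The failure mode in this comment can be illustrated with a small Python model. It simplifies the shader to one store per output element (the real shaders store whole tiles), and the function names are hypothetical. The point is only that a store index of gi * N + gj has no batch term.

```python
def written_indices(M, N):
    """Flat indices a 2-D dispatch writes when it stores at gi * N + gj."""
    return {gi * N + gj for gi in range(M) for gj in range(N)}


def is_rank_eligible(out_sizes):
    """Sketch of the suggested gate: only rank <= 2 outputs may use coopmat."""
    return len(out_sizes) <= 2


B, M, N = 2, 4, 4          # a small batched output of shape [B, M, N]
total = B * M * N
covered = written_indices(M, N)
missing = set(range(total)) - covered  # elements no invocation ever writes
```

With these sizes, `missing` is exactly the second batch (flat indices 16 to 31), so rejecting rank > 2 outputs and falling back to the tiled path avoids leaving part of the buffer unwritten.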

@SS-JIA

SS-JIA commented Jun 23, 2026

Contributor

ack -- looking now

Copilot AI review requested due to automatic review settings June 24, 2026 23:03

Copilot AI left a comment

Contributor

Copilot was unable to review this pull request because the user who requested the review has reached their quota limit.

Adds cooperative-matrix (WMMA) coopmat shaders and dispatch for the two
int4-weight quantized linear ops on the Vulkan backend:
  - linear_q4gsw_coopmat        (int4 group-symmetric weight-only)
  - linear_dq8ca_q4gsw_coopmat  (int8 dynamic act x int4 group-sym weight)

Dispatch in QuantizedLinear.cpp selects the coopmat shader when the device
supports cooperative matrix, subgroup size is 64, the output is buffer-stored
half, and M/N/K are tile-aligned; falls back to the tiled path otherwise.

The int8-input variant additionally gates on a new
Adapter::supports_int8_cooperative_matrix() (backed by a
vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR query in Device.cpp that
checks for VK_COMPONENT_TYPE_SINT8_KHR), so it is never selected on devices
that expose fp16 coopmat but not int8. The gate also rejects batched
(rank > 2) outputs, since the shaders dispatch over gl_WorkGroupID.xy only.

Tests: test_q4gsw_linear and test_coopmat_linear_bench cover both ops against
a CPU reference at coopmat-eligible shapes; test_coopmat_probe reports the
device's enumerated coopmat properties.

@xuyanwen2012 xuyanwen2012 requested a review from Copilot June 26, 2026 19:40

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

3 participants