From 993cff58ae5232e5be0ec360b0cde49b746b2019 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 24 Jun 2026 01:03:55 -0700 Subject: [PATCH 1/3] [gemma4_31b][cuda] Export Gemma4-31B @128k under 32 GB Three CUDA-export memory optimizations: - tq4_sdpa: add BLOCK_N=16 (and a BLOCK_M=32) autotune config. The superset is kept for big-shared-memory GPUs (A100/H100); the Triton autotuner auto-prunes configs that exceed a GPU's shared memory (OutOfResources -> inf), so the same config list also works on the 5090 (Blackwell, ~101 KB SMEM) where the previous smallest config did not fit. - int4_dispatch: chunk the inline _dequant_matmul along N for vocab-sized weights (N>65536, i.e. only the lm_head). Avoids transiently materializing the full ~10 GiB bf16 lm_head when AOTI executes the int4_plain_mm custom op during autotune / cpp_wrapper. The runtime decode path uses the C++ dp4a shim and the M>4 prefill inline path is below the threshold, so this never enters the runtime graph -> zero runtime / accuracy impact. Applied unconditionally (no flag). - cuda_backend / aoti_backend: skip occupying the GPU with the KV-cache buffers during AOTI compile (gated behind low_memory_mode). A new move_program_to_device hook places KV constants on the target device but immediately frees their storage (resize_(0)), so the fake-tensor device check passes while no real KV bytes sit on the GPU during autotune. The emptied buffers are re-synthesized as zeros at the _unlift_graph clone and at serialization, and excluded from constant dedup (resize_(0) gives every KV data_ptr 0, which would otherwise collapse same-shape caches across layers). Result on 2xA100: Gemma4-31B @128k no-TQ export peak 36.3 -> 27.0 GiB; the exported model runs correctly (output "...Paris."). --- backends/aoti/aoti_backend.py | 25 ++- backends/cuda/cuda_backend.py | 172 +++++++++++++++++- .../quantize_op_dispatch/int4_dispatch.py | 46 ++++- backends/cuda/triton/kernels/tq4_sdpa.py | 5 + 4 files changed, 229 insertions(+), 19 deletions(-) diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py index 91a8a60078e..22f6feeab6c 100644 --- a/backends/aoti/aoti_backend.py +++ b/backends/aoti/aoti_backend.py @@ -112,6 +112,21 @@ def codesign_so(cls, so_path: str, compile_specs: List[CompileSpec]) -> None: """ return + @classmethod + def move_program_to_device( + cls, + edge_program: ExportedProgram, + device: str, + compile_specs: List[CompileSpec], + ) -> ExportedProgram: + """Move the exported program to the target device for compilation. + + Default implementation moves everything (params, buffers, constants) via + ``move_to_device_pass``. Concrete backends may override to keep large + non-parameter tensors off the device during a low-memory export. + """ + return move_to_device_pass(edge_program, device) + @classmethod def release_moved_tensors( cls, @@ -196,9 +211,13 @@ def preprocess( decomposition_table = cls.get_decomposition_table() options = cls.get_aoti_compile_options(compile_specs) - # Move the edge_program to the target device - device_edge_program = move_to_device_pass( - edge_program, device_name if device_name != "metal" else "mps" + # Move the edge_program to the target device. Routed through a hook so + # backends can keep large non-parameter tensors (e.g. KV-cache buffers) + # off the device during a low-memory export. + device_edge_program = cls.move_program_to_device( + edge_program, + device_name if device_name != "metal" else "mps", + compile_specs, ) # Replace view_copy with view diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index f9f23a842f9..1781c5bfd39 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -61,29 +61,81 @@ def _is_cpu_clone_active() -> bool: return getattr(_CPU_CLONE_GUARD, "active", False) +def _full_zeros_preserving_strides(x: torch.Tensor, device) -> torch.Tensor: + """Allocate a zero-filled tensor matching ``x``'s size/stride/dtype on ``device``. + + Used to re-synthesize KV-cache buffers whose storage was freed (``resize_(0)``) + during the low-memory device move. KV content is all zeros, so this exactly + reproduces the buffer for both the lifted graph value and serialization. + """ + needed = 1 + for size, stride in zip(x.size(), x.stride()): + needed += (size - 1) * stride + buf = torch.zeros(int(needed), dtype=x.dtype, device=device) + return torch.as_strided(buf, x.size(), x.stride()) + + +def _is_emptied(x) -> bool: + return ( + isinstance(x, torch.Tensor) + and x.numel() > 0 + and x.untyped_storage().nbytes() == 0 + ) + + @contextlib.contextmanager def _compile_time_cpu_clones(target_device: torch.device): """Force AOTI's mutated-buffer clones onto CPU while preserving the serialized constants' target device.""" - from torch._inductor import compile_fx as _cfx + from torch._inductor import compile_fx as _cfx, graph as _graph from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu as _Cpp + from torch._inductor.graph import GraphLowering as _GL orig_clone = _cfx.clone_preserve_strides orig_codegen_device = _Cpp.codegen_device + orig_get_const = _GL.get_original_value_of_constant + orig_is_same = _graph.is_same_tensor + + def _is_same_skip_emptied(data, value): + # KV buffers freed via resize_(0) all have data_ptr 0, so the stock + # is_same_tensor would treat every same-shape KV constant as a duplicate + # and collapse the 60 layers' caches into one — the runtime needs each + # FQN's own buffer, so the collapsed ones load uninitialized garbage. + # Never dedup an emptied tensor. + if _is_emptied(data) or _is_emptied(value): + return False + return orig_is_same(data, value) def _cpu_clone_preserve_strides(x: torch.Tensor) -> torch.Tensor: - # `clone_preserve_strides` is shared by `_unlift_graph` (clones - # lifted buffers — can be safely kept on CPU) and by autotuning code - # in `triton_heuristics.py` (clones for benchmark — must stay on - # GPU for Triton). Discriminate by caller frame so we only force - # CPU clones for the buffer-lifting path. + # `clone_preserve_strides` is shared by `_unlift_graph` (clones lifted + # buffers — can be safely kept on CPU) and by autotuning code in + # `triton_heuristics.py` (clones for benchmark — must stay on GPU for + # Triton). Discriminate by caller frame so we only force CPU clones for + # the buffer-lifting path. import sys caller = sys._getframe(1).f_code.co_name if caller == "_unlift_graph": + # KV-cache buffers are emptied (storage resize_(0)) by the low-memory + # device move so they never occupy GPU memory during compile. Their + # content is all zeros, so re-synthesize zeros (on CPU, strides + # preserved) instead of cloning the now-empty storage. + if _is_emptied(x): + return _full_zeros_preserving_strides(x, "cpu") return orig_clone(x).cpu() return orig_clone(x) + def _get_const_synthesize_zeros(self, name): + # AOTI serializes each constant via get_original_value_of_constant -> + # _to_bytes. For KV buffers we freed with resize_(0) this would otherwise + # fall back to the empty-storage constant and write 0 bytes, producing a + # .ptd with an uninitialized cache. Re-synthesize the zeros so the blob + # holds a correctly-zeroed KV cache. + value = orig_get_const(self, name) + if _is_emptied(value): + return _full_zeros_preserving_strides(value, "cpu") + return value + def _codegen_device_target_aware(self, device): # Translate accidental CPU device strings back to the model target # device only when a constant we forced to CPU is being serialized. @@ -99,6 +151,8 @@ def _codegen_device_target_aware(self, device): _cfx.clone_preserve_strides = _cpu_clone_preserve_strides _Cpp.codegen_device = _codegen_device_target_aware + _GL.get_original_value_of_constant = _get_const_synthesize_zeros + _graph.is_same_tensor = _is_same_skip_emptied prev_active = getattr(_CPU_CLONE_GUARD, "active", False) _CPU_CLONE_GUARD.active = True try: @@ -107,6 +161,89 @@ def _codegen_device_target_aware(self, device): _CPU_CLONE_GUARD.active = prev_active _cfx.clone_preserve_strides = orig_clone _Cpp.codegen_device = orig_codegen_device + _GL.get_original_value_of_constant = orig_get_const + _graph.is_same_tensor = orig_is_same + + +def _is_kv_buffer(name, v) -> bool: + return ( + isinstance(v, torch.Tensor) + and not isinstance(v, torch.nn.Parameter) + and "kv_cache" in name + ) + + +def _empty_strided_on_device(v, location): + """A device tensor with v's shape/stride/dtype but zero (freed) storage.""" + t = torch.empty_strided(v.shape, v.stride(), dtype=v.dtype, device=location) + t.untyped_storage().resize_(0) # free bytes, keep device + shape/stride + return t + + +def _move_graph_nodes_to_device(graph_module, location): + """Point node device kwargs / aten.to.device targets / meta vals at location.""" + import torch.utils._pytree as pytree + + def _to_loc(v): + return v.to(location) if isinstance(v, torch.Tensor) else v + + for m in graph_module.modules(): + if not isinstance(m, torch.fx.GraphModule): + continue + for node in m.graph.nodes: + if "device" in node.kwargs: + node.kwargs = {**node.kwargs, "device": location} + if node.op == "call_function" and node.target is torch.ops.aten.to.device: + args = list(node.args) + args[1] = location + node.args = tuple(args) + node.meta["val"] = pytree.tree_map(_to_loc, node.meta.get("val")) + + +def _move_to_device_resize_kv(ep, location): + """``move_to_device_pass`` variant that frees KV-cache storage on-device. + + Mirrors ``torch.export.passes.move_to_device_pass`` exactly, except KV-cache + buffers (FQN contains ``kv_cache``) are placed on ``location`` but with their + storage immediately freed via ``resize_(0)``. This keeps ``device == + location`` — so the fake-tensor device check on the ``index_copy`` cache + update passes (``self`` and ``values`` both on cuda) — while no real KV bytes + occupy the device during the AOTI compile. KV content is all zeros, so the + emptied tensors are re-synthesized as zeros at the ``_unlift_graph`` clone + (see ``_compile_time_cpu_clones``), which is reused as both the lifted initial + value and the serialized ``.ptd`` constant. The empty/free is interleaved per + tensor so the transient device peak is a single KV buffer, not the whole cache. + Only ``kv_cache`` tensors are emptied (they are the lone large zero-buffers); + every other tensor is moved normally so non-zero content is never lost. + """ + import torch.utils._pytree as pytree + + for k, v in ep.state_dict.items(): + if isinstance(v, torch.nn.Parameter): + ep._state_dict[k] = torch.nn.Parameter(v.to(location), v.requires_grad) + elif _is_kv_buffer(k, v): + ep._state_dict[k] = _empty_strided_on_device(v, location) + else: + ep._state_dict[k] = v.to(location) + + for k, v in ep.constants.items(): + if isinstance(v, torch.Tensor): + ep._constants[k] = ( + _empty_strided_on_device(v, location) + if _is_kv_buffer(k, v) + else v.to(location) + ) + + if ep.example_inputs is not None: + args, kwargs = ep.example_inputs + ep._example_inputs = ( + pytree.tree_map_only(torch.Tensor, lambda t: t.to(location), args), + pytree.tree_map_only(torch.Tensor, lambda t: t.to(location), kwargs), + ) + + _move_graph_nodes_to_device(ep.graph_module, location) + ep.validate() + return ep @final @@ -424,6 +561,29 @@ def _is_low_memory_mode(compile_specs: List[CompileSpec]) -> bool: return spec.value.decode("utf-8").upper() == "ON" return False + @classmethod + def move_program_to_device( + cls, + edge_program, + device: str, + compile_specs: List[CompileSpec], + ): + """Move the program to ``device`` for AOTI compile. + + On a low-memory export (``low_memory_mode="ON"``) the KV-cache buffers — + which can be 10+ GiB at long context — are placed on-device but with their + storage freed (``resize_(0)``), so they never occupy device memory during + the autotune / cpp_wrapper compile while still satisfying the device-match + check on the cache update. They are re-synthesized as zeros for the lifted + graph and the serialized blob. This activates automatically with low-memory + mode. Other (non-low-memory) exports use the stock pass. + """ + from torch.export.passes import move_to_device_pass + + if not cls._is_low_memory_mode(compile_specs): + return move_to_device_pass(edge_program, device) + return _move_to_device_resize_kv(edge_program, device) + @classmethod def release_moved_tensors( cls, diff --git a/backends/cuda/quantize_op_dispatch/int4_dispatch.py b/backends/cuda/quantize_op_dispatch/int4_dispatch.py index c3b8921e2fe..1b8c370eecf 100644 --- a/backends/cuda/quantize_op_dispatch/int4_dispatch.py +++ b/backends/cuda/quantize_op_dispatch/int4_dispatch.py @@ -60,11 +60,29 @@ def _cuda(self, qdata, scale, zero, group_size): return _dequant_matmul(self, qdata, scale, zero, group_size) +# Chunked dequant for the export GPU budget. The lm_head dequant (N = vocab_size, +# e.g. 262144) runs through the int4_plain_mm custom op (M=1); AOTI executes that +# op's CUDA impl during autotune / cpp_wrapper codegen, where it transiently holds +# ~5 full-size bf16 temporaries (low/high/data/data-z/w_deq) — ~10 GiB for a +# 262144-row weight even though the final w_deq is only ~2.6 GiB. Chunking along N +# caps that at ~chunk rows. It is numerically identical (F.linear output rows are +# independent), and because only the lm_head (custom-op) path crosses the N +# threshold — never the M>4 prefill inline path — it never enters the runtime +# graph: ZERO runtime / accuracy impact. Applied unconditionally to any weight +# whose row count exceeds the threshold. +_DEQUANT_N_THRESHOLD = 65536 +_DEQUANT_N_CHUNK = 32768 + + def _dequant_matmul(x, qdata, scale, zero, group_size): """Dequant INT4 weights to input dtype and call F.linear. scale/zero are in the coalesced [N, n_groups] layout (baked into the weight constant at pack time), aligned row-for-row with qdata's [N, *]. + + Large weights (N > threshold, i.e. the lm_head) are chunked along N to bound + the dequant intermediate (see note above); smaller weights take the original + single-shot dequant. """ N, K_half = qdata.shape K = K_half * 2 @@ -72,16 +90,24 @@ def _dequant_matmul(x, qdata, scale, zero, group_size): gs_half = group_size // 2 dtype = x.dtype - p = qdata.to(torch.uint8).reshape(N, n_groups, gs_half) - low = (p & 0x0F).to(dtype) - high = ((p >> 4) & 0x0F).to(dtype) - data = torch.stack([low, high], dim=-1).reshape(N, n_groups, group_size) - - s = scale.to(dtype).unsqueeze(-1) - z = zero.to(dtype).unsqueeze(-1) - w_deq = ((data - z) * s).reshape(N, K) - - return F.linear(x, w_deq) + def _dq(qd, sc, ze, rows): + p = qd.to(torch.uint8).reshape(rows, n_groups, gs_half) + low = (p & 0x0F).to(dtype) + high = ((p >> 4) & 0x0F).to(dtype) + data = torch.stack([low, high], dim=-1).reshape(rows, n_groups, group_size) + s = sc.to(dtype).unsqueeze(-1) + z = ze.to(dtype).unsqueeze(-1) + w_deq = ((data - z) * s).reshape(rows, K) + return F.linear(x, w_deq) + + if N <= _DEQUANT_N_THRESHOLD: + return _dq(qdata, scale, zero, N) + + outs = [] + for i in range(0, N, _DEQUANT_N_CHUNK): + j = min(i + _DEQUANT_N_CHUNK, N) + outs.append(_dq(qdata[i:j], scale[i:j], zero[i:j], j - i)) + return torch.cat(outs, dim=-1) # --------------------------------------------------------------------------- diff --git a/backends/cuda/triton/kernels/tq4_sdpa.py b/backends/cuda/triton/kernels/tq4_sdpa.py index 10f02c7fa3c..7a41eaf92c1 100644 --- a/backends/cuda/triton/kernels/tq4_sdpa.py +++ b/backends/cuda/triton/kernels/tq4_sdpa.py @@ -294,6 +294,10 @@ def _tq4_sdpa_fwd_kernel_body( triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2), triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=3), triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=3), ], key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"], ) @@ -410,6 +414,7 @@ def _tq4_sdpa_fwd_kernel_m64( triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2), triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 16}, num_warps=4, num_stages=3), ], key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"], ) From 92d62c974345d8fd387d6698f5e161b542ec9939 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Thu, 25 Jun 2026 08:55:35 -0700 Subject: [PATCH 2/3] Fix TurboQuant KV zeroed by low-mem export (993cff58ae): _is_kv_buffer only frees genuinely all-zero kv_cache.* buffers (count_nonzero==0); preserves TQ4 centroids/boundaries/rotation/rotation_T --- backends/cuda/cuda_backend.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 1781c5bfd39..b328a05df54 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -166,11 +166,29 @@ def _codegen_device_target_aware(self, device): def _is_kv_buffer(name, v) -> bool: - return ( - isinstance(v, torch.Tensor) - and not isinstance(v, torch.nn.Parameter) - and "kv_cache" in name - ) + """True only for an actual KV-cache *content* buffer that is safe to free. + + The low-memory path (``_move_to_device_resize_kv``) frees every buffer this + matches and re-synthesizes it as ZEROS in both the lifted graph and the + serialized ``.ptd`` (see ``_full_zeros_preserving_strides`` / + ``_get_const_synthesize_zeros``). That is only valid for genuine KV *content*, + which is all-zeros at export time (caches start empty). + + It must NOT match the non-zero constants that some KV-cache modules register + alongside the cache — e.g. TurboQuant registers its codebook/rotation + (``centroids``/``boundaries``/``rotation``/``rotation_T``) as buffers on the + ``kv_cache`` module, so their FQNs also contain ``kv_cache``. Freeing+zeroing + those silently corrupts the serialized model (TQ4 dequant -> 0 -> garbage). + Gate on the buffer actually being all-zeros so only empty KV content is freed; + this is robust to any future constant name (a non-zero buffer is never freed). + """ + if not isinstance(v, torch.Tensor) or isinstance(v, torch.nn.Parameter): + return False + if "kv_cache" not in name or v.numel() == 0 or v.is_meta: + return False + # Only the genuinely all-zero KV content may be freed + re-zeroed; non-zero + # constants (TurboQuant centroids/rotation/...) must be preserved as-is. + return bool(torch.count_nonzero(v) == 0) def _empty_strided_on_device(v, location): From 4025660ac810cf796f6a19c06692b1777e0ac145 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Thu, 25 Jun 2026 08:56:51 -0700 Subject: [PATCH 3/3] [executorch][cuda] gemma4_31b: fuse gate/up MLP projections (default-on) Summary: Fuse each gemma4_31b MLP's gate_proj|up_proj into a single [2*intermediate, hidden] coalesced-int4 matmul, applied by default in the CUDA export. This issues one activation-quant + one W4A8 matvec per layer instead of two, cutting per-token launch + activation-quant overhead in the launch-bound decode path. Only Q4_K (CudaCoalescedInt4Tensor) gate/up pairs are fused; any other quant type (e.g. Q6_K) is left as two matmuls (guarded, still correct). Builds on the already-landed kv_len-bounded tq4_sdpa kernel + gemma4_31b call-site (kv_len + mask_is_causal), which recovered 128k decode from ~2.8 to ~43 tok/s. With both, ET gemma4_31b 128k+TurboQuant decode beats llama.cpp at every measured context (cuda_graph ON): ctx ET llama 512 44.80 42.77 2K 43.20 41.97 8K 42.23 41.23 32K 41.64 40.27 127K 38.41 35.97 TurboQuant KV compression kept; prefill restored (6-8x) with no regression; output quality preserved. Test Plan: - Fusion numerics: fused vs unfused MLP through the real W4A8 int4_plain_mm kernel = bit-exact (max_abs_diff 0.0, cos 1.000000) for decode (T=1) and prefill (T=4). - Export + run: fused module exported via CudaPartitioner and executed through executor_runner (RC=0, cos 0.999915 vs eager). Full 31B export logs "Fused gate+up on 60 MLP layers". - Decode A/B (gemma4_31b 128k+TQ, cuda_graph ON, 5x median): table above; beats llama.cpp at 512 -> 127K. nsys: tq4_sdpa 91.7% -> 2.9% of decode. --- .../gemma4_31b/cuda_source_transformations.py | 107 ++++++++++++++++++ examples/models/gemma4_31b/export.py | 9 +- 2 files changed, 111 insertions(+), 5 deletions(-) diff --git a/examples/models/gemma4_31b/cuda_source_transformations.py b/examples/models/gemma4_31b/cuda_source_transformations.py index 666d0c44e9d..6609178e084 100644 --- a/examples/models/gemma4_31b/cuda_source_transformations.py +++ b/examples/models/gemma4_31b/cuda_source_transformations.py @@ -30,6 +30,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from executorch.examples.models.gemma4.text_decoder import apply_rotary_emb from executorch.extension.llm.modules.turboquant import TurboQuantKVCache @@ -110,6 +111,105 @@ def _turboquant_attention_forward( return self.o_proj(y) +def _fused_mlp_forward(self, x: torch.Tensor) -> torch.Tensor: + """Drop-in ``Gemma4MLP.forward`` over a fused gate|up projection. + + Identical math to ``down(gelu(gate(x)) * up(x))``: the single + ``gate_up_proj`` emits ``[gate | up]`` concatenated on the last dim, + which is then split. One W4A8 matmul (and one activation-quant of ``x``) + instead of two. + """ + h = self.gate_up_proj(x) + gate = h[..., : self.intermediate_size] + up = h[..., self.intermediate_size :] + return self.down_proj(F.gelu(gate, approximate="tanh") * up) + + +def _concat_coalesced_int4_along_n(a, b): + """Concatenate two ``CudaCoalescedInt4Tensor`` along the output (N) dim. + + qdata is ``[N, K/2]`` and scale/zero_point are ``[N, n_groups]`` in the + coalesced layout, so a per-output-row concat on dim 0 is exact: the W4A8 + dp4a matvec reads each output row's qdata/scale/zero independently, so + out[:N_a] reproduces ``a`` and out[N_a:] reproduces ``b`` bit-for-bit. + """ + from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor + + return CudaCoalescedInt4Tensor( + torch.cat([a.qdata, b.qdata], dim=0), + torch.cat([a.scale, b.scale], dim=0), + torch.cat([a.zero_point, b.zero_point], dim=0), + a.block_size, + torch.Size([a.shape[0] + b.shape[0], a.shape[1]]), + None, + a.activation_dtype, + ) + + +def _is_fuseable_int4_pair(gate_w, up_w) -> bool: + """True iff gate/up are both coalesced-int4 with matching K + block_size. + + Q4_K MLP weights become ``CudaCoalescedInt4Tensor`` (fuseable); a Q6_K + weight becomes ``CudaDp4aPlanarInt6Tensor`` (left alone). ``act_pre_scale`` + is unused on this path but we require it absent so the concat stays exact. + """ + from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor + + return ( + isinstance(gate_w, CudaCoalescedInt4Tensor) + and isinstance(up_w, CudaCoalescedInt4Tensor) + and list(gate_w.block_size) == list(up_w.block_size) + and gate_w.shape[1] == up_w.shape[1] + and gate_w.act_pre_scale is None + and up_w.act_pre_scale is None + ) + + +def _fuse_gate_up_proj(model: nn.Module) -> None: + """Fuse each MLP's ``gate_proj | up_proj`` into one ``gate_up_proj``. + + gate and up share the same input, so the unfused path quantizes ``x`` to + int8 twice and launches two W4A8 matvecs per layer. Fusing the weights + into one ``[2*inter, hidden]`` tensor halves both. Weight bytes read are + unchanged, so the win is launch + activation-quant overhead (decode is + launch-bound). Only Q4_K (coalesced-int4) layers are fused; any layer + with a non-int4 weight is left as two matmuls (still correct). + + Must run AFTER weights are packed to ``CudaCoalescedInt4Tensor`` (i.e. + inside ``_export_cuda``), and is independent of TurboQuant. + """ + n_fused = 0 + n_skipped = 0 + for layer in model.layers: + mlp = getattr(layer, "mlp", None) + if mlp is None or not (hasattr(mlp, "gate_proj") and hasattr(mlp, "up_proj")): + continue + gate_w = mlp.gate_proj.weight + up_w = mlp.up_proj.weight + if not _is_fuseable_int4_pair(gate_w, up_w): + n_skipped += 1 + continue + inter = up_w.shape[0] + hidden = up_w.shape[1] + fused_w = _concat_coalesced_int4_along_n(gate_w, up_w) + + # Container built on meta to avoid materializing a dense + # [2*inter, hidden] weight before we overwrite it with fused_w. + gate_up = nn.Linear(hidden, 2 * inter, bias=False, device="meta") + gate_up.weight = nn.Parameter(fused_w, requires_grad=False) + mlp.gate_up_proj = gate_up + mlp.intermediate_size = inter + del mlp.gate_proj + del mlp.up_proj + mlp.forward = types.MethodType(_fused_mlp_forward, mlp) + n_fused += 1 + + msg = f"[gemma4_31b cuda] Fused gate+up on {n_fused} MLP layers" + if n_skipped: + msg += f" ({n_skipped} skipped: non-int4 weights)" + print(msg) + + def cuda_source_transformations( model: nn.Module, *, @@ -117,6 +217,11 @@ def cuda_source_transformations( ) -> None: """Apply CUDA source transformations to a Gemma 4 31B model in place. + Always fuses each MLP's ``gate_proj|up_proj`` into a single matmul (one + activation-quant + one W4A8 matvec per layer instead of two; Q4_K + coalesced-int4 layers only — other quant types are left untouched). + Optionally also swaps full-attention KV caches for TurboQuant TQ4. + Args: model: ``Gemma4_31B`` instance to transform. use_turboquant: When True, swap full-attention layers' KV caches @@ -125,6 +230,8 @@ def cuda_source_transformations( ``torch.ops.triton.tq4_sdpa``. Sliding-window layers are unaffected. """ + _fuse_gate_up_proj(model) + if not use_turboquant: return diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index d9e16bc34df..b2b2264178a 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -182,12 +182,11 @@ def _export_cuda( materialize_runtime_buffers(model, dtype=torch.bfloat16) - if use_turboquant: - from executorch.examples.models.gemma4_31b.cuda_source_transformations import ( - cuda_source_transformations, - ) + from executorch.examples.models.gemma4_31b.cuda_source_transformations import ( + cuda_source_transformations, + ) - cuda_source_transformations(model, use_turboquant=True) + cuda_source_transformations(model, use_turboquant=use_turboquant) # Int4Tensor weights are used directly — no format conversion. # F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim).