From 8b145b5fc6d3d18b3d5702ea0823402a7699f401 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 23 Jun 2026 15:21:25 -0700 Subject: [PATCH 1/2] [executorch][cuda] gemma4_31b: fuse gate/up MLP projections (default-on) Summary: Fuse each gemma4_31b MLP's gate_proj|up_proj into a single [2*intermediate, hidden] coalesced-int4 matmul, applied by default in the CUDA export. This issues one activation-quant + one W4A8 matvec per layer instead of two, cutting per-token launch + activation-quant overhead in the launch-bound decode path. Only Q4_K (CudaCoalescedInt4Tensor) gate/up pairs are fused; any other quant type (e.g. Q6_K) is left as two matmuls (guarded, still correct). Builds on the already-landed kv_len-bounded tq4_sdpa kernel + gemma4_31b call-site (kv_len + mask_is_causal), which recovered 128k decode from ~2.8 to ~43 tok/s. With both, ET gemma4_31b 128k+TurboQuant decode beats llama.cpp at every measured context (cuda_graph ON): ctx ET llama 512 44.80 42.77 2K 43.20 41.97 8K 42.23 41.23 32K 41.64 40.27 127K 38.41 35.97 TurboQuant KV compression kept; prefill restored (6-8x) with no regression; output quality preserved. Test Plan: - Fusion numerics: fused vs unfused MLP through the real W4A8 int4_plain_mm kernel = bit-exact (max_abs_diff 0.0, cos 1.000000) for decode (T=1) and prefill (T=4). - Export + run: fused module exported via CudaPartitioner and executed through executor_runner (RC=0, cos 0.999915 vs eager). Full 31B export logs "Fused gate+up on 60 MLP layers". - Decode A/B (gemma4_31b 128k+TQ, cuda_graph ON, 5x median): table above; beats llama.cpp at 512 -> 127K. nsys: tq4_sdpa 91.7% -> 2.9% of decode. --- .../gemma4_31b/cuda_source_transformations.py | 107 ++++++++++++++++++ examples/models/gemma4_31b/export.py | 9 +- 2 files changed, 111 insertions(+), 5 deletions(-) diff --git a/examples/models/gemma4_31b/cuda_source_transformations.py b/examples/models/gemma4_31b/cuda_source_transformations.py index 666d0c44e9d..6609178e084 100644 --- a/examples/models/gemma4_31b/cuda_source_transformations.py +++ b/examples/models/gemma4_31b/cuda_source_transformations.py @@ -30,6 +30,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from executorch.examples.models.gemma4.text_decoder import apply_rotary_emb from executorch.extension.llm.modules.turboquant import TurboQuantKVCache @@ -110,6 +111,105 @@ def _turboquant_attention_forward( return self.o_proj(y) +def _fused_mlp_forward(self, x: torch.Tensor) -> torch.Tensor: + """Drop-in ``Gemma4MLP.forward`` over a fused gate|up projection. + + Identical math to ``down(gelu(gate(x)) * up(x))``: the single + ``gate_up_proj`` emits ``[gate | up]`` concatenated on the last dim, + which is then split. One W4A8 matmul (and one activation-quant of ``x``) + instead of two. + """ + h = self.gate_up_proj(x) + gate = h[..., : self.intermediate_size] + up = h[..., self.intermediate_size :] + return self.down_proj(F.gelu(gate, approximate="tanh") * up) + + +def _concat_coalesced_int4_along_n(a, b): + """Concatenate two ``CudaCoalescedInt4Tensor`` along the output (N) dim. + + qdata is ``[N, K/2]`` and scale/zero_point are ``[N, n_groups]`` in the + coalesced layout, so a per-output-row concat on dim 0 is exact: the W4A8 + dp4a matvec reads each output row's qdata/scale/zero independently, so + out[:N_a] reproduces ``a`` and out[N_a:] reproduces ``b`` bit-for-bit. + """ + from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor + + return CudaCoalescedInt4Tensor( + torch.cat([a.qdata, b.qdata], dim=0), + torch.cat([a.scale, b.scale], dim=0), + torch.cat([a.zero_point, b.zero_point], dim=0), + a.block_size, + torch.Size([a.shape[0] + b.shape[0], a.shape[1]]), + None, + a.activation_dtype, + ) + + +def _is_fuseable_int4_pair(gate_w, up_w) -> bool: + """True iff gate/up are both coalesced-int4 with matching K + block_size. + + Q4_K MLP weights become ``CudaCoalescedInt4Tensor`` (fuseable); a Q6_K + weight becomes ``CudaDp4aPlanarInt6Tensor`` (left alone). ``act_pre_scale`` + is unused on this path but we require it absent so the concat stays exact. + """ + from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor + + return ( + isinstance(gate_w, CudaCoalescedInt4Tensor) + and isinstance(up_w, CudaCoalescedInt4Tensor) + and list(gate_w.block_size) == list(up_w.block_size) + and gate_w.shape[1] == up_w.shape[1] + and gate_w.act_pre_scale is None + and up_w.act_pre_scale is None + ) + + +def _fuse_gate_up_proj(model: nn.Module) -> None: + """Fuse each MLP's ``gate_proj | up_proj`` into one ``gate_up_proj``. + + gate and up share the same input, so the unfused path quantizes ``x`` to + int8 twice and launches two W4A8 matvecs per layer. Fusing the weights + into one ``[2*inter, hidden]`` tensor halves both. Weight bytes read are + unchanged, so the win is launch + activation-quant overhead (decode is + launch-bound). Only Q4_K (coalesced-int4) layers are fused; any layer + with a non-int4 weight is left as two matmuls (still correct). + + Must run AFTER weights are packed to ``CudaCoalescedInt4Tensor`` (i.e. + inside ``_export_cuda``), and is independent of TurboQuant. + """ + n_fused = 0 + n_skipped = 0 + for layer in model.layers: + mlp = getattr(layer, "mlp", None) + if mlp is None or not (hasattr(mlp, "gate_proj") and hasattr(mlp, "up_proj")): + continue + gate_w = mlp.gate_proj.weight + up_w = mlp.up_proj.weight + if not _is_fuseable_int4_pair(gate_w, up_w): + n_skipped += 1 + continue + inter = up_w.shape[0] + hidden = up_w.shape[1] + fused_w = _concat_coalesced_int4_along_n(gate_w, up_w) + + # Container built on meta to avoid materializing a dense + # [2*inter, hidden] weight before we overwrite it with fused_w. + gate_up = nn.Linear(hidden, 2 * inter, bias=False, device="meta") + gate_up.weight = nn.Parameter(fused_w, requires_grad=False) + mlp.gate_up_proj = gate_up + mlp.intermediate_size = inter + del mlp.gate_proj + del mlp.up_proj + mlp.forward = types.MethodType(_fused_mlp_forward, mlp) + n_fused += 1 + + msg = f"[gemma4_31b cuda] Fused gate+up on {n_fused} MLP layers" + if n_skipped: + msg += f" ({n_skipped} skipped: non-int4 weights)" + print(msg) + + def cuda_source_transformations( model: nn.Module, *, @@ -117,6 +217,11 @@ def cuda_source_transformations( ) -> None: """Apply CUDA source transformations to a Gemma 4 31B model in place. + Always fuses each MLP's ``gate_proj|up_proj`` into a single matmul (one + activation-quant + one W4A8 matvec per layer instead of two; Q4_K + coalesced-int4 layers only — other quant types are left untouched). + Optionally also swaps full-attention KV caches for TurboQuant TQ4. + Args: model: ``Gemma4_31B`` instance to transform. use_turboquant: When True, swap full-attention layers' KV caches @@ -125,6 +230,8 @@ def cuda_source_transformations( ``torch.ops.triton.tq4_sdpa``. Sliding-window layers are unaffected. """ + _fuse_gate_up_proj(model) + if not use_turboquant: return diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index d9e16bc34df..b2b2264178a 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -182,12 +182,11 @@ def _export_cuda( materialize_runtime_buffers(model, dtype=torch.bfloat16) - if use_turboquant: - from executorch.examples.models.gemma4_31b.cuda_source_transformations import ( - cuda_source_transformations, - ) + from executorch.examples.models.gemma4_31b.cuda_source_transformations import ( + cuda_source_transformations, + ) - cuda_source_transformations(model, use_turboquant=True) + cuda_source_transformations(model, use_turboquant=use_turboquant) # Int4Tensor weights are used directly — no format conversion. # F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim). From 638f07ae1d3aad4f00122217dde062d5a0a4b3a8 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 23 Jun 2026 17:08:39 -0700 Subject: [PATCH 2/2] [executorch][gemma4] fuse MLP gate/up at GGUF load (single point, cuda+mlx) Summary: Move the gemma4 MLP gate_proj|up_proj fusion to a single backend-agnostic point in the GGUF loader, and make the model forward consume it. Supersedes the earlier CUDA-only export-time fusion (reverted here). - gguf_loader.py: before any backend conversion (_convert_weight), buffer each layer's raw gate/up ExportableGGUFTensor and, once both arrive, row-concat their raw GGUF blocks along the output dim into one fused gate_up ExportableGGUFTensor (gate rows then up rows). Both backends then pack the already-fused weight with NO per-type concat: CUDA (Q4_K -> CudaCoalescedInt4Tensor, Q6_K -> CudaDp4aPlanarInt6Tensor) and MLX (ExportableGGUFTensor). Guards: same ggml_type + K; non-fuseable pairs and unpaired leftovers fall through unfused. - Gemma4MLP: when a fused gate_up_proj is present, run one matmul and split the [.., 2*intermediate_size] output back into gate/up; otherwise use the separate projections. The shared MLP stays safe for unfused checkpoints and the prequant/HF load paths (no gate_up_proj -> original path, no crash). - Revert the previous CUDA-localized fusion (cuda_source_transformations.py and export.py back to their original form). The kv_len-bounded tq4_sdpa kernel + call-site (already on main) are unchanged. Single fusion point widens applicability (CUDA + MLX, incl. Q6_K) and keeps the model def backend-agnostic. Decode win is unchanged (same fused matmul, produced at load instead of at export). Test Plan: - Raw concat (real GGUF blk.0 ffn, q4_k): fused.dequantize() == [gate; up] stacked, bit-exact; fused CudaCoalescedInt4Tensor rows [:N]/[N:] qdata+scale+zero bit-identical to gate/up. - Model-def fused vs unfused forward through real W4A8 int4_plain_mm: decode (T=1) bit-exact (cos 1.000000); prefill (T=4) cos 0.999988 -- the only delta is cuBLAS GEMM shape-dependent fp ordering (N=43008 vs 21504, identical weights), benign and inherent to any gate/up fusion. - Full CUDA GGUF export (gemma4_31b, --turboquant, max-seq-len 131072): loader logs "Fused gate+up on 60 MLP layers", TurboQuant swaps 10 layers, AOTI build clean (model.pte + 26.18GB aoti_cuda_blob.ptd, "Done."). - Decode via gemma4_31b_runner on the new build: coherent output, no NaN; prefill 1375 tok/s, decode 38.3 tok/s (no cuda_graph sanity). --- .../text_decoder/gemma4_decoder_layer.py | 17 ++- .../gemma4_31b/cuda_source_transformations.py | 107 ---------------- examples/models/gemma4_31b/export.py | 9 +- examples/models/gemma4_31b/gguf_loader.py | 119 ++++++++++++++++++ 4 files changed, 138 insertions(+), 114 deletions(-) diff --git a/examples/models/gemma4/text_decoder/gemma4_decoder_layer.py b/examples/models/gemma4/text_decoder/gemma4_decoder_layer.py index e10c1c7e415..fe3e3bb94cb 100644 --- a/examples/models/gemma4/text_decoder/gemma4_decoder_layer.py +++ b/examples/models/gemma4/text_decoder/gemma4_decoder_layer.py @@ -34,14 +34,25 @@ class Gemma4MLP(nn.Module): def __init__(self, hidden_size: int, intermediate_size: int): super().__init__() + self.intermediate_size = intermediate_size self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj( - F.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x) - ) + # If a loader fused gate_proj|up_proj into one gate_up_proj (single + # matmul; e.g. the GGUF loader's coalesced fusion), use it and split the + # [.., 2*intermediate_size] output back into gate/up. Otherwise fall back + # to the separate projections (unfused checkpoints / non-fusing loaders). + gate_up = getattr(self, "gate_up_proj", None) + if gate_up is not None: + fused = gate_up(x) + gate = fused[..., : self.intermediate_size] + up = fused[..., self.intermediate_size :] + else: + gate = self.gate_proj(x) + up = self.up_proj(x) + return self.down_proj(F.gelu(gate, approximate="tanh") * up) class Gemma4DecoderLayer(nn.Module): diff --git a/examples/models/gemma4_31b/cuda_source_transformations.py b/examples/models/gemma4_31b/cuda_source_transformations.py index 6609178e084..666d0c44e9d 100644 --- a/examples/models/gemma4_31b/cuda_source_transformations.py +++ b/examples/models/gemma4_31b/cuda_source_transformations.py @@ -30,7 +30,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from executorch.examples.models.gemma4.text_decoder import apply_rotary_emb from executorch.extension.llm.modules.turboquant import TurboQuantKVCache @@ -111,105 +110,6 @@ def _turboquant_attention_forward( return self.o_proj(y) -def _fused_mlp_forward(self, x: torch.Tensor) -> torch.Tensor: - """Drop-in ``Gemma4MLP.forward`` over a fused gate|up projection. - - Identical math to ``down(gelu(gate(x)) * up(x))``: the single - ``gate_up_proj`` emits ``[gate | up]`` concatenated on the last dim, - which is then split. One W4A8 matmul (and one activation-quant of ``x``) - instead of two. - """ - h = self.gate_up_proj(x) - gate = h[..., : self.intermediate_size] - up = h[..., self.intermediate_size :] - return self.down_proj(F.gelu(gate, approximate="tanh") * up) - - -def _concat_coalesced_int4_along_n(a, b): - """Concatenate two ``CudaCoalescedInt4Tensor`` along the output (N) dim. - - qdata is ``[N, K/2]`` and scale/zero_point are ``[N, n_groups]`` in the - coalesced layout, so a per-output-row concat on dim 0 is exact: the W4A8 - dp4a matvec reads each output row's qdata/scale/zero independently, so - out[:N_a] reproduces ``a`` and out[N_a:] reproduces ``b`` bit-for-bit. - """ - from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor - - return CudaCoalescedInt4Tensor( - torch.cat([a.qdata, b.qdata], dim=0), - torch.cat([a.scale, b.scale], dim=0), - torch.cat([a.zero_point, b.zero_point], dim=0), - a.block_size, - torch.Size([a.shape[0] + b.shape[0], a.shape[1]]), - None, - a.activation_dtype, - ) - - -def _is_fuseable_int4_pair(gate_w, up_w) -> bool: - """True iff gate/up are both coalesced-int4 with matching K + block_size. - - Q4_K MLP weights become ``CudaCoalescedInt4Tensor`` (fuseable); a Q6_K - weight becomes ``CudaDp4aPlanarInt6Tensor`` (left alone). ``act_pre_scale`` - is unused on this path but we require it absent so the concat stays exact. - """ - from executorch.backends.cuda.coalesced_int4_tensor import CudaCoalescedInt4Tensor - - return ( - isinstance(gate_w, CudaCoalescedInt4Tensor) - and isinstance(up_w, CudaCoalescedInt4Tensor) - and list(gate_w.block_size) == list(up_w.block_size) - and gate_w.shape[1] == up_w.shape[1] - and gate_w.act_pre_scale is None - and up_w.act_pre_scale is None - ) - - -def _fuse_gate_up_proj(model: nn.Module) -> None: - """Fuse each MLP's ``gate_proj | up_proj`` into one ``gate_up_proj``. - - gate and up share the same input, so the unfused path quantizes ``x`` to - int8 twice and launches two W4A8 matvecs per layer. Fusing the weights - into one ``[2*inter, hidden]`` tensor halves both. Weight bytes read are - unchanged, so the win is launch + activation-quant overhead (decode is - launch-bound). Only Q4_K (coalesced-int4) layers are fused; any layer - with a non-int4 weight is left as two matmuls (still correct). - - Must run AFTER weights are packed to ``CudaCoalescedInt4Tensor`` (i.e. - inside ``_export_cuda``), and is independent of TurboQuant. - """ - n_fused = 0 - n_skipped = 0 - for layer in model.layers: - mlp = getattr(layer, "mlp", None) - if mlp is None or not (hasattr(mlp, "gate_proj") and hasattr(mlp, "up_proj")): - continue - gate_w = mlp.gate_proj.weight - up_w = mlp.up_proj.weight - if not _is_fuseable_int4_pair(gate_w, up_w): - n_skipped += 1 - continue - inter = up_w.shape[0] - hidden = up_w.shape[1] - fused_w = _concat_coalesced_int4_along_n(gate_w, up_w) - - # Container built on meta to avoid materializing a dense - # [2*inter, hidden] weight before we overwrite it with fused_w. - gate_up = nn.Linear(hidden, 2 * inter, bias=False, device="meta") - gate_up.weight = nn.Parameter(fused_w, requires_grad=False) - mlp.gate_up_proj = gate_up - mlp.intermediate_size = inter - del mlp.gate_proj - del mlp.up_proj - mlp.forward = types.MethodType(_fused_mlp_forward, mlp) - n_fused += 1 - - msg = f"[gemma4_31b cuda] Fused gate+up on {n_fused} MLP layers" - if n_skipped: - msg += f" ({n_skipped} skipped: non-int4 weights)" - print(msg) - - def cuda_source_transformations( model: nn.Module, *, @@ -217,11 +117,6 @@ def cuda_source_transformations( ) -> None: """Apply CUDA source transformations to a Gemma 4 31B model in place. - Always fuses each MLP's ``gate_proj|up_proj`` into a single matmul (one - activation-quant + one W4A8 matvec per layer instead of two; Q4_K - coalesced-int4 layers only — other quant types are left untouched). - Optionally also swaps full-attention KV caches for TurboQuant TQ4. - Args: model: ``Gemma4_31B`` instance to transform. use_turboquant: When True, swap full-attention layers' KV caches @@ -230,8 +125,6 @@ def cuda_source_transformations( ``torch.ops.triton.tq4_sdpa``. Sliding-window layers are unaffected. """ - _fuse_gate_up_proj(model) - if not use_turboquant: return diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index b2b2264178a..d9e16bc34df 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -182,11 +182,12 @@ def _export_cuda( materialize_runtime_buffers(model, dtype=torch.bfloat16) - from executorch.examples.models.gemma4_31b.cuda_source_transformations import ( - cuda_source_transformations, - ) + if use_turboquant: + from executorch.examples.models.gemma4_31b.cuda_source_transformations import ( + cuda_source_transformations, + ) - cuda_source_transformations(model, use_turboquant=use_turboquant) + cuda_source_transformations(model, use_turboquant=True) # Int4Tensor weights are used directly — no format conversion. # F.linear dispatches to executorch_cuda::int4_plain_mm (CUDA shim). diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index 90839ea6f6a..6a4a70ced18 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -104,6 +104,89 @@ def _convert_weight(model, model_key: str, gtensor, backend: str): return gtensor +# --------------------------------------------------------------------------- +# Single-point gate/up fusion (backend-agnostic, at the raw GGUF level) +# +# gate_proj and up_proj share the same input, so the MLP can issue ONE matmul +# over a [2*intermediate, hidden] weight instead of two. We fuse here -- before +# any backend conversion (_convert_weight) -- by concatenating the two raw GGUF +# block blobs along the output (row) dim. ExportableGGUFTensor.raw is +# (N, row_bytes) row-major with each output row self-contained, so the concat is +# an exact row-stack (no re-quant, no scale recompute). Both CUDA and MLX then +# pack the already-fused weight, so there is no per-backend-type concat. The +# model's Gemma4MLP.forward splits the [.., 2*intermediate] output back into +# gate/up only when a fused gate_up_proj is present (graceful for unfused loads). + + +def _gate_up_layer_kind(model_key: str): + """If ``model_key`` is an MLP gate/up proj weight, return ``(layer_idx, kind)`` + with ``kind`` in ``{"gate", "up"}``; otherwise ``None``.""" + prefix = "layers." + for kind in ("gate", "up"): + suffix = f".mlp.{kind}_proj.weight" + if model_key.startswith(prefix) and model_key.endswith(suffix): + mid = model_key[len(prefix) : len(model_key) - len(suffix)] + if mid.isdigit(): + return int(mid), kind + return None + + +def _gate_up_fuseable(gate, up) -> bool: + """True iff gate/up are the same GGUF quant type and same packed row width + (hence same K + block layout), so a row-concat along output N is valid.""" + return ( + gate.ggml_type == up.ggml_type + and gate.raw.shape[1] == up.raw.shape[1] + and int(gate.shape[1]) == int(up.shape[1]) + ) + + +def _fuse_gate_up_raw(gate, up): + """Row-concat gate|up raw GGUF blocks (gate rows first) into one fused + ExportableGGUFTensor of shape (2*N, K).""" + from executorch.extension.llm.export.gguf import ExportableGGUFTensor + + fused_raw = torch.cat([gate.raw, up.raw], dim=0) + return ExportableGGUFTensor.from_raw(fused_raw, gate.ggml_type, gate.orig_dtype) + + +def _assign_gate_up_unfused(model, layer_idx, kind, gtensor, backend, packers): + """Assign a single gate/up GGUF tensor to its own projection (no fusion).""" + from executorch.examples.models.gemma4_31b.quant import pack_one + + key = f"layers.{layer_idx}.mlp.{kind}_proj.weight" + pack_one(model, key, _convert_weight(model, key, gtensor, backend), packers) + + +def _install_and_pack_fused_gate_up(model, layer_idx, gate, up, backend, packers): + """Fuse gate|up at the raw level, swap the layer's MLP to a single + ``gate_up_proj`` (dropping gate_proj/up_proj), then pack the fused weight.""" + import torch.nn as nn + + from executorch.examples.models.gemma4_31b.quant import pack_one + + fused = _fuse_gate_up_raw(gate, up) + inter, hidden = int(gate.shape[0]), int(gate.shape[1]) + + mlp = model.get_submodule(f"layers.{layer_idx}.mlp") + mlp.gate_up_proj = nn.Linear(hidden, 2 * inter, bias=False, device="meta") + del mlp.gate_proj + del mlp.up_proj + + key = f"layers.{layer_idx}.mlp.gate_up_proj.weight" + pack_one(model, key, _convert_weight(model, key, fused, backend), packers) + + +def _process_gate_up_pair(model, layer_idx, gate, up, backend, packers) -> bool: + """Fuse gate|up if compatible (returns True), else assign them unfused.""" + if _gate_up_fuseable(gate, up): + _install_and_pack_fused_gate_up(model, layer_idx, gate, up, backend, packers) + return True + _assign_gate_up_unfused(model, layer_idx, "gate", gate, backend, packers) + _assign_gate_up_unfused(model, layer_idx, "up", up, backend, packers) + return False + + def _resolve_tied_lm_head(model, lm_head_weight, packers): """Assign a tied lm_head (GGUF ties it to the token embedding).""" from executorch.examples.models.gemma4_31b.quant import pack_one @@ -217,11 +300,32 @@ def load_gguf_model( n_processed = 0 print(f"Streaming GGUF from {gguf_path}...") + pending_gate_up: dict = {} # layer_idx -> {"gate": raw, "up": raw} + n_fused = 0 + n_unfused = 0 for gguf_name, value in iter_gguf(gguf_path): model_key = gguf_to_model_key(gguf_name) if model_key is None: continue + # Buffer the RAW gate/up ExportableGGUFTensor (pre-conversion) and fuse + # once both arrive -- the single common point upstream of _convert_weight. + gu = _gate_up_layer_kind(model_key) + if gu is not None and isinstance(value, ExportableGGUFTensor): + layer_idx, kind = gu + slot = pending_gate_up.setdefault(layer_idx, {}) + slot[kind] = value + if "gate" in slot and "up" in slot: + if _process_gate_up_pair( + model, layer_idx, slot["gate"], slot["up"], backend, packers + ): + n_fused += 1 + else: + n_unfused += 1 + pending_gate_up.pop(layer_idx, None) + n_processed += 2 + continue + if isinstance(value, ExportableGGUFTensor): weight = _convert_weight(model, model_key, value, backend) if model_key == "embed_tokens.weight": @@ -238,6 +342,21 @@ def load_gguf_model( if n_processed % 100 == 0: print(f" Processed {n_processed} tensors...") + # Flush any unpaired gate/up (partial/malformed) as separate unfused + # projections so no weight is left on meta. + for layer_idx, slot in pending_gate_up.items(): + for kind in ("gate", "up"): + if kind in slot: + _assign_gate_up_unfused( + model, layer_idx, kind, slot[kind], backend, packers + ) + n_unfused += 1 + + print( + f"[gemma4_31b gguf] Fused gate+up on {n_fused} MLP layers" + + (f" ({n_unfused} left unfused)" if n_unfused else "") + ) + _resolve_tied_lm_head(model, lm_head_weight, packers) # Fill RoPE tables / KV caches / scalar constants (left on meta by the