diff --git a/examples/models/gemma4/text_decoder/gemma4_decoder_layer.py b/examples/models/gemma4/text_decoder/gemma4_decoder_layer.py index e10c1c7e415..fe3e3bb94cb 100644 --- a/examples/models/gemma4/text_decoder/gemma4_decoder_layer.py +++ b/examples/models/gemma4/text_decoder/gemma4_decoder_layer.py @@ -34,14 +34,25 @@ class Gemma4MLP(nn.Module): def __init__(self, hidden_size: int, intermediate_size: int): super().__init__() + self.intermediate_size = intermediate_size self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj( - F.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x) - ) + # If a loader fused gate_proj|up_proj into one gate_up_proj (single + # matmul; e.g. the GGUF loader's coalesced fusion), use it and split the + # [.., 2*intermediate_size] output back into gate/up. Otherwise fall back + # to the separate projections (unfused checkpoints / non-fusing loaders). + gate_up = getattr(self, "gate_up_proj", None) + if gate_up is not None: + fused = gate_up(x) + gate = fused[..., : self.intermediate_size] + up = fused[..., self.intermediate_size :] + else: + gate = self.gate_proj(x) + up = self.up_proj(x) + return self.down_proj(F.gelu(gate, approximate="tanh") * up) class Gemma4DecoderLayer(nn.Module): diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index 90839ea6f6a..6a4a70ced18 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -104,6 +104,89 @@ def _convert_weight(model, model_key: str, gtensor, backend: str): return gtensor +# --------------------------------------------------------------------------- +# Single-point gate/up fusion (backend-agnostic, at the raw GGUF level) +# +# gate_proj and up_proj share the same input, so the MLP can issue ONE matmul +# over a [2*intermediate, hidden] weight instead of two. We fuse here -- before +# any backend conversion (_convert_weight) -- by concatenating the two raw GGUF +# block blobs along the output (row) dim. ExportableGGUFTensor.raw is +# (N, row_bytes) row-major with each output row self-contained, so the concat is +# an exact row-stack (no re-quant, no scale recompute). Both CUDA and MLX then +# pack the already-fused weight, so there is no per-backend-type concat. The +# model's Gemma4MLP.forward splits the [.., 2*intermediate] output back into +# gate/up only when a fused gate_up_proj is present (graceful for unfused loads). + + +def _gate_up_layer_kind(model_key: str): + """If ``model_key`` is an MLP gate/up proj weight, return ``(layer_idx, kind)`` + with ``kind`` in ``{"gate", "up"}``; otherwise ``None``.""" + prefix = "layers." + for kind in ("gate", "up"): + suffix = f".mlp.{kind}_proj.weight" + if model_key.startswith(prefix) and model_key.endswith(suffix): + mid = model_key[len(prefix) : len(model_key) - len(suffix)] + if mid.isdigit(): + return int(mid), kind + return None + + +def _gate_up_fuseable(gate, up) -> bool: + """True iff gate/up are the same GGUF quant type and same packed row width + (hence same K + block layout), so a row-concat along output N is valid.""" + return ( + gate.ggml_type == up.ggml_type + and gate.raw.shape[1] == up.raw.shape[1] + and int(gate.shape[1]) == int(up.shape[1]) + ) + + +def _fuse_gate_up_raw(gate, up): + """Row-concat gate|up raw GGUF blocks (gate rows first) into one fused + ExportableGGUFTensor of shape (2*N, K).""" + from executorch.extension.llm.export.gguf import ExportableGGUFTensor + + fused_raw = torch.cat([gate.raw, up.raw], dim=0) + return ExportableGGUFTensor.from_raw(fused_raw, gate.ggml_type, gate.orig_dtype) + + +def _assign_gate_up_unfused(model, layer_idx, kind, gtensor, backend, packers): + """Assign a single gate/up GGUF tensor to its own projection (no fusion).""" + from executorch.examples.models.gemma4_31b.quant import pack_one + + key = f"layers.{layer_idx}.mlp.{kind}_proj.weight" + pack_one(model, key, _convert_weight(model, key, gtensor, backend), packers) + + +def _install_and_pack_fused_gate_up(model, layer_idx, gate, up, backend, packers): + """Fuse gate|up at the raw level, swap the layer's MLP to a single + ``gate_up_proj`` (dropping gate_proj/up_proj), then pack the fused weight.""" + import torch.nn as nn + + from executorch.examples.models.gemma4_31b.quant import pack_one + + fused = _fuse_gate_up_raw(gate, up) + inter, hidden = int(gate.shape[0]), int(gate.shape[1]) + + mlp = model.get_submodule(f"layers.{layer_idx}.mlp") + mlp.gate_up_proj = nn.Linear(hidden, 2 * inter, bias=False, device="meta") + del mlp.gate_proj + del mlp.up_proj + + key = f"layers.{layer_idx}.mlp.gate_up_proj.weight" + pack_one(model, key, _convert_weight(model, key, fused, backend), packers) + + +def _process_gate_up_pair(model, layer_idx, gate, up, backend, packers) -> bool: + """Fuse gate|up if compatible (returns True), else assign them unfused.""" + if _gate_up_fuseable(gate, up): + _install_and_pack_fused_gate_up(model, layer_idx, gate, up, backend, packers) + return True + _assign_gate_up_unfused(model, layer_idx, "gate", gate, backend, packers) + _assign_gate_up_unfused(model, layer_idx, "up", up, backend, packers) + return False + + def _resolve_tied_lm_head(model, lm_head_weight, packers): """Assign a tied lm_head (GGUF ties it to the token embedding).""" from executorch.examples.models.gemma4_31b.quant import pack_one @@ -217,11 +300,32 @@ def load_gguf_model( n_processed = 0 print(f"Streaming GGUF from {gguf_path}...") + pending_gate_up: dict = {} # layer_idx -> {"gate": raw, "up": raw} + n_fused = 0 + n_unfused = 0 for gguf_name, value in iter_gguf(gguf_path): model_key = gguf_to_model_key(gguf_name) if model_key is None: continue + # Buffer the RAW gate/up ExportableGGUFTensor (pre-conversion) and fuse + # once both arrive -- the single common point upstream of _convert_weight. + gu = _gate_up_layer_kind(model_key) + if gu is not None and isinstance(value, ExportableGGUFTensor): + layer_idx, kind = gu + slot = pending_gate_up.setdefault(layer_idx, {}) + slot[kind] = value + if "gate" in slot and "up" in slot: + if _process_gate_up_pair( + model, layer_idx, slot["gate"], slot["up"], backend, packers + ): + n_fused += 1 + else: + n_unfused += 1 + pending_gate_up.pop(layer_idx, None) + n_processed += 2 + continue + if isinstance(value, ExportableGGUFTensor): weight = _convert_weight(model, model_key, value, backend) if model_key == "embed_tokens.weight": @@ -238,6 +342,21 @@ def load_gguf_model( if n_processed % 100 == 0: print(f" Processed {n_processed} tensors...") + # Flush any unpaired gate/up (partial/malformed) as separate unfused + # projections so no weight is left on meta. + for layer_idx, slot in pending_gate_up.items(): + for kind in ("gate", "up"): + if kind in slot: + _assign_gate_up_unfused( + model, layer_idx, kind, slot[kind], backend, packers + ) + n_unfused += 1 + + print( + f"[gemma4_31b gguf] Fused gate+up on {n_fused} MLP layers" + + (f" ({n_unfused} left unfused)" if n_unfused else "") + ) + _resolve_tied_lm_head(model, lm_head_weight, packers) # Fill RoPE tables / KV caches / scalar constants (left on meta by the