Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions examples/models/gemma4/text_decoder/gemma4_decoder_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,25 @@ class Gemma4MLP(nn.Module):

def __init__(self, hidden_size: int, intermediate_size: int):
    """Build the three bias-free projections of the gated MLP.

    Args:
        hidden_size: Width of the MLP input and output.
        intermediate_size: Width of the gate and up projections.
    """
    super().__init__()
    # Remembered so forward() knows where to split a fused gate|up output.
    self.intermediate_size = intermediate_size

    # gate_proj and up_proj share the same hidden -> intermediate geometry.
    expand = {
        "in_features": hidden_size,
        "out_features": intermediate_size,
        "bias": False,
    }
    self.gate_proj = nn.Linear(**expand)
    self.up_proj = nn.Linear(**expand)
    self.down_proj = nn.Linear(
        in_features=intermediate_size,
        out_features=hidden_size,
        bias=False,
    )

def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the gated MLP: ``down_proj(gelu_tanh(gate) * up)``.

    If a loader has installed a fused ``gate_up_proj`` attribute on this
    module (one matmul producing ``[..., 2 * intermediate_size]``), its
    output is split back into gate (first ``intermediate_size`` columns)
    and up (the remaining columns). Otherwise the separate ``gate_proj``
    and ``up_proj`` are used (unfused checkpoints / non-fusing loaders).

    Args:
        x: Input activations; the last dim is fed to the projections.

    Returns:
        The result of ``down_proj`` applied to the gated activations.
    """
    # NOTE: there must be no unconditional return using gate_proj/up_proj
    # ahead of this check -- it would make the fused path unreachable and
    # break once a loader has deleted the separate projections.
    gate_up = getattr(self, "gate_up_proj", None)
    if gate_up is None:
        gate = self.gate_proj(x)
        up = self.up_proj(x)
    else:
        fused = gate_up(x)
        gate = fused[..., : self.intermediate_size]
        up = fused[..., self.intermediate_size :]
    return self.down_proj(F.gelu(gate, approximate="tanh") * up)


class Gemma4DecoderLayer(nn.Module):
Expand Down
119 changes: 119 additions & 0 deletions examples/models/gemma4_31b/gguf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,89 @@
return gtensor


# ---------------------------------------------------------------------------
# Single-point gate/up fusion (backend-agnostic, at the raw GGUF level)
#
# gate_proj and up_proj share the same input, so the MLP can issue ONE matmul
# over a [2*intermediate, hidden] weight instead of two. We fuse here -- before
# any backend conversion (_convert_weight) -- by concatenating the two raw GGUF
# block blobs along the output (row) dim. ExportableGGUFTensor.raw is
# (N, row_bytes) row-major with each output row self-contained, so the concat is
# an exact row-stack (no re-quant, no scale recompute). Both CUDA and MLX then
# pack the already-fused weight, so there is no per-backend-type concat. The
# model's Gemma4MLP.forward splits the [.., 2*intermediate] output back into
# gate/up only when a fused gate_up_proj is present (graceful for unfused loads).


def _gate_up_layer_kind(model_key: str):
"""If ``model_key`` is an MLP gate/up proj weight, return ``(layer_idx, kind)``
with ``kind`` in ``{"gate", "up"}``; otherwise ``None``."""
prefix = "layers."
for kind in ("gate", "up"):
suffix = f".mlp.{kind}_proj.weight"
if model_key.startswith(prefix) and model_key.endswith(suffix):
mid = model_key[len(prefix) : len(model_key) - len(suffix)]
if mid.isdigit():
return int(mid), kind
return None


def _gate_up_fuseable(gate, up) -> bool:
"""True iff gate/up are the same GGUF quant type and same packed row width
(hence same K + block layout), so a row-concat along output N is valid."""
return (
gate.ggml_type == up.ggml_type
and gate.raw.shape[1] == up.raw.shape[1]
and int(gate.shape[1]) == int(up.shape[1])
)


def _fuse_gate_up_raw(gate, up):
    """Stack the raw GGUF rows of ``gate`` and ``up`` into one fused tensor.

    The gate rows come first, followed by the up rows. The result is rebuilt
    with ``ExportableGGUFTensor.from_raw`` using ``gate``'s quant type and
    original dtype.
    """
    from executorch.extension.llm.export.gguf import ExportableGGUFTensor

    # dim=0 is the row axis of ``raw``; gate must precede up.
    stacked_rows = torch.cat((gate.raw, up.raw), dim=0)
    quant_type = gate.ggml_type
    source_dtype = gate.orig_dtype
    return ExportableGGUFTensor.from_raw(stacked_rows, quant_type, source_dtype)


def _assign_gate_up_unfused(model, layer_idx, kind, gtensor, backend, packers):
    """Convert and pack one gate/up tensor into its own projection weight.

    ``kind`` selects the target (``"gate"`` or ``"up"``); no fusion happens.
    """
    from executorch.examples.models.gemma4_31b.quant import pack_one

    target_key = f"layers.{layer_idx}.mlp.{kind}_proj.weight"
    converted = _convert_weight(model, target_key, gtensor, backend)
    pack_one(model, target_key, converted, packers)


def _install_and_pack_fused_gate_up(model, layer_idx, gate, up, backend, packers):
    """Replace a layer's gate/up projections with one fused ``gate_up_proj``.

    Steps, in order: fuse the raw tensors, install a meta-device
    ``nn.Linear(hidden, 2 * inter)`` as ``gate_up_proj`` on the layer's MLP,
    remove ``gate_proj`` and ``up_proj``, then convert and pack the fused
    weight under the new key.
    """
    import torch.nn as nn

    from executorch.examples.models.gemma4_31b.quant import pack_one

    fused_tensor = _fuse_gate_up_raw(gate, up)
    n_inter = int(gate.shape[0])
    n_hidden = int(gate.shape[1])

    mlp_module = model.get_submodule(f"layers.{layer_idx}.mlp")
    # Placeholder on meta; the real weight is supplied by pack_one below.
    mlp_module.gate_up_proj = nn.Linear(
        n_hidden, 2 * n_inter, bias=False, device="meta"
    )
    for stale_name in ("gate_proj", "up_proj"):
        delattr(mlp_module, stale_name)

    fused_key = f"layers.{layer_idx}.mlp.gate_up_proj.weight"
    converted = _convert_weight(model, fused_key, fused_tensor, backend)
    pack_one(model, fused_key, converted, packers)


def _process_gate_up_pair(model, layer_idx, gate, up, backend, packers) -> bool:
    """Handle one layer's gate/up pair and report whether it was fused.

    Returns:
        True when the pair was fused into ``gate_up_proj``; False when the
        two tensors were assigned to their separate projections instead.
    """
    if not _gate_up_fuseable(gate, up):
        # Incompatible pair: load gate first, then up, each on its own.
        for kind, tensor in (("gate", gate), ("up", up)):
            _assign_gate_up_unfused(model, layer_idx, kind, tensor, backend, packers)
        return False

    _install_and_pack_fused_gate_up(model, layer_idx, gate, up, backend, packers)
    return True


def _resolve_tied_lm_head(model, lm_head_weight, packers):
"""Assign a tied lm_head (GGUF ties it to the token embedding)."""
from executorch.examples.models.gemma4_31b.quant import pack_one
Expand Down Expand Up @@ -168,7 +251,7 @@
return dequantize_weight(weight, torch.bfloat16), weight


def load_gguf_model(

Check warning on line 254 in examples/models/gemma4_31b/gguf_loader.py

View workflow job for this annotation

GitHub Actions / lintrunner

FLAKE8 C901

'load_gguf_model' is too complex (16) See https://www.flake8rules.com/rules/C901.html.
gguf_path: str,
max_seq_len: int = 4096,
backend: str = "cuda",
Expand Down Expand Up @@ -217,11 +300,32 @@
n_processed = 0

print(f"Streaming GGUF from {gguf_path}...")
pending_gate_up: dict = {} # layer_idx -> {"gate": raw, "up": raw}
n_fused = 0
n_unfused = 0
for gguf_name, value in iter_gguf(gguf_path):
model_key = gguf_to_model_key(gguf_name)
if model_key is None:
continue

# Buffer the RAW gate/up ExportableGGUFTensor (pre-conversion) and fuse
# once both arrive -- the single common point upstream of _convert_weight.
gu = _gate_up_layer_kind(model_key)
if gu is not None and isinstance(value, ExportableGGUFTensor):
layer_idx, kind = gu
slot = pending_gate_up.setdefault(layer_idx, {})
slot[kind] = value
if "gate" in slot and "up" in slot:
if _process_gate_up_pair(
model, layer_idx, slot["gate"], slot["up"], backend, packers
):
n_fused += 1
else:
n_unfused += 1
pending_gate_up.pop(layer_idx, None)
n_processed += 2
continue

if isinstance(value, ExportableGGUFTensor):
weight = _convert_weight(model, model_key, value, backend)
if model_key == "embed_tokens.weight":
Expand All @@ -238,6 +342,21 @@
if n_processed % 100 == 0:
print(f" Processed {n_processed} tensors...")

# Flush any unpaired gate/up (partial/malformed) as separate unfused
# projections so no weight is left on meta.
for layer_idx, slot in pending_gate_up.items():
for kind in ("gate", "up"):
if kind in slot:
_assign_gate_up_unfused(
model, layer_idx, kind, slot[kind], backend, packers
)
n_unfused += 1

print(
f"[gemma4_31b gguf] Fused gate+up on {n_fused} MLP layers"
+ (f" ({n_unfused} left unfused)" if n_unfused else "")
)

_resolve_tied_lm_head(model, lm_head_weight, packers)

# Fill RoPE tables / KV caches / scalar constants (left on meta by the
Expand Down
Loading