Skip to content

[WIP] Refactor Model Design#13794

Open
DN6 wants to merge 21 commits into
mainfrom
refactor-model-metadata
Open

[WIP] Refactor Model Design#13794
DN6 wants to merge 21 commits into
mainfrom
refactor-model-metadata

Conversation

@DN6
Copy link
Copy Markdown
Collaborator

@DN6 DN6 commented May 22, 2026

What does this PR do?

This refactor turns models into self-contained modules that declare their capabilities in one place. Per-model conversion code moves next to the model and a unified metadata() API makes feature attributes inspectable from any model class.

Motivation

Today, features are added to models through a mix of class attributes and mixins. Mixins define their own class attributes as well, so when examining a model class it isn't immediately clear which attributes and features are relevant or available.

Models are defined in a single file, so we end up using centralized utility files for things like model-specific weight and LoRA conversions. These files have grown enormous as they accumulate code to handle per-model variants and their idiosyncrasies.

The new design makes the mixins model-agnostic and has each mixin reach for the per-model metadata it needs through small handler objects attached to the model class.

Proposed Structure

Using Flux as a reference:

models/transformers/flux/
├── __init__.py
├── _ip_adapter.py        # FluxIPAdapterMixin + converters (internal)
├── _lora.py              # FLUX_LORA handler + per-format converters (internal)
├── _weight_mapping.py    # FLUX_WEIGHT_MAPPING handler + key tables (internal)
└── model.py              # FluxTransformer2DModel class declaration

Two patterns live next to model.py, picked per subsystem based on whether the behavior actually generalizes across models:

  1. handler + shared mixin — for features where the steps are the same across models and only the data/conversion function varies. The model opts in by inheriting the shared mixin in loaders/ and assigning its handler as a class attribute. LoRA and single-file weight mapping fit here:

    class FluxTransformer2DModel(ModelMixin, LoRAModelMixin, ...):
        _lora = LoRAHandler("...")               # handler instance, consumed by LoRAModelMixin
  2. Per-model mixin — for features that vary too much across models for a single shared mixin to be useful. Each model gets its own mixin declared right next to the model and inherited directly. IP-Adapter is the showcase:

    class FluxTransformer2DModel(..., FluxIPAdapterMixin):
        ...

This should simplify developing on top of these models — modifications or enhancements stay within one folder. If a model has a very specific feature that doesn't generalize across others, it can be kept isolated there too (e.g. FreeNoise for the AnimateDiff UNet). Additionally, if a custom model is modifying an existing diffusers model (Self-Forcing Wan), the folder method of organizing the model lends itself well to custom code loading with AutoModel.

Features Introduced

Model capability introspection via Model.metadata()

Each model exposes a metadata() classmethod that returns a metadata object, keyed by the class attribute that controls each feature. The displayed row tells you exactly what to set or inherit to change the behavior.

>>> print(FluxTransformer2DModel.metadata())
FluxTransformer2DModel feature attributes
──────────────────────────────────────────────────────────────────────────────────
  _supports_gradient_checkpointing  True
  _supports_group_offloading        True
  _no_split_modules                 FluxTransformerBlock, FluxSingleTransformerBlock
  _skip_layerwise_casting_patterns  pos_embed, norm
  _repeated_blocks                  FluxTransformerBlock, FluxSingleTransformerBlock
  _cp_plan                          True
  _weight_mapping                   flux-depth, flux-dev, flux-fill, flux-schnell
  _lora                             bfl, kohya, kontext, xlabs
  _supports_cache                   True
  _supports_ip_adapter              True

The returned ModelMetadata exposes each feature value as an attribute (meta._supports_ip_adapter, meta._lora, ...), supports keys() / values() / items() for mapping-style iteration, and in for presence checks. meta.describe(verbose=True) adds an indented description and docs link under each row. Which can be useful for agents.

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6 DN6 marked this pull request as ready for review May 27, 2026 16:35
@DN6 DN6 requested review from dg845, sayakpaul and yiyixuxu May 27, 2026 16:35
@DN6 DN6 requested a review from asomoza May 27, 2026 16:36
Copy link
Copy Markdown
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much for working on this!
I really like the direction this is taking, very exciting!
It's a big PR with a lot of changes, so I'll go through it step by step - left some initial questions:)

_default_processor_cls = None
_available_processors = []
_supports_qkv_fusion = True
_parallel_config = None
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need it here?

f"{DOCS_BASE}/optimization/memory#gradient-checkpointing",
)
if cls._supports_group_offloading:
rows["_supports_group_offloading"] = (
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like the metadata + mixin direction but I'd like to understand a bit more: is there a fundamental reason why some features get their own self-contained mixin (Cache, Lora) and many others stay as methods on ModelMixin?

return "\n".join(lines)


def register_metadata(metadata):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason we have to keep two systems to attach metadata to a class? the register_metadata vs the class attribute like _cp_plan?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that is because for handling with stuff like _cp_plan we don't have a mixin like CPMixin. However for others, we have dedicated mixins.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But do we need this registration, though? All the available feature set can be queried through the main model class, no?



@dataclass
class TransformerBlockOutput(TransformerModuleOutput):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these are not used yet no?

_cached_parameter_indices: dict[str, int] = None

def _register(self, cls):
"""Attach this metadata to ``cls`` and register it in :class:`TransformerBlockRegistry`.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Single "" is what we use across diffusers. Also, :class:` isn't something we do.

Suggested change
"""Attach this metadata to ``cls`` and register it in :class:`TransformerBlockRegistry`.
"""Attach this metadata to `cls` and register it in `TransformerBlockRegistry`.

Copy link
Copy Markdown
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did a pass on LoRA. Will now do a pass on modeling_utils.py.

I think it might be better to do the LoRA-related separation in another PR because it's difficult to truly assess if it's in a working state.

_cls: Type = None
_cached_parameter_indices: dict[str, int] = None

def _register(self, cls):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a little unclear about the docstring. How is this method metadata (which is what the docstring reads)?

def _register(self, cls):
"""Attach this metadata to ``cls`` and register it in :class:`TransformerBlockRegistry`.

Lets ``@register_metadata(TransformerBlockMetadata(...))`` work for block classes that opt into the decorator
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same.

Comment on lines +42 to +44
self._cls = cls
cls._block_metadata = self
TransformerBlockRegistry._registry[cls] = self
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it only applicable to classes with _repeated_blocks set?



@maybe_allow_in_graph
@register_metadata(TransformerBlockMetadata(return_hidden_states_index=1, return_encoder_hidden_states_index=0))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of this?

"_supports_cache": (
True,
"True",
"Supports caching techniques (PAB / FasterCache / FirstBlockCache) via `enable_cache`.",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"Supports caching techniques (PAB / FasterCache / FirstBlockCache) via `enable_cache`.",
"Supports caching techniques (e.g. FasterCache) via `enable_cache`.",

r"""
Add an adapter to the underlying model.

``source`` can be either:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is nothing called source here.


_maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name)

def _inject_adapter(self, state_dict, lora_config, adapter_name, peft_kwargs):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will prefer it inline as it's not shared.

self._rollback_adapter(adapter_name, e)
raise

def _maybe_apply_deferred_hotswap_prep(self, lora_config):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same

prepare_model_for_compiled_hotswap(self, config=lora_config, **self._lora_hotswap_kwargs)
self._lora_hotswap_kwargs = None

def _hotswap_adapter(self, state_dict, lora_config, adapter_name):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same.

self._rollback_adapter(adapter_name, e)
raise

def _rollback_adapter(self, adapter_name, error):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same.



@dataclass
class AttnProcessorOutput(TransformerModuleOutput):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does AttnProcessorOutput has to live in src/diffusers/models/transformers/utils.py? Could it not be used by VAEs or other components under src/diffusers/models/?

class ModelMetadata:
"""Snapshot of a model class's feature attributes.

Constructed by :meth:`ModelMixin.metadata` — walks ``cls.__mro__`` collecting rows from each mixin's ``_metadata``
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:meth: can be dangerous abbreviation 🤪

return "\n".join(lines)


def register_metadata(metadata):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that is because for handling with stuff like _cp_plan we don't have a mixin like CPMixin. However for others, we have dedicated mixins.

return "\n".join(lines)


def register_metadata(metadata):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But do we need this registration, though? All the available feature set can be queried through the main model class, no?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants