[WIP] Refactor Model Design#13794
Conversation
yiyixuxu
left a comment
There was a problem hiding this comment.
Thanks so much for working on this!
I really like the direction this is taking, very exciting!
It's a big PR with a lot of changes, so I'll go through it step by step - left some initial questions:)
| _default_processor_cls = None | ||
| _available_processors = [] | ||
| _supports_qkv_fusion = True | ||
| _parallel_config = None |
| f"{DOCS_BASE}/optimization/memory#gradient-checkpointing", | ||
| ) | ||
| if cls._supports_group_offloading: | ||
| rows["_supports_group_offloading"] = ( |
There was a problem hiding this comment.
I really like the metadata + mixin direction but I'd like to understand a bit more: is there a fundamental reason why some features get their own self-contained mixin (Cache, Lora) and many others stay as methods on ModelMixin?
| return "\n".join(lines) | ||
|
|
||
|
|
||
| def register_metadata(metadata): |
There was a problem hiding this comment.
Is there a reason we have to keep two systems to attach metadata to a class? the register_metadata vs the class attribute like _cp_plan?
There was a problem hiding this comment.
I think that is because for handling with stuff like _cp_plan we don't have a mixin like CPMixin. However for others, we have dedicated mixins.
There was a problem hiding this comment.
But do we need this registration, though? All the available feature set can be queried through the main model class, no?
|
|
||
|
|
||
| @dataclass | ||
| class TransformerBlockOutput(TransformerModuleOutput): |
There was a problem hiding this comment.
these are not used yet no?
| _cached_parameter_indices: dict[str, int] = None | ||
|
|
||
| def _register(self, cls): | ||
| """Attach this metadata to ``cls`` and register it in :class:`TransformerBlockRegistry`. |
There was a problem hiding this comment.
Single "" is what we use across diffusers. Also, :class:` isn't something we do.
| """Attach this metadata to ``cls`` and register it in :class:`TransformerBlockRegistry`. | |
| """Attach this metadata to `cls` and register it in `TransformerBlockRegistry`. |
sayakpaul
left a comment
There was a problem hiding this comment.
Did a pass on LoRA. Will now do a pass on modeling_utils.py.
I think it might be better to do the LoRA-related separation in another PR because it's difficult to truly assess if it's in a working state.
- We are likely missing andling Flux Control LoRA and as such, handling text-encoder LoRA modules. It's also not clear how this would affect
lora_pipeline.pyandlora_conversion_utils.py. - One-off utilities that are not shared across multiple different functions / classes.
- It's not clear how the existing https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/peft.py gets affected by this.
| _cls: Type = None | ||
| _cached_parameter_indices: dict[str, int] = None | ||
|
|
||
| def _register(self, cls): |
There was a problem hiding this comment.
I am a little unclear about the docstring. How is this method metadata (which is what the docstring reads)?
| def _register(self, cls): | ||
| """Attach this metadata to ``cls`` and register it in :class:`TransformerBlockRegistry`. | ||
|
|
||
| Lets ``@register_metadata(TransformerBlockMetadata(...))`` work for block classes that opt into the decorator |
| self._cls = cls | ||
| cls._block_metadata = self | ||
| TransformerBlockRegistry._registry[cls] = self |
There was a problem hiding this comment.
Is it only applicable to classes with _repeated_blocks set?
|
|
||
|
|
||
| @maybe_allow_in_graph | ||
| @register_metadata(TransformerBlockMetadata(return_hidden_states_index=1, return_encoder_hidden_states_index=0)) |
| "_supports_cache": ( | ||
| True, | ||
| "True", | ||
| "Supports caching techniques (PAB / FasterCache / FirstBlockCache) via `enable_cache`.", |
There was a problem hiding this comment.
| "Supports caching techniques (PAB / FasterCache / FirstBlockCache) via `enable_cache`.", | |
| "Supports caching techniques (e.g. FasterCache) via `enable_cache`.", |
| r""" | ||
| Add an adapter to the underlying model. | ||
|
|
||
| ``source`` can be either: |
There was a problem hiding this comment.
There is nothing called source here.
|
|
||
| _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name) | ||
|
|
||
| def _inject_adapter(self, state_dict, lora_config, adapter_name, peft_kwargs): |
There was a problem hiding this comment.
Will prefer it inline as it's not shared.
| self._rollback_adapter(adapter_name, e) | ||
| raise | ||
|
|
||
| def _maybe_apply_deferred_hotswap_prep(self, lora_config): |
| prepare_model_for_compiled_hotswap(self, config=lora_config, **self._lora_hotswap_kwargs) | ||
| self._lora_hotswap_kwargs = None | ||
|
|
||
| def _hotswap_adapter(self, state_dict, lora_config, adapter_name): |
| self._rollback_adapter(adapter_name, e) | ||
| raise | ||
|
|
||
| def _rollback_adapter(self, adapter_name, error): |
|
|
||
|
|
||
| @dataclass | ||
| class AttnProcessorOutput(TransformerModuleOutput): |
There was a problem hiding this comment.
Why does AttnProcessorOutput has to live in src/diffusers/models/transformers/utils.py? Could it not be used by VAEs or other components under src/diffusers/models/?
| class ModelMetadata: | ||
| """Snapshot of a model class's feature attributes. | ||
|
|
||
| Constructed by :meth:`ModelMixin.metadata` — walks ``cls.__mro__`` collecting rows from each mixin's ``_metadata`` |
There was a problem hiding this comment.
:meth: can be dangerous abbreviation 🤪
| return "\n".join(lines) | ||
|
|
||
|
|
||
| def register_metadata(metadata): |
There was a problem hiding this comment.
I think that is because for handling with stuff like _cp_plan we don't have a mixin like CPMixin. However for others, we have dedicated mixins.
| return "\n".join(lines) | ||
|
|
||
|
|
||
| def register_metadata(metadata): |
There was a problem hiding this comment.
But do we need this registration, though? All the available feature set can be queried through the main model class, no?
What does this PR do?
This refactor turns models into self-contained modules that declare their capabilities in one place. Per-model conversion code moves next to the model and a unified
metadata()API makes feature attributes inspectable from any model class.Motivation
Today, features are added to models through a mix of class attributes and mixins. Mixins define their own class attributes as well, so when examining a model class it isn't immediately clear which attributes and features are relevant or available.
Models are defined in a single file, so we end up using centralized utility files for things like model-specific weight and LoRA conversions. These files have grown enormous as they accumulate code to handle per-model variants and their idiosyncrasies.
The new design makes the mixins model-agnostic and has each mixin reach for the per-model metadata it needs through small handler objects attached to the model class.
Proposed Structure
Using Flux as a reference:
Two patterns live next to
model.py, picked per subsystem based on whether the behavior actually generalizes across models:handler + shared mixin — for features where the steps are the same across models and only the data/conversion function varies. The model opts in by inheriting the shared mixin in
loaders/and assigning its handler as a class attribute. LoRA and single-file weight mapping fit here:Per-model mixin — for features that vary too much across models for a single shared mixin to be useful. Each model gets its own mixin declared right next to the model and inherited directly. IP-Adapter is the showcase:
This should simplify developing on top of these models — modifications or enhancements stay within one folder. If a model has a very specific feature that doesn't generalize across others, it can be kept isolated there too (e.g. FreeNoise for the AnimateDiff UNet). Additionally, if a custom model is modifying an existing diffusers model (Self-Forcing Wan), the folder method of organizing the model lends itself well to custom code loading with
AutoModel.Features Introduced
Model capability introspection via
Model.metadata()Each model exposes a
metadata()classmethod that returns a metadata object, keyed by the class attribute that controls each feature. The displayed row tells you exactly what to set or inherit to change the behavior.>>> print(FluxTransformer2DModel.metadata()) FluxTransformer2DModel feature attributes ────────────────────────────────────────────────────────────────────────────────── _supports_gradient_checkpointing True _supports_group_offloading True _no_split_modules FluxTransformerBlock, FluxSingleTransformerBlock _skip_layerwise_casting_patterns pos_embed, norm _repeated_blocks FluxTransformerBlock, FluxSingleTransformerBlock _cp_plan True _weight_mapping flux-depth, flux-dev, flux-fill, flux-schnell _lora bfl, kohya, kontext, xlabs _supports_cache True _supports_ip_adapter TrueThe returned
ModelMetadataexposes each feature value as an attribute (meta._supports_ip_adapter,meta._lora, ...), supportskeys()/values()/items()for mapping-style iteration, andinfor presence checks.meta.describe(verbose=True)adds an indented description and docs link under each row. Which can be useful for agents.Fixes # (issue)
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.