Skip to content

Cache ModelMixin.dtype to avoid named_parameters walk per access#13571

Open
akshan-main wants to merge 5 commits into
huggingface:mainfrom
akshan-main:cache-modelmixin-dtype
Open

Cache ModelMixin.dtype to avoid named_parameters walk per access#13571
akshan-main wants to merge 5 commits into
huggingface:mainfrom
akshan-main:cache-modelmixin-dtype

Conversation

@akshan-main

@akshan-main akshan-main commented Apr 28, 2026

Copy link
Copy Markdown
Contributor

What does this PR do?

Addresses #13401

ModelMixin.dtype calls get_parameter_dtype(), which walks named_parameters() on every access. Pipelines read self.transformer.dtype / text_encoder.dtype / vae.dtype inside denoise loops, so the walk fires every step.

This PR caches dtype on first access and invalidates in _apply, which .to(), .cpu(), .cuda(), .half(), and .bfloat16() all route through. Generation outputs are bit-identical. A microbench on AutoencoderKL drops cached .dtype access from ~88us to ~0.1us (~1000x).

device is intentionally not cached: with group offloading the effective device changes per-forward, so a cache there would be wrong.

Same shape as the cache_context._set_context cache in #13356

Profiling: 10 pipelines (eager, 2 inference steps, A100)

pipeline             |  dtype calls  | dtype total (ms)  | step gap (ms) | pipeline_call (ms) 
                     | before  after |   before    after | before  after |    before     after
----------------------------------------------------------------------------------------------
flux2                |      2      0 |     6.04     0.00 |   3.83   0.85 |    384.78    369.71
qwenimage(fd)        |      1      0 |     2.46     0.00 |   0.29   0.29 |   2148.46   2147.69
qwenimage_edit(fd)   |      1      0 |     2.44     0.00 |   0.30   0.31 |   6250.67   6237.57
z_image              |      2      0 |    10.52     0.00 |  10.63   5.40 |   1394.03   1391.21
chroma               |      0      0 |     0.00     0.00 |   0.16   0.17 |   2163.04   2162.70
sdxl(fd)             |      2      0 |     4.46     0.00 |   0.00   0.00 |    752.58    758.52
sana                 |      1      0 |     3.68     0.00 |   1.69   1.57 |    387.98    388.07
hunyuanv15(video)    |      5      0 |    82.59     0.00 |   0.19   0.18 |  16970.52  16881.84
wan2.2(video)        |      1      0 |     6.60     0.00 |   0.15   0.16 |   3015.05   3045.30
ltx2(video)          |      0      0 |     0.00     0.00 |   8.75   8.79 |   5240.18   5228.16

The fix removes the walk where it appears. The largest single saving is hunyuanv15 at 82.59 ms over 2 inference steps, scaling linearly to ~2.1 s at a typical 50 steps. Pipelines without the walk (chroma, ltx2) are unaffected.

Reproduction notebook (Colab)

Before submitting

Who can review?

@sayakpaul @dg845

@akshan-main

akshan-main commented Apr 28, 2026

Copy link
Copy Markdown
Contributor Author

Profiled SD3 too (eager + compile, RTX PRO 6000 Blackwell, 2 steps) following the profiling guide.
Notebook: https://colab.research.google.com/gist/akshan-main/cb9ee83575806704e93e03496ba0d940/sd3_profiling.ipynb

Denoising loop is clean. 0 syncs in in_loop_body, in_transformer_forward, or in_scheduler_step after set_begin_index(0).

Pre-loop has 2x ~10ms aten::copy_ from scheduler.set_timesteps (numpy to GPU sigmas) and _get_clip_prompt_embeds (tokenizer ids to GPU). One 62ms aten::nonzero in the first _init_step_index call which set_begin_index(0) eliminates.

Tested adding set_begin_index(0) (matches Flux/Wan/Flux2). Trace sync drops from 62ms to 0 but wall-clock is within noise:

Mode Before After Delta
Eager 233.0 ± 0.6 ms 231.9 ± 1.0 ms -1.1 ms
Compile 200.2 ± 0.3 ms 199.9 ± 0.3 ms -0.3 ms

The sync was queue-drain. GPU has to do that work anyway, CPU just doesn't wait for it. Unlike Z-Image #13461, no per-step .item()/.cpu() to chase here. Remaining pre-loop syncs are legitimate one-time copies. Not opening a PR for SD3 profile.

@sayakpaul @dg845

@sayakpaul sayakpaul requested review from DN6 and yiyixuxu April 30, 2026 06:35

@DN6 DN6 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Proposal looks good to me. But we also need to account for enable_layerwise_casting.

Comment thread src/diffusers/models/modeling_utils.py
@github-actions github-actions Bot added the tests label Jun 1, 2026
@akshan-main akshan-main requested a review from DN6 June 1, 2026 13:36
@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@DN6

DN6 commented Jun 8, 2026

Copy link
Copy Markdown
Collaborator

@akshan-main Could you take a look at the CI failures please.

@github-actions github-actions Bot added size/M PR with diff < 200 LOC and removed size/S PR with diff < 50 LOC labels Jun 8, 2026
@akshan-main

Copy link
Copy Markdown
Contributor Author

@DN6 test_to_dtype broke because StableUnCLIP's custom .to() bypasses _apply; I added register_parameter invalidation. That's a third hook though, if you'd rather not chase every mutation path, I'm happy to instead fix the normalizer's .to() or drop the cache. wdyt.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

models size/M PR with diff < 200 LOC tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants