Fix fp16 LoRA unscale crash after validation in train_dreambooth_lora.py#13895
Merged
sayakpaul merged 2 commits intoJun 9, 2026
Merged
Conversation
When training with `--mixed_precision="fp16"` and `--validation_prompt`, the first optimizer step after a validation run fails with `ValueError: Attempting to unscale FP16 gradients`. Under fp16, `cast_training_params` keeps the trainable LoRA params in fp32. The in-loop validation pipeline is built with the same live `unet` object, and `log_validation` then calls `pipeline.to(device, dtype=torch_dtype)`, which downcasts those fp32 LoRA params back to fp16. The next backward therefore produces fp16 grads and `GradScaler.unscale_` raises. Drop the dtype cast from that `.to(...)` so the shared `unet` keeps its fp32 LoRA params. This matches train_dreambooth_lora_sdxl.py, which moves the validation pipeline with `.to(accelerator.device)` only. Fixes huggingface#13124
fd60288 to
4a9c51a
Compare
sayakpaul
approved these changes
Jun 9, 2026
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
3 tasks
HaozheZhang6
added a commit
to HaozheZhang6/diffusers
that referenced
this pull request
Jun 10, 2026
…LoRA scripts Follow-up to huggingface#13895, which fixed this for examples/dreambooth/train_dreambooth_lora.py. The same fp16 footgun is present in the other DreamBooth LoRA training scripts: under `--mixed_precision="fp16"`, `cast_training_params(..., dtype=torch.float32)` keeps the trainable LoRA params in fp32, but `log_validation` rebuilds the in-loop validation pipeline around the *live* training transformer (`transformer=unwrap_model(transformer)`) and then casts it to fp16. That downcasts the fp32 LoRA params, so the next optimizer step raises `ValueError: Attempting to unscale FP16 gradients`. Apply the same fix across the remaining scripts: - flux, flux_kontext, qwen_image, hidream, and advanced flux: drop `dtype=torch_dtype` from `pipeline.to(accelerator.device, ...)` (keep the device move), matching huggingface#13895. - z_image and the flux2 variants: the cast is `pipeline.to(dtype=torch_dtype)` with no device move, immediately followed by `enable_model_cpu_offload()`, so just drop the cast line. Frozen weights already use `weight_dtype` and the offload call handles device placement. The final (post-training) validation in every script builds a fresh pipeline from the saved weights, so it is unaffected either way.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Fixes #13124
Training
train_dreambooth_lora.pywith--mixed_precision="fp16"and--validation_promptcrashes on the first optimizer step after a validation run:Removing
--validation_promptavoids it, which points atlog_validation.Root cause
Under fp16,
cast_training_params(models, dtype=torch.float32)keeps the trainable LoRA params in fp32 (the standard fp16 mitigation, see #6514 / #6554).The in-loop validation pipeline is built with the same live
unetobject:log_validationthen runspipeline.to(accelerator.device, dtype=torch_dtype)withtorch_dtype=weight_dtype(fp16). That.to(..., dtype=fp16)downcasts the sharedunet's fp32 LoRA params back to fp16, so the next backward produces fp16 grads andGradScaler.unscale_raises.train_dreambooth_lora_sdxl.pydoes not hit this — itslog_validationmoves the pipeline with.to(accelerator.device)only (no dtype). This PR makes the SD1.5 script consistent.Fix
Drop the
dtype=torch_dtypefrom the.to(...)inlog_validation(plus an explanatory comment) so the sharedunetkeeps its fp32 LoRA params. The validation pipeline is already built withtorch_dtype=weight_dtype, and inference runs undertorch.amp.autocast, so validation behavior is unchanged.Verification
The crash is GPU-only and hard to exercise in the CPU-based example CI (see "On a regression test" below), so I verified the mechanism directly on CPU: a module with fp16 base weights + fp32 LoRA params, run through autocast forward → backward →
GradScaler.unscale_:unscale_.to(device, dtype=fp16)(before)ValueError: Attempting to unscale FP16 gradients..to(device)(after)ruff checkandruff format --checkpass on the changed file.On a regression test
I looked into adding a subprocess test under
examples/dreambooth/test_dreambooth_lora.py, but this bug cannot be reproduced by the CPU-based example CI, for two independent reasons:GradScaleris CUDA-only. WithAccelerator(mixed_precision="fp16")on CPU,accelerator.scaler is None, sounscale_is never called and the error can't fire. This is also why the example suite currently has no--mixed_precision fp16tests.--validation_promptonhf-internal-testing/tiny-stable-diffusion-pipefails earlier insidelog_validationwithValueError: Input image size (224*224) doesn't match model (30*30), before the post-validation step is ever reached.So a green/red CPU test isn't achievable here. I kept the change minimal and consistent with the SDXL script instead. Happy to add a
@require_torch_gpunightly test (or anything else you'd prefer) if that's the convention you'd like for this path.Note: a prior attempt (#13510) was self-closed unreviewed; it used the alternative approach of re-running
cast_training_paramsafter validation. This PR instead removes the source of the downcast, matching the SDXL script.Before submitting
Who can review?
@sayakpaul