From 4a9c51afbeac4db300a4feec8920ac5c75c5c1ab Mon Sep 17 00:00:00 2001 From: Haozhe Zhang Date: Tue, 9 Jun 2026 01:45:01 -0700 Subject: [PATCH] Fix fp16 LoRA unscale crash after validation in train_dreambooth_lora.py When training with `--mixed_precision="fp16"` and `--validation_prompt`, the first optimizer step after a validation run fails with `ValueError: Attempting to unscale FP16 gradients`. Under fp16, `cast_training_params` keeps the trainable LoRA params in fp32. The in-loop validation pipeline is built with the same live `unet` object, and `log_validation` then calls `pipeline.to(device, dtype=torch_dtype)`, which downcasts those fp32 LoRA params back to fp16. The next backward therefore produces fp16 grads and `GradScaler.unscale_` raises. Drop the dtype cast from that `.to(...)` so the shared `unet` keeps its fp32 LoRA params. This matches train_dreambooth_lora_sdxl.py, which moves the validation pipeline with `.to(accelerator.device)` only. Fixes #13124 --- examples/dreambooth/train_dreambooth_lora.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index e6168b257c85..0ee636b4a00e 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -147,7 +147,11 @@ def log_validation( pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + # Don't pass `dtype` here: under fp16 the trainable LoRA params are kept in fp32 (see + # `cast_training_params` above) and the validation pipeline shares the training `unet`, so casting it + # to fp16 would break the next optimizer step ("Attempting to unscale FP16 gradients"). Matches the + # SDXL script. + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference