diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index e6168b257c85..0ee636b4a00e 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -147,7 +147,11 @@ def log_validation( pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + # Don't pass `dtype` here: under fp16 the trainable LoRA params are kept in fp32 (see + # `cast_training_params` above) and the validation pipeline shares the training `unet`, so casting it + # to fp16 would break the next optimizer step ("Attempting to unscale FP16 gradients"). Matches the + # SDXL script. + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference