
Fix EMAModel.restore() foreach path crashing with device mismatch when model is on GPU #13782

Open
Dev-X25874 wants to merge 2 commits into huggingface:main from Dev-X25874:fix/ema-restore-foreach-device-mismatch


@Dev-X25874
Contributor

What does this PR do?

Fixes a runtime crash in EMAModel.restore() when foreach=True and the model lives on a non-CPU device (e.g. CUDA).

store() always saves a copy of each parameter on the CPU (param.detach().cpu().clone()). Before this PR, the foreach path in restore() passed those CPU tensors directly to torch._foreach_copy_(), which requires all tensors to be on the same device:

# before (broken on GPU)
torch._foreach_copy_(
    [param.data for param in parameters],
    [c_param.data for c_param in self.temp_stored_params],  # always CPU
)

As a result, anyone who runs the standard EMA validation pattern (store → copy_to → restore) with foreach=True on a GPU machine hits RuntimeError: Expected all tensors to be on same device.
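For reviewers less familiar with the validation pattern named above: it temporarily swaps the EMA (shadow) weights into the model and then puts the training weights back. The following is a minimal pure-Python sketch of that sequence. `ToyEMA` is a hypothetical stand-in that works on lists of floats; it is not diffusers' `EMAModel` and it ignores devices entirely.

```python
class ToyEMA:
    """Hypothetical stand-in for an EMA helper; operates on plain lists of floats."""

    def __init__(self, params, decay=0.9):
        self.decay = decay
        self.shadow = list(params)   # EMA (shadow) copy of the weights
        self.temp_stored = None      # filled by store(), consumed by restore()

    def step(self, params):
        # shadow <- decay * shadow + (1 - decay) * param
        self.shadow = [
            self.decay * s + (1.0 - self.decay) * p
            for s, p in zip(self.shadow, params)
        ]

    def store(self, params):
        # Snapshot the current training weights.
        self.temp_stored = list(params)

    def copy_to(self, params):
        # Overwrite the weights in place with the EMA weights.
        params[:] = self.shadow

    def restore(self, params):
        # Put the snapshot back; requires a prior store().
        if self.temp_stored is None:
            raise RuntimeError("restore() called without a prior store()")
        params[:] = self.temp_stored
        self.temp_stored = None


weights = [1.0, 2.0]
ema = ToyEMA(weights, decay=0.5)
weights[:] = [3.0, 4.0]   # pretend an optimizer step changed the weights
ema.step(weights)         # shadow is now [2.0, 3.0]

ema.store(weights)                  # 1. save training weights
ema.copy_to(weights)                # 2. validate with EMA weights
validation_weights = list(weights)
ema.restore(weights)                # 3. resume training with original weights
```

The crash described in this PR occurs in step 3, because the snapshot from step 1 lives on the CPU while the model parameters live on the GPU.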

The fix mirrors the pattern already used correctly in copy_to()'s foreach path (line 780), which moves each shadow param to the target device before the copy:

# after (matches copy_to() pattern)
torch._foreach_copy_(
    [param.data for param in parameters],
    [c_param.to(param.device).data for c_param, param in zip(self.temp_stored_params, parameters)],
)

Also adds test_store_restore to both EMAModelTests and EMAModelTestsForeach — the store/restore round-trip was completely untested prior to this PR.

Who can review?

@sayakpaul

github-actions bot added the tests and size/S (PR with diff < 50 LOC) labels on May 21, 2026