diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index 029dcf4ee4..5e8fe741be 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -58,9 +58,6 @@ def __init__( self.load_path = load_path assert load_dict is not None and len(load_dict) > 0, "must provide target objects to load." self.logger = logging.getLogger(name) - for k, v in load_dict.items(): - if hasattr(v, "module"): - load_dict[k] = v.module self.load_dict = load_dict self._name = name self.map_location = map_location @@ -80,10 +77,6 @@ def __call__(self, engine: Engine) -> None: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ checkpoint = torch.load(self.load_path, map_location=self.map_location) - if len(self.load_dict) == 1: - key = list(self.load_dict.keys())[0] - if not (key in checkpoint): - checkpoint = {key: checkpoint} Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint) self.logger.info(f"Restored all variables from {self.load_path}") diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py index da4df981ee..57d8728cd4 100644 --- a/monai/handlers/checkpoint_saver.py +++ b/monai/handlers/checkpoint_saver.py @@ -91,9 +91,6 @@ def __init__( assert save_dir is not None, "must provide directory to save the checkpoints." self.save_dir = save_dir assert save_dict is not None and len(save_dict) > 0, "must provide source objects to save." - for k, v in save_dict.items(): - if hasattr(v, "module"): - save_dict[k] = v.module self.save_dict = save_dict self.logger = logging.getLogger(name) self.epoch_level = epoch_level