from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Needed for annotations only. This module lives inside the ``monai.engines``
    # package, so importing the trainer from that package at runtime risks a
    # circular import (NOTE(review): the trainer module presumably imports
    # helpers from this file -- confirm).
    from monai.engines import SupervisedTrainer


class GradientAccumulationSupervisedTrainingStep:
    """Callable supervised training step that accumulates gradients over several iterations.

    An instance is called with ``(engine, batchdata)`` and performs one forward/backward pass.
    The loss used for ``backward()`` is divided by ``gradient_accumulation_steps``; gradients
    are zeroed on the first iteration of each accumulation window and the optimizer is stepped
    on the last one. Windows are derived from ``engine.state.iteration``:

    - gradients are zeroed when ``(iteration - 1) % gradient_accumulation_steps == 0``
    - the optimizer steps when ``iteration % gradient_accumulation_steps == 0``

    Args:
        gradient_accumulation_steps: number of iterations the gradients are accumulated across.
            Defaults to 1, which means no gradient accumulation.

    Raises:
        ValueError: when ``gradient_accumulation_steps`` is not strictly positive.
    """

    def __init__(self, gradient_accumulation_steps: int = 1) -> None:
        if gradient_accumulation_steps <= 0:
            raise ValueError(
                "Gradient_accumulation_steps must be strictly positive. \n"
                "No gradient accumulation if the value set to one (default)."
            )
        self.gradient_accumulation_steps = gradient_accumulation_steps

    def __call__(self, engine: "SupervisedTrainer", batchdata: Sequence[torch.Tensor]) -> dict:
        """Run one training iteration with gradient accumulation.

        Args:
            engine: the trainer driving the loop. This method reads its ``prepare_batch``,
                ``state``, ``non_blocking``, ``to_kwargs``, ``inferer``, ``network``,
                ``loss_function``, ``optimizer``, ``optim_set_to_none``, ``amp``, ``scaler``
                and ``amp_kwargs`` attributes and fires events through ``fire_event``.
            batchdata: the data of the current iteration, forwarded to ``engine.prepare_batch``.

        Returns:
            ``engine.state.output``: a dict holding the inputs, targets, predictions and the
            (unscaled, undivided) mean loss under the ``CommonKeys`` ``IMAGE``, ``LABEL``,
            ``PRED`` and ``LOSS`` keys.

        Raises:
            ValueError: when ``batchdata`` is ``None``.
        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")

        batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
        if len(batch) == 2:
            inputs, targets = batch
            args: tuple = ()
            kwargs: dict = {}
        else:
            # prepare_batch may also return extra positional/keyword arguments for the network
            inputs, targets, args, kwargs = batch

        # put iteration outputs into engine.state
        engine.state.output = {CommonKeys.IMAGE: inputs, CommonKeys.LABEL: targets}

        def _compute_pred_loss() -> None:
            engine.state.output[CommonKeys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
            engine.fire_event(IterationEvents.FORWARD_COMPLETED)
            engine.state.output[CommonKeys.LOSS] = engine.loss_function(
                engine.state.output[CommonKeys.PRED], targets
            ).mean()
            engine.fire_event(IterationEvents.LOSS_COMPLETED)

        steps = self.gradient_accumulation_steps

        engine.network.train()
        # first iteration of an accumulation window: start from clean gradients
        if (engine.state.iteration - 1) % steps == 0:
            engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)

        use_scaler = engine.amp and engine.scaler is not None
        if use_scaler:
            # only the forward pass and loss run under autocast; backward stays outside
            with torch.cuda.amp.autocast(**engine.amp_kwargs):
                _compute_pred_loss()
            backward_loss = engine.scaler.scale(engine.state.output[CommonKeys.LOSS] / steps)
        else:
            _compute_pred_loss()
            backward_loss = engine.state.output[CommonKeys.LOSS] / steps

        # dividing by the window length makes the summed gradients an average over the window
        backward_loss.backward()
        engine.fire_event(IterationEvents.BACKWARD_COMPLETED)

        # last iteration of an accumulation window: apply the accumulated gradients
        if engine.state.iteration % steps == 0:
            if use_scaler:
                engine.scaler.step(engine.optimizer)
                # the scaler is updated only when the optimizer actually stepped
                engine.scaler.update()
            else:
                engine.optimizer.step()
        # fired on every iteration, whether or not the optimizer stepped
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output