From 87b8b82c1955ecc3b16ba0fe1749e0bf14b3828f Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Tue, 9 Sep 2025 17:16:59 +0200 Subject: [PATCH 01/10] add normalizer callback --- docs/source/_rst/_code.rst | 1 + docs/source/_rst/callback/normalizer.rst | 7 + pina/callback/__init__.py | 2 + pina/callback/normalizer.py | 195 +++++++++++++++++++++++ tests/test_callback/test_normalizer.py | 187 ++++++++++++++++++++++ 5 files changed, 392 insertions(+) create mode 100644 docs/source/_rst/callback/normalizer.rst create mode 100644 pina/callback/normalizer.py create mode 100644 tests/test_callback/test_normalizer.py diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index a7242562b..e911936fd 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -253,6 +253,7 @@ Callbacks Optimizer callback R3 Refinment callback Refinment Interface callback + Normalizer callback Losses and Weightings --------------------- diff --git a/docs/source/_rst/callback/normalizer.rst b/docs/source/_rst/callback/normalizer.rst new file mode 100644 index 000000000..eb61b754a --- /dev/null +++ b/docs/source/_rst/callback/normalizer.rst @@ -0,0 +1,7 @@ +Normalizer callbacks +======================= + +.. currentmodule:: pina.callback.normalizer +.. 
autoclass:: NormalizerDataCallback + :members: + :show-inheritance: \ No newline at end of file diff --git a/pina/callback/__init__.py b/pina/callback/__init__.py index e9a70ea34..599f76b9c 100644 --- a/pina/callback/__init__.py +++ b/pina/callback/__init__.py @@ -5,8 +5,10 @@ "MetricTracker", "PINAProgressBar", "R3Refinement", + "NormalizerDataCallback", ] from .optimizer_callback import SwitchOptimizer from .processing_callback import MetricTracker, PINAProgressBar from .refinement import R3Refinement +from .normalizer import NormalizerDataCallback diff --git a/pina/callback/normalizer.py b/pina/callback/normalizer.py new file mode 100644 index 000000000..d35ef1e99 --- /dev/null +++ b/pina/callback/normalizer.py @@ -0,0 +1,195 @@ +"""Module for the Normalizer callback.""" + +import torch +from lightning.pytorch import Callback +from ..label_tensor import LabelTensor +from ..utils import check_consistency + +_REQUIRED_KEYS = {"scale", "shift"} + + +class NormalizerDataCallback(Callback): + r""" + A Lightning Callback that normalizes dataset inputs or targets + according to user-provided scale and shift parameters. + + The transformation is applied as: + + .. math:: + + x_{\text{new}} = \frac{x - \text{shift}}{\text{scale}} + + :Example: + + >>> NormalizerDataCallback({"scale": 1, "shift": 0}) + >>> NormalizerDataCallback({ + ... "a": {"scale": 2.0, "shift": 1.0}, + ... "b": {"scale": 0.5, "shift": 0.0}, + ... }) + + """ + + def __init__( + self, + normalizer=None, + stage="all", + apply_to="input", + ): + """ + Initialize the NormalizerDataCallback. + + :param dict normalizer: Normalization specification. Either + - a dict with + {"scale": float | torch.Tensor, "shift": float | torch.Tensor}, or + - a dict mapping condition names to such dicts. If ``None`` no + normalization is performed. Default ``None``. + :param str stage: Stage during which to apply normalization. + One of {"train", "validate", "test", "all"}. + Defaults to "all". 
+ :param str apply_to: Whether to normalize "input" or "target" data. + Defaults to "input". + :raises ValueError: If `apply_to` or `stage` are invalid. + """ + super().__init__() + + # validate apply_to + check_consistency(apply_to, str) + if apply_to not in {"input", "target"}: + raise ValueError( + f"apply_to must be 'input' or 'target', got {apply_to!r}" + ) + + # validate stage (can be None for setup flexibility) + check_consistency(stage, str) + if stage not in {"train", "validate", "test", "all"}: + raise ValueError( + f"stage must be 'train', 'validate', 'test', or 'all' " + f"got {stage!r}" + ) + + normalizer = normalizer or {"scale": 1, "shift": 1} + self.normalizer = self._validate_normalizer(normalizer) + self.apply_to = apply_to + self.stage = stage + + def _is_normalizer_dict(self, d): + """ + Check if a dictionary is a valid normalizer specification. + + :param dict d: Dictionary to validate. + :return: True if dict has {"scale", "shift"} keys with numeric values. + :rtype: bool + """ + return ( + isinstance(d, dict) + and set(d.keys()) == _REQUIRED_KEYS + and all( + isinstance(d[k], (float, int, torch.Tensor)) + for k in _REQUIRED_KEYS + ) + ) + + def _validate_normalizer(self, normalizer): + """ + Validate a normalizer configuration. + + :param dict normalizer: Candidate normalizer specification. + :raises ValueError: If the normalizer format is invalid. + :return: A validated normalizer dictionary. + :rtype: dict + """ + if self._is_normalizer_dict(normalizer): + return normalizer + + if isinstance(normalizer, dict) and all( + self._is_normalizer_dict(v) for v in normalizer.values() + ): + return normalizer + + raise ValueError( + "normalizer must be either:\n" + f" - dict with {_REQUIRED_KEYS}\n" + f" - dict of such dicts" + ) + + def setup(self, trainer, solver, stage): + """ + Apply normalization during setup. + + :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance. 
+ :param SolverInterface pl_module: A + :class:`~pina.solver.solver.SolverInterface` instance. + :param str stage: Current stage, not used kept for consistency. + :raises RuntimeError: If condition names do not match solver conditions. + :raises RuntimeError: If attempting to scale unavailable targets. + :return: Result of parent setup. + :rtype: Any + """ + # extract conditions + conditions = solver.problem.conditions + + # expand single normalizer to all conditions + if set(self.normalizer.keys()) == _REQUIRED_KEYS: + self.normalizer = {c: self.normalizer for c in conditions} + + # check condition keys + for cond in self.normalizer: + if cond not in conditions: + raise RuntimeError( + f"Condition '{cond}' not found in the normalizer dict. " + f"Got {list(self.normalizer)}, expected {list(conditions)}." + ) + if ( + hasattr(conditions[cond], "equation") + and self.apply_to == "target" + ): + raise RuntimeError( + f"Condition '{cond}' contains an equation object, " + "so there is no available target data to scale." + ) + + # select dataset and normalize + stage = stage or "fit" + if stage == "fit" and self.stage in ["train", "all"]: + self._scale_data(trainer.data_module.train_dataset) + if stage == "fit" and self.stage in ["validate", "all"]: + self._scale_data(trainer.data_module.val_dataset) + if stage == "test" and self.stage in ["test", "all"]: + self._scale_data(trainer.data_module.test_dataset) + + return super().setup(trainer, solver, stage) + + @staticmethod + def scale_fn(value, scale, shift): + """ + Normalize a tensor with the given scale and shift. + + :param value: Input tensor to normalize. + :type value: torch.Tensor | LabelTensor + :param scale: Scaling factor. + :type scale: float | int + :param shift: Shifting factor. + :type shift: float | int + :return: Normalized tensor (value - shift) / scale. 
+ :rtype: torch.Tensor | LabelTensor + """ + if isinstance(value, LabelTensor): + return LabelTensor((value.tensor - shift) / scale, value.labels) + return (value - shift) / scale + + def _scale_data(self, dataset): + """ + Apply normalization to a dataset in-place. + + :param dataset: Dataset object with `conditions_dict` and `update_data`. + :type dataset: object + """ + new_points = {} + for cond in self.normalizer: + current_points = dataset.conditions_dict[cond][self.apply_to] + scale = self.normalizer[cond]["scale"] + shift = self.normalizer[cond]["shift"] + new_points[cond] = { + self.apply_to: self.scale_fn(current_points, scale, shift) + } + dataset.update_data(new_points) diff --git a/tests/test_callback/test_normalizer.py b/tests/test_callback/test_normalizer.py new file mode 100644 index 000000000..b3caf215c --- /dev/null +++ b/tests/test_callback/test_normalizer.py @@ -0,0 +1,187 @@ +import torch +import pytest + +from copy import deepcopy + +from pina import Trainer, LabelTensor, Condition +from pina.solver import PINN, SupervisedSolver +from pina.model import FeedForward +from pina.callback import NormalizerDataCallback +from pina.problem import AbstractProblem +from pina.problem.zoo import Poisson2DSquareProblem as Poisson + + +# for checking normalization +stage_map = { + "train": ["train_dataset"], + "validate": ["val_dataset"], + "test": ["test_dataset"], + "all": ["train_dataset", "val_dataset", "test_dataset"], +} + +# pinn solver +problem = Poisson() +problem.discretise_domain(10) +model = FeedForward(len(problem.input_variables), len(problem.output_variables)) +pinn_solver = PINN(problem=problem, model=model) + + +class LabelTensorProblem(AbstractProblem): + input_variables = ["u_0", "u_1"] + output_variables = ["u"] + conditions = { + "data1": Condition( + input=LabelTensor(torch.randn(20, 2), ["u_0", "u_1"]), + target=LabelTensor(torch.randn(20, 1), ["u"]), + ), + "data2": Condition( + input=LabelTensor(torch.randn(20, 2), ["u_0", "u_1"]), 
+ target=LabelTensor(torch.randn(20, 1), ["u"]), + ), + } + + +class TensorProblem(AbstractProblem): + input_variables = ["u_0", "u_1"] + output_variables = ["u"] + conditions = { + "data1": Condition(input=torch.randn(20, 2), target=torch.randn(20, 1)), + "data2": Condition(input=torch.randn(20, 2), target=torch.randn(20, 1)), + } + + +supervised_solver_no_lt = SupervisedSolver( + problem=TensorProblem(), model=FeedForward(2, 1), use_lt=False +) +supervised_solver_lt = SupervisedSolver( + problem=LabelTensorProblem(), model=FeedForward(2, 1), use_lt=False +) + + +# Test constructor +@pytest.mark.parametrize( + "normalizer", + [ + {"scale": 2.1, "shift": 1}, + {"scale": 2, "shift": torch.randn(2)}, + {"scale": 2, "shift": 1.1}, + {"a": {"scale": 2, "shift": 1}, "b": {"scale": 3, "shift": 0.5}}, + ], +) +def test_constructor_valid_normalizers(normalizer): + NormalizerDataCallback(normalizer=normalizer) + + +@pytest.mark.parametrize( + "invalid_normalizer", + [ + {"scale": 1}, # missing shift + {"shift": 1}, # missing scale + {"a": {"scale": 1}}, # dict of dicts, inner missing shift + [1, 2, 3], # wrong type + "invalid", # wrong type + ], +) +def test_constructor_invalid_normalizer_raises(invalid_normalizer): + with pytest.raises(ValueError): + NormalizerDataCallback(normalizer=invalid_normalizer) + + +@pytest.mark.parametrize("apply_to", ["input", "target"]) +def test_constructor_valid_apply_to(apply_to): + cb = NormalizerDataCallback(apply_to=apply_to) + assert cb.apply_to == apply_to + + +@pytest.mark.parametrize("apply_to", ["invalid", "", None, 123]) +def test_constructor_invalid_apply_to_raises(apply_to): + with pytest.raises(ValueError): + NormalizerDataCallback(apply_to=apply_to) + + +@pytest.mark.parametrize("stage", ["train", "validate", "test", "all"]) +def test_constructor_valid_stage(stage): + cb = NormalizerDataCallback(stage=stage) + assert cb.stage == stage + + +@pytest.mark.parametrize("stage", ["invalid", "", None, 123]) +def 
test_constructor_invalid_stage_raises(stage): + with pytest.raises(ValueError): + NormalizerDataCallback(stage=stage) + + +# Test setup +@pytest.mark.parametrize( + "normalizer", + [ + {"scale": 0.5, "shift": 1}, + ], +) +def test_invalid_setup(normalizer): + trainer = Trainer( + solver=pinn_solver, + callbacks=NormalizerDataCallback(normalizer, apply_to="target"), + max_epochs=1, + train_size=0.4, + val_size=0.3, + test_size=0.3, + ) + # trigger setup + with pytest.raises(RuntimeError): + trainer.train() + with pytest.raises(RuntimeError): + trainer.test() + + +@pytest.mark.parametrize("apply_to", ["input", "target"]) +@pytest.mark.parametrize( + "solver", [supervised_solver_lt, supervised_solver_no_lt] +) +@pytest.mark.parametrize("stage", ["train", "validate", "test", "all"]) +def test_setup(apply_to, solver, stage): + shift = torch.tensor([1, 1]) if apply_to == "input" else torch.tensor([1]) + + # Helper function to run trainer and check normalization + def check_normalization(normalizer_spec, check_cond=None): + trainer = Trainer( + solver=solver, + callbacks=NormalizerDataCallback( + normalizer=normalizer_spec, stage=stage, apply_to=apply_to + ), + max_epochs=1, + train_size=0.4, + val_size=0.3, + test_size=0.3, + shuffle=False, + ) + # save a copy of the old trainer datamodule + trainer_copy = deepcopy(trainer) + # trigger setup + trainer_copy.data_module.setup("fit") + trainer_copy.data_module.setup("test") + trainer.train() + trainer.test() + normalizer_spec = trainer.callbacks[0].normalizer + for ds_name in stage_map[stage]: + dataset = getattr(trainer.data_module, ds_name, None) + old_dataset = getattr(trainer_copy.data_module, ds_name, None) + for cond in ["data1", "data2"]: + current_points = dataset.conditions_dict[cond][apply_to] + old_points = old_dataset.conditions_dict[cond][apply_to] + if check_cond is None or cond in check_cond: + scale = normalizer_spec[cond]["scale"] + shift_val = normalizer_spec[cond]["shift"] + expected = (old_points - 
shift_val) / scale + else: + expected = old_points + print(torch.allclose(current_points, expected)) + assert torch.allclose(current_points, expected) + + # Test full normalizer applied to all conditions + full_normalizer = {"scale": 0.5, "shift": shift} + check_normalization(full_normalizer) + + # Test partial normalizer applied to some conditions + partial_normalizer = {"data1": {"scale": 0.5, "shift": shift}} + check_normalization(partial_normalizer, check_cond=["data1"]) From b5b068098f224984691e6ad64c12507aac1e57ee Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Thu, 11 Sep 2025 16:17:46 +0200 Subject: [PATCH 02/10] implement shift and scale parameters computation --- pina/callback/normalizer.py | 81 ++++++++++++++++++-------- tests/test_callback/test_normalizer.py | 13 +++++ 2 files changed, 71 insertions(+), 23 deletions(-) diff --git a/pina/callback/normalizer.py b/pina/callback/normalizer.py index d35ef1e99..83b31610b 100644 --- a/pina/callback/normalizer.py +++ b/pina/callback/normalizer.py @@ -41,8 +41,8 @@ def __init__( :param dict normalizer: Normalization specification. Either - a dict with {"scale": float | torch.Tensor, "shift": float | torch.Tensor}, or - - a dict mapping condition names to such dicts. If ``None`` no - normalization is performed. Default ``None``. + - a dict mapping condition names to such dicts. If ``None`` shift + and scale parameters are inferred during runtime. Default ``None``. :param str stage: Stage during which to apply normalization. One of {"train", "validate", "test", "all"}. Defaults to "all". @@ -67,7 +67,6 @@ def __init__( f"got {stage!r}" ) - normalizer = normalizer or {"scale": 1, "shift": 1} self.normalizer = self._validate_normalizer(normalizer) self.apply_to = apply_to self.stage = stage @@ -98,6 +97,9 @@ def _validate_normalizer(self, normalizer): :return: A validated normalizer dictionary. 
:rtype: dict """ + if normalizer is None: + return None + if self._is_normalizer_dict(normalizer): return normalizer @@ -108,8 +110,9 @@ def _validate_normalizer(self, normalizer): raise ValueError( "normalizer must be either:\n" - f" - dict with {_REQUIRED_KEYS}\n" - f" - dict of such dicts" + f" - dict with {_REQUIRED_KEYS} as keys\n" + f" - dict of dicts where the outer keys are condition names\n" + f" and the inner dicts have {_REQUIRED_KEYS} as keys\n" ) def setup(self, trainer, solver, stage): @@ -129,36 +132,68 @@ def setup(self, trainer, solver, stage): conditions = solver.problem.conditions # expand single normalizer to all conditions - if set(self.normalizer.keys()) == _REQUIRED_KEYS: - self.normalizer = {c: self.normalizer for c in conditions} - - # check condition keys - for cond in self.normalizer: - if cond not in conditions: - raise RuntimeError( - f"Condition '{cond}' not found in the normalizer dict. " - f"Got {list(self.normalizer)}, expected {list(conditions)}." - ) - if ( - hasattr(conditions[cond], "equation") - and self.apply_to == "target" - ): - raise RuntimeError( - f"Condition '{cond}' contains an equation object, " - "so there is no available target data to scale." - ) + if self.normalizer is not None: + if set(self.normalizer.keys()) == _REQUIRED_KEYS: + self.normalizer = {c: self.normalizer for c in conditions} + + # check condition keys + for cond in self.normalizer: + if cond not in conditions: + raise RuntimeError( + f"Condition '{cond}' not found in the normalizer dict. " + f"Got {list(self.normalizer)}, expected {list(conditions)}." + ) + if ( + hasattr(conditions[cond], "equation") + and self.apply_to == "target" + ): + raise RuntimeError( + f"Condition '{cond}' contains an equation object, " + "so there is no available target data to scale." 
+ ) # select dataset and normalize stage = stage or "fit" if stage == "fit" and self.stage in ["train", "all"]: + if self.normalizer is None: + self.normalizer = {} + self._compute_scale_shift( + conditions, trainer.data_module.train_dataset + ) self._scale_data(trainer.data_module.train_dataset) if stage == "fit" and self.stage in ["validate", "all"]: + if self.normalizer is None: + raise RuntimeError( + "Cannot compute scale and shift from test data. " + "Please provide a valid normalizer dict." + ) self._scale_data(trainer.data_module.val_dataset) if stage == "test" and self.stage in ["test", "all"]: + if self.normalizer is None: + raise RuntimeError( + "Cannot compute scale and shift from test data. " + "Please provide a valid normalizer dict." + ) self._scale_data(trainer.data_module.test_dataset) return super().setup(trainer, solver, stage) + def _compute_scale_shift(self, conditions, dataset): + """ + Compute scale and shift for each condition from dataset. + + :param list conditions: List of condition names. + :param dataset: `~pina.data.dataset.PinaDataset` object. 
+ :rtype: dict + """ + for cond in conditions: + if cond in dataset.conditions_dict: + data = dataset.conditions_dict[cond][self.apply_to] + self.normalizer[cond] = { + "shift": data.mean(dim=0), + "scale": data.std(dim=0) + 1e-8, + } + @staticmethod def scale_fn(value, scale, shift): """ diff --git a/tests/test_callback/test_normalizer.py b/tests/test_callback/test_normalizer.py index b3caf215c..7e1507b07 100644 --- a/tests/test_callback/test_normalizer.py +++ b/tests/test_callback/test_normalizer.py @@ -160,6 +160,16 @@ def check_normalization(normalizer_spec, check_cond=None): # trigger setup trainer_copy.data_module.setup("fit") trainer_copy.data_module.setup("test") + if normalizer_spec is None: + if stage == "validate": + with pytest.raises(RuntimeError): + trainer.train() + return + if stage == "test": + with pytest.raises(RuntimeError): + trainer.test() + return + trainer.train() trainer.test() normalizer_spec = trainer.callbacks[0].normalizer @@ -185,3 +195,6 @@ def check_normalization(normalizer_spec, check_cond=None): # Test partial normalizer applied to some conditions partial_normalizer = {"data1": {"scale": 0.5, "shift": shift}} check_normalization(partial_normalizer, check_cond=["data1"]) + + none_normalizer = None + check_normalization(none_normalizer) From 12111c685247b7db083b52fbdc223d5f3ca12ff7 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 15 Sep 2025 13:35:13 +0200 Subject: [PATCH 03/10] fix normalizer callback --- pina/callback/normalizer.py | 187 +++++++----------- tests/test_callback/test_normalizer.py | 262 +++++++++++++------------ 2 files changed, 207 insertions(+), 242 deletions(-) diff --git a/pina/callback/normalizer.py b/pina/callback/normalizer.py index 83b31610b..7efc0ec02 100644 --- a/pina/callback/normalizer.py +++ b/pina/callback/normalizer.py @@ -4,6 +4,7 @@ from lightning.pytorch import Callback from ..label_tensor import LabelTensor from ..utils import check_consistency +from ..condition import InputTargetCondition 
_REQUIRED_KEYS = {"scale", "shift"} @@ -21,28 +22,28 @@ class NormalizerDataCallback(Callback): :Example: - >>> NormalizerDataCallback({"scale": 1, "shift": 0}) - >>> NormalizerDataCallback({ - ... "a": {"scale": 2.0, "shift": 1.0}, - ... "b": {"scale": 0.5, "shift": 0.0}, - ... }) + >>> NormalizerDataCallback() + >>> NormalizerDataCallback( + ... "scale": torch.var, + ... "shift": torch.median + ... ) """ def __init__( self, - normalizer=None, + scale_fn=torch.std, + shift_fn=torch.mean, stage="all", apply_to="input", ): """ Initialize the NormalizerDataCallback. - :param dict normalizer: Normalization specification. Either - - a dict with - {"scale": float | torch.Tensor, "shift": float | torch.Tensor}, or - - a dict mapping condition names to such dicts. If ``None`` shift - and scale parameters are inferred during runtime. Default ``None``. + :param dict strategy: Normalization specification. It must be a dict + with keys "scale" and "shift", each mapping to a callable that + computes the respective value from a tensor. If None, defaults to + using mean and std. Defaults is ``None``. :param str stage: Stage during which to apply normalization. One of {"train", "validate", "test", "all"}. Defaults to "all". @@ -52,68 +53,49 @@ def __init__( """ super().__init__() - # validate apply_to + self.apply_to = self._validate_apply_to(apply_to) + self.stage = self._validate_stage(stage) + if not callable(scale_fn): + raise ValueError(f"scale_fn must be callable, got {scale_fn}") + self.scale_fn = scale_fn + if not callable(shift_fn): + raise ValueError(f"shift_fn must be callable, got {shift_fn}") + self.shift_fn = shift_fn + self.normalizer = {} + + def _validate_apply_to(self, apply_to): + """ + Validate the `apply_to` parameter. + + :param str apply_to: Candidate value for `apply_to`. + :raises ValueError: If `apply_to` is not "input" or "target". + :return: Validated `apply_to` value. 
+ :rtype: str + """ check_consistency(apply_to, str) if apply_to not in {"input", "target"}: raise ValueError( - f"apply_to must be 'input' or 'target', got {apply_to!r}" - ) - - # validate stage (can be None for setup flexibility) - check_consistency(stage, str) - if stage not in {"train", "validate", "test", "all"}: - raise ValueError( - f"stage must be 'train', 'validate', 'test', or 'all' " - f"got {stage!r}" + f"apply_to must be 'input' or 'target', got {apply_to}" ) + return apply_to - self.normalizer = self._validate_normalizer(normalizer) - self.apply_to = apply_to - self.stage = stage - - def _is_normalizer_dict(self, d): + def _validate_stage(self, stage): """ - Check if a dictionary is a valid normalizer specification. + Validate the `stage` parameter. - :param dict d: Dictionary to validate. - :return: True if dict has {"scale", "shift"} keys with numeric values. - :rtype: bool + :param str stage: Candidate value for `stage`. + :raises ValueError: If `stage` is not one of "train", "validate", + "test", or "all". + :return: Validated `stage` value. + :rtype: str """ - return ( - isinstance(d, dict) - and set(d.keys()) == _REQUIRED_KEYS - and all( - isinstance(d[k], (float, int, torch.Tensor)) - for k in _REQUIRED_KEYS + check_consistency(stage, str) + if stage not in {"train", "validate", "test", "all"}: + raise ValueError( + f"stage must be 'train', 'validate', 'test', or 'all', got " + f"{stage}" ) - ) - - def _validate_normalizer(self, normalizer): - """ - Validate a normalizer configuration. - - :param dict normalizer: Candidate normalizer specification. - :raises ValueError: If the normalizer format is invalid. - :return: A validated normalizer dictionary. 
- :rtype: dict - """ - if normalizer is None: - return None - - if self._is_normalizer_dict(normalizer): - return normalizer - - if isinstance(normalizer, dict) and all( - self._is_normalizer_dict(v) for v in normalizer.values() - ): - return normalizer - - raise ValueError( - "normalizer must be either:\n" - f" - dict with {_REQUIRED_KEYS} as keys\n" - f" - dict of dicts where the outer keys are condition names\n" - f" and the inner dicts have {_REQUIRED_KEYS} as keys\n" - ) + return stage def setup(self, trainer, solver, stage): """ @@ -129,53 +111,27 @@ def setup(self, trainer, solver, stage): :rtype: Any """ # extract conditions - conditions = solver.problem.conditions - - # expand single normalizer to all conditions - if self.normalizer is not None: - if set(self.normalizer.keys()) == _REQUIRED_KEYS: - self.normalizer = {c: self.normalizer for c in conditions} - - # check condition keys - for cond in self.normalizer: - if cond not in conditions: - raise RuntimeError( - f"Condition '{cond}' not found in the normalizer dict. " - f"Got {list(self.normalizer)}, expected {list(conditions)}." - ) - if ( - hasattr(conditions[cond], "equation") - and self.apply_to == "target" - ): - raise RuntimeError( - f"Condition '{cond}' contains an equation object, " - "so there is no available target data to scale." - ) - - # select dataset and normalize - stage = stage or "fit" - if stage == "fit" and self.stage in ["train", "all"]: - if self.normalizer is None: - self.normalizer = {} - self._compute_scale_shift( - conditions, trainer.data_module.train_dataset - ) - self._scale_data(trainer.data_module.train_dataset) - if stage == "fit" and self.stage in ["validate", "all"]: - if self.normalizer is None: - raise RuntimeError( - "Cannot compute scale and shift from test data. " - "Please provide a valid normalizer dict." 
- ) - self._scale_data(trainer.data_module.val_dataset) - if stage == "test" and self.stage in ["test", "all"]: - if self.normalizer is None: + conditions_to_normalize = [] + for name, cond in solver.problem.conditions.items(): + if isinstance(cond, InputTargetCondition): + conditions_to_normalize.append(name) + + if not self.normalizer: + if not trainer.datamodule.train_dataset: raise RuntimeError( - "Cannot compute scale and shift from test data. " - "Please provide a valid normalizer dict." + "Training dataset is not available. Cannot compute " + "normalization parameters." ) - self._scale_data(trainer.data_module.test_dataset) + self._compute_scale_shift( + conditions_to_normalize, trainer.datamodule.train_dataset + ) + if stage == "fit" and self.stage in ["train", "all"]: + self._scale_data(trainer.datamodule.train_dataset) + if stage == "fit" and self.stage in ["validate", "all"]: + self._scale_data(trainer.datamodule.val_dataset) + if stage == "test" and self.stage in ["test", "all"]: + self._scale_data(trainer.datamodule.test_dataset) return super().setup(trainer, solver, stage) def _compute_scale_shift(self, conditions, dataset): @@ -189,13 +145,15 @@ def _compute_scale_shift(self, conditions, dataset): for cond in conditions: if cond in dataset.conditions_dict: data = dataset.conditions_dict[cond][self.apply_to] + shift = self.shift_fn(data) + scale = self.scale_fn(data) self.normalizer[cond] = { - "shift": data.mean(dim=0), - "scale": data.std(dim=0) + 1e-8, + "shift": shift, + "scale": scale, } @staticmethod - def scale_fn(value, scale, shift): + def _norm_fn(value, scale, shift): """ Normalize a tensor with the given scale and shift. @@ -208,9 +166,10 @@ def scale_fn(value, scale, shift): :return: Normalized tensor (value - shift) / scale. 
:rtype: torch.Tensor | LabelTensor """ + scaled_value = (value - shift) / scale if isinstance(value, LabelTensor): - return LabelTensor((value.tensor - shift) / scale, value.labels) - return (value - shift) / scale + scaled_value = LabelTensor(scaled_value, value.labels) + return scaled_value def _scale_data(self, dataset): """ @@ -225,6 +184,6 @@ def _scale_data(self, dataset): scale = self.normalizer[cond]["scale"] shift = self.normalizer[cond]["shift"] new_points[cond] = { - self.apply_to: self.scale_fn(current_points, scale, shift) + self.apply_to: self._norm_fn(current_points, scale, shift) } dataset.update_data(new_points) diff --git a/tests/test_callback/test_normalizer.py b/tests/test_callback/test_normalizer.py index 7e1507b07..4f2a660d3 100644 --- a/tests/test_callback/test_normalizer.py +++ b/tests/test_callback/test_normalizer.py @@ -1,15 +1,15 @@ import torch import pytest - from copy import deepcopy from pina import Trainer, LabelTensor, Condition -from pina.solver import PINN, SupervisedSolver +from pina.solver import SupervisedSolver from pina.model import FeedForward from pina.callback import NormalizerDataCallback from pina.problem import AbstractProblem from pina.problem.zoo import Poisson2DSquareProblem as Poisson - +from pina.condition.input_target_condition import InputTargetCondition +from pina.solver import PINN # for checking normalization stage_map = { @@ -19,11 +19,10 @@ "all": ["train_dataset", "val_dataset", "test_dataset"], } -# pinn solver -problem = Poisson() -problem.discretise_domain(10) -model = FeedForward(len(problem.input_variables), len(problem.output_variables)) -pinn_solver = PINN(problem=problem, model=model) +input_1 = torch.rand(20, 2) * 10 +target_1 = torch.rand(20, 1) * 10 +input_2 = torch.rand(20, 2) * 5 +target_2 = torch.rand(20, 1) * 5 class LabelTensorProblem(AbstractProblem): @@ -31,12 +30,12 @@ class LabelTensorProblem(AbstractProblem): output_variables = ["u"] conditions = { "data1": Condition( - 
input=LabelTensor(torch.randn(20, 2), ["u_0", "u_1"]), - target=LabelTensor(torch.randn(20, 1), ["u"]), + input=LabelTensor(input_1, ["u_0", "u_1"]), + target=LabelTensor(target_1, ["u"]), ), "data2": Condition( - input=LabelTensor(torch.randn(20, 2), ["u_0", "u_1"]), - target=LabelTensor(torch.randn(20, 1), ["u"]), + input=LabelTensor(input_2, ["u_0", "u_1"]), + target=LabelTensor(target_2, ["u"]), ), } @@ -45,8 +44,8 @@ class TensorProblem(AbstractProblem): input_variables = ["u_0", "u_1"] output_variables = ["u"] conditions = { - "data1": Condition(input=torch.randn(20, 2), target=torch.randn(20, 1)), - "data2": Condition(input=torch.randn(20, 2), target=torch.randn(20, 1)), + "data1": Condition(input=input_1, target=target_1), + "data2": Condition(input=input_2, target=target_2), } @@ -57,144 +56,151 @@ class TensorProblem(AbstractProblem): problem=LabelTensorProblem(), model=FeedForward(2, 1), use_lt=False ) - -# Test constructor -@pytest.mark.parametrize( - "normalizer", - [ - {"scale": 2.1, "shift": 1}, - {"scale": 2, "shift": torch.randn(2)}, - {"scale": 2, "shift": 1.1}, - {"a": {"scale": 2, "shift": 1}, "b": {"scale": 3, "shift": 0.5}}, - ], +poisson_problem = Poisson() +poisson_problem.conditions["data"] = Condition( + input=LabelTensor(torch.rand(20, 2) * 10, ["x", "y"]), + target=LabelTensor(torch.rand(20, 1) * 10, ["u"]), ) -def test_constructor_valid_normalizers(normalizer): - NormalizerDataCallback(normalizer=normalizer) -@pytest.mark.parametrize( - "invalid_normalizer", - [ - {"scale": 1}, # missing shift - {"shift": 1}, # missing scale - {"a": {"scale": 1}}, # dict of dicts, inner missing shift - [1, 2, 3], # wrong type - "invalid", # wrong type - ], -) -def test_constructor_invalid_normalizer_raises(invalid_normalizer): - with pytest.raises(ValueError): - NormalizerDataCallback(normalizer=invalid_normalizer) +@pytest.mark.parametrize("scale_fn", [torch.std, torch.var]) +@pytest.mark.parametrize("shift_fn", [torch.mean, torch.median]) 
+@pytest.mark.parametrize("apply_to", ["input", "target"]) +@pytest.mark.parametrize("stage", ["train", "validate", "test", "all"]) +def test_init(scale_fn, shift_fn, apply_to, stage): + normalizer = NormalizerDataCallback( + scale_fn=scale_fn, shift_fn=shift_fn, apply_to=apply_to, stage=stage + ) + assert normalizer.scale_fn == scale_fn + assert normalizer.shift_fn == shift_fn + assert normalizer.apply_to == apply_to + assert normalizer.stage == stage -@pytest.mark.parametrize("apply_to", ["input", "target"]) -def test_constructor_valid_apply_to(apply_to): - cb = NormalizerDataCallback(apply_to=apply_to) - assert cb.apply_to == apply_to +def test_init_invalid_scale(): + with pytest.raises(ValueError): + NormalizerDataCallback(scale_fn=1) -@pytest.mark.parametrize("apply_to", ["invalid", "", None, 123]) -def test_constructor_invalid_apply_to_raises(apply_to): +def test_init_invalid_shift(): with pytest.raises(ValueError): - NormalizerDataCallback(apply_to=apply_to) + NormalizerDataCallback(shift_fn=1) -@pytest.mark.parametrize("stage", ["train", "validate", "test", "all"]) -def test_constructor_valid_stage(stage): - cb = NormalizerDataCallback(stage=stage) - assert cb.stage == stage +@pytest.mark.parametrize("invalid_apply_to", ["inputt", "targett", 1]) +def test_init_invalid_apply_to(invalid_apply_to): + with pytest.raises(ValueError): + NormalizerDataCallback(apply_to=invalid_apply_to) -@pytest.mark.parametrize("stage", ["invalid", "", None, 123]) -def test_constructor_invalid_stage_raises(stage): +@pytest.mark.parametrize("invalid_stage", ["trainn", "validatee", 1]) +def test_init_invalid_stage(invalid_stage): with pytest.raises(ValueError): - NormalizerDataCallback(stage=stage) + NormalizerDataCallback(stage=invalid_stage) -# Test setup @pytest.mark.parametrize( - "normalizer", - [ - {"scale": 0.5, "shift": 1}, - ], + "solver", [supervised_solver_lt, supervised_solver_no_lt] ) -def test_invalid_setup(normalizer): +@pytest.mark.parametrize("scale_fn", 
[torch.std, torch.var]) +@pytest.mark.parametrize("shift_fn", [torch.mean, torch.median]) +@pytest.mark.parametrize("apply_to", ["input", "target"]) +@pytest.mark.parametrize("stage", ["all", "train", "validate", "test"]) +def test_setup(solver, scale_fn, shift_fn, stage, apply_to): trainer = Trainer( - solver=pinn_solver, - callbacks=NormalizerDataCallback(normalizer, apply_to="target"), + solver=solver, + callbacks=NormalizerDataCallback( + scale_fn=scale_fn, shift_fn=shift_fn, stage=stage, apply_to=apply_to + ), max_epochs=1, train_size=0.4, val_size=0.3, test_size=0.3, + shuffle=False, ) - # trigger setup - with pytest.raises(RuntimeError): - trainer.train() - with pytest.raises(RuntimeError): - trainer.test() - - -@pytest.mark.parametrize("apply_to", ["input", "target"]) -@pytest.mark.parametrize( - "solver", [supervised_solver_lt, supervised_solver_no_lt] -) -@pytest.mark.parametrize("stage", ["train", "validate", "test", "all"]) -def test_setup(apply_to, solver, stage): - shift = torch.tensor([1, 1]) if apply_to == "input" else torch.tensor([1]) - - # Helper function to run trainer and check normalization - def check_normalization(normalizer_spec, check_cond=None): - trainer = Trainer( - solver=solver, - callbacks=NormalizerDataCallback( - normalizer=normalizer_spec, stage=stage, apply_to=apply_to - ), - max_epochs=1, - train_size=0.4, - val_size=0.3, - test_size=0.3, - shuffle=False, + trainer_copy = deepcopy(trainer) + trainer_copy.data_module.setup("fit") + trainer_copy.data_module.setup("test") + trainer.train() + trainer.test() + + normalizer = trainer.callbacks[0].normalizer + + for cond in ["data1", "data2"]: + scale = scale_fn( + trainer_copy.data_module.train_dataset.conditions_dict[cond][ + apply_to + ] ) - # save a copy of the old trainer datamodule - trainer_copy = deepcopy(trainer) - # trigger setup - trainer_copy.data_module.setup("fit") - trainer_copy.data_module.setup("test") - if normalizer_spec is None: - if stage == "validate": - with 
pytest.raises(RuntimeError): - trainer.train() - return - if stage == "test": - with pytest.raises(RuntimeError): - trainer.test() - return - - trainer.train() - trainer.test() - normalizer_spec = trainer.callbacks[0].normalizer + shift = shift_fn( + trainer_copy.data_module.train_dataset.conditions_dict[cond][ + apply_to + ] + ) + assert "scale" in normalizer[cond] + assert "shift" in normalizer[cond] + assert normalizer[cond]["scale"] - scale < 1e-5 + assert normalizer[cond]["shift"] - shift < 1e-5 for ds_name in stage_map[stage]: dataset = getattr(trainer.data_module, ds_name, None) old_dataset = getattr(trainer_copy.data_module, ds_name, None) - for cond in ["data1", "data2"]: - current_points = dataset.conditions_dict[cond][apply_to] - old_points = old_dataset.conditions_dict[cond][apply_to] - if check_cond is None or cond in check_cond: - scale = normalizer_spec[cond]["scale"] - shift_val = normalizer_spec[cond]["shift"] - expected = (old_points - shift_val) / scale - else: - expected = old_points - print(torch.allclose(current_points, expected)) - assert torch.allclose(current_points, expected) - - # Test full normalizer applied to all conditions - full_normalizer = {"scale": 0.5, "shift": shift} - check_normalization(full_normalizer) - - # Test partial normalizer applied to some conditions - partial_normalizer = {"data1": {"scale": 0.5, "shift": shift}} - check_normalization(partial_normalizer, check_cond=["data1"]) - - none_normalizer = None - check_normalization(none_normalizer) + current_points = dataset.conditions_dict[cond][apply_to] + old_points = old_dataset.conditions_dict[cond][apply_to] + expected = (old_points - shift) / scale + assert torch.allclose(current_points, expected) + + +@pytest.mark.parametrize("scale_fn", [torch.std, torch.var]) +@pytest.mark.parametrize("shift_fn", [torch.mean, torch.median]) +@pytest.mark.parametrize("apply_to", ["input"]) +@pytest.mark.parametrize("stage", ["all", "train", "validate", "test"]) +def 
test_setup_pinn(scale_fn, shift_fn, stage, apply_to): + pinn = PINN( + problem=poisson_problem, + model=FeedForward(2, 1), + ) + poisson_problem.discretise_domain(n=10) + trainer = Trainer( + solver=pinn, + callbacks=NormalizerDataCallback( + scale_fn=scale_fn, + shift_fn=shift_fn, + stage=stage, + apply_to=apply_to, + ), + max_epochs=1, + train_size=0.4, + val_size=0.3, + test_size=0.3, + shuffle=False, + ) + + trainer_copy = deepcopy(trainer) + trainer_copy.data_module.setup("fit") + trainer_copy.data_module.setup("test") + trainer.train() + trainer.test() + + conditions = trainer.callbacks[0].normalizer.keys() + assert "data" in conditions + assert len(conditions) == 1 + normalizer = trainer.callbacks[0].normalizer + cond = "data" + + scale = scale_fn( + trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to] + ) + shift = shift_fn( + trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to] + ) + assert "scale" in normalizer[cond] + assert "shift" in normalizer[cond] + assert normalizer[cond]["scale"] - scale < 1e-5 + assert normalizer[cond]["shift"] - shift < 1e-5 + for ds_name in stage_map[stage]: + dataset = getattr(trainer.data_module, ds_name, None) + old_dataset = getattr(trainer_copy.data_module, ds_name, None) + current_points = dataset.conditions_dict[cond][apply_to] + old_points = old_dataset.conditions_dict[cond][apply_to] + expected = (old_points - shift) / scale + assert torch.allclose(current_points, expected) From ab542dab23853d2afc2e1f6cffc591c6742d4dfb Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 15 Sep 2025 14:16:32 +0200 Subject: [PATCH 04/10] fix codacy --- pina/callback/normalizer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pina/callback/normalizer.py b/pina/callback/normalizer.py index 7efc0ec02..4f9f1b184 100644 --- a/pina/callback/normalizer.py +++ b/pina/callback/normalizer.py @@ -97,7 +97,7 @@ def _validate_stage(self, stage): ) return stage - def setup(self, 
trainer, solver, stage): + def setup(self, trainer, pl_module, stage): """ Apply normalization during setup. @@ -112,7 +112,7 @@ def setup(self, trainer, solver, stage): """ # extract conditions conditions_to_normalize = [] - for name, cond in solver.problem.conditions.items(): + for name, cond in pl_module.problem.conditions.items(): if isinstance(cond, InputTargetCondition): conditions_to_normalize.append(name) @@ -132,7 +132,7 @@ def setup(self, trainer, solver, stage): self._scale_data(trainer.datamodule.val_dataset) if stage == "test" and self.stage in ["test", "all"]: self._scale_data(trainer.datamodule.test_dataset) - return super().setup(trainer, solver, stage) + return super().setup(trainer, pl_module, stage) def _compute_scale_shift(self, conditions, dataset): """ @@ -179,10 +179,10 @@ def _scale_data(self, dataset): :type dataset: object """ new_points = {} - for cond in self.normalizer: + for cond, norm_params in self.normalizer.items(): current_points = dataset.conditions_dict[cond][self.apply_to] - scale = self.normalizer[cond]["scale"] - shift = self.normalizer[cond]["shift"] + scale = norm_params["scale"] + shift = norm_params["shift"] new_points[cond] = { self.apply_to: self._norm_fn(current_points, scale, shift) } From 68aad602155718e048bfd62f1fb09e4af8cac363 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 15 Sep 2025 17:19:47 +0200 Subject: [PATCH 05/10] minor fix --- pina/callback/normalizer.py | 52 ++++++++++++++++---------- tests/test_callback/test_normalizer.py | 2 +- 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/pina/callback/normalizer.py b/pina/callback/normalizer.py index 4f9f1b184..dc1d3a9d6 100644 --- a/pina/callback/normalizer.py +++ b/pina/callback/normalizer.py @@ -6,8 +6,6 @@ from ..utils import check_consistency from ..condition import InputTargetCondition -_REQUIRED_KEYS = {"scale", "shift"} - class NormalizerDataCallback(Callback): r""" @@ -61,7 +59,7 @@ def __init__( if not callable(shift_fn): raise 
ValueError(f"shift_fn must be callable, got {shift_fn}") self.shift_fn = shift_fn - self.normalizer = {} + self._normalizer = {} def _validate_apply_to(self, apply_to): """ @@ -111,12 +109,13 @@ def setup(self, trainer, pl_module, stage): :rtype: Any """ # extract conditions - conditions_to_normalize = [] - for name, cond in pl_module.problem.conditions.items(): - if isinstance(cond, InputTargetCondition): - conditions_to_normalize.append(name) + conditions_to_normalize = [ + name + for name, cond in pl_module.problem.conditions.items() + if isinstance(cond, InputTargetCondition) + ] - if not self.normalizer: + if not self._normalizer: if not trainer.datamodule.train_dataset: raise RuntimeError( "Training dataset is not available. Cannot compute " @@ -127,11 +126,11 @@ def setup(self, trainer, pl_module, stage): ) if stage == "fit" and self.stage in ["train", "all"]: - self._scale_data(trainer.datamodule.train_dataset) + self.normalize_dataset(trainer.datamodule.train_dataset) if stage == "fit" and self.stage in ["validate", "all"]: - self._scale_data(trainer.datamodule.val_dataset) + self.normalize_dataset(trainer.datamodule.val_dataset) if stage == "test" and self.stage in ["test", "all"]: - self._scale_data(trainer.datamodule.test_dataset) + self.normalize_dataset(trainer.datamodule.test_dataset) return super().setup(trainer, pl_module, stage) def _compute_scale_shift(self, conditions, dataset): @@ -147,7 +146,7 @@ def _compute_scale_shift(self, conditions, dataset): data = dataset.conditions_dict[cond][self.apply_to] shift = self.shift_fn(data) scale = self.scale_fn(data) - self.normalizer[cond] = { + self._normalizer[cond] = { "shift": shift, "scale": scale, } @@ -171,19 +170,34 @@ def _norm_fn(value, scale, shift): scaled_value = LabelTensor(scaled_value, value.labels) return scaled_value - def _scale_data(self, dataset): + def normalize_dataset(self, dataset): """ Apply normalization to a dataset in-place. 
:param dataset: Dataset object with `conditions_dict` and `update_data`. :type dataset: object """ - new_points = {} - for cond, norm_params in self.normalizer.items(): - current_points = dataset.conditions_dict[cond][self.apply_to] + update_dataset_dict = {} + for cond, norm_params in self._normalizer.items(): + points = dataset.conditions_dict[cond][self.apply_to] scale = norm_params["scale"] shift = norm_params["shift"] - new_points[cond] = { - self.apply_to: self._norm_fn(current_points, scale, shift) + normalized_points = self._norm_fn(points, scale, shift) + update_dataset_dict[cond] = { + self.apply_to: ( + LabelTensor(normalized_points, points.labels) + if isinstance(points, LabelTensor) + else normalized_points + ) } - dataset.update_data(new_points) + dataset.update_data(update_dataset_dict) + + @property + def normalizer(self): + """ + Get the computed normalizer parameters. + + :return: Dictionary of normalization parameters. + :rtype: dict + """ + return self._normalizer diff --git a/tests/test_callback/test_normalizer.py b/tests/test_callback/test_normalizer.py index 4f2a660d3..274a62a43 100644 --- a/tests/test_callback/test_normalizer.py +++ b/tests/test_callback/test_normalizer.py @@ -53,7 +53,7 @@ class TensorProblem(AbstractProblem): problem=TensorProblem(), model=FeedForward(2, 1), use_lt=False ) supervised_solver_lt = SupervisedSolver( - problem=LabelTensorProblem(), model=FeedForward(2, 1), use_lt=False + problem=LabelTensorProblem(), model=FeedForward(2, 1), use_lt=True ) poisson_problem = Poisson() From 3aa821453332bf17e2c9b06c0df7cb87e09db50f Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Mon, 15 Sep 2025 18:56:46 +0200 Subject: [PATCH 06/10] * update is_function * change name files normalizer data callback --- docs/source/_rst/_code.rst | 2 +- .../{normalizer.rst => normalizer_data_callback.rst} | 2 +- pina/callback/__init__.py | 2 +- .../{normalizer.py => normalizer_data_callback.py} | 10 +++++----- pina/utils.py | 2 +- 
..._normalizer.py => test_normalizer_data_callback.py} | 0 6 files changed, 9 insertions(+), 9 deletions(-) rename docs/source/_rst/callback/{normalizer.rst => normalizer_data_callback.rst} (67%) rename pina/callback/{normalizer.py => normalizer_data_callback.py} (96%) rename tests/test_callback/{test_normalizer.py => test_normalizer_data_callback.py} (100%) diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index e911936fd..160eb3542 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -253,7 +253,7 @@ Callbacks Optimizer callback R3 Refinment callback Refinment Interface callback - Normalizer callback + Normalizer callback Losses and Weightings --------------------- diff --git a/docs/source/_rst/callback/normalizer.rst b/docs/source/_rst/callback/normalizer_data_callback.rst similarity index 67% rename from docs/source/_rst/callback/normalizer.rst rename to docs/source/_rst/callback/normalizer_data_callback.rst index eb61b754a..6f59f7aee 100644 --- a/docs/source/_rst/callback/normalizer.rst +++ b/docs/source/_rst/callback/normalizer_data_callback.rst @@ -1,7 +1,7 @@ Normalizer callbacks ======================= -.. currentmodule:: pina.callback.normalizer +.. currentmodule:: pina.callback.normalizer_data_callback .. 
autoclass:: NormalizerDataCallback :members: :show-inheritance: \ No newline at end of file diff --git a/pina/callback/__init__.py b/pina/callback/__init__.py index 599f76b9c..f71a89f91 100644 --- a/pina/callback/__init__.py +++ b/pina/callback/__init__.py @@ -11,4 +11,4 @@ from .optimizer_callback import SwitchOptimizer from .processing_callback import MetricTracker, PINAProgressBar from .refinement import R3Refinement -from .normalizer import NormalizerDataCallback +from .normalizer_data_callback import NormalizerDataCallback diff --git a/pina/callback/normalizer.py b/pina/callback/normalizer_data_callback.py similarity index 96% rename from pina/callback/normalizer.py rename to pina/callback/normalizer_data_callback.py index dc1d3a9d6..356bb48cf 100644 --- a/pina/callback/normalizer.py +++ b/pina/callback/normalizer_data_callback.py @@ -3,7 +3,7 @@ import torch from lightning.pytorch import Callback from ..label_tensor import LabelTensor -from ..utils import check_consistency +from ..utils import check_consistency, is_function from ..condition import InputTargetCondition @@ -53,10 +53,10 @@ def __init__( self.apply_to = self._validate_apply_to(apply_to) self.stage = self._validate_stage(stage) - if not callable(scale_fn): + if not is_function(scale_fn): raise ValueError(f"scale_fn must be callable, got {scale_fn}") self.scale_fn = scale_fn - if not callable(shift_fn): + if not is_function(shift_fn): raise ValueError(f"shift_fn must be callable, got {shift_fn}") self.shift_fn = shift_fn self._normalizer = {} @@ -115,7 +115,7 @@ def setup(self, trainer, pl_module, stage): if isinstance(cond, InputTargetCondition) ] - if not self._normalizer: + if not self.normalizer: if not trainer.datamodule.train_dataset: raise RuntimeError( "Training dataset is not available. 
Cannot compute " @@ -178,7 +178,7 @@ def normalize_dataset(self, dataset): :type dataset: object """ update_dataset_dict = {} - for cond, norm_params in self._normalizer.items(): + for cond, norm_params in self.normalizer.items(): points = dataset.conditions_dict[cond][self.apply_to] scale = norm_params["scale"] shift = norm_params["shift"] diff --git a/pina/utils.py b/pina/utils.py index 2aafba1f2..efc48424e 100644 --- a/pina/utils.py +++ b/pina/utils.py @@ -206,7 +206,7 @@ def is_function(f): :return: ``True`` if ``f`` is a function, ``False`` otherwise. :rtype: bool """ - return isinstance(f, (types.FunctionType, types.LambdaType)) + return callable(f) def chebyshev_roots(n): diff --git a/tests/test_callback/test_normalizer.py b/tests/test_callback/test_normalizer_data_callback.py similarity index 100% rename from tests/test_callback/test_normalizer.py rename to tests/test_callback/test_normalizer_data_callback.py From 63135b451f56394f508e91b2d914ee68ed81bd01 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 15 Sep 2025 19:41:05 +0200 Subject: [PATCH 07/10] reduce tests --- .../test_normalizer_data_callback.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/test_callback/test_normalizer_data_callback.py b/tests/test_callback/test_normalizer_data_callback.py index 274a62a43..d6c906b1b 100644 --- a/tests/test_callback/test_normalizer_data_callback.py +++ b/tests/test_callback/test_normalizer_data_callback.py @@ -102,11 +102,13 @@ def test_init_invalid_stage(invalid_stage): @pytest.mark.parametrize( "solver", [supervised_solver_lt, supervised_solver_no_lt] ) -@pytest.mark.parametrize("scale_fn", [torch.std, torch.var]) -@pytest.mark.parametrize("shift_fn", [torch.mean, torch.median]) +@pytest.mark.parametrize( + "fn", [[torch.std, torch.mean], [torch.var, torch.median]] +) @pytest.mark.parametrize("apply_to", ["input", "target"]) @pytest.mark.parametrize("stage", ["all", "train", "validate", "test"]) -def 
test_setup(solver, scale_fn, shift_fn, stage, apply_to): +def test_setup(solver, fn, stage, apply_to): + scale_fn, shift_fn = fn trainer = Trainer( solver=solver, callbacks=NormalizerDataCallback( @@ -150,11 +152,13 @@ def test_setup(solver, scale_fn, shift_fn, stage, apply_to): assert torch.allclose(current_points, expected) -@pytest.mark.parametrize("scale_fn", [torch.std, torch.var]) -@pytest.mark.parametrize("shift_fn", [torch.mean, torch.median]) +@pytest.mark.parametrize( + "fn", [[torch.std, torch.mean], [torch.var, torch.median]] +) @pytest.mark.parametrize("apply_to", ["input"]) @pytest.mark.parametrize("stage", ["all", "train", "validate", "test"]) -def test_setup_pinn(scale_fn, shift_fn, stage, apply_to): +def test_setup_pinn(fn, stage, apply_to): + scale_fn, shift_fn = fn pinn = PINN( problem=poisson_problem, model=FeedForward(2, 1), From ae7a4d75054c565d2c46c73a6def862257208b93 Mon Sep 17 00:00:00 2001 From: giovanni Date: Tue, 16 Sep 2025 11:03:34 +0200 Subject: [PATCH 08/10] fix documentation --- pina/callback/normalizer_data_callback.py | 114 ++++++++++++---------- 1 file changed, 64 insertions(+), 50 deletions(-) diff --git a/pina/callback/normalizer_data_callback.py b/pina/callback/normalizer_data_callback.py index 356bb48cf..213175c97 100644 --- a/pina/callback/normalizer_data_callback.py +++ b/pina/callback/normalizer_data_callback.py @@ -9,8 +9,8 @@ class NormalizerDataCallback(Callback): r""" - A Lightning Callback that normalizes dataset inputs or targets - according to user-provided scale and shift parameters. + A Callback used to normalize the dataset inputs or targets according to + user-provided scale and shift functions. The transformation is applied as: @@ -22,10 +22,11 @@ class NormalizerDataCallback(Callback): >>> NormalizerDataCallback() >>> NormalizerDataCallback( - ... "scale": torch.var, - ... "shift": torch.median + ... scale_fn: torch.std, + ... shift_fn: torch.mean, + ... stage: "all", + ... apply_to: "input", ... 
) - """ def __init__( @@ -36,63 +37,71 @@ def __init__( apply_to="input", ): """ - Initialize the NormalizerDataCallback. + Initialization of the :class:`NormalizerDataCallback` class. - :param dict strategy: Normalization specification. It must be a dict - with keys "scale" and "shift", each mapping to a callable that - computes the respective value from a tensor. If None, defaults to - using mean and std. Defaults is ``None``. - :param str stage: Stage during which to apply normalization. - One of {"train", "validate", "test", "all"}. - Defaults to "all". + :param Callable scale_fn: The function to compute the scaling factor. + Default is ``torch.std``. + :param Callable shift_fn: The function to compute the shifting factor. + Default is ``torch.mean``. + :param str stage: The stage in which normalization is applied. + Accepted values are "train", "validate", "test", or "all". + Default is ``"all"``. :param str apply_to: Whether to normalize "input" or "target" data. - Defaults to "input". - :raises ValueError: If `apply_to` or `stage` are invalid. + Default is ``"input"``. + :raises ValueError: If ``scale_fn`` is not callable. + :raises ValueError: If ``shift_fn`` is not callable. """ super().__init__() + # Validate parameters self.apply_to = self._validate_apply_to(apply_to) self.stage = self._validate_stage(stage) + + # Validate functions if not is_function(scale_fn): - raise ValueError(f"scale_fn must be callable, got {scale_fn}") - self.scale_fn = scale_fn + raise ValueError(f"scale_fn must be Callable, got {scale_fn}") if not is_function(shift_fn): - raise ValueError(f"shift_fn must be callable, got {shift_fn}") + raise ValueError(f"shift_fn must be Callable, got {shift_fn}") + self.scale_fn = scale_fn self.shift_fn = shift_fn + + # Initialize normalizer dictionary self._normalizer = {} def _validate_apply_to(self, apply_to): """ - Validate the `apply_to` parameter. + Validate the ``apply_to`` parameter. 
- :param str apply_to: Candidate value for `apply_to`. - :raises ValueError: If `apply_to` is not "input" or "target". - :return: Validated `apply_to` value. + :param str apply_to: The candidate value for the ``apply_to`` parameter. + :raises ValueError: If ``apply_to`` is neither "input" nor "target". + :return: The validated ``apply_to`` value. :rtype: str """ check_consistency(apply_to, str) if apply_to not in {"input", "target"}: raise ValueError( - f"apply_to must be 'input' or 'target', got {apply_to}" + f"apply_to must be either 'input' or 'target', got {apply_to}" ) + return apply_to def _validate_stage(self, stage): """ - Validate the `stage` parameter. + Validate the ``stage`` parameter. - :param str stage: Candidate value for `stage`. - :raises ValueError: If `stage` is not one of "train", "validate", + :param str stage: The candidate value for the ``stage`` parameter. + :raises ValueError: If ``stage`` is not one of "train", "validate", "test", or "all". - :return: Validated `stage` value. + :return: The validated ``stage`` value. :rtype: str """ check_consistency(stage, str) if stage not in {"train", "validate", "test", "all"}: raise ValueError( - f"stage must be 'train', 'validate', 'test', or 'all', got " - f"{stage}" + "stage must be one of 'train', 'validate', 'test', or 'all'," + f" got {stage}" ) + return stage def setup(self, trainer, pl_module, stage): @@ -102,19 +111,20 @@ def setup(self, trainer, pl_module, stage): :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance. :param SolverInterface pl_module: A :class:`~pina.solver.solver.SolverInterface` instance. - :param str stage: Current stage, not used kept for consistency. - :raises RuntimeError: If condition names do not match solver conditions. - :raises RuntimeError: If attempting to scale unavailable targets. - :return: Result of parent setup. + :param str stage: The current stage. 
+ :raises RuntimeError: If the training dataset is not available when + computing normalization parameters. + :return: The result of the parent setup. :rtype: Any """ - # extract conditions + # Extract conditions conditions_to_normalize = [ name for name, cond in pl_module.problem.conditions.items() if isinstance(cond, InputTargetCondition) ] + # Compute scale and shift parameters if not self.normalizer: if not trainer.datamodule.train_dataset: raise RuntimeError( @@ -125,21 +135,22 @@ def setup(self, trainer, pl_module, stage): conditions_to_normalize, trainer.datamodule.train_dataset ) + # Apply normalization based on the specified stage if stage == "fit" and self.stage in ["train", "all"]: self.normalize_dataset(trainer.datamodule.train_dataset) if stage == "fit" and self.stage in ["validate", "all"]: self.normalize_dataset(trainer.datamodule.val_dataset) if stage == "test" and self.stage in ["test", "all"]: self.normalize_dataset(trainer.datamodule.test_dataset) + return super().setup(trainer, pl_module, stage) def _compute_scale_shift(self, conditions, dataset): """ - Compute scale and shift for each condition from dataset. + Compute scale and shift parameters for each condition in the dataset. - :param list conditions: List of condition names. - :param dataset: `~pina.data.dataset.PinaDataset` object. - :rtype: dict + :param list conditions: The list of condition names. + :param dataset: The `~pina.data.dataset.PinaDataset` dataset. """ for cond in conditions: if cond in dataset.conditions_dict: @@ -154,30 +165,31 @@ def _compute_scale_shift(self, conditions, dataset): @staticmethod def _norm_fn(value, scale, shift): """ - Normalize a tensor with the given scale and shift. + Normalize a value according to the scale and shift parameters. - :param value: Input tensor to normalize. + :param value: The input tensor to normalize. :type value: torch.Tensor | LabelTensor - :param scale: Scaling factor. - :type scale: float | int - :param shift: Shifting factor. 
- :type shift: float | int - :return: Normalized tensor (value - shift) / scale. + :param float scale: The scaling factor. + :param float shift: The shifting factor. + :return: The normalized tensor. :rtype: torch.Tensor | LabelTensor """ scaled_value = (value - shift) / scale if isinstance(value, LabelTensor): scaled_value = LabelTensor(scaled_value, value.labels) + return scaled_value def normalize_dataset(self, dataset): """ - Apply normalization to a dataset in-place. + Apply in-place normalization to the dataset. - :param dataset: Dataset object with `conditions_dict` and `update_data`. - :type dataset: object + :param PinaDataset dataset: The dataset to be normalized. """ + # Initialize update dictionary update_dataset_dict = {} + + # Iterate over conditions and apply normalization for cond, norm_params in self.normalizer.items(): points = dataset.conditions_dict[cond][self.apply_to] scale = norm_params["scale"] @@ -190,14 +202,16 @@ def normalize_dataset(self, dataset): else normalized_points ) } + + # Update the dataset in-place dataset.update_data(update_dataset_dict) @property def normalizer(self): """ - Get the computed normalizer parameters. + Get the dictionary of normalization parameters. - :return: Dictionary of normalization parameters. + :return: The dictionary of normalization parameters. 
:rtype: dict """ return self._normalizer From d6bbd93a01b529dbce7d90fd2b94edeea99b17d3 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 16 Sep 2025 14:46:34 +0200 Subject: [PATCH 09/10] add NotImplementedError for PinaGraphDataset --- pina/callback/normalizer_data_callback.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pina/callback/normalizer_data_callback.py b/pina/callback/normalizer_data_callback.py index 213175c97..ef957b9ef 100644 --- a/pina/callback/normalizer_data_callback.py +++ b/pina/callback/normalizer_data_callback.py @@ -5,6 +5,7 @@ from ..label_tensor import LabelTensor from ..utils import check_consistency, is_function from ..condition import InputTargetCondition +from ..data.dataset import PinaGraphDataset class NormalizerDataCallback(Callback): @@ -116,7 +117,17 @@ def setup(self, trainer, pl_module, stage): computing normalization parameters. :return: The result of the parent setup. :rtype: Any + + :raises NotImplementedError: If the dataset is graph-based. """ + + # Ensure datasets are not graph-based + if isinstance(trainer.datamodule.train_dataset, PinaGraphDataset): + raise NotImplementedError( + "NormalizerDataCallback is not compatible with " + "graph-based datasets." 
+ ) + # Extract conditions conditions_to_normalize = [ name From 7f6316d89327f4a1d3c640aaf7701240d99b3f94 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 16 Sep 2025 17:02:36 +0200 Subject: [PATCH 10/10] add graph test --- .../test_normalizer_data_callback.py | 36 ++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/tests/test_callback/test_normalizer_data_callback.py b/tests/test_callback/test_normalizer_data_callback.py index d6c906b1b..7cdcc9510 100644 --- a/tests/test_callback/test_normalizer_data_callback.py +++ b/tests/test_callback/test_normalizer_data_callback.py @@ -8,8 +8,8 @@ from pina.callback import NormalizerDataCallback from pina.problem import AbstractProblem from pina.problem.zoo import Poisson2DSquareProblem as Poisson -from pina.condition.input_target_condition import InputTargetCondition from pina.solver import PINN +from pina.graph import RadiusGraph # for checking normalization stage_map = { @@ -49,6 +49,18 @@ class TensorProblem(AbstractProblem): } +input_graph = [RadiusGraph(radius=0.5, pos=torch.rand(10, 2)) for _ in range(5)] +output_graph = torch.rand(5, 1) + + +class GraphProblem(AbstractProblem): + input_variables = ["u_0", "u_1"] + output_variables = ["u"] + conditions = { + "data": Condition(input=input_graph, target=output_graph), + } + + supervised_solver_no_lt = SupervisedSolver( problem=TensorProblem(), model=FeedForward(2, 1), use_lt=False ) @@ -208,3 +220,25 @@ def test_setup_pinn(fn, stage, apply_to): old_points = old_dataset.conditions_dict[cond][apply_to] expected = (old_points - shift) / scale assert torch.allclose(current_points, expected) + + +def test_setup_graph_dataset(): + solver = SupervisedSolver( + problem=GraphProblem(), model=FeedForward(2, 1), use_lt=False + ) + trainer = Trainer( + solver=solver, + callbacks=NormalizerDataCallback( + scale_fn=torch.std, + shift_fn=torch.mean, + stage="all", + apply_to="input", + ), + max_epochs=1, + train_size=0.4, + val_size=0.3, + 
test_size=0.3, + shuffle=False, + ) + with pytest.raises(NotImplementedError): + trainer.train()