Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/_rst/_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ Callbacks
Optimizer callback <callback/optimizer_callback.rst>
R3 Refinement callback <callback/refinement/r3_refinement.rst>
Refinement Interface callback <callback/refinement/refinement_interface.rst>
Normalizer callback <callback/normalizer_data_callback.rst>

Losses and Weightings
---------------------
Expand Down
7 changes: 7 additions & 0 deletions docs/source/_rst/callback/normalizer_data_callback.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Normalizer callbacks
=======================

.. currentmodule:: pina.callback.normalizer_data_callback
.. autoclass:: NormalizerDataCallback
    :members:
    :show-inheritance:
2 changes: 2 additions & 0 deletions pina/callback/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
"MetricTracker",
"PINAProgressBar",
"R3Refinement",
"NormalizerDataCallback",
]

from .optimizer_callback import SwitchOptimizer
from .processing_callback import MetricTracker, PINAProgressBar
from .refinement import R3Refinement
from .normalizer_data_callback import NormalizerDataCallback
228 changes: 228 additions & 0 deletions pina/callback/normalizer_data_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
"""Module for the Normalizer callback."""

import torch
from lightning.pytorch import Callback
from ..label_tensor import LabelTensor
from ..utils import check_consistency, is_function
from ..condition import InputTargetCondition
from ..data.dataset import PinaGraphDataset


class NormalizerDataCallback(Callback):
    r"""
    A Callback used to normalize the dataset inputs or targets according to
    user-provided scale and shift functions.

    The transformation is applied as:

    .. math::

        x_{\text{new}} = \frac{x - \text{shift}}{\text{scale}}

    The scale and shift parameters are computed once, from the training
    dataset, and then reused for every dataset the callback normalizes.

    .. note::
        No guard is applied against a zero scale: if ``scale_fn`` returns
        zero (e.g. constant data with ``torch.std``), the normalized values
        are not finite.

    :Example:

    >>> NormalizerDataCallback()
    >>> NormalizerDataCallback(
    ...     scale_fn=torch.std,
    ...     shift_fn=torch.mean,
    ...     stage="all",
    ...     apply_to="input",
    ... )
    """

    def __init__(
        self,
        scale_fn=torch.std,
        shift_fn=torch.mean,
        stage="all",
        apply_to="input",
    ):
        """
        Initialization of the :class:`NormalizerDataCallback` class.

        :param Callable scale_fn: The function to compute the scaling factor.
            Default is ``torch.std``.
        :param Callable shift_fn: The function to compute the shifting factor.
            Default is ``torch.mean``.
        :param str stage: The stage in which normalization is applied.
            Accepted values are "train", "validate", "test", or "all".
            Default is ``"all"``.
        :param str apply_to: Whether to normalize "input" or "target" data.
            Default is ``"input"``.
        :raises ValueError: If ``scale_fn`` is not callable.
        :raises ValueError: If ``shift_fn`` is not callable.
        :raises ValueError: If ``stage`` or ``apply_to`` is not one of the
            accepted values.
        """
        super().__init__()

        # Validate parameters
        self.apply_to = self._validate_apply_to(apply_to)
        self.stage = self._validate_stage(stage)

        # Validate functions
        if not is_function(scale_fn):
            raise ValueError(f"scale_fn must be Callable, got {scale_fn}")
        if not is_function(shift_fn):
            raise ValueError(f"shift_fn must be Callable, got {shift_fn}")
        self.scale_fn = scale_fn
        self.shift_fn = shift_fn

        # Normalization parameters, keyed by condition name. Filled lazily
        # the first time ``setup`` runs with a training dataset available.
        self._normalizer = {}

    def _validate_apply_to(self, apply_to):
        """
        Validate the ``apply_to`` parameter.

        :param str apply_to: The candidate value for the ``apply_to`` parameter.
        :raises ValueError: If ``apply_to`` is neither "input" nor "target".
        :return: The validated ``apply_to`` value.
        :rtype: str
        """
        check_consistency(apply_to, str)
        if apply_to not in {"input", "target"}:
            raise ValueError(
                f"apply_to must be either 'input' or 'target', got {apply_to}"
            )

        return apply_to

    def _validate_stage(self, stage):
        """
        Validate the ``stage`` parameter.

        :param str stage: The candidate value for the ``stage`` parameter.
        :raises ValueError: If ``stage`` is not one of "train", "validate",
            "test", or "all".
        :return: The validated ``stage`` value.
        :rtype: str
        """
        check_consistency(stage, str)
        if stage not in {"train", "validate", "test", "all"}:
            raise ValueError(
                "stage must be one of 'train', 'validate', 'test', or 'all',"
                f" got {stage}"
            )

        return stage

    def setup(self, trainer, pl_module, stage):
        """
        Apply normalization during setup.

        The normalization parameters are computed from the training dataset
        only if they have not been computed yet; afterwards the datasets
        selected by the ``stage`` given at construction are normalized.

        :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance.
        :param SolverInterface pl_module: A
            :class:`~pina.solver.solver.SolverInterface` instance.
        :param str stage: The current stage.
        :raises NotImplementedError: If the dataset is graph-based.
        :raises RuntimeError: If the training dataset is not available when
            computing normalization parameters.
        :return: The result of the parent setup.
        :rtype: Any
        """
        datamodule = trainer.datamodule
        train_dataset = datamodule.train_dataset

        # Ensure datasets are not graph-based
        if isinstance(train_dataset, PinaGraphDataset):
            raise NotImplementedError(
                "NormalizerDataCallback is not compatible with "
                "graph-based datasets."
            )

        # Extract conditions: only input-target conditions are normalized
        conditions_to_normalize = [
            name
            for name, cond in pl_module.problem.conditions.items()
            if isinstance(cond, InputTargetCondition)
        ]

        # Compute scale and shift parameters (only once: a non-empty
        # normalizer is reused as-is on later calls)
        if not self.normalizer:
            if not train_dataset:
                raise RuntimeError(
                    "Training dataset is not available. Cannot compute "
                    "normalization parameters."
                )
            self._compute_scale_shift(conditions_to_normalize, train_dataset)

        # Apply normalization based on the specified stage
        if stage == "fit":
            if self.stage in ("train", "all"):
                self.normalize_dataset(train_dataset)
            if self.stage in ("validate", "all"):
                self.normalize_dataset(datamodule.val_dataset)
        elif stage == "test" and self.stage in ("test", "all"):
            self.normalize_dataset(datamodule.test_dataset)

        return super().setup(trainer, pl_module, stage)

    def _compute_scale_shift(self, conditions, dataset):
        """
        Compute scale and shift parameters for each condition in the dataset.

        Conditions that are not present in ``dataset.conditions_dict`` are
        skipped. Results are stored in the normalizer dictionary.

        :param list conditions: The list of condition names.
        :param dataset: The `~pina.data.dataset.PinaDataset` dataset.
        """
        for cond in conditions:
            if cond in dataset.conditions_dict:
                data = dataset.conditions_dict[cond][self.apply_to]
                self._normalizer[cond] = {
                    "shift": self.shift_fn(data),
                    "scale": self.scale_fn(data),
                }

    @staticmethod
    def _norm_fn(value, scale, shift):
        """
        Normalize a value according to the scale and shift parameters.

        :param value: The input tensor to normalize.
        :type value: torch.Tensor | LabelTensor
        :param float scale: The scaling factor.
        :param float shift: The shifting factor.
        :return: The normalized tensor. Labels of ``value`` are re-attached
            when it is a :class:`~pina.label_tensor.LabelTensor`.
        :rtype: torch.Tensor | LabelTensor
        """
        scaled_value = (value - shift) / scale
        if isinstance(value, LabelTensor):
            scaled_value = LabelTensor(scaled_value, value.labels)

        return scaled_value

    def normalize_dataset(self, dataset):
        """
        Apply in-place normalization to the dataset.

        Nothing is done when ``dataset`` is ``None``. Conditions for which
        parameters were computed but that are missing from
        ``dataset.conditions_dict`` are skipped, mirroring the check done
        when the parameters are computed.

        :param PinaDataset dataset: The dataset to be normalized.
        """
        # A split that does not exist cannot be normalized; skipping avoids
        # an AttributeError on ``None``.
        if dataset is None:
            return

        # Initialize update dictionary
        update_dataset_dict = {}

        # Iterate over conditions and apply normalization
        for cond, norm_params in self.normalizer.items():
            if cond not in dataset.conditions_dict:
                continue
            points = dataset.conditions_dict[cond][self.apply_to]
            normalized_points = self._norm_fn(
                points, norm_params["scale"], norm_params["shift"]
            )
            # ``_norm_fn`` already restores the labels; wrap only if an
            # overriding implementation returned a plain tensor.
            if isinstance(points, LabelTensor) and not isinstance(
                normalized_points, LabelTensor
            ):
                normalized_points = LabelTensor(
                    normalized_points, points.labels
                )
            update_dataset_dict[cond] = {self.apply_to: normalized_points}

        # Update the dataset in-place
        dataset.update_data(update_dataset_dict)

    @property
    def normalizer(self):
        """
        Get the dictionary of normalization parameters.

        :return: The dictionary of normalization parameters, mapping each
            condition name to its ``"shift"`` and ``"scale"`` values.
        :rtype: dict
        """
        return self._normalizer
2 changes: 1 addition & 1 deletion pina/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def is_function(f):
:return: ``True`` if ``f`` is a function, ``False`` otherwise.
:rtype: bool
"""
return isinstance(f, (types.FunctionType, types.LambdaType))
return callable(f)


def chebyshev_roots(n):
Expand Down
Loading
Loading