From 17a59ebd5a7af0dc9793414d73e9bb9df2312636 Mon Sep 17 00:00:00 2001 From: giovanni Date: Fri, 29 Aug 2025 19:11:08 +0200 Subject: [PATCH 1/4] add mutual solver-weighting link --- pina/loss/ntk_weighting.py | 38 ++++++------ pina/loss/scalar_weighting.py | 6 +- pina/loss/weighting_interface.py | 12 +++- pina/solver/solver.py | 2 +- tests/test_weighting/test_ntk_weighting.py | 62 +++++-------------- ..._weighting.py => test_scalar_weighting.py} | 16 ++--- 6 files changed, 61 insertions(+), 75 deletions(-) rename tests/test_weighting/{test_standard_weighting.py => test_scalar_weighting.py} (82%) diff --git a/pina/loss/ntk_weighting.py b/pina/loss/ntk_weighting.py index d8c947f06..6149f2376 100644 --- a/pina/loss/ntk_weighting.py +++ b/pina/loss/ntk_weighting.py @@ -1,7 +1,6 @@ """Module for Neural Tangent Kernel Class""" import torch -from torch.nn import Module from .weighting_interface import WeightingInterface from ..utils import check_consistency @@ -21,43 +20,45 @@ class NeuralTangentKernelWeighting(WeightingInterface): """ - def __init__(self, model, alpha=0.5): + def __init__(self, alpha=0.5): """ Initialization of the :class:`NeuralTangentKernelWeighting` class. - :param torch.nn.Module model: The neural network model. :param float alpha: The alpha parameter. - :raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive). """ - super().__init__() + + # Check consistency check_consistency(alpha, float) - check_consistency(model, Module) if alpha < 0 or alpha > 1: raise ValueError("alpha should be a value between 0 and 1") + + # Initialize parameters self.alpha = alpha - self.model = model self.weights = {} - self.default_value_weights = 1 + self.default_value_weights = 1.0 def aggregate(self, losses): """ - Weight the losses according to the Neural Tangent Kernel - algorithm. + Weight the losses according to the Neural Tangent Kernel algorithm. :param dict(torch.Tensor) input: The dictionary of losses. - :return: The losses aggregation. It should be a scalar Tensor. + :return: The aggregation of the losses. It should be a scalar Tensor. :rtype: torch.Tensor """ + # Define a dictionary to store the norms of the gradients losses_norm = {} - for condition in losses: - losses[condition].backward(retain_graph=True) - grads = [] - for param in self.model.parameters(): - grads.append(param.grad.view(-1)) - grads = torch.cat(grads) - losses_norm[condition] = torch.norm(grads) + + # Compute the gradient norms for each loss component + for condition, loss in losses.items(): + loss.backward(retain_graph=True) + grads = torch.cat( + [p.grad.flatten() for p in self.solver.model.parameters()] + ) + losses_norm[condition] = grads.norm() + + # Update the weights self.weights = { condition: self.alpha * self.weights.get(condition, self.default_value_weights) @@ -66,6 +67,7 @@ def aggregate(self, losses): / sum(losses_norm.values()) for condition in losses } + return sum( self.weights[condition] * loss for condition, loss in losses.items() ) diff --git a/pina/loss/scalar_weighting.py b/pina/loss/scalar_weighting.py index 6bc093c7d..c10b5741a 100644 --- a/pina/loss/scalar_weighting.py +++ b/pina/loss/scalar_weighting.py @@ -37,12 +37,16 @@ def __init__(self, weights): :type weights: float | int | dict """ super().__init__() + + # Check consistency check_consistency([weights], (float, dict, int)) + + # Weights initialization if isinstance(weights, (float, int)): self.default_value_weights = weights self.weights = {} else: - self.default_value_weights = 1 + self.default_value_weights = 1.0 self.weights = weights def aggregate(self, losses): diff --git a/pina/loss/weighting_interface.py b/pina/loss/weighting_interface.py index 8b8cb2f28..567d493c1 100644 --- a/pina/loss/weighting_interface.py +++ b/pina/loss/weighting_interface.py @@ -13,7 +13,7 @@ def __init__(self): """ Initialization of the :class:`WeightingInterface` class. """ - self.condition_names = None + self._solver = None @abstractmethod def aggregate(self, losses): @@ -22,3 +22,13 @@ def aggregate(self, losses): :param dict losses: The dictionary of losses. """ + + @property + def solver(self): + """ + The solver employing this weighting schema. + + :return: The solver. + :rtype: :class:`~pina.solver.SolverInterface` + """ + return self._solver diff --git a/pina/solver/solver.py b/pina/solver/solver.py index f3ff40579..6948ec664 100644 --- a/pina/solver/solver.py +++ b/pina/solver/solver.py @@ -44,7 +44,7 @@ def __init__(self, problem, weighting, use_lt): weighting = _NoWeighting() check_consistency(weighting, WeightingInterface) self._pina_weighting = weighting - weighting.condition_names = list(self._pina_problem.conditions.keys()) + weighting._solver = self # check consistency use_lt check_consistency(use_lt, bool) diff --git a/tests/test_weighting/test_ntk_weighting.py b/tests/test_weighting/test_ntk_weighting.py index 840237fb4..236c4987e 100644 --- a/tests/test_weighting/test_ntk_weighting.py +++ b/tests/test_weighting/test_ntk_weighting.py @@ -2,64 +2,32 @@ from pina import Trainer from pina.solver import PINN from pina.model import FeedForward -from pina.problem.zoo import Poisson2DSquareProblem from pina.loss import NeuralTangentKernelWeighting +from pina.problem.zoo import Poisson2DSquareProblem -problem = Poisson2DSquareProblem() -condition_names = problem.conditions.keys() +# Initialize problem and model +problem = Poisson2DSquareProblem() +problem.discretise_domain(10) +model = FeedForward(len(problem.input_variables), len(problem.output_variables)) -@pytest.mark.parametrize( - "model,alpha", - [ - ( - FeedForward( - len(problem.input_variables), len(problem.output_variables) - ), - 0.5, - ) - ], -) -def test_constructor(model, alpha): - NeuralTangentKernelWeighting(model=model, alpha=alpha) +@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0]) +def test_constructor(alpha): + NeuralTangentKernelWeighting(alpha=alpha) -@pytest.mark.parametrize("model", [0.5]) -def test_wrong_constructor1(model): + # Should fail if alpha is not >= 0 with pytest.raises(ValueError): - NeuralTangentKernelWeighting(model) - + NeuralTangentKernelWeighting(alpha=-0.1) -@pytest.mark.parametrize( - "model,alpha", - [ - ( - FeedForward( - len(problem.input_variables), len(problem.output_variables) - ), - 1.2, - ) - ], -) -def test_wrong_constructor2(model, alpha): + # Should fail if alpha is not <= 1 with pytest.raises(ValueError): - NeuralTangentKernelWeighting(model, alpha) + NeuralTangentKernelWeighting(alpha=1.1) -@pytest.mark.parametrize( - "model,alpha", - [ - ( - FeedForward( - len(problem.input_variables), len(problem.output_variables) - ), - 0.5, - ) - ], -) -def test_train_aggregation(model, alpha): - weighting = NeuralTangentKernelWeighting(model=model, alpha=alpha) - problem.discretise_domain(50) +@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0]) +def test_train_aggregation(alpha): + weighting = NeuralTangentKernelWeighting(alpha=alpha) solver = PINN(problem=problem, model=model, weighting=weighting) trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") trainer.train() diff --git a/tests/test_weighting/test_standard_weighting.py b/tests/test_weighting/test_scalar_weighting.py similarity index 82% rename from tests/test_weighting/test_standard_weighting.py rename to tests/test_weighting/test_scalar_weighting.py index 9caa89ae1..54b3293f5 100644 --- a/tests/test_weighting/test_standard_weighting.py +++ b/tests/test_weighting/test_scalar_weighting.py @@ -1,16 +1,17 @@ import pytest import torch - from pina import Trainer from pina.solver import PINN from pina.model import FeedForward -from pina.problem.zoo import Poisson2DSquareProblem from pina.loss import ScalarWeighting +from pina.problem.zoo import Poisson2DSquareProblem + +# Initialize problem and model problem = Poisson2DSquareProblem() +problem.discretise_domain(50) model = FeedForward(len(problem.input_variables), len(problem.output_variables)) condition_names = problem.conditions.keys() -print(problem.conditions.keys()) @pytest.mark.parametrize( @@ -19,11 +20,13 @@ def test_constructor(weights): ScalarWeighting(weights=weights) + # Should fail if weights are not a scalar + with pytest.raises(ValueError): + ScalarWeighting(weights="invalid") -@pytest.mark.parametrize("weights", ["a", [1, 2, 3]]) -def test_wrong_constructor(weights): + # Should fail if weights are not a dictionary with pytest.raises(ValueError): - ScalarWeighting(weights=weights) + ScalarWeighting(weights=[1, 2, 3]) @pytest.mark.parametrize( @@ -45,7 +48,6 @@ def test_aggregate(weights): ) def test_train_aggregation(weights): weighting = ScalarWeighting(weights=weights) - problem.discretise_domain(50) solver = PINN(problem=problem, model=model, weighting=weighting) trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") trainer.train() From 28b0223827d46a4d57099c20539538fa62644f9a Mon Sep 17 00:00:00 2001 From: giovanni Date: Fri, 29 Aug 2025 19:11:48 +0200 Subject: [PATCH 2/4] add self-adaptive weighting --- docs/source/_rst/_code.rst | 1 + .../_rst/loss/self_adaptive_weighting.rst | 9 +++ pina/loss/__init__.py | 2 + pina/loss/self_adaptive_weighting.py | 80 +++++++++++++++++++ .../test_self_adaptive_weighting.py | 37 +++++++++ 5 files changed, 129 insertions(+) create mode 100644 docs/source/_rst/loss/self_adaptive_weighting.rst create mode 100644 pina/loss/self_adaptive_weighting.py create mode 100644 tests/test_weighting/test_self_adaptive_weighting.py diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 9bd36ab2d..2bb62a4e8 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -267,3 +267,4 @@ Losses and Weightings WeightingInterface ScalarWeighting NeuralTangentKernelWeighting + SelfAdaptiveWeighting \ No newline at end of file diff --git a/docs/source/_rst/loss/self_adaptive_weighting.rst b/docs/source/_rst/loss/self_adaptive_weighting.rst new file mode 100644 index 000000000..cd1daed1f --- /dev/null +++ b/docs/source/_rst/loss/self_adaptive_weighting.rst @@ -0,0 +1,9 @@ +SelfAdaptiveWeighting +============================= +.. currentmodule:: pina.loss.self_adaptive_weighting + +.. automodule:: pina.loss.self_adaptive_weighting + +.. autoclass:: SelfAdaptiveWeighting + :members: + :show-inheritance: \ No newline at end of file diff --git a/pina/loss/__init__.py b/pina/loss/__init__.py index 2f15c6db9..fc47e62de 100644 --- a/pina/loss/__init__.py +++ b/pina/loss/__init__.py @@ -7,6 +7,7 @@ "WeightingInterface", "ScalarWeighting", "NeuralTangentKernelWeighting", + "SelfAdaptiveWeighting", ] from .loss_interface import LossInterface @@ -15,3 +16,4 @@ from .weighting_interface import WeightingInterface from .scalar_weighting import ScalarWeighting from .ntk_weighting import NeuralTangentKernelWeighting +from .self_adaptive_weighting import SelfAdaptiveWeighting diff --git a/pina/loss/self_adaptive_weighting.py b/pina/loss/self_adaptive_weighting.py new file mode 100644 index 000000000..853307852 --- /dev/null +++ b/pina/loss/self_adaptive_weighting.py @@ -0,0 +1,80 @@ +"""Module for Self-Adaptive Weighting class.""" + +import torch +from .weighting_interface import WeightingInterface +from ..utils import check_positive_integer + + +class SelfAdaptiveWeighting(WeightingInterface): + """ + A self-adaptive weighting scheme to tackle the imbalance among the loss + components. This formulation equalizes the gradient norms of the losses, + preventing bias toward any particular term during training. + + .. seealso:: + + **Original reference**: + Wang, S., Sankaran, S., Stinis., P., Perdikaris, P. (2025). + *Simulating Three-dimensional Turbulence with Physics-informed Neural + Networks*. + DOI: `arXiv preprint arXiv:2507.08972. + `_ + + """ + + def __init__(self, k=100): + """ + Initialization of the :class:`SelfAdaptiveWeighting` class. + + :param int k: The number of epochs after which the weights are updated. + Default is 100. + + :raises ValueError: If ``k`` is not a positive integer. + """ + super().__init__() + + # Check consistency + check_positive_integer(value=k, strict=True) + + # Initialize parameters + self.k = k + self.weights = {} + self.default_value_weights = 1.0 + + def aggregate(self, losses): + """ + Weight the losses according to the self-adaptive algorithm. + + :param dict(torch.Tensor) losses: The dictionary of losses. + :return: The aggregation of the losses. It should be a scalar Tensor. + :rtype: torch.Tensor + """ + # If weights have not been initialized, set them to 1 + if not self.weights: + self.weights = { + condition: self.default_value_weights for condition in losses + } + + # Update every k epochs + if self.solver.trainer.current_epoch % self.k == 0: + + # Define a dictionary to store the norms of the gradients + losses_norm = {} + + # Compute the gradient norms for each loss component + for condition, loss in losses.items(): + loss.backward(retain_graph=True) + grads = torch.cat( + [p.grad.flatten() for p in self.solver.model.parameters()] + ) + losses_norm[condition] = grads.norm() + + # Update the weights + self.weights = { + condition: sum(losses_norm.values()) / losses_norm[condition] + for condition in losses + } + + return sum( + self.weights[condition] * loss for condition, loss in losses.items() + ) diff --git a/tests/test_weighting/test_self_adaptive_weighting.py b/tests/test_weighting/test_self_adaptive_weighting.py new file mode 100644 index 000000000..b82f54575 --- /dev/null +++ b/tests/test_weighting/test_self_adaptive_weighting.py @@ -0,0 +1,37 @@ +import pytest +from pina import Trainer +from pina.solver import PINN +from pina.model import FeedForward +from pina.loss import SelfAdaptiveWeighting +from pina.problem.zoo import Poisson2DSquareProblem + + +# Initialize problem and model +problem = Poisson2DSquareProblem() +problem.discretise_domain(10) +model = FeedForward(len(problem.input_variables), len(problem.output_variables)) + + +@pytest.mark.parametrize("k", [10, 100, 1000]) +def test_constructor(k): + SelfAdaptiveWeighting(k=k) + + # Should fail if k is not an integer + with pytest.raises(AssertionError): + SelfAdaptiveWeighting(k=1.5) + + # Should fail if k is not > 0 + with pytest.raises(AssertionError): + SelfAdaptiveWeighting(k=0) + + # Should fail if k is not > 0 + with pytest.raises(AssertionError): + SelfAdaptiveWeighting(k=-3) + + +@pytest.mark.parametrize("k", [2, 3]) +def test_train_aggregation(k): + weighting = SelfAdaptiveWeighting(k=k) + solver = PINN(problem=problem, model=model, weighting=weighting) + trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") + trainer.train() From 6f5a816fbb4e085197428740718d00bee46cf98c Mon Sep 17 00:00:00 2001 From: giovanni Date: Mon, 1 Sep 2025 11:00:14 +0200 Subject: [PATCH 3/4] weighting refactory Co-authored-by: Dario Coscia --- .../linear_weight_update_callback.rst | 7 - pina/callback/__init__.py | 2 - .../callback/linear_weight_update_callback.py | 87 ---------- pina/loss/ntk_weighting.py | 32 ++-- pina/loss/scalar_weighting.py | 58 +++---- pina/loss/self_adaptive_weighting.py | 75 +++----- pina/loss/weighting_interface.py | 81 ++++++++- pina/utils.py | 28 +++ .../test_linear_weight_update_callback.py | 164 ------------------ tests/test_weighting/test_ntk_weighting.py | 30 +++- tests/test_weighting/test_scalar_weighting.py | 14 -- .../test_self_adaptive_weighting.py | 26 +-- 12 files changed, 215 insertions(+), 389 deletions(-) delete mode 100644 docs/source/_rst/callback/linear_weight_update_callback.rst delete mode 100644 pina/callback/linear_weight_update_callback.py delete mode 100644 tests/test_callback/test_linear_weight_update_callback.py diff --git a/docs/source/_rst/callback/linear_weight_update_callback.rst b/docs/source/_rst/callback/linear_weight_update_callback.rst deleted file mode 100644 index fe45b56e2..000000000 --- a/docs/source/_rst/callback/linear_weight_update_callback.rst +++ /dev/null @@ -1,7 +0,0 @@ -Weighting callbacks -======================== - -.. currentmodule:: pina.callback.linear_weight_update_callback -.. autoclass:: LinearWeightUpdate - :members: - :show-inheritance: \ No newline at end of file diff --git a/pina/callback/__init__.py b/pina/callback/__init__.py index dc1164e47..e9a70ea34 100644 --- a/pina/callback/__init__.py +++ b/pina/callback/__init__.py @@ -4,11 +4,9 @@ "SwitchOptimizer", "MetricTracker", "PINAProgressBar", - "LinearWeightUpdate", "R3Refinement", ] from .optimizer_callback import SwitchOptimizer from .processing_callback import MetricTracker, PINAProgressBar -from .linear_weight_update_callback import LinearWeightUpdate from .refinement import R3Refinement diff --git a/pina/callback/linear_weight_update_callback.py b/pina/callback/linear_weight_update_callback.py deleted file mode 100644 index ae25ca158..000000000 --- a/pina/callback/linear_weight_update_callback.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Module for the LinearWeightUpdate callback.""" - -import warnings -from lightning.pytorch.callbacks import Callback -from ..utils import check_consistency -from ..loss import ScalarWeighting - - -class LinearWeightUpdate(Callback): - """ - Callback to linearly adjust the weight of a condition from an - initial value to a target value over a specified number of epochs. - """ - - def __init__( - self, target_epoch, condition_name, initial_value, target_value - ): - """ - Callback initialization. - - :param int target_epoch: The epoch at which the weight of the condition - should reach the target value. - :param str condition_name: The name of the condition whose weight - should be adjusted. - :param float initial_value: The initial value of the weight. - :param float target_value: The target value of the weight. - """ - super().__init__() - self.target_epoch = target_epoch - self.condition_name = condition_name - self.initial_value = initial_value - self.target_value = target_value - - # Check consistency - check_consistency(self.target_epoch, int, subclass=False) - check_consistency(self.condition_name, str, subclass=False) - check_consistency(self.initial_value, (float, int), subclass=False) - check_consistency(self.target_value, (float, int), subclass=False) - - def on_train_start(self, trainer, pl_module): - """ - Initialize the weight of the condition to the specified `initial_value`. - - :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance. - :param SolverInterface pl_module: A - :class:`~pina.solver.solver.SolverInterface` instance. - """ - # Check that the target epoch is valid - if not 0 < self.target_epoch <= trainer.max_epochs: - raise ValueError( - "`target_epoch` must be greater than 0" - " and less than or equal to `max_epochs`." - ) - - # Check that the condition is a problem condition - if self.condition_name not in pl_module.problem.conditions: - raise ValueError( - f"`{self.condition_name}` must be a problem condition." - ) - - # Check that the initial value is not equal to the target value - if self.initial_value == self.target_value: - warnings.warn( - "`initial_value` is equal to `target_value`. " - "No effective adjustment will be performed.", - UserWarning, - ) - - # Check that the weighting schema is ScalarWeighting - if not isinstance(pl_module.weighting, ScalarWeighting): - raise ValueError("The weighting schema must be ScalarWeighting.") - - # Initialize the weight of the condition - pl_module.weighting.weights[self.condition_name] = self.initial_value - - def on_train_epoch_start(self, trainer, pl_module): - """ - Adjust at each epoch the weight of the condition. - - :param Trainer trainer: A :class:`~pina.trainer.Trainer` instance. - :param SolverInterface pl_module: A - :class:`~pina.solver.solver.SolverInterface` instance. - """ - if 0 < trainer.current_epoch <= self.target_epoch: - pl_module.weighting.weights[self.condition_name] += ( - self.target_value - self.initial_value - ) / (self.target_epoch - 1) diff --git a/pina/loss/ntk_weighting.py b/pina/loss/ntk_weighting.py index 6149f2376..b88812615 100644 --- a/pina/loss/ntk_weighting.py +++ b/pina/loss/ntk_weighting.py @@ -2,7 +2,7 @@ import torch from .weighting_interface import WeightingInterface -from ..utils import check_consistency +from ..utils import check_consistency, in_range class NeuralTangentKernelWeighting(WeightingInterface): @@ -20,32 +20,34 @@ class NeuralTangentKernelWeighting(WeightingInterface): """ - def __init__(self, alpha=0.5): + def __init__(self, update_every_n_epochs=1, alpha=0.5): """ Initialization of the :class:`NeuralTangentKernelWeighting` class. + :param int update_every_n_epochs: The number of training epochs between + weight updates. If set to 1, the weights are updated at every epoch. + Default is 1. :param float alpha: The alpha parameter. :raises ValueError: If ``alpha`` is not between 0 and 1 (inclusive). """ - super().__init__() + super().__init__(update_every_n_epochs=update_every_n_epochs) # Check consistency check_consistency(alpha, float) - if alpha < 0 or alpha > 1: - raise ValueError("alpha should be a value between 0 and 1") + if not in_range(alpha, [0, 1], strict=False): + raise ValueError("alpha must be in range (0, 1).") # Initialize parameters self.alpha = alpha self.weights = {} - self.default_value_weights = 1.0 - def aggregate(self, losses): + def weights_update(self, losses): """ - Weight the losses according to the Neural Tangent Kernel algorithm. + Update the weighting scheme based on the given losses. - :param dict(torch.Tensor) input: The dictionary of losses. - :return: The aggregation of the losses. It should be a scalar Tensor. - :rtype: torch.Tensor + :param dict losses: The dictionary of losses. + :return: The updated weights. + :rtype: dict """ # Define a dictionary to store the norms of the gradients losses_norm = {} @@ -60,14 +62,10 @@ def aggregate(self, losses): # Update the weights self.weights = { - condition: self.alpha - * self.weights.get(condition, self.default_value_weights) + condition: self.alpha * self.weights.get(condition, 1) + (1 - self.alpha) * losses_norm[condition] / sum(losses_norm.values()) for condition in losses } - - return sum( - self.weights[condition] * loss for condition, loss in losses.items() - ) + return self.weights diff --git a/pina/loss/scalar_weighting.py b/pina/loss/scalar_weighting.py index c10b5741a..d770c8961 100644 --- a/pina/loss/scalar_weighting.py +++ b/pina/loss/scalar_weighting.py @@ -4,22 +4,6 @@ from ..utils import check_consistency -class _NoWeighting(WeightingInterface): - """ - Weighting scheme that does not apply any weighting to the losses. - """ - - def aggregate(self, losses): - """ - Aggregate the losses. - - :param dict losses: The dictionary of losses. - :return: The aggregated losses. - :rtype: torch.Tensor - """ - return sum(losses.values()) - - class ScalarWeighting(WeightingInterface): """ Weighting scheme that assigns a scalar weight to each loss term. @@ -36,28 +20,42 @@ def __init__(self, weights): dictionary, the default value is used. :type weights: float | int | dict """ - super().__init__() + super().__init__(update_every_n_epochs=1, aggregator="sum") # Check consistency check_consistency([weights], (float, dict, int)) - # Weights initialization - if isinstance(weights, (float, int)): + # Initialization + if isinstance(weights, dict): + self.values = weights + self.default_value_weights = 1 + elif isinstance(weights, (float, int)): + self.values = {} self.default_value_weights = weights - self.weights = {} else: - self.default_value_weights = 1.0 - self.weights = weights + raise ValueError - def aggregate(self, losses): + def weights_update(self, losses): """ - Aggregate the losses. + Update the weighting scheme based on the given losses. :param dict losses: The dictionary of losses. - :return: The aggregated losses. - :rtype: torch.Tensor + :return: The updated weights. + :rtype: dict + """ + return { + condition: self.values.get(condition, self.default_value_weights) + for condition in losses.keys() + } + + +class _NoWeighting(ScalarWeighting): + """ + Weighting scheme that does not apply any weighting to the losses. + """ + + def __init__(self): + """ + Initialization of the :class:`_NoWeighting` class. """ - return sum( - self.weights.get(condition, self.default_value_weights) * loss - for condition, loss in losses.items() - ) + super().__init__(weights=1) diff --git a/pina/loss/self_adaptive_weighting.py b/pina/loss/self_adaptive_weighting.py index 853307852..62196c529 100644 --- a/pina/loss/self_adaptive_weighting.py +++ b/pina/loss/self_adaptive_weighting.py @@ -2,7 +2,6 @@ import torch from .weighting_interface import WeightingInterface -from ..utils import check_positive_integer class SelfAdaptiveWeighting(WeightingInterface): @@ -22,59 +21,37 @@ class SelfAdaptiveWeighting(WeightingInterface): """ - def __init__(self, k=100): + def __init__(self, update_every_n_epochs=1): """ Initialization of the :class:`SelfAdaptiveWeighting` class. - :param int k: The number of epochs after which the weights are updated. - Default is 100. - - :raises ValueError: If ``k`` is not a positive integer. + :param int update_every_n_epochs: The number of training epochs between + weight updates. If set to 1, the weights are updated at every epoch. + Default is 1. """ - super().__init__() - - # Check consistency - check_positive_integer(value=k, strict=True) + super().__init__(update_every_n_epochs=update_every_n_epochs) - # Initialize parameters - self.k = k - self.weights = {} - self.default_value_weights = 1.0 - - def aggregate(self, losses): + def weights_update(self, losses): """ - Weight the losses according to the self-adaptive algorithm. + Update the weighting scheme based on the given losses. - :param dict(torch.Tensor) losses: The dictionary of losses. - :return: The aggregation of the losses. It should be a scalar Tensor. - :rtype: torch.Tensor + :param dict losses: The dictionary of losses. + :return: The updated weights. + :rtype: dict """ - # If weights have not been initialized, set them to 1 - if not self.weights: - self.weights = { - condition: self.default_value_weights for condition in losses - } - - # Update every k epochs - if self.solver.trainer.current_epoch % self.k == 0: - - # Define a dictionary to store the norms of the gradients - losses_norm = {} - - # Compute the gradient norms for each loss component - for condition, loss in losses.items(): - loss.backward(retain_graph=True) - grads = torch.cat( - [p.grad.flatten() for p in self.solver.model.parameters()] - ) - losses_norm[condition] = grads.norm() - - # Update the weights - self.weights = { - condition: sum(losses_norm.values()) / losses_norm[condition] - for condition in losses - } - - return sum( - self.weights[condition] * loss for condition, loss in losses.items() - ) + # Define a dictionary to store the norms of the gradients + losses_norm = {} + + # Compute the gradient norms for each loss component + for condition, loss in losses.items(): + loss.backward(retain_graph=True) + grads = torch.cat( + [p.grad.flatten() for p in self.solver.model.parameters()] + ) + losses_norm[condition] = grads.norm() + + # Update the weights + return { + condition: sum(losses_norm.values()) / losses_norm[condition] + for condition in losses + } diff --git a/pina/loss/weighting_interface.py b/pina/loss/weighting_interface.py index 567d493c1..bc34c3181 100644 --- a/pina/loss/weighting_interface.py +++ b/pina/loss/weighting_interface.py @@ -1,6 +1,10 @@ """Module for the Weighting Interface.""" from abc import ABCMeta, abstractmethod +from typing import final +from ..utils import check_positive_integer, is_function + +_AGGREGATE_METHODS = {"sum": sum, "mean": lambda x: sum(x) / len(x)} class WeightingInterface(metaclass=ABCMeta): @@ -9,19 +13,92 @@ class WeightingInterface(metaclass=ABCMeta): should inherit from this class. """ - def __init__(self): + def __init__(self, update_every_n_epochs=1, aggregator="sum"): """ Initialization of the :class:`WeightingInterface` class. + + :param int update_every_n_epochs: The number of training epochs between + weight updates. If set to 1, the weights are updated at every epoch. + This parameter is ignored by static weighting schemes. Default is 1. + :param aggregator: The aggregation method. Either: + - 'sum' → torch.sum + - 'mean' → torch.mean + - callable → custom aggregation function + :type aggregator: str | Callable """ + # Check consistency + check_positive_integer(value=update_every_n_epochs, strict=True) + + # Aggregation + if isinstance(aggregator, str): + if aggregator not in _AGGREGATE_METHODS: + raise ValueError( + f"Invalid aggregator '{aggregator}'. Must be one of " + f"{list(_AGGREGATE_METHODS.keys())}." + ) + aggregator = _AGGREGATE_METHODS[aggregator] + + elif not is_function(aggregator): + raise TypeError( + f"Aggregator must be either a string or a callable, " + f"got {type(aggregator).__name__}." + ) + + # Initialization self._solver = None + self.update_every_n_epochs = update_every_n_epochs + self.aggregator_fn = aggregator + self._saved_weights = {} @abstractmethod + def weights_update(self, losses): + """ + Update the weighting scheme based on the given losses. + + This method must be implemented by subclasses. Its role is to update the + values of the weights. The updated weights will then be used by + :meth:`aggregate` to compute the final aggregated loss. + + :param dict losses: The dictionary of losses. + :return: The updated weights. + :rtype: dict + """ + + @final def aggregate(self, losses): """ - Aggregate the losses. + Update the weights (if needed) and aggregate the given losses. + + This method first checks whether the loss weights need to be updated + based on the current epoch and the ``update_every_n_epochs`` setting. + If an update is required, it calls :meth:`weights_update` to refresh the + weights. Afterwards, it aggregates the (weighted) losses into a single + scalar tensor using the configured aggregator function. This method must + not be overridden. :param dict losses: The dictionary of losses. + :return: The aggregated loss tensor. + :rtype: torch.Tensor + """ + # Update weights + if self.solver.trainer.current_epoch % self.update_every_n_epochs == 0: + self._saved_weights = self.weights_update(losses) + + # Aggregate. Using direct indexing instead of .get() ensures that a + # KeyError is raised if the expected condition is missing from the dict. + return self.aggregator_fn( + self._saved_weights[condition] * loss + for condition, loss in losses.items() + ) + + def last_saved_weights(self): + """ + Get the last saved weights. + + :return: The last saved weights. + :rtype: dict """ + return self._saved_weights @property def solver(self): diff --git a/pina/utils.py b/pina/utils.py index ddbd2e8ac..2aafba1f2 100644 --- a/pina/utils.py +++ b/pina/utils.py @@ -240,3 +240,31 @@ def check_positive_integer(value, strict=True): assert ( isinstance(value, int) and value >= 0 ), f"Expected a non-negative integer, got {value}." + + +def in_range(value, range_vals, strict=True): + """ + Check if a value is within a specified range. + + :param int value: The integer value to check. + :param list[int] range_vals: A list of two integers representing the range + limits. The first element specifies the lower bound, and the second + specifies the upper bound. + :param bool strict: If True, the value must be strictly positive. + Default is True. + :return: True if the value satisfies the range condition, False otherwise. + :rtype: bool + """ + # Validate inputs + check_consistency(value, (float, int)) + check_consistency(range_vals, (float, int)) + assert ( + isinstance(range_vals, list) and len(range_vals) == 2 + ), "range_vals must be a list of two integers [lower, upper]" + lower, upper = range_vals + + # Check the range + if strict: + return lower < value < upper + + return lower <= value <= upper diff --git a/tests/test_callback/test_linear_weight_update_callback.py b/tests/test_callback/test_linear_weight_update_callback.py deleted file mode 100644 index c1f4cf357..000000000 --- a/tests/test_callback/test_linear_weight_update_callback.py +++ /dev/null @@ -1,164 +0,0 @@ -import pytest -import math -from pina.solver import PINN -from pina.loss import ScalarWeighting -from pina.trainer import Trainer -from pina.model import FeedForward -from pina.problem.zoo import Poisson2DSquareProblem as Poisson -from pina.callback import LinearWeightUpdate - - -# Define the problem -poisson_problem = Poisson() -poisson_problem.discretise_domain(50, "grid") -cond_name = list(poisson_problem.conditions.keys())[0] - -# Define the model -model = FeedForward( - input_dimensions=len(poisson_problem.input_variables), - output_dimensions=len(poisson_problem.output_variables), - layers=[32, 32], -) - -# Define the weighting schema -weights_dict = {key: 1 for key in poisson_problem.conditions.keys()} -weighting = ScalarWeighting(weights=weights_dict) - -# Define the solver -solver = PINN(problem=poisson_problem, model=model, weighting=weighting) - -# Value used for testing -epochs = 10 - - -@pytest.mark.parametrize("initial_value", [1, 5.5]) -@pytest.mark.parametrize("target_value", [10, 25.5]) -def test_constructor(initial_value, target_value): - LinearWeightUpdate( - target_epoch=epochs, - condition_name=cond_name, - initial_value=initial_value, - target_value=target_value, - ) - - # Target_epoch must be int - with pytest.raises(ValueError): - LinearWeightUpdate( - target_epoch=10.0, - condition_name=cond_name, - initial_value=0, - target_value=1, - ) - - # Condition_name must be str - with pytest.raises(ValueError): - LinearWeightUpdate( - target_epoch=epochs, - condition_name=100, - initial_value=0, - target_value=1, - ) - - # Initial_value must be float or int - with pytest.raises(ValueError): - LinearWeightUpdate( - target_epoch=epochs, - condition_name=cond_name, - initial_value="0", - target_value=1, - ) - - # Target_value must be float or int - with pytest.raises(ValueError): - LinearWeightUpdate( - target_epoch=epochs, - condition_name=cond_name, - initial_value=0, - target_value="1", - ) - - -@pytest.mark.parametrize("initial_value, target_value", [(1, 10), (10, 1)]) -def test_training(initial_value, target_value): - callback = LinearWeightUpdate( - target_epoch=epochs, - condition_name=cond_name, - initial_value=initial_value, - target_value=target_value, - ) - trainer = Trainer( - solver=solver, - callbacks=[callback], - accelerator="cpu", - max_epochs=epochs, - ) - trainer.train() - - # Check that the final weight value matches the target value - final_value = solver.weighting.weights[cond_name] - assert math.isclose(final_value, target_value) - - # Target_epoch must be greater than 0 - with pytest.raises(ValueError): - callback = LinearWeightUpdate( - target_epoch=0, - condition_name=cond_name, - initial_value=0, - target_value=1, - ) - trainer = Trainer( - solver=solver, - callbacks=[callback], - accelerator="cpu", - max_epochs=5, - ) - trainer.train() - - # Target_epoch must be less than or equal to max_epochs - with pytest.raises(ValueError): - callback = LinearWeightUpdate( - target_epoch=epochs, - condition_name=cond_name, - initial_value=0, - target_value=1, - ) - trainer = Trainer( - solver=solver, - callbacks=[callback], - accelerator="cpu", - max_epochs=epochs - 1, - ) - trainer.train() - - # Condition_name must be a problem condition - with pytest.raises(ValueError): - callback = LinearWeightUpdate( - target_epoch=epochs, - condition_name="not_a_condition", - initial_value=0, - target_value=1, - ) - trainer = Trainer( - solver=solver, - callbacks=[callback], - accelerator="cpu", - max_epochs=epochs, - ) - trainer.train() - - # Weighting schema must be ScalarWeighting - with pytest.raises(ValueError): - callback = LinearWeightUpdate( - target_epoch=epochs, - condition_name=cond_name, - initial_value=0, - target_value=1, - ) - unweighted_solver = PINN(problem=poisson_problem, model=model) - trainer = Trainer( - solver=unweighted_solver, - callbacks=[callback], - accelerator="cpu", - max_epochs=epochs, - ) - trainer.train() diff --git a/tests/test_weighting/test_ntk_weighting.py b/tests/test_weighting/test_ntk_weighting.py index 236c4987e..49442b9fb 100644 --- a/tests/test_weighting/test_ntk_weighting.py +++ b/tests/test_weighting/test_ntk_weighting.py @@ -12,22 +12,42 @@ model = FeedForward(len(problem.input_variables), len(problem.output_variables)) +@pytest.mark.parametrize("update_every_n_epochs", [1, 10, 100, 1000]) @pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0]) -def test_constructor(alpha): - NeuralTangentKernelWeighting(alpha=alpha) +def test_constructor(update_every_n_epochs, alpha): + NeuralTangentKernelWeighting( + update_every_n_epochs=update_every_n_epochs, alpha=alpha + ) # Should fail if alpha is not >= 0 with pytest.raises(ValueError): - NeuralTangentKernelWeighting(alpha=-0.1) + NeuralTangentKernelWeighting( + update_every_n_epochs=update_every_n_epochs, alpha=-0.1 + ) # Should fail if alpha is not <= 1 with pytest.raises(ValueError): NeuralTangentKernelWeighting(alpha=1.1) + # Should fail if update_every_n_epochs is not an integer + with pytest.raises(AssertionError): + NeuralTangentKernelWeighting(update_every_n_epochs=1.5) + # Should fail if update_every_n_epochs is not > 0 + with pytest.raises(AssertionError): + NeuralTangentKernelWeighting(update_every_n_epochs=0) + + # Should fail if update_every_n_epochs is not > 0 + with pytest.raises(AssertionError): + NeuralTangentKernelWeighting(update_every_n_epochs=-3) + + +@pytest.mark.parametrize("update_every_n_epochs", [1, 3]) @pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0]) -def test_train_aggregation(alpha): - weighting = NeuralTangentKernelWeighting(alpha=alpha) +def test_train_aggregation(update_every_n_epochs, alpha): + weighting = NeuralTangentKernelWeighting( + update_every_n_epochs=update_every_n_epochs, alpha=alpha + ) solver = PINN(problem=problem, model=model, weighting=weighting) trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") trainer.train() diff --git a/tests/test_weighting/test_scalar_weighting.py b/tests/test_weighting/test_scalar_weighting.py index 54b3293f5..bbf71afde 100644 --- a/tests/test_weighting/test_scalar_weighting.py +++ b/tests/test_weighting/test_scalar_weighting.py @@ -29,20 +29,6 @@ def test_constructor(weights): ScalarWeighting(weights=[1, 2, 3]) -@pytest.mark.parametrize( - "weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))] -) -def test_aggregate(weights): - weighting = ScalarWeighting(weights=weights) - losses = dict( - zip( - condition_names, - [torch.randn(1) for _ in range(len(condition_names))], - ) - ) - weighting.aggregate(losses=losses) - - @pytest.mark.parametrize( "weights", [1, 1.0, dict(zip(condition_names, [1] * len(condition_names)))] ) diff --git a/tests/test_weighting/test_self_adaptive_weighting.py b/tests/test_weighting/test_self_adaptive_weighting.py index b82f54575..066e8855e 100644 --- a/tests/test_weighting/test_self_adaptive_weighting.py +++ b/tests/test_weighting/test_self_adaptive_weighting.py @@ -12,26 +12,28 @@ model = FeedForward(len(problem.input_variables), len(problem.output_variables)) -@pytest.mark.parametrize("k", [10, 100, 1000]) -def test_constructor(k): - SelfAdaptiveWeighting(k=k) +@pytest.mark.parametrize("update_every_n_epochs", [10, 100, 1000]) +def test_constructor(update_every_n_epochs): + SelfAdaptiveWeighting(update_every_n_epochs=update_every_n_epochs) - # Should fail if k is not an integer + # Should fail if update_every_n_epochs is not an integer with pytest.raises(AssertionError): - SelfAdaptiveWeighting(k=1.5) + SelfAdaptiveWeighting(update_every_n_epochs=1.5) - # Should fail if k is not > 0 + # Should fail if update_every_n_epochs is not > 0 with pytest.raises(AssertionError): - SelfAdaptiveWeighting(k=0) + SelfAdaptiveWeighting(update_every_n_epochs=0) - # Should fail if k is not > 0 + # Should fail if update_every_n_epochs is not > 0 with pytest.raises(AssertionError): - SelfAdaptiveWeighting(k=-3) + SelfAdaptiveWeighting(update_every_n_epochs=-3) -@pytest.mark.parametrize("k", [2, 3]) -def test_train_aggregation(k): - weighting = SelfAdaptiveWeighting(k=k) +@pytest.mark.parametrize("update_every_n_epochs", [1, 3]) +def test_train_aggregation(update_every_n_epochs): + weighting = SelfAdaptiveWeighting( + update_every_n_epochs=update_every_n_epochs + ) solver = PINN(problem=problem, model=model, weighting=weighting) trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu") trainer.train() From 9e4b2389369d4c7a3cd003252a6198af018ff6f3 Mon Sep 17 00:00:00 2001 From: giovanni Date: Fri, 5 Sep 2025 10:12:09 +0200 Subject: [PATCH 4/4] add linear weighting --- docs/source/_rst/_code.rst | 4 +- docs/source/_rst/loss/linear_weighting.rst | 9 ++ pina/loss/__init__.py | 2 + pina/loss/linear_weighting.py | 64 +++++++++++++ pina/loss/ntk_weighting.py | 5 +- pina/loss/scalar_weighting.py | 6 +- tests/test_weighting/test_linear_weighting.py | 95 +++++++++++++++++++ 7 files changed, 176 insertions(+), 9 deletions(-) create mode 100644 docs/source/_rst/loss/linear_weighting.rst create mode 100644 pina/loss/linear_weighting.py create mode 100644 tests/test_weighting/test_linear_weighting.py diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 2bb62a4e8..a7242562b 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -253,7 +253,6 @@ Callbacks Optimizer callback R3 Refinment callback Refinment Interface callback - Weighting callback Losses and Weightings --------------------- @@ -267,4 +266,5 @@ Losses and Weightings WeightingInterface ScalarWeighting NeuralTangentKernelWeighting - SelfAdaptiveWeighting \ No newline at end of file + SelfAdaptiveWeighting + LinearWeighting \ No newline at end of file diff --git a/docs/source/_rst/loss/linear_weighting.rst b/docs/source/_rst/loss/linear_weighting.rst new file mode 100644 index 000000000..16e6232d0 --- /dev/null +++ b/docs/source/_rst/loss/linear_weighting.rst @@ -0,0 +1,9 @@ +LinearWeighting +============================= +.. currentmodule:: pina.loss.linear_weighting + +.. automodule:: pina.loss.linear_weighting + +.. autoclass:: LinearWeighting + :members: + :show-inheritance: diff --git a/pina/loss/__init__.py b/pina/loss/__init__.py index fc47e62de..d91cf7ab0 100644 --- a/pina/loss/__init__.py +++ b/pina/loss/__init__.py @@ -8,6 +8,7 @@ "ScalarWeighting", "NeuralTangentKernelWeighting", "SelfAdaptiveWeighting", + "LinearWeighting", ] from .loss_interface import LossInterface @@ -17,3 +18,4 @@ from .scalar_weighting import ScalarWeighting from .ntk_weighting import NeuralTangentKernelWeighting from .self_adaptive_weighting import SelfAdaptiveWeighting +from .linear_weighting import LinearWeighting diff --git a/pina/loss/linear_weighting.py b/pina/loss/linear_weighting.py new file mode 100644 index 000000000..9049b52fa --- /dev/null +++ b/pina/loss/linear_weighting.py @@ -0,0 +1,64 @@ +"""Module for the LinearWeighting class.""" + +from ..loss import WeightingInterface +from ..utils import check_consistency, check_positive_integer + + +class LinearWeighting(WeightingInterface): + """ + A weighting scheme that linearly scales weights from initial values to final + values over a specified number of epochs. + """ + + def __init__(self, initial_weights, final_weights, target_epoch): + """ + :param dict initial_weights: The weights to be assigned to each loss + term at the beginning of training. The keys are the conditions and + the values are the corresponding weights. If a condition is not + present in the dictionary, the default value (1) is used. + :param dict final_weights: The weights to be assigned to each loss term + once the target epoch is reached. The keys are the conditions and + the values are the corresponding weights. If a condition is not + present in the dictionary, the default value (1) is used. + :param int target_epoch: The epoch at which the weights reach their + final values. + :raises ValueError: If the keys of the two dictionaries are not + consistent. + """ + super().__init__(update_every_n_epochs=1, aggregator="sum") + + # Check consistency + check_consistency([initial_weights, final_weights], dict) + check_positive_integer(value=target_epoch, strict=True) + + # Check that the keys of the two dictionaries are the same + if initial_weights.keys() != final_weights.keys(): + raise ValueError( + "The keys of the initial_weights and final_weights " + "dictionaries must be the same." + ) + + # Initialization + self.initial_weights = initial_weights + self.final_weights = final_weights + self.target_epoch = target_epoch + + def weights_update(self, losses): + """ + Update the weighting scheme based on the given losses. + + :param dict losses: The dictionary of losses. + :return: The updated weights. + :rtype: dict + """ + return { + condition: self.last_saved_weights().get( + condition, self.initial_weights.get(condition, 1) + ) + + ( + self.final_weights.get(condition, 1) + - self.initial_weights.get(condition, 1) + ) + / (self.target_epoch) + for condition in losses.keys() + } diff --git a/pina/loss/ntk_weighting.py b/pina/loss/ntk_weighting.py index b88812615..fe671157a 100644 --- a/pina/loss/ntk_weighting.py +++ b/pina/loss/ntk_weighting.py @@ -61,11 +61,10 @@ def weights_update(self, losses): losses_norm[condition] = grads.norm() # Update the weights - self.weights = { - condition: self.alpha * self.weights.get(condition, 1) + return { + condition: self.alpha * self.last_saved_weights().get(condition, 1) + (1 - self.alpha) * losses_norm[condition] / sum(losses_norm.values()) for condition in losses } - return self.weights diff --git a/pina/loss/scalar_weighting.py b/pina/loss/scalar_weighting.py index d770c8961..692c4937b 100644 --- a/pina/loss/scalar_weighting.py +++ b/pina/loss/scalar_weighting.py @@ -17,7 +17,7 @@ def __init__(self, weights): If a single scalar value is provided, it is assigned to all loss terms. If a dictionary is provided, the keys are the conditions and the values are the weights. If a condition is not present in the - dictionary, the default value is used. + dictionary, the default value (1) is used. :type weights: float | int | dict """ super().__init__(update_every_n_epochs=1, aggregator="sum") @@ -29,11 +29,9 @@ def __init__(self, weights): if isinstance(weights, dict): self.values = weights self.default_value_weights = 1 - elif isinstance(weights, (float, int)): + else: self.values = {} self.default_value_weights = weights - else: - raise ValueError def weights_update(self, losses): """ diff --git a/tests/test_weighting/test_linear_weighting.py b/tests/test_weighting/test_linear_weighting.py new file mode 100644 index 000000000..a11952073 --- /dev/null +++ b/tests/test_weighting/test_linear_weighting.py @@ -0,0 +1,95 @@ +import math +import pytest +from pina import Trainer +from pina.solver import PINN +from pina.model import FeedForward +from pina.loss import LinearWeighting +from pina.problem.zoo import Poisson2DSquareProblem + + +# Initialize problem and model +problem = Poisson2DSquareProblem() +problem.discretise_domain(10) +model = FeedForward(len(problem.input_variables), len(problem.output_variables)) + +# Weights for testing +init_weight_1 = {cond: 3 for cond in problem.conditions.keys()} +init_weight_2 = {cond: 4 for cond in problem.conditions.keys()} +final_weight_1 = {cond: 1 for cond in problem.conditions.keys()} +final_weight_2 = {cond: 5 for cond in problem.conditions.keys()} + + +@pytest.mark.parametrize("initial_weights", [init_weight_1, init_weight_2]) +@pytest.mark.parametrize("final_weights", [final_weight_1, final_weight_2]) +@pytest.mark.parametrize("target_epoch", [5, 10]) +def test_constructor(initial_weights, final_weights, target_epoch): + LinearWeighting( + initial_weights=initial_weights, + final_weights=final_weights, + target_epoch=target_epoch, + ) + + # Should fail if initial_weights is not a dictionary + with pytest.raises(ValueError): + LinearWeighting( + initial_weights=[1, 1, 1], + final_weights=final_weights, + target_epoch=target_epoch, + ) + + # Should fail if final_weights is not a dictionary + with pytest.raises(ValueError): + LinearWeighting( + initial_weights=initial_weights, + final_weights=[1, 1, 1], + target_epoch=target_epoch, + ) + + # Should fail if target_epoch is not an integer + with pytest.raises(AssertionError): + LinearWeighting( + initial_weights=initial_weights, + final_weights=final_weights, + target_epoch=1.5, + ) + + # Should fail if target_epoch is not positive + with pytest.raises(AssertionError): + LinearWeighting( + initial_weights=initial_weights, + final_weights=final_weights, + target_epoch=0, + ) + + # Should fail if dictionary keys do not match + with pytest.raises(ValueError): + LinearWeighting( + initial_weights={list(initial_weights.keys())[0]: 1}, + final_weights=final_weights, + target_epoch=target_epoch, + ) + + +@pytest.mark.parametrize("initial_weights", [init_weight_1, init_weight_2]) +@pytest.mark.parametrize("final_weights", [final_weight_1, final_weight_2]) +@pytest.mark.parametrize("target_epoch", [5, 10]) +def test_train_aggregation(initial_weights, final_weights, target_epoch): + weighting = LinearWeighting( + initial_weights=initial_weights, + final_weights=final_weights, + target_epoch=target_epoch, + ) + solver = PINN(problem=problem, model=model, weighting=weighting) + trainer = Trainer(solver=solver, max_epochs=target_epoch, accelerator="cpu") + trainer.train() + + # Check that weights are updated correctly + assert all( + math.isclose( + weighting.last_saved_weights()[cond], + final_weights[cond], + rel_tol=1e-5, + abs_tol=1e-8, + ) + for cond in final_weights.keys() + )