From f1df50809a3504e53e4c57d5d24a863385559505 Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Mon, 8 Sep 2025 10:52:08 +0200 Subject: [PATCH] fix compile issue --- .../physics_informed_solver/pinn_interface.py | 36 +++++++++++++++++++ pina/solver/solver.py | 5 ++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/pina/solver/physics_informed_solver/pinn_interface.py b/pina/solver/physics_informed_solver/pinn_interface.py index 976f6ce6b..535e7ae11 100644 --- a/pina/solver/physics_informed_solver/pinn_interface.py +++ b/pina/solver/physics_informed_solver/pinn_interface.py @@ -1,8 +1,10 @@ """Module for the Physics-Informed Neural Network Interface.""" from abc import ABCMeta, abstractmethod +import warnings import torch +from ...utils import custom_warning_format from ..supervised_solver import SupervisedSolverInterface from ...condition import ( InputTargetCondition, @@ -10,6 +12,10 @@ DomainEquationCondition, ) +# set the warning for torch >= 2.8 compile +warnings.formatwarning = custom_warning_format +warnings.filterwarnings("always", category=UserWarning) + class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta): """ @@ -46,6 +52,36 @@ def __init__(self, **kwargs): # current condition name self.__metric = None + def setup(self, stage): + """ + Setup method executed at the beginning of training and testing. + + This method compiles the model only if the installed torch version + is earlier than 2.8, due to known issues with later versions + (see https://github.com/mathLab/PINA/issues/621). + + .. warning:: + For torch >= 2.8, compilation is disabled. Forcing compilation + on these versions may cause runtime errors or unstable behavior. + + :param str stage: The current stage of the training process + (e.g., ``fit``, ``validate``, ``test``, ``predict``). + :return: The result of the parent class ``setup`` method. + :rtype: Any + """ + # Override the compilation, compiling only for torch < 2.8, see + # related issue at https://github.com/mathLab/PINA/issues/621 + if torch.__version__ < "2.8": + self.trainer.compile = True + else: + self.trainer.compile = False + warnings.warn( + "Compilation is disabled for torch >= 2.8. " + "Forcing compilation may cause runtime errors or instability.", + UserWarning, + ) + return super().setup(stage) + def optimization_cycle(self, batch, loss_residuals=None): """ The optimization cycle for the PINN solver. diff --git a/pina/solver/solver.py b/pina/solver/solver.py index f6bcc2ac2..f3ff40579 100644 --- a/pina/solver/solver.py +++ b/pina/solver/solver.py @@ -169,7 +169,10 @@ def setup(self, stage): compile the model if the :class:`~pina.trainer.Trainer` ``compile`` is ``True``. - + :param str stage: The current stage of the training process + (e.g., ``fit``, ``validate``, ``test``, ``predict``). + :return: The result of the parent class ``setup`` method. + :rtype: Any """ if stage == "fit" and self.trainer.compile: self._setup_compile()