Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions pina/solver/physics_informed_solver/pinn_interface.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
"""Module for the Physics-Informed Neural Network Interface."""

from abc import ABCMeta, abstractmethod
import warnings
import torch

from ...utils import custom_warning_format
from ..supervised_solver import SupervisedSolverInterface
from ...condition import (
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition,
)

# set the warning for torch >= 2.8 compile
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=UserWarning)


class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
"""
Expand Down Expand Up @@ -46,6 +52,36 @@ def __init__(self, **kwargs):
# current condition name
self.__metric = None

def setup(self, stage):
"""
Setup method executed at the beginning of training and testing.

This method compiles the model only if the installed torch version
is earlier than 2.8, due to known issues with later versions
(see https://github.com/mathLab/PINA/issues/621).

.. warning::
For torch >= 2.8, compilation is disabled. Forcing compilation
on these versions may cause runtime errors or unstable behavior.

:param str stage: The current stage of the training process
(e.g., ``fit``, ``validate``, ``test``, ``predict``).
:return: The result of the parent class ``setup`` method.
:rtype: Any
"""
# Override the compilation, compiling only for torch < 2.8, see
# related issue at https://github.com/mathLab/PINA/issues/621
if torch.__version__ < "2.8":
self.trainer.compile = True
else:
self.trainer.compile = False
warnings.warn(
"Compilation is disabled for torch >= 2.8. "
"Forcing compilation may cause runtime errors or instability.",
UserWarning,
)
return super().setup(stage)

def optimization_cycle(self, batch, loss_residuals=None):
"""
The optimization cycle for the PINN solver.
Expand Down
5 changes: 4 additions & 1 deletion pina/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,10 @@ def setup(self, stage):
compile the model if the :class:`~pina.trainer.Trainer`
``compile`` is ``True``.


:param str stage: The current stage of the training process
(e.g., ``fit``, ``validate``, ``test``, ``predict``).
:return: The result of the parent class ``setup`` method.
:rtype: Any
"""
if stage == "fit" and self.trainer.compile:
self._setup_compile()
Expand Down
Loading