From 027fd6db7087669753c1d6293dbdc8e118586d83 Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Tue, 16 Apr 2024 14:23:26 +0200 Subject: [PATCH] modify dataloading in solvers --- pina/solvers/garom.py | 9 +-------- pina/solvers/pinns/basepinn.py | 15 --------------- pina/solvers/solver.py | 15 +++++++++++++++ pina/solvers/supervised.py | 10 ++-------- 4 files changed, 18 insertions(+), 31 deletions(-) diff --git a/pina/solvers/garom.py b/pina/solvers/garom.py index 08856704f..d6cd6246e 100644 --- a/pina/solvers/garom.py +++ b/pina/solvers/garom.py @@ -253,18 +253,11 @@ def training_step(self, batch, batch_idx): :rtype: LabelTensor """ - dataloader = self.trainer.train_dataloader condition_idx = batch["condition"] for condition_id in range(condition_idx.min(), condition_idx.max() + 1): - if sys.version_info >= (3, 8): - condition_name = dataloader.condition_names[condition_id] - else: - condition_name = dataloader.loaders.condition_names[ - condition_id - ] - + condition_name = self._dataloader.condition_names[condition_id] condition = self.problem.conditions[condition_name] pts = batch["pts"].detach() out = batch["output"] diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py index f1b59d977..1b623f343 100644 --- a/pina/solvers/pinns/basepinn.py +++ b/pina/solvers/pinns/basepinn.py @@ -76,21 +76,6 @@ def __init__( self.__logged_res_losses = [] - def on_train_start(self): - """ - On training epoch start this function is call to do global checks for - the PINN training. - """ - - # 1. Check the verison for dataloader - dataloader = self.trainer.train_dataloader - if sys.version_info < (3, 8): - dataloader = dataloader.loaders - self._dataloader = dataloader - - return super().on_train_start() - - def training_step(self, batch, _): """ PINN solver training step. diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py index 324a023dd..729a9d485 100644 --- a/pina/solvers/solver.py +++ b/pina/solvers/solver.py @@ -6,6 +6,7 @@ from ..utils import check_consistency from ..problem import AbstractProblem import torch +import sys class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta): @@ -141,6 +142,20 @@ def problem(self): """ The problem formulation.""" return self._pina_problem + + def on_train_start(self): + """ + On training epoch start this function is call to do global checks for + the different solvers. + """ + + # 1. Check the verison for dataloader + dataloader = self.trainer.train_dataloader + if sys.version_info < (3, 8): + dataloader = dataloader.loaders + self._dataloader = dataloader + + return super().on_train_start() # @model.setter # def model(self, new_model): diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py index c6a8a35bf..75d95fb8b 100644 --- a/pina/solvers/supervised.py +++ b/pina/solvers/supervised.py @@ -96,18 +96,12 @@ def training_step(self, batch, batch_idx): :return: The sum of the loss functions. :rtype: LabelTensor """ - - dataloader = self.trainer.train_dataloader + condition_idx = batch["condition"] for condition_id in range(condition_idx.min(), condition_idx.max() + 1): - if sys.version_info >= (3, 8): - condition_name = dataloader.condition_names[condition_id] - else: - condition_name = dataloader.loaders.condition_names[ - condition_id - ] + condition_name = self._dataloader.condition_names[condition_id] condition = self.problem.conditions[condition_name] pts = batch["pts"] out = batch["output"]