From d9231b61a81f6c117a6ddea5d44b443ea4f08e68 Mon Sep 17 00:00:00 2001 From: Anna Ivagnes Date: Mon, 8 Apr 2024 18:04:54 +0200 Subject: [PATCH] fix GPU training in inverse problem --- pina/trainer.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pina/trainer.py b/pina/trainer.py index 0acecaaa9..90779a6e9 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -1,5 +1,6 @@ """ Trainer module. """ +import torch import pytorch_lightning from .utils import check_consistency from .dataset import SamplePointDataset, SamplePointLoader, DataPointDataset @@ -63,6 +64,12 @@ def _create_or_update_loader(self): self._loader = SamplePointLoader( dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True ) + pb = self._model.problem + if hasattr(pb, "unknown_parameters"): + for key in pb.unknown_parameters: + pb.unknown_parameters[key] = torch.nn.Parameter(pb.unknown_parameters[key].data.to(device)) + + def train(self, **kwargs): """