diff --git a/pina/trainer.py b/pina/trainer.py index 0acecaaa9..90779a6e9 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -1,5 +1,6 @@ """ Trainer module. """ +import torch import pytorch_lightning from .utils import check_consistency from .dataset import SamplePointDataset, SamplePointLoader, DataPointDataset @@ -63,6 +64,12 @@ def _create_or_update_loader(self): self._loader = SamplePointLoader( dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True ) + pb = self._model.problem + if hasattr(pb, "unknown_parameters"): + for key in pb.unknown_parameters: + pb.unknown_parameters[key] = torch.nn.Parameter(pb.unknown_parameters[key].data.to(device)) + + def train(self, **kwargs): """