From 1d0ea1ccdac774499ec57897f7a1cea2625dcea3 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 17 Feb 2025 12:11:43 +0100 Subject: [PATCH 1/2] Fix bug in Collector with Graph data --- pina/collector.py | 6 ++++ pina/data/data_module.py | 39 +++++++++++++++++----- pina/data/dataset.py | 72 +++++++++++++++++++++++++++++++--------- pina/graph.py | 15 +++++---- pina/trainer.py | 2 -- pina/utils.py | 3 +- 6 files changed, 105 insertions(+), 32 deletions(-) diff --git a/pina/collector.py b/pina/collector.py index 381e9499c..93ea18254 100644 --- a/pina/collector.py +++ b/pina/collector.py @@ -1,3 +1,7 @@ +""" +# TODO +""" +from .graph import Graph from .utils import check_consistency @@ -52,6 +56,8 @@ def store_fixed_data(self): # get data keys = condition.__slots__ values = [getattr(condition, name) for name in keys] + values = [value.data if isinstance( + value, Graph) else value for value in values] self.data_collections[condition_name] = dict(zip(keys, values)) # condition now is ready self._is_conditions_ready[condition_name] = True diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 20b3c1c29..7c3def654 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -2,10 +2,10 @@ import warnings from lightning.pytorch import LightningDataModule import torch -from ..label_tensor import LabelTensor -from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, \ - RandomSampler +from torch_geometric.data import Data, Batch +from torch.utils.data import DataLoader, SequentialSampler, RandomSampler from torch.utils.data.distributed import DistributedSampler +from ..label_tensor import LabelTensor from .dataset import PinaDatasetFactory from ..collector import Collector @@ -86,9 +86,8 @@ def _collate_standard_dataloader(self, batch): single_cond_dict[arg] = LabelTensor.stack(data_list) elif isinstance(data_list[0], torch.Tensor): single_cond_dict[arg] = torch.stack(data_list) - else: - raise NotImplementedError( - f"Data type {type(data_list[0])} not supported") + elif isinstance(data_list[0], Data): + single_cond_dict[arg] = Batch.from_data_list(data_list) batch_dict[condition_name] = single_cond_dict return batch_dict @@ -125,7 +124,7 @@ def __init__(self, batch_size=None, shuffle=True, repeat=False, - automatic_batching=False, + automatic_batching=None, num_workers=0, pin_memory=False, ): @@ -158,7 +157,6 @@ def __init__(self, logging.debug('Start initialization of Pina DataModule') logging.info('Start initialization of Pina DataModule') super().__init__() - self.automatic_batching = automatic_batching self.batch_size = batch_size self.shuffle = shuffle self.repeat = repeat @@ -192,6 +190,10 @@ def __init__(self, collector = Collector(problem) collector.store_fixed_data() collector.store_sample_domains() + + self.automatic_batching = self._set_automatic_batching_option( + collector, automatic_batching) + if batch_size is None and num_workers != 0: warnings.warn( "Setting num_workers when batch_size is None has no effect on " @@ -393,6 +395,27 @@ def _check_slit_sizes(train_size, test_size, val_size, predict_size): if abs(train_size + test_size + val_size + predict_size - 1) > 1e-6: raise ValueError("The sum of the splits must be 1") + @staticmethod + def _set_automatic_batching_option(collector, automatic_batching): + """ + Determines whether automatic batching should be enabled. + + If all 'input_points' in the collector's data collections are + tensors (torch.Tensor or LabelTensor), it respects the provided + `automatic_batching` value; otherwise, mainly in the Graph scenario, + it forces automatic batching on. + + :param Collector collector: Collector object with contains all data + retrieved from input conditions + :param bool automatic_batching : If the user wants to enable automatic + batching or not + """ + if all(isinstance(v['input_points'], (torch.Tensor, LabelTensor)) + for v in collector.data_collections.values()): + return automatic_batching if automatic_batching is not None \ + else False + return True + @property def input_points(self): """ diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 02450400b..3c4a1b9df 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -1,10 +1,12 @@ """ This module provide basic data management functionalities """ +import functools import torch from torch.utils.data import Dataset from abc import abstractmethod from torch_geometric.data import Batch +from pina import LabelTensor class PinaDatasetFactory: @@ -107,10 +109,25 @@ class PinaGraphDataset(PinaDataset): def __init__(self, conditions_dict, max_conditions_lengths, automatic_batching): super().__init__(conditions_dict, max_conditions_lengths) - if automatic_batching: - self._getitem_func = self._getitem_int - else: - self._getitem_func = self._getitem_list + self.in_labels = {} + self.out_labels = None + ex_data = conditions_dict[list(conditions_dict.keys())[ + 0]]['input_points'][0] + for name, attr in ex_data.items(): + if isinstance(attr, LabelTensor): + self.in_labels[name] = attr.stored_labels + ex_data = conditions_dict[list(conditions_dict.keys())[ + 0]]['output_points'][0] + if isinstance(ex_data, LabelTensor): + self.out_labels = ex_data.labels + + self._create_graph_batch_from_list = self._labelise_batch( + self._base_create_graph_batch_from_list) if self.in_labels \ + else self._base_create_graph_batch_from_list + + self._create_output_batch = self._labelise_tensor( + self._base_create_output_batch) if self.out_labels is not None \ + else self._base_create_output_batch def fetch_from_idx_list(self, idx): to_return_dict = {} @@ -119,18 +136,22 @@ def fetch_from_idx_list(self, idx): condition_len = self.conditions_length[condition] if self.length > condition_len: cond_idx = [idx % condition_len for idx in cond_idx] - to_return_dict[condition] = {k: Batch.from_data_list([ - v[i] for i in cond_idx]) - if isinstance(v, list) - else v[ - cond_idx].reshape( - -1, *v[cond_idx].shape[2:]) - for k, v in data.items() - } + to_return_dict[condition] = { + k: self._create_graph_batch_from_list(v, cond_idx) + if isinstance(v, list) + else self._create_output_batch(v, cond_idx) + for k, v in data.items() + } + return to_return_dict - def _getitem_list(self, idx): - return idx + def _base_create_graph_batch_from_list(self, data, idx): + batch = Batch.from_data_list([data[i] for i in idx]) + return batch + + def _base_create_output_batch(self, data, idx): + out = data[idx].reshape(-1, *data[idx].shape[2:]) + return out def _getitem_int(self, idx): return { @@ -143,4 +164,25 @@ def get_all_data(self): return self.fetch_from_idx_list(index) def __getitem__(self, idx): - return self._getitem_func(idx) + return self._getitem_int(idx) if isinstance(idx, int) else \ + self.fetch_from_idx_list(idx=idx) + + def _labelise_batch(self, func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + batch = func(*args, **kwargs) + for k, v in self.in_labels.items(): + tmp = batch[k] + tmp.labels = v + batch[k] = tmp + return batch + return wrapper + + def _labelise_tensor(self, func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + out = func(*args, **kwargs) + if isinstance(out, LabelTensor): + out.labels = self.out_labels + return out + return wrapper diff --git a/pina/graph.py b/pina/graph.py index 959bd9cc0..cd89868ed 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -108,16 +108,14 @@ def __init__( x) # Perform the graph construction - self._build_graph_list(x, pos, edge_index, edge_attr, additional_params) + self._build_graph_list( + x, pos, edge_index, edge_attr, additional_params) def _build_graph_list(self, x, pos, edge_index, edge_attr, additional_params): for i, (x_, pos_, edge_index_) in enumerate(zip(x, pos, edge_index)): - if isinstance(x_, LabelTensor): - x_ = x_.tensor add_params_local = {k: v[i] for k, v in additional_params.items()} if edge_attr is not None: - self.data.append(Data(x=x_, pos=pos_, edge_index=edge_index_, edge_attr=edge_attr[i], **add_params_local)) @@ -165,7 +163,8 @@ def _check_input_consistency(x, pos, edge_index=None): # If edge_index is a 3D tensor, we split it into a list of 2D tensors if edge_index is not None: if isinstance(edge_index, torch.Tensor) and edge_index.ndim == 3: - edge_index = [edge_index[i] for i in range(edge_index.shape[0])] + edge_index = [edge_index[i] + for i in range(edge_index.shape[0])] elif not (isinstance(edge_index, list) and all( t.ndim == 2 for t in edge_index)) and not ( isinstance(edge_index, @@ -219,7 +218,7 @@ def _check_and_build_edge_attr(self, edge_attr, build_edge_attr, data_len, if isinstance(edge_attr, list): if len(edge_attr) != data_len: raise TypeError("edge_attr must have the same length as x " - "and pos.") + "and pos.") return [edge_attr] * data_len if build_edge_attr: @@ -258,6 +257,8 @@ def _radius_graph(points, r): """ dist = torch.cdist(points, points, p=2) edge_index = torch.nonzero(dist <= r, as_tuple=False).t() + if isinstance(edge_index, LabelTensor): + edge_index = edge_index.tensor return edge_index @@ -293,4 +294,6 @@ def _knn_graph(points, k): row = torch.arange(points.size(0)).repeat_interleave(k) col = knn_indices.flatten() edge_index = torch.stack([row, col], dim=0) + if isinstance(edge_index, LabelTensor): + edge_index = edge_index.tensor return edge_index diff --git a/pina/trainer.py b/pina/trainer.py index 0d15e7699..4e6282078 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -105,8 +105,6 @@ def __init__(self, # checking compilation and automatic batching if compile is None or sys.platform == "win32": compile = False - if automatic_batching is None: - automatic_batching = False # set attributes self.compile = compile diff --git a/pina/utils.py b/pina/utils.py index e633cceaf..3bd77e6b6 100644 --- a/pina/utils.py +++ b/pina/utils.py @@ -48,7 +48,8 @@ def labelize_forward(forward, input_variables, output_variables): :type output_variables: list[str] | tuple[str] """ def wrapper(x): - x = x.extract(input_variables) + if isinstance(x, LabelTensor): + x = x.extract(input_variables) output = forward(x) # keep it like this, directly using LabelTensor(...) raises errors # when compiling the code From 37d61ea7e3bec223fd099e104ef7a7ad992f0778 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 18 Feb 2025 12:04:42 +0100 Subject: [PATCH 2/2] Add comments in DataModule class and bug fix in collate --- pina/data/data_module.py | 101 ++++++++++++++-------------- pina/data/dataset.py | 58 ++++++++++++---- pina/graph.py | 3 +- pina/trainer.py | 4 +- pina/utils.py | 3 +- tests/test_data/test_datamodule.py | 102 +++++++++++++++++++++++++---- 6 files changed, 193 insertions(+), 78 deletions(-) diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 7c3def654..9ecfaa5ad 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -2,11 +2,11 @@ import warnings from lightning.pytorch import LightningDataModule import torch -from torch_geometric.data import Data, Batch +from torch_geometric.data import Data from torch.utils.data import DataLoader, SequentialSampler, RandomSampler from torch.utils.data.distributed import DistributedSampler from ..label_tensor import LabelTensor -from .dataset import PinaDatasetFactory +from .dataset import PinaDatasetFactory, PinaTensorDataset from ..collector import Collector @@ -61,6 +61,10 @@ def __init__(self, max_conditions_lengths, dataset=None): max_conditions_lengths is None else ( self._collate_standard_dataloader) self.dataset = dataset + if isinstance(self.dataset, PinaTensorDataset): + self._collate = self._collate_tensor_dataset + else: + self._collate = self._collate_graph_dataset def _collate_custom_dataloader(self, batch): return self.dataset.fetch_from_idx_list(batch) @@ -73,7 +77,6 @@ def _collate_standard_dataloader(self, batch): if isinstance(batch, dict): return batch conditions_names = batch[0].keys() - # Condition names for condition_name in conditions_names: single_cond_dict = {} @@ -82,15 +85,28 @@ def _collate_standard_dataloader(self, batch): data_list = [batch[idx][condition_name][arg] for idx in range( min(len(batch), self.max_conditions_lengths[condition_name]))] - if isinstance(data_list[0], LabelTensor): - single_cond_dict[arg] = LabelTensor.stack(data_list) - elif isinstance(data_list[0], torch.Tensor): - single_cond_dict[arg] = torch.stack(data_list) - elif isinstance(data_list[0], Data): - single_cond_dict[arg] = Batch.from_data_list(data_list) + single_cond_dict[arg] = self._collate(data_list) + batch_dict[condition_name] = single_cond_dict return batch_dict + @staticmethod + def _collate_tensor_dataset(data_list): + if isinstance(data_list[0], LabelTensor): + return LabelTensor.stack(data_list) + if isinstance(data_list[0], torch.Tensor): + return torch.stack(data_list) + raise RuntimeError("Data must be Tensors or LabelTensor ") + + def _collate_graph_dataset(self, data_list): + if isinstance(data_list[0], LabelTensor): + return LabelTensor.cat(data_list) + if isinstance(data_list[0], torch.Tensor): + return torch.cat(data_list) + if isinstance(data_list[0], Data): + return self.dataset.create_graph_batch(data_list) + raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data") + def __call__(self, batch): return self.callable_function(batch) @@ -157,14 +173,35 @@ def __init__(self, logging.debug('Start initialization of Pina DataModule') logging.info('Start initialization of Pina DataModule') super().__init__() + + # Store fixed attributes self.batch_size = batch_size self.shuffle = shuffle self.repeat = repeat + self.automatic_batching = automatic_batching + if batch_size is None and num_workers != 0: + warnings.warn( + "Setting num_workers when batch_size is None has no effect on " + "the DataLoading process.") + self.num_workers = 0 + else: + self.num_workers = num_workers + if batch_size is None and pin_memory: + warnings.warn("Setting pin_memory to True has no effect when " + "batch_size is None.") + self.pin_memory = False + else: + self.pin_memory = pin_memory + + # Collect data + collector = Collector(problem) + collector.store_fixed_data() + collector.store_sample_domains() # Check if the splits are correct self._check_slit_sizes(train_size, test_size, val_size, predict_size) - # Begin Data splitting + # Split input data into subsets splits_dict = {} if train_size > 0: splits_dict['train'] = train_size @@ -186,23 +223,6 @@ def __init__(self, self.predict_dataset = None else: self.predict_dataloader = super().predict_dataloader - - collector = Collector(problem) - collector.store_fixed_data() - collector.store_sample_domains() - - self.automatic_batching = self._set_automatic_batching_option( - collector, automatic_batching) - - if batch_size is None and num_workers != 0: - warnings.warn( - "Setting num_workers when batch_size is None has no effect on " - "the DataLoading process.") - if batch_size is None and pin_memory: - warnings.warn("Setting pin_memory to True has no effect when " - "batch_size is None.") - self.num_workers = num_workers - self.pin_memory = pin_memory self.collector_splits = self._create_splits(collector, splits_dict) self.transfer_batch_to_device = self._transfer_batch_to_device @@ -318,10 +338,10 @@ def _create_dataloader(self, split, dataset): if self.batch_size is not None: sampler = PinaSampler(dataset, shuffle) if self.automatic_batching: - collate = Collator(self.find_max_conditions_lengths(split)) - + collate = Collator(self.find_max_conditions_lengths(split), + dataset=dataset) else: - collate = Collator(None, dataset) + collate = Collator(None, dataset=dataset) return DataLoader(dataset, self.batch_size, collate_fn=collate, sampler=sampler, num_workers=self.num_workers) @@ -395,27 +415,6 @@ def _check_slit_sizes(train_size, test_size, val_size, predict_size): if abs(train_size + test_size + val_size + predict_size - 1) > 1e-6: raise ValueError("The sum of the splits must be 1") - @staticmethod - def _set_automatic_batching_option(collector, automatic_batching): - """ - Determines whether automatic batching should be enabled. - - If all 'input_points' in the collector's data collections are - tensors (torch.Tensor or LabelTensor), it respects the provided - `automatic_batching` value; otherwise, mainly in the Graph scenario, - it forces automatic batching on. - - :param Collector collector: Collector object with contains all data - retrieved from input conditions - :param bool automatic_batching : If the user wants to enable automatic - batching or not - """ - if all(isinstance(v['input_points'], (torch.Tensor, LabelTensor)) - for v in collector.data_collections.values()): - return automatic_batching if automatic_batching is not None \ - else False - return True - @property def input_points(self): """ diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 3c4a1b9df..2fecb9348 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -5,7 +5,7 @@ import torch from torch.utils.data import Dataset from abc import abstractmethod -from torch_geometric.data import Batch +from torch_geometric.data import Batch, Data from pina import LabelTensor @@ -64,7 +64,7 @@ def __init__(self, conditions_dict, max_conditions_lengths, if automatic_batching: self._getitem_func = self._getitem_int else: - self._getitem_func = self._getitem_list + self._getitem_func = self._getitem_dummy def _getitem_int(self, idx): return { @@ -84,7 +84,7 @@ def fetch_from_idx_list(self, idx): return to_return_dict @staticmethod - def _getitem_list(idx): + def _getitem_dummy(idx): return idx def get_all_data(self): @@ -104,6 +104,27 @@ def input_points(self): } +class PinaBatch(Batch): + """ + Add extract function to torch_geometric Batch object + """ + def __init__(self): + + super().__init__(self) + + def extract(self, labels): + """ + Perform extraction of labels on node features (x) + + :param labels: Labels to extract + :type labels: list[str] | tuple[str] | str + :return: Batch object with extraction performed on x + :rtype: PinaBatch + """ + self.x = self.x.extract(labels) + return self + + class PinaGraphDataset(PinaDataset): def __init__(self, conditions_dict, max_conditions_lengths, @@ -111,6 +132,11 @@ def __init__(self, conditions_dict, max_conditions_lengths, super().__init__(conditions_dict, max_conditions_lengths) self.in_labels = {} self.out_labels = None + if automatic_batching: + self._getitem_func = self._getitem_int + else: + self._getitem_func = self._getitem_dummy + ex_data = conditions_dict[list(conditions_dict.keys())[ 0]]['input_points'][0] for name, attr in ex_data.items(): @@ -137,22 +163,25 @@ def fetch_from_idx_list(self, idx): if self.length > condition_len: cond_idx = [idx % condition_len for idx in cond_idx] to_return_dict[condition] = { - k: self._create_graph_batch_from_list(v, cond_idx) + k: self._create_graph_batch_from_list([v[i] for i in idx]) if isinstance(v, list) - else self._create_output_batch(v, cond_idx) + else self._create_output_batch(v[idx]) for k, v in data.items() } return to_return_dict - def _base_create_graph_batch_from_list(self, data, idx): - batch = Batch.from_data_list([data[i] for i in idx]) + def _base_create_graph_batch_from_list(self, data): + batch = PinaBatch.from_data_list(data) return batch - def _base_create_output_batch(self, data, idx): - out = data[idx].reshape(-1, *data[idx].shape[2:]) + def _base_create_output_batch(self, data): + out = data.reshape(-1, *data.shape[2:]) return out + def _getitem_dummy(self, idx): + return idx + def _getitem_int(self, idx): return { k: {k_data: v[k_data][idx % len(v['input_points'])] for k_data @@ -164,8 +193,7 @@ def get_all_data(self): return self.fetch_from_idx_list(index) def __getitem__(self, idx): - return self._getitem_int(idx) if isinstance(idx, int) else \ - self.fetch_from_idx_list(idx=idx) + return self._getitem_func(idx) def _labelise_batch(self, func): @functools.wraps(func) @@ -186,3 +214,11 @@ def wrapper(*args, **kwargs): out.labels = self.out_labels return out return wrapper + + def create_graph_batch(self, data): + """ + # TODO + """ + if isinstance(data[0], Data): + return self._create_graph_batch_from_list(data) + return self._create_output_batch(data) diff --git a/pina/graph.py b/pina/graph.py index cd89868ed..ca92ab435 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -125,7 +125,8 @@ def _build_graph_list(self, x, pos, edge_index, edge_attr, @staticmethod def _build_edge_attr(x, pos, edge_index): - distance = torch.abs(pos[edge_index[0]] - pos[edge_index[1]]) + distance = torch.abs(pos[edge_index[0]] - + pos[edge_index[1]]).as_subclass(torch.Tensor) return distance @staticmethod diff --git a/pina/trainer.py b/pina/trainer.py index 4e6282078..eb8639e16 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -106,6 +106,8 @@ def __init__(self, if compile is None or sys.platform == "win32": compile = False + self.automatic_batching = automatic_batching if automatic_batching \ + is not None else False # set attributes self.compile = compile self.solver = solver @@ -113,7 +115,7 @@ def __init__(self, self._move_to_device() self.data_module = None self._create_datamodule(train_size, test_size, val_size, predict_size, - batch_size, automatic_batching, pin_memory, + batch_size, automatic_batching, pin_memory, num_workers) # logging diff --git a/pina/utils.py b/pina/utils.py index 3bd77e6b6..e633cceaf 100644 --- a/pina/utils.py +++ b/pina/utils.py @@ -48,8 +48,7 @@ def labelize_forward(forward, input_variables, output_variables): :type output_variables: list[str] | tuple[str] """ def wrapper(x): - if isinstance(x, LabelTensor): - x = x.extract(input_variables) + x = x.extract(input_variables) output = forward(x) # keep it like this, directly using LabelTensor(...) raises errors # when compiling the code diff --git a/tests/test_data/test_datamodule.py b/tests/test_data/test_datamodule.py index 866eebc69..f475c0498 100644 --- a/tests/test_data/test_datamodule.py +++ b/tests/test_data/test_datamodule.py @@ -13,10 +13,10 @@ input_tensor = torch.rand((100, 10)) output_tensor = torch.rand((100, 2)) -x = torch.rand((100, 50 , 10)) -pos = torch.rand((100, 50 , 2)) +x = torch.rand((100, 50, 10)) +pos = torch.rand((100, 50, 2)) input_graph = RadiusGraph(x, pos, r=.1, build_edge_attr=True) -output_graph = torch.rand((100, 50 , 10)) +output_graph = torch.rand((100, 50, 10)) @pytest.mark.parametrize( @@ -30,6 +30,7 @@ def test_constructor(input_, output_): problem = SupervisedProblem(input_=input_, output_=output_) PinaDataModule(problem) + @pytest.mark.parametrize( "input_, output_", [ @@ -46,14 +47,15 @@ def test_constructor(input_, output_): ) def test_setup_train(input_, output_, train_size, val_size, test_size): problem = SupervisedProblem(input_=input_, output_=output_) - dm = PinaDataModule(problem, train_size=train_size, val_size=val_size, test_size=test_size) + dm = PinaDataModule(problem, train_size=train_size, + val_size=val_size, test_size=test_size) dm.setup() assert hasattr(dm, "train_dataset") if isinstance(input_, torch.Tensor): assert isinstance(dm.train_dataset, PinaTensorDataset) else: assert isinstance(dm.train_dataset, PinaGraphDataset) - #assert len(dm.train_dataset) == int(len(input_) * train_size) + # assert len(dm.train_dataset) == int(len(input_) * train_size) if test_size > 0: assert hasattr(dm, "test_dataset") assert dm.test_dataset is None @@ -64,7 +66,8 @@ def test_setup_train(input_, output_, train_size, val_size, test_size): assert isinstance(dm.val_dataset, PinaTensorDataset) else: assert isinstance(dm.val_dataset, PinaGraphDataset) - #assert len(dm.val_dataset) == int(len(input_) * val_size) + # assert len(dm.val_dataset) == int(len(input_) * val_size) + @pytest.mark.parametrize( "input_, output_", @@ -82,7 +85,8 @@ def test_setup_train(input_, output_, train_size, val_size, test_size): ) def test_setup_test(input_, output_, train_size, val_size, test_size): problem = SupervisedProblem(input_=input_, output_=output_) - dm = PinaDataModule(problem, train_size=train_size, val_size=val_size, test_size=test_size) + dm = PinaDataModule(problem, train_size=train_size, + val_size=val_size, test_size=test_size) dm.setup(stage='test') if train_size > 0: assert hasattr(dm, "train_dataset") @@ -94,13 +98,14 @@ def test_setup_test(input_, output_, train_size, val_size, test_size): assert dm.val_dataset is None else: assert not hasattr(dm, "val_dataset") - + assert hasattr(dm, "test_dataset") if isinstance(input_, torch.Tensor): assert isinstance(dm.test_dataset, PinaTensorDataset) else: assert isinstance(dm.test_dataset, PinaGraphDataset) - #assert len(dm.test_dataset) == int(len(input_) * test_size) + # assert len(dm.test_dataset) == int(len(input_) * test_size) + @pytest.mark.parametrize( "input_, output_", @@ -112,7 +117,8 @@ def test_setup_test(input_, output_, train_size, val_size, test_size): def test_dummy_dataloader(input_, output_): problem = SupervisedProblem(input_=input_, output_=output_) solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) - trainer = Trainer(solver, batch_size=None, train_size=.7, val_size=.3, test_size=0.) + trainer = Trainer(solver, batch_size=None, train_size=.7, + val_size=.3, test_size=0.) dm = trainer.data_module dm.setup() dm.trainer = trainer @@ -140,6 +146,7 @@ def test_dummy_dataloader(input_, output_): assert isinstance(data[0][1]['input_points'], torch.Tensor) assert isinstance(data[0][1]['output_points'], torch.Tensor) + @pytest.mark.parametrize( "input_, output_", [ @@ -147,10 +154,17 @@ def test_dummy_dataloader(input_, output_): (input_graph, output_graph) ] ) -def test_dataloader(input_, output_): +@pytest.mark.parametrize( + "automatic_batching", + [ + True, False + ] +) +def test_dataloader(input_, output_, automatic_batching): problem = SupervisedProblem(input_=input_, output_=output_) solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) - trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3, test_size=0.) + trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3, + test_size=0., automatic_batching=automatic_batching) dm = trainer.data_module dm.setup() dm.trainer = trainer @@ -176,3 +190,67 @@ def test_dataloader(input_, output_): assert isinstance(data['data']['input_points'], torch.Tensor) assert isinstance(data['data']['output_points'], torch.Tensor) +from pina import LabelTensor + +input_tensor = LabelTensor(torch.rand((100, 3)), ['u', 'v', 'w']) +output_tensor = LabelTensor(torch.rand((100, 3)), ['u', 'v', 'w']) + +x = LabelTensor(torch.rand((100, 50, 3)), ['u', 'v', 'w']) +pos = LabelTensor(torch.rand((100, 50, 2)), ['x', 'y']) +input_graph = RadiusGraph(x, pos, r=.1, build_edge_attr=True) +output_graph = LabelTensor(torch.rand((100, 50, 3)), ['u', 'v', 'w']) + +@pytest.mark.parametrize( + "input_, output_", + [ + (input_tensor, output_tensor), + (input_graph, output_graph) + ] +) +@pytest.mark.parametrize( + "automatic_batching", + [ + True, False + ] +) +def test_dataloader_labels(input_, output_, automatic_batching): + problem = SupervisedProblem(input_=input_, output_=output_) + solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) + trainer = Trainer(solver, batch_size=10, train_size=.7, val_size=.3, + test_size=0., automatic_batching=automatic_batching) + dm = trainer.data_module + dm.setup() + dm.trainer = trainer + dataloader = dm.train_dataloader() + assert isinstance(dataloader, DataLoader) + assert len(dataloader) == 7 + data = next(iter(dataloader)) + assert isinstance(data, dict) + if isinstance(input_, RadiusGraph): + assert isinstance(data['data']['input_points'], Batch) + assert isinstance(data['data']['input_points'].x, LabelTensor) + assert data['data']['input_points'].x.labels == ['u', 'v', 'w'] + assert data['data']['input_points'].pos.labels == ['x', 'y'] + else: + assert isinstance(data['data']['input_points'], LabelTensor) + assert data['data']['input_points'].labels == ['u', 'v', 'w'] + assert isinstance(data['data']['output_points'], LabelTensor) + assert data['data']['output_points'].labels == ['u', 'v', 'w'] + + dataloader = dm.val_dataloader() + assert isinstance(dataloader, DataLoader) + assert len(dataloader) == 3 + data = next(iter(dataloader)) + assert isinstance(data, dict) + if isinstance(input_, RadiusGraph): + assert isinstance(data['data']['input_points'], Batch) + assert isinstance(data['data']['input_points'].x, LabelTensor) + assert data['data']['input_points'].x.labels == ['u', 'v', 'w'] + assert data['data']['input_points'].pos.labels == ['x', 'y'] + else: + assert isinstance(data['data']['input_points'], torch.Tensor) + assert isinstance(data['data']['input_points'], LabelTensor) + assert data['data']['input_points'].labels == ['u', 'v', 'w'] + assert isinstance(data['data']['output_points'], torch.Tensor) + assert data['data']['output_points'].labels == ['u', 'v', 'w'] +test_dataloader_labels(input_graph, output_graph, True) \ No newline at end of file