```python
def _create_tensor_batch(self, data):
    """
    Reshape properly ``data`` tensor to be processed handle by the graph
    based models.

    :param data: torch.Tensor object of shape ``(N, ...)`` where ``N`` is
    the number of data objects.
    :type data: torch.Tensor | LabelTensor
    :return: Reshaped tensor object.
    :rtype: torch.Tensor | LabelTensor
    """
    out = data.reshape(-1, *data.shape[2:])
    return out
```
These lines (`PINA/pina/data/dataset.py`, lines 279 to 291 in efc9e32) prevent the class from working with 2D tensors.
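The problem comes from `data.reshape(-1, *data.shape[2:])`. For a 3D input of shape `(N, M, F)`, this merges the first two axes into `(N * M, F)`. For a 2D input of shape `(N, F)`, however, `data.shape[2:]` is empty, so the call reduces to `data.reshape(-1)` and flattens the tensor to 1D, which drops the feature axis. The sketch below reproduces this with NumPy, whose `reshape` infers `-1` the same way `torch.reshape` does. It uses NumPy only so that it runs without PyTorch installed.

The guarded variant is a hypothetical fix and is not PINA's code. It assumes a 2D input is already laid out as `(nodes, features)` and should pass through unchanged. The function names are illustrative.

```python
import numpy as np


def create_tensor_batch_original(data):
    # Same body as _create_tensor_batch: merge the first two axes.
    return data.reshape(-1, *data.shape[2:])


def create_tensor_batch_guarded(data):
    # Hypothetical fix (assumption): only merge the leading two axes when
    # there are at least three; a 2D input is returned unchanged so its
    # trailing feature axis is preserved.
    if data.ndim < 3:
        return data
    return data.reshape(-1, *data.shape[2:])


batch_3d = np.zeros((4, 10, 3))  # (N, nodes, features)
batch_2d = np.zeros((10, 3))     # 2D input: (nodes, features)

print(create_tensor_batch_original(batch_3d).shape)  # (40, 3)
print(create_tensor_batch_original(batch_2d).shape)  # (30,): feature axis lost
print(create_tensor_batch_guarded(batch_2d).shape)   # (10, 3)
```

`torch.Tensor` also exposes `.ndim`, so the same guard should apply unchanged to `torch.Tensor` and `LabelTensor` inputs. Whether passing 2D data through unchanged is the right semantics for PINA's graph datasets remains an open design question.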