Skip to content
2 changes: 1 addition & 1 deletion monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
resolve_writer,
)
from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer
from .meta_obj import MetaObj, get_track_meta, get_track_transforms, set_track_meta, set_track_transforms
from .meta_obj import MetaObj, get_track_meta, set_track_meta
from .meta_tensor import MetaTensor
from .nifti_saver import NiftiSaver
from .nifti_writer import write_nifti
Expand Down
81 changes: 38 additions & 43 deletions monai/data/meta_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
from typing import Any, Callable, Sequence

_TRACK_META = True
_TRACK_TRANSFORMS = True

__all__ = ["get_track_meta", "get_track_transforms", "set_track_meta", "set_track_transforms", "MetaObj"]
__all__ = ["get_track_meta", "set_track_meta", "MetaObj"]
Comment thread
rijobro marked this conversation as resolved.


def set_track_meta(val: bool) -> None:
Expand All @@ -26,9 +25,8 @@ def set_track_meta(val: bool) -> None:
its data by using subclasses of `MetaObj`. If `False`, then data will be returned
with empty metadata.

If both `set_track_meta` and `set_track_transforms` are set to
`False`, then standard data objects will be returned (e.g., `torch.Tensor` and
`np.ndarray`) as opposed to our enhanced objects.
If `set_track_meta` is `False`, then standard data objects will be returned (e.g.,
`torch.Tensor` and `np.ndarray`) as opposed to our enhanced objects.

By default, this is `True`, and most users will want to leave it this way. However,
if you are experiencing any problems regarding metadata, and aren't interested in
Expand All @@ -38,33 +36,14 @@ def set_track_meta(val: bool) -> None:
_TRACK_META = val


def set_track_transforms(val: bool) -> None:
"""
Boolean to set whether transforms are tracked. If `True`, applied transforms will be
associated its data by using subclasses of `MetaObj`. If `False`, then transforms
won't be tracked.

If both `set_track_meta` and `set_track_transforms` are set to
`False`, then standard data objects will be returned (e.g., `torch.Tensor` and
`np.ndarray`) as opposed to our enhanced objects.

By default, this is `True`, and most users will want to leave it this way. However,
if you are experiencing any problems regarding transforms, and aren't interested in
preserving transforms, then you can disable it.
"""
global _TRACK_TRANSFORMS
_TRACK_TRANSFORMS = val


def get_track_meta() -> bool:
"""
Return the boolean as to whether metadata is tracked. If `True`, metadata will be
associated its data by using subclasses of `MetaObj`. If `False`, then data will be
returned with empty metadata.

If both `set_track_meta` and `set_track_transforms` are set to
`False`, then standard data objects will be returned (e.g., `torch.Tensor` and
`np.ndarray`) as opposed to our enhanced objects.
If `set_track_meta` is `False`, then standard data objects will be returned (e.g.,
`torch.Tensor` and `np.ndarray`) as opposed to our enhanced objects.

By default, this is `True`, and most users will want to leave it this way. However,
if you are experiencing any problems regarding metadata, and aren't interested in
Expand All @@ -73,23 +52,6 @@ def get_track_meta() -> bool:
return _TRACK_META


def get_track_transforms() -> bool:
"""
Return the boolean as to whether transforms are tracked. If `True`, applied
transforms will be associated its data by using subclasses of `MetaObj`. If `False`,
then transforms won't be tracked.

If both `set_track_meta` and `set_track_transforms` are set to
`False`, then standard data objects will be returned (e.g., `torch.Tensor` and
`np.ndarray`) as opposed to our enhanced objects.

By default, this is `True`, and most users will want to leave it this way. However,
if you are experiencing any problems regarding transforms, and aren't interested in
preserving transforms, then you can disable it.
"""
return _TRACK_TRANSFORMS


class MetaObj:
"""
Abstract base class that stores data as well as any extra metadata.
Expand Down Expand Up @@ -177,6 +139,7 @@ def _copy_meta(self, input_objs: list[MetaObj]) -> None:
id_in = id(input_objs[0]) if len(input_objs) > 0 else None
deep_copy = id(self) != id_in
self._copy_attr("meta", input_objs, self.get_default_meta, deep_copy)
self._copy_attr("applied_operations", input_objs, self.get_default_applied_operations, deep_copy)
Comment thread
rijobro marked this conversation as resolved.
self.is_batch = input_objs[0].is_batch if len(input_objs) > 0 else False

def get_default_meta(self) -> dict:
Expand All @@ -187,6 +150,14 @@ def get_default_meta(self) -> dict:
"""
return {}

def get_default_applied_operations(self) -> list:
"""Get the default applied operations.

Returns:
default applied operations.
"""
return []

def __repr__(self) -> str:
"""String representation of class."""
out: str = super().__repr__()
Expand All @@ -196,6 +167,14 @@ def __repr__(self) -> str:
out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items())
else:
out += "None"

out += "\nApplied operations\n"
if self.applied_operations is not None:
for i in self.applied_operations:
out += f"\t{str(i)}\n"
else:
out += "None"

out += f"\nIs batch?: {self.is_batch}"

return out
Expand All @@ -210,6 +189,22 @@ def meta(self, d: dict) -> None:
"""Set the meta."""
self._meta = d

@property
def applied_operations(self) -> list:
"""Get the applied operations."""
return self._applied_operations

@applied_operations.setter
def applied_operations(self, t: list) -> None:
"""Set the applied operations."""
self._applied_operations = t

def push_applied_operation(self, t: Any) -> None:
self._applied_operations.append(t)

def pop_applied_operation(self) -> Any:
return self._applied_operations.pop()

@property
def is_batch(self) -> bool:
"""Return whether object is part of batch or not."""
Expand Down
40 changes: 32 additions & 8 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

import torch

from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms
from monai.config.type_definitions import NdarrayTensor
from monai.data.meta_obj import MetaObj, get_track_meta
from monai.data.utils import decollate_batch, list_data_collate
from monai.utils.enums import PostFix

Expand Down Expand Up @@ -72,10 +73,20 @@ class MetaTensor(MetaObj, torch.Tensor):
"""

@staticmethod
def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None, *args, **kwargs) -> MetaTensor:
def __new__(
cls,
x,
affine: torch.Tensor | None = None,
meta: dict | None = None,
applied_operations: list | None = None,
*args,
**kwargs,
) -> MetaTensor:
return torch.as_tensor(x, *args, **kwargs).as_subclass(cls) # type: ignore

def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = None) -> None:
def __init__(
self, x, affine: torch.Tensor | None = None, meta: dict | None = None, applied_operations: list | None = None
) -> None:
"""
If `meta` is given, use it. Else, if `meta` exists in the input tensor, use it.
Else, use the default value. Similar for the affine, except this could come from
Expand All @@ -94,15 +105,24 @@ def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = No
warnings.warn("Setting affine, but the applied meta contains an affine. This will be overwritten.")
self.affine = affine
elif "affine" in self.meta:
pass # nothing to do
# by using the setter function, we ensure it is converted to torch.Tensor if not already
self.affine = self.meta["affine"]
elif isinstance(x, MetaTensor):
self.affine = x.affine
else:
self.affine = self.get_default_affine()
# applied_operations
if applied_operations is not None:
self.applied_operations = applied_operations
elif isinstance(x, MetaTensor):
self.applied_operations = x.applied_operations
else:
self.applied_operations = self.get_default_applied_operations()

# if we are creating a new MetaTensor, then deep copy attributes
if isinstance(x, torch.Tensor) and not isinstance(x, MetaTensor):
self.meta = deepcopy(self.meta)
self.applied_operations = deepcopy(self.applied_operations)
self.affine = self.affine.to(self.device)

def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None:
Expand All @@ -126,7 +146,7 @@ def update_meta(rets: Sequence, func, args, kwargs):
if not isinstance(ret, MetaTensor):
pass
# if not tracking, convert to `torch.Tensor`.
elif not (get_track_meta() or get_track_transforms()):
elif not get_track_meta():
ret = ret.as_tensor()
# else, handle the `MetaTensor` metadata.
else:
Expand Down Expand Up @@ -221,17 +241,21 @@ def as_dict(self, key: str) -> dict:
A dictionary consisting of two keys, the main data (stored under `key`) and
the metadata.
"""
return {key: self.as_tensor(), PostFix.meta(key): self.meta}
return {
key: self.as_tensor(),
PostFix.meta(key): deepcopy(self.meta),
PostFix.transforms(key): deepcopy(self.applied_operations),
}

@property
def affine(self) -> torch.Tensor:
"""Get the affine."""
return self.meta["affine"] # type: ignore

@affine.setter
def affine(self, d: torch.Tensor) -> None:
def affine(self, d: NdarrayTensor) -> None:
"""Set the affine."""
self.meta["affine"] = d
self.meta["affine"] = torch.as_tensor(d, device=self.device)

def new_empty(self, size, dtype=None, device=None, requires_grad=False):
"""
Expand Down
27 changes: 19 additions & 8 deletions monai/transforms/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import torch

from monai.data.meta_tensor import MetaTensor
from monai.transforms.transform import Transform
from monai.utils.enums import TraceKeys

Expand Down Expand Up @@ -54,7 +55,8 @@ def trace_key(key: Hashable = None):
def push_transform(
self, data: Mapping, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None
) -> None:
"""PUsh to a stack of applied transforms for that key."""
"""Push to a stack of applied transforms for that key."""

if not self.tracing:
return
info = {TraceKeys.CLASS_NAME: self.__class__.__name__, TraceKeys.ID: id(self)}
Expand All @@ -67,17 +69,23 @@ def push_transform(
# If class is randomizable transform, store whether the transform was actually performed (based on `prob`)
if hasattr(self, "_do_transform"): # RandomizableTransform
info[TraceKeys.DO_TRANSFORM] = self._do_transform # type: ignore
# If this is the first, create list
if self.trace_key(key) not in data:
if not isinstance(data, dict):
data = dict(data)
data[self.trace_key(key)] = []
data[self.trace_key(key)].append(info)

if key in data and isinstance(data[key], MetaTensor):
data[key].push_applied_operation(info)
Comment thread
rijobro marked this conversation as resolved.
else:
# If this is the first, create list
if self.trace_key(key) not in data:
if not isinstance(data, dict):
data = dict(data)
data[self.trace_key(key)] = []
data[self.trace_key(key)].append(info)

def pop_transform(self, data: Mapping, key: Hashable = None):
"""Remove the most recent applied transform."""
if not self.tracing:
return
if key in data and isinstance(data[key], MetaTensor):
return data[key].pop_applied_operation()
return data.get(self.trace_key(key), []).pop()


Expand Down Expand Up @@ -133,7 +141,10 @@ def get_most_recent_transform(self, data: Mapping, key: Hashable = None):
"""Get most recent transform."""
if not self.tracing:
raise RuntimeError("Transform Tracing must be enabled to get the most recent transform.")
transform = data[self.trace_key(key)][-1]
if isinstance(data[key], MetaTensor):
transform = data[key].applied_operations[-1]
else:
transform = data[self.trace_key(key)][-1]
self.check_transforms_match(transform)
return transform

Expand Down
16 changes: 10 additions & 6 deletions monai/transforms/meta_utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,17 @@ class FromMetaTensord(MapTransform, InvertibleTransform):
Dictionary-based transform to convert MetaTensor to a dictionary.

If input is `{"a": MetaTensor, "b": MetaTensor}`, then output will
have the form `{"a": torch.Tensor, "a_meta_dict": dict, "b": ...}`.
have the form `{"a": torch.Tensor, "a_meta_dict": dict, "a_transforms": list, "b": ...}`.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
self.push_transform(d, key)
im: MetaTensor = d[key] # type: ignore
d.update(im.as_dict(key))
self.push_transform(d, key)
return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
Expand All @@ -58,8 +58,10 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
# check transform
_ = self.get_most_recent_transform(d, key)
# do the inverse
im, meta = d[key], d.pop(PostFix.meta(key), None)
im = MetaTensor(im, meta=meta) # type: ignore
im = d[key]
meta = d.pop(PostFix.meta(key), None)
transforms = d.pop(PostFix.transforms(key), None)
im = MetaTensor(im, meta=meta, applied_operations=transforms) # type: ignore
d[key] = im
# Remove the applied transform
self.pop_transform(d, key)
Expand All @@ -80,8 +82,10 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
d = dict(data)
for key in self.key_iterator(d):
self.push_transform(d, key)
im, meta = d[key], d.pop(PostFix.meta(key), None)
im = MetaTensor(im, meta=meta) # type: ignore
im = d[key]
meta = d.pop(PostFix.meta(key), None)
transforms = d.pop(PostFix.transforms(key), None)
im = MetaTensor(im, meta=meta, applied_operations=transforms) # type: ignore
d[key] = im
return d

Expand Down
Loading