diff --git a/monai/apps/detection/transforms/array.py b/monai/apps/detection/transforms/array.py index cb788f8d92..fb338682ee 100644 --- a/monai/apps/detection/transforms/array.py +++ b/monai/apps/detection/transforms/array.py @@ -13,7 +13,7 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ -from typing import Optional, Sequence, Tuple, Type, Union +from typing import Callable, Optional, Sequence, Tuple, Type, Union import numpy as np import torch @@ -27,7 +27,7 @@ get_spatial_dims, spatial_crop_boxes, ) -from monai.transforms import SpatialCrop +from monai.transforms import Rotate90, SpatialCrop from monai.transforms.transform import Transform from monai.utils import ensure_tuple, ensure_tuple_rep, fall_back_tuple, look_up_option from monai.utils.enums import TransformBackends @@ -38,6 +38,7 @@ convert_mask_to_box, flip_boxes, resize_boxes, + rot90_boxes, select_labels, zoom_boxes, ) @@ -53,6 +54,7 @@ "BoxToMask", "MaskToBox", "SpatialCropBox", + "RotateBox90", ] @@ -514,3 +516,30 @@ def __call__( # type: ignore [self.slices[axis].stop for axis in range(spatial_dims)], ) return boxes_crop, select_labels(labels, keep) + + +class RotateBox90(Rotate90): + """ + Rotate a boxes by 90 degrees in the plane specified by `axes`. + See box_ops.rot90_boxes for additional details + + Args: + k: number of times to rotate by 90 degrees. + spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. + Default: (0, 1), this is the first two axis in spatial dimensions. + If axis is negative it counts from the last to the first axis. + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: + super().__init__(k, spatial_axes) + + def __call__(self, boxes: NdarrayOrTensor, spatial_size: Union[Sequence[int], int]) -> NdarrayOrTensor: # type: ignore + """ + Args: + img: channel first array, must have shape: (num_channels, H[, W, ..., ]), + """ + rot90: Callable = rot90_boxes + out: NdarrayOrTensor = rot90(boxes, spatial_size, self.k, self.spatial_axes) + return out diff --git a/monai/apps/detection/transforms/box_ops.py b/monai/apps/detection/transforms/box_ops.py index 65d95d8220..4f877967b8 100644 --- a/monai/apps/detection/transforms/box_ops.py +++ b/monai/apps/detection/transforms/box_ops.py @@ -347,3 +347,79 @@ def select_labels( return labels_select_list[0] # type: ignore return tuple(labels_select_list) + + +def swapaxes_boxes(boxes: NdarrayOrTensor, axis1: int, axis2: int) -> NdarrayOrTensor: + """ + Interchange two axes of boxes. + + Args: + boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` + axis1: First axis. + axis2: Second axis. + + Returns: + boxes with two axes interchanged. + + """ + spatial_dims: int = get_spatial_dims(boxes=boxes) + boxes_swap: NdarrayOrTensor = deepcopy(boxes) + boxes_swap[:, [axis1, axis2]] = boxes_swap[:, [axis2, axis1]] # type: ignore + boxes_swap[:, [spatial_dims + axis1, spatial_dims + axis2]] = boxes_swap[ # type: ignore + :, [spatial_dims + axis2, spatial_dims + axis1] + ] + return boxes_swap + + +def rot90_boxes( + boxes: NdarrayOrTensor, spatial_size: Union[Sequence[int], int], k: int = 1, axes: Tuple[int, int] = (0, 1) +) -> NdarrayOrTensor: + """ + Rotate boxes by 90 degrees in the plane specified by axes. + Rotation direction is from the first towards the second axis. + + Args: + boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` + spatial_size: image spatial size. + k : number of times the array is rotated by 90 degrees. + axes: (2,) array_like + The array is rotated in the plane defined by the axes. Axes must be different. + + Returns: + A rotated view of `boxes`. + + Notes: + ``rot90_boxes(boxes, spatial_size, k=1, axes=(1,0))`` is the reverse of + ``rot90_boxes(boxes, spatial_size, k=1, axes=(0,1))`` + ``rot90_boxes(boxes, spatial_size, k=1, axes=(1,0))`` is equivalent to + ``rot90_boxes(boxes, spatial_size, k=-1, axes=(0,1))`` + """ + spatial_dims: int = get_spatial_dims(boxes=boxes) + spatial_size_ = list(ensure_tuple_rep(spatial_size, spatial_dims)) + + axes = ensure_tuple(axes) # type: ignore + + if len(axes) != 2: + raise ValueError("len(axes) must be 2.") + + if axes[0] == axes[1] or abs(axes[0] - axes[1]) == spatial_dims: + raise ValueError("Axes must be different.") + + if axes[0] >= spatial_dims or axes[0] < -spatial_dims or axes[1] >= spatial_dims or axes[1] < -spatial_dims: + raise ValueError(f"Axes={axes} out of range for array of ndim={spatial_dims}.") + + k %= 4 + + if k == 0: + return boxes + if k == 2: + return flip_boxes(flip_boxes(boxes, spatial_size_, axes[0]), spatial_size_, axes[1]) + + if k == 1: + boxes_ = flip_boxes(boxes, spatial_size_, axes[1]) + return swapaxes_boxes(boxes_, axes[0], axes[1]) + else: + # k == 3 + boxes_ = swapaxes_boxes(boxes, axes[0], axes[1]) + spatial_size_[axes[0]], spatial_size_[axes[1]] = spatial_size_[axes[1]], spatial_size_[axes[0]] + return flip_boxes(boxes_, spatial_size_, axes[1]) diff --git a/monai/apps/detection/transforms/dictionary.py b/monai/apps/detection/transforms/dictionary.py index b1591c097c..e47238c222 100644 --- a/monai/apps/detection/transforms/dictionary.py +++ b/monai/apps/detection/transforms/dictionary.py @@ -29,6 +29,7 @@ ConvertBoxToStandardMode, FlipBox, MaskToBox, + RotateBox90, SpatialCropBox, ZoomBox, ) @@ -37,7 +38,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.data.box_utils import COMPUTE_DTYPE, BoxMode, clip_boxes_to_image from monai.data.utils import orientation_ras_lps -from monai.transforms import Flip, RandFlip, RandZoom, SpatialCrop, SpatialPad, Zoom +from monai.transforms import Flip, RandFlip, RandRotate90d, RandZoom, Rotate90, SpatialCrop, SpatialPad, Zoom from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform from monai.transforms.utils import generate_pos_neg_label_crop_centers, map_binary_to_indices @@ -80,6 +81,12 @@ "RandCropBoxByPosNegLabeld", "RandCropBoxByPosNegLabelD", "RandCropBoxByPosNegLabelDict", + "RotateBox90d", + "RotateBox90D", + "RotateBox90Dict", + "RandRotateBox90d", + "RandRotateBox90D", + "RandRotateBox90Dict", ] DEFAULT_POST_FIX = PostFix.meta() @@ -285,7 +292,7 @@ class ZoomBoxd(MapTransform, InvertibleTransform): Args: image_keys: Keys to pick image data for transformation. box_keys: Keys to pick box data for transformation. The box mode is assumed to be ``StandardMode``. - box_ref_image_keys: Keys that represents the reference images to which ``box_keys`` are attached. + box_ref_image_keys: Keys that represent the reference images to which ``box_keys`` are attached. zoom: The zoom factor along the spatial axes. If a float, zoom is the same for each spatial axis. If a sequence, zoom should contain one value for each spatial axis. @@ -414,7 +421,7 @@ class RandZoomBoxd(RandomizableTransform, MapTransform, InvertibleTransform): Args: image_keys: Keys to pick image data for transformation. box_keys: Keys to pick box data for transformation. The box mode is assumed to be ``StandardMode``. - box_ref_image_keys: Keys that represents the reference images to which ``box_keys`` are attached. + box_ref_image_keys: Keys that represent the reference images to which ``box_keys`` are attached. prob: Probability of zooming. min_zoom: Min zoom factor. Can be float or sequence same size as image. If a float, select a random factor from `[min_zoom, max_zoom]` then apply to all spatial dims @@ -577,7 +584,7 @@ class FlipBoxd(MapTransform, InvertibleTransform): Args: image_keys: Keys to pick image data for transformation. box_keys: Keys to pick box data for transformation. The box mode is assumed to be ``StandardMode``. - box_ref_image_keys: Keys that represents the reference images to which ``box_keys`` are attached. + box_ref_image_keys: Keys that represent the reference images to which ``box_keys`` are attached. spatial_axis: Spatial axes along which to flip over. Default is None. allow_missing_keys: don't raise exception if key is missing. """ @@ -641,7 +648,7 @@ class RandFlipBoxd(RandomizableTransform, MapTransform, InvertibleTransform): Args: image_keys: Keys to pick image data for transformation. box_keys: Keys to pick box data for transformation. The box mode is assumed to be ``StandardMode``. - box_ref_image_keys: Keys that represents the reference images to which ``box_keys`` are attached. + box_ref_image_keys: Keys that represent the reference images to which ``box_keys`` are attached. prob: Probability of flipping. spatial_axis: Spatial axes along which to flip over. Default is None. allow_missing_keys: don't raise exception if key is missing. @@ -721,7 +728,7 @@ class ClipBoxToImaged(MapTransform): Args: box_keys: The single key to pick box data for transformation. The box mode is assumed to be ``StandardMode``. - label_keys: Keys that represents the labels corresponding to the ``box_keys``. Multiple keys are allowed. + label_keys: Keys that represent the labels corresponding to the ``box_keys``. Multiple keys are allowed. box_ref_image_keys: The single key that represents the reference image to which ``box_keys`` and ``label_keys`` are attached. remove_empty: whether to remove the boxes that are actually empty @@ -791,8 +798,8 @@ class BoxToMaskd(MapTransform): Args: box_keys: Keys to pick box data for transformation. The box mode is assumed to be ``StandardMode``. box_mask_keys: Keys to store output box mask results for transformation. Same length with ``box_keys``. - label_keys: Keys that represents the labels corresponding to the ``box_keys``. Same length with ``box_keys``. - box_ref_image_keys: Keys that represents the reference images to which ``box_keys`` are attached. + label_keys: Keys that represent the labels corresponding to the ``box_keys``. Same length with ``box_keys``. + box_ref_image_keys: Keys that represent the reference images to which ``box_keys`` are attached. min_fg_label: min foreground box label. ellipse_mask: bool. @@ -879,7 +886,7 @@ class MaskToBoxd(MapTransform): Args: box_keys: Keys to pick box data for transformation. The box mode is assumed to be ``StandardMode``. box_mask_keys: Keys to store output box mask results for transformation. Same length with ``box_keys``. - label_keys: Keys that represents the labels corresponding to the ``box_keys``. Same length with ``box_keys``. + label_keys: Keys that represent the labels corresponding to the ``box_keys``. Same length with ``box_keys``. min_fg_label: min foreground box label. box_dtype: output dtype for box_keys label_dtype: output dtype for label_keys @@ -954,7 +961,7 @@ class RandCropBoxByPosNegLabeld(Randomizable, MapTransform): Args: image_keys: Keys to pick image data for transformation. They need to have the same spatial size. box_keys: The single key to pick box data for transformation. The box mode is assumed to be ``StandardMode``. - label_keys: Keys that represents the labels corresponding to the ``box_keys``. Multiple keys are allowed. + label_keys: Keys that represent the labels corresponding to the ``box_keys``. Multiple keys are allowed. spatial_size: the spatial size of the crop region e.g. [224, 224, 128]. if a dimension of ROI size is bigger than image size, will not crop that dimension of the image. if its components have non-positive values, the corresponding size of `data[label_key]` will be used. @@ -1161,6 +1168,167 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashab return results +class RotateBox90d(MapTransform, InvertibleTransform): + """ + Input boxes and images are rotated by 90 degrees + in the plane specified by ``spatial_axes`` for ``k`` times + + Args: + image_keys: Keys to pick image data for transformation. + box_keys: Keys to pick box data for transformation. The box mode is assumed to be ``StandardMode``. + box_ref_image_keys: Keys that represent the reference images to which ``box_keys`` are attached. + k: number of times to rotate by 90 degrees. + spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. + Default (0, 1), this is the first two axis in spatial dimensions. + allow_missing_keys: don't raise exception if key is missing. + """ + + backend = RotateBox90.backend + + def __init__( + self, + image_keys: KeysCollection, + box_keys: KeysCollection, + box_ref_image_keys: KeysCollection, + k: int = 1, + spatial_axes: Tuple[int, int] = (0, 1), + allow_missing_keys: bool = False, + ) -> None: + self.image_keys = ensure_tuple(image_keys) + self.box_keys = ensure_tuple(box_keys) + super().__init__(self.image_keys + self.box_keys, allow_missing_keys) + self.box_ref_image_keys = ensure_tuple_rep(box_ref_image_keys, len(self.box_keys)) + self.img_rotator = Rotate90(k, spatial_axes) + self.box_rotator = RotateBox90(k, spatial_axes) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + d = dict(data) + for key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys): + spatial_size = list(d[box_ref_image_key].shape[1:]) + d[key] = self.box_rotator(d[key], spatial_size) + if self.img_rotator.k % 2 == 1: + # if k = 1 or 3, spatial_size will be transposed + spatial_size[self.img_rotator.spatial_axes[0]], spatial_size[self.img_rotator.spatial_axes[1]] = ( + spatial_size[self.img_rotator.spatial_axes[1]], + spatial_size[self.img_rotator.spatial_axes[0]], + ) + self.push_transform(d, key, extra_info={"spatial_size": spatial_size, "type": "box_key"}) + + for key in self.image_keys: + d[key] = self.img_rotator(d[key]) + self.push_transform(d, key, extra_info={"type": "image_key"}) + return d + + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = deepcopy(dict(data)) + + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + key_type = transform[TraceKeys.EXTRA_INFO]["type"] + num_times_to_rotate = 4 - self.img_rotator.k + + if key_type == "image_key": + inverse_transform = Rotate90(num_times_to_rotate, self.img_rotator.spatial_axes) + d[key] = inverse_transform(d[key]) + if key_type == "box_key": + spatial_size = transform[TraceKeys.EXTRA_INFO]["spatial_size"] + inverse_transform = RotateBox90(num_times_to_rotate, self.box_rotator.spatial_axes) + d[key] = inverse_transform(d[key], spatial_size) + self.pop_transform(d, key) + return d + + +class RandRotateBox90d(RandRotate90d): + """ + With probability `prob`, input boxes and images are rotated by 90 degrees + in the plane specified by `spatial_axes`. + + Args: + image_keys: Keys to pick image data for transformation. + box_keys: Keys to pick box data for transformation. The box mode is assumed to be ``StandardMode``. + box_ref_image_keys: Keys that represent the reference images to which ``box_keys`` are attached. + prob: probability of rotating. + (Default 0.1, with 10% probability it returns a rotated array.) + max_k: number of rotations will be sampled from `np.random.randint(max_k) + 1`. + (Default 3) + spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. + Default: (0, 1), this is the first two axis in spatial dimensions. + allow_missing_keys: don't raise exception if key is missing. + """ + + backend = RotateBox90.backend + + def __init__( + self, + image_keys: KeysCollection, + box_keys: KeysCollection, + box_ref_image_keys: KeysCollection, + prob: float = 0.1, + max_k: int = 3, + spatial_axes: Tuple[int, int] = (0, 1), + allow_missing_keys: bool = False, + ) -> None: + self.image_keys = ensure_tuple(image_keys) + self.box_keys = ensure_tuple(box_keys) + super().__init__(self.image_keys + self.box_keys, prob, max_k, spatial_axes, allow_missing_keys) + self.box_ref_image_keys = ensure_tuple_rep(box_ref_image_keys, len(self.box_keys)) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + self.randomize() + d = dict(data) + + if self._rand_k % 4 == 0: + return d + + # FIXME: here we didn't use array version `RandRotate90` transform as others, because we need + # to be compatible with the random status of some previous integration tests + box_rotator = RotateBox90(self._rand_k, self.spatial_axes) + img_rotator = Rotate90(self._rand_k, self.spatial_axes) + + for key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys): + if self._do_transform: + spatial_size = list(d[box_ref_image_key].shape[1:]) + d[key] = box_rotator(d[key], spatial_size) + if self._rand_k % 2 == 1: + # if k = 1 or 3, spatial_size will be transposed + spatial_size[self.spatial_axes[0]], spatial_size[self.spatial_axes[1]] = ( + spatial_size[self.spatial_axes[1]], + spatial_size[self.spatial_axes[0]], + ) + self.push_transform( + d, key, extra_info={"rand_k": self._rand_k, "spatial_size": spatial_size, "type": "box_key"} + ) + + for key in self.image_keys: + if self._do_transform: + d[key] = img_rotator(d[key]) + self.push_transform(d, key, extra_info={"rand_k": self._rand_k, "type": "image_key"}) + return d + + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = deepcopy(dict(data)) + if self._rand_k % 4 == 0: + return d + + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + key_type = transform[TraceKeys.EXTRA_INFO]["type"] + # Check if random transform was actually performed (based on `prob`) + if transform[TraceKeys.DO_TRANSFORM]: + num_times_rotated = transform[TraceKeys.EXTRA_INFO]["rand_k"] + num_times_to_rotate = 4 - num_times_rotated + # flip image, copied from monai.transforms.spatial.dictionary.RandFlipd + if key_type == "image_key": + inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) + d[key] = inverse_transform(d[key]) + if key_type == "box_key": + spatial_size = transform[TraceKeys.EXTRA_INFO]["spatial_size"] + inverse_transform = RotateBox90(num_times_to_rotate, self.spatial_axes) + d[key] = inverse_transform(d[key], spatial_size) + self.pop_transform(d, key) + return d + + ConvertBoxModeD = ConvertBoxModeDict = ConvertBoxModed ConvertBoxToStandardModeD = ConvertBoxToStandardModeDict = ConvertBoxToStandardModed ZoomBoxD = ZoomBoxDict = ZoomBoxd @@ -1172,3 +1340,5 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashab BoxToMaskD = BoxToMaskDict = BoxToMaskd MaskToBoxD = MaskToBoxDict = MaskToBoxd RandCropBoxByPosNegLabelD = RandCropBoxByPosNegLabelDict = RandCropBoxByPosNegLabeld +RotateBox90D = RotateBox90Dict = RotateBox90d +RandRotateBox90D = RandRotateBox90Dict = RandRotateBox90d diff --git a/tests/test_box_transform.py b/tests/test_box_transform.py index ea999ecb91..5d984175aa 100644 --- a/tests/test_box_transform.py +++ b/tests/test_box_transform.py @@ -24,7 +24,9 @@ MaskToBoxd, RandCropBoxByPosNegLabeld, RandFlipBoxd, + RandRotateBox90d, RandZoomBoxd, + RotateBox90d, ZoomBoxd, ) from monai.transforms import CastToTyped, Invertd @@ -47,6 +49,7 @@ p([[1, -6, -1, 1, -6, -1], [1, -3, -1, 2, 3, 3.5], [1, -3, 0.5, 2, 3, 5]]), p([[4, 6, 4, 4, 6, 4], [2, 3, 1, 4, 5, 4], [2, 3, 0, 4, 5, 3]]), p([[0, 1, 0, 2, 3, 3], [0, 1, 1, 2, 3, 4]]), + p([[6, 0, 0, 6, 0, 0], [3, 0, 0, 5, 2, 3], [3, 0, 1, 5, 2, 4]]), ] ) @@ -118,6 +121,7 @@ def test_value_3d( expected_zoom_keepsize_result, expected_flip_result, expected_clip_result, + expected_rotate_result, ): test_dtype = [torch.float32] for dtype in test_dtype: @@ -254,6 +258,30 @@ def test_value_3d( atol=1e-3, ) + # test RotateBox90d + transform_rotate = RotateBox90d( + image_keys="image", box_keys="boxes", box_ref_image_keys="image", k=1, spatial_axes=[0, 1] + ) + rotate_result = transform_rotate(data) + assert_allclose(rotate_result["boxes"], expected_rotate_result, type_test=True, device_test=True, atol=1e-3) + invert_transform_rotate = Invertd( + keys=["image", "boxes"], transform=transform_rotate, orig_keys=["image", "boxes"] + ) + data_back = invert_transform_rotate(rotate_result) + assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3) + assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3) + + transform_rotate = RandRotateBox90d( + image_keys="image", box_keys="boxes", box_ref_image_keys="image", prob=1.0, max_k=3, spatial_axes=[0, 1] + ) + rotate_result = transform_rotate(data) + invert_transform_rotate = Invertd( + keys=["image", "boxes"], transform=transform_rotate, orig_keys=["image", "boxes"] + ) + data_back = invert_transform_rotate(rotate_result) + assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3) + assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3) + if __name__ == "__main__": unittest.main()