diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 3743581794f..ca78e667b83 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1613,12 +1613,12 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes_expected -@pytest.mark.parametrize("min_size", (1, 10)) +@pytest.mark.parametrize("min_size,min_area", [(1, 1), (10, 1), (10, 101)]) @pytest.mark.parametrize( "labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None) ) @pytest.mark.parametrize("sample_type", (tuple, dict)) -def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type): +def test_sanitize_bounding_boxes(min_size, min_area, labels_getter, sample_type): if sample_type is tuple and not isinstance(labels_getter, str): # The "lambda inputs: inputs["labels"]" labels_getter used in this test @@ -1634,12 +1634,13 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type): ([0, 0, 10, min_size - 1], False), # W < min_size ([0, 0, 10, H + 1], False), # Y2 > H ([0, 0, W + 1, 10], False), # X2 > W - ([-1, 1, 10, 20], False), # any < 0 - ([0, 0, -1, 20], False), # any < 0 - ([0, 0, -10, -1], False), # any < 0 - ([0, 0, min_size, 10], True), # H < min_size - ([0, 0, 10, min_size], True), # W < min_size - ([0, 0, W, H], True), # TODO: Is that actually OK?? Should it be -1? + ([-1, 1, 10, 20], False), # X1 < 0 + ([0, -1, 10, 20], False), # Y1 < 0 + ([0, 0, -1, 20], False), # X2 < 0 + ([0, 0, 10, -1], False), # Y2 < 0 + ([0, 0, min_size, 10], min_area / 10 <= min_size), # H >= min_size + ([0, 0, 10, min_size], min_area / 10 <= min_size), # W >= min_size + ([0, 0, W, H], True), ([1, 1, 30, 20], True), ([0, 0, 10, 10], True), ([1, 1, 30, 20], True), @@ -1674,7 +1675,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type): img = sample.pop("image") sample = (img, sample) - out = transforms.SanitizeBoundingBox(min_size=min_size, labels_getter=labels_getter)(sample) + out = transforms.SanitizeBoundingBox(min_size=min_size, min_area=min_area, labels_getter=labels_getter)(sample) if sample_type is tuple: out_image = out[0] @@ -1730,6 +1731,8 @@ def test_sanitize_bounding_boxes_errors(): with pytest.raises(ValueError, match="min_size must be >= 1"): transforms.SanitizeBoundingBox(min_size=0) + with pytest.raises(ValueError, match="min_area must be >= 1"): + transforms.SanitizeBoundingBox(min_area=0) with pytest.raises(ValueError, match="labels_getter should either be a str"): transforms.SanitizeBoundingBox(labels_getter=12) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 90741c4ec7d..f138664bbbb 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -264,7 +264,7 @@ class SanitizeBoundingBox(Transform): This transform removes bounding boxes and their associated labels/masks that: - - are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1. + - are below a given ``min_size`` or ``min_area``: by default this also removes degenerate boxes that have e.g. X2 <= X1. - have any coordinate outside of their corresponding image. You may want to call :class:`~torchvision.transforms.v2.ClampBoundingBox` first to avoid undesired removals. @@ -277,6 +277,7 @@ class SanitizeBoundingBox(Transform): Args: min_size (float, optional) The size below which bounding boxes are removed. Default is 1. + min_area (float, optional) The area below which bounding boxes are removed. Default is 1. labels_getter (callable or str or None, optional): indicates how to identify the labels in the input. It can be a str in which case the input is expected to be a dict, and ``labels_getter`` then specifies the key whose value corresponds to the labels. It can also be a callable that takes the same input @@ -289,6 +290,7 @@ class SanitizeBoundingBox(Transform): def __init__( self, min_size: float = 1.0, + min_area: float = 1.0, labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default", ) -> None: super().__init__() @@ -297,6 +299,10 @@ def __init__( raise ValueError(f"min_size must be >= 1, got {min_size}.") self.min_size = min_size + if min_area < 1: + raise ValueError(f"min_area must be >= 1, got {min_area}.") + self.min_area = min_area + self.labels_getter = labels_getter self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]] if labels_getter == "default": @@ -381,10 +387,11 @@ def forward(self, *inputs: Any) -> Any: ), ) ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] - valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1) + valid = (ws >= self.min_size) & (hs >= self.min_size) & (ws * hs >= self.min_area) # TODO: Do we really need to check for out of bounds here? All # transforms should be clamping anyway, so this should never happen? image_h, image_w = boxes.spatial_size + valid &= (boxes >= 0).all(dim=-1) valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w) valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)