Skip to content

Commit

Permalink
SanitizeBoundingBox based on minimum area
Browse files Browse the repository at this point in the history
  • Loading branch information
antoinebrl committed Jul 11, 2023
1 parent 08c9938 commit 5e5358b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
21 changes: 12 additions & 9 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1613,12 +1613,12 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes_expected


@pytest.mark.parametrize("min_size", (1, 10))
@pytest.mark.parametrize("min_size,min_area", [(1, 1), (10, 1), (10, 101)])
@pytest.mark.parametrize(
"labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None)
)
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
def test_sanitize_bounding_boxes(min_size, min_area, labels_getter, sample_type):

if sample_type is tuple and not isinstance(labels_getter, str):
# The "lambda inputs: inputs["labels"]" labels_getter used in this test
Expand All @@ -1634,12 +1634,13 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
([0, 0, 10, min_size - 1], False), # W < min_size
([0, 0, 10, H + 1], False), # Y2 > H
([0, 0, W + 1, 10], False), # X2 > W
([-1, 1, 10, 20], False), # any < 0
([0, 0, -1, 20], False), # any < 0
([0, 0, -10, -1], False), # any < 0
([0, 0, min_size, 10], True), # H < min_size
([0, 0, 10, min_size], True), # W < min_size
([0, 0, W, H], True), # TODO: Is that actually OK?? Should it be -1?
([-1, 1, 10, 20], False), # X1 < 0
([0, -1, 10, 20], False), # Y1 < 0
([0, 0, -1, 20], False), # X2 < 0
([0, 0, 10, -1], False), # Y2 < 0
([0, 0, min_size, 10], min_area / 10 <= min_size), # H >= min_size
([0, 0, 10, min_size], min_area / 10 <= min_size), # W >= min_size
([0, 0, W, H], True),
([1, 1, 30, 20], True),
([0, 0, 10, 10], True),
([1, 1, 30, 20], True),
Expand Down Expand Up @@ -1674,7 +1675,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
img = sample.pop("image")
sample = (img, sample)

out = transforms.SanitizeBoundingBox(min_size=min_size, labels_getter=labels_getter)(sample)
out = transforms.SanitizeBoundingBox(min_size=min_size, min_area=min_area, labels_getter=labels_getter)(sample)

if sample_type is tuple:
out_image = out[0]
Expand Down Expand Up @@ -1730,6 +1731,8 @@ def test_sanitize_bounding_boxes_errors():

with pytest.raises(ValueError, match="min_size must be >= 1"):
transforms.SanitizeBoundingBox(min_size=0)
with pytest.raises(ValueError, match="min_area must be >= 1"):
transforms.SanitizeBoundingBox(min_area=0)
with pytest.raises(ValueError, match="labels_getter should either be a str"):
transforms.SanitizeBoundingBox(labels_getter=12)

Expand Down
11 changes: 9 additions & 2 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ class SanitizeBoundingBox(Transform):
This transform removes bounding boxes and their associated labels/masks that:
- are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- are below a given ``min_size`` or ``min_area``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- have any coordinate outside of their corresponding image. You may want to
call :class:`~torchvision.transforms.v2.ClampBoundingBox` first to avoid undesired removals.
Expand All @@ -277,6 +277,7 @@ class SanitizeBoundingBox(Transform):
Args:
min_size (float, optional): The size below which bounding boxes are removed. Default is 1.
min_area (float, optional): The area below which bounding boxes are removed. Default is 1.
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input.
It can be a str in which case the input is expected to be a dict, and ``labels_getter`` then specifies
the key whose value corresponds to the labels. It can also be a callable that takes the same input
Expand All @@ -289,6 +290,7 @@ class SanitizeBoundingBox(Transform):
def __init__(
self,
min_size: float = 1.0,
min_area: float = 1.0,
labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default",
) -> None:
super().__init__()
Expand All @@ -297,6 +299,10 @@ def __init__(
raise ValueError(f"min_size must be >= 1, got {min_size}.")
self.min_size = min_size

if min_area < 1:
raise ValueError(f"min_area must be >= 1, got {min_area}.")
self.min_area = min_area

self.labels_getter = labels_getter
self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]]
if labels_getter == "default":
Expand Down Expand Up @@ -381,10 +387,11 @@ def forward(self, *inputs: Any) -> Any:
),
)
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
valid = (ws >= self.min_size) & (hs >= self.min_size) & (ws * hs >= self.min_area)
# TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen?
image_h, image_w = boxes.spatial_size
valid &= (boxes >= 0).all(dim=-1)
valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)

Expand Down

0 comments on commit 5e5358b

Please sign in to comment.