Skip to content

Commit

Permalink
Update segment cropping (#31)
Browse files Browse the repository at this point in the history
* Allow uncompressed rle in crop_covered

* Add an option to prohibit segment removal in crop covered segments

* Improve type declarations for RLE
  • Loading branch information
zhiltsov-max committed Nov 20, 2023
1 parent 91c7499 commit 26bd789
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 46 deletions.
2 changes: 1 addition & 1 deletion datumaro/plugins/coco_format/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def crop_segments(cls, instances, img_width, img_height):
for inst_idx, inst in enumerate(instances):
new_segments = [s for si_id, s in zip(segment_map, segments) if si_id == inst_idx]

if not new_segments:
if not new_segments or isinstance(new_segments[0], list) and len(new_segments[0]) == 0:
inst[1] = []
inst[2] = None
continue
Expand Down
34 changes: 31 additions & 3 deletions datumaro/plugins/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,23 @@ class CropCoveredSegments(ItemTransform, CliPlugin):
the corresponding number of separate annotations joined into a group.
"""

ALLOW_REMOVAL_ARG = "--allow-removal"

@classmethod
def build_cmdline_parser(cls, **kwargs):
    """
    Extend the parent command-line parser with the segment removal switch.

    All keyword arguments are forwarded unchanged to the parent
    implementation. The option named by ``ALLOW_REMOVAL_ARG`` is registered
    as a boolean flag, and the extended parser is returned.
    """
    removal_help = "Allow automatic removal of completely covered segments (default: %(default)s)"

    cmdline = super().build_cmdline_parser(**kwargs)
    # A "store_true" flag is off unless passed explicitly, so removal is opt-in
    cmdline.add_argument(cls.ALLOW_REMOVAL_ARG, action="store_true", help=removal_help)
    return cmdline

def __init__(self, extractor, allow_removal=False):
    """
    Wrap ``extractor`` and remember the segment removal policy.

    Parameters:
        extractor: passed to the parent initializer unchanged
        allow_removal: stored in ``_allow_removal``; defaults to False
    """
    # The parent is initialized first, then the option is recorded
    super().__init__(extractor)
    self._allow_removal = allow_removal

def transform_item(self, item):
annotations = []
segments = []
Expand All @@ -71,13 +88,15 @@ def transform_item(self, item):
if not isinstance(item.media, Image):
raise Exception("Image info is required for this transform")
h, w = item.media.size
segments = self.crop_segments(segments, w, h)
segments = self.crop_segments(segments, w, h, item=item, allow_removal=self._allow_removal)

annotations += segments
return self.wrap_item(item, annotations=annotations)

@classmethod
def crop_segments(cls, segment_anns, img_width, img_height):
def crop_segments(
cls, segment_anns, img_width, img_height, *, item: DatasetItem, allow_removal: bool = False
):
segment_anns = sorted(segment_anns, key=lambda x: x.z_order)

segments = []
Expand All @@ -95,6 +114,15 @@ def crop_segments(cls, segment_anns, img_width, img_height):

new_anns = []
for ann, new_segment in zip(segment_anns, segments):
if new_segment is None or isinstance(new_segment, list) and not new_segment:
message = "completely covered object removed " "(allow with '%s')" % (
cls.ALLOW_REMOVAL_ARG,
)
if not allow_removal:
raise DatumaroError(("Item %s: " + message) % (item.id,))
else:
log.debug("[%s]: item %s: " + message, cls.NAME, item.id)

fields = {
"z_order": ann.z_order,
"label": ann.label,
Expand All @@ -107,7 +135,7 @@ def crop_segments(cls, segment_anns, img_width, img_height):
fields["group"] = cls._make_group_id(segment_anns + new_anns, fields["id"])
for polygon in new_segment:
new_anns.append(Polygon(points=polygon, **fields))
else:
elif new_segment is not None:
rle = mask_tools.mask_to_rle(new_segment)
rle = mask_utils.frPyObjects(rle, *rle["size"])
new_anns.append(RleMask(rle=rle, **fields))
Expand Down
117 changes: 81 additions & 36 deletions datumaro/util/mask_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,33 @@

from functools import partial
from itertools import chain
from typing import Tuple
from typing import List, NamedTuple, NewType, Optional, Sequence, Tuple, TypedDict, Union

import numpy as np

from datumaro.util.image import lazy_image, load_image


class UncompressedRle(TypedDict):
    """
    RLE mask dictionary whose run counts are stored as a ``bytes`` string.

    NOTE(review): pycocotools refers to the bytes-encoded form as
    "compressed" RLE, so the naming here looks reversed relative to that
    library - confirm before relying on the name alone.
    """

    # Mask dimensions; presumably [height, width] as in mask_to_rle - TODO confirm
    size: Sequence[int]
    # Encoded run lengths
    counts: bytes


class CompressedRle(TypedDict):
    """
    RLE mask dictionary whose run counts are stored as a sequence of integers.

    This is the form declared as the return type of ``mask_to_rle``, which
    fills "size" with ``list(binary_mask.shape)``.
    """

    # Mask dimensions, as produced from the mask array shape
    size: Sequence[int]
    # Run lengths as plain integers
    counts: Sequence[int]


# Any RLE-encoded mask, regardless of how the counts are stored
Rle = Union[CompressedRle, UncompressedRle]
# NOTE(review): declared as nested lists of ints - presumably a group of flat
# [x0, y0, x1, y1, ...] coordinate lists; confirm against the callers
Polygon = List[List[int]]
# Bounding box tuple with named fields x, y, w, h
BboxCoords = NamedTuple("BboxCoords", [("x", int), ("y", int), ("w", int), ("h", int)])
# A single segment: either a polygon or an RLE-encoded mask
Segment = Union[Polygon, Rle]

# Distinct names for numpy arrays by mask kind. NewType only matters to static
# type checkers; at runtime the values are ordinary ndarrays.
BinaryMask = NewType("BinaryMask", np.ndarray)
IndexMask = NewType("IndexMask", np.ndarray)
ColorMask = NewType("ColorMask", np.ndarray)


def generate_colormap(length=256, *, include_background=True):
"""
Generates colors using PASCAL VOC algorithm.
Expand Down Expand Up @@ -42,7 +62,7 @@ def invert_colormap(colormap):
return {tuple(a): index for index, a in colormap.items()}


def check_is_mask(mask):
def check_is_mask(mask: np.ndarray) -> bool:
assert len(mask.shape) in {2, 3}
if len(mask.shape) == 3:
assert mask.shape[2] == 1
Expand All @@ -52,7 +72,7 @@ def check_is_mask(mask):
_default_unpaint_colormap = invert_colormap(_default_colormap)


def unpaint_mask(painted_mask, inverse_colormap=None):
def unpaint_mask(painted_mask: ColorMask, inverse_colormap=None) -> IndexMask:
"""
Convert color mask to index mask
Expand Down Expand Up @@ -86,7 +106,7 @@ def unpaint_mask(painted_mask, inverse_colormap=None):
return unpainted_mask


def paint_mask(mask, colormap=None):
def paint_mask(mask: IndexMask, colormap=None) -> ColorMask:
"""
Applies colormap to index mask
Expand All @@ -109,7 +129,7 @@ def paint_mask(mask, colormap=None):
return painted_mask


def remap_mask(mask, map_fn):
def remap_mask(mask: ColorMask, map_fn) -> ColorMask:
"""
Changes mask elements from one colormap to another
Expand All @@ -120,11 +140,11 @@ def remap_mask(mask, map_fn):
return np.array([max(0, map_fn(c)) for c in range(256)], dtype=np.uint8)[mask]


def make_index_mask(binary_mask, index, dtype=None):
def make_index_mask(binary_mask: BinaryMask, index: int, dtype=None) -> IndexMask:
return binary_mask * np.array([index], dtype=dtype or np.min_scalar_type(index))


def make_binary_mask(mask):
def make_binary_mask(mask: Union[BinaryMask, IndexMask]) -> BinaryMask:
if mask.dtype.kind == "b":
return mask
return mask.astype(bool)
Expand Down Expand Up @@ -152,7 +172,7 @@ def lazy_mask(path, inverse_colormap=None):
return lazy_image(path, partial(load_mask, inverse_colormap=inverse_colormap))


def mask_to_rle(binary_mask):
def mask_to_rle(binary_mask: BinaryMask) -> CompressedRle:
# walk in row-major order as COCO format specifies
bounded = binary_mask.ravel(order="F")

Expand All @@ -170,7 +190,7 @@ def mask_to_rle(binary_mask):
return {"counts": counts, "size": list(binary_mask.shape)}


def mask_to_polygons(mask, area_threshold=1):
def mask_to_polygons(mask: BinaryMask, area_threshold=1) -> List[Polygon]:
"""
Convert an instance mask to polygons
Expand Down Expand Up @@ -210,15 +230,36 @@ def mask_to_polygons(mask, area_threshold=1):
return polygons


def is_uncompressed_rle(obj: Segment) -> bool:
    """
    Tell whether ``obj`` is an RLE dictionary with byte-encoded counts.

    Returns True only for a ``dict`` whose "counts" entry is a ``bytes``
    object. Any other value, including a dict without "counts", gives False.
    """
    if not isinstance(obj, dict):
        return False

    counts = obj.get("counts")
    return isinstance(counts, bytes)


def is_polygon(obj: Segment) -> bool:
    """
    Tell whether ``obj`` looks like a list of flat polygon coordinate lists.

    Only the first inner list is inspected and, if it is not empty, only its
    first element.

    Returns:
        True if ``obj`` is a non-empty list whose first element is a list
        that is either empty or starts with a number (int or float).
        False for anything else, including an empty outer list.
    """
    if not isinstance(obj, list) or not obj:
        # An empty outer list has no first element to inspect
        return False

    first = obj[0]
    if not isinstance(first, list):
        return False

    # Polygon coordinates may be fractional, so floats are accepted as well as ints
    return not first or isinstance(first[0], (int, float))


def to_uncompressed_rle(rle: Rle, *, width: int, height: int) -> UncompressedRle:
    """
    Return ``rle`` in the byte-encoded counts form.

    An input that already passes ``is_uncompressed_rle`` is returned as is
    (the same object). Anything else is handed to pycocotools'
    ``frPyObjects`` together with the mask height and width.
    """
    if not is_uncompressed_rle(rle):
        # Imported lazily, only when a conversion is actually needed
        from pycocotools import mask as mask_utils

        return mask_utils.frPyObjects(rle, height, width)

    return rle


def crop_covered_segments(
segments,
width,
height,
iou_threshold=0.0,
ratio_tolerance=0.001,
area_threshold=1,
return_masks=False,
):
segments: Sequence[Segment],
width: int,
height: int,
iou_threshold: float = 0.0,
ratio_tolerance: float = 0.001,
area_threshold: int = 1,
return_masks: bool = False,
) -> List[Union[Optional[BinaryMask], Polygon]]:
"""
Find all segments occluded by others and crop them to the visible part only.
Input segments are expected to be sorted from background to foreground.
Expand Down Expand Up @@ -248,13 +289,18 @@ def crop_covered_segments(
"""
from pycocotools import mask as mask_utils

segments = [[s] for s in segments]
input_rles = [mask_utils.frPyObjects(s, height, width) for s in segments]
# Convert to uncompressed RLEs
wrapped_segments = [[s] for s in segments]
input_rles = [
mask_utils.frPyObjects(s, height, width) if not is_uncompressed_rle(s[0]) else s
for s in wrapped_segments
]

output_segments = []
for i, rle_bottom in enumerate(input_rles):
area_bottom = sum(mask_utils.area(rle_bottom))
if area_bottom < area_threshold:
segments[i] = [] if not return_masks else None
output_segments.append([] if not return_masks else None)
continue

rles_top = []
Expand All @@ -268,19 +314,14 @@ def crop_covered_segments(
area_top = sum(mask_utils.area(rle_top))
area_ratio = area_top / area_bottom

# If a segment is fully inside another one, skip this segment
# If a segment is already fully inside the top ones, stop accumulating the top
if abs(area_ratio - iou) < ratio_tolerance:
continue

# Check if the bottom segment is fully covered by the top one.
# There is a mistake in the annotation, keep the background one
if abs(1 / area_ratio - iou) < ratio_tolerance:
rles_top = []
break

rles_top += rle_top

if not rles_top and not isinstance(segments[i][0], dict) and not return_masks:
if not rles_top and is_polygon(wrapped_segments[i]) and not return_masks:
output_segments.append(wrapped_segments[i])
continue

rle_bottom = rle_bottom[0]
Expand All @@ -293,15 +334,17 @@ def crop_covered_segments(
bottom_mask -= top_mask
bottom_mask[bottom_mask != 1] = 0

if not return_masks and not isinstance(segments[i][0], dict):
segments[i] = mask_to_polygons(bottom_mask, area_threshold=area_threshold)
if not return_masks and is_polygon(wrapped_segments[i]):
output_segments.append(mask_to_polygons(bottom_mask, area_threshold=area_threshold))
else:
segments[i] = bottom_mask
if np.sum(bottom_mask) < area_threshold:
bottom_mask = None
output_segments.append(bottom_mask)

return segments
return output_segments


def rles_to_mask(rles, width, height):
def rles_to_mask(rles: Sequence[Union[CompressedRle, Polygon]], width, height) -> BinaryMask:
from pycocotools import mask as mask_utils

rles = mask_utils.frPyObjects(rles, height, width)
Expand All @@ -310,15 +353,17 @@ def rles_to_mask(rles, width, height):
return mask


def find_mask_bbox(mask) -> Tuple[int, int, int, int]:
def find_mask_bbox(mask: BinaryMask) -> BboxCoords:
cols = np.any(mask, axis=0)
rows = np.any(mask, axis=1)
x0, x1 = np.where(cols)[0][[0, -1]]
y0, y1 = np.where(rows)[0][[0, -1]]
return (x0, y0, x1 - x0, y1 - y0)
return BboxCoords(x0, y0, x1 - x0, y1 - y0)


def merge_masks(masks, start=None):
def merge_masks(
masks: Sequence[Union[IndexMask, Tuple[BinaryMask, int]]], start: Optional[BinaryMask] = None
) -> IndexMask:
"""
Merges masks into one, mask order is responsible for z order.
To avoid memory explosion on mask materialization, consider passing
Expand Down
44 changes: 38 additions & 6 deletions tests/test_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
from .requirements import Requirements, mark_requirement


def _compare_polygons(a, b) -> bool:
return len(a) == len(b) and frozenset(map(frozenset, a)) == frozenset(map(frozenset, b))


class PolygonConversionsTest(TestCase):
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_mask_can_be_converted_to_polygon(self):
Expand All @@ -27,13 +31,13 @@ def test_mask_can_be_converted_to_polygon(self):

computed = mask_tools.mask_to_polygons(mask)

self.assertEqual(len(expected), len(computed))
self.assertTrue(_compare_polygons(expected, computed))

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_crop_covered_segments(self):
image_size = [7, 7]
initial = [
[1, 1, 6, 1, 6, 6, 1, 6], # rectangle
[1, 1, 6, 1, 6, 6, 1, 6], # rectangle polygon
mask_tools.mask_to_rle(
np.array(
[
Expand All @@ -46,16 +50,33 @@ def test_can_crop_covered_segments(self):
[0, 0, 0, 0, 0, 0, 0],
]
)
),
[1, 1, 6, 6, 1, 6], # lower-left triangle
), # compressed RLE
mask_tools.to_uncompressed_rle(
mask_tools.mask_to_rle(
np.array(
[
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
]
)
),
width=image_size[1],
height=image_size[0],
), # uncompressed RLE
[1, 1, 6, 6, 1, 6], # lower-left triangle polygon
]
expected = [
np.array(
[
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
Expand All @@ -72,7 +93,18 @@ def test_can_crop_covered_segments(self):
[0, 0, 0, 0, 0, 0, 0],
]
), # half-covered
mask_tools.rles_to_mask([initial[2]], *image_size), # unchanged
np.array(
[
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
]
), # unchanged
mask_tools.rles_to_mask([initial[3]], *image_size), # unchanged
]

computed = mask_tools.crop_covered_segments(
Expand Down
Loading

0 comments on commit 26bd789

Please sign in to comment.