Skip to content

Commit

Permalink
Update object matching algorithm (#30)
Browse files Browse the repository at this point in the history
* Change segment matching strategy
  • Loading branch information
zhiltsov-max committed Nov 20, 2023
1 parent 6b65cf8 commit 91c7499
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 49 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/cvat-ai/datumaro/pull/5>)
- item id in MOT format
(<https://github.com/cvat-ai/datumaro/pull/17>)
- Annotation matching algorithm in `datumaro.components.operations.match_segments()`
(<https://github.com/cvat-ai/datumaro/pull/30>)

### Deprecated
- `--save-images` is replaced with `--save-media` in CLI and converter API
Expand Down
129 changes: 81 additions & 48 deletions datumaro/components/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,31 @@
# SPDX-License-Identifier: MIT

import hashlib
import itertools
import logging as log
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
)
from unittest import TestCase

import attr
import cv2
import numpy as np
from attr import attrib, attrs
from scipy.optimize import linear_sum_assignment

from datumaro.components.annotation import (
Annotation,
Expand Down Expand Up @@ -1241,65 +1256,83 @@ class ImageAnnotationMerger(AnnotationMerger, ImageAnnotationMatcher):
pass


_AT1 = TypeVar("_AT1")
_AT2 = TypeVar("_AT2")


def match_segments(
a_segms,
b_segms,
distance=segment_iou,
dist_thresh=1.0,
label_matcher=lambda a, b: a.label == b.label,
):
a_segms: Sequence[_AT1],
b_segms: Sequence[_AT2],
distance: Callable[[_AT1, _AT2], float] = segment_iou,
dist_thresh: float = 1.0,
label_matcher: Callable[[_AT1, _AT2], bool] = lambda a, b: a.label == b.label,
) -> Tuple[List[Tuple[_AT1, _AT2]], List[Tuple[_AT1, _AT2]], List[_AT1], List[_AT2]]:
"""
Finds the best matching annotations using the provided distance function.
If the annotations match by distance, but have different labels,
they are considered mismatching.
Parameters:
- distance: func(a_ann, b_ann) -> float [0; 1] - a function that estimates annotation
similarity, with 0 meaning 'not similar' and 1 - 'exactly the same'.
- dist_thresh: a value in the range [0; 1], the minimal similarity (as returned by 'distance')
a pair of annotations must reach to be considered for matching
Returns (matching, mismatching, a_unmatched, b_unmatched), where:
- 'matching' and 'mismatching' - lists of (a_ann, b_ann) tuples
- 'a_unmatched' and 'b_unmatched' - lists of corresponding unmatched annotations
"""

assert callable(distance), distance
assert callable(label_matcher), label_matcher

a_segms.sort(key=lambda ann: 1 - ann.attributes.get("score", 1))
b_segms.sort(key=lambda ann: 1 - ann.attributes.get("score", 1))

# a_matches: indices of b_segms matched to a bboxes
# b_matches: indices of a_segms matched to b bboxes
a_matches = -np.ones(len(a_segms), dtype=int)
b_matches = -np.ones(len(b_segms), dtype=int)
max_anns = max(len(a_segms), len(b_segms))
distances = np.array(
[
[
1 - distance(a, b) if a is not None and b is not None else 1
for b, _ in itertools.zip_longest(b_segms, range(max_anns), fillvalue=None)
]
for a, _ in itertools.zip_longest(a_segms, range(max_anns), fillvalue=None)
]
)
distances[distances > 1 - dist_thresh] = 1

distances = np.array([[distance(a, b) for b in b_segms] for a in a_segms])
if a_segms and b_segms:
a_matches, b_matches = linear_sum_assignment(distances)
else:
a_matches = []
b_matches = []

# matches: boxes we succeeded to match completely
# mismatches (named 'mispred' before this change): boxes we succeeded to match, having label mismatch
matches = []
mispred = []

for a_idx, a_segm in enumerate(a_segms):
if len(b_segms) == 0:
break
matched_b = -1
max_dist = -1
b_indices = np.argsort(
[not label_matcher(a_segm, b_segm) for b_segm in b_segms], kind="stable"
) # prioritize those with same label, keep score order
for b_idx in b_indices:
if 0 <= b_matches[b_idx]: # assign a_segm with max conf
continue
d = distances[a_idx, b_idx]
if d < dist_thresh or d <= max_dist:
continue
max_dist = d
matched_b = b_idx

if matched_b < 0:
continue
a_matches[a_idx] = matched_b
b_matches[matched_b] = a_idx

b_segm = b_segms[matched_b]

if label_matcher(a_segm, b_segm):
matches.append((a_segm, b_segm))
mismatches = []
# *_unmatched: boxes of (*) we failed to match
a_unmatched = []
b_unmatched = []

for a_idx, b_idx in zip(a_matches, b_matches):
dist = distances[a_idx, b_idx]
if dist > 1 - dist_thresh or dist == 1:
if a_idx < len(a_segms):
a_unmatched.append(a_segms[a_idx])
if b_idx < len(b_segms):
b_unmatched.append(b_segms[b_idx])
else:
mispred.append((a_segm, b_segm))
a_ann = a_segms[a_idx]
b_ann = b_segms[b_idx]
if label_matcher(a_ann, b_ann):
matches.append((a_ann, b_ann))
else:
mismatches.append((a_ann, b_ann))

# *_umatched: boxes of (*) we failed to match
a_unmatched = [a_segms[i] for i, m in enumerate(a_matches) if m < 0]
b_unmatched = [b_segms[i] for i, m in enumerate(b_matches) if m < 0]
if not len(a_matches) and not len(b_matches):
a_unmatched = list(a_segms)
b_unmatched = list(b_segms)

return matches, mispred, a_unmatched, b_unmatched
return matches, mismatches, a_unmatched, b_unmatched


def mean_std(dataset: IDataset):
Expand Down
2 changes: 1 addition & 1 deletion requirements-core.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ tensorboardX>=1.8,!=2.3

# Builtin plugin dependencies

# NDR, matlab format
# NDR, matlab format, assignment
scipy

# Image generator
Expand Down
66 changes: 66 additions & 0 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
compute_ann_statistics,
compute_image_statistics,
find_unique_images,
match_segments,
mean_std,
)
from datumaro.util.test_utils import compare_datasets
Expand Down Expand Up @@ -398,6 +399,71 @@ def test_unique_image_count(self):
self.assertEqual(expected, set(frozenset(s) for s in groups.values()))


class TestAnnotationMatching(TestCase):
    """Checks how match_segments() pairs up two lists of bounding boxes."""

    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
    def test_can_match_shape_first_and_label_later(self):
        # With the default matchers, boxes are expected to be paired by shape
        # first, so two boxes with identical geometry but different labels are
        # reported as a label mismatch, even though a box with the same label
        # is available nearby. This gives more adequate results for use cases
        # such as annotation tools.
        # For model outputs we typically work with annotations after NMS,
        # so this approach yields adequate results there as well.
        # Without NMS, a different strategy could be considered, one that
        # also takes labels into account when looking for the best match.

        def first_ann_id(pair):
            return pair[0].id

        source_anns = [
            Bbox(0, 0, 4, 4, label=0, id=1),
            Bbox(1, 1, 4, 4, label=1, id=2),
        ]
        target_anns = [
            Bbox(1, 1, 4, 4, label=0, id=2),
            Bbox(0, 0, 4, 4, label=1, id=1),
        ]

        matched, mismatched, source_extra, target_extra = match_segments(
            source_anns, target_anns, dist_thresh=0.5
        )

        # Each source box is paired with the target box of identical geometry.
        expected_mismatched = [
            (source_anns[src], target_anns[dst]) for src, dst in ((0, 1), (1, 0))
        ]
        assert sorted(mismatched, key=first_ann_id) == expected_mismatched
        assert not matched + source_extra + target_extra

    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
    def test_can_match(self):
        def first_ann_id(pair):
            return pair[0].id

        source_anns = [
            # same shapes as in the target list, but labels differ
            Bbox(0, 0, 4, 4, label=0, id=1),
            Bbox(1, 1, 4, 4, label=1, id=2),
            # same shapes and labels as in the target list
            Bbox(5, 5, 4, 4, label=0, id=3),
            Bbox(6, 6, 4, 4, label=1, id=4),
            # has no counterpart in the target list
            Bbox(6, 0, 4, 4, label=1, id=5),
        ]
        target_anns = [
            # same shapes as in the source list, but labels differ
            Bbox(1, 1, 4, 4, label=0, id=2),
            Bbox(0, 0, 4, 4, label=1, id=1),
            # same shapes and labels as in the source list
            Bbox(5, 5, 4, 4, label=0, id=3),
            Bbox(6, 6, 4, 4, label=1, id=4),
            # has no counterpart in the source list
            Bbox(0, 6, 4, 4, label=1, id=5),
        ]

        matched, mismatched, source_extra, target_extra = match_segments(
            source_anns, target_anns, dist_thresh=0.5
        )

        expected_mismatched = [
            (source_anns[src], target_anns[dst]) for src, dst in ((0, 1), (1, 0))
        ]
        expected_matched = [
            (source_anns[src], target_anns[dst]) for src, dst in ((2, 2), (3, 3))
        ]

        assert sorted(mismatched, key=first_ann_id) == expected_mismatched
        assert sorted(matched, key=first_ann_id) == expected_matched
        assert source_extra == [source_anns[4]]
        assert target_extra == [target_anns[4]]


class TestMultimerge(TestCase):
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_match_items(self):
Expand Down

0 comments on commit 91c7499

Please sign in to comment.