Skip to content

Commit

Permalink
remove batch_dims from make bounding boxes and detection masks (#7855)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Aug 18, 2023
1 parent 59b27ed commit a7501e1
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,26 +406,21 @@ def make_bounding_boxes(
canvas_size=DEFAULT_SIZE,
*,
format=datapoints.BoundingBoxFormat.XYXY,
batch_dims=(),
dtype=None,
device="cpu",
):
def sample_position(values, max_value):
# We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
# However, if we have batch_dims, we need tensors as limits.
return torch.stack([torch.randint(max_value - v, ()) for v in values.flatten().tolist()]).reshape(values.shape)
return torch.stack([torch.randint(max_value - v, ()) for v in values.tolist()])

if isinstance(format, str):
format = datapoints.BoundingBoxFormat[format]

dtype = dtype or torch.float32

if any(dim == 0 for dim in batch_dims):
return datapoints.BoundingBoxes(
torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, canvas_size=canvas_size
)

h, w = [torch.randint(1, c, batch_dims) for c in canvas_size]
num_objects = 1
h, w = [torch.randint(1, c, (num_objects,)) for c in canvas_size]
y = sample_position(h, canvas_size[0])
x = sample_position(w, canvas_size[1])

Expand All @@ -448,11 +443,12 @@ def sample_position(values, max_value):
)


def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtype=None, device="cpu"):
def make_detection_mask(size=DEFAULT_SIZE, *, dtype=None, device="cpu"):
"""Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
num_objects = 1
return datapoints.Mask(
torch.testing.make_tensor(
(*batch_dims, num_objects, *size),
(num_objects, *size),
low=0,
high=2,
dtype=dtype or torch.bool,
Expand Down

0 comments on commit a7501e1

Please sign in to comment.