diff --git a/test/common_utils.py b/test/common_utils.py
index c815786b586..61f06994801 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -406,26 +406,21 @@ def make_bounding_boxes(
     canvas_size=DEFAULT_SIZE,
     *,
     format=datapoints.BoundingBoxFormat.XYXY,
-    batch_dims=(),
     dtype=None,
     device="cpu",
 ):
     def sample_position(values, max_value):
         # We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
         # However, if we have batch_dims, we need tensors as limits.
-        return torch.stack([torch.randint(max_value - v, ()) for v in values.flatten().tolist()]).reshape(values.shape)
+        return torch.stack([torch.randint(max_value - v, ()) for v in values.tolist()])

     if isinstance(format, str):
         format = datapoints.BoundingBoxFormat[format]

     dtype = dtype or torch.float32

-    if any(dim == 0 for dim in batch_dims):
-        return datapoints.BoundingBoxes(
-            torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, canvas_size=canvas_size
-        )
-
-    h, w = [torch.randint(1, c, batch_dims) for c in canvas_size]
+    num_objects = 1
+    h, w = [torch.randint(1, c, (num_objects,)) for c in canvas_size]
     y = sample_position(h, canvas_size[0])
     x = sample_position(w, canvas_size[1])

@@ -448,11 +443,12 @@ def sample_position(values, max_value):
     )


-def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtype=None, device="cpu"):
+def make_detection_mask(size=DEFAULT_SIZE, *, dtype=None, device="cpu"):
     """Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
+    num_objects = 1
     return datapoints.Mask(
         torch.testing.make_tensor(
-            (*batch_dims, num_objects, *size),
+            (num_objects, *size),
             low=0,
             high=2,
             dtype=dtype or torch.bool,
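
After this change, the bounding-box helper always samples a single box per call instead of a batch. A minimal sketch of that sampling logic, using plain `torch` only (the `canvas_size` value and the final XYXY stacking here are illustrative assumptions, not the full `make_bounding_boxes` implementation):

```python
import torch

canvas_size = (32, 32)  # (H, W); stand-in for DEFAULT_SIZE
num_objects = 1

# Sample one height/width per object, each strictly smaller than the canvas.
h, w = [torch.randint(1, c, (num_objects,)) for c in canvas_size]

# torch.randint only accepts integer scalars for low/high, so sample the
# top-left corner element-wise to keep each box inside the canvas.
y = torch.stack([torch.randint(canvas_size[0] - v, ()) for v in h.tolist()])
x = torch.stack([torch.randint(canvas_size[1] - v, ()) for v in w.tolist()])

boxes_xyxy = torch.stack([x, y, x + w, y + h], dim=-1).to(torch.float32)
print(boxes_xyxy.shape)  # torch.Size([1, 4])
```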