Commit 2b2d10a ("format")

NicolasHug committed Jul 31, 2023
Parent: 124db1f

Showing 3 changed files with 9 additions and 3 deletions.
4 changes: 3 additions & 1 deletion torchvision/prototype/transforms/_geometry.py
@@ -115,7 +115,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             elif isinstance(inpt, datapoints.BoundingBoxes):
                 inpt = datapoints.BoundingBoxes.wrap_like(
                     inpt,
-                    F.clamp_bounding_boxes(inpt[params["is_valid"]], format=inpt.format, spatial_size=inpt.spatial_size),
+                    F.clamp_bounding_boxes(
+                        inpt[params["is_valid"]], format=inpt.format, spatial_size=inpt.spatial_size
+                    ),
                 )
 
         if params["needs_pad"]:
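For context, a minimal sketch of the pattern this hunk reformats: a BoundingBoxes datapoint is filtered with a boolean mask (which yields a plain tensor), clamped, and re-wrapped so the format and spatial_size metadata are preserved. The imports, the alias F, and the sample values below are assumptions for illustration, not part of the commit:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F  # assumed alias for this file's F

    boxes = datapoints.BoundingBoxes(
        torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 200.0, 200.0]]),
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(100, 100),
    )
    is_valid = torch.tensor([True, False])

    # Boolean indexing of a datapoint yields a plain tensor, so clamp_bounding_boxes
    # needs format/spatial_size passed explicitly; wrap_like then re-attaches the metadata.
    boxes = datapoints.BoundingBoxes.wrap_like(
        boxes,
        F.clamp_bounding_boxes(boxes[is_valid], format=boxes.format, spatial_size=boxes.spatial_size),
    )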
4 changes: 3 additions & 1 deletion torchvision/transforms/v2/functional/_geometry.py
@@ -1826,7 +1826,9 @@ def center_crop_bounding_boxes(
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
     crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *spatial_size)
-    return crop_bounding_boxes(bounding_boxes, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width)
+    return crop_bounding_boxes(
+        bounding_boxes, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width
+    )
 
 
 def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
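As a usage note, the reformatted kernel works on plain tensors plus explicit metadata and returns both the adjusted boxes and the spatial size of the crop. This is only a sketch; the import path and the sample values are assumptions:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2.functional import center_crop_bounding_boxes

    # One XYXY box on a 100x200 (height, width) image; the numbers are made up.
    boxes = torch.tensor([[10.0, 20.0, 60.0, 80.0]])

    out_boxes, out_spatial_size = center_crop_bounding_boxes(
        boxes,
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(100, 200),
        output_size=[50, 50],
    )
    # out_spatial_size is the size of the crop; the boxes are shifted by the crop anchor.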
4 changes: 3 additions & 1 deletion torchvision/transforms/v2/functional/_meta.py
@@ -272,7 +272,9 @@ def clamp_bounding_boxes(
     elif isinstance(inpt, datapoints.BoundingBoxes):
         if format is not None or spatial_size is not None:
             raise ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.")
-        output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size)
+        output = _clamp_bounding_boxes(
+            inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size
+        )
         return datapoints.BoundingBoxes.wrap_like(inpt, output)
     else:
         raise TypeError(
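The reformatted call sits in the datapoint branch of clamp_bounding_boxes, which the surrounding code dispatches in two ways. A minimal sketch of both call styles (the import alias and sample values are assumptions):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    boxes = datapoints.BoundingBoxes(
        torch.tensor([[-5.0, 10.0, 120.0, 90.0]]),
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(100, 100),
    )

    # Datapoint input: format and spatial_size are read from the datapoint itself;
    # passing them explicitly raises the ValueError shown in the branch above.
    clamped = F.clamp_bounding_boxes(boxes)

    # Plain-tensor input: the metadata must be supplied explicitly instead.
    clamped_tensor = F.clamp_bounding_boxes(
        boxes.as_subclass(torch.Tensor),
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(100, 100),
    )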
