-
Notifications
You must be signed in to change notification settings - Fork 6.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add --backend and --use-v2 support for segmentation references (#7743)
- Loading branch information
1 parent
8233c9c
commit b9b7cfc
Showing
7 changed files
with
220 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,39 +1,106 @@ | ||
from collections import defaultdict | ||
|
||
import torch | ||
import transforms as T | ||
|
||
|
||
def get_modules(use_v2): | ||
# We need a protected import to avoid the V2 warning in case just V1 is used | ||
if use_v2: | ||
import torchvision.datapoints | ||
import torchvision.transforms.v2 | ||
import v2_extras | ||
|
||
return torchvision.transforms.v2, torchvision.datapoints, v2_extras | ||
else: | ||
import transforms | ||
|
||
return transforms, None, None | ||
|
||
|
||
class SegmentationPresetTrain: | ||
def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): | ||
min_size = int(0.5 * base_size) | ||
max_size = int(2.0 * base_size) | ||
def __init__( | ||
self, | ||
*, | ||
base_size, | ||
crop_size, | ||
hflip_prob=0.5, | ||
mean=(0.485, 0.456, 0.406), | ||
std=(0.229, 0.224, 0.225), | ||
backend="pil", | ||
use_v2=False, | ||
): | ||
T, datapoints, v2_extras = get_modules(use_v2) | ||
|
||
transforms = [] | ||
backend = backend.lower() | ||
if backend == "datapoint": | ||
transforms.append(T.ToImageTensor()) | ||
elif backend == "tensor": | ||
transforms.append(T.PILToTensor()) | ||
elif backend != "pil": | ||
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") | ||
|
||
transforms += [T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size))] | ||
|
||
trans = [T.RandomResize(min_size, max_size)] | ||
if hflip_prob > 0: | ||
trans.append(T.RandomHorizontalFlip(hflip_prob)) | ||
trans.extend( | ||
[ | ||
T.RandomCrop(crop_size), | ||
T.PILToTensor(), | ||
T.ConvertImageDtype(torch.float), | ||
T.Normalize(mean=mean, std=std), | ||
transforms += [T.RandomHorizontalFlip(hflip_prob)] | ||
|
||
if use_v2: | ||
# We need a custom pad transform here, since the padding we want to perform here is fundamentally | ||
# different from the padding in `RandomCrop` if `pad_if_needed=True`. | ||
transforms += [v2_extras.PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))] | ||
|
||
transforms += [T.RandomCrop(crop_size)] | ||
|
||
if backend == "pil": | ||
transforms += [T.PILToTensor()] | ||
|
||
if use_v2: | ||
img_type = datapoints.Image if backend == "datapoint" else torch.Tensor | ||
transforms += [ | ||
T.ToDtype(dtype={img_type: torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True) | ||
] | ||
) | ||
self.transforms = T.Compose(trans) | ||
else: | ||
# No need to explicitly convert masks as they're magically int64 already | ||
transforms += [T.ConvertImageDtype(torch.float)] | ||
|
||
transforms += [T.Normalize(mean=mean, std=std)] | ||
|
||
self.transforms = T.Compose(transforms) | ||
|
||
def __call__(self, img, target): | ||
return self.transforms(img, target) | ||
|
||
|
||
class SegmentationPresetEval: | ||
def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): | ||
self.transforms = T.Compose( | ||
[ | ||
T.RandomResize(base_size, base_size), | ||
T.PILToTensor(), | ||
T.ConvertImageDtype(torch.float), | ||
T.Normalize(mean=mean, std=std), | ||
] | ||
) | ||
def __init__( | ||
self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), backend="pil", use_v2=False | ||
): | ||
T, _, _ = get_modules(use_v2) | ||
|
||
transforms = [] | ||
backend = backend.lower() | ||
if backend == "tensor": | ||
transforms += [T.PILToTensor()] | ||
elif backend == "datapoint": | ||
transforms += [T.ToImageTensor()] | ||
elif backend != "pil": | ||
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") | ||
|
||
if use_v2: | ||
transforms += [T.Resize(size=(base_size, base_size))] | ||
else: | ||
transforms += [T.RandomResize(min_size=base_size, max_size=base_size)] | ||
|
||
if backend == "pil": | ||
# Note: we could just convert to pure tensors even in v2? | ||
transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()] | ||
|
||
transforms += [ | ||
T.ConvertImageDtype(torch.float), | ||
T.Normalize(mean=mean, std=std), | ||
] | ||
self.transforms = T.Compose(transforms) | ||
|
||
def __call__(self, img, target): | ||
return self.transforms(img, target) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
"""This file only exists to be lazy-imported and avoid V2-related import warnings when just using V1.""" | ||
import torch | ||
from torchvision import datapoints | ||
from torchvision.transforms import v2 | ||
|
||
|
||
class PadIfSmaller(v2.Transform): | ||
def __init__(self, size, fill=0): | ||
super().__init__() | ||
self.size = size | ||
self.fill = v2._geometry._setup_fill_arg(fill) | ||
|
||
def _get_params(self, sample): | ||
_, height, width = v2.utils.query_chw(sample) | ||
padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)] | ||
needs_padding = any(padding) | ||
return dict(padding=padding, needs_padding=needs_padding) | ||
|
||
def _transform(self, inpt, params): | ||
if not params["needs_padding"]: | ||
return inpt | ||
|
||
fill = self.fill[type(inpt)] | ||
fill = v2._utils._convert_fill_arg(fill) | ||
|
||
return v2.functional.pad(inpt, padding=params["padding"], fill=fill) | ||
|
||
|
||
class CocoDetectionToVOCSegmentation(v2.Transform): | ||
"""Turn samples from datasets.CocoDetection into the same format as VOCSegmentation. | ||
This is achieved in two steps: | ||
1. COCO differentiates between 91 categories while VOC only supports 21, including background for both. Fortunately, | ||
the COCO categories are a superset of the VOC ones and thus can be mapped. Instances of the 70 categories not | ||
present in VOC are dropped and replaced by background. | ||
2. COCO only offers detection masks, i.e. a (N, H, W) bool-ish tensor, where the truthy values in each individual | ||
mask denote the instance. However, a segmentation mask is a (H, W) integer tensor (typically torch.uint8), where | ||
the value of each pixel denotes the category it belongs to. The detection masks are merged into one segmentation | ||
mask while pixels that belong to multiple detection masks are marked as invalid. | ||
""" | ||
|
||
COCO_TO_VOC_LABEL_MAP = dict( | ||
zip( | ||
[0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72], | ||
range(21), | ||
) | ||
) | ||
INVALID_VALUE = 255 | ||
|
||
def _coco_detection_masks_to_voc_segmentation_mask(self, target): | ||
if "masks" not in target: | ||
return None | ||
|
||
instance_masks, instance_labels_coco = target["masks"], target["labels"] | ||
|
||
valid_labels_voc = [ | ||
(idx, label_voc) | ||
for idx, label_coco in enumerate(instance_labels_coco.tolist()) | ||
if (label_voc := self.COCO_TO_VOC_LABEL_MAP.get(label_coco)) is not None | ||
] | ||
|
||
if not valid_labels_voc: | ||
return None | ||
|
||
valid_voc_category_idcs, instance_labels_voc = zip(*valid_labels_voc) | ||
|
||
instance_masks = instance_masks[list(valid_voc_category_idcs)].to(torch.uint8) | ||
instance_labels_voc = torch.tensor(instance_labels_voc, dtype=torch.uint8) | ||
|
||
# Calling `.max()` on the stacked detection masks works fine to separate background from foreground as long as | ||
# there is at most a single instance per pixel. Overlapping instances will be filtered out in the next step. | ||
segmentation_mask, _ = (instance_masks * instance_labels_voc.reshape(-1, 1, 1)).max(dim=0) | ||
segmentation_mask[instance_masks.sum(dim=0) > 1] = self.INVALID_VALUE | ||
|
||
return segmentation_mask | ||
|
||
def forward(self, image, target): | ||
segmentation_mask = self._coco_detection_masks_to_voc_segmentation_mask(target) | ||
if segmentation_mask is None: | ||
segmentation_mask = torch.zeros(v2.functional.get_spatial_size(image), dtype=torch.uint8) | ||
|
||
return image, datapoints.Mask(segmentation_mask) |