Add --backend and --use-v2 support for segmentation references (#7743)
NicolasHug authored Jul 27, 2023
1 parent 8233c9c commit b9b7cfc
Showing 7 changed files with 220 additions and 49 deletions.
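The new flags flow from train.py's argument parser into the presets. As a rough illustration (not part of the diff; the launch command and data path are assumptions), the segmentation reference can now be run with a non-PIL backend and the v2 transforms:

    # Assumed invocation, for illustration only:
    #   torchrun --nproc_per_node=8 train.py --dataset coco --data-path /data/coco --backend datapoint --use-v2
    from presets import SegmentationPresetTrain, SegmentationPresetEval

    # --backend selects the input type handed to the transforms ("pil", "tensor" or "datapoint"),
    # --use-v2 switches from the v1 reference transforms to torchvision.transforms.v2.
    train_tf = SegmentationPresetTrain(base_size=520, crop_size=480, backend="datapoint", use_v2=True)
    val_tf = SegmentationPresetEval(base_size=520, backend="datapoint", use_v2=True)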
3 changes: 2 additions & 1 deletion references/detection/coco_utils.py
@@ -6,7 +6,6 @@
import transforms as T
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO
from torchvision.datasets import wrap_dataset_for_transforms_v2


def convert_coco_poly_to_mask(segmentations, height, width):
@@ -213,6 +212,8 @@ def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_m
ann_file = os.path.join(root, ann_file)

if use_v2:
from torchvision.datasets import wrap_dataset_for_transforms_v2

dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
target_keys = ["boxes", "labels", "image_id"]
if with_masks:
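Moving the wrap_dataset_for_transforms_v2 import under "if use_v2:" keeps the v2-related import warning away from users who stay on v1 (the same motivation as the get_modules helper in presets.py below). A minimal sketch of the deferred-import pattern, with an illustrative function name:

    import torchvision


    def build_coco_detection(img_folder, ann_file, transforms, use_v2=False):
        # Importing this module alone never touches the v2 namespace, so the v2
        # warning is only emitted for users who actually opt in.
        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
        if use_v2:
            from torchvision.datasets import wrap_dataset_for_transforms_v2

            dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=["boxes", "labels", "image_id"])
        return dataset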
24 changes: 15 additions & 9 deletions references/segmentation/coco_utils.py
@@ -68,11 +68,6 @@ def _has_valid_annotation(anno):
# if more than 1k pixels occupied in the image
return sum(obj["area"] for obj in anno) > 1000

if not isinstance(dataset, torchvision.datasets.CocoDetection):
raise TypeError(
f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
)

ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
@@ -86,21 +81,32 @@ def _has_valid_annotation(anno):
return dataset


def get_coco(root, image_set, transforms):
def get_coco(root, image_set, transforms, use_v2=False):
PATHS = {
"train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
"val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
# "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
}
CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])

img_folder, ann_file = PATHS[image_set]
img_folder = os.path.join(root, img_folder)
ann_file = os.path.join(root, ann_file)

dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
# The 2 "Compose" below achieve the same thing: converting coco detection
# samples into segmentation-compatible samples. They just do it with
# slightly different implementations. We could refactor and unify, but
# keeping them separate helps keeping the v2 version clean
if use_v2:
import v2_extras
from torchvision.datasets import wrap_dataset_for_transforms_v2

transforms = Compose([v2_extras.CocoDetectionToVOCSegmentation(), transforms])
dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"masks", "labels"})
else:
transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])
dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)

if image_set == "train":
dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)
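Both branches produce (image, segmentation mask) samples; only the conversion machinery differs. A hedged usage sketch of the updated helper (the data path is a placeholder):

    from coco_utils import get_coco
    from presets import SegmentationPresetTrain

    # With use_v2=True, CocoDetectionToVOCSegmentation plus wrap_dataset_for_transforms_v2
    # turn COCO detection annotations into VOC-style (H, W) masks: labels in [0, 20],
    # 255 for ambiguous or ignored pixels.
    transform = SegmentationPresetTrain(base_size=520, crop_size=480, use_v2=True)
    dataset = get_coco("/data/coco", image_set="train", transforms=transform, use_v2=True)
    image, mask = dataset[0]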
113 changes: 90 additions & 23 deletions references/segmentation/presets.py
@@ -1,39 +1,106 @@
from collections import defaultdict

import torch
import transforms as T


def get_modules(use_v2):
# We need a protected (lazy) import to avoid the V2 warning when only V1 is used
if use_v2:
import torchvision.datapoints
import torchvision.transforms.v2
import v2_extras

return torchvision.transforms.v2, torchvision.datapoints, v2_extras
else:
import transforms

return transforms, None, None


class SegmentationPresetTrain:
def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
min_size = int(0.5 * base_size)
max_size = int(2.0 * base_size)
def __init__(
self,
*,
base_size,
crop_size,
hflip_prob=0.5,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
backend="pil",
use_v2=False,
):
T, datapoints, v2_extras = get_modules(use_v2)

transforms = []
backend = backend.lower()
if backend == "datapoint":
transforms.append(T.ToImageTensor())
elif backend == "tensor":
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")

transforms += [T.RandomResize(min_size=int(0.5 * base_size), max_size=int(2.0 * base_size))]

trans = [T.RandomResize(min_size, max_size)]
if hflip_prob > 0:
trans.append(T.RandomHorizontalFlip(hflip_prob))
trans.extend(
[
T.RandomCrop(crop_size),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
transforms += [T.RandomHorizontalFlip(hflip_prob)]

if use_v2:
# We need a custom pad transform here, since the padding we want to perform here is fundamentally
# different from the padding in `RandomCrop` if `pad_if_needed=True`.
transforms += [v2_extras.PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))]

transforms += [T.RandomCrop(crop_size)]

if backend == "pil":
transforms += [T.PILToTensor()]

if use_v2:
img_type = datapoints.Image if backend == "datapoint" else torch.Tensor
transforms += [
T.ToDtype(dtype={img_type: torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True)
]
)
self.transforms = T.Compose(trans)
else:
# No need to explicitly convert masks as they're magically int64 already
transforms += [T.ConvertImageDtype(torch.float)]

transforms += [T.Normalize(mean=mean, std=std)]

self.transforms = T.Compose(transforms)

def __call__(self, img, target):
return self.transforms(img, target)


class SegmentationPresetEval:
def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
self.transforms = T.Compose(
[
T.RandomResize(base_size, base_size),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
)
def __init__(
self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), backend="pil", use_v2=False
):
T, _, _ = get_modules(use_v2)

transforms = []
backend = backend.lower()
if backend == "tensor":
transforms += [T.PILToTensor()]
elif backend == "datapoint":
transforms += [T.ToImageTensor()]
elif backend != "pil":
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")

if use_v2:
transforms += [T.Resize(size=(base_size, base_size))]
else:
transforms += [T.RandomResize(min_size=base_size, max_size=base_size)]

if backend == "pil":
# Note: we could just convert to pure tensors even in v2?
transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]

transforms += [
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
self.transforms = T.Compose(transforms)

def __call__(self, img, target):
return self.transforms(img, target)
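The dictionaries handed to ToDtype and to the PadIfSmaller fill follow the v2 convention of keying behaviour on the input type. A small standalone sketch, assuming the pre-release datapoints API that this commit targets:

    from collections import defaultdict

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2

    # Images become float32 and are rescaled to [0, 1]; masks keep integer class labels;
    # any other input is left untouched.
    to_dtype = v2.ToDtype(
        dtype={datapoints.Image: torch.float32, datapoints.Mask: torch.int64, "others": None},
        scale=True,
    )

    # Pad images with 0 but pad masks with 255 (the "ignore" label), so padded pixels
    # never contribute to the loss.
    mask_aware_fill = defaultdict(lambda: 0, {datapoints.Mask: 255})

    img = datapoints.Image(torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8))
    mask = datapoints.Mask(torch.randint(0, 21, (4, 4), dtype=torch.uint8))
    img, mask = to_dtype(img, mask)  # img: float32 in [0, 1]; mask: int64 labels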
38 changes: 26 additions & 12 deletions references/segmentation/train.py
@@ -14,24 +14,30 @@
from torchvision.transforms import functional as F, InterpolationMode


def get_dataset(dir_path, name, image_set, transform):
def get_dataset(args, is_train):
def sbd(*args, **kwargs):
kwargs.pop("use_v2")
return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)

def voc(*args, **kwargs):
kwargs.pop("use_v2")
return torchvision.datasets.VOCSegmentation(*args, **kwargs)

paths = {
"voc": (dir_path, torchvision.datasets.VOCSegmentation, 21),
"voc_aug": (dir_path, sbd, 21),
"coco": (dir_path, get_coco, 21),
"voc": (args.data_path, voc, 21),
"voc_aug": (args.data_path, sbd, 21),
"coco": (args.data_path, get_coco, 21),
}
p, ds_fn, num_classes = paths[name]
p, ds_fn, num_classes = paths[args.dataset]

ds = ds_fn(p, image_set=image_set, transforms=transform)
image_set = "train" if is_train else "val"
ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2)
return ds, num_classes


def get_transform(train, args):
if train:
return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
def get_transform(is_train, args):
if is_train:
return presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend=args.backend, use_v2=args.use_v2)
elif args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
trans = weights.transforms()
@@ -44,7 +50,7 @@ def preprocessing(img, target):

return preprocessing
else:
return presets.SegmentationPresetEval(base_size=520)
return presets.SegmentationPresetEval(base_size=520, backend=args.backend, use_v2=args.use_v2)


def criterion(inputs, target):
@@ -120,6 +126,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi


def main(args):
if args.backend.lower() != "pil" and not args.use_v2:
# TODO: Support tensor backend in V1?
raise ValueError("Use --use-v2 if you want to use the datapoint or tensor backend.")
if args.use_v2 and args.dataset != "coco":
raise ValueError("v2 is only support supported for coco dataset for now.")

if args.output_dir:
utils.mkdir(args.output_dir)

@@ -134,8 +146,8 @@ def main(args):
else:
torch.backends.cudnn.benchmark = True

dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args))
dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args))
dataset, num_classes = get_dataset(args, is_train=True)
dataset_test, _ = get_dataset(args, is_train=False)

if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
@@ -307,6 +319,8 @@ def get_args_parser(add_help=True):
# Mixed precision training parameters
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
return parser


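The small voc and sbd wrappers in get_dataset exist only to swallow the use_v2 keyword that those dataset constructors do not accept, so every entry in paths can be called with one uniform signature. The pattern generalises as in this sketch (the helper name is hypothetical):

    import torchvision


    def _ignore_kwargs(ds_cls, ignored=("use_v2",), **fixed):
        # Wrap a dataset constructor so that keyword arguments it does not understand
        # are dropped instead of raising a TypeError.
        def builder(*args, **kwargs):
            for key in ignored:
                kwargs.pop(key, None)
            return ds_cls(*args, **fixed, **kwargs)

        return builder


    voc = _ignore_kwargs(torchvision.datasets.VOCSegmentation)
    sbd = _ignore_kwargs(torchvision.datasets.SBDataset, mode="segmentation")

Note also that --backend is parsed with type=str.lower, which is why the presets can compare against lowercase strings no matter how the flag was typed on the command line.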
2 changes: 1 addition & 1 deletion references/segmentation/transforms.py
@@ -35,7 +35,7 @@ def __init__(self, min_size, max_size=None):

def __call__(self, image, target):
size = random.randint(self.min_size, self.max_size)
image = F.resize(image, size)
image = F.resize(image, size, antialias=True)
target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
return image, target

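The image is resized with antialiasing, while the target keeps nearest-neighbour interpolation so that class labels are never blended into values that do not exist in the label set. A minimal standalone illustration:

    import torch
    from torchvision.transforms import functional as F, InterpolationMode

    image = torch.rand(3, 100, 100)
    mask = torch.randint(0, 21, (1, 100, 100), dtype=torch.uint8)

    image = F.resize(image, 60, antialias=True)  # smooth, antialiased downscaling for the image
    mask = F.resize(mask, 60, interpolation=InterpolationMode.NEAREST)  # labels stay exact integers
    assert set(mask.unique().tolist()) <= set(range(21))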
6 changes: 3 additions & 3 deletions references/segmentation/utils.py
@@ -267,9 +267,9 @@ def init_distributed_mode(args):
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
# elif "SLURM_PROCID" in os.environ:
# args.rank = int(os.environ["SLURM_PROCID"])
# args.gpu = args.rank % torch.cuda.device_count()
elif hasattr(args, "rank"):
pass
else:
83 changes: 83 additions & 0 deletions references/segmentation/v2_extras.py
@@ -0,0 +1,83 @@
"""This file only exists to be lazy-imported and avoid V2-related import warnings when just using V1."""
import torch
from torchvision import datapoints
from torchvision.transforms import v2


class PadIfSmaller(v2.Transform):
def __init__(self, size, fill=0):
super().__init__()
self.size = size
self.fill = v2._geometry._setup_fill_arg(fill)

def _get_params(self, sample):
_, height, width = v2.utils.query_chw(sample)
padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
needs_padding = any(padding)
return dict(padding=padding, needs_padding=needs_padding)

def _transform(self, inpt, params):
if not params["needs_padding"]:
return inpt

fill = self.fill[type(inpt)]
fill = v2._utils._convert_fill_arg(fill)

return v2.functional.pad(inpt, padding=params["padding"], fill=fill)


class CocoDetectionToVOCSegmentation(v2.Transform):
"""Turn samples from datasets.CocoDetection into the same format as VOCSegmentation.
This is achieved in two steps:
1. COCO differentiates between 91 categories while VOC only supports 21, including background for both. Fortunately,
the COCO categories are a superset of the VOC ones and thus can be mapped. Instances of the 70 categories not
present in VOC are dropped and replaced by background.
2. COCO only offers detection masks, i.e. a (N, H, W) bool-ish tensor, where the truthy values in each individual
mask denote the instance. However, a segmentation mask is a (H, W) integer tensor (typically torch.uint8), where
the value of each pixel denotes the category it belongs to. The detection masks are merged into one segmentation
mask while pixels that belong to multiple detection masks are marked as invalid.
"""

COCO_TO_VOC_LABEL_MAP = dict(
zip(
[0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72],
range(21),
)
)
INVALID_VALUE = 255

def _coco_detection_masks_to_voc_segmentation_mask(self, target):
if "masks" not in target:
return None

instance_masks, instance_labels_coco = target["masks"], target["labels"]

valid_labels_voc = [
(idx, label_voc)
for idx, label_coco in enumerate(instance_labels_coco.tolist())
if (label_voc := self.COCO_TO_VOC_LABEL_MAP.get(label_coco)) is not None
]

if not valid_labels_voc:
return None

valid_voc_category_idcs, instance_labels_voc = zip(*valid_labels_voc)

instance_masks = instance_masks[list(valid_voc_category_idcs)].to(torch.uint8)
instance_labels_voc = torch.tensor(instance_labels_voc, dtype=torch.uint8)

# Calling `.max()` on the stacked detection masks works fine to separate background from foreground as long as
# there is at most a single instance per pixel. Overlapping instances will be filtered out in the next step.
segmentation_mask, _ = (instance_masks * instance_labels_voc.reshape(-1, 1, 1)).max(dim=0)
segmentation_mask[instance_masks.sum(dim=0) > 1] = self.INVALID_VALUE

return segmentation_mask

def forward(self, image, target):
segmentation_mask = self._coco_detection_masks_to_voc_segmentation_mask(target)
if segmentation_mask is None:
segmentation_mask = torch.zeros(v2.functional.get_spatial_size(image), dtype=torch.uint8)

return image, datapoints.Mask(segmentation_mask)
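To make the mask-merging step concrete, here is a tiny worked example (toy tensors, not real data) of the max-reduction and overlap handling used in _coco_detection_masks_to_voc_segmentation_mask:

    import torch

    # Two 3x3 instance masks carrying VOC labels 5 and 7; they overlap at pixel (1, 1).
    instance_masks = torch.tensor(
        [
            [[1, 1, 0], [1, 1, 0], [0, 0, 0]],
            [[0, 0, 0], [0, 1, 1], [0, 1, 1]],
        ],
        dtype=torch.uint8,
    )
    instance_labels_voc = torch.tensor([5, 7], dtype=torch.uint8)

    # Multiplying each mask by its label and taking the per-pixel max collapses the stack
    # into one (H, W) mask, which is exact wherever at most one instance covers a pixel.
    segmentation_mask, _ = (instance_masks * instance_labels_voc.reshape(-1, 1, 1)).max(dim=0)

    # Pixels covered by more than one instance are ambiguous and get the invalid value 255.
    segmentation_mask[instance_masks.sum(dim=0) > 1] = 255

    # segmentation_mask is now:
    #   [[5,   5, 0],
    #    [5, 255, 7],
    #    [0,   7, 7]]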
