Add --backend and --use-v2 support to detection refs (#7732)
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
NicolasHug and pmeier authored Jul 13, 2023
1 parent 08c9938 commit bb3aae7
Showing 8 changed files with 166 additions and 106 deletions.
47 changes: 25 additions & 22 deletions references/classification/presets.py
@@ -15,6 +15,9 @@ def get_module(use_v2):


class ClassificationPresetTrain:
# Note: this transform assumes that the inputs to forward() are always PIL
# images, regardless of the backend parameter. We may change that in the
# future, though, if we change the output type from the dataset.
def __init__(
self,
*,
@@ -30,42 +33,42 @@ def __init__(
backend="pil",
use_v2=False,
):
module = get_module(use_v2)
T = get_module(use_v2)

transforms = []
backend = backend.lower()
if backend == "tensor":
transforms.append(module.PILToTensor())
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")

transforms.append(module.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
transforms.append(T.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
if hflip_prob > 0:
transforms.append(module.RandomHorizontalFlip(hflip_prob))
transforms.append(T.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
if auto_augment_policy == "ra":
transforms.append(module.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
transforms.append(T.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
elif auto_augment_policy == "ta_wide":
transforms.append(module.TrivialAugmentWide(interpolation=interpolation))
transforms.append(T.TrivialAugmentWide(interpolation=interpolation))
elif auto_augment_policy == "augmix":
transforms.append(module.AugMix(interpolation=interpolation, severity=augmix_severity))
transforms.append(T.AugMix(interpolation=interpolation, severity=augmix_severity))
else:
aa_policy = module.AutoAugmentPolicy(auto_augment_policy)
transforms.append(module.AutoAugment(policy=aa_policy, interpolation=interpolation))
aa_policy = T.AutoAugmentPolicy(auto_augment_policy)
transforms.append(T.AutoAugment(policy=aa_policy, interpolation=interpolation))

if backend == "pil":
transforms.append(module.PILToTensor())
transforms.append(T.PILToTensor())

transforms.extend(
[
module.ConvertImageDtype(torch.float),
module.Normalize(mean=mean, std=std),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
)
if random_erase_prob > 0:
transforms.append(module.RandomErasing(p=random_erase_prob))
transforms.append(T.RandomErasing(p=random_erase_prob))

self.transforms = module.Compose(transforms)
self.transforms = T.Compose(transforms)

def __call__(self, img):
return self.transforms(img)
@@ -83,28 +86,28 @@ def __init__(
backend="pil",
use_v2=False,
):
module = get_module(use_v2)
T = get_module(use_v2)
transforms = []
backend = backend.lower()
if backend == "tensor":
transforms.append(module.PILToTensor())
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")

transforms += [
module.Resize(resize_size, interpolation=interpolation, antialias=True),
module.CenterCrop(crop_size),
T.Resize(resize_size, interpolation=interpolation, antialias=True),
T.CenterCrop(crop_size),
]

if backend == "pil":
transforms.append(module.PILToTensor())
transforms.append(T.PILToTensor())

transforms += [
module.ConvertImageDtype(torch.float),
module.Normalize(mean=mean, std=std),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]

self.transforms = module.Compose(transforms)
self.transforms = T.Compose(transforms)

def __call__(self, img):
return self.transforms(img)
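
For orientation, a minimal usage sketch of the updated classification presets (the crop_size value and the input image are illustrative assumptions, not part of this diff):

import PIL.Image
from presets import ClassificationPresetTrain  # references/classification/presets.py

# use_v2=True selects torchvision.transforms.v2; backend="tensor" prepends
# PILToTensor() so the rest of the pipeline runs on tensors.
preset = ClassificationPresetTrain(
    crop_size=224,  # illustrative value
    backend="tensor",
    use_v2=True,
)

# Per the note in the class, the input is still expected to be a PIL image.
img = PIL.Image.new("RGB", (256, 256))
out = preset(img)  # normalized float tensor of shape (3, 224, 224)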
35 changes: 20 additions & 15 deletions references/detection/coco_utils.py
@@ -7,6 +7,7 @@
import transforms as T
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO
from torchvision.datasets import wrap_dataset_for_transforms_v2


class FilterAndRemapCocoCategories:
@@ -49,7 +50,6 @@ def __call__(self, image, target):
w, h = image.size

image_id = target["image_id"]
image_id = torch.tensor([image_id])

anno = target["annotations"]

@@ -126,10 +126,6 @@ def _has_valid_annotation(anno):
return True
return False

if not isinstance(dataset, torchvision.datasets.CocoDetection):
raise TypeError(
f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
)
ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
@@ -196,12 +192,15 @@ def convert_to_coco_api(ds):


def get_coco_api_from_dataset(dataset):
# FIXME: This is... awful?
for _ in range(10):
if isinstance(dataset, torchvision.datasets.CocoDetection):
break
if isinstance(dataset, torch.utils.data.Subset):
dataset = dataset.dataset
if isinstance(dataset, torchvision.datasets.CocoDetection):
if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
):
return dataset.coco
return convert_to_coco_api(dataset)

@@ -220,25 +219,29 @@ def __getitem__(self, idx):
return img, target


def get_coco(root, image_set, transforms, mode="instances"):
def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
anno_file_template = "{}_{}2017.json"
PATHS = {
"train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
"val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))),
# "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
}

t = [ConvertCocoPolysToMask()]

if transforms is not None:
t.append(transforms)
transforms = T.Compose(t)

img_folder, ann_file = PATHS[image_set]
img_folder = os.path.join(root, img_folder)
ann_file = os.path.join(root, ann_file)

dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
if use_v2:
dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
# TODO: need to update target_keys to handle masks for segmentation!
dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})
else:
t = [ConvertCocoPolysToMask()]
if transforms is not None:
t.append(transforms)
transforms = T.Compose(t)

dataset = CocoDetection(img_folder, ann_file, transforms=transforms)

if image_set == "train":
dataset = _coco_remove_images_without_annotations(dataset)
@@ -248,5 +251,7 @@ def get_coco(root, image_set, transforms, mode="instances"):
return dataset


def get_coco_kp(root, image_set, transforms):
def get_coco_kp(root, image_set, transforms, use_v2=False):
if use_v2:
raise ValueError("KeyPoints aren't supported by transforms V2 yet.")
return get_coco(root, image_set, transforms, mode="person_keypoints")
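
For reference, a short sketch of what the v2 wrapping above produces (the dataset paths are placeholders; target_keys mirrors the call in get_coco):

import torchvision
from torchvision.datasets import wrap_dataset_for_transforms_v2

# Placeholder paths; in get_coco these are derived from root and image_set.
dataset = torchvision.datasets.CocoDetection(
    "coco/train2017", "coco/annotations/instances_train2017.json"
)
dataset = wrap_dataset_for_transforms_v2(
    dataset, target_keys={"boxes", "labels", "image_id"}
)

img, target = dataset[0]
# target is now a dict with "boxes", "labels" and a plain-int "image_id"
# instead of the raw list of COCO annotation dicts, which is why the
# ConvertCocoPolysToMask step is only needed on the v1 path.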
4 changes: 2 additions & 2 deletions references/detection/engine.py
@@ -26,7 +26,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc

for images, targets in metric_logger.log_every(data_loader, print_freq, header):
images = list(image.to(device) for image in images)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
with torch.cuda.amp.autocast(enabled=scaler is not None):
loss_dict = model(images, targets)
losses = sum(loss for loss in loss_dict.values())
@@ -97,7 +97,7 @@ def evaluate(model, data_loader, device):
outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
model_time = time.time() - model_time

res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
res = {target["image_id"]: output for target, output in zip(targets, outputs)}

Inline review comment from vfdev-5 (Collaborator), Aug 22, 2023:

This change is breaking for non-COCO datasets, typically the example from
https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html, since
image_id types are now mixed between tensors and ints. We can fix the
counterpart (namely convert_to_coco_api); I'll send a PR.

evaluator_time = time.time()
coco_evaluator.update(res)
evaluator_time = time.time() - evaluator_time
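A tiny sketch of the type change behind the two edits above and the review comment (values are made up):

import torch

# v1 path before this commit: ConvertCocoPolysToMask wrapped image_id in a
# tensor, so evaluate() needed .item() to recover a plain int key.
old_target = {"image_id": torch.tensor([42])}
key = old_target["image_id"].item()  # -> 42

# After this commit, image_id stays a plain int on both paths, so .item() is
# dropped, and train_one_epoch only calls .to(device) on actual tensors.
new_target = {"image_id": 42, "boxes": torch.rand(3, 4)}
moved = {k: v.to("cpu") if isinstance(v, torch.Tensor) else v for k, v in new_target.items()}
key = new_target["image_id"]  # -> 42, already an int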
4 changes: 3 additions & 1 deletion references/detection/group_by_aspect_ratio.py
@@ -164,7 +164,9 @@ def compute_aspect_ratios(dataset, indices=None):
if hasattr(dataset, "get_height_and_width"):
return _compute_aspect_ratios_custom_dataset(dataset, indices)

if isinstance(dataset, torchvision.datasets.CocoDetection):
if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
):
return _compute_aspect_ratios_coco_dataset(dataset, indices)

if isinstance(dataset, torchvision.datasets.VOCDetection):
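The getattr(dataset, "_dataset", None) test here (and in get_coco_api_from_dataset above) is needed because the v2 wrapper is not a CocoDetection subclass; it keeps the wrapped dataset on a _dataset attribute. A sketch of the pattern, with a hypothetical helper name:

import torchvision

def is_coco_like(dataset):
    # Hypothetical helper: accept both a plain CocoDetection and the
    # wrap_dataset_for_transforms_v2 wrapper, which stores the original
    # dataset on `_dataset`.
    return isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
        getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
    )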
142 changes: 89 additions & 53 deletions references/detection/presets.py
@@ -1,73 +1,109 @@
from collections import defaultdict

import torch
import transforms as T
import transforms as reference_transforms


def get_modules(use_v2):
# We need a protected import to avoid the V2 warning in case just V1 is used
if use_v2:
import torchvision.datapoints
import torchvision.transforms.v2

return torchvision.transforms.v2, torchvision.datapoints
else:
return reference_transforms, None


class DetectionPresetTrain:
def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
# Note: this transform assumes that the inputs to forward() are always PIL
# images, regardless of the backend parameter.
def __init__(
self,
*,
data_augmentation,
hflip_prob=0.5,
mean=(123.0, 117.0, 104.0),
backend="pil",
use_v2=False,
):

T, datapoints = get_modules(use_v2)

transforms = []
backend = backend.lower()
if backend == "datapoint":
transforms.append(T.ToImageTensor())
elif backend == "tensor":
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")

if data_augmentation == "hflip":
self.transforms = T.Compose(
[
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
elif data_augmentation == "lsj":
self.transforms = T.Compose(
[
T.ScaleJitter(target_size=(1024, 1024)),
T.FixedSizeCrop(size=(1024, 1024), fill=mean),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
transforms += [
T.ScaleJitter(target_size=(1024, 1024), antialias=True),
# TODO: FixedSizeCrop below doesn't work on tensors!
reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
T.RandomHorizontalFlip(p=hflip_prob),
]
elif data_augmentation == "multiscale":
self.transforms = T.Compose(
[
T.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
transforms += [
T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
T.RandomHorizontalFlip(p=hflip_prob),
]
elif data_augmentation == "ssd":
self.transforms = T.Compose(
[
T.RandomPhotometricDistort(),
T.RandomZoomOut(fill=list(mean)),
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
fill = defaultdict(lambda: mean, {datapoints.Mask: 0}) if use_v2 else list(mean)
transforms += [
T.RandomPhotometricDistort(),
T.RandomZoomOut(fill=fill),
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
]
elif data_augmentation == "ssdlite":
self.transforms = T.Compose(
[
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
transforms += [
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
]
else:
raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')

if backend == "pil":
# Note: we could just convert to pure tensors even in v2.
transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]

transforms += [T.ConvertImageDtype(torch.float)]

if use_v2:
transforms += [
T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY),
T.SanitizeBoundingBox(),
]

self.transforms = T.Compose(transforms)

def __call__(self, img, target):
return self.transforms(img, target)


class DetectionPresetEval:
def __init__(self):
self.transforms = T.Compose(
[
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
def __init__(self, backend="pil", use_v2=False):
T, _ = get_modules(use_v2)
transforms = []
backend = backend.lower()
if backend == "pil":
# Note: we could just convert to pure tensors even in v2?
transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
elif backend == "tensor":
transforms += [T.PILToTensor()]
elif backend == "datapoint":
transforms += [T.ToImageTensor()]
else:
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")

transforms += [T.ConvertImageDtype(torch.float)]
self.transforms = T.Compose(transforms)

def __call__(self, img, target):
return self.transforms(img, target)
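
Finally, a hedged usage sketch of the reworked detection presets (argument values are illustrative; with use_v2=True the pipeline expects datasets wrapped as in coco_utils.py above):

from presets import DetectionPresetTrain, DetectionPresetEval  # references/detection/presets.py

# backend="datapoint" converts inputs with ToImageTensor() up front; use_v2
# additionally converts boxes to XYXY and drops degenerate ones at the end.
train_tf = DetectionPresetTrain(
    data_augmentation="hflip",  # illustrative choice
    backend="datapoint",
    use_v2=True,
)
eval_tf = DetectionPresetEval(backend="datapoint", use_v2=True)

# Both presets are called as transform(img, target), matching the
# `transforms` argument of the detection datasets:
# img, target = train_tf(img, target)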