Add --backend and --use-v2 support to detection refs #7732

Merged · merged 7 commits on Jul 13, 2023
Changes from 4 commits
47 changes: 25 additions & 22 deletions references/classification/presets.py
@@ -15,6 +15,9 @@ def get_module(use_v2):


class ClassificationPresetTrain:
# Note: this transform assumes that the input to forward() are always PIL
# images, regardless of the backend parameter. We may change that in the
# future though, if we change the output type from the dataset.
def __init__(
self,
*,
@@ -30,42 +33,42 @@ def __init__(
backend="pil",
use_v2=False,
):
module = get_module(use_v2)
T = get_module(use_v2)
Member Author:
I just did s/module/T/ in the file to make it consistent with the detection one


transforms = []
backend = backend.lower()
if backend == "tensor":
transforms.append(module.PILToTensor())
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")

transforms.append(module.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
transforms.append(T.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
if hflip_prob > 0:
transforms.append(module.RandomHorizontalFlip(hflip_prob))
transforms.append(T.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
if auto_augment_policy == "ra":
transforms.append(module.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
transforms.append(T.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
elif auto_augment_policy == "ta_wide":
transforms.append(module.TrivialAugmentWide(interpolation=interpolation))
transforms.append(T.TrivialAugmentWide(interpolation=interpolation))
elif auto_augment_policy == "augmix":
transforms.append(module.AugMix(interpolation=interpolation, severity=augmix_severity))
transforms.append(T.AugMix(interpolation=interpolation, severity=augmix_severity))
else:
aa_policy = module.AutoAugmentPolicy(auto_augment_policy)
transforms.append(module.AutoAugment(policy=aa_policy, interpolation=interpolation))
aa_policy = T.AutoAugmentPolicy(auto_augment_policy)
transforms.append(T.AutoAugment(policy=aa_policy, interpolation=interpolation))

if backend == "pil":
transforms.append(module.PILToTensor())
transforms.append(T.PILToTensor())

transforms.extend(
[
module.ConvertImageDtype(torch.float),
module.Normalize(mean=mean, std=std),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
)
if random_erase_prob > 0:
transforms.append(module.RandomErasing(p=random_erase_prob))
transforms.append(T.RandomErasing(p=random_erase_prob))

self.transforms = module.Compose(transforms)
self.transforms = T.Compose(transforms)

def __call__(self, img):
return self.transforms(img)
@@ -83,28 +86,28 @@ def __init__(
backend="pil",
use_v2=False,
):
module = get_module(use_v2)
T = get_module(use_v2)
transforms = []
backend = backend.lower()
if backend == "tensor":
transforms.append(module.PILToTensor())
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")

transforms += [
module.Resize(resize_size, interpolation=interpolation, antialias=True),
module.CenterCrop(crop_size),
T.Resize(resize_size, interpolation=interpolation, antialias=True),
T.CenterCrop(crop_size),
]

if backend == "pil":
transforms.append(module.PILToTensor())
transforms.append(T.PILToTensor())

transforms += [
module.ConvertImageDtype(torch.float),
module.Normalize(mean=mean, std=std),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]

self.transforms = module.Compose(transforms)
self.transforms = T.Compose(transforms)

def __call__(self, img):
return self.transforms(img)
34 changes: 19 additions & 15 deletions references/detection/coco_utils.py
@@ -7,6 +7,7 @@
import transforms as T
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO
from torchvision.datasets import wrap_dataset_for_transforms_v2


class FilterAndRemapCocoCategories:
@@ -49,7 +50,6 @@
w, h = image.size

image_id = target["image_id"]
image_id = torch.tensor([image_id])

anno = target["annotations"]

@@ -126,10 +126,6 @@
return True
return False

if not isinstance(dataset, torchvision.datasets.CocoDetection):
Member Author:
Instead of removing this (seemingly useless check) I could just add the same workaround as elsewhere i.e. add

or isinstance(
        getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
    ):

Contributor:
We still have #7239. Maybe we should go at it again?

raise TypeError(
f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
)
ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
@@ -196,12 +192,15 @@


def get_coco_api_from_dataset(dataset):
# FIXME: This is... awful?
Contributor:
Yeah. Happy for you to address it here, but not required.

Member Author:
I would if I knew what to do lol. (I'm gonna leave this out for now I think)

for _ in range(10):
if isinstance(dataset, torchvision.datasets.CocoDetection):
break
if isinstance(dataset, torch.utils.data.Subset):
dataset = dataset.dataset
if isinstance(dataset, torchvision.datasets.CocoDetection):
if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
):
return dataset.coco
return convert_to_coco_api(dataset)

@@ -220,25 +219,29 @@
return img, target


def get_coco(root, image_set, transforms, mode="instances"):
def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
anno_file_template = "{}_{}2017.json"
PATHS = {
"train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
"val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))),
# "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
}

t = [ConvertCocoPolysToMask()]

if transforms is not None:
t.append(transforms)
transforms = T.Compose(t)

img_folder, ann_file = PATHS[image_set]
img_folder = os.path.join(root, img_folder)
ann_file = os.path.join(root, ann_file)

dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
if use_v2:
dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
# TODO: need to update target_keys to handle masks for segmentation!
dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})
else:
t = [ConvertCocoPolysToMask()]
if transforms is not None:
t.append(transforms)
transforms = T.Compose(t)

dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
Member Author:
I wonder if we could get rid of this custom CocoDetection dataset here. Ideally we would always call wrap_dataset_for_transforms_v2 and just "unwrap" the datapoints classes into pure-tensors etc...? But we can't use it without silencing the V2 warning first :/

Not sure what to do to clean that up.
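A rough sketch of that "unwrap" idea, for illustration only (hypothetical, not part of this PR; the helper name and the exact target layout are assumptions). Since datapoints subclass torch.Tensor, a wrapped sample could be cast back to plain tensors after wrap_dataset_for_transforms_v2 has built the target dict:

    # Hypothetical helper, assuming a dataset already wrapped with
    # wrap_dataset_for_transforms_v2; datapoints are Tensor subclasses,
    # so .as_subclass(torch.Tensor) drops the wrapper class.
    import torch

    def unwrap_to_plain_tensors(img, target):
        if isinstance(img, torch.Tensor):
            img = img.as_subclass(torch.Tensor)
        target = {
            k: (v.as_subclass(torch.Tensor) if isinstance(v, torch.Tensor) else v)
            for k, v in target.items()
        }
        return img, target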


if image_set == "train":
dataset = _coco_remove_images_without_annotations(dataset)
@@ -248,5 +251,6 @@
return dataset


def get_coco_kp(root, image_set, transforms):
def get_coco_kp(root, image_set, transforms, use_v2):

[GitHub Actions / bc warning on line 254 in references/detection/coco_utils.py: Function get_coco_kp: use_v2 was added and is now required]
# TODO: handle use_v2
return get_coco(root, image_set, transforms, mode="person_keypoints")
4 changes: 2 additions & 2 deletions references/detection/engine.py
@@ -26,7 +26,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc

for images, targets in metric_logger.log_every(data_loader, print_freq, header):
images = list(image.to(device) for image in images)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
Contributor:
This is for the image ID, right?

with torch.cuda.amp.autocast(enabled=scaler is not None):
loss_dict = model(images, targets)
losses = sum(loss for loss in loss_dict.values())
@@ -97,7 +97,7 @@ def evaluate(model, data_loader, device):
outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
model_time = time.time() - model_time

res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
res = {target["image_id"]: output for target, output in zip(targets, outputs)}
Member Author:
This is for consistency with the V2 wrapper which leaves image_id as an int. In our references we used to manually wrap it into a tensor (why, IDK), and I removed that as well below in coco_utils
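As an aside, a hedged compatibility sketch (illustrative only, not in this PR) for code that has to accept both the old tensor-wrapped image_id and the plain int left by the v2 wrapper:

    # Assumes a detection-style target dict as used in these references.
    image_id = target["image_id"]
    if isinstance(image_id, torch.Tensor):  # old references wrapped it in a tensor
        image_id = int(image_id.item())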

evaluator_time = time.time()
coco_evaluator.update(res)
evaluator_time = time.time() - evaluator_time
4 changes: 3 additions & 1 deletion references/detection/group_by_aspect_ratio.py
@@ -164,7 +164,9 @@ def compute_aspect_ratios(dataset, indices=None):
if hasattr(dataset, "get_height_and_width"):
return _compute_aspect_ratios_custom_dataset(dataset, indices)

if isinstance(dataset, torchvision.datasets.CocoDetection):
if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
):
return _compute_aspect_ratios_coco_dataset(dataset, indices)

if isinstance(dataset, torchvision.datasets.VOCDetection):
140 changes: 87 additions & 53 deletions references/detection/presets.py
@@ -1,73 +1,107 @@
from collections import defaultdict

import torch
import transforms as T
import transforms as reference_transforms


def get_modules(use_v2):
# We need a protected import to avoid the V2 warning in case just V1 is used
if use_v2:
import torchvision.datapoints
import torchvision.transforms.v2

return torchvision.transforms.v2, torchvision.datapoints
else:
return reference_transforms, None


class DetectionPresetTrain:
def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
def __init__(
self,
*,
data_augmentation,
hflip_prob=0.5,
mean=(123.0, 117.0, 104.0),
backend="pil",
use_v2=False,
):

T, datapoints = get_modules(use_v2)

transforms = []
backend = backend.lower()
if backend == "datapoint":
transforms.append(T.ToImageTensor())
elif backend == "tensor":
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
NicolasHug marked this conversation as resolved.

if data_augmentation == "hflip":
self.transforms = T.Compose(
[
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
elif data_augmentation == "lsj":
self.transforms = T.Compose(
[
T.ScaleJitter(target_size=(1024, 1024)),
T.FixedSizeCrop(size=(1024, 1024), fill=mean),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
transforms += [
T.ScaleJitter(target_size=(1024, 1024), antialias=True),
# TODO: FixedSizeCrop below doesn't work on tensors!
reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
Contributor (on lines +47 to +48):
In v2 we have RandomCrop that does what FixedSizedCrop does minus the clamping and sanitizing bounding boxes.
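A hedged sketch of that alternative (not what this PR ships; the padding fill and the clamp/sanitize placement are assumptions, and `mean` is taken from the surrounding __init__):

    # v2 RandomCrop can pad to the target size and then crop; the clamping and
    # box sanitization that FixedSizeCrop performed would have to be added separately.
    import torchvision.transforms.v2 as T2

    lsj_crop = T2.Compose([
        T2.RandomCrop(size=(1024, 1024), pad_if_needed=True, fill=list(mean)),
        T2.ClampBoundingBox(),
        T2.SanitizeBoundingBox(),
    ])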

T.RandomHorizontalFlip(p=hflip_prob),
]
elif data_augmentation == "multiscale":
self.transforms = T.Compose(
[
T.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
transforms += [
T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
T.RandomHorizontalFlip(p=hflip_prob),
]
elif data_augmentation == "ssd":
self.transforms = T.Compose(
[
T.RandomPhotometricDistort(),
T.RandomZoomOut(fill=list(mean)),
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
fill = defaultdict(lambda: mean, {datapoints.Mask: 0}) if use_v2 else list(mean)
transforms += [
T.RandomPhotometricDistort(),
T.RandomZoomOut(fill=fill),
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
]
elif data_augmentation == "ssdlite":
self.transforms = T.Compose(
[
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
transforms += [
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
]
else:
raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')

if backend == "pil":
# Note: we could just convert to pure tensors even in v2.
transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]

transforms += [T.ConvertImageDtype(torch.float)]

if use_v2:
transforms += [
T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY),
T.SanitizeBoundingBox(),
Contributor:
Do we also need ClampBoundingBox here?

Member Author:
I don't think so since we established that all transforms should clamp already (those that need to, at least)?

]

self.transforms = T.Compose(transforms)

def __call__(self, img, target):
return self.transforms(img, target)

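For context, a hedged usage sketch of the presets rewritten above (constructor arguments taken from this diff; importing the module as `presets` assumes the references/detection folder is on sys.path):

    # Illustrative only: combining the backend/use_v2 knobs added in this PR.
    import presets

    train_tf = presets.DetectionPresetTrain(
        data_augmentation="ssd", backend="datapoint", use_v2=True
    )
    eval_tf = presets.DetectionPresetEval(backend="tensor", use_v2=False)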

class DetectionPresetEval:
def __init__(self):
self.transforms = T.Compose(
[
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
def __init__(self, backend="pil", use_v2=False):
T, _ = get_modules(use_v2)
transforms = []
backend = backend.lower()
if backend == "pil":
# Note: we could just convert to pure tensors even in v2?
transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
elif backend == "tensor":
transforms += [T.PILToTensor()]
elif backend == "datapoint":
transforms += [T.ToImageTensor()]
else:
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")

transforms += [T.ConvertImageDtype(torch.float)]
self.transforms = T.Compose(transforms)

def __call__(self, img, target):
return self.transforms(img, target)