Merge branch 'main' into fix_ssl_issue
atalman authored Aug 24, 2023
2 parents dd6f30a + 9f0afd5 commit ba3d230
Showing 14 changed files with 140 additions and 35 deletions.
7 changes: 3 additions & 4 deletions gallery/v2_transforms/plot_transforms_v2.py
@@ -99,10 +99,9 @@
format="XYXY", canvas_size=img.shape[-2:])

transforms = v2.Compose([
v2.RandomPhotometricDistort(),
v2.RandomIoUCrop(),
v2.RandomHorizontalFlip(p=0.5),
v2.SanitizeBoundingBoxes(),
v2.RandomResizedCrop(size=(224, 224), antialias=True),
v2.RandomPhotometricDistort(p=1),
v2.RandomHorizontalFlip(p=1),
])
out_img, out_bboxes = transforms(img, bboxes)

4 changes: 2 additions & 2 deletions references/classification/presets.py
@@ -61,7 +61,7 @@ def __init__(

transforms.extend(
[
T.ConvertImageDtype(torch.float),
T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
)
@@ -106,7 +106,7 @@ def __init__(
transforms.append(T.PILToTensor())

transforms += [
T.ConvertImageDtype(torch.float),
T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]

4 changes: 2 additions & 2 deletions references/detection/presets.py
@@ -73,7 +73,7 @@ def __init__(
# Note: we could just convert to pure tensors even in v2.
transforms += [T.ToImage() if use_v2 else T.PILToTensor()]

transforms += [T.ConvertImageDtype(torch.float)]
transforms += [T.ToDtype(torch.float, scale=True)]

if use_v2:
transforms += [
@@ -103,7 +103,7 @@ def __init__(self, backend="pil", use_v2=False):
else:
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")

transforms += [T.ConvertImageDtype(torch.float)]
transforms += [T.ToDtype(torch.float, scale=True)]

if use_v2:
transforms += [T.ToPureTensor()]
7 changes: 5 additions & 2 deletions references/detection/transforms.py
@@ -53,14 +53,17 @@ def forward(
return image, target


class ConvertImageDtype(nn.Module):
def __init__(self, dtype: torch.dtype) -> None:
class ToDtype(nn.Module):
def __init__(self, dtype: torch.dtype, scale: bool = False) -> None:
super().__init__()
self.dtype = dtype
self.scale = scale

def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if not self.scale:
return image.to(dtype=self.dtype), target
image = F.convert_image_dtype(image, self.dtype)
return image, target

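As a rough illustration of the scale flag introduced in this ToDtype helper (a standalone sketch, not part of the commit; shapes and values are made up):

import torch
from torchvision.transforms import functional as F

img = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)

# scale=False behaves like a plain cast: values keep their 0-255 range.
cast_only = img.to(dtype=torch.float)

# scale=True also rescales values to the target dtype's range (here [0, 1]),
# which is what F.convert_image_dtype does in the forward() above.
cast_and_scaled = F.convert_image_dtype(img, torch.float)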
4 changes: 2 additions & 2 deletions references/segmentation/presets.py
@@ -60,7 +60,7 @@ def __init__(
]
else:
# No need to explicitly convert masks as they're magically int64 already
transforms += [T.ConvertImageDtype(torch.float)]
transforms += [T.ToDtype(torch.float, scale=True)]

transforms += [T.Normalize(mean=mean, std=std)]
if use_v2:
@@ -97,7 +97,7 @@ def __init__(
transforms += [T.ToImage() if use_v2 else T.PILToTensor()]

transforms += [
T.ConvertImageDtype(torch.float),
T.ToDtype(torch.float, scale=True),
T.Normalize(mean=mean, std=std),
]
if use_v2:
7 changes: 5 additions & 2 deletions references/segmentation/transforms.py
@@ -81,11 +81,14 @@ def __call__(self, image, target):
return image, target


class ConvertImageDtype:
def __init__(self, dtype):
class ToDtype:
def __init__(self, dtype, scale=False):
self.dtype = dtype
self.scale = scale

def __call__(self, image, target):
if not self.scale:
return image.to(dtype=self.dtype), target
image = F.convert_image_dtype(image, self.dtype)
return image, target

2 changes: 1 addition & 1 deletion references/segmentation/v2_extras.py
@@ -78,6 +78,6 @@ def _coco_detection_masks_to_voc_segmentation_mask(self, target):
def forward(self, image, target):
segmentation_mask = self._coco_detection_masks_to_voc_segmentation_mask(target)
if segmentation_mask is None:
segmentation_mask = torch.zeros(v2.functional.get_spatial_size(image), dtype=torch.uint8)
segmentation_mask = torch.zeros(v2.functional.get_size(image), dtype=torch.uint8)

return image, datapoints.Mask(segmentation_mask)
22 changes: 17 additions & 5 deletions test/datasets_utils.py
@@ -662,27 +662,39 @@ class VideoDatasetTestCase(DatasetTestCase):
FEATURE_TYPES = (torch.Tensor, torch.Tensor, int)
REQUIRED_PACKAGES = ("av",)

DEFAULT_FRAMES_PER_CLIP = 1
FRAMES_PER_CLIP = 1

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_args = self._set_default_frames_per_clip(self.dataset_args)

def _set_default_frames_per_clip(self, inject_fake_data):
def _set_default_frames_per_clip(self, dataset_args):
argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__)
args_without_default = argspec.args[1 : (-len(argspec.defaults) if argspec.defaults else None)]
frames_per_clip_last = args_without_default[-1] == "frames_per_clip"

@functools.wraps(inject_fake_data)
@functools.wraps(dataset_args)
def wrapper(tmpdir, config):
args = inject_fake_data(tmpdir, config)
args = dataset_args(tmpdir, config)
if frames_per_clip_last and len(args) == len(args_without_default) - 1:
args = (*args, self.DEFAULT_FRAMES_PER_CLIP)
args = (*args, self.FRAMES_PER_CLIP)

return args

return wrapper

def test_output_format(self):
for output_format in ["TCHW", "THWC"]:
with self.create_dataset(output_format=output_format) as (dataset, _):
for video, *_ in dataset:
if output_format == "TCHW":
num_frames, num_channels, *_ = video.shape
else: # output_format == "THWC":
num_frames, *_, num_channels = video.shape

assert num_frames == self.FRAMES_PER_CLIP
assert num_channels == 3

@test_all_configs
def test_transforms_v2_wrapper(self, config):
# `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set explicitly
5 changes: 1 addition & 4 deletions test/test_image.py
@@ -83,12 +83,9 @@ def test_decode_jpeg(img_path, pil_mode, mode):
with Image.open(img_path) as img:
is_cmyk = img.mode == "CMYK"
if pil_mode is not None:
if is_cmyk:
# libjpeg does not support the conversion
pytest.xfail("Decoding a CMYK jpeg isn't supported")
img = img.convert(pil_mode)
img_pil = torch.from_numpy(np.array(img))
if is_cmyk:
if is_cmyk and mode == ImageReadMode.UNCHANGED:
# flip the colors to match libjpeg
img_pil = 255 - img_pil

14 changes: 14 additions & 0 deletions test/test_transforms_v2.py
@@ -1256,6 +1256,20 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
assert out_labels.tolist() == valid_indices


def test_sanitize_bounding_boxes_no_label():
# Non-regression test for https://github.com/pytorch/vision/issues/7878

img = make_image()
boxes = make_bounding_boxes()

with pytest.raises(ValueError, match="or a two-tuple whose second item is a dict"):
transforms.SanitizeBoundingBoxes()(img, boxes)

out_img, out_boxes = transforms.SanitizeBoundingBoxes(labels_getter=None)(img, boxes)
assert isinstance(out_img, datapoints.Image)
assert isinstance(out_boxes, datapoints.BoundingBoxes)


def test_sanitize_bounding_boxes_errors():

good_bbox = datapoints.BoundingBoxes(
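For context on the two code paths the new test exercises, a hedged usage sketch (the image, boxes, and labels below are made up, not from the diff):

import torch
from torchvision import datapoints
from torchvision.transforms import v2

img = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
boxes = datapoints.BoundingBoxes(
    torch.tensor([[0, 0, 10, 10], [5, 5, 5, 5]]),  # second box is degenerate
    format="XYXY",
    canvas_size=(32, 32),
)

# Usual path: labels travel in a dict, where the default labels_getter can find them.
target = {"boxes": boxes, "labels": torch.tensor([1, 2])}
out_img, out_target = v2.SanitizeBoundingBoxes()(img, target)

# Path covered by the new test: no labels at all, so opt out of the lookup.
out_img, out_boxes = v2.SanitizeBoundingBoxes(labels_getter=None)(img, boxes)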
87 changes: 82 additions & 5 deletions torchvision/csrc/io/image/cpu/decode_jpeg.cpp
@@ -67,6 +67,58 @@ static void torch_jpeg_set_source_mgr(
src->pub.next_input_byte = src->data;
}

inline unsigned char clamped_cmyk_rgb_convert(
unsigned char k,
unsigned char cmy) {
// Inspired from Pillow:
// https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L568-L569
int v = k * cmy + 128;
v = ((v >> 8) + v) >> 8;
return std::clamp(k - v, 0, 255);
}

void convert_line_cmyk_to_rgb(
j_decompress_ptr cinfo,
const unsigned char* cmyk_line,
unsigned char* rgb_line) {
int width = cinfo->output_width;
for (int i = 0; i < width; ++i) {
int c = cmyk_line[i * 4 + 0];
int m = cmyk_line[i * 4 + 1];
int y = cmyk_line[i * 4 + 2];
int k = cmyk_line[i * 4 + 3];

rgb_line[i * 3 + 0] = clamped_cmyk_rgb_convert(k, 255 - c);
rgb_line[i * 3 + 1] = clamped_cmyk_rgb_convert(k, 255 - m);
rgb_line[i * 3 + 2] = clamped_cmyk_rgb_convert(k, 255 - y);
}
}

inline unsigned char rgb_to_gray(int r, int g, int b) {
// Inspired from Pillow:
// https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L226
return (r * 19595 + g * 38470 + b * 7471 + 0x8000) >> 16;
}

void convert_line_cmyk_to_gray(
j_decompress_ptr cinfo,
const unsigned char* cmyk_line,
unsigned char* gray_line) {
int width = cinfo->output_width;
for (int i = 0; i < width; ++i) {
int c = cmyk_line[i * 4 + 0];
int m = cmyk_line[i * 4 + 1];
int y = cmyk_line[i * 4 + 2];
int k = cmyk_line[i * 4 + 3];

int r = clamped_cmyk_rgb_convert(k, 255 - c);
int g = clamped_cmyk_rgb_convert(k, 255 - m);
int b = clamped_cmyk_rgb_convert(k, 255 - y);

gray_line[i] = rgb_to_gray(r, g, b);
}
}

} // namespace

torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
@@ -102,20 +102,29 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
jpeg_read_header(&cinfo, TRUE);

int channels = cinfo.num_components;
bool cmyk_to_rgb_or_gray = false;

if (mode != IMAGE_READ_MODE_UNCHANGED) {
switch (mode) {
case IMAGE_READ_MODE_GRAY:
if (cinfo.jpeg_color_space != JCS_GRAYSCALE) {
if (cinfo.jpeg_color_space == JCS_CMYK ||
cinfo.jpeg_color_space == JCS_YCCK) {
cinfo.out_color_space = JCS_CMYK;
cmyk_to_rgb_or_gray = true;
} else {
cinfo.out_color_space = JCS_GRAYSCALE;
channels = 1;
}
channels = 1;
break;
case IMAGE_READ_MODE_RGB:
if (cinfo.jpeg_color_space != JCS_RGB) {
if (cinfo.jpeg_color_space == JCS_CMYK ||
cinfo.jpeg_color_space == JCS_YCCK) {
cinfo.out_color_space = JCS_CMYK;
cmyk_to_rgb_or_gray = true;
} else {
cinfo.out_color_space = JCS_RGB;
channels = 3;
}
channels = 3;
break;
/*
* Libjpeg does not support converting from CMYK to grayscale etc. There
@@ -139,12 +200,28 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
auto tensor =
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
auto ptr = tensor.data_ptr<uint8_t>();
torch::Tensor cmyk_line_tensor;
if (cmyk_to_rgb_or_gray) {
cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8);
}

while (cinfo.output_scanline < cinfo.output_height) {
/* jpeg_read_scanlines expects an array of pointers to scanlines.
* Here the array is only one element long, but you could ask for
* more than one scanline at a time if that's more convenient.
*/
jpeg_read_scanlines(&cinfo, &ptr, 1);
if (cmyk_to_rgb_or_gray) {
auto cmyk_line_ptr = cmyk_line_tensor.data_ptr<uint8_t>();
jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1);

if (channels == 3) {
convert_line_cmyk_to_rgb(&cinfo, cmyk_line_ptr, ptr);
} else if (channels == 1) {
convert_line_cmyk_to_gray(&cinfo, cmyk_line_ptr, ptr);
}
} else {
jpeg_read_scanlines(&cinfo, &ptr, 1);
}
ptr += stride;
}

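A quick sanity check of the fixed-point arithmetic used by the new CMYK helpers (standalone Python sketch, not part of the C++ diff): with v = x + 128, ((v >> 8) + v) >> 8 is the usual exact rounding division of x by 255, so clamped_cmyk_rgb_convert(k, 255 - c) yields roughly k * c / 255, and the rgb_to_gray weights are the BT.601 luma coefficients in 16.16 fixed point. This conversion is also why the CMYK xfail could be dropped from test_decode_jpeg above.

def div255_round(x):  # mirrors ((v >> 8) + v) >> 8 with v = x + 128
    v = x + 128
    return ((v >> 8) + v) >> 8

# Exact rounding division by 255 over the full product range used above.
assert all(div255_round(x) == round(x / 255) for x in range(255 * 255 + 1))

# BT.601 luma weights scaled by 65536, as in rgb_to_gray:
print(19595 / 65536, 38470 / 65536, 7471 / 65536)  # ~0.299, ~0.587, ~0.114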
6 changes: 3 additions & 3 deletions torchvision/datasets/video_utils.py
@@ -187,9 +187,9 @@ def subset(self, indices: List[int]) -> "VideoClips":
}
return type(self)(
video_paths,
self.num_frames,
self.step,
self.frame_rate,
clip_length_in_frames=self.num_frames,
frames_between_clips=self.step,
frame_rate=self.frame_rate,
_precomputed_metadata=metadata,
num_workers=self.num_workers,
_video_width=self._video_width,
4 changes: 2 additions & 2 deletions torchvision/transforms/v2/_augment.py
@@ -229,7 +229,7 @@ class MixUp(_BaseMixUpCutMix):
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
By default, this will pick the second parameter a the labels if it's a tensor. This covers the most
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
It can also be a callable that takes the same input as the transform, and returns the labels.
"""
@@ -279,7 +279,7 @@ class CutMix(_BaseMixUpCutMix):
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
By default, this will pick the second parameter a the labels if it's a tensor. This covers the most
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
It can also be a callable that takes the same input as the transform, and returns the labels.
"""
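A hedged usage sketch of the default labels_getter behavior described in the docstrings above (batch size and class count are made up):

import torch
from torchvision.transforms import v2

cutmix = v2.CutMix(num_classes=10)
imgs_batch = torch.rand(4, 3, 224, 224)
labels_batch = torch.randint(0, 10, (4,))

# The second positional argument is a plain tensor, so the default
# labels_getter treats it as the labels; the output labels are soft
# one-hot targets of shape (4, 10).
out_imgs, out_labels = cutmix(imgs_batch, labels_batch)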
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/_utils.py
@@ -112,7 +112,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
inputs = inputs[1]

# MixUp, CutMix
if isinstance(inputs, torch.Tensor):
if is_pure_tensor(inputs):
return inputs

if not isinstance(inputs, collections.abc.Mapping):
Expand Down
