From 9f0afd55394454e5686af69efc4a38905a3f96c4 Mon Sep 17 00:00:00 2001 From: vfdev Date: Thu, 24 Aug 2023 17:51:05 +0200 Subject: [PATCH] Replaced ConvertImageDtype by ToDtype in reference scripts (#7862) Co-authored-by: Nicolas Hug --- references/classification/presets.py | 4 ++-- references/detection/presets.py | 4 ++-- references/detection/transforms.py | 7 +++++-- references/segmentation/presets.py | 4 ++-- references/segmentation/transforms.py | 7 +++++-- references/segmentation/v2_extras.py | 2 +- 6 files changed, 17 insertions(+), 11 deletions(-) diff --git a/references/classification/presets.py b/references/classification/presets.py index 84651493f01..8653957a576 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -61,7 +61,7 @@ def __init__( transforms.extend( [ - T.ConvertImageDtype(torch.float), + T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float), T.Normalize(mean=mean, std=std), ] ) @@ -106,7 +106,7 @@ def __init__( transforms.append(T.PILToTensor()) transforms += [ - T.ConvertImageDtype(torch.float), + T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float), T.Normalize(mean=mean, std=std), ] diff --git a/references/detection/presets.py b/references/detection/presets.py index 0949a99896e..e7b2ca35792 100644 --- a/references/detection/presets.py +++ b/references/detection/presets.py @@ -73,7 +73,7 @@ def __init__( # Note: we could just convert to pure tensors even in v2. transforms += [T.ToImage() if use_v2 else T.PILToTensor()] - transforms += [T.ConvertImageDtype(torch.float)] + transforms += [T.ToDtype(torch.float, scale=True)] if use_v2: transforms += [ @@ -103,7 +103,7 @@ def __init__(self, backend="pil", use_v2=False): else: raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") - transforms += [T.ConvertImageDtype(torch.float)] + transforms += [T.ToDtype(torch.float, scale=True)] if use_v2: transforms += [T.ToPureTensor()] diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 65cf4e83592..e07ccfc9921 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -53,14 +53,17 @@ def forward( return image, target -class ConvertImageDtype(nn.Module): - def __init__(self, dtype: torch.dtype) -> None: +class ToDtype(nn.Module): + def __init__(self, dtype: torch.dtype, scale: bool = False) -> None: super().__init__() self.dtype = dtype + self.scale = scale def forward( self, image: Tensor, target: Optional[Dict[str, Tensor]] = None ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + if not self.scale: + return image.to(dtype=self.dtype), target image = F.convert_image_dtype(image, self.dtype) return image, target diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py index 7b7d0493bd2..b0539fcca3f 100644 --- a/references/segmentation/presets.py +++ b/references/segmentation/presets.py @@ -60,7 +60,7 @@ def __init__( ] else: # No need to explicitly convert masks as they're magically int64 already - transforms += [T.ConvertImageDtype(torch.float)] + transforms += [T.ToDtype(torch.float, scale=True)] transforms += [T.Normalize(mean=mean, std=std)] if use_v2: @@ -97,7 +97,7 @@ def __init__( transforms += [T.ToImage() if use_v2 else T.PILToTensor()] transforms += [ - T.ConvertImageDtype(torch.float), + T.ToDtype(torch.float, scale=True), T.Normalize(mean=mean, std=std), ] if use_v2: diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py index 2b3e79b1461..6934b9f862e 100644 --- a/references/segmentation/transforms.py +++ b/references/segmentation/transforms.py @@ -81,11 +81,14 @@ def __call__(self, image, target): return image, target -class ConvertImageDtype: - def __init__(self, dtype): +class ToDtype: + def __init__(self, dtype, scale=False): self.dtype = dtype + self.scale = scale def __call__(self, image, target): + if not self.scale: + return image.to(dtype=self.dtype), target image = F.convert_image_dtype(image, self.dtype) return image, target diff --git a/references/segmentation/v2_extras.py b/references/segmentation/v2_extras.py index 137a00ccf55..ae55f0727a4 100644 --- a/references/segmentation/v2_extras.py +++ b/references/segmentation/v2_extras.py @@ -78,6 +78,6 @@ def _coco_detection_masks_to_voc_segmentation_mask(self, target): def forward(self, image, target): segmentation_mask = self._coco_detection_masks_to_voc_segmentation_mask(target) if segmentation_mask is None: - segmentation_mask = torch.zeros(v2.functional.get_spatial_size(image), dtype=torch.uint8) + segmentation_mask = torch.zeros(v2.functional.get_size(image), dtype=torch.uint8) return image, datapoints.Mask(segmentation_mask)