Skip to content

Commit

Permalink
When creating processing instances some arguments may be not Tuple bu…
Browse files Browse the repository at this point in the history
…t ListConfig and silently serialized to state dict with pickle. By adding explicit casting to primitive tuple type we fix that (#1534)
  • Loading branch information
BloodAxe committed Oct 16, 2023
1 parent 69d2594 commit c74b034
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/super_gradients/training/processing/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class ImagePermute(Processing):
"""

def __init__(self, permutation: Tuple[int, int, int] = (2, 0, 1)):
self.permutation = permutation
self.permutation = tuple(permutation)

def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]:
processed_image = np.ascontiguousarray(image.transpose(*self.permutation))
Expand Down Expand Up @@ -187,7 +187,7 @@ class StandardizeImage(Processing):

def __init__(self, max_value: float = 255.0):
super().__init__()
self.max_value = max_value
self.max_value = float(max_value)

def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]:
"""Reverse the channel order of an image.
Expand Down Expand Up @@ -246,7 +246,7 @@ class _DetectionPadding(Processing, ABC):
"""

def __init__(self, output_shape: Tuple[int, int], pad_value: int):
self.output_shape = output_shape
self.output_shape = tuple(output_shape)
self.pad_value = pad_value

def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, DetectionPadToSizeMetadata]:
Expand Down Expand Up @@ -288,7 +288,7 @@ class _KeypointsPadding(Processing, ABC):
"""

def __init__(self, output_shape: Tuple[int, int], pad_value: int):
self.output_shape = output_shape
self.output_shape = tuple(output_shape)
self.pad_value = pad_value

def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, DetectionPadToSizeMetadata]:
Expand Down Expand Up @@ -351,7 +351,7 @@ class _Rescale(Processing, ABC):
"""

def __init__(self, output_shape: Tuple[int, int]):
self.output_shape = output_shape
self.output_shape = tuple(output_shape)

def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, RescaleMetadata]:
scale_factor_h, scale_factor_w = self.output_shape[0] / image.shape[0], self.output_shape[1] / image.shape[1]
Expand All @@ -378,7 +378,7 @@ class _LongestMaxSizeRescale(Processing, ABC):
"""

def __init__(self, output_shape: Tuple[int, int]):
self.output_shape = output_shape
self.output_shape = tuple(output_shape)

def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, RescaleMetadata]:
height, width = image.shape[:2]
Expand Down Expand Up @@ -443,7 +443,7 @@ def postprocess_predictions(self, predictions: Prediction, metadata: None) -> Pr
class Resize(ClassificationProcess):
def __init__(self, size: int = 224):
super().__init__()
self.size = size
self.size = int(size)

def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]:
"""Resize an image.
Expand Down Expand Up @@ -477,7 +477,7 @@ class CenterCrop(ClassificationProcess):

def __init__(self, size: int = 224):
super().__init__()
self.size = size
self.size = int(size)

def preprocess_image(self, image: np.ndarray) -> Tuple[np.ndarray, None]:
"""Crops the given image at the center.
Expand Down

0 comments on commit c74b034

Please sign in to comment.