Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hotfix/sg 000 fix predict loading from np torch #1419

Merged
merged 9 commits into from
Sep 19, 2023
1 change: 1 addition & 0 deletions src/super_gradients/training/pipelines/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def predict_images(self, images: Union[ImageSource, List[ImageSource]], batch_si
from super_gradients.training.utils.media.image import load_images

images = load_images(images)

result_generator = self._generate_prediction_result(images=images, batch_size=batch_size)
return self._combine_image_prediction_to_images(result_generator, n_images=len(images))

Expand Down
21 changes: 16 additions & 5 deletions src/super_gradients/training/utils/media/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,25 @@ def generate_image_loader(images: Union[List[ImageSource], ImageSource]) -> Iter
- List: A list of images of any of the above types.

:param images: Single image or a list of images of supported types.
:return: Generator of images as numpy arrays. If loaded from string, the image will be returned as RGB.
:return: Generator of images as numpy arrays (H, W, C). If loaded from string, the image will be returned as RGB.
"""
if isinstance(images, str) and os.path.isdir(images):
images_paths = list_images_in_folder(images)
for image_path in images_paths:
yield load_image(image=image_path)
elif isinstance(images, (list, Iterator)):
elif _is_batch_of_images(images=images):
for image in images:
yield load_image(image=image)
else:
yield load_image(image=images)


def _is_batch_of_images(images: ImageSource) -> bool:
return (
isinstance(images, (list, Iterator)) or (isinstance(images, np.ndarray) and images.ndim == 4) or (isinstance(images, torch.Tensor) and images.ndim == 4)
)


def list_images_in_folder(directory: str) -> List[str]:
"""List all the images in a directory.
:param directory: The path to the directory containing the images.
Expand All @@ -68,7 +74,7 @@ def list_images_in_folder(directory: str) -> List[str]:


def load_image(image: ImageSource) -> np.ndarray:
"""Load a single image and return it as a numpy arrays.
"""Load a single image and return it as a numpy arrays (H, W, C).

Supported image types include:
- numpy.ndarray: A numpy array representing the image
Expand All @@ -77,12 +83,17 @@ def load_image(image: ImageSource) -> np.ndarray:
- str: A string representing either a local file path or a URL to an image

:param image: Single image of supported types.
:return: Image as numpy arrays. If loaded from string, the image will be returned as RGB.
:return: Image as numpy arrays (H, W, C). If loaded from string, the image will be returned as RGB.
"""
if isinstance(image, np.ndarray):
if image.shape[0] == 3:
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved
image = image.transpose((1, 2, 0))
return image
elif isinstance(image, torch.Tensor):
return image.numpy()
image = image.cpu().numpy()
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved
if image.shape[0] == 3:
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved
image = image.transpose((1, 2, 0))
return image
elif isinstance(image, PIL.Image.Image):
return load_np_image_from_pil(image)
elif isinstance(image, str):
Expand Down
46 changes: 46 additions & 0 deletions tests/unit_tests/test_media_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import unittest

import numpy as np
import torch

from super_gradients.training.utils.media.image import load_images


class TrainingParamsTest(unittest.TestCase):
def test_load_images(self):

# list - numpy
list_images = [np.zeros((3, 100, 100)) for _ in range(15)]
loaded_images = load_images(list_images)
self.assertEqual(len(loaded_images), 15)
for image in loaded_images:
self.assertIsInstance(image, np.ndarray)
self.assertEqual(image.shape, (100, 100, 3))

# numpy - batch
np_images = np.zeros((15, 3, 100, 100))
loaded_images = load_images(np_images)
self.assertEqual(len(loaded_images), 15)
for image in loaded_images:
self.assertIsInstance(image, np.ndarray)
self.assertEqual(image.shape, (100, 100, 3))

# list - torcj
list_images = [torch.zeros((3, 100, 100)) for _ in range(15)]
loaded_images = load_images(list_images)
self.assertEqual(len(loaded_images), 15)
for image in loaded_images:
self.assertIsInstance(image, np.ndarray)
self.assertEqual(image.shape, (100, 100, 3))

# torch - batch
torch_images = torch.zeros((15, 3, 100, 100))
loaded_images = load_images(torch_images)
self.assertEqual(len(loaded_images), 15)
for image in loaded_images:
self.assertIsInstance(image, np.ndarray)
self.assertEqual(image.shape, (100, 100, 3))


if __name__ == "__main__":
unittest.main()