diff --git a/documentation/source/ModelPredictions.md b/documentation/source/ModelPredictions.md index 765c08a6dc..fceca00c2c 100644 --- a/documentation/source/ModelPredictions.md +++ b/documentation/source/ModelPredictions.md @@ -8,6 +8,44 @@ The model used in this tutorial is [YOLO-NAS](YoloNASQuickstart.md), pre-trained *Note that the `model.predict()` method is currently only available for detection tasks.* +## Supported Media Formats + +The `model.predict()` method is built to handle multiple data formats and types. +Here is the full list of what the `predict()` method can handle: + +| Argument Semantics | Argument Type | Supported layout | Example | Notes | +|------------------------------------|--------------------|-----------------------------------|------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------| +| Path to local image | `str` | - | `predict("path/to/image.jpg")` | All common image extensions are supported. | +| Path to images directory | `str` | - | `predict("path/to/images/directory")` | | +| Path to local video | `str` | - | `predict("path/to/video.mp4")` | All common video extensions are supported. | +| URL to remote image | `str` | - | `predict("https://example.com/image.jpg")` | | +| 3-dimensional Numpy image | `np.ndarray` | `[H, W, C]` | `predict(np.zeros((480, 640, 3), dtype=np.uint8))` | Channels last, RGB channel order for 3-channel images | +| 4-dimensional Numpy image | `np.ndarray` | `[N, H, W, C]` or `[N, C, H, W]` | `predict(np.zeros((4, 480, 640, 3), dtype=np.uint8))` | Tensor layout (NHWC or NCHW) is inferred w.r.t. the number of input channels of the underlying model | +| List of 3-dimensional numpy arrays | `List[np.ndarray]` | `[H1, W1, C]`, `[H2, W2, C]`, ... 
| `predict([np.zeros((480, 640, 3), dtype=np.uint8), np.zeros((384, 512, 3), dtype=np.uint8) ])` | Images may vary in size, but should have same number of channels | +| 3-dimensional Torch Tensor | `torch.Tensor` | `[H, W, C]` or `[C, H, W]` | `predict(torch.zeros((480, 640, 3), dtype=torch.uint8))` | Tensor layout (HWC or CHW) is inferred w.r.t. the number of input channels of the underlying model | +| 4-dimensional Torch Tensor | `torch.Tensor` | `[N, H, W, C]` or `[N, C, H, W]` | `predict(torch.zeros((4, 480, 640, 3), dtype=torch.uint8))` | Tensor layout (NHWC or NCHW) is inferred w.r.t. the number of input channels of the underlying model | + +**Important note** - When using batched input (4-dimensional `np.ndarray` or `torch.Tensor`) formats, **normalization and size preprocessing will be applied to these inputs**. +This means that the input tensors **should not** be normalized beforehand. +Here is an example of **incorrect** usage of `model.predict()`: + +```python +# Incorrect code example. Do not use it. +from super_gradients.training import dataloaders +from super_gradients.common.object_names import Models +from super_gradients.training import models + +val_loader = dataloaders.get("coco2017_val_yolo_nas") + +model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco") + +for (inputs, *_) in val_loader: # Error here: inputs are already normalized by the dataset class + model.predict(inputs).show() # This will not work as expected +``` + +Since `model.predict()` encapsulates normalization and size preprocessing, it is not designed to handle pre-normalized images as input. +Please keep this in mind when using `model.predict()` with batched inputs. 
+ ## Detect Objects in Multiple Images diff --git a/src/super_gradients/training/pipelines/pipelines.py b/src/super_gradients/training/pipelines/pipelines.py index f099eb592d..1812a35929 100644 --- a/src/super_gradients/training/pipelines/pipelines.py +++ b/src/super_gradients/training/pipelines/pipelines.py @@ -118,6 +118,7 @@ def predict_images(self, images: Union[ImageSource, List[ImageSource]], batch_si from super_gradients.training.utils.media.image import load_images images = load_images(images) + result_generator = self._generate_prediction_result(images=images, batch_size=batch_size) return self._combine_image_prediction_to_images(result_generator, n_images=len(images)) diff --git a/src/super_gradients/training/utils/media/image.py b/src/super_gradients/training/utils/media/image.py index 3ef3b84ac4..9a007eb444 100644 --- a/src/super_gradients/training/utils/media/image.py +++ b/src/super_gradients/training/utils/media/image.py @@ -1,3 +1,4 @@ +import warnings from typing import Union, List, Iterable, Iterator from typing_extensions import get_args import PIL @@ -14,7 +15,7 @@ IMG_EXTENSIONS = ("bmp", "dng", "jpeg", "jpg", "mpo", "pfm", "pgm", "png", "ppm", "tif", "tiff", "webp") SingleImageSource = Union[str, np.ndarray, torch.Tensor, PIL.Image.Image] -ImageSource = Union[SingleImageSource, List[SingleImageSource]] +ImageSource = Union[SingleImageSource, List[SingleImageSource], Iterator[SingleImageSource]] def load_images(images: Union[List[ImageSource], ImageSource]) -> List[np.ndarray]: @@ -44,31 +45,49 @@ def generate_image_loader(images: Union[List[ImageSource], ImageSource]) -> Iter - List: A list of images of any of the above types. :param images: Single image or a list of images of supported types. - :return: Generator of images as numpy arrays. If loaded from string, the image will be returned as RGB. + :return: Generator of images as numpy arrays (H, W, C). If loaded from string, the image will be returned as RGB. 
""" if isinstance(images, str) and os.path.isdir(images): images_paths = list_images_in_folder(images) for image_path in images_paths: yield load_image(image=image_path) - elif isinstance(images, (list, Iterator)): + elif _is_4d_array(images): + warnings.warn( + "It seems you are using predict() with 4D array as input. " + "Please note we cannot track whether the input was already normalized or not. " + "You will get incorrect results if you feed batches from train/validation dataloader that were already normalized." + "Please check https://docs.deci.ai/super-gradients/latest/documentation/source/ModelPredictions.html for more details." + ) + for image in images: + yield load_image(image=image) + elif _is_list_of_images(images=images): + warnings.warn("It seems you are using predict() with batch input") for image in images: yield load_image(image=image) else: yield load_image(image=images) +def _is_4d_array(images: ImageSource) -> bool: + return isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4 + + +def _is_list_of_images(images: ImageSource) -> bool: + return isinstance(images, (list, Iterator)) + + def list_images_in_folder(directory: str) -> List[str]: """List all the images in a directory. :param directory: The path to the directory containing the images. :return: A list of image file names. """ files = os.listdir(directory) - images_paths = [os.path.join(directory, f) for f in files if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif"))] + images_paths = [os.path.join(directory, f) for f in files if is_image(f)] return images_paths -def load_image(image: ImageSource) -> np.ndarray: - """Load a single image and return it as a numpy arrays. +def load_image(image: ImageSource, input_image_channels: int = 3) -> np.ndarray: + """Load a single image and return it as a numpy arrays (H, W, C). 
Supported image types include: - numpy.ndarray: A numpy array representing the image @@ -77,12 +96,26 @@ def load_image(image: ImageSource) -> np.ndarray: - str: A string representing either a local file path or a URL to an image :param image: Single image of supported types. - :return: Image as numpy arrays. If loaded from string, the image will be returned as RGB. + :param input_image_channels: Number of channels that model expects as input. + This value helps to infer the layout of the input image array. + As of now this argument has default value of 3, but in future it will become mandatory. + + :return: Image as numpy arrays (H, W, C). If loaded from string, the image will be returned as RGB. """ if isinstance(image, np.ndarray): + if image.ndim != 3: + raise ValueError(f"Unsupported image shape: {image.shape}. This function only supports 3-dimensional images.") + if image.shape[0] == input_image_channels: + image = np.ascontiguousarray(image.transpose((1, 2, 0))) + elif image.shape[2] == input_image_channels: + pass + else: + raise ValueError(f"Cannot infer image layout (HWC or CHW) for image of shape {image.shape} while C is {input_image_channels}") + return image elif isinstance(image, torch.Tensor): - return image.numpy() + image = image.detach().cpu().numpy() + return load_image(image=image, input_image_channels=input_image_channels) elif isinstance(image, PIL.Image.Image): return load_np_image_from_pil(image) elif isinstance(image, str): diff --git a/tests/unit_tests/test_media_utils.py b/tests/unit_tests/test_media_utils.py new file mode 100644 index 0000000000..881cb110e1 --- /dev/null +++ b/tests/unit_tests/test_media_utils.py @@ -0,0 +1,46 @@ +import unittest + +import numpy as np +import torch + +from super_gradients.training.utils.media.image import load_images + + +class TrainingParamsTest(unittest.TestCase): + def test_load_images(self): + + # list - numpy + list_images = [np.zeros((3, 100, 100)) for _ in range(15)] + loaded_images = 
load_images(list_images) + self.assertEqual(len(loaded_images), 15) + for image in loaded_images: + self.assertIsInstance(image, np.ndarray) + self.assertEqual(image.shape, (100, 100, 3)) + + # numpy - batch + np_images = np.zeros((15, 3, 100, 100)) + loaded_images = load_images(np_images) + self.assertEqual(len(loaded_images), 15) + for image in loaded_images: + self.assertIsInstance(image, np.ndarray) + self.assertEqual(image.shape, (100, 100, 3)) + + # list - torch + list_images = [torch.zeros((3, 100, 100)) for _ in range(15)] + loaded_images = load_images(list_images) + self.assertEqual(len(loaded_images), 15) + for image in loaded_images: + self.assertIsInstance(image, np.ndarray) + self.assertEqual(image.shape, (100, 100, 3)) + + # torch - batch + torch_images = torch.zeros((15, 3, 100, 100)) + loaded_images = load_images(torch_images) + self.assertEqual(len(loaded_images), 15) + for image in loaded_images: + self.assertIsInstance(image, np.ndarray) + self.assertEqual(image.shape, (100, 100, 3)) + + +if __name__ == "__main__": + unittest.main()