Skip to content

Commit

Permalink
Hotfix/sg 000 fix predict loading from np torch (#1419)
Browse files Browse the repository at this point in the history
* fix

* Added table with all supported input types to predict() and improved load_image method to get rid of hard-coded number of input channels

* Improve spelling

* Improve type alias

---------

Co-authored-by: Shay Aharon <80472096+shaydeci@users.noreply.github.com>
Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
  • Loading branch information
3 people committed Sep 19, 2023
1 parent 61f46e3 commit 4536d2d
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 8 deletions.
38 changes: 38 additions & 0 deletions documentation/source/ModelPredictions.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,44 @@ The model used in this tutorial is [YOLO-NAS](YoloNASQuickstart.md), pre-trained

*Note that the `model.predict()` method is currently only available for detection tasks.*

## Supported Media Formats

A `mode.predict()` method is built to handle multiple data formats and types.
Here is the full list of what `predict()` method can handle:

| Argument Semantics | Argument Type | Supported layout | Example | Notes |
|------------------------------------|--------------------|-----------------------------------|------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------|
| Path to local image | `str` | - | `predict("path/to/image.jpg")` | All common image extensions are supported. |
| Path to images directory | `str` | - | `predict("path/to/images/directory")` | |
| Path to local video | `str` | - | `predict("path/to/video.mp4")` | All common video extensions are supported. |
| URL to remote image | `str` | - | `predict("https://example.com/image.jpg")` | |
| 3-dimensional Numpy image | `np.ndarray` | `[H, W, C]` | `predict(np.zeros((480, 640, 3), dtype=np.uint8))` | Channels last, RGB channel order for 3-channel images |
| 4-dimensional Numpy image | `np.ndarray` | `[N, H, W, C]` or `[N, C, H, W]` | `predict(np.zeros((480, 640, 3), dtype=np.uint8))` | Tensor layout (NHWC or NCHW) is inferred w.r.t to number of input channels of underlying model |
| List of 3-dimensional numpy arrays | `List[np.ndarray]` | `[H1, W1, C]`, `[H2, W2, C]`, ... | `predict([np.zeros((480, 640, 3), dtype=np.uint8), np.zeros((384, 512, 3), dtype=np.uint8) ])` | Images may vary in size, but should have same number of channels |
| 3-dimensional Torch Tensor | `torch.Tensor` | `[H, W, C]` or `[C, H, W]` | `predict(torch.zeros((480, 640, 3), dtype=torch.uint8))` | Tensor layout (HWC or CHW) is inferred w.r.t to number of input channels of underlying model |
| 4-dimensional Torch Tensor | `torch.Tensor` | `[N, H, W, C]` or `[N, C, H, W]` | `predict(torch.zeros((4, 480, 640, 3), dtype=torch.uint8))` | Tensor layout (NHWC or NCHW) is inferred w.r.t to number of input channels of underlying model |

**Important note** - When using batched input (4-dimensional `np.ndarray` or `torch.Tensor`) formats, **normalization and size preprocessing will be applied to these inputs**.
This means that the input tensors **should not** be normalized beforehand.
Here is the example of **incorrect** code of using `model.predict()`:

```python
# Incorrect code example. Do not use it.
from super_gradients.training import dataloaders
from super_gradients.common.object_names import Models
from super_gradients.training import models

val_loader = dataloaders.get("coco2017_val_yolo_nas")

model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")

for (inputs, *_) in val_loader: # Error here: inputs as already normalized by dataset class
model.predict(inputs).show() # This will not work as expected
```

Since `model.predict()` encapsulates normalization and size preprocessing, it is not designed to handle pre-normalized images as input.
Please keep this in mind when using `model.predict()` with batched inputs.


## Detect Objects in Multiple Images

Expand Down
1 change: 1 addition & 0 deletions src/super_gradients/training/pipelines/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def predict_images(self, images: Union[ImageSource, List[ImageSource]], batch_si
from super_gradients.training.utils.media.image import load_images

images = load_images(images)

result_generator = self._generate_prediction_result(images=images, batch_size=batch_size)
return self._combine_image_prediction_to_images(result_generator, n_images=len(images))

Expand Down
49 changes: 41 additions & 8 deletions src/super_gradients/training/utils/media/image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Union, List, Iterable, Iterator
from typing_extensions import get_args
import PIL
Expand All @@ -14,7 +15,7 @@

IMG_EXTENSIONS = ("bmp", "dng", "jpeg", "jpg", "mpo", "pfm", "pgm", "png", "ppm", "tif", "tiff", "webp")
SingleImageSource = Union[str, np.ndarray, torch.Tensor, PIL.Image.Image]
ImageSource = Union[SingleImageSource, List[SingleImageSource]]
ImageSource = Union[SingleImageSource, List[SingleImageSource], Iterator[SingleImageSource]]


def load_images(images: Union[List[ImageSource], ImageSource]) -> List[np.ndarray]:
Expand Down Expand Up @@ -44,31 +45,49 @@ def generate_image_loader(images: Union[List[ImageSource], ImageSource]) -> Iter
- List: A list of images of any of the above types.
:param images: Single image or a list of images of supported types.
:return: Generator of images as numpy arrays. If loaded from string, the image will be returned as RGB.
:return: Generator of images as numpy arrays (H, W, C). If loaded from string, the image will be returned as RGB.
"""
if isinstance(images, str) and os.path.isdir(images):
images_paths = list_images_in_folder(images)
for image_path in images_paths:
yield load_image(image=image_path)
elif isinstance(images, (list, Iterator)):
elif _is_4d_array(images):
warnings.warn(
"It seems you are using predict() with 4D array as input. "
"Please note we cannot track whether the input was already normalized or not. "
"You will get incorrect results if you feed batches from train/validation dataloader that were already normalized."
"Please check https://docs.deci.ai/super-gradients/latest/documentation/source/ModelPredictions.html for more details."
)
for image in images:
yield load_image(image=image)
elif _is_list_of_images(images=images):
warnings.warn("It seems you are using predict() with batch input")
for image in images:
yield load_image(image=image)
else:
yield load_image(image=images)


def _is_4d_array(images: ImageSource) -> bool:
return isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4


def _is_list_of_images(images: ImageSource) -> bool:
return isinstance(images, (list, Iterator))


def list_images_in_folder(directory: str) -> List[str]:
"""List all the images in a directory.
:param directory: The path to the directory containing the images.
:return: A list of image file names.
"""
files = os.listdir(directory)
images_paths = [os.path.join(directory, f) for f in files if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif"))]
images_paths = [os.path.join(directory, f) for f in files if is_image(f)]
return images_paths


def load_image(image: ImageSource) -> np.ndarray:
"""Load a single image and return it as a numpy arrays.
def load_image(image: ImageSource, input_image_channels: int = 3) -> np.ndarray:
"""Load a single image and return it as a numpy arrays (H, W, C).
Supported image types include:
- numpy.ndarray: A numpy array representing the image
Expand All @@ -77,12 +96,26 @@ def load_image(image: ImageSource) -> np.ndarray:
- str: A string representing either a local file path or a URL to an image
:param image: Single image of supported types.
:return: Image as numpy arrays. If loaded from string, the image will be returned as RGB.
:param input_image_channels: Number of channels that model expects as input.
This value helps to infer the layout of the input image array.
As of now this argument has default value of 3, but in future it will become mandatory.
:return: Image as numpy arrays (H, W, C). If loaded from string, the image will be returned as RGB.
"""
if isinstance(image, np.ndarray):
if image.ndim != 3:
raise ValueError(f"Unsupported image shape: {image.shape}. This function only supports 3-dimensional images.")
if image.shape[0] == input_image_channels:
image = np.ascontiguousarray(image.transpose((1, 2, 0)))
elif image.shape[2] == input_image_channels:
pass
else:
raise ValueError(f"Cannot infer image layout (HWC or CHW) for image of shape {image.shape} while C is {input_image_channels}")

return image
elif isinstance(image, torch.Tensor):
return image.numpy()
image = image.detach().cpu().numpy()
return load_image(image=image, input_image_channels=input_image_channels)
elif isinstance(image, PIL.Image.Image):
return load_np_image_from_pil(image)
elif isinstance(image, str):
Expand Down
46 changes: 46 additions & 0 deletions tests/unit_tests/test_media_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import unittest

import numpy as np
import torch

from super_gradients.training.utils.media.image import load_images


class TrainingParamsTest(unittest.TestCase):
def test_load_images(self):

# list - numpy
list_images = [np.zeros((3, 100, 100)) for _ in range(15)]
loaded_images = load_images(list_images)
self.assertEqual(len(loaded_images), 15)
for image in loaded_images:
self.assertIsInstance(image, np.ndarray)
self.assertEqual(image.shape, (100, 100, 3))

# numpy - batch
np_images = np.zeros((15, 3, 100, 100))
loaded_images = load_images(np_images)
self.assertEqual(len(loaded_images), 15)
for image in loaded_images:
self.assertIsInstance(image, np.ndarray)
self.assertEqual(image.shape, (100, 100, 3))

# list - torcj
list_images = [torch.zeros((3, 100, 100)) for _ in range(15)]
loaded_images = load_images(list_images)
self.assertEqual(len(loaded_images), 15)
for image in loaded_images:
self.assertIsInstance(image, np.ndarray)
self.assertEqual(image.shape, (100, 100, 3))

# torch - batch
torch_images = torch.zeros((15, 3, 100, 100))
loaded_images = load_images(torch_images)
self.assertEqual(len(loaded_images), 15)
for image in loaded_images:
self.assertIsInstance(image, np.ndarray)
self.assertEqual(image.shape, (100, 100, 3))


if __name__ == "__main__":
unittest.main()

0 comments on commit 4536d2d

Please sign in to comment.