Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hotfix/sg 000 fix predict loading from np torch #1419

Merged
merged 9 commits into from
Sep 19, 2023
38 changes: 38 additions & 0 deletions documentation/source/ModelPredictions.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,44 @@ The model used in this tutorial is [YOLO-NAS](YoloNASQuickstart.md), pre-trained

*Note that the `model.predict()` method is currently only available for detection tasks.*

## Supported Media Formats

A `mode.predict()` method is built to handle multiple data formats and types.
Here is the full list of what `predict()` method can handle:

| Argument Semantics | Argument Type | Supported layout | Example | Notes |
|------------------------------------|--------------------|-----------------------------------|------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------|
| Path to local image | `str` | - | `predict("path/to/image.jpg")` | All common image extensions are supported. |
| Path to images directory | `str` | - | `predict("path/to/images/directory")` | |
| Path to local video | `str` | - | `predict("path/to/video.mp4")` | All common video extensions are supported. |
| URL to remote image | `str` | - | `predict("https://example.com/image.jpg")` | |
| 3-dimensional Numpy image | `np.ndarray` | `[H, W, C]` | `predict(np.zeros((480, 640, 3), dtype=np.uint8))` | Channels last, RGB channel order for 3-channel images |
| 4-dimensional Numpy image | `np.ndarray` | `[N, H, W, C]` or `[N, C, H, W]` | `predict(np.zeros((480, 640, 3), dtype=np.uint8))` | Tensor layout (NHWC or NCHW) is inferred w.r.t to number of input channels of underlying model |
| List of 3-dimensional numpy arrays | `List[np.ndarray]` | `[H1, W1, C]`, `[H2, W2, C]`, ... | `predict([np.zeros((480, 640, 3), dtype=np.uint8), np.zeros((384, 512, 3), dtype=np.uint8) ])` | Images may vary in size, but should have same number of channels |
| 3-dimensional Torch Tensor | `torch.Tensor` | `[H, W, C]` or `[C, H, W]` | `predict(torch.zeros((480, 640, 3), dtype=torch.uint8))` | Tensor layout (HWC or CHW) is inferred w.r.t to number of input channels of underlying model |
| 4-dimensional Torch Tensor | `torch.Tensor` | `[N, H, W, C]` or `[N, C, H, W]` | `predict(torch.zeros((4, 480, 640, 3), dtype=torch.uint8))` | Tensor layout (NHWC or NCHW) is inferred w.r.t to number of input channels of underlying model |

**Important note** - When using batched input (4-dimensional `np.ndarray` or `torch.Tensor`) formats, **normalization and size preprocessing will be applied to these inputs**.
This means that the input tensors **should not** be normalized beforehand.
Here is the example of **incorrect** code of using `model.predict()`:

```python
# Incorrect code example. Do not use it.
from super_gradients.training import dataloaders
from super_gradients.common.object_names import Models
from super_gradients.training import models

val_loader = dataloaders.get("coco2017_val_yolo_nas")

model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")

for (inputs, *_) in val_loader: # Error here: inputs as already normalized by dataset class
model.predict(inputs).show() # This will not work as expected
```

Since `model.predict()` encapsulates normalization and size preprocessing, it is not designed to handle pre-normalized images as input.
Please keep this in mind when using `model.predict()` with batched inputs.


## Detect Objects in Multiple Images

Expand Down
48 changes: 35 additions & 13 deletions src/super_gradients/training/utils/media/image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Union, List, Iterable, Iterator
from typing_extensions import get_args
import PIL
Expand Down Expand Up @@ -50,17 +51,29 @@ def generate_image_loader(images: Union[List[ImageSource], ImageSource]) -> Iter
images_paths = list_images_in_folder(images)
for image_path in images_paths:
yield load_image(image=image_path)
elif _is_batch_of_images(images=images):
elif _is_4d_array(images):
warnings.warn(
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved
"It seems you are using predict() with 4D array as input. "
"Please note we cannot track whether the input was already normalized or not. "
"You will get incorrect results if you feeding batches from train/validation dataloader."
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved
"Please check https://docs.deci.ai/super-gradients/latest/documentation/source/ModelPredictions.html for more details."
)
for image in images:
yield load_image(image=image)
elif _is_list_of_images(images=images):
warnings.warn("It seems you are using predict() with batch input")
for image in images:
yield load_image(image=image)
else:
yield load_image(image=images)


def _is_batch_of_images(images: ImageSource) -> bool:
return (
isinstance(images, (list, Iterator)) or (isinstance(images, np.ndarray) and images.ndim == 4) or (isinstance(images, torch.Tensor) and images.ndim == 4)
)
def _is_4d_array(images: ImageSource) -> bool:
return isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4


def _is_list_of_images(images: ImageSource) -> bool:
return isinstance(images, (list, Iterator))
BloodAxe marked this conversation as resolved.
Show resolved Hide resolved


def list_images_in_folder(directory: str) -> List[str]:
Expand All @@ -69,11 +82,11 @@ def list_images_in_folder(directory: str) -> List[str]:
:return: A list of image file names.
"""
files = os.listdir(directory)
images_paths = [os.path.join(directory, f) for f in files if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif"))]
images_paths = [os.path.join(directory, f) for f in files if is_image(f)]
return images_paths


def load_image(image: ImageSource) -> np.ndarray:
def load_image(image: ImageSource, input_image_channels: int = 3) -> np.ndarray:
"""Load a single image and return it as a numpy arrays (H, W, C).

Supported image types include:
Expand All @@ -83,17 +96,26 @@ def load_image(image: ImageSource) -> np.ndarray:
- str: A string representing either a local file path or a URL to an image

:param image: Single image of supported types.
:param input_image_channels: Number of channels that model expects as input.
This value helps to infer the layout of the input image array.
As of now this argument has default value of 3, but in future it will become mandatory.
shaydeci marked this conversation as resolved.
Show resolved Hide resolved

:return: Image as numpy arrays (H, W, C). If loaded from string, the image will be returned as RGB.
"""
if isinstance(image, np.ndarray):
if image.shape[0] == 3:
image = image.transpose((1, 2, 0))
if image.ndim != 3:
raise ValueError(f"Unsupported image shape: {image.shape}. This function only supports 3-dimensional images.")
if image.shape[0] == input_image_channels:
image = np.ascontiguousarray(image.transpose((1, 2, 0)))
elif image.shape[2] == input_image_channels:
pass
else:
raise ValueError(f"Cannot infer image layout (HWC or CHW) for image of shape {image.shape} while C is {input_image_channels}")

return image
elif isinstance(image, torch.Tensor):
image = image.cpu().numpy()
if image.shape[0] == 3:
image = image.transpose((1, 2, 0))
return image
image = image.detach().cpu().numpy()
return load_image(image=image, input_image_channels=input_image_channels)
elif isinstance(image, PIL.Image.Image):
return load_np_image_from_pil(image)
elif isinstance(image, str):
Expand Down