diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 8efe2a8878a..8d4e8470afc 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -7,6 +7,7 @@ import numpy as np import torch +from numpy.typing import NDArray from PIL import Image from PIL.Image import Image as PILImage from torch import Tensor @@ -124,7 +125,7 @@ def _is_numpy_image(img: Any) -> bool: return img.ndim in {2, 3} -def to_tensor(pic: Union[PILImage, np.ndarray]) -> Tensor: +def to_tensor(pic: Union[PILImage.Image, NDArray[Union[np.uint8, np.float32]]]) -> Tensor: """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This function does not support torchscript.