From d849e85a602c89331e69dbbcfa5f04fd6b77285b Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Thu, 8 Aug 2024 11:34:56 +0000 Subject: [PATCH] 2024-08-08 nightly release (0d8084811eb442c8211a71017b7e02143d97fb0d) --- benchmarks/encoding.py | 67 -- benchmarks/encoding_decoding.py | 99 +++ test/test_image.py | 121 +++- .../csrc/io/image/cuda/decode_jpeg_cuda.cpp | 208 ------ .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 603 ++++++++++++++++++ .../csrc/io/image/cuda/decode_jpegs_cuda.h | 45 ++ .../io/image/cuda/encode_decode_jpegs_cuda.h | 45 +- torchvision/csrc/io/image/image.cpp | 2 +- torchvision/io/image.py | 54 +- .../transforms/v2/functional/_augment.py | 7 +- 10 files changed, 934 insertions(+), 317 deletions(-) delete mode 100644 benchmarks/encoding.py create mode 100644 benchmarks/encoding_decoding.py delete mode 100644 torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp create mode 100644 torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp create mode 100644 torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h diff --git a/benchmarks/encoding.py b/benchmarks/encoding.py deleted file mode 100644 index f994b03c783..00000000000 --- a/benchmarks/encoding.py +++ /dev/null @@ -1,67 +0,0 @@ -import os -import platform -import statistics - -import torch -import torch.utils.benchmark as benchmark -import torchvision - - -def print_machine_specs(): - print("Processor:", platform.processor()) - print("Platform:", platform.platform()) - print("Logical CPUs:", os.cpu_count()) - print(f"\nCUDA device: {torch.cuda.get_device_name()}") - print(f"Total Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB") - - -def get_data(): - transform = torchvision.transforms.Compose( - [ - torchvision.transforms.PILToTensor(), - ] - ) - path = os.path.join(os.getcwd(), "data") - testset = torchvision.datasets.Places365( - root="./data", download=not os.path.exists(path), transform=transform, split="val" - ) - testloader = torch.utils.data.DataLoader( - testset, 
batch_size=1000, shuffle=False, num_workers=1, collate_fn=lambda batch: [r[0] for r in batch] - ) - return next(iter(testloader)) - - -def run_benchmark(batch): - results = [] - for device in ["cpu", "cuda"]: - batch_device = [t.to(device=device) for t in batch] - for size in [1, 100, 1000]: - for num_threads in [1, 12, 24]: - for stmt, strat in zip( - [ - "[torchvision.io.encode_jpeg(img) for img in batch_input]", - "torchvision.io.encode_jpeg(batch_input)", - ], - ["unfused", "fused"], - ): - batch_input = batch_device[:size] - t = benchmark.Timer( - stmt=stmt, - setup="import torchvision", - globals={"batch_input": batch_input}, - label="Image Encoding", - sub_label=f"{device.upper()} ({strat}): {stmt}", - description=f"{size} images", - num_threads=num_threads, - ) - results.append(t.blocked_autorange()) - compare = benchmark.Compare(results) - compare.print() - - -if __name__ == "__main__": - print_machine_specs() - batch = get_data() - mean_h, mean_w = statistics.mean(t.shape[-2] for t in batch), statistics.mean(t.shape[-1] for t in batch) - print(f"\nMean image size: {int(mean_h)}x{int(mean_w)}") - run_benchmark(batch) diff --git a/benchmarks/encoding_decoding.py b/benchmarks/encoding_decoding.py new file mode 100644 index 00000000000..0cafdb2d8a6 --- /dev/null +++ b/benchmarks/encoding_decoding.py @@ -0,0 +1,99 @@ +import os +import platform +import statistics + +import torch +import torch.utils.benchmark as benchmark +import torchvision + + +def print_machine_specs(): + print("Processor:", platform.processor()) + print("Platform:", platform.platform()) + print("Logical CPUs:", os.cpu_count()) + print(f"\nCUDA device: {torch.cuda.get_device_name()}") + print(f"Total Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB") + + +def get_data(): + transform = torchvision.transforms.Compose( + [ + torchvision.transforms.PILToTensor(), + ] + ) + path = os.path.join(os.getcwd(), "data") + testset = torchvision.datasets.Places365( + root="./data", 
download=not os.path.exists(path), transform=transform, split="val" + ) + testloader = torch.utils.data.DataLoader( + testset, batch_size=1000, shuffle=False, num_workers=1, collate_fn=lambda batch: [r[0] for r in batch] + ) + return next(iter(testloader)) + + +def run_encoding_benchmark(decoded_images): + results = [] + for device in ["cpu", "cuda"]: + decoded_images_device = [t.to(device=device) for t in decoded_images] + for size in [1, 100, 1000]: + for num_threads in [1, 12, 24]: + for stmt, strat in zip( + [ + "[torchvision.io.encode_jpeg(img) for img in decoded_images_device_trunc]", + "torchvision.io.encode_jpeg(decoded_images_device_trunc)", + ], + ["unfused", "fused"], + ): + decoded_images_device_trunc = decoded_images_device[:size] + t = benchmark.Timer( + stmt=stmt, + setup="import torchvision", + globals={"decoded_images_device_trunc": decoded_images_device_trunc}, + label="Image Encoding", + sub_label=f"{device.upper()} ({strat}): {stmt}", + description=f"{size} images", + num_threads=num_threads, + ) + results.append(t.blocked_autorange()) + compare = benchmark.Compare(results) + compare.print() + + +def run_decoding_benchmark(encoded_images): + results = [] + for device in ["cpu", "cuda"]: + for size in [1, 100, 1000]: + for num_threads in [1, 12, 24]: + for stmt, strat in zip( + [ + f"[torchvision.io.decode_jpeg(img, device='{device}') for img in encoded_images_trunc]", + f"torchvision.io.decode_jpeg(encoded_images_trunc, device='{device}')", + ], + ["unfused", "fused"], + ): + encoded_images_trunc = encoded_images[:size] + t = benchmark.Timer( + stmt=stmt, + setup="import torchvision", + globals={"encoded_images_trunc": encoded_images_trunc}, + label="Image Decoding", + sub_label=f"{device.upper()} ({strat}): {stmt}", + description=f"{size} images", + num_threads=num_threads, + ) + results.append(t.blocked_autorange()) + compare = benchmark.Compare(results) + compare.print() + + +if __name__ == "__main__": + print_machine_specs() + decoded_images 
= get_data() + mean_h, mean_w = statistics.mean(t.shape[-2] for t in decoded_images), statistics.mean( + t.shape[-1] for t in decoded_images + ) + print(f"\nMean image size: {int(mean_h)}x{int(mean_w)}") + run_encoding_benchmark(decoded_images) + encoded_images_cuda = torchvision.io.encode_jpeg([img.cuda() for img in decoded_images]) + encoded_images_cpu = [img.cpu() for img in encoded_images_cuda] + run_decoding_benchmark(encoded_images_cpu) diff --git a/test/test_image.py b/test/test_image.py index 005cf41b1ca..f083e53b87b 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -413,23 +413,32 @@ def test_read_interlaced_png(): @needs_cuda -@pytest.mark.parametrize( - "img_path", - [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")], -) @pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) @pytest.mark.parametrize("scripted", (False, True)) -def test_decode_jpeg_cuda(mode, img_path, scripted): - if "cmyk" in img_path: - pytest.xfail("Decoding a CMYK jpeg isn't supported") +def test_decode_jpegs_cuda(mode, scripted): + encoded_images = [] + for jpeg_path in get_images(IMAGE_ROOT, ".jpg"): + if "cmyk" in jpeg_path: + continue + encoded_image = read_file(jpeg_path) + encoded_images.append(encoded_image) + decoded_images_cpu = decode_jpeg(encoded_images, mode=mode) + decode_fn = torch.jit.script(decode_jpeg) if scripted else decode_jpeg - data = read_file(img_path) - img = decode_image(data, mode=mode) - f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg - img_nvjpeg = f(data, mode=mode, device="cuda") + # test multithreaded decoding + # in the current version we prevent this by using a lock but we still want to test it + num_workers = 10 - # Some difference expected between jpeg implementations - assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2 + with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: + 
futures = [executor.submit(decode_fn, encoded_images, mode, "cuda") for _ in range(num_workers)] + decoded_images_threaded = [future.result() for future in futures] + assert len(decoded_images_threaded) == num_workers + for decoded_images in decoded_images_threaded: + assert len(decoded_images) == len(encoded_images) + for decoded_image_cuda, decoded_image_cpu in zip(decoded_images, decoded_images_cpu): + assert decoded_image_cuda.shape == decoded_image_cpu.shape + assert decoded_image_cuda.dtype == decoded_image_cpu.dtype == torch.uint8 + assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 2 @needs_cuda @@ -440,12 +449,21 @@ def test_decode_image_cuda_raises(): @needs_cuda -@pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda"))) -def test_decode_jpeg_cuda_device_param(cuda_device): - """Make sure we can pass a string or a torch.device as device param""" +def test_decode_jpeg_cuda_device_param(): path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path) data = read_file(path) - decode_jpeg(data, device=cuda_device) + current_device = torch.cuda.current_device() + current_stream = torch.cuda.current_stream() + num_devices = torch.cuda.device_count() + devices = ["cuda", torch.device("cuda")] + [torch.device(f"cuda:{i}") for i in range(num_devices)] + results = [] + for device in devices: + results.append(decode_jpeg(data, device=device)) + assert len(results) == len(devices) + for result in results: + assert torch.all(result.cpu() == results[0].cpu()) + assert current_device == torch.cuda.current_device() + assert current_stream == torch.cuda.current_stream() @needs_cuda @@ -453,12 +471,73 @@ def test_decode_jpeg_cuda_errors(): data = read_file(next(get_images(IMAGE_ROOT, ".jpg"))) with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"): decode_jpeg(data.reshape(-1, 1), device="cuda") - with pytest.raises(RuntimeError, match="input tensor must be on 
CPU"): + with pytest.raises(ValueError, match="must be tensors"): + decode_jpeg([1, 2, 3]) + with pytest.raises(ValueError, match="Input tensor must be a CPU tensor"): decode_jpeg(data.to("cuda"), device="cuda") with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"): decode_jpeg(data.to(torch.float), device="cuda") - with pytest.raises(RuntimeError, match="Expected a cuda device"): - torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, "cpu") + with pytest.raises(RuntimeError, match="Expected the device parameter to be a cuda device"): + torch.ops.image.decode_jpegs_cuda([data], ImageReadMode.UNCHANGED.value, "cpu") + with pytest.raises(ValueError, match="Input tensor must be a CPU tensor"): + decode_jpeg( + torch.empty((100,), dtype=torch.uint8, device="cuda"), + ) + with pytest.raises(ValueError, match="Input list must contain tensors on CPU"): + decode_jpeg( + [ + torch.empty((100,), dtype=torch.uint8, device="cuda"), + torch.empty((100,), dtype=torch.uint8, device="cuda"), + ] + ) + + with pytest.raises(ValueError, match="Input list must contain tensors on CPU"): + decode_jpeg( + [ + torch.empty((100,), dtype=torch.uint8, device="cuda"), + torch.empty((100,), dtype=torch.uint8, device="cuda"), + ], + device="cuda", + ) + + with pytest.raises(ValueError, match="Input list must contain tensors on CPU"): + decode_jpeg( + [ + torch.empty((100,), dtype=torch.uint8, device="cpu"), + torch.empty((100,), dtype=torch.uint8, device="cuda"), + ], + device="cuda", + ) + + with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"): + decode_jpeg( + [ + torch.empty((100,), dtype=torch.uint8), + torch.empty((100,), dtype=torch.float32), + ], + device="cuda", + ) + + with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"): + decode_jpeg( + [ + torch.empty((100,), dtype=torch.uint8), + torch.empty((1, 100), dtype=torch.uint8), + ], + device="cuda", + ) + + with pytest.raises(RuntimeError, match="Error 
while decoding JPEG images"): + decode_jpeg( + [ + torch.empty((100,), dtype=torch.uint8), + torch.empty((100,), dtype=torch.uint8), + ], + device="cuda", + ) + + with pytest.raises(ValueError, match="Input list must contain at least one element"): + decode_jpeg([], device="cuda") def test_encode_jpeg_errors(): @@ -515,12 +594,10 @@ def test_encode_jpeg_cuda_device_param(): devices = ["cuda", torch.device("cuda")] + [torch.device(f"cuda:{i}") for i in range(num_devices)] results = [] for device in devices: - print(f"python: device: {device}") results.append(encode_jpeg(data.to(device=device))) assert len(results) == len(devices) for result in results: assert torch.all(result.cpu() == results[0].cpu()) - assert current_device == torch.cuda.current_device() assert current_stream == torch.cuda.current_stream() diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp deleted file mode 100644 index 26fecc3e1f3..00000000000 --- a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp +++ /dev/null @@ -1,208 +0,0 @@ -#include "encode_decode_jpegs_cuda.h" - -#include - -#if NVJPEG_FOUND -#include -#include -#include -#endif - -#include - -namespace vision { -namespace image { - -#if !NVJPEG_FOUND - -torch::Tensor decode_jpeg_cuda( - const torch::Tensor& data, - ImageReadMode mode, - torch::Device device) { - TORCH_CHECK( - false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support"); -} - -#else - -namespace { -static nvjpegHandle_t nvjpeg_handle = nullptr; -} - -torch::Tensor decode_jpeg_cuda( - const torch::Tensor& data, - ImageReadMode mode, - torch::Device device) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cuda.decode_jpeg_cuda.decode_jpeg_cuda"); - TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - - TORCH_CHECK( - !data.is_cuda(), - "The input tensor must be on CPU when decoding with nvjpeg") - - TORCH_CHECK( - data.dim() == 1 && data.numel() > 0, - "Expected 
a non empty 1-dimensional tensor"); - - TORCH_CHECK(device.is_cuda(), "Expected a cuda device") - - int major_version; - int minor_version; - nvjpegStatus_t get_major_property_status = - nvjpegGetProperty(MAJOR_VERSION, &major_version); - nvjpegStatus_t get_minor_property_status = - nvjpegGetProperty(MINOR_VERSION, &minor_version); - - TORCH_CHECK( - get_major_property_status == NVJPEG_STATUS_SUCCESS, - "nvjpegGetProperty failed: ", - get_major_property_status); - TORCH_CHECK( - get_minor_property_status == NVJPEG_STATUS_SUCCESS, - "nvjpegGetProperty failed: ", - get_minor_property_status); - if ((major_version < 11) || ((major_version == 11) && (minor_version < 6))) { - TORCH_WARN_ONCE( - "There is a memory leak issue in the nvjpeg library for CUDA versions < 11.6. " - "Make sure to rely on CUDA 11.6 or above before using decode_jpeg(..., device='cuda')."); - } - - at::cuda::CUDAGuard device_guard(device); - - // Create global nvJPEG handle - static std::once_flag nvjpeg_handle_creation_flag; - std::call_once(nvjpeg_handle_creation_flag, []() { - if (nvjpeg_handle == nullptr) { - nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); - - if (create_status != NVJPEG_STATUS_SUCCESS) { - // Reset handle so that one can still call the function again in the - // same process if there was a failure - free(nvjpeg_handle); - nvjpeg_handle = nullptr; - } - TORCH_CHECK( - create_status == NVJPEG_STATUS_SUCCESS, - "nvjpegCreateSimple failed: ", - create_status); - } - }); - - // Create the jpeg state - nvjpegJpegState_t jpeg_state; - nvjpegStatus_t state_status = - nvjpegJpegStateCreate(nvjpeg_handle, &jpeg_state); - - TORCH_CHECK( - state_status == NVJPEG_STATUS_SUCCESS, - "nvjpegJpegStateCreate failed: ", - state_status); - - auto datap = data.data_ptr(); - - // Get the image information - int num_channels; - nvjpegChromaSubsampling_t subsampling; - int widths[NVJPEG_MAX_COMPONENT]; - int heights[NVJPEG_MAX_COMPONENT]; - nvjpegStatus_t info_status = 
nvjpegGetImageInfo( - nvjpeg_handle, - datap, - data.numel(), - &num_channels, - &subsampling, - widths, - heights); - - if (info_status != NVJPEG_STATUS_SUCCESS) { - nvjpegJpegStateDestroy(jpeg_state); - TORCH_CHECK(false, "nvjpegGetImageInfo failed: ", info_status); - } - - if (subsampling == NVJPEG_CSS_UNKNOWN) { - nvjpegJpegStateDestroy(jpeg_state); - TORCH_CHECK(false, "Unknown NVJPEG chroma subsampling"); - } - - int width = widths[0]; - int height = heights[0]; - - nvjpegOutputFormat_t ouput_format; - int num_channels_output; - - switch (mode) { - case IMAGE_READ_MODE_UNCHANGED: - num_channels_output = num_channels; - // For some reason, setting output_format to NVJPEG_OUTPUT_UNCHANGED will - // not properly decode RGB images (it's fine for grayscale), so we set - // output_format manually here - if (num_channels == 1) { - ouput_format = NVJPEG_OUTPUT_Y; - } else if (num_channels == 3) { - ouput_format = NVJPEG_OUTPUT_RGB; - } else { - nvjpegJpegStateDestroy(jpeg_state); - TORCH_CHECK( - false, - "When mode is UNCHANGED, only 1 or 3 input channels are allowed."); - } - break; - case IMAGE_READ_MODE_GRAY: - ouput_format = NVJPEG_OUTPUT_Y; - num_channels_output = 1; - break; - case IMAGE_READ_MODE_RGB: - ouput_format = NVJPEG_OUTPUT_RGB; - num_channels_output = 3; - break; - default: - nvjpegJpegStateDestroy(jpeg_state); - TORCH_CHECK( - false, "The provided mode is not supported for JPEG decoding on GPU"); - } - - auto out_tensor = torch::empty( - {int64_t(num_channels_output), int64_t(height), int64_t(width)}, - torch::dtype(torch::kU8).device(device)); - - // nvjpegImage_t is a struct with - // - an array of pointers to each channel - // - the pitch for each channel - // which must be filled in manually - nvjpegImage_t out_image; - - for (int c = 0; c < num_channels_output; c++) { - out_image.channel[c] = out_tensor[c].data_ptr(); - out_image.pitch[c] = width; - } - for (int c = num_channels_output; c < NVJPEG_MAX_COMPONENT; c++) { - out_image.channel[c] = 
nullptr; - out_image.pitch[c] = 0; - } - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index()); - - nvjpegStatus_t decode_status = nvjpegDecode( - nvjpeg_handle, - jpeg_state, - datap, - data.numel(), - ouput_format, - &out_image, - stream); - - nvjpegJpegStateDestroy(jpeg_state); - - TORCH_CHECK( - decode_status == NVJPEG_STATUS_SUCCESS, - "nvjpegDecode failed: ", - decode_status); - - return out_tensor; -} - -#endif // NVJPEG_FOUND - -} // namespace image -} // namespace vision diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp new file mode 100644 index 00000000000..6314ececef1 --- /dev/null +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -0,0 +1,603 @@ +#include "decode_jpegs_cuda.h" +#if !NVJPEG_FOUND +namespace vision { +namespace image { +std::vector decode_jpegs_cuda( + const std::vector& encoded_images, + vision::image::ImageReadMode mode, + torch::Device device) { + TORCH_CHECK( + false, "decode_jpegs_cuda: torchvision not compiled with nvJPEG support"); +} +} // namespace image +} // namespace vision + +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace vision { +namespace image { + +std::mutex decoderMutex; +std::unique_ptr cudaJpegDecoder; + +std::vector decode_jpegs_cuda( + const std::vector& encoded_images, + vision::image::ImageReadMode mode, + torch::Device device) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cuda.decode_jpegs_cuda.decode_jpegs_cuda"); + + std::lock_guard lock(decoderMutex); + std::vector contig_images; + contig_images.reserve(encoded_images.size()); + + TORCH_CHECK( + device.is_cuda(), "Expected the device parameter to be a cuda device"); + + for (auto& encoded_image : encoded_images) { + TORCH_CHECK( + encoded_image.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + + TORCH_CHECK( + !encoded_image.is_cuda(), + "The 
input tensor must be on CPU when decoding with nvjpeg") + + TORCH_CHECK( + encoded_image.dim() == 1 && encoded_image.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + // nvjpeg requires images to be contiguous + if (encoded_image.is_contiguous()) { + contig_images.push_back(encoded_image); + } else { + contig_images.push_back(encoded_image.contiguous()); + } + } + + int major_version; + int minor_version; + nvjpegStatus_t get_major_property_status = + nvjpegGetProperty(MAJOR_VERSION, &major_version); + nvjpegStatus_t get_minor_property_status = + nvjpegGetProperty(MINOR_VERSION, &minor_version); + + TORCH_CHECK( + get_major_property_status == NVJPEG_STATUS_SUCCESS, + "nvjpegGetProperty failed: ", + get_major_property_status); + TORCH_CHECK( + get_minor_property_status == NVJPEG_STATUS_SUCCESS, + "nvjpegGetProperty failed: ", + get_minor_property_status); + if ((major_version < 11) || ((major_version == 11) && (minor_version < 6))) { + TORCH_WARN_ONCE( + "There is a memory leak issue in the nvjpeg library for CUDA versions < 11.6. " + "Make sure to rely on CUDA 11.6 or above before using decode_jpeg(..., device='cuda')."); + } + + at::cuda::CUDAGuard device_guard(device); + + if (cudaJpegDecoder == nullptr || device != cudaJpegDecoder->target_device) { + if (cudaJpegDecoder != nullptr) + cudaJpegDecoder.reset(new CUDAJpegDecoder(device)); + else { + cudaJpegDecoder = std::make_unique(device); + std::atexit([]() { cudaJpegDecoder.reset(); }); + } + } + + nvjpegOutputFormat_t output_format; + + switch (mode) { + case vision::image::IMAGE_READ_MODE_UNCHANGED: + // Using NVJPEG_OUTPUT_UNCHANGED causes differently sized output channels + // which is related to the subsampling used I'm not sure why this is the + // case, but for now we're just using RGB and later removing channels from + // grayscale images. 
+ output_format = NVJPEG_OUTPUT_UNCHANGED; + break; + case vision::image::IMAGE_READ_MODE_GRAY: + output_format = NVJPEG_OUTPUT_Y; + break; + case vision::image::IMAGE_READ_MODE_RGB: + output_format = NVJPEG_OUTPUT_RGB; + break; + default: + TORCH_CHECK( + false, "The provided mode is not supported for JPEG decoding on GPU"); + } + + try { + at::cuda::CUDAEvent event; + auto result = cudaJpegDecoder->decode_images(contig_images, output_format); + auto current_stream{ + device.has_index() ? at::cuda::getCurrentCUDAStream( + cudaJpegDecoder->original_device.index()) + : at::cuda::getCurrentCUDAStream()}; + event.record(cudaJpegDecoder->stream); + event.block(current_stream); + return result; + } catch (const std::exception& e) { + if (typeid(e) != typeid(std::runtime_error)) { + TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what()); + } else { + throw; + } + } +} + +CUDAJpegDecoder::CUDAJpegDecoder(const torch::Device& target_device) + : original_device{torch::kCUDA, torch::cuda::current_device()}, + target_device{target_device}, + stream{ + target_device.has_index() + ? 
at::cuda::getStreamFromPool(false, target_device.index()) + : at::cuda::getStreamFromPool(false)} { + nvjpegStatus_t status; + + hw_decode_available = true; + status = nvjpegCreateEx( + NVJPEG_BACKEND_HARDWARE, + NULL, + NULL, + NVJPEG_FLAGS_DEFAULT, + &nvjpeg_handle); + if (status == NVJPEG_STATUS_ARCH_MISMATCH) { + status = nvjpegCreateEx( + NVJPEG_BACKEND_DEFAULT, + NULL, + NULL, + NVJPEG_FLAGS_DEFAULT, + &nvjpeg_handle); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to initialize nvjpeg with default backend: ", + status); + hw_decode_available = false; + } else { + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to initialize nvjpeg with hardware backend: ", + status); + } + + status = nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create nvjpeg state: ", + status); + + status = nvjpegDecoderCreate( + nvjpeg_handle, NVJPEG_BACKEND_DEFAULT, &nvjpeg_decoder); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create nvjpeg decoder: ", + status); + + status = nvjpegDecoderStateCreate( + nvjpeg_handle, nvjpeg_decoder, &nvjpeg_decoupled_state); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create nvjpeg decoder state: ", + status); + + status = nvjpegBufferPinnedCreate(nvjpeg_handle, NULL, &pinned_buffers[0]); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create pinned buffer: ", + status); + + status = nvjpegBufferPinnedCreate(nvjpeg_handle, NULL, &pinned_buffers[1]); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create pinned buffer: ", + status); + + status = nvjpegBufferDeviceCreate(nvjpeg_handle, NULL, &device_buffer); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create device buffer: ", + status); + + status = nvjpegJpegStreamCreate(nvjpeg_handle, &jpeg_streams[0]); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create jpeg stream: ", + status); + + status = 
nvjpegJpegStreamCreate(nvjpeg_handle, &jpeg_streams[1]); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create jpeg stream: ", + status); + + status = nvjpegDecodeParamsCreate(nvjpeg_handle, &nvjpeg_decode_params); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to create decode params: ", + status); +} + +CUDAJpegDecoder::~CUDAJpegDecoder() { + /* + The below code works on Mac and Linux, but fails on Windows. + This is because on Windows, the atexit hook which calls this + destructor executes after cuda is already shut down causing SIGSEGV. + We do not have a solution to this problem at the moment, so we'll + just leak the libnvjpeg & cuda variables for the time being and hope + that the CUDA runtime handles cleanup for us. + Please send a PR if you have a solution for this problem. + */ + + // nvjpegStatus_t status; + + // status = nvjpegDecodeParamsDestroy(nvjpeg_decode_params); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy nvjpeg decode params: ", + // status); + + // status = nvjpegJpegStreamDestroy(jpeg_streams[0]); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy jpeg stream: ", + // status); + + // status = nvjpegJpegStreamDestroy(jpeg_streams[1]); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy jpeg stream: ", + // status); + + // status = nvjpegBufferPinnedDestroy(pinned_buffers[0]); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy pinned buffer[0]: ", + // status); + + // status = nvjpegBufferPinnedDestroy(pinned_buffers[1]); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy pinned buffer[1]: ", + // status); + + // status = nvjpegBufferDeviceDestroy(device_buffer); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy device buffer: ", + // status); + + // status = nvjpegJpegStateDestroy(nvjpeg_decoupled_state); + // TORCH_CHECK( + // status == 
NVJPEG_STATUS_SUCCESS, + // "Failed to destroy nvjpeg decoupled state: ", + // status); + + // status = nvjpegDecoderDestroy(nvjpeg_decoder); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy nvjpeg decoder: ", + // status); + + // status = nvjpegJpegStateDestroy(nvjpeg_state); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, + // "Failed to destroy nvjpeg state: ", + // status); + + // status = nvjpegDestroy(nvjpeg_handle); + // TORCH_CHECK( + // status == NVJPEG_STATUS_SUCCESS, "nvjpegDestroy failed: ", status); +} + +std::tuple< + std::vector, + std::vector, + std::vector> +CUDAJpegDecoder::prepare_buffers( + const std::vector& encoded_images, + const nvjpegOutputFormat_t& output_format) { + /* + This function scans the encoded images' jpeg headers and + allocates decoding buffers based on the metadata found + + Args: + - encoded_images (std::vector): a vector of tensors + containing the jpeg bitstreams to be decoded. Each tensor must have dtype + torch.uint8 and device cpu + - output_format (nvjpegOutputFormat_t): NVJPEG_OUTPUT_RGB, NVJPEG_OUTPUT_Y + or NVJPEG_OUTPUT_UNCHANGED + + Returns: + - decoded_images (std::vector): a vector of nvjpegImages + containing pointers to the memory of the decoded images + - output_tensors (std::vector): a vector of Tensors + containing the decoded images. 
`decoded_images` points to the memory of + output_tensors + - channels (std::vector): a vector of ints containing the number of + output image channels for every image + */ + + int width[NVJPEG_MAX_COMPONENT]; + int height[NVJPEG_MAX_COMPONENT]; + std::vector channels(encoded_images.size()); + nvjpegChromaSubsampling_t subsampling; + nvjpegStatus_t status; + + std::vector output_tensors{encoded_images.size()}; + std::vector decoded_images{encoded_images.size()}; + + for (std::vector::size_type i = 0; i < encoded_images.size(); + i++) { + // extract bitstream meta data to figure out the number of channels, height, + // width for every image + status = nvjpegGetImageInfo( + nvjpeg_handle, + (unsigned char*)encoded_images[i].data_ptr(), + encoded_images[i].numel(), + &channels[i], + &subsampling, + width, + height); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, "Failed to get image info: ", status); + + TORCH_CHECK( + subsampling != NVJPEG_CSS_UNKNOWN, "Unknown chroma subsampling"); + + // output channels may be different from the actual number of channels in + // the image, e.g. 
we decode a grayscale image as RGB and slice off the + // extra channels later + int output_channels = 3; + if (output_format == NVJPEG_OUTPUT_RGB || + output_format == NVJPEG_OUTPUT_UNCHANGED) { + output_channels = 3; + } else if (output_format == NVJPEG_OUTPUT_Y) { + output_channels = 1; + } + + // reserve output buffer + auto output_tensor = torch::empty( + {int64_t(output_channels), int64_t(height[0]), int64_t(width[0])}, + torch::dtype(torch::kU8).device(target_device)); + output_tensors[i] = output_tensor; + + // fill nvjpegImage_t struct + for (int c = 0; c < output_channels; c++) { + decoded_images[i].channel[c] = output_tensor[c].data_ptr(); + decoded_images[i].pitch[c] = width[0]; + } + for (int c = output_channels; c < NVJPEG_MAX_COMPONENT; c++) { + decoded_images[i].channel[c] = NULL; + decoded_images[i].pitch[c] = 0; + } + } + return {decoded_images, output_tensors, channels}; +} + +std::vector CUDAJpegDecoder::decode_images( + const std::vector& encoded_images, + const nvjpegOutputFormat_t& output_format) { + /* + This function decodes a batch of jpeg bitstreams. + We scan all encoded bitstreams and sort them into two groups: + 1. Baseline JPEGs: Can be decoded with hardware support on A100+ GPUs. + 2. Other JPEGs (e.g. progressive JPEGs): Can also be decoded on the + GPU (albeit with software support only) but need some preprocessing on the + host first. + + See + https://github.com/NVIDIA/CUDALibrarySamples/blob/f17940ac4e705bf47a8c39f5365925c1665f6c98/nvJPEG/nvJPEG-Decoder/nvjpegDecoder.cpp#L33 + for reference. 
+ + Args: + - encoded_images (std::vector): a vector of tensors + containing the jpeg bitstreams to be decoded + - output_format (nvjpegOutputFormat_t): NVJPEG_OUTPUT_RGB, NVJPEG_OUTPUT_Y + or NVJPEG_OUTPUT_UNCHANGED + - device (torch::Device): The desired CUDA device for the returned Tensors + + Returns: + - output_tensors (std::vector): a vector of Tensors + containing the decoded images + */ + + auto [decoded_imgs_buf, output_tensors, channels] = + prepare_buffers(encoded_images, output_format); + + nvjpegStatus_t status; + cudaError_t cudaStatus; + + cudaStatus = cudaStreamSynchronize(stream); + TORCH_CHECK( + cudaStatus == cudaSuccess, + "Failed to synchronize CUDA stream: ", + cudaStatus); + + // baseline JPEGs can be batch decoded with hardware support on A100+ GPUs + // ultra fast! + std::vector hw_input_buffer; + std::vector hw_input_buffer_size; + std::vector hw_output_buffer; + + // other JPEG types such as progressive JPEGs can be decoded one-by-one in + // software slow :( + std::vector sw_input_buffer; + std::vector sw_input_buffer_size; + std::vector sw_output_buffer; + + if (hw_decode_available) { + for (std::vector::size_type i = 0; i < encoded_images.size(); + ++i) { + // extract bitstream meta data to figure out whether a bit-stream can be + // decoded + nvjpegJpegStreamParseHeader( + nvjpeg_handle, + encoded_images[i].data_ptr(), + encoded_images[i].numel(), + jpeg_streams[0]); + int isSupported = -1; + nvjpegDecodeBatchedSupported( + nvjpeg_handle, jpeg_streams[0], &isSupported); + + if (isSupported == 0) { + hw_input_buffer.push_back(encoded_images[i].data_ptr()); + hw_input_buffer_size.push_back(encoded_images[i].numel()); + hw_output_buffer.push_back(decoded_imgs_buf[i]); + } else { + sw_input_buffer.push_back(encoded_images[i].data_ptr()); + sw_input_buffer_size.push_back(encoded_images[i].numel()); + sw_output_buffer.push_back(decoded_imgs_buf[i]); + } + } + } else { + for (std::vector::size_type i = 0; i < encoded_images.size(); + ++i) { 
+ sw_input_buffer.push_back(encoded_images[i].data_ptr()); + sw_input_buffer_size.push_back(encoded_images[i].numel()); + sw_output_buffer.push_back(decoded_imgs_buf[i]); + } + } + + if (hw_input_buffer.size() > 0) { + // UNCHANGED behaves weird, so we use RGB instead + status = nvjpegDecodeBatchedInitialize( + nvjpeg_handle, + nvjpeg_state, + hw_input_buffer.size(), + 1, + output_format == NVJPEG_OUTPUT_UNCHANGED ? NVJPEG_OUTPUT_RGB + : output_format); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to initialize batch decoding: ", + status); + + status = nvjpegDecodeBatched( + nvjpeg_handle, + nvjpeg_state, + hw_input_buffer.data(), + hw_input_buffer_size.data(), + hw_output_buffer.data(), + stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, "Failed to decode batch: ", status); + } + + if (sw_input_buffer.size() > 0) { + status = + nvjpegStateAttachDeviceBuffer(nvjpeg_decoupled_state, device_buffer); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to attach device buffer: ", + status); + int buffer_index = 0; + // UNCHANGED behaves weird, so we use RGB instead + status = nvjpegDecodeParamsSetOutputFormat( + nvjpeg_decode_params, + output_format == NVJPEG_OUTPUT_UNCHANGED ? 
NVJPEG_OUTPUT_RGB + : output_format); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to set output format: ", + status); + for (std::vector::size_type i = 0; i < sw_input_buffer.size(); + ++i) { + status = nvjpegJpegStreamParse( + nvjpeg_handle, + sw_input_buffer[i], + sw_input_buffer_size[i], + 0, + 0, + jpeg_streams[buffer_index]); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to parse jpeg stream: ", + status); + + status = nvjpegStateAttachPinnedBuffer( + nvjpeg_decoupled_state, pinned_buffers[buffer_index]); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to attach pinned buffer: ", + status); + + status = nvjpegDecodeJpegHost( + nvjpeg_handle, + nvjpeg_decoder, + nvjpeg_decoupled_state, + nvjpeg_decode_params, + jpeg_streams[buffer_index]); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to decode jpeg stream: ", + status); + + cudaStatus = cudaStreamSynchronize(stream); + TORCH_CHECK( + cudaStatus == cudaSuccess, + "Failed to synchronize CUDA stream: ", + cudaStatus); + + status = nvjpegDecodeJpegTransferToDevice( + nvjpeg_handle, + nvjpeg_decoder, + nvjpeg_decoupled_state, + jpeg_streams[buffer_index], + stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to transfer jpeg to device: ", + status); + + buffer_index = 1 - buffer_index; // switch pinned buffer in pipeline mode + // to avoid an extra sync + + status = nvjpegDecodeJpegDevice( + nvjpeg_handle, + nvjpeg_decoder, + nvjpeg_decoupled_state, + &sw_output_buffer[i], + stream); + TORCH_CHECK( + status == NVJPEG_STATUS_SUCCESS, + "Failed to decode jpeg stream: ", + status); + } + } + + cudaStatus = cudaStreamSynchronize(stream); + TORCH_CHECK( + cudaStatus == cudaSuccess, + "Failed to synchronize CUDA stream: ", + cudaStatus); + + // prune extraneous channels from single channel images + if (output_format == NVJPEG_OUTPUT_UNCHANGED) { + for (std::vector::size_type i = 0; i < output_tensors.size(); + ++i) { + if (channels[i] == 1) { + 
output_tensors[i] = output_tensors[i][0].unsqueeze(0).clone(); + } + } + } + + return output_tensors; +} + +} // namespace image +} // namespace vision + +#endif diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h new file mode 100644 index 00000000000..2458a103a3a --- /dev/null +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -0,0 +1,45 @@ +#pragma once +#include +#include +#include "../image_read_mode.h" + +#if NVJPEG_FOUND +#include +#include + +namespace vision { +namespace image { +class CUDAJpegDecoder { + public: + CUDAJpegDecoder(const torch::Device& target_device); + ~CUDAJpegDecoder(); + + std::vector decode_images( + const std::vector& encoded_images, + const nvjpegOutputFormat_t& output_format); + + const torch::Device original_device; + const torch::Device target_device; + const c10::cuda::CUDAStream stream; + + private: + std::tuple< + std::vector, + std::vector, + std::vector> + prepare_buffers( + const std::vector& encoded_images, + const nvjpegOutputFormat_t& output_format); + nvjpegJpegState_t nvjpeg_state; + nvjpegJpegState_t nvjpeg_decoupled_state; + nvjpegBufferPinned_t pinned_buffers[2]; + nvjpegBufferDevice_t device_buffer; + nvjpegJpegStream_t jpeg_streams[2]; + nvjpegDecodeParams_t nvjpeg_decode_params; + nvjpegJpegDecoder_t nvjpeg_decoder; + bool hw_decode_available{false}; + nvjpegHandle_t nvjpeg_handle; +}; +} // namespace image +} // namespace vision +#endif diff --git a/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h index 7723d11d621..3fdf715b00f 100644 --- a/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h @@ -2,16 +2,55 @@ #include #include "../image_read_mode.h" +#include "decode_jpegs_cuda.h" #include "encode_jpegs_cuda.h" namespace vision { namespace image { -C10_EXPORT torch::Tensor decode_jpeg_cuda( - const 
torch::Tensor& data, - ImageReadMode mode, +/* + +Fast jpeg decoding with CUDA. +A100+ GPUs have dedicated hardware support for jpeg decoding. + +Args: + - encoded_images (const std::vector<torch::Tensor>&): a vector of tensors + containing the jpeg bitstreams to be decoded. Each tensor must have dtype + torch.uint8 and device cpu + - mode (ImageReadMode): IMAGE_READ_MODE_UNCHANGED, IMAGE_READ_MODE_GRAY and +IMAGE_READ_MODE_RGB are supported + - device (torch::Device): The desired CUDA device to run the decoding on and +which will contain the output tensors + +Returns: + - decoded_images (std::vector<torch::Tensor>): a vector of torch::Tensors of +dtype torch.uint8 on the specified device containing the decoded images + +Notes: + - If a single image fails, the whole batch fails. + - This function is thread-safe +*/ +C10_EXPORT std::vector decode_jpegs_cuda( + const std::vector& encoded_images, + vision::image::ImageReadMode mode, torch::Device device); +/* +Fast jpeg encoding with CUDA. + +Args: + - decoded_images (const std::vector<torch::Tensor>&): a vector of contiguous +CUDA tensors of dtype torch.uint8 to be encoded. + - quality (int64_t): 0-100, 75 is the default + +Returns: + - encoded_images (std::vector<torch::Tensor>): a vector of CUDA +torch::Tensors of dtype torch.uint8 containing the encoded images + +Notes: + - If a single image fails, the whole batch fails. 
+ - This function is thread-safe +*/ C10_EXPORT std::vector encode_jpegs_cuda( const std::vector& decoded_images, const int64_t quality); diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index e351ed425b5..9f7563eebf8 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -26,7 +26,7 @@ static auto registry = .op("image::write_file", &write_file) .op("image::decode_image(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", &decode_image) - .op("image::decode_jpeg_cuda", &decode_jpeg_cuda) + .op("image::decode_jpegs_cuda", &decode_jpegs_cuda) .op("image::encode_jpegs_cuda", &encode_jpegs_cuda) .op("image::_jpeg_version", &_jpeg_version) .op("image::_is_compiled_against_turbo", &_is_compiled_against_turbo); diff --git a/torchvision/io/image.py b/torchvision/io/image.py index debef443f7a..eec073ce55e 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -143,22 +143,28 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6): def decode_jpeg( - input: torch.Tensor, + input: Union[torch.Tensor, List[torch.Tensor]], mode: ImageReadMode = ImageReadMode.UNCHANGED, - device: str = "cpu", + device: Union[str, torch.device] = "cpu", apply_exif_orientation: bool = False, -) -> torch.Tensor: +) -> Union[torch.Tensor, List[torch.Tensor]]: """ - Decodes a JPEG image into a 3 dimensional RGB or grayscale Tensor. + Decode JPEG image(s) into 3 dimensional RGB or grayscale Tensor(s). The values of the output tensor are uint8 between 0 and 255. + .. note:: + When using a CUDA device, passing a list of tensors is more efficient than repeated individual calls to ``decode_jpeg``. + When using CPU the performance is equivalent. + The CUDA version of this function has explicitly been designed with thread-safety in mind. + This function does not return partial results in case of an error. 
+ Args: - input (Tensor[1]): a one dimensional uint8 tensor containing - the raw bytes of the JPEG image. This tensor must be on CPU, + input (Tensor[1] or list[Tensor[1]]): a (list of) one dimensional uint8 tensor(s) containing + the raw bytes of the JPEG image. The tensor(s) must be on CPU, regardless of the ``device`` parameter. mode (ImageReadMode): the read mode used for optionally - converting the image. The supported modes are: ``ImageReadMode.UNCHANGED``, + converting the image(s). The supported modes are: ``ImageReadMode.UNCHANGED``, ``ImageReadMode.GRAY`` and ``ImageReadMode.RGB`` Default: ``ImageReadMode.UNCHANGED``. See ``ImageReadMode`` class for more information on various @@ -177,16 +183,36 @@ def decode_jpeg( Default: False. Only implemented for JPEG format on CPU. Returns: - output (Tensor[image_channels, image_height, image_width]) + output (Tensor[image_channels, image_height, image_width] or list[Tensor[image_channels, image_height, image_width]]): + The values of the output tensor(s) are uint8 between 0 and 255. 
+ ``output.device`` will be set to the specified ``device`` + + """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(decode_jpeg) - device = torch.device(device) - if device.type == "cuda": - output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device) - else: - output = torch.ops.image.decode_jpeg(input, mode.value, apply_exif_orientation) - return output + if isinstance(device, str): + device = torch.device(device) + + if isinstance(input, list): + if len(input) == 0: + raise ValueError("Input list must contain at least one element") + if not all(isinstance(t, torch.Tensor) for t in input): + raise ValueError("All elements of the input list must be tensors.") + if not all(t.device.type == "cpu" for t in input): + raise ValueError("Input list must contain tensors on CPU.") + if device.type == "cuda": + return torch.ops.image.decode_jpegs_cuda(input, mode.value, device) + else: + return [torch.ops.image.decode_jpeg(img, mode.value, apply_exif_orientation) for img in input] + + else: # input is tensor + if input.device.type != "cpu": + raise ValueError("Input tensor must be a CPU tensor") + if device.type == "cuda": + return torch.ops.image.decode_jpegs_cuda([input], mode.value, device)[0] + else: + return torch.ops.image.decode_jpeg(input, mode.value, apply_exif_orientation) def encode_jpeg( diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 60b49099fc5..a904d8d7cbd 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -80,9 +80,12 @@ def jpeg_image(image: torch.Tensor, quality: int) -> torch.Tensor: images = [] for i in range(image.shape[0]): + # isinstance checks are needed for torchscript. 
encoded_image = encode_jpeg(image[i], quality=quality) - assert isinstance(encoded_image, torch.Tensor) # For torchscript - images.append(decode_jpeg(encoded_image)) + assert isinstance(encoded_image, torch.Tensor) + decoded_image = decode_jpeg(encoded_image) + assert isinstance(decoded_image, torch.Tensor) + images.append(decoded_image) images = torch.stack(images, dim=0).view(original_shape) return images