ahmad's comments
deekay42 committed Jun 26, 2024
1 parent 01a5621 commit ccdafd4
Showing 4 changed files with 66 additions and 13 deletions.
30 changes: 20 additions & 10 deletions torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp
@@ -40,6 +40,8 @@ std::vector<torch::Tensor> decode_jpegs_cuda(
"torchvision.csrc.io.image.cuda.decode_jpegs_cuda.decode_jpegs_cuda");

std::lock_guard<std::mutex> lock(decoderMutex);
std::vector<torch::Tensor> contig_images;
contig_images.reserve(encoded_images.size());

for (auto& encoded_image : encoded_images) {
TORCH_CHECK(
@@ -52,6 +54,13 @@ std::vector<torch::Tensor> decode_jpegs_cuda(
TORCH_CHECK(
encoded_image.dim() == 1 && encoded_image.numel() > 0,
"Expected a non empty 1-dimensional tensor");

// nvjpeg requires images to be contiguous
if (encoded_image.is_contiguous()) {
contig_images.push_back(encoded_image);
} else {
contig_images.push_back(encoded_image.contiguous());
}
}

TORCH_CHECK(device.is_cuda(), "Expected a cuda device");
@@ -81,9 +90,11 @@ std::vector<torch::Tensor> decode_jpegs_cuda(

  if (cudaJpegDecoder == nullptr || device != cudaJpegDecoder->target_device) {
    if (cudaJpegDecoder != nullptr)
-      delete cudaJpegDecoder.release();
-    cudaJpegDecoder = std::make_unique<CUDAJpegDecoder>(device);
-    std::atexit([]() { delete cudaJpegDecoder.release(); });
+      cudaJpegDecoder.reset(new CUDAJpegDecoder(device));
+    else {
+      cudaJpegDecoder = std::make_unique<CUDAJpegDecoder>(device);
+      std::atexit([]() { cudaJpegDecoder.reset(); });
+    }
  }

nvjpegOutputFormat_t output_format;
@@ -109,14 +120,13 @@ std::vector<torch::Tensor> decode_jpegs_cuda(

  try {
    at::cuda::CUDAEvent event;
+    auto result = cudaJpegDecoder->decode_images(contig_images, output_format);
+    auto current_stream{
+        device.has_index() ? at::cuda::getCurrentCUDAStream(
+                                 cudaJpegDecoder->original_device.index())
+                           : at::cuda::getCurrentCUDAStream()};
    event.record(cudaJpegDecoder->stream);
-    auto result = cudaJpegDecoder->decode_images(encoded_images, output_format);
-    if (device.has_index())
-      event.block(at::cuda::getCurrentCUDAStream(
-          cudaJpegDecoder->original_device.index()));
-    else
-      event.block(at::cuda::getCurrentCUDAStream());
-    return result;
+    event.block(current_stream);
+    return result;
} catch (const std::exception& e) {
if (typeid(e) != typeid(std::runtime_error)) {
TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what());
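The change above records a CUDAEvent on the decoder's internal stream and then makes the caller's current stream block on that event, so later kernels cannot consume the decoded tensors before the nvJPEG work has finished. The block below is a rough Python analogue of that record/wait pattern, not part of this commit; it assumes a CUDA-capable PyTorch build.

```python
import torch

# Sketch of the record/wait pattern used in decode_jpegs_cuda: work runs on a
# side stream, an event is recorded on it, and the current stream waits on the
# event before any later kernels consume the result.
assert torch.cuda.is_available()

side_stream = torch.cuda.Stream()  # stands in for cudaJpegDecoder->stream
event = torch.cuda.Event()

with torch.cuda.stream(side_stream):
    x = torch.randn(1024, 1024, device="cuda")
    result = x @ x  # stands in for the nvJPEG decode work
    event.record(side_stream)

# Like event.block(current_stream) in the C++ code: future work submitted to
# the current stream waits until the side-stream work behind the event is done.
event.wait(torch.cuda.current_stream())
print(result.sum().item())
```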
2 changes: 1 addition & 1 deletion torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h
@@ -22,7 +22,7 @@ class CUDAJpegDecoder {
const torch::Device target_device;
const c10::cuda::CUDAStream stream;

-  protected:
+  private:
std::tuple<
std::vector<nvjpegImage_t>,
std::vector<torch::Tensor>,
38 changes: 38 additions & 0 deletions torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h
@@ -8,11 +8,49 @@
namespace vision {
namespace image {

/*
Fast jpeg decoding with CUDA.
A100+ GPUs have dedicated hardware support for jpeg decoding.
Args:
- encoded_images (const std::vector<torch::Tensor>&): a vector of tensors
containing the jpeg bitstreams to be decoded. Each tensor must have dtype
torch.uint8 and device cpu
- mode (ImageReadMode): IMAGE_READ_MODE_UNCHANGED, IMAGE_READ_MODE_GRAY and
IMAGE_READ_MODE_RGB are supported
- device (torch::Device): the CUDA device on which to run the decoding and on
which the output tensors will be allocated
Returns:
- decoded_images (std::vector<torch::Tensor>): a vector of torch::Tensors of
dtype torch.uint8 on the specified <device> containing the decoded images
Notes:
- If a single image fails, the whole batch fails.
- This function is thread-safe
*/
C10_EXPORT std::vector<torch::Tensor> decode_jpegs_cuda(
const std::vector<torch::Tensor>& encoded_images,
vision::image::ImageReadMode mode,
torch::Device device);
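The user-facing counterpart of this entry point is torchvision.io.decode_jpeg with device="cuda". A minimal usage sketch follows; it is an illustration rather than part of the commit, with a placeholder file name, and it assumes a torchvision build with CUDA/nvJPEG support.

```python
from torchvision.io import ImageReadMode, decode_jpeg, read_file

# "cat.jpg" is a placeholder path; read_file returns a 1-D uint8 CPU tensor
# holding the raw JPEG bitstream, which is what the decoder expects.
encoded = read_file("cat.jpg")
img = decode_jpeg(encoded, mode=ImageReadMode.RGB, device="cuda")
print(img.dtype, img.device, img.shape)  # torch.uint8, cuda:0, [3, H, W]
```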

/*
Fast jpeg encoding with CUDA.
Args:
- decoded_images (const std::vector<torch::Tensor>&): a vector of contiguous
CUDA tensors of dtype torch.uint8 to be encoded.
- quality (int64_t): 0-100, 75 is the default
Returns:
- encoded_images (std::vector<torch::Tensor>): a vector of CUDA
torch::Tensors of dtype torch.uint8 containing the encoded images
Notes:
- If a single image fails, the whole batch fails.
- This function is thread-safe
*/
C10_EXPORT std::vector<torch::Tensor> encode_jpegs_cuda(
const std::vector<torch::Tensor>& decoded_images,
const int64_t quality);
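On the Python side this is reached through torchvision.io.encode_jpeg. The sketch below assumes that the Python wrapper in this torchvision release accepts CUDA tensors (and lists of them) and that a CUDA/nvJPEG-enabled build is available; it is illustrative only.

```python
import torch
from torchvision.io import encode_jpeg

# Two dummy CHW uint8 images already resident on the GPU.
imgs = [
    torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8, device="cuda")
    for _ in range(2)
]

# Batched GPU encoding; each output is a 1-D uint8 tensor holding a JPEG bitstream.
encoded = encode_jpeg(imgs, quality=90)
print(len(encoded), encoded[0].dtype, encoded[0].device)
```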
9 changes: 7 additions & 2 deletions torchvision/io/image.py
@@ -145,11 +145,13 @@ def decode_jpeg(
"""
Decodes a (list of) JPEG image(s) into a (list of) 3 dimensional RGB or grayscale Tensor(s).
Optionally converts the image(s) to the desired format.
The values of the output tensor(s) are uint8 between 0 and 255.
.. note::
When using a CUDA device, passing a list of tensors is more efficient than repeated individual calls to ``decode_jpeg``.
When using the CPU, the performance is equivalent.
The CUDA version of this function has explicitly been designed with thread-safety in mind.
This function does not return partial results in case of an error.
Args:
input (Tensor[1] or list[Tensor[1]]): a (list of) one dimensional uint8 tensor(s) containing
@@ -175,7 +177,10 @@ def decode_jpeg(
Default: False. Only implemented for JPEG format on CPU.
Returns:
-        output (Tensor[image_channels, image_height, image_width])
+        output (Tensor[image_channels, image_height, image_width] or list[Tensor[image_channels, image_height, image_width]]):
+            The values of the output tensor(s) are uint8 between 0 and 255. output.device will be set to the specified ``device``
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(decode_jpeg)
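As the updated docstring notes, passing a list of encoded images is more efficient on CUDA than calling ``decode_jpeg`` in a loop. A small sketch of the batched call follows; it is illustrative only, with placeholder file names, and assumes a CUDA-enabled torchvision.

```python
from torchvision.io import ImageReadMode, decode_jpeg, read_file

# Placeholder file names; each read_file call returns a 1-D uint8 CPU tensor.
paths = ["img0.jpg", "img1.jpg", "img2.jpg"]
encoded = [read_file(p) for p in paths]

# One batched call decodes all bitstreams on the GPU; if any single image is
# invalid the whole batch raises, as documented above.
images = decode_jpeg(encoded, mode=ImageReadMode.RGB, device="cuda")
for img in images:
    print(img.shape, img.device)  # [3, H, W], cuda:0
```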
