Add CUDA decoding support #242

Open
wants to merge 27 commits into main
Conversation

@ahmadsharif1 ahmadsharif1 (Contributor) commented Oct 4, 2024

Actually implement CUDA decoding in C++:

  1. Initialize a CUDA device if requested. We create a small tensor on the device to initialize the context.
  2. Use the CUDA device to decode the video to NV12 format.
  3. Use libNPP to convert from NV12 to RGB. We record a CUDA event after the conversion and wait on it, so there are no race conditions when the downstream consumer (which is on a different stream than libNPP) accesses this tensor.

Note that the frames the GPU decodes are not bit-accurate. This is by design, so the tests check that tensors are approximately equal rather than exactly equal. The actual tensor values depend on the GPU architecture because GPU math is not precise.
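To make the flow concrete, here is a minimal sketch of steps 2-3, assuming NPP's nppiNV12ToRGB_8u_P2C3R and the default NPP stream. The function name, variable names, and the surrounding plumbing are illustrative, not the actual code in this PR:

#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Exception.h>
#include <cuda_runtime.h>
#include <npp.h>
#include <torch/types.h>

// Convert one decoded NV12 frame (two GPU planes) into an HWC uint8 RGB tensor and
// synchronize so downstream consumers on a different stream see the finished result.
torch::Tensor convertNV12FrameToRGB(
    const Npp8u* yPlane,   // luma plane, already on the GPU
    const Npp8u* uvPlane,  // interleaved chroma plane, already on the GPU
    int pitch,             // line stride of both planes, in bytes
    int width,
    int height,
    const torch::Device& device) {
  torch::Tensor rgb = torch::empty(
      {height, width, 3}, torch::dtype(torch::kUInt8).device(device));

  const Npp8u* src[2] = {yPlane, uvPlane};
  NppiSize roi = {width, height};
  NppStatus status = nppiNV12ToRGB_8u_P2C3R(
      src, pitch, rgb.data_ptr<uint8_t>(), width * 3, roi);
  TORCH_CHECK(status == NPP_SUCCESS, "NV12 -> RGB conversion failed: ", status);

  // Record an event on the stream NPP ran on and make the current PyTorch stream wait
  // on it, so downstream reads of `rgb` cannot race with the color conversion.
  cudaEvent_t conversionDone;
  cudaEventCreate(&conversionDone);
  cudaEventRecord(conversionDone, nppGetStream());
  cudaStreamWaitEvent(at::cuda::getCurrentCUDAStream(), conversionDone, 0);
  cudaEventDestroy(conversionDone);
  return rgb;
}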

Also added a gpu_benchmark with the following results:

python benchmarks/decoders/gpu_benchmark.py --video /tmp/frame_numbers_1920x1080_100.mp4
[--------------------- Decode+Resize Time --------------------]
                       |  video=frame_numbers_1920x1080_100.mp4
1 threads: ----------------------------------------------------
      D=cuda R=cuda    |                   12.1                
      D=cuda R=cpu     |                  148.0                
      D=cuda R=native  |                   11.5                
      D=cuda R=none    |                   11.5                
      D=cpu R=cuda     |                   16.9                
      D=cpu R=cpu      |                  134.2                
      D=cpu R=native   |                   23.2                
      D=cpu R=none     |                    9.4                

Times are in seconds (s).

Key: D=Decode, R=Resize
Native resize is done as part of the decode step.
R=none means there is no resize step, native or otherwise.

Results show that a single NVDEC is slower than a 22-core CPU without resizing, but faster with resizing.

I also added a "throughput mode" for the benchmark that decodes W videos in parallel using T threads. Results of this "throughput mode" shows that A100 has higher decode throughput than my 22-core CPU:

python benchmarks/decoders/gpu_benchmark.py --video /tmp/frame_numbers_1920x1080_100.mp4 --devices=cuda:0,cpu --resize_devices=none --num_threads 10 --num_videos 10

[---------------------------------- Decode+Resize Time ----------------------------------]
                               |  threads=10 work=10 video=frame_numbers_1920x1080_100.mp4
1 threads: -------------------------------------------------------------------------------
      D=cuda R=none T=10 W=10  |                            29.0                          
      D=cpu R=none T=10 W=10   |                            38.8                          

Times are in seconds (s).

Key: D=Decode, R=Resize, T=threads, W=work (number of videos to decode)
Native resize is done as part of the decode step.
R=none means there is no resize step, native or otherwise.

nvidia-smi shows 99% NVDEC utilization :)

# gpu         pid   type     sm    mem    enc    dec    jpg    ofa    command 
# Idx           #    C/G      %      %      %      %      %      %    name 
    0    2180816     C     70      5      -     99      -      -    python      

This throughput mode is representative of video decoding using the dataloader with multiple threads.

@facebook-github-bot added the CLA Signed label Oct 4, 2024
@ronghanghu commented Oct 7, 2024

Looking forward to this! It would be great to have GPU decoding added (back) to TorchCodec

const torch::Device& device,
AVCodecContext* codecContext);

VideoDecoder::DecodedOutput convertAVFrameToDecodedOutputOnDevice(
Member

Should "OnDevice" be "OnCUDA"? I know that within the context of CUDA development, "device" is often used to mean the GPU in contrast to the host, but in the context of torchcodec/pytorch the distinction isn't always as obvious to me. The CPU is a device too.

Contributor Author

Since this is an interface, I want it to be generic so we can support AMD, etc. in the future. That's why I call it "device".

@scotts scotts (Contributor) Oct 8, 2024

Agreed that we should say "OnCUDA" if the implementation only supports CUDA. We may support other kinds of devices in the future. If we have "device" in the name, that should mean the implementation works for any kind of device.

const VideoDecoder::VideoStreamDecoderOptions& options,
AVCodecContext* codecContext,
VideoDecoder::RawDecodedOutput& rawOutput) {
throwUnsupportedDeviceError(device);
Contributor

Should this function and the function below always throw? If yes, then we should just do something like TORCH_CHECK(false, "Unsupported device."). To avoid the need for a return value, mark the function as [[noreturn]]: https://en.cppreference.com/w/cpp/language/attributes/noreturn. We should rely on a TORCH macro to do the throwing for us rather than throwing ourselves, and we should make it obvious that the check will always fail.
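For reference, a minimal sketch of the suggested pattern (the function name and message are illustrative):

#include <c10/util/Exception.h>
#include <torch/types.h>

// TORCH_CHECK(false, ...) always throws, and [[noreturn]] tells the compiler that no
// return value is needed on this path.
[[noreturn]] void throwUnsupportedDeviceError(const torch::Device& device) {
  TORCH_CHECK(false, "Unsupported device: ", device.str());
}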

Contributor Author

Good suggestion. Done

Contributor

I think we should probably also annotate convertAVFrameToDecodedOutputOnDevice() and initializeDeviceContext() with [[noreturn]]. Let's also avoid two TORCH_CHECK calls. Whatever message we want to put on stderr, we can do it in one check.

Contributor Author

The two checks are there because one catches a programming/logic error on our part: we should never pass a CPU device into device functions.

The other catches a device type that support was not compiled in for.
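A sketch of how those two distinct checks might read (the function name and messages are illustrative, not the PR's actual code):

#include <c10/util/Exception.h>
#include <torch/types.h>

[[noreturn]] void reportUnsupportedDevice(const torch::Device& device) {
  // Logic error on our side: device-specific code paths should never be entered with a
  // CPU device.
  TORCH_CHECK(
      device.type() != torch::kCPU,
      "Device functions should never be called with a CPU device.");
  // User-facing error: support for this device type was not compiled into this build.
  TORCH_CHECK(
      false,
      "Unsupported device: ",
      device.str(),
      ". TorchCodec was not built with support for this device.");
}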

Comment on lines 90 to 91
at::DeviceIndex deviceIndex = device.index();
deviceIndex = std::max<at::DeviceIndex>(deviceIndex, 0);
Contributor

Two things:

  1. This is the second place we're doing this same logic. We should abstract it into a function, even though it's small. The function name will probably help with my second point.
  2. I'm not quite sure why we're doing it? Under what circumstance will ATen's reported index for a device be less than 0? It looks like it defaults to -1 in some cases (https://pytorch.org/cppdocs/api/structc10_1_1_device.html#_CPPv4N3c106Device6DeviceE10DeviceType11DeviceIndex), but wouldn't that be an error for us? Notably, this logic will make any value less than 0 be 0, which means maybe we could map multiple devices to 0. I don't think we should ever see such values, but it's confusing to me that our code makes it possible.

Contributor Author

FFmpeg doesn't accept negative values for the device index. I added a comment to that effect.

Contributor Author

For a single-GPU setup, libtorch returns -1 while FFmpeg assumes it will be 0, so we have to bridge that gap.

For a multi-GPU setup, I haven't seen -1 returned by torch, so there we won't have to do the max.

The -1 seems to be a libtorch-specific thing.
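A sketch of the small helper this logic could be factored into, per the suggestion above (the name is hypothetical):

#include <algorithm>
#include <c10/core/Device.h>
#include <torch/types.h>

// libtorch reports -1 ("current device") on a single-GPU setup, but FFmpeg rejects
// negative device indices, so clamp to 0.
at::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) {
  return std::max<at::DeviceIndex>(device.index(), 0);
}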

at::DeviceIndex deviceIndex = device.index();
deviceIndex = std::max<at::DeviceIndex>(deviceIndex, 0);
at::DeviceIndex originalDeviceIndex = at::cuda::current_device();
cudaSetDevice(deviceIndex);
Contributor

I can see that later on line 121 we restore the device index. Why? Can we explain why we need to set and then restore?

Contributor Author

I ended up using CUDADeviceGuard. Callers assume this function does not interfere with the current CUDA device.
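A minimal sketch of that RAII approach, using c10::cuda::CUDAGuard (the surrounding function is illustrative):

#include <c10/cuda/CUDAGuard.h>
#include <torch/types.h>

void doCudaWorkWithoutLeakingDeviceState(const torch::Device& device) {
  // Sets the current CUDA device to `device` and restores the previous one when the
  // guard goes out of scope, so callers never observe a changed current device.
  c10::cuda::CUDAGuard deviceGuard(device);
  // ... device-specific work goes here ...
}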

@@ -856,6 +856,25 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
output.duration = getDuration(frame);
output.durationSeconds = ptsToSeconds(
getDuration(frame), formatContext_->streams[streamIndex]->time_base);
if (streamInfo.options.device.type() != torch::kCPU) {
Contributor

Is this logic truly general to all non-CPU devices? If no, then here and elsewhere, we should do something closer to:

if (streamInfo.options.device.type() == torch::kCUDA) {
  logicSpecificToCUDA();
}
else if (streamInfo.options.device.type() == torch::kCPU) {
  logicSpecificToCPU();
}
else {
  TORCH_CHECK(false, "Unsupported device");
}

Contributor Author

Right now we only support CUDA.

I am assuming that if we support AMD we will use the same interface, and use CMake or #ifdefs to link in the correct device code.

So VideoDecoder.cpp just assumes CMake or the linker will do the right thing and calls the device code for any type of non-CPU device.

At the moment the CUDA device code is linked in by CMake for CUDA builds. How that will be done for AMD is TBD.

Contributor

I see, that makes sense. That also means that we don't need to have CPU versions of functions that throw for all N devices we support.

@ahmadsharif1 ahmadsharif1 marked this pull request as ready for review October 8, 2024 14:25
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
torch::DeviceIndex deviceIndex = device.index();
// FFMPEG cannot handle negative device indices.
Contributor

I understand now what instigated this code, but I still can't evaluate if it's correct. Looking at the docs, a negative value indicates the "current device": https://pytorch.org/cppdocs/api/structc10_1_1_device.html#_CPPv4N3c106DeviceE

Is it safe to map all values of "current device" to 0? Is this a mapping we need to track? What happens when we are on a system with multiple GPUs? I'm assuming we don't fully understand the answers to these questions, and I don't want to block progress. So I think we should have a meatier comment both explaining what we do know, and indicating this may be a problem in the future.

Contributor Author

I have added a longer comment with a TODO to investigate that it works properly with a multi-GPU setup. I am sure once users start using it, we will hit more edge cases.

@facebook-github-bot (Contributor)

@ahmadsharif1 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@ronghanghu
Hi @ahmadsharif1, thanks for the great PR to add back GPU support! Wondering if it's possible to also add back the device parameter into SimpleVideoDecoder, which was previously removed in https://github.com/pytorch/torchcodec/pull/196/files#diff-5ff4f051479ffd5d021001e2a101973746feda3a3f579bf2d072629329c421dc?
