From 72bdd2512c16f7cc6c6f2e026024502776289c37 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Fri, 4 Oct 2024 11:31:46 -0700 Subject: [PATCH 01/26] Add GPU decoding support --- .../decoders/_core/CPUOnlyDevice.cpp | 5 ++- src/torchcodec/decoders/_core/CudaDevice.cpp | 43 ++++++++++++++++--- .../decoders/_core/DeviceInterface.h | 8 +++- .../decoders/_core/VideoDecoder.cpp | 2 +- 4 files changed, 50 insertions(+), 8 deletions(-) diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index 5a591b04..4b28c970 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -1,8 +1,11 @@ #include +#include "src/torchcodec/decoders/_core/DeviceInterface.h" namespace facebook::torchcodec { -void maybeInitializeDeviceContext(const torch::Device& device) { +void maybeInitializeDeviceContext( + const torch::Device& device, + AVCodecContext* codecContext) { if (device.type() == torch::kCPU) { return; } diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index d1ae233c..ca492661 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -1,15 +1,48 @@ #include +#include "src/torchcodec/decoders/_core/DeviceInterface.h" +#include "src/torchcodec/decoders/_core/FFMPEGCommon.h" + +extern "C" { +#include +#include +} namespace facebook::torchcodec { -void maybeInitializeDeviceContext(const torch::Device& device) { +AVBufferRef* getCudaContext() { + enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); + TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); + int err = 0; + AVBufferRef* hw_device_ctx; + err = av_hwdevice_ctx_create( + &hw_device_ctx, + type, + nullptr, + nullptr, + // Introduced in 58.26.100: + // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265 +#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100) + AV_CUDA_USE_CURRENT_CONTEXT +#else + 0 +#endif + ); + if (err < 0) { + TORCH_CHECK( + false, + "Failed to create specified HW device", + getFFMPEGErrorStringFromErrorCode(err)); + } + return hw_device_ctx; +} + +void maybeInitializeDeviceContext( + const torch::Device& device, + AVCodecContext* codecContext) { if (device.type() == torch::kCPU) { return; } else if (device.type() == torch::kCUDA) { - // TODO: https://github.com/pytorch/torchcodec/issues/238: Implement CUDA - // device. - throw std::runtime_error( - "CUDA device is unimplemented. Follow this issue for tracking progress: https://github.com/pytorch/torchcodec/issues/238"); + codecContext->hw_device_ctx = av_buffer_ref(getCudaContext()); } throw std::runtime_error("Unsupported device: " + device.str()); } diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h index 3c13ac5e..884ca3f3 100644 --- a/src/torchcodec/decoders/_core/DeviceInterface.h +++ b/src/torchcodec/decoders/_core/DeviceInterface.h @@ -11,10 +11,16 @@ #include #include +extern "C" { +#include +} + namespace facebook::torchcodec { // Initialize the hardware device that is specified in `device`. Some builds // support CUDA and others only support CPU. 
-void maybeInitializeDeviceContext(const torch::Device& device); +void maybeInitializeDeviceContext( + const torch::Device& device, + AVCodecContext* codecContext); } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index dcf82ea4..802cb840 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -428,7 +428,7 @@ void VideoDecoder::addVideoStreamDecoder( streamInfo.codecContext.reset(codecContext); int retVal = avcodec_parameters_to_context( streamInfo.codecContext.get(), streamInfo.stream->codecpar); - maybeInitializeDeviceContext(options.device); + maybeInitializeDeviceContext(options.device, codecContext); TORCH_CHECK_EQ(retVal, AVSUCCESS); retVal = avcodec_open2(streamInfo.codecContext.get(), codec, nullptr); if (retVal < AVSUCCESS) { From 707cff34f303de98504b3b7021383e245d9dbc23 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Fri, 4 Oct 2024 13:25:01 -0700 Subject: [PATCH 02/26] . --- src/torchcodec/decoders/_core/CudaDevice.cpp | 64 ++++++++++++++++++- .../decoders/_core/DeviceInterface.h | 7 ++ .../decoders/_core/VideoDecoder.cpp | 7 ++ test/decoders/test_video_decoder_ops.py | 8 ++- 4 files changed, 83 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index ca492661..b9bb8281 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -1,14 +1,17 @@ +#include #include #include "src/torchcodec/decoders/_core/DeviceInterface.h" #include "src/torchcodec/decoders/_core/FFMPEGCommon.h" +#include "src/torchcodec/decoders/_core/VideoDecoder.h" extern "C" { #include #include +#include } namespace facebook::torchcodec { - +namespace { AVBufferRef* getCudaContext() { enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); @@ -36,15 +39,74 @@ AVBufferRef* getCudaContext() { return hw_device_ctx; } +torch::Tensor allocateDeviceTensor( + at::IntArrayRef shape, + torch::Device device, + const torch::Dtype dtype = torch::kUInt8) { + return torch::empty( + shape, + torch::TensorOptions() + .dtype(dtype) + .layout(torch::kStrided) + .device(device)); +} +} // namespace + void maybeInitializeDeviceContext( const torch::Device& device, AVCodecContext* codecContext) { if (device.type() == torch::kCPU) { return; } else if (device.type() == torch::kCUDA) { + torch::Tensor dummyTensorForCudaInitialization = torch::empty( + {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device)); codecContext->hw_device_ctx = av_buffer_ref(getCudaContext()); + return; } throw std::runtime_error("Unsupported device: " + device.str()); } +VideoDecoder::DecodedOutput convertAVFrameToDecodedOutputOnDevice( + const torch::Device& device, + const VideoDecoder::VideoStreamDecoderOptions& options, + AVCodecContext* codecContext, + VideoDecoder::RawDecodedOutput& rawOutput) { + AVFrame* src = rawOutput.frame.get(); + + TORCH_CHECK( + src->format == AV_PIX_FMT_CUDA, + "Expected format to be AV_PIX_FMT_CUDA, got " + + std::string(av_get_pix_fmt_name((AVPixelFormat)src->format))); + int width = options.width.value_or(codecContext->width); + int height = options.height.value_or(codecContext->height); + NppStatus status; + NppiSize oSizeROI; + oSizeROI.width = width; + oSizeROI.height = height; + Npp8u* input[2]; + input[0] = (Npp8u*)src->data[0]; + input[1] 
= (Npp8u*)src->data[1]; + VideoDecoder::DecodedOutput output; + torch::Tensor& dst = output.frame; + dst = allocateDeviceTensor({height, width, 3}, options.device); + auto start = std::chrono::high_resolution_clock::now(); + status = nppiNV12ToRGB_8u_P2C3R( + input, + src->linesize[0], + static_cast(dst.data_ptr()), + dst.stride(0), + oSizeROI); + TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration duration = end - start; + VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width + << " took: " << duration.count() << "us" << std::endl; + if (options.dimensionOrder == "NCHW") { + // The docs guaranty this to return a view: + // https://pytorch.org/docs/stable/generated/torch.permute.html + dst = dst.permute({2, 0, 1}); + } + return output; +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h index 884ca3f3..b1e8be82 100644 --- a/src/torchcodec/decoders/_core/DeviceInterface.h +++ b/src/torchcodec/decoders/_core/DeviceInterface.h @@ -10,6 +10,7 @@ #include #include #include +#include "src/torchcodec/decoders/_core/VideoDecoder.h" extern "C" { #include @@ -23,4 +24,10 @@ void maybeInitializeDeviceContext( const torch::Device& device, AVCodecContext* codecContext); +VideoDecoder::DecodedOutput convertAVFrameToDecodedOutputOnDevice( + const torch::Device& device, + const VideoDecoder::VideoStreamDecoderOptions& options, + AVCodecContext* codecContext, + VideoDecoder::RawDecodedOutput& rawOutput); + } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 802cb840..dcdf2826 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -854,6 +854,13 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( output.duration = getDuration(frame); output.durationSeconds = ptsToSeconds( getDuration(frame), formatContext_->streams[streamIndex]->time_base); + if (streamInfo.options.device.type() != torch::kCPU) { + return convertAVFrameToDecodedOutputOnDevice( + streamInfo.options.device, + streamInfo.options, + streamInfo.codecContext.get(), + rawOutput); + } if (output.streamType == AVMEDIA_TYPE_VIDEO) { if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) { int width = streamInfo.options.width.value_or(frame->width); diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 2fbe8d9a..331d4968 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -464,8 +464,12 @@ def test_color_conversion_library_with_generated_videos( def test_cuda_decoder(self): decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder) - with pytest.raises(RuntimeError, match="CUDA device is unimplemented"): - add_video_stream(decoder, device="cuda") + add_video_stream(decoder, device="cuda") + frame0, *_ = get_next_frame(decoder) + assert frame0.device.type == "cuda" + frame0_cpu = frame0.to("cpu") + reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) + torch.testing.assert_close(frame0_cpu, reference_frame0, atol=60, rtol=0) if __name__ == "__main__": From 0c916b33c8a53e673117d5e0ba2bb8658eeb5329 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Fri, 4 Oct 2024 14:43:02 -0700 Subject: [PATCH 03/26] . 
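Background for the CUDA color-conversion path added in the previous commit: NV12 stores a full-resolution Y plane plus an interleaved, half-resolution UV plane, which is why the NPP call takes two source pointers. Below is a minimal sketch of that conversion into a CUDA tensor; the function name and parameters are illustrative assumptions, not torchcodec API. (The commit below adds the matching CPU-only stub so non-CUDA builds still link.)

    #include <npp.h>
    #include <torch/torch.h>

    // Illustrative only: convert an NV12 frame (two device pointers) into a
    // packed HWC uint8 RGB tensor allocated on the same GPU.
    torch::Tensor nv12ToRgbSketch(
        const Npp8u* yPlane,   // AVFrame::data[0]
        const Npp8u* uvPlane,  // AVFrame::data[1], interleaved U/V
        int linesize,          // AVFrame::linesize[0]
        int width,
        int height,
        torch::Device device) {
      torch::Tensor rgb = torch::empty(
          {height, width, 3},
          torch::TensorOptions().dtype(torch::kUInt8).device(device));
      const Npp8u* src[2] = {yPlane, uvPlane};
      NppiSize roi = {width, height};
      NppStatus status = nppiNV12ToRGB_8u_P2C3R(
          src,
          linesize,
          static_cast<Npp8u*>(rgb.data_ptr()),
          static_cast<int>(rgb.stride(0)),
          roi);
      TORCH_CHECK(status == NPP_SUCCESS, "NV12 -> RGB conversion failed");
      return rgb;
    }

The NPP kernel runs on NPP's current stream, so callers still have to order it against PyTorch's stream; later commits in this series do exactly that.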
--- src/torchcodec/decoders/_core/CPUOnlyDevice.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index 4b28c970..4c3fbd22 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -12,4 +12,12 @@ void maybeInitializeDeviceContext( throw std::runtime_error("Unsupported device: " + device.str()); } +VideoDecoder::DecodedOutput convertAVFrameToDecodedOutputOnDevice( + const torch::Device& device, + const VideoDecoder::VideoStreamDecoderOptions& options, + AVCodecContext* codecContext, + VideoDecoder::RawDecodedOutput& rawOutput) { + TORCH_CHECK(false, "We should not run device code on CPU") +} + } // namespace facebook::torchcodec From 3600eeee1f037b59a06855e28c8599f25f8d1809 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Fri, 4 Oct 2024 14:58:42 -0700 Subject: [PATCH 04/26] . --- src/torchcodec/decoders/_core/CMakeLists.txt | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index eff5b1f6..f0b46f68 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -34,12 +34,17 @@ function(make_torchcodec_library library_name ffmpeg_target) ${Python3_INCLUDE_DIRS} ) + set(NEEDED_LIBRARIES ${ffmpeg_target} ${TORCH_LIBRARIES} ${Python3_LIBRARIES}) + if(ENABLE_CUDA) + list(APPEND NEEDED_LIBRARIES ${CUDA_CUDA_LIBRARY} ${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY} ) + endif() + if(ENABLE_NVTX) + list(APPEND NEEDED_LIBRARIES nvtx3-cpp) + endif() target_link_libraries( ${library_name} PUBLIC - ${ffmpeg_target} - ${TORCH_LIBRARIES} - ${Python3_LIBRARIES} + ${NEEDED_LIBRARIES} ) # We already set the library_name to be libtorchcodecN, so we don't want From 05d02a4a542b537be1928e94118b72d5166f8651 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Mon, 7 Oct 2024 10:19:47 -0700 Subject: [PATCH 05/26] . --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 2a9c28ff..b2941efe 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -429,7 +429,7 @@ void VideoDecoder::addVideoStreamDecoder( int retVal = avcodec_parameters_to_context( streamInfo.codecContext.get(), streamInfo.stream->codecpar); if (options.device.type() != torch::kCPU) { - initializeDeviceContext(options.device); + initializeDeviceContext(options.device, codecContext); } TORCH_CHECK_EQ(retVal, AVSUCCESS); retVal = avcodec_open2(streamInfo.codecContext.get(), codec, nullptr); From 9f169c9810588f10296057ff0e6ba293a016ebef Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Mon, 7 Oct 2024 10:23:46 -0700 Subject: [PATCH 06/26] . 
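Context for the device setup threaded through the patches above: FFmpeg only selects the CUDA/NVDEC path if AVCodecContext::hw_device_ctx is populated before avcodec_open2() is called, which is why the preceding commits pass the AVCodecContext* into the device-initialization call. A rough sketch of that ordering, with an assumed helper name and simplified error handling (not the exact torchcodec code):

    extern "C" {
    #include <libavcodec/avcodec.h>
    #include <libavutil/hwcontext.h>
    }
    #include <stdexcept>

    // Illustrative only: attach a CUDA hardware device context to the codec
    // context, then open the codec so decoded frames land in GPU memory.
    void openWithCudaSketch(AVCodecContext* codecContext, const AVCodec* codec) {
      AVBufferRef* hwDeviceCtx = nullptr;
      int err = av_hwdevice_ctx_create(
          &hwDeviceCtx, AV_HWDEVICE_TYPE_CUDA, /*device=*/"0",
          /*opts=*/nullptr, /*flags=*/0);
      if (err < 0) {
        throw std::runtime_error("Failed to create CUDA hw device context");
      }
      // The codec context keeps its own reference; drop the local one.
      codecContext->hw_device_ctx = av_buffer_ref(hwDeviceCtx);
      av_buffer_unref(&hwDeviceCtx);
      if (avcodec_open2(codecContext, codec, nullptr) < 0) {
        throw std::runtime_error("Failed to open codec");
      }
    }

With this in place, decoded AVFrames come back with format AV_PIX_FMT_CUDA and device pointers in data[0]/data[1], which is what the NV12 conversion earlier in the series consumes.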
--- src/torchcodec/decoders/_core/CPUOnlyDevice.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index f739e084..9b53fd3b 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -24,7 +24,9 @@ VideoDecoder::DecodedOutput convertAVFrameToDecodedOutputOnDevice( return output; } -void initializeDeviceContext(const torch::Device& device) { +void initializeDeviceContext( + const torch::Device& device, + AVCodecContext* codecContext) { throwUnsupportedDeviceError(device); } From 461a2ff379f529c2d16a13bf350cc8a1a4f9b196 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Mon, 7 Oct 2024 11:12:37 -0700 Subject: [PATCH 07/26] . --- src/torchcodec/decoders/_core/CudaDevice.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 6765ad85..875ab2d3 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -95,12 +95,23 @@ VideoDecoder::DecodedOutput convertAVFrameToDecodedOutputOnDevice( torch::Tensor& dst = output.frame; dst = allocateDeviceTensor({height, width, 3}, options.device); auto start = std::chrono::high_resolution_clock::now(); + cudaStream_t nppStream = nppGetStream(); + cudaStream_t torchStream = getCurrentCUDAStream().stream(); status = nppiNV12ToRGB_8u_P2C3R( input, src->linesize[0], static_cast(dst.data_ptr()), dst.stride(0), oSizeROI); + // Make the pytorch stream wait for the npp kernel to finish before using the + // output. + cudaEvent_t nppDoneEvent; + cudaEventCreate(&nppDoneEvent); + cudaEventRecord(nppDoneEvent, nppStream); + cudaEvent_t torchDoneEvent; + cudaEventCreate(&torchDoneEvent); + cudaEventRecord(torchDoneEvent, torchStream); + cudaStreamWaitEvent(torchStream, nppDoneEvent, 0); TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration = end - start; From 41a1ba2e5279b9ef26293e02b3d45e65885bf1ae Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Mon, 7 Oct 2024 11:18:12 -0700 Subject: [PATCH 08/26] . --- src/torchcodec/decoders/_core/CudaDevice.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 875ab2d3..5d0fd4f0 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -1,3 +1,4 @@ +#include #include #include #include "src/torchcodec/decoders/_core/DeviceInterface.h" @@ -96,7 +97,7 @@ VideoDecoder::DecodedOutput convertAVFrameToDecodedOutputOnDevice( dst = allocateDeviceTensor({height, width, 3}, options.device); auto start = std::chrono::high_resolution_clock::now(); cudaStream_t nppStream = nppGetStream(); - cudaStream_t torchStream = getCurrentCUDAStream().stream(); + cudaStream_t torchStream = at::cuda::getCurrentCUDAStream().stream(); status = nppiNV12ToRGB_8u_P2C3R( input, src->linesize[0], From 57818c5524a97915307f387ea99c6a32343912bd Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Mon, 7 Oct 2024 11:28:38 -0700 Subject: [PATCH 09/26] . 
--- benchmarks/decoders/gpu_benchmark.py | 144 ++++++++++++++++++ .../decoders/_core/VideoDecoderOps.cpp | 5 +- 2 files changed, 147 insertions(+), 2 deletions(-) create mode 100644 benchmarks/decoders/gpu_benchmark.py diff --git a/benchmarks/decoders/gpu_benchmark.py b/benchmarks/decoders/gpu_benchmark.py new file mode 100644 index 00000000..fd26c2f7 --- /dev/null +++ b/benchmarks/decoders/gpu_benchmark.py @@ -0,0 +1,144 @@ +import argparse +import os +import time + +import torch.utils.benchmark as benchmark + +import torchcodec +import torchvision.transforms.v2.functional as F + +RESIZED_WIDTH = 256 +RESIZED_HEIGHT = 256 + + +def transfer_and_resize_frame(frame, resize_device_string): + # This should be a no-op if the frame is already on the target device. + frame = frame.to(resize_device_string) + frame = F.resize(frame, (RESIZED_HEIGHT, RESIZED_WIDTH)) + return frame + + +def decode_full_video(video_path, decode_device_string, resize_device_string): + # We use the core API instead of SimpleVideoDecoder because the core API + # allows us to natively resize as part of the decode step. + print(f"{decode_device_string=} {resize_device_string=}") + decoder = torchcodec.decoders._core.create_from_file(video_path) + num_threads = None + if "cuda" in decode_device_string: + num_threads = 1 + width = None + height = None + if "native" in resize_device_string: + width = RESIZED_WIDTH + height = RESIZED_HEIGHT + torchcodec.decoders._core._add_video_stream( + decoder, + stream_index=-1, + device=decode_device_string, + num_threads=num_threads, + width=width, + height=height, + ) + + start_time = time.time() + frame_count = 0 + while True: + try: + frame, *_ = torchcodec.decoders._core.get_next_frame(decoder) + if resize_device_string != "none" and "native" not in resize_device_string: + frame = transfer_and_resize_frame(frame, resize_device_string) + + frame_count += 1 + except Exception as e: + print("EXCEPTION", e) + break + + end_time = time.time() + elapsed = end_time - start_time + fps = frame_count / (end_time - start_time) + print( + f"****** DECODED full video {decode_device_string=} {frame_count=} {elapsed=} {fps=}" + ) + return frame_count, end_time - start_time + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--devices", + default="cuda:0,cpu", + type=str, + help="Comma-separated devices to test decoding on.", + ) + parser.add_argument( + "--resize_devices", + default="cuda:0,cpu,native,none", + type=str, + help="Comma-separated devices to test preroc (resize) on. Use 'none' to specify no resize.", + ) + parser.add_argument( + "--video", + type=str, + default=os.path.dirname(__file__) + "/../../test/resources/nasa_13013.mp4", + ) + parser.add_argument( + "--use_torch_benchmark", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Use pytorch benchmark to measure decode time with warmup and " + "autorange. Without this we just run one iteration without warmup " + "to measure the cold start time." 
+ ), + ) + args = parser.parse_args() + video_path = args.video + + if not args.use_torch_benchmark: + for device in args.devices.split(","): + print("Testing on", device) + decode_full_video(video_path, device) + return + + resize_devices = args.resize_devices.split(",") + resize_devices = [d for d in resize_devices if d != ""] + if len(resize_devices) == 0: + resize_devices.append("none") + + label = "Decode+Resize Time" + + results = [] + for decode_device_string in args.devices.split(","): + for resize_device_string in resize_devices: + decode_label = decode_device_string + if "cuda" in decode_label: + # Shorten "cuda:0" to "cuda" + decode_label = "cuda" + resize_label = resize_device_string + if "cuda" in resize_device_string: + # Shorten "cuda:0" to "cuda" + resize_label = "cuda" + print("decode_device", decode_device_string) + print("resize_device", resize_device_string) + t = benchmark.Timer( + stmt="decode_full_video(video_path, decode_device_string, resize_device_string)", + globals={ + "decode_device_string": decode_device_string, + "video_path": video_path, + "decode_full_video": decode_full_video, + "resize_device_string": resize_device_string, + }, + label=label, + description=f"video={os.path.basename(video_path)}", + sub_label=f"D={decode_label} R={resize_label}", + ).blocked_autorange() + results.append(t) + compare = benchmark.Compare(results) + compare.print() + print("Key: D=Decode, R=Resize") + print("Native resize is done as part of the decode step") + print("none resize means there is no resize step -- native or otherwise") + + +if __name__ == "__main__": + main() diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 8b3a373d..03800a04 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -168,8 +168,9 @@ void _add_video_stream( if (device.has_value()) { if (device.value() == "cpu") { options.device = torch::Device(torch::kCPU); - } else if (device.value() == "cuda") { - options.device = torch::Device(torch::kCUDA); + } else if (device.value().starts_with("cuda")) { + std::string deviceStr(device.value()); + options.device = torch::Device(deviceStr); } else { throw std::runtime_error( "Invalid device=" + std::string(device.value()) + From 16112452aa961ac4984ddeccc21d334b5c32f19a Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Mon, 7 Oct 2024 11:44:26 -0700 Subject: [PATCH 10/26] . --- test/decoders/test_video_decoder_ops.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 331d4968..8806392c 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -469,6 +469,8 @@ def test_cuda_decoder(self): assert frame0.device.type == "cuda" frame0_cpu = frame0.to("cpu") reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) + # We pass in atol of 60 because the CUDA decoder is not bit-accurate + # compared to the CPU decoder. torch.testing.assert_close(frame0_cpu, reference_frame0, atol=60, rtol=0) From 58624d0abcd47dec48c69402780a6315f3d6121c Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Mon, 7 Oct 2024 12:29:06 -0700 Subject: [PATCH 11/26] . 
--- benchmarks/decoders/gpu_benchmark.py | 70 ++++++++++++++++---- src/torchcodec/decoders/_core/CudaDevice.cpp | 14 +--- 2 files changed, 57 insertions(+), 27 deletions(-) diff --git a/benchmarks/decoders/gpu_benchmark.py b/benchmarks/decoders/gpu_benchmark.py index fd26c2f7..ea677a91 100644 --- a/benchmarks/decoders/gpu_benchmark.py +++ b/benchmarks/decoders/gpu_benchmark.py @@ -1,6 +1,7 @@ import argparse import os import time +from concurrent.futures import ThreadPoolExecutor import torch.utils.benchmark as benchmark @@ -62,6 +63,17 @@ def decode_full_video(video_path, decode_device_string, resize_device_string): return frame_count, end_time - start_time +def decode_videos_using_threads( + video_path, decode_device_string, resize_device_string, num_videos, num_threads +): + executor = ThreadPoolExecutor(max_workers=num_threads) + for i in range(num_videos): + executor.submit( + decode_full_video, video_path, decode_device_string, resize_device_string + ) + executor.shutdown(wait=True) + + def main(): parser = argparse.ArgumentParser() parser.add_argument( @@ -91,6 +103,18 @@ def main(): "to measure the cold start time." ), ) + parser.add_argument( + "--num_threads", + type=int, + default=1, + help="Number of threads to use for decoding. Only used when --use_torch_benchmark is set.", + ) + parser.add_argument( + "--num_videos", + type=int, + default=50, + help="Number of videos to decode in parallel. Only used when --num_threads is set.", + ) args = parser.parse_args() video_path = args.video @@ -120,22 +144,40 @@ def main(): resize_label = "cuda" print("decode_device", decode_device_string) print("resize_device", resize_device_string) - t = benchmark.Timer( - stmt="decode_full_video(video_path, decode_device_string, resize_device_string)", - globals={ - "decode_device_string": decode_device_string, - "video_path": video_path, - "decode_full_video": decode_full_video, - "resize_device_string": resize_device_string, - }, - label=label, - description=f"video={os.path.basename(video_path)}", - sub_label=f"D={decode_label} R={resize_label}", - ).blocked_autorange() - results.append(t) + if args.num_threads > 1: + t = benchmark.Timer( + stmt="decode_videos_using_threads(video_path, decode_device_string, resize_device_string, num_videos, num_threads)", + globals={ + "decode_device_string": decode_device_string, + "video_path": video_path, + "decode_full_video": decode_full_video, + "decode_videos_using_threads": decode_videos_using_threads, + "resize_device_string": resize_device_string, + "num_videos": args.num_videos, + "num_threads": args.num_threads, + }, + label=label, + description=f"threads={args.num_threads} work={args.num_videos} video={os.path.basename(video_path)}", + sub_label=f"D={decode_label} R={resize_label} T={args.num_threads} W={args.num_videos}", + ).blocked_autorange() + results.append(t) + else: + t = benchmark.Timer( + stmt="decode_full_video(video_path, decode_device_string, resize_device_string)", + globals={ + "decode_device_string": decode_device_string, + "video_path": video_path, + "decode_full_video": decode_full_video, + "resize_device_string": resize_device_string, + }, + label=label, + description=f"video={os.path.basename(video_path)}", + sub_label=f"D={decode_label} R={resize_label}", + ).blocked_autorange() + results.append(t) compare = benchmark.Compare(results) compare.print() - print("Key: D=Decode, R=Resize") + print("Key: D=Decode, R=Resize T=threads W=work (number of videos to decode)") print("Native resize is done as part of the decode step") 
print("none resize means there is no resize step -- native or otherwise") diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 5d0fd4f0..44956d53 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -18,19 +18,7 @@ AVBufferRef* getCudaContext() { TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); int err = 0; AVBufferRef* hw_device_ctx; - err = av_hwdevice_ctx_create( - &hw_device_ctx, - type, - nullptr, - nullptr, - // Introduced in 58.26.100: - // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265 -#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100) - AV_CUDA_USE_CURRENT_CONTEXT -#else - 0 -#endif - ); + err = av_hwdevice_ctx_create(&hw_device_ctx, type, nullptr, nullptr, 0); if (err < 0) { TORCH_CHECK( false, From e5769295a53b1a5b16d85464ffc9afd1d8b71b5a Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Mon, 7 Oct 2024 15:04:54 -0700 Subject: [PATCH 12/26] . --- src/torchcodec/decoders/_core/CudaDevice.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 44956d53..dd7602fc 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -13,12 +13,18 @@ extern "C" { namespace facebook::torchcodec { namespace { -AVBufferRef* getCudaContext() { +AVBufferRef* getCudaContext(const torch::Device& device) { enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); int err = 0; AVBufferRef* hw_device_ctx; - err = av_hwdevice_ctx_create(&hw_device_ctx, type, nullptr, nullptr, 0); + torch::DeviceIndex deviceIndex = device.index(); + if (deviceIndex < 0) { + deviceIndex = 0; + } + std::string deviceOrdinal = std::to_string(deviceIndex); + err = av_hwdevice_ctx_create( + &hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0); if (err < 0) { TORCH_CHECK( false, @@ -56,7 +62,7 @@ void initializeDeviceContext( throwErrorIfNonCudaDevice(device); torch::Tensor dummyTensorForCudaInitialization = torch::empty( {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device)); - codecContext->hw_device_ctx = av_buffer_ref(getCudaContext()); + codecContext->hw_device_ctx = av_buffer_ref(getCudaContext(device)); return; } From dca3540121e9e21ebee295be057739b0a9829b94 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Mon, 7 Oct 2024 15:23:03 -0700 Subject: [PATCH 13/26] . 
--- src/torchcodec/decoders/_core/CudaDevice.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index dd7602fc..771d4877 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -89,7 +89,13 @@ VideoDecoder::DecodedOutput convertAVFrameToDecodedOutputOnDevice( VideoDecoder::DecodedOutput output; torch::Tensor& dst = output.frame; dst = allocateDeviceTensor({height, width, 3}, options.device); + at::DeviceIndex deviceIndex = device.index(); + deviceIndex = std::max(deviceIndex, 0); + at::DeviceIndex originalDeviceIndex = at::cuda::current_device(); + cudaSetDevice(deviceIndex); + auto start = std::chrono::high_resolution_clock::now(); + cudaStream_t nppStream = nppGetStream(); cudaStream_t torchStream = at::cuda::getCurrentCUDAStream().stream(); status = nppiNV12ToRGB_8u_P2C3R( @@ -108,7 +114,14 @@ VideoDecoder::DecodedOutput convertAVFrameToDecodedOutputOnDevice( cudaEventRecord(torchDoneEvent, torchStream); cudaStreamWaitEvent(torchStream, nppDoneEvent, 0); TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); + cudaEventDestroy(nppDoneEvent); + cudaEventDestroy(torchDoneEvent); + auto end = std::chrono::high_resolution_clock::now(); + + // Restore the original device_index. + cudaSetDevice(originalDeviceIndex); + std::chrono::duration duration = end - start; VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width << " took: " << duration.count() << "us" << std::endl; From 8ff05ee93cb5996f2e8a920aeca71445983ccb9c Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Mon, 7 Oct 2024 15:32:39 -0700 Subject: [PATCH 14/26] . --- src/torchcodec/decoders/_core/CudaDevice.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 771d4877..02ae39ae 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -19,9 +19,7 @@ AVBufferRef* getCudaContext(const torch::Device& device) { int err = 0; AVBufferRef* hw_device_ctx; torch::DeviceIndex deviceIndex = device.index(); - if (deviceIndex < 0) { - deviceIndex = 0; - } + deviceIndex = std::max(deviceIndex, 0); std::string deviceOrdinal = std::to_string(deviceIndex); err = av_hwdevice_ctx_create( &hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0); From a52ba5ccc828deac21b2d8d8cdb942dca483476a Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Mon, 7 Oct 2024 15:41:16 -0700 Subject: [PATCH 15/26] . 
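A note on the cudaSetDevice()/restore pairing introduced two commits back: the RAII guard in c10 expresses the same thing and restores the caller's device even on early returns, and a later commit in this series switches to it. A minimal sketch, assuming the c10 CUDA headers are available:

    #include <c10/cuda/CUDAGuard.h>

    void runNppOnDeviceSketch(c10::DeviceIndex deviceIndex) {
      // Switches the current CUDA device for the scope of this function.
      c10::cuda::CUDAGuard deviceGuard(deviceIndex);
      // ... launch the NPP color conversion here ...
    }  // The guard's destructor restores the previous device.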
--- benchmarks/decoders/gpu_benchmark.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/benchmarks/decoders/gpu_benchmark.py b/benchmarks/decoders/gpu_benchmark.py index ea677a91..a2a38357 100644 --- a/benchmarks/decoders/gpu_benchmark.py +++ b/benchmarks/decoders/gpu_benchmark.py @@ -3,6 +3,8 @@ import time from concurrent.futures import ThreadPoolExecutor +import torch + import torch.utils.benchmark as benchmark import torchcodec @@ -64,12 +66,20 @@ def decode_full_video(video_path, decode_device_string, resize_device_string): def decode_videos_using_threads( - video_path, decode_device_string, resize_device_string, num_videos, num_threads + video_path, + decode_device_string, + resize_device_string, + num_videos, + num_threads, + use_multiple_gpus, ): executor = ThreadPoolExecutor(max_workers=num_threads) for i in range(num_videos): + actual_decode_device = decode_device_string + if "cuda" in decode_device_string: + actual_decode_device = f"cuda:{i % torch.cuda.device_count()}" executor.submit( - decode_full_video, video_path, decode_device_string, resize_device_string + decode_full_video, video_path, actual_decode_device, resize_device_string ) executor.shutdown(wait=True) @@ -115,6 +125,12 @@ def main(): default=50, help="Number of videos to decode in parallel. Only used when --num_threads is set.", ) + parser.add_argument( + "--use_multiple_gpus", + action=argparse.BooleanOptionalAction, + default=True, + help=("Use multiple GPUs to decode multiple videos in multi-threaded mode."), + ) args = parser.parse_args() video_path = args.video @@ -146,7 +162,7 @@ def main(): print("resize_device", resize_device_string) if args.num_threads > 1: t = benchmark.Timer( - stmt="decode_videos_using_threads(video_path, decode_device_string, resize_device_string, num_videos, num_threads)", + stmt="decode_videos_using_threads(video_path, decode_device_string, resize_device_string, num_videos, num_threads, use_multiple_gpus)", globals={ "decode_device_string": decode_device_string, "video_path": video_path, @@ -155,6 +171,7 @@ def main(): "resize_device_string": resize_device_string, "num_videos": args.num_videos, "num_threads": args.num_threads, + "use_multiple_gpus": args.use_multiple_gpus, }, label=label, description=f"threads={args.num_threads} work={args.num_videos} video={os.path.basename(video_path)}", From a98aaa3643e3d73863603f21169e9e182daace1a Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Tue, 8 Oct 2024 06:45:29 -0700 Subject: [PATCH 16/26] . 
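This refactor has the CPU and CUDA paths fill a caller-provided DecodedOutput, and it also changes the CUDA test below: NVDEC plus NPP color conversion is not bit-exact against the CPU decoder, so instead of a fixed atol the test bounds the fraction of pixels that differ by more than a threshold. A hedged C++ sketch of that comparison (the Python test below expresses the same idea with tensor ops):

    #include <torch/torch.h>

    // Returns true when at most `maxFraction` of the values differ from the
    // reference by more than `maxDiff` (e.g. 0.3% of pixels off by > 20).
    bool framesRoughlyEqualSketch(
        const torch::Tensor& reference,
        const torch::Tensor& decoded,
        double maxDiff = 20.0,
        double maxFraction = 0.003) {
      torch::Tensor diff =
          (reference.to(torch::kFloat) - decoded.to(torch::kFloat)).abs();
      double fraction =
          (diff > maxDiff).to(torch::kFloat).mean().item<double>();
      return fraction <= maxFraction;
    }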
--- benchmarks/decoders/gpu_benchmark.py | 3 ++- src/torchcodec/decoders/_core/CMakeLists.txt | 3 --- src/torchcodec/decoders/_core/CPUOnlyDevice.cpp | 11 +++++------ src/torchcodec/decoders/_core/CudaDevice.cpp | 8 ++++---- src/torchcodec/decoders/_core/DeviceInterface.h | 5 +++-- src/torchcodec/decoders/_core/VideoDecoder.cpp | 17 ++++++++++++++--- src/torchcodec/decoders/_core/VideoDecoder.h | 3 +++ test/decoders/test_video_decoder_ops.py | 13 +++++++++---- 8 files changed, 40 insertions(+), 23 deletions(-) diff --git a/benchmarks/decoders/gpu_benchmark.py b/benchmarks/decoders/gpu_benchmark.py index a2a38357..44e9cea7 100644 --- a/benchmarks/decoders/gpu_benchmark.py +++ b/benchmarks/decoders/gpu_benchmark.py @@ -1,5 +1,6 @@ import argparse import os +import pathlib import time from concurrent.futures import ThreadPoolExecutor @@ -101,7 +102,7 @@ def main(): parser.add_argument( "--video", type=str, - default=os.path.dirname(__file__) + "/../../test/resources/nasa_13013.mp4", + default=pathlib.Path(__file__).parent / "../../test/resources/nasa_13013.mp4", ) parser.add_argument( "--use_torch_benchmark", diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index f0b46f68..d3d2c202 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -38,9 +38,6 @@ function(make_torchcodec_library library_name ffmpeg_target) if(ENABLE_CUDA) list(APPEND NEEDED_LIBRARIES ${CUDA_CUDA_LIBRARY} ${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY} ) endif() - if(ENABLE_NVTX) - list(APPEND NEEDED_LIBRARIES nvtx3-cpp) - endif() target_link_libraries( ${library_name} PUBLIC diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index 9b53fd3b..65082e6f 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -7,21 +7,20 @@ namespace facebook::torchcodec { // So all functions will throw an error because they should only be called if // the device is not CPU. -void throwUnsupportedDeviceError(const torch::Device& device) { +[[noreturn]] void throwUnsupportedDeviceError(const torch::Device& device) { TORCH_CHECK( device.type() != torch::kCPU, "Device functions should only be called if the device is not CPU.") - throw std::runtime_error("Unsupported device: " + device.str()); + TORCH_CHECK(false, "Unsupported device: " + device.str()); } -VideoDecoder::DecodedOutput convertAVFrameToDecodedOutputOnDevice( +void convertAVFrameToDecodedOutputOnDevice( const torch::Device& device, const VideoDecoder::VideoStreamDecoderOptions& options, AVCodecContext* codecContext, - VideoDecoder::RawDecodedOutput& rawOutput) { + VideoDecoder::RawDecodedOutput& rawOutput, + VideoDecoder::DecodedOutput& output) { throwUnsupportedDeviceError(device); - VideoDecoder::DecodedOutput output; - return output; } void initializeDeviceContext( diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 02ae39ae..7abaf886 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -58,17 +58,19 @@ void initializeDeviceContext( const torch::Device& device, AVCodecContext* codecContext) { throwErrorIfNonCudaDevice(device); + // This is a dummy tensor to initialize the cuda context. 
torch::Tensor dummyTensorForCudaInitialization = torch::empty( {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device)); codecContext->hw_device_ctx = av_buffer_ref(getCudaContext(device)); return; } -VideoDecoder::DecodedOutput convertAVFrameToDecodedOutputOnDevice( +void convertAVFrameToDecodedOutputOnDevice( const torch::Device& device, const VideoDecoder::VideoStreamDecoderOptions& options, AVCodecContext* codecContext, - VideoDecoder::RawDecodedOutput& rawOutput) { + VideoDecoder::RawDecodedOutput& rawOutput, + VideoDecoder::DecodedOutput& output) { AVFrame* src = rawOutput.frame.get(); TORCH_CHECK( @@ -84,7 +86,6 @@ VideoDecoder::DecodedOutput convertAVFrameToDecodedOutputOnDevice( Npp8u* input[2]; input[0] = (Npp8u*)src->data[0]; input[1] = (Npp8u*)src->data[1]; - VideoDecoder::DecodedOutput output; torch::Tensor& dst = output.frame; dst = allocateDeviceTensor({height, width, 3}, options.device); at::DeviceIndex deviceIndex = device.index(); @@ -128,7 +129,6 @@ VideoDecoder::DecodedOutput convertAVFrameToDecodedOutputOnDevice( // https://pytorch.org/docs/stable/generated/torch.permute.html dst = dst.permute({2, 0, 1}); } - return output; } } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h index 0aed72c4..7f20cf0b 100644 --- a/src/torchcodec/decoders/_core/DeviceInterface.h +++ b/src/torchcodec/decoders/_core/DeviceInterface.h @@ -32,10 +32,11 @@ void initializeDeviceContext( const torch::Device& device, AVCodecContext* codecContext); -VideoDecoder::DecodedOutput convertAVFrameToDecodedOutputOnDevice( +void convertAVFrameToDecodedOutputOnDevice( const torch::Device& device, const VideoDecoder::VideoStreamDecoderOptions& options, AVCodecContext* codecContext, - VideoDecoder::RawDecodedOutput& rawOutput); + VideoDecoder::RawDecodedOutput& rawOutput, + VideoDecoder::DecodedOutput& output); } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index b2941efe..f294f60f 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -857,12 +857,24 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( output.durationSeconds = ptsToSeconds( getDuration(frame), formatContext_->streams[streamIndex]->time_base); if (streamInfo.options.device.type() != torch::kCPU) { - return convertAVFrameToDecodedOutputOnDevice( + convertAVFrameToDecodedOutputOnDevice( streamInfo.options.device, streamInfo.options, streamInfo.codecContext.get(), - rawOutput); + rawOutput, + output); + } else { + convertAVFrameToDecodedOutputOnCPU(rawOutput, output); } + return output; +} + +void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( + VideoDecoder::RawDecodedOutput& rawOutput, + DecodedOutput& output) { + int streamIndex = rawOutput.streamIndex; + AVFrame* frame = rawOutput.frame.get(); + auto& streamInfo = streams_[streamIndex]; if (output.streamType == AVMEDIA_TYPE_VIDEO) { if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) { int width = streamInfo.options.width.value_or(frame->width); @@ -891,7 +903,6 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( // audio decoding. 
throw std::runtime_error("Audio is not supported yet."); } - return output; } VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestampNoDemux( diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index a3a1888b..41509adc 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -364,6 +364,9 @@ class VideoDecoder { const AVFrame* frame); void convertFrameToBufferUsingSwsScale(RawDecodedOutput& rawOutput); DecodedOutput convertAVFrameToDecodedOutput(RawDecodedOutput& rawOutput); + void convertAVFrameToDecodedOutputOnCPU( + RawDecodedOutput& rawOutput, + DecodedOutput& output); DecoderOptions options_; ContainerMetadata containerMetadata_; diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 8806392c..1bb28feb 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -465,13 +465,18 @@ def test_cuda_decoder(self): decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder) add_video_stream(decoder, device="cuda") - frame0, *_ = get_next_frame(decoder) + frame0, pts, duration = get_next_frame(decoder) assert frame0.device.type == "cuda" frame0_cpu = frame0.to("cpu") reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) - # We pass in atol of 60 because the CUDA decoder is not bit-accurate - # compared to the CPU decoder. - torch.testing.assert_close(frame0_cpu, reference_frame0, atol=60, rtol=0) + # GPU decode is not bit-accurate. In the following assertion we ensure + # not more than 0.3% of values have a difference greater than 20. + diff = (reference_frame0.float() - frame0_cpu.float()).abs() + assert (diff > 20).float().mean() <= 0.003 + assert pts == torch.tensor([0]) + torch.testing.assert_close( + duration, torch.tensor(0.0334).double(), atol=0, rtol=1e-3 + ) if __name__ == "__main__": From ec160a9712717f14f0d7b907626aa35ac426797f Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Tue, 8 Oct 2024 06:46:28 -0700 Subject: [PATCH 17/26] . --- src/torchcodec/decoders/_core/CudaDevice.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 7abaf886..715f9be8 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -16,12 +16,11 @@ namespace { AVBufferRef* getCudaContext(const torch::Device& device) { enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); - int err = 0; - AVBufferRef* hw_device_ctx; torch::DeviceIndex deviceIndex = device.index(); deviceIndex = std::max(deviceIndex, 0); std::string deviceOrdinal = std::to_string(deviceIndex); - err = av_hwdevice_ctx_create( + AVBufferRef* hw_device_ctx; + int err = av_hwdevice_ctx_create( &hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0); if (err < 0) { TORCH_CHECK( From f096a16576b2c83762760e9134b98c3718074265 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Tue, 8 Oct 2024 07:24:12 -0700 Subject: [PATCH 18/26] . 
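The change below replaces the raw cudaEvent_t/cudaSetDevice handling with the ATen wrappers: a CUDAGuard pins the requested GPU, and a CUDAEvent recorded on NPP's stream makes the current PyTorch stream wait without blocking the host. A sketch of just the synchronization piece (header names and the helper are assumptions, mirroring the calls used in the diff):

    #include <ATen/cuda/CUDAContext.h>
    #include <ATen/cuda/CUDAEvent.h>
    #include <c10/cuda/CUDAStream.h>
    #include <npp.h>

    // After launching an NPP kernel, gate the current PyTorch stream on it.
    void waitForNppSketch(c10::DeviceIndex deviceIndex) {
      at::cuda::CUDAEvent nppDone;
      at::cuda::CUDAStream nppStream =
          c10::cuda::getStreamFromExternal(nppGetStream(), deviceIndex);
      nppDone.record(nppStream);                        // marks NPP's work
      nppDone.block(at::cuda::getCurrentCUDAStream());  // torch waits on it
    }

Only the consuming stream waits; the host thread continues immediately, which keeps the decode loop from serializing on every frame.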
--- benchmarks/decoders/gpu_benchmark.py | 4 +- src/torchcodec/decoders/_core/CudaDevice.cpp | 41 +++++++------------- 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/benchmarks/decoders/gpu_benchmark.py b/benchmarks/decoders/gpu_benchmark.py index 44e9cea7..eef50765 100644 --- a/benchmarks/decoders/gpu_benchmark.py +++ b/benchmarks/decoders/gpu_benchmark.py @@ -102,7 +102,9 @@ def main(): parser.add_argument( "--video", type=str, - default=pathlib.Path(__file__).parent / "../../test/resources/nasa_13013.mp4", + default=str( + pathlib.Path(__file__).parent / "../../test/resources/nasa_13013.mp4" + ), ) parser.add_argument( "--use_torch_benchmark", diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 715f9be8..f8abb5a1 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -17,6 +18,7 @@ AVBufferRef* getCudaContext(const torch::Device& device) { enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); torch::DeviceIndex deviceIndex = device.index(); + // FFMPEG cannot handle negative device indices. deviceIndex = std::max(deviceIndex, 0); std::string deviceOrdinal = std::to_string(deviceIndex); AVBufferRef* hw_device_ctx; @@ -78,48 +80,33 @@ void convertAVFrameToDecodedOutputOnDevice( std::string(av_get_pix_fmt_name((AVPixelFormat)src->format))); int width = options.width.value_or(codecContext->width); int height = options.height.value_or(codecContext->height); - NppStatus status; - NppiSize oSizeROI; - oSizeROI.width = width; - oSizeROI.height = height; - Npp8u* input[2]; - input[0] = (Npp8u*)src->data[0]; - input[1] = (Npp8u*)src->data[1]; + NppiSize oSizeROI = {width, height}; + Npp8u* input[2] = {src->data[0], src->data[1]}; torch::Tensor& dst = output.frame; dst = allocateDeviceTensor({height, width, 3}, options.device); - at::DeviceIndex deviceIndex = device.index(); - deviceIndex = std::max(deviceIndex, 0); - at::DeviceIndex originalDeviceIndex = at::cuda::current_device(); - cudaSetDevice(deviceIndex); + + // Use the user-requested GPU for running the NPP kernel. + c10::cuda::CUDAGuard deviceGuard(device); auto start = std::chrono::high_resolution_clock::now(); - cudaStream_t nppStream = nppGetStream(); - cudaStream_t torchStream = at::cuda::getCurrentCUDAStream().stream(); - status = nppiNV12ToRGB_8u_P2C3R( + NppStatus status = nppiNV12ToRGB_8u_P2C3R( input, src->linesize[0], static_cast(dst.data_ptr()), dst.stride(0), oSizeROI); + TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); // Make the pytorch stream wait for the npp kernel to finish before using the // output. 
- cudaEvent_t nppDoneEvent; - cudaEventCreate(&nppDoneEvent); - cudaEventRecord(nppDoneEvent, nppStream); - cudaEvent_t torchDoneEvent; - cudaEventCreate(&torchDoneEvent); - cudaEventRecord(torchDoneEvent, torchStream); - cudaStreamWaitEvent(torchStream, nppDoneEvent, 0); - TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); - cudaEventDestroy(nppDoneEvent); - cudaEventDestroy(torchDoneEvent); + at::cuda::CUDAEvent nppDoneEvent; + at::cuda::CUDAStream nppStreamWrapper = + c10::cuda::getStreamFromExternal(nppGetStream(), device.index()); + nppDoneEvent.record(nppStreamWrapper); + nppDoneEvent.block(at::cuda::getCurrentCUDAStream()); auto end = std::chrono::high_resolution_clock::now(); - // Restore the original device_index. - cudaSetDevice(originalDeviceIndex); - std::chrono::duration duration = end - start; VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width << " took: " << duration.count() << "us" << std::endl; From 0c78564594b8191c915df4d0c128567af862bcbe Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Wed, 9 Oct 2024 07:45:23 -0700 Subject: [PATCH 19/26] . --- src/torchcodec/decoders/_core/CPUOnlyDevice.cpp | 4 ++-- src/torchcodec/decoders/_core/CudaDevice.cpp | 4 ++-- src/torchcodec/decoders/_core/DeviceInterface.h | 4 ++-- src/torchcodec/decoders/_core/VideoDecoder.cpp | 17 ++++++++++++----- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index 65082e6f..404d8750 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -14,7 +14,7 @@ namespace facebook::torchcodec { TORCH_CHECK(false, "Unsupported device: " + device.str()); } -void convertAVFrameToDecodedOutputOnDevice( +void convertAVFrameToDecodedOutputOnCuda( const torch::Device& device, const VideoDecoder::VideoStreamDecoderOptions& options, AVCodecContext* codecContext, @@ -23,7 +23,7 @@ void convertAVFrameToDecodedOutputOnDevice( throwUnsupportedDeviceError(device); } -void initializeDeviceContext( +void initializeContextOnCuda( const torch::Device& device, AVCodecContext* codecContext) { throwUnsupportedDeviceError(device); diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index f8abb5a1..a3e5af3d 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -55,7 +55,7 @@ void throwErrorIfNonCudaDevice(const torch::Device& device) { } } // namespace -void initializeDeviceContext( +void initializeContextOnCuda( const torch::Device& device, AVCodecContext* codecContext) { throwErrorIfNonCudaDevice(device); @@ -66,7 +66,7 @@ void initializeDeviceContext( return; } -void convertAVFrameToDecodedOutputOnDevice( +void convertAVFrameToDecodedOutputOnCuda( const torch::Device& device, const VideoDecoder::VideoStreamDecoderOptions& options, AVCodecContext* codecContext, diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h index 7f20cf0b..3ef428fd 100644 --- a/src/torchcodec/decoders/_core/DeviceInterface.h +++ b/src/torchcodec/decoders/_core/DeviceInterface.h @@ -28,11 +28,11 @@ namespace facebook::torchcodec { // Initialize the hardware device that is specified in `device`. Some builds // support CUDA and others only support CPU. 
-void initializeDeviceContext( +void initializeContextOnCuda( const torch::Device& device, AVCodecContext* codecContext); -void convertAVFrameToDecodedOutputOnDevice( +void convertAVFrameToDecodedOutputOnCuda( const torch::Device& device, const VideoDecoder::VideoStreamDecoderOptions& options, AVCodecContext* codecContext, diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index f294f60f..c16009cd 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -428,8 +428,12 @@ void VideoDecoder::addVideoStreamDecoder( streamInfo.codecContext.reset(codecContext); int retVal = avcodec_parameters_to_context( streamInfo.codecContext.get(), streamInfo.stream->codecpar); - if (options.device.type() != torch::kCPU) { - initializeDeviceContext(options.device, codecContext); + if (options.device.type() == torch::kCPU) { + // No more initialization needed for CPU. + } else if (options.device.type() == torch::kCUDA) { + initializeContextOnCuda(options.device, codecContext); + } else { + throw std::invalid_argument("Invalid device type: " + options.device.str()); } TORCH_CHECK_EQ(retVal, AVSUCCESS); retVal = avcodec_open2(streamInfo.codecContext.get(), codec, nullptr); @@ -856,15 +860,18 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( output.duration = getDuration(frame); output.durationSeconds = ptsToSeconds( getDuration(frame), formatContext_->streams[streamIndex]->time_base); - if (streamInfo.options.device.type() != torch::kCPU) { - convertAVFrameToDecodedOutputOnDevice( + if (streamInfo.options.device.type() == torch::kCPU) { + convertAVFrameToDecodedOutputOnCPU(rawOutput, output); + } else if (streamInfo.options.device.type() == torch::kCUDA) { + convertAVFrameToDecodedOutputOnCuda( streamInfo.options.device, streamInfo.options, streamInfo.codecContext.get(), rawOutput, output); } else { - convertAVFrameToDecodedOutputOnCPU(rawOutput, output); + throw std::invalid_argument( + "Invalid device type: " + streamInfo.options.device.str()); } return output; } From 32cdb3758a8d5f4f255e57a1e920990958595c61 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Wed, 9 Oct 2024 07:50:37 -0700 Subject: [PATCH 20/26] . --- src/torchcodec/decoders/_core/CudaDevice.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index a3e5af3d..9700d419 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -19,6 +19,9 @@ AVBufferRef* getCudaContext(const torch::Device& device) { TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); torch::DeviceIndex deviceIndex = device.index(); // FFMPEG cannot handle negative device indices. + // For single GPU- machines libtorch returns -1 for the device index. So for + // that case we set the device index to 0. + // TODO: Double check if this works for multi-GPU machines correctly. deviceIndex = std::max(deviceIndex, 0); std::string deviceOrdinal = std::to_string(deviceIndex); AVBufferRef* hw_device_ctx; From 27bb2b29112ccfea91b304278958aa936597551c Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Wed, 9 Oct 2024 08:01:26 -0700 Subject: [PATCH 21/26] . 
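The one-line change below drops the extra av_buffer_ref(): av_hwdevice_ctx_create() already returns a buffer reference that the caller owns, so assigning it straight to hw_device_ctx transfers ownership to the codec context, while wrapping it in av_buffer_ref() left a second reference that was never unreferenced. A small sketch of the ownership rule (illustrative, not the torchcodec helper):

    extern "C" {
    #include <libavcodec/avcodec.h>
    #include <libavutil/hwcontext.h>
    }

    // av_hwdevice_ctx_create() hands back one owned reference. Either give
    // that reference away exactly once, or ref/unref it symmetrically.
    int attachOwnedContextSketch(AVCodecContext* codecContext) {
      AVBufferRef* ctx = nullptr;
      int err = av_hwdevice_ctx_create(
          &ctx, AV_HWDEVICE_TYPE_CUDA, "0", nullptr, 0);
      if (err < 0) {
        return err;
      }
      codecContext->hw_device_ctx = ctx;  // ownership moves; no extra ref
      return 0;
    }

avcodec_free_context() releases hw_device_ctx along with the rest of the codec context, so no separate unref is needed on the normal teardown path.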
--- src/torchcodec/decoders/_core/CudaDevice.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 9700d419..d2773ffd 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -65,7 +65,7 @@ void initializeContextOnCuda( // This is a dummy tensor to initialize the cuda context. torch::Tensor dummyTensorForCudaInitialization = torch::empty( {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device)); - codecContext->hw_device_ctx = av_buffer_ref(getCudaContext(device)); + codecContext->hw_device_ctx = getCudaContext(device); return; } From e65ddc35258c0d6673c2fe3751177ac31db1f771 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Wed, 9 Oct 2024 09:27:27 -0700 Subject: [PATCH 22/26] . --- src/torchcodec/decoders/_core/CudaDevice.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index d2773ffd..58234922 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -62,6 +62,8 @@ void initializeContextOnCuda( const torch::Device& device, AVCodecContext* codecContext) { throwErrorIfNonCudaDevice(device); + // It is important for pytorch itself to create the cuda context. If ffmpeg + // creates the context it may not be compatible with pytorch. // This is a dummy tensor to initialize the cuda context. torch::Tensor dummyTensorForCudaInitialization = torch::empty( {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device)); From 0abc173a85dff036d1a488a404bc932223822e1b Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Wed, 9 Oct 2024 11:58:27 -0700 Subject: [PATCH 23/26] . --- benchmarks/decoders/gpu_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/decoders/gpu_benchmark.py b/benchmarks/decoders/gpu_benchmark.py index eef50765..8655f889 100644 --- a/benchmarks/decoders/gpu_benchmark.py +++ b/benchmarks/decoders/gpu_benchmark.py @@ -77,7 +77,7 @@ def decode_videos_using_threads( executor = ThreadPoolExecutor(max_workers=num_threads) for i in range(num_videos): actual_decode_device = decode_device_string - if "cuda" in decode_device_string: + if "cuda" in decode_device_string and use_multiple_gpus: actual_decode_device = f"cuda:{i % torch.cuda.device_count()}" executor.submit( decode_full_video, video_path, actual_decode_device, resize_device_string From 21d8c1a7bac169ff3cb7dd3df841143700a5a06d Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Wed, 9 Oct 2024 12:35:20 -0700 Subject: [PATCH 24/26] . 
--- src/torchcodec/decoders/_core/CMakeLists.txt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index d3d2c202..edf65696 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -34,9 +34,11 @@ function(make_torchcodec_library library_name ffmpeg_target) ${Python3_INCLUDE_DIRS} ) - set(NEEDED_LIBRARIES ${ffmpeg_target} ${TORCH_LIBRARIES} ${Python3_LIBRARIES}) + set(NEEDED_LIBRARIES ${ffmpeg_target} ${TORCH_LIBRARIES} \ + ${Python3_LIBRARIES}) if(ENABLE_CUDA) - list(APPEND NEEDED_LIBRARIES ${CUDA_CUDA_LIBRARY} ${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY} ) + list(APPEND NEEDED_LIBRARIES ${CUDA_CUDA_LIBRARY} \ + ${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY} ) endif() target_link_libraries( ${library_name} From e4e02b38e66b6b0a67a1d3772abcb14ad2b433b8 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Wed, 9 Oct 2024 12:35:59 -0700 Subject: [PATCH 25/26] . --- src/torchcodec/decoders/_core/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index edf65696..7ba2fb76 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -38,7 +38,7 @@ function(make_torchcodec_library library_name ffmpeg_target) ${Python3_LIBRARIES}) if(ENABLE_CUDA) list(APPEND NEEDED_LIBRARIES ${CUDA_CUDA_LIBRARY} \ - ${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY} ) + ${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY} ) endif() target_link_libraries( ${library_name} From 4624f5b2808817924eef9056630b50205d013140 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Wed, 9 Oct 2024 12:40:50 -0700 Subject: [PATCH 26/26] . --- src/torchcodec/decoders/_core/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index 7ba2fb76..2527c217 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -34,10 +34,10 @@ function(make_torchcodec_library library_name ffmpeg_target) ${Python3_INCLUDE_DIRS} ) - set(NEEDED_LIBRARIES ${ffmpeg_target} ${TORCH_LIBRARIES} \ + set(NEEDED_LIBRARIES ${ffmpeg_target} ${TORCH_LIBRARIES} ${Python3_LIBRARIES}) if(ENABLE_CUDA) - list(APPEND NEEDED_LIBRARIES ${CUDA_CUDA_LIBRARY} \ + list(APPEND NEEDED_LIBRARIES ${CUDA_CUDA_LIBRARY} ${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY} ) endif() target_link_libraries(