Add CUDA decoding support #242

Open
wants to merge 27 commits into main
204 changes: 204 additions & 0 deletions benchmarks/decoders/gpu_benchmark.py
@@ -0,0 +1,204 @@
import argparse
import os
import pathlib
import time
from concurrent.futures import ThreadPoolExecutor

import torch

import torch.utils.benchmark as benchmark

import torchcodec
import torchvision.transforms.v2.functional as F

RESIZED_WIDTH = 256
RESIZED_HEIGHT = 256


def transfer_and_resize_frame(frame, resize_device_string):
# This should be a no-op if the frame is already on the target device.
frame = frame.to(resize_device_string)
frame = F.resize(frame, (RESIZED_HEIGHT, RESIZED_WIDTH))
return frame


def decode_full_video(video_path, decode_device_string, resize_device_string):
# We use the core API instead of SimpleVideoDecoder because the core API
# allows us to natively resize as part of the decode step.
print(f"{decode_device_string=} {resize_device_string=}")
decoder = torchcodec.decoders._core.create_from_file(video_path)
num_threads = None
if "cuda" in decode_device_string:
num_threads = 1
width = None
height = None
if "native" in resize_device_string:
width = RESIZED_WIDTH
height = RESIZED_HEIGHT
torchcodec.decoders._core._add_video_stream(
decoder,
stream_index=-1,
device=decode_device_string,
num_threads=num_threads,
width=width,
height=height,
)

start_time = time.time()
frame_count = 0
while True:
try:
frame, *_ = torchcodec.decoders._core.get_next_frame(decoder)
if resize_device_string != "none" and "native" not in resize_device_string:
frame = transfer_and_resize_frame(frame, resize_device_string)

frame_count += 1
except Exception as e:
print("EXCEPTION", e)
break

end_time = time.time()
elapsed = end_time - start_time
fps = frame_count / (end_time - start_time)
print(
f"****** DECODED full video {decode_device_string=} {frame_count=} {elapsed=} {fps=}"
)
return frame_count, end_time - start_time


def decode_videos_using_threads(
video_path,
decode_device_string,
resize_device_string,
num_videos,
num_threads,
use_multiple_gpus,
):
executor = ThreadPoolExecutor(max_workers=num_threads)
for i in range(num_videos):
actual_decode_device = decode_device_string
if "cuda" in decode_device_string:
actual_decode_device = f"cuda:{i % torch.cuda.device_count()}"
executor.submit(
decode_full_video, video_path, actual_decode_device, resize_device_string
)
executor.shutdown(wait=True)


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--devices",
default="cuda:0,cpu",
type=str,
help="Comma-separated devices to test decoding on.",
)
parser.add_argument(
"--resize_devices",
default="cuda:0,cpu,native,none",
type=str,
help="Comma-separated devices to test preroc (resize) on. Use 'none' to specify no resize.",
)
parser.add_argument(
"--video",
type=str,
default=pathlib.Path(__file__).parent / "../../test/resources/nasa_13013.mp4",
)
parser.add_argument(
"--use_torch_benchmark",
action=argparse.BooleanOptionalAction,
default=True,
help=(
"Use pytorch benchmark to measure decode time with warmup and "
"autorange. Without this we just run one iteration without warmup "
"to measure the cold start time."
),
)
parser.add_argument(
"--num_threads",
type=int,
default=1,
help="Number of threads to use for decoding. Only used when --use_torch_benchmark is set.",
)
parser.add_argument(
"--num_videos",
type=int,
default=50,
help="Number of videos to decode in parallel. Only used when --num_threads is set.",
)
parser.add_argument(
"--use_multiple_gpus",
action=argparse.BooleanOptionalAction,
default=True,
help=("Use multiple GPUs to decode multiple videos in multi-threaded mode."),
)
args = parser.parse_args()
video_path = args.video

if not args.use_torch_benchmark:
for device in args.devices.split(","):
print("Testing on", device)
            decode_full_video(video_path, device, "none")
return

resize_devices = args.resize_devices.split(",")
resize_devices = [d for d in resize_devices if d != ""]
if len(resize_devices) == 0:
resize_devices.append("none")

label = "Decode+Resize Time"

results = []
for decode_device_string in args.devices.split(","):
for resize_device_string in resize_devices:
decode_label = decode_device_string
if "cuda" in decode_label:
# Shorten "cuda:0" to "cuda"
decode_label = "cuda"
resize_label = resize_device_string
if "cuda" in resize_device_string:
# Shorten "cuda:0" to "cuda"
resize_label = "cuda"
print("decode_device", decode_device_string)
print("resize_device", resize_device_string)
if args.num_threads > 1:
t = benchmark.Timer(
stmt="decode_videos_using_threads(video_path, decode_device_string, resize_device_string, num_videos, num_threads, use_multiple_gpus)",
globals={
"decode_device_string": decode_device_string,
"video_path": video_path,
"decode_full_video": decode_full_video,
"decode_videos_using_threads": decode_videos_using_threads,
"resize_device_string": resize_device_string,
"num_videos": args.num_videos,
"num_threads": args.num_threads,
"use_multiple_gpus": args.use_multiple_gpus,
},
label=label,
description=f"threads={args.num_threads} work={args.num_videos} video={os.path.basename(video_path)}",
sub_label=f"D={decode_label} R={resize_label} T={args.num_threads} W={args.num_videos}",
).blocked_autorange()
results.append(t)
else:
t = benchmark.Timer(
stmt="decode_full_video(video_path, decode_device_string, resize_device_string)",
globals={
"decode_device_string": decode_device_string,
"video_path": video_path,
"decode_full_video": decode_full_video,
"resize_device_string": resize_device_string,
},
label=label,
description=f"video={os.path.basename(video_path)}",
sub_label=f"D={decode_label} R={resize_label}",
).blocked_autorange()
results.append(t)
compare = benchmark.Compare(results)
compare.print()
print("Key: D=Decode, R=Resize T=threads W=work (number of videos to decode)")
print("Native resize is done as part of the decode step")
print("none resize means there is no resize step -- native or otherwise")


if __name__ == "__main__":
main()
8 changes: 5 additions & 3 deletions src/torchcodec/decoders/_core/CMakeLists.txt
@@ -34,12 +34,14 @@ function(make_torchcodec_library library_name ffmpeg_target)
${Python3_INCLUDE_DIRS}
)

set(NEEDED_LIBRARIES ${ffmpeg_target} ${TORCH_LIBRARIES} ${Python3_LIBRARIES})
if(ENABLE_CUDA)
list(APPEND NEEDED_LIBRARIES ${CUDA_CUDA_LIBRARY} ${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY} )
endif()
target_link_libraries(
${library_name}
PUBLIC
${ffmpeg_target}
${TORCH_LIBRARIES}
${Python3_LIBRARIES}
${NEEDED_LIBRARIES}
)

# We already set the library_name to be libtorchcodecN, so we don't want
18 changes: 15 additions & 3 deletions src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
@@ -1,19 +1,31 @@
#include <torch/types.h>
#include "src/torchcodec/decoders/_core/DeviceInterface.h"

namespace facebook::torchcodec {

// This file is linked with the CPU-only version of torchcodec.
// So all functions will throw an error because they should only be called if
// the device is not CPU.

void throwUnsupportedDeviceError(const torch::Device& device) {
[[noreturn]] void throwUnsupportedDeviceError(const torch::Device& device) {
TORCH_CHECK(
device.type() != torch::kCPU,
"Device functions should only be called if the device is not CPU.")
throw std::runtime_error("Unsupported device: " + device.str());
TORCH_CHECK(false, "Unsupported device: " + device.str());
}

void initializeDeviceContext(const torch::Device& device) {
void convertAVFrameToDecodedOutputOnDevice(
const torch::Device& device,
const VideoDecoder::VideoStreamDecoderOptions& options,
AVCodecContext* codecContext,
VideoDecoder::RawDecodedOutput& rawOutput,
VideoDecoder::DecodedOutput& output) {
throwUnsupportedDeviceError(device);
Contributor:
Should this function and the function below always throw? If yes, then we should just do something like TORCH_CHECK(false, "Unsupported device.");. In order to avoid the need for a return value, mark the function as [[noreturn]]: https://en.cppreference.com/w/cpp/language/attributes/noreturn. We should rely on a TORCH macro to do the throwing for us rather than doing the throw ourselves, and we should make it obviously one that will always fail its check.

Contributor Author:
Good suggestion. Done.

Contributor:
I think we maybe should also annotate convertAVFrameToDecodedOutputOnDevice() and initializeDeviceContext() with [[noreturn]]. Let's also avoid two TORCH_CHECK calls. Whatever message we want to put on stderr, we can do it in one check.

Contributor Author:
The two checks are there because one catches a programming/logic error on our part -- we should never pass in a CPU device to device functions.

The other check catches passing in a device that this build was not compiled for.
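A minimal sketch of the pattern being discussed, assuming libtorch's TORCH_CHECK macro (a standalone illustration, not the PR's exact code):

#include <torch/types.h>

// [[noreturn]] tells the compiler (and callers) that this function never
// returns normally, so call sites don't need a dummy return value.
[[noreturn]] void throwUnsupportedDeviceError(const torch::Device& device) {
  // First check: passing a CPU device here is a programming/logic error on
  // our part -- device functions should never be called with a CPU device.
  TORCH_CHECK(
      device.type() != torch::kCPU,
      "Device functions should only be called if the device is not CPU.");
  // Second check: the device is valid but this build was not compiled for it.
  // TORCH_CHECK(false, ...) always fails, so the macro does the throwing.
  TORCH_CHECK(false, "Unsupported device: ", device.str());
}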

}

void initializeDeviceContext(
const torch::Device& device,
AVCodecContext* codecContext) {
throwUnsupportedDeviceError(device);
}

121 changes: 116 additions & 5 deletions src/torchcodec/decoders/_core/CudaDevice.cpp
@@ -1,6 +1,47 @@
#include <c10/cuda/CUDAStream.h>
#include <npp.h>
#include <torch/types.h>
#include "src/torchcodec/decoders/_core/DeviceInterface.h"
#include "src/torchcodec/decoders/_core/FFMPEGCommon.h"
#include "src/torchcodec/decoders/_core/VideoDecoder.h"

extern "C" {
#include <libavcodec/avcodec.h>
#include <libavutil/hwcontext_cuda.h>
#include <libavutil/pixdesc.h>
}

namespace facebook::torchcodec {
namespace {
AVBufferRef* getCudaContext(const torch::Device& device) {
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
torch::DeviceIndex deviceIndex = device.index();
deviceIndex = std::max<at::DeviceIndex>(deviceIndex, 0);
std::string deviceOrdinal = std::to_string(deviceIndex);
AVBufferRef* hw_device_ctx;
int err = av_hwdevice_ctx_create(
&hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0);
if (err < 0) {
TORCH_CHECK(
false,
"Failed to create specified HW device",
getFFMPEGErrorStringFromErrorCode(err));
}
return hw_device_ctx;
}

torch::Tensor allocateDeviceTensor(
at::IntArrayRef shape,
torch::Device device,
const torch::Dtype dtype = torch::kUInt8) {
return torch::empty(
shape,
torch::TensorOptions()
.dtype(dtype)
.layout(torch::kStrided)
.device(device));
}

void throwErrorIfNonCudaDevice(const torch::Device& device) {
TORCH_CHECK(
@@ -10,13 +51,83 @@ void throwErrorIfNonCudaDevice(const torch::Device& device) {
throw std::runtime_error("Unsupported device: " + device.str());
}
}
} // namespace

void initializeDeviceContext(const torch::Device& device) {
void initializeDeviceContext(
const torch::Device& device,
AVCodecContext* codecContext) {
throwErrorIfNonCudaDevice(device);
// TODO: https://github.com/pytorch/torchcodec/issues/238: Implement CUDA
// device.
throw std::runtime_error(
"CUDA device is unimplemented. Follow this issue for tracking progress: https://github.com/pytorch/torchcodec/issues/238");
// This is a dummy tensor to initialize the cuda context.
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device));
codecContext->hw_device_ctx = av_buffer_ref(getCudaContext(device));
return;
}

void convertAVFrameToDecodedOutputOnDevice(
const torch::Device& device,
const VideoDecoder::VideoStreamDecoderOptions& options,
AVCodecContext* codecContext,
VideoDecoder::RawDecodedOutput& rawOutput,
VideoDecoder::DecodedOutput& output) {
AVFrame* src = rawOutput.frame.get();

TORCH_CHECK(
src->format == AV_PIX_FMT_CUDA,
"Expected format to be AV_PIX_FMT_CUDA, got " +
std::string(av_get_pix_fmt_name((AVPixelFormat)src->format)));
int width = options.width.value_or(codecContext->width);
int height = options.height.value_or(codecContext->height);
NppStatus status;
NppiSize oSizeROI;
oSizeROI.width = width;
oSizeROI.height = height;
Npp8u* input[2];
input[0] = (Npp8u*)src->data[0];
input[1] = (Npp8u*)src->data[1];
torch::Tensor& dst = output.frame;
dst = allocateDeviceTensor({height, width, 3}, options.device);
at::DeviceIndex deviceIndex = device.index();
deviceIndex = std::max<at::DeviceIndex>(deviceIndex, 0);
Contributor:
Two things:

  1. This is the second place we're doing this same logic. We should abstract it into a function, even though it's small. The function name will probably help with my second point.
  2. I'm not quite sure why we're doing it. Under what circumstance will ATen's reported index for a device be less than 0? It looks like it defaults to -1 in some cases (https://pytorch.org/cppdocs/api/structc10_1_1_device.html#_CPPv4N3c106Device6DeviceE10DeviceType11DeviceIndex), but wouldn't that be an error for us? Notably, this logic will make any value less than 0 be 0, which means maybe we could map multiple devices to 0. I don't think we should ever see such values, but it's confusing to me that our code makes it possible.

Contributor Author:
ffmpeg doesn't accept negative values for the device index. I added a comment to that effect.

Contributor Author:
For a single GPU, libtorch returns -1 while ffmpeg assumes it will be 0, so we have to bridge that gap.

In a multi-GPU setup, I haven't seen -1 being returned by torch -- so there we won't have to do a max.

The -1 seems to be a libtorch-specific thing.
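A small sketch of the helper the reviewer is asking for; the name getFFMPEGCompatibleDeviceIndex is hypothetical and not part of this PR:

#include <algorithm>
#include <torch/types.h>

// libtorch can report -1 as the index of the default CUDA device, but FFmpeg
// expects a non-negative ordinal, so clamp negative indices to 0.
at::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) {
  return std::max<at::DeviceIndex>(device.index(), 0);
}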

at::DeviceIndex originalDeviceIndex = at::cuda::current_device();
cudaSetDevice(deviceIndex);
Contributor:
I can see that later on line 121 we restore the device index. Why? Can we explain why we need to set and then restore?

Contributor Author:
I ended up using CUDADeviceGuard. Callers assume we are not interfering with the current CUDA device in this function.
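A minimal sketch of the RAII set-and-restore behavior being described, using c10::cuda::CUDAGuard (illustrative only; the wrapper function below is a placeholder, and the exact guard type used in the PR may differ):

#include <c10/cuda/CUDAGuard.h>

// Hypothetical wrapper showing the guard's effect: it switches the current
// CUDA device for the duration of this scope and restores the caller's
// device when it goes out of scope, so callers never observe a change.
void runConversionOnDevice(c10::DeviceIndex deviceIndex) {
  c10::cuda::CUDAGuard deviceGuard(deviceIndex);
  // ... launch NPP / CUDA work targeting deviceIndex here ...
}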


auto start = std::chrono::high_resolution_clock::now();

cudaStream_t nppStream = nppGetStream();
cudaStream_t torchStream = at::cuda::getCurrentCUDAStream().stream();
status = nppiNV12ToRGB_8u_P2C3R(
input,
src->linesize[0],
static_cast<Npp8u*>(dst.data_ptr()),
dst.stride(0),
oSizeROI);
// Make the pytorch stream wait for the npp kernel to finish before using the
// output.
cudaEvent_t nppDoneEvent;
cudaEventCreate(&nppDoneEvent);
cudaEventRecord(nppDoneEvent, nppStream);
cudaEvent_t torchDoneEvent;
cudaEventCreate(&torchDoneEvent);
cudaEventRecord(torchDoneEvent, torchStream);
cudaStreamWaitEvent(torchStream, nppDoneEvent, 0);
TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");
cudaEventDestroy(nppDoneEvent);
cudaEventDestroy(torchDoneEvent);

auto end = std::chrono::high_resolution_clock::now();

// Restore the original device_index.
cudaSetDevice(originalDeviceIndex);

std::chrono::duration<double, std::micro> duration = end - start;
VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width
<< " took: " << duration.count() << "us" << std::endl;
if (options.dimensionOrder == "NCHW") {
// The docs guarantee this to return a view:
// https://pytorch.org/docs/stable/generated/torch.permute.html
dst = dst.permute({2, 0, 1});
}
}

} // namespace facebook::torchcodec