Skip to content

Commit

Permalink
[torchcodec] Add support for Nvidia GPU Decoding (#58)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #58

1. Add CUDA support to VideoDecoder.cpp. This is done by checking what device is passed into the options and using CUDA if the device type is cuda.
2. Add -DENABLE_CUDA flag in cmake.
3. Check ENABLE_CUDA environment variable in setup.py and pass it down to cmake if it is present.
4. Add a unit test to demonstrate that CUDA decoding works. This uses a different reference tensor than the one used for CPU decoding because hardware decoding is intrinsically slightly inaccurate. I generated the reference tensor by dumping the tensor from the GPU on my devVM. It is possible that different Nvidia hardware produces different outputs. How to test this in a more robust way is TBD.
5. Added a new `device_string` parameter to `add_video_stream`. If it is present and names a CUDA device, we will use it to do hardware decoding on that CUDA device.

Differential Revision: D59121006
  • Loading branch information
ahmadsharif1 authored and facebook-github-bot committed Jul 1, 2024
1 parent faf73b0 commit cc7223e
Show file tree
Hide file tree
Showing 15 changed files with 221 additions and 19 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
cmake_minimum_required(VERSION 3.18)
project(TorchCodec)

option(ENABLE_CUDA "Enable CUDA decoding using NVDEC" OFF)

add_subdirectory(src/torchcodec/decoders/_core)


Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,13 @@ def _build_all_extensions_with_cmake(self):
# python setup.py build_ext --debug install
torch_dir = Path(torch.utils.cmake_prefix_path) / "Torch"
cmake_build_type = os.environ.get("CMAKE_BUILD_TYPE", "Release")
enable_cuda = os.environ.get("ENABLE_CUDA", "")
cmake_args = [
f"-DCMAKE_INSTALL_PREFIX={self._install_prefix}",
f"-DTorch_DIR={torch_dir}",
"-DCMAKE_VERBOSE_MAKEFILE=ON",
f"-DCMAKE_BUILD_TYPE={cmake_build_type}",
f"-DENABLE_CUDA={enable_cuda}",
]

Path(self.build_temp).mkdir(parents=True, exist_ok=True)
Expand Down
3 changes: 3 additions & 0 deletions src/torchcodec/decoders/_core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ function(make_torchcodec_library library_name ffmpeg_target)
)
add_library(${library_name} SHARED ${sources})
set_property(TARGET ${library_name} PROPERTY CXX_STANDARD 17)
if(ENABLE_CUDA)
target_compile_definitions(${library_name} PRIVATE ENABLE_CUDA=1)
endif()

target_include_directories(
${library_name}
Expand Down
2 changes: 2 additions & 0 deletions src/torchcodec/decoders/_core/FFMPEGCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ using UniqueAVFilterInOut = std::unique_ptr<
Deleterp<AVFilterInOut, void, avfilter_inout_free>>;
using UniqueAVIOContext = std::
unique_ptr<AVIOContext, Deleterp<AVIOContext, void, avio_context_free>>;
// Owning smart pointer for an AVBufferRef. Deleterp is parameterized with
// av_buffer_unref as the release function, in the same way the neighbouring
// aliases are parameterized with their own FFmpeg free functions.
// NOTE(review): presumably Deleterp hands the function the address of the
// pointer, matching av_buffer_unref(AVBufferRef**) -- confirm against the
// Deleterp definition.
using UniqueAVBufferRef =
std::unique_ptr<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>;

// av_find_best_stream is not const-correct before commit:
// https://github.com/FFmpeg/FFmpeg/commit/46dac8cf3d250184ab4247809bc03f60e14f4c0c
Expand Down
127 changes: 120 additions & 7 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include "src/torchcodec/decoders/_core/VideoDecoder.h"
#include <torch/torch.h>
#include <torch/types.h>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <stdexcept>
#include <string_view>
#include "torch/types.h"

#ifdef ENABLE_CUDA
#include <c10/cuda/CUDAStream.h>
#include <npp.h>
#endif

extern "C" {
#include <libavcodec/avcodec.h>
Expand Down Expand Up @@ -84,6 +91,47 @@ std::vector<std::string> splitStringWithDelimiters(
return result;
}

#ifdef ENABLE_CUDA
// Cache of FFmpeg CUDA hardware device contexts, keyed by CUDA device index.
// The cache holds the owning reference for each context; nothing in this
// block ever removes an entry. Every access must hold gCudaContextCacheMutex.
static std::mutex gCudaContextCacheMutex;
static std::map<int, UniqueAVBufferRef> gCudaContextCache;

// Returns the FFmpeg CUDA device context for `index`, creating and caching it
// on first use.
//
// An index of -1 is treated as device 0.
// NOTE(review): torch conventionally uses -1 for "current device"; mapping it
// to 0 is kept from the original code -- confirm that this is intended.
//
// The returned pointer is borrowed from the cache. Callers that need to keep
// the context must take their own reference with av_buffer_ref().
//
// Throws (via TORCH_CHECK) when FFmpeg fails to create the device context.
AVBufferRef* getCudaContextForDeviceIndex(int index) {
  std::lock_guard<std::mutex> lock(gCudaContextCacheMutex);
  if (index == -1) {
    index = 0;
  }

  // Single lookup: reuse the iterator instead of count() followed by at().
  auto cached = gCudaContextCache.find(index);
  if (cached != gCudaContextCache.end()) {
    return cached->second.get();
  }

  AVBufferRef* created = nullptr;
  const int ret = av_hwdevice_ctx_create(
      &created,
      AV_HWDEVICE_TYPE_CUDA,
      std::to_string(index).c_str(),
      nullptr,
      0);
  TORCH_CHECK(
      ret >= 0,
      "Failed to create CUDA device context on device ",
      index,
      " (",
      getFFMPEGErrorStringFromErrorCode(ret),
      ")");
  // Checked unconditionally: a plain assert() disappears in release builds and
  // a null entry would later be handed to av_buffer_ref().
  TORCH_CHECK(
      created != nullptr,
      "av_hwdevice_ctx_create reported success but returned no context for device ",
      index);

  // Wrap the raw pointer in its owner before inserting so the context is
  // released if the map insertion throws.
  auto inserted = gCudaContextCache.emplace(index, UniqueAVBufferRef(created));
  return inserted.first->second.get();
}

// Allocates an uninitialized, strided tensor with the given shape on `device`.
// The element type defaults to uint8.
torch::Tensor getCudaImageBuffer(
    at::IntArrayRef shape,
    torch::Device device,
    const torch::Dtype dtype = torch::kUInt8) {
  const torch::TensorOptions bufferOptions = torch::TensorOptions()
                                                 .device(device)
                                                 .layout(torch::kStrided)
                                                 .dtype(dtype);
  torch::Tensor buffer = torch::empty(shape, bufferOptions);
  return buffer;
}
#endif

} // namespace

VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions(
Expand Down Expand Up @@ -299,13 +347,13 @@ void VideoDecoder::initializeFilterGraphForStream(
inputs.reset(inputsTmp);
if (ffmpegStatus < 0) {
throw std::runtime_error(
"Failed to parse filter description: " +
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
"Failed to parse filter description: " + std::string(description) +
"; " + getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
}
ffmpegStatus = avfilter_graph_config(filterState.filterGraph.get(), nullptr);
if (ffmpegStatus < 0) {
throw std::runtime_error(
"Failed to configure filter graph: " +
"Failed to configure filter graph: " + std::string(description) + "; " +
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
}
}
Expand Down Expand Up @@ -354,15 +402,34 @@ void VideoDecoder::addVideoStreamDecoder(
int retVal = avcodec_parameters_to_context(
streamInfo.codecContext.get(), streamInfo.stream->codecpar);
TORCH_CHECK_EQ(retVal, AVSUCCESS);

if (options.device.type() == torch::DeviceType::CUDA) {
#ifdef ENABLE_CUDA
codecContext->hw_device_ctx =
av_buffer_ref(getCudaContextForDeviceIndex(options.device.index()));

TORCH_INTERNAL_ASSERT(
codecContext->hw_device_ctx,
"Failed to create/reference the CUDA HW device context for index=" +
std::to_string(options.device.index()) + ".");
#else
throw std::runtime_error(
"CUDA support is not enabled in this build of TorchCodec.");
#endif
}

retVal = avcodec_open2(streamInfo.codecContext.get(), codec, nullptr);
if (retVal < AVSUCCESS) {
throw std::invalid_argument(getFFMPEGErrorStringFromErrorCode(retVal));
}

codecContext->time_base = streamInfo.stream->time_base;
activeStreamIndices_.insert(streamNumber);
updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext);
streamInfo.options = options;
initializeFilterGraphForStream(streamNumber, options);
if (options.device.is_cpu()) {
initializeFilterGraphForStream(streamNumber, options);
}
}

void VideoDecoder::updateMetadataWithCodecContext(
Expand Down Expand Up @@ -584,6 +651,40 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
}
}

#ifdef ENABLE_CUDA
// Converts a decoded CUDA hardware frame to an RGB uint8 tensor of shape
// {height, width, 3}, allocated on the device configured for the stream.
//
// streamIndex: key into streams_ for the stream that produced `src`.
// src: decoded frame; must be an AV_PIX_FMT_CUDA frame whose data[0]/data[1]
//      planes are passed to NPP's NV12 -> RGB conversion.
//
// Throws (via TORCH_CHECK) if the frame is not a CUDA frame or if the NPP
// conversion reports an error.
//
// NOTE(review): the plane layout is assumed to be NV12; the frame's software
// format is not verified here -- confirm for 10-bit content.
// NOTE(review): the NPP call does not select a CUDA device or stream
// explicitly; presumably it runs on the current device and NPP's default
// stream -- confirm this matches streamInfo.options.device.
torch::Tensor VideoDecoder::convertFrameToTensorUsingCUDA(
    int streamIndex,
    const AVFrame* src) {
  // If the decoder silently fell back to software decoding, data[] holds host
  // pointers and must not be handed to a device-side conversion.
  TORCH_CHECK(
      src->format == AV_PIX_FMT_CUDA,
      "Expected a CUDA hardware frame (AV_PIX_FMT_CUDA) but got pixel format ",
      src->format,
      ".");

  StreamInfo& streamInfo = streams_[streamIndex];

  // Use the dimensions of the frame that is actually being converted rather
  // than the codec context's.
  const int width = src->width;
  const int height = src->height;

  NppiSize oSizeROI;
  oSizeROI.width = width;
  oSizeROI.height = height;

  Npp8u* input[2] = {
      static_cast<Npp8u*>(src->data[0]), static_cast<Npp8u*>(src->data[1])};

  torch::Tensor dst =
      getCudaImageBuffer({height, width, 3}, streamInfo.options.device);

  auto start = std::chrono::high_resolution_clock::now();
  const NppStatus status = nppiNV12ToRGB_8u_P2C3R(
      input,
      src->linesize[0],
      static_cast<Npp8u*>(dst.data_ptr()),
      // Destination row pitch in bytes: 3 interleaved channels per pixel.
      width * 3,
      oSizeROI);
  TORCH_CHECK(
      status == NPP_SUCCESS,
      "Failed to convert NV12 frame. NPP status=",
      static_cast<int>(status));
  auto end = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double, std::micro> duration = end - start;
  VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width
          << " took: " << duration.count() << "us" << std::endl;

  return dst;
}
#endif

VideoDecoder::DecodedOutput VideoDecoder::getDecodedOutputWithFilter(
std::function<bool(int, AVFrame*)> filterFunction) {
if (activeStreamIndices_.size() == 0) {
Expand Down Expand Up @@ -674,8 +775,12 @@ VideoDecoder::DecodedOutput VideoDecoder::getDecodedOutputWithFilter(
// This packet is not for any of the active streams.
continue;
}
auto start = std::chrono::high_resolution_clock::now();
ffmpegStatus = avcodec_send_packet(
streams_[packet->stream_index].codecContext.get(), packet.get());
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double, std::micro> duration = end - start;
VLOG(9) << "Packet send took: " << duration.count() << "us" << std::endl;
decodeStats_.numPacketsSentToDecoder++;
if (ffmpegStatus < AVSUCCESS) {
throw std::runtime_error(
Expand Down Expand Up @@ -714,8 +819,16 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
output.ptsSeconds =
1.0 * frame->pts / formatContext_->streams[streamIndex]->time_base.den;
if (output.streamType == AVMEDIA_TYPE_VIDEO) {
output.frame =
convertFrameToTensorUsingFilterGraph(streamIndex, frame.get());
if (streams_[streamIndex].options.device.is_cpu()) {
output.frame =
convertFrameToTensorUsingFilterGraph(streamIndex, frame.get());
} else if (streams_[streamIndex].options.device.is_cuda()) {
#ifdef ENABLE_CUDA
output.frame = convertFrameToTensorUsingCUDA(streamIndex, frame.get());
#else
throw std::runtime_error("CUDA is not enabled in this build.");
#endif // ENABLE_CUDA
}
} else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
// TODO: implement audio AVFrame to Tensor conversion here.
throw std::runtime_error("Audio is not supported yet.");
Expand Down
8 changes: 8 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <ostream>
#include <string_view>

#include "c10/core/Device.h"
#include "src/torchcodec/decoders/_core/FFMPEGCommon.h"

namespace facebook::torchcodec {
Expand Down Expand Up @@ -136,6 +137,8 @@ class VideoDecoder {
// is the same as the original video.
std::optional<int> width;
std::optional<int> height;
// Set the device to torch::kCUDA for GPU decoding.
torch::Device device = torch::kCPU;
};
struct AudioStreamDecoderOptions {
// TODO: Add channels, shape, sample options, etc.
Expand Down Expand Up @@ -277,6 +280,11 @@ class VideoDecoder {
torch::Tensor convertFrameToTensorUsingFilterGraph(
int streamIndex,
const AVFrame* frame);
#ifdef ENABLE_CUDA
torch::Tensor convertFrameToTensorUsingCUDA(
int streamIndex,
const AVFrame* frame);
#endif
DecodedOutput convertAVFrameToDecodedOutput(
int streamIndex,
UniqueAVFrame frame);
Expand Down
9 changes: 7 additions & 2 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def("create_from_file(str filename) -> Tensor");
m.def("create_from_tensor(Tensor video_tensor) -> Tensor");
m.def(
"add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? shape=None, int? stream_index=None) -> ()");
"add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? shape=None, int? stream_index=None, str? device_string=None) -> ()");
m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()");
m.def("get_next_frame(Tensor(a!) decoder) -> Tensor");
m.def("get_frame_at_pts(Tensor(a!) decoder, float seconds) -> Tensor");
Expand Down Expand Up @@ -87,7 +87,8 @@ void add_video_stream(
std::optional<int64_t> height = std::nullopt,
std::optional<int64_t> num_threads = std::nullopt,
std::optional<c10::string_view> shape = std::nullopt,
std::optional<int64_t> stream_index = std::nullopt) {
std::optional<int64_t> stream_index = std::nullopt,
std::optional<c10::string_view> device_string = std::nullopt) {
VideoDecoder::VideoStreamDecoderOptions options;
options.width = width;
options.height = height;
Expand All @@ -98,6 +99,10 @@ void add_video_stream(
TORCH_CHECK(stdShape == "NHWC" || stdShape == "NCHW");
options.shape = stdShape;
}
if (device_string.has_value()) {
std::string deviceString{device_string.value()};
options.device = torch::Device(deviceString);
}

auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
videoDecoder->addVideoStreamDecoder(stream_index.value_or(-1), options);
Expand Down
3 changes: 2 additions & 1 deletion src/torchcodec/decoders/_core/VideoDecoderOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ void add_video_stream(
std::optional<int64_t> height = std::nullopt,
std::optional<int64_t> num_threads = std::nullopt,
std::optional<c10::string_view> shape = std::nullopt,
std::optional<int64_t> stream_index = std::nullopt);
std::optional<int64_t> stream_index = std::nullopt,
std::optional<c10::string_view> device_string = std::nullopt);

// Seek to a particular presentation timestamp in the video in seconds.
void seek_to_pts(at::Tensor& decoder, double seconds);
Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/decoders/_core/video_decoder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def add_video_stream_abstract(
num_threads: Optional[int] = None,
shape: Optional[str] = None,
stream_index: Optional[int] = None,
device_string: Optional[str] = None,
) -> None:
return

Expand Down
4 changes: 4 additions & 0 deletions test/decoders/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ add_executable(
VideoDecoderOpsTest.cpp
)

# Define ENABLE_CUDA=1 for the VideoDecoderTest target when the ENABLE_CUDA
# CMake variable is on, so test code guarded by `#ifdef ENABLE_CUDA` is built.
if(ENABLE_CUDA)
target_compile_definitions(VideoDecoderTest PRIVATE ENABLE_CUDA=1)
endif()

target_include_directories(VideoDecoderTest SYSTEM PRIVATE ${TORCH_INCLUDE_DIRS})
target_include_directories(VideoDecoderTest PRIVATE ../../)
target_include_directories(VideoDecoderTest SYSTEM PRIVATE ${TORCH_INCLUDE_DIRS})
Expand Down
44 changes: 44 additions & 0 deletions test/decoders/VideoDecoderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,50 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) {
}
}

#ifdef ENABLE_CUDA
// Decodes the first two frames of nasa_13013.mp4 on a CUDA device and compares
// them against reference tensors stored on disk. The references are separate
// ".cuda.pt" files rather than the CPU ones.
TEST(GPUVideoDecoderTest, ReturnsFirstTwoFramesOfVideo) {
  std::string path = getResourcePath("nasa_13013.mp4");
  std::unique_ptr<VideoDecoder> ourDecoder =
      VideoDecoder::createFromFilePath(path);
  VideoDecoder::VideoStreamDecoderOptions streamOptions;
  streamOptions.device = torch::Device("cuda");
  ASSERT_TRUE(streamOptions.device.is_cuda());
  ASSERT_EQ(streamOptions.device.type(), torch::DeviceType::CUDA);
  ourDecoder->addVideoStreamDecoder(-1, streamOptions);

  // First frame.
  auto output = ourDecoder->getNextDecodedOutput();
  torch::Tensor tensor1FromOurDecoder = output.frame;
  EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector<long>({270, 480, 3}));
  EXPECT_EQ(output.ptsSeconds, 0.0);
  EXPECT_EQ(output.pts, 0);

  // Second frame.
  output = ourDecoder->getNextDecodedOutput();
  torch::Tensor tensor2FromOurDecoder = output.frame;
  EXPECT_EQ(tensor2FromOurDecoder.sizes(), std::vector<long>({270, 480, 3}));
  EXPECT_EQ(output.ptsSeconds, 1'001. / 30'000);
  EXPECT_EQ(output.pts, 1001);

  torch::Tensor tensor1FromFFMPEG =
      readTensorFromDisk("nasa_13013.mp4.frame000001.cuda.pt");
  torch::Tensor tensor2FromFFMPEG =
      readTensorFromDisk("nasa_13013.mp4.frame000002.cuda.pt");

  EXPECT_EQ(tensor1FromFFMPEG.sizes(), std::vector<long>({270, 480, 3}));
  EXPECT_EQ(tensor2FromFFMPEG.sizes(), std::vector<long>({270, 480, 3}));
  EXPECT_EQ(tensor1FromOurDecoder.device().type(), torch::DeviceType::CUDA);
  EXPECT_EQ(tensor2FromOurDecoder.device().type(), torch::DeviceType::CUDA);

  // Copy each decoded frame to the CPU before comparing with the references.
  // The second copy must come from the second frame; it was previously taken
  // from the first frame, so frame 2 was never actually verified.
  // NOTE(review): if frame000002.cuda.pt was produced through the old dump
  // path below it may contain frame 1 -- regenerate it if this now fails.
  torch::Tensor tensor1FromOurDecoderCPU = tensor1FromOurDecoder.cpu();
  torch::Tensor tensor2FromOurDecoderCPU = tensor2FromOurDecoder.cpu();
  EXPECT_TRUE(torch::equal(tensor1FromOurDecoderCPU, tensor1FromFFMPEG));
  EXPECT_TRUE(torch::equal(tensor2FromOurDecoderCPU, tensor2FromFFMPEG));

  if (FLAGS_dump_frames_for_debugging) {
    dumpTensorToDisk(tensor1FromFFMPEG, "tensor1FromFFMPEG.pt");
    dumpTensorToDisk(tensor2FromFFMPEG, "tensor2FromFFMPEG.pt");
    dumpTensorToDisk(tensor1FromOurDecoderCPU, "tensor1FromOurDecoder.pt");
    dumpTensorToDisk(tensor2FromOurDecoderCPU, "tensor2FromOurDecoder.pt");
  }
}
#endif

TEST_P(VideoDecoderTest, DecodesFramesInABatchInNHWC) {
std::string path = getResourcePath("nasa_13013.mp4");
std::unique_ptr<VideoDecoder> ourDecoder =
Expand Down
Loading

0 comments on commit cc7223e

Please sign in to comment.