From 473e20e74042cb2f276a35a63d1f95f4c681d37f Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 18 Jun 2024 08:08:16 -0700 Subject: [PATCH] [torchcodec] stream_index should not be optional for getting frame by index Differential Revision: D58505608 --- src/torchcodec/decoders/_core/VideoDecoderOps.cpp | 14 ++++++-------- src/torchcodec/decoders/_core/video_decoder_ops.py | 4 ++-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index b2f69ed1..013d672f 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -37,9 +37,9 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def("get_next_frame(Tensor(a!) decoder) -> Tensor"); m.def("get_frame_at_pts(Tensor(a!) decoder, float seconds) -> Tensor"); m.def( - "get_frame_at_index(Tensor(a!) decoder, *, int frame_index, int? stream_index=None) -> Tensor"); + "get_frame_at_index(Tensor(a!) decoder, *, int frame_index, int stream_index) -> Tensor"); m.def( - "get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices, int? stream_index=None) -> Tensor"); + "get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices, int stream_index) -> Tensor"); m.def("get_json_metadata(Tensor(a!) decoder) -> str"); } @@ -131,22 +131,20 @@ at::Tensor get_frame_at_pts(at::Tensor& decoder, double seconds) { at::Tensor get_frame_at_index( at::Tensor& decoder, int64_t frame_index, - std::optional stream_index) { + int64_t stream_index) { auto videoDecoder = static_cast(decoder.mutable_data_ptr()); - auto result = - videoDecoder->getFrameAtIndex(stream_index.value_or(-1), frame_index); + auto result = videoDecoder->getFrameAtIndex(stream_index, frame_index); return result.frame; } at::Tensor get_frames_at_indices( at::Tensor& decoder, at::IntArrayRef frame_indices, - std::optional stream_index) { + int64_t stream_index) { auto videoDecoder = static_cast(decoder.mutable_data_ptr()); std::vector frameIndicesVec( frame_indices.begin(), frame_indices.end()); - auto result = videoDecoder->getFramesAtIndexes( - stream_index.value_or(-1), frameIndicesVec); + auto result = videoDecoder->getFramesAtIndexes(stream_index, frameIndicesVec); return result.frames; } diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index cb1d21e6..c8d24a4d 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -114,7 +114,7 @@ def get_frame_at_pts_abstract(decoder: torch.Tensor, seconds: float) -> torch.Te @register_fake("torchcodec_ns::get_frame_at_index") def get_frame_at_index_abstract( - decoder: torch.Tensor, *, frame_index: int, stream_index: Optional[int] = None + decoder: torch.Tensor, *, frame_index: int, stream_index: int ) -> torch.Tensor: image_size = [get_ctx().new_dynamic_size() for _ in range(3)] return torch.empty(image_size) @@ -125,7 +125,7 @@ def get_frames_at_indices_abstract( decoder: torch.Tensor, *, frame_indices: List[int], - stream_index: Optional[int] = None, + stream_index: int, ) -> torch.Tensor: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] return torch.empty(image_size)