[torchcodec] add get_frames_at method to SimpleVideoDecoder (#80)
Summary:
Pull Request resolved: #80

This diff does several things:
* Adds a new method to `SimpleVideoDecoder` (see the usage sketch below), with the signature:
    get_frames_at(
      self,
      start: int,
      stop: int,
      step: int = 1
    ) -> FrameBatch
* Adds the `FrameBatch` dataclass as part of the public API. It holds a stacked version of the frame data, and the pts and duration seconds are tensors of length N, where N is the number of frames in the range. It is a sibling to the `Frame` dataclass.
* Changes the return value of the core library function `get_frames_in_range()`, as well as the underlying member function `VideoDecoder::getFramesInRange()`, to `Tuple[Tensor, Tensor, Tensor]`. The first value is the stacked frames, the second is a tensor of each pts in the range, and the third is a tensor of each duration in seconds. This matches the new return type of `get_frame_at_index()`.
* Updates our testing framework so that we can systematically associate pts and duration metadata with a particular test file. In this diff we hardcode the metadata into the testing utils. In the future, we should read it from a checked-in JSON file generated with ffprobe.

This diff is a partial implementation of the design in: https://fburl.com/gdoc/i6eqb634
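
For orientation, a minimal usage sketch of the new API (the file path and tensor shapes are illustrative, not taken from this diff):

    from torchcodec.decoders import SimpleVideoDecoder

    decoder = SimpleVideoDecoder("video.mp4")  # hypothetical local file
    batch = decoder.get_frames_at(start=0, stop=10, step=2)  # 5 frames

    # FrameBatch holds stacked frame data plus per-frame metadata tensors.
    print(batch.data.shape)              # e.g. torch.Size([5, 270, 480, 3]) in NHWC
    print(batch.pts_seconds.shape)       # torch.Size([5])
    print(batch.duration_seconds.shape)  # torch.Size([5])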

Reviewed By: ahmadsharif1

Differential Revision: D59767617
scotts authored and facebook-github-bot committed Jul 17, 2024
1 parent cfb0388 commit b13fed8
Showing 10 changed files with 296 additions and 110 deletions.
2 changes: 1 addition & 1 deletion src/torchcodec/decoders/__init__.py
@@ -1 +1 @@
from ._simple_video_decoder import Frame, SimpleVideoDecoder # noqa
from ._simple_video_decoder import Frame, FrameBatch, SimpleVideoDecoder # noqa
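
With this one-line change, `FrameBatch` becomes importable from the public package alongside the existing names:

    from torchcodec.decoders import Frame, FrameBatch, SimpleVideoDecoder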
64 changes: 32 additions & 32 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -124,6 +124,31 @@ VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions(
}
}

VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
int64_t numFrames,
const VideoStreamDecoderOptions& options,
const StreamMetadata& metadata)
: ptsSeconds(torch::empty({numFrames}, {torch::kFloat})),
durationSeconds(torch::empty({numFrames}, {torch::kFloat})) {
if (options.shape == "NHWC") {
frames = torch::empty(
{numFrames,
options.height.value_or(*metadata.height),
options.width.value_or(*metadata.width),
3},
{torch::kUInt8});
} else if (options.shape == "NCHW") {
frames = torch::empty(
{numFrames,
3,
options.height.value_or(*metadata.height),
options.width.value_or(*metadata.width)},
{torch::kUInt8});
} else {
TORCH_CHECK(false, "Unsupported frame shape=" + options.shape);
}
}

VideoDecoder::VideoDecoder() {}

void VideoDecoder::initializeDecoder() {
@@ -734,30 +759,6 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
return output;
}

torch::Tensor VideoDecoder::getEmptyTensorForBatch(
int64_t numFrames,
const VideoStreamDecoderOptions& options,
const StreamMetadata& metadata) {
if (options.shape == "NHWC") {
return torch::empty(
{numFrames,
options.height.value_or(*metadata.height),
options.width.value_or(*metadata.width),
3},
{torch::kUInt8});
} else if (options.shape == "NCHW") {
return torch::empty(
{numFrames,
3,
options.height.value_or(*metadata.height),
options.width.value_or(*metadata.width)},
{torch::kUInt8});
} else {
// TODO: should this be a TORCH macro of some kind?
throw std::runtime_error("Unsupported frame shape=" + options.shape);
}
}

VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestamp(
double seconds) {
for (auto& [streamIndex, stream] : streams_) {
@@ -822,11 +823,9 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndexes(
"Must scan all streams to update metadata before calling getFrameAtIndex");
}

BatchDecodedOutput output;
const auto& streamMetadata = containerMetadata_.streams[streamIndex];
const auto& options = streams_[streamIndex].options;
output.frames =
getEmptyTensorForBatch(frameIndexes.size(), options, streamMetadata);
BatchDecodedOutput output(frameIndexes.size(), options, streamMetadata);

int i = 0;
if (streams_.count(streamIndex) == 0) {
@@ -873,16 +872,17 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(

int64_t numOutputFrames = std::ceil((stop - start) / double(step));
const auto& options = stream.options;
BatchDecodedOutput output;
output.frames =
getEmptyTensorForBatch(numOutputFrames, options, streamMetadata);
BatchDecodedOutput output(numOutputFrames, options, streamMetadata);

int64_t f = 0;
for (int64_t i = start; i < stop; i += step) {
int64_t pts = stream.allFrames[i].pts;
setCursorPtsInSeconds(1.0 * pts / stream.timeBase.den);
torch::Tensor frame = getNextDecodedOutput().frame;
output.frames[f++] = frame;
DecodedOutput singleOut = getNextDecodedOutput();
output.frames[f] = singleOut.frame;
output.ptsSeconds[f] = singleOut.ptsSeconds;
output.durationSeconds[f] = singleOut.durationSeconds;
++f;
}

return output;
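The C++ changes above follow a preallocate-then-fill pattern: the new `BatchDecodedOutput` constructor sizes all three output tensors up front, and the decode loop in `getFramesInRange()` writes one frame per slot. A rough Python/torch analogue of the pattern (shapes and the decode step are placeholders, not the real decoder):

    import torch

    def decode_range_sketch(decode_frame, num_frames, height, width):
        # Preallocate batch outputs, mirroring BatchDecodedOutput's constructor.
        frames = torch.empty((num_frames, height, width, 3), dtype=torch.uint8)  # NHWC
        pts_seconds = torch.empty(num_frames, dtype=torch.float)
        duration_seconds = torch.empty(num_frames, dtype=torch.float)
        for f in range(num_frames):
            # decode_frame stands in for seeking + getNextDecodedOutput().
            frame, pts, duration = decode_frame(f)
            frames[f] = frame
            pts_seconds[f] = pts
            duration_seconds[f] = duration
        return frames, pts_seconds, duration_seconds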
7 changes: 7 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.h
@@ -186,6 +186,13 @@ class VideoDecoder {
DecodedOutput getFrameAtIndex(int streamIndex, int64_t frameIndex);
struct BatchDecodedOutput {
torch::Tensor frames;
torch::Tensor ptsSeconds;
torch::Tensor durationSeconds;

explicit BatchDecodedOutput(
int64_t numFrames,
const VideoStreamDecoderOptions& options,
const StreamMetadata& metadata);
};
// Returns frames at the given indexes for a given stream as a single stacked
// Tensor.
20 changes: 12 additions & 8 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -35,7 +35,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> Tensor");
m.def(
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> Tensor");
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
m.def(
@@ -64,13 +64,17 @@ VideoDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) {
return decoder;
}

FramePtsDuration getTensorPtsDurationFromFrame(
VideoDecoder::DecodedOutput& frame) {
FramePtsDuration makeFramePtsDuration(VideoDecoder::DecodedOutput& frame) {
return std::make_tuple(
frame.frame,
torch::tensor(frame.ptsSeconds),
torch::tensor(frame.durationSeconds));
}

BatchedFramesPtsDuration makeBatchedFramesPtsDuration(
VideoDecoder::BatchDecodedOutput& batch) {
return std::make_tuple(batch.frames, batch.ptsSeconds, batch.durationSeconds);
}
} // namespace

// ==============================
@@ -139,13 +143,13 @@ FramePtsDuration get_next_frame(at::Tensor& decoder) {
"image_size is unexpected. Expected 3, got: " +
std::to_string(result.frame.sizes().size()));
}
return getTensorPtsDurationFromFrame(result);
return makeFramePtsDuration(result);
}

FramePtsDuration get_frame_at_pts(at::Tensor& decoder, double seconds) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
auto result = videoDecoder->getFrameDisplayedAtTimestamp(seconds);
return getTensorPtsDurationFromFrame(result);
return makeFramePtsDuration(result);
}

FramePtsDuration get_frame_at_index(
@@ -154,7 +158,7 @@ FramePtsDuration get_frame_at_index(
int64_t frame_index) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
auto result = videoDecoder->getFrameAtIndex(stream_index, frame_index);
return getTensorPtsDurationFromFrame(result);
return makeFramePtsDuration(result);
}

at::Tensor get_frames_at_indices(
@@ -168,7 +172,7 @@ at::Tensor get_frames_at_indices(
return result.frames;
}

at::Tensor get_frames_in_range(
BatchedFramesPtsDuration get_frames_in_range(
at::Tensor& decoder,
int64_t stream_index,
int64_t start,
@@ -177,7 +181,7 @@ at::Tensor get_frames_in_range(
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
auto result = videoDecoder->getFramesInRange(
stream_index, start, stop, step.value_or(1));
return result.frames;
return makeBatchedFramesPtsDuration(result);
}

std::string quoteValue(const std::string& value) {
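With the updated schema, the dispatcher-level op now returns the same 3-tuple. A sketch of an eager call through torch.ops (assumes a torchcodec build is importable and `decoder` is the tensor-wrapped decoder produced by the core API):

    frames, pts_seconds, duration_seconds = torch.ops.torchcodec_ns.get_frames_in_range(
        decoder, stream_index=0, start=0, stop=10, step=2
    )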
22 changes: 17 additions & 5 deletions src/torchcodec/decoders/_core/VideoDecoderOps.h
@@ -36,13 +36,25 @@ void add_video_stream(
// Seek to a particular presentation timestamp in the video in seconds.
void seek_to_pts(at::Tensor& decoder, double seconds);

// The first element of this tuple has the frame data. The second element is a
// Tensor that has a single float value for the PTS and the third element is a
// Tensor that has a single value for the duration.
// The reason we use Tensors for the second and third value is so we can run
// The elements of this tuple are all tensors that represent a single frame:
// 1. The frame data, which is a multidimensional tensor.
// 2. A single float value for the pts in seconds.
// 3. A single float value for the duration in seconds.
// The reason we use Tensors for the second and third values is so we can run
// under torch.compile().
using FramePtsDuration = std::tuple<at::Tensor, at::Tensor, at::Tensor>;

// All elements of this tuple are tensors of the same leading dimension. The
// tuple represents N total frames, where N is the size of the leading
// dimension of each stacked tensor. The elements are:
// 1. Stacked tensor of data for all N frames. Each frame is also a
// multidimensional tensor.
// 2. Tensor of N pts values in seconds, where each pts is a single float.
// 3. Tensor of N durations in seconds, where each duration is a single float.
using BatchedFramesPtsDuration = std::tuple<at::Tensor, at::Tensor, at::Tensor>;

// Return the frame that is visible at a given timestamp in seconds. Each frame
// in FFMPEG has a presentation timestamp and a duration. The frame visible at a
// given timestamp T has T >= PTS and T < PTS + Duration.
@@ -67,7 +79,7 @@ at::Tensor get_frames_at_indices(

// Return the frames inside a range as a single stacked Tensor. The range is
// defined as [start, stop).
at::Tensor get_frames_in_range(
BatchedFramesPtsDuration get_frames_in_range(
at::Tensor& decoder,
int64_t stream_index,
int64_t start,
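The practical difference between the two aliases is the rank of the metadata tensors: the single-frame tuple carries 0-dim pts/duration tensors, while the batched tuple carries length-N tensors. In torch terms:

    import torch

    pts_single = torch.tensor(0.04)  # FramePtsDuration style: shape []
    pts_batch = torch.empty(5)       # BatchedFramesPtsDuration style: shape [5]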
8 changes: 6 additions & 2 deletions src/torchcodec/decoders/_core/video_decoder_ops.py
@@ -175,9 +175,13 @@ def get_frames_in_range_abstract(
start: int,
stop: int,
step: Optional[int] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
return torch.empty(image_size)
return (
torch.empty(image_size),
torch.empty([], dtype=torch.float),
torch.empty([], dtype=torch.float),
)


@impl_abstract("torchcodec_ns::get_json_metadata")
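Under torch.compile, this abstract implementation is traced in place of the real decoder, so its return structure must mirror the real kernel: same arity and dtypes, with dynamic sizes standing in for the unknown frame shape. A hedged sketch of a compiled caller (illustrative only; assumes a constructed decoder tensor):

    import torch

    @torch.compile
    def grab_frames(decoder):
        return torch.ops.torchcodec_ns.get_frames_in_range(
            decoder, stream_index=0, start=0, stop=8, step=2
        )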
34 changes: 33 additions & 1 deletion src/torchcodec/decoders/_simple_video_decoder.py
@@ -21,6 +21,17 @@ def __iter__(self) -> Iterator[Union[Tensor, float]]:
yield getattr(self, field.name)


@dataclass
class FrameBatch(Iterable):
data: Tensor
pts_seconds: Tensor
duration_seconds: Tensor

def __iter__(self) -> Iterator[Union[Tensor, float]]:
for field in dataclasses.fields(self):
yield getattr(self, field.name)


_ERROR_REPORTING_INSTRUCTIONS = """
This should never happen. Please report an issue following the steps in <TODO>.
"""
@@ -92,13 +103,14 @@ def _getitem_slice(self, key: slice) -> Tensor:
assert isinstance(key, slice)

start, stop, step = key.indices(len(self))
return core.get_frames_in_range(
frame_data, *_ = core.get_frames_in_range(
self._decoder,
stream_index=self._stream_index,
start=start,
stop=stop,
step=step,
)
return frame_data

def __getitem__(self, key: Union[int, slice]) -> Tensor:
if isinstance(key, int):
Expand All @@ -120,6 +132,26 @@ def get_frame_at(self, index: int) -> Frame:
)
return Frame(*frame)

def get_frames_at(self, start: int, stop: int, step: int = 1) -> FrameBatch:
if not 0 <= start < self._num_frames:
raise IndexError(
f"Start index {start} is out of bounds; must be in the range [0, {self._num_frames})."
)
if stop < start:
raise IndexError(
f"Stop index ({stop}) must not be less than the start index ({start})."
)
if not step > 0:
raise IndexError(f"Step ({step}) must be greater than 0.")
frames = core.get_frames_in_range(
self._decoder,
stream_index=self._stream_index,
start=start,
stop=stop,
step=step,
)
return FrameBatch(*frames)

def get_frame_displayed_at(self, pts_seconds: float) -> Frame:
if not self._min_pts_seconds <= pts_seconds < self._max_pts_seconds:
raise IndexError(
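Two behaviors of the new method worth noting: `FrameBatch`, like `Frame`, iterates over its fields in declaration order, so it unpacks positionally; and the bounds checks reject bad ranges before any decoding happens. A short sketch (the decoder and frame counts are illustrative):

    batch = decoder.get_frames_at(start=0, stop=10, step=2)
    data, pts_seconds, duration_seconds = batch  # field-order unpacking

    decoder.get_frames_at(start=10, stop=2)         # IndexError: stop < start
    decoder.get_frames_at(start=0, stop=5, step=0)  # IndexError: step must be > 0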