Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torchcodec] add get_frames_at method to SimpleVideoDecoder #80

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/torchcodec/decoders/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from ._simple_video_decoder import Frame, SimpleVideoDecoder # noqa
from ._simple_video_decoder import Frame, FrameBatch, SimpleVideoDecoder # noqa
64 changes: 32 additions & 32 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,31 @@ VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions(
}
}

// Pre-allocates the output tensors for a batch decode of `numFrames` frames:
// one float tensor each for pts and duration (one entry per frame), and the
// uint8 frame-data tensor. Frame height/width come from the stream options
// when set, otherwise from the stream metadata; the data layout follows
// options.shape ("NHWC" or "NCHW").
VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
    int64_t numFrames,
    const VideoStreamDecoderOptions& options,
    const StreamMetadata& metadata)
    : ptsSeconds(torch::empty({numFrames}, {torch::kFloat})),
      durationSeconds(torch::empty({numFrames}, {torch::kFloat})) {
  // Validate the layout up front so we never dereference the metadata
  // optionals on an unsupported shape. Note: the original had no trailing
  // semicolon after TORCH_CHECK, which does not compile; fixed here.
  TORCH_CHECK(
      options.shape == "NHWC" || options.shape == "NCHW",
      "Unsupported frame shape=" + options.shape);
  // Assumes metadata height/width are populated by the time a batch is
  // allocated (stream scan has run) — matches the original's dereference.
  const int64_t height = options.height.value_or(*metadata.height);
  const int64_t width = options.width.value_or(*metadata.width);
  if (options.shape == "NHWC") {
    frames = torch::empty({numFrames, height, width, 3}, {torch::kUInt8});
  } else {
    frames = torch::empty({numFrames, 3, height, width}, {torch::kUInt8});
  }
}

// Construction does no work; initializeDecoder() completes setup.
VideoDecoder::VideoDecoder() = default;

void VideoDecoder::initializeDecoder() {
Expand Down Expand Up @@ -734,30 +759,6 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
return output;
}

// Allocates an uninitialized uint8 tensor sized to hold `numFrames` decoded
// frames. Height/width come from the stream options when set, otherwise from
// the stream metadata (assumes the metadata optionals are populated — TODO
// confirm a stream scan has always run by this point). Layout follows
// options.shape, which must be "NHWC" or "NCHW".
torch::Tensor VideoDecoder::getEmptyTensorForBatch(
    int64_t numFrames,
    const VideoStreamDecoderOptions& options,
    const StreamMetadata& metadata) {
  if (options.shape == "NHWC") {
    return torch::empty(
        {numFrames,
         options.height.value_or(*metadata.height),
         options.width.value_or(*metadata.width),
         3},
        {torch::kUInt8});
  } else if (options.shape == "NCHW") {
    return torch::empty(
        {numFrames,
         3,
         options.height.value_or(*metadata.height),
         options.width.value_or(*metadata.width)},
        {torch::kUInt8});
  } else {
    // TODO: should this be a TORCH macro of some kind?
    throw std::runtime_error("Unsupported frame shape=" + options.shape);
  }
}

VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestamp(
double seconds) {
for (auto& [streamIndex, stream] : streams_) {
Expand Down Expand Up @@ -822,11 +823,9 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndexes(
"Must scan all streams to update metadata before calling getFrameAtIndex");
}

BatchDecodedOutput output;
const auto& streamMetadata = containerMetadata_.streams[streamIndex];
const auto& options = streams_[streamIndex].options;
output.frames =
getEmptyTensorForBatch(frameIndexes.size(), options, streamMetadata);
BatchDecodedOutput output(frameIndexes.size(), options, streamMetadata);

int i = 0;
if (streams_.count(streamIndex) == 0) {
Expand Down Expand Up @@ -873,16 +872,17 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(

int64_t numOutputFrames = std::ceil((stop - start) / double(step));
const auto& options = stream.options;
BatchDecodedOutput output;
output.frames =
getEmptyTensorForBatch(numOutputFrames, options, streamMetadata);
BatchDecodedOutput output(numOutputFrames, options, streamMetadata);

int64_t f = 0;
for (int64_t i = start; i < stop; i += step) {
int64_t pts = stream.allFrames[i].pts;
setCursorPtsInSeconds(1.0 * pts / stream.timeBase.den);
torch::Tensor frame = getNextDecodedOutput().frame;
output.frames[f++] = frame;
DecodedOutput singleOut = getNextDecodedOutput();
output.frames[f] = singleOut.frame;
output.ptsSeconds[f] = singleOut.ptsSeconds;
output.durationSeconds[f] = singleOut.durationSeconds;
++f;
}

return output;
Expand Down
7 changes: 7 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,13 @@ class VideoDecoder {
DecodedOutput getFrameAtIndex(int streamIndex, int64_t frameIndex);
struct BatchDecodedOutput {
torch::Tensor frames;
torch::Tensor ptsSeconds;
torch::Tensor durationSeconds;

explicit BatchDecodedOutput(
int64_t numFrames,
const VideoStreamDecoderOptions& options,
const StreamMetadata& metadata);
};
// Returns frames at the given indexes for a given stream as a single stacked
// Tensor.
Expand Down
20 changes: 12 additions & 8 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> Tensor");
m.def(
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> Tensor");
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
m.def(
Expand Down Expand Up @@ -64,13 +64,17 @@ VideoDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) {
return decoder;
}

FramePtsDuration getTensorPtsDurationFromFrame(
VideoDecoder::DecodedOutput& frame) {
// Packages a single decoded frame as the (data, pts, duration) tuple the ops
// return. The pts/duration scalars are wrapped in tensors so the ops can run
// under torch.compile().
FramePtsDuration makeFramePtsDuration(VideoDecoder::DecodedOutput& frame) {
  return {
      frame.frame,
      torch::tensor(frame.ptsSeconds),
      torch::tensor(frame.durationSeconds)};
}

// Packages a batch result as the (frames, pts, durations) tuple the ops
// return. The batch fields are already tensors, so no wrapping is needed.
BatchedFramesPtsDuration makeBatchedFramesPtsDuration(
    VideoDecoder::BatchDecodedOutput& batch) {
  return {batch.frames, batch.ptsSeconds, batch.durationSeconds};
}
} // namespace

// ==============================
Expand Down Expand Up @@ -139,13 +143,13 @@ FramePtsDuration get_next_frame(at::Tensor& decoder) {
"image_size is unexpected. Expected 3, got: " +
std::to_string(result.frame.sizes().size()));
}
return getTensorPtsDurationFromFrame(result);
return makeFramePtsDuration(result);
}

// Op entry point: decode the frame displayed at `seconds` and return it as a
// (data, pts, duration) tuple.
FramePtsDuration get_frame_at_pts(at::Tensor& decoder, double seconds) {
  VideoDecoder* videoDecoder = unwrapTensorToGetDecoder(decoder);
  VideoDecoder::DecodedOutput decoded =
      videoDecoder->getFrameDisplayedAtTimestamp(seconds);
  return makeFramePtsDuration(decoded);
}

FramePtsDuration get_frame_at_index(
Expand All @@ -154,7 +158,7 @@ FramePtsDuration get_frame_at_index(
int64_t frame_index) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
auto result = videoDecoder->getFrameAtIndex(stream_index, frame_index);
return getTensorPtsDurationFromFrame(result);
return makeFramePtsDuration(result);
}

at::Tensor get_frames_at_indices(
Expand All @@ -168,7 +172,7 @@ at::Tensor get_frames_at_indices(
return result.frames;
}

at::Tensor get_frames_in_range(
BatchedFramesPtsDuration get_frames_in_range(
at::Tensor& decoder,
int64_t stream_index,
int64_t start,
Expand All @@ -177,7 +181,7 @@ at::Tensor get_frames_in_range(
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
auto result = videoDecoder->getFramesInRange(
stream_index, start, stop, step.value_or(1));
return result.frames;
return makeBatchedFramesPtsDuration(result);
}

std::string quoteValue(const std::string& value) {
Expand Down
22 changes: 17 additions & 5 deletions src/torchcodec/decoders/_core/VideoDecoderOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,25 @@ void add_video_stream(
// Seek to a particular presentation timestamp in the video in seconds.
void seek_to_pts(at::Tensor& decoder, double seconds);

// The first element of this tuple has the frame data. The second element is a
// Tensor that has a single float value for the PTS and the third element is a
// Tensor that has a single value for the duration.
// The reason we use Tensors for the second and third value is so we can run
// The elements of this tuple are all tensors that represent a single frame:
// 1. The frame data, which is a multidimensional tensor.
// 2. A single float value for the pts in seconds.
// 3. A single float value for the duration in seconds.
// The reason we use Tensors for the second and third values is so we can run
// under torch.compile().
using FramePtsDuration = std::tuple<at::Tensor, at::Tensor, at::Tensor>;

// All elements of this tuple are tensors of the same leading dimension. The
// tuple represents the frames for N total frames, where N is the dimension of
// each stacked tensor. The elements are:
// 1. Stacked tensor of data for all N frames. Each frame is also a
// multidimensional tensor.
// 2. Tensor of N pts values in seconds, where each pts is a single
// float.
// 3. Tensor of N durations in seconds, where each duration is a
// single float.
using BatchedFramesPtsDuration = std::tuple<at::Tensor, at::Tensor, at::Tensor>;

// Return the frame that is visible at a given timestamp in seconds. Each frame
// in FFMPEG has a presentation timestamp and a duration. The frame visible at a
// given timestamp T has T >= PTS and T < PTS + Duration.
Expand All @@ -67,7 +79,7 @@ at::Tensor get_frames_at_indices(

// Return the frames inside a range as a single stacked Tensor. The range is
// defined as [start, stop).
at::Tensor get_frames_in_range(
BatchedFramesPtsDuration get_frames_in_range(
at::Tensor& decoder,
int64_t stream_index,
int64_t start,
Expand Down
8 changes: 6 additions & 2 deletions src/torchcodec/decoders/_core/video_decoder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,13 @@ def get_frames_in_range_abstract(
start: int,
stop: int,
step: Optional[int] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
return torch.empty(image_size)
return (
torch.empty(image_size),
torch.empty([], dtype=torch.float),
torch.empty([], dtype=torch.float),
)


@impl_abstract("torchcodec_ns::get_json_metadata")
Expand Down
34 changes: 33 additions & 1 deletion src/torchcodec/decoders/_simple_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ def __iter__(self) -> Iterator[Union[Tensor, float]]:
yield getattr(self, field.name)


@dataclass
class FrameBatch(Iterable):
    """A batch of decoded frames: stacked frame data plus per-frame pts and
    duration tensors, all sharing the same leading dimension."""

    data: Tensor
    pts_seconds: Tensor
    duration_seconds: Tensor

    def __iter__(self) -> Iterator[Union[Tensor, float]]:
        # Yield fields in declaration order so a batch unpacks like a tuple.
        yield from (getattr(self, f.name) for f in dataclasses.fields(self))


_ERROR_REPORTING_INSTRUCTIONS = """
This should never happen. Please report an issue following the steps in <TODO>.
"""
Expand Down Expand Up @@ -92,13 +103,14 @@ def _getitem_slice(self, key: slice) -> Tensor:
assert isinstance(key, slice)

start, stop, step = key.indices(len(self))
return core.get_frames_in_range(
frame_data, *_ = core.get_frames_in_range(
self._decoder,
stream_index=self._stream_index,
start=start,
stop=stop,
step=step,
)
return frame_data

def __getitem__(self, key: Union[int, slice]) -> Tensor:
if isinstance(key, int):
Expand All @@ -120,6 +132,26 @@ def get_frame_at(self, index: int) -> Frame:
)
return Frame(*frame)

def get_frames_at(self, start: int, stop: int, step: int = 1) -> FrameBatch:
    """Return the frames in the half-open range [start, stop), sampled every
    `step` frames, as a FrameBatch of stacked data/pts/duration tensors.

    Raises IndexError if start is outside [0, num_frames), if stop precedes
    start, or if step is not positive.
    """
    if start < 0 or start >= self._num_frames:
        raise IndexError(
            f"Start index {start} is out of bounds; must be in the range [0, {self._num_frames})."
        )
    if stop < start:
        raise IndexError(
            f"Stop index ({stop}) must not be less than the start index ({start})."
        )
    if step <= 0:
        raise IndexError(f"Step ({step}) must be greater than 0.")
    batch = core.get_frames_in_range(
        self._decoder,
        stream_index=self._stream_index,
        start=start,
        stop=stop,
        step=step,
    )
    return FrameBatch(*batch)

def get_frame_displayed_at(self, pts_seconds: float) -> Frame:
if not self._min_pts_seconds <= pts_seconds < self._max_pts_seconds:
raise IndexError(
Expand Down
Loading
Loading