[torchcodec] fix simple decoder iteration (#59)
Summary:
Pull Request resolved: #59

The previous iterable and iterator implementation had a bug, demonstrated by the test modified in this diff. The problem:

1. We were using the `SimpleVideoDecoder` as its own iterator object by directly implementing `__iter__()` and `__next__()`.
2. In `__next__()`, we were calling a core library function, `get_next_frame()`, which returns the next frame to be decoded and advances the internal state of the C++ decoder.
3. But we were not *initializing* the iterator: `__iter__()` simply returned `self` and never moved the decoder back to the first frame.

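Because of the points above, `for`-based iteration only worked as expected on a freshly created `SimpleVideoDecoder` object. Once a frame had been accessed by index, or a previous loop had already run, a new loop started from wherever the decoder's internal cursor happened to be rather than from the first frame.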
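
A minimal sketch of this failure mode, for illustration only. `FakeDecoder` and `BuggyIterable` are hypothetical stand-ins, not torchcodec code, and the assumption that an indexed access leaves the cursor just past the accessed frame is ours:

```python
class FakeDecoder:
    """Hypothetical stand-in for the stateful C++ decoder: it only tracks a cursor."""

    def __init__(self, frames):
        self._frames = list(frames)
        self._cursor = 0

    def seek_and_get(self, index):
        # Assumption: random access leaves the cursor just past the accessed frame.
        self._cursor = index + 1
        return self._frames[index]

    def get_next_frame(self):
        # Mirrors the described behavior: return the next frame and advance state.
        if self._cursor >= len(self._frames):
            raise RuntimeError("no more frames")
        frame = self._frames[self._cursor]
        self._cursor += 1
        return frame


class BuggyIterable:
    """Acts as its own iterator and never resets the decoder in __iter__()."""

    def __init__(self, frames):
        self._decoder = FakeDecoder(frames)

    def __getitem__(self, index):
        return self._decoder.seek_and_get(index)

    def __iter__(self):
        return self  # no reset: iteration continues from the current cursor

    def __next__(self):
        try:
            return self._decoder.get_next_frame()
        except RuntimeError:
            raise StopIteration()


frames = ["f0", "f1", "f2", "f3"]

fresh = list(BuggyIterable(frames))  # freshly created: sees every frame

used = BuggyIterable(frames)
_ = used[1]                  # indexed access moves the cursor
after_access = list(used)    # iteration starts after frame 1
second_pass = list(used)     # cursor is already at the end
```

In this sketch `fresh` contains all four frames, `after_access` contains only the last two, and `second_pass` is empty, which is the same class of problem the modified test exercises by accessing `decoder[35]` before looping.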

The simplest fix is to remove the implementations of `__iter__()` and `__next__()`. Because it implements `__len__()` and `__getitem__()`, a `SimpleVideoDecoder` is a Python sequence, and Python sequences are automatically iterable: when a class defines no `__iter__()`, each `for` loop gets a fresh iterator that calls `__getitem__()` with indices 0, 1, 2, ... until it raises `IndexError`. See: https://docs.python.org/3/glossary.html#term-iterable
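
A small sketch of the sequence-based fallback, assuming a hypothetical `FrameSequence` class (not the real `SimpleVideoDecoder`) whose `__getitem__()` raises `IndexError` for out-of-range indices; negative indices and slices are left out to keep it short:

```python
class FrameSequence:
    """Hypothetical sequence: defines __len__() and __getitem__(), but no __iter__()."""

    def __init__(self, frames):
        self._frames = list(frames)

    def __len__(self):
        return len(self._frames)

    def __getitem__(self, index):
        if not isinstance(index, int):
            raise TypeError(f"Unsupported key type: {type(index)}")
        if not 0 <= index < len(self._frames):
            # IndexError is what ends the implicit iteration.
            raise IndexError(index)
        return self._frames[index]


frames = ["f0", "f1", "f2", "f3"]
seq = FrameSequence(frames)

_ = seq[2]               # random access before iterating
first_pass = list(seq)   # Python calls seq[0], seq[1], ... until IndexError
second_pass = list(seq)  # every loop starts again from index 0
```

Both passes yield every frame in order regardless of earlier accesses, because each loop builds a new iterator that starts at index 0. The trade-off is that every step goes through `__getitem__()` rather than a dedicated next-frame call.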

Differential Revision: D59309882
scotts authored and facebook-github-bot committed Jul 3, 2024
1 parent 7af94c7 commit b910603
Showing 2 changed files with 12 additions and 11 deletions.
11 changes: 0 additions & 11 deletions src/torchcodec/decoders/_simple_video_decoder.py
@@ -67,17 +67,6 @@ def __getitem__(self, key: Union[int, slice]) -> torch.Tensor:
             f"Unsupported key type: {type(key)}. Supported types are int and slice."
         )

-    def __iter__(self) -> "SimpleVideoDecoder":
-        return self
-
-    def __next__(self) -> torch.Tensor:
-        # TODO: We should distinguish between expected end-of-file and unexpected
-        # runtime error.
-        try:
-            return core.get_next_frame(self._decoder)
-        except RuntimeError:
-            raise StopIteration()
-

 def _get_and_validate_stream_metadata(decoder: torch.Tensor) -> core.StreamMetadata:
     video_metadata = core.get_video_metadata(decoder)
12 changes: 12 additions & 0 deletions test/decoders/simple_video_decoder_test.py
@@ -201,14 +201,26 @@ def test_next(self):

         ref_frame0 = NASA_VIDEO.get_tensor_by_index(0)
         ref_frame1 = NASA_VIDEO.get_tensor_by_index(1)
+        ref_frame9 = NASA_VIDEO.get_tensor_by_index(9)
+        ref_frame35 = NASA_VIDEO.get_tensor_by_index(35)
         ref_frame180 = NASA_VIDEO.get_tensor_by_name("time6.000000")
         ref_frame_last = NASA_VIDEO.get_tensor_by_name("time12.979633")

+        # Access an arbitrary frame to make sure that the later iteration
+        # still works as expected. The underlying C++ decoder object is
+        # actually stateful, and accessing a frame will move its internal
+        # cursor.
+        assert_tensor_equal(ref_frame35, decoder[35])
+
         for i, frame in enumerate(decoder):
             if i == 0:
                 assert_tensor_equal(ref_frame0, frame)
             elif i == 1:
                 assert_tensor_equal(ref_frame1, frame)
+            elif i == 9:
+                assert_tensor_equal(ref_frame9, frame)
+            elif i == 35:
+                assert_tensor_equal(ref_frame35, frame)
             elif i == 180:
                 assert_tensor_equal(ref_frame180, frame)
             elif i == 389:
