Skip to content

Commit

Permalink
[torchcodec] refactor test utils to be based around dataclasses (#57)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #57

Refactors our testing utilities to treat the actual reference file as a first-class concept by making it an object, so that all operations we perform on a reference file are methods on that object. Some principles I was trying to keep:

1. There should be one clear definition of a reference media file, with all of its important parameters in one place with names that have obvious semantic meaning.
2. Operations that are conceptually connected to a reference media file should be methods on that object.
3. Tests should only have to use the object defined in 1.
4. Formalize the patterns we've already established in how we name reference files.
5. Make adding new reference files easy and obvious.

Right now, we can only support reference tensors by timestamp with a generic approach. I think we should try to make that more of an explicit pattern based on the pts value, but I'm not sure exactly how to do that right now. That will probably require changing some of our current reference file names.

Also in the future, the reference file *generation* should probably use these definitions as well. That will ensure we keep everything consistent.

Differential Revision: D59161329
  • Loading branch information
scotts authored and facebook-github-bot committed Jun 28, 2024
1 parent faf73b0 commit 9343787
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 109 deletions.
144 changes: 102 additions & 42 deletions test/decoders/simple_video_decoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,18 @@

from torchcodec.decoders import _core, SimpleVideoDecoder

from ..test_utils import (
assert_equal,
EMPTY_REF_TENSOR,
get_reference_video_path,
get_reference_video_tensor,
load_tensor_from_file,
REF_DIMS,
)
from ..test_utils import assert_equal, NASA_REF_VIDEO


class TestSimpleDecoder:
@pytest.mark.parametrize("source_kind", ("path", "tensor", "bytes"))
def test_create(self, source_kind):
if source_kind == "path":
source = str(get_reference_video_path())
source = str(NASA_REF_VIDEO.path)
elif source_kind == "tensor":
source = get_reference_video_tensor()
source = NASA_REF_VIDEO.to_tensor()
elif source_kind == "bytes":
path = str(get_reference_video_path())
path = str(NASA_REF_VIDEO.path)
with open(path, "rb") as f:
source = f.read()
else:
Expand All @@ -42,99 +35,166 @@ def test_create_fails(self):
decoder = SimpleVideoDecoder(123) # noqa

def test_getitem_int(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))
decoder = SimpleVideoDecoder(str(NASA_REF_VIDEO.path))

ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt")
ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt")
ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt")
ref_frame0 = NASA_REF_VIDEO.get_ref_tensor_token("frame000001")
ref_frame1 = NASA_REF_VIDEO.get_ref_tensor_token("frame000002")
ref_frame180 = NASA_REF_VIDEO.get_ref_tensor_token("time6.000000")
ref_frame_last = NASA_REF_VIDEO.get_ref_tensor_token("time12.979633")

assert_equal(ref_frame0, decoder[0])
assert_equal(ref_frame1, decoder[1])
assert_equal(ref_frame180, decoder[180])
assert_equal(ref_frame_last, decoder[-1])

def test_getitem_slice(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))
decoder = SimpleVideoDecoder(str(NASA_REF_VIDEO.path))

ref_frames0_9 = [
load_tensor_from_file(f"nasa_13013.mp4.frame{i + 1:06d}.pt")
for i in range(0, 9)
NASA_REF_VIDEO.get_ref_tensor_index(i + 1) for i in range(0, 9)
]

# Ensure that the degenerate case of a range of size 1 works; note that we get
# a tensor which CONTAINS a single frame, rather than a tensor that itself IS a
# single frame. Hence we have to access the 0th element of the return tensor.
slice_0 = decoder[0:1]
assert slice_0.shape == torch.Size([1, *REF_DIMS])
assert slice_0.shape == torch.Size(
[
1,
NASA_REF_VIDEO.height,
NASA_REF_VIDEO.width,
NASA_REF_VIDEO.num_color_channels,
]
)
assert_equal(ref_frames0_9[0], slice_0[0])

slice_4 = decoder[4:5]
assert slice_4.shape == torch.Size([1, *REF_DIMS])
assert slice_4.shape == torch.Size(
[
1,
NASA_REF_VIDEO.height,
NASA_REF_VIDEO.width,
NASA_REF_VIDEO.num_color_channels,
]
)
assert_equal(ref_frames0_9[4], slice_4[0])

slice_8 = decoder[8:9]
assert slice_8.shape == torch.Size([1, *REF_DIMS])
assert slice_8.shape == torch.Size(
[
1,
NASA_REF_VIDEO.height,
NASA_REF_VIDEO.width,
NASA_REF_VIDEO.num_color_channels,
]
)
assert_equal(ref_frames0_9[8], slice_8[0])

ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
ref_frame180 = NASA_REF_VIDEO.get_ref_tensor_token("time6.000000")
slice_180 = decoder[180:181]
assert slice_180.shape == torch.Size([1, *REF_DIMS])
assert slice_180.shape == torch.Size(
[
1,
NASA_REF_VIDEO.height,
NASA_REF_VIDEO.width,
NASA_REF_VIDEO.num_color_channels,
]
)
assert_equal(ref_frame180, slice_180[0])

# contiguous ranges
slice_frames0_9 = decoder[0:9]
assert slice_frames0_9.shape == torch.Size([9, *REF_DIMS])
assert slice_frames0_9.shape == torch.Size(
[
9,
NASA_REF_VIDEO.height,
NASA_REF_VIDEO.width,
NASA_REF_VIDEO.num_color_channels,
]
)
for ref_frame, slice_frame in zip(ref_frames0_9, slice_frames0_9):
assert_equal(ref_frame, slice_frame)

slice_frames4_8 = decoder[4:8]
assert slice_frames4_8.shape == torch.Size([4, *REF_DIMS])
assert slice_frames4_8.shape == torch.Size(
[
4,
NASA_REF_VIDEO.height,
NASA_REF_VIDEO.width,
NASA_REF_VIDEO.num_color_channels,
]
)
for ref_frame, slice_frame in zip(ref_frames0_9[4:8], slice_frames4_8):
assert_equal(ref_frame, slice_frame)

# ranges with a stride
ref_frames15_35 = [
load_tensor_from_file(f"nasa_13013.mp4.frame{i:06d}.pt")
for i in range(15, 36, 5)
NASA_REF_VIDEO.get_ref_tensor_index(i) for i in range(15, 36, 5)
]
slice_frames15_35 = decoder[15:36:5]
assert slice_frames15_35.shape == torch.Size([5, *REF_DIMS])
assert slice_frames15_35.shape == torch.Size(
[
5,
NASA_REF_VIDEO.height,
NASA_REF_VIDEO.width,
NASA_REF_VIDEO.num_color_channels,
]
)
for ref_frame, slice_frame in zip(ref_frames15_35, slice_frames15_35):
assert_equal(ref_frame, slice_frame)

slice_frames0_9_2 = decoder[0:9:2]
assert slice_frames0_9_2.shape == torch.Size([5, *REF_DIMS])
assert slice_frames0_9_2.shape == torch.Size(
[
5,
NASA_REF_VIDEO.height,
NASA_REF_VIDEO.width,
NASA_REF_VIDEO.num_color_channels,
]
)
for ref_frame, slice_frame in zip(ref_frames0_9[0:0:2], slice_frames0_9_2):
assert_equal(ref_frame, slice_frame)

# negative numbers in the slice
ref_frames386_389 = [
load_tensor_from_file(f"nasa_13013.mp4.frame{i:06d}.pt")
for i in range(386, 390)
NASA_REF_VIDEO.get_ref_tensor_index(i) for i in range(386, 390)
]

slice_frames386_389 = decoder[-4:]
assert slice_frames386_389.shape == torch.Size([4, *REF_DIMS])
assert slice_frames386_389.shape == torch.Size(
[
4,
NASA_REF_VIDEO.height,
NASA_REF_VIDEO.width,
NASA_REF_VIDEO.num_color_channels,
]
)
for ref_frame, slice_frame in zip(ref_frames386_389[-4:], slice_frames386_389):
assert_equal(ref_frame, slice_frame)

# an empty range is valid!
empty_frame = decoder[5:5]
assert_equal(empty_frame, EMPTY_REF_TENSOR)
assert_equal(empty_frame, NASA_REF_VIDEO.empty_hwc_tensor)

# slices that are out-of-range are also valid - they return an empty tensor
also_empty = decoder[10000:]
assert_equal(also_empty, EMPTY_REF_TENSOR)
assert_equal(also_empty, NASA_REF_VIDEO.empty_hwc_tensor)

# should be just a copy
all_frames = decoder[:]
assert all_frames.shape == torch.Size([len(decoder), *REF_DIMS])
assert all_frames.shape == torch.Size(
[
len(decoder),
NASA_REF_VIDEO.height,
NASA_REF_VIDEO.width,
NASA_REF_VIDEO.num_color_channels,
]
)
for sliced, ref in zip(all_frames, decoder):
assert_equal(sliced, ref)

def test_getitem_fails(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))
decoder = SimpleVideoDecoder(str(NASA_REF_VIDEO.path))

with pytest.raises(IndexError, match="out of bounds"):
frame = decoder[1000] # noqa
Expand All @@ -146,12 +206,12 @@ def test_getitem_fails(self):
frame = decoder["0"] # noqa

def test_next(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))
decoder = SimpleVideoDecoder(str(NASA_REF_VIDEO.path))

ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt")
ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt")
ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt")
ref_frame0 = NASA_REF_VIDEO.get_ref_tensor_token("frame000001")
ref_frame1 = NASA_REF_VIDEO.get_ref_tensor_token("frame000002")
ref_frame180 = NASA_REF_VIDEO.get_ref_tensor_token("time6.000000")
ref_frame_last = NASA_REF_VIDEO.get_ref_tensor_token("time12.979633")

for i, frame in enumerate(decoder):
if i == 0:
Expand Down
4 changes: 2 additions & 2 deletions test/decoders/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
StreamMetadata,
)

from ..test_utils import get_reference_video_path
from ..test_utils import NASA_REF_VIDEO


def test_get_video_metadata():
decoder = create_from_file(str(get_reference_video_path()))
decoder = create_from_file(str(NASA_REF_VIDEO.path))
metadata = get_video_metadata(decoder)
assert len(metadata.streams) == 6
assert metadata.best_video_stream_index == 3
Expand Down
Loading

0 comments on commit 9343787

Please sign in to comment.