From 9343787df759ec93fc74bf87923b8123d8cc5ec6 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 28 Jun 2024 08:29:21 -0700 Subject: [PATCH] [torchcodec] refactor test utils to be based around dataclasses (#57) Summary: Pull Request resolved: https://github.com/pytorch-labs/torchcodec/pull/57 Refactors our testing utilities to consider the actual reference file to be a first-class concept by making it an object, and then all operations we do are on that object. Some principles I was trying to keep: 1. There should be one clear definition of a reference media file, with all of its important parameters in one place with names that have obvious semantic meaning. 2. Operations that are conceptually connected to a reference media file should be methods on that object. 3. Tests should only have to use the object defined in 1. 4. Formalize the patterns we've already established in how we name reference files. 5. Make adding new reference files easy and obvious. Right now, we can only support reference tensors by timestamp with a generic approach. I think we should try to make that more an explicit pattern based on the pts value, but I'm not sure exactly how to do that right now. That will probably require changing some of our current reference file names. Also in the future, the reference file *generation* should probably use these definitions as well. That will ensure we keep everything consistent. Differential Revision: D59161329 --- test/decoders/simple_video_decoder_test.py | 144 +++++++++++++++------ test/decoders/test_metadata.py | 4 +- test/decoders/video_decoder_ops_test.py | 78 +++++------ test/test_utils.py | 69 ++++++---- 4 files changed, 186 insertions(+), 109 deletions(-) diff --git a/test/decoders/simple_video_decoder_test.py b/test/decoders/simple_video_decoder_test.py index 208e9dbb..f29628ef 100644 --- a/test/decoders/simple_video_decoder_test.py +++ b/test/decoders/simple_video_decoder_test.py @@ -3,25 +3,18 @@ from torchcodec.decoders import _core, SimpleVideoDecoder -from ..test_utils import ( - assert_equal, - EMPTY_REF_TENSOR, - get_reference_video_path, - get_reference_video_tensor, - load_tensor_from_file, - REF_DIMS, -) +from ..test_utils import assert_equal, NASA_REF_VIDEO class TestSimpleDecoder: @pytest.mark.parametrize("source_kind", ("path", "tensor", "bytes")) def test_create(self, source_kind): if source_kind == "path": - source = str(get_reference_video_path()) + source = str(NASA_REF_VIDEO.path) elif source_kind == "tensor": - source = get_reference_video_tensor() + source = NASA_REF_VIDEO.to_tensor() elif source_kind == "bytes": - path = str(get_reference_video_path()) + path = str(NASA_REF_VIDEO.path) with open(path, "rb") as f: source = f.read() else: @@ -42,12 +35,12 @@ def test_create_fails(self): decoder = SimpleVideoDecoder(123) # noqa def test_getitem_int(self): - decoder = SimpleVideoDecoder(str(get_reference_video_path())) + decoder = SimpleVideoDecoder(str(NASA_REF_VIDEO.path)) - ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") - ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt") - ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") - ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt") + ref_frame0 = NASA_REF_VIDEO.get_ref_tensor_token("frame000001") + ref_frame1 = NASA_REF_VIDEO.get_ref_tensor_token("frame000002") + ref_frame180 = NASA_REF_VIDEO.get_ref_tensor_token("time6.000000") + ref_frame_last = NASA_REF_VIDEO.get_ref_tensor_token("time12.979633") assert_equal(ref_frame0, decoder[0]) assert_equal(ref_frame1, decoder[1]) @@ -55,86 +48,153 @@ def test_getitem_int(self): assert_equal(ref_frame_last, decoder[-1]) def test_getitem_slice(self): - decoder = SimpleVideoDecoder(str(get_reference_video_path())) + decoder = SimpleVideoDecoder(str(NASA_REF_VIDEO.path)) ref_frames0_9 = [ - load_tensor_from_file(f"nasa_13013.mp4.frame{i + 1:06d}.pt") - for i in range(0, 9) + NASA_REF_VIDEO.get_ref_tensor_index(i + 1) for i in range(0, 9) ] # Ensure that the degenerate case of a range of size 1 works; note that we get # a tensor which CONTAINS a single frame, rather than a tensor that itself IS a # single frame. Hence we have to access the 0th element of the return tensor. slice_0 = decoder[0:1] - assert slice_0.shape == torch.Size([1, *REF_DIMS]) + assert slice_0.shape == torch.Size( + [ + 1, + NASA_REF_VIDEO.height, + NASA_REF_VIDEO.width, + NASA_REF_VIDEO.num_color_channels, + ] + ) assert_equal(ref_frames0_9[0], slice_0[0]) slice_4 = decoder[4:5] - assert slice_4.shape == torch.Size([1, *REF_DIMS]) + assert slice_4.shape == torch.Size( + [ + 1, + NASA_REF_VIDEO.height, + NASA_REF_VIDEO.width, + NASA_REF_VIDEO.num_color_channels, + ] + ) assert_equal(ref_frames0_9[4], slice_4[0]) slice_8 = decoder[8:9] - assert slice_8.shape == torch.Size([1, *REF_DIMS]) + assert slice_8.shape == torch.Size( + [ + 1, + NASA_REF_VIDEO.height, + NASA_REF_VIDEO.width, + NASA_REF_VIDEO.num_color_channels, + ] + ) assert_equal(ref_frames0_9[8], slice_8[0]) - ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") + ref_frame180 = NASA_REF_VIDEO.get_ref_tensor_token("time6.000000") slice_180 = decoder[180:181] - assert slice_180.shape == torch.Size([1, *REF_DIMS]) + assert slice_180.shape == torch.Size( + [ + 1, + NASA_REF_VIDEO.height, + NASA_REF_VIDEO.width, + NASA_REF_VIDEO.num_color_channels, + ] + ) assert_equal(ref_frame180, slice_180[0]) # contiguous ranges slice_frames0_9 = decoder[0:9] - assert slice_frames0_9.shape == torch.Size([9, *REF_DIMS]) + assert slice_frames0_9.shape == torch.Size( + [ + 9, + NASA_REF_VIDEO.height, + NASA_REF_VIDEO.width, + NASA_REF_VIDEO.num_color_channels, + ] + ) for ref_frame, slice_frame in zip(ref_frames0_9, slice_frames0_9): assert_equal(ref_frame, slice_frame) slice_frames4_8 = decoder[4:8] - assert slice_frames4_8.shape == torch.Size([4, *REF_DIMS]) + assert slice_frames4_8.shape == torch.Size( + [ + 4, + NASA_REF_VIDEO.height, + NASA_REF_VIDEO.width, + NASA_REF_VIDEO.num_color_channels, + ] + ) for ref_frame, slice_frame in zip(ref_frames0_9[4:8], slice_frames4_8): assert_equal(ref_frame, slice_frame) # ranges with a stride ref_frames15_35 = [ - load_tensor_from_file(f"nasa_13013.mp4.frame{i:06d}.pt") - for i in range(15, 36, 5) + NASA_REF_VIDEO.get_ref_tensor_index(i) for i in range(15, 36, 5) ] slice_frames15_35 = decoder[15:36:5] - assert slice_frames15_35.shape == torch.Size([5, *REF_DIMS]) + assert slice_frames15_35.shape == torch.Size( + [ + 5, + NASA_REF_VIDEO.height, + NASA_REF_VIDEO.width, + NASA_REF_VIDEO.num_color_channels, + ] + ) for ref_frame, slice_frame in zip(ref_frames15_35, slice_frames15_35): assert_equal(ref_frame, slice_frame) slice_frames0_9_2 = decoder[0:9:2] - assert slice_frames0_9_2.shape == torch.Size([5, *REF_DIMS]) + assert slice_frames0_9_2.shape == torch.Size( + [ + 5, + NASA_REF_VIDEO.height, + NASA_REF_VIDEO.width, + NASA_REF_VIDEO.num_color_channels, + ] + ) for ref_frame, slice_frame in zip(ref_frames0_9[0:0:2], slice_frames0_9_2): assert_equal(ref_frame, slice_frame) # negative numbers in the slice ref_frames386_389 = [ - load_tensor_from_file(f"nasa_13013.mp4.frame{i:06d}.pt") - for i in range(386, 390) + NASA_REF_VIDEO.get_ref_tensor_index(i) for i in range(386, 390) ] slice_frames386_389 = decoder[-4:] - assert slice_frames386_389.shape == torch.Size([4, *REF_DIMS]) + assert slice_frames386_389.shape == torch.Size( + [ + 4, + NASA_REF_VIDEO.height, + NASA_REF_VIDEO.width, + NASA_REF_VIDEO.num_color_channels, + ] + ) for ref_frame, slice_frame in zip(ref_frames386_389[-4:], slice_frames386_389): assert_equal(ref_frame, slice_frame) # an empty range is valid! empty_frame = decoder[5:5] - assert_equal(empty_frame, EMPTY_REF_TENSOR) + assert_equal(empty_frame, NASA_REF_VIDEO.empty_hwc_tensor) # slices that are out-of-range are also valid - they return an empty tensor also_empty = decoder[10000:] - assert_equal(also_empty, EMPTY_REF_TENSOR) + assert_equal(also_empty, NASA_REF_VIDEO.empty_hwc_tensor) # should be just a copy all_frames = decoder[:] - assert all_frames.shape == torch.Size([len(decoder), *REF_DIMS]) + assert all_frames.shape == torch.Size( + [ + len(decoder), + NASA_REF_VIDEO.height, + NASA_REF_VIDEO.width, + NASA_REF_VIDEO.num_color_channels, + ] + ) for sliced, ref in zip(all_frames, decoder): assert_equal(sliced, ref) def test_getitem_fails(self): - decoder = SimpleVideoDecoder(str(get_reference_video_path())) + decoder = SimpleVideoDecoder(str(NASA_REF_VIDEO.path)) with pytest.raises(IndexError, match="out of bounds"): frame = decoder[1000] # noqa @@ -146,12 +206,12 @@ def test_getitem_fails(self): frame = decoder["0"] # noqa def test_next(self): - decoder = SimpleVideoDecoder(str(get_reference_video_path())) + decoder = SimpleVideoDecoder(str(NASA_REF_VIDEO.path)) - ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") - ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt") - ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") - ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt") + ref_frame0 = NASA_REF_VIDEO.get_ref_tensor_token("frame000001") + ref_frame1 = NASA_REF_VIDEO.get_ref_tensor_token("frame000002") + ref_frame180 = NASA_REF_VIDEO.get_ref_tensor_token("time6.000000") + ref_frame_last = NASA_REF_VIDEO.get_ref_tensor_token("time12.979633") for i, frame in enumerate(decoder): if i == 0: diff --git a/test/decoders/test_metadata.py b/test/decoders/test_metadata.py index 321dd640..469ad155 100644 --- a/test/decoders/test_metadata.py +++ b/test/decoders/test_metadata.py @@ -6,11 +6,11 @@ StreamMetadata, ) -from ..test_utils import get_reference_video_path +from ..test_utils import NASA_REF_VIDEO def test_get_video_metadata(): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_REF_VIDEO.path)) metadata = get_video_metadata(decoder) assert len(metadata.streams) == 6 assert metadata.best_video_stream_index == 3 diff --git a/test/decoders/video_decoder_ops_test.py b/test/decoders/video_decoder_ops_test.py index 0a2c402e..3be97518 100644 --- a/test/decoders/video_decoder_ops_test.py +++ b/test/decoders/video_decoder_ops_test.py @@ -24,20 +24,14 @@ seek_to_pts, ) -from ..test_utils import ( - assert_equal, - EMPTY_REF_TENSOR, - get_reference_audio_path, - get_reference_video_path, - load_tensor_from_file, -) +from ..test_utils import assert_equal, NASA_REF_AUDIO, NASA_REF_VIDEO torch._dynamo.config.capture_dynamic_output_shape_ops = True class ReferenceDecoder: def __init__(self): - self.decoder: torch.Tensor = create_from_file(str(get_reference_video_path())) + self.decoder: torch.Tensor = create_from_file(str(NASA_REF_VIDEO.path)) add_video_stream(self.decoder) def get_next_frame(self) -> torch.Tensor: @@ -52,26 +46,26 @@ def seek(self, pts: float): # TODO: Some of these tests could probably be unified and parametrized? class TestOps: def test_seek_and_next(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_REF_VIDEO.path)) add_video_stream(decoder) frame1 = get_next_frame(decoder) - reference_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") + reference_frame1 = NASA_REF_VIDEO.get_ref_tensor_token("frame000001") assert_equal(frame1, reference_frame1) - reference_frame2 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt") + reference_frame2 = NASA_REF_VIDEO.get_ref_tensor_token("frame000002") img2 = get_next_frame(decoder) assert_equal(img2, reference_frame2) seek_to_pts(decoder, 6.0) frame_time6 = get_next_frame(decoder) - reference_frame_time6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") + reference_frame_time6 = NASA_REF_VIDEO.get_ref_tensor_token("time6.000000") assert_equal(frame_time6, reference_frame_time6) def test_get_frame_at_pts(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_REF_VIDEO.path)) add_video_stream(decoder) # This frame has pts=6.006 and duration=0.033367, so it should be visible # at timestamps in the range [6.006, 6.039367) (not including the last timestamp). frame6 = get_frame_at_pts(decoder, 6.006) - reference_frame6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") + reference_frame6 = NASA_REF_VIDEO.get_ref_tensor_token("time6.000000") assert_equal(frame6, reference_frame6) frame6 = get_frame_at_pts(decoder, 6.02) assert_equal(frame6, reference_frame6) @@ -85,37 +79,36 @@ def test_get_frame_at_pts(self): assert_equal(next_frame, reference_frame6) def test_get_frame_at_index(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_REF_VIDEO.path)) add_video_stream(decoder) frame1 = get_frame_at_index(decoder, stream_index=3, frame_index=0) - reference_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") + reference_frame1 = NASA_REF_VIDEO.get_ref_tensor_token("frame000001") assert_equal(frame1, reference_frame1) # The frame that is displayed at 6 seconds is frame 180 from a 0-based index. frame6 = get_frame_at_index(decoder, stream_index=3, frame_index=180) - reference_frame6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") + reference_frame6 = NASA_REF_VIDEO.get_ref_tensor_token("time6.000000") assert_equal(frame6, reference_frame6) def test_get_frames_at_indices(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_REF_VIDEO.path)) add_video_stream(decoder) frames1and6 = get_frames_at_indices( decoder, stream_index=3, frame_indices=[0, 180] ) - reference_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") - reference_frame6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") + reference_frame1 = NASA_REF_VIDEO.get_ref_tensor_token("frame000001") + reference_frame6 = NASA_REF_VIDEO.get_ref_tensor_token("time6.000000") assert_equal(frames1and6[0], reference_frame1) assert_equal(frames1and6[1], reference_frame6) def test_get_frames_in_range(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_REF_VIDEO.path)) add_video_stream(decoder) ref_frames0_9 = [ - load_tensor_from_file(f"nasa_13013.mp4.frame{i + 1:06d}.pt") - for i in range(0, 9) + NASA_REF_VIDEO.get_ref_tensor_index(i + 1) for i in range(0, 9) ] - ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") - ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt") + ref_frame180 = NASA_REF_VIDEO.get_ref_tensor_token("time6.000000") + ref_frame_last = NASA_REF_VIDEO.get_ref_tensor_token("time12.979633") # ensure that the degenerate case of a range of size 1 works bulk_frame0 = get_frames_in_range(decoder, stream_index=3, start=0, stop=1) @@ -145,8 +138,7 @@ def test_get_frames_in_range(self): # ranges with a stride ref_frames15_35 = [ - load_tensor_from_file(f"nasa_13013.mp4.frame{i:06d}.pt") - for i in range(15, 36, 5) + NASA_REF_VIDEO.get_ref_tensor_index(i) for i in range(15, 36, 5) ] bulk_frames15_35 = get_frames_in_range( decoder, stream_index=3, start=15, stop=36, step=5 @@ -162,20 +154,20 @@ def test_get_frames_in_range(self): # an empty range is valid! empty_frame = get_frames_in_range(decoder, stream_index=3, start=5, stop=5) - assert_equal(empty_frame, EMPTY_REF_TENSOR) + assert_equal(empty_frame, NASA_REF_VIDEO.empty_hwc_tensor) def test_throws_exception_at_eof(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_REF_VIDEO.path)) add_video_stream(decoder) seek_to_pts(decoder, 12.979633) last_frame = get_next_frame(decoder) - reference_last_frame = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt") + reference_last_frame = NASA_REF_VIDEO.get_ref_tensor_token("time12.979633") assert_equal(last_frame, reference_last_frame) with pytest.raises(RuntimeError, match="End of file"): get_next_frame(decoder) def test_throws_exception_if_seek_too_far(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_REF_VIDEO.path)) add_video_stream(decoder) # pts=12.979633 is the last frame in the video. seek_to_pts(decoder, 12.979633 + 1.0e-4) @@ -196,10 +188,10 @@ def get_frame1_and_frame_time6(decoder): # NB: create needs to happen outside the torch.compile region, # for now. Otherwise torch.compile constant-props it. - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_REF_VIDEO.path)) frame1, frame_time6 = get_frame1_and_frame_time6(decoder) - reference_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") - reference_frame_time6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") + reference_frame1 = NASA_REF_VIDEO.get_ref_tensor_token("frame000001") + reference_frame_time6 = NASA_REF_VIDEO.get_ref_tensor_token("time6.000000") assert_equal(frame1, reference_frame1) assert_equal(frame_time6, reference_frame_time6) @@ -216,14 +208,14 @@ def class_based_get_frame1_and_frame_time6( decoder = ReferenceDecoder() frame1, frame_time6 = class_based_get_frame1_and_frame_time6(decoder) - reference_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") - reference_frame_time6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") + reference_frame1 = NASA_REF_VIDEO.get_ref_tensor_token("frame000001") + reference_frame_time6 = NASA_REF_VIDEO.get_ref_tensor_token("time6.000000") assert_equal(frame1, reference_frame1) assert_equal(frame_time6, reference_frame_time6) @pytest.mark.parametrize("create_from", ("file", "tensor", "bytes")) def test_create_decoder(self, create_from): - path = str(get_reference_video_path()) + path = str(NASA_REF_VIDEO.path) if create_from == "file": decoder = create_from_file(path) elif create_from == "tensor": @@ -237,14 +229,14 @@ def test_create_decoder(self, create_from): add_video_stream(decoder) frame1 = get_next_frame(decoder) - reference_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") + reference_frame1 = NASA_REF_VIDEO.get_ref_tensor_token("frame000001") assert_equal(frame1, reference_frame1) - reference_frame2 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt") + reference_frame2 = NASA_REF_VIDEO.get_ref_tensor_token("frame000002") img2 = get_next_frame(decoder) assert_equal(img2, reference_frame2) seek_to_pts(decoder, 6.0) frame_time6 = get_next_frame(decoder) - reference_frame_time6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") + reference_frame_time6 = NASA_REF_VIDEO.get_ref_tensor_token("time6.000000") assert_equal(frame_time6, reference_frame_time6) # TODO: Keeping the metadata tests below for now, but we should remove them @@ -255,7 +247,7 @@ def test_create_decoder(self, create_from): # always call scanFileAndUpdateMetadataAndIndex() when creating a decoder # from the core API. def test_video_get_json_metadata(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_REF_VIDEO.path)) metadata = get_json_metadata(decoder) metadata_dict = json.loads(metadata) @@ -271,7 +263,7 @@ def test_video_get_json_metadata(self): assert metadata_dict["bitRate"] == 324915.0 def test_video_get_json_metadata_with_stream(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_REF_VIDEO.path)) add_video_stream(decoder) metadata = get_json_metadata(decoder) metadata_dict = json.loads(metadata) @@ -281,7 +273,7 @@ def test_video_get_json_metadata_with_stream(self): assert metadata_dict["maxPtsSecondsFromScan"] == 13.013 def test_audio_get_json_metadata(self): - decoder = create_from_file(str(get_reference_audio_path())) + decoder = create_from_file(str(NASA_REF_AUDIO.path)) metadata = get_json_metadata(decoder) metadata_dict = json.loads(metadata) assert metadata_dict["durationSeconds"] == pytest.approx(13.25, abs=0.01) diff --git a/test/test_utils.py b/test/test_utils.py index 6073d843..f0f0ea92 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -2,25 +2,23 @@ import os import pathlib +from dataclasses import dataclass + import numpy as np import pytest import torch -# The dimensions and type have to match the frames in our reference video. -REF_DIMS = (270, 480, 3) -EMPTY_REF_TENSOR = torch.empty([0, *REF_DIMS], dtype=torch.uint8) + +def assert_equal(*args, **kwargs): + torch.testing.assert_close(*args, **kwargs, atol=0, rtol=0) def in_fbcode() -> bool: return os.environ.get("IN_FBCODE_TORCHCODEC") == "1" -def assert_equal(*args, **kwargs): - torch.testing.assert_close(*args, **kwargs, atol=0, rtol=0) - - -def get_video_path(filename: str) -> pathlib.Path: +def _get_file_path(filename: str) -> pathlib.Path: if in_fbcode(): resource = ( importlib.resources.files(__spec__.parent) @@ -33,25 +31,52 @@ def get_video_path(filename: str) -> pathlib.Path: return pathlib.Path(__file__).parent / "resources" / filename -def get_reference_video_path() -> pathlib.Path: - return get_video_path("nasa_13013.mp4") +def _load_tensor_from_file(filename: str) -> torch.Tensor: + file_path = _get_file_path(filename) + return torch.load(file_path, weights_only=True) -def get_reference_audio_path() -> pathlib.Path: - return get_video_path("nasa_13013.mp4.audio.mp3") +@pytest.fixture() +def reference_video_tensor() -> torch.Tensor: + return NASA_REF_VIDEO.to_tensor() -def load_tensor_from_file(filename: str) -> torch.Tensor: - file_path = get_video_path(filename) - return torch.load(file_path, weights_only=True) +@dataclass +class TestFile: + filename: str + @property + def path(self) -> pathlib.Path: + return _get_file_path(self.filename) -def get_reference_video_tensor() -> torch.Tensor: - arr = np.fromfile(get_reference_video_path(), dtype=np.uint8) - video_tensor = torch.from_numpy(arr) - return video_tensor + def to_tensor(self) -> torch.Tensor: + arr = np.fromfile(self.path, dtype=np.uint8) + return torch.from_numpy(arr) + def get_ref_tensor_index(self, idx: int) -> torch.Tensor: + return _load_tensor_from_file(f"{self.filename}.frame{idx:06d}.pt") -@pytest.fixture() -def reference_video_tensor() -> torch.Tensor: - return get_reference_video_tensor() + def get_ref_tensor_token(self, token: str) -> torch.Tensor: + return _load_tensor_from_file(f"{self.filename}.{token}.pt") + + +@dataclass +class TestVideo(TestFile): + height: int + width: int + num_color_channels: int + + @property + def empty_hwc_tensor(self) -> torch.Tensor: + return torch.empty( + [0, self.height, self.width, self.num_color_channels], dtype=torch.uint8 + ) + + +NASA_REF_VIDEO = TestVideo( + filename="nasa_13013.mp4", height=270, width=480, num_color_channels=3 +) + +# When we start actually decoding audio-only files, we'll probably need to define +# a TestAudio class with audio specific values. Until then, we only need a filename. +NASA_REF_AUDIO = TestFile(filename="nasa_13013.mp4.audio.mp3")