Skip to content

Commit

Permalink
[torchcodec] refactor test utils into its own library (#25)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #25

Refactors utility functions that were defined directly in test source files into a library that can be shared among tests. This diff:

1. Creates a Python test utility library that can be imported by Python tests.
2. Moves the test resources to the test top-level.
3. Defines a C++ target for those resources.

This is in preparation for adding new tests that will need this library.

Differential Revision: D58530481
  • Loading branch information
scotts authored and facebook-github-bot committed Jun 13, 2024
1 parent 067e18b commit cbbfe89
Show file tree
Hide file tree
Showing 14 changed files with 65 additions and 60 deletions.
3 changes: 1 addition & 2 deletions test/decoders/VideoDecoderOpsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ namespace facebook::torchcodec {

std::string getResourcePath(const std::string& filename) {
#ifdef FBCODE_BUILD
std::string filepath =
"pytorch/torchcodec/test/decoders/resources/" + filename;
std::string filepath = "pytorch/torchcodec/test/resources/" + filename;
filepath = build::getResourcePath(filepath).string();
#else
std::filesystem::path dirPath = std::filesystem::path(__FILE__);
Expand Down
3 changes: 1 addition & 2 deletions test/decoders/VideoDecoderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ namespace facebook::torchcodec {

std::string getResourcePath(const std::string& filename) {
#ifdef FBCODE_BUILD
std::string filepath =
"pytorch/torchcodec/test/decoders/resources/" + filename;
std::string filepath = "pytorch/torchcodec/test/resources/" + filename;
filepath = build::getResourcePath(filepath).string();
#else
std::filesystem::path dirPath = std::filesystem::path(__FILE__);
Expand Down
2 changes: 1 addition & 1 deletion test/decoders/manual_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torchcodec

decoder = torchcodec.decoders._core.create_from_file(
os.path.dirname(__file__) + "/resources/nasa_13013.mp4"
os.path.dirname(__file__) + "../resources/nasa_13013.mp4"
)
torchcodec.decoders._core.add_video_stream(decoder, stream_index=3)
frame = torchcodec.decoders._core.get_frame_at_index(
Expand Down
43 changes: 7 additions & 36 deletions test/decoders/video_decoder_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import pytest

import torch
import torchvision.transforms as transforms
from PIL import Image

from torchcodec.decoders._core import (
Expand All @@ -26,42 +25,14 @@
seek_to_pts,
)

torch._dynamo.config.capture_dynamic_output_shape_ops = True
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHCODEC") == "1"


# TODO: Eventually move that as a common test util
def assert_equal(*args, **kwargs):
torch.testing.assert_close(*args, **kwargs, atol=0, rtol=0)


# TODO: Eventually move that as a common test util
def get_video_path(filename: str) -> pathlib.Path:
if IN_FBCODE:
resource = (
importlib.resources.files(__package__)
.joinpath("resources")
.joinpath(filename)
)
with importlib.resources.as_file(resource) as path:
return path
else:
return pathlib.Path(__file__).parent / "resources" / filename


# TODO: make this a fixture or wrap with @functools.lru_cache to avoid
# re-computing?
def load_tensor_from_file(filename: str) -> torch.Tensor:
file_path = get_video_path(filename)
return torch.load(file_path)


def get_reference_video_path() -> pathlib.Path:
return get_video_path("nasa_13013.mp4")

from ..test_utils import (
assert_equal,
get_reference_audio_path,
get_reference_video_path,
load_tensor_from_file,
)

def get_reference_audio_path() -> pathlib.Path:
return get_video_path("nasa_13013.mp4.audio.mp3")
torch._dynamo.config.capture_dynamic_output_shape_ops = True


class ReferenceDecoder:
Expand Down
File renamed without changes.
File renamed without changes.
22 changes: 3 additions & 19 deletions test/samplers/video_clip_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,7 @@
VideoClipSampler,
)


# TODO: move this to a common util
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHCODEC") == "1"


# TODO: Eventually rely on common util for this
@pytest.fixture()
def nasa_13013() -> torch.Tensor:
if IN_FBCODE:
video_path = importlib.resources.path(__package__, "nasa_13013.mp4")
else:
video_path = (
Path(__file__).parent.parent / "decoders" / "resources" / "nasa_13013.mp4"
)
arr = np.fromfile(video_path, dtype=np.uint8)
video_tensor = torch.from_numpy(arr)
return video_tensor
from ..test_utils import assert_equal, nasa_13013 # noqa: F401; see nasa_13013 use


@pytest.mark.parametrize(
Expand All @@ -51,13 +35,13 @@ def nasa_13013() -> torch.Tensor:
),
],
)
def test_sampler(sampler_args, nasa_13013):
def test_sampler(sampler_args, nasa_13013): # noqa: F811; linter does not see this as a use
torch.manual_seed(0)
desired_width, desired_height = 320, 240
video_args = VideoArgs(desired_width=desired_width, desired_height=desired_height)
sampler = VideoClipSampler(video_args, sampler_args)
clips = sampler(nasa_13013)
assert len(clips) == sampler_args.clips_per_video
assert_equal(len(clips), sampler_args.clips_per_video)
clip = clips[0]
if isinstance(sampler_args, TimeBasedSamplerArgs):
# TODO FIXME: Looks like we have an API inconsistency.
Expand Down
52 changes: 52 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import importlib
import os
import pathlib

import numpy as np
import pytest

import torch


def in_fbcode() -> bool:
return os.environ.get("IN_FBCODE_TORCHCODEC") == "1"


IN_FBCODE = in_fbcode()


def assert_equal(*args, **kwargs):
torch.testing.assert_close(*args, **kwargs, atol=0, rtol=0)


def get_video_path(filename: str) -> pathlib.Path:
if IN_FBCODE:
resource = (
importlib.resources.files(__spec__.parent)
.joinpath("resources")
.joinpath(filename)
)
with importlib.resources.as_file(resource) as path:
return path
else:
return pathlib.Path(__file__).parent / "resources" / filename


def get_reference_video_path() -> pathlib.Path:
return get_video_path("nasa_13013.mp4")


def get_reference_audio_path() -> pathlib.Path:
return get_video_path("nasa_13013.mp4.audio.mp3")


def load_tensor_from_file(filename: str) -> torch.Tensor:
file_path = get_video_path(filename)
return torch.load(file_path)


@pytest.fixture()
def nasa_13013() -> torch.Tensor:
arr = np.fromfile(get_reference_video_path(), dtype=np.uint8)
video_tensor = torch.from_numpy(arr)
return video_tensor

0 comments on commit cbbfe89

Please sign in to comment.