Skip to content

Commit

Permalink
[fbsync] [PoC] refactor transforms v2 tests (#7562)
Browse files Browse the repository at this point in the history
Reviewed By: vmoens

Differential Revision: D47186580

fbshipit-source-id: b5703d86d716c6eb804076ebb969fa23e20287b4

Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
  • Loading branch information
3 people authored and facebook-github-bot committed Jul 3, 2023
1 parent 5726885 commit c3d3914
Show file tree
Hide file tree
Showing 5 changed files with 743 additions and 210 deletions.
22 changes: 22 additions & 0 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import os
import pathlib
import random
import re
import shutil
import sys
import tempfile
import warnings
from collections import defaultdict
from subprocess import CalledProcessError, check_output, STDOUT
from typing import Callable, Sequence, Tuple, Union
Expand Down Expand Up @@ -880,3 +882,23 @@ def assert_run_python_script(source_code):
raise RuntimeError(f"script errored with output:\n{e.output.decode()}")
if out != b"":
raise AssertionError(out.decode())


@contextlib.contextmanager
def assert_no_warnings():
# The name `catch_warnings` is a misnomer as the context manager does **not** catch any warnings, but rather scopes
# the warning filters. All changes that are made to the filters while in this context, will be reset upon exit.
with warnings.catch_warnings():
warnings.simplefilter("error")
yield


@contextlib.contextmanager
def ignore_jit_no_profile_information_warning():
# Calling a scripted object often triggers a warning like
# `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
# with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
# them.
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=re.escape("operator() profile_node %"), category=UserWarning)
yield
24 changes: 0 additions & 24 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1711,8 +1711,6 @@ def test_antialias_warning():
tensor_video = torch.randint(0, 256, size=(2, 3, 10, 10), dtype=torch.uint8)

match = "The default value of the antialias parameter"
with pytest.warns(UserWarning, match=match):
transforms.Resize((20, 20))(tensor_img)
with pytest.warns(UserWarning, match=match):
transforms.RandomResizedCrop((20, 20))(tensor_img)
with pytest.warns(UserWarning, match=match):
Expand All @@ -1722,18 +1720,6 @@ def test_antialias_warning():
with pytest.warns(UserWarning, match=match):
transforms.RandomResize(10, 20)(tensor_img)

with pytest.warns(UserWarning, match=match):
transforms.functional.resize(tensor_img, (20, 20))
with pytest.warns(UserWarning, match=match):
transforms.functional.resize_image_tensor(tensor_img, (20, 20))

with pytest.warns(UserWarning, match=match):
transforms.functional.resize(tensor_video, (20, 20))
with pytest.warns(UserWarning, match=match):
transforms.functional.resize_video(tensor_video, (20, 20))

with pytest.warns(UserWarning, match=match):
datapoints.Image(tensor_img).resize((20, 20))
with pytest.warns(UserWarning, match=match):
datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20))

Expand All @@ -1744,27 +1730,17 @@ def test_antialias_warning():

with warnings.catch_warnings():
warnings.simplefilter("error")
transforms.Resize((20, 20))(pil_img)
transforms.RandomResizedCrop((20, 20))(pil_img)
transforms.ScaleJitter((20, 20))(pil_img)
transforms.RandomShortestSize((20, 20))(pil_img)
transforms.RandomResize(10, 20)(pil_img)
transforms.functional.resize(pil_img, (20, 20))

transforms.Resize((20, 20), antialias=True)(tensor_img)
transforms.RandomResizedCrop((20, 20), antialias=True)(tensor_img)
transforms.ScaleJitter((20, 20), antialias=True)(tensor_img)
transforms.RandomShortestSize((20, 20), antialias=True)(tensor_img)
transforms.RandomResize(10, 20, antialias=True)(tensor_img)

transforms.functional.resize(tensor_img, (20, 20), antialias=True)
transforms.functional.resize_image_tensor(tensor_img, (20, 20), antialias=True)
transforms.functional.resize(tensor_video, (20, 20), antialias=True)
transforms.functional.resize_video(tensor_video, (20, 20), antialias=True)

datapoints.Image(tensor_img).resize((20, 20), antialias=True)
datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
datapoints.Video(tensor_video).resize((20, 20), antialias=True)
datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)


Expand Down
Loading

0 comments on commit c3d3914

Please sign in to comment.