diff --git a/.circleci/config.yml b/.circleci/config.yml index 5f98b8b6e..1ccc2b23a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -16,7 +16,7 @@ cpu: &cpu environment: TERM: xterm machine: - image: default + image: ubuntu-1604:201903-01 resource_class: medium gpu: &gpu @@ -37,8 +37,8 @@ install_python: &install_python working_directory: ~/ command: | pyenv versions - pyenv install 3.6.2 - pyenv global 3.6.2 + pyenv install -f 3.7.0 + pyenv global 3.7.0 update_gcc7: &update_gcc7 - run: @@ -107,6 +107,17 @@ install_vissl_dep: &install_vissl_dep # Update this since classy_vision seems to need it. pip install --progress-bar off --upgrade iopath +# Must install python3-magic as per documentation: +# https://github.com/facebookresearch/AugLy#installation +install_augly: &install_augly + - run: + name: Install augly + working_directory: ~/vissl + command: | + pip install augly + sudo apt-get update + sudo apt-get install python3-magic + install_apex_gpu: &install_apex_gpu - run: name: Install Apex @@ -153,9 +164,10 @@ jobs: # Cache the vissl_venv directory that contains dependencies - restore_cache: keys: - - v5-cpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }} + - v6-cpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }} - <<: *install_vissl_dep + - <<: *install_augly - <<: *install_classy_vision - <<: *install_apex_cpu - <<: *pip_list @@ -163,7 +175,7 @@ jobs: - save_cache: paths: - ~/vissl_venv - key: v5-cpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }} + key: v6-cpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }} - <<: *install_vissl @@ -196,7 +208,7 @@ jobs: # Download and cache dependencies - restore_cache: keys: - - v5-gpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}-{{ checksum "docker/common/install_apex.sh" }} + - v6-gpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}-{{ checksum "docker/common/install_apex.sh" }} - <<: *install_vissl_dep - <<: *install_classy_vision @@ -211,7 +223,7 @@ jobs: - save_cache: paths: - ~/vissl_venv - key: v5-gpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}-{{ checksum "docker/common/install_apex.sh" }} + key: v6-gpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}-{{ checksum "docker/common/install_apex.sh" }} - <<: *install_vissl diff --git a/configs/config/test/transforms/augly_transforms_example.yaml b/configs/config/test/transforms/augly_transforms_example.yaml new file mode 100644 index 000000000..1fa86c762 --- /dev/null +++ b/configs/config/test/transforms/augly_transforms_example.yaml @@ -0,0 +1,25 @@ +# @package _global_ +config: + DATA: + TRAIN: + TRANSFORMS: + - name: ImgReplicatePil + num_times: 2 + - name: RandomResizedCrop + size: 224 + - name: RandomHorizontalFlip + p: 0.5 + - name: ImgPilColorDistortion + strength: 1.0 + - name: ImgPilGaussianBlur + p: 0.5 + radius_min: 0.1 + radius_max: 2.0 + - name: Blur + transform_type: "augly" + radius: 2.0 + p: 1.0 + - name: ToTensor + - name: Normalize + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] diff --git a/tests/test_transforms.py b/tests/test_transforms.py index e4f395f98..99a451e77 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -12,6 +12,11 @@ from vissl.data.ssl_transforms.img_pil_to_multicrop import ImgPilToMultiCrop from vissl.data.ssl_transforms.img_pil_to_tensor import ImgToTensor from vissl.data.ssl_transforms.mnist_img_pil_to_rgb_mode import MNISTImgPil2RGB +from vissl.utils.hydra_config import compose_hydra_configuration, convert_to_attrdict +from vissl.utils.test_utils import ( + in_temporary_directory, + run_integration_test, +) RAND_TENSOR = (torch.rand((224, 224, 3)) * 255).to(dtype=torch.uint8) @@ -77,3 +82,16 @@ def test_img_pil_to_multicrop(self): self.assertEqual((224, 224), crop.size) for crop in crops[2:]: self.assertEqual((96, 96), crop.size) + + def test_augly_transforms(self): + cfg = compose_hydra_configuration( + [ + "config=test/cpu_test/test_cpu_resnet_simclr.yaml", + "+config/test/transforms=augly_transforms_example", + ], + ) + args, config = convert_to_attrdict(cfg) + + with in_temporary_directory() as _: + # Test that the training runs with an augly transformation. + run_integration_test(config) diff --git a/vissl/data/ssl_transforms/ssl_transforms_wrapper.py b/vissl/data/ssl_transforms/ssl_transforms_wrapper.py index 3891d4f34..c615002fb 100644 --- a/vissl/data/ssl_transforms/ssl_transforms_wrapper.py +++ b/vissl/data/ssl_transforms/ssl_transforms_wrapper.py @@ -5,9 +5,15 @@ from typing import Any, Dict -from classy_vision.dataset.transforms import build_transform, register_transform +from classy_vision.dataset.transforms import ( + build_transform as build_classy_transform, + register_transform, +) from classy_vision.dataset.transforms.classy_transform import ClassyTransform +from vissl.utils.misc import is_augly_available +if is_augly_available(): + import augly.image as imaugs # NOQA # Below the transforms that require passing the labels as well. This is specifc # to SSL only where we automatically generate the labels for training. All other @@ -108,7 +114,7 @@ def __init__( """ self.indices = set(indices) self.name = args["name"] - self.transform = build_transform(args) + self.transform = self._build_transform(args) self.transform_receives_entire_batch = transform_receives_entire_batch self.transforms_with_labels = transform_types["TRANSFORMS_WITH_LABELS"] self.transforms_with_copies = transform_types["TRANSFORMS_WITH_COPIES"] @@ -117,6 +123,38 @@ def __init__( ] self.transforms_with_grouping = transform_types["TRANSFORMS_WITH_GROUPING"] + def _build_transform(self, args): + if "transform_type" not in args: + # Default to classy transform. + return build_classy_transform(args) + elif args["transform_type"] == "augly": + # Build augly transform. + return self._build_augly_transform(args) + else: + raise RuntimeError( + f"Transform type: { args.transform_type } is not supported" + ) + + def _build_augly_transform(self, args): + assert is_augly_available(), "Please pip install augly." + + # the name should be available in augly.image + # if users specify the transform name in snake case, + # we need to convert it to title case. + name = args["name"] + + if not hasattr(imaugs, name): + # Try converting name to title case. + name = name.title().replace("_", "") + + assert hasattr(imaugs, name), f"{name} isn't a registered tranform for augly." + + # Delete superfluous keys. + del args["name"] + del args["transform_type"] + + return getattr(imaugs, name)(**args) + def _is_transform_with_labels(self): """ _TRANSFORMS_WITH_LABELS = ["ImgRotatePil", "ShuffleImgPatches"] diff --git a/vissl/utils/hydra_config.py b/vissl/utils/hydra_config.py index 9922621c7..b46e80475 100644 --- a/vissl/utils/hydra_config.py +++ b/vissl/utils/hydra_config.py @@ -12,6 +12,7 @@ from omegaconf import DictConfig, OmegaConf from vissl.config import AttrDict, check_cfg_version from vissl.utils.io import save_file +from vissl.utils.misc import is_augly_available def save_attrdict_to_disk(cfg: AttrDict): @@ -462,6 +463,16 @@ def infer_losses_config(cfg): return cfg +def assert_transforms(cfg): + for transforms in [cfg.DATA.TRAIN.TRANSFORMS, cfg.DATA.TEST.TRANSFORMS]: + for transform in transforms: + if "transform_type" in transform: + assert transform["transform_type"] in [None, "augly"] + + if transform["transform_type"] == "augly": + assert is_augly_available(), "Please pip install augly." + + def infer_and_assert_hydra_config(cfg): """ Infer values of few parameters in the config file using the value of other config parameters @@ -480,6 +491,7 @@ def infer_and_assert_hydra_config(cfg): """ cfg = infer_losses_config(cfg) cfg = infer_learning_rate(cfg) + assert_transforms(cfg) # pass the seed to cfg["MODEL"] so that model init on different nodes can # use the same seed. diff --git a/vissl/utils/misc.py b/vissl/utils/misc.py index 6e988b86b..13cbe107f 100644 --- a/vissl/utils/misc.py +++ b/vissl/utils/misc.py @@ -7,6 +7,7 @@ import logging import os import random +import sys import tempfile import time from functools import partial, wraps @@ -80,6 +81,25 @@ def is_apex_available(): return apex_available +def is_augly_available(): + """ + Check if apex is available with simple python imports. + """ + try: + assert sys.version_info >= ( + 3, + 7, + 0, + ), "Please upgrade your python version to 3.7 or higher to use Augly." + + import augly.image # NOQA + + augly_available = True + except ImportError: + augly_available = False + return augly_available + + def find_free_tcp_port(): """ Find the free port that can be used for Rendezvous on the local machine.