Commit
Add tests
Ashwin Vaidya committed Jan 5, 2022
1 parent 9e8ead8 commit 07c1b3c
Showing 9 changed files with 321 additions and 67 deletions.
30 changes: 18 additions & 12 deletions anomalib/core/callbacks/compress.py
@@ -1,10 +1,26 @@
"""Callback that compresses a trained model by first exporting to .onnx format, and then converting to OpenVINO IR."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

import os
from typing import Tuple

import torch
from pytorch_lightning import Callback, LightningModule

from anomalib.utils.optimize import export_convert


class CompressModelCallback(Callback):
"""Callback to compresses a trained model.
@@ -30,14 +46,4 @@ def on_train_end(self, trainer, pl_module: LightningModule) -> None:  # pylint:
"""
os.makedirs(self.dirpath, exist_ok=True)
onnx_path = os.path.join(self.dirpath, self.filename + ".onnx")
-        height, width = self.input_size
-        torch.onnx.export(
-            pl_module.model,
-            torch.zeros((1, 3, height, width)).to(pl_module.device),
-            onnx_path,
-            opset_version=11,
-            input_names=["input"],
-            output_names=["output"],
-        )
-        optimize_command = "mo --input_model " + onnx_path + " --output_dir " + self.dirpath
-        os.system(optimize_command)
+        export_convert(model=pl_module.model, input_size=self.input_size, onnx_path=onnx_path, export_path=self.dirpath)
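For context, a minimal usage sketch of the refactored callback. The constructor parameters below (input_size, dirpath, filename) are assumptions inferred from the attributes read in on_train_end, not taken from this diff.

from pytorch_lightning import Trainer

from anomalib.core.callbacks.compress import CompressModelCallback

# Hypothetical constructor arguments, inferred from the attributes the callback reads.
compress_callback = CompressModelCallback(input_size=(256, 256), dirpath="./compressed", filename="model")
trainer = Trainer(callbacks=[compress_callback], max_epochs=1)
# trainer.fit(model=model, datamodule=datamodule)
# on_train_end then writes ./compressed/model.onnx and the OpenVINO IR produced by mo next to it.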
52 changes: 52 additions & 0 deletions anomalib/utils/optimize.py
@@ -0,0 +1,52 @@
"""Utilities for optimization and OpenVINO conversion."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.


import os
from pathlib import Path
from typing import List, Tuple, Union

import pytorch_lightning as pl
import torch

from anomalib.core.model.anomaly_module import AnomalyModule


def export_convert(
model: Union[pl.LightningModule, AnomalyModule],
input_size: Union[List[int], Tuple[int, int]],
onnx_path: Union[str, Path],
export_path: Union[str, Path],
):
"""Export the model to onnx format and convert to OpenVINO IR.
Args:
model (Union[pl.LightningModule, AnomalyModule]): Model to convert.
input_size (Union[List[int], Tuple[int, int]]): Image size used as the input for onnx converter.
onnx_path (Union[str, Path]): Path to output onnx model.
export_path (Union[str, Path]): Path to exported OpenVINO IR.
"""
height, width = input_size
    torch.onnx.export(
        model,
        torch.zeros((1, 3, height, width)).to(model.device),
        onnx_path,
        opset_version=11,
        input_names=["input"],
        output_names=["output"],
    )
    optimize_command = "mo --input_model " + str(onnx_path) + " --output_dir " + str(export_path)
    os.system(optimize_command)

samet-akcay (Contributor) commented on Jan 5, 2022, on the torch.onnx.export call above:
I was wondering if we could use PyTorch Lightning's to_onnx method instead?
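A minimal sketch of the reviewer's suggestion, assuming model is a pytorch_lightning.LightningModule (AnomalyModule subclasses it). LightningModule.to_onnx forwards extra keyword arguments to torch.onnx.export, so the explicit call above could reduce to:

import torch

# Hedged alternative, not part of this commit: to_onnx wraps torch.onnx.export.
height, width = input_size
model.to_onnx(
    onnx_path,
    input_sample=torch.zeros((1, 3, height, width)).to(model.device),
    opset_version=11,
    input_names=["input"],
    output_names=["output"],
)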
15 changes: 15 additions & 0 deletions anomalib/utils/sweep/helpers/__init__.py
@@ -0,0 +1,15 @@
"""Helpers for benchmarking and hyperparameter optimization."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
anomalib/utils/sweep/helpers/inference.py
@@ -16,7 +16,7 @@

import time
from pathlib import Path
-from typing import Dict, Iterable, List, Union
+from typing import Dict, Iterable, List, Tuple, Union

import numpy as np
from omegaconf import DictConfig, ListConfig
@@ -55,6 +55,41 @@ def __call__(self) -> Iterable[np.ndarray]:
yield self.image
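Only the tail of MockImageLoader is visible in this hunk. A hedged consumption sketch, with the constructor arguments copied from the tests added later in this commit:

from anomalib.utils.sweep.helpers.inference import MockImageLoader

loader = MockImageLoader((256, 256), total_count=1)  # arguments as used in the tests below
for image in loader():  # __call__ yields synthetic numpy images, so no dataset I/O is required
    print(image.shape)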


def get_meta_data(model: AnomalyModule, input_size: Tuple[int, int]) -> Dict:
"""Get meta data for inference.
Args:
model (AnomalyModule): Trained model from which the metadata is extracted.
input_size (Tuple[int, int]): Input size used to resize the pixel level mean and std.
Returns:
(Dict): Metadata as dictionary.
"""
meta_data = {
"image_threshold": model.image_threshold.value.cpu().numpy(),
"pixel_threshold": model.pixel_threshold.value.cpu().numpy(),
"stats": {},
}

image_mean = model.training_distribution.image_mean.cpu().numpy()
if image_mean.size > 0:
meta_data["stats"]["image_mean"] = image_mean

image_std = model.training_distribution.image_std.cpu().numpy()
if image_std.size > 0:
meta_data["stats"]["image_std"] = image_std

pixel_mean = model.training_distribution.pixel_mean.cpu().numpy()
if pixel_mean.size > 0:
meta_data["stats"]["pixel_mean"] = pixel_mean.reshape(input_size)

pixel_std = model.training_distribution.pixel_std.cpu().numpy()
if pixel_std.size > 0:
meta_data["stats"]["pixel_std"] = pixel_std.reshape(input_size)

return meta_data
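A hedged usage sketch for get_meta_data, assuming a trained AnomalyModule whose training_distribution buffers have been populated. The expected shape follows from the reshape calls above.

meta_data = get_meta_data(model, input_size=(256, 256))  # `model` is assumed to be trained
print(meta_data["image_threshold"])  # scalar ndarray
if "pixel_mean" in meta_data["stats"]:
    # pixel-level stats are stored flattened and restored to the image shape here
    print(meta_data["stats"]["pixel_mean"].shape)  # (256, 256)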


def get_torch_throughput(
config: Union[DictConfig, ListConfig], model: AnomalyModule, test_dataset: DataLoader, meta_data: Dict
) -> float:
tests/core/model/__init__.py
@@ -1,4 +1,4 @@
"""Helpers for benchmarking."""
"""Tests for models and functions in core/model."""

# Copyright (C) 2020 Intel Corporation
#
124 changes: 124 additions & 0 deletions tests/core/model/test_inferencer.py
@@ -0,0 +1,124 @@
"""Tests for Torch and OpenVINO inferencers."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Union

import pytest
import torch
from omegaconf import DictConfig, ListConfig
from pytorch_lightning import Trainer

from anomalib.config import get_configurable_parameters
from anomalib.core.model.inference import OpenVINOInferencer, TorchInferencer
from anomalib.data import get_datamodule
from anomalib.models import get_model
from anomalib.utils.optimize import export_convert
from anomalib.utils.sweep.helpers.inference import MockImageLoader, get_meta_data
from tests.helpers.dataset import TestDataset, get_dataset_path


def get_model_config(
model_name: str, project_path: str, dataset_path: str, category: str
) -> Union[DictConfig, ListConfig]:
model_config = get_configurable_parameters(model_name=model_name)
model_config.project.path = project_path
model_config.dataset.path = dataset_path
model_config.dataset.category = category
model_config.trainer.max_epochs = 1
return model_config


class TestInferencers:
@pytest.mark.parametrize(
"model_name",
[
"padim",
"stfpm",
"patchcore",
],
)
@TestDataset(num_train=20, num_test=1, path=get_dataset_path(), use_mvtec=False)
def test_torch_inference(self, model_name: str, category: str = "shapes", path: str = "./datasets/MVTec"):
"""Tests Torch inference.
Model is not trained as this checks that the inferencers are working.
Args:
model_name (str): Name of the model
"""
with TemporaryDirectory() as project_path:
model_config = get_model_config(
model_name=model_name, dataset_path=path, category=category, project_path=project_path
)

model = get_model(model_config)
trainer = Trainer(logger=False, **model_config.trainer)
datamodule = get_datamodule(model_config)

trainer.fit(model=model, datamodule=datamodule)

model.eval()

# Test torch inferencer
torch_inferencer = TorchInferencer(model_config, model)
torch_dataloader = MockImageLoader(model_config.dataset.image_size, total_count=1)
meta_data = get_meta_data(model, model_config.dataset.image_size)
with torch.no_grad():
for image in torch_dataloader():
torch_inferencer.predict(image, superimpose=False, meta_data=meta_data)

@pytest.mark.parametrize(
"model_name",
[
"padim",
"stfpm",
],
)
@TestDataset(num_train=20, num_test=1, path=get_dataset_path(), use_mvtec=False)
def test_openvino_inference(self, model_name: str, category: str = "shapes", path: str = "./datasets/MVTec"):
"""Tests OpenVINO inference.
Model is not trained as this checks that the inferencers are working.
Args:
model_name (str): Name of the model
"""
with TemporaryDirectory() as project_path:
model_config = get_model_config(
model_name=model_name, dataset_path=path, category=category, project_path=project_path
)
export_path = Path(project_path)

model = get_model(model_config)
trainer = Trainer(logger=False, **model_config.trainer)
datamodule = get_datamodule(model_config)
trainer.fit(model=model, datamodule=datamodule)

export_convert(
model=model,
input_size=model_config.dataset.image_size,
onnx_path=export_path / "model.onnx",
export_path=export_path,
)

# Test OpenVINO inferencer
openvino_inferencer = OpenVINOInferencer(model_config, export_path / "model.xml")
openvino_dataloader = MockImageLoader(model_config.dataset.image_size, total_count=1)
meta_data = get_meta_data(model, model_config.dataset.image_size)
for image in openvino_dataloader():
openvino_inferencer.predict(image, superimpose=False, meta_data=meta_data)
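For reference, OpenVINO's Model Optimizer names its outputs after the input model's basename, so the export_convert call above leaves three files in export_path, which is why the inferencer loads export_path / "model.xml":

# Files produced inside export_path by export_convert:
#   model.onnx - ONNX graph written by torch.onnx.export
#   model.xml  - OpenVINO IR topology written by mo
#   model.bin  - OpenVINO IR weights written by mo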
66 changes: 66 additions & 0 deletions tests/utils/test_sweep_config.py
@@ -0,0 +1,66 @@
"""Tests for benchmarking configuration utils."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from omegaconf import DictConfig

from anomalib.utils.sweep.config import (
flatten_sweep_params,
get_run_config,
set_in_nested_config,
)


class TestSweepConfig:
def test_flatten_params(self):
# simulate grid search config
dummy_config = DictConfig(
{"parent1": {"child1": ["a", "b", "c"], "child2": [1, 2, 3]}, "parent2": ["model1", "model2"]}
)
dummy_config = flatten_sweep_params(dummy_config)
assert dummy_config == {
"parent1.child1": ["a", "b", "c"],
"parent1.child2": [1, 2, 3],
"parent2": ["model1", "model2"],
}

def test_get_run_config(self):
# simulate model config
model_config = DictConfig(
{
"parent1": {
"child1": "e",
"child2": 4,
},
"parent3": False,
}
)
# simulate grid search config
dummy_config = DictConfig({"parent1": {"child1": ["a"], "child2": [1, 2]}, "parent2": ["model1"]})

config_iterator = get_run_config(dummy_config)
# First iteration
run_config = next(config_iterator)
assert run_config == {"parent1.child1": "a", "parent1.child2": 1, "parent2": "model1"}
for param in run_config.keys():
set_in_nested_config(model_config, param.split("."), run_config[param])
assert model_config == {"parent1": {"child1": "a", "child2": 1}, "parent3": False, "parent2": "model1"}

# Second iteration
run_config = next(config_iterator)
assert run_config == {"parent1.child1": "a", "parent1.child2": 2, "parent2": "model1"}
for param in run_config.keys():
set_in_nested_config(model_config, param.split("."), run_config[param])
assert model_config == {"parent1": {"child1": "a", "child2": 2}, "parent3": False, "parent2": "model1"}
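Taken together, these helpers compose into a small grid-search driver. A hedged sketch, assuming sweep_config and model_config are DictConfigs shaped like the ones above; run_experiment is a hypothetical stand-in for the actual training or benchmarking entry point.

from anomalib.utils.sweep.config import get_run_config, set_in_nested_config

for run_config in get_run_config(sweep_config):  # one flat parameter dict per grid point
    for param, value in run_config.items():
        set_in_nested_config(model_config, param.split("."), value)
    # run_experiment(model_config)  # hypothetical: evaluate one fully-resolved config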