Skip to content

Commit

Permalink
Fix/da/image size bug (#135)
Browse files Browse the repository at this point in the history
* fix image size bug and add test case

* fix padim inference

* docstrings

* use dummy dataset for image size tests

* detach before moving to cpu
  • Loading branch information
djdameln committed Mar 9, 2022
1 parent 0d23715 commit ee8807b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 4 deletions.
2 changes: 1 addition & 1 deletion anomalib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule
# TODO: Remove config values. IAAALD-211
root=config.dataset.path,
category=config.dataset.category,
image_size=(config.dataset.image_size[0], config.dataset.image_size[0]),
image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),
train_batch_size=config.dataset.train_batch_size,
test_batch_size=config.dataset.test_batch_size,
num_workers=config.dataset.num_workers,
Expand Down
2 changes: 1 addition & 1 deletion anomalib/deploy/inferencers/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def post_process(
anomaly_map, pred_score = self._normalize(anomaly_map, pred_score, meta_data)

if isinstance(anomaly_map, Tensor):
anomaly_map = anomaly_map.cpu().numpy()
anomaly_map = anomaly_map.detach().cpu().numpy()

if "image_shape" in meta_data and anomaly_map.shape != meta_data["image_shape"]:
anomaly_map = cv2.resize(anomaly_map, meta_data["image_shape"])
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/padim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(

n_features = DIMS[backbone]["reduced_dims"]
patches_dims = torch.tensor(input_size) / DIMS[backbone]["emb_scale"]
n_patches = patches_dims.prod().int().item()
n_patches = patches_dims.ceil().prod().int().item()
self.gaussian = MultiVariateGaussian(n_features, n_patches)

if apply_tiling:
Expand Down
31 changes: 30 additions & 1 deletion tests/pre_merge/datasets/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import numpy as np
import pytest

from anomalib.config import get_configurable_parameters, update_input_size_config
from anomalib.data import get_datamodule
from anomalib.data.mvtec import MVTecDataModule
from anomalib.pre_processing.transforms import Denormalize, ToNumpy
from tests.helpers.dataset import get_dataset_path
from tests.helpers.dataset import TestDataset, get_dataset_path


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -100,3 +102,30 @@ def test_one_channel_images(self, data_sample):
def test_representation(self):
"""Test ToNumpy() representation should return string `ToNumpy()`"""
assert str(ToNumpy()) == "ToNumpy()"


class TestConfigToDataModule:
"""Tests that check if the dataset parameters in the config achieve the desired effect."""

@pytest.mark.parametrize(
["input_size", "effective_image_size"],
[
(512, (512, 512)),
((245, 276), (245, 276)),
((263, 134), (263, 134)),
((267, 267), (267, 267)),
],
)
@TestDataset(num_train=20, num_test=10)
def test_image_size(self, input_size, effective_image_size, category="shapes", path=""):
"""Test if the image size parameter works as expected."""
model_name = "stfpm"
configurable_parameters = get_configurable_parameters(model_name)
configurable_parameters.dataset.path = path
configurable_parameters.dataset.category = category
configurable_parameters.dataset.image_size = input_size
configurable_parameters = update_input_size_config(configurable_parameters)

data_module = get_datamodule(configurable_parameters)
data_module.setup()
assert iter(data_module.train_dataloader()).__next__()["image"].shape[-2:] == effective_image_size

0 comments on commit ee8807b

Please sign in to comment.