Skip to content

Commit

Permalink
🐞 Fix inference for draem (#470)
Browse files Browse the repository at this point in the history
* fix lightning and openvino inference for draem

* add draem to inference tests
  • Loading branch information
djdameln committed Aug 2, 2022
1 parent 496786b commit 92a4b95
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 15 deletions.
6 changes: 4 additions & 2 deletions anomalib/deploy/inferencers/openvino_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,11 @@ def pre_process(self, image: np.ndarray) -> np.ndarray:
Returns:
np.ndarray: pre-processed image.
"""
config = self.config.transform if "transform" in self.config.keys() else None
transform_config = (
self.config.dataset.transform_config.val if "transform_config" in self.config.dataset.keys() else None
)
image_size = tuple(self.config.dataset.image_size)
pre_processor = PreProcessor(config, image_size)
pre_processor = PreProcessor(transform_config, image_size)
processed_image = pre_processor(image=image)["image"]

if len(processed_image.shape) == 3:
Expand Down
8 changes: 5 additions & 3 deletions anomalib/deploy/inferencers/torch_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,11 @@ def pre_process(self, image: np.ndarray) -> Tensor:
Returns:
Tensor: pre-processed image.
"""
config = self.config.transform if "transform" in self.config.keys() else None
transform_config = (
self.config.dataset.transform_config.val if "transform_config" in self.config.dataset.keys() else None
)
image_size = tuple(self.config.dataset.image_size)
pre_processor = PreProcessor(config, image_size)
pre_processor = PreProcessor(transform_config, image_size)
processed_image = pre_processor(image=image)["image"]

if len(processed_image) == 3:
Expand Down Expand Up @@ -143,7 +145,7 @@ def post_process(self, predictions: Tensor, meta_data: Optional[Union[Dict, Dict
meta_data = self.meta_data

if isinstance(predictions, Tensor):
anomaly_map = predictions.cpu().numpy()
anomaly_map = predictions.detach().cpu().numpy()
pred_score = anomaly_map.reshape(-1).max()
else:
# NOTE: Patchcore `forward`` returns heatmap and score.
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/draem/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def validation_step(self, batch, _):
Dictionary to which predicted anomaly maps have been added.
"""
prediction = self.model(batch["image"])
batch["anomaly_maps"] = prediction[:, 1, :, :]
batch["anomaly_maps"] = prediction
return batch


Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/draem/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def forward(self, batch: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
prediction = self.discriminative_subnetwork(concatenated_inputs)
if self.training:
return reconstruction, prediction
return torch.softmax(prediction, dim=1)
return torch.softmax(prediction, dim=1)[:, 1, ...]


class ReconstructiveSubNetwork(nn.Module):
Expand Down
20 changes: 13 additions & 7 deletions tests/pre_merge/deploy/test_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,18 @@ def get_model_config(
class TestInferencers:
@pytest.mark.parametrize(
"model_name",
["cflow", "dfm", "dfkde", "fastflow", "ganomaly", "padim", "patchcore", "reverse_distillation", "stfpm"],
[
"cflow",
"dfm",
"dfkde",
"draem",
"fastflow",
"ganomaly",
"padim",
"patchcore",
"reverse_distillation",
"stfpm",
],
)
@TestDataset(num_train=20, num_test=1, path=get_dataset_path(), use_mvtec=False)
def test_torch_inference(self, model_name: str, category: str = "shapes", path: str = "./datasets/MVTec"):
Expand Down Expand Up @@ -81,12 +92,7 @@ def test_torch_inference(self, model_name: str, category: str = "shapes", path:

@pytest.mark.parametrize(
"model_name",
[
"dfm",
"ganomaly",
"padim",
"stfpm",
],
["dfm", "draem", "ganomaly", "padim", "stfpm"],
)
@TestDataset(num_train=20, num_test=1, path=get_dataset_path(), use_mvtec=False)
def test_openvino_inference(self, model_name: str, category: str = "shapes", path: str = "./datasets/MVTec"):
Expand Down
5 changes: 4 additions & 1 deletion tools/inference/lightning_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def infer():

trainer = Trainer(callbacks=callbacks, **config.trainer)

dataset = InferenceDataset(args.input, image_size=tuple(config.dataset.image_size))
transform_config = config.dataset.transform_config.val if "transform_config" in config.dataset.keys() else None
dataset = InferenceDataset(
args.input, image_size=tuple(config.dataset.image_size), transform_config=transform_config
)
dataloader = DataLoader(dataset)
trainer.predict(model=model, dataloaders=[dataloader])

Expand Down

0 comments on commit 92a4b95

Please sign in to comment.