Skip to content

Commit

Permalink
Directory streaming (#210)
Browse files Browse the repository at this point in the history
  • Loading branch information
ashwinvaidya17 committed Apr 11, 2022
1 parent 2c71f97 commit 2405892
Showing 1 changed file with 41 additions and 9 deletions.
50 changes: 41 additions & 9 deletions tools/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from argparse import ArgumentParser, Namespace
from importlib import import_module
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
Expand Down Expand Up @@ -65,9 +66,11 @@ def add_label(prediction: np.ndarray, scores: float, font: int = cv2.FONT_HERSHE
return prediction


def infer() -> None:
"""Perform inference on an input image."""
def stream() -> None:
"""Stream predictions.
Show/save the output if path is to an image. If the path is a directory, go over each image in the directory.
"""
# Get the command line arguments, and config from the config.yaml file.
# This config file is also used for training and contains all the relevant
# information regarding the data, model, train and inference details.
Expand All @@ -77,28 +80,51 @@ def infer() -> None:
# Get the inferencer. We use .ckpt extension for Torch models and (onnx, bin)
# for the openvino models.
extension = args.weight_path.suffix
inference: Inferencer
inferencer: Inferencer
if extension in (".ckpt"):
module = import_module("anomalib.deploy.inferencers.torch")
TorchInferencer = getattr(module, "TorchInferencer") # pylint: disable=invalid-name
inference = TorchInferencer(config=config, model_source=args.weight_path, meta_data_path=args.meta_data)
inferencer = TorchInferencer(config=config, model_source=args.weight_path, meta_data_path=args.meta_data)

elif extension in (".onnx", ".bin", ".xml"):
module = import_module("anomalib.deploy.inferencers.openvino")
OpenVINOInferencer = getattr(module, "OpenVINOInferencer") # pylint: disable=invalid-name
inference = OpenVINOInferencer(config=config, path=args.weight_path, meta_data_path=args.meta_data)
inferencer = OpenVINOInferencer(config=config, path=args.weight_path, meta_data_path=args.meta_data)

else:
raise ValueError(
f"Model extension is not supported. Torch Inferencer exptects a .ckpt file,"
f"OpenVINO Inferencer expects either .onnx, .bin or .xml file. Got {extension}"
)
if args.image_path.is_dir():
# Write the output to save_path in the same structure as the input directory.
for image in args.image_path.glob("**/*"):
if image.is_file() and image.suffix in (".jpg", ".png", ".jpeg"):
# Here save_path is assumed to be a directory. Image subdirectories are appended to the save_path.
save_path = Path(args.save_path / image.relative_to(args.image_path).parent) if args.save_path else None
infer(image, inferencer, save_path)
elif args.image_path.suffix in (".jpg", ".png", ".jpeg"):
infer(args.image_path, inferencer, args.save_path)
else:
raise ValueError(
f"Image extension is not supported. Supported extensions are .jpg, .png, .jpeg."
f" Got {args.image_path.suffix}"
)


def infer(image_path: Path, inferencer: Inferencer, save_path: Optional[Path] = None) -> None:
"""Perform inference on a single image.
Args:
image_path (Path): Path to image/directory containing images.
inferencer (Inferencer): Inferencer to use.
save_path (Path, optional): Path to save the output image. If this is None, the output is visualized.
"""
# Perform inference for the given image or image path. if image
# path is provided, `predict` method will read the image from
# file for convenience. We set the superimpose flag to True
# to overlay the predicted anomaly map on top of the input image.
output = inference.predict(image=args.image_path, superimpose=True)
output = inferencer.predict(image=image_path, superimpose=True)

# Incase both anomaly map and scores are returned add scores to the image.
if isinstance(output, tuple):
Expand All @@ -108,11 +134,17 @@ def infer() -> None:
# Show or save the output image, depending on what's provided as
# the command line argument.
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
if args.save_path is None:
if save_path is None:
cv2.imshow("Anomaly Map", output)
cv2.waitKey(0) # wait for any key press
else:
cv2.imwrite(filename=str(args.save_path), img=output)
# Create directory for parents if it doesn't exist.
save_path.parent.mkdir(parents=True, exist_ok=True)
if save_path.suffix == "": # This is a directory
save_path.mkdir(exist_ok=True) # Create current directory
save_path = save_path / image_path.name
cv2.imwrite(filename=str(save_path), img=output)


if __name__ == "__main__":
infer()
stream()

0 comments on commit 2405892

Please sign in to comment.