diff --git a/tools/inference.py b/tools/inference.py index c19531c123..1c42cec3c7 100644 --- a/tools/inference.py +++ b/tools/inference.py @@ -21,6 +21,7 @@ from argparse import ArgumentParser, Namespace from importlib import import_module from pathlib import Path +from typing import Optional import cv2 import numpy as np @@ -65,9 +66,11 @@ def add_label(prediction: np.ndarray, scores: float, font: int = cv2.FONT_HERSHE return prediction -def infer() -> None: - """Perform inference on an input image.""" +def stream() -> None: + """Stream predictions. + Show/save the output if path is to an image. If the path is a directory, go over each image in the directory. + """ # Get the command line arguments, and config from the config.yaml file. # This config file is also used for training and contains all the relevant # information regarding the data, model, train and inference details. @@ -77,28 +80,51 @@ def infer() -> None: # Get the inferencer. We use .ckpt extension for Torch models and (onnx, bin) # for the openvino models. extension = args.weight_path.suffix - inference: Inferencer + inferencer: Inferencer if extension in (".ckpt"): module = import_module("anomalib.deploy.inferencers.torch") TorchInferencer = getattr(module, "TorchInferencer") # pylint: disable=invalid-name - inference = TorchInferencer(config=config, model_source=args.weight_path, meta_data_path=args.meta_data) + inferencer = TorchInferencer(config=config, model_source=args.weight_path, meta_data_path=args.meta_data) elif extension in (".onnx", ".bin", ".xml"): module = import_module("anomalib.deploy.inferencers.openvino") OpenVINOInferencer = getattr(module, "OpenVINOInferencer") # pylint: disable=invalid-name - inference = OpenVINOInferencer(config=config, path=args.weight_path, meta_data_path=args.meta_data) + inferencer = OpenVINOInferencer(config=config, path=args.weight_path, meta_data_path=args.meta_data) else: raise ValueError( f"Model extension is not supported. Torch Inferencer exptects a .ckpt file," f"OpenVINO Inferencer expects either .onnx, .bin or .xml file. Got {extension}" ) + if args.image_path.is_dir(): + # Write the output to save_path in the same structure as the input directory. + for image in args.image_path.glob("**/*"): + if image.is_file() and image.suffix in (".jpg", ".png", ".jpeg"): + # Here save_path is assumed to be a directory. Image subdirectories are appended to the save_path. + save_path = Path(args.save_path / image.relative_to(args.image_path).parent) if args.save_path else None + infer(image, inferencer, save_path) + elif args.image_path.suffix in (".jpg", ".png", ".jpeg"): + infer(args.image_path, inferencer, args.save_path) + else: + raise ValueError( + f"Image extension is not supported. Supported extensions are .jpg, .png, .jpeg." + f" Got {args.image_path.suffix}" + ) + +def infer(image_path: Path, inferencer: Inferencer, save_path: Optional[Path] = None) -> None: + """Perform inference on a single image. + + Args: + image_path (Path): Path to image/directory containing images. + inferencer (Inferencer): Inferencer to use. + save_path (Path, optional): Path to save the output image. If this is None, the output is visualized. + """ # Perform inference for the given image or image path. if image # path is provided, `predict` method will read the image from # file for convenience. We set the superimpose flag to True # to overlay the predicted anomaly map on top of the input image. - output = inference.predict(image=args.image_path, superimpose=True) + output = inferencer.predict(image=image_path, superimpose=True) # Incase both anomaly map and scores are returned add scores to the image. if isinstance(output, tuple): @@ -108,11 +134,17 @@ def infer() -> None: # Show or save the output image, depending on what's provided as # the command line argument. output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) - if args.save_path is None: + if save_path is None: cv2.imshow("Anomaly Map", output) + cv2.waitKey(0) # wait for any key press else: - cv2.imwrite(filename=str(args.save_path), img=output) + # Create directory for parents if it doesn't exist. + save_path.parent.mkdir(parents=True, exist_ok=True) + if save_path.suffix == "": # This is a directory + save_path.mkdir(exist_ok=True) # Create current directory + save_path = save_path / image_path.name + cv2.imwrite(filename=str(save_path), img=output) if __name__ == "__main__": - infer() + stream()