🐞 Fix inferencer in Gradio (#332)
* Fix inferencer in gradio

* Address PR comments

* Address PR comments
ashwinvaidya17 committed May 25, 2022
1 parent c7d5232 commit b044e63
Showing 1 changed file with 32 additions and 16 deletions.
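For context on the fix itself: on the old side of the diff below, get_inferencer's parameter is named meta_data_path, but both inferencer constructors were called with meta_data_path=meta_data, a name never defined in that scope, so building either inferencer raised NameError at runtime. A minimal, self-contained sketch of that failure mode (a simplified stand-in, not anomalib code):

def get_inferencer_old(weight_path, meta_data_path=None):
    # Bug reproduced from the old diff side: the body references the
    # undefined name meta_data instead of the parameter meta_data_path.
    return {"weights": weight_path, "meta_data": meta_data}

try:
    get_inferencer_old("model.ckpt")
except NameError as err:
    print(err)  # name 'meta_data' is not defined

The commit fixes that reference and, for consistency, renames the CLI flags (--config to --config_path, --meta_data to --meta_data_path).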
tools/inference_gradio.py (48 changes: 32 additions & 16 deletions)
@@ -6,13 +6,12 @@
 from argparse import ArgumentParser, Namespace
 from importlib import import_module
 from pathlib import Path
-from typing import Tuple, Union
+from typing import Optional, Tuple
 
 import gradio as gr
 import gradio.inputs
 import gradio.outputs
 import numpy as np
-from omegaconf import DictConfig, ListConfig
 from skimage.segmentation import mark_boundaries
 
 from anomalib.config import get_configurable_parameters
@@ -46,13 +45,21 @@ def infer(
 def get_args() -> Namespace:
     """Get command line arguments.
+    Example:
+        >>> python tools/inference_gradio.py \
+                --config_path ./anomalib/models/padim/config.yaml \
+                --weight_path ./results/padim/mvtec/bottle/weights/model.ckpt
     Returns:
         Namespace: List of arguments.
     """
     parser = ArgumentParser()
-    parser.add_argument("--config", type=Path, required=True, help="Path to a model config file")
+    parser.add_argument("--config_path", type=Path, required=True, help="Path to a model config file")
     parser.add_argument("--weight_path", type=Path, required=True, help="Path to a model weights")
-    parser.add_argument("--meta_data", type=Path, required=False, help="Path to JSON file containing the metadata.")
+    parser.add_argument(
+        "--meta_data_path", type=Path, required=False, help="Path to JSON file containing the metadata."
+    )
 
     parser.add_argument(
         "--threshold",
@@ -69,26 +76,35 @@ def get_args() -> Namespace:
     return args
 
 
-def get_inferencer(config_path: Path, weight_path: Path, meta_data_path: Path) -> Inferencer:
-    """Parse args and open inferencer."""
-    config = get_configurable_parameters(config_path)
+def get_inferencer(config_path: Path, weight_path: Path, meta_data_path: Optional[Path] = None) -> Inferencer:
+    """Parse args and open inferencer.
+    Args:
+        config_path (Path): Path to model configuration file or the name of the model.
+        weight_path (Path): Path to model weights.
+        meta_data_path (Optional[Path], optional): Metadata is required for OpenVINO models. Defaults to None.
+    Raises:
+        ValueError: If unsupported model weight is passed.
+    Returns:
+        Inferencer: Torch or OpenVINO inferencer.
+    """
+    config = get_configurable_parameters(config_path=config_path)
 
     # Get the inferencer. We use .ckpt extension for Torch models and (onnx, bin)
     # for the openvino models.
     extension = weight_path.suffix
     inferencer: Inferencer
     if extension in (".ckpt"):
         module = import_module("anomalib.deploy.inferencers.torch")
-        TorchInferencer = getattr(module, "TorchInferencer")  # pylint: disable=invalid-name
-        inferencer = TorchInferencer(
-            config=config, model_source=weight_path, meta_data_path=meta_data
-        )
+        TorchInferencer = getattr(module, "TorchInferencer")
+        inferencer = TorchInferencer(config=config, model_source=weight_path, meta_data_path=meta_data_path)
 
     elif extension in (".onnx", ".bin", ".xml"):
         module = import_module("anomalib.deploy.inferencers.openvino")
-        OpenVINOInferencer = getattr(module, "OpenVINOInferencer")  # pylint: disable=invalid-name
-        inferencer = OpenVINOInferencer(
-            config=config, path=weight_path, meta_data_path=meta_data
-        )
+        OpenVINOInferencer = getattr(module, "OpenVINOInferencer")
+        inferencer = OpenVINOInferencer(config=config, path=weight_path, meta_data_path=meta_data_path)
 
     else:
         raise ValueError(
@@ -102,7 +118,7 @@ def get_inferencer(config_path: Path, weight_path: Path, meta_data_path: Path) -> Inferencer:
 if __name__ == "__main__":
     session_args = get_args()
 
-    gradio_inferencer = get_inferencer(session_args.config, session_args.weight_path, session_args.meta_data)
+    gradio_inferencer = get_inferencer(session_args.config_path, session_args.weight_path, session_args.meta_data_path)
 
     interface = gr.Interface(
         fn=lambda image, threshold: infer(image, gradio_inferencer, threshold),
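For reference, the backend choice above hinges entirely on weight_path.suffix. A standalone sketch of the same dispatch pattern (plain Python, no anomalib imports; pick_backend is a hypothetical helper name). One caveat the sketch sidesteps: the original's extension in (".ckpt") is a substring test against a plain string rather than tuple membership, since (".ckpt") is not a tuple; it behaves as intended for real suffixes, but == is the unambiguous spelling:

from pathlib import Path

def pick_backend(weight_path: Path) -> str:
    # .ckpt selects the Torch inferencer; ONNX/OpenVINO IR files select OpenVINO.
    extension = weight_path.suffix
    if extension == ".ckpt":
        return "torch"
    if extension in (".onnx", ".bin", ".xml"):
        return "openvino"
    raise ValueError(f"Model extension is not supported: {extension}")

print(pick_backend(Path("results/padim/mvtec/bottle/weights/model.ckpt")))  # torch
print(pick_backend(Path("model.xml")))  # openvino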

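The truncated tail of the diff keeps the existing gr.Interface wiring, in which a lambda closes over the preloaded inferencer so the Gradio callback keeps a plain (image, threshold) signature and the model is loaded once rather than per request. A hedged, self-contained sketch of that pattern (infer_stub and the stub model are illustrative; string input/output shortcuts are used to stay Gradio-version-agnostic):

import gradio as gr

def infer_stub(image, model, threshold):
    # Stand-in for anomalib's infer(): just echo what the callback received.
    return f"model={model}, threshold={threshold}"

model = "preloaded inferencer"  # placeholder for get_inferencer()'s return value

interface = gr.Interface(
    fn=lambda image, threshold: infer_stub(image, model, threshold),
    inputs=["image", "number"],
    outputs="text",
)
# interface.launch()  # uncomment to serve the app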