Skip to content

Commit

Permalink
Add Gradio (#283)
Browse files Browse the repository at this point in the history
* add gradio

* linting fix

* fix threshold

* fix share

* fix docstring

* fix typo

Co-authored-by: julien.blanchon <julien.blanchon@thalesaleniaspace.com>
  • Loading branch information
julien-blanchon and julien.blanchon committed Apr 29, 2022
1 parent 9076a95 commit edf9393
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ torchvision>=0.9.1
torchtext>=0.9.1
wandb==0.12.9
matplotlib>=3.4.3
gradio>=2.9.4
126 changes: 126 additions & 0 deletions tools/inference_gradio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Anomalib Gradio Script.
This script provide a gradio web interface
"""

from argparse import ArgumentParser, Namespace
from importlib import import_module
from pathlib import Path
from typing import Tuple, Union

import gradio as gr
import gradio.inputs
import gradio.outputs
import numpy as np
from omegaconf import DictConfig, ListConfig
from skimage.segmentation import mark_boundaries

from anomalib.config import get_configurable_parameters
from anomalib.deploy.inferencers.base import Inferencer
from anomalib.post_processing import compute_mask, superimpose_anomaly_map


def infer(
image: np.ndarray, inferencer: Inferencer, threshold: float = 50.0
) -> Tuple[np.ndarray, float, np.ndarray, np.ndarray, np.ndarray]:
"""Inference function, return anomaly map, score, heat map, prediction mask ans visualisation.
Args:
image (np.ndarray): image to compute
inferencer (Inferencer): model inferencer
threshold (float, optional): threshold between 0 and 100. Defaults to 50.0.
Returns:
Tuple[np.ndarray, float, np.ndarray, np.ndarray, np.ndarray]:
anomaly_map, anomaly_score, heat_map, pred_mask, vis_img
"""
# Perform inference for the given image.
threshold = threshold / 100
anomaly_map, anomaly_score = inferencer.predict(image=image, superimpose=False)
heat_map = superimpose_anomaly_map(anomaly_map, image)
pred_mask = compute_mask(anomaly_map, threshold)
vis_img = mark_boundaries(image, pred_mask, color=(1, 0, 0), mode="thick")
return anomaly_map, anomaly_score, heat_map, pred_mask, vis_img


def get_args() -> Namespace:
"""Get command line arguments.
Returns:
Namespace: List of arguments.
"""
parser = ArgumentParser()
parser.add_argument("--config", type=Path, required=True, help="Path to a model config file")
parser.add_argument("--weight_path", type=Path, required=True, help="Path to a model weights")
parser.add_argument("--meta_data", type=Path, required=False, help="Path to JSON file containing the metadata.")

parser.add_argument(
"--threshold",
type=float,
required=False,
default=75.0,
help="Value to threshold anomaly scores into 0-100 range",
)

parser.add_argument("--share", type=bool, required=False, default=False, help="Share Gradio `share_url`")

args = parser.parse_args()

return args


def get_inferencer(config_path: Path, weight_path: Path, meta_data_path: Path) -> Inferencer:
"""Parse args and open inferencer."""
config = get_configurable_parameters(config_path)
# Get the inferencer. We use .ckpt extension for Torch models and (onnx, bin)
# for the openvino models.
extension = weight_path.suffix
inferencer: Inferencer
if extension in (".ckpt"):
module = import_module("anomalib.deploy.inferencers.torch")
TorchInferencer = getattr(module, "TorchInferencer") # pylint: disable=invalid-name
inferencer = TorchInferencer(
config=config, model_source=weight_path, meta_data_path=meta_data
)

elif extension in (".onnx", ".bin", ".xml"):
module = import_module("anomalib.deploy.inferencers.openvino")
OpenVINOInferencer = getattr(module, "OpenVINOInferencer") # pylint: disable=invalid-name
inferencer = OpenVINOInferencer(
config=config, path=weight_path, meta_data_path=meta_data
)

else:
raise ValueError(
f"Model extension is not supported. Torch Inferencer exptects a .ckpt file,"
f"OpenVINO Inferencer expects either .onnx, .bin or .xml file. Got {extension}"
)

return inferencer


if __name__ == "__main__":
session_args = get_args()

gradio_inferencer = get_inferencer(session_args.config, session_args.weight_path, session_args.meta_data)

interface = gr.Interface(
fn=lambda image, threshold: infer(image, gradio_inferencer, threshold),
inputs=[
gradio.inputs.Image(
shape=None, image_mode="RGB", source="upload", tool="editor", type="numpy", label="Image"
),
gradio.inputs.Slider(default=session_args.threshold, label="threshold", optional=False),
],
outputs=[
gradio.outputs.Image(type="numpy", label="Anomaly Map"),
gradio.outputs.Textbox(type="number", label="Anomaly Score"),
gradio.outputs.Image(type="numpy", label="Predicted Heat Map"),
gradio.outputs.Image(type="numpy", label="Predicted Mask"),
gradio.outputs.Image(type="numpy", label="Segmentation Result"),
],
title="Anomalib",
description="Anomalib Gradio",
)

interface.launch(share=session_args.share)

0 comments on commit edf9393

Please sign in to comment.