Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Gradio demo for yolov5 (command line and colab notebook) #10572

Closed
wants to merge 9 commits into from
43 changes: 40 additions & 3 deletions tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "f9f016ad-3dcf-4bd2-e1c3-d5b79efc6f32"
"outputId": "f0aa921f-9303-46bb-c590-212b518edbe9"
},
"source": [
"!git clone https://github.com/ultralytics/yolov5 # clone\n",
Expand All @@ -418,14 +418,14 @@
"output_type": "stream",
"name": "stderr",
"text": [
"YOLOv5 πŸš€ v7.0-1-gb32f67f Python-3.7.15 torch-1.12.1+cu113 CUDA:0 (Tesla T4, 15110MiB)\n"
"YOLOv5 πŸš€ 2022-12-23 Python-3.8.16 torch-1.13.0+cu116 CUDA:0 (Tesla T4, 15110MiB)\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Setup complete βœ… (2 CPUs, 12.7 GB RAM, 22.6/78.2 GB disk)\n"
"Setup complete βœ… (2 CPUs, 12.7 GB RAM, 23.1/78.2 GB disk)\n"
]
}
]
Expand Down Expand Up @@ -971,6 +971,43 @@
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Integrated gradio interface."
],
"metadata": {
"id": "FUxwtzRH2a5Y"
}
},
{
"cell_type": "code",
"source": [
"%pip install -q gradio\n",
"import gradio as gr\n",
"\n",
"def predict(inp, conf, iou, agnostic_nms):\n",
" model.conf = conf\n",
" model.iou = iou\n",
" model.agnostic = agnostic_nms\n",
" res = model([inp[..., ::-1]], size=imgsz).render()[0][..., ::-1]\n",
" return res\n",
"\n",
"imgsz = [640, 640]\n",
"model = torch.hub.load('ultralytics/yolov5', 'yolov5s')\n",
"# model = torch.hub.load('./', 'custom', source='local', path='yolov5s.pt')\n",
"demo = gr.Interface(fn=predict,\n",
" inputs=[gr.Image(), gr.Slider(0, 1, 0.25), gr.Slider(0, 1, 0.45), gr.Checkbox()],\n",
" outputs=\"image\",\n",
" examples=[['data/images/bus.jpg'], ['data/images/zidane.jpg']])\n",
"demo.launch()"
],
"metadata": {
"id": "YVJRQLMQuxxm"
},
"execution_count": null,
"outputs": []
}
]
}
69 changes: 69 additions & 0 deletions utils/gradio/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
Run Gradio demo.
Add --host 0.0.0.0 to share with other machines by IP address.

Usage - local model:
python utils/gradio/demo.py --path MODEL_FILE

Usage - github repo:
python utils/gradio/demo.py --model yolov5s
"""

import argparse
import os
import sys
from pathlib import Path

import torch

# Resolve the repository root so the script works regardless of the caller's CWD.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[2]  # YOLOv5 root directory (utils/gradio/demo.py -> two levels up)
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH so `utils.*` imports below resolve
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative (shorter paths in printed args)

from utils.dataloaders import IMG_FORMATS
from utils.general import check_requirements, print_args


def predict(inp, conf, iou, agnostic_nms):
    """Run YOLOv5 inference on one image and return the annotated result.

    Args:
        inp: input image as a numpy array (presumably RGB, as delivered by gr.Image — TODO confirm).
        conf: confidence threshold in [0, 1].
        iou: NMS IoU threshold in [0, 1].
        agnostic_nms: whether to apply class-agnostic NMS.

    Returns:
        The rendered detection image, channel order matching the input.
    """
    # Push the UI-selected thresholds onto the module-level hub model.
    for attr, value in (('conf', conf), ('iou', iou), ('agnostic', agnostic_nms)):
        setattr(model, attr, value)
    # Channel order is reversed going in and coming back out (RGB <-> BGR flip).
    flipped = inp[..., ::-1]
    rendered = model([flipped], size=opt.imgsz).render()[0]
    return rendered[..., ::-1]


if __name__ == '__main__':
    # ---- CLI ----------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
    parser.add_argument('--share', action='store_true', help='share yolov5 demo with public link')
    parser.add_argument('--host', type=str, default='localhost', help='server ip/name (0.0.0.0 for network request)')
    parser.add_argument('--port', type=int, default=7860, help='server port')
    parser.add_argument('--example_dir', type=str, default=ROOT / 'data/images', help='example image dir')
    parser.add_argument('--model', type=str, default='custom', help='model name used by github source')
    parser.add_argument('--path', type=str, default=ROOT / 'yolov5l.pt', help='local model path')
    opt = parser.parse_args()
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand a single size to [s, s]
    print_args(vars(opt))

    # ---- Dependencies -------------------------------------------------------
    check_requirements(exclude=('tensorboard', 'thop'))
    check_requirements('gradio')
    import gradio as gr  # deferred so the check above can install it first

    # ---- Example gallery ----------------------------------------------------
    # One [path] row per image file in --example_dir (non-image files skipped).
    examples = [[f] for f in Path(opt.example_dir).glob('*') if f.suffix.lower()[1:] in IMG_FORMATS]

    # ---- Model --------------------------------------------------------------
    # '--model custom' means load a local checkpoint from --path; any other
    # name is fetched from the ultralytics/yolov5 GitHub hub repo.
    source = 'local' if opt.model == 'custom' else 'github'
    kwargs = {'path': opt.path} if source == 'local' else {}
    repo = ROOT if source == 'local' else 'ultralytics/yolov5'
    model = torch.hub.load(repo, opt.model, source=source, **kwargs)

    # ---- UI -----------------------------------------------------------------
    demo = gr.Interface(fn=predict,
                        inputs=[gr.Image(), gr.Slider(0, 1, 0.25),
                                gr.Slider(0, 1, 0.45),
                                gr.Checkbox()],
                        outputs="image",
                        examples=examples)
    demo.launch(share=opt.share, server_name=opt.host, server_port=opt.port)