Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Classify]: Allow inference on dirs and videos #9003

Merged
9 commits merged on Aug 17, 2022
64 changes: 33 additions & 31 deletions classify/predict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Run classification inference on images
Run classification inference on file/dir/URL/glob

Usage:
$ python classify/predict.py --weights yolov5s-cls.pt --source data/images/bus.jpg
Expand All @@ -11,7 +11,6 @@
import sys
from pathlib import Path

import cv2
import torch.nn.functional as F

FILE = Path(__file__).resolve()
Expand All @@ -20,73 +19,76 @@
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative

from classify.train import imshow_cls
from models.common import DetectMultiBackend
from utils.augmentations import classify_transforms
from utils.general import LOGGER, check_requirements, colorstr, increment_path, print_args
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages
from utils.general import LOGGER, check_file, check_requirements, colorstr, increment_path, print_args
from utils.torch_utils import select_device, smart_inference_mode, time_sync


@smart_inference_mode()
def run(
weights=ROOT / 'yolov5s-cls.pt', # model.pt path(s)
source=ROOT / 'data/images/bus.jpg', # file/dir/URL/glob, 0 for webcam
source=ROOT / 'data/images', # file/dir/URL/glob
imgsz=224, # inference size
device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
half=False, # use FP16 half-precision inference
dnn=False, # use OpenCV DNN for ONNX inference
show=True,
project=ROOT / 'runs/predict-cls', # save to project/name
name='exp', # save to project/name
exist_ok=False, # existing project/name ok, do not increment
):
file = str(source)
source = str(source)
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
if is_url and is_file:
source = check_file(source) # download

seen, dt = 1, [0.0, 0.0, 0.0]
device = select_device(device)

# Directories
save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
save_dir.mkdir(parents=True, exist_ok=True) # make dir

# Transforms
transforms = classify_transforms(imgsz)

# Load model
model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half)
model.warmup(imgsz=(1, 3, imgsz, imgsz)) # warmup

# Image
t1 = time_sync()
im = cv2.cvtColor(cv2.imread(file), cv2.COLOR_BGR2RGB)
im = transforms(im).unsqueeze(0).to(device)
im = im.half() if model.fp16 else im.float()
t2 = time_sync()
dt[0] += t2 - t1

# Inference
results = model(im)
t3 = time_sync()
dt[1] += t3 - t2

p = F.softmax(results, dim=1) # probabilities
i = p.argsort(1, descending=True)[:, :5].squeeze() # top 5 indices
dt[2] += time_sync() - t3
LOGGER.info(f"image 1/1 {file}: {imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i.tolist())}")
dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz))
for path, im, im0s, vid_cap, s in dataset:
# Image
t1 = time_sync()
im = im.unsqueeze(0).to(device)
im = im.half() if model.fp16 else im.float()
t2 = time_sync()
dt[0] += t2 - t1

# Inference
results = model(im)
t3 = time_sync()
dt[1] += t3 - t2

# Post-process
p = F.softmax(results, dim=1) # probabilities
i = p.argsort(1, descending=True)[:, :5].squeeze().tolist() # top 5 indices
dt[2] += time_sync() - t3
# if save:
# imshow_cls(im, f=save_dir / Path(path).name, verbose=True)
seen += 1
LOGGER.info(f"{s}{imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i)}")

# Print results
t = tuple(x / seen * 1E3 for x in dt) # speeds per image
shape = (1, 3, imgsz, imgsz)
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
if show:
imshow_cls(im, f=save_dir / Path(file).name, verbose=True)
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
return p


def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-cls.pt', help='model path(s)')
parser.add_argument('--source', type=str, default=ROOT / 'data/images/bus.jpg', help='file')
parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob')
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=224, help='train, val image size (pixels)')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
Expand Down
25 changes: 13 additions & 12 deletions utils/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def __iter__(self):

class LoadImages:
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
def __init__(self, path, img_size=640, stride=32, auto=True):
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None):
files = []
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
p = str(Path(p).resolve())
Expand All @@ -210,6 +210,7 @@ def __init__(self, path, img_size=640, stride=32, auto=True):
self.video_flag = [False] * ni + [True] * nv
self.mode = 'image'
self.auto = auto
self.transforms = transforms # optional
if any(videos):
self.new_video(videos[0]) # new video
else:
Expand All @@ -229,34 +230,34 @@ def __next__(self):
if self.video_flag[self.count]:
# Read video
self.mode = 'video'
ret_val, img0 = self.cap.read()
ret_val, im0 = self.cap.read()
while not ret_val:
self.count += 1
self.cap.release()
if self.count == self.nf: # last video
raise StopIteration
path = self.files[self.count]
self.new_video(path)
ret_val, img0 = self.cap.read()
ret_val, im0 = self.cap.read()

self.frame += 1
s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '

else:
# Read image
self.count += 1
img0 = cv2.imread(path) # BGR
assert img0 is not None, f'Image Not Found {path}'
im0 = cv2.imread(path) # BGR
assert im0 is not None, f'Image Not Found {path}'
s = f'image {self.count}/{self.nf} {path}: '

# Padded resize
img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]

# Convert
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
img = np.ascontiguousarray(img)
if self.transforms:
im = self.transforms(cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)) # classify transforms
else:
im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im) # contiguous

return path, img, img0, self.cap, s
return path, im, im0, self.cap, s

def new_video(self, path):
self.frame = 0
Expand Down