Skip to content

Commit

Permalink
[Classify]: Allow inference on dirs and videos (#9003)
Browse files Browse the repository at this point in the history
* allow image dirs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update predict.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update dataloaders.py

* Update predict.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update predict.py

* Update predict.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
3 people committed Aug 17, 2022
1 parent e83b422 commit 64e0757
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 43 deletions.
64 changes: 33 additions & 31 deletions classify/predict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Run classification inference on images
Run classification inference on file/dir/URL/glob
Usage:
$ python classify/predict.py --weights yolov5s-cls.pt --source data/images/bus.jpg
Expand All @@ -11,7 +11,6 @@
import sys
from pathlib import Path

import cv2
import torch.nn.functional as F

FILE = Path(__file__).resolve()
Expand All @@ -20,73 +19,76 @@
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative

from classify.train import imshow_cls
from models.common import DetectMultiBackend
from utils.augmentations import classify_transforms
from utils.general import LOGGER, check_requirements, colorstr, increment_path, print_args
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages
from utils.general import LOGGER, check_file, check_requirements, colorstr, increment_path, print_args
from utils.torch_utils import select_device, smart_inference_mode, time_sync


# Classification inference entry point: loads a model through DetectMultiBackend, walks a
# LoadImages dataset built from `source`, logs the top-5 class probabilities for every item,
# logs averaged timings and returns the softmax probabilities `p` of the last item processed.
# NOTE(review): this listing is a rendered diff without +/- markers, so removed and added
# lines appear side by side (two `source=` defaults, `file = ...` next to `source = ...`, the
# single-image "# Image" section ahead of the dataset loop) - confirm against the committed
# file which lines are live before relying on any of them.
@smart_inference_mode()
def run(
weights=ROOT / 'yolov5s-cls.pt', # model.pt path(s)
source=ROOT / 'data/images/bus.jpg', # file/dir/URL/glob, 0 for webcam
source=ROOT / 'data/images', # file/dir/URL/glob
imgsz=224, # inference size
device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
half=False, # use FP16 half-precision inference
dnn=False, # use OpenCV DNN for ONNX inference
show=True,  # when truthy, imshow_cls is called after the speed log (see the end of the function)
project=ROOT / 'runs/predict-cls', # save to project/name
name='exp', # save to project/name
exist_ok=False, # existing project/name ok, do not increment
):
file = str(source)
source = str(source)
# A source counts as a file when its suffix (without the dot) is listed in IMG_FORMATS + VID_FORMATS.
# NOTE(review): the suffix test is case-sensitive while the URL-scheme test lower-cases the
# string; presumably the format lists are lower-case - confirm.
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
# Only a URL that also looks like an image/video file is handed to check_file.
if is_url and is_file:
source = check_file(source) # download

# dt accumulates seconds for [pre-process, inference, post-process], matching the speed log below.
# NOTE(review): `seen` starts at 1 and is also incremented once per dataset item in the loop,
# so the per-image averages divide by (items + 1) - confirm whether 0 was intended.
seen, dt = 1, [0.0, 0.0, 0.0]
device = select_device(device)

# Directories
save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
save_dir.mkdir(parents=True, exist_ok=True) # make dir

# Transforms
transforms = classify_transforms(imgsz)

# Load model
model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half)
model.warmup(imgsz=(1, 3, imgsz, imgsz)) # warmup

# Image
t1 = time_sync()
im = cv2.cvtColor(cv2.imread(file), cv2.COLOR_BGR2RGB)
im = transforms(im).unsqueeze(0).to(device)
im = im.half() if model.fp16 else im.float()
t2 = time_sync()
dt[0] += t2 - t1

# Inference
results = model(im)
t3 = time_sync()
dt[1] += t3 - t2

p = F.softmax(results, dim=1) # probabilities
i = p.argsort(1, descending=True)[:, :5].squeeze() # top 5 indices
dt[2] += time_sync() - t3
LOGGER.info(f"image 1/1 {file}: {imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i.tolist())}")
# The dataset applies classify_transforms itself (passed as `transforms=`), so the loop
# receives an already-transformed `im` plus the log prefix `s` for each item.
dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz))
for path, im, im0s, vid_cap, s in dataset:
# Image
t1 = time_sync()
im = im.unsqueeze(0).to(device)  # add a leading batch dimension of 1
im = im.half() if model.fp16 else im.float()
t2 = time_sync()
dt[0] += t2 - t1

# Inference
results = model(im)
t3 = time_sync()
dt[1] += t3 - t2

# Post-process
p = F.softmax(results, dim=1) # probabilities
# assumes results is (1, num_classes) so squeeze() leaves a flat list of class indices - TODO confirm
i = p.argsort(1, descending=True)[:, :5].squeeze().tolist() # top 5 indices
dt[2] += time_sync() - t3
# if save:
# imshow_cls(im, f=save_dir / Path(path).name, verbose=True)
seen += 1
LOGGER.info(f"{s}{imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i)}")

# Print results
t = tuple(x / seen * 1E3 for x in dt) # speeds per image
shape = (1, 3, imgsz, imgsz)
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
# NOTE(review): `file` is only assigned by the `file = str(source)` line near the top; if that
# line is a removed diff line this reference is stale - confirm.
if show:
imshow_cls(im, f=save_dir / Path(file).name, verbose=True)
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
return p


def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-cls.pt', help='model path(s)')
parser.add_argument('--source', type=str, default=ROOT / 'data/images/bus.jpg', help='file')
parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob')
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=224, help='train, val image size (pixels)')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
Expand Down
25 changes: 13 additions & 12 deletions utils/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def __iter__(self):

class LoadImages:
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
def __init__(self, path, img_size=640, stride=32, auto=True):
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None):
files = []
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
p = str(Path(p).resolve())
Expand All @@ -210,6 +210,7 @@ def __init__(self, path, img_size=640, stride=32, auto=True):
self.video_flag = [False] * ni + [True] * nv
self.mode = 'image'
self.auto = auto
self.transforms = transforms # optional
if any(videos):
self.new_video(videos[0]) # new video
else:
Expand All @@ -229,34 +230,34 @@ def __next__(self):
if self.video_flag[self.count]:
# Read video
self.mode = 'video'
ret_val, img0 = self.cap.read()
ret_val, im0 = self.cap.read()
while not ret_val:
self.count += 1
self.cap.release()
if self.count == self.nf: # last video
raise StopIteration
path = self.files[self.count]
self.new_video(path)
ret_val, img0 = self.cap.read()
ret_val, im0 = self.cap.read()

self.frame += 1
s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '

else:
# Read image
self.count += 1
img0 = cv2.imread(path) # BGR
assert img0 is not None, f'Image Not Found {path}'
im0 = cv2.imread(path) # BGR
assert im0 is not None, f'Image Not Found {path}'
s = f'image {self.count}/{self.nf} {path}: '

# Padded resize
img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]

# Convert
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
img = np.ascontiguousarray(img)
if self.transforms:
im = self.transforms(cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)) # classify transforms
else:
im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im) # contiguous

return path, img, img0, self.cap, s
return path, im, im0, self.cap, s

def new_video(self, path):
self.frame = 0
Expand Down

0 comments on commit 64e0757

Please sign in to comment.