From 84a9f1edae6539e1c8d685558fd644260db49588 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 17 Aug 2022 23:29:45 +0530 Subject: [PATCH 1/9] allow image dirs --- classify/predict.py | 43 +++++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/classify/predict.py b/classify/predict.py index 87379e42159b..9bba48c88397 100644 --- a/classify/predict.py +++ b/classify/predict.py @@ -25,7 +25,7 @@ from utils.augmentations import classify_transforms from utils.general import LOGGER, check_requirements, colorstr, increment_path, print_args from utils.torch_utils import select_device, smart_inference_mode, time_sync - +from utils.dataloaders import LoadImages @smart_inference_mode() def run( @@ -54,24 +54,27 @@ def run( # Load model model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half) model.warmup(imgsz=(1, 3, imgsz, imgsz)) # warmup - - # Image - t1 = time_sync() - im = cv2.cvtColor(cv2.imread(file), cv2.COLOR_BGR2RGB) - im = transforms(im).unsqueeze(0).to(device) - im = im.half() if model.fp16 else im.float() - t2 = time_sync() - dt[0] += t2 - t1 - - # Inference - results = model(im) - t3 = time_sync() - dt[1] += t3 - t2 - - p = F.softmax(results, dim=1) # probabilities - i = p.argsort(1, descending=True)[:, :5].squeeze() # top 5 indices - dt[2] += time_sync() - t3 - LOGGER.info(f"image 1/1 {file}: {imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i.tolist())}") + dataset = LoadImages(source, img_size=imgsz) + for data in dataset: + # Image + t1 = time_sync() + path = data[0] + im = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) + im = transforms(im).unsqueeze(0).to(device) + im = im.half() if model.fp16 else im.float() + t2 = time_sync() + dt[0] += t2 - t1 + + # Inference + results = model(im) + t3 = time_sync() + dt[1] += t3 - t2 + + p = F.softmax(results, dim=1) # probabilities + i = p.argsort(1, descending=True)[:, :5].squeeze() # top 5 indices + dt[2] += time_sync() - t3 + seen += 1 + LOGGER.info(f"image 1/1 {file}: {imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i.tolist())}") # Print results t = tuple(x / seen * 1E3 for x in dt) # speeds per image @@ -86,7 +89,7 @@ def run( def parse_opt(): parser = argparse.ArgumentParser() parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-cls.pt', help='model path(s)') - parser.add_argument('--source', type=str, default=ROOT / 'data/images/bus.jpg', help='file') + parser.add_argument('--source', type=str, default=ROOT / 'data/images/bus.jpg', help='Image file/ dir') # TODO: Video parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=224, help='train, val image size (pixels)') parser.add_argument('--device', default='', help='cuda device, i.e. 
0 or 0,1,2,3 or cpu') parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') From dbdbd3d17d67f4912b8a5f44e59219de5af06a08 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Aug 2022 18:04:38 +0000 Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- classify/predict.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/classify/predict.py b/classify/predict.py index 9bba48c88397..5544bbe6ecbf 100644 --- a/classify/predict.py +++ b/classify/predict.py @@ -23,9 +23,10 @@ from classify.train import imshow_cls from models.common import DetectMultiBackend from utils.augmentations import classify_transforms +from utils.dataloaders import LoadImages from utils.general import LOGGER, check_requirements, colorstr, increment_path, print_args from utils.torch_utils import select_device, smart_inference_mode, time_sync -from utils.dataloaders import LoadImages + @smart_inference_mode() def run( @@ -74,7 +75,8 @@ def run( i = p.argsort(1, descending=True)[:, :5].squeeze() # top 5 indices dt[2] += time_sync() - t3 seen += 1 - LOGGER.info(f"image 1/1 {file}: {imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i.tolist())}") + LOGGER.info( + f"image 1/1 {file}: {imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i.tolist())}") # Print results t = tuple(x / seen * 1E3 for x in dt) # speeds per image @@ -89,7 +91,8 @@ def run( def parse_opt(): parser = argparse.ArgumentParser() parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-cls.pt', help='model path(s)') - parser.add_argument('--source', type=str, default=ROOT / 'data/images/bus.jpg', help='Image file/ dir') # TODO: Video + parser.add_argument('--source', type=str, default=ROOT / 'data/images/bus.jpg', + help='Image file/ dir') # TODO: Video parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=224, help='train, val image size (pixels)') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') From 7480c2933d319246258638493ddd1bd013b2f070 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 17 Aug 2022 21:26:47 +0200 Subject: [PATCH 3/9] Update predict.py --- classify/predict.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/classify/predict.py b/classify/predict.py index 5544bbe6ecbf..66981083c3fe 100644 --- a/classify/predict.py +++ b/classify/predict.py @@ -20,7 +20,7 @@ sys.path.append(str(ROOT)) # add ROOT to PATH ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative -from classify.train import imshow_cls +from utils.plots import imshow_cls from models.common import DetectMultiBackend from utils.augmentations import classify_transforms from utils.dataloaders import LoadImages @@ -31,7 +31,7 @@ @smart_inference_mode() def run( weights=ROOT / 'yolov5s-cls.pt', # model.pt path(s) - source=ROOT / 'data/images/bus.jpg', # file/dir/URL/glob, 0 for webcam + source=ROOT / 'data/images', # file or dir imgsz=224, # inference size device='', # cuda device, i.e. 
0 or 0,1,2,3 or cpu half=False, # use FP16 half-precision inference @@ -56,10 +56,9 @@ def run( model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half) model.warmup(imgsz=(1, 3, imgsz, imgsz)) # warmup dataset = LoadImages(source, img_size=imgsz) - for data in dataset: + for path, im, im0s, vid_cap, s in dataset: # Image t1 = time_sync() - path = data[0] im = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) im = transforms(im).unsqueeze(0).to(device) im = im.half() if model.fp16 else im.float() @@ -72,11 +71,10 @@ def run( dt[1] += t3 - t2 p = F.softmax(results, dim=1) # probabilities - i = p.argsort(1, descending=True)[:, :5].squeeze() # top 5 indices + i = p.argsort(1, descending=True)[:, :5].squeeze().tolist() # top 5 indices dt[2] += time_sync() - t3 seen += 1 - LOGGER.info( - f"image 1/1 {file}: {imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i.tolist())}") + LOGGER.info(f"{s}{imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i)}") # Print results t = tuple(x / seen * 1E3 for x in dt) # speeds per image @@ -91,8 +89,7 @@ def run( def parse_opt(): parser = argparse.ArgumentParser() parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-cls.pt', help='model path(s)') - parser.add_argument('--source', type=str, default=ROOT / 'data/images/bus.jpg', - help='Image file/ dir') # TODO: Video + parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file or dir') # TODO: Video parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=224, help='train, val image size (pixels)') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') From d29e4d5c96183c673445db15f7bed5bb85dd937a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Aug 2022 19:27:34 +0000 Subject: [PATCH 4/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- classify/predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/classify/predict.py b/classify/predict.py index 66981083c3fe..c61b026fd7c5 100644 --- a/classify/predict.py +++ b/classify/predict.py @@ -20,11 +20,11 @@ sys.path.append(str(ROOT)) # add ROOT to PATH ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative -from utils.plots import imshow_cls from models.common import DetectMultiBackend from utils.augmentations import classify_transforms from utils.dataloaders import LoadImages from utils.general import LOGGER, check_requirements, colorstr, increment_path, print_args +from utils.plots import imshow_cls from utils.torch_utils import select_device, smart_inference_mode, time_sync From ece280a78642e871eca62baf3d55842280cd8155 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 17 Aug 2022 22:18:18 +0200 Subject: [PATCH 5/9] Update dataloaders.py --- utils/dataloaders.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/utils/dataloaders.py b/utils/dataloaders.py index 33e84ce4056e..3f26be2cd32d 100755 --- a/utils/dataloaders.py +++ b/utils/dataloaders.py @@ -186,7 +186,7 @@ def __iter__(self): class LoadImages: # YOLOv5 image/video dataloader, i.e. 
`python detect.py --source image.jpg/vid.mp4` - def __init__(self, path, img_size=640, stride=32, auto=True): + def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None): files = [] for p in sorted(path) if isinstance(path, (list, tuple)) else [path]: p = str(Path(p).resolve()) @@ -210,6 +210,7 @@ def __init__(self, path, img_size=640, stride=32, auto=True): self.video_flag = [False] * ni + [True] * nv self.mode = 'image' self.auto = auto + self.transforms = transforms # optional if any(videos): self.new_video(videos[0]) # new video else: @@ -229,7 +230,7 @@ def __next__(self): if self.video_flag[self.count]: # Read video self.mode = 'video' - ret_val, img0 = self.cap.read() + ret_val, im0 = self.cap.read() while not ret_val: self.count += 1 self.cap.release() @@ -237,7 +238,7 @@ def __next__(self): raise StopIteration path = self.files[self.count] self.new_video(path) - ret_val, img0 = self.cap.read() + ret_val, im0 = self.cap.read() self.frame += 1 s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ' @@ -245,18 +246,18 @@ def __next__(self): else: # Read image self.count += 1 - img0 = cv2.imread(path) # BGR - assert img0 is not None, f'Image Not Found {path}' + im0 = cv2.imread(path) # BGR + assert im0 is not None, f'Image Not Found {path}' s = f'image {self.count}/{self.nf} {path}: ' - # Padded resize - img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0] - - # Convert - img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB - img = np.ascontiguousarray(img) + if self.transforms: + im = self.transforms(cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)) # classify transforms + else: + im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize + im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB + im = np.ascontiguousarray(im) # contiguous - return path, img, img0, self.cap, s + return path, im, im0, self.cap, s def new_video(self, path): self.frame = 0 From aeff18202317645f3f3632c8ee097dad4b20db1c Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 17 Aug 2022 22:19:22 +0200 Subject: [PATCH 6/9] Update predict.py --- classify/predict.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/classify/predict.py b/classify/predict.py index c61b026fd7c5..689bf0838453 100644 --- a/classify/predict.py +++ b/classify/predict.py @@ -11,7 +11,6 @@ import sys from pathlib import Path -import cv2 import torch.nn.functional as F FILE = Path(__file__).resolve() @@ -20,11 +19,11 @@ sys.path.append(str(ROOT)) # add ROOT to PATH ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative +from utils.plots import imshow_cls from models.common import DetectMultiBackend from utils.augmentations import classify_transforms from utils.dataloaders import LoadImages from utils.general import LOGGER, check_requirements, colorstr, increment_path, print_args -from utils.plots import imshow_cls from utils.torch_utils import select_device, smart_inference_mode, time_sync @@ -55,12 +54,11 @@ def run( # Load model model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half) model.warmup(imgsz=(1, 3, imgsz, imgsz)) # warmup - dataset = LoadImages(source, img_size=imgsz) + dataset = LoadImages(source, img_size=imgsz, transforms=transforms) for path, im, im0s, vid_cap, s in dataset: # Image t1 = time_sync() - im = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) - im = transforms(im).unsqueeze(0).to(device) + im = im.unsqueeze(0).to(device) im = im.half() if model.fp16 else im.float() t2 = 
time_sync() dt[0] += t2 - t1 From ff012e12dd77093c53773c20b819619c26bae2ce Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Aug 2022 20:19:42 +0000 Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- classify/predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/classify/predict.py b/classify/predict.py index 689bf0838453..e3be32943c1a 100644 --- a/classify/predict.py +++ b/classify/predict.py @@ -19,11 +19,11 @@ sys.path.append(str(ROOT)) # add ROOT to PATH ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative -from utils.plots import imshow_cls from models.common import DetectMultiBackend from utils.augmentations import classify_transforms from utils.dataloaders import LoadImages from utils.general import LOGGER, check_requirements, colorstr, increment_path, print_args +from utils.plots import imshow_cls from utils.torch_utils import select_device, smart_inference_mode, time_sync From b68dcd5e3243ed5edae515a9e2548370001fa189 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 17 Aug 2022 22:30:04 +0200 Subject: [PATCH 8/9] Update predict.py --- classify/predict.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/classify/predict.py b/classify/predict.py index e3be32943c1a..281a3d37a985 100644 --- a/classify/predict.py +++ b/classify/predict.py @@ -1,6 +1,6 @@ # YOLOv5 🚀 by Ultralytics, GPL-3.0 license """ -Run classification inference on images +Run classification inference on file/dir/URL/glob Usage: $ python classify/predict.py --weights yolov5s-cls.pt --source data/images/bus.jpg @@ -21,26 +21,29 @@ from models.common import DetectMultiBackend from utils.augmentations import classify_transforms -from utils.dataloaders import LoadImages -from utils.general import LOGGER, check_requirements, colorstr, increment_path, print_args -from utils.plots import imshow_cls +from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages +from utils.general import LOGGER, check_file, check_requirements, colorstr, increment_path, print_args from utils.torch_utils import select_device, smart_inference_mode, time_sync @smart_inference_mode() def run( weights=ROOT / 'yolov5s-cls.pt', # model.pt path(s) - source=ROOT / 'data/images', # file or dir + source=ROOT / 'data/images', # file/dir/URL/glob imgsz=224, # inference size device='', # cuda device, i.e. 
0 or 0,1,2,3 or cpu half=False, # use FP16 half-precision inference dnn=False, # use OpenCV DNN for ONNX inference - show=True, project=ROOT / 'runs/predict-cls', # save to project/name name='exp', # save to project/name exist_ok=False, # existing project/name ok, do not increment ): - file = str(source) + source = str(source) + is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) + is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://')) + if is_url and is_file: + source = check_file(source) # download + seen, dt = 1, [0.0, 0.0, 0.0] device = select_device(device) @@ -68,9 +71,12 @@ def run( t3 = time_sync() dt[1] += t3 - t2 + # Post-process p = F.softmax(results, dim=1) # probabilities i = p.argsort(1, descending=True)[:, :5].squeeze().tolist() # top 5 indices dt[2] += time_sync() - t3 + # if save: + # imshow_cls(im, f=save_dir / Path(path).name, verbose=True) seen += 1 LOGGER.info(f"{s}{imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i)}") @@ -78,8 +84,6 @@ def run( t = tuple(x / seen * 1E3 for x in dt) # speeds per image shape = (1, 3, imgsz, imgsz) LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t) - if show: - imshow_cls(im, f=save_dir / Path(file).name, verbose=True) LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}") return p @@ -87,7 +91,7 @@ def run( def parse_opt(): parser = argparse.ArgumentParser() parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-cls.pt', help='model path(s)') - parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file or dir') # TODO: Video + parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob') parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=224, help='train, val image size (pixels)') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') From 5b82b318de08120808d4305d62f13751e2e64195 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 17 Aug 2022 22:33:31 +0200 Subject: [PATCH 9/9] Update predict.py --- classify/predict.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/classify/predict.py b/classify/predict.py index 281a3d37a985..7af5f60a2b9d 100644 --- a/classify/predict.py +++ b/classify/predict.py @@ -51,13 +51,10 @@ def run( save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run save_dir.mkdir(parents=True, exist_ok=True) # make dir - # Transforms - transforms = classify_transforms(imgsz) - # Load model model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half) model.warmup(imgsz=(1, 3, imgsz, imgsz)) # warmup - dataset = LoadImages(source, img_size=imgsz, transforms=transforms) + dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz)) for path, im, im0s, vid_cap, s in dataset: # Image t1 = time_sync()
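
PATCH 5 moves the classification-specific preprocessing into the shared dataloader: LoadImages gains an optional transforms callable (and renames img/img0 to im/im0), and when transforms is set each frame is converted BGR->RGB and handed to that callable instead of going through the usual letterbox / HWC->CHW path. A minimal sketch of the two call patterns, assuming the utils.dataloaders and utils.augmentations layout shown in the hunks above; the 'data/images' directory and the sizes are illustrative, not taken from the patches:

from utils.augmentations import classify_transforms
from utils.dataloaders import LoadImages

# Detection-style use (no transforms): `im` is the letterboxed, HWC->CHW,
# BGR->RGB, contiguous numpy array, exactly as before this series.
for path, im, im0, cap, s in LoadImages('data/images', img_size=640):
    print(s, im.shape, im.dtype)

# Classification-style use: each frame is converted BGR->RGB and passed through
# classify_transforms, so `im` arrives as a normalized torch tensor.
dataset = LoadImages('data/images', img_size=224, transforms=classify_transforms(224))
for path, im, im0, cap, s in dataset:
    print(s, tuple(im.shape), im.dtype)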
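
Taken together, the series replaces the single cv2.imread call in classify/predict.py with a loop over a LoadImages dataset, so --source may be an image file or a directory (and, after PATCH 8, a URL that check_file downloads first). A condensed, approximate sketch of the inference loop as it stands after PATCH 9, with argument parsing, save_dir handling, and timing elided; the weights, source, and size below are simply the argparse defaults from the patches, and the script is assumed to run from the repository root so the yolov5 modules import cleanly:

import torch.nn.functional as F

from models.common import DetectMultiBackend
from utils.augmentations import classify_transforms
from utils.dataloaders import LoadImages
from utils.torch_utils import select_device

imgsz = 224
device = select_device('')  # '' -> first available CUDA device, else CPU
model = DetectMultiBackend('yolov5s-cls.pt', device=device, dnn=False, fp16=False)
model.warmup(imgsz=(1, 3, imgsz, imgsz))  # warmup

dataset = LoadImages('data/images', img_size=imgsz, transforms=classify_transforms(imgsz))
for path, im, im0s, vid_cap, s in dataset:  # the 5-tuple LoadImages yields after PATCH 5
    im = im.unsqueeze(0).to(device)  # add batch dimension
    im = im.half() if model.fp16 else im.float()
    results = model(im)  # classification logits
    p = F.softmax(results, dim=1)  # probabilities
    top5 = p.argsort(1, descending=True)[:, :5].squeeze().tolist()  # top-5 class indices
    print(f"{s}{imgsz}x{imgsz} " + ', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in top5))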