diff --git a/README.md b/README.md index da8bf1dad862..1d43111d56e7 100644 --- a/README.md +++ b/README.md @@ -107,6 +107,7 @@ the latest YOLOv5 [release](https://github.com/ultralytics/yolov5/releases) and python detect.py --source 0 # webcam img.jpg # image vid.mp4 # video + screen # screenshot path/ # directory 'path/*.jpg' # glob 'https://youtu.be/Zgi9g1ksQHc' # YouTube diff --git a/classify/predict.py b/classify/predict.py index ef59ff6f550a..011e7b83f09b 100644 --- a/classify/predict.py +++ b/classify/predict.py @@ -42,7 +42,7 @@ from models.common import DetectMultiBackend from utils.augmentations import classify_transforms -from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams +from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2, increment_path, print_args, strip_optimizer) from utils.plots import Annotator @@ -52,7 +52,7 @@ @smart_inference_mode() def run( weights=ROOT / 'yolov5s-cls.pt', # model.pt path(s) - source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam + source=ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam) data=ROOT / 'data/coco128.yaml', # dataset.yaml path imgsz=(224, 224), # inference size (height, width) device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu @@ -74,6 +74,7 @@ def run( is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://')) webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file) + screenshot = source.lower().startswith('screen') if is_url and is_file: source = check_file(source) # download @@ -91,6 +92,8 @@ def run( if webcam: view_img = check_imshow() dataset = LoadStreams(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride) + elif screenshot: + dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt) else: dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride) bs = len(dataset) # batch_size @@ -187,7 +190,7 @@ def run( def parse_opt(): parser = argparse.ArgumentParser() parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-cls.pt', help='model path(s)') - parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam') + parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)') parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path') parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[224], help='inference size h,w') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') diff --git a/detect.py b/detect.py index 4015b9ae0d7f..9036b26263e5 100644 --- a/detect.py +++ b/detect.py @@ -40,7 +40,7 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative from models.common import DetectMultiBackend -from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams +from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2, increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh) from utils.plots import Annotator, colors, save_one_box @@ -50,7 +50,7 @@ @smart_inference_mode() def run( weights=ROOT / 'yolov5s.pt', # model.pt path(s) - source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam + source=ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam) data=ROOT / 'data/coco128.yaml', # dataset.yaml path imgsz=(640, 640), # inference size (height, width) conf_thres=0.25, # confidence threshold @@ -82,6 +82,7 @@ def run( is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://')) webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file) + screenshot = source.lower().startswith('screen') if is_url and is_file: source = check_file(source) # download @@ -99,6 +100,8 @@ def run( if webcam: view_img = check_imshow() dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) + elif screenshot: + dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt) else: dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) bs = len(dataset) # batch_size @@ -212,7 +215,7 @@ def run( def parse_opt(): parser = argparse.ArgumentParser() parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)') - parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam') + parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)') parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path') parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w') parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold') diff --git a/requirements.txt b/requirements.txt index 55c1f2428e3f..914da54e73fc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -38,6 +38,7 @@ seaborn>=0.11.0 ipython # interactive notebook psutil # system utilization thop>=0.1.1 # FLOPs computation +# mss # screenshots # albumentations>=1.0.3 # pycocotools>=2.0 # COCO mAP # roboflow diff --git a/segment/predict.py b/segment/predict.py index 2ea6bd9327e0..43cebc706371 100644 --- a/segment/predict.py +++ b/segment/predict.py @@ -40,7 +40,7 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative from models.common import DetectMultiBackend -from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams +from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2, increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh) from utils.plots import Annotator, colors, save_one_box @@ -51,7 +51,7 @@ @smart_inference_mode() def run( weights=ROOT / 'yolov5s-seg.pt', # model.pt path(s) - source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam + source=ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam) data=ROOT / 'data/coco128.yaml', # dataset.yaml path imgsz=(640, 640), # inference size (height, width) conf_thres=0.25, # confidence threshold @@ -84,6 +84,7 @@ def run( is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://')) webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file) + screenshot = source.lower().startswith('screen') if is_url and is_file: source = check_file(source) # download @@ -101,6 +102,8 @@ def run( if webcam: view_img = check_imshow() dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) + elif screenshot: + dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt) else: dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) bs = len(dataset) # batch_size @@ -222,7 +225,7 @@ def run( def parse_opt(): parser = argparse.ArgumentParser() parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-seg.pt', help='model path(s)') - parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam') + parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)') parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path') parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w') parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold') diff --git a/tutorial.ipynb b/tutorial.ipynb index 957437b2be6d..f87cccd99df8 100644 --- a/tutorial.ipynb +++ b/tutorial.ipynb @@ -445,6 +445,7 @@ "python detect.py --source 0 # webcam\n", " img.jpg # image \n", " vid.mp4 # video\n", + " screen # screenshot\n", " path/ # directory\n", " 'path/*.jpg' # glob\n", " 'https://youtu.be/Zgi9g1ksQHc' # YouTube\n", diff --git a/utils/dataloaders.py b/utils/dataloaders.py index ee79bd0bc5a5..7aee0b891161 100644 --- a/utils/dataloaders.py +++ b/utils/dataloaders.py @@ -185,6 +185,55 @@ def __iter__(self): yield from iter(self.sampler) +class LoadScreenshots: + # YOLOv5 screenshot dataloader, i.e. `python detect.py --source "screen 0 100 100 512 256"` + def __init__(self, source, img_size=640, stride=32, auto=True, transforms=None): + # source = [screen_number left top width height] (pixels) + check_requirements('mss') + import mss + + source, *params = source.split() + self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0 + if len(params) == 1: + self.screen = int(params[0]) + elif len(params) == 4: + left, top, width, height = (int(x) for x in params) + elif len(params) == 5: + self.screen, left, top, width, height = (int(x) for x in params) + self.img_size = img_size + self.stride = stride + self.transforms = transforms + self.auto = auto + self.mode = 'stream' + self.frame = 0 + self.sct = mss.mss() + + # Parse monitor shape + monitor = self.sct.monitors[self.screen] + self.top = monitor["top"] if top is None else (monitor["top"] + top) + self.left = monitor["left"] if left is None else (monitor["left"] + left) + self.width = width or monitor["width"] + self.height = height or monitor["height"] + self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height} + + def __iter__(self): + return self + + def __next__(self): + # mss screen capture: get raw pixels from the screen as np array + im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR + s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: " + + if self.transforms: + im = self.transforms(im0) # transforms + else: + im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize + im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB + im = np.ascontiguousarray(im) # contiguous + self.frame += 1 + return str(self.screen), im, im0, None, s # screen, img, original img, im0s, s + + class LoadImages: # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4` def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):