Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --source screen for screenshot inference #9542

Merged
merged 24 commits into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative

from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshot, LoadStreams
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
Expand Down Expand Up @@ -82,6 +82,18 @@ def run(
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
scrshot = source.lower().startswith('screen') # screenshot
if scrshot:
# get all params
source, *params = source.split()
screen, left, top, width, height = 0, None, None, None, None # default to full screen 0
if len(params) == 1:
screen = int(params[0])
elif len(params) == 4:
left, top, width, height = (int(x) for x in params)
elif len(params) == 5:
screen, left, top, width, height = (int(x) for x in params)

if is_url and is_file:
source = check_file(source) # download

Expand All @@ -100,6 +112,16 @@ def run(
view_img = check_imshow()
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = len(dataset) # batch_size
elif scrshot:
dataset = LoadScreenshot(screen=screen,
img_size=imgsz,
stride=stride,
auto=pt,
left=left,
top=top,
width=width,
height=height)
bs = 1 # batch_size
else:
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = 1 # batch_size
Expand Down Expand Up @@ -213,7 +235,11 @@ def run(
def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)')
parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam')
parser.add_argument(
'--source',
type=str,
default=ROOT / 'data/images',
help='file/dir/URL/glob, 0 for webcam, "screen [screennumber] [left top width height]" for screen')
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ torch>=1.7.0
torchvision>=0.8.1
tqdm>=4.64.0
# protobuf<=3.20.1 # https://github.com/ultralytics/yolov5/issues/8012
mss

# Logging -------------------------------------
tensorboard>=2.4.1
Expand Down
58 changes: 58 additions & 0 deletions utils/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from urllib.parse import urlparse
from zipfile import ZipFile

import mss
import numpy as np
import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -185,6 +186,63 @@ def __iter__(self):
yield from iter(self.sampler)


class LoadScreenshot:
    # YOLOv5 screenshot dataloader, i.e. `python detect.py --source "screen 0 100 100 512 256"`
    #
    # Endless iterator that grabs one frame per step from a monitor (or a sub-region of it) using mss.
    # `left`/`top` are offsets relative to the selected monitor's own origin; `width`/`height` are the size of the
    # captured region. Any of them may be None to fall back to the monitor's geometry.
    def __init__(self,
                 screen=0,
                 left=None,
                 top=None,
                 width=None,
                 height=None,
                 img_size=640,
                 stride=32,
                 auto=True,
                 transforms=None):
        self.img_size = img_size
        self.stride = stride
        self.transforms = transforms
        self.auto = auto
        self.screen = screen
        self.mode = 'image'
        self.sct = mss.mss()

        # Look up the geometry of the requested monitor; fail with an actionable message instead of a bare
        # "list index out of range" and release the mss handle that was just opened
        monitors = self.sct.monitors
        try:
            monitor = monitors[screen]
        except IndexError:
            self.sct.close()
            raise IndexError(f'screen {screen} is not available, mss reports {len(monitors)} monitor entries') from None

        # Absolute capture region handed to mss on every grab
        self.monitor = self._resolve_region(monitor, left, top, width, height)
        self.left = self.monitor["left"]
        self.top = self.monitor["top"]
        self.width = self.monitor["width"]
        self.height = self.monitor["height"]

    @staticmethod
    def _resolve_region(monitor, left=None, top=None, width=None, height=None):
        # Translate monitor-relative offsets and an optional size into an absolute region dict with keys
        # "left", "top", "width", "height". A missing offset means 0. A missing size defaults to whatever is left
        # of the monitor after a positive offset, so the region never runs past the monitor's right/bottom edge.
        dx = 0 if left is None else left
        dy = 0 if top is None else top
        return {
            "left": monitor["left"] + dx,
            "top": monitor["top"] + dy,
            "width": monitor["width"] - max(dx, 0) if width is None else width,
            "height": monitor["height"] - max(dy, 0) if height is None else height}

    def __iter__(self):
        # The loader is its own iterator; it never raises StopIteration
        return self

    def __next__(self):
        # Grab the raw pixels of the configured region into a numpy array and drop the alpha channel
        im0 = np.array(self.sct.grab(self.monitor))[:, :, :3]  # [:, :, :3] BGRA to BGR
        s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "

        if self.transforms:
            im = self.transforms(im0)  # caller-supplied transforms replace the default preprocessing
        else:
            im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous
        return str(self.screen), im, im0, None, s  # source name, preprocessed img, original img, None, log string


class LoadImages:
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
Expand Down