Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add variable-stride inference support #2091

Merged
merged 1 commit into from
Jan 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def detect(save_img=False):

# Load model
model = attempt_load(weights, map_location=device) # load FP32 model
imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size
stride = int(model.stride.max()) # model stride
imgsz = check_img_size(imgsz, s=stride) # check img_size
if half:
model.half() # to FP16

Expand All @@ -46,10 +47,10 @@ def detect(save_img=False):
if webcam:
view_img = True
cudnn.benchmark = True # set True to speed up constant image size inference
dataset = LoadStreams(source, img_size=imgsz)
dataset = LoadStreams(source, img_size=imgsz, stride=stride)
else:
save_img = True
dataset = LoadImages(source, img_size=imgsz)
dataset = LoadImages(source, img_size=imgsz, stride=stride)

# Get names and colors
names = model.module.names if hasattr(model, 'module') else model.names
Expand Down
23 changes: 13 additions & 10 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __iter__(self):


class LoadImages: # for inference
def __init__(self, path, img_size=640):
def __init__(self, path, img_size=640, stride=32):
p = str(Path(path)) # os-agnostic
p = os.path.abspath(p) # absolute path
if '*' in p:
Expand All @@ -136,6 +136,7 @@ def __init__(self, path, img_size=640):
ni, nv = len(images), len(videos)

self.img_size = img_size
self.stride = stride
self.files = images + videos
self.nf = ni + nv # number of files
self.video_flag = [False] * ni + [True] * nv
Expand Down Expand Up @@ -181,7 +182,7 @@ def __next__(self):
print(f'image {self.count}/{self.nf} {path}: ', end='')

# Padded resize
img = letterbox(img0, new_shape=self.img_size)[0]
img = letterbox(img0, self.img_size, stride=self.stride)[0]

# Convert
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
Expand All @@ -199,8 +200,9 @@ def __len__(self):


class LoadWebcam: # for inference
def __init__(self, pipe='0', img_size=640):
def __init__(self, pipe='0', img_size=640, stride=32):
self.img_size = img_size
self.stride = stride

if pipe.isnumeric():
pipe = eval(pipe) # local camera
Expand Down Expand Up @@ -243,7 +245,7 @@ def __next__(self):
print(f'webcam {self.count}: ', end='')

# Padded resize
img = letterbox(img0, new_shape=self.img_size)[0]
img = letterbox(img0, self.img_size, stride=self.stride)[0]

# Convert
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
Expand All @@ -256,9 +258,10 @@ def __len__(self):


class LoadStreams: # multiple IP or RTSP cameras
def __init__(self, sources='streams.txt', img_size=640):
def __init__(self, sources='streams.txt', img_size=640, stride=32):
self.mode = 'stream'
self.img_size = img_size
self.stride = stride

if os.path.isfile(sources):
with open(sources, 'r') as f:
Expand All @@ -284,7 +287,7 @@ def __init__(self, sources='streams.txt', img_size=640):
print('') # newline

# check for common shapes
s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape for x in self.imgs], 0) # inference shapes
s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0) # shapes
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
if not self.rect:
print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
Expand Down Expand Up @@ -313,7 +316,7 @@ def __next__(self):
raise StopIteration

# Letterbox
img = [letterbox(x, new_shape=self.img_size, auto=self.rect)[0] for x in img0]
img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]

# Stack
img = np.stack(img, 0)
Expand Down Expand Up @@ -784,8 +787,8 @@ def replicate(img, labels):
return img, labels


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
# Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
# Resize and pad image while meeting stride-multiple constraints
shape = img.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
Expand All @@ -800,7 +803,7 @@ def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
if auto: # minimum rectangle
dw, dh = np.mod(dw, 32), np.mod(dh, 32) # wh padding
dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
elif scaleFill: # stretch
dw, dh = 0.0, 0.0
new_unpad = (new_shape[1], new_shape[0])
Expand Down