Skip to content

Commit

Permalink
batch inference update
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Oct 5, 2020
1 parent 372ad4c commit 50974db
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,22 +116,38 @@ def forward(self, x):


class autoShape(nn.Module):
# auto-reshape image size model wrapper
# auto-reshape input image model wrapper
img_size = 640 # inference size (pixels)

def __init__(self, model):
super(autoShape, self).__init__()
self.model = model

def forward(self, x, shape=640, augment=False, profile=False): # x = cv2.imread('img.jpg')
x0shape = x.shape[:2]
p = next(self.model.parameters())
x, ratio, (dw, dh) = letterbox(x, new_shape=make_divisible(shape or max(x0shape), int(self.stride.max())))
x1shape = x.shape[:2]
x = np.ascontiguousarray(x[:, :, ::-1].transpose(2, 0, 1)) # BGR to RGB, to 3x640x640
x = torch.from_numpy(x).to(p.device).type_as(p).unsqueeze(0) / 255. # uint8 to fp16/32
def forward(self, x, shape=640, augment=False, profile=False):
# x is cv2/np/PIL RGB image, or list of images for batched inference, i.e. x = Image.open('image.jpg')
p = next(self.model.parameters()) # for device and type
if not isinstance(x, list):
x = [x]
batch = range(len(x)) # batch size

shape0, shape1 = [], [] # image and inference shapes
for i in batch:
x[i] = np.array(x[i])[:, :, :3] # up to 3 channels if png
s = x[i].shape[:2] # HWC
shape0.append(s) # image shape
g = (shape / max(s)) # gain
shape1.append([y * g for y in s])
shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape

x = [letterbox(x[i], new_shape=shape1, auto=False)[0] for i in batch] # pad
x = np.stack(x, 0) if batch[-1] else x[0][None] # stack
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32

x = self.model(x, augment, profile) # forward
x[0][:, :4] = scale_coords(x1shape, x[0][:, :4], x0shape)

for i in batch:
x[i][:, :4] = scale_coords(shape1, x[i][:, :4], shape0[i]) # postprocess
return x


Expand Down

0 comments on commit 50974db

Please sign in to comment.