YOLOv5 AWS Inferentia Inplace compatibility updates #2953

Merged (10 commits, Apr 30, 2021)
models/experimental.py (8 changes: 4 additions & 4 deletions)

@@ -6,6 +6,7 @@

from models.common import Conv, DWConv
from utils.google_utils import attempt_download
+from models.yolo import Detect, Model


class CrossConv(nn.Module):
@@ -110,7 +111,7 @@ def forward(self, x, augment=False):
        return y, None  # inference, train output


-def attempt_load(weights, map_location=None):
+def attempt_load(weights, map_location=None, inplace=True):
    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
    model = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
@@ -120,11 +121,10 @@ def attempt_load(weights, map_location=None):

    # Compatibility updates
    for m in model.modules():
-        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
-            m.inplace = True  # pytorch 1.7.0 compatibility
+        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
+            m.inplace = inplace  # pytorch 1.7.0 compatibility
        elif type(m) is Conv:
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    if len(model) == 1:
        return model[-1]  # return model
    else:
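
For context, a minimal usage sketch (not part of this PR): the new inplace argument lets callers request a mutation-free graph, which tracing compilers for targets like AWS Inferentia require. The checkpoint path, input size, and the use of torch.jit.trace as a stand-in tracer below are all illustrative assumptions:

import torch
from models.experimental import attempt_load

# Load with in-place ops disabled so the traced graph contains no slice assignment
model = attempt_load('yolov5s.pt', map_location=torch.device('cpu'), inplace=False)
model.eval()  # inference mode so Detect emits decoded predictions

img = torch.zeros(1, 3, 640, 640)  # dummy input at the common 640x640 size
traced = torch.jit.trace(model, img, strict=False)  # strict=False: outputs include a list
traced.save('yolov5s_traced.pt')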
models/yolo.py (78 changes: 59 additions & 19 deletions)

@@ -25,7 +25,7 @@ class Detect(nn.Module):
    stride = None  # strides computed during build
    export = False  # onnx export

-    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
+    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super(Detect, self).__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
@@ -36,6 +36,7 @@ def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
+        self.inplace = inplace  # use in-place ops (e.g. slice assignment)

    def forward(self, x):
        # x = x.copy()  # for profiling
@@ -51,8 +52,17 @@ def forward(self, x):
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
-                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
-                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+
+                # Default path modifies y in place; the else branch avoids slice
+                # assignment for compilers that cannot trace in-place ops (AWS Inferentia).
+                if self.inplace:
+                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                else:
+                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                    scores = y[..., 4:]
+                    y = torch.cat([xy, wh, scores], -1)
+
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)
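
A self-contained sketch of why the two branches are interchangeable (toy shapes assumed: 3 anchors on a 20x20 grid, 85 outputs per anchor), showing the concat path reproduces the in-place arithmetic without mutating y:

import torch

y = torch.rand(1, 3, 20, 20, 85)         # sigmoid outputs: xywh + obj + 80 classes
grid = torch.zeros(1, 1, 20, 20, 2)      # dummy grid offsets
anchor_grid = torch.ones(1, 3, 1, 1, 2)  # dummy anchor sizes
stride = 8.

y_inplace = y.clone()
y_inplace[..., 0:2] = (y_inplace[..., 0:2] * 2. - 0.5 + grid) * stride  # xy
y_inplace[..., 2:4] = (y_inplace[..., 2:4] * 2) ** 2 * anchor_grid      # wh

xy = (y[..., 0:2] * 2. - 0.5 + grid) * stride
wh = (y[..., 2:4] * 2) ** 2 * anchor_grid
y_concat = torch.cat([xy, wh, y[..., 4:]], -1)

assert torch.allclose(y_inplace, y_concat)  # identical values, no mutation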
@@ -63,6 +73,33 @@ def _make_grid(nx=20, ny=20):
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()


+def _rescale_coords(y, flips, scale, img_size):
+    y[..., :4] /= scale  # de-scale
+    if flips == 2:
+        y[..., 1] = img_size[0] - y[..., 1]  # de-flip ud
+    elif flips == 3:
+        y[..., 0] = img_size[1] - y[..., 0]  # de-flip lr
+    return y


+def _rescale_coords_concat(y, flips, scale, img_size):
+    coords = y[..., :4] / scale  # de-scale
+    scores = y[..., 4:]
+
+    x = coords[..., 0:1]
+    y = coords[..., 1:2]
+    wh = coords[..., 2:4]
+
+    if flips == 2:
+        y = img_size[0] - y  # de-flip ud
+    elif flips == 3:
+        x = img_size[1] - x  # de-flip lr
+    else:
+        return torch.cat([coords, scores], -1)
+
+    return torch.cat([x, y, wh, scores], -1)
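
A quick equivalence check for the two helpers (assumed shapes; 85 = xywh + objectness + 80 class scores, and flips=2 exercises the up-down branch):

import torch

y = torch.rand(1, 100, 85)  # 100 predictions of [x, y, w, h, obj, classes...]
a = _rescale_coords(y.clone(), flips=2, scale=0.83, img_size=(640, 640))
b = _rescale_coords_concat(y, flips=2, scale=0.83, img_size=(640, 640))
assert torch.allclose(a, b)  # in-place and concat variants agree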


class Model(nn.Module):
    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
        super(Model, self).__init__()
@@ -84,11 +121,13 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
            self.yaml['anchors'] = round(anchors)  # override yaml value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
        self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
+        self.inplace = self.yaml.get('inplace', True)
        # logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])

        # Build strides, anchors
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):
+            m.inplace = self.inplace
            s = 256  # 2x min stride
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
            m.anchors /= m.stride.view(-1, 1, 1)
@@ -104,24 +143,26 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes

    def forward(self, x, augment=False, profile=False):
        if augment:
-            img_size = x.shape[-2:]  # height, width
-            s = [1, 0.83, 0.67]  # scales
-            f = [None, 3, None]  # flips (2-ud, 3-lr)
-            y = []  # outputs
-            for si, fi in zip(s, f):
-                xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
-                yi = self.forward_once(xi)[0]  # forward
-                # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
-                yi[..., :4] /= si  # de-scale
-                if fi == 2:
-                    yi[..., 1] = img_size[0] - yi[..., 1]  # de-flip ud
-                elif fi == 3:
-                    yi[..., 0] = img_size[1] - yi[..., 0]  # de-flip lr
-                y.append(yi)
-            return torch.cat(y, 1), None  # augmented inference, train
+            return self.forward_augment(x)  # augmented inference, None
        else:
            return self.forward_once(x, profile)  # single-scale inference, train

+    def forward_augment(self, x):
+        img_size = x.shape[-2:]  # height, width
+        s = [1, 0.83, 0.67]  # scales
+        f = [None, 3, None]  # flips (2-ud, 3-lr)
+        y = []  # outputs
+        for si, fi in zip(s, f):
+            xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
+            yi = self.forward_once(xi)[0]  # forward
+            # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
+            if self.inplace:
+                yi = _rescale_coords(yi, fi, si, img_size)
+            else:
+                yi = _rescale_coords_concat(yi, fi, si, img_size)
+            y.append(yi)
+        return torch.cat(y, 1), None  # augmented inference, train
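
Externally nothing changes: model(img, augment=True) still returns one merged prediction tensor, now produced by forward_augment(). An illustrative smoke test with randomly initialized weights (assumes the yaml path resolves from the working directory):

import torch
from models.yolo import Model

model = Model('yolov5s.yaml').eval()  # random weights suffice for a shape check
pred, _ = model(torch.zeros(1, 3, 640, 640), augment=True)
print(pred.shape)  # torch.Size([1, N, 85]) with all scales and flips merged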

    def forward_once(self, x, profile=False):
        y, dt = [], []  # outputs
        for m in self.model:
@@ -264,7 +305,6 @@ def parse_model(d, ch):  # model_dict, input_channels(3)
    # Create model
    model = Model(opt.cfg).to(device)
    model.train()
-
    # Profile
    # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 320, 320).to(device)
    # y = model(img, profile=True)