Skip to content

Commit

Permalink
Clip TTA Augmented Tails (#5028)
Browse files Browse the repository at this point in the history
* Clip TTA Augmented Tails

Experimental TTA update.

* Update yolo.py

* Update yolo.py

* Update yolo.py

* Update yolo.py
  • Loading branch information
glenn-jocher committed Oct 4, 2021
1 parent 1922dde commit d133968
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def _forward_augment(self, x):
# cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
yi = self._descale_pred(yi, fi, si, img_size)
y.append(yi)
y = self._clip_augmented(y) # clip augmented tails
return torch.cat(y, 1), None # augmented inference, train

def _forward_once(self, x, profile=False, visualize=False):
Expand Down Expand Up @@ -166,6 +167,17 @@ def _descale_pred(self, p, flips, scale, img_size):
p = torch.cat((x, y, wh, p[..., 4:]), -1)
return p

def _clip_augmented(self, y):
# Clip YOLOv5 augmented inference tails
nl = self.model[-1].nl # number of detection layers (P3-P5)
g = sum(4 ** x for x in range(nl)) # grid points
e = 1 # exclude layer count
i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e)) # indices
y[0] = y[0][:, :-i] # large
i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
y[-1] = y[-1][:, i:] # small
return y

def _profile_one_layer(self, m, x, dt):
c = isinstance(m, Detect) # is final layer, copy input as inplace fix
o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
Expand Down

0 comments on commit d133968

Please sign in to comment.