diff --git a/models/common.py b/models/common.py index 21a2ed5a2ca7..5c0e571b752f 100644 --- a/models/common.py +++ b/models/common.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn from PIL import Image +from torch.cuda import amp from utils.datasets import letterbox from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh @@ -219,17 +220,17 @@ def forward(self, imgs, size=640, augment=False, profile=False): x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32 t.append(time_synchronized()) - # Inference - with torch.no_grad(): + with torch.no_grad(), amp.autocast(enabled=p.device.type != 'cpu'): + # Inference y = self.model(x, augment, profile)[0] # forward - t.append(time_synchronized()) + t.append(time_synchronized()) - # Post-process - y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS - for i in range(n): - scale_coords(shape1, y[i][:, :4], shape0[i]) - t.append(time_synchronized()) + # Post-process + y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS + for i in range(n): + scale_coords(shape1, y[i][:, :4], shape0[i]) + t.append(time_synchronized()) return Detections(imgs, y, files, t, self.names, x.shape)