diff --git a/models/common.py b/models/common.py index 7690f714def8..a6488dd85648 100644 --- a/models/common.py +++ b/models/common.py @@ -441,6 +441,9 @@ def wrap_frozen_graph(gd, inputs, outputs): def forward(self, im, augment=False, visualize=False, val=False): # YOLOv5 MultiBackend inference b, ch, h, w = im.shape # batch, channel, height, width + if self.fp16 and im.dtype != torch.float16: + im = im.half() # to FP16 + if self.pt: # PyTorch y = self.model(im, augment=augment, visualize=visualize)[0] elif self.jit: # TorchScript