From 52c1399fdc6c3db550123e47a2cdcb6dc951e211 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 12 Mar 2022 13:16:29 +0100 Subject: [PATCH] DetectMultiBackend() return `device` update (#6958) Fixes ONNX validation that returns outputs on CPU. --- models/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/models/common.py b/models/common.py index 251463525392..48cf55795dd4 100644 --- a/models/common.py +++ b/models/common.py @@ -458,7 +458,8 @@ def forward(self, im, augment=False, visualize=False, val=False): y = (y.astype(np.float32) - zero_point) * scale # re-scale y[..., :4] *= [w, h, w, h] # xywh normalized to pixels - y = torch.tensor(y) if isinstance(y, np.ndarray) else y + if isinstance(y, np.ndarray): + y = torch.tensor(y, device=self.device) return (y, []) if val else y def warmup(self, imgsz=(1, 3, 640, 640)):