diff --git a/models/common.py b/models/common.py index 251463525392..48cf55795dd4 100644 --- a/models/common.py +++ b/models/common.py @@ -458,7 +458,8 @@ def forward(self, im, augment=False, visualize=False, val=False): y = (y.astype(np.float32) - zero_point) * scale # re-scale y[..., :4] *= [w, h, w, h] # xywh normalized to pixels - y = torch.tensor(y) if isinstance(y, np.ndarray) else y + if isinstance(y, np.ndarray): + y = torch.tensor(y, device=self.device) return (y, []) if val else y def warmup(self, imgsz=(1, 3, 640, 640)):