diff --git a/src/super_gradients/examples/predict/detection_predict.py b/src/super_gradients/examples/predict/detection_predict.py index 42963e4ed7..00aaa02ce4 100644 --- a/src/super_gradients/examples/predict/detection_predict.py +++ b/src/super_gradients/examples/predict/detection_predict.py @@ -1,9 +1,13 @@ +import torch from super_gradients.common.object_names import Models from super_gradients.training import models # Note that currently only YoloX, PPYoloE and YOLO-NAS are supported. model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco") +# We want to use cuda if available to speed up inference. +model = model.to("cuda" if torch.cuda.is_available() else "cpu") + IMAGES = [ "../../../../documentation/source/images/examples/countryside.jpg", "../../../../documentation/source/images/examples/street_busy.jpg", diff --git a/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_head.py b/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_head.py index c82e5a652f..82bcbcdee0 100644 --- a/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_head.py +++ b/src/super_gradients/training/models/detection_models/pp_yolo_e/pp_yolo_head.py @@ -270,8 +270,12 @@ def _generate_anchors(self, feats=None, dtype=None, device=None): else: h = int(self.eval_size[0] / stride) w = int(self.eval_size[1] / stride) - shift_x = torch.arange(end=w, dtype=dtype) + self.grid_cell_offset - shift_y = torch.arange(end=h, dtype=dtype) + self.grid_cell_offset + + # ONNX export does not support arange with float16, so it is created as fp32 and then casted to fp16 + # This produce correct fp16 weights in ONNX model when exported + shift_x = torch.arange(end=w, dtype=torch.float32, device=device) + self.grid_cell_offset + shift_y = torch.arange(end=h, dtype=torch.float32, device=device) + self.grid_cell_offset + if torch_version_is_greater_or_equal(1, 10): shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing="ij") else: @@ -279,13 +283,10 @@ def _generate_anchors(self, feats=None, dtype=None, device=None): anchor_point = torch.stack([shift_x, shift_y], dim=-1).to(dtype=dtype) anchor_points.append(anchor_point.reshape([-1, 2])) - stride_tensor.append(torch.full([h * w, 1], stride, dtype=dtype)) + stride_tensor.append(torch.full([h * w, 1], stride, dtype=dtype, device=device)) + anchor_points = torch.cat(anchor_points) stride_tensor = torch.cat(stride_tensor) - - if device is not None: - anchor_points = anchor_points.to(device) - stride_tensor = stride_tensor.to(device) return anchor_points, stride_tensor def forward(self, feats: Tuple[Tensor]): diff --git a/src/super_gradients/training/models/detection_models/yolo_nas/dfl_heads.py b/src/super_gradients/training/models/detection_models/yolo_nas/dfl_heads.py index 0418f97782..790b030e1c 100644 --- a/src/super_gradients/training/models/detection_models/yolo_nas/dfl_heads.py +++ b/src/super_gradients/training/models/detection_models/yolo_nas/dfl_heads.py @@ -283,8 +283,10 @@ def _generate_anchors(self, feats=None, dtype=None, device=None): h = int(self.eval_size[0] / stride) w = int(self.eval_size[1] / stride) - shift_x = torch.arange(end=w, dtype=dtype) + self.grid_cell_offset - shift_y = torch.arange(end=h, dtype=dtype) + self.grid_cell_offset + # ONNX export does not support arange with float16, so it is created as fp32 and then casted to fp16 + # This produce correct fp16 weights in ONNX model when exported + shift_x = torch.arange(end=w, dtype=torch.float32, device=device) + self.grid_cell_offset + shift_y = torch.arange(end=h, dtype=torch.float32, device=device) + self.grid_cell_offset if torch_version_is_greater_or_equal(1, 10): shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing="ij") @@ -293,11 +295,8 @@ def _generate_anchors(self, feats=None, dtype=None, device=None): anchor_point = torch.stack([shift_x, shift_y], dim=-1).to(dtype=dtype) anchor_points.append(anchor_point.reshape([-1, 2])) - stride_tensor.append(torch.full([h * w, 1], stride, dtype=dtype)) + stride_tensor.append(torch.full([h * w, 1], stride, dtype=dtype, device=device)) + anchor_points = torch.cat(anchor_points) stride_tensor = torch.cat(stride_tensor) - - if device is not None: - anchor_points = anchor_points.to(device) - stride_tensor = stride_tensor.to(device) return anchor_points, stride_tensor