Skip to content

Commit

Permalink
Fix YoloNAS on cuda (#1444)
Browse files Browse the repository at this point in the history
* fix

* Fixed creation of torch.arange with fp16 dtype

---------

Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
  • Loading branch information
Louis-Dupont and BloodAxe committed Oct 12, 2023
1 parent 06038a2 commit ecdec5e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
4 changes: 4 additions & 0 deletions src/super_gradients/examples/predict/detection_predict.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import torch
from super_gradients.common.object_names import Models
from super_gradients.training import models

# Note that currently only YoloX, PPYoloE and YOLO-NAS are supported.
model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco")

# We want to use cuda if available to speed up inference.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

IMAGES = [
"../../../../documentation/source/images/examples/countryside.jpg",
"../../../../documentation/source/images/examples/street_busy.jpg",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,22 +270,23 @@ def _generate_anchors(self, feats=None, dtype=None, device=None):
else:
h = int(self.eval_size[0] / stride)
w = int(self.eval_size[1] / stride)
shift_x = torch.arange(end=w, dtype=dtype) + self.grid_cell_offset
shift_y = torch.arange(end=h, dtype=dtype) + self.grid_cell_offset

# ONNX export does not support arange with float16, so it is created as fp32 and then casted to fp16
# This produce correct fp16 weights in ONNX model when exported
shift_x = torch.arange(end=w, dtype=torch.float32, device=device) + self.grid_cell_offset
shift_y = torch.arange(end=h, dtype=torch.float32, device=device) + self.grid_cell_offset

if torch_version_is_greater_or_equal(1, 10):
shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing="ij")
else:
shift_y, shift_x = torch.meshgrid(shift_y, shift_x)

anchor_point = torch.stack([shift_x, shift_y], dim=-1).to(dtype=dtype)
anchor_points.append(anchor_point.reshape([-1, 2]))
stride_tensor.append(torch.full([h * w, 1], stride, dtype=dtype))
stride_tensor.append(torch.full([h * w, 1], stride, dtype=dtype, device=device))

anchor_points = torch.cat(anchor_points)
stride_tensor = torch.cat(stride_tensor)

if device is not None:
anchor_points = anchor_points.to(device)
stride_tensor = stride_tensor.to(device)
return anchor_points, stride_tensor

def forward(self, feats: Tuple[Tensor]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,10 @@ def _generate_anchors(self, feats=None, dtype=None, device=None):
h = int(self.eval_size[0] / stride)
w = int(self.eval_size[1] / stride)

shift_x = torch.arange(end=w, dtype=dtype) + self.grid_cell_offset
shift_y = torch.arange(end=h, dtype=dtype) + self.grid_cell_offset
# ONNX export does not support arange with float16, so it is created as fp32 and then casted to fp16
# This produce correct fp16 weights in ONNX model when exported
shift_x = torch.arange(end=w, dtype=torch.float32, device=device) + self.grid_cell_offset
shift_y = torch.arange(end=h, dtype=torch.float32, device=device) + self.grid_cell_offset

if torch_version_is_greater_or_equal(1, 10):
shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing="ij")
Expand All @@ -293,11 +295,8 @@ def _generate_anchors(self, feats=None, dtype=None, device=None):

anchor_point = torch.stack([shift_x, shift_y], dim=-1).to(dtype=dtype)
anchor_points.append(anchor_point.reshape([-1, 2]))
stride_tensor.append(torch.full([h * w, 1], stride, dtype=dtype))
stride_tensor.append(torch.full([h * w, 1], stride, dtype=dtype, device=device))

anchor_points = torch.cat(anchor_points)
stride_tensor = torch.cat(stride_tensor)

if device is not None:
anchor_points = anchor_points.to(device)
stride_tensor = stride_tensor.to(device)
return anchor_points, stride_tensor

0 comments on commit ecdec5e

Please sign in to comment.