Skip to content

Commit

Permalink
Apple MPS -> CPU NMS fallback strategy (#9600)
Browse files Browse the repository at this point in the history
Until more ops are fully supported this update will allow for seamless MPS inference (but slower MPS to CPU transfer before NMS, so slower NMS times).

Partially resolves #9596

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
glenn-jocher authored Sep 26, 2022
1 parent bd9c0c4 commit c4c0ee8
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,8 @@ def non_max_suppression(
if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out)
prediction = prediction[0] # select only inference output

if 'mps' in prediction.device.type: # MPS not fully supported yet, convert tensors to CPU before NMS
prediction = prediction.cpu()
bs = prediction.shape[0] # batch size
nc = prediction.shape[2] - nm - 5 # number of classes
xc = prediction[..., 4] > conf_thres # candidates
Expand Down

0 comments on commit c4c0ee8

Please sign in to comment.