diff --git a/utils/general.py b/utils/general.py index de7871cb23f9..a855691d3a1f 100644 --- a/utils/general.py +++ b/utils/general.py @@ -843,6 +843,8 @@ def non_max_suppression( if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out) prediction = prediction[0] # select only inference output + if 'mps' in prediction.device.type: # MPS not fully supported yet, convert tensors to CPU before NMS + prediction = prediction.cpu() bs = prediction.shape[0] # batch size nc = prediction.shape[2] - nm - 5 # number of classes xc = prediction[..., 4] > conf_thres # candidates