From a610f0c0d50c5fb6bf39912680ce147491c8a03e Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Tue, 27 Sep 2022 17:55:45 +0200
Subject: [PATCH 1/2] NMS MPS device wrapper

May resolve https://github.com/ultralytics/yolov5/issues/9613

Signed-off-by: Glenn Jocher
---
 utils/general.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/utils/general.py b/utils/general.py
index a855691d3a1f..8d307e944104 100644
--- a/utils/general.py
+++ b/utils/general.py
@@ -843,7 +843,9 @@ def non_max_suppression(
     if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation model, output = (inference_out, loss_out)
         prediction = prediction[0]  # select only inference output
 
-    if 'mps' in prediction.device.type:  # MPS not fully supported yet, convert tensors to CPU before NMS
+    device = prediction.device
+    mps = 'mps' in device.type
+    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
         prediction = prediction.cpu()
     bs = prediction.shape[0]  # batch size
     nc = prediction.shape[2] - nm - 5  # number of classes
@@ -930,6 +932,8 @@
             i = i[iou.sum(1) > 1]  # require redundancy
 
         output[xi] = x[i]
+        if mps:
+            output[xi] = output[xi].to(device)
         if (time.time() - t) > time_limit:
             LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
             break  # time limit exceeded

From 53622188091f87c88b8dd943eac7e6651d869910 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Tue, 27 Sep 2022 17:58:35 +0200
Subject: [PATCH 2/2] Update general.py

Signed-off-by: Glenn Jocher
---
 utils/general.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils/general.py b/utils/general.py
index 8d307e944104..d31b043a113e 100644
--- a/utils/general.py
+++ b/utils/general.py
@@ -844,7 +844,7 @@ def non_max_suppression(
         prediction = prediction[0]  # select only inference output
 
     device = prediction.device
-    mps = 'mps' in device.type
+    mps = 'mps' in device.type  # Apple MPS
     if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
         prediction = prediction.cpu()
     bs = prediction.shape[0]  # batch size