From 2276307b2ac82264abb88303ff78660d1d27201a Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 7 Jul 2022 20:36:23 +0200 Subject: [PATCH] Avoid FP64 ops for MPS support in train.py (#8511) Avoid FP64 ops for MPS support Resolves https://github.com/ultralytics/yolov5/pull/7878#issuecomment-1177952614 --- utils/general.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/general.py b/utils/general.py index 17b689010b39..a85a2915a31a 100755 --- a/utils/general.py +++ b/utils/general.py @@ -644,7 +644,7 @@ def labels_to_class_weights(labels, nc=80): return torch.Tensor() labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO - classes = labels[:, 0].astype(np.int) # labels = [class xywh] + classes = labels[:, 0].astype(int) # labels = [class xywh] weights = np.bincount(classes, minlength=nc) # occurrences per class # Prepend gridpoint count (for uCE training) @@ -654,13 +654,13 @@ def labels_to_class_weights(labels, nc=80): weights[weights == 0] = 1 # replace empty bins with 1 weights = 1 / weights # number of targets per class weights /= weights.sum() # normalize - return torch.from_numpy(weights) + return torch.from_numpy(weights).float() def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)): # Produces image weights based on class_weights and image contents # Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample - class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels]) + class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels]) return (class_weights.reshape(1, nc) * class_counts).sum(1)