Skip to content

Commit

Permalink
Avoid FP64 ops for MPS support in train.py (ultralytics#8511)
Browse files Browse the repository at this point in the history
Avoid FP64 ops for MPS support

Resolves ultralytics#7878 (comment)
  • Loading branch information
glenn-jocher authored and Shivvrat committed Jul 12, 2022
1 parent ad26981 commit 6cd9ed1
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ def labels_to_class_weights(labels, nc=80):
return torch.Tensor()

labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
classes = labels[:, 0].astype(np.int) # labels = [class xywh]
classes = labels[:, 0].astype(int) # labels = [class xywh]
weights = np.bincount(classes, minlength=nc) # occurrences per class

# Prepend gridpoint count (for uCE training)
Expand All @@ -654,13 +654,13 @@ def labels_to_class_weights(labels, nc=80):
weights[weights == 0] = 1 # replace empty bins with 1
weights = 1 / weights # number of targets per class
weights /= weights.sum() # normalize
return torch.from_numpy(weights)
return torch.from_numpy(weights).float()


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
# Produces image weights based on class_weights and image contents
# Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample
class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
return (class_weights.reshape(1, nc) * class_counts).sum(1)


Expand Down

0 comments on commit 6cd9ed1

Please sign in to comment.