diff --git a/utils/datasets.py b/utils/datasets.py index a086e6bfb782..4f9bd0f05d09 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -437,10 +437,6 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r self.shapes = np.array(shapes, dtype=np.float64) self.img_files = list(cache.keys()) # update self.label_files = img2label_paths(cache.keys()) # update - if single_cls: - for x in self.labels: - x[:, 0] = 0 - n = len(shapes) # number of images bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index nb = bi[-1] + 1 # number of batches @@ -448,6 +444,20 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r self.n = n self.indices = range(n) + # Update labels + include_class = [] # filter labels to include only these classes (optional) + include_class_array = np.array(include_class).reshape(1, -1) + for i, (label, segment) in enumerate(zip(self.labels, self.segments)): + if include_class: + j = (label[:, 0:1] == include_class_array).any(1) + self.labels[i] = label[j] + if segment: + self.segments[i] = segment[j] + if single_cls: # single-class training, merge all classes into 0 + self.labels[i][:, 0] = 0 + if segment: + self.segments[i][:, 0] = 0 + # Rectangular Training if self.rect: # Sort by aspect ratio