From c55e2cd73b472de808665f8337d6edeaebb74521 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Tue, 1 Nov 2022 14:53:14 +0100
Subject: [PATCH] Add `min_items` filter option (#9997)

* Add `min_items` filter option

@AyushExel @Laughing-q dataset filter

Signed-off-by: Glenn Jocher

* Update dataloaders.py

Signed-off-by: Glenn Jocher

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Glenn Jocher
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 utils/dataloaders.py         | 17 +++++++++++++++--
 utils/segment/dataloaders.py |  3 ++-
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/utils/dataloaders.py b/utils/dataloaders.py
index 403252ff6227..6b6e83e30456 100644
--- a/utils/dataloaders.py
+++ b/utils/dataloaders.py
@@ -444,6 +444,7 @@ def __init__(self,
             single_cls=False,
             stride=32,
             pad=0.0,
+            min_items=0,
             prefix=''):
         self.img_size = img_size
         self.augment = augment
@@ -475,7 +476,7 @@ def __init__(self,
             # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
             assert self.im_files, f'{prefix}No images found'
         except Exception as e:
-            raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}')
+            raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}') from e
 
         # Check cache
         self.label_files = img2label_paths(self.im_files)  # labels
@@ -505,7 +506,19 @@ def __init__(self,
         self.shapes = np.array(shapes)
         self.im_files = list(cache.keys())  # update
         self.label_files = img2label_paths(cache.keys())  # update
-        n = len(shapes)  # number of images
+
+        # Filter images
+        if min_items:
+            include = np.array([len(x) > min_items for x in self.labels]).nonzero()[0].astype(int)
+            LOGGER.info(f'{prefix}{nf - len(include)}/{nf} images filtered from dataset')
+            self.im_files = [self.im_files[i] for i in include]
+            self.label_files = [self.label_files[i] for i in include]
+            self.labels = [self.labels[i] for i in include]
+            self.segments = [self.segments[i] for i in include]
+            self.shapes = self.shapes[include]  # wh
+
+        # Create indices
+        n = len(self.shapes)  # number of images
         bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
         nb = bi[-1] + 1  # number of batches
         self.batch = bi  # batch index of image
diff --git a/utils/segment/dataloaders.py b/utils/segment/dataloaders.py
index a63d6ec013fd..9de6f0fbf903 100644
--- a/utils/segment/dataloaders.py
+++ b/utils/segment/dataloaders.py
@@ -93,12 +93,13 @@ def __init__(
         single_cls=False,
         stride=32,
         pad=0,
+        min_items=0,
         prefix="",
         downsample_ratio=1,
         overlap=False,
     ):
         super().__init__(path, img_size, batch_size, augment, hyp, rect, image_weights, cache_images, single_cls,
-                         stride, pad, prefix)
+                         stride, pad, min_items, prefix)
        self.downsample_ratio = downsample_ratio
        self.overlap = overlap
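
For reference, a minimal standalone sketch of the new filtering step, run on toy data. The label arrays, file names, and shapes below are made up for illustration; only the indexing pattern mirrors the hunk above. Note the strict comparison (`len(x) > min_items`): an image is kept only when it has more than `min_items` labels.

    import numpy as np

    # Toy stand-ins for the dataset attributes the patch touches (illustrative values only)
    labels = [np.zeros((0, 5)), np.zeros((1, 5)), np.zeros((3, 5))]  # per-image label arrays
    im_files = ['a.jpg', 'b.jpg', 'c.jpg']
    shapes = np.array([[640, 480], [1280, 720], [640, 640]])  # wh per image
    min_items = 1

    if min_items:
        # Keep indices of images with more than `min_items` labels (strict '>', as in the patch)
        include = np.array([len(x) > min_items for x in labels]).nonzero()[0].astype(int)
        print(f'{len(labels) - len(include)}/{len(labels)} images filtered from dataset')
        im_files = [im_files[i] for i in include]
        labels = [labels[i] for i in include]
        shapes = shapes[include]  # wh

    print(im_files)  # ['c.jpg'] -> only the image with 3 labels survives min_items=1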
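
A hedged usage sketch follows, assuming a YOLOv5 checkout on the Python path; the dataset path, image size, and prefix are placeholders, not part of the patch. Only the `min_items` keyword itself comes from the constructor signature added above.

    # Hypothetical usage of the new keyword; the dataset path below is a placeholder
    from utils.dataloaders import LoadImagesAndLabels

    dataset = LoadImagesAndLabels(
        path='../datasets/coco128/images/train2017',  # example path, adjust to your data
        img_size=640,
        batch_size=16,
        min_items=1,  # drop images with 1 or fewer labels under this revision's strict '>' test
        prefix='train: ')
    print(len(dataset.im_files), 'images kept after filtering')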