Add min_items filter option (#9997)
* Add `min_items` filter option

@AyushExel @Laughing-q dataset filter

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* Update dataloaders.py

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
glenn-jocher and pre-commit-ci[bot] committed Nov 1, 2022
1 parent cf99788 commit c55e2cd
Showing 2 changed files with 17 additions and 3 deletions.
17 changes: 15 additions & 2 deletions utils/dataloaders.py
@@ -444,6 +444,7 @@ def __init__(self,
                  single_cls=False,
                  stride=32,
                  pad=0.0,
+                 min_items=0,
                  prefix=''):
         self.img_size = img_size
         self.augment = augment
@@ -475,7 +476,7 @@ def __init__(self,
             # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
             assert self.im_files, f'{prefix}No images found'
         except Exception as e:
-            raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}')
+            raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}') from e
 
         # Check cache
         self.label_files = img2label_paths(self.im_files)  # labels
@@ -505,7 +506,19 @@ def __init__(self,
         self.shapes = np.array(shapes)
         self.im_files = list(cache.keys())  # update
         self.label_files = img2label_paths(cache.keys())  # update
-        n = len(shapes)  # number of images
+
+        # Filter images
+        if min_items:
+            include = np.array([len(x) > min_items for x in self.labels]).nonzero()[0].astype(int)
+            LOGGER.info(f'{prefix}{nf - len(include)}/{nf} images filtered from dataset')
+            self.im_files = [self.im_files[i] for i in include]
+            self.label_files = [self.label_files[i] for i in include]
+            self.labels = [self.labels[i] for i in include]
+            self.segments = [self.segments[i] for i in include]
+            self.shapes = self.shapes[include]  # wh
+
+        # Create indices
+        n = len(self.shapes)  # number of images
         bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
         nb = bi[-1] + 1  # number of batches
         self.batch = bi  # batch index of image
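For reference, the new block keeps only images whose label count is strictly greater than min_items, then re-indexes every per-image list with the surviving indices. A minimal, self-contained sketch of the same indexing pattern on toy data (the file names and shapes below are made up for illustration):

import numpy as np

min_items = 1  # keep images with more than min_items labels (note the strict '>')
labels = [np.zeros((0, 5)), np.zeros((1, 5)), np.zeros((3, 5))]  # toy per-image label arrays
im_files = ['a.jpg', 'b.jpg', 'c.jpg']  # hypothetical image paths
shapes = np.array([[640, 480], [640, 640], [1280, 720]])  # wh per image

include = np.array([len(x) > min_items for x in labels]).nonzero()[0].astype(int)
print(f'{len(labels) - len(include)}/{len(labels)} images filtered from dataset')  # 2/3 filtered
im_files = [im_files[i] for i in include]
labels = [labels[i] for i in include]
shapes = shapes[include]  # only 'c.jpg' (3 labels) survives with min_items=1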
3 changes: 2 additions & 1 deletion utils/segment/dataloaders.py
@@ -93,12 +93,13 @@ def __init__(
         single_cls=False,
         stride=32,
         pad=0,
+        min_items=0,
         prefix="",
         downsample_ratio=1,
         overlap=False,
     ):
         super().__init__(path, img_size, batch_size, augment, hyp, rect, image_weights, cache_images, single_cls,
-                         stride, pad, prefix)
+                         stride, pad, min_items, prefix)
         self.downsample_ratio = downsample_ratio
         self.overlap = overlap
 
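The segmentation dataset simply forwards min_items to its parent, so both dataloaders expose the same knob. A usage sketch, assuming the parent class in utils/dataloaders.py is YOLOv5's LoadImagesAndLabels (the class name and dataset path below are not shown in this diff):

from utils.dataloaders import LoadImagesAndLabels  # assumed class name

dataset = LoadImagesAndLabels(
    '../datasets/coco128/images/train2017',  # hypothetical dataset path
    img_size=640,
    batch_size=16,
    min_items=1)  # drops images with one label or fewer, since the check is len(labels) > min_items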
