Skip to content

Commit

Permalink
Create dataset_stats() for HUB
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Jun 8, 2021
1 parent ac8691e commit b6fdd2e
Showing 1 changed file with 34 additions and 2 deletions.
36 changes: 34 additions & 2 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
import numpy as np
import torch
import torch.nn.functional as F
import yaml
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.general import check_requirements, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, segment2box, segments2boxes, \
resample_segments, clean_str
from utils.general import check_requirements, check_file, check_dataset, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, \
segment2box, segments2boxes, resample_segments, clean_str
from utils.torch_utils import torch_distributed_zero_first

# Parameters
Expand Down Expand Up @@ -1083,3 +1084,34 @@ def verify_image_label(params):
nc = 1
logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
return [None] * 4 + [nm, nf, ne, nc]


def dataset_stats(path='data/coco128.yaml', verbose=False):
    """ Return dataset statistics dictionary with images and instances counts per split per class
    Usage: from utils.datasets import *; dataset_stats('data/coco128.yaml')
    Arguments
        path:           Path to data.yaml
        verbose:        Print stats dictionary
    Returns
        stats:          dict with keys 'nc' and 'names' copied from the data dict, plus one key per split
                        ('train', 'val', 'test'). A split that is absent from the yaml, or present with an
                        empty value, maps to None. Otherwise it maps to
                        {'instances': {'total', 'per_class'}, 'images': {'total', 'unlabelled', 'per_class'}}
    """
    path = check_file(Path(path))
    with open(path) as f:
        data = yaml.safe_load(f)  # data dict
    check_dataset(data)  # download dataset if missing

    nc = data['nc']  # number of classes
    stats = {'nc': nc, 'names': data['names']}  # statistics dictionary
    for split in 'train', 'val', 'test':
        # An empty yaml entry such as `test:` loads as None, so test the value rather than key membership;
        # otherwise None would be passed on to the dataset loader.
        if data.get(split) is None:
            stats[split] = None  # i.e. no test set
            continue
        dataset = LoadImagesAndLabels(data[split], augment=False, rect=True)  # load dataset
        # One row of per-class instance counts per image; column 0 of each label array is used as the class index
        x = np.array([np.bincount(label[:, 0].astype(int), minlength=nc)
                      for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics')])  # shape(images, nc)
        stats[split] = {'instances': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
                        'images': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
                                   'per_class': (x > 0).sum(0).tolist()}}
    if verbose:
        print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
    return stats

0 comments on commit b6fdd2e

Please sign in to comment.