def dataset_stats(path='data/coco128.yaml', verbose=False):
    """ Return dataset statistics dictionary with images and instances counts per split per class
        Usage: from utils.datasets import *; dataset_stats('data/coco128.yaml')
        Arguments
            path:           Path to data.yaml
            verbose:        Print stats dictionary
        Returns
            Dictionary with keys 'nc' and 'names' (copied from the data dict) plus one entry per split
            ('train', 'val', 'test'). A split entry is None when the split is absent from the data dict
            or declared with an empty value, otherwise:
                {'instances': {'total': int, 'per_class': list},
                 'images': {'total': int, 'unlabelled': int, 'per_class': list}}
    """
    path = check_file(Path(path))
    with open(path) as f:
        data = yaml.safe_load(f)  # data dict
    check_dataset(data)  # download dataset if missing

    nc = data['nc']  # number of classes
    stats = {'nc': nc, 'names': data['names']}  # statistics dictionary
    for split in 'train', 'val', 'test':
        # A split may be missing entirely OR present with an empty value (e.g. a bare `test:` line parses to
        # None). Both mean "no such split"; testing only `split not in data` would pass None on to the loader.
        if data.get(split) is None:
            stats[split] = None  # i.e. no test set
            continue
        dataset = LoadImagesAndLabels(data[split], augment=False, rect=True)  # load dataset
        # One row per image: occurrences of each class index (label column 0), padded to at least nc entries
        x = np.array([np.bincount(label[:, 0].astype(int), minlength=nc)
                      for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics')])  # shape(n, nc)
        stats[split] = {'instances': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
                        'images': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
                                   'per_class': (x > 0).sum(0).tolist()}}
    if verbose:
        print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
    return stats