diff --git a/collect.py b/collect.py new file mode 100644 index 000000000000..4c2475270138 --- /dev/null +++ b/collect.py @@ -0,0 +1,237 @@ +# YOLOv5 🚀 AGPL-3.0 license +""" +Collect 3LC metrics for a trained YOLOv5 detection model on a detection dataset + +Usage: + $ python collect.py --weights yolov5s.pt --data coco128.yaml --img 640 +""" + +import argparse +import os +import sys +from pathlib import Path + +import tlc + +FILE = Path(__file__).resolve() +ROOT = FILE.parents[0] # YOLOv5 root directory +if str(ROOT) not in sys.path: + sys.path.append(str(ROOT)) # add ROOT to PATH +ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative + +from tlc.client.torch.metrics.collect import collect_metrics +from tlc.core.builtins.constants.column_names import BOUNDING_BOXES + +from models.common import DetectMultiBackend +from utils.general import LOGGER, check_dataset, check_img_size, check_requirements, check_yaml, colorstr, print_args +from utils.tlc_integration import (TLCComputeLoss, create_dataloader, get_or_create_tlc_table, + tlc_create_metrics_collectors) +from utils.torch_utils import select_device, smart_inference_mode + + +@smart_inference_mode() +def run( + data=ROOT / 'data/coco128.yaml', # dataset.yaml path + weights=ROOT / 'yolov5s.pt', # model.pt path(s) + batch_size=1, # batch size TODO: Support batch size > 1 + imgsz=640, # inference size (pixels) + conf_thres=0.001, # confidence threshold + iou_thres=0.6, # NMS IoU threshold + tlc_iou_thres=0.3, # 3LC Metrics collection IoU threshold + max_det=300, # maximum detections per image + split='val', # Split to collect metrics for + device='', # cuda device, i.e. 0 or cpu (only single device supported) + workers=8, # max dataloader workers + single_cls=False, # treat as single-class dataset + half=True, # use FP16 half-precision inference + dnn=False, # use OpenCV DNN for ONNX inference + stride=None, # stride (from training) + epoch=None, # epoch (use when training) + model=None, # model instance + table=None, # table (from training) + tlc_revision_url='', # 3LC revision URL to use for metrics collection + tlc_image_embeddings_dim=0, # Dimension of image embeddings (2 or 3). Default is 0, which means no image embeddings are used. 
+ compute_loss=None, # ComputeLoss instance (from training) + collect_loss=False, # Compute and collect loss for each image during metrics collection +): + + # Initialize/load model and set device + training = model is not None + + if tlc_image_embeddings_dim not in (0, 2, 3): + raise ValueError(f'Invalid value for tlc_image_embeddings_dim: {tlc_image_embeddings_dim}') + if tlc_image_embeddings_dim in (2, 3): + # We need to ensure we have UMAP installed + try: + import umap # noqa: F401 + except ImportError: + raise ValueError('Missing UMAP dependency, run `pip install umap-learn` to enable embeddings collection.') + + if training: # called by train.py + # Check for required args + if any(v is None for v in (epoch, table)): + raise ValueError('When training, epoch and table must be passed') + + device, pt, jit, engine = next(model.parameters()).device, True, False, False # get model device, PyTorch model + half &= device.type != 'cpu' # half precision only supported on CUDA + model.half() if half else model.float() + model.collecting = tlc_image_embeddings_dim > 0 + + else: # called directly + # Check for required args + if any(v is None for v in (weights, data, split)): + raise ValueError('When not training, model weights, data and split must be passed') + + device = select_device(device, batch_size=batch_size) + + # Load model + model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half, fuse=False) + stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine + imgsz = check_img_size(imgsz, s=stride) # check image size + half = model.fp16 # FP16 supported on limited backends with CUDA + if engine: + batch_size = model.batch_size + else: + device = model.device + if not (pt or jit): + batch_size = 1 # export.py models default to batch-size 1 + LOGGER.info(f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models') + + # Data + if data: + check_dataset(data) # check + table = get_or_create_tlc_table( + yolo_yaml_file=data, + split=split, + revision_url=tlc_revision_url, + ) + if collect_loss: + m = model.model.model[-1] + compute_loss = TLCComputeLoss('cpu', model.model.hyp, m.stride, m.na, m.nc, m.nl, + m.anchors) # DetectMultiBackend holds a DetectionModel, which has hyp + else: + compute_loss = None + run = tlc.init(project_name=table.project_name) # Only create a run when called directly + + # Ensure table is in collecting metrics mode + table.collecting_metrics = True + + # Setup dataloader + dataloader = create_dataloader( + data, # Not really used + imgsz, + batch_size, + stride, + single_cls, + pad=0.5, + rect=False, + workers=workers, + prefix=colorstr(f'collect-{split}: '), + table=table, + )[0] + + # Verify dataset classes + categories = table.get_value_map_for_column(BOUNDING_BOXES) if not training else dataloader.dataset.categories + nc = 1 if single_cls else len(categories) # number of classes + + if not training and pt and not single_cls: # check --weights are trained on --data + ncm = model.model.nc + assert ncm == nc, (f'{weights} ({ncm} classes) trained on different --data than what you passed ({nc} ' + f'classes). 
Pass correct combination of --weights and --data that are trained together.')
+        model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz))  # warmup
+        model.model.collecting = tlc_image_embeddings_dim > 0
+
+    # Configure
+    model.eval()
+
+    # Set up metrics collectors
+    metrics_collectors = tlc_create_metrics_collectors(model=model,
+                                                       names=categories,
+                                                       conf_thres=conf_thres,
+                                                       nms_iou_thres=iou_thres,
+                                                       max_det=max_det,
+                                                       iou_thres=tlc_iou_thres,
+                                                       compute_embeddings=tlc_image_embeddings_dim > 0,
+                                                       compute_loss=compute_loss)
+
+    # If half precision, update metrics collector models to this
+    # if half:
+    #     for metrics_collector in metrics_collectors:
+    #         metrics_collector.model.half()
+
+    # Collect metrics
+    collect_metrics(
+        table=dataloader.dataset,
+        metrics_collectors=metrics_collectors,
+        constants={'epoch': epoch} if epoch is not None else {},
+        dataset_name=dataloader.dataset.tlc_name,
+        dataset_url=dataloader.dataset.tlc_table_url,
+        dataloader_args={
+            'batch_size': batch_size,
+            'collate_fn': dataloader.collate_fn,
+            'num_workers': workers, },
+    )
+
+    # Finish up
+    if training:
+        model.float()
+        model.train()
+        model.collecting = False
+        table.collecting_metrics = False
+        return None, dataloader
+
+    else:
+        if tlc_image_embeddings_dim in (2, 3):
+            run.reduce_embeddings_per_dataset(n_components=tlc_image_embeddings_dim)
+        tlc.close()
+        return run, dataloader
+
+
+def parse_opt():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
+    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)')
+    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
+    parser.add_argument('--conf-thres', type=float, default=0.001, help='confidence threshold')
+    parser.add_argument('--iou-thres', type=float, default=0.6, help='NMS IoU threshold')
+    parser.add_argument('--max-det', type=int, default=300, help='maximum detections per image')
+    parser.add_argument('--split', type=str, default='val', help='Split to collect metrics for')
+    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or cpu (only single device supported)')
+    parser.add_argument('--workers', type=int, default=1, help='max dataloader workers')
+    parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
+    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
+    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
+    parser.add_argument('--batch-size', type=int, default=1, help='Batch size for metrics collection. Defaults to 1.')
+    # 3LC args
+    parser.add_argument('--tlc-iou-thres',
+                        type=float,
+                        default=0.3,
+                        help='IoU threshold for 3LC to consider a prediction a match')
+    parser.add_argument('--tlc-revision-url',
+                        type=str,
+                        default='',
+                        help='URL to the revision of the 3LC dataset to collect metrics for')
+    parser.add_argument('--tlc-image-embeddings-dim',
+                        type=int,
+                        default=0,
+                        help='Dimension of image embeddings (2 or 3).
Defaults to 0, corresponding to no embeddings.') + parser.add_argument('--tlc-collect-loss', + dest='collect_loss', + action='store_true', + help='Collect loss for each image during metrics collection.') + + opt = parser.parse_args() + opt.data = check_yaml(opt.data) # check YAML + print_args(vars(opt)) + return opt + + +def main(opt): + check_requirements(ROOT / 'requirements.txt', exclude=('tensorboard', 'thop')) + + run(**vars(opt)) + + +if __name__ == '__main__': + opt = parse_opt() + main(opt) diff --git a/models/yolo.py b/models/yolo.py index 4f4d567bec73..984c177bb36c 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -21,6 +21,7 @@ if platform.system() != 'Windows': ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative +import utils.tlc_integration.collectors as collectors from models.common import * # noqa from models.experimental import * # noqa from utils.autoanchor import check_anchor_order @@ -119,6 +120,9 @@ def _forward_once(self, x, profile=False, visualize=False): if profile: self._profile_one_layer(m, x, dt) x = m(x) # run + if 'SPPF' in m.type and hasattr(self, 'collecting') and self.collecting: + activations = x.mean(dim=(2, 3)) + collectors.ACTIVATIONS.append(activations) y.append(x if m.i in self.save else None) # save output if visualize: feature_visualization(x, m.type, m.i, save_dir=visualize) @@ -185,6 +189,7 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, i self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist self.names = [str(i) for i in range(self.yaml['nc'])] # default names self.inplace = self.yaml.get('inplace', True) + self.collecting = False # Build strides, anchors m = self.model[-1] # Detect() diff --git a/train.py b/train.py index 4c3bec34835f..b7a63c04439f 100644 --- a/train.py +++ b/train.py @@ -45,24 +45,27 @@ sys.path.append(str(ROOT)) # add ROOT to PATH ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative +import tlc +from tlc.core.builtins.constants.column_names import BOUNDING_BOXES + import val as validate # for end-of-epoch mAP from models.experimental import attempt_load from models.yolo import Model from utils.autoanchor import check_anchors from utils.autobatch import check_train_batch_size from utils.callbacks import Callbacks -from utils.dataloaders import create_dataloader from utils.downloads import attempt_download, is_url from utils.general import (LOGGER, TQDM_BAR_FORMAT, check_amp, check_dataset, check_file, check_git_info, check_git_status, check_img_size, check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds, intersect_dicts, labels_to_class_weights, - labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer, - yaml_save) + methods, one_cycle, print_args, print_mutation, strip_optimizer, yaml_save) from utils.loggers import Loggers from utils.loggers.comet.comet_utils import check_comet_resume from utils.loss import ComputeLoss from utils.metrics import fitness from utils.plots import plot_evolve +from utils.tlc_integration import TLCComputeLoss, create_dataloader, get_or_create_tlc_table +from utils.tlc_integration.utils import create_tlc_info_string_before_training from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer, smart_resume, torch_distributed_zero_first) @@ -71,6 +74,8 @@ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) GIT_INFO = check_git_info() +import collect + def train(hyp, opt, device, callbacks): # hyp is 
path/to/hyp.yaml or hyp dictionary
     save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
@@ -78,6 +83,13 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
         opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
     callbacks.run('on_pretrain_routine_start')
 
+    if opt.tlc_image_embeddings_dim in [2, 3]:
+        # We need to ensure we have UMAP installed
+        try:
+            import umap  # noqa: F401
+        except ImportError:
+            raise ValueError('Missing UMAP dependency, run `pip install umap-learn` to enable embeddings collection.')
+
     # Directories
     w = save_dir / 'weights'  # weights dir
     (w.parent if evolve else w).mkdir(parents=True, exist_ok=True)  # make dir
@@ -114,10 +126,34 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
     cuda = device.type != 'cpu'
     init_seeds(opt.seed + 1 + RANK, deterministic=True)
     with torch_distributed_zero_first(LOCAL_RANK):
-        data_dict = data_dict or check_dataset(data)  # check if None
+
+        train_table, val_table = (get_or_create_tlc_table(
+            yolo_yaml_file=data,
+            root_url=opt.tlc_root_url,
+            split=split,
+            revision_url=opt.tlc_train_revision_url if split == 'train' else opt.tlc_val_revision_url,
+        ) for split in ('train', 'val'))
+
+        if data and not data_dict:
+            # A data YAML was provided but has not been resolved yet, so check it here
+            data_dict = check_dataset(data)
+
+        # Make sure train and val tables have the same categories
+        train_categories = train_table.get_value_map_for_column(BOUNDING_BOXES)
+        val_categories = val_table.get_value_map_for_column(BOUNDING_BOXES)
+
+        if train_categories != val_categories:
+            raise ValueError('Train and val tables have different categories. They must have the same categories.')
+
+        nc = 1 if single_cls else len(train_categories)
+        names = {0: 'item'} if single_cls and len(train_categories) != 1 else train_categories
+        if not data_dict:
+            # --data argument was explicitly set empty, use 3LC Table to
+            # populate data dict as far as possible.
+ data_dict = {'nc': nc, 'train': None, 'val': None} + train_path, val_path = data_dict['train'], data_dict['val'] - nc = 1 if single_cls else int(data_dict['nc']) # number of classes - names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names + is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset # Model @@ -192,40 +228,47 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio LOGGER.info('Using SyncBatchNorm()') # Trainloader - train_loader, dataset = create_dataloader(train_path, - imgsz, - batch_size // WORLD_SIZE, - gs, - single_cls, - hyp=hyp, - augment=True, - cache=None if opt.cache == 'val' else opt.cache, - rect=opt.rect, - rank=LOCAL_RANK, - workers=workers, - image_weights=opt.image_weights, - quad=opt.quad, - prefix=colorstr('train: '), - shuffle=True, - seed=opt.seed) + train_loader, dataset = create_dataloader( + train_path, + imgsz, + batch_size // WORLD_SIZE, + gs, + single_cls, + hyp=hyp, + augment=True, + cache=False, # None if opt.cache == 'val' else opt.cache, + rect=opt.rect, + rank=LOCAL_RANK, + workers=workers, + image_weights=opt.image_weights, + quad=opt.quad, + prefix=colorstr('train: '), + shuffle=True, + seed=opt.seed, + table=train_table, + tlc_sampling_weights=not opt.tlc_disable_sample_weights, # Use sampling weights in training + ) labels = np.concatenate(dataset.labels, 0) mlc = int(labels[:, 0].max()) # max label class assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}' # Process 0 if RANK in {-1, 0}: - val_loader = create_dataloader(val_path, - imgsz, - batch_size // WORLD_SIZE * 2, - gs, - single_cls, - hyp=hyp, - cache=None if noval else opt.cache, - rect=True, - rank=-1, - workers=workers * 2, - pad=0.5, - prefix=colorstr('val: '))[0] + val_loader = create_dataloader( + val_path, + imgsz, + batch_size // WORLD_SIZE * 2, + gs, + single_cls, + hyp=hyp, + cache=False, # None if noval else opt.cache, + rect=True, + rank=-1, + workers=workers * 2, + pad=0.5, + prefix=colorstr('val: '), + table=val_table, + )[0] if not resume: if not opt.noautoanchor: @@ -249,6 +292,43 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights model.names = names + # 3LC Setup -------------------------------------------------------------------------------------------------------- + tlc_metrics_collection_epochs = range(opt.tlc_mc_start, epochs + + 1, opt.tlc_mc_interval) if not opt.tlc_disable_mc else [] + tlc_mc_string = create_tlc_info_string_before_training(tlc_metrics_collection_epochs, opt.tlc_mc_before_training) + LOGGER.info(colorstr('3LC: ') + tlc_mc_string) + compute_loss = ComputeLoss(model) # init loss class + + if RANK in {-1, 0} and not opt.tlc_disable_mc and opt.tlc_mc_collect_loss: + m = de_parallel(model).model[-1] # Detect() module + compute_loss_tlc = TLCComputeLoss('cpu', model.hyp, m.stride, m.na, m.nc, m.nl, + m.anchors) # init loss class on cpu + else: + compute_loss_tlc = None + + if RANK in {-1, 0} and not opt.tlc_disable_mc: + run = tlc.init(train_table.project_name) + + # Collect metrics prior to training? 
+ if opt.tlc_mc_before_training: + for table in (train_table, val_table): + collect.run( + model=ema.ema, + table=table, + stride=gs, + epoch=-1, + imgsz=imgsz, # Infer on the same size images + conf_thres=opt.tlc_mc_conf_thres, + iou_thres=opt.tlc_mc_nms_iou_thres, + max_det=opt.tlc_mc_max_det, + tlc_iou_thres=opt.tlc_mc_iou_thres, + workers=workers, + half=amp, + tlc_image_embeddings_dim=opt.tlc_image_embeddings_dim, + compute_loss=compute_loss_tlc) + + # 3LC Setup End ---------------------------------------------------------------------------------------------------- + # Start training t0 = time.time() nb = len(train_loader) # number of batches @@ -260,21 +340,25 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio scheduler.last_epoch = start_epoch - 1 # do not move scaler = torch.cuda.amp.GradScaler(enabled=amp) stopper, stop = EarlyStopping(patience=opt.patience), False - compute_loss = ComputeLoss(model) # init loss class callbacks.run('on_train_start') + LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n' f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n' f"Logging results to {colorstr('bold', save_dir)}\n" f'Starting training for {epochs} epochs...') + for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------ callbacks.run('on_train_epoch_start') model.train() # Update image weights (optional, single-GPU only) - if opt.image_weights: - cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights - iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights - dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx + # if opt.image_weights: + # cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights + # iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights + # dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx + + # Resample train dataset images (only resamples if enabled) + train_loader.dataset.resample(epoch) # Update mosaic border (optional) # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs) @@ -356,17 +440,19 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights']) final_epoch = (epoch + 1 == epochs) or stopper.possible_stop if not noval or final_epoch: # Calculate mAP - results, maps, _ = validate.run(data_dict, - batch_size=batch_size // WORLD_SIZE * 2, - imgsz=imgsz, - half=amp, - model=ema.ema, - single_cls=single_cls, - dataloader=val_loader, - save_dir=save_dir, - plots=False, - callbacks=callbacks, - compute_loss=compute_loss) + results, maps, _ = validate.run( + data_dict, + batch_size=batch_size // WORLD_SIZE * 2, + imgsz=imgsz, + half=amp, + model=ema.ema, + single_cls=single_cls, + dataloader=val_loader, + save_dir=save_dir, + plots=False, + callbacks=callbacks, + compute_loss=compute_loss, + ) # Update best mAP fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95] @@ -407,6 +493,26 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio if stop: break # must break all DDP ranks + # TLC Metrics Collection START --------------------------------------------------------------------------------- + if not opt.tlc_disable_mc: + if epoch in tlc_metrics_collection_epochs and RANK in 
{-1, 0}:
+                for table in (train_table, val_table):
+                    collect.run(
+                        model=ema.ema,
+                        table=table,
+                        epoch=epoch,
+                        stride=gs,
+                        imgsz=imgsz,  # Infer on the same size images
+                        conf_thres=opt.tlc_mc_conf_thres,
+                        iou_thres=opt.tlc_mc_nms_iou_thres,
+                        max_det=opt.tlc_mc_max_det,
+                        tlc_iou_thres=opt.tlc_mc_iou_thres,
+                        workers=workers,
+                        half=amp,
+                        tlc_image_embeddings_dim=opt.tlc_image_embeddings_dim,
+                        compute_loss=compute_loss_tlc)
+
+        # TLC Metrics Collection END -----------------------------------------------------------------------------------
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training -----------------------------------------------------------------------------------------------------
     if RANK in {-1, 0}:
@@ -429,13 +535,20 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
                                         verbose=True,
                                         plots=plots,
                                         callbacks=callbacks,
-                                        compute_loss=compute_loss)  # val best model with plots
+                                        compute_loss=compute_loss,
+                                        tlc_discard_non_zero_preds=opt.tlc_discard_non_zero_preds,
+                                        )  # val best model with plots
                 if is_coco:
                     callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
 
         callbacks.run('on_train_end', last, best, epoch, results)
 
     torch.cuda.empty_cache()
+    if RANK in {-1, 0} and not opt.tlc_disable_mc and opt.tlc_image_embeddings_dim in (2, 3):
+        # Reduce the embeddings to the desired dimensionality (the run only exists when metrics collection is enabled)
+        run.reduce_embeddings_by_example_table_url(train_table.url, n_components=opt.tlc_image_embeddings_dim)
+    if not opt.tlc_disable_mc:
+        tlc.close()
     return results
 
@@ -476,6 +589,73 @@ def parse_opt(known=False):
     parser.add_argument('--seed', type=int, default=0, help='Global training seed')
     parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
 
+    # 3LC arguments
+    parser.add_argument('--tlc-disable-mc',
+                        '--tlc-disable-metrics-collection',
+                        dest='tlc_disable_mc',
+                        action='store_true',
+                        help='Disable 3LC metrics collection.')
+    parser.add_argument('--tlc-mc-interval',
+                        '--tlc-metrics-collection-interval',
+                        dest='tlc_mc_interval',
+                        type=int,
+                        default=1,
+                        help='Epoch interval between metrics collections.')
+    parser.add_argument('--tlc-mc-start',
+                        '--tlc-metrics-collection-start',
+                        dest='tlc_mc_start',
+                        type=int,
+                        default=0,
+                        help='Epoch to start collecting metrics.
Defaults to first epoch (0).') + parser.add_argument('--tlc-mc-before-training', + '--tlc-metrics-collection-before-training', + dest='tlc_mc_before_training', + action='store_true', + help='Collect metrics before training.') + parser.add_argument('--tlc-mc-iou-thres', + '--tlc-metrics-collection-iou-threshold', + dest='tlc_mc_iou_thres', + type=float, + default=0.3, + help='IoU threshold for 3LC metrics collection.') + parser.add_argument('--tlc-mc-conf-thres', + '--tlc-metrics-collection-confidence-threshold', + dest='tlc_mc_conf_thres', + type=float, + default=0.25, + help='NMS Confidence threshold for metrics collection') + parser.add_argument('--tlc-mc-nms-iou-thres', + '--tlc-metrics-collection-nms-iou-threshold', + dest='tlc_mc_nms_iou_thres', + type=float, + default=0.45, + help='IoU threshold for metrics collection NMS') + parser.add_argument('--tlc-mc-max-det', + '--tlc-metrics-collection-max-det', + dest='tlc_mc_max_det', + type=int, + default=300, + help='Maximum number of detections per image for metrics collection') + parser.add_argument('--tlc-mc-collect-loss', action='store_true', help='Collect loss during metrics collection') + parser.add_argument('--tlc-train-revision-url', + type=str, + default='', + help='Train dataset revision, defaults to latest.') + parser.add_argument('--tlc-val-revision-url', + type=str, + default='', + help='Validation dataset revision, defaults to latest.') + parser.add_argument('--tlc-root-url', type=str, default=None, help='Root URL for datasets. Defaults to None.') + parser.add_argument('--tlc-disable-sample-weights', action='store_true', help='Disable sampling weights.') + parser.add_argument('--tlc-discard-non-zero-preds', + action='store_true', + help='Discard predictions with class != 0 before validating') + parser.add_argument( + '--tlc-image-embeddings-dim', + type=int, + default=0, + help='Dimension of image embeddings (2 or 3). 
Default is 0, which means no image embeddings are used.')
+
     # Logger arguments
     parser.add_argument('--entity', default=None, help='Entity')
     parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='Upload data, "val" option')
diff --git a/utils/general.py b/utils/general.py
index 135141e21436..61434a108430 100644
--- a/utils/general.py
+++ b/utils/general.py
@@ -899,6 +899,7 @@ def non_max_suppression(
     max_wh = 7680  # (pixels) maximum box width and height
     max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
     time_limit = 0.5 + 0.05 * bs  # seconds to quit after
+    time_limit *= 3  # 3LC integration: triple the NMS time limit
     redundant = True  # require redundant detections
     multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
     merge = False  # use merge-NMS
diff --git a/utils/tlc_integration/__init__.py b/utils/tlc_integration/__init__.py
new file mode 100644
index 000000000000..5457f4dfa1f1
--- /dev/null
+++ b/utils/tlc_integration/__init__.py
@@ -0,0 +1,9 @@
+# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
+from .collectors import (NoOpMetricsCollectorWithModel, Preprocessor, YOLOv5BoundingBoxMetricsCollector,
+                         tlc_create_metrics_collectors)
+from .dataloaders import create_dataloader
+from .loss import TLCComputeLoss
+from .utils import get_or_create_tlc_table
+
+__all__ = ('NoOpMetricsCollectorWithModel', 'Preprocessor', 'YOLOv5BoundingBoxMetricsCollector',
+           'tlc_create_metrics_collectors', 'create_dataloader', 'get_or_create_tlc_table', 'TLCComputeLoss')
diff --git a/utils/tlc_integration/collectors.py b/utils/tlc_integration/collectors.py
new file mode 100644
index 000000000000..d07a17cac5d1
--- /dev/null
+++ b/utils/tlc_integration/collectors.py
@@ -0,0 +1,248 @@
+# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
+"""
+Metrics collectors - 3LC integration
+"""
+import torch
+from tlc.client.torch.metrics.metrics_collectors.bounding_box_metrics_collector import (YOLOAnnotation,
+                                                                                        YOLOBoundingBoxMetricsCollector,
+                                                                                        YOLOGroundTruth, YOLOPrediction)
+from tlc.client.torch.metrics.metrics_collectors.metrics_collector_base import MetricsCollectorBase
+from tlc.core.builtins.constants.number_roles import NUMBER_ROLE_NN_EMBEDDING
+from tlc.core.schema import DimensionNumericValue, Float32Value, Schema
+
+from ..general import non_max_suppression, scale_boxes, xywh2xyxy, xyxy2xywhn
+from .utils import xyxy_to_xywh
+
+ACTIVATIONS = []
+
+
+class YOLOv5BoundingBoxMetricsCollector(YOLOBoundingBoxMetricsCollector):
+    """A YOLOv5 specific bounding box metrics collector."""
+
+    def __init__(self, *args, collect_embeddings=False, compute_loss=None, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Support other layers
+        # Collecting embeddings for the last backbone layer.
+        # This is the output of the SPPF layer with mean pooling the spatial dimensions, resulting
+        # in a 512-dimensional vector.
+        self._collect_embeddings = collect_embeddings
+        self._compute_loss = compute_loss
+        self._activation_size = 512  # 512 for most models. TODO: Infer this for special YOLOv5 models with a different size.
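+        # Note on the data flow: models/yolo.py appends the spatially mean-pooled SPPF output
+        # to collectors.ACTIVATIONS in _forward_once whenever the model's `collecting` flag is
+        # set; compute_metrics below pops exactly one entry per batch.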
+ + def compute_metrics(self, batch, _1=None, _2=None): + with torch.no_grad(): + device = next(self.model.parameters()).device # get sample device + # half precision only supported on CUDA and only when model is fp16 + half = (device.type != 'cpu' and next(self.model.parameters()).dtype == torch.float16) + self.model.eval() + im, targets, paths, shapes = batch + im = im.to(device, non_blocking=True) + targets = targets.to(device) + im = im.half() if half else im.float() + im /= 255 + nb, _, height, width = im.shape # batch size, channels, height, width + + # Forward + preds, train_out = self.model(im, augment=False) + + # Read the activations written during the forward pass for the batch + # We remove it once it has been read + if self._collect_embeddings: + assert len(ACTIVATIONS) == 1 + activations = ACTIVATIONS.pop() + + assert len(ACTIVATIONS) == 0 + + metrics = super().compute_metrics(batch, preds) + + # Compute and add loss values to metrics + # TODO: Batched computation + if self._compute_loss is not None: + metrics.update({'loss': [], 'box_loss': [], 'obj_loss': [], 'cls_loss': []}) + train_out = [to.cpu() for to in train_out] + targets = targets.cpu() + # Pretend batch size is 1 and compute loss for each sample + for i in range(nb): + train_out_sample = [to[i:i + 1, ...] + for to in train_out] # Get the train_out for the sample, keep batch dim + targets_sample = targets[targets[:, 0] == i, :] # Get the targets for the sample + targets_sample[:, 0] = 0 # Set the batch index to 0 + losses = self._compute_loss(train_out_sample, targets_sample)[1].numpy() + metrics['loss'].append(losses.sum()) + metrics['box_loss'].append(losses[0]) + metrics['obj_loss'].append(losses[1]) + metrics['cls_loss'].append(losses[2]) + + # Add embeddings to metrics + if self._collect_embeddings and activations is not None: + metrics['embeddings'] = activations.cpu().numpy() + + return metrics + + @property + def column_schemas(self): + _column_schemas = super().column_schemas + + # Loss schemas + if self._compute_loss is not None: + _column_schemas['loss'] = Schema(description='Sample loss', + writable=False, + value=Float32Value(), + display_importance=3003) + _column_schemas['box_loss'] = Schema(description='Box loss', + writable=False, + value=Float32Value(), + display_importance=3004) + _column_schemas['obj_loss'] = Schema(description='Object loss', + writable=False, + value=Float32Value(), + display_importance=3005) + _column_schemas['cls_loss'] = Schema(description='Classification loss', + writable=False, + value=Float32Value(), + display_importance=3006) + + # Embedding schema + if self._collect_embeddings: + embedding_schema = Schema('Embedding', + 'Large NN embedding', + writable=False, + computable=False, + value=Float32Value(number_role=NUMBER_ROLE_NN_EMBEDDING)) + # 512 for all YOLO detection models + embedding_schema.size0 = DimensionNumericValue(value_min=self._activation_size, + value_max=self._activation_size, + enforce_min=True, + enforce_max=True) + _column_schemas['embeddings'] = embedding_schema + + return _column_schemas + + +class NoOpMetricsCollectorWithModel(MetricsCollectorBase): + """ This metrics collector does nothing, except to block 3LC from performing a forward pass. 
+ + """ + + def compute_metrics(self, batch, predictions=None, hook_outputs=None): + return {} + + @property + def model(self): + return torch.nn.Identity() + + +class Preprocessor: + + def __init__(self, nms_kwargs): + self.nms_kwargs = nms_kwargs + + def __call__(self, batch, predictions): + # Apply NMS + predictions = non_max_suppression(predictions, **self.nms_kwargs) + + images, targets, paths, shapes = batch + batch_size = len(paths) + + # Ground truth + processed_batch = [] + + nb, _, height, width = images.shape # batch size, channels, height, width + targets = targets.cpu() + targets[:, 2:] *= torch.tensor((width, height, width, height), device=targets.device) + + for i in range(batch_size): + height, width = shapes[i][0] + + labels = targets[targets[:, 0] == i, 1:] + tbox = xywh2xyxy(labels[:, 1:5]) # target boxes + scale_boxes(images[i].shape[1:], tbox, shapes[i][0], shapes[i][1]) + # This is xyxy scaled boxes. Now go back to xywh-normalized and write these + labelsn = torch.cat((labels[:, 0:1], tbox), 1) # normalized labels + xywh_boxes = xyxy2xywhn(labelsn[:, 1:5], w=width, h=height) + num_boxes = labelsn.shape[0] + + # Create annotations with YOLO format + annotations = [ + YOLOAnnotation( + category_id=labels[j, 0], + bbox=xywh_boxes[j], + score=1.0, + ) for j in range(num_boxes)] + ground_truth = YOLOGroundTruth( + file_name=paths[i], + height=height, + width=width, + annotations=annotations, + ) + processed_batch.append(ground_truth) + + # Predictions + processed_predictions = [] + for i, prediction in enumerate(predictions): + height, width = shapes[i][0] + scaled_boxes = scale_boxes( + images[i].shape[1:], + prediction[:, :4], + shapes[i][0], + shapes[i][1], + ) + prediction = prediction.cpu().numpy() + annotations = [ + YOLOAnnotation( + category_id=prediction[j, 5], + bbox=xyxy_to_xywh(scaled_boxes[j, :].tolist(), height=height, width=width), + score=prediction[j, 4], + ) for j in range(prediction.shape[0])] + processed_predictions.append(YOLOPrediction(annotations=annotations)) + + return processed_batch, processed_predictions + + +def tlc_create_metrics_collectors(model, + names, + conf_thres: float = 0.45, + nms_iou_thres: float = 0.45, + max_det: int = 300, + iou_thres: float = 0.4, + compute_embeddings: bool = False, + compute_loss=None): + """ Sets up the default metrics collectors for YOLO bounding box metrics collection. + + :param model: The model to use for metrics collection. + :param conf_thres: Confidence threshold for predictions. Anything under is discarded. + :param nms_iou_thres: IoU threshold to use for NMS. Boxes with IoU > nms_iou_thres are + collapsed to the one with greatest confidence. + :param max_det: Maximum number of detections for a sample. + :param iou_thres: IoU threshold to use for computing True Positives. + :param compute_embeddings: Whether to compute embeddings for each sample. + :param compute_loss: Function to compute loss for each sample. + + :returns metrics_collectors: A list of metrics collectors to use. + + """ + nms_kwargs = { + 'conf_thres': conf_thres, + 'iou_thres': nms_iou_thres, + 'classes': None, # TODO: Add this? Filters to subset of classes. + 'agnostic': False, # TODO: Add this as a kwarg option? 3LC doesn't really support it? 
+ 'max_det': max_det} + + preprocess_fn = Preprocessor(nms_kwargs) + metrics_collectors = [ + YOLOv5BoundingBoxMetricsCollector( + model=model, + classes=list(names.values()), + label_mapping={i: i + for i in range(len(names))}, + iou_threshold=iou_thres, + compute_derived_metrics=True, + derived_metrics_mode='strict', + preprocess_fn=preprocess_fn, + collect_embeddings=compute_embeddings, + compute_loss=compute_loss, + ), + NoOpMetricsCollectorWithModel(metric_names=[]), # Avoids extra 3LC forward pass + ] + return metrics_collectors diff --git a/utils/tlc_integration/dataloaders.py b/utils/tlc_integration/dataloaders.py new file mode 100644 index 000000000000..f4c385674a9d --- /dev/null +++ b/utils/tlc_integration/dataloaders.py @@ -0,0 +1,330 @@ +# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license +""" +Dataloaders and dataset utils - 3LC integration +""" +import os +from collections import Counter +from multiprocessing.pool import ThreadPool +from pathlib import Path + +import numpy as np +import tlc +import torch +from PIL import Image, ImageOps +from tlc.core.builtins.constants.column_names import BOUNDING_BOXES, HEIGHT, IMAGE, SAMPLE_WEIGHT, WIDTH +from tlc.core.url import Url +from torch.utils.data import DataLoader, distributed +from tqdm import tqdm + +from utils.augmentations import Albumentations +from utils.general import LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, colorstr, cv2 +from utils.torch_utils import torch_distributed_zero_first + +from ..dataloaders import InfiniteDataLoader, LoadImagesAndLabels, img2label_paths, seed_worker +from .utils import tlc_table_row_to_yolo_label + +LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html +RANK = int(os.getenv('RANK', -1)) +PIN_MEMORY = (str(os.getenv('PIN_MEMORY', True)).lower() == 'true') # global pin_memory for dataloaders + + +def create_dataloader( + path, + imgsz, + batch_size, + stride, + single_cls=False, + hyp=None, + augment=False, + cache=False, + pad=0.0, + rect=False, + rank=-1, + workers=8, + image_weights=False, + quad=False, + prefix='', + shuffle=False, + seed=0, + table=None, + tlc_sampling_weights=False, +): + if rect and shuffle: + LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False') + shuffle = False + + assert table is not None, 'table must be provided' + tlc_prefix = colorstr('3LC: ') + + with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP + LOGGER.info(f'{tlc_prefix}Creating dataloader for {table.dataset_name} dataset') + dataset = TLCLoadImagesAndLabels( + path, + imgsz, + batch_size, + augment=augment, # augmentation + hyp=hyp, # hyperparameters + rect=rect, # rectangular batches + cache_images=cache, + single_cls=single_cls, + stride=int(stride), + pad=pad, + image_weights=image_weights, + prefix=tlc_prefix, + table=table, + tlc_sampling_weights=tlc_sampling_weights, + ) + + batch_size = min(batch_size, len(dataset)) + nd = torch.cuda.device_count() # number of CUDA devices + nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers + sampler = (None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)) + loader = (DataLoader if (image_weights or tlc_sampling_weights) else InfiniteDataLoader + ) # only DataLoader allows for attribute updates + generator = torch.Generator() + generator.manual_seed(6148914691236517205 + seed + RANK) + return ( + loader( + dataset, + batch_size=batch_size, + shuffle=shuffle and sampler is None, + 
num_workers=nw, + sampler=sampler, + pin_memory=PIN_MEMORY, + collate_fn=TLCLoadImagesAndLabels.collate_fn4 if quad else TLCLoadImagesAndLabels.collate_fn, + worker_init_fn=seed_worker, + generator=generator, + ), + dataset, + ) + + +class TLCLoadImagesAndLabels(LoadImagesAndLabels): + # YOLOv5 train_loader/val_loader, loads images and labels for training and validation + cache_version = 0.6 # dataset labels *.cache version + rand_interp_methods = [ + cv2.INTER_NEAREST, + cv2.INTER_LINEAR, + cv2.INTER_CUBIC, + cv2.INTER_AREA, + cv2.INTER_LANCZOS4, ] + + def __init__( + self, + path, + img_size=640, + batch_size=16, + augment=False, + hyp=None, + rect=False, + image_weights=False, + cache_images=False, + single_cls=False, + stride=32, + pad=0.0, + min_items=0, + prefix='', + table=None, + tlc_sampling_weights=False, + ): + self.img_size = img_size + self.augment = augment + self.hyp = hyp + self.image_weights = image_weights + self.rect = False if image_weights else rect + self.mosaic = (self.augment and not self.rect) # load 4 images at a time into a mosaic (only during training) + self.mosaic_border = [-img_size // 2, -img_size // 2] + self.stride = stride + self.path = path + self.prefix = prefix + self.albumentations = Albumentations(size=img_size) if augment else None + + if rect and tlc_sampling_weights: + raise ValueError('Rectangular training is not compatible with 3LC sampling weights') + + self.tlc_use_sampling_weights = tlc_sampling_weights + if tlc_sampling_weights: + LOGGER.info(f'{prefix}Using 3LC sampling weights') + + # Get 3lc table - read the yolo image and label paths and any revisions + self.categories = table.get_value_map_for_column(BOUNDING_BOXES) + self.tlc_name = table.dataset_name + self.tlc_table_url = table.url.to_str() + + self.sampling_weights = [] + self.im_files = [] + self.shapes = [] + self.labels = [] + + num_fixed, num_corrupt = 0, 0 + msgs = [] + + pbar = iter(table.table_rows) + if RANK in {-1, 0}: + pbar = tlc.track(pbar, description=f'Loading data from 3LC Table {table.url.name}', total=len(table)) + + for row in pbar: + im_file = Url(row[IMAGE]).to_absolute().to_str() + fixed, corrupt, msg = fix_image(im_file) + if msg: + msgs.append(msg) + num_fixed += int(fixed) + num_corrupt += int(corrupt) + + # Ignore corrupt images when training or validating + # Don't ignore when collecting metrics since the dataset length will change + if not corrupt or table.collecting_metrics: + self.sampling_weights.append(row[SAMPLE_WEIGHT]) + self.im_files.append(str(Path(im_file))) # Ensure path is os.sep-delimited + self.shapes.append((row[WIDTH], row[HEIGHT])) + self.labels.append(tlc_table_row_to_yolo_label(row)) + + self.shapes = np.array(self.shapes) + self.sampling_weights = np.array(self.sampling_weights) + self.sampling_weights = self.sampling_weights / np.sum(self.sampling_weights) + + if num_fixed > 0 or num_corrupt > 0: + LOGGER.info(f'Fixed {num_fixed} images. 
Found and ignored {num_corrupt} corrupt images')
+
+        if len(msgs) > 0:
+            LOGGER.info('\n'.join(msgs))
+
+        n = len(self.im_files)
+        self.label_files = img2label_paths(
+            self.im_files)  # .label_files is not really used in the 3LC integration, as labels are stored in the table
+        self.segments = tuple([] for _ in range(n))  # TODO: Add segmentation support
+
+        # Filter images
+        if min_items:
+            include = (np.array([len(x) >= min_items for x in self.labels]).nonzero()[0].astype(int))
+            LOGGER.info(f'{prefix}{n - len(include)}/{n} images filtered from dataset')
+            self.im_files = [self.im_files[i] for i in include]
+            self.label_files = [self.label_files[i] for i in include]
+            self.labels = [self.labels[i] for i in include]
+            self.segments = [self.segments[i] for i in include]
+            self.shapes = self.shapes[include]  # wh
+
+        # Create indices
+        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
+        nb = bi[-1] + 1  # number of batches
+        self.batch = bi  # batch index of image
+        self.n = n
+        self.indices = range(n)
+
+        # Update labels
+        include_class = []  # filter labels to include only these classes (optional)
+        self.segments = list(self.segments)
+        include_class_array = np.array(include_class).reshape(1, -1)
+        for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
+            if include_class:
+                j = (label[:, 0:1] == include_class_array).any(1)
+                self.labels[i] = label[j]
+                if segment:
+                    self.segments[i] = [segment[idx] for idx, elem in enumerate(j) if elem]
+            if single_cls:  # single-class training, merge all classes into 0
+                self.labels[i][:, 0] = 0
+
+        # Rectangular Training
+        if self.rect:
+            # Sort by aspect ratio
+            s = self.shapes  # wh
+            ar = s[:, 1] / s[:, 0]  # aspect ratio
+            irect = ar.argsort()
+            self.im_files = [self.im_files[i] for i in irect]
+            self.label_files = [self.label_files[i] for i in irect]
+            self.labels = [self.labels[i] for i in irect]
+            self.segments = [self.segments[i] for i in irect]
+            self.shapes = s[irect]  # wh
+            ar = ar[irect]
+            self.sampling_weights = self.sampling_weights[irect]
+
+            # Set training image shapes
+            shapes = [[1, 1]] * nb
+            for i in range(nb):
+                ari = ar[bi == i]
+                mini, maxi = ari.min(), ari.max()
+                if maxi < 1:
+                    shapes[i] = [maxi, 1]
+                elif mini > 1:
+                    shapes[i] = [1, 1 / mini]
+
+            self.batch_shapes = (np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride)
+
+        # Cache images into RAM/disk for faster training
+        if cache_images == 'ram' and not self.check_cache_ram(prefix=prefix):
+            cache_images = False
+        self.ims = [None] * n
+        self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
+        if cache_images:
+            b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
+            self.im_hw0, self.im_hw = [None] * n, [None] * n
+            fcn = (self.cache_images_to_disk if cache_images == 'disk' else self.load_image)
+            results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
+            pbar = tqdm(
+                enumerate(results),
+                total=n,
+                bar_format=TQDM_BAR_FORMAT,
+                disable=LOCAL_RANK > 0,
+            )
+            for i, x in pbar:
+                if cache_images == 'disk':
+                    b += self.npy_files[i].stat().st_size
+                else:  # 'ram'
+                    (
+                        self.ims[i],
+                        self.im_hw0[i],
+                        self.im_hw[i],
+                    ) = x  # im, hw_orig, hw_resized = load_image(self, i)
+                    b += self.ims[i].nbytes
+                pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
+            pbar.close()
+
+    @staticmethod
+    def _print_histogram_with_buckets(data, num_buckets=10) -> None:
+        # Bucketing data (clamp bucket_size to at least 1 to avoid division by zero for narrow ranges)
+        bucket_size = max((max(data) - min(data)) // num_buckets, 1)
+        bucketed_data = [x // bucket_size * bucket_size for x in data]
counter = Counter(bucketed_data) + max_count = max(counter.values()) + + for value in range(min(bucketed_data), max(bucketed_data) + bucket_size, bucket_size): + count = counter.get(value, 0) + bar = '*' * int(count / max_count * 50) # Scale the bar length + LOGGER.info(f'{value:5} - {value + bucket_size - 1:5} | {bar}') + + def resample(self, epoch=None): + if self.tlc_use_sampling_weights: + # Seed such that each process does the same sampling + if epoch is not None: + np.random.seed(epoch) + LOGGER.info(f'{self.prefix}Resampling dataset for epoch {epoch}') + # Sample indices weighted by 3LC sampling weight + self.indices = np.random.choice( + len(self.indices), + size=len(self.indices), + replace=True, + p=self.sampling_weights, + ) + + +def fix_image(im_file): + fixed = False + corrupt = False + msg = '' + + # From utils/dataloaders.py + try: + im = Image.open(im_file) + if im.format.lower() in ('jpg', 'jpeg'): + with open(im_file, 'rb') as f: + f.seek(-2, 2) + if f.read() != b'\xff\xd9': # corrupt JPEG + ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100) + msg = f'WARNING ⚠️ {im_file}: corrupt JPEG restored and saved' + fixed = True + + except Exception as e: + msg = f'WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}' + corrupt = True + + return fixed, corrupt, msg diff --git a/utils/tlc_integration/loss.py b/utils/tlc_integration/loss.py new file mode 100644 index 000000000000..931cd9575a20 --- /dev/null +++ b/utils/tlc_integration/loss.py @@ -0,0 +1,35 @@ +# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license +""" +Loss functions per sample - WIP +""" +import torch +import torch.nn as nn + +from ..loss import ComputeLoss, FocalLoss, smooth_BCE + + +class TLCComputeLoss(ComputeLoss): + # Compute losses + def __init__(self, device, h, stride, na, nc, nl, anchors, autobalance=False): + + # Define criteria + BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device)) + BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device)) + + # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3 + self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0)) # positive, negative BCE targets + + # Focal loss + g = h['fl_gamma'] # focal loss gamma + if g > 0: + BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) + + self.balance = {3: [4.0, 1.0, 0.4]}.get(nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7 + self.ssi = list(stride).index(16) if autobalance else 0 # stride 16 index + self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance + self.na = na # number of anchors + self.nc = nc # number of classes + self.nl = nl # number of layers + self.device = device + + self.anchors = torch.clone(anchors).detach().to(self.device) diff --git a/utils/tlc_integration/utils.py b/utils/tlc_integration/utils.py new file mode 100644 index 000000000000..5b501ee06481 --- /dev/null +++ b/utils/tlc_integration/utils.py @@ -0,0 +1,179 @@ +# YOLOv5 🚀 AGPL-3.0 license +""" +3LC Utils +""" + +from pathlib import Path + +import numpy as np +import tlc +from tlc.core.objects.tables.from_url.utils import get_cache_file_name, resolve_table_url + +from utils.general import LOGGER, colorstr + + +def get_or_create_tlc_table(yolo_yaml_file, split, revision_url='', root_url=None): + """Get or create a 3LC Table for the given inputs""" + + if not yolo_yaml_file and not revision_url: + raise ValueError('Either yolo_yaml_file or revision_url must be specified') + + if not split and 
not revision_url:
+        raise ValueError('split must be specified if revision_url is not specified')
+
+    tlc_prefix = colorstr('3LC: ')
+
+    # Ensure complete index before resolving any Tables
+    tlc.TableIndexingTable.instance().ensure_fully_defined()
+
+    # Infer dataset and project names
+    dataset_name_base = Path(yolo_yaml_file).stem
+    dataset_name = dataset_name_base + '-' + split
+    project_name = 'yolov5-' + dataset_name_base
+
+    if yolo_yaml_file:
+        yolo_yaml_file = str(Path(yolo_yaml_file).resolve())  # Ensure absolute path for resolving Table Url
+
+    # Resolve a unique Table name using dataset_name, yaml file path, yaml file size (and optionally
+    # root_url path and size), and split to create a deterministic url.
+    # The Table Url will be <3LC Table root> / <dataset_name> / <table_name>.json
+    table_url_from_yaml = resolve_table_url([yolo_yaml_file, root_url if root_url else '', split],
+                                            dataset_name,
+                                            prefix='yolo_')
+
+    # If revision_url is specified as an argument, use that Table
+    if revision_url:
+        try:
+            table = tlc.Table.from_url(revision_url)
+        except FileNotFoundError:
+            raise ValueError(f'Could not find Table {revision_url} for {split} split')
+
+        # If YAML file (--data argument) is also set, write appropriate log messages
+        if yolo_yaml_file:
+            try:
+                root_table = tlc.Table.from_url(table_url_from_yaml)
+                if not table.is_descendant_of(root_table):
+                    LOGGER.info(
+                        f"{tlc_prefix}Revision URL is not a descendant of the Table corresponding to the YAML file's {split} split. Ignoring YAML file."
+                    )
+            except FileNotFoundError:
+                LOGGER.warning(
+                    f'{tlc_prefix}Ignoring YAML file {yolo_yaml_file} because --tlc-{split}{"-" if split else ""}revision-url is set'
+                )
+        try:
+            check_table_compatibility(table)
+        except AssertionError as e:
+            raise ValueError(f'Table {revision_url} is not compatible with YOLOv5') from e
+
+        LOGGER.info(f'{tlc_prefix}Using {split} revision {revision_url}')
+    else:
+
+        try:
+            table = tlc.Table.from_url(table_url_from_yaml)
+            initial_url = table.url
+            table = table.latest()
+            latest_url = table.url
+            if initial_url != latest_url:
+                LOGGER.info(f'{tlc_prefix}Using latest version of {split} table: {latest_url.to_str()}')
+            else:
+                LOGGER.info(f'{tlc_prefix}Using root {split} table: {initial_url.to_str()}')
+        except FileNotFoundError:
+            cache_url = get_cache_file_name(table_url_from_yaml)
+            table = tlc.TableFromYolo(
+                url=table_url_from_yaml,
+                row_cache_url=cache_url,
+                dataset_name=dataset_name,
+                project_name=project_name,
+                input_url=yolo_yaml_file,
+                root_url=root_url,
+                split=split,
+            )
+            table.get_rows_as_binary()  # Force immediate creation of row cache
+            LOGGER.info(f'{tlc_prefix}Using {split} table {table.url}')
+
+        try:
+            check_table_compatibility(table)
+        except AssertionError as e:
+            raise ValueError(f'Table {table_url_from_yaml.to_str()} is not compatible with YOLOv5') from e
+
+    table.ensure_fully_defined()
+    return table
+
+
+def unpack_box(bbox):
+    return [bbox[tlc.LABEL], bbox[tlc.X0], bbox[tlc.Y0], bbox[tlc.X1], bbox[tlc.Y1]]
+
+
+def tlc_table_row_to_yolo_label(row):
+    unpacked = [unpack_box(box) for box in row[tlc.BOUNDING_BOXES][tlc.BOUNDING_BOX_LIST]]
+    arr = np.array(unpacked, ndmin=2, dtype=np.float32)
+    if len(unpacked) == 0:
+        arr = arr.reshape(0, 5)
+    return arr
+
+
+def xyxy_to_xywh(xyxy, height, width):
+    """Converts a bounding box from XYXY_ABS to XYWH_REL format.
+
+    :param xyxy: A bounding box in XYXY_ABS format.
+    :param height: The height of the image the bounding box is in.
+    :param width: The width of the image the bounding box is in.
+
+    :returns: The bounding box in XYWH_REL format with XY being centered.
+    """
+    x0, y0, x1, y1 = xyxy
+    return [
+        (x0 + x1) / (2 * width),
+        (y0 + y1) / (2 * height),
+        (x1 - x0) / width,
+        (y1 - y0) / height, ]
+
+
+def create_tlc_info_string_before_training(metrics_collection_epochs, collect_before_training=False):
+    """Creates the 3LC info string to log before training.
+
+    :param metrics_collection_epochs: The epochs to collect metrics for.
+    :param collect_before_training: Whether to collect metrics before training.
+
+    :returns: The 3LC info string.
+    """
+    if not metrics_collection_epochs and not collect_before_training:
+        tlc_mc_string = 'Metrics collection disabled for this run.'
+    elif not metrics_collection_epochs and collect_before_training:
+        tlc_mc_string = 'Collecting metrics only before training.'
+    else:
+        plural_epochs = len(metrics_collection_epochs) > 1
+        mc_epochs_str = ','.join(map(str, metrics_collection_epochs))
+        before_string = 'before training and ' if collect_before_training else ''
+        tlc_mc_string = f'Collecting metrics {before_string}for epoch{"s" if plural_epochs else ""} {mc_epochs_str}'
+
+    return tlc_mc_string
+
+
+def check_table_compatibility(table: tlc.Table) -> bool:
+    """Check that the 3LC Table is compatible with YOLOv5"""
+
+    row_schema = table.row_schema.values
+    assert tlc.IMAGE in row_schema
+    assert tlc.WIDTH in row_schema
+    assert tlc.HEIGHT in row_schema
+    assert tlc.BOUNDING_BOXES in row_schema
+    assert tlc.BOUNDING_BOX_LIST in row_schema[tlc.BOUNDING_BOXES].values
+    assert tlc.SAMPLE_WEIGHT in row_schema
+    assert tlc.LABEL in row_schema[tlc.BOUNDING_BOXES].values[tlc.BOUNDING_BOX_LIST].values
+    assert tlc.X0 in row_schema[tlc.BOUNDING_BOXES].values[tlc.BOUNDING_BOX_LIST].values
+    assert tlc.Y0 in row_schema[tlc.BOUNDING_BOXES].values[tlc.BOUNDING_BOX_LIST].values
+    assert tlc.X1 in row_schema[tlc.BOUNDING_BOXES].values[tlc.BOUNDING_BOX_LIST].values
+    assert tlc.Y1 in row_schema[tlc.BOUNDING_BOXES].values[tlc.BOUNDING_BOX_LIST].values
+
+    X0 = row_schema[tlc.BOUNDING_BOXES].values[tlc.BOUNDING_BOX_LIST].values[tlc.X0]
+    Y0 = row_schema[tlc.BOUNDING_BOXES].values[tlc.BOUNDING_BOX_LIST].values[tlc.Y0]
+    X1 = row_schema[tlc.BOUNDING_BOXES].values[tlc.BOUNDING_BOX_LIST].values[tlc.X1]
+    Y1 = row_schema[tlc.BOUNDING_BOXES].values[tlc.BOUNDING_BOX_LIST].values[tlc.Y1]
+
+    assert X0.value.number_role == tlc.NUMBER_ROLE_BB_CENTER_X
+    assert Y0.value.number_role == tlc.NUMBER_ROLE_BB_CENTER_Y
+    assert X1.value.number_role == tlc.NUMBER_ROLE_BB_SIZE_X
+    assert Y1.value.number_role == tlc.NUMBER_ROLE_BB_SIZE_Y
+
+    return True
diff --git a/val.py b/val.py
index 8da3ef7667aa..f5cd8188c598 100644
--- a/val.py
+++ b/val.py
@@ -27,6 +27,7 @@
 from pathlib import Path
 
 import numpy as np
+import tlc
 import torch
 from tqdm import tqdm
 
@@ -38,12 +39,12 @@
 from models.common import DetectMultiBackend
 from utils.callbacks import Callbacks
-from utils.dataloaders import create_dataloader
 from utils.general import (LOGGER, TQDM_BAR_FORMAT, Profile, check_dataset, check_img_size, check_requirements,
                            check_yaml, coco80_to_coco91_class, colorstr, increment_path, non_max_suppression,
                            print_args, scale_boxes, xywh2xyxy, xyxy2xywh)
 from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
 from utils.plots import output_to_target, plot_images, plot_val_study
+from utils.tlc_integration import create_dataloader, get_or_create_tlc_table
 from utils.torch_utils import select_device, smart_inference_mode
 
@@ -125,6 +126,8 @@ def run(
     plots=True,
     callbacks=Callbacks(),
     compute_loss=None,
+    tlc_discard_non_zero_preds=False,
+    tlc_revision_url='',
 ):
     # Initialize/load model and set device
     training = model is not None
@@ -152,8 +155,20 @@ def run(
             batch_size = 1  # export.py models default to batch-size 1
             LOGGER.info(f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
 
+    table = get_or_create_tlc_table(
+        yolo_yaml_file=data,
+        split=task if task in ('train', 'val', 'test') else 'val',  # task is normalized the same way further down
+        revision_url=tlc_revision_url,
+    )
+
     # Data
-    data = check_dataset(data)  # check
+    if data:
+        data = check_dataset(data)  # check
+    else:
+        # --data argument was explicitly set empty, use 3LC Table to
+        # populate data dict as far as possible.
+        nc = len(table.get_value_map_for_column(tlc.BOUNDING_BOXES))
+        data = {'val': None, 'train': None, 'nc': nc}
 
     # Configure
     model.eval()
@@ -172,15 +187,18 @@ def run(
         model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz))  # warmup
         pad, rect = (0.0, False) if task == 'speed' else (0.5, pt)  # square inference for benchmarks
         task = task if task in ('train', 'val', 'test') else 'val'  # path to train/val/test images
-        dataloader = create_dataloader(data[task],
-                                       imgsz,
-                                       batch_size,
-                                       stride,
-                                       single_cls,
-                                       pad=pad,
-                                       rect=rect,
-                                       workers=workers,
-                                       prefix=colorstr(f'{task}: '))[0]
+        dataloader = create_dataloader(
+            data[task],
+            imgsz,
+            batch_size,
+            stride,
+            single_cls,
+            pad=pad,
+            rect=rect,
+            workers=workers,
+            prefix=colorstr(f'{task}: '),
+            table=table,
+        )[0]
 
     seen = 0
     confusion_matrix = ConfusionMatrix(nc=nc)
@@ -227,6 +245,10 @@ def run(
         # Metrics
         for si, pred in enumerate(preds):
+            if tlc_discard_non_zero_preds:
+                # Filter out predictions with class != 0.
+                # This is useful if you want to evaluate the model on a dataset where only label 0 is relevant.
+                pred = pred[pred[:, 5] == 0]
             labels = targets[targets[:, 0] == si, 1:]
             nl, npr = labels.shape[0], pred.shape[0]  # number of labels, predictions
             path, shape = Path(paths[si]), shapes[si][0]
@@ -363,6 +385,17 @@ def parse_opt():
     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
     parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
     parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
+
+    # 3LC arguments
+    parser.add_argument('--tlc-revision-url',
+                        type=str,
+                        default='',
+                        help='URL to the revision of the 3LC dataset to collect metrics for')
+
+    parser.add_argument('--tlc-discard-non-zero-preds',
+                        action='store_true',
+                        help='Discard predictions with class != 0 before validating')
+
     opt = parser.parse_args()
     opt.data = check_yaml(opt.data)  # check YAML
     opt.save_json |= opt.data.endswith('coco.yaml')
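
For reference, example invocations of the two entry points touched by this patch (values are illustrative; all flags are defined in the diff above, and embeddings collection additionally requires `pip install umap-learn`):

    # Collect metrics for a trained model on the val split, with per-sample loss and 2D image embeddings
    $ python collect.py --weights yolov5s.pt --data coco128.yaml --img 640 --split val \
          --tlc-collect-loss --tlc-image-embeddings-dim 2

    # Train with metrics collected before training and every second epoch thereafter
    $ python train.py --weights yolov5s.pt --data coco128.yaml --img 640 \
          --tlc-mc-before-training --tlc-mc-interval 2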