From bfa2cd683cc02a9902e82ffd0e9e4161d340add2 Mon Sep 17 00:00:00 2001
From: Joshua Friedrich
Date: Tue, 23 Mar 2021 15:51:54 +0100
Subject: [PATCH] Merge

---
 train.py | 163 ++++++++++++++++++++++++++++---------------------------
 1 file changed, 84 insertions(+), 79 deletions(-)

diff --git a/train.py b/train.py
index d6909a8e7b4d..711a0aa234b1 100644
--- a/train.py
+++ b/train.py
@@ -1,3 +1,4 @@
+
 import argparse
 import logging
 import math
@@ -33,6 +34,7 @@
 from utils.loss import ComputeLoss
 from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
 from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
+from utils.wandb_logging.wandb_utils import WandbLogger, resume_and_get_id, check_wandb_config_file
@@ -49,7 +51,7 @@
 logger = logging.getLogger(__name__)
 
 
-def train(hyp, opt, device, tb_writer=None, wandb=None):
+def train(hyp, opt, device, tb_writer=None):
     logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
     save_dir, epochs, batch_size, total_batch_size, weights, rank = \
         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
@@ -73,10 +75,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     init_seeds(2 + rank)
     with open(opt.data) as f:
         data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
-    with torch_distributed_zero_first(rank):
-        check_dataset(data_dict)  # check
-    train_path = data_dict['train']
-    test_path = data_dict['val']
+    is_coco = opt.data.endswith('coco.yaml')
+
+    # Logging- Doing this before checking the dataset. Might update data_dict
+    if rank in [-1, 0]:
+        opt.hyp = hyp  # add hyperparameters
+        run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
+        wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
+        data_dict = wandb_logger.data_dict
+        if wandb_logger.wandb:
+            weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # WandbLogger might update weights, epochs if resuming
+            loggers = {'wandb': wandb_logger.wandb}  # loggers dict
     nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
     names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
     assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check
@@ -95,6 +104,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
     else:
         model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+    with torch_distributed_zero_first(rank):
+        check_dataset(data_dict)  # check
+    train_path = data_dict['train']
+    test_path = data_dict['val']
 
     # Freeze
     freeze = []  # parameter names to freeze (full or partial)
@@ -138,17 +151,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
     # plot_lr_scheduler(optimizer, scheduler, epochs)
 
-    # Logging
-    if rank in [-1, 0] and wandb and wandb.run is None:
-        opt.hyp = hyp  # add hyperparameters
-        wandb_run = wandb.init(config=opt, resume="allow",
-                               project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
-                               name=save_dir.stem,
-                               entity=opt.entity,
-                               group=opt.wandb_group,
-                               id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
-        loggers = {'wandb': wandb}  # loggers dict
-
     # EMA
     ema = ModelEMA(model) if rank in [-1, 0] else None
 
@@ -363,9 +365,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                     # if tb_writer:
                     #     tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
                     #     tb_writer.add_graph(model, imgs)  # add model to tensorboard
-                elif plots and ni == 10 and wandb:
-                    wandb.log({"Mosaics": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('train*.jpg')
-                                           if x.exists()]}, commit=False)
+                elif plots and ni == 10 and wandb_logger.wandb:
+                    wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
                                                   save_dir.glob('train*.jpg') if x.exists()]})
 
             # end batch ------------------------------------------------------------------------------------------------
         # end epoch ----------------------------------------------------------------------------------------------------
@@ -380,8 +382,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
             if not opt.notest or final_epoch:  # Calculate mAP
-                results, maps, times = test.test(opt.data,
-                                                 batch_size=batch_size * 2,
+                wandb_logger.current_epoch = epoch + 1
+                results, maps, times = test.test(data_dict,
+                                                 batch_size=total_batch_size,
                                                  imgsz=imgsz_test,
                                                  model=ema.ema,
                                                  single_cls=opt.single_cls,
@@ -389,8 +392,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                                                  save_dir=save_dir,
                                                  verbose=nc < 50 and final_epoch,
                                                  plots=plots and final_epoch,
-                                                 log_imgs=opt.log_imgs if wandb else 0,
-                                                 compute_loss=compute_loss)
+                                                 wandb_logger=wandb_logger,
+                                                 compute_loss=compute_loss,
+                                                 is_coco=is_coco)
 
             # Write
             with open(results_file, 'a') as f:
@@ -406,8 +410,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
             for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                 if tb_writer:
                     tb_writer.add_scalar(tag, x, epoch)  # tensorboard
-                if wandb:
-                    wandb.log({tag: x}, step=epoch, commit=tag == tags[-1])  # W&B
+                if wandb_logger.wandb:
+                    wandb_logger.log({tag: x})  # W&B
 
             # Update best mAP
             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
@@ -432,36 +436,29 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                         'ema': deepcopy(ema.ema).half(),
                         'updates': ema.updates,
                         'optimizer': optimizer.state_dict(),
-                        'wandb_id': wandb_run.id if wandb else None}
+                        'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
 
                 # Save last, best and delete
                 torch.save(ckpt, last)
                 if best_fitness == fi:
                     torch.save(ckpt, best)
+                if wandb_logger.wandb:
+                    if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
+                        wandb_logger.log_model(
+                            last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                 del ckpt
+            wandb_logger.end_epoch(best_result=best_fitness == fi)
 
     # end epoch ----------------------------------------------------------------------------------------------------
     # end training
-
     if rank in [-1, 0]:
-        # Strip optimizers
-        final = best if best.exists() else last  # final model
-        for f in last, best:
-            if f.exists():
-                strip_optimizer(f)
-        if opt.bucket:
-            os.system(f'gsutil cp {final} gs://{opt.bucket}/weights')  # upload
-
         # Plots
         if plots:
             plot_results(save_dir=save_dir)  # save as results.png
-            if wandb:
+            if wandb_logger.wandb:
                 files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
-                wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
-                                       if (save_dir / f).exists()]})
-                if opt.log_artifacts:
-                    wandb.log_artifact(artifact_or_path=str(final), type='model', name=save_dir.stem)
-
+                wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
+                                              if (save_dir / f).exists()]})
         # Test best.pt
         logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
         if opt.data.endswith('coco.yaml') and nc == 80:  # if COCO
@@ -476,13 +473,24 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                                       dataloader=testloader,
                                       save_dir=save_dir,
                                       save_json=True,
-                                      plots=False)
+                                      plots=False,
+                                      is_coco=is_coco)
+
+        # Strip optimizers
+        final = best if best.exists() else last  # final model
+        for f in last, best:
+            if f.exists():
+                strip_optimizer(f)  # strip optimizers
+        if opt.bucket:
+            os.system(f'gsutil cp {final} gs://{opt.bucket}/weights')  # upload
+        if wandb_logger.wandb:  # Log the stripped model
+            wandb_logger.wandb.log_artifact(str(final), type='model',
+                                            name='run_' + wandb_logger.wandb_run.id + '_model',
+                                            aliases=['last', 'best', 'stripped'])
     else:
         dist.destroy_process_group()
-
-    wandb.run.finish() if wandb and wandb.run else None
     torch.cuda.empty_cache()
+    wandb_logger.finish_run()
 
     return results
@@ -529,8 +537,6 @@ def train_ray_tune(config):
     parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
     parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
     parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
-    parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
-    parser.add_argument('--log-artifacts', action='store_true', help='log artifacts, i.e. final trained model')
     parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
     parser.add_argument('--project', default='runs/train', help='save to project/name')
     parser.add_argument('--entity', default=None, help='W&B entity')
@@ -538,6 +544,10 @@ def train_ray_tune(config):
     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
     parser.add_argument('--quad', action='store_true', help='quad dataloader, experimental for >640 image sizes')
     parser.add_argument('--linear-lr', action='store_true', help='linear LR')
+    parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
+    parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
+    parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
+    parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
     parser.add_argument('--wandb-group', type=str, default='no_group', help='JF: grouping for wandb logging')
     parser.add_argument('--raytune', action='store_true', help='JF: optimize with ray tune')
 
@@ -553,7 +563,8 @@ def train_ray_tune(config):
     check_requirements()
 
     # Resume
-    if opt.resume:  # resume an interrupted run
+    wandb_run = resume_and_get_id(opt)
+    if opt.resume and not wandb_run:  # resume an interrupted run
         ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
         assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
         apriori = opt.global_rank, opt.local_rank
@@ -586,12 +597,6 @@ def train_ray_tune(config):
 
     # Train
     logger.info(opt)
-    try:
-        import wandb
-    except ImportError:
-        wandb = None
-        prefix = colorstr('wandb: ')
-        logger.info(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
     if not opt.evolve:
         tb_writer = None  # init loggers
         if opt.global_rank in [-1, 0]:
@@ -600,34 +605,34 @@ def train_ray_tune(config):
-    ### Joshua Friedrich
-    ### Ray Tune
-    if opt.raytune:
-        print("raytune")
-        global_hyp = hyp
-        global_opt = opt
-        global_device = device
-        global_tb_writer = tb_writer
-        global_wandb = wandb
+    ### Joshua Friedrich
+    ### Ray Tune
+    if opt.raytune:
+        print("raytune")
+        global_hyp = hyp
+        global_opt = opt
+        global_device = device
+        global_tb_writer = tb_writer
+        # global_wandb = wandb
-        train_ray_tune({
-            "alpha": tune.grid_search([0.001, 0.01, 0.1]),
-            "beta": tune.choice([1, 2, 3])
-        })
-
-        exit(5)
-
-        analysis = tune.run(
-            train_ray_tune,
-            metric=fitness,
-            mode=max,
-            config={
+        train_ray_tune({
             "alpha": tune.grid_search([0.001, 0.01, 0.1]),
             "beta": tune.choice([1, 2, 3])
-            },
-            fail_fast=True)
-    else:
-        train(hyp, opt, device, tb_writer, wandb)
+        })
+
+        exit(5)
+
+        analysis = tune.run(
+            train_ray_tune,
+            metric=fitness,
+            mode=max,
+            config={
+                "alpha": tune.grid_search([0.001, 0.01, 0.1]),
+                "beta": tune.choice([1, 2, 3])
+            },
+            fail_fast=True)
+    else:
+        train(hyp, opt, device, tb_writer)
@@ -703,7 +708,7 @@ def train_ray_tune(config):
                 hyp[k] = round(hyp[k], 5)  # significant digits
 
             # Train mutation
-            results = train(hyp.copy(), opt, device, wandb=wandb)
+            results = train(hyp.copy(), opt, device)
 
             # Write mutation results
             print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
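
Note on the W&B wiring above: after this patch, train.py only talks to Weights & Biases through the helper interface imported at the top (WandbLogger and resume_and_get_id from utils/wandb_logging/wandb_utils.py). The stub below is a minimal, hypothetical sketch of that interface, written only to spell out what the patched train.py relies on: the attributes wandb, wandb_run, data_dict and current_epoch, and the methods log(), log_model(), end_epoch() and finish_run(). It is not the implementation shipped in the repository; every name or signature not visible in the diff above is an assumption.

# sketch_wandb_utils.py -- hypothetical minimal stand-in, NOT the real utils/wandb_logging/wandb_utils.py
from pathlib import Path

try:
    import wandb  # optional dependency; everything below degrades to a no-op without it
except ImportError:
    wandb = None


def resume_and_get_id(opt):
    # Assumed behaviour: return a resumable W&B run (or None). The patched train.py
    # only checks its truthiness before falling back to the normal --resume path.
    return None


class WandbLogger:
    def __init__(self, opt, name, run_id, data_dict):
        self.wandb = wandb               # train.py gates every W&B call on wandb_logger.wandb
        self.data_dict = data_dict       # train.py reads this back ("Might update data_dict")
        self.current_epoch = 0           # set from train() right before test.test() is called
        self.log_dict = {}               # per-epoch buffer filled by log()
        self.wandb_run = wandb.init(config=opt, resume="allow", project=Path(opt.project).stem,
                                    name=name, entity=opt.entity, id=run_id) if wandb else None

    def log(self, log_dict):
        self.log_dict.update(log_dict)   # buffered; flushed once per epoch in end_epoch()

    def log_model(self, path, opt, epoch, fitness_score, best_model=False):
        if self.wandb_run:               # called from train() every opt.save_period epochs
            artifact = wandb.Artifact('run_' + self.wandb_run.id + '_model', type='model')
            artifact.add_file(str(path / 'last.pt'))
            self.wandb_run.log_artifact(artifact,
                                        aliases=['latest', 'best' if best_model else f'epoch {epoch}'])

    def end_epoch(self, best_result=False):
        if self.wandb_run:
            wandb.log(self.log_dict)     # one W&B step per training epoch
            self.log_dict = {}

    def finish_run(self):
        if self.wandb_run:
            self.wandb_run.finish()

Of the new command-line flags, only --save_period is checked directly in the patched train() (to decide when log_model() runs); --upload_dataset, --bbox_interval and --artifact_alias do not appear elsewhere in this diff and are presumably consumed inside the helper module.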