diff --git a/train.py b/train.py
index 85bdf1bf9a1f..27f42c9a9c1d 100644
--- a/train.py
+++ b/train.py
@@ -22,7 +22,7 @@
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
-import test  # import test.py to get mAP after each epoch
+import test  # for end-of-epoch mAP
 from models.experimental import attempt_load
 from models.yolo import Model
 from utils.autoanchor import check_anchors
@@ -39,7 +39,11 @@
 logger = logging.getLogger(__name__)
 
 
-def train(hyp, opt, device, tb_writer=None):
+def train(hyp,
+          opt,
+          device,
+          tb_writer=None
+          ):
     logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
     save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
@@ -341,7 +345,7 @@ def train(hyp, opt, device, tb_writer=None):
                                                   save_dir.glob('train*.jpg') if x.exists()]})
 
             # end batch ------------------------------------------------------------------------------------------------
-
+
         # Scheduler
         lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
         scheduler.step()
@@ -404,12 +408,11 @@ def train(hyp, opt, device, tb_writer=None):
                     torch.save(ckpt, best)
                 if wandb_logger.wandb:
                     if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
-                        wandb_logger.log_model(
-                            last.parent, opt, epoch, fi, best_model=best_fitness == fi)
+                        wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                 del ckpt
 
         # end epoch ----------------------------------------------------------------------------------------------------
-    # end training
+    # end training -----------------------------------------------------------------------------------------------------
     if rank in [-1, 0]:
         logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
         if plots: