diff --git a/train.py b/train.py index e37cf816bcb1..bbf879f3af5f 100644 --- a/train.py +++ b/train.py @@ -134,6 +134,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): wandb_run = wandb.init(config=opt, resume="allow", project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem, name=save_dir.stem, + entity=opt.entity, id=ckpt.get('wandb_id') if 'ckpt' in locals() else None) loggers = {'wandb': wandb} # loggers dict @@ -467,6 +468,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): parser.add_argument('--log-artifacts', action='store_true', help='log artifacts, i.e. final trained model') parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers') parser.add_argument('--project', default='runs/train', help='save to project/name') + parser.add_argument('--entity', default=None, help='W&B entity') parser.add_argument('--name', default='exp', help='save to project/name') parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') parser.add_argument('--quad', action='store_true', help='quad dataloader')