diff --git a/train.py b/train.py index 8f206a9401c5..5ad47fe9ea6a 100644 --- a/train.py +++ b/train.py @@ -46,10 +46,11 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary opt, device, ): - save_dir, epochs, batch_size, total_batch_size, weights, single_cls = \ - Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.single_cls + save_dir, epochs, batch_size, weights, single_cls = \ + opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls # Directories + save_dir = Path(save_dir) wdir = save_dir / 'weights' wdir.mkdir(parents=True, exist_ok=True) # make dir last = wdir / 'last.pt' @@ -127,8 +128,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # Optimizer nbs = 64 # nominal batch size - accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing - hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay + accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing + hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay logger.info(f"Scaled weight_decay = {hyp['weight_decay']}") pg0, pg1, pg2 = [], [], [] # optimizer parameter groups @@ -205,7 +206,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary logger.info('Using SyncBatchNorm()') # Trainloader - dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls, + dataloader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls, hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK, workers=opt.workers, image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: ')) @@ -215,7 +216,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # Process 0 if RANK in [-1, 0]: - testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls, + testloader = create_dataloader(test_path, imgsz_test, batch_size // WORLD_SIZE * 2, gs, single_cls, hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1, workers=opt.workers, pad=0.5, prefix=colorstr('val: '))[0] @@ -302,7 +303,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if ni <= nw: xi = [0, nw] # x interp # model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou) - accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round()) + accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round()) for j, x in enumerate(optimizer.param_groups): # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0 x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)]) @@ -371,7 +372,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if not opt.notest or final_epoch: # Calculate mAP wandb_logger.current_epoch = epoch + 1 results, maps, _ = test.test(data_dict, - batch_size=batch_size * 2, + batch_size=batch_size // WORLD_SIZE * 2, imgsz=imgsz_test, model=ema.ema, single_cls=single_cls, @@ -439,7 +440,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if is_coco: # COCO dataset for m in [last, best] if best.exists() else [last]: # speed, mAP tests results, _, _ = test.test(opt.data, - batch_size=batch_size * 2, + batch_size=batch_size // WORLD_SIZE * 2, imgsz=imgsz_test, conf_thres=0.001, iou_thres=0.7, @@ -518,7 +519,7 @@ def main(opt): assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist' with open(Path(ckpt).parent.parent / 'opt.yaml') as f: opt = argparse.Namespace(**yaml.safe_load(f)) # replace - opt.cfg, opt.weights, opt.resume, opt.batch_size = '', ckpt, True, opt.total_batch_size # reinstate + opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate logger.info('Resuming training from %s' % ckpt) else: # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml') @@ -529,17 +530,15 @@ def main(opt): opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve)) # DDP mode - opt.total_batch_size = opt.batch_size device = select_device(opt.device, batch_size=opt.batch_size) if LOCAL_RANK != -1: from datetime import timedelta - assert torch.cuda.device_count() > LOCAL_RANK, 'too few GPUS for DDP command' + assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command' torch.cuda.set_device(LOCAL_RANK) device = torch.device('cuda', LOCAL_RANK) dist.init_process_group(backend="gloo", timeout=timedelta(seconds=60)) assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count' assert not opt.image_weights, '--image-weights argument is not compatible with DDP training' - opt.batch_size = opt.total_batch_size // WORLD_SIZE # Train if not opt.evolve: