diff --git a/train.py b/train.py index 257256e85185..e8034a648cac 100644 --- a/train.py +++ b/train.py @@ -99,7 +99,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary plots = not evolve # create plots cuda = device.type != 'cpu' init_seeds(1 + RANK) - with torch_distributed_zero_first(RANK): + with torch_distributed_zero_first(LOCAL_RANK): data_dict = data_dict or check_dataset(data) # check if None train_path, val_path = data_dict['train'], data_dict['val'] nc = 1 if single_cls else int(data_dict['nc']) # number of classes @@ -111,7 +111,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary check_suffix(weights, '.pt') # check weights pretrained = weights.endswith('.pt') if pretrained: - with torch_distributed_zero_first(RANK): + with torch_distributed_zero_first(LOCAL_RANK): weights = attempt_download(weights) # download if not found locally ckpt = torch.load(weights, map_location=device) # load checkpoint model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create @@ -208,7 +208,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # Trainloader train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls, - hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=RANK, + hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK, workers=workers, image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: ')) mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class