Fix rank and processgroup error
NanoCode012 committed Jul 21, 2020
1 parent bb15ba3 commit 927370f
Showing 2 changed files with 10 additions and 11 deletions.
train.py (8 additions, 9 deletions)
@@ -53,12 +53,11 @@ def train(local_rank, hyp, opt, device):
     mixed_precision = opt.mixed_precision
     # local_rank = opt.local_rank
 
-    if (opt.parallel):
+    if (opt.distributed):
         device = torch.device(local_rank)
-        if (opt.distributed):
-            torch.cuda.set_device(local_rank)
-            dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:9999', rank=local_rank,
-                                    world_size=opt.world_size)  # distributed backend
+        torch.cuda.set_device(local_rank)
+        dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:9999', rank=local_rank,
+                                world_size=opt.world_size)  # distributed backend
     # TODO: Init DDP logging. Only the first process is allowed to log.
     # Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs.

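For readers following the DDP path, a minimal standalone sketch of the initialization pattern the new branch uses: one process per GPU, pin the device before joining the process group, then initialize NCCL once. The nccl backend and the tcp://127.0.0.1:9999 endpoint come from the diff; the helper name and the world_size argument are illustrative and not part of the repository.

import torch
import torch.distributed as dist

def init_ddp(local_rank: int, world_size: int) -> torch.device:
    # Illustrative helper (not in the repo): bind this process to its GPU
    # *before* creating the process group so NCCL picks the right device.
    device = torch.device(local_rank)
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl',
                            init_method='tcp://127.0.0.1:9999',
                            rank=local_rank,
                            world_size=world_size)
    return device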
@@ -116,7 +115,7 @@ def train(local_rank, hyp, opt, device):
 
     # Load Model
     # Avoid multiple downloads.
-    with torch_distributed_zero_first(local_rank, (opt.parallel and not opt.distributed)):
+    with torch_distributed_zero_first(local_rank):
         google_utils.attempt_download(weights)
     start_epoch, best_fitness = 0, 0.0
     if weights.endswith('.pt'):  # pytorch format
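With the DP-specific second argument gone (the matching change in utils/utils.py is below), the call site reduces to the plain context-manager form. Under the new rank convention, only the first process performs the download and the other DDP workers wait for it; a brief annotated sketch, with the surrounding names as in the diff:

# local_rank == -1 : CPU / single GPU / DataParallel, no barriers, just download
# local_rank ==  0 : DDP master downloads, then releases the other ranks
# local_rank  >  0 : DDP workers block until rank 0 has the file, then reuse it
with torch_distributed_zero_first(local_rank):
    google_utils.attempt_download(weights)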
@@ -249,7 +248,7 @@ def train(local_rank, hyp, opt, device):
         # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders
 
         mloss = torch.zeros(4, device=device)  # mean losses
-        if local_rank != -1:
+        if opt.distributed:
             dataloader.sampler.set_epoch(epoch)
         pbar = enumerate(dataloader)
         if local_rank in [-1, 0]:
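The epoch hook matters because DistributedSampler derives its shuffle from a seed plus the epoch number; if set_epoch is never called, every epoch replays the same permutation on every rank. A minimal standalone sketch (the toy dataset, batch size, and explicit num_replicas/rank are made up so the snippet runs without an initialized process group):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16).float())
# Passing num_replicas/rank explicitly avoids needing dist.init_process_group here.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reseeds the shuffle; omit it and every epoch repeats
    for (batch,) in dataloader:
        pass  # training step would go here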
@@ -503,9 +502,9 @@ def run(fn, hyp, opt, device):
     # Train
     if not opt.evolve:
         if (opt.distributed):
-            run(train, hyp, opt, None)
+            run(train, hyp, opt, None)  # DDP
         else:
-            train(0, hyp, opt, device)  # CPU/Single GPU
+            train(-1, hyp, opt, device)  # CPU, 1 GPU, DP
 
     # Evolve hyperparameters (optional)
     else:
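The body of run() is not part of this diff; a plausible reading, given that train() now receives -1 on the non-distributed path and a per-process index on the DDP path, is a torch.multiprocessing.spawn launcher along these lines (entirely an assumption, shown only to make the two call sites concrete):

import torch.multiprocessing as mp

def run(fn, hyp, opt, device):
    # Hypothetical launcher: spawn one child per GPU; each child is called as
    # fn(process_index, hyp, opt, device), so process_index acts as local_rank.
    mp.spawn(fn, args=(hyp, opt, device), nprocs=opt.world_size, join=True)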
utils/utils.py (2 additions, 2 deletions)

@@ -33,14 +33,14 @@
 
 
 @contextmanager
-def torch_distributed_zero_first(local_rank: int, dp_mode=False):
+def torch_distributed_zero_first(local_rank: int):
     """
     Decorator to make all processes in distributed training wait for each local_master to do something.
     """
     if local_rank not in [-1, 0]:
         torch.distributed.barrier()
     yield
-    if local_rank == 0 and not dp_mode:
+    if local_rank == 0:
         torch.distributed.barrier()


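The dropped dp_mode flag existed so a lone DataParallel process (rank 0, but with no process group) would not hit the trailing barrier. Under the new convention that case arrives as local_rank == -1, which already skips both barriers, so the flag is redundant. A standalone sketch of the three branches, with torch.distributed.barrier() replaced by a print so it runs without a process group:

from contextlib import contextmanager

@contextmanager
def zero_first_sketch(local_rank: int):
    barrier = lambda: print(f'rank {local_rank}: barrier')  # stand-in for dist.barrier()
    if local_rank not in [-1, 0]:
        barrier()  # DDP workers wait for the master to finish the guarded work
    yield
    if local_rank == 0:
        barrier()  # the master releases the waiting workers

for rank in (-1, 0, 1):
    with zero_first_sketch(rank):
        print(f'rank {rank}: doing the guarded work')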
