From 6dcc47b14ac5bd9b52872ad94aca92b07e54da72 Mon Sep 17 00:00:00 2001 From: Giacomo Guiduzzi <10937563+giacomoguiduzzi@users.noreply.github.com> Date: Sun, 26 Jun 2022 19:54:00 +0200 Subject: [PATCH 1/8] Implementation of Early Stopping for DDP training This edit correctly uses the broadcast_object_list() function to send slave processes a boolean so to end the training phase if the variable is True, thus allowing the master process to destroy the process group and terminate. --- train.py | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index a06ad5a418f8..acecf99bb38d 100644 --- a/train.py +++ b/train.py @@ -301,6 +301,11 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n' f"Logging results to {colorstr('bold', save_dir)}\n" f'Starting training for {epochs} epochs...') + + stop = False + if RANK != -1: + broadcast_list = list() + for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------ callbacks.run('on_train_epoch_start') model.train() @@ -428,11 +433,32 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio del ckpt callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi) + stop = stopper(epoch=epoch, fitness=fi) + # Stop Single-GPU - if RANK == -1 and stopper(epoch=epoch, fitness=fi): - break + # if RANK == -1 and stopper(epoch=epoch, fitness=fi): + # break + + if RANK != -1: # if DDP training + if RANK == 0: + broadcast_list.append(stop) + + else: + broadcast_list.append(None) + + dist.broadcast_object_list(broadcast_list, 0) + + if RANK != 0: + stop = broadcast_list[0] + + # Stop Single GPU and Multi GPU training + if stop: + break + + if RANK != -1: + broadcast_list.clear() - # Stop DDP TODO: known issues shttps://github.com/ultralytics/yolov5/pull/4576 + # Stop DDP TODO: known issues https://github.com/ultralytics/yolov5/pull/4576 # stop = stopper(epoch=epoch, fitness=fi) # if RANK == 0: # dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks From 4aa4305dc6db78e69c3b7a378e23fa8610c7f895 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 28 Jun 2022 18:16:45 +0200 Subject: [PATCH 2/8] Update train.py --- train.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/train.py b/train.py index 84edbc92f89a..b0145b27ab6a 100644 --- a/train.py +++ b/train.py @@ -440,14 +440,8 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio # break if RANK != -1: # if DDP training - if RANK == 0: - broadcast_list.append(stop) - - else: - broadcast_list.append(None) - + broadcast_list.append(stop if RANK == 0 else None) dist.broadcast_object_list(broadcast_list, 0) - if RANK != 0: stop = broadcast_list[0] From 953aaa3c7e392a3c1b1679899f627f148374bbf3 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 28 Jun 2022 18:25:02 +0200 Subject: [PATCH 3/8] Update train.py --- train.py | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/train.py b/train.py index b0145b27ab6a..2fa67b4ab5e5 100644 --- a/train.py +++ b/train.py @@ -435,33 +435,22 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio stop = stopper(epoch=epoch, fitness=fi) - # Stop Single-GPU - # if RANK == -1 and stopper(epoch=epoch, fitness=fi): - # break - + # Stop DDP TODO: known issues https://github.com/ultralytics/yolov5/pull/4576 + # stop = stopper(epoch=epoch, fitness=fi) + # if RANK == 0: + # dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks + + # EarlyStop Single and Multi-GPU training if RANK != -1: # if DDP training broadcast_list.append(stop if RANK == 0 else None) dist.broadcast_object_list(broadcast_list, 0) if RANK != 0: stop = broadcast_list[0] - - # Stop Single GPU and Multi GPU training if stop: - break - + break # must break all DDP ranks if RANK != -1: broadcast_list.clear() - # Stop DDP TODO: known issues https://github.com/ultralytics/yolov5/pull/4576 - # stop = stopper(epoch=epoch, fitness=fi) - # if RANK == 0: - # dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks - - # Stop DPP - # with torch_distributed_zero_first(RANK): - # if stop: - # break # must break all DDP ranks - # end epoch ---------------------------------------------------------------------------------------------------- # end training ----------------------------------------------------------------------------------------------------- if RANK in {-1, 0}: From 0e87ad64e8f678d244d1b4f4e1639fc81c1ef4e7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Jun 2022 16:26:17 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 2fa67b4ab5e5..b3a5f3a6957b 100644 --- a/train.py +++ b/train.py @@ -439,7 +439,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio # stop = stopper(epoch=epoch, fitness=fi) # if RANK == 0: # dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks - + # EarlyStop Single and Multi-GPU training if RANK != -1: # if DDP training broadcast_list.append(stop if RANK == 0 else None) From d6ad6805ddccd2d769420e78886874cedec77a4d Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 28 Jun 2022 18:27:53 +0200 Subject: [PATCH 5/8] Update train.py --- train.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index b3a5f3a6957b..22eb64ac9208 100644 --- a/train.py +++ b/train.py @@ -269,6 +269,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio # DDP mode if cuda and RANK != -1: + broadcast_list = list() if check_version(torch.__version__, '1.11.0'): model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True) else: @@ -294,18 +295,13 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) scheduler.last_epoch = start_epoch - 1 # do not move scaler = torch.cuda.amp.GradScaler(enabled=amp) - stopper = EarlyStopping(patience=opt.patience) + stopper, stop = EarlyStopping(patience=opt.patience), False compute_loss = ComputeLoss(model) # init loss class callbacks.run('on_train_start') LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n' f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n' f"Logging results to {colorstr('bold', save_dir)}\n" f'Starting training for {epochs} epochs...') - - stop = False - if RANK != -1: - broadcast_list = list() - for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------ callbacks.run('on_train_epoch_start') model.train() From 39c1f11e0678436a2085d2363e3b9d36c2c4717d Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 28 Jun 2022 18:31:18 +0200 Subject: [PATCH 6/8] Update train.py --- train.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/train.py b/train.py index 22eb64ac9208..f291106b759a 100644 --- a/train.py +++ b/train.py @@ -431,15 +431,10 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio stop = stopper(epoch=epoch, fitness=fi) - # Stop DDP TODO: known issues https://github.com/ultralytics/yolov5/pull/4576 - # stop = stopper(epoch=epoch, fitness=fi) - # if RANK == 0: - # dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks - - # EarlyStop Single and Multi-GPU training + # EarlyStopping if RANK != -1: # if DDP training broadcast_list.append(stop if RANK == 0 else None) - dist.broadcast_object_list(broadcast_list, 0) + dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks if RANK != 0: stop = broadcast_list[0] if stop: From 227a77ae2185eccf10a5a5d227bf401ee6670ebf Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 28 Jun 2022 18:32:52 +0200 Subject: [PATCH 7/8] Update train.py --- train.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/train.py b/train.py index f291106b759a..ccc74f3e98d8 100644 --- a/train.py +++ b/train.py @@ -403,6 +403,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio # Update best mAP fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95] + stop = stopper(epoch=epoch, fitness=fi) # early stop check if fi > best_fitness: best_fitness = fi log_vals = list(mloss) + list(results) + lr @@ -429,8 +430,6 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio del ckpt callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi) - stop = stopper(epoch=epoch, fitness=fi) - # EarlyStopping if RANK != -1: # if DDP training broadcast_list.append(stop if RANK == 0 else None) From 58bc763e1c7adb4995c5961de584fadee22d8978 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 28 Jun 2022 18:40:13 +0200 Subject: [PATCH 8/8] Further cleanup This cleans up the definition of broadcast_list and removes the requirement for clear() afterward. --- train.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/train.py b/train.py index ccc74f3e98d8..dd5eeb600a76 100644 --- a/train.py +++ b/train.py @@ -269,7 +269,6 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio # DDP mode if cuda and RANK != -1: - broadcast_list = list() if check_version(torch.__version__, '1.11.0'): model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True) else: @@ -432,14 +431,12 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio # EarlyStopping if RANK != -1: # if DDP training - broadcast_list.append(stop if RANK == 0 else None) + broadcast_list = [stop if RANK == 0 else None] dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks if RANK != 0: stop = broadcast_list[0] if stop: break # must break all DDP ranks - if RANK != -1: - broadcast_list.clear() # end epoch ---------------------------------------------------------------------------------------------------- # end training -----------------------------------------------------------------------------------------------------