From 355d526717aa7e4e613856dc4bb4fe0ca5c464b1 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 11 Mar 2022 01:10:30 +0100 Subject: [PATCH 1/4] Implement DDP `static_graph=True` Experimental implementation of new PyTorch 1.11.0 DDP feature. --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 60be962d447f..641f3d3b3f2c 100644 --- a/train.py +++ b/train.py @@ -253,7 +253,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # DDP mode if cuda and RANK != -1: - model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) + model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True) # Model attributes nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps) From 36a46140e43ee1a9702851f80127bb1be333d3ce Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 11 Mar 2022 12:52:06 +0100 Subject: [PATCH 2/4] Add 1.11.0 check --- train.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 641f3d3b3f2c..f783e9337646 100644 --- a/train.py +++ b/train.py @@ -47,7 +47,8 @@ from utils.datasets import create_dataloader from utils.downloads import attempt_download from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, - check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds, + check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path, + init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer) from utils.loggers import Loggers @@ -253,7 +254,10 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # DDP mode if cuda and RANK != -1: - model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True) + if check_version(torch.__version__, '1.11.0'): + model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True) + else: + model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) # Model attributes nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps) From 2d54991bf94439145132488cf166803966f86eb9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 11 Mar 2022 11:52:16 +0000 Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- train.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index f783e9337646..0c7f301638f0 100644 --- a/train.py +++ b/train.py @@ -48,9 +48,8 @@ from utils.downloads import attempt_download from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path, - init_seeds, - intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle, - print_args, print_mutation, strip_optimizer) + init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, + one_cycle, print_args, print_mutation, strip_optimizer) from utils.loggers import Loggers from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.loss import ComputeLoss From 3c7d1f8374220e3a0e11721153eb369f4bf1749e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 May 2022 10:30:19 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 46452d50ed37..d884c82e0bd7 100644 --- a/train.py +++ b/train.py @@ -47,9 +47,9 @@ from utils.datasets import create_dataloader from utils.downloads import attempt_download from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, - check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds, check_version, - intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, methods, - one_cycle, print_args, print_mutation, strip_optimizer) + check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path, + init_seeds, intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, + methods, one_cycle, print_args, print_mutation, strip_optimizer) from utils.loggers import Loggers from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.loss import ComputeLoss