From 166824ea40c9e049d8b8c996ba61b3549d6ed3cd Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 13 May 2022 12:32:47 +0200 Subject: [PATCH] Implement DDP `static_graph=True` (#6940) * Implement DDP `static_graph=True` Experimental implementation of new PyTorch 1.11.0 DDP feature. * Add 1.11.0 check * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- train.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 707651637a79..d884c82e0bd7 100644 --- a/train.py +++ b/train.py @@ -47,9 +47,9 @@ from utils.datasets import create_dataloader from utils.downloads import attempt_download from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, - check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds, - intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, methods, - one_cycle, print_args, print_mutation, strip_optimizer) + check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path, + init_seeds, intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, + methods, one_cycle, print_args, print_mutation, strip_optimizer) from utils.loggers import Loggers from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.loss import ComputeLoss @@ -269,7 +269,10 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio # DDP mode if cuda and RANK != -1: - model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) + if check_version(torch.__version__, '1.11.0'): + model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True) + else: + model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) # Model attributes nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)