Commit

assert torch!=1.12.0 for DDP training (#8621)
* assert torch!=1.12.0 for DDP training

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
glenn-jocher and pre-commit-ci[bot] committed Jul 18, 2022
1 parent 51fb467 commit 9cf5fd5
Showing 3 changed files with 24 additions and 12 deletions.
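In effect, this commit moves the !=1.12.0 pin out of requirements.txt and into a runtime check inside a new smart_DDP() helper, so installation on torch 1.12.0 is no longer blocked and only DDP runs hit the assert. Below is a minimal sketch of what the pinned check is assumed to amount to: check_version(current, minimum, pinned=True) is taken here to behave as an exact version match, and pinned_match is a hypothetical stand-in for illustration, not a function from this repository.

# Illustration only: why the assert blocks exactly torch 1.12.0 and nothing else.
# Assumes check_version(..., pinned=True) reduces to an exact version comparison.
from pkg_resources import parse_version  # packaging.version.parse would also work

def pinned_match(current, pinned):
    # True only when the installed version equals the pinned version exactly
    return parse_version(current) == parse_version(pinned)

assert not pinned_match('1.11.0', '1.12.0')  # older torch: DDP still allowed
assert pinned_match('1.12.0', '1.12.0')      # exactly 1.12.0: the assert fires
assert not pinned_match('1.12.1', '1.12.0')  # patched release: DDP allowed again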
4 changes: 2 additions & 2 deletions requirements.txt
@@ -9,8 +9,8 @@ Pillow>=7.1.2
PyYAML>=5.3.1
requests>=2.23.0
scipy>=1.4.1
-torch>=1.7.0,!=1.12.0 # https://github.com/ultralytics/yolov5/issues/8395
-torchvision>=0.8.1,!=0.13.0 # https://github.com/ultralytics/yolov5/issues/8395
+torch>=1.7.0
+torchvision>=0.8.1
tqdm>=4.64.0
protobuf<4.21.3 # https://github.com/ultralytics/yolov5/issues/8012

14 changes: 5 additions & 9 deletions train.py
@@ -27,7 +27,6 @@
import torch.distributed as dist
import torch.nn as nn
import yaml
-from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import lr_scheduler
from tqdm import tqdm

@@ -46,15 +45,15 @@
from utils.dataloaders import create_dataloader
from utils.downloads import attempt_download
from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
-                           check_requirements, check_suffix, check_version, check_yaml, colorstr, get_latest_run,
-                           increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
-                           labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer)
+                           check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path,
+                           init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
+                           one_cycle, print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss
from utils.metrics import fitness
from utils.plots import plot_evolve, plot_labels
-from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_optimizer,
+from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
                               torch_distributed_zero_first)

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
@@ -248,10 +247,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary

    # DDP mode
    if cuda and RANK != -1:
-        if check_version(torch.__version__, '1.11.0'):
-            model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
-        else:
-            model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
+        model = smart_DDP(model)

    # Model attributes
    nl = de_parallel(model).model[-1].nl  # number of detection layers (to scale hyps)
18 changes: 17 additions & 1 deletion utils/torch_utils.py
@@ -17,8 +17,13 @@
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
+from torch.nn.parallel import DistributedDataParallel as DDP

-from utils.general import LOGGER, colorstr, file_date, git_describe
+from utils.general import LOGGER, check_version, colorstr, file_date, git_describe
+
+LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
+RANK = int(os.getenv('RANK', -1))
+WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))

try:
import thop # for FLOPs computation
@@ -29,6 +34,17 @@
warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')


+def smart_DDP(model):
+    # Model DDP creation with checks
+    assert not check_version(torch.__version__, '1.12.0', pinned=True), \
+        'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
+        'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
+    if check_version(torch.__version__, '1.11.0'):
+        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
+    else:
+        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    # Decorator to make all processes in distributed training wait for each local_master to do something
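For context, here is a minimal sketch of how smart_DDP() slots into a torchrun-launched script. This is not YOLOv5's full train.py: the nn.Linear placeholder model, the main() scaffold, and the backend choice are illustrative assumptions; only the smart_DDP import path and the LOCAL_RANK/RANK environment variables come from the diff above.

# Minimal usage sketch for smart_DDP (run with: torchrun --nproc_per_node=N this_script.py)
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from utils.torch_utils import smart_DDP  # helper added in this commit

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))
RANK = int(os.getenv('RANK', -1))

def main():
    # torchrun sets RANK/LOCAL_RANK/WORLD_SIZE; initialize the default process group
    dist.init_process_group(backend='nccl' if dist.is_nccl_available() else 'gloo')
    torch.cuda.set_device(LOCAL_RANK)
    device = torch.device('cuda', LOCAL_RANK)

    model = nn.Linear(10, 2).to(device)  # placeholder model, not a YOLOv5 model
    model = smart_DDP(model)  # asserts torch != 1.12.0, then wraps in DistributedDataParallel

    # ... training loop would go here ...

    dist.destroy_process_group()

if __name__ == '__main__':
    main()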
