diff --git a/requirements.txt b/requirements.txt
index 4550fc771b04..a3284d6529eb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,8 +9,8 @@ Pillow>=7.1.2
 PyYAML>=5.3.1
 requests>=2.23.0
 scipy>=1.4.1
-torch>=1.7.0,!=1.12.0  # https://github.com/ultralytics/yolov5/issues/8395
-torchvision>=0.8.1,!=0.13.0  # https://github.com/ultralytics/yolov5/issues/8395
+torch>=1.7.0
+torchvision>=0.8.1
 tqdm>=4.64.0
 protobuf<4.21.3  # https://github.com/ultralytics/yolov5/issues/8012
 
diff --git a/train.py b/train.py
index 6b463bf56423..c298692b7335 100644
--- a/train.py
+++ b/train.py
@@ -27,7 +27,6 @@
 import torch.distributed as dist
 import torch.nn as nn
 import yaml
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import lr_scheduler
 from tqdm import tqdm
 
@@ -46,15 +45,15 @@
 from utils.dataloaders import create_dataloader
 from utils.downloads import attempt_download
 from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
-                           check_requirements, check_suffix, check_version, check_yaml, colorstr, get_latest_run,
-                           increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
-                           labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer)
+                           check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path,
+                           init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
+                           one_cycle, print_args, print_mutation, strip_optimizer)
 from utils.loggers import Loggers
 from utils.loggers.wandb.wandb_utils import check_wandb_resume
 from utils.loss import ComputeLoss
 from utils.metrics import fitness
 from utils.plots import plot_evolve, plot_labels
-from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_optimizer,
+from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
                                torch_distributed_zero_first)
 
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
@@ -248,10 +247,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
 
     # DDP mode
     if cuda and RANK != -1:
-        if check_version(torch.__version__, '1.11.0'):
-            model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
-        else:
-            model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
+        model = smart_DDP(model)
 
     # Model attributes
     nl = de_parallel(model).model[-1].nl  # number of detection layers (to scale hyps)
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index d82368dc6271..5f2a22c36f1a 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -17,8 +17,13 @@
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.parallel import DistributedDataParallel as DDP
 
-from utils.general import LOGGER, colorstr, file_date, git_describe
+from utils.general import LOGGER, check_version, colorstr, file_date, git_describe
+
+LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
+RANK = int(os.getenv('RANK', -1))
+WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
 
 try:
     import thop  # for FLOPs computation
@@ -29,6 +34,17 @@
 warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
 
 
+def smart_DDP(model):
+    # Model DDP creation with checks
+    assert not check_version(torch.__version__, '1.12.0', pinned=True), \
+        'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
+        'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
+    if check_version(torch.__version__, '1.11.0'):
+        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
+    else:
+        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
+
+
 @contextmanager
 def torch_distributed_zero_first(local_rank: int):
     # Decorator to make all processes in distributed training wait for each local_master to do something
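
For readers who want to see the pattern outside the repo, below is a minimal, self-contained sketch of the version gate that the new smart_DDP() helper encodes. It is not part of the patch: wrap_ddp, the packaging.version comparison, and the toy nn.Linear model are illustrative stand-ins for YOLOv5's check_version() and real model, and the script assumes a CUDA machine launched with torchrun (e.g. torchrun --nproc_per_node 2 ddp_sketch.py).

# ddp_sketch.py -- illustrative only, not part of this PR
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from packaging import version
from torch.nn.parallel import DistributedDataParallel as DDP

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # set by torchrun for each process


def wrap_ddp(model):
    # Same gate as smart_DDP: refuse the known-bad 1.12.0 release, then enable
    # static_graph where it exists (torch>=1.11); packaging.version stands in
    # for YOLOv5's check_version() helper here.
    v = version.parse(torch.__version__.split('+')[0])  # drop local tag, e.g. '+cu113'
    assert v != version.parse('1.12.0'), 'torch 1.12.0 DDP training is broken, use another release'
    if v >= version.parse('1.11.0'):
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
    return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)


if __name__ == '__main__':
    dist.init_process_group(backend='nccl')  # torchrun provides MASTER_ADDR/PORT, RANK, WORLD_SIZE
    torch.cuda.set_device(LOCAL_RANK)  # one GPU per process
    model = wrap_ddp(nn.Linear(10, 2).cuda(LOCAL_RANK))
    # ... optimizer / dataloader / training loop would go here ...
    dist.destroy_process_group()

static_graph=True lets DDP skip per-iteration graph rebuilds when the model's forward/backward graph does not change between iterations, which is why the helper opts in only on torch>=1.11, the first release that exposes the flag.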