From 8a32fa8477f7b105d23417810f4b86168398d545 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 22 May 2022 13:41:18 +0200 Subject: [PATCH] Add PyTorch AMP check (#7917) * Add PyTorch AMP check * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Cleanup * Cleanup * Cleanup * Robust for DDP * Fixes * Add amp enabled boolean to check_train_batch_size * Simplify * space to prefix Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- models/common.py | 5 +++-- train.py | 16 ++++++++-------- utils/autobatch.py | 5 ++--- utils/general.py | 25 ++++++++++++++++++++++++- 4 files changed, 37 insertions(+), 14 deletions(-) diff --git a/models/common.py b/models/common.py index abee3591a9f8..996b7128bfba 100644 --- a/models/common.py +++ b/models/common.py @@ -524,9 +524,10 @@ class AutoShape(nn.Module): max_det = 1000 # maximum number of detections per image amp = False # Automatic Mixed Precision (AMP) inference - def __init__(self, model): + def __init__(self, model, verbose=True): super().__init__() - LOGGER.info('Adding AutoShape... ') + if verbose: + LOGGER.info('Adding AutoShape... ') copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance self.pt = not self.dmb or model.pt # PyTorch model diff --git a/train.py b/train.py index feec7e9ae2f9..5552b77f1a40 100644 --- a/train.py +++ b/train.py @@ -27,7 +27,6 @@ import torch.distributed as dist import torch.nn as nn import yaml -from torch.cuda import amp from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import SGD, Adam, AdamW, lr_scheduler from tqdm import tqdm @@ -46,10 +45,10 @@ from utils.callbacks import Callbacks from utils.dataloaders import create_dataloader from utils.downloads import attempt_download -from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, - check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path, - init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, - one_cycle, print_args, print_mutation, strip_optimizer) +from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size, + check_requirements, check_suffix, check_version, check_yaml, colorstr, get_latest_run, + increment_path, init_seeds, intersect_dicts, labels_to_class_weights, + labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer) from utils.loggers import Loggers from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.loss import ComputeLoss @@ -126,6 +125,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report else: model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create + amp = check_amp(model) # check AMP # Freeze freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze @@ -141,7 +141,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio # Batch size if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size - batch_size = check_train_batch_size(model, imgsz) + batch_size = check_train_batch_size(model, imgsz, amp) loggers.on_params_update({"batch_size": batch_size}) # Optimizer @@ -293,7 +293,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio maps = np.zeros(nc) # mAP per class results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) scheduler.last_epoch = start_epoch - 1 # do not move - scaler = amp.GradScaler(enabled=cuda) + scaler = torch.cuda.amp.GradScaler(enabled=amp) stopper = EarlyStopping(patience=opt.patience) compute_loss = ComputeLoss(model) # init loss class callbacks.run('on_train_start') @@ -348,7 +348,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False) # Forward - with amp.autocast(enabled=cuda): + with torch.cuda.amp.autocast(amp): pred = model(imgs) # forward loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size if RANK != -1: diff --git a/utils/autobatch.py b/utils/autobatch.py index e53b4787b87d..11009453b36a 100644 --- a/utils/autobatch.py +++ b/utils/autobatch.py @@ -7,15 +7,14 @@ import numpy as np import torch -from torch.cuda import amp from utils.general import LOGGER, colorstr from utils.torch_utils import profile -def check_train_batch_size(model, imgsz=640): +def check_train_batch_size(model, imgsz=640, amp=True): # Check YOLOv5 training batch size - with amp.autocast(): + with torch.cuda.amp.autocast(amp): return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size diff --git a/utils/general.py b/utils/general.py index e1c5e7c1c321..4d92ffe1ea49 100755 --- a/utils/general.py +++ b/utils/general.py @@ -36,9 +36,11 @@ from utils.downloads import gsutil_getsize from utils.metrics import box_iou, fitness -# Settings FILE = Path(__file__).resolve() ROOT = FILE.parents[1] # YOLOv5 root directory +RANK = int(os.getenv('RANK', -1)) + +# Settings DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode @@ -505,6 +507,27 @@ def check_dataset(data, autodownload=True): return data # dictionary +def check_amp(model): + # Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation + from models.common import AutoShape + + if next(model.parameters()).device.type == 'cpu': # get model device + return False + prefix = colorstr('AMP: ') + im = cv2.imread(ROOT / 'data' / 'images' / 'bus.jpg')[..., ::-1] # OpenCV image (BGR to RGB) + m = AutoShape(model, verbose=False) # model + a = m(im).xyxy[0] # FP32 inference + m.amp = True + b = m(im).xyxy[0] # AMP inference + if (a.shape == b.shape) and torch.allclose(a, b, atol=1.0): # close to 1.0 pixel bounding box + LOGGER.info(emojis(f'{prefix}checks passed ✅')) + return True + else: + help_url = 'https://github.com/ultralytics/yolov5/issues/7908' + LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}')) + return False + + def url2file(url): # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/