Skip to content

Commit

Permalink
Add PyTorch AMP check (ultralytics#7917)
Browse files Browse the repository at this point in the history
* Add PyTorch AMP check

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleanup

* Cleanup

* Cleanup

* Robust for DDP

* Fixes

* Add amp enabled boolean to check_train_batch_size

* Simplify

* space to prefix

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and Clay Januhowski committed Sep 8, 2022
1 parent 94787f4 commit 8a32fa8
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 14 deletions.
5 changes: 3 additions & 2 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,9 +524,10 @@ class AutoShape(nn.Module):
max_det = 1000 # maximum number of detections per image
amp = False # Automatic Mixed Precision (AMP) inference

def __init__(self, model):
def __init__(self, model, verbose=True):
super().__init__()
LOGGER.info('Adding AutoShape... ')
if verbose:
LOGGER.info('Adding AutoShape... ')
copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
self.pt = not self.dmb or model.pt # PyTorch model
Expand Down
16 changes: 8 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import torch.distributed as dist
import torch.nn as nn
import yaml
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD, Adam, AdamW, lr_scheduler
from tqdm import tqdm
Expand All @@ -46,10 +45,10 @@
from utils.callbacks import Callbacks
from utils.dataloaders import create_dataloader
from utils.downloads import attempt_download
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
one_cycle, print_args, print_mutation, strip_optimizer)
from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
check_requirements, check_suffix, check_version, check_yaml, colorstr, get_latest_run,
increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss
Expand Down Expand Up @@ -126,6 +125,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
else:
model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
amp = check_amp(model) # check AMP

# Freeze
freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
Expand All @@ -141,7 +141,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio

# Batch size
if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
batch_size = check_train_batch_size(model, imgsz)
batch_size = check_train_batch_size(model, imgsz, amp)
loggers.on_params_update({"batch_size": batch_size})

# Optimizer
Expand Down Expand Up @@ -293,7 +293,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
scheduler.last_epoch = start_epoch - 1 # do not move
scaler = amp.GradScaler(enabled=cuda)
scaler = torch.cuda.amp.GradScaler(enabled=amp)
stopper = EarlyStopping(patience=opt.patience)
compute_loss = ComputeLoss(model) # init loss class
callbacks.run('on_train_start')
Expand Down Expand Up @@ -348,7 +348,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

# Forward
with amp.autocast(enabled=cuda):
with torch.cuda.amp.autocast(amp):
pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
if RANK != -1:
Expand Down
5 changes: 2 additions & 3 deletions utils/autobatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@

import numpy as np
import torch
from torch.cuda import amp

from utils.general import LOGGER, colorstr
from utils.torch_utils import profile


def check_train_batch_size(model, imgsz=640):
def check_train_batch_size(model, imgsz=640, amp=True):
# Check YOLOv5 training batch size
with amp.autocast():
with torch.cuda.amp.autocast(amp):
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size


Expand Down
25 changes: 24 additions & 1 deletion utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@
from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness

# Settings
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1] # YOLOv5 root directory
RANK = int(os.getenv('RANK', -1))

# Settings
DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
Expand Down Expand Up @@ -505,6 +507,27 @@ def check_dataset(data, autodownload=True):
return data # dictionary


def check_amp(model):
# Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
from models.common import AutoShape

if next(model.parameters()).device.type == 'cpu': # get model device
return False
prefix = colorstr('AMP: ')
im = cv2.imread(ROOT / 'data' / 'images' / 'bus.jpg')[..., ::-1] # OpenCV image (BGR to RGB)
m = AutoShape(model, verbose=False) # model
a = m(im).xyxy[0] # FP32 inference
m.amp = True
b = m(im).xyxy[0] # AMP inference
if (a.shape == b.shape) and torch.allclose(a, b, atol=1.0): # close to 1.0 pixel bounding box
LOGGER.info(emojis(f'{prefix}checks passed ✅'))
return True
else:
help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}'))
return False


def url2file(url):
# Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
Expand Down

0 comments on commit 8a32fa8

Please sign in to comment.