diff --git a/train.py b/train.py index dfe997b1bde2..d83f3cd1863c 100644 --- a/train.py +++ b/train.py @@ -36,7 +36,7 @@ from models.experimental import attempt_load from models.yolo import Model from utils.autoanchor import check_anchors -from utils.autobatch import check_batch_size +from utils.autobatch import check_train_batch_size from utils.datasets import create_dataloader from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \ strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \ @@ -137,9 +137,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple # Batch size - if cuda and RANK == -1: # single-GPU only - with amp.autocast(): - batch_size = check_batch_size(deepcopy(model).train(), imgsz, batch_size) + if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size + batch_size = check_train_batch_size(model, imgsz) # Optimizer nbs = 64 # nominal batch size @@ -446,7 +445,7 @@ def parse_opt(known=False): parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path') parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch.yaml', help='hyperparameters path') parser.add_argument('--epochs', type=int, default=300) - parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs') + parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch') parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)') parser.add_argument('--rect', action='store_true', help='rectangular training') parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training') diff --git a/utils/autobatch.py b/utils/autobatch.py index 3f97140c1a29..22a8c59040c8 100644 --- a/utils/autobatch.py +++ b/utils/autobatch.py @@ -7,20 +7,20 @@ import numpy as np import torch +from torch.cuda import amp from utils.general import colorstr -from utils.torch_utils import de_parallel, profile +from utils.torch_utils import profile -def check_batch_size(model, imgsz=640, b=16): - # Check YOLOv5 batch size - if b < 1 or b == 'auto': - b = autobatch(model, imgsz) # compute optimal batch size - return b +def check_train_batch_size(model, imgsz=640): + # Check YOLOv5 training batch size + with amp.autocast(): + return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size -def autobatch(model, imgsz=640, fraction=0.9): - # Automatically compute optimal batch size to use `fraction` of available CUDA memory +def autobatch(model, imgsz=640, fraction=0.9, batch_size=16): + # Automatically estimate best batch size to use `fraction` of available CUDA memory # Usage: # import torch # from utils.autobatch import autobatch @@ -30,6 +30,10 @@ def autobatch(model, imgsz=640, fraction=0.9): prefix = colorstr('autobatch: ') print(f'{prefix}Computing optimal batch size for --imgsz {imgsz}') device = next(model.parameters()).device # get model device + if device.type == 'cpu': + print(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}') + return batch_size + t = torch.cuda.get_device_properties(device).total_memory / 1024 ** 3 # (GB) r = torch.cuda.memory_reserved(device) / 1024 ** 3 # (GB) a = torch.cuda.memory_allocated(device) / 1024 ** 3 # (GB)