diff --git a/utils/autobatch.py b/utils/autobatch.py index 641b055b9fe3..bafff1c20658 100644 --- a/utils/autobatch.py +++ b/utils/autobatch.py @@ -7,6 +7,7 @@ import numpy as np import torch +from torch.backends import cudnn from utils.general import LOGGER, colorstr from utils.torch_utils import profile @@ -46,11 +47,14 @@ def autobatch(model, imgsz=640, fraction=0.8, batch_size=16): # Profile batch sizes batch_sizes = [1, 2, 4, 8, 16] + bm = cudnn.benchmark + cudnn.benchmark = False # avoid benchmark interference https://github.com/ultralytics/yolov5/issues/9287 try: img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes] results = profile(img, model, n=3, device=device) except Exception as e: LOGGER.warning(f'{prefix}{e}') + cudnn.benchmark = bm # reset to original value # Fit a solution y = [x[2] for x in results if x] # memory [2]