Skip to content

Commit

Permalink
AMP check improvements backup YOLOv5n pretrained (ultralytics#7959)
Browse files Browse the repository at this point in the history
* Reduce AMP check to detections verification

More robust and faster

* Update general.py

* Update general.py
  • Loading branch information
glenn-jocher authored and tdhooghe committed Jun 10, 2022
1 parent da6752a commit 9fcf5c0
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,27 +507,27 @@ def check_dataset(data, autodownload=True):

def check_amp(model):
# Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
from models.common import AutoShape
from models.common import AutoShape, DetectMultiBackend

def amp_allclose(model, im):
# All close FP32 vs AMP results
m = AutoShape(model, verbose=False) # model
a = m(im).xywhn[0] # FP32 inference
m.amp = True
b = m(im).xywhn[0] # AMP inference
return a.shape == b.shape and torch.allclose(a, b, atol=0.1) # close to 10% absolute tolerance

if next(model.parameters()).device.type == 'cpu': # get model device
return False
prefix = colorstr('AMP: ')
file = ROOT / 'data' / 'images' / 'bus.jpg' # image to test
if file.exists():
im = cv2.imread(file)[..., ::-1] # OpenCV image (BGR to RGB)
elif check_online():
im = 'https://ultralytics.com/images/bus.jpg'
else:
LOGGER.warning(emojis(f'{prefix}checks skipped ⚠️, not online.'))
return True
m = AutoShape(model, verbose=False) # model
a = m(im).xywhn[0] # FP32 inference
m.amp = True
b = m(im).xywhn[0] # AMP inference
if (a.shape == b.shape) and torch.allclose(a, b, atol=0.05): # close to 5% absolute tolerance
device = next(model.parameters()).device # get model device
if device.type == 'cpu':
return False # AMP disabled on CPU
f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
try:
assert amp_allclose(model, im) or amp_allclose(DetectMultiBackend('yolov5n.pt', device), im)
LOGGER.info(emojis(f'{prefix}checks passed ✅'))
return True
else:
except Exception:
help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}'))
return False
Expand Down

0 comments on commit 9fcf5c0

Please sign in to comment.