diff --git a/export.py b/export.py index 623844ff3531..9868fcae95c3 100644 --- a/export.py +++ b/export.py @@ -492,6 +492,8 @@ def run( # Checks imgsz *= 2 if len(imgsz) == 1 else 1 # expand assert nc == len(names), f'Model class count {nc} != len(names) {len(names)}' + if optimize: + assert device.type != 'cuda', '--optimize not compatible with cuda devices, i.e. use --device cpu' # Input gs = int(max(model.stride)) # grid size (max stride)