diff --git a/export.py b/export.py
index 6a8c4f6f94a0..15e92a784a50 100644
--- a/export.py
+++ b/export.py
@@ -433,9 +433,12 @@ def run(data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'
         conf_thres=0.25  # TF.js NMS: confidence threshold
         ):
     t = time.time()
-    include = [x.lower() for x in include]
-    tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'))  # TensorFlow exports
-    file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)
+    include = [x.lower() for x in include]  # to lowercase
+    formats = tuple(export_formats()['Argument'][1:])  # --include arguments
+    flags = [x in include for x in formats]
+    assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {formats}'
+    jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags  # export booleans
+    file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)  # PyTorch weights
 
     # Load PyTorch model
     device = select_device(device)
@@ -475,20 +478,19 @@ def run(data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'
     # Exports
     f = [''] * 10  # exported filenames
     warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)  # suppress TracerWarning
-    if 'torchscript' in include:
+    if jit:
         f[0] = export_torchscript(model, im, file, optimize)
-    if 'engine' in include:  # TensorRT required before ONNX
+    if engine:  # TensorRT required before ONNX
         f[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose)
-    if ('onnx' in include) or ('openvino' in include):  # OpenVINO requires ONNX
+    if onnx or xml:  # OpenVINO requires ONNX
         f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
-    if 'openvino' in include:
+    if xml:  # OpenVINO
         f[3] = export_openvino(model, im, file)
-    if 'coreml' in include:
+    if coreml:
         _, f[4] = export_coreml(model, im, file)
 
     # TensorFlow Exports
-    if any(tf_exports):
-        pb, tflite, edgetpu, tfjs = tf_exports[1:]
+    if any((saved_model, pb, tflite, edgetpu, tfjs)):
         if int8 or edgetpu:  # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
             check_requirements(('flatbuffers==1.12',))  # required before `import tensorflow`
         assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'