diff --git a/export.py b/export.py
index 5611ab95b1dc..d550a85fd99f 100644
--- a/export.py
+++ b/export.py
@@ -448,7 +448,8 @@ def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
 
 
 @try_export
-def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
+def export_tflite(keras_model, im, file, int8, per_tensor, data, nms, agnostic_nms,
+                  prefix=colorstr('TensorFlow Lite:')):
     # YOLOv5 TensorFlow Lite export
     import tensorflow as tf
 
@@ -469,6 +470,8 @@ def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=c
         converter.inference_input_type = tf.uint8  # or tf.int8
         converter.inference_output_type = tf.uint8  # or tf.int8
         converter.experimental_new_quantizer = True
+        if per_tensor:
+            converter._experimental_disable_per_channel = True
         f = str(file).replace('.pt', '-int8.tflite')
     if nms or agnostic_nms:
         converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
@@ -713,6 +716,7 @@ def run(
         keras=False,  # use Keras
         optimize=False,  # TorchScript: optimize for mobile
         int8=False,  # CoreML/TF INT8 quantization
+        per_tensor=False,  # TF per tensor quantization
         dynamic=False,  # ONNX/TF/TensorRT: dynamic axes
         simplify=False,  # ONNX: simplify model
         opset=12,  # ONNX: opset version
@@ -798,7 +802,14 @@ def run(
         if pb or tfjs:  # pb prerequisite to tfjs
             f[6], _ = export_pb(s_model, file)
         if tflite or edgetpu:
-            f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
+            f[7], _ = export_tflite(s_model,
+                                    im,
+                                    file,
+                                    int8 or edgetpu,
+                                    per_tensor,
+                                    data=data,
+                                    nms=nms,
+                                    agnostic_nms=agnostic_nms)
         if edgetpu:
             f[8], _ = export_edgetpu(file)
             add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs))
@@ -837,6 +848,7 @@ def parse_opt(known=False):
     parser.add_argument('--keras', action='store_true', help='TF: use Keras')
     parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
     parser.add_argument('--int8', action='store_true', help='CoreML/TF/OpenVINO INT8 quantization')
+    parser.add_argument('--per-tensor', action='store_true', help='TF per-tensor quantization')
    parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
     parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
     parser.add_argument('--opset', type=int, default=17, help='ONNX: opset version')