From 11a2b2a7e19b63a516518c98f1a9afad1c694ea5 Mon Sep 17 00:00:00 2001 From: SecretStar112 Date: Mon, 23 Aug 2021 17:05:53 +0200 Subject: [PATCH] Automatic TFLite uint8 determination (#4515) * Auto TFLite uint8 detection This PR automatically determines if TFLite models are uint8 quantized rather than accepting a manual argument. The quantization determination is based on @zldrobit comment https://github.com/ultralytics/yolov5/pull/1127#issuecomment-901713847 * Cleanup --- detect.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/detect.py b/detect.py index 601d5da..15ddc1f 100644 --- a/detect.py +++ b/detect.py @@ -52,7 +52,6 @@ def run(weights='yolov5s.pt', # model.pt path(s) hide_labels=False, # hide labels hide_conf=False, # hide confidences half=False, # use FP16 half-precision inference - tfl_int8=False, # INT8 quantized TFLite model ): save_img = not nosave and not source.endswith('.txt') # save inference images webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith( @@ -104,6 +103,7 @@ def wrap_frozen_graph(gd, inputs, outputs): interpreter.allocate_tensors() # allocate input_details = interpreter.get_input_details() # inputs output_details = interpreter.get_output_details() # outputs + int8 = input_details[0]['dtype'] == np.uint8 # is TFLite quantized uint8 model imgsz = check_img_size(imgsz, s=stride) # check image size # Dataloader @@ -145,15 +145,15 @@ def wrap_frozen_graph(gd, inputs, outputs): elif saved_model: pred = model(imn, training=False).numpy() elif tflite: - if tfl_int8: + if int8: scale, zero_point = input_details[0]['quantization'] - imn = (imn / scale + zero_point).astype(np.uint8) + imn = (imn / scale + zero_point).astype(np.uint8) # de-scale interpreter.set_tensor(input_details[0]['index'], imn) interpreter.invoke() pred = interpreter.get_tensor(output_details[0]['index']) - if tfl_int8: + if int8: scale, zero_point = output_details[0]['quantization'] - pred = (pred.astype(np.float32) - zero_point) * scale + pred = (pred.astype(np.float32) - zero_point) * scale # re-scale pred[..., 0] *= imgsz[1] # x pred[..., 1] *= imgsz[0] # y pred[..., 2] *= imgsz[1] # w @@ -268,7 +268,6 @@ def parse_opt(): parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels') parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences') parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') - parser.add_argument('--tfl-int8', action='store_true', help='INT8 quantized TFLite model') opt = parser.parse_args() opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand return opt