diff --git a/detect.py b/detect.py index 6af3c252f7e0..5db6f5512cf7 100644 --- a/detect.py +++ b/detect.py @@ -159,6 +159,8 @@ def _imports_graph_def(): _ = frozen_func(x=tf.constant(img.permute(0, 2, 3, 1).cpu().numpy())) elif backend == 'tflite': input_data = img.permute(0, 2, 3, 1).cpu().numpy() + if opt.tfl_int8: + input_data = input_data.astype(np.uint8) interpreter.set_tensor(input_details[0]['index'], input_data) interpreter.invoke() output_data = interpreter.get_tensor(output_details[0]['index']) @@ -194,6 +196,10 @@ def _imports_graph_def(): elif backend == 'tflite': input_data = img.permute(0, 2, 3, 1).cpu().numpy() + if opt.tfl_int8: + scale, zero_point = input_details[0]['quantization'] + input_data = input_data / scale + zero_point + input_data = input_data.astype(np.uint8) interpreter.set_tensor(input_details[0]['index'], input_data) interpreter.invoke() if not opt.tfl_detect: @@ -203,12 +209,18 @@ def _imports_graph_def(): import yaml yaml_file = Path(opt.cfg).name with open(opt.cfg) as f: - yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict + yaml = yaml.load(f, Loader=yaml.FullLoader) anchors = yaml['anchors'] nc = yaml['nc'] nl = len(anchors) x = [torch.tensor(interpreter.get_tensor(output_details[i]['index']), device=device) for i in range(nl)] + if opt.tfl_int8: + for i in range(nl): + scale, zero_point = output_details[i]['quantization'] + x[i] = x[i].float() + x[i] = (x[i] - zero_point) * scale + def _make_grid(nx=20, ny=20): yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) return torch.stack((xv, yv), 2).view((1, 1, ny * nx, 2)).float() @@ -318,6 +330,7 @@ def _make_grid(nx=20, ny=20): parser.add_argument('--update', action='store_true', help='update all models') parser.add_argument('--tfl-detect', action='store_true', help='add Detect module in TFLite') parser.add_argument('--cfg', type=str, default='./models/yolov5s.yaml', help='cfg path') + parser.add_argument('--tfl-int8', action='store_true', help='use int8 quantized TFLite model') opt = parser.parse_args() print(opt)