Skip to content

Commit

Permalink
Add int8 quantized TFLite inference in detect.py
Browse files Browse the repository at this point in the history
  • Loading branch information
zldrobit committed Oct 12, 2020
1 parent 1ab4883 commit e98991d
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def _imports_graph_def():
_ = frozen_func(x=tf.constant(img.permute(0, 2, 3, 1).cpu().numpy()))
elif backend == 'tflite':
input_data = img.permute(0, 2, 3, 1).cpu().numpy()
if opt.tfl_int8:
input_data = input_data.astype(np.uint8)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
Expand Down Expand Up @@ -194,6 +196,10 @@ def _imports_graph_def():

elif backend == 'tflite':
input_data = img.permute(0, 2, 3, 1).cpu().numpy()
if opt.tfl_int8:
scale, zero_point = input_details[0]['quantization']
input_data = input_data / scale + zero_point
input_data = input_data.astype(np.uint8)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
if not opt.tfl_detect:
Expand All @@ -203,12 +209,18 @@ def _imports_graph_def():
import yaml
yaml_file = Path(opt.cfg).name
with open(opt.cfg) as f:
yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
yaml = yaml.load(f, Loader=yaml.FullLoader)

anchors = yaml['anchors']
nc = yaml['nc']
nl = len(anchors)
x = [torch.tensor(interpreter.get_tensor(output_details[i]['index']), device=device) for i in range(nl)]
if opt.tfl_int8:
for i in range(nl):
scale, zero_point = output_details[i]['quantization']
x[i] = x[i].float()
x[i] = (x[i] - zero_point) * scale

def _make_grid(nx=20, ny=20):
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
return torch.stack((xv, yv), 2).view((1, 1, ny * nx, 2)).float()
Expand Down Expand Up @@ -318,6 +330,7 @@ def _make_grid(nx=20, ny=20):
parser.add_argument('--update', action='store_true', help='update all models')
parser.add_argument('--tfl-detect', action='store_true', help='add Detect module in TFLite')
parser.add_argument('--cfg', type=str, default='./models/yolov5s.yaml', help='cfg path')
parser.add_argument('--tfl-int8', action='store_true', help='use int8 quantized TFLite model')
opt = parser.parse_args()
print(opt)

Expand Down

0 comments on commit e98991d

Please sign in to comment.