Skip to content

Commit

Permalink
Add TensorFlow and TFLite export (ultralytics#1127)
Browse files Browse the repository at this point in the history
* Add models/tf.py for TensorFlow and TFLite export

* Set auto=False for int8 calibration

* Update requirements.txt for TensorFlow and TFLite export

* Read anchors directly from PyTorch weights

* Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export

* Remove check_anchor_order, check_file, set_logging from import

* Reformat code and optimize imports

* Autodownload model and check cfg

* update --source path, img-size to 320, single output

* Adjust representative_dataset

* Put representative dataset in tfl_int8 block

* detect.py TF inference

* weights to string

* weights to string

* cleanup tf.py

* Add --dynamic-batch-size

* Add xywh normalization to reduce calibration error

* Update requirements.txt

TensorFlow 2.3.1 -> 2.4.0 to avoid int8 quantization error

* Fix imports

Move C3 from models.experimental to models.common

* Add models/tf.py for TensorFlow and TFLite export

* Set auto=False for int8 calibration

* Update requirements.txt for TensorFlow and TFLite export

* Read anchors directly from PyTorch weights

* Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export

* Remove check_anchor_order, check_file, set_logging from import

* Reformat code and optimize imports

* Autodownload model and check cfg

* update --source path, img-size to 320, single output

* Adjust representative_dataset

* detect.py TF inference

* Put representative dataset in tfl_int8 block

* weights to string

* weights to string

* cleanup tf.py

* Add --dynamic-batch-size

* Add xywh normalization to reduce calibration error

* Update requirements.txt

TensorFlow 2.3.1 -> 2.4.0 to avoid int8 quantization error

* Fix imports

Move C3 from models.experimental to models.common

* implement C3() and SiLU()

* Fix reshape dim to support dynamic batching

* Add epsilon argument in tf_BN, which is different between TF and PT

* Set stride to None if not using PyTorch, and do not warmup without PyTorch

* Add list support in check_img_size()

* Add list input support in detect.py

* sys.path.append('./') to run from yolov5/

* Add int8 quantization support for TensorFlow 2.5

* Add get_coco128.sh

* Remove --no-tfl-detect in models/tf.py (Use tf-android-tfl-detect branch for EdgeTPU)

* Update requirements.txt

* Replace torch.load() with attempt_load()

* Update requirements.txt

* Add --tf-raw-resize to set half_pixel_centers=False

* Add --agnostic-nms for TF class-agnostic NMS

* Cleanup after merge

* Cleanup2 after merge

* Cleanup3 after merge

* Add tf.py docstring with credit and usage

* pb saved_model and tflite use only one model in detect.py

* Add use cases in docstring of tf.py

* Remove redundant `stride` definition

* Remove keras direct import

* Fix `check_requirements(('tensorflow>=2.4.1',))`

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
zldrobit and glenn-jocher authored Aug 17, 2021
1 parent 5cd1e1d commit 8ebe069
Show file tree
Hide file tree
Showing 5 changed files with 626 additions and 17 deletions.
64 changes: 54 additions & 10 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn

Expand Down Expand Up @@ -51,6 +52,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
hide_labels=False, # hide labels
hide_conf=False, # hide confidences
half=False, # use FP16 half-precision inference
tfl_int8=False, # INT8 quantized TFLite model
):
save_img = not nosave and not source.endswith('.txt') # save inference images
webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
Expand All @@ -68,7 +70,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
# Load model
w = weights[0] if isinstance(weights, list) else weights
classify, suffix = False, Path(w).suffix.lower()
pt, onnx, tflite, pb, graph_def = (suffix == x for x in ['.pt', '.onnx', '.tflite', '.pb', '']) # backend
pt, onnx, tflite, pb, saved_model = (suffix == x for x in ['.pt', '.onnx', '.tflite', '.pb', '']) # backend
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
if pt:
model = attempt_load(weights, map_location=device) # load FP32 model
Expand All @@ -83,30 +85,49 @@ def run(weights='yolov5s.pt', # model.pt path(s)
check_requirements(('onnx', 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
else: # TensorFlow models
check_requirements(('tensorflow>=2.4.1',))
import tensorflow as tf
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
def wrap_frozen_graph(gd, inputs, outputs):
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped import
return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
tf.nest.map_structure(x.graph.as_graph_element, outputs))

graph_def = tf.Graph().as_graph_def()
graph_def.ParseFromString(open(w, 'rb').read())
frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
elif saved_model:
model = tf.keras.models.load_model(w)
elif tflite:
interpreter = tf.lite.Interpreter(model_path=w) # load TFLite model
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
imgsz = check_img_size(imgsz, s=stride) # check image size

# Dataloader
if webcam:
view_img = check_imshow()
cudnn.benchmark = True # set True to speed up constant image size inference
dataset = LoadStreams(source, img_size=imgsz, stride=stride)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
bs = len(dataset) # batch_size
else:
dataset = LoadImages(source, img_size=imgsz, stride=stride)
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
bs = 1 # batch_size
vid_path, vid_writer = [None] * bs, [None] * bs

# Run inference
if pt and device.type != 'cpu':
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters()))) # run once
t0 = time.time()
for path, img, im0s, vid_cap in dataset:
if pt:
if onnx:
img = img.astype('float32')
else:
img = torch.from_numpy(img).to(device)
img = img.half() if half else img.float() # uint8 to fp16/32
elif onnx:
img = img.astype('float32')
img /= 255.0 # 0 - 255 to 0.0 - 1.0
img = img / 255.0 # 0 - 255 to 0.0 - 1.0
if len(img.shape) == 3:
img = img[None] # expand for batch dim

Expand All @@ -117,6 +138,27 @@ def run(weights='yolov5s.pt', # model.pt path(s)
pred = model(img, augment=augment, visualize=visualize)[0]
elif onnx:
pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
else: # tensorflow model (tflite, pb, saved_model)
imn = img.permute(0, 2, 3, 1).cpu().numpy() # image in numpy
if pb:
pred = frozen_func(x=tf.constant(imn)).numpy()
elif saved_model:
pred = model(imn, training=False).numpy()
elif tflite:
if tfl_int8:
scale, zero_point = input_details[0]['quantization']
imn = (imn / scale + zero_point).astype(np.uint8)
interpreter.set_tensor(input_details[0]['index'], imn)
interpreter.invoke()
pred = interpreter.get_tensor(output_details[0]['index'])
if tfl_int8:
scale, zero_point = output_details[0]['quantization']
pred = (pred.astype(np.float32) - zero_point) * scale
pred[..., 0] *= imgsz[1] # x
pred[..., 1] *= imgsz[0] # y
pred[..., 2] *= imgsz[1] # w
pred[..., 3] *= imgsz[0] # h
pred = torch.tensor(pred)

# NMS
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
Expand Down Expand Up @@ -202,9 +244,9 @@ def run(weights='yolov5s.pt', # model.pt path(s)

def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pb', help='model.pt path(s)')
parser.add_argument('--source', type=str, default='data/images', help='file/dir/URL/glob, 0 for webcam')
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
Expand All @@ -226,7 +268,9 @@ def parse_opt():
parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
parser.add_argument('--tfl-int8', action='store_true', help='INT8 quantized TFLite model')
opt = parser.parse_args()
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
return opt


Expand Down
8 changes: 6 additions & 2 deletions models/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,18 @@ def forward(self, x, augment=False, profile=False, visualize=False):
return y, None # inference, train output


def attempt_load(weights, map_location=None, inplace=True):
def attempt_load(weights, map_location=None, inplace=True, fuse=True):
from models.yolo import Detect, Model

# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch.load(attempt_download(w), map_location=map_location) # load
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
if fuse:
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
else:
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval()) # without layer fuse


# Compatibility updates
for m in model.modules():
Expand Down
Loading

0 comments on commit 8ebe069

Please sign in to comment.