Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TensorFlow and TFLite export #1127

Merged
merged 71 commits into from
Aug 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
5b6528c
Add models/tf.py for TensorFlow and TFLite export
zldrobit Oct 9, 2020
a30dad4
Set auto=False for int8 calibration
zldrobit Oct 9, 2020
f0cb6e2
Update requirements.txt for TensorFlow and TFLite export
zldrobit Oct 20, 2020
ce73d3d
Read anchors directly from PyTorch weights
zldrobit Oct 23, 2020
d101f7e
Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export
zldrobit Nov 2, 2020
55fb2eb
Remove check_anchor_order, check_file, set_logging from import
zldrobit Nov 16, 2020
3dd69d9
Reformat code and optimize imports
glenn-jocher Nov 28, 2020
e8a9ad2
Autodownload model and check cfg
glenn-jocher Nov 28, 2020
efbb853
update --source path, img-size to 320, single output
glenn-jocher Nov 28, 2020
aed53ce
Adjust representative_dataset
glenn-jocher Nov 28, 2020
ccb2336
Put representative dataset in tfl_int8 block
zldrobit Nov 30, 2020
9f893c8
detect.py TF inference
glenn-jocher Dec 2, 2020
d9fad06
Merge remote-tracking branch 'origin/tf-only-export' into tf-only-export
glenn-jocher Dec 2, 2020
49a9e05
weights to string
glenn-jocher Dec 2, 2020
1867bb4
weights to string
glenn-jocher Dec 2, 2020
4eed608
cleanup tf.py
glenn-jocher Dec 4, 2020
8ba2ca9
Add --dynamic-batch-size
zldrobit Dec 22, 2020
4d9104a
Add xywh normalization to reduce calibration error
zldrobit Dec 22, 2020
ae9bce8
Merge branch 'master' into tf-only-export
glenn-jocher Dec 22, 2020
cabb802
Update requirements.txt
zldrobit Dec 23, 2020
a5967f8
Fix imports
zldrobit Dec 24, 2020
dbc7f71
Add models/tf.py for TensorFlow and TFLite export
zldrobit Oct 9, 2020
565e620
Set auto=False for int8 calibration
zldrobit Oct 9, 2020
fc26561
Update requirements.txt for TensorFlow and TFLite export
zldrobit Oct 20, 2020
05cc389
Read anchors directly from PyTorch weights
zldrobit Oct 23, 2020
817fcf8
Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export
zldrobit Nov 2, 2020
9a4ce5e
Remove check_anchor_order, check_file, set_logging from import
zldrobit Nov 16, 2020
719479e
Reformat code and optimize imports
glenn-jocher Nov 28, 2020
243fa7f
Autodownload model and check cfg
glenn-jocher Nov 28, 2020
3b6cf12
update --source path, img-size to 320, single output
glenn-jocher Nov 28, 2020
5a5f949
Adjust representative_dataset
glenn-jocher Nov 28, 2020
00d50fd
detect.py TF inference
glenn-jocher Dec 2, 2020
061907b
Put representative dataset in tfl_int8 block
zldrobit Nov 30, 2020
ca4550b
weights to string
glenn-jocher Dec 2, 2020
5e04c5c
weights to string
glenn-jocher Dec 2, 2020
9121a87
cleanup tf.py
glenn-jocher Dec 4, 2020
e9bc606
Add --dynamic-batch-size
zldrobit Dec 22, 2020
dacd8af
Add xywh normalization to reduce calibration error
zldrobit Dec 22, 2020
b492af9
Update requirements.txt
zldrobit Dec 23, 2020
fbf5a45
Fix imports
zldrobit Dec 24, 2020
c4cfbd9
Merge branch 'tf-only-export' of https://github.com/zldrobit/yolov5 i…
glenn-jocher Jan 8, 2021
c761637
implement C3() and SiLU()
glenn-jocher Jan 8, 2021
36ed3cd
Fix reshape dim to support dynamic batching
zldrobit Feb 2, 2021
aad9e24
Merge branch 'master' into tf-only-export
glenn-jocher Feb 4, 2021
8ec975e
Merge branch 'master' into tf-only-export
glenn-jocher Feb 6, 2021
b0fa5a3
Add epsilon argument in tf_BN, which is different between TF and PT
zldrobit Mar 11, 2021
710bf56
Set stride to None if not using PyTorch, and do not warmup without Py…
zldrobit Mar 12, 2021
c45ceef
Add list support in check_img_size()
zldrobit Mar 23, 2021
0d39b24
Add list input support in detect.py
zldrobit Mar 23, 2021
47da942
merge ultralytics:master
glenn-jocher Mar 25, 2021
8cb7032
merge ultralytics:master
glenn-jocher Mar 28, 2021
4e1485b
sys.path.append('./') to run from yolov5/
glenn-jocher Mar 28, 2021
e4e6d6f
Add int8 quantization support for TensorFlow 2.5
zldrobit Apr 3, 2021
aafe224
Add get_coco128.sh
zldrobit May 6, 2021
d3be281
Remove --no-tfl-detect in models/tf.py (Use tf-android-tfl-detect bra…
zldrobit May 6, 2021
9ca5d7a
Merge branch 'develop' into tf-only-export
glenn-jocher Jun 2, 2021
215c865
Update requirements.txt
glenn-jocher Jun 2, 2021
a2867da
Replace torch.load() with attempt_load()
zldrobit Jun 7, 2021
86768d1
Update requirements.txt
zldrobit Jun 7, 2021
972c6a2
Add --tf-raw-resize to set half_pixel_centers=False
zldrobit Jun 11, 2021
eed4980
Add --agnostic-nms for TF class-agnostic NMS
zldrobit Jun 18, 2021
34b7d67
Merge master
glenn-jocher Aug 16, 2021
10eebf2
Cleanup after merge
glenn-jocher Aug 16, 2021
e3aa755
Cleanup2 after merge
glenn-jocher Aug 16, 2021
27fc39b
Cleanup3 after merge
glenn-jocher Aug 16, 2021
48a6bf8
Add tf.py docstring with credit and usage
glenn-jocher Aug 16, 2021
4e12c70
pb saved_model and tflite use only one model in detect.py
zldrobit Aug 17, 2021
ea0274f
Add use cases in docstring of tf.py
zldrobit Aug 17, 2021
2f63c6d
Remove redundant `stride` definition
glenn-jocher Aug 17, 2021
c3f46c3
Remove keras direct import
glenn-jocher Aug 17, 2021
3133381
Fix `check_requirements(('tensorflow>=2.4.1',))`
glenn-jocher Aug 17, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 54 additions & 10 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn

Expand Down Expand Up @@ -51,6 +52,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
hide_labels=False, # hide labels
hide_conf=False, # hide confidences
half=False, # use FP16 half-precision inference
tfl_int8=False, # INT8 quantized TFLite model
):
save_img = not nosave and not source.endswith('.txt') # save inference images
webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
Expand All @@ -68,7 +70,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
# Load model
w = weights[0] if isinstance(weights, list) else weights
classify, suffix = False, Path(w).suffix.lower()
pt, onnx, tflite, pb, graph_def = (suffix == x for x in ['.pt', '.onnx', '.tflite', '.pb', '']) # backend
pt, onnx, tflite, pb, saved_model = (suffix == x for x in ['.pt', '.onnx', '.tflite', '.pb', '']) # backend
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
if pt:
model = attempt_load(weights, map_location=device) # load FP32 model
Expand All @@ -83,30 +85,49 @@ def run(weights='yolov5s.pt', # model.pt path(s)
check_requirements(('onnx', 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
else: # TensorFlow models
check_requirements(('tensorflow>=2.4.1',))
import tensorflow as tf
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
def wrap_frozen_graph(gd, inputs, outputs):
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped import
return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
tf.nest.map_structure(x.graph.as_graph_element, outputs))

graph_def = tf.Graph().as_graph_def()
graph_def.ParseFromString(open(w, 'rb').read())
frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
elif saved_model:
model = tf.keras.models.load_model(w)
elif tflite:
interpreter = tf.lite.Interpreter(model_path=w) # load TFLite model
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
imgsz = check_img_size(imgsz, s=stride) # check image size

# Dataloader
if webcam:
view_img = check_imshow()
cudnn.benchmark = True # set True to speed up constant image size inference
dataset = LoadStreams(source, img_size=imgsz, stride=stride)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
bs = len(dataset) # batch_size
else:
dataset = LoadImages(source, img_size=imgsz, stride=stride)
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
bs = 1 # batch_size
vid_path, vid_writer = [None] * bs, [None] * bs

# Run inference
if pt and device.type != 'cpu':
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters()))) # run once
t0 = time.time()
for path, img, im0s, vid_cap in dataset:
if pt:
if onnx:
img = img.astype('float32')
else:
img = torch.from_numpy(img).to(device)
img = img.half() if half else img.float() # uint8 to fp16/32
elif onnx:
img = img.astype('float32')
img /= 255.0 # 0 - 255 to 0.0 - 1.0
img = img / 255.0 # 0 - 255 to 0.0 - 1.0
if len(img.shape) == 3:
img = img[None] # expand for batch dim

Expand All @@ -117,6 +138,27 @@ def run(weights='yolov5s.pt', # model.pt path(s)
pred = model(img, augment=augment, visualize=visualize)[0]
elif onnx:
pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
else: # tensorflow model (tflite, pb, saved_model)
imn = img.permute(0, 2, 3, 1).cpu().numpy() # image in numpy
if pb:
pred = frozen_func(x=tf.constant(imn)).numpy()
elif saved_model:
pred = model(imn, training=False).numpy()
elif tflite:
if tfl_int8:
scale, zero_point = input_details[0]['quantization']
imn = (imn / scale + zero_point).astype(np.uint8)
interpreter.set_tensor(input_details[0]['index'], imn)
interpreter.invoke()
pred = interpreter.get_tensor(output_details[0]['index'])
if tfl_int8:
scale, zero_point = output_details[0]['quantization']
pred = (pred.astype(np.float32) - zero_point) * scale
pred[..., 0] *= imgsz[1] # x
pred[..., 1] *= imgsz[0] # y
pred[..., 2] *= imgsz[1] # w
pred[..., 3] *= imgsz[0] # h
pred = torch.tensor(pred)

# NMS
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
Expand Down Expand Up @@ -202,9 +244,9 @@ def run(weights='yolov5s.pt', # model.pt path(s)

def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pb', help='model.pt path(s)')
parser.add_argument('--source', type=str, default='data/images', help='file/dir/URL/glob, 0 for webcam')
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
Expand All @@ -226,7 +268,9 @@ def parse_opt():
parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
parser.add_argument('--tfl-int8', action='store_true', help='INT8 quantized TFLite model')
opt = parser.parse_args()
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
return opt


Expand Down
8 changes: 6 additions & 2 deletions models/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,18 @@ def forward(self, x, augment=False, profile=False, visualize=False):
return y, None # inference, train output


def attempt_load(weights, map_location=None, inplace=True):
def attempt_load(weights, map_location=None, inplace=True, fuse=True):
from models.yolo import Detect, Model

# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch.load(attempt_download(w), map_location=map_location) # load
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
if fuse:
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
else:
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval()) # without layer fuse


# Compatibility updates
for m in model.modules():
Expand Down
Loading