Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EdgeTPU support #3630

Merged
merged 107 commits into from
Dec 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
107 commits
Select commit Hold shift + click to select a range
5b6528c
Add models/tf.py for TensorFlow and TFLite export
zldrobit Oct 9, 2020
a30dad4
Set auto=False for int8 calibration
zldrobit Oct 9, 2020
f0cb6e2
Update requirements.txt for TensorFlow and TFLite export
zldrobit Oct 20, 2020
ce73d3d
Read anchors directly from PyTorch weights
zldrobit Oct 23, 2020
d101f7e
Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export
zldrobit Nov 2, 2020
55fb2eb
Remove check_anchor_order, check_file, set_logging from import
zldrobit Nov 16, 2020
3dd69d9
Reformat code and optimize imports
glenn-jocher Nov 28, 2020
e8a9ad2
Autodownload model and check cfg
glenn-jocher Nov 28, 2020
efbb853
update --source path, img-size to 320, single output
glenn-jocher Nov 28, 2020
aed53ce
Adjust representative_dataset
glenn-jocher Nov 28, 2020
ccb2336
Put representative dataset in tfl_int8 block
zldrobit Nov 30, 2020
9f893c8
detect.py TF inference
glenn-jocher Dec 2, 2020
d9fad06
Merge remote-tracking branch 'origin/tf-only-export' into tf-only-export
glenn-jocher Dec 2, 2020
49a9e05
weights to string
glenn-jocher Dec 2, 2020
1867bb4
weights to string
glenn-jocher Dec 2, 2020
4eed608
cleanup tf.py
glenn-jocher Dec 4, 2020
8ba2ca9
Add --dynamic-batch-size
zldrobit Dec 22, 2020
4d9104a
Add xywh normalization to reduce calibration error
zldrobit Dec 22, 2020
ae9bce8
Merge branch 'master' into tf-only-export
glenn-jocher Dec 22, 2020
cabb802
Update requirements.txt
zldrobit Dec 23, 2020
a5967f8
Fix imports
zldrobit Dec 24, 2020
dbc7f71
Add models/tf.py for TensorFlow and TFLite export
zldrobit Oct 9, 2020
565e620
Set auto=False for int8 calibration
zldrobit Oct 9, 2020
fc26561
Update requirements.txt for TensorFlow and TFLite export
zldrobit Oct 20, 2020
05cc389
Read anchors directly from PyTorch weights
zldrobit Oct 23, 2020
817fcf8
Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export
zldrobit Nov 2, 2020
9a4ce5e
Remove check_anchor_order, check_file, set_logging from import
zldrobit Nov 16, 2020
719479e
Reformat code and optimize imports
glenn-jocher Nov 28, 2020
243fa7f
Autodownload model and check cfg
glenn-jocher Nov 28, 2020
3b6cf12
update --source path, img-size to 320, single output
glenn-jocher Nov 28, 2020
5a5f949
Adjust representative_dataset
glenn-jocher Nov 28, 2020
00d50fd
detect.py TF inference
glenn-jocher Dec 2, 2020
061907b
Put representative dataset in tfl_int8 block
zldrobit Nov 30, 2020
ca4550b
weights to string
glenn-jocher Dec 2, 2020
5e04c5c
weights to string
glenn-jocher Dec 2, 2020
9121a87
cleanup tf.py
glenn-jocher Dec 4, 2020
e9bc606
Add --dynamic-batch-size
zldrobit Dec 22, 2020
dacd8af
Add xywh normalization to reduce calibration error
zldrobit Dec 22, 2020
b492af9
Update requirements.txt
zldrobit Dec 23, 2020
fbf5a45
Fix imports
zldrobit Dec 24, 2020
c4cfbd9
Merge branch 'tf-only-export' of https://github.com/zldrobit/yolov5 i…
glenn-jocher Jan 8, 2021
c761637
implement C3() and SiLU()
glenn-jocher Jan 8, 2021
e610134
Add TensorFlow and TFLite Detection
zldrobit Oct 10, 2020
98bd249
Add --tfl-detect for TFLite Detection
zldrobit Oct 10, 2020
aeffcb0
Add int8 quantized TFLite inference in detect.py
zldrobit Oct 12, 2020
c10bd95
Add --edgetpu for Edge TPU detection
zldrobit Dec 30, 2020
0b943fe
Fix --img-size to add rectangle TensorFlow and TFLite input
zldrobit Oct 20, 2020
15e7be4
Add --no-tf-nms to detect objects using models combined with TensorFl…
zldrobit Nov 2, 2020
83852c2
Fix --img-size list type input
zldrobit Nov 2, 2020
992ec81
Update README.md
zldrobit Nov 9, 2020
058afbf
Add Android project for TFLite inference
zldrobit Oct 14, 2020
5e7dc98
Upgrade TensorFlow v2.3.1 -> v2.4.0
zldrobit Dec 31, 2020
14af2ea
Disable normalization of xywh
zldrobit Jan 8, 2021
986cc0f
Rewrite names init in detect.py
zldrobit Jan 8, 2021
fb5398c
Change input resolution 640 -> 320 on Android
Jan 8, 2021
5091564
Disable NNAPI
Jan 8, 2021
856462f
Update README.me --img 640 -> 320
zldrobit Jan 10, 2021
a16f037
Update README.me for Edge TPU
zldrobit Jan 10, 2021
c37d6b5
Update README.md
zldrobit Jan 27, 2021
36ed3cd
Fix reshape dim to support dynamic batching
zldrobit Feb 2, 2021
f9b6202
Fix reshape dim to support dynamic batching
zldrobit Feb 2, 2021
aad9e24
Merge branch 'master' into tf-only-export
glenn-jocher Feb 4, 2021
8ec975e
Merge branch 'master' into tf-only-export
glenn-jocher Feb 6, 2021
b0fa5a3
Add epsilon argument in tf_BN, which is different between TF and PT
zldrobit Mar 11, 2021
710bf56
Set stride to None if not using PyTorch, and do not warmup without Py…
zldrobit Mar 12, 2021
c45ceef
Add list support in check_img_size()
zldrobit Mar 23, 2021
0d39b24
Add list input support in detect.py
zldrobit Mar 23, 2021
47da942
merge ultralytics:master
glenn-jocher Mar 25, 2021
8cb7032
merge ultralytics:master
glenn-jocher Mar 28, 2021
4e1485b
sys.path.append('./') to run from yolov5/
glenn-jocher Mar 28, 2021
e4e6d6f
Add int8 quantization support for TensorFlow 2.5
zldrobit Apr 3, 2021
aafe224
Add get_coco128.sh
zldrobit May 6, 2021
d3be281
Remove --no-tfl-detect in models/tf.py (Use tf-android-tfl-detect bra…
zldrobit May 6, 2021
9ca5d7a
Merge branch 'develop' into tf-only-export
glenn-jocher Jun 2, 2021
215c865
Update requirements.txt
glenn-jocher Jun 2, 2021
a2867da
Replace torch.load() with attempt_load()
zldrobit Jun 7, 2021
86768d1
Update requirements.txt
zldrobit Jun 7, 2021
972c6a2
Add --tf-raw-resize to set half_pixel_centers=False
zldrobit Jun 11, 2021
14cb406
Merge branch 'tf-only-export' into tf-android-tfl-detect
zldrobit Jun 15, 2021
a5885b7
Remove android directory
zldrobit Jun 16, 2021
aa10f48
Update README.md
zldrobit Jun 16, 2021
a4b1d8b
Update README.md
zldrobit Jun 16, 2021
3685829
Add multiple OS support for EdgeTPU detection
zldrobit Jul 21, 2021
244a9ae
Merge master
zldrobit Nov 12, 2021
e2e6987
Fix export and detect
zldrobit Nov 12, 2021
cc0d939
Fix saved_model and pb detect error
zldrobit Nov 12, 2021
1d1ec40
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 12, 2021
ae9bf97
Fix pre-commit.ci failure
zldrobit Nov 12, 2021
672c800
Add edgetpu in export.py docstring
zldrobit Nov 19, 2021
ad1c9f7
Merge branch 'master' into tf-edgetpu
glenn-jocher Nov 22, 2021
9b00f86
Fix Edge TPU model detection exported by TF 2.7
zldrobit Nov 24, 2021
a8ef5c4
Add class names for TF/TFLite in DetectMultibackend
zldrobit Nov 25, 2021
e02b160
Fix assignment with nl in TFLite Detection
zldrobit Nov 25, 2021
df45b51
Add check when getting Edge TPU compiler version
zldrobit Nov 25, 2021
6b35aaf
Merge master
zldrobit Nov 26, 2021
9912bc6
Add UTF-8 encoding in opening --data file for Windows
zldrobit Nov 26, 2021
f335296
Remove redundant TensorFlow import
zldrobit Nov 26, 2021
aee5a65
Add Edge TPU in export.py's docstring
zldrobit Dec 10, 2021
177f01e
Add the detect layer in Edge TPU model conversion
zldrobit Dec 10, 2021
411d7cf
Merge master
zldrobit Dec 10, 2021
516db4a
Merge branch 'master' into tf-edgetpu
glenn-jocher Dec 23, 2021
939650d
Default `dnn=False`
glenn-jocher Dec 23, 2021
4303323
Cleanup data.yaml loading
glenn-jocher Dec 23, 2021
e2172cf
Update detect.py
glenn-jocher Dec 23, 2021
48bc0d2
Update val.py
glenn-jocher Dec 23, 2021
243f537
Merge master
glenn-jocher Dec 31, 2021
16ef8e7
Comments and generalize data.yaml names
glenn-jocher Dec 31, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
@torch.no_grad()
def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam
data=ROOT / 'data/coco128.yaml', # dataset.yaml path
imgsz=(640, 640), # inference size (height, width)
conf_thres=0.25, # confidence threshold
iou_thres=0.45, # NMS IOU threshold
Expand Down Expand Up @@ -76,7 +77,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)

# Load model
device = select_device(device)
model = DetectMultiBackend(weights, device=device, dnn=dnn)
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data)
stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine
imgsz = check_img_size(imgsz, s=stride) # check image size

Expand Down Expand Up @@ -204,6 +205,7 @@ def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)')
parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam')
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
Expand Down
29 changes: 25 additions & 4 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,24 @@ def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('Te
LOGGER.info(f'\n{prefix} export failure: {e}')


def export_edgetpu(keras_model, im, file, prefix=colorstr('Edge TPU:')):
# YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
try:
cmd = 'edgetpu_compiler --version'
out = subprocess.run(cmd, shell=True, capture_output=True, check=True)
ver = out.stdout.decode().split()[-1]
LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
f = str(file).replace('.pt', '-int8_edgetpu.tflite')
f_tfl = str(file).replace('.pt', '-int8.tflite') # TFLite model

cmd = f"edgetpu_compiler -s {f_tfl}"
subprocess.run(cmd, shell=True, check=True)

LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')


def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
# YOLOv5 TensorFlow.js export
try:
Expand Down Expand Up @@ -285,6 +303,7 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):


def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
# YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
try:
check_requirements(('tensorrt',))
import tensorrt as trt
Expand Down Expand Up @@ -356,7 +375,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
):
t = time.time()
include = [x.lower() for x in include]
tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'tfjs')) # TensorFlow exports
tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs')) # TensorFlow exports
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)

# Checks
Expand Down Expand Up @@ -405,15 +424,17 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'

# TensorFlow Exports
if any(tf_exports):
pb, tflite, tfjs = tf_exports[1:]
pb, tflite, edgetpu, tfjs = tf_exports[1:]
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
model = export_saved_model(model, im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs,
agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class, topk_all=topk_all,
conf_thres=conf_thres, iou_thres=iou_thres) # keras model
if pb or tfjs: # pb prerequisite to tfjs
export_pb(model, im, file)
if tflite:
export_tflite(model, im, file, int8=int8, data=data, ncalib=100)
if tflite or edgetpu:
export_tflite(model, im, file, int8=int8 or edgetpu, data=data, ncalib=100)
if edgetpu:
export_edgetpu(model, im, file)
if tfjs:
export_tfjs(model, im, file)

Expand Down
10 changes: 8 additions & 2 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import requests
import torch
import torch.nn as nn
import yaml
from PIL import Image
from torch.cuda import amp

Expand Down Expand Up @@ -276,14 +277,15 @@ def forward(self, x):

class DetectMultiBackend(nn.Module):
# YOLOv5 MultiBackend class for python inference on various backends
def __init__(self, weights='yolov5s.pt', device=None, dnn=False):
def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
# Usage:
# PyTorch: weights = *.pt
# TorchScript: *.torchscript
# CoreML: *.mlmodel
# TensorFlow: *_saved_model
# TensorFlow: *.pb
# TensorFlow Lite: *.tflite
# TensorFlow Edge TPU: *_edgetpu.tflite
# ONNX Runtime: *.onnx
# OpenCV DNN: *.onnx with dnn=True
# TensorRT: *.engine
Expand All @@ -297,6 +299,9 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False):
pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
w = attempt_download(w) # download if not local
if data: # data.yaml path (optional)
with open(data, errors='ignore') as f:
names = yaml.safe_load(f)['names'] # class names

if jit: # TorchScript
LOGGER.info(f'Loading {w} for TorchScript inference...')
Expand Down Expand Up @@ -343,7 +348,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False):
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
context = model.create_execution_context()
batch_size = bindings['images'].shape[0]
else: # TensorFlow model (TFLite, pb, saved_model)
else: # TensorFlow (TFLite, pb, saved_model)
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...')
import tensorflow as tf
Expand Down Expand Up @@ -425,6 +430,7 @@ def forward(self, im, augment=False, visualize=False, val=False):
y[..., 1] *= h # y
y[..., 2] *= w # w
y[..., 3] *= h # h

y = torch.tensor(y) if isinstance(y, np.ndarray) else y
return (y, []) if val else y

Expand Down
2 changes: 1 addition & 1 deletion val.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def run(data,
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir

# Load model
model = DetectMultiBackend(weights, device=device, dnn=dnn)
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data)
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
imgsz = check_img_size(imgsz, s=stride) # check image size
half &= (pt or jit or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
Expand Down