Skip to content

Commit

Permalink
Fix TF exports >= 2GB (ultralytics#6292)
Browse files Browse the repository at this point in the history
* Fix exporting saved_model: pb exceeds 2GB

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Replace TF v1.x API with TF v2.x API for saved_model export

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Clean up

* Remove lambda in tf.function()

* Revert "Remove lambda in tf.function()" to be compatible with TF v2.4

This reverts commit 46c7931f11dfdea6ae340c77287c35c30b9e0779.

* Fix for pre-commit.ci

* Cleanup1

* Cleanup2

* Backwards compatibility update

* Update common.py

* Update common.py

* Cleanup3

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
3 people authored and eladco committed Mar 10, 2022
1 parent 84dfc43 commit e5c11ff
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 55 deletions.
98 changes: 45 additions & 53 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite
TensorFlow.js | `tfjs` | yolov5s_web_model/
Requirements:
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
Usage:
$ python path/to/export.py --weights yolov5s.pt --include torchscript onnx openvino engine coreml tflite ...
Expand Down Expand Up @@ -45,6 +49,7 @@
import subprocess
import sys
import time
import warnings
from pathlib import Path

import pandas as pd
Expand Down Expand Up @@ -239,41 +244,14 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')

def export_keras(model, im, file, dynamic, prefix=colorstr('Keras:')):
# YOLOv5 TensorFlow SavedModel export
try:
import tensorflow as tf
from tensorflow import keras

from models.keras import TFDetect, KerasModel

LOGGER.info(f'\n{prefix} starting export with keras {tf.__version__}...')
f = str(file).replace('.pt', '.h5')
batch_size, ch, *imgsz = list(im.shape) # BCHW

model = KerasModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for Keras
_ = model.predict(im) # first call to create weights
inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
outputs = model.predict(inputs)
keras_model = keras.Model(inputs=inputs, outputs=outputs, name="yolov5n")
keras_model.trainable = False
keras_model.summary()
keras_model.save(f, save_format='h5')

LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return keras_model, f
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')
return None, None

def export_saved_model(model, im, file, dynamic,
tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
conf_thres=0.25, prefix=colorstr('TensorFlow SavedModel:')):
conf_thres=0.25, keras=False, prefix=colorstr('TensorFlow SavedModel:')):
# YOLOv5 TensorFlow SavedModel export
try:
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

from models.tf import TFDetect, TFModel

Expand All @@ -282,16 +260,28 @@ def export_saved_model(model, im, file, dynamic,
batch_size, ch, *imgsz = list(im.shape) # BCHW

tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
im = tf.ones((batch_size, *imgsz, 3)) # BHWC order for TensorFlow
y = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
y = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow
_ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
inputs = tf.keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
keras_model = keras.Model(inputs=inputs, outputs=outputs)
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
keras_model.trainable = False
keras_model.summary()
keras_model.save(f, save_format='tf')

if keras:
keras_model.save(f, save_format='tf')
else:
m = tf.function(lambda x: keras_model(x)) # full model
spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
m = m.get_concrete_function(spec)
frozen_func = convert_variables_to_constants_v2(m)
tfm = tf.Module()
tfm.__call__ = tf.function(lambda x: frozen_func(x), [spec])
tfm.__call__(im)
tf.saved_model.save(
tfm,
f,
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if
check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions())
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return keras_model, f
except Exception as e:
Expand Down Expand Up @@ -358,13 +348,14 @@ def export_edgetpu(keras_model, im, file, prefix=colorstr('Edge TPU:')):
cmd = 'edgetpu_compiler --version'
help_url = 'https://coral.ai/docs/edgetpu/compiler/'
assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
if subprocess.run(cmd, shell=True).returncode != 0:
if subprocess.run(cmd + ' >/dev/null', shell=True).returncode != 0:
LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
for c in ['curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
'sudo apt-get update',
'sudo apt-get install edgetpu-compiler']:
subprocess.run(c, shell=True, check=True)
subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]

LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
Expand Down Expand Up @@ -446,16 +437,17 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs')) # TensorFlow exports
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)

# Checks
imgsz *= 2 if len(imgsz) == 1 else 1 # expand
opset = 12 if ('openvino' in include) else opset # OpenVINO requires opset <= 12

# Load PyTorch model
device = select_device(device)
assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'
model = attempt_load(weights, map_location=device, inplace=True, fuse=True) # load FP32 model
nc, names = model.nc, model.names # number of classes, class names

# Checks
imgsz *= 2 if len(imgsz) == 1 else 1 # expand
opset = 12 if ('openvino' in include) else opset # OpenVINO requires opset <= 12
assert nc == len(names), f'Model class count {nc} != len(names) {len(names)}'

# Input
gs = int(max(model.stride)) # grid size (max stride)
imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples
Expand All @@ -477,10 +469,12 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'

for _ in range(2):
y = model(im) # dry runs
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} ({file_size(file):.1f} MB)")
shape = tuple(y[0].shape) # model output shape
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")

# Exports
f = [''] * 11 # exported filenames
f = [''] * 10 # exported filenames
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
if 'torchscript' in include:
f[0] = export_torchscript(model, im, file, optimize)
if 'engine' in include: # TensorRT required before ONNX
Expand Down Expand Up @@ -510,17 +504,15 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
if tfjs:
f[9] = export_tfjs(model, im, file)

if 'keras' in include:
_, f[10] = export_keras(model, im, file, dynamic)

# Finish
f = [str(x) for x in f if x] # filter out '' and None
LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)'
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
f"\nVisualize with https://netron.app"
f"\nDetect with `python detect.py --weights {f[-1]}`"
f" or `model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')"
f"\nValidate with `python val.py --weights {f[-1]}`")
if any(f):
LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)'
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
f"\nDetect: python detect.py --weights {f[-1]}"
f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')"
f"\nValidate: python val.py --weights {f[-1]}"
f"\nVisualize: https://netron.app")
return f # return list of exported files/dirs


Expand Down
5 changes: 3 additions & 2 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,8 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
if saved_model: # SavedModel
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
import tensorflow as tf
model = tf.keras.models.load_model(w)
keras = False # assume TF1 saved_model
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
import tensorflow as tf
Expand Down Expand Up @@ -431,7 +432,7 @@ def forward(self, im, augment=False, visualize=False, val=False):
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
if self.saved_model: # SavedModel
y = self.model(im, training=False).numpy()
y = (self.model(im, training=False) if self.keras else self.model(im)[0]).numpy()
elif self.pb: # GraphDef
y = self.frozen_func(x=self.tf.constant(im)).numpy()
elif self.tflite: # Lite
Expand Down

0 comments on commit e5c11ff

Please sign in to comment.