Skip to content

Commit

Permalink
Fix TF exports >= 2GB (ultralytics#6292)
Browse files Browse the repository at this point in the history
* Fix exporting saved_model: pb exceeds 2GB

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Replace TF v1.x API with TF v2.x API for saved_model export

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Clean up

* Remove lambda in tf.function()

* Revert "Remove lambda in tf.function()" to be compatible with TF v2.4

This reverts commit 46c7931f11dfdea6ae340c77287c35c30b9e0779.

* Fix for pre-commit.ci

* Cleanup1

* Cleanup2

* Backwards compatibility update

* Update common.py

* Update common.py

* Cleanup3

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
3 people committed Feb 18, 2022
1 parent 3a80ec4 commit 0a1a89d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
25 changes: 19 additions & 6 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,11 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F

def export_saved_model(model, im, file, dynamic,
tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
conf_thres=0.25, prefix=colorstr('TensorFlow SavedModel:')):
conf_thres=0.25, keras=False, prefix=colorstr('TensorFlow SavedModel:')):
# YOLOv5 TensorFlow SavedModel export
try:
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

from models.tf import TFDetect, TFModel

Expand All @@ -262,13 +262,26 @@ def export_saved_model(model, im, file, dynamic,
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow
_ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
inputs = tf.keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
keras_model = keras.Model(inputs=inputs, outputs=outputs)
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
keras_model.trainable = False
keras_model.summary()
keras_model.save(f, save_format='tf')

if keras:
keras_model.save(f, save_format='tf')
else:
m = tf.function(lambda x: keras_model(x)) # full model
spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
m = m.get_concrete_function(spec)
frozen_func = convert_variables_to_constants_v2(m)
tfm = tf.Module()
tfm.__call__ = tf.function(lambda x: frozen_func(x), [spec])
tfm.__call__(im)
tf.saved_model.save(
tfm,
f,
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if
check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions())
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return keras_model, f
except Exception as e:
Expand Down
5 changes: 3 additions & 2 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,8 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
if saved_model: # SavedModel
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
import tensorflow as tf
model = tf.keras.models.load_model(w)
keras = False # assume TF1 saved_model
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
import tensorflow as tf
Expand Down Expand Up @@ -431,7 +432,7 @@ def forward(self, im, augment=False, visualize=False, val=False):
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
if self.saved_model: # SavedModel
y = self.model(im, training=False).numpy()
y = (self.model(im, training=False) if self.keras else self.model(im)[0]).numpy()
elif self.pb: # GraphDef
y = self.frozen_func(x=self.tf.constant(im)).numpy()
elif self.tflite: # Lite
Expand Down

0 comments on commit 0a1a89d

Please sign in to comment.