Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weโ€™ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TensorFlow SegmentationModel support #9472

Merged
merged 9 commits into from
Sep 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
python benchmarks.py --data coco128.yaml --weights ${{ matrix.model }}.pt --img 320 --hard-fail 0.29
- name: Benchmark SegmentationModel
run: |
python benchmarks.py --data coco128-seg.yaml --weights ${{ matrix.model }}-seg.pt --img 320
python benchmarks.py --data coco128-seg.yaml --weights ${{ matrix.model }}-seg.pt --img 320 --hard-fail 0.22

Tests:
timeout-minutes: 60
Expand Down
2 changes: 1 addition & 1 deletion export.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def export_saved_model(model,
m = m.get_concrete_function(spec)
frozen_func = convert_variables_to_constants_v2(m)
tfm = tf.Module()
tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x)[0], [spec])
tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x), [spec])
tfm.__call__(im)
tf.saved_model.save(tfm,
f,
Expand Down
29 changes: 20 additions & 9 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,10 +427,17 @@ def wrap_frozen_graph(gd, inputs, outputs):
ge = x.graph.as_graph_element
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))

def gd_outputs(gd):
name_list, input_list = [], []
for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
name_list.append(node.name)
input_list.extend(node.input)
return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))

gd = tf.Graph().as_graph_def() # TF GraphDef
with open(w, 'rb') as f:
gd.ParseFromString(f.read())
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
from tflite_runtime.interpreter import Interpreter, load_delegate
Expand Down Expand Up @@ -528,22 +535,26 @@ def forward(self, im, augment=False, visualize=False):
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
if self.saved_model: # SavedModel
y = (self.model(im, training=False) if self.keras else self.model(im)).numpy()
y = self.model(im, training=False) if self.keras else self.model(im)
elif self.pb: # GraphDef
y = self.frozen_func(x=self.tf.constant(im)).numpy()
y = self.frozen_func(x=self.tf.constant(im))
else: # Lite or Edge TPU
input, output = self.input_details[0], self.output_details[0]
input = self.input_details[0]
int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
if int8:
scale, zero_point = input['quantization']
im = (im / scale + zero_point).astype(np.uint8) # de-scale
self.interpreter.set_tensor(input['index'], im)
self.interpreter.invoke()
y = self.interpreter.get_tensor(output['index'])
if int8:
scale, zero_point = output['quantization']
y = (y.astype(np.float32) - zero_point) * scale # re-scale
y[..., :4] *= [w, h, w, h] # xywh normalized to pixels
y = []
for output in self.output_details:
x = self.interpreter.get_tensor(output['index'])
if int8:
scale, zero_point = output['quantization']
x = (x.astype(np.float32) - zero_point) * scale # re-scale
y.append(x)
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels

if isinstance(y, (list, tuple)):
return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
Expand Down
15 changes: 8 additions & 7 deletions models/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,15 @@ def call(self, inputs):
x[i] = tf.reshape(x[i], [-1, ny * nx, self.na, self.no])

if not self.training: # inference
y = tf.sigmoid(x[i])
y = x[i]
grid = tf.transpose(self.grid[i], [0, 2, 1, 3]) - 0.5
anchor_grid = tf.transpose(self.anchor_grid[i], [0, 2, 1, 3]) * 4
xy = (y[..., 0:2] * 2 + grid) * self.stride[i] # xy
wh = y[..., 2:4] ** 2 * anchor_grid
xy = (tf.sigmoid(y[..., 0:2]) * 2 + grid) * self.stride[i] # xy
wh = tf.sigmoid(y[..., 2:4]) ** 2 * anchor_grid
# Normalize xywh to 0-1 to reduce calibration error
xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
y = tf.concat([xy, wh, y[..., 4:]], -1)
y = tf.concat([xy, wh, tf.sigmoid(y[..., 4:5 + self.nc]), y[..., 5 + self.nc:]], -1)
z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))

return tf.transpose(x, [0, 2, 1, 3]) if self.training else (tf.concat(z, 1), x)
Expand All @@ -333,8 +333,9 @@ def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), imgsz=(640, 640), w

def call(self, x):
p = self.proto(x[0])
p = tf.transpose(p, [0, 3, 1, 2]) # from shape(1,160,160,32) to shape(1,32,160,160)
x = self.detect(self, x)
return (x, p) if self.training else ((x[0], p),)
return (x, p) if self.training else (x[0], p)


class TFProto(keras.layers.Layer):
Expand Down Expand Up @@ -485,8 +486,8 @@ def predict(self,
conf_thres,
clip_boxes=False)
return nms, x[1]
return x[0] # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]
# x = x[0][0] # [x(1,6300,85), ...] to x(6300,85)
return x # output [1,6300,85] = [xywh, conf, class0, class1, ...]
# x = x[0] # [x(1,6300,85), ...] to x(6300,85)
# xywh = x[..., :4] # x(6300,4) boxes
# conf = x[..., 4:5] # x(6300,1) confidences
# cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1)) # x(6300,1) classes
Expand Down