Skip to content

Commit

Permalink
Merge pull request #23 from AyushExel/instance_seg_tf
Browse files Browse the repository at this point in the history
Add support for TF export
  • Loading branch information
AyushExel committed Sep 4, 2022
2 parents 161d253 + dde9a55 commit 0a039c3
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions models/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from models.common import (C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv,
DWConvTranspose2d, Focus, autopad)
from models.experimental import MixConv2d, attempt_load
from models.yolo import Detect
from models.yolo import Detect, Segment
from utils.activations import SiLU
from utils.general import LOGGER, make_divisible, print_args

Expand Down Expand Up @@ -319,6 +319,32 @@ def _make_grid(nx=20, ny=20):
xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)

class TFSegment(TFDetect):
# YOLOv5 Segment head for segmentation models
def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), imgsz=(640, 640), w=None):
super().__init__(nc, anchors, ch, imgsz, w)
self.nm = nm # number of masks
self.npr = npr # number of protos
self.no = 5 + nc + self.nm # number of outputs per anchor
self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)] # output conv
self.proto = TFProto(ch[0], self.npr, self.nm, w=w.proto) # protos
self.detect = TFDetect.call

def call(self, x):
p = self.proto(x[0])
x = self.detect(self, x)
return (x, p) if self.training else ((x[0], p),)

class TFProto(keras.layers.Layer):
def __init__(self, c1, c_=256, c2=32, w=None):
super().__init__()
self.cv1 = TFConv(c1, c_, k=3, w=w.cv1)
self.upsample = TFUpsample(None, scale_factor=2, mode='nearest')
self.cv2 = TFConv(c_, c_, k=3, w=w.cv2)
self.cv3 = TFConv(c_, c2, w=w.cv3)

def call(self, inputs):
return self.cv3(self.cv2(self.upsample(self.cv1(inputs))))

class TFUpsample(keras.layers.Layer):
# TF version of torch.nn.Upsample()
Expand Down Expand Up @@ -377,10 +403,12 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
args = [ch[f]]
elif m is Concat:
c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
elif m is Detect:
elif m in [Detect, Segment]:
args.append([ch[x + 1] for x in f])
if isinstance(args[1], int): # number of anchors
args[1] = [list(range(args[1] * 2))] * len(f)
if m is Segment:
args[3] = make_divisible(args[3] * gw, 8)
args.append(imgsz)
else:
c2 = ch[f]
Expand Down

0 comments on commit 0a039c3

Please sign in to comment.