
Add support for TF export #23
Merged · 7 commits · Sep 4, 2022
Conversation

@AyushExel (Owner) commented Aug 30, 2022

Usage:

python export.py --weights yolov5s-seg.pt --include tflite

Output
[Screenshot: TFLite export output]

We'll need to wait for the official models to finish training before we can run benchmarks on the exported models.

@glenn-jocher

@AyushExel nice! I think TFProto needs a call() method though no?

@glenn-jocher commented Aug 30, 2022

@AyushExel also you should build and run a segmentation model following the detection example below to verify it works, i.e. from tf.py run() function:

    # TensorFlow model
    im = tf.zeros((batch_size, *imgsz, 3))  # BHWC image
    tf_model = TFModel(cfg='yolov5-seg.yaml', model=model, nc=model.nc, imgsz=imgsz)
    _ = tf_model.predict(im)  # inference

EDIT: TFSegment() forward method needs to be renamed to 'call' in TF terminology

@AyushExel (Owner, author)

@glenn-jocher ok.
Channels are last in TF format, right? So the input to the proto layer should be 1x80x80x64 instead of 1x64x80x80, but it still complains about dimensions.
[Screenshots: dimension-mismatch errors during TF export]
I think something is messing up the weight transfer; I'm jumping into this rabbit hole.
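As an aside, a minimal sketch of the layout mismatch being chased here, assuming the usual conventions: PyTorch activations are NCHW and Conv2d weights are OIHW, while TF/Keras expects NHWC activations and HWIO kernels, so a weight transfer that skips (or misapplies) the permute complains about dimensions exactly like this. The helper names are made up for this sketch, not tf.py code:

```python
def nchw_to_nhwc(shape):
    """Activation layout: PyTorch NCHW -> TF NHWC."""
    n, c, h, w = shape
    return (n, h, w, c)

def oihw_to_hwio(shape):
    """Conv kernel layout: PyTorch (out, in, kH, kW) -> Keras (kH, kW, in, out)."""
    o, i, h, w = shape
    return (h, w, i, o)

# The proto-layer input 1x64x80x80 becomes 1x80x80x64 in TF:
assert nchw_to_nhwc((1, 64, 80, 80)) == (1, 80, 80, 64)
# A 3x3 conv taking 64 channels to 256 transposes to HWIO:
assert oihw_to_hwio((256, 64, 3, 3)) == (3, 3, 64, 256)
```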

@glenn-jocher

@zldrobit can you take a look at this PR please? We are adding segmentation model support to YOLOv5 in ultralytics#9052, and we need to add two additional modules to tf.py: a Segment() head and a Proto() module for masks. Segment() is a little tricky, as it inherits from Detect() and uses the Detect.forward method in addition to Proto.forward.

yolov5/models/yolo.py

Lines 92 to 107 in 3b25f3c

    class Segment(Detect):
        # YOLOv5 Segment head for segmentation models
        def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), inplace=True):
            super().__init__(nc, anchors, ch, inplace)
            self.nm = nm  # number of masks
            self.npr = npr  # number of protos
            self.no = 5 + nc + self.nm  # number of outputs per anchor
            self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
            self.proto = Proto(ch[0], self.npr, self.nm)  # protos
            self.detect = Detect.forward

        def forward(self, x):
            p = self.proto(x[0])
            x = self.detect(self, x)
            return (x, p) if self.training else (x[0], p) if self.export else (x[0], (x[1], p))

yolov5/models/common.py

Lines 764 to 775 in 3b25f3c

    class Proto(nn.Module):
        # YOLOv5 mask Proto module for segmentation models
        def __init__(self, c1, c_=256, c2=32):  # ch_in, number of protos, number of masks
            super().__init__()
            self.cv1 = Conv(c1, c_, k=3)
            self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
            self.cv2 = Conv(c_, c_, k=3)
            self.cv3 = Conv(c_, c2)

        def forward(self, x):
            return self.cv3(self.cv2(self.upsample(self.cv1(x))))

@zldrobit commented Sep 1, 2022

@glenn-jocher @AyushExel I am trying to reproduce the error, but I couldn't find the segmentation model (yolov5s-seg.pt). If possible, please share the segmentation model for debugging.

EDIT: By comparing

yolov5/models/yolo.py

Lines 92 to 107 in 3b25f3c

    class Segment(Detect):
        # YOLOv5 Segment head for segmentation models
        def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), inplace=True):
            super().__init__(nc, anchors, ch, inplace)
            self.nm = nm  # number of masks
            self.npr = npr  # number of protos
            self.no = 5 + nc + self.nm  # number of outputs per anchor
            self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
            self.proto = Proto(ch[0], self.npr, self.nm)  # protos
            self.detect = Detect.forward

        def forward(self, x):
            p = self.proto(x[0])
            x = self.detect(self, x)
            return (x, p) if self.training else (x[0], p) if self.export else (x[0], (x[1], p))

yolov5/models/common.py

Lines 764 to 775 in 3b25f3c

    class Proto(nn.Module):
        # YOLOv5 mask Proto module for segmentation models
        def __init__(self, c1, c_=256, c2=32):  # ch_in, number of protos, number of masks
            super().__init__()
            self.cv1 = Conv(c1, c_, k=3)
            self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
            self.cv2 = Conv(c_, c_, k=3)
            self.cv3 = Conv(c_, c2)

        def forward(self, x):
            return self.cv3(self.cv2(self.upsample(self.cv1(x))))

with

yolov5/models/tf.py

Lines 322 to 347 in d2af8e1

    class TFSegment(TFDetect):
        # YOLOv5 Segment head for segmentation models
        def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), imgsz=(640, 640), w=None):
            super().__init__(nc, anchors, ch, imgsz, w)
            self.nm = nm  # number of masks
            self.npr = npr  # number of protos
            self.no = 5 + nc + self.nm  # number of outputs per anchor
            self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]  # output conv
            self.proto = TFProto(ch[0], self.npr, self.nm, w=w.proto)  # protos
            self.detect = TFDetect.call

        def call(self, x):
            p = self.proto(x[0])
            x = self.detect(self, x)
            return (x, p) if self.training else (x[0], p) if self.export else (x[0], (x[1], p))


    class TFProto(keras.layers.Layer):
        def __init__(self, c1, c_=256, c2=32, w=None):
            super().__init__()
            self.cv1 = TFConv(c1, c_, k=3, w=w.cv1)
            self.upsample = TFUpsample(None, scale_factor=2, mode='nearest')
            self.cv2 = TFConv(c_, c_, k=3, w=w.cv2)
            self.cv3 = TFConv(c_, c2, w=w.cv3)

        def call(self, inputs):
            return self.cv3(self.cv2(self.upsample(self.cv1(inputs))))

I cannot find any mismatched numbers.

@zldrobit commented Sep 1, 2022

@AyushExel

Channels are last in TF format, right? So the input to the proto layer should be 1x80x80x64 instead of 1x64x80x80, but it still complains about dimensions.

I think something is messing up the weight transfer; I'm jumping into this rabbit hole.

From the second screenshot, I noticed that 36864 = 147456 / 4 = 3 x 3 x 64 x 256 / 4. Maybe this is helpful for locating the problem.
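The arithmetic above checks out (this is just a scratch verification, not thread code): 147456 is the element count of a 3x3 conv kernel mapping 64 channels to 256, and the mismatched count in the error is exactly a quarter of it.

```python
# 3x3 kernel, 64 input channels, 256 output channels -> 147456 weights;
# 36864 from the error message is exactly one quarter of that.
kernel_elems = 3 * 3 * 64 * 256
assert kernel_elems == 147456
assert kernel_elems // 4 == 36864
```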

@glenn-jocher commented Sep 2, 2022

@zldrobit thanks for taking a look! The new v6.3 segmentation models are temporarily in the v6.2 assets. They are just finishing training now. I've uploaded yolov5s-seg.pt here:
https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5s-seg.pt

@zldrobit commented Sep 3, 2022

@glenn-jocher glad to help! I sent a PR to address the TF export problem for the segmentation model. After that, I tried to run segment/predict.py with the TFLite model but failed. TFLite inference in DetectMultiBackend() assumes TF/TFLite models have only one output, which conflicts with the segmentation model's multiple outputs. IMHO, running the YOLOv5 TF segmentation model may take some more development and refactoring effort.

Linked PR: Fix TF/TFLite export for segmentation model
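As a hedged illustration of the single-output assumption just described (illustrative shapes and a made-up helper, not the actual DetectMultiBackend code): detection models return one tensor, while segmentation models also return a proto tensor, so unconditionally taking `outputs[0]` drops the masks. A shape-based split could look like:

```python
def split_outputs(outputs):
    """Return (pred, proto) from a list of raw model output records.

    Assumes detections are 3-D (e.g. 1x25200x117 for a 640x640
    yolov5s-seg export) and protos are 4-D (e.g. 1x160x160x32 in
    TF's channels-last layout); shapes are illustrative only.
    """
    if len(outputs) == 1:
        return outputs[0], None  # plain detection/classification model
    pred = next(o for o in outputs if len(o["shape"]) == 3)
    proto = next(o for o in outputs if len(o["shape"]) == 4)
    return pred, proto

# Order-independent: works whichever output the runtime lists first.
outs = [{"shape": (1, 160, 160, 32)}, {"shape": (1, 25200, 117)}]
pred, proto = split_outputs(outs)
assert pred["shape"] == (1, 25200, 117)
assert proto["shape"] == (1, 160, 160, 32)
```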
@AyushExel (Owner, author)

@zldrobit thanks! I think your last PR that I just merged solves this problem?

@zldrobit commented Sep 3, 2022

@AyushExel my pleasure! Yes, I can confirm that after the PR, the instance_seg_tf branch can export a YOLOv5 TF/TFLite segmentation model.

@glenn-jocher

@zldrobit awesome thanks for the help!

@AyushExel alright, let's merge this into instance_seg and I can debug a bit there.

@glenn-jocher commented Sep 3, 2022

@zldrobit yes, it looks like DetectMultiBackend is expecting all models to output a single np.array/torch.tensor, but Segmentation uses lists/tuples, so we need some conditional logic in DetectMultiBackend to handle SegmentationModels specially, and/or we should convert ClassificationModel and DetectionModel to also output a list/tuple of length 1 to unify all the YOLOv5 models.
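One way to sketch the unification idea (hypothetical helper, not repo code): normalize every backend's output to a tuple, so segmentation's multi-output case and the single-tensor detection/classification case flow through the same downstream path:

```python
def as_tuple(y):
    """Wrap a model output so callers always receive a tuple,
    whether the backend produced one tensor or several."""
    if isinstance(y, (list, tuple)):
        return tuple(y)
    return (y,)

# A single detection tensor and a (pred, proto) segmentation pair
# both normalize to tuples of length >= 1:
assert as_tuple("pred") == ("pred",)
assert as_tuple(("pred", "proto")) == ("pred", "proto")
assert as_tuple(["pred"]) == ("pred",)
```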

@AyushExel AyushExel merged commit 0a039c3 into instance_seg Sep 4, 2022