-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for TF export #23
Conversation
@AyushExel nice! I think TFProto needs a call() method though no? |
@AyushExel also you should build and run a segmentation model following the detection example below to verify it works, i.e. from tf.py run() function: # TensorFlow model
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image
tf_model = TFModel(cfg='yolov5-seg.yaml', model=model, nc=model.nc, imgsz=imgsz)
_ = tf_model.predict(im) # inference EDIT: TFSegment() forward method needs to be renamed to 'call' in TF terminology |
@glenn-jocher ok. |
@zldrobit can you take a look at this PR please? We are adding Segmentation model support to YOLOv5 in ultralytics#9052, and we need to add two additional modules to tf.py which are a Segment() head and a Proto() module for masks. Segment() is a little tricky as it inherits from Detect() and uses the Detect.forward method in addition to Proto.forward. Lines 92 to 107 in 3b25f3c
Lines 764 to 775 in 3b25f3c
|
@glenn-jocher @AyushExel I am trying to reproduce the error. However, I couldn't find the segmentation model ( EDIT: By comparing Lines 92 to 107 in 3b25f3c
Lines 764 to 775 in 3b25f3c
with Lines 322 to 347 in d2af8e1
, I cannot find any mismatch numbers. |
From the second image, I recognized that 36864 = 147456 / 4 = 3 x 3 x 64 x 256 / 4. Maybe this is helpful to locate the problem. |
@zldrobit thanks for taking a look! The new v6.3 segmentation models are temporarily in the v6.2 assets. They are just finishing training now. I've uploaded yolvo5s-seg.pt here: |
@glenn-jocher glad to help! I sent a PR to address the TF export problem of segmentation model. After that, I tried to run |
Fix TF/TFLite export for segmentation model
@zldrobit thanks! I think your last PR that I just merged solves this problem? |
@AyushExel my pleasure! Yes, I could confirm after the PR, the |
@zldrobit awesome thanks for the help! @AyushExel allright let's merge this into instance_seg and I can debug a bit there. |
@zldrobit yes it looks like DetectMultiBackend is expecting all models to output a single np.array/torch.tensor, but we have lists/tuples used for Segmentation, so we need some conditional logic in DetectMultiBackend to handle SegmentationModels specially, and/or we need to convert ClassificationModel and DetectionModel to output list/tuple also of length 1 to unify all the YOLOv5 models. |
Usage:
Output
![Screenshot from 2022-08-30 13-20-58](https://user-images.githubusercontent.com/15766192/187381789-11dfa02b-7eb2-4b15-8ab1-2cd7d188bf57.png)
We'll need to wait for the official models to complete training before we can run benchmark on exported model