Applying NMS to YOLOv5-exported TFlite Model #160

Open
jansdhillon opened this issue Oct 18, 2023 · 3 comments

Comments

@jansdhillon

Hello All,

I am trying to use this (great) package for a Flutter app I am making, and I have run into a lot of issues using the tflite_model_maker. The best I've been able to get is training YOLOv5 on my dataset and then exporting it to TFLite.

When I export the TFLite model with NMS applied (the --nms option), I get output tensors like these:

  T#561(model/tf.math.multiply/Mul) shape:[1, 25200, 80], type:FLOAT32
  T#562(StatefulPartitionedCall:0) shape:[1, 100, 4], type:FLOAT32
  T#563(StatefulPartitionedCall:1) shape:[1, 100], type:FLOAT32
  T#564(StatefulPartitionedCall:2) shape:[1, 100], type:FLOAT32
  T#565(StatefulPartitionedCall:3) shape:[1], type:INT32

---------------------------------------------------------------
Your TFLite model has ‘1’ signature_def(s).

Signature#0 key: 'serving_default'
- Subgraph: Subgraph#0
- Inputs: 
    'input_1' : T#0
- Outputs: 
    'tf.image.combined_non_max_suppression' : T#562
    'tf.image.combined_non_max_suppression_1' : T#563
    'tf.image.combined_non_max_suppression_2' : T#564
    'tf.image.combined_non_max_suppression_3' : T#565

However, when I load this model into the Flutter example app, I'm warned that this operation is not supported by TFLite.

When I export it without the --nms applied, this is the output shape:
(from Netron)

name: StatefulPartitionedCall:0
tensor: float32[1,25200,85]
location: 532

However, this obviously does not work directly with the output shape defined in this package.

I am no ML expert by any means so I am a bit lost about how to apply NMS to my TFLite model. From what I can understand, the tflite_model_maker package has some sort of postprocessing step to get the desired output, but I have yet to get that working (see tensorflow/tensorflow#62135).

Could someone give me some pointers on how to process the output tensor to use in my Flutter app?

@CaptainDario
Contributor

I do not have any experience with the model maker, so keep that in mind while reading this answer.

I think adding an NMS layer to the model should be possible, but I am not sure how to actually do it. Here are the TF docs for NMS (I guess you need TF Select ops for this to work) and here is a discussion about NMS with YOLO.

A second approach is to understand the output format of YOLOv5 and implement NMS yourself. I implemented NMS for YOLOv5 in Dart; you can take a look if you want, but it is very slow and should be implemented a different way (maybe in C).

@jansdhillon
Author

Hi there, thanks so much for your response. I'll have to assess my options and go from there.

@chathurach

As I understand it, NMS cannot be embedded in TFLite models, so the NMS part is stripped out when you convert the YOLO model to TFLite; hence the warning. The TFLite model gives you every candidate detection, and you have to run NMS yourself after inference.

  1. Loop through the 25200 candidate detections. Each one has 85 values: the first 4 are x, y, width, and height, the 5th is the confidence, and the remaining 80 are the class scores. For each detection, loop through the 80 classes and keep only the bounding boxes with more than 0.1 confidence. This reduces the number of detections drastically.
  2. Then run NMS on the filtered bounding boxes.
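
The filtering in step 1 could look roughly like the sketch below. It assumes the [1, 25200, 85] output has already been copied into a list of rows with the batch dimension removed, and that each row is laid out as x, y, width, height, confidence, then 80 class scores. `Candidate` and `filterCandidates` are hypothetical names made up for this example, not part of any package, and multiplying the box confidence by the best class score is a common YOLOv5 convention rather than something this thread confirms. Whether the coordinates are normalised depends on how the model was exported, so the sketch leaves them untouched.

```dart
// Minimal sketch: drop low-confidence rows from the raw YOLOv5 output.
// `Candidate` and `filterCandidates` are hypothetical names for illustration.

class Candidate {
  Candidate(this.cx, this.cy, this.w, this.h, this.classIndex, this.score);

  // Box centre and size, exactly as emitted by the model.
  final double cx, cy, w, h;
  // Index (0-based) of the best-scoring class.
  final int classIndex;
  // Box confidence multiplied by the best class score (assumed convention).
  final double score;
}

/// [rows] holds one list per candidate: x, y, w, h, confidence, class scores.
List<Candidate> filterCandidates(List<List<double>> rows,
    {double threshold = 0.1}) {
  final kept = <Candidate>[];
  for (final row in rows) {
    final confidence = row[4];
    // Class scores are at most 1, so the product cannot exceed `confidence`.
    if (confidence < threshold) continue;

    // Find the best class among the scores that follow the first 5 values.
    var bestClass = 0;
    var bestScore = row[5];
    for (var c = 6; c < row.length; c++) {
      if (row[c] > bestScore) {
        bestScore = row[c];
        bestClass = c - 5;
      }
    }

    final score = confidence * bestScore;
    if (score > threshold) {
      kept.add(Candidate(row[0], row[1], row[2], row[3], bestClass, score));
    }
  }
  return kept;
}
```

The surviving candidates would still need to be converted to corner-style rectangles before being passed to an NMS routine.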

Here is an NMS example in Dart:

import 'dart:ui';

import 'package:collection/collection.dart' as p;

import 'recognitions.dart';

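/// Per-class non-maximum suppression (ported from a Java implementation).
///
/// [list] holds the candidate detections, [labels] the class names, and
/// [ioU] the overlap threshold at or above which a lower-ranked box is dropped.
List<Recognition> nms(List<Recognition> list, var labels, double ioU) {
  final List<Recognition> nmsList = <Recognition>[];

  for (int k = 0; k < labels.length; k++) {
    // 1. Collect the detections of class k in a priority queue.
    // No comparator is passed, so HeapPriorityQueue relies on Recognition
    // implementing Comparable; it must rank higher confidence first.
    final p.PriorityQueue<Recognition> pq = p.HeapPriorityQueue<Recognition>();
    for (int i = 0; i < list.length; ++i) {
      if (list[i].label == labels[k]) {
        // Classes are matched by label string rather than by class index.
        pq.add(list[i]);
      }
    }

    // 2. Do non-maximum suppression.
    while (pq.length > 0) {
      // toList() returns the elements in priority order, so the first
      // element is the best remaining detection; keep it.
      final List<Recognition> detections = pq.toList();
      final Recognition max = detections[0];
      nmsList.add(max);
      pq.clear();
      // Re-queue only the boxes that overlap the kept box by less than ioU.
      for (int j = 1; j < detections.length; j++) {
        final Recognition detection = detections[j];
        final Rect b = detection.location!;
        final Rect x = max.location!;
        if (boxIou(x, b) < ioU) {
          pq.add(detection);
        }
      }
    }
  }

  return nmsList;
}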

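double boxIou(Rect a, Rect b) {
  final double union = boxUnion(a, b);
  // Two zero-area boxes would otherwise give 0 / 0 (NaN).
  return union <= 0 ? 0.0 : boxIntersection(a, b) / union;
}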

double boxIntersection(Rect a, Rect b) {
  double w = overlap((a.left + a.right) / 2, a.right - a.left,
      (b.left + b.right) / 2, b.right - b.left);
  double h = overlap((a.top + a.bottom) / 2, a.bottom - a.top,
      (b.top + b.bottom) / 2, b.bottom - b.top);
  if ((w < 0) || (h < 0)) {
    return 0;
  }
  double area = (w * h);
  return area;
}

double boxUnion(Rect a, Rect b) {
  double i = boxIntersection(a, b);
  double u = ((((a.right - a.left) * (a.bottom - a.top)) +
          ((b.right - b.left) * (b.bottom - b.top))) -
      i);
  return u;
}

double overlap(double x1, double w1, double x2, double w2) {
  double l1 = (x1 - (w1 / 2));
  double l2 = (x2 - (w2 / 2));
  double left = ((l1 > l2) ? l1 : l2);
  double r1 = (x1 + (w1 / 2));
  double r2 = (x2 + (w2 / 2));
  double right = ((r1 < r2) ? r1 : r2);
  return right - left;
}
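
As a sanity check on the IoU arithmetic above, here is a small self-contained sketch that computes the same quantity directly from the box corners. `Box` is a hypothetical stand-in for `Rect` from `dart:ui` so the snippet runs without Flutter, and `iou` is an illustrative name, not one of the functions in the example.

```dart
import 'dart:math' as math;

// Hypothetical stand-in for dart:ui's Rect so the sketch runs without Flutter.
class Box {
  const Box(this.left, this.top, this.right, this.bottom);

  final double left, top, right, bottom;

  double get area => (right - left) * (bottom - top);
}

/// Intersection over union of two axis-aligned boxes.
double iou(Box a, Box b) {
  // Width and height of the overlapping region; negative means no overlap.
  final w = math.min(a.right, b.right) - math.max(a.left, b.left);
  final h = math.min(a.bottom, b.bottom) - math.max(a.top, b.top);
  final inter = (w <= 0 || h <= 0) ? 0.0 : w * h;

  // Union = sum of both areas minus the part counted twice.
  final union = a.area + b.area - inter;
  return union <= 0 ? 0.0 : inter / union;
}
```

For example, `iou(const Box(0, 0, 2, 2), const Box(1, 1, 3, 3))` overlaps in the 1x1 square between (1, 1) and (2, 2), so the intersection is 1, the union is 4 + 4 - 1 = 7, and the result is 1/7, roughly 0.143. With a typical threshold such as 0.5, these two boxes would both survive NMS.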
