Skip to content

Commit

Permalink
feat(predict): add best_bbox for torch, trt, caffe, ncnn backend
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed Jun 9, 2021
1 parent a0da6f7 commit 7890401
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 3 deletions.
3 changes: 2 additions & 1 deletion docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -1052,7 +1052,7 @@ sentences | bool | yes | false
characters | bool | yes | false | character-level text processing, as opposed to word-based text processing
sequence | int | yes | N/A | for character-level text processing, the fixed length of each sample of text
read_forward | bool | yes | false | for character-level text processing, whether to read content from left to right
alphabet | string | yes | abcdefghijklmnopqrstuvwxyz 0123456789 ,;.!?:'"/\\\ \|_@#$%^&*~\`+-=<>()[]{} | for character-level text processing, the alphabet of recognized symbols
alphabet | string | yes | abcdefghijklmnopqrstuvwxyz 0123456789 ,;.!?:'"/\\\ \|\_@#$%^&\*~\`+-=<>()[]{} | for character-level text processing, the alphabet of recognized symbols
sparse | bool | yes | false | whether to use sparse features (and sparce computations with Caffe for huge memory savings, for xgboost use `svm` connector instead)

- SVM (`svm`)
Expand All @@ -1069,6 +1069,7 @@ network | object | yes | empty | Output netw
measure | array | yes | empty | Output measures requested, from `acc`: accuracy, `acc-k`: top-k accuracy, replace k with number (e.g. `acc-5`), `f1`: f1, precision and recall, `mcll`: multi-class log loss, `auc`: area under the curve, `cmdiag`: diagonal of confusion matrix (requires `f1`), `cmfull`: full confusion matrix (requires `f1`), `mcc`: Matthews correlation coefficient
confidence_threshold | double | yes | 0.0 | only returns classifications or detections with probability strictly above threshold
bbox | bool | yes | false | returns bounding boxes around object when using an object detection model, such that (xmin,ymax) yields the top left corner and (xmax,ymin) the lower right corner of a box.
best_bbox | int | yes | -1 | if > 0, returns only the `best_bbox` with highest confidence
regression | bool | yes | false | whether the output of a model is a regression target (i.e. vector of one or more floats)
rois | string | yes | empty | set the ROI layer from which to extract the features from bounding boxes. Both the boxes and features ar returned when using an object detection model with ROI pooling layer
index | bool | yes | false | whether to index the output from prediction, for similarity search
Expand Down
6 changes: 6 additions & 0 deletions src/backends/caffe/caffelib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2802,6 +2802,7 @@ namespace dd
int blank_label = -1;
std::string roi_layer;
double confidence_threshold = 0.0;
int best_bbox = -1;
if (ad_output.has("confidence_threshold"))
{
try
Expand All @@ -2816,6 +2817,8 @@ namespace dd
ad_output.get("confidence_threshold").get<int>());
}
}
if (ad_output.has("best_bbox"))
best_bbox = ad_output.get("best_bbox").get<int>();

if (inputc._timeserie
&& ad.getobj("parameters").getobj("input").has("timesteps"))
Expand Down Expand Up @@ -3247,6 +3250,9 @@ namespace dd
int curi = -1;
while (true && k < results_height)
{
if (best_bbox > 0
&& bboxes.size() >= static_cast<size_t>(best_bbox))
break;
if (outr[0] == -1)
{
// skipping invalid detection
Expand Down
4 changes: 4 additions & 0 deletions src/backends/ncnn/ncnnlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ namespace dd
for (int i = 0; i < inputc._out.at(b).h; i++)
{
const float *values = inputc._out.at(b).row(i);
if (output_params->best_bbox > 0
&& bboxes.size()
>= static_cast<size_t>(output_params->best_bbox))
break;
if (values[1] < output_params->confidence_threshold)
break; // output is sorted by confidence

Expand Down
5 changes: 5 additions & 0 deletions src/backends/tensorrt/tensorrtlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,11 @@ namespace dd
int curi = -1;
while (true && k < results_height)
{
if (output_params->best_bbox > 0
&& bboxes.size() >= static_cast<size_t>(
output_params->best_bbox))
break;

if (outr[0] == -1)
{
// skipping invalid detection
Expand Down
13 changes: 12 additions & 1 deletion src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1131,6 +1131,7 @@ namespace dd
bool bbox = _bbox;
double confidence_threshold = 0.0;
int best_count = _nclasses;
int best_bbox = -1;

if (params.has("mllib"))
{
Expand Down Expand Up @@ -1188,7 +1189,13 @@ namespace dd
}
}
if (output_params.has("best"))
best_count = output_params.get("best").get<int>();
{
best_count = output_params.get("best").get<int>();
}
if (output_params.has("best_bbox"))
{
best_bbox = output_params.get("best_bbox").get<int>();
}

bool lstm_continuation = false;
TInputConnectorStrategy inputc(this->_inputc);
Expand Down Expand Up @@ -1388,6 +1395,10 @@ namespace dd

for (int j = 0; j < labels_tensor.size(0); ++j)
{
if (best_bbox > 0
&& bboxes.size() >= static_cast<size_t>(best_bbox))
break;

double score = score_acc[j];
if (score < confidence_threshold)
continue;
Expand Down
1 change: 1 addition & 0 deletions src/dto/output_connector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ namespace dd
DTO_FIELD(Boolean, ctc) = false;
DTO_FIELD(Float32, confidence_threshold) = 0.0;
DTO_FIELD(Int32, best);
DTO_FIELD(Int32, best_bbox) = -1;

/* ncnn */
DTO_FIELD(Int32, blank_label) = -1;
Expand Down
18 changes: 17 additions & 1 deletion tests/ut-torchapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ TEST(torchapi, service_predict_object_detection)
"\"width\":224},\"output\":{\"bbox\":true, "
"\"confidence_threshold\":0.8}},\"data\":[\""
+ detect_repo + "cat.jpg\"]}";
// TODO changer image test ?

joutstr = japi.jrender(japi.service_predict(jpredictstr));
JDoc jd;
Expand All @@ -274,6 +273,23 @@ TEST(torchapi, service_predict_object_detection)
&& bbox["ymax"].GetDouble() > 300);
// Check confidence threshold
ASSERT_TRUE(preds[preds.Size() - 1]["prob"].GetDouble() >= 0.8);

// best
jpredictstr = "{\"service\":\"detectserv\",\"parameters\":{"
"\"input\":{\"height\":224,"
"\"width\":224},\"output\":{\"bbox\":true, "
"\"best_bbox\":3}},\"data\":[\""
+ detect_repo + "cat.jpg\"]}";
joutstr = japi.jrender(japi.service_predict(jpredictstr));
std::cout << "joutstr=" << joutstr << std::endl;
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());

ASSERT_TRUE(!jd.HasParseError());
ASSERT_EQ(200, jd["status"]["code"]);
ASSERT_TRUE(jd["body"]["predictions"].IsArray());

auto &preds_best = jd["body"]["predictions"][0]["classes"];
ASSERT_EQ(preds_best.Size(), 3);
}

TEST(torchapi, service_predict_txt_classification)
Expand Down

0 comments on commit 7890401

Please sign in to comment.