Skip to content

Commit

Permalink
fix(chain): empty predictions were too empty
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed Sep 7, 2023
1 parent 27b72b3 commit 57bed0b
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 11 deletions.
11 changes: 7 additions & 4 deletions src/backends/tensorrt/tensorrtlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,13 @@ namespace dd
TMLModel>::clear_mllib(const APIData &ad)
{
(void)ad;
cudaFree(_buffers.data()[_inputIndex]);
cudaFree(_buffers.data()[_outputIndex0]);
if (_bbox)
cudaFree(_buffers.data()[_outputIndex1]);
if (!_buffers.empty())
{
cudaFree(_buffers.at(_inputIndex));
cudaFree(_buffers.at(_outputIndex0));
if (_bbox)
cudaFree(_buffers.at(_outputIndex1));
}

// remove compiled model files.
std::vector<std::string> extensions
Expand Down
10 changes: 4 additions & 6 deletions src/services.h
Original file line number Diff line number Diff line change
Expand Up @@ -842,9 +842,8 @@ namespace dd
{
chain_logger->info("[" + std::to_string(chain_pos)
+ "] / no result from prediction");
cdata.add_model_data(
pred_id,
DTO::PredictBody::createShared()); // store empty model output
cdata.add_model_data(pred_id,
pred_dto); // store empty model output
return 1;
}

Expand Down Expand Up @@ -1170,9 +1169,8 @@ namespace dd
{
chain_logger->info("[" + std::to_string(chain_pos)
+ "] / no result from prediction");
cdata.add_model_data(
pred_id,
DTO::PredictBody::createShared()); // store empty model output
cdata.add_model_data(pred_id,
pred_dto); // store empty model output
return 1;
}

Expand Down
103 changes: 102 additions & 1 deletion tests/ut-chain.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,30 @@ TEST(chain, chain_torch_detection_classification)
.GetString(),
std::string("n02086079 Pekinese, Pekingese, Peke"));

// chain predict without detection
jchainstr
= "{\"chain\":{\"name\":\"chain\",\"calls\":["
"{\"service\":\""
+ detect_sname
+ "\",\"parameters\":{\"input\":{\"keep_orig\":true},\"output\":{"
"\"bbox\":true,\"confidence_threshold\":0.9999}},\"data\":[\""
+ uri1
+ "\"]},"
"{\"id\":\"crop\",\"action\":{\"type\":\"crop\",\"parameters\":{"
"\"padding_ratio\":0.05}}},{\"service\":\""
+ classif_sname
+ "\",\"parent_id\":\"crop\",\"parameters\":{\"output\":{\"best\":1}}}"
"]}}";
joutstr = japi.jrender(japi.service_chain("chain", jchainstr));
jd = JDoc();
std::cout << "joutstr=" << joutstr << std::endl;
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
ASSERT_TRUE(!jd.HasParseError());
ASSERT_EQ(200, jd["status"]["code"]);
ASSERT_TRUE(jd["body"]["predictions"].IsArray());
ASSERT_EQ(jd["body"]["predictions"].Size(), 1);
ASSERT_EQ(jd["body"]["predictions"][0]["classes"].Size(), 0);

// multiple models (tree)
std::string classif2_sname = "classif2";
jstr = "{\"mllib\":\"torch\",\"description\":\"squeezenet\",\"type\":"
Expand Down Expand Up @@ -181,7 +205,6 @@ TEST(chain, chain_torch_detection_classification)
// cleanup
fileops::remove_file(torch_detect_repo, "model.json");
}

#endif

#ifdef USE_CAFFE
Expand Down Expand Up @@ -486,4 +509,82 @@ TEST(chain, chain_trt_detection_gan)
+ get_trt_archi() + "_bs1"));
}

// Test internal call without json
TEST(chain, chain_trt_dto)
{
JsonAPI japi;
std::string detect_sname = "detect";
std::string jstr
= "{\"mllib\":\"tensorrt\",\"description\":\"yolox\","
"\"type\":\"supervised\",\"model\":{\"repository\":\""
+ trt_detect_repo
+ "\"},\"parameters\":{\"input\":{\"connector\":"
"\"image\",\"height\":640,\"width\":640},\"mllib\":{"
"\"maxWorkspaceSize\":256,\"gpuid\":0,"
"\"template\":\"yolox\",\"nclasses\":81,\"datatype\":\"fp16\"}}}";
std::string joutstr = japi.jrender(japi.service_create(detect_sname, jstr));
ASSERT_EQ(created_str, joutstr);

std::string gan_sname = "gan";
jstr = "{\"mllib\":\"tensorrt\",\"description\":\"gan\",\"type\":"
"\"supervised\",\"model\":{\"repository\":\""
+ trt_gan_repo
+ "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":"
"360,\"width\":360,\"rgb\":true,\"scale\":0.0039,\"mean\":[0.5, "
"0.5,0.5],\"std\":[0.5,0.5,0.5]},\"mllib\":{\"maxBatchSize\":1,"
"\"maxWorkspaceSize\":256,\"gpuid\":0,\"datatype\":\"fp16\"}}}";
joutstr = japi.jrender(japi.service_create(gan_sname, jstr));
ASSERT_EQ(created_str, joutstr);

// chain call with no predictions
std::string uri1 = trt_gan_repo + "/horse_1024.jpg";
auto input_dto = oatpp::Object<DTO::ServiceChain>::createShared();
input_dto->chain = oatpp::Object<DTO::Chain>::createShared();

auto call1 = oatpp::Object<DTO::ChainCall>::createShared();
call1->service = detect_sname;
call1->parameters->input->keep_orig = true;
call1->parameters->output->bbox = true;
call1->parameters->output->confidence_threshold = 0.9999;
call1->data->push_back(uri1);
input_dto->chain->calls->push_back(call1);

auto call2 = oatpp::Object<DTO::ChainCall>::createShared();
call2->id = "crop";
call2->action = oatpp::Object<DTO::ChainAction>::createShared();
call2->action->type = "crop";
call2->action->parameters->padding_ratio = 0.05;
input_dto->chain->calls->push_back(call2);

auto call3 = oatpp::Object<DTO::ChainCall>::createShared();
call3->service = gan_sname;
call3->parent_id = "crop";
call3->parameters->mllib->extract_layer = "last";
call3->parameters->output->image = true;
input_dto->chain->calls->push_back(call3);

auto chain_out = japi.chain(input_dto, "chain");
JDoc jdoc;
oatpp_utils::dtoToJDoc(chain_out, jdoc);
std::cout << dd_utils::jrender(jdoc) << std::endl;

ASSERT_EQ(chain_out->predictions->size(), 1);
// PredictClass -> only one model, so the Predict DTO is returned without
// modifications
ASSERT_EQ((*chain_out->predictions->at(0))["classes"]
.retrieve<oatpp::Vector<oatpp::Object<DTO::PredictClass>>>()
->size(),
0);

jstr = "{\"clear\":\"lib\"}";
joutstr = japi.jrender(japi.service_delete(detect_sname, jstr));
ASSERT_EQ(ok_str, joutstr);
joutstr = japi.jrender(japi.service_delete(gan_sname, jstr));
ASSERT_EQ(ok_str, joutstr);
ASSERT_TRUE(!fileops::file_exists(trt_detect_repo + "/TRTengine_arch"
+ get_trt_archi() + "_bs1"));
ASSERT_TRUE(!fileops::file_exists(trt_gan_repo + "/TRTengine_arch"
+ get_trt_archi() + "_bs1"));
}

#endif

0 comments on commit 57bed0b

Please sign in to comment.