Skip to content

Commit

Permalink
fix(torch): concurrent_predict was always true
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob committed Jul 17, 2023
1 parent 5eb7890 commit edb28c1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
7 changes: 2 additions & 5 deletions src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ namespace dd
"GPU is not available, service could not be created");
}

_concurrent_predict = mllib_dto->concurrent_predict;
std::vector<int> gpuids = mllib_dto->gpuid->_ids;

if (mllib_dto->nclasses != 0)
Expand Down Expand Up @@ -313,11 +314,6 @@ namespace dd
_multi_label = true;
}

if (mllib_dto->concurrent_predict)
{
_concurrent_predict = true;
}

if (_template == "bert")
{
if (!self_supervised.empty())
Expand Down Expand Up @@ -1380,6 +1376,7 @@ namespace dd
{
// concurrent calls can use more memory on gpu than initially expected
lock = std::make_unique<std::lock_guard<std::mutex>>(_net_mutex);
this->_logger->info("Locking torch service for predict");
}
oatpp::Object<DTO::ServicePredict> predict_dto;

Expand Down
20 changes: 20 additions & 0 deletions src/backends/torch/torchmodule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,11 @@ namespace dd
const torch::Device &device,
const oatpp::Object<DTO::ServicePredict> &pred_dto);

template void TorchModule::create_native_template(
const std::string &tmpl, const APIData &lib_ad,
const ImgTorchInputFileConn &inputc, const TorchModel &tmodel,
const torch::Device &device);

template void TorchModule::post_transform(
const std::string tmpl, const APIData &template_params,
const VideoTorchInputFileConn &inputc, const TorchModel &tmodel,
Expand All @@ -768,6 +773,11 @@ namespace dd
const torch::Device &device,
const oatpp::Object<DTO::ServicePredict> &pred_dto);

template void TorchModule::create_native_template(
const std::string &tmpl, const APIData &lib_ad,
const VideoTorchInputFileConn &inputc, const TorchModel &tmodel,
const torch::Device &device);

template void TorchModule::post_transform(
const std::string tmpl, const APIData &template_params,
const TxtTorchInputFileConn &inputc, const TorchModel &tmodel,
Expand All @@ -784,6 +794,11 @@ namespace dd
const torch::Device &device,
const oatpp::Object<DTO::ServicePredict> &pred_dto);

template void TorchModule::create_native_template(
const std::string &tmpl, const APIData &lib_ad,
const TxtTorchInputFileConn &inputc, const TorchModel &tmodel,
const torch::Device &device);

template void TorchModule::post_transform(
const std::string tmpl, const APIData &template_params,
const CSVTSTorchInputFileConn &inputc, const TorchModel &tmodel,
Expand All @@ -799,4 +814,9 @@ namespace dd
const CSVTSTorchInputFileConn &inputc, const TorchModel &tmodel,
const torch::Device &device,
const oatpp::Object<DTO::ServicePredict> &pred_dto);

template void TorchModule::create_native_template(
const std::string &tmpl, const APIData &lib_ad,
const CSVTSTorchInputFileConn &inputc, const TorchModel &tmodel,
const torch::Device &device);
}

0 comments on commit edb28c1

Please sign in to comment.