feat(torch): Added param disable_concurrent_predict
This parameter helps manage GPU memory when multiple external API callers are in use and memory management cannot be done externally. Not in use by default: `concurrent_predict` defaults to `true`, so predict calls are only serialized when it is explicitly set to `false`.
cchadowitz authored and Bycob committed Jul 3, 2023
1 parent 0e516bb commit 71cb66a
Showing 4 changed files with 31 additions and 4 deletions.
9 changes: 5 additions & 4 deletions docs/api.md
@@ -1119,12 +1119,13 @@ test_batch_size | int | yes | N/A | Prediction batch size (the server
- Torch

Parameter | Type | Optional | Default | Description
--------- | ---- | -------- | ------- | -----------
gpu | bool | yes | false | Whether to use GPU
gpuid | int or array | yes | 0 | GPU id, use single int for single GPU, `-1` for using all GPUs, and array e.g. `[1,3]` for selecting among multiple GPUs
extract_layer | string | yes | "" | Returns tensor values from intermediate layers. In BERT models, "hidden_state" extracts the raw hidden_states values to return as output. If set to 'last', simply returns the tensor values from the last layer.
forward_method | string | yes | "" | Executes a custom function from within a traced/JIT model, instead of the standard forward()
multi_label | bool | yes | false | Model outputs an independent score for each class
concurrent_predict | bool | yes | true | Enable/disable concurrent predict for the model


- XGBoost
11 changes: 11 additions & 0 deletions src/backends/torch/torchlib.cc
@@ -83,6 +83,7 @@ namespace dd
_segmentation = tl._segmentation;
_ctc = tl._ctc;
_multi_label = tl._multi_label;
_concurrent_predict = tl._concurrent_predict;
_loss = tl._loss;
_template_params = tl._template_params;
_dtype = tl._dtype;
@@ -312,6 +313,11 @@ namespace dd
_multi_label = true;
}

if (mllib_dto->concurrent_predict)
{
_concurrent_predict = true;
}

if (_template == "bert")
{
if (!self_supervised.empty())
@@ -1369,6 +1375,11 @@
int TorchLib<TInputConnectorStrategy, TOutputConnectorStrategy,
TMLModel>::predict(const APIData &ad_in, APIData &out)
{
std::unique_ptr<std::lock_guard<std::mutex>> lock;
if (!_concurrent_predict)
{
// concurrent calls can use more memory on gpu than initially expected
lock = std::make_unique<std::lock_guard<std::mutex>>(_net_mutex);
}
oatpp::Object<DTO::ServicePredict> predict_dto;

// XXX: until everything is DTO, we consider the two cases:
8 changes: 8 additions & 0 deletions src/backends/torch/torchlib.h
@@ -112,10 +112,18 @@ namespace dd
bool _segmentation = false; /**< select segmentation type problem */
bool _ctc = false; /**< select OCR type problem */
bool _multi_label = false; /**< whether model outputs multiple labels */
bool _concurrent_predict = true; /**< allow concurrent predicts */
std::string _loss = ""; /**< selected loss*/
double _reg_weight
= 1; /**< for detection models, weight for bbox regression loss. */

std::mutex
_net_mutex; /**< mutex around net, e.g. no concurrent predict calls as
it can use more gpu memory than initially expected.
Use batches instead.
This is only used if concurrent_predict is
disabled. */

APIData _template_params; /**< template parameters, for recurrent and
native models*/

7 changes: 7 additions & 0 deletions src/dto/mllib.hpp
@@ -183,6 +183,13 @@ namespace dd
}
DTO_FIELD(Boolean, multi_label) = false;

DTO_FIELD_INFO(concurrent_predict)
{
info->description
= "Enable/disable concurrent predict for the model";
}
DTO_FIELD(Boolean, concurrent_predict) = true;

// Libtorch predict options
DTO_FIELD_INFO(forward_method)
{
