Skip to content

Commit

Permalink
feat(torch): add multilabel classification
Browse files Browse the repository at this point in the history
  • Loading branch information
cchadowitz authored and mergify[bot] committed Aug 12, 2022
1 parent 24aa37c commit 90d536e
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,7 @@ gpu | bool | yes | false | Whether to use GPU
gpuid | int or array | yes | 0 | GPU id, use single int for single GPU, `-1` for using all GPUs, and array e.g. `[1,3]` for selecting among multiple GPUs
extract_layer | string | yes | "" | Returns tensor values from intermediate layers. In bert models "hidden_state" allows to extract raw hidden_states values to return as output. If set to 'last', simply returns the tensor values from last layer.
forward_method | string | yes | "" | Executes a custom function from within a traced/JIT model, instead of the standard forward()
multi_label | bool | yes | false | Model outputs an independent score for each class


- XGBoost
Expand Down
13 changes: 12 additions & 1 deletion src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ namespace dd
_bbox = tl._bbox;
_segmentation = tl._segmentation;
_ctc = tl._ctc;
_multi_label = tl._multi_label;
_loss = tl._loss;
_template_params = tl._template_params;
_dtype = tl._dtype;
Expand Down Expand Up @@ -295,6 +296,11 @@ namespace dd
_regression = true;
}

if (mllib_dto->multi_label)
{
_multi_label = true;
}

if (_template == "bert")
{
if (!self_supervised.empty())
Expand Down Expand Up @@ -1500,7 +1506,12 @@ namespace dd
// XXX: why is (!_timeserie) needed here? Aren't _timeserie and
// _classification mutually exclusive?
if (extract_layer.empty() && !_timeserie && _classification)
output = torch::softmax(output, 1).to(cpu);
{
if (_multi_label)
output = torch::sigmoid(output).to(cpu);
else
output = torch::softmax(output, 1).to(cpu);
}
else if (extract_layer.empty() && !_timeserie && ctc)
output = torch::softmax(output, 2).to(cpu);
else
Expand Down
1 change: 1 addition & 0 deletions src/backends/torch/torchlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ namespace dd
bool _bbox = false; /**< select detection type problem */
bool _segmentation = false; /**< select segmentation type problem */
bool _ctc = false; /**< select OCR type problem */
bool _multi_label = false; /**< whether model outputs multiple labels */
std::string _loss = ""; /**< selected loss*/
double _reg_weight
= 1; /**< for detection models, weight for bbox regression loss. */
Expand Down
7 changes: 7 additions & 0 deletions src/dto/mllib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,13 @@ namespace dd
};
DTO_FIELD(UnorderedFields<Any>, template_params);

DTO_FIELD_INFO(multi_label)
{
info->description
= "Model outputs an independent score for each class";
}
DTO_FIELD(Boolean, multi_label) = false;

// Libtorch predict options
DTO_FIELD_INFO(forward_method)
{
Expand Down
40 changes: 40 additions & 0 deletions tests/ut-torchapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,46 @@ TEST(torchapi, service_predict_native_bw)
fileops::remove_dir(native_resnet_repo);
}

TEST(torchapi, service_predict_multi_label)
{
// create service
JsonAPI japi;
std::string sname = "imgserv";
std::string jstr
= "{\"mllib\":\"torch\",\"description\":\"resnet-50\",\"type\":"
"\"supervised\",\"model\":{\"repository\":\""
+ incept_repo
+ "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":"
"224,\"width\":224,\"rgb\":true,\"scale\":0.0039},\"mllib\":{"
"\"nclasses\":1000,\"multi_label\":true}}}";
std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
ASSERT_EQ(created_str, joutstr);

// predict
std::string jpredictstr
= "{\"service\":\"imgserv\",\"parameters\":{\"input\":{\"height\":224,"
"\"width\":224},\"output\":{\"best\":10}},\"data\":"
"[\""
+ incept_repo + "cat.jpg\"]}";

joutstr = japi.jrender(japi.service_predict(jpredictstr));
JDoc jd = JDoc();
std::cout << "joutstr=" << joutstr << std::endl;
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
ASSERT_TRUE(!jd.HasParseError());
ASSERT_EQ(200, jd["status"]["code"]);
ASSERT_TRUE(jd["body"]["predictions"].IsArray());
ASSERT_EQ(jd["body"]["predictions"][0]["classes"].Size(), 10);

// multi_label = no softmax = sums of probs is not 1
double sum = 0.0;
for (size_t i = 0; i < jd["body"]["predictions"][0]["classes"].Size(); i++)
{
sum += jd["body"]["predictions"][0]["classes"][i]["prob"].GetDouble();
}
ASSERT_GT(sum, 1.0);
}

#if !defined(CPU_ONLY)
TEST(torchapi, service_predict_fp16)
{
Expand Down

0 comments on commit 90d536e

Please sign in to comment.