Skip to content

Commit

Permalink
feat(ml): add support for Segformer with the torch backend
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz authored and mergify[bot] committed Dec 28, 2021
1 parent f86b8b8 commit ab03d1d
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 6 deletions.
18 changes: 14 additions & 4 deletions src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1642,8 +1642,13 @@ namespace dd
}
else if (_segmentation)
{
auto out_dict = out_ivalue.toGenericDict();
output = torch_utils::to_tensor_safe(out_dict.at("out"));
if (out_ivalue.isGenericDict())
{
auto out_dict = out_ivalue.toGenericDict();
output = torch_utils::to_tensor_safe(out_dict.at("out"));
}
else
output = torch_utils::to_tensor_safe(out_ivalue);
output = torch::softmax(output, 1);

int imgsize = inputc.width() * inputc.height();
Expand Down Expand Up @@ -1968,8 +1973,13 @@ namespace dd
}
else if (_segmentation)
{
auto out_dict = out_ivalue.toGenericDict();
output = torch_utils::to_tensor_safe(out_dict.at("out"));
if (out_ivalue.isGenericDict())
{
auto out_dict = out_ivalue.toGenericDict();
output = torch_utils::to_tensor_safe(out_dict.at("out"));
}
else
output = torch_utils::to_tensor_safe(out_ivalue);
output = torch::softmax(output, 1);
torch::Tensor target = batch.target.at(0).to(torch::kFloat64);
torch::Tensor segmap
Expand Down
7 changes: 7 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,13 @@ if (USE_TORCH)
"deeplabv3_training_torch.tar.gz"
"deeplabv3_training_torch"
)
# Segformer segmentation model pretrained for torch-backend training tests
# (used by ut-torchapi.cc: service_train_image_segmentation_segformer).
DOWNLOAD_DATASET(
"Torch training Segformer model"
"https://www.deepdetect.com/dd/examples/torch/segformer_training_torch.tar.gz"
"examples/torch"
"segformer_training_torch.tar.gz"
"segformer_training_torch"
)
DOWNLOAD_DATASET(
"Torch BERT classification test model"
"https://www.deepdetect.com/dd/examples/torch/bert_inference_torch.tar.gz"
Expand Down
89 changes: 87 additions & 2 deletions tests/ut-torchapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ static std::string deeplabv3_test_data
static std::string deeplabv3_test_image
= "../examples/torch/deeplabv3_training_torch/CamVid_square/test/"
"Seq05VD_f00330.png";
// Repository of the Segformer training model downloaded via
// tests/CMakeLists.txt (DOWNLOAD_DATASET "Torch training Segformer model").
static std::string segformer_train_repo
= "../examples/torch/segformer_training_torch/";

static std::string resnet50_native_weights
= "../examples/torch/resnet50_native_torch/resnet50.npt";
Expand Down Expand Up @@ -746,7 +748,7 @@ TEST(torchapi, service_train_images)
fileops::remove_dir(resnet50_train_repo + "test_0.lmdb");
}

TEST(torchapi, service_train_image_segmentation)
TEST(torchapi, service_train_image_segmentation_deeplabv3)
{
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
torch::manual_seed(torch_seed);
Expand Down Expand Up @@ -829,7 +831,7 @@ TEST(torchapi, service_train_image_segmentation)
fileops::remove_dir(deeplabv3_train_repo + "test_0.lmdb");
}

TEST(torchapi, service_train_image_segmentation_dice)
TEST(torchapi, service_train_image_segmentation_deeplabv3_dice)
{
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
torch::manual_seed(torch_seed);
Expand Down Expand Up @@ -914,6 +916,89 @@ TEST(torchapi, service_train_image_segmentation_dice)
fileops::remove_dir(deeplabv3_train_repo + "test_0.lmdb");
}

// Trains a Segformer segmentation model with the torch backend, checks the
// reported metrics, runs one predict call, then removes every training
// artifact (checkpoints, solver state, lmdb databases) from the repository.
// NOTE(review): training data, test data, test image and the iteration count
// are reused from the deeplabv3 tests (same dataset) — confirm intended.
TEST(torchapi, service_train_image_segmentation_segformer)
{
  // Make CUDA/cuDNN runs deterministic so the metric thresholds below are
  // stable across runs.
  setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
  torch::manual_seed(torch_seed);
  at::globalContext().setDeterministicCuDNN(true);

  // Create the segmentation service on the Segformer repository.
  JsonAPI japi;
  std::string sname = "imgserv";
  std::string jstr
      = "{\"mllib\":\"torch\",\"description\":\"image\",\"type\":"
        "\"supervised\",\"model\":{\"repository\":\""
        + segformer_train_repo
        + "\"},\"parameters\":{\"input\":{\"connector\":\"image\","
          "\"width\":480,\"height\":480,\"db\":true,\"segmentation\":true},"
          "\"mllib\":{\"nclasses\":"
          "13,\"gpu\":true,\"segmentation\":true}}}";
  std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
  ASSERT_EQ(created_str, joutstr);

  // Train synchronously with augmentation (mirror/rotate/crop/cutout/
  // geometry/noise/distort) and mean-IOU + accuracy measures.
  std::string jtrainstr
      = "{\"service\":\"imgserv\",\"async\":false,\"parameters\":{"
        "\"mllib\":{\"solver\":{\"iterations\":"
        + iterations_deeplabv3 + ",\"base_lr\":" + torch_lr
        + ",\"iter_size\":1,\"solver_type\":\"ADAM\",\"test_"
          "interval\":100},\"net\":{\"batch_size\":4},"
          "\"resume\":false,\"mirror\":true,\"rotate\":true,\"crop_size\":224,"
          "\"cutout\":0.5,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
          "true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
          "\"pad_mode\":1},\"noise\":{\"prob\":0.01},\"distort\":{\"prob\":0."
          "01}},"
          "\"input\":{\"seed\":12345,\"db\":true,\"shuffle\":true,"
          "\"segmentation\":true,\"scale\":0.0039,\"mean\":[0.485,0.456,0.406]"
          ",\"std\":[0.229,0.224,0.225]},"
          "\"output\":{\"measure\":[\"meaniou\",\"acc\"]}},\"data\":[\""
        + deeplabv3_train_data + "\",\"" + deeplabv3_test_data + "\"]}";
  joutstr = japi.jrender(japi.service_train(jtrainstr));
  JDoc jd;
  std::cout << "joutstr=" << joutstr << std::endl;
  jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
  ASSERT_TRUE(!jd.HasParseError());
  ASSERT_EQ(201, jd["status"]["code"]);

  // Sanity bounds only: a few iterations of training, so thresholds are weak.
  ASSERT_TRUE(jd["body"]["measure"]["meanacc"].GetDouble() <= 1) << "accuracy";
  ASSERT_TRUE(jd["body"]["measure"]["meanacc"].GetDouble() >= 0.006)
      << "accuracy good";
  ASSERT_TRUE(jd["body"]["measure"]["meaniou"].GetDouble() <= 1) << "meaniou";

  // Predict on a single test image with the freshly trained model.
  std::string jpredictstr
      = "{\"service\":\"imgserv\",\"parameters\":{"
        "\"input\":{\"height\":480,"
        "\"width\":480,\"scale\":0.0039,\"mean\":[0.485,0.456,0.406],\"std\":["
        "0.229,0.224,0.225]},\"output\":{\"segmentation\":true, "
        "\"confidences\":[\"best\"]}},\"data\":[\""
        + deeplabv3_test_image + "\"]}";

  joutstr = japi.jrender(japi.service_predict(jpredictstr));
  std::cout << "joutstr=" << joutstr << std::endl;
  jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
  ASSERT_TRUE(!jd.HasParseError());
  ASSERT_EQ(200, jd["status"]["code"]);
  ASSERT_TRUE(jd["body"]["predictions"].IsArray());

  // Cleanup: delete checkpoint and solver files left by training.
  std::unordered_set<std::string> lfiles;
  fileops::list_directory(segformer_train_repo, true, false, false, lfiles);
  // const ref: avoids copying every path string (performance-for-range-copy).
  for (const std::string &ff : lfiles)
    {
      if (ff.find("checkpoint") != std::string::npos
          || ff.find("solver") != std::string::npos)
        remove(ff.c_str());
    }
  ASSERT_TRUE(!fileops::file_exists(segformer_train_repo + "checkpoint-"
                                    + iterations_deeplabv3 + ".ptw"));
  ASSERT_TRUE(!fileops::file_exists(segformer_train_repo + "checkpoint-"
                                    + iterations_deeplabv3 + ".pt"));

  // Remove the lmdb databases built for training/testing.
  fileops::clear_directory(segformer_train_repo + "train.lmdb");
  fileops::clear_directory(segformer_train_repo + "test_0.lmdb");
  fileops::remove_dir(segformer_train_repo + "train.lmdb");
  fileops::remove_dir(segformer_train_repo + "test_0.lmdb");
}

TEST(torchapi, service_publish_trained_model)
{
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
Expand Down

0 comments on commit ab03d1d

Please sign in to comment.