Skip to content

Commit

Permalink
feat(torch): allow multigpu for traced models
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed Sep 23, 2022
1 parent 101407b commit 6b3b9c0
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 13 deletions.
35 changes: 28 additions & 7 deletions src/backends/torch/torchmodule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@ namespace dd
// reload params after finalize
graph_model_load(tmodel);
}
to(_device);

if (_require_linear_head && !_linear_head)
{
Expand All @@ -239,7 +238,6 @@ namespace dd
setup_linear_head(_nclasses,
const_cast<TInputConnectorStrategy &>(inputc)
.get_input_example(device));
_linear_head->to(_device);
}
catch (std::exception &e)
{
Expand All @@ -255,14 +253,15 @@ namespace dd
const_cast<TInputConnectorStrategy &>(inputc)
.get_input_example(device),
inputc._alphabet_size);
_crnn_head->to(_device);
}
catch (std::exception &e)
{
throw MLLibInternalException(std::string("Libtorch error: ")
+ e.what());
}
}

to(device);
}

template <class TInputConnectorStrategy>
Expand Down Expand Up @@ -560,17 +559,39 @@ namespace dd
/**
 * \brief Builds a copy of this module whose sub-modules live on `device`.
 *
 * Native modules, traced (TorchScript) modules and the optional linear /
 * CRNN heads are each cloned individually; the copy is then moved to
 * `device` as a whole.
 *
 * \param device torch device that must host the cloned module
 * \return shared pointer to the cloned module
 * \throws MLLibBadParamException if the module holds a graph model, which
 *         cannot be cloned
 */
std::shared_ptr<TorchModule> TorchModule::clone(torch::Device device)
{
  // Graph models are the only non cloneable kind: reject them before any
  // copy is made so that no per-device clone is built just to be thrown
  // away.
  if (_graph)
    {
      throw MLLibBadParamException("MultiGPU is not supported on non "
                                   "cloneable models (graph models)");
    }

  // Copy-construct from *this, then replace every sub-module of the copy
  // with its own clone below.
  auto cloned = std::make_shared<TorchModule>(*this);

  if (_native)
    {
      cloned->_native
          = std::dynamic_pointer_cast<NativeModule>(_native->clone(device));
    }
  if (_traced)
    {
      cloned->_traced
          = std::make_shared<torch::jit::script::Module>(_traced->clone());
      cloned->_traced->to(device);
      // NOTE(review): presumably detaches the cloned parameters from the
      // source module and re-enables gradients so the clone can be trained
      // on its own device -- confirm against the multi-GPU training loop.
      for (auto param : cloned->_traced->parameters())
        {
          param.detach_().requires_grad_();
        }
    }
  if (_linear_head)
    {
      cloned->_linear_head
          = std::dynamic_pointer_cast<torch::nn::LinearImpl>(
              _linear_head->clone(device));
    }
  if (_crnn_head)
    {
      cloned->_crnn_head = std::dynamic_pointer_cast<CRNNHeadImpl>(
          _crnn_head->clone(device));
    }

  cloned->to(device);
  return cloned;
}

Expand Down
104 changes: 98 additions & 6 deletions tests/ut-torchapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2008,8 +2008,8 @@ TEST(torchapi, service_train_object_detection_yolox)
+ detect_train_repo_yolox
+ "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":"
"640,\"width\":640,\"rgb\":true,\"bbox\":true},"
"\"mllib\":{\"template\":\"yolox\",\"gpu\":true,\"nclasses\":"
"2}}}";
"\"mllib\":{\"template\":\"yolox\",\"gpu\":true,"
"\"nclasses\":2}}}";

std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
ASSERT_EQ(created_str, joutstr);
Expand All @@ -2028,8 +2028,7 @@ TEST(torchapi, service_train_object_detection_yolox)
"\"cutout\":0.1,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
"\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
"\"prob\":0."
"01}},\"input\":{\"seed\":12347,\"db\":true,"
"\"prob\":0.01}},\"input\":{\"seed\":12347,\"db\":true,"
"\"shuffle\":true},\"output\":{\"measure\":[\"map\"]}},\"data\":[\""
+ fasterrcnn_train_data + "\",\"" + fasterrcnn_test_data + "\"]}";

Expand Down Expand Up @@ -2091,6 +2090,101 @@ TEST(torchapi, service_train_object_detection_yolox)
fileops::remove_dir(detect_train_repo_yolox + "test_0.lmdb");
}

// Trains the yolox template through the JSON API with "gpuid":[0,1], then
// checks the reported measures, runs one prediction and cleans the model
// repository.
// NOTE(review): presumably needs at least two visible CUDA devices to run --
// confirm how the CI selects this test.
TEST(torchapi, service_train_object_detection_yolox_multigpu)
{
  // Deterministic setup: cuBLAS workspace config, torch seed, cuDNN flag.
  setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
  torch::manual_seed(torch_seed);
  at::globalContext().setDeterministicCuDNN(true);

  JsonAPI japi;
  std::string sname = "detectserv";
  std::string jstr
      = "{\"mllib\":\"torch\",\"description\":\"yolox\",\"type\":"
        "\"supervised\",\"model\":{\"repository\":\""
        + detect_train_repo_yolox
        + "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":"
          "640,\"width\":640,\"rgb\":true,\"bbox\":true},"
          "\"mllib\":{\"template\":\"yolox\",\"gpu\":true,\"gpuid\":[0,1],"
          "\"nclasses\":2}}}";

  std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
  ASSERT_EQ(created_str, joutstr);

  // Train
  std::string jtrainstr
      = "{\"service\":\"detectserv\",\"async\":false,\"parameters\":{"
        "\"mllib\":{\"solver\":{\"iterations\":"
        + iterations_detection + ",\"base_lr\":" + torch_lr
        + ",\"iter_size\":2,\"solver_"
          "type\":\"ADAM\",\"test_interval\":200},\"net\":{\"batch_size\":2,"
          "\"test_batch_size\":2,\"reg_weight\":0.5},\"resume\":false,"
          "\"mirror\":true,\"rotate\":true,\"crop_size\":512,"
          "\"test_crop_samples\":10,"
          "\"cutout\":0.1,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
          "true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
          "\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
          "\"prob\":0.01}},\"input\":{\"seed\":12347,\"db\":true,"
          "\"shuffle\":true},\"output\":{\"measure\":[\"map\"]}},\"data\":[\""
        + fasterrcnn_train_data + "\",\"" + fasterrcnn_test_data + "\"]}";

  joutstr = japi.jrender(japi.service_train(jtrainstr));
  JDoc jd;
  std::cout << "joutstr=" << joutstr << std::endl;
  jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
  ASSERT_TRUE(!jd.HasParseError());
  ASSERT_EQ(201, jd["status"]["code"]);

  ASSERT_EQ(jd["body"]["measure"]["iteration"], 200) << "iterations";
  ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() <= 1.0) << "map";
  // ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() > 0.0) << "map";

  // check metrics
  auto &meas = jd["body"]["measure"];
  ASSERT_TRUE(meas.HasMember("iou_loss"));
  ASSERT_TRUE(meas.HasMember("conf_loss"));
  ASSERT_TRUE(meas.HasMember("cls_loss"));
  ASSERT_TRUE(meas.HasMember("l1_loss"));
  ASSERT_TRUE(meas.HasMember("train_loss"));
  // train_loss must equal the weighted sum of its components; the 0.5
  // factor on iou_loss presumably mirrors the "reg_weight" sent above.
  ASSERT_TRUE(
      std::abs(meas["train_loss"].GetDouble()
               - (meas["iou_loss"].GetDouble() * 0.5
                  + meas["cls_loss"].GetDouble() + meas["l1_loss"].GetDouble()
                  + meas["conf_loss"].GetDouble()))
      < 0.0001);

  // check that predict works fine
  std::string jpredictstr = "{\"service\":\"detectserv\",\"parameters\":{"
                            "\"input\":{\"height\":640,"
                            "\"width\":640},\"output\":{\"bbox\":true, "
                            "\"confidence_threshold\":0.8}},\"data\":[\""
                            + detect_train_repo_fasterrcnn
                            + "/imgs/la_melrose_ave-000020.jpg\"]}";
  joutstr = japi.jrender(japi.service_predict(jpredictstr));
  jd = JDoc();
  std::cout << "joutstr=" << joutstr << std::endl;
  jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
  ASSERT_TRUE(!jd.HasParseError());
  ASSERT_EQ(200, jd["status"]["code"]);

  // Cleanup: drop checkpoint and solver files left in the repository.
  std::unordered_set<std::string> lfiles;
  fileops::list_directory(detect_train_repo_yolox, true, false, false, lfiles);
  // Bind by const reference: the paths are only read, no per-entry copy.
  for (const std::string &ff : lfiles)
    {
      if (ff.find("checkpoint") != std::string::npos
          || ff.find("solver") != std::string::npos)
        remove(ff.c_str());
    }
  ASSERT_TRUE(!fileops::file_exists(detect_train_repo_yolox + "checkpoint-"
                                    + iterations_detection + ".ptw"));
  ASSERT_TRUE(!fileops::file_exists(detect_train_repo_yolox + "checkpoint-"
                                    + iterations_detection + ".pt"));

  // Remove the LMDB databases created by the "db":true input setting.
  fileops::clear_directory(detect_train_repo_yolox + "train.lmdb");
  fileops::clear_directory(detect_train_repo_yolox + "test_0.lmdb");
  fileops::remove_dir(detect_train_repo_yolox + "train.lmdb");
  fileops::remove_dir(detect_train_repo_yolox + "test_0.lmdb");
}

TEST(torchapi, service_train_images_native)
{
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
Expand Down Expand Up @@ -3655,7 +3749,6 @@ TEST(torchapi, service_train_vit_images_gpu)
std::string jstr
= "{\"mllib\":\"torch\",\"description\":\"image\",\"type\":"
"\"supervised\",\"model\":{\"repository\":\""

+ vit_train_repo
+ "\"},\"parameters\":{\"input\":{\"connector\":\"image\","
"\"width\":224,\"height\":224,\"db\":true},\"mllib\":{\"nclasses\":"
Expand Down Expand Up @@ -3742,7 +3835,6 @@ TEST(torchapi, service_train_vit_images_multigpu)
std::string jstr
= "{\"mllib\":\"torch\",\"description\":\"image\",\"type\":"
"\"supervised\",\"model\":{\"repository\":\""

+ vit_train_repo
+ "\"},\"parameters\":{\"input\":{\"connector\":\"image\","
"\"width\":224,\"height\":224,\"db\":true},\"mllib\":{\"nclasses\":"
Expand Down

0 comments on commit 6b3b9c0

Please sign in to comment.