Skip to content

Commit

Permalink
fix: cropped model input size when publishing torch models + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz authored and mergify[bot] committed Mar 21, 2022
1 parent 4a4fd3f commit 2dabd89
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
21 changes: 10 additions & 11 deletions src/backends/torch/torchmodel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -260,18 +260,17 @@ namespace dd
d_config.GetAllocator());

//- crop_size
auto d_input = d_config["parameters"]["input"].GetObject();
auto d_mllib = d_config["parameters"]["mllib"].GetObject();
if (d_mllib.HasMember("crop_size"))
auto d_config_input = d_config["parameters"]["input"].GetObject();
auto d_model_mllib = d_model["parameters"]["mllib"].GetObject();
if (d_model_mllib.HasMember("crop_size"))
{
try
{
int crop_size
= d_model["parameters"]["mllib"]["crop_size"].GetInt();
int crop_size = d_model_mllib["crop_size"].GetInt();
if (crop_size > 0)
{
d_config["parameters"]["input"]["width"].SetInt(crop_size);
d_config["parameters"]["input"]["height"].SetInt(crop_size);
d_config_input["width"].SetInt(crop_size);
d_config_input["height"].SetInt(crop_size);
}
}
catch (RapidjsonException &e)
Expand All @@ -281,10 +280,10 @@ namespace dd
//- db
try
{
if (d_input.HasMember("db"))
d_input["db"].SetBool(false);
if (d_mllib.HasMember("db"))
d_input["db"].SetBool(false);
if (d_config_input.HasMember("db"))
d_config_input["db"].SetBool(false);
if (d_model_mllib.HasMember("db"))
d_model_mllib["db"].SetBool(false);
}
catch (RapidjsonException &e)
{
Expand Down
16 changes: 16 additions & 0 deletions tests/ut-torchapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1063,8 +1063,24 @@ TEST(torchapi, service_publish_trained_model)
ASSERT_TRUE(fileops::file_exists(published_repo + "/checkpoint-1.pt"));
ASSERT_TRUE(fileops::file_exists(published_repo + "/best_model.txt"));
ASSERT_TRUE(fileops::file_exists(published_repo + "/model.json"));
ASSERT_TRUE(fileops::file_exists(published_repo + "/config.json"));
ASSERT_FALSE(fileops::file_exists(published_repo + "/resnet50.pt"));

// Check on published model configuration
std::string config_path = published_repo + "/config.json";
std::ifstream ifs_config(config_path.c_str(), std::ios::binary);
ASSERT_TRUE(ifs_config.is_open());
std::stringstream config_sstr;
config_sstr << ifs_config.rdbuf();
ifs_config.close();
rapidjson::Document d_config;
d_config.Parse<rapidjson::kParseNanAndInfFlag>(config_sstr.str().c_str());
auto d_config_input = d_config["parameters"]["input"].GetObject();
ASSERT_TRUE(d_config_input.HasMember("width"));
ASSERT_TRUE(d_config_input["width"].GetInt() == 224);
ASSERT_TRUE(d_config_input.HasMember("height"));
ASSERT_TRUE(d_config_input["height"].GetInt() == 224);

// Clean up train repo
std::unordered_set<std::string> lfiles;
fileops::list_directory(resnet50_train_repo, true, false, false, lfiles);
Expand Down

0 comments on commit 2dabd89

Please sign in to comment.