Skip to content

Commit

Permalink
fix: torch model published config file
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz authored and mergify[bot] committed Jan 18, 2022
1 parent 809f00a commit b0d4e04
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 110 deletions.
1 change: 0 additions & 1 deletion docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,6 @@ datatype | string | true | "fp32" | datatype inside compiled

Parameter | Type | Optional | Default | Description
--------- | ---- | -------- | ------- | -----------
store_config | bool | yes | false | stores the creation call in a `config.json` file in the model directory
measure | array of string | yes | depending on problem type | measure to use at test time


Expand Down
6 changes: 5 additions & 1 deletion src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1199,7 +1199,10 @@ namespace dd
}

if (elapsed_it == iterations)
out = meas_out;
{
out.add("measure", meas_out.getobj("measure"));
out.add("measures", meas_out.getv("measures"));
}
}

train_loss = 0;
Expand Down Expand Up @@ -1269,6 +1272,7 @@ namespace dd

inputc.response_params(out);
this->_logger->info("Training done.");

return 0;
}

Expand Down
103 changes: 98 additions & 5 deletions src/backends/torch/torchmodel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
*/

#include "torchmodel.h"

#include "utils/utils.hpp"

#include <rapidjson/document.h>
#include <rapidjson/filereadstream.h>

namespace dd
{
int TorchModel::read_from_repository(
Expand Down Expand Up @@ -200,16 +202,17 @@ namespace dd
{
std::vector<std::string> selts = dd_utils::split((*hit), '/');
fileops::copy_file((*hit), target_repo + '/' + selts.back());
if (selts.back().find(".json") != std::string::npos)
fileops::replace_string_in_file(target_repo + '/'
+ selts.back(),
"db\":true", "db\":false");
logger->info("successfully copied model file {} to {}", (*hit),
target_repo + '/' + selts.back());
}
++hit;
}

logger->info("successfully copied best model files from {} to {}",
source_repo, target_repo);

update_config_json_parameters(target_repo, logger);

return 0;
}
// else if best model file does not exist
Expand All @@ -218,4 +221,94 @@ namespace dd
source_repo, target_repo);
return 1;
}

void TorchModel::update_config_json_parameters(
const std::string &target_repo,
const std::shared_ptr<spdlog::logger> &logger)
{
// parse config.json and model.json
std::string config_path = target_repo + "/config.json";
std::string model_path = target_repo + "/model.json";
std::ifstream ifs_config(config_path.c_str(), std::ios::binary);
if (!ifs_config.is_open())
{
logger->error("could not find config file {} for export update",
config_path);
return;
}
std::stringstream config_sstr;
config_sstr << ifs_config.rdbuf();
ifs_config.close();
std::ifstream ifs_model(model_path.c_str(), std::ios::binary);
if (!ifs_model.is_open())
{
logger->error("could not find model file {} for export update",
config_path);
return;
}
std::stringstream model_sstr;
model_sstr << ifs_model.rdbuf();
ifs_model.close();

rapidjson::Document d_config;
d_config.Parse<rapidjson::kParseNanAndInfFlag>(config_sstr.str().c_str());
rapidjson::Document d_model;
d_model.Parse<rapidjson::kParseNanAndInfFlag>(model_sstr.str().c_str());

// apply changes
bool config_update = false;
//- crop_size
try
{
int crop_size = d_model["parameters"]["mllib"]["crop_size"].GetInt();
if (crop_size > 0)
{
d_config["parameters"]["input"]["width"].SetInt(crop_size);
d_config["parameters"]["input"]["height"].SetInt(crop_size);
}
config_update = true;
}
catch (RapidjsonException &e)
{
config_update = false;
}
//- db
try
{
auto d_input = d_config["parameters"]["input"].GetObject();
if (d_input.HasMember("db"))
d_input["db"].SetBool(false);
auto d_mllib = d_config["parameters"]["mllib"].GetObject();
if (d_mllib.HasMember("db"))
d_input["db"].SetBool(false);
config_update = true;
}
catch (RapidjsonException &e)
{
config_update |= false;
}

if (!config_update)
{
logger->warn("no update required to config.json");
return;
}

// save updated config.json
rapidjson::StringBuffer buffer;
rapidjson::Writer<rapidjson::StringBuffer, rapidjson::UTF8<>,
rapidjson::UTF8<>, rapidjson::CrtAllocator,
rapidjson::kWriteNanAndInfFlag>
writer(buffer);
bool done = d_config.Accept(writer);
if (!done)
throw DataConversionException("JSON rendering failed");
std::string config_str = buffer.GetString();
std::ofstream config_out(config_path.c_str(), std::ios::out
| std::ios::binary
| std::ios::trunc);
config_out << config_str;
config_out.close();
logger->info("successfully updated {}", config_path);
}
}
4 changes: 4 additions & 0 deletions src/backends/torch/torchmodel.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ namespace dd
const std::string &target_repo,
const std::shared_ptr<spdlog::logger> &logger);

void update_config_json_parameters(
const std::string &target_repo,
const std::shared_ptr<spdlog::logger> &logger);

public:
std::string _traced; /**< path of the traced part of the net. */
std::string _weights; /**< path of the weights of the net. */
Expand Down
Loading

0 comments on commit b0d4e04

Please sign in to comment.