Skip to content

Commit

Permalink
feat(api): add a model_stats field containing the number of parameters of the model
Browse files Browse the repository at this point in the history

Added a utility DTO type to allow DTO & APIData interoperability. This may ease APIData replacement in the future.
  • Loading branch information
Bycob authored and mergify[bot] committed Feb 27, 2023
1 parent 9b7581a commit b562fee
Show file tree
Hide file tree
Showing 21 changed files with 326 additions and 164 deletions.
2 changes: 1 addition & 1 deletion src/apidata.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ namespace dd
buffer.GetString());
}

template <typename T> static APIData fromDTO(const oatpp::Void &dto)
static APIData fromDTO(const oatpp::Void &dto)
{
std::shared_ptr<oatpp::data::mapping::ObjectMapper> object_mapper
= dd::oatpp_utils::createDDMapper();
Expand Down
2 changes: 1 addition & 1 deletion src/backends/caffe/caffelib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2830,7 +2830,7 @@ namespace dd
= oatpp::Object<DTO::ServicePredict>(
std::static_pointer_cast<typename DTO::ServicePredict>(
any->ptr));
ad = APIData::fromDTO<oatpp::Void>(predict_dto);
ad = APIData::fromDTO(predict_dto);

if (predict_dto->_chain)
{
Expand Down
2 changes: 1 addition & 1 deletion src/backends/torch/torchgraphbackend.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#include <torch/torch.h>
#pragma GCC diagnostic pop
#include <torch/ordered_dict.h>
#pragma GCC diagnostic pop

namespace dd
{
Expand Down
18 changes: 17 additions & 1 deletion src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,16 @@ namespace dd
}
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
          class TMLModel>
void TorchLib<TInputConnectorStrategy, TOutputConnectorStrategy,
              TMLModel>::compute_and_print_model_info()
{
  // Delegate the parameter counting (and its logging) to the wrapped torch
  // module, then mirror the freshly computed totals into this service's
  // stats fields so they can be reported through the API.
  _module.compute_and_print_model_info();
  this->_model_frozen_params = _module._frozen_params_count;
  this->_model_params = _module._params_count;
}

/*- from mllib -*/
template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
class TMLModel>
Expand Down Expand Up @@ -492,7 +502,7 @@ namespace dd
// print
if (_module.is_ready(_template))
{
_module.print_model_info();
compute_and_print_model_info();
}

_best_metrics = { "map", "meaniou", "mlacc", "delta_score_0.1", "bacc",
Expand Down Expand Up @@ -658,8 +668,11 @@ namespace dd
try
{
inputc.transform(ad);
bool module_was_ready = _module.is_ready(_template);
_module.post_transform_train<TInputConnectorStrategy>(
_template, _template_params, inputc, this->_mlmodel, _main_device);
if (!module_was_ready)
compute_and_print_model_info();
}
catch (...)
{
Expand Down Expand Up @@ -1457,9 +1470,12 @@ namespace dd
// XXX: torchinputconn does not fully support DTOs yet
inputc.transform(ad_in);
}
bool module_was_ready = _module.is_ready(_template);
_module.post_transform_predict(_template, _template_params, inputc,
this->_mlmodel, _main_device,
predict_dto);
if (!module_was_ready)
compute_and_print_model_info();
}
catch (...)
{
Expand Down
3 changes: 3 additions & 0 deletions src/backends/torch/torchlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ namespace dd
*/
double unscale(double val, unsigned int k,
const TInputConnectorStrategy &inputc);

/** print and update model stats */
void compute_and_print_model_info();
};
}

Expand Down
55 changes: 27 additions & 28 deletions src/backends/torch/torchmodule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,10 @@ namespace dd
const TorchModel &tmodel,
const torch::Device &device)
{
bool model_changed = false;
if (!_native)
{
create_native_template<TInputConnectorStrategy>(
tmpl, template_params, inputc, tmodel, device);
if (_native)
model_changed = true;
}
if (_graph)
{
Expand All @@ -228,7 +225,6 @@ namespace dd
if (_graph->needs_reload())
{
_logger->info("net was reallocated due to input dim changes");
model_changed = true;
}
// reload params after finalize
graph_model_load(tmodel);
Expand All @@ -244,7 +240,6 @@ namespace dd
setup_linear_head(_nclasses,
const_cast<TInputConnectorStrategy &>(inputc)
.get_input_example(device));
model_changed = true;
}
catch (std::exception &e)
{
Expand All @@ -260,7 +255,6 @@ namespace dd
const_cast<TInputConnectorStrategy &>(inputc)
.get_input_example(device),
inputc._alphabet_size);
model_changed = true;
}
catch (std::exception &e)
{
Expand All @@ -270,11 +264,6 @@ namespace dd
}

to(device);

if (model_changed)
{
print_model_info();
}
}

template <class TInputConnectorStrategy>
Expand Down Expand Up @@ -641,11 +630,11 @@ namespace dd
void print_native_params(std::shared_ptr<spdlog::logger> logger,
const std::string &name,
const torch::nn::Module &module,
int64_t &param_count)
int64_t &param_count, int64_t &frozen_count)
{
logger->info("## {} parameters", name);
param_count = 0;
int64_t frozen_count = 0;
frozen_count = 0;
for (const auto &p : module.named_parameters())
{
std::stringstream sstream;
Expand All @@ -661,33 +650,36 @@ namespace dd
count *= s;
}
param_count += count;
if (p.value().requires_grad())
if (!p.value().requires_grad())
frozen_count += count;
}
logger->info("{} parameters count: {}", name,
long_number_to_str(param_count));
if (frozen_count != 0)
{
logger->info("\tfrozen = {}", frozen_count);
logger->info("\tfrozen = {}", long_number_to_str(frozen_count));
}
}

void TorchModule::print_model_info()
void TorchModule::compute_and_print_model_info()
{
int64_t total_param_count = 0;
int64_t total_frozen_count = 0;
if (_graph)
{
int64_t graph_param_count;
print_native_params(_logger, "Graph", *_graph, graph_param_count);
int64_t graph_param_count, graph_frozen_count;
print_native_params(_logger, "Graph", *_graph, graph_param_count,
graph_frozen_count);
total_param_count += graph_param_count;
return;
total_frozen_count += graph_frozen_count;
}
if (_native)
{
int64_t native_param_count;
print_native_params(_logger, "Native", *_native, native_param_count);
int64_t native_param_count, native_frozen_count;
print_native_params(_logger, "Native", *_native, native_param_count,
native_frozen_count);
total_param_count += native_param_count;
return;
total_frozen_count += native_frozen_count;
}
if (_traced)
{
Expand All @@ -709,32 +701,39 @@ namespace dd
count *= s;
}
traced_param_count += count;
if (p.value.requires_grad())
if (!p.value.requires_grad())
traced_frozen_count += count;
}
_logger->info("Traced parameters count: {}",
long_number_to_str(traced_param_count));
if (traced_frozen_count != 0)
{
_logger->info("\tfrozen = {}", traced_frozen_count);
_logger->info("\tfrozen = {}",
long_number_to_str(traced_frozen_count));
}
total_param_count += traced_param_count;
total_frozen_count += traced_frozen_count;
}
if (_linear_head)
{
int64_t linear_param_count;
int64_t linear_param_count, linear_frozen_count;
print_native_params(_logger, "Linear", *_linear_head,
linear_param_count);
linear_param_count, linear_frozen_count);
total_param_count += linear_param_count;
total_frozen_count += linear_frozen_count;
}
if (_crnn_head)
{
int64_t crnn_param_count;
print_native_params(_logger, "CRNN", *_crnn_head, crnn_param_count);
int64_t crnn_param_count, crnn_frozen_count;
print_native_params(_logger, "CRNN", *_crnn_head, crnn_param_count,
crnn_frozen_count);
total_param_count += crnn_param_count;
total_frozen_count += crnn_frozen_count;
}
_logger->info("## Total number of parameters: {}",
long_number_to_str(total_param_count));
_params_count = total_param_count;
_frozen_params_count = total_frozen_count;
}

template void TorchModule::post_transform(
Expand Down
17 changes: 11 additions & 6 deletions src/backends/torch/torchmodule.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,17 @@
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#include <torch/torch.h>
#pragma GCC diagnostic pop
#include "torchmodel.h"
#include "torchgraphbackend.h"
#include "native/native_net.h"
#include "native/templates/crnn_head.hpp"
#include <torch/script.h>
#include <torch/nn/pimpl.h>
#if !defined(CPU_ONLY)
#include <torch/nn/parallel/data_parallel.h>
#endif
#pragma GCC diagnostic pop

#include "torchmodel.h"
#include "torchgraphbackend.h"
#include "native/native_net.h"
#include "native/templates/crnn_head.hpp"

namespace dd
{
Expand Down Expand Up @@ -203,7 +204,7 @@ namespace dd
* \brief print model information such as parameter count, number of
* parameters for each layer, whether the layers are frozen or not
**/
void print_model_info();
void compute_and_print_model_info();

public:
std::shared_ptr<torch::jit::script::Module>
Expand All @@ -218,6 +219,10 @@ namespace dd
torch::nn::Linear _linear_head = nullptr;
CRNNHead _crnn_head = nullptr;

// stats
int _params_count = 0; /**< number of parameters */
int _frozen_params_count = 0; /**< number of frozen parameters */

bool _require_linear_head = false;
bool _require_crnn_head = false;
std::string
Expand Down
2 changes: 2 additions & 0 deletions src/dto/ddtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ namespace dd
{
namespace __class
{
const oatpp::ClassId APIDataClass::CLASS_ID("APIData");

const oatpp::ClassId GpuIdsClass::CLASS_ID("GpuIds");

template <>
Expand Down
Loading

0 comments on commit b562fee

Please sign in to comment.