Skip to content

Commit

Permalink
feat(regression): add l1 metric for regression
Browse files Browse the repository at this point in the history
  • Loading branch information
fantes authored and mergify[bot] committed Jan 6, 2023
1 parent 8259a29 commit c82f08d
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 14 deletions.
2 changes: 1 addition & 1 deletion docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ No parameters
Parameter | Type | Optional | Default | Description
--------- | ---- | -------- | ------- | -----------
best | int | yes | 1 | Number of top predictions returned by data URI (supervised)
measure | array | yes | empty | Output measures requested, from `acc`: accuracy, `acc-k`: top-k accuracy, replace k with number (e.g. `acc-5`), `f1`: f1, precision and recall, `mcll`: multi-class log loss, `auc`: area under the curve, `cmdiag`: diagonal of confusion matrix (requires `f1`), `cmfull`: full confusion matrix (requires `f1`), `mcc`: Matthews correlation coefficient, `eucll`: euclidean distance (e.g. for regression tasks), `kl`: KL_divergence, `js`: JS divergence, `was`: Wasserstein, `ks`: Kolmogorov Smirnov, `dc`: distance correlation, `r2`: R2, `deltas`: delta scores, 'raw': ouput raw results, in case of predict call, this requires a special deploy.prototxt that is a test network (to have ground truth)
measure | array | yes | empty | Output measures requested, from `acc`: accuracy, `acc-k`: top-k accuracy, replace k with number (e.g. `acc-5`), `f1`: f1, precision and recall, `mcll`: multi-class log loss, `auc`: area under the curve, `cmdiag`: diagonal of confusion matrix (requires `f1`), `cmfull`: full confusion matrix (requires `f1`), `mcc`: Matthews correlation coefficient, `eucll`: euclidean distance (e.g. for regression tasks),`l1`: l1 distance (e.g. for regression tasks), `kl`: KL_divergence, `js`: JS divergence, `was`: Wasserstein, `ks`: Kolmogorov Smirnov, `dc`: distance correlation, `r2`: R2, `deltas`: delta scores, 'raw': ouput raw results, in case of predict call, this requires a special deploy.prototxt that is a test network (to have ground truth)
target_repository | string | yes | empty | target directory to which to copy the best model files once training has completed

#### Machine learning libraries
Expand Down
57 changes: 44 additions & 13 deletions src/supervisedoutputconnector.h
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,10 @@ namespace dd
bool beucll = false;
float beucll_thres = -1;
find_presence_and_thres("eucll", measures, beucll, beucll_thres);
bool compute_all_meucll = beucll && !autoencoder;
bool bl1 = (std::find(measures.begin(), measures.end(), "l1")
!= measures.end());
bool compute_all_distl = (beucll || bl1) && !autoencoder;

bool bmcc = (std::find(measures.begin(), measures.end(), "mcc")
!= measures.end());
bool baccv = false;
Expand Down Expand Up @@ -1074,16 +1077,16 @@ namespace dd
double meucll;
std::vector<double> all_meucll;
std::tie(meucll, all_meucll)
= eucll(ad_res, -1, compute_all_meucll);
= distl(ad_res, -1, compute_all_distl, false);
meas_out.add("eucll", meucll);
if (all_meucll.size() > 1 && compute_all_meucll)
if (all_meucll.size() > 1 && compute_all_distl)
for (unsigned int i = 0; i < all_meucll.size(); ++i)
meas_out.add("eucll_" + std::to_string(i), all_meucll[i]);

if (beucll_thres > 0)
{
std::tuple<double, std::vector<double>> tmeucll_thres
= eucll(ad_res, beucll_thres, compute_all_meucll);
= distl(ad_res, beucll_thres, compute_all_distl, false);
double meucll_thres = std::get<0>(tmeucll_thres);
std::string b = "eucll_no_" + std::to_string(beucll_thres);
meas_out.add(b, meucll_thres);
Expand All @@ -1098,6 +1101,16 @@ namespace dd
}
}
}
if (bl1)
{
double ml1;
std::vector<double> all_ml1;
std::tie(ml1, all_ml1)
= distl(ad_res, -1, compute_all_distl, true);
meas_out.add("l1", ml1);
for (unsigned int i = 0; i < all_ml1.size(); ++i)
meas_out.add("l1_" + std::to_string(i), all_ml1[i]);
}
if (bmcc)
{
double mmcc = mcc(ad_res);
Expand Down Expand Up @@ -2574,15 +2587,16 @@ namespace dd
}

static std::tuple<double, std::vector<double>>
eucll(const APIData &ad, float thres, bool compute_all_meucll)
distl(const APIData &ad, float thres, bool compute_all_distl,
bool l1 = false)
{
double eucl = 0.0;
unsigned int psize = ad.getobj(std::to_string(0))
.get("pred")
.get<std::vector<double>>()
.size();
std::vector<double> all_eucl;
if (compute_all_meucll)
if (compute_all_distl)
all_eucl.resize(psize, 0.0);
int batch_size = ad.get("batch_size").get<int>();
bool has_ignore = ad.has("ignore_label");
Expand Down Expand Up @@ -2612,22 +2626,39 @@ namespace dd
{
if (diff >= thres)
{
leucl += diff * diff;
if (compute_all_meucll)
if (l1)
eucl += diff;
else
leucl += diff * diff;
if (compute_all_distl)
all_eucl[j] += diff;
}
}
else
{
leucl += diff * diff;
if (compute_all_meucll)
all_eucl[j] += diff;
if (l1)
eucl += diff;
else
leucl += diff * diff;
if (compute_all_distl)
{
if (l1)
all_eucl[j] += diff;
else
all_eucl[j] += diff * diff;
}
}
}
eucl += sqrt(leucl);
if (!l1)
{
eucl += sqrt(leucl);
if (compute_all_distl)
for (size_t j = 0; j < target.size(); ++j)
all_eucl[j] = sqrt(all_eucl[j]);
}
}

if (compute_all_meucll)
if (compute_all_distl)
for (unsigned int i = 0; i < all_eucl.size(); ++i)
all_eucl[i] /= static_cast<double>(batch_size);

Expand Down
62 changes: 62 additions & 0 deletions tests/ut-torchapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1996,6 +1996,68 @@ TEST(torchapi, service_train_images_split_regression_2dims_db_false)
fileops::remove_dir(resnet50_train_repo + "test_0.lmdb");
}

TEST(torchapi, service_train_images_split_regression_2dims_db_false_l1)
{
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
torch::manual_seed(torch_seed);
at::globalContext().setDeterministicCuDNN(true);

// Create service
JsonAPI japi;
std::string sname = "imgserv";
std::string jstr
= "{\"mllib\":\"torch\",\"description\":\"image\",\"type\":"
"\"supervised\",\"model\":{\"repository\":\""
+ resnet50_train_repo
+ "\"},\"parameters\":{\"input\":{\"connector\":\"image\","
"\"width\":224,\"height\":224,\"db\":false},\"mllib\":{\"ntargets\":"
"2,\"finetuning\":true,\"regression\":true,\"gpu\":true,\"loss\":"
"\"L2\"}}}";
std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
ASSERT_EQ(created_str, joutstr);

// Train
std::string jtrainstr
= "{\"service\":\"imgserv\",\"async\":false,\"parameters\":{"
"\"mllib\":{\"solver\":{\"iterations\":"
+ iterations_resnet50 + ",\"base_lr\":" + torch_lr
+ ",\"iter_size\":4,\"solver_"
"type\":\"ADAM\",\"test_interval\":200},\"net\":{\"batch_size\":4, "
"\"test_batch_size\":3},"
"\"resume\":false},"
"\"input\":{\"seed\":12345,\"db\":false,\"shuffle\":true,\"test_"
"split\":0.1},"
"\"output\":{\"measure\":[\"l1\"]}},\"data\":[\""
+ resnet50_train_data_reg2 + "\"]}";
joutstr = japi.jrender(japi.service_train(jtrainstr));
JDoc jd;
std::cout << "joutstr=" << joutstr << std::endl;
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
ASSERT_TRUE(!jd.HasParseError());
ASSERT_EQ(201, jd["status"]["code"]);

ASSERT_TRUE(jd["body"]["measure"]["iteration"] == 200) << "iterations";
ASSERT_TRUE(jd["body"]["measure"]["l1"].GetDouble() <= 15.0) << "l1";

std::unordered_set<std::string> lfiles;
fileops::list_directory(resnet50_train_repo, true, false, false, lfiles);
for (std::string ff : lfiles)
{
if (ff.find("checkpoint") != std::string::npos
|| ff.find("solver") != std::string::npos)
remove(ff.c_str());
}
ASSERT_TRUE(!fileops::file_exists(resnet50_train_repo + "checkpoint-"
+ iterations_resnet50 + ".ptw"));
ASSERT_TRUE(!fileops::file_exists(resnet50_train_repo + "checkpoint-"
+ iterations_resnet50 + ".pt"));

fileops::clear_directory(resnet50_train_repo + "train.lmdb");
fileops::clear_directory(resnet50_train_repo + "test_0.lmdb");
fileops::remove_dir(resnet50_train_repo + "train.lmdb");
fileops::remove_dir(resnet50_train_repo + "test_0.lmdb");
}

TEST(torchapi, service_train_object_detection_fasterrcnn)
{
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
Expand Down

0 comments on commit c82f08d

Please sign in to comment.