Skip to content

Commit

Permalink
fix(torch): correctly normalize l1 and l2 metrics in case of multi di…
Browse files Browse the repository at this point in the history
…m regression
  • Loading branch information
fantes committed Jan 10, 2023
1 parent 11905eb commit cc9a636
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/supervisedoutputconnector.h
Original file line number Diff line number Diff line change
Expand Up @@ -2615,6 +2615,7 @@ namespace dd
target = bad.get("target").get<std::vector<double>>();
else
target.push_back(bad.get("target").get<double>());
int reg_dim = predictions.size();
double leucl = 0;
for (size_t j = 0; j < target.size(); j++)
{
Expand All @@ -2627,7 +2628,7 @@ namespace dd
if (diff >= thres)
{
if (l1)
eucl += diff;
eucl += diff / reg_dim;
else
leucl += diff * diff;
if (compute_all_distl)
Expand All @@ -2637,7 +2638,7 @@ namespace dd
else
{
if (l1)
eucl += diff;
eucl += diff / reg_dim;
else
leucl += diff * diff;
if (compute_all_distl)
Expand All @@ -2651,7 +2652,7 @@ namespace dd
}
if (!l1)
{
eucl += sqrt(leucl);
eucl += sqrt(leucl) / reg_dim;
if (compute_all_distl)
for (size_t j = 0; j < target.size(); ++j)
all_eucl[j] = sqrt(all_eucl[j]);
Expand Down

0 comments on commit cc9a636

Please sign in to comment.