Skip to content

Commit

Permalink
fix(detection/torch): correctly normalize MAP wrt torchlib outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
fantes authored and beniz committed Feb 9, 2022
1 parent 72d7f51 commit b12d188
Showing 1 changed file with 25 additions and 24 deletions.
49 changes: 25 additions & 24 deletions src/supervisedoutputconnector.h
Original file line number Diff line number Diff line change
Expand Up @@ -2358,35 +2358,28 @@ namespace dd
{
double mmAP = 0.0;
std::map<int, int> APs_count;

int APs_count_all = 0;
// extract tp, fp, labels
APIData bad = ad.getobj("0");
int pos_count = ad.get("pos_count").get<int>();
for (int i = 0; i < pos_count; i++)
{
// do a mean over label AP per image in test set
double mAP = 0.0;
std::vector<APIData> vbad = bad.getv(std::to_string(i));
// std::cerr << "vbad size=" << vbad.size() << std::endl;
for (size_t j = 0; j < vbad.size(); j++)
{
std::vector<double> tp_d
= vbad.at(j).get("tp_d").get<std::vector<double>>();
// std::cerr << "tp_d size=" << tp_d.size() << std::endl;
std::vector<int> tp_i
= vbad.at(j).get("tp_i").get<std::vector<int>>();
std::vector<double> fp_d
= vbad.at(j).get("fp_d").get<std::vector<double>>();
std::vector<int> fp_i
= vbad.at(j).get("fp_i").get<std::vector<int>>();
int num_pos = vbad.at(j).get("num_pos").get<int>();
// std::cerr << "num_pos=" << num_pos << std::endl;
int label = vbad.at(j).get("label").get<int>();
// std::cerr << "label=" << label << std::endl;
std::vector<std::pair<double, int>> tp;
std::vector<std::pair<double, int>> fp;

// std::cerr << "fp_d size=" << fp_d.size() << std::endl;
for (size_t j = 0; j < tp_d.size(); j++)
{
tp.push_back(std::pair<double, int>(tp_d.at(j), tp_i.at(j)));
Expand All @@ -2396,31 +2389,39 @@ namespace dd
fp.push_back(std::pair<double, int>(fp_d.at(j), fp_i.at(j)));
}

double local_ap = compute_ap(tp, fp, num_pos);
if (APs.find(label) == APs.end())
if (tp.size() > 0 or fp.size() > 0)
{
APs[label] = local_ap;
APs_count[label] = 1;
APs_count_all += 1;
double local_ap = compute_ap(tp, fp, num_pos);
if (APs.find(label) == APs.end())
{
APs[label] = local_ap;
APs_count[label] = 1;
}
else
{
APs[label] += local_ap;
APs_count[label] += 1;
}
mmAP += local_ap;
}
else
else if (APs.find(label) == APs.end())
{
APs[label] += local_ap;
APs_count[label] += 1;
APs[label] = 0.0;
APs_count[label] = 0;
}
// std::cerr << "ap for label " << label << "=" << APs[label]
// << std::endl;
mAP += local_ap;
}
mAP /= static_cast<double>(vbad.size());
// do a mean mAP over images in test set
mmAP += mAP;
}
for (auto ap : APs)
{
APs[ap.first] /= static_cast<float>(APs_count[ap.first]);
if (APs_count[ap.first] == 0)
APs[ap.first] = 0.0;
else
APs[ap.first] /= static_cast<float>(APs_count[ap.first]);
}
return mmAP / static_cast<double>(pos_count);
;
if (APs_count_all == 0)
return 0.0;
return mmAP / static_cast<double>(APs_count_all);
}

// measure: AUC
Expand Down

0 comments on commit b12d188

Please sign in to comment.