Skip to content

Commit

Permalink
feat(torch): add map metrics with arbitrary iou threshold
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed Apr 6, 2023
1 parent c29ce88 commit 20d8ebe
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 30 deletions.
39 changes: 32 additions & 7 deletions src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2017,6 +2017,17 @@ namespace dd
ad_out.add("measure", meas);
}

std::vector<int> iou_thresholds;
std::map<int, APIData> ad_bbox_per_iou;
if (_bbox)
{
auto meas = ad_out.get("measure").get<std::vector<std::string>>();
SupervisedOutput::find_ap_iou_thresholds(meas, iou_thresholds);

for (int i : iou_thresholds)
ad_bbox_per_iou[i] = APIData();
}

auto dataloader = torch::data::make_data_loader(
dataset, data::DataLoaderOptions(batch_size));
torch::Device cpu("cpu");
Expand Down Expand Up @@ -2120,11 +2131,19 @@ namespace dd
++stop;
}

auto vbad = get_bbox_stats(
targ_bboxes.index({ torch::indexing::Slice(start, stop) }),
targ_labels.index({ torch::indexing::Slice(start, stop) }),
bboxes_tensor, labels_tensor, score_tensor);
ad_bbox.add(std::to_string(entry_id), vbad);
for (int iou_thres : iou_thresholds)
{
double iou_thres_d = static_cast<double>(iou_thres) / 100;
std::vector<APIData> vbad = get_bbox_stats(
targ_bboxes.index(
{ torch::indexing::Slice(start, stop) }),
targ_labels.index(
{ torch::indexing::Slice(start, stop) }),
bboxes_tensor, labels_tensor, score_tensor,
iou_thres_d);
ad_bbox_per_iou[iou_thres].add(std::to_string(entry_id),
vbad);
}
++entry_id;
}
}
Expand Down Expand Up @@ -2299,6 +2318,12 @@ namespace dd
{
ad_res.add("bbox", true);
ad_res.add("pos_count", entry_id);

for (int iou_thres : iou_thresholds)
{
ad_bbox.add("map-" + std::to_string(iou_thres),
ad_bbox_per_iou[iou_thres]);
}
ad_res.add("0", ad_bbox);
}
else if (_segmentation)
Expand All @@ -2318,7 +2343,8 @@ namespace dd
const at::Tensor &targ_labels,
const at::Tensor &bboxes_tensor,
const at::Tensor &labels_tensor,
const at::Tensor &score_tensor)
const at::Tensor &score_tensor,
float overlap_threshold)
{
auto targ_bboxes_acc = targ_bboxes.accessor<float, 2>();
auto targ_labels_acc = targ_labels.accessor<int64_t, 1>();
Expand Down Expand Up @@ -2348,7 +2374,6 @@ namespace dd
};

std::vector<eval_info> eval_infos(_nclasses);
float overlap_threshold = 0.5; // TODO: parameter

for (int j = 0; j < pred_bbox_count; ++j)
{
Expand Down
3 changes: 2 additions & 1 deletion src/backends/torch/torchlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ namespace dd
const at::Tensor &targ_labels,
const at::Tensor &bboxes_tensor,
const at::Tensor &labels_tensor,
const at::Tensor &score_tensor);
const at::Tensor &score_tensor,
float overlap_threshold);

public:
unsigned int _nclasses = 0; /**< number of classes*/
Expand Down
106 changes: 96 additions & 10 deletions src/supervisedoutputconnector.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
#define SUPERVISEDOUTPUTCONNECTOR_H
#define TS_METRICS_EPSILON 1E-2

#include <sstream>
#include <iomanip>

#include "dto/output_connector.hpp"

template <typename T>
Expand Down Expand Up @@ -845,19 +848,64 @@ namespace dd
}
if (bbox)
{
bool bbmap = (std::find(measures.begin(), measures.end(), "map")
!= measures.end());
if (bbmap)
// Required iou thresholds for map. If there is more than one
// threshold, the global map is the mean over the different iou
// thresholds.
std::vector<int> thresholds;
bool has_map = find_ap_iou_thresholds(measures, thresholds);

if (has_map)
{
std::map<int, float> aps;
double bmap = ap(ad_res, aps);
meas_out.add("map", bmap);
for (auto ap : aps)
double sum_map = 0;
std::map<int, float> sum_aps;
int ap_count = 0;

// map for each threshold
for (int iou_thres : thresholds)
{
std::string s = "map_" + std::to_string(ap.first);
meas_out.add(s, static_cast<double>(ap.second));
std::map<int, float> aps;
double bmap = ap(ad_res, aps, iou_thres);
std::string map_key = "map";
if (iou_thres != 0)
{
std::stringstream ss;
ss << map_key << "-" << std::setfill('0')
<< std::setw(2) << iou_thres;
map_key = ss.str();
}
meas_out.add(map_key, bmap);
for (auto ap : aps)
{
std::string s
= map_key + "_" + std::to_string(ap.first);
meas_out.add(s, static_cast<double>(ap.second));
}

sum_map += bmap;
if (sum_aps.size() == 0)
sum_aps = aps;
else
{
for (auto ap : aps)
sum_aps[ap.first] += ap.second;
}
ap_count++;
}

// mean of all thresholds
if (thresholds.size() > 0)
{
meas_out.add("map", sum_map / ap_count);
for (auto sum_ap : sum_aps)
{
std::string s
= "map_" + std::to_string(sum_ap.first);
meas_out.add(s, static_cast<double>(sum_ap.second
/ ap_count));
}
}
}

bool raw = (std::find(measures.begin(), measures.end(), "raw")
!= measures.end());
if (raw)
Expand Down Expand Up @@ -1608,6 +1656,35 @@ namespace dd
}
}

/**
 * \brief Collect the iou thresholds requested through "map" measures.
 *
 * Every measure whose name contains "map" marks map as requested. When
 * dd_utils::split(measure, '-') yields exactly two parts (e.g. "map-90"),
 * the second part is parsed with std::atoi as a percentage and appended to
 * \p thresholds. A value already present in \p thresholds is not appended
 * again, so a repeated measure does not produce a duplicated threshold.
 *
 * \param measures names of the requested measures
 * \param thresholds output, the requested iou thresholds in percent (int);
 *        if it is still empty after the scan, the default 50 is appended
 * \return true if map is requested, false otherwise
 */
static bool
find_ap_iou_thresholds(const std::vector<std::string> &measures,
                       std::vector<int> &thresholds)
{
  bool has_map = false;
  // iterate by const reference: measure names are only read here
  for (const std::string &measure : measures)
    {
      if (measure.find("map") == std::string::npos)
        continue;
      has_map = true;

      const std::vector<std::string> parts = dd_utils::split(measure, '-');
      if (parts.size() != 2)
        continue; // no single "-<threshold>" suffix on this measure

      // std::atoi yields 0 when the suffix is not a number
      const int iou_thres = std::atoi(parts[1].c_str());
      // keep each threshold once, in order of first appearance
      if (std::find(thresholds.begin(), thresholds.end(), iou_thres)
          == thresholds.end())
        thresholds.push_back(iou_thres);
    }

  // Default threshold is 0.5 (map 50)
  if (thresholds.empty())
    thresholds.push_back(50);
  return has_map;
}

static double straight_meas(const APIData &ad)
{
APIData bad = ad.getobj("0");
Expand Down Expand Up @@ -2407,13 +2484,22 @@ namespace dd
return ap;
}

static double ap(const APIData &ad, std::map<int, float> &APs)
/**
* Compute AP for all classes and mean AP
* \param APs std::map containing AP for each class
* \param thres iou threshold for map in percent
*/
static double ap(const APIData &ad, std::map<int, float> &APs, int thres)
{
double mmAP = 0.0;
std::map<int, int> APs_count;
int APs_count_all = 0;
// extract tp, fp, labels
APIData bad = ad.getobj("0");
std::string map_key = "map-" + std::to_string(thres);
if (bad.has(map_key))
bad = bad.getobj(map_key);
// else: default threshold (legacy)
int pos_count = ad.get("pos_count").get<int>();
for (int i = 0; i < pos_count; i++)
{
Expand Down
63 changes: 51 additions & 12 deletions tests/ut-torchapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ TEST(torchapi, compute_bbox_stats)
11, 11, 101, 101, // matching
900, 10, 950, 100, // false positive
510, 510, 610, 610, // 2 preds for 1 targets
490, 490, 590, 590, // --
490, 490, 590, 590, // (second pred)
940, 940, 990, 990, // overlapping but iou < 0.5 -> false positive
};
at::Tensor bboxes_tensor = torch::from_blob(bboxes_data, { 5, 4 });
Expand All @@ -619,10 +619,10 @@ TEST(torchapi, compute_bbox_stats)
float score_data[] = { 0.9, 0.8, 0.7, 0.6, 0.5 };
at::Tensor score_tensor = torch::from_blob(score_data, 5);

auto vbad = torchlib.get_bbox_stats(targ_bboxes, targ_labels, bboxes_tensor,
labels_tensor, score_tensor);

auto lbad = vbad.at(0);
auto vbad50
= torchlib.get_bbox_stats(targ_bboxes, targ_labels, bboxes_tensor,
labels_tensor, score_tensor, 0.5);
auto lbad = vbad50.at(0);
auto tp_i = lbad.get("tp_i").get<std::vector<int>>();
auto tp_d = lbad.get("tp_d").get<std::vector<double>>();
auto fp_i = lbad.get("fp_i").get<std::vector<int>>();
Expand All @@ -643,6 +643,19 @@ TEST(torchapi, compute_bbox_stats)
}
ASSERT_EQ(lbad.get("num_pos").get<int>(), 4);
ASSERT_EQ(lbad.get("label").get<int>(), 1);
APIData ad_bbox_50;
ad_bbox_50.add("0", vbad50);

// with map 90 the third bbox is no longer matching
auto vbad90
= torchlib.get_bbox_stats(targ_bboxes, targ_labels, bboxes_tensor,
labels_tensor, score_tensor, 0.9);
lbad = vbad90.at(0);
tp_i = lbad.get("tp_i").get<std::vector<int>>();
ASSERT_EQ(std::accumulate(tp_i.begin(), tp_i.end(), 0), 1);
ASSERT_FALSE(tp_i[2]);
APIData ad_bbox_90;
ad_bbox_90.add("0", vbad90);

// Get MAP
APIData ad_res;
Expand All @@ -651,15 +664,25 @@ TEST(torchapi, compute_bbox_stats)
ad_res.add("bbox", true);
ad_res.add("pos_count", 1);
APIData ad_bbox;
ad_bbox.add("0", vbad);
ad_bbox.add("map-50", ad_bbox_50);
ad_bbox.add("map-90", ad_bbox_90);
ad_res.add("0", ad_bbox);
ad_res.add("batch_size", 1);
APIData ad_out;
ad_out.add("measure", std::vector<std::string>{ "map" });
ad_out.add("measure", std::vector<std::string>{ "map", "map-50", "map-90" });
APIData out;
SupervisedOutput::measure(ad_res, ad_out, out, 0, "test");
ASSERT_NEAR(out.getobj("measure").get("map").get<double>(), 5. / 12.,
JsonAPI japi;
JDoc jdoc;
jdoc.SetObject();
out.toJDoc(jdoc);
std::cout << japi.jrender(jdoc) << std::endl;
ASSERT_NEAR(out.getobj("measure").get("map-50").get<double>(), 5. / 12.,
std::numeric_limits<float>::epsilon());
ASSERT_NEAR(out.getobj("measure").get("map-90").get<double>(), 0.25,
std::numeric_limits<float>::epsilon());
ASSERT_NEAR(out.getobj("measure").get("map").get<double>(),
(5. / 12 + 0.25) / 2., std::numeric_limits<float>::epsilon());
}

TEST(torchapi, map_false_negative)
Expand Down Expand Up @@ -691,7 +714,7 @@ TEST(torchapi, map_false_negative)
at::Tensor score_tensor = torch::from_blob(score_data, 1);

auto vbad = torchlib.get_bbox_stats(targ_bboxes, targ_labels, bboxes_tensor,
labels_tensor, score_tensor);
labels_tensor, score_tensor, 0.5);

// Get MAP
APIData ad_res;
Expand Down Expand Up @@ -2360,7 +2383,8 @@ TEST(torchapi, service_train_object_detection_yolox)
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
"\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
"\"prob\":0.01}},\"input\":{\"seed\":12347,\"db\":true,"
"\"shuffle\":true},\"output\":{\"measure\":[\"map\"]}},\"data\":[\""
"\"shuffle\":true},\"output\":{\"measure\":[\"map-05\",\"map-50\","
"\"map-90\"]}},\"data\":[\""
+ fasterrcnn_train_data + "\",\"" + fasterrcnn_test_data + "\"]}";

joutstr = japi.jrender(japi.service_train(jtrainstr));
Expand All @@ -2372,6 +2396,9 @@ TEST(torchapi, service_train_object_detection_yolox)

// ASSERT_EQ(jd["body"]["measure"]["iteration"], 200) << "iterations";
ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() <= 1.0) << "map";
ASSERT_TRUE(jd["body"]["measure"]["map-05"].GetDouble() <= 1.0) << "map-05";
ASSERT_TRUE(jd["body"]["measure"]["map-50"].GetDouble() <= 1.0) << "map-50";
ASSERT_TRUE(jd["body"]["measure"]["map-90"].GetDouble() <= 1.0) << "map-90";
// ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() > 0.0) << "map";

// check metrics
Expand Down Expand Up @@ -2456,7 +2483,8 @@ TEST(torchapi, service_train_object_detection_yolox_no_db)
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
"\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
"\"prob\":0.01}},\"input\":{\"seed\":12347,\"db\":false,"
"\"shuffle\":true},\"output\":{\"measure\":[\"map\"]}},\"data\":[\""
"\"shuffle\":true},\"output\":{\"measure\":[\"map-90\",\"map\"]}},"
"\"data\":[\""
+ fasterrcnn_train_data + "\",\"" + fasterrcnn_test_data + "\"]}";

joutstr = japi.jrender(japi.service_train(jtrainstr));
Expand All @@ -2468,6 +2496,8 @@ TEST(torchapi, service_train_object_detection_yolox_no_db)

// ASSERT_EQ(jd["body"]["measure"]["iteration"], 200) << "iterations";
ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() <= 1.0) << "map";
ASSERT_TRUE(jd["body"]["measure"]["map-90"].GetDouble() <= 1.0) << "map-90";
ASSERT_FALSE(jd["body"]["measure"].HasMember("map-50"));
// ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() > 0.0) << "map";

// check metrics
Expand Down Expand Up @@ -2551,7 +2581,8 @@ TEST(torchapi, service_train_object_detection_yolox_multigpu)
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
"\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
"\"prob\":0.01}},\"input\":{\"seed\":12347,\"db\":true,"
"\"shuffle\":true},\"output\":{\"measure\":[\"map\"]}},\"data\":[\""
"\"shuffle\":true},\"output\":{\"measure\":[\"map-50\",\"map-90\"]}}"
",\"data\":[\""
+ fasterrcnn_train_data + "\",\"" + fasterrcnn_test_data + "\"]}";

joutstr = japi.jrender(japi.service_train(jtrainstr));
Expand All @@ -2563,6 +2594,14 @@ TEST(torchapi, service_train_object_detection_yolox_multigpu)

ASSERT_EQ(jd["body"]["measure"]["iteration"], 200) << "iterations";
ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() <= 1.0) << "map";
ASSERT_TRUE(jd["body"]["measure"]["map-50"].GetDouble() <= 1.0) << "map-50";
ASSERT_TRUE(jd["body"]["measure"]["map-90"].GetDouble() <= 1.0) << "map-90";
ASSERT_LE(jd["body"]["measure"]["map-90"].GetDouble(),
jd["body"]["measure"]["map-50"].GetDouble());
ASSERT_NEAR((jd["body"]["measure"]["map-90"].GetDouble()
+ jd["body"]["measure"]["map-50"].GetDouble())
/ 2,
jd["body"]["measure"]["map"].GetDouble(), 0.001);
// ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() > 0.0) << "map";

// check metrics
Expand Down

0 comments on commit 20d8ebe

Please sign in to comment.