Skip to content

Commit

Permalink
NonMaxSuppression op ref implementation (#968)
Browse files Browse the repository at this point in the history
This PR is the ref implementation of the nonmaxsuppression operator. It always returns the max possible output shape, which is the problem tracked in issue #948.
  • Loading branch information
scxiao committed Oct 28, 2021
1 parent cf0b6d6 commit c98b22d
Show file tree
Hide file tree
Showing 10 changed files with 450 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ register_migraphx_ops(
multibroadcast
multinomial
neg
nonmaxsuppression
nonzero
outline
pad
Expand Down
234 changes: 234 additions & 0 deletions src/include/migraphx/op/nonmaxsuppression.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_NONMAXSUPPRESSION_HPP
#define MIGRAPHX_GUARD_OPERATORS_NONMAXSUPPRESSION_HPP

#include <cmath>
#include <queue>
#include <cstdint>
#include <iterator>
#include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/tensor_view.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/output_iterator.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct nonmaxsuppression
{
bool center_point_box = false;

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.center_point_box, "center_point_box"));
}

std::string name() const { return "nonmaxsuppression"; }

shape compute_shape(std::vector<shape> inputs) const
{
// requires at least 2 inputs
check_shapes{inputs, *this}.standard();
check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3);
auto lens = inputs.front().lens();

// check input shape
if(lens[1] != inputs.at(1).lens()[2])
{
MIGRAPHX_THROW("NonMaxSuppression: dimension mismatch between first and second input!");
}

std::vector<int64_t> out_lens(2);
out_lens.at(0) = lens.at(1);
out_lens.at(1) = 3;
return {shape::int64_type, out_lens};
}

struct box
{
std::array<float, 2> x;
std::array<float, 2> y;

void sort()
{
std::sort(x.begin(), x.end());
std::sort(y.begin(), y.end());
}

std::array<float, 2>& operator[](std::size_t i) { return i == 0 ? x : y; }

float area() const
{
assert(std::is_sorted(x.begin(), x.end()));
assert(std::is_sorted(y.begin(), y.end()));
return (x[1] - x[0]) * (y[1] - y[0]);
}
};

template <class T>
box batch_box(const T* boxes, std::size_t bidx) const
{
box result{};
const T* start = boxes + 4 * bidx;
if(center_point_box)
{
float half_width = start[2] / 2.0f;
float half_height = start[3] / 2.0f;
float x_center = start[0];
float y_center = start[1];
result.x = {x_center - half_width, x_center + half_width};
result.y = {y_center - half_height, y_center + half_height};
}
else
{
result.x = {start[1], start[3]};
result.y = {start[0], start[2]};
}

return result;
}

inline bool suppress_by_iou(box b1, box b2, float iou_threshold) const
{
b1.sort();
b2.sort();

box intersection{};
for(auto i : range(2))
{
intersection[i][0] = std::max(b1[i][0], b2[i][0]);
intersection[i][1] = std::min(b1[i][1], b2[i][1]);
}

std::vector<std::array<float, 2>> bbox = {intersection.x, intersection.y};
if(std::any_of(bbox.begin(), bbox.end(), [](auto bx) {
return not std::is_sorted(bx.begin(), bx.end());
}))
{
return false;
}

const float area1 = b1.area();
const float area2 = b2.area();
const float intersection_area = intersection.area();
const float union_area = area1 + area2 - intersection_area;

if(area1 <= .0f or area2 <= .0f or union_area <= .0f)
{
return false;
}

const float intersection_over_union = intersection_area / union_area;

return intersection_over_union > iou_threshold;
}

argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};

result.visit([&](auto out) { std::fill(out.begin(), out.end(), 0); });

std::size_t max_output_boxes_per_class = 0;
float iou_threshold = 0.0f;
float score_threshold = 0.0f;

if(args.size() > 2)
{
max_output_boxes_per_class = args.at(2).at<std::size_t>();
}
// max_output_boxes_per_class is 0, no output
if(max_output_boxes_per_class == 0)
{
return result;
}

if(args.size() > 3)
{
iou_threshold = args.at(3).at<float>();
}

if(args.size() > 4)
{
score_threshold = args.at(4).at<float>();
}

const auto& lens = args.at(1).get_shape().lens();
auto batch_num = lens[0];
auto class_num = lens[1];
auto box_num = args.at(0).get_shape().lens()[1];

std::vector<std::pair<float, int64_t>> selected_boxes_inside_class;
std::vector<int64_t> selected_indices;
selected_boxes_inside_class.reserve(output_shape.elements());

auto scores = make_view<float>(args.at(1).get_shape(), args.at(1).cast<float>());
const float* boxes = args.at(0).cast<float>();
shape comp_s{shape::float_type, {batch_num, class_num}};
shape_for_each(comp_s, [&](auto idx) {
auto bidx = idx[0];
auto cidx = idx[1];

std::size_t score_offset = (bidx * class_num + cidx) * box_num;
const float* batch_boxes = boxes + bidx * box_num * 4;
std::priority_queue<std::pair<float, int64_t>> sorted_boxes;
auto insert_to_sorted_boxes =
make_function_output_iterator([&](const auto& x) { sorted_boxes.push(x); });

int64_t box_idx = 0;
transform_if(scores.begin() + score_offset,
scores.begin() + score_offset + box_num,
insert_to_sorted_boxes,
[&](auto sc) {
box_idx++;
return sc >= score_threshold;
},
[&](auto sc) { return std::make_pair(sc, box_idx - 1); });

selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold
while(!sorted_boxes.empty() &&
selected_boxes_inside_class.size() < max_output_boxes_per_class)
{
const std::pair<float, int64_t>& next_top_score = sorted_boxes.top();

// Check with existing selected boxes for this class, suppress if exceed the IOU
// (Intersection Over Union) threshold
bool not_selected = std::any_of(
selected_boxes_inside_class.begin(),
selected_boxes_inside_class.end(),
[&](auto selected_index) {
return this->suppress_by_iou(batch_box(batch_boxes, next_top_score.second),
batch_box(batch_boxes, selected_index.second),
iou_threshold);
});

if(not not_selected)
{
selected_boxes_inside_class.push_back(next_top_score);
selected_indices.push_back(bidx);
selected_indices.push_back(cidx);
selected_indices.push_back(next_top_score.second);
}
sorted_boxes.pop();
}
});

result.visit([&](auto out) {
std::copy(selected_indices.begin(), selected_indices.end(), out.begin());
});

return result;
}
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
1 change: 1 addition & 0 deletions src/include/migraphx/operators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/neg.hpp>
#include <migraphx/op/nonmaxsuppression.hpp>
#include <migraphx/op/nonzero.hpp>
#include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp>
Expand Down
3 changes: 2 additions & 1 deletion src/onnx/parse_generic_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Log", "log"},
{"LRN", "lrn"},
{"Neg", "neg"},
{"NonMaxSuppression", "nonmaxsuppression"},
{"Reciprocal", "recip"},
{"Relu", "relu"},
{"Round", "round"},
Expand All @@ -49,7 +50,7 @@ struct parse_generic_op : op_parser<parse_generic_op>

bool needs_contiguous(const std::string& op_name) const
{
return contains({"flatten", "gather", "scatter"}, op_name);
return contains({"flatten", "gather", "nonmaxsuppression", "scatter"}, op_name);
}

instruction_ref parse(const op_desc& opd,
Expand Down
21 changes: 21 additions & 0 deletions src/targets/gpu/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ struct miopen_apply
add_if_op();
add_loop_op();
add_neg_op();
add_nms_op();
add_quant_convolution_op();
add_roialign();
}
Expand Down Expand Up @@ -524,6 +525,26 @@ struct miopen_apply
ins, make_op("gpu::loop", ins->get_operator().to_value()), inputs, mod_args);
});
}

void add_nms_op()
{
apply_map.emplace("nonmaxsuppression", [=](instruction_ref ins) {
auto s = ins->get_shape();
auto output = insert_allocation(ins, s);
std::vector<instruction_ref> cpu_inputs;
auto inputs = ins->inputs();
std::transform(
inputs.begin(), inputs.end(), std::back_inserter(cpu_inputs), [&](auto in) {
return mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), in);
});
cpu_inputs.front() =
mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_inputs);
auto cpu_out = mod->insert_instruction(ins, ins->get_operator(), cpu_inputs);
auto gpu_out =
mod->insert_instruction(ins, make_op("hip::copy_to_gpu"), cpu_out, output);
return mod->replace_instruction(ins, gpu_out);
});
}
};

void lowering::apply(module& m) const { miopen_apply{&m, this}.apply(); }
Expand Down
25 changes: 25 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2771,6 +2771,31 @@ def neg_test():
return ([node], [x], [y])


@onnx_test
def nms_test():
b = helper.make_tensor_value_info('boxes', TensorProto.FLOAT, [1, 6, 4])
s = helper.make_tensor_value_info('scores', TensorProto.FLOAT, [1, 1, 6])
mo = helper.make_tensor_value_info('max_output_boxes_per_class',
TensorProto.INT64, [1])
iou = helper.make_tensor_value_info('iou_threshold', TensorProto.FLOAT,
[1])
st = helper.make_tensor_value_info('score_threshold', TensorProto.FLOAT,
[1])
out = helper.make_tensor_value_info('selected_indices', TensorProto.INT64,
[6, 3])

node = onnx.helper.make_node('NonMaxSuppression',
inputs=[
'boxes', 'scores',
'max_output_boxes_per_class',
'iou_threshold', 'score_threshold'
],
outputs=['selected_indices'],
center_point_box=1)

return ([node], [b, s, mo, iou, st], [out])


@onnx_test
def not_test():
x = helper.make_tensor_value_info('0', TensorProto.INT32, [4])
Expand Down
34 changes: 34 additions & 0 deletions test/onnx/nms_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
nms_test:�
�
boxes
scores
max_output_boxes_per_class
iou_threshold
score_thresholdselected_indices"NonMaxSuppression*
center_point_box�nms_testZ
boxes



Z
scores



Z(
max_output_boxes_per_class


Z
iou_threshold


Z
score_threshold


b"
selected_indices


B
Expand Down
Loading

0 comments on commit c98b22d

Please sign in to comment.