Skip to content

Commit

Permalink
Merge 20a953a into 3c9df3b
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieL7 committed Jul 2, 2023
2 parents 3c9df3b + 20a953a commit 9dfe1fb
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs);
}
std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) {
// uses previous multibroadcast to avoid recalculating the common shape from the
// uses previous input to avoid recalculating the common shape from the
// full set of input shapes at runtime
if(input->get_shape().dyn_dims() != c_dyn_dims)
{
Expand Down
11 changes: 6 additions & 5 deletions src/include/migraphx/op/clip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
#define MIGRAPHX_GUARD_OPERATORS_CLIP_HPP

#include <array>
#include <cmath>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <cmath>
#include <migraphx/dyn_output.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand All @@ -48,15 +49,15 @@ struct clip

shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3).same_type().same_dims();
check_shapes{inputs, *this, true}.has(3).same_type().same_dims();
return inputs.front();
}

argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
argument result{dyn_out.computed_shape};
visit_all(result, args[0], args[1], args[2])([&](auto output, auto x, auto min, auto max) {
par_for(output_shape.elements(),
par_for(dyn_out.computed_shape.elements(),
[&](auto i) { output[i] = std::min(std::max(min[i], x[i]), max[i]); });
});

Expand Down
Binary file added test/onnx/clip_dyn_min_max_test.onnx
Binary file not shown.
Binary file added test/onnx/clip_dyn_min_only_test.onnx
Binary file not shown.
27 changes: 27 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,33 @@ def clip_test_args_type_mismatch():
return ([node], [x], [y], [min_val, max_val])


@onnx_test()
def clip_dyn_min_max_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [None])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [None])

min_val = helper.make_tensor('min', TensorProto.FLOAT, [], [0.0])
max_val = helper.make_tensor('max', TensorProto.FLOAT, [], [6.0])

node = onnx.helper.make_node('Clip',
inputs=['0', 'min', 'max'],
outputs=['1'])

return ([node], [x], [y], [min_val, max_val])


@onnx_test()
def clip_dyn_min_only_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [None])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [None])

min_val = helper.make_tensor('min', TensorProto.FLOAT, [], [0.0])

node = onnx.helper.make_node('Clip', inputs=['0', 'min'], outputs=['1'])

return ([node], [x], [y], [min_val])


@onnx_test()
def concat_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 4, 3])
Expand Down
41 changes: 41 additions & 0 deletions test/onnx/onnx_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,47 @@ TEST_CASE(clip_test_args_type_mismatch)
EXPECT(p == prog);
}

TEST_CASE(clip_dyn_min_max_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
std::vector<migraphx::shape::dynamic_dimension> dds = {{2, 8, {3}}};
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, dds});
min_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_dyn_dims", to_value(dds)}}), min_val, l0);
max_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_dyn_dims", to_value(dds)}}), max_val, l0);
auto ret = mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
mm->add_return({ret});

migraphx::onnx_options options;
options.default_dyn_dim_value = {2, 8, {3}};
auto prog = parse_onnx("clip_dyn_min_max_test.onnx", options);

EXPECT(p == prog);
}

TEST_CASE(clip_dyn_min_only_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto min_val = mm->add_literal(0.0f);
std::vector<migraphx::shape::dynamic_dimension> dds = {{2, 8, {3}}};
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, dds});
min_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_dyn_dims", to_value(dds)}}), min_val, l0);
auto ret = mm->add_instruction(migraphx::make_op("max"), l0, min_val);
mm->add_return({ret});

migraphx::onnx_options options;
options.default_dyn_dim_value = {2, 8, {3}};
auto prog = parse_onnx("clip_dyn_min_only_test.onnx", options);

EXPECT(p == prog);
}

TEST_CASE(concat_test)
{
migraphx::program p;
Expand Down
25 changes: 25 additions & 0 deletions test/ref_ops_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,31 @@ TEST_CASE(clip_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}

TEST_CASE(clip_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dds = {{2, 8, {3}}};
migraphx::shape s{migraphx::shape::float_type, dds};
auto l = mm->add_parameter("X", s);
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction(migraphx::make_op("multibroadcast"), min_val, l);
max_val = mm->add_instruction(migraphx::make_op("multibroadcast"), max_val, l);
mm->add_instruction(migraphx::make_op("clip"), l, min_val, max_val);
p.compile(migraphx::make_target("ref"));

migraphx::shape static_shape{migraphx::shape::float_type, {3}};
migraphx::parameter_map params;
std::vector<float> data = {-1.0, 0.0, 10.0};
params["X"] = migraphx::argument(static_shape, data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0, 0.0, 6.0};
EXPECT(migraphx::verify_range(results_vector, gold));
}

TEST_CASE(concat_test)
{
{
Expand Down

0 comments on commit 9dfe1fb

Please sign in to comment.