diff --git a/src/common.cpp b/src/common.cpp index e73fee9c1fa..2a16a11f959 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -157,7 +157,7 @@ insert_common_args(module& m, instruction_ref ins, std::vector ins, make_op("multibroadcast", {{"out_dyn_dims", to_value(c_dyn_dims)}}), inputs); } std::transform(inputs.begin() + 1, inputs.end(), inputs.begin() + 1, [&](auto input) { - // uses previous multibroadcast to avoid recalculating the common shape from the + // uses previous input to avoid recalculating the common shape from the // full set of input shapes at runtime if(input->get_shape().dyn_dims() != c_dyn_dims) { diff --git a/src/include/migraphx/op/clip.hpp b/src/include/migraphx/op/clip.hpp index c6221cb1f30..379797a869a 100644 --- a/src/include/migraphx/op/clip.hpp +++ b/src/include/migraphx/op/clip.hpp @@ -25,12 +25,13 @@ #define MIGRAPHX_GUARD_OPERATORS_CLIP_HPP #include +#include #include #include #include #include #include -#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -48,15 +49,15 @@ struct clip shape compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.has(3).same_type().same_dims(); + check_shapes{inputs, *this, true}.has(3).same_type().same_dims(); return inputs.front(); } - argument compute(const shape& output_shape, std::vector args) const + argument compute(const dyn_output& dyn_out, std::vector args) const { - argument result{output_shape}; + argument result{dyn_out.computed_shape}; visit_all(result, args[0], args[1], args[2])([&](auto output, auto x, auto min, auto max) { - par_for(output_shape.elements(), + par_for(dyn_out.computed_shape.elements(), [&](auto i) { output[i] = std::min(std::max(min[i], x[i]), max[i]); }); }); diff --git a/test/onnx/clip_dyn_min_max_test.onnx b/test/onnx/clip_dyn_min_max_test.onnx new file mode 100644 index 00000000000..758a7086d64 Binary files /dev/null and b/test/onnx/clip_dyn_min_max_test.onnx differ diff --git a/test/onnx/clip_dyn_min_only_test.onnx b/test/onnx/clip_dyn_min_only_test.onnx new file mode 100644 index 00000000000..dae03347048 Binary files /dev/null and b/test/onnx/clip_dyn_min_only_test.onnx differ diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index a224aa1b9df..3f029f45e18 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -743,6 +743,33 @@ def clip_test_args_type_mismatch(): return ([node], [x], [y], [min_val, max_val]) +@onnx_test() +def clip_dyn_min_max_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [None]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [None]) + + min_val = helper.make_tensor('min', TensorProto.FLOAT, [], [0.0]) + max_val = helper.make_tensor('max', TensorProto.FLOAT, [], [6.0]) + + node = onnx.helper.make_node('Clip', + inputs=['0', 'min', 'max'], + outputs=['1']) + + return ([node], [x], [y], [min_val, max_val]) + + +@onnx_test() +def clip_dyn_min_only_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [None]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [None]) + + min_val = helper.make_tensor('min', TensorProto.FLOAT, [], [0.0]) + + node = onnx.helper.make_node('Clip', inputs=['0', 'min'], outputs=['1']) + + return ([node], [x], [y], [min_val]) + + @onnx_test() def concat_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 4, 3]) diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp index b8fad810305..15199959a20 100644 --- a/test/onnx/onnx_test.cpp +++ b/test/onnx/onnx_test.cpp @@ -823,6 +823,47 @@ TEST_CASE(clip_test_args_type_mismatch) EXPECT(p == prog); } +TEST_CASE(clip_dyn_min_max_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto min_val = mm->add_literal(0.0f); + auto max_val = mm->add_literal(6.0f); + std::vector dds = {{2, 8, {3}}}; + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, dds}); + min_val = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_dyn_dims", to_value(dds)}}), min_val, l0); + max_val = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_dyn_dims", to_value(dds)}}), max_val, l0); + auto ret = mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val); + mm->add_return({ret}); + + migraphx::onnx_options options; + options.default_dyn_dim_value = {2, 8, {3}}; + auto prog = parse_onnx("clip_dyn_min_max_test.onnx", options); + + EXPECT(p == prog); +} + +TEST_CASE(clip_dyn_min_only_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto min_val = mm->add_literal(0.0f); + std::vector dds = {{2, 8, {3}}}; + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, dds}); + min_val = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_dyn_dims", to_value(dds)}}), min_val, l0); + auto ret = mm->add_instruction(migraphx::make_op("max"), l0, min_val); + mm->add_return({ret}); + + migraphx::onnx_options options; + options.default_dyn_dim_value = {2, 8, {3}}; + auto prog = parse_onnx("clip_dyn_min_only_test.onnx", options); + + EXPECT(p == prog); +} + TEST_CASE(concat_test) { migraphx::program p; diff --git a/test/ref_ops_test.cpp b/test/ref_ops_test.cpp index a1bb7b050d9..8b67b301f4d 100644 --- a/test/ref_ops_test.cpp +++ b/test/ref_ops_test.cpp @@ -850,6 +850,31 @@ TEST_CASE(clip_test) EXPECT(migraphx::verify_range(results_vector, gold)); } +TEST_CASE(clip_dyn_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector dds = {{2, 8, {3}}}; + migraphx::shape s{migraphx::shape::float_type, dds}; + auto l = mm->add_parameter("X", s); + auto min_val = mm->add_literal(0.0f); + auto max_val = mm->add_literal(6.0f); + min_val = mm->add_instruction(migraphx::make_op("multibroadcast"), min_val, l); + max_val = mm->add_instruction(migraphx::make_op("multibroadcast"), max_val, l); + mm->add_instruction(migraphx::make_op("clip"), l, min_val, max_val); + p.compile(migraphx::make_target("ref")); + + migraphx::shape static_shape{migraphx::shape::float_type, {3}}; + migraphx::parameter_map params; + std::vector data = {-1.0, 0.0, 10.0}; + params["X"] = migraphx::argument(static_shape, data.data()); + auto result = p.eval(params).back(); + std::vector results_vector(3); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {0.0, 0.0, 6.0}; + EXPECT(migraphx::verify_range(results_vector, gold)); +} + TEST_CASE(concat_test) { {