From f06113976790912b1d1d1224636a67a0c94e1840 Mon Sep 17 00:00:00 2001 From: Shiv Date: Thu, 22 Feb 2024 21:02:17 +0000 Subject: [PATCH 1/8] initial const dot matcher work --- src/simplify_reshapes.cpp | 102 ++++++++++++++++++++++++++++++++ test/simplify_reshapes_test.cpp | 28 ++++++++- 2 files changed, 128 insertions(+), 2 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index b8b0439273c..eb5031b9b1a 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -900,6 +900,38 @@ struct find_scalar_multibroadcast_reshape_or_transpose } }; +// Remove unnecessary preceeding size 1 dims for constants +struct find_const_broadcast +{ + auto matcher() const + { + return match::name("multibroadcast", "broadcast")( + match::arg(0)(match::is_constant()(match::used_once()).bind("constant"))); + } + + void apply(module& m, const match::matcher_result& mr) const + { + auto mbr = mr.result; + auto constant = mr.instructions["constant"]; + + if(constant->get_shape().scalar()) + return; + + auto const_lens = constant->get_shape().lens(); + auto it = std::find_if(const_lens.begin(), const_lens.end(), [](auto i) { return i != 1; }); + auto naxes = std::distance(const_lens.begin(), it); + if(naxes == 0) + return; + + std::vector sq_axes(naxes); + std::iota(sq_axes.begin(), sq_axes.end(), 0); + + auto sq_const = + m.insert_instruction(mbr, make_op("squeeze", {{"axes", sq_axes}}), constant); + m.replace_instruction(mbr, mbr->get_operator(), sq_const); + } +}; + struct find_reshape_reshape_dot { auto matcher() const @@ -949,6 +981,74 @@ struct find_reshape_reshape_dot } }; +struct find_reshape_const_dot +{ + auto matcher() const + { + return match::name("dot")( + match::used_once(), + match::args(match::skip(match::name("convert").bind("convert"))( + match::name("reshape").bind("rsp")), + match::skip_broadcasts(match::is_constant().bind("constant")))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto dot = r.result; + auto rsp = r.instructions["rsp"]; + auto constant = r.instructions["constant"]; + + auto const_lens = constant->get_shape().lens(); + if(const_lens.size() > 2) + return; + + auto rsp_lens = rsp->get_shape().lens(); + auto inp = rsp->inputs().front(); + auto inp_lens = inp->get_shape().lens(); + + // Gemm axis should not be altered by the reshape + if(rsp_lens.back() != inp_lens.back()) + return; + + if(contains(r.instructions, "convert")) + { + auto convert = r.instructions["convert"]; + inp = m.insert_instruction(dot, convert->get_operator(), inp); + rsp = m.insert_instruction(dot, rsp->get_operator(), inp); + } + + std::vector new_const_lens{inp_lens.begin(), inp_lens.end() - 2}; + migraphx::operation new_bc_op; + + auto bc_const = dot->inputs().back(); + auto bc_const_lens = bc_const->get_shape().lens(); + new_const_lens.insert(new_const_lens.end(), bc_const_lens.end() - 2, bc_const_lens.end()); + + // if the orignal weight is one dimensional, look at the original broadcast + // to determine the correct broadcast axis + if(const_lens.size() == 1) + { + auto bc_const_strides = bc_const->get_shape().strides(); + auto it = std::find_if( + bc_const_strides.begin(), bc_const_strides.end(), [&](auto i) { return i != 0; }); + auto orig_bc_axis = std::distance(bc_const_strides.begin(), it); + + auto new_bc_axis = new_const_lens.size() - (bc_const_lens.size() - orig_bc_axis); + new_bc_op = migraphx::make_op("broadcast", + {{"axis", new_bc_axis}, {"out_lens", new_const_lens}}); + } + else + { + new_bc_op = migraphx::make_op("multibroadcast", {{"out_lens", new_const_lens}}); + } + + auto new_bc_const = m.insert_instruction(dot, new_bc_op, constant); + auto new_dot = m.insert_instruction(dot, make_op("dot"), inp, new_bc_const); + m.replace_instruction( + dot, make_op("reshape", {{"dims", dot->get_shape().lens()}}), new_dot); + } +}; + void simplify_reshapes::apply(module& m) const { for(int i = 0; i < depth; i++) @@ -960,6 +1060,7 @@ void simplify_reshapes::apply(module& m) const find_reshaper{}, find_reshape_cont{}, find_transpose{}, + find_const_broadcast{}, find_concat_slice{}, find_concat_transpose{}, find_concat_multibroadcasts{}, @@ -969,6 +1070,7 @@ void simplify_reshapes::apply(module& m) const find_broadcast_transpose{}, find_slice_transpose{}, find_transpose_contiguous_reshaper_unary{}, + find_reshape_const_dot{}, find_reshape_reshape_dot{}, find_scalar_multibroadcast_reshape_or_transpose{}); dead_code_elimination{}.apply(m); diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index abf3928a14e..0a48b3e7062 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -1665,8 +1665,8 @@ TEST_CASE(transpose_contiguous_squeeze_unary) auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}}); auto transpose_ins = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); - auto rsqrt = m2.add_instruction(migraphx::make_op("rsqrt"), transpose_ins); - auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), rsqrt); + auto rsqrt = m2.add_instruction(migraphx::make_op("rsqrt"), transpose_ins); + auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), rsqrt); m2.add_instruction(pass_op{}, sq_ins); } EXPECT(m1 == m2); @@ -2087,4 +2087,28 @@ TEST_CASE(reshape_reshape_dot_gemm_axis) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(const_multibroadcast) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 64, 1}}; + migraphx::module m1; + { + auto a = m1.add_literal(migraphx::generate_literal(s)); + auto mbc = m1.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {32, 64, 64}}}), a); + m1.add_return({mbc}); + }; + run_pass(m1); + + migraphx::module m2; + { + auto a = m2.add_literal(migraphx::generate_literal(s)); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), a); + auto mbc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {32, 64, 64}}}), sq); + m2.add_return({mbc}); + }; + + EXPECT(m1.sort() == m2.sort()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From cdc7977d536ccf5d578a1c20fff177988bceee7e Mon Sep 17 00:00:00 2001 From: Shiv Date: Thu, 22 Feb 2024 16:59:00 -0800 Subject: [PATCH 2/8] rsp_const_dot matcher --- src/simplify_reshapes.cpp | 38 ++++-- test/simplify_reshapes_test.cpp | 227 +++++++++++++++++++++++++++++++- 2 files changed, 246 insertions(+), 19 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index eb5031b9b1a..362d67313ab 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -901,11 +901,11 @@ struct find_scalar_multibroadcast_reshape_or_transpose }; // Remove unnecessary preceeding size 1 dims for constants -struct find_const_broadcast +struct find_const_multibroadcast { auto matcher() const { - return match::name("multibroadcast", "broadcast")( + return match::name("multibroadcast")( match::arg(0)(match::is_constant()(match::used_once()).bind("constant"))); } @@ -987,9 +987,9 @@ struct find_reshape_const_dot { return match::name("dot")( match::used_once(), - match::args(match::skip(match::name("convert").bind("convert"))( - match::name("reshape").bind("rsp")), - match::skip_broadcasts(match::is_constant().bind("constant")))); + match::either_arg(0, 1)(match::skip(match::name("convert").bind("convert"))( + match::name("reshape").bind("rsp")), + match::skip_broadcasts(match::is_constant().bind("constant")))); } void apply(module& m, const match::matcher_result& r) const @@ -1007,7 +1007,9 @@ struct find_reshape_const_dot auto inp_lens = inp->get_shape().lens(); // Gemm axis should not be altered by the reshape - if(rsp_lens.back() != inp_lens.back()) + bool flipped = rsp == dot->inputs().back(); + size_t dot_axis = (flipped) ? -2 : -1; + if(rsp_lens.end()[dot_axis] != inp_lens.end()[dot_axis]) return; if(contains(r.instructions, "convert")) @@ -1018,9 +1020,9 @@ struct find_reshape_const_dot } std::vector new_const_lens{inp_lens.begin(), inp_lens.end() - 2}; - migraphx::operation new_bc_op; + operation new_bc_op; - auto bc_const = dot->inputs().back(); + auto bc_const = (flipped) ? dot->inputs().front() : dot->inputs().back(); auto bc_const_lens = bc_const->get_shape().lens(); new_const_lens.insert(new_const_lens.end(), bc_const_lens.end() - 2, bc_const_lens.end()); @@ -1034,16 +1036,24 @@ struct find_reshape_const_dot auto orig_bc_axis = std::distance(bc_const_strides.begin(), it); auto new_bc_axis = new_const_lens.size() - (bc_const_lens.size() - orig_bc_axis); - new_bc_op = migraphx::make_op("broadcast", - {{"axis", new_bc_axis}, {"out_lens", new_const_lens}}); + new_bc_op = make_op("broadcast", {{"axis", new_bc_axis}, {"out_lens", new_const_lens}}); } else { - new_bc_op = migraphx::make_op("multibroadcast", {{"out_lens", new_const_lens}}); + new_bc_op = make_op("multibroadcast", {{"out_lens", new_const_lens}}); } auto new_bc_const = m.insert_instruction(dot, new_bc_op, constant); - auto new_dot = m.insert_instruction(dot, make_op("dot"), inp, new_bc_const); + + instruction_ref new_dot; + if(flipped) + { + new_dot = m.insert_instruction(dot, make_op("dot"), new_bc_const, inp); + } + else + { + new_dot = m.insert_instruction(dot, make_op("dot"), inp, new_bc_const); + } m.replace_instruction( dot, make_op("reshape", {{"dims", dot->get_shape().lens()}}), new_dot); } @@ -1060,7 +1070,7 @@ void simplify_reshapes::apply(module& m) const find_reshaper{}, find_reshape_cont{}, find_transpose{}, - find_const_broadcast{}, + find_const_multibroadcast{}, find_concat_slice{}, find_concat_transpose{}, find_concat_multibroadcasts{}, @@ -1070,8 +1080,8 @@ void simplify_reshapes::apply(module& m) const find_broadcast_transpose{}, find_slice_transpose{}, find_transpose_contiguous_reshaper_unary{}, - find_reshape_const_dot{}, find_reshape_reshape_dot{}, + find_reshape_const_dot{}, find_scalar_multibroadcast_reshape_or_transpose{}); dead_code_elimination{}.apply(m); } diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 0a48b3e7062..6065f65c035 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -2089,12 +2089,12 @@ TEST_CASE(reshape_reshape_dot_gemm_axis) TEST_CASE(const_multibroadcast) { - migraphx::shape s{migraphx::shape::float_type, {1, 64, 1}}; + migraphx::shape s{migraphx::shape::float_type, {1, 1, 64, 1}}; migraphx::module m1; { auto a = m1.add_literal(migraphx::generate_literal(s)); auto mbc = m1.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {32, 64, 64}}}), a); + migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 64, 64}}}), a); m1.add_return({mbc}); }; run_pass(m1); @@ -2102,12 +2102,229 @@ TEST_CASE(const_multibroadcast) migraphx::module m2; { auto a = m2.add_literal(migraphx::generate_literal(s)); - auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), a); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 1}}}), a); auto mbc = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {32, 64, 64}}}), sq); + migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 64, 64}}}), sq); m2.add_return({mbc}); }; - + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(const_multibroadcast_no_apply) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 1, 64, 1}}; + migraphx::module m1; + { + auto a = m1.add_literal(migraphx::generate_literal(s)); + auto mbc = m1.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 64, 64}}}), a); + m1.add_return({mbc}); + }; + + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(reshape_const_dot) +{ + migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; + migraphx::shape s_w{migraphx::shape::float_type, {32, 32}}; + + migraphx::module m1; + { + auto inp = m1.add_parameter("inp", s_inp); + auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), inp); + auto w = m1.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 32}}}), w); + auto dot = m1.add_instruction(migraphx::make_op("dot"), rsp, w_bc); + m1.add_return({dot}); + }; + run_pass(m1); + + migraphx::module m2; + { + auto inp = m2.add_parameter("inp", s_inp); + auto w = m2.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 32, 32}}}), w); + auto dot = m2.add_instruction(migraphx::make_op("dot"), inp, w_bc); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), dot); + m2.add_return({rsp}); + }; + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(reshape_const_dot_flipped) +{ + migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; + migraphx::shape s_w{migraphx::shape::float_type, {16, 8}}; + + migraphx::module m1; + { + auto inp = m1.add_parameter("inp", s_inp); + auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {16, 8, 32}}}), inp); + auto w = m1.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16, 16, 8}}}), w); + auto dot = m1.add_instruction(migraphx::make_op("dot"), w_bc, rsp); + m1.add_return({dot}); + }; + run_pass(m1); + + migraphx::module m2; + { + auto inp = m2.add_parameter("inp", s_inp); + auto w = m2.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 16, 8}}}), w); + auto dot = m2.add_instruction(migraphx::make_op("dot"), w_bc, inp); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {16, 16, 32}}}), dot); + m2.add_return({rsp}); + }; + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(reshape_const_dot_dot_axis) +{ + migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 4}}; + migraphx::shape s_w{migraphx::shape::float_type, {32, 32}}; + + migraphx::module m1; + { + auto inp = m1.add_parameter("inp", s_inp); + auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 8, 32}}}), inp); + auto w = m1.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 32}}}), w); + auto dot = m1.add_instruction(migraphx::make_op("dot"), rsp, w_bc); + m1.add_return({dot}); + }; + + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(reshape_const_dot_flipped_dot_axis) +{ + migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; + migraphx::shape s_w{migraphx::shape::float_type, {8, 64}}; + + migraphx::module m1; + { + auto inp = m1.add_parameter("inp", s_inp); + auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), inp); + auto w = m1.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 64}}}), w); + auto dot = m1.add_instruction(migraphx::make_op("dot"), w_bc, rsp); + m1.add_return({dot}); + }; + + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(reshape_const_dot_broadcast) +{ + migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; + migraphx::shape s_w{migraphx::shape::float_type, {32}}; + + migraphx::module m1; + { + auto inp = m1.add_parameter("inp", s_inp); + auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), inp); + auto w = m1.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 32, 32}}}), w); + auto dot = m1.add_instruction(migraphx::make_op("dot"), rsp, w_bc); + m1.add_return({dot}); + }; + run_pass(m1); + + migraphx::module m2; + { + auto inp = m2.add_parameter("inp", s_inp); + auto w = m2.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", {2, 8, 32, 32}}}), w); + auto dot = m2.add_instruction(migraphx::make_op("dot"), inp, w_bc); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), dot); + m2.add_return({rsp}); + }; + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(reshape_const_dot_broadcast_2) +{ + migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; + migraphx::shape s_w{migraphx::shape::float_type, {32}}; + + migraphx::module m1; + { + auto inp = m1.add_parameter("inp", s_inp); + auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 32}}}), inp); + auto w = m1.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {32, 32}}}), w); + auto dot = m1.add_instruction(migraphx::make_op("dot"), rsp, w_bc); + m1.add_return({dot}); + }; + run_pass(m1); + + migraphx::module m2; + { + auto inp = m2.add_parameter("inp", s_inp); + auto w = m2.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 3}, {"out_lens", {2, 8, 32, 32}}}), w); + auto dot = m2.add_instruction(migraphx::make_op("dot"), inp, w_bc); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 32}}}), dot); + m2.add_return({rsp}); + }; + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(reshape_const_dot_with_convert) +{ + migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; + migraphx::shape s_w{migraphx::shape::half_type, {32, 32}}; + + migraphx::module m1; + { + auto inp = m1.add_parameter("inp", s_inp); + auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), inp); + auto convert = m1.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), rsp); + auto w = m1.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 32}}}), w); + auto dot = m1.add_instruction(migraphx::make_op("dot"), convert, w_bc); + m1.add_return({dot}); + }; + run_pass(m1); + + migraphx::module m2; + { + auto inp = m2.add_parameter("inp", s_inp); + auto convert = m2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), inp); + auto w = m2.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 32, 32}}}), w); + auto dot = m2.add_instruction(migraphx::make_op("dot"), convert, w_bc); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), dot); + m2.add_return({rsp}); + }; + EXPECT(m1.sort() == m2.sort()); } From f147dfcbab62054e9c6f967ee9e2355d8afc1423 Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 23 Feb 2024 21:50:38 +0000 Subject: [PATCH 3/8] remove handling of convert and fix breaking test case --- src/simplify_reshapes.cpp | 12 ++--------- test/simplify_reshapes_test.cpp | 35 --------------------------------- 2 files changed, 2 insertions(+), 45 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 362d67313ab..4ee6c159639 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -914,7 +914,7 @@ struct find_const_multibroadcast auto mbr = mr.result; auto constant = mr.instructions["constant"]; - if(constant->get_shape().scalar()) + if(constant->get_shape().lens().size() <= 1) return; auto const_lens = constant->get_shape().lens(); @@ -987,8 +987,7 @@ struct find_reshape_const_dot { return match::name("dot")( match::used_once(), - match::either_arg(0, 1)(match::skip(match::name("convert").bind("convert"))( - match::name("reshape").bind("rsp")), + match::either_arg(0, 1)(match::name("reshape").bind("rsp"), match::skip_broadcasts(match::is_constant().bind("constant")))); } @@ -1012,13 +1011,6 @@ struct find_reshape_const_dot if(rsp_lens.end()[dot_axis] != inp_lens.end()[dot_axis]) return; - if(contains(r.instructions, "convert")) - { - auto convert = r.instructions["convert"]; - inp = m.insert_instruction(dot, convert->get_operator(), inp); - rsp = m.insert_instruction(dot, rsp->get_operator(), inp); - } - std::vector new_const_lens{inp_lens.begin(), inp_lens.end() - 2}; operation new_bc_op; diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 6065f65c035..7e0060fbeb6 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -2293,39 +2293,4 @@ TEST_CASE(reshape_const_dot_broadcast_2) EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(reshape_const_dot_with_convert) -{ - migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; - migraphx::shape s_w{migraphx::shape::half_type, {32, 32}}; - - migraphx::module m1; - { - auto inp = m1.add_parameter("inp", s_inp); - auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), inp); - auto convert = m1.add_instruction( - migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), rsp); - auto w = m1.add_literal(migraphx::generate_literal(s_w)); - auto w_bc = - m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 32}}}), w); - auto dot = m1.add_instruction(migraphx::make_op("dot"), convert, w_bc); - m1.add_return({dot}); - }; - run_pass(m1); - - migraphx::module m2; - { - auto inp = m2.add_parameter("inp", s_inp); - auto convert = m2.add_instruction( - migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), inp); - auto w = m2.add_literal(migraphx::generate_literal(s_w)); - auto w_bc = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 32, 32}}}), w); - auto dot = m2.add_instruction(migraphx::make_op("dot"), convert, w_bc); - auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), dot); - m2.add_return({rsp}); - }; - - EXPECT(m1.sort() == m2.sort()); -} - int main(int argc, const char* argv[]) { test::run(argc, argv); } From cf23737e9a5a53caea31ce09bdf36691253c9b2c Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 26 Feb 2024 18:34:22 +0000 Subject: [PATCH 4/8] remove const constraint in rehsape_dot matcher --- src/simplify_reshapes.cpp | 52 +++++++++++++++++---------------- test/simplify_reshapes_test.cpp | 12 ++++---- 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 4ee6c159639..706ebe3d92e 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -981,24 +981,24 @@ struct find_reshape_reshape_dot } }; -struct find_reshape_const_dot +struct find_reshape_dot { auto matcher() const { return match::name("dot")( match::used_once(), match::either_arg(0, 1)(match::name("reshape").bind("rsp"), - match::skip_broadcasts(match::is_constant().bind("constant")))); + match::skip_broadcasts(match::any().bind("other")))); } void apply(module& m, const match::matcher_result& r) const { - auto dot = r.result; - auto rsp = r.instructions["rsp"]; - auto constant = r.instructions["constant"]; + auto dot = r.result; + auto rsp = r.instructions["rsp"]; + auto other = r.instructions["other"]; - auto const_lens = constant->get_shape().lens(); - if(const_lens.size() > 2) + auto other_lens = other->get_shape().lens(); + if(other_lens.size() > 2) return; auto rsp_lens = rsp->get_shape().lens(); @@ -1007,44 +1007,46 @@ struct find_reshape_const_dot // Gemm axis should not be altered by the reshape bool flipped = rsp == dot->inputs().back(); - size_t dot_axis = (flipped) ? -2 : -1; - if(rsp_lens.end()[dot_axis] != inp_lens.end()[dot_axis]) + size_t dot_axis = (flipped) ? 2 : 1; + + if(inp_lens.size() < dot_axis or + rsp_lens[rsp_lens.size() - dot_axis] != inp_lens[inp_lens.size() - dot_axis]) return; - std::vector new_const_lens{inp_lens.begin(), inp_lens.end() - 2}; + std::vector new_other_lens{inp_lens.begin(), inp_lens.end() - 2}; operation new_bc_op; - auto bc_const = (flipped) ? dot->inputs().front() : dot->inputs().back(); - auto bc_const_lens = bc_const->get_shape().lens(); - new_const_lens.insert(new_const_lens.end(), bc_const_lens.end() - 2, bc_const_lens.end()); + auto bc_other = (flipped) ? dot->inputs().front() : dot->inputs().back(); + auto bc_other_lens = bc_other->get_shape().lens(); + new_other_lens.insert(new_other_lens.end(), bc_other_lens.end() - 2, bc_other_lens.end()); - // if the orignal weight is one dimensional, look at the original broadcast + // if the original weight is one dimensional, look at the original broadcast // to determine the correct broadcast axis - if(const_lens.size() == 1) + if(other_lens.size() == 1) { - auto bc_const_strides = bc_const->get_shape().strides(); + auto bc_other_strides = bc_other->get_shape().strides(); auto it = std::find_if( - bc_const_strides.begin(), bc_const_strides.end(), [&](auto i) { return i != 0; }); - auto orig_bc_axis = std::distance(bc_const_strides.begin(), it); + bc_other_strides.begin(), bc_other_strides.end(), [&](auto i) { return i != 0; }); + auto orig_bc_axis = std::distance(bc_other_strides.begin(), it); - auto new_bc_axis = new_const_lens.size() - (bc_const_lens.size() - orig_bc_axis); - new_bc_op = make_op("broadcast", {{"axis", new_bc_axis}, {"out_lens", new_const_lens}}); + auto new_bc_axis = new_other_lens.size() - (bc_other_lens.size() - orig_bc_axis); + new_bc_op = make_op("broadcast", {{"axis", new_bc_axis}, {"out_lens", new_other_lens}}); } else { - new_bc_op = make_op("multibroadcast", {{"out_lens", new_const_lens}}); + new_bc_op = make_op("multibroadcast", {{"out_lens", new_other_lens}}); } - auto new_bc_const = m.insert_instruction(dot, new_bc_op, constant); + auto new_bc_other = m.insert_instruction(dot, new_bc_op, other); instruction_ref new_dot; if(flipped) { - new_dot = m.insert_instruction(dot, make_op("dot"), new_bc_const, inp); + new_dot = m.insert_instruction(dot, make_op("dot"), new_bc_other, inp); } else { - new_dot = m.insert_instruction(dot, make_op("dot"), inp, new_bc_const); + new_dot = m.insert_instruction(dot, make_op("dot"), inp, new_bc_other); } m.replace_instruction( dot, make_op("reshape", {{"dims", dot->get_shape().lens()}}), new_dot); @@ -1073,7 +1075,7 @@ void simplify_reshapes::apply(module& m) const find_slice_transpose{}, find_transpose_contiguous_reshaper_unary{}, find_reshape_reshape_dot{}, - find_reshape_const_dot{}, + find_reshape_dot{}, find_scalar_multibroadcast_reshape_or_transpose{}); dead_code_elimination{}.apply(m); } diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 7e0060fbeb6..8a8b1c0c6a4 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -2127,7 +2127,7 @@ TEST_CASE(const_multibroadcast_no_apply) EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(reshape_const_dot) +TEST_CASE(reshape_dot) { migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; migraphx::shape s_w{migraphx::shape::float_type, {32, 32}}; @@ -2158,7 +2158,7 @@ TEST_CASE(reshape_const_dot) EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(reshape_const_dot_flipped) +TEST_CASE(reshape_dot_flipped) { migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; migraphx::shape s_w{migraphx::shape::float_type, {16, 8}}; @@ -2189,7 +2189,7 @@ TEST_CASE(reshape_const_dot_flipped) EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(reshape_const_dot_dot_axis) +TEST_CASE(reshape_dot_dot_axis) { migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 4}}; migraphx::shape s_w{migraphx::shape::float_type, {32, 32}}; @@ -2210,7 +2210,7 @@ TEST_CASE(reshape_const_dot_dot_axis) EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(reshape_const_dot_flipped_dot_axis) +TEST_CASE(reshape_dot_flipped_dot_axis) { migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; migraphx::shape s_w{migraphx::shape::float_type, {8, 64}}; @@ -2231,7 +2231,7 @@ TEST_CASE(reshape_const_dot_flipped_dot_axis) EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(reshape_const_dot_broadcast) +TEST_CASE(reshape_dot_broadcast) { migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; migraphx::shape s_w{migraphx::shape::float_type, {32}}; @@ -2262,7 +2262,7 @@ TEST_CASE(reshape_const_dot_broadcast) EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(reshape_const_dot_broadcast_2) +TEST_CASE(reshape_dot_broadcast_2) { migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; migraphx::shape s_w{migraphx::shape::float_type, {32}}; From 6819b956c26e3fc16758298336887f0b1923889a Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 26 Feb 2024 18:45:09 +0000 Subject: [PATCH 5/8] remove used once constraint for const_multibroadcast --- src/simplify_reshapes.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 706ebe3d92e..59600874325 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -905,8 +905,7 @@ struct find_const_multibroadcast { auto matcher() const { - return match::name("multibroadcast")( - match::arg(0)(match::is_constant()(match::used_once()).bind("constant"))); + return match::name("multibroadcast")(match::arg(0)(match::is_constant().bind("constant"))); } void apply(module& m, const match::matcher_result& mr) const From d06e4f8f941fcfc3164685f63805faa506ded73b Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 26 Feb 2024 23:05:44 +0000 Subject: [PATCH 6/8] add matcher to move convert before reshapes --- src/simplify_reshapes.cpp | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 59600874325..5ad2aecee09 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -980,6 +980,31 @@ struct find_reshape_reshape_dot } }; +// Move convert before shape manipulation ops so shape manipulations can be matched with +// following ops +struct find_shape_op_convert +{ + auto matcher() const + { + auto shape_ops = reshaper_names(); + shape_ops.insert("transpose"); + shape_ops.insert("broadcast"); + shape_ops.insert("multibroadcast"); + + return match::name("convert")(match::arg(0)(match::name(shape_ops).bind("shape_op"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto convert = r.result; + auto shape_op = r.instructions["shape_op"]; + + auto inp = shape_op->inputs().front(); + auto new_convert = m.insert_instruction(shape_op, convert->get_operator(), inp); + m.replace_instruction(convert, shape_op->get_operator(), new_convert); + } +}; + struct find_reshape_dot { auto matcher() const @@ -1072,6 +1097,7 @@ void simplify_reshapes::apply(module& m) const find_transpose_slice{}, find_broadcast_transpose{}, find_slice_transpose{}, + find_shape_op_convert{}, find_transpose_contiguous_reshaper_unary{}, find_reshape_reshape_dot{}, find_reshape_dot{}, From efa81d9429a163c723191c1f4643afadee9be507 Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 27 Feb 2024 03:39:59 +0000 Subject: [PATCH 7/8] change reshape_convert matcher to only apply when preceeding dot --- src/simplify_reshapes.cpp | 26 ++++++++++------------ test/simplify_reshapes_test.cpp | 38 +++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 15 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 5ad2aecee09..345a616ddf0 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -980,28 +980,24 @@ struct find_reshape_reshape_dot } }; -// Move convert before shape manipulation ops so shape manipulations can be matched with -// following ops -struct find_shape_op_convert +// Move convert before reshape when preceeding a dot op +struct find_reshape_convert_dot { auto matcher() const { - auto shape_ops = reshaper_names(); - shape_ops.insert("transpose"); - shape_ops.insert("broadcast"); - shape_ops.insert("multibroadcast"); - - return match::name("convert")(match::arg(0)(match::name(shape_ops).bind("shape_op"))); + auto dot_output = match::any_of[match::outputs()](match::name("dot")); + return match::name("convert")(match::arg(0)(match::name("reshape").bind("rsp")), + dot_output); } void apply(module& m, const match::matcher_result& r) const { - auto convert = r.result; - auto shape_op = r.instructions["shape_op"]; + auto convert = r.result; + auto rsp = r.instructions["rsp"]; - auto inp = shape_op->inputs().front(); - auto new_convert = m.insert_instruction(shape_op, convert->get_operator(), inp); - m.replace_instruction(convert, shape_op->get_operator(), new_convert); + auto inp = rsp->inputs().front(); + auto new_convert = m.insert_instruction(rsp, convert->get_operator(), inp); + m.replace_instruction(convert, rsp->get_operator(), new_convert); } }; @@ -1097,7 +1093,7 @@ void simplify_reshapes::apply(module& m) const find_transpose_slice{}, find_broadcast_transpose{}, find_slice_transpose{}, - find_shape_op_convert{}, + find_reshape_convert_dot{}, find_transpose_contiguous_reshaper_unary{}, find_reshape_reshape_dot{}, find_reshape_dot{}, diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 8a8b1c0c6a4..2f552a21756 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -2293,4 +2293,42 @@ TEST_CASE(reshape_dot_broadcast_2) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(reshape_convert_dot) +{ + migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}}; + migraphx::shape s_w{migraphx::shape::half_type, {32, 32}}; + + migraphx::module m1; + { + auto inp = m1.add_parameter("inp", s_inp); + auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), inp); + auto convert = m1.add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::half_type)}}), + rsp); + auto w = m1.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = + m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 32}}}), w); + auto dot = m1.add_instruction(migraphx::make_op("dot"), convert, w_bc); + m1.add_return({dot}); + }; + run_pass(m1); + + migraphx::module m2; + { + auto inp = m2.add_parameter("inp", s_inp); + auto convert = m2.add_instruction( + migraphx::make_op("convert", + {{"target_type", migraphx::to_value(migraphx::shape::half_type)}}), + inp); + auto w = m2.add_literal(migraphx::generate_literal(s_w)); + auto w_bc = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 32, 32}}}), w); + auto dot = m2.add_instruction(migraphx::make_op("dot"), convert, w_bc); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), dot); + m2.add_return({rsp}); + }; + + EXPECT(m1.sort() == m2.sort()); +} int main(int argc, const char* argv[]) { test::run(argc, argv); } From 6dc309cf90bac77565bdbfbc3b2848112ab4a030 Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 4 Mar 2024 19:06:06 +0000 Subject: [PATCH 8/8] combine reshape-dot matchers --- src/simplify_reshapes.cpp | 141 +++++++++++++++++--------------------- 1 file changed, 64 insertions(+), 77 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 345a616ddf0..c10182280b2 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -931,55 +931,6 @@ struct find_const_multibroadcast } }; -struct find_reshape_reshape_dot -{ - auto matcher() const - { - return match::name("dot")(match::used_once(), - match::args(match::name("reshape").bind("inp_rsp1"), - match::name("reshape").bind("inp_rsp2"))); - } - - // Gemm axis should not be altered by the reshape - auto is_valid_reshape(instruction_ref in, instruction_ref rsp) const - { - auto in_lens = in->get_shape().lens(); - auto rsp_lens = rsp->get_shape().lens(); - - return std::equal(rsp_lens.end() - 2, rsp_lens.end(), in_lens.end() - 2, in_lens.end()); - } - - // Batch dims should match for both inputs - auto is_valid_inputs(instruction_ref in1, instruction_ref in2) const - { - auto in1_lens = in1->get_shape().lens(); - auto in2_lens = in2->get_shape().lens(); - - return ( - in1_lens.size() == in2_lens.size() and - std::equal(in1_lens.begin(), in1_lens.end() - 2, in2_lens.begin(), in2_lens.end() - 2)); - } - - void apply(module& m, const match::matcher_result& r) const - { - auto dot = r.result; - auto inp_rsp1 = r.instructions["inp_rsp1"]; - auto inp_rsp2 = r.instructions["inp_rsp2"]; - - auto dot_lens = dot->get_shape().lens(); - - auto inp1 = inp_rsp1->inputs().front(); - auto inp2 = inp_rsp2->inputs().front(); - - if(not(is_valid_reshape(inp1, inp_rsp1) and is_valid_reshape(inp2, inp_rsp2) and - is_valid_inputs(inp1, inp2))) - return; - - auto new_dot = m.insert_instruction(dot, dot->get_operator(), inp1, inp2); - m.replace_instruction(dot, make_op("reshape", {{"dims", dot_lens}}), new_dot); - } -}; - // Move convert before reshape when preceeding a dot op struct find_reshape_convert_dot { @@ -1011,16 +962,33 @@ struct find_reshape_dot match::skip_broadcasts(match::any().bind("other")))); } + // Gemm axis should not be altered by the reshape + auto is_valid_reshape(instruction_ref inp, instruction_ref rsp, size_t dot_axis) const + { + auto inp_lens = inp->get_shape().lens(); + auto rsp_lens = rsp->get_shape().lens(); + + return (inp_lens.size() >= dot_axis and + rsp_lens[rsp_lens.size() - dot_axis] == inp_lens[inp_lens.size() - dot_axis]); + } + + // Same batch dims + auto has_same_batch_dims(instruction_ref in1, instruction_ref in2) const + { + auto in1_lens = in1->get_shape().lens(); + auto in2_lens = in2->get_shape().lens(); + + return ( + in1_lens.size() == in2_lens.size() and + std::equal(in1_lens.begin(), in1_lens.end() - 2, in2_lens.begin(), in2_lens.end() - 2)); + } + void apply(module& m, const match::matcher_result& r) const { auto dot = r.result; auto rsp = r.instructions["rsp"]; auto other = r.instructions["other"]; - auto other_lens = other->get_shape().lens(); - if(other_lens.size() > 2) - return; - auto rsp_lens = rsp->get_shape().lens(); auto inp = rsp->inputs().front(); auto inp_lens = inp->get_shape().lens(); @@ -1029,44 +997,64 @@ struct find_reshape_dot bool flipped = rsp == dot->inputs().back(); size_t dot_axis = (flipped) ? 2 : 1; - if(inp_lens.size() < dot_axis or - rsp_lens[rsp_lens.size() - dot_axis] != inp_lens[inp_lens.size() - dot_axis]) + if(not is_valid_reshape(inp, rsp, dot_axis)) return; - std::vector new_other_lens{inp_lens.begin(), inp_lens.end() - 2}; - operation new_bc_op; - - auto bc_other = (flipped) ? dot->inputs().front() : dot->inputs().back(); - auto bc_other_lens = bc_other->get_shape().lens(); - new_other_lens.insert(new_other_lens.end(), bc_other_lens.end() - 2, bc_other_lens.end()); - - // if the original weight is one dimensional, look at the original broadcast - // to determine the correct broadcast axis - if(other_lens.size() == 1) + instruction_ref new_other; + if(other->get_operator().name() == "reshape") { - auto bc_other_strides = bc_other->get_shape().strides(); - auto it = std::find_if( - bc_other_strides.begin(), bc_other_strides.end(), [&](auto i) { return i != 0; }); - auto orig_bc_axis = std::distance(bc_other_strides.begin(), it); + auto other_inp = other->inputs().front(); + size_t other_dot_axis = (flipped) ? 1 : 2; + if(not is_valid_reshape(other_inp, other, other_dot_axis) or + not has_same_batch_dims(inp, other_inp)) + return; - auto new_bc_axis = new_other_lens.size() - (bc_other_lens.size() - orig_bc_axis); - new_bc_op = make_op("broadcast", {{"axis", new_bc_axis}, {"out_lens", new_other_lens}}); + new_other = other_inp; } else { - new_bc_op = make_op("multibroadcast", {{"out_lens", new_other_lens}}); - } + auto other_lens = other->get_shape().lens(); + if(other_lens.size() > 2) + return; + + std::vector new_other_lens{inp_lens.begin(), inp_lens.end() - 2}; + operation new_bc_op; + + auto bc_other = (flipped) ? dot->inputs().front() : dot->inputs().back(); + auto bc_other_lens = bc_other->get_shape().lens(); + new_other_lens.insert( + new_other_lens.end(), bc_other_lens.end() - 2, bc_other_lens.end()); - auto new_bc_other = m.insert_instruction(dot, new_bc_op, other); + // if the original weight is one dimensional, look at the original broadcast + // to determine the correct broadcast axis + if(other_lens.size() == 1) + { + auto bc_other_strides = bc_other->get_shape().strides(); + auto it = std::find_if(bc_other_strides.begin(), + bc_other_strides.end(), + [&](auto i) { return i != 0; }); + auto orig_bc_axis = std::distance(bc_other_strides.begin(), it); + + auto new_bc_axis = new_other_lens.size() - (bc_other_lens.size() - orig_bc_axis); + new_bc_op = + make_op("broadcast", {{"axis", new_bc_axis}, {"out_lens", new_other_lens}}); + } + else + { + new_bc_op = make_op("multibroadcast", {{"out_lens", new_other_lens}}); + } + + new_other = m.insert_instruction(dot, new_bc_op, other); + } instruction_ref new_dot; if(flipped) { - new_dot = m.insert_instruction(dot, make_op("dot"), new_bc_other, inp); + new_dot = m.insert_instruction(dot, make_op("dot"), new_other, inp); } else { - new_dot = m.insert_instruction(dot, make_op("dot"), inp, new_bc_other); + new_dot = m.insert_instruction(dot, make_op("dot"), inp, new_other); } m.replace_instruction( dot, make_op("reshape", {{"dims", dot->get_shape().lens()}}), new_dot); @@ -1095,7 +1083,6 @@ void simplify_reshapes::apply(module& m) const find_slice_transpose{}, find_reshape_convert_dot{}, find_transpose_contiguous_reshaper_unary{}, - find_reshape_reshape_dot{}, find_reshape_dot{}, find_scalar_multibroadcast_reshape_or_transpose{}); dead_code_elimination{}.apply(m);