Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent collapsing batch dims in dot ops with constants #2823

Merged
merged 13 commits into from
May 31, 2024
Merged
100 changes: 80 additions & 20 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1012,26 +1012,28 @@
}
};

struct find_reshape_reshape_dot
struct find_reshape_dot
{
auto matcher() const
{
return match::name("dot")(match::used_once(),
match::args(match::name("reshape").bind("inp_rsp1"),
match::name("reshape").bind("inp_rsp2")));
return match::name("dot")(
match::used_once(),
match::either_arg(0, 1)(match::name("reshape").bind("rsp"),
match::skip_broadcasts(match::any().bind("other"))));
}

// Gemm axis should not be altered by the reshape
auto is_valid_reshape(instruction_ref in, instruction_ref rsp) const
auto is_valid_reshape(instruction_ref inp, instruction_ref rsp, size_t dot_axis) const
{
auto in_lens = in->get_shape().lens();
auto inp_lens = inp->get_shape().lens();
auto rsp_lens = rsp->get_shape().lens();

return std::equal(rsp_lens.end() - 2, rsp_lens.end(), in_lens.end() - 2, in_lens.end());
return (inp_lens.size() >= dot_axis and
rsp_lens[rsp_lens.size() - dot_axis] == inp_lens[inp_lens.size() - dot_axis]);
}

// Batch dims should match for both inputs
auto is_valid_inputs(instruction_ref in1, instruction_ref in2) const
// Same batch dims
auto has_same_batch_dims(instruction_ref in1, instruction_ref in2) const
{
auto in1_lens = in1->get_shape().lens();
auto in2_lens = in2->get_shape().lens();
Expand All @@ -1043,21 +1045,79 @@

void apply(module& m, const match::matcher_result& r) const
{
auto dot = r.result;
auto inp_rsp1 = r.instructions["inp_rsp1"];
auto inp_rsp2 = r.instructions["inp_rsp2"];
auto dot = r.result;
auto rsp = r.instructions["rsp"];
auto other = r.instructions["other"];

auto dot_lens = dot->get_shape().lens();
auto rsp_lens = rsp->get_shape().lens();
auto inp = rsp->inputs().front();
auto inp_lens = inp->get_shape().lens();

auto inp1 = inp_rsp1->inputs().front();
auto inp2 = inp_rsp2->inputs().front();
// Gemm axis should not be altered by the reshape
bool flipped = rsp == dot->inputs().back();
size_t dot_axis = (flipped) ? 2 : 1;

if(not(is_valid_reshape(inp1, inp_rsp1) and is_valid_reshape(inp2, inp_rsp2) and
is_valid_inputs(inp1, inp2)))
if(not is_valid_reshape(inp, rsp, dot_axis))
return;

auto new_dot = m.insert_instruction(dot, dot->get_operator(), inp1, inp2);
m.replace_instruction(dot, make_op("reshape", {{"dims", dot_lens}}), new_dot);
instruction_ref new_other;
if(other->get_operator().name() == "reshape")
{
auto other_inp = other->inputs().front();
size_t other_dot_axis = (flipped) ? 1 : 2;
if(not is_valid_reshape(other_inp, other, other_dot_axis) or
not has_same_batch_dims(inp, other_inp))
return;

Check warning on line 1070 in src/simplify_reshapes.cpp

View check run for this annotation

Codecov / codecov/patch

src/simplify_reshapes.cpp#L1070

Added line #L1070 was not covered by tests

new_other = other_inp;
}
else
{
auto other_lens = other->get_shape().lens();
if(other_lens.size() > 2)
return;

std::vector<size_t> new_other_lens{inp_lens.begin(), inp_lens.end() - 2};
operation new_bc_op;

auto bc_other = (flipped) ? dot->inputs().front() : dot->inputs().back();
auto bc_other_lens = bc_other->get_shape().lens();
new_other_lens.insert(
new_other_lens.end(), bc_other_lens.end() - 2, bc_other_lens.end());

// if the original weight is one dimensional, look at the original broadcast
// to determine the correct broadcast axis
if(other_lens.size() == 1)
{
auto bc_other_strides = bc_other->get_shape().strides();
auto it = std::find_if(bc_other_strides.begin(),
bc_other_strides.end(),
[&](auto i) { return i != 0; });
auto orig_bc_axis = std::distance(bc_other_strides.begin(), it);

auto new_bc_axis = new_other_lens.size() - (bc_other_lens.size() - orig_bc_axis);
new_bc_op =
make_op("broadcast", {{"axis", new_bc_axis}, {"out_lens", new_other_lens}});
}
else
{
new_bc_op = make_op("multibroadcast", {{"out_lens", new_other_lens}});
}

new_other = m.insert_instruction(dot, new_bc_op, other);
}

instruction_ref new_dot;
if(flipped)
{
new_dot = m.insert_instruction(dot, make_op("dot"), new_other, inp);
}
else
{
new_dot = m.insert_instruction(dot, make_op("dot"), inp, new_other);
}
m.replace_instruction(
dot, make_op("reshape", {{"dims", dot->get_shape().lens()}}), new_dot);
}
};

Expand All @@ -1081,7 +1141,7 @@
find_broadcast_transpose{},
find_slice_transpose{},
find_unary_shape_transforms{},
find_reshape_reshape_dot{},
find_reshape_dot{},
find_scalar_multibroadcast_reshape_or_transpose{});
dead_code_elimination{}.apply(m);
}
Expand Down
166 changes: 166 additions & 0 deletions test/simplify_reshapes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2266,4 +2266,170 @@ TEST_CASE(reshape_reshape_dot_gemm_axis)
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(reshape_dot)
{
migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}};
migraphx::shape s_w{migraphx::shape::float_type, {32, 32}};

migraphx::module m1;
{
auto inp = m1.add_parameter("inp", s_inp);
auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), inp);
auto w = m1.add_literal(migraphx::generate_literal(s_w));
auto w_bc =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 32}}}), w);
auto dot = m1.add_instruction(migraphx::make_op("dot"), rsp, w_bc);
m1.add_return({dot});
};
run_pass(m1);

migraphx::module m2;
{
auto inp = m2.add_parameter("inp", s_inp);
auto w = m2.add_literal(migraphx::generate_literal(s_w));
auto w_bc = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 32, 32}}}), w);
auto dot = m2.add_instruction(migraphx::make_op("dot"), inp, w_bc);
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), dot);
m2.add_return({rsp});
};

EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(reshape_dot_flipped)
{
migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}};
migraphx::shape s_w{migraphx::shape::float_type, {16, 8}};

migraphx::module m1;
{
auto inp = m1.add_parameter("inp", s_inp);
auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {16, 8, 32}}}), inp);
auto w = m1.add_literal(migraphx::generate_literal(s_w));
auto w_bc =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16, 16, 8}}}), w);
auto dot = m1.add_instruction(migraphx::make_op("dot"), w_bc, rsp);
m1.add_return({dot});
};
run_pass(m1);

migraphx::module m2;
{
auto inp = m2.add_parameter("inp", s_inp);
auto w = m2.add_literal(migraphx::generate_literal(s_w));
auto w_bc = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 16, 8}}}), w);
auto dot = m2.add_instruction(migraphx::make_op("dot"), w_bc, inp);
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {16, 16, 32}}}), dot);
m2.add_return({rsp});
};

EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(reshape_dot_dot_axis)
{
migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 4}};
migraphx::shape s_w{migraphx::shape::float_type, {32, 32}};

migraphx::module m1;
{
auto inp = m1.add_parameter("inp", s_inp);
auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 8, 32}}}), inp);
auto w = m1.add_literal(migraphx::generate_literal(s_w));
auto w_bc =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 32}}}), w);
auto dot = m1.add_instruction(migraphx::make_op("dot"), rsp, w_bc);
m1.add_return({dot});
};

migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(reshape_dot_flipped_dot_axis)
{
migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}};
migraphx::shape s_w{migraphx::shape::float_type, {8, 64}};

migraphx::module m1;
{
auto inp = m1.add_parameter("inp", s_inp);
auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), inp);
auto w = m1.add_literal(migraphx::generate_literal(s_w));
auto w_bc =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 64}}}), w);
auto dot = m1.add_instruction(migraphx::make_op("dot"), w_bc, rsp);
m1.add_return({dot});
};

migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(reshape_dot_broadcast)
{
migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}};
migraphx::shape s_w{migraphx::shape::float_type, {32}};

migraphx::module m1;
{
auto inp = m1.add_parameter("inp", s_inp);
auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), inp);
auto w = m1.add_literal(migraphx::generate_literal(s_w));
auto w_bc = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 32, 32}}}), w);
auto dot = m1.add_instruction(migraphx::make_op("dot"), rsp, w_bc);
m1.add_return({dot});
};
run_pass(m1);

migraphx::module m2;
{
auto inp = m2.add_parameter("inp", s_inp);
auto w = m2.add_literal(migraphx::generate_literal(s_w));
auto w_bc = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", {2, 8, 32, 32}}}), w);
auto dot = m2.add_instruction(migraphx::make_op("dot"), inp, w_bc);
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 64, 32}}}), dot);
m2.add_return({rsp});
};

EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(reshape_dot_broadcast_2)
{
migraphx::shape s_inp{migraphx::shape::float_type, {2, 8, 8, 32}};
migraphx::shape s_w{migraphx::shape::float_type, {32}};

migraphx::module m1;
{
auto inp = m1.add_parameter("inp", s_inp);
auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 32}}}), inp);
auto w = m1.add_literal(migraphx::generate_literal(s_w));
auto w_bc = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {32, 32}}}), w);
auto dot = m1.add_instruction(migraphx::make_op("dot"), rsp, w_bc);
m1.add_return({dot});
};
run_pass(m1);

migraphx::module m2;
{
auto inp = m2.add_parameter("inp", s_inp);
auto w = m2.add_literal(migraphx::generate_literal(s_w));
auto w_bc = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 3}, {"out_lens", {2, 8, 32, 32}}}), w);
auto dot = m2.add_instruction(migraphx::make_op("dot"), inp, w_bc);
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 32}}}), dot);
m2.add_return({rsp});
};

EXPECT(m1.sort() == m2.sort());
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
Loading