Skip to content

Commit

Permalink
find_splits bugfix (#3244)
Browse files Browse the repository at this point in the history
  • Loading branch information
shivadbhavsar committed Jul 9, 2024
1 parent 937048b commit fbb3de7
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 4 deletions.
61 changes: 57 additions & 4 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1095,6 +1095,47 @@ struct find_splits
return true;
}

int get_binary_op_split_idx(std::vector<instruction_ref> group,
std::vector<instruction_ref> splits) const
{
auto first_group_inputs = group.front()->inputs();
auto arg_it =
std::find_if(first_group_inputs.begin(), first_group_inputs.end(), [&](auto i) {
return std::find(splits.begin(), splits.end(), i) != splits.end();
});
auto split_idx = arg_it - first_group_inputs.begin();

// All splits are at the same input index
if(std::all_of(group.begin() + 1, group.end(), [&](auto i) {
auto split_idx_input = i->inputs().at(split_idx);
return std::find(splits.begin(), splits.end(), split_idx_input) != splits.end();
}))
return split_idx;

return -1;
}

void align_commutative_op_args(module& m,
std::vector<instruction_ref> group,
std::vector<instruction_ref> splits,
size_t split_idx) const
{
auto group_op = group.front()->get_operator();
assert(std::all_of(
group.begin(), group.end(), [&](auto i) { return i->get_operator() == group_op; }));

for(auto i : group)
{
if(std::find(splits.begin(), splits.end(), i->inputs().at(split_idx)) == splits.end())
{
auto args = i->inputs();
assert(args.size() == 2);
std::reverse(args.begin(), args.end());
m.replace_instruction(i, i->get_operator(), args);
}
}
}

void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
Expand Down Expand Up @@ -1127,11 +1168,23 @@ struct find_splits
assert(not std::none_of(start->inputs().begin(), start->inputs().end(), [](auto i) {
return i->name() == "slice";
}) && "one argument must be a split");
auto data_idx = 1;
if(start->inputs().back()->name() == "slice")

split_idx = get_binary_op_split_idx(group, splits);
assert(split_idx < 2);
size_t data_idx;
if(split_idx < 0 and op.attributes().contains("commutative"))
{
split_idx = 0;
data_idx = 1;
align_commutative_op_args(m, group, splits, split_idx);
}
else if(split_idx < 0)
{
return;
}
else
{
split_idx = 1;
data_idx = 0;
data_idx = split_idx == 0 ? 1 : 0;
}

std::vector<instruction_ref> data_args;
Expand Down
91 changes: 91 additions & 0 deletions test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2023,6 +2023,97 @@ TEST_CASE(simplify_split_add_relu)
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(simplify_split_add_flipped_input)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto one = m1.add_literal(1);
auto oneb = m1.add_instruction(b, one);
auto two = m1.add_literal(2);
auto twob = m1.add_instruction(b, two);
auto sum1 = m1.add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = m1.add_instruction(migraphx::make_op("add"), twob, y);
auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2);
auto add = m1.add_instruction(migraphx::make_op("add"), relu1, relu2);
m1.add_instruction(pass_op{}, add);
}
run_pass(m1);

migraphx::module m2;
{
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1);
auto two = m2.add_literal(2);
auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), one, two);
auto concatb = m2.add_instruction(b, concat);
auto sum = m2.add_instruction(migraphx::make_op("add"), input, concatb);
auto relu = m2.add_instruction(migraphx::make_op("relu"), sum);
auto x = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu);
auto y = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu);
auto add = m2.add_instruction(migraphx::make_op("add"), x, y);
m2.add_instruction(pass_op{}, add);
}
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(simplify_split_non_comm_flipped_input)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto one = m1.add_literal(1);
auto oneb = m1.add_instruction(b, one);
auto two = m1.add_literal(2);
auto twob = m1.add_instruction(b, two);
auto sum1 = m1.add_instruction(migraphx::make_op("sub"), x, oneb);
auto relu1 = m1.add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = m1.add_instruction(migraphx::make_op("sub"), twob, y);
auto relu2 = m1.add_instruction(migraphx::make_op("relu"), sum2);
auto add = m1.add_instruction(migraphx::make_op("sub"), relu1, relu2);
m1.add_instruction(pass_op{}, add);
}
run_pass(m1);

migraphx::module m2;
{
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m2.add_parameter("input", s);
auto x = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto y = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto one = m2.add_literal(1);
auto neg_one = m2.add_instruction(migraphx::make_op("neg"), one);
auto oneb = m2.add_instruction(b, neg_one);
auto two = m2.add_literal(2);
auto twob = m2.add_instruction(b, two);
auto sum1 = m2.add_instruction(migraphx::make_op("add"), x, oneb);
auto relu1 = m2.add_instruction(migraphx::make_op("relu"), sum1);
auto sum2 = m2.add_instruction(migraphx::make_op("sub"), twob, y);
auto relu2 = m2.add_instruction(migraphx::make_op("relu"), sum2);
auto add = m2.add_instruction(migraphx::make_op("sub"), relu1, relu2);
m2.add_instruction(pass_op{}, add);
}
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(simplify_split_reduce0)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
Expand Down

0 comments on commit fbb3de7

Please sign in to comment.