Skip to content

Commit

Permalink
Handle broadcast operator for inputs to concat as well (#3450)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 committed Sep 20, 2024
1 parent 32a84a8 commit 8b992d7
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 15 deletions.
44 changes: 29 additions & 15 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,8 @@ struct find_concat_multibroadcasts
{
auto matcher() const
{
return match::name("concat")(match::all_of[match::inputs()](match::name("multibroadcast")));
return match::name("concat")(
match::all_of[match::inputs()](match::name("multibroadcast", "broadcast")));
}

void apply(module& m, const match::matcher_result& mr) const
Expand All @@ -287,32 +288,46 @@ struct find_concat_multibroadcasts
return;
}

// Skip if the broadcasts are different
auto broadcast = concat_inputs.front()->get_operator();
auto broadcast_value = broadcast.to_value();
if(not std::all_of(concat_inputs.begin() + 1, concat_inputs.end(), [&](instruction_ref b) {
if(b->name() != broadcast.name())
return false;
if(broadcast.name() == "broadcast")
return b->get_operator().to_value()["axis"] == broadcast_value["axis"];
return true;
}))
{
return;
}

// Get the inputs of multibroadcast ops. Will be used as inputs to new concat op
std::vector<instruction_ref> mb_inputs(concat_inputs.size());
std::transform(concat_inputs.begin(), concat_inputs.end(), mb_inputs.begin(), [](auto i) {
std::vector<instruction_ref> inputs(concat_inputs.size());
std::transform(concat_inputs.begin(), concat_inputs.end(), inputs.begin(), [](auto i) {
return i->inputs().front();
});

// Check that the inputs into the multibroadcasts have the same rank
const auto& first_shape = mb_inputs.front()->get_shape();
if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto mb_in) {
return mb_in->get_shape().ndim() == first_shape.ndim();
// Check that the inputs into the broadcasts have the same rank
const auto& first_shape = inputs.front()->get_shape();
if(not std::all_of(inputs.begin() + 1, inputs.end(), [&](auto input) {
return input->get_shape().ndim() == first_shape.ndim();
}))
{
return;
}

// Reduce axis by number of leading broadcasted dimensions
if(mb_inputs.front()->get_shape().lens().size() < concat_out_lens.size())
if(inputs.front()->get_shape().lens().size() < concat_out_lens.size())
{
concat_op.axis -=
std::count(front_mb_strides.begin(), front_mb_strides.begin() + concat_op.axis, 0);
}

// Inputs to multibroadcasts should have the same dimensions except for the axis to
// Inputs to broadcasts should have the same dimensions except for the axis to
// concatenate over
const auto& front_in_lens = mb_inputs.front()->get_shape().lens();
if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto input_to_mb) {
const auto& front_in_lens = inputs.front()->get_shape().lens();
if(not std::all_of(inputs.begin() + 1, inputs.end(), [&](auto input_to_mb) {
const auto& lens = input_to_mb->get_shape().lens();
return std::equal(
lens.begin(), lens.begin() + concat_op.axis, front_in_lens.begin()) and
Expand All @@ -324,10 +339,9 @@ struct find_concat_multibroadcasts
return;
}

auto new_concat_ins = m.insert_instruction(concat_ins, concat_op, mb_inputs);
m.replace_instruction(concat_ins,
migraphx::make_op("multibroadcast", {{"out_lens", concat_out_lens}}),
new_concat_ins);
auto new_concat_ins = m.insert_instruction(concat_ins, concat_op, inputs);
broadcast.from_value({{"out_lens", concat_ins->get_shape().lens()}});
m.replace_instruction(concat_ins, broadcast, new_concat_ins);
}
};

Expand Down
27 changes: 27 additions & 0 deletions test/simplify_reshapes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,33 @@ TEST_CASE(concat_multibroadcasts9)
EXPECT(m == m_original);
}

TEST_CASE(concat_broadcast1)
{
auto s = migraphx::shape{migraphx::shape::float_type, {1024, 1024}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto xb = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {8, 1024, 1024}}}), x);
auto yb = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {8, 1024, 1024}}}), y);
auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), xb, yb);
m1.add_return({concat});
}
migraphx::module m2;
{
auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s);
auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y);
auto b = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {8, 1024, 2048}}}), concat);
m2.add_return({b});
}
run_pass(m1);
EXPECT(m1 == m2);
}

TEST_CASE(concat_transpose1)
{
migraphx::module m;
Expand Down

0 comments on commit 8b992d7

Please sign in to comment.