Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent collapsing batch dims in dot ops with constants #2823

Merged
merged 13 commits into from
May 31, 2024
Merged
154 changes: 134 additions & 20 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -900,26 +900,80 @@
}
};

struct find_reshape_reshape_dot
// Remove unnecessary preceeding size 1 dims for constants
struct find_const_multibroadcast
{
auto matcher() const
{
return match::name("dot")(match::used_once(),
match::args(match::name("reshape").bind("inp_rsp1"),
match::name("reshape").bind("inp_rsp2")));
return match::name("multibroadcast")(match::arg(0)(match::is_constant().bind("constant")));
}

void apply(module& m, const match::matcher_result& mr) const
{
auto mbr = mr.result;
auto constant = mr.instructions["constant"];

if(constant->get_shape().lens().size() <= 1)
return;

auto const_lens = constant->get_shape().lens();
auto it = std::find_if(const_lens.begin(), const_lens.end(), [](auto i) { return i != 1; });
auto naxes = std::distance(const_lens.begin(), it);
if(naxes == 0)
return;

std::vector<std::size_t> sq_axes(naxes);
std::iota(sq_axes.begin(), sq_axes.end(), 0);

auto sq_const =
m.insert_instruction(mbr, make_op("squeeze", {{"axes", sq_axes}}), constant);
m.replace_instruction(mbr, mbr->get_operator(), sq_const);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't we replace it with broadcast instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just removing any unnecessary preceding dims in literals eg. {1, 1, 640, 640) which are later broadcasted to something like {2, 32, 640, 640}. Would broadcast work for this? I thought it only does 1 axis

}
};
shivadbhavsar marked this conversation as resolved.
Show resolved Hide resolved

// Move convert before reshape when preceeding a dot op
struct find_reshape_convert_dot
{
auto matcher() const
{
auto dot_output = match::any_of[match::outputs()](match::name("dot"));
return match::name("convert")(match::arg(0)(match::name("reshape").bind("rsp")),
dot_output);
}

void apply(module& m, const match::matcher_result& r) const
{
auto convert = r.result;
auto rsp = r.instructions["rsp"];

auto inp = rsp->inputs().front();
auto new_convert = m.insert_instruction(rsp, convert->get_operator(), inp);
m.replace_instruction(convert, rsp->get_operator(), new_convert);
}
};
shivadbhavsar marked this conversation as resolved.
Show resolved Hide resolved

struct find_reshape_dot
{
auto matcher() const
{
return match::name("dot")(
match::used_once(),
match::either_arg(0, 1)(match::name("reshape").bind("rsp"),
match::skip_broadcasts(match::any().bind("other"))));
}

// Gemm axis should not be altered by the reshape
auto is_valid_reshape(instruction_ref in, instruction_ref rsp) const
auto is_valid_reshape(instruction_ref inp, instruction_ref rsp, size_t dot_axis) const
{
auto in_lens = in->get_shape().lens();
auto inp_lens = inp->get_shape().lens();
auto rsp_lens = rsp->get_shape().lens();

return std::equal(rsp_lens.end() - 2, rsp_lens.end(), in_lens.end() - 2, in_lens.end());
return (inp_lens.size() >= dot_axis and
rsp_lens[rsp_lens.size() - dot_axis] == inp_lens[inp_lens.size() - dot_axis]);
}

// Batch dims should match for both inputs
auto is_valid_inputs(instruction_ref in1, instruction_ref in2) const
// Same batch dims
auto has_same_batch_dims(instruction_ref in1, instruction_ref in2) const
{
auto in1_lens = in1->get_shape().lens();
auto in2_lens = in2->get_shape().lens();
Expand All @@ -931,21 +985,79 @@

void apply(module& m, const match::matcher_result& r) const
{
auto dot = r.result;
auto inp_rsp1 = r.instructions["inp_rsp1"];
auto inp_rsp2 = r.instructions["inp_rsp2"];
auto dot = r.result;
auto rsp = r.instructions["rsp"];
auto other = r.instructions["other"];

auto dot_lens = dot->get_shape().lens();
auto rsp_lens = rsp->get_shape().lens();
auto inp = rsp->inputs().front();
auto inp_lens = inp->get_shape().lens();

auto inp1 = inp_rsp1->inputs().front();
auto inp2 = inp_rsp2->inputs().front();
// Gemm axis should not be altered by the reshape
bool flipped = rsp == dot->inputs().back();
size_t dot_axis = (flipped) ? 2 : 1;

if(not(is_valid_reshape(inp1, inp_rsp1) and is_valid_reshape(inp2, inp_rsp2) and
is_valid_inputs(inp1, inp2)))
if(not is_valid_reshape(inp, rsp, dot_axis))
return;

auto new_dot = m.insert_instruction(dot, dot->get_operator(), inp1, inp2);
m.replace_instruction(dot, make_op("reshape", {{"dims", dot_lens}}), new_dot);
instruction_ref new_other;
if(other->get_operator().name() == "reshape")
{
auto other_inp = other->inputs().front();
size_t other_dot_axis = (flipped) ? 1 : 2;
if(not is_valid_reshape(other_inp, other, other_dot_axis) or
not has_same_batch_dims(inp, other_inp))

Check warning on line 1009 in src/simplify_reshapes.cpp

View check run for this annotation

Codecov / codecov/patch

src/simplify_reshapes.cpp#L1009

Added line #L1009 was not covered by tests
return;

new_other = other_inp;
}
else
{
auto other_lens = other->get_shape().lens();
if(other_lens.size() > 2)
return;

std::vector<size_t> new_other_lens{inp_lens.begin(), inp_lens.end() - 2};
operation new_bc_op;

auto bc_other = (flipped) ? dot->inputs().front() : dot->inputs().back();
auto bc_other_lens = bc_other->get_shape().lens();
new_other_lens.insert(
new_other_lens.end(), bc_other_lens.end() - 2, bc_other_lens.end());

// if the original weight is one dimensional, look at the original broadcast
// to determine the correct broadcast axis
if(other_lens.size() == 1)
{
auto bc_other_strides = bc_other->get_shape().strides();
auto it = std::find_if(bc_other_strides.begin(),
bc_other_strides.end(),
[&](auto i) { return i != 0; });
auto orig_bc_axis = std::distance(bc_other_strides.begin(), it);

auto new_bc_axis = new_other_lens.size() - (bc_other_lens.size() - orig_bc_axis);
new_bc_op =
make_op("broadcast", {{"axis", new_bc_axis}, {"out_lens", new_other_lens}});
}
else
{
new_bc_op = make_op("multibroadcast", {{"out_lens", new_other_lens}});
}

new_other = m.insert_instruction(dot, new_bc_op, other);
}

instruction_ref new_dot;
if(flipped)
{
new_dot = m.insert_instruction(dot, make_op("dot"), new_other, inp);
}
else
{
new_dot = m.insert_instruction(dot, make_op("dot"), inp, new_other);
}
m.replace_instruction(
dot, make_op("reshape", {{"dims", dot->get_shape().lens()}}), new_dot);
}
};

Expand All @@ -960,6 +1072,7 @@
find_reshaper{},
find_reshape_cont{},
find_transpose{},
find_const_multibroadcast{},
find_concat_slice{},
find_concat_transpose{},
find_concat_multibroadcasts{},
Expand All @@ -968,8 +1081,9 @@
find_transpose_slice{},
find_broadcast_transpose{},
find_slice_transpose{},
find_reshape_convert_dot{},
find_transpose_contiguous_reshaper_unary{},
find_reshape_reshape_dot{},
find_reshape_dot{},
find_scalar_multibroadcast_reshape_or_transpose{});
dead_code_elimination{}.apply(m);
}
Expand Down
Loading
Loading