Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Match "mul + where" as the pointwise operation before softmax in attention fusion #3381

Merged
merged 13 commits into from
Aug 20, 2024
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/composable_kernel@57cdd70b7cb14e5e3b60cd9a5f96ba8dc343763e -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/rocMLIR@8f51edbff77499fbd4d14e38b3efda1d210f6f6e -DBUILD_FAT_LIBROCKCOMPILER=On
ROCm/rocMLIR@26c8d17e70db4690da8db5ea60dab3d271c82c54 -DBUILD_FAT_LIBROCKCOMPILER=On
2 changes: 1 addition & 1 deletion src/include/migraphx/op/pad.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down
13 changes: 13 additions & 0 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,19 @@ struct find_mlir_standalone_attention_op
bias = mm->add_instruction(make_op("add"), scaled_gemm0, bias_param);
inputs.push_back(bias_input);
}
else if(orig_inputs.size() == 5) // gemm1 + mul_where + softmax + gemm2 + trailing_pm case
{
auto select_cond = orig_inputs[2];
auto select_const = orig_inputs[3];
instruction_ref select_cond_param =
mm->add_parameter("y_cond", select_cond->get_shape().as_standard());
instruction_ref select_cond_const =
mm->add_parameter("y_const", select_const->get_shape().as_standard());
bias = mm->add_instruction(
make_op("where"), select_cond_param, scaled_gemm0, select_cond_const);
inputs.push_back(select_cond);
inputs.push_back(select_const);
}

auto softmax = mm->add_instruction(
make_op("softmax", {{"axis", gemm0->get_shape().lens().size() - 1}}),
Expand Down
25 changes: 23 additions & 2 deletions src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ struct gemm_softmax_gemm

// Validates that `s` is an acceptable operand shape for this op.
// A shape is accepted when s.scalar() is true, or when at least one of its
// (up to) three innermost strides equals 1. Any other shape is rejected via
// MIGRAPHX_THROW with the message "Invalid shape for <op name>".
void check_gemm_shape(const shape& s) const
{
    // Consult scalar() first so scalar operands never reach the stride scan below.
    if(s.scalar())
        return;
    const auto& strides = s.strides();
    // Inspect only the last three strides, but never step past rend(): for a
    // shape with fewer than three dimensions `rbegin() + 3` would be out of range.
    const auto last = strides.size() < 3 ? strides.rend() : strides.rbegin() + 3;
    if(not contains(range(strides.rbegin(), last), 1))
        MIGRAPHX_THROW("Invalid shape for " + name());
}

Expand All @@ -58,15 +59,35 @@ struct gemm_softmax_gemm
MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size()));

const bool is_bias_enabled = inputs.size() == 4;
const bool is_mul_where = inputs.size() == 5;
auto a = inputs[0];
auto b = inputs[1];
auto b1 = inputs[is_bias_enabled ? 3 : 2];
auto b1 = inputs.back();

for(const auto& input : inputs)
{
check_gemm_shape(input);
}
auto gemm0_shape = op.compute_shape({a, b});
if(is_mul_where)
{
auto select_cond = inputs[2];
auto select_const = inputs[3];
if(select_cond.lens() != select_const.lens())
{
std::stringstream err_msg;
err_msg << name() << ": has inconsistent where op condition and constant size: "
<< select_cond << "!=" << select_const;
MIGRAPHX_THROW(err_msg.str());
}
if(select_cond.lens() != gemm0_shape.lens())
{
std::stringstream err_msg;
err_msg << name() << ": has inconsistent where op condition size"
<< ". Expected: " << gemm0_shape << ". Given: " << select_cond;
MIGRAPHX_THROW(err_msg.str());
}
}
if(is_bias_enabled)
{
auto bias_shape = inputs[2];
Expand Down
2 changes: 1 addition & 1 deletion src/targets/gpu/jit/pad.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down
13 changes: 11 additions & 2 deletions src/targets/gpu/prefuse_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,15 @@ struct find_gemm_softmax_gemm
.bind("gemm1")));
auto mul = match::name("mul")(
match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
auto where = match::name("where")(match::arg(2)(match::is_constant().bind("select_const")),
match::arg(1)(mul),
match::arg(0)(match::any().bind("select_cond")));
auto add =
match::name("add")(is_bias_supported(),
match::nargs(2),
match::either_arg(0, 1)(match::none_of(mul).bind("bias"), mul));
auto softmax =
match::name("softmax")(match::arg(0)(match::any_of(mul, add, gemm1))).bind("softmax");
auto softmax = match::name("softmax")(match::arg(0)(match::any_of(mul, add, gemm1, where)))
.bind("softmax");

return match::name("dot")(
match::any_of(is_ck_gemm(), is_mlir_gemm(), is_test_gemm(enable_attention))
Expand Down Expand Up @@ -229,10 +232,16 @@ struct find_gemm_softmax_gemm
}

auto inputs = gemm1_ins->inputs(); // A, B
if(contains(r.instructions, "select_cond"))
{
inputs.push_back(r.instructions["select_cond"]);
inputs.push_back(r.instructions["select_const"]);
}
if(contains(r.instructions, "bias"))
{
inputs.push_back(r.instructions["bias"]);
}

inputs.push_back(gemm2_ins->inputs().back()); // B1

mpm.get_module().replace_instruction(
Expand Down
56 changes: 56 additions & 0 deletions test/verify/test_gemm_mul_where_softmax_gemm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

// Verification program for the pattern: dot -> mul (constant scale) -> where -> softmax -> dot.
struct test_gemm_mul_where_softmax_gemm : verify_program<test_gemm_mul_where_softmax_gemm>
{
    // Builds the program whose main module contains the pattern described above.
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        // Half-precision data tensors and a boolean mask with identical dimensions.
        const migraphx::shape data_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
        const migraphx::shape mask_shape{migraphx::shape::bool_type, {1, 12, 256, 256}};
        const auto count = data_shape.elements();

        auto lhs   = main_mod->add_parameter("1", data_shape);
        auto rhs   = main_mod->add_parameter("2", data_shape);
        auto value = main_mod->add_parameter("3", data_shape);
        auto mask  = main_mod->add_parameter("4", mask_shape);

        // Constant operands: the multiplier applied after the first dot, and the
        // value passed as the third `where` argument.
        std::vector<float> scale_data(count, 0.125);
        std::vector<float> fill_data(count, 10);
        auto scale_lit = main_mod->add_literal(migraphx::literal{data_shape, scale_data});
        auto fill_lit  = main_mod->add_literal(migraphx::literal{data_shape, fill_data});

        // Swap the last two axes of the second operand before the first dot.
        auto rhs_t = main_mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), rhs);
        auto scores = main_mod->add_instruction(migraphx::make_op("dot"), lhs, rhs_t);
        auto scaled = main_mod->add_instruction(migraphx::make_op("mul"), scores, scale_lit);
        auto masked =
            main_mod->add_instruction(migraphx::make_op("where"), mask, scaled, fill_lit);
        auto probs =
            main_mod->add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), masked);
        main_mod->add_instruction(migraphx::make_op("dot"), probs, value);

        return prog;
    }

    // Name of the section this verification test reports itself under.
    std::string section() const
    {
        return "gemm";
    }
};
Loading