diff --git a/requirements.txt b/requirements.txt index d0555dba6c0..174a745cc6c 100755 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/composable_kernel@57cdd70b7cb14e5e3b60cd9a5f96ba8dc343763e -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/rocMLIR@8f51edbff77499fbd4d14e38b3efda1d210f6f6e -DBUILD_FAT_LIBROCKCOMPILER=On +ROCm/rocMLIR@26c8d17e70db4690da8db5ea60dab3d271c82c54 -DBUILD_FAT_LIBROCKCOMPILER=On diff --git a/src/include/migraphx/op/pad.hpp b/src/include/migraphx/op/pad.hpp index 7fad02b8375..478299624f9 100644 --- a/src/include/migraphx/op/pad.hpp +++ b/src/include/migraphx/op/pad.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 06ec1f1dec5..c69d9f10de8 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -592,6 +592,19 @@ struct find_mlir_standalone_attention_op bias = mm->add_instruction(make_op("add"), scaled_gemm0, bias_param); inputs.push_back(bias_input); } + else if(orig_inputs.size() == 5) // gemm1 + mul_where + softmax + gemm2 + trailing_pm case + { + auto select_cond = orig_inputs[2]; + auto select_const = orig_inputs[3]; + instruction_ref select_cond_param = + mm->add_parameter("y_cond", select_cond->get_shape().as_standard()); + instruction_ref select_cond_const = + mm->add_parameter("y_const", select_const->get_shape().as_standard()); + bias = mm->add_instruction( + make_op("where"), select_cond_param, scaled_gemm0, select_cond_const); + inputs.push_back(select_cond); + inputs.push_back(select_const); + } auto softmax = mm->add_instruction( make_op("softmax", {{"axis", gemm0->get_shape().lens().size() - 1}}), diff --git a/src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp b/src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp index ec8561e779c..6a63bde373d 100644 --- a/src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp +++ b/src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp @@ -47,7 +47,8 @@ struct gemm_softmax_gemm void check_gemm_shape(const shape& s) const { - if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1)) + if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1) and + not s.scalar()) MIGRAPHX_THROW("Invalid shape for " + name()); } @@ -58,15 +59,35 @@ struct gemm_softmax_gemm MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size())); const bool is_bias_enabled = inputs.size() == 4; + const bool is_mul_where = inputs.size() == 5; auto a = inputs[0]; auto b = inputs[1]; - auto b1 = inputs[is_bias_enabled ? 3 : 2]; + auto b1 = inputs.back(); for(const auto& input : inputs) { check_gemm_shape(input); } auto gemm0_shape = op.compute_shape({a, b}); + if(is_mul_where) + { + auto select_cond = inputs[2]; + auto select_const = inputs[3]; + if(select_cond.lens() != select_const.lens()) + { + std::stringstream err_msg; + err_msg << name() << ": has inconsistent where op condition and constant size: " + << select_cond << "!=" << select_const; + MIGRAPHX_THROW(err_msg.str()); + } + if(select_cond.lens() != gemm0_shape.lens()) + { + std::stringstream err_msg; + err_msg << name() << ": has inconsistent where op condition size" + << ". Expected: " << gemm0_shape << ". Given: " << select_cond; + MIGRAPHX_THROW(err_msg.str()); + } + } if(is_bias_enabled) { auto bias_shape = inputs[2]; diff --git a/src/targets/gpu/jit/pad.cpp b/src/targets/gpu/jit/pad.cpp index 48f43359e76..d216cb2dc74 100644 --- a/src/targets/gpu/jit/pad.cpp +++ b/src/targets/gpu/jit/pad.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/prefuse_ops.cpp b/src/targets/gpu/prefuse_ops.cpp index 321a3e8f888..7856a6e8f58 100644 --- a/src/targets/gpu/prefuse_ops.cpp +++ b/src/targets/gpu/prefuse_ops.cpp @@ -196,12 +196,15 @@ struct find_gemm_softmax_gemm .bind("gemm1"))); auto mul = match::name("mul")( match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1)); + auto where = match::name("where")(match::arg(2)(match::is_constant().bind("select_const")), + match::arg(1)(mul), + match::arg(0)(match::any().bind("select_cond"))); auto add = match::name("add")(is_bias_supported(), match::nargs(2), match::either_arg(0, 1)(match::none_of(mul).bind("bias"), mul)); - auto softmax = - match::name("softmax")(match::arg(0)(match::any_of(mul, add, gemm1))).bind("softmax"); + auto softmax = match::name("softmax")(match::arg(0)(match::any_of(mul, add, gemm1, where))) + .bind("softmax"); return match::name("dot")( match::any_of(is_ck_gemm(), is_mlir_gemm(), is_test_gemm(enable_attention)) @@ -229,10 +232,16 @@ struct find_gemm_softmax_gemm } auto inputs = gemm1_ins->inputs(); // A, B + if(contains(r.instructions, "select_cond")) + { + inputs.push_back(r.instructions["select_cond"]); + inputs.push_back(r.instructions["select_const"]); + } if(contains(r.instructions, "bias")) { inputs.push_back(r.instructions["bias"]); } + inputs.push_back(gemm2_ins->inputs().back()); // B1 mpm.get_module().replace_instruction( diff --git a/test/verify/test_gemm_mul_where_softmax_gemm.cpp b/test/verify/test_gemm_mul_where_softmax_gemm.cpp new file mode 100644 index 00000000000..20305f7dc03 --- /dev/null +++ b/test/verify/test_gemm_mul_where_softmax_gemm.cpp @@ -0,0 +1,56 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gemm_mul_where_softmax_gemm : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}}; + migraphx::shape m2_shape{migraphx::shape::bool_type, {1, 12, 256, 256}}; + auto m1_elements = m1_shape.elements(); + auto a = mm->add_parameter("1", m1_shape); + auto b = mm->add_parameter("2", m1_shape); + auto b1 = mm->add_parameter("3", m1_shape); + auto select = mm->add_parameter("4", m2_shape); + std::vector eights(m1_elements, 0.125); + std::vector tens(m1_elements, 10); + auto eight = mm->add_literal(migraphx::literal{m1_shape, eights}); + auto ten = mm->add_literal(migraphx::literal{m1_shape, tens}); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight); + auto where = mm->add_instruction(migraphx::make_op("where"), select, scale, ten); + auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), where); + mm->add_instruction(migraphx::make_op("dot"), softmax, b1); + return p; + } + std::string section() const { return "gemm"; } +};