Generalize attention fusion #3403

Merged: 28 commits from generalized_attn into develop on Sep 19, 2024

Commits (28)
95a38c3  submodule matcher (shivadbhavsar, Aug 22, 2024)
ba37389  wip - remove buggy submodule matcher, add fuse reduce unrolling (shivadbhavsar, Aug 22, 2024)
1340c50  wip - almost working without trailing pw (shivadbhavsar, Aug 23, 2024)
1e976fb  working attn fusion (shivadbhavsar, Aug 26, 2024)
0c5bcd4  remove include (shivadbhavsar, Aug 26, 2024)
d0d65c3  working bert (shivadbhavsar, Aug 27, 2024)
a247ca3  typo (shivadbhavsar, Aug 27, 2024)
26141c3  format (shivadbhavsar, Aug 27, 2024)
61023b0  Merge remote-tracking branch 'origin/develop' into generalized_attn (shivadbhavsar, Aug 27, 2024)
df3af29  add tests (shivadbhavsar, Aug 28, 2024)
07c3985  add tests for new utility functions (shivadbhavsar, Aug 28, 2024)
19a9eed  test update (shivadbhavsar, Aug 28, 2024)
2638b6c  fix module test case (shivadbhavsar, Aug 28, 2024)
6367bbe  Merge remote-tracking branch 'origin/develop' into generalized_attn (shivadbhavsar, Sep 10, 2024)
5dd25c5  document test cases (shivadbhavsar, Sep 10, 2024)
491576f  format (shivadbhavsar, Sep 10, 2024)
10d254f  Merge branch 'develop' into generalized_attn (causten, Sep 11, 2024)
d2fb486  Merge remote-tracking branch 'origin/develop' into generalized_attn (shivadbhavsar, Sep 16, 2024)
dcc5e8b  fix temporarly modules (shivadbhavsar, Sep 16, 2024)
8410707  move lambda for unrolling pointwise ops to a static func and reuse (shivadbhavsar, Sep 16, 2024)
cac74e4  bad assert (shivadbhavsar, Sep 16, 2024)
2a86823  Merge branch 'develop' into generalized_attn (pfultz2, Sep 18, 2024)
83b7194  revert pass_manager change (shivadbhavsar, Sep 18, 2024)
b12fc83  Merge branch 'develop' into generalized_attn (shivadbhavsar, Sep 18, 2024)
f82b327  fix compilation on Windows (apwojcik, Sep 18, 2024)
ea5b5d3  add missing set_bypass (shivadbhavsar, Sep 18, 2024)
98522a8  accept reshapes between gemm1 and fused_reduce (shivadbhavsar, Sep 18, 2024)
df7d9f0  format (shivadbhavsar, Sep 18, 2024)
src/include/migraphx/match/softmax.hpp (+72 −0)

@@ -0,0 +1,72 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MATCH_SOFTMAX_HPP
#define MIGRAPHX_GUARD_MATCH_SOFTMAX_HPP

#include <migraphx/config.hpp>
#include <migraphx/matcher.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace match {

namespace detail {
template <class F>
struct softmax_matcher
{
    F f;

    auto exp_x_minus_max() const
    {
        auto x_minus_max =
            f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(f("reduce_max"))));
return f("exp")(arg(0)(x_minus_max));
}

auto softmax_base_ops() const
{
auto sum_exp_x_minus_max = f("reduce_sum")(arg(0)(exp_x_minus_max()));
return f("div")(arg(0)(exp_x_minus_max()), arg(1)(skip_broadcasts(sum_exp_x_minus_max)));
}

auto matcher() const { return softmax_base_ops(); }
};
} // namespace detail

template <class F>
auto softmax(F f)
{
    return detail::softmax_matcher<F>{f}.matcher();
}

inline auto softmax()
{
    return softmax([](auto x) { return name(x); });
}

} // namespace match
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
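Note: the matcher above encodes the numerically stable softmax decomposition softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))) as a graph pattern: a sub of x and a (possibly broadcast) reduce_max, fed through exp, then divided by a (possibly broadcast) reduce_sum of that same exp. Below is a minimal usage sketch, not part of this PR: the find_softmax struct, the axis choice, and the rewrite to a single softmax op are illustrative assumptions.

#include <migraphx/match/softmax.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {

// Hypothetical matcher-pass struct (illustrative, not from this PR).
struct find_softmax
{
    auto matcher() const { return match::softmax(); }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto div = r.result;            // the final div of the pattern
        auto x   = r.instructions["x"]; // bound by any().bind("x") above
        // Assumes the reduction runs over the last axis; a real pass would
        // read the axes from the matched reduce_max/reduce_sum instructions.
        m.replace_instruction(div, make_op("softmax", {{"axis", -1}}), x);
    }
};

} // namespace migraphx

// A pass would then run: match::find_matches(m, find_softmax{});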
src/include/migraphx/module.hpp (+11 −0)

@@ -229,6 +229,12 @@ struct MIGRAPHX_EXPORT module
    std::unordered_map<instruction_ref, instruction_ref>
    get_ins_param_map(const std::vector<instruction_ref>& inputs, bool reverse = false) const;

    /// Given a mapping from submodule instructions to parent module
    /// instructions, construct a vector of inputs with the parent module
    /// instructions in the correct order
    std::vector<instruction_ref>
    get_inputs(const std::unordered_map<instruction_ref, instruction_ref>& map_ins) const;

    using with_inputs = module_with_inputs;

    /// This will split the module into two parts at the instruction splits.

@@ -245,6 +251,11 @@
        const std::vector<instruction_ref>& splits1,
        const std::vector<instruction_ref>& splits2) const;

    // Insert params into the module based on the given input instructions and
    // add mappings from each input to its corresponding param in the
    // instruction map
    void add_params(const std::vector<instruction_ref>& inputs,
                    std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr);

    // Fuse the instructions into the module by inserting the instructions and
    // parameters for any missing inputs.
    std::vector<instruction_ref>
src/module.cpp (+21 −0)

@@ -859,6 +859,21 @@
    return result;
}

std::vector<instruction_ref>
module::get_inputs(const std::unordered_map<instruction_ref, instruction_ref>& map_ins) const
{
    std::vector<instruction_ref> inputs;
    auto params = this->get_parameters();
    sort_params(params);

    std::transform(params.begin(),
                   params.end(),
                   std::back_inserter(inputs),
                   [&](instruction_ref param) { return map_ins.at(param); });

    return inputs;
}
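In effect, get_inputs inverts a parameter-to-instruction mapping into an ordered input list: each parameter of the submodule, in sorted order, is looked up in map_ins to find the parent-module instruction that feeds it. A hedged sketch of the intended use follows; the function, the parameter names "x0"/"x1", and the instruction names are illustrative assumptions, not part of this PR.

#include <migraphx/module.hpp>
#include <unordered_map>

// Illustrative only: map each submodule parameter to a parent instruction,
// then recover the parent inputs ordered by the submodule's sorted parameters.
static std::vector<migraphx::instruction_ref>
order_inputs_example(const migraphx::module& sub,
                     migraphx::instruction_ref parent_a,
                     migraphx::instruction_ref parent_b)
{
    std::unordered_map<migraphx::instruction_ref, migraphx::instruction_ref> map_ins;
    map_ins[sub.get_parameter("x0")] = parent_a; // assumed parameter names
    map_ins[sub.get_parameter("x1")] = parent_b;
    return sub.get_inputs(map_ins); // {parent_a, parent_b} if params sort as x0, x1
}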

static std::vector<instruction_ref>
select_params(const std::vector<instruction_ref>& instructions,
              const std::unordered_map<instruction_ref, instruction_ref>& param_map)
@@ -1008,6 +1023,12 @@ static void insert_params(module& m,
    }
}

void module::add_params(const std::vector<instruction_ref>& inputs,
                        std::unordered_map<instruction_ref, instruction_ref>* map_ins)
{
    insert_params(*this, inputs, *map_ins);
}
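add_params is a thin public wrapper over the existing static insert_params helper: each input not already present in the map gets a fresh parameter added to the module, and the input-to-parameter mapping is recorded. Note that map_ins is dereferenced unconditionally, so despite the nullptr default a caller must pass a valid map. A short illustrative sketch (function and instruction names are assumptions, not from this PR):

#include <migraphx/module.hpp>
#include <unordered_map>

// Illustrative only: create parameters in `sub` standing in for two
// parent-module instructions, capturing the input -> param mapping.
static void add_params_example(migraphx::module& sub,
                               migraphx::instruction_ref parent_a,
                               migraphx::instruction_ref parent_b)
{
    std::unordered_map<migraphx::instruction_ref, migraphx::instruction_ref> map_ins;
    sub.add_params({parent_a, parent_b}, &map_ins);
    // map_ins.at(parent_a) and map_ins.at(parent_b) are now parameters of
    // `sub`, ready for use when moving fused computation into the submodule.
}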

std::vector<instruction_ref>
module::fuse(const std::vector<instruction_ref>& inss,
             std::unordered_map<instruction_ref, instruction_ref>* map_ins,