From 7e24411dae95a008490a7fffaf5035669345f61f Mon Sep 17 00:00:00 2001
From: Paul
Date: Mon, 15 Apr 2024 14:18:56 -0700
Subject: [PATCH 001/145] Add atomic ops

---
 .../include/migraphx/kernels/atomic.hpp       | 118 ++++++++++++++++++
 .../kernels/scatter_reduction_modes.hpp       |  37 +-----
 .../include/migraphx/kernels/types.hpp        |   1 +
 3 files changed, 124 insertions(+), 32 deletions(-)
 create mode 100644 src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp

diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp
new file mode 100644
index 00000000000..ff931f50929
--- /dev/null
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp
@@ -0,0 +1,118 @@
+#ifndef MIGRAPHX_GUARD_KERNELS_ATOMIC_HPP
+#define MIGRAPHX_GUARD_KERNELS_ATOMIC_HPP
+
+#include <migraphx/kernels/types.hpp>
+#include <migraphx/kernels/type_traits.hpp>
+#include <migraphx/kernels/vec.hpp>
+#include <migraphx/kernels/ops.hpp>
+#include <migraphx/kernels/rank.hpp>
+#include <migraphx/kernels/bit_cast.hpp>
+#include <migraphx/kernels/debug.hpp>
+#include <migraphx/kernels/hip.hpp>
+
+#ifndef MIGRAPHX_ALLOW_ATOMIC_CAS
+// NOLINTNEXTLINE
+#define MIGRAPHX_ALLOW_ATOMIC_CAS 0
+#endif
+
+// NOLINTNEXTLINE
+#define MIGRAPHX_ATOMIC_CAS_WARNING() \
+    MIGRAPHX_ASSERT(MIGRAPHX_ALLOW_ATOMIC_CAS and "Using atomicCAS is slow")
+
+namespace migraphx {
+namespace atomic {
+
+using cas_rank = rank<1>;
+
+template <class T, class Op>
+MIGRAPHX_DEVICE_CONSTEXPR void cas(rank<1>, T& x, T y, Op op)
+{
+    MIGRAPHX_ATOMIC_CAS_WARNING();
+    using U    = conditional_t<sizeof(T) == 2, uint16_t, uint32_t>;
+    U* address = reinterpret_cast<U*>(&x);
+    U expected = __hip_atomic_load(address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
+    while(not __hip_atomic_compare_exchange_strong(address,
+                                                   &expected,
+                                                   bit_cast<U>(op(bit_cast<T>(expected), y)),
+                                                   __ATOMIC_RELAXED,
+                                                   __ATOMIC_RELAXED,
+                                                   __HIP_MEMORY_SCOPE_AGENT))
+    {
+    }
+}
+
+template <class T, index_int N, class Op>
+MIGRAPHX_DEVICE_CONSTEXPR auto cas(rank<0>, vec<T, N>& x, vec<T, N> y, Op op)
+    -> decltype(cas(cas_rank{}, x[0], y[0], op), void())
+{
+    for(index_int i = 0; i < N; i++)
+    {
+        cas(cas_rank{}, x[i], y[i], op);
+    }
+}
+
+template <class T>
+MIGRAPHX_DEVICE_CONSTEXPR auto builtin_assign(T& x, T y, op::sum)
+    MIGRAPHX_RETURNS(unsafeAtomicAdd(&x, y));
+
+__device__ inline void builtin_assign(half2& x, half2 y, op::sum)
+{
+    __builtin_amdgcn_global_atomic_fadd_v2f16(&x, y);
+}
+
+template <class T>
+constexpr bool is_aligned(const void* ptr)
+{
+    auto iptr = bit_cast<uintptr_t>(ptr);
+    return (iptr % alignof(T)) != 0;
+}
+
+__device__ inline void builtin_assign(half& x, half y, op::sum)
+{
+    half* address = &x;
+    if(is_aligned<half2>(address))
+    {
+        __builtin_amdgcn_global_atomic_fadd_v2f16(address, half2{half(y), half(0)});
+    }
+    else
+    {
+        __builtin_amdgcn_global_atomic_fadd_v2f16(address - 1, half2{half(0), half(y)});
+    }
+}
+
+template <class T>
+MIGRAPHX_DEVICE_CONSTEXPR auto builtin_assign(T& x, T y, op::min)
+    MIGRAPHX_RETURNS(unsafeAtomicMin(&x, y));
+
+template <class T>
+MIGRAPHX_DEVICE_CONSTEXPR auto builtin_assign(T& x, T y, op::max)
+    MIGRAPHX_RETURNS(unsafeAtomicMax(&x, y));
+
+template <class T, index_int N, class Op>
+MIGRAPHX_DEVICE_CONSTEXPR auto builtin_assign(vec<T, N>& x, vec<T, N> y, Op op)
+    -> decltype(builtin_assign(x[0], y[0], op), void())
+{
+    for(index_int i = 0; i < N; i++)
+    {
+        builtin_assign(x[i], y[i], op);
+    }
+}
+
+template <class T, class Op>
+MIGRAPHX_DEVICE_CONSTEXPR auto assign(rank<0>, T& x, T y, Op op)
+    MIGRAPHX_RETURNS(cas(cas_rank{}, x, y, op));
+
+template <class T, class Op>
+MIGRAPHX_DEVICE_CONSTEXPR auto assign(rank<1>, T& x, T y, Op op)
+    MIGRAPHX_RETURNS(builtin_assign(x, y, op));
+
+} // namespace atomic
+
+template <class T, class U, class Op>
+MIGRAPHX_DEVICE_CONSTEXPR void atomic_assign(T& x, U y, Op op)
+{
+    atomic::assign(rank<1>{}, x, T(y), op);
+}
+
+} // namespace migraphx
+#endif // MIGRAPHX_GUARD_KERNELS_ATOMIC_HPP

diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp
index b0236f92f2c..166552a849d 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/scatter_reduction_modes.hpp
@@ -26,16 +26,7 @@
 
 #include <migraphx/kernels/types.hpp>
 #include <migraphx/kernels/math.hpp>
-#include <migraphx/kernels/debug.hpp>
-
-#ifndef MIGRAPHX_ALLOW_ATOMIC_CAS
-// NOLINTNEXTLINE
-#define MIGRAPHX_ALLOW_ATOMIC_CAS 0
-#endif
-
-// NOLINTNEXTLINE
-#define MIGRAPHX_ATOMIC_CAS_WARNING() \
-    MIGRAPHX_ASSERT(MIGRAPHX_ALLOW_ATOMIC_CAS and "Using atomicCAS is slow")
+#include <migraphx/kernels/atomic.hpp>
 
 namespace migraphx {
 
@@ -53,15 +44,7 @@ struct assign_add
     template <class T, class U>
     MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
     {
-        if constexpr(is_same<T, float>{} or is_same<T, double>{})
-        {
-            unsafeAtomicAdd(&x, T(y));
-        }
-        else
-        {
-            MIGRAPHX_ATOMIC_CAS_WARNING();
-            atomicAdd(&x, T(y));
-        }
+        atomic_assign(x, y, op::sum{});
     }
 };
 
@@ -70,17 +53,7 @@ struct assign_mul
     template <class T, class U>
     MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
     {
-        MIGRAPHX_ATOMIC_CAS_WARNING();
-        T old = x;
-        T assumed;
-        do
-        {
-            assumed = old;
-            old     = atomicCAS(&x, assumed, assumed * y);
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wfloat-equal"
-        } while(assumed != old);
-#pragma clang diagnostic pop
+        atomic_assign(x, y, op::product{});
     }
 };
 
@@ -89,7 +62,7 @@ struct assign_max
     template <class T, class U>
     MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
     {
-        atomicMax(&x, T(y));
+        atomic_assign(x, y, op::max{});
     }
 };
 
@@ -98,7 +71,7 @@ struct assign_min
     template <class T, class U>
     MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
    {
-        atomicMin(&x, T(y));
+        atomic_assign(x, y, op::min{});
     }
 };

diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
index c2d40ff63fd..a3e03507789 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp
@@ -58,6 +58,7 @@ using uint64_t = std::uint64_t;
 #endif // MIGRAPHX_USE_HIPRTC
 using index_int = uint32_t;
 using diff_int  = int32_t;
+using uintptr_t = uint64_t;
 
 static_assert(sizeof(int8_t) == 1, "int8_t must be 1 bytes");
 static_assert(sizeof(uint8_t) == 1, "uint8_t must be 1 bytes");
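A note on the CAS fallback in atomic.hpp above: atomic::cas implements a generic atomic read-modify-write by retrying __hip_atomic_compare_exchange_strong until no other thread has touched the location between the load and the exchange. A minimal host-side analogue of the same loop, using std::atomic instead of the HIP builtins (illustrative only, not part of the patch):

    #include <atomic>
    #include <cstdio>

    // Retry op(x, y) until the compare-exchange observes an unchanged x.
    // compare_exchange_weak reloads `expected` on failure, so the loop
    // always applies op to the latest value.
    template <class T, class Op>
    void cas_assign(std::atomic<T>& x, T y, Op op)
    {
        T expected = x.load(std::memory_order_relaxed);
        while(not x.compare_exchange_weak(expected, op(expected, y), std::memory_order_relaxed))
        {
        }
    }

    int main()
    {
        std::atomic<float> a{2.0f};
        cas_assign(a, 3.0f, [](float u, float v) { return u * v; });
        std::printf("%g\n", a.load()); // prints 6
    }

This retry loop can spin under contention, which is why the patch gates it behind MIGRAPHX_ATOMIC_CAS_WARNING and prefers the builtin atomics whenever they exist.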
From 244d8b8eb3054f53c09a0a7c8f6174cca1943764 Mon Sep 17 00:00:00 2001
From: Paul
Date: Mon, 15 Apr 2024 14:19:15 -0700
Subject: [PATCH 002/145] Add missing header

---
 .../kernels/include/migraphx/kernels/rank.hpp | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)
 create mode 100644 src/targets/gpu/kernels/include/migraphx/kernels/rank.hpp

diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/rank.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/rank.hpp
new file mode 100644
index 00000000000..4058de120fd
--- /dev/null
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/rank.hpp
@@ -0,0 +1,17 @@
+#ifndef MIGRAPHX_GUARD_KERNELS_RANK_HPP
+#define MIGRAPHX_GUARD_KERNELS_RANK_HPP
+
+namespace migraphx {
+
+template <int N>
+struct rank : rank<N - 1>
+{
+};
+
+template <>
+struct rank<0>
+{
+};
+
+} // namespace migraphx
+#endif // MIGRAPHX_GUARD_KERNELS_RANK_HPP
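rank.hpp is the tag-dispatch helper the atomic code relies on: rank<1> derives from rank<0>, so an overload taking rank<1> wins overload resolution when it is viable, and resolution silently falls back to the rank<0> overload when SFINAE removes the preferred one. A standalone sketch of the idiom (the floating-point condition here is only illustrative; the real code keys on which atomic builtins are available):

    #include <cstdio>
    #include <type_traits>

    template <int N>
    struct rank : rank<N - 1>
    {
    };

    template <>
    struct rank<0>
    {
    };

    // Preferred overload: viable only for floating-point types
    template <class T, class = std::enable_if_t<std::is_floating_point<T>{}>>
    const char* dispatch(rank<1>, T)
    {
        return "fast builtin path";
    }

    // Fallback overload: always viable, chosen when the one above drops out
    template <class T>
    const char* dispatch(rank<0>, T)
    {
        return "CAS fallback";
    }

    int main()
    {
        std::printf("%s\n", dispatch(rank<1>{}, 1.0f)); // fast builtin path
        std::printf("%s\n", dispatch(rank<1>{}, 1));    // CAS fallback
    }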
From c53c40a6812a38e7250aebd21943adc0710859d8 Mon Sep 17 00:00:00 2001
From: Paul
Date: Mon, 15 Apr 2024 14:39:41 -0700
Subject: [PATCH 003/145] Add support for half type

---
 src/split_reduce.cpp                                        | 2 +-
 src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp
index f2bcf343131..fefa5b80138 100644
--- a/src/split_reduce.cpp
+++ b/src/split_reduce.cpp
@@ -167,7 +167,7 @@ void split_reduce::apply(module_pass_manager& mpm) const
         // Only use split reduce with float for now
         // TODO: Support half and other data types
         if(not std::all_of(splits.begin(), splits.end(), [](instruction_ref split) {
-               return split->get_shape().type() == shape::float_type;
+               return contains({shape::float_type, shape::half_type}, split->get_shape().type());
           }))
             continue;
         auto v = ins->get_operator().to_value();

diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp
index ff931f50929..56fb8d6341f 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp
@@ -64,7 +64,7 @@ template <class T>
 constexpr bool is_aligned(const void* ptr)
 {
     auto iptr = bit_cast<uintptr_t>(ptr);
-    return (iptr % alignof(T)) != 0;
+    return (iptr % alignof(T)) == 0;
 }
 
 __device__ inline void builtin_assign(half& x, half y, op::sum)

From d39f83252cc6fc8bde53170e7cd4845ddab13313 Mon Sep 17 00:00:00 2001
From: Paul
Date: Wed, 17 Apr 2024 19:17:27 -0700
Subject: [PATCH 004/145] Add fuse methods to module

---
 src/fuse_reduce.cpp             | 95 ++++++++-------------------------
 src/include/migraphx/module.hpp | 10 ++++
 src/module.cpp                  | 56 +++++++++++++++++++
 src/targets/gpu/fuse_mlir.cpp   | 16 ++++++
 4 files changed, 103 insertions(+), 74 deletions(-)

diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp
index 7f60d5ebe70..cbcca1cad8b 100644
--- a/src/fuse_reduce.cpp
+++ b/src/fuse_reduce.cpp
@@ -83,67 +83,14 @@ struct fused_reduce
 };
 MIGRAPHX_REGISTER_OP(fused_reduce);
 
-static void insert_params(module_ref sm,
-                          const std::vector<instruction_ref>& inputs,
-                          std::unordered_map<instruction_ref, instruction_ref>& map_ins)
-{
-    auto n = sm->get_parameter_shapes().size();
-    for(auto input : inputs)
-    {
-        if(contains(map_ins, input))
-            continue;
-        map_ins[input] =
-            sm->add_parameter("x" + std::to_string(n++), input->get_shape().as_standard());
-    }
-}
-
-static auto insert_ins_in_submodule(module_ref sm,
-                                    instruction_ref ins,
-                                    std::unordered_map<instruction_ref, instruction_ref>& map_ins)
-{
-    insert_params(sm, ins->inputs(), map_ins);
-    return sm->add_instructions({ins}, &map_ins);
-}
-
-static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)
-{
-    std::unordered_map<instruction_ref, instruction_ref> map_ins;
-    return insert_ins_in_submodule(sm, ins, map_ins);
-}
-
-static auto
-insert_module_in_submodule(module_ref sm,
-                           const std::vector<instruction_ref>& inputs,
-                           module_ref m,
-                           std::unordered_map<instruction_ref, instruction_ref>& map_ins,
-                           module::inserter insert = nullptr)
-{
-    insert_params(sm, inputs, map_ins);
-    auto param_map = m->get_ins_param_map(inputs);
-    for(auto&& [input, param] : param_map)
-    {
-        map_ins[param] = map_ins.at(input);
-    }
-    return sm->add_instructions(m, &map_ins, std::move(insert));
-}
-
 static auto
 insert_module_in_submodule(module_ref sm,
                            instruction_ref ins,
-                           std::unordered_map<instruction_ref, instruction_ref>& map_ins,
+                           std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
                            module::inserter insert = nullptr)
 {
-    return insert_module_in_submodule(
-        sm, ins->inputs(), ins->module_inputs().front(), map_ins, std::move(insert));
-}
-
-static auto insert_module_in_submodule(module_ref sm,
-                                       const std::vector<instruction_ref>& inputs,
-                                       module_ref m,
-                                       module::inserter insert = nullptr)
-{
-    std::unordered_map<instruction_ref, instruction_ref> map_ins;
-    return insert_module_in_submodule(sm, inputs, m, map_ins, std::move(insert));
+    assert(ins->module_inputs().size() == 1);
+    return sm->fuse(*ins->module_inputs().front(), ins->inputs(), map_ins, std::move(insert));
 }
 
 static std::vector<instruction_ref>
@@ -186,7 +133,7 @@ static void create_reduce_modules(module_pass_manager& mpm)
             mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++));
         rm->set_bypass();
 
-        rm->add_return(insert_ins_in_submodule(rm, ins));
+        rm->add_return(rm->fuse({ins}));
 
         auto v = ins->get_operator().to_value();
         mpm.get_module().replace_instruction(
             ins, make_op("fused_reduce", {{"axes", v["axes"]}}), ins->inputs(), {rm});
@@ -234,17 +181,17 @@ struct find_pointwise_reduce
         std::unordered_map<instruction_ref, instruction_ref> map_ins;
 
         // Insert pointwise
-        auto rins      = insert_ins_in_submodule(rm, input, map_ins).front();
+        auto rins      = rm->fuse({input}, &map_ins).front();
         map_ins[input] = rins;
 
         if(contains(r.instructions, "broadcast"))
         {
             auto broadcast     = r.instructions["broadcast"];
-            map_ins[broadcast] = insert_ins_in_submodule(rm, broadcast, map_ins).front();
+            map_ins[broadcast] = rm->fuse({broadcast}, &map_ins).front();
         }
 
         // Insert fused_reduce
-        rm->add_return(insert_module_in_submodule(rm, reduce, map_ins));
+        rm->add_return(insert_module_in_submodule(rm, reduce, &map_ins));
 
         auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
         mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm});
@@ -271,12 +218,12 @@ struct find_reduce_pointwise
         rm->set_bypass();
         std::unordered_map<instruction_ref, instruction_ref> map_ins;
         // Copy module instructions
-        insert_module_in_submodule(rm, reduce, map_ins);
+        insert_module_in_submodule(rm, reduce, &map_ins);
         if(contains(r.instructions, "broadcast"))
         {
             auto broadcast                       = r.instructions["broadcast"];
             map_ins[broadcast->inputs().front()] = rm->get_returns().front();
-            auto bout                            = insert_ins_in_submodule(rm, broadcast, map_ins);
+            auto bout                            = rm->fuse({broadcast}, &map_ins);
             map_ins[input]                       = bout.front();
         }
         else
@@ -284,7 +231,7 @@ struct find_reduce_pointwise
             map_ins[input] = rm->get_returns().front();
         }
 
-        auto out = insert_ins_in_submodule(rm, pw, map_ins);
+        auto out = rm->fuse({pw}, &map_ins);
         rm->replace_return(out);
 
         auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
@@ -315,12 +262,12 @@ struct find_reduce_reduce
         std::unordered_map<instruction_ref, instruction_ref> map_ins;
 
         // Copy reduce1 instructions
-        insert_module_in_submodule(rm, reduce2, map_ins);
+        insert_module_in_submodule(rm, reduce2, &map_ins);
         if(contains(r.instructions, "broadcast"))
        {
             auto broadcast                       = r.instructions["broadcast"];
             map_ins[broadcast->inputs().front()] = rm->get_returns().front();
-            auto bout                            = insert_ins_in_submodule(rm, broadcast, map_ins);
+            auto bout                            = rm->fuse({broadcast}, &map_ins);
             map_ins[input]                       = bout.front();
         }
         else
@@ -328,7 +275,7 @@ struct find_reduce_reduce
             map_ins[input] = rm->get_returns().front();
         }
 
-        auto out = insert_module_in_submodule(rm, reduce1, map_ins);
+        auto out = insert_module_in_submodule(rm, reduce1, &map_ins);
         rm->replace_return(out);
 
         auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
         mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm});
     }
 };
@@ -370,14 +317,14 @@ struct reduce_reshape : rewrite_reshapes_base
         auto dims  = base_dims(inputs);
         auto* oldm = ins->module_inputs().front();
         auto* sm   = mpm.create_module(oldm->name() + "_reshape");
-        insert_module_in_submodule(sm, inputs, oldm, transform_op([&](const operation& sop) {
-            if(contains(sop.name(), "reduce"))
-                return make_op(sop.name(), {{"axes", axes}});
-            if(sop.name() == "multibroadcast")
-                return make_op("multibroadcast", {{"out_lens", dims}});
-            assert(sop.name() == "pointwise");
-            return sop;
-        }));
+        sm->fuse(*oldm, inputs, nullptr, transform_op([&](const operation& sop) {
+                if(contains(sop.name(), "reduce"))
+                    return make_op(sop.name(), {{"axes", axes}});
+                if(sop.name() == "multibroadcast")
+                    return make_op("multibroadcast", {{"out_lens", dims}});
+                assert(sop.name() == "pointwise");
+                return sop;
+            }));
 
         return mpm.get_module().insert_instruction(ins, fused_reduce{axes}, inputs, {sm});
     }

diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp
index f9d41121159..159313dfbc9 100644
--- a/src/include/migraphx/module.hpp
+++ b/src/include/migraphx/module.hpp
@@ -231,6 +231,16 @@ struct MIGRAPHX_EXPORT module
           const std::vector<instruction_ref>& splits1,
           const std::vector<instruction_ref>& splits2) const;
 
+    std::vector<instruction_ref> fuse(
+        const std::vector<instruction_ref>& inss,
+        std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr, inserter insert = nullptr);
+
+    std::vector<instruction_ref>
+    fuse(const module& m,
+        const std::vector<instruction_ref>& inputs,
+        std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
+        inserter insert = nullptr);
+
     void debug_print() const;
     void debug_print(instruction_ref ins) const;
     void debug_print(instruction_ref ins,

diff --git a/src/module.cpp b/src/module.cpp
index 8c28c15d22d..9a9f8c0b07d 100644
--- a/src/module.cpp
+++ b/src/module.cpp
@@ -914,6 +914,62 @@ std::array<module_with_inputs, 2> module::split(const std::vector<instruction_r
     return {{{std::move(m1), std::move(inputs1)}, {std::move(m2), std::move(inputs2)}}};
 }
 
+static void insert_params(module& m,
+                          const std::vector<instruction_ref>& inputs,
+                          std::unordered_map<instruction_ref, instruction_ref>& map_ins)
+{
+    auto n = m.get_parameter_shapes().size();
+    for(auto input : inputs)
+    {
+        if(contains(map_ins, input))
+            continue;
+        map_ins[input] =
+            m.add_parameter(param_name(n++), input->get_shape().as_standard());
+    }
+}
+
+std::vector<instruction_ref> module::fuse(
+    const std::vector<instruction_ref>& inss,
+    std::unordered_map<instruction_ref, instruction_ref>* map_ins, module::inserter insert)
+{
+    std::unordered_map<instruction_ref, instruction_ref> default_map_ins;
+    if(not map_ins)
+        map_ins = &default_map_ins;
+    std::vector<instruction_ref> inputs;
+    for(auto ins:inss)
+    {
+        for(auto input:ins->inputs())
+        {
+            if(contains(inss, input))
+                continue;
+            if(contains(inputs, input))
+                continue;
+            inputs.push_back(input);
+        }
+    }
+    insert_params(*this, inputs, *map_ins);
+    return this->add_instructions(inss, map_ins, std::move(insert));
+}
+
+std::vector<instruction_ref>
+module::fuse(
+    const module& m,
+    const std::vector<instruction_ref>& inputs,
+    std::unordered_map<instruction_ref, instruction_ref>* map_ins,
+    module::inserter insert)
+{
+    std::unordered_map<instruction_ref, instruction_ref> default_map_ins;
+    if(not map_ins)
+        map_ins = &default_map_ins;
+    insert_params(*this, inputs, *map_ins);
+    auto param_map = m.get_ins_param_map(inputs);
+    for(auto&& [input, param] : param_map)
+    {
+        (*map_ins)[param] = map_ins->at(input);
+    }
+    return this->add_instructions(&m, map_ins, std::move(insert));
+}
+
 void module_with_inputs::replace(instruction_ref ins, instruction_ref rep)
 {
     auto it = std::find(inputs.begin(), inputs.end(), ins);

diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp
index a0a16512358..de1ffa0330d 100644
--- a/src/targets/gpu/fuse_mlir.cpp
+++ b/src/targets/gpu/fuse_mlir.cpp
@@ -604,6 +604,22 @@ struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
     }
 };
 
+struct find_pointwise_mlir
+{
+    auto matcher() const
+    {
+        return match::name("gpu::mlir_op")(match::any_of[match::inputs()](match::name("pointwise")(match::used_once()).bind("pointwise")));
+    }
+
+    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        auto pw = r.instructions["pointwise"];
+
+
+    }
+};
+
 } // namespace
 
 #endif // MIGRAPHX_MLIR
From d2d3baeaf99a89b04685d5f7b22c32fb36c68fbf Mon Sep 17 00:00:00 2001
From: Paul
Date: Wed, 17 Apr 2024 19:17:30 -0700
Subject: [PATCH 005/145] Format

---
 src/fuse_reduce.cpp             | 16 ++++++++--------
 src/include/migraphx/module.hpp | 13 +++++++------
 src/module.cpp                  | 23 +++++++++++------------
 src/targets/gpu/fuse_mlir.cpp   |  7 +++----
 4 files changed, 29 insertions(+), 30 deletions(-)

diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp
index cbcca1cad8b..efb897d2d2b 100644
--- a/src/fuse_reduce.cpp
+++ b/src/fuse_reduce.cpp
@@ -87,7 +87,7 @@ static auto
 insert_module_in_submodule(module_ref sm,
                            instruction_ref ins,
                            std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
-                           module::inserter insert = nullptr)
+                           module::inserter insert                                       = nullptr)
 {
     assert(ins->module_inputs().size() == 1);
     return sm->fuse(*ins->module_inputs().front(), ins->inputs(), map_ins, std::move(insert));
@@ -318,13 +318,13 @@ struct reduce_reshape : rewrite_reshapes_base
         auto* oldm = ins->module_inputs().front();
         auto* sm   = mpm.create_module(oldm->name() + "_reshape");
         sm->fuse(*oldm, inputs, nullptr, transform_op([&](const operation& sop) {
-                if(contains(sop.name(), "reduce"))
-                    return make_op(sop.name(), {{"axes", axes}});
-                if(sop.name() == "multibroadcast")
-                    return make_op("multibroadcast", {{"out_lens", dims}});
-                assert(sop.name() == "pointwise");
-                return sop;
-            }));
+                     if(contains(sop.name(), "reduce"))
+                         return make_op(sop.name(), {{"axes", axes}});
+                     if(sop.name() == "multibroadcast")
+                         return make_op("multibroadcast", {{"out_lens", dims}});
+                     assert(sop.name() == "pointwise");
+                     return sop;
+                 }));
 
         return mpm.get_module().insert_instruction(ins, fused_reduce{axes}, inputs, {sm});
     }

diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp
index 159313dfbc9..2fa9d1aa6e5 100644
--- a/src/include/migraphx/module.hpp
+++ b/src/include/migraphx/module.hpp
@@ -231,15 +231,16 @@ struct MIGRAPHX_EXPORT module
           const std::vector<instruction_ref>& splits1,
           const std::vector<instruction_ref>& splits2) const;
 
-    std::vector<instruction_ref> fuse(
-        const std::vector<instruction_ref>& inss,
-        std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr, inserter insert = nullptr);
+    std::vector<instruction_ref>
+    fuse(const std::vector<instruction_ref>& inss,
+         std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
+         inserter insert                                               = nullptr);
 
     std::vector<instruction_ref>
     fuse(const module& m,
-        const std::vector<instruction_ref>& inputs,
-        std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
-        inserter insert = nullptr);
+         const std::vector<instruction_ref>& inputs,
+         std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
+         inserter insert                                               = nullptr);
 
     void debug_print() const;
     void debug_print(instruction_ref ins) const;

diff --git a/src/module.cpp b/src/module.cpp
index 9a9f8c0b07d..73b51bd05e9 100644
--- a/src/module.cpp
+++ b/src/module.cpp
@@ -923,22 +923,22 @@ static void insert_params(module& m,
     {
         if(contains(map_ins, input))
             continue;
-        map_ins[input] =
-            m.add_parameter(param_name(n++), input->get_shape().as_standard());
+        map_ins[input] = m.add_parameter(param_name(n++), input->get_shape().as_standard());
     }
 }
 
-std::vector<instruction_ref> module::fuse(
-    const std::vector<instruction_ref>& inss,
-    std::unordered_map<instruction_ref, instruction_ref>* map_ins, module::inserter insert)
+std::vector<instruction_ref>
+module::fuse(const std::vector<instruction_ref>& inss,
+             std::unordered_map<instruction_ref, instruction_ref>* map_ins,
+             module::inserter insert)
 {
     std::unordered_map<instruction_ref, instruction_ref> default_map_ins;
     if(not map_ins)
         map_ins = &default_map_ins;
     std::vector<instruction_ref> inputs;
-    for(auto ins:inss)
+    for(auto ins : inss)
     {
-        for(auto input:ins->inputs())
+        for(auto input : ins->inputs())
         {
             if(contains(inss, input))
                 continue;
@@ -952,11 +952,10 @@ std::vector<instruction_ref>
 }
 
 std::vector<instruction_ref>
-module::fuse(
-    const module& m,
-    const std::vector<instruction_ref>& inputs,
-    std::unordered_map<instruction_ref, instruction_ref>* map_ins,
-    module::inserter insert)
+module::fuse(const module& m,
+             const std::vector<instruction_ref>& inputs,
+             std::unordered_map<instruction_ref, instruction_ref>* map_ins,
+             module::inserter insert)
 {
     std::unordered_map<instruction_ref, instruction_ref> default_map_ins;
     if(not map_ins)

diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp
index de1ffa0330d..1f86845d758 100644
--- a/src/targets/gpu/fuse_mlir.cpp
+++ b/src/targets/gpu/fuse_mlir.cpp
@@ -608,15 +608,14 @@ struct find_pointwise_mlir
 {
     auto matcher() const
     {
-        return match::name("gpu::mlir_op")(match::any_of[match::inputs()](match::name("pointwise")(match::used_once()).bind("pointwise")));
+        return match::name("gpu::mlir_op")(match::any_of[match::inputs()](
+            match::name("pointwise")(match::used_once()).bind("pointwise")));
     }
 
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
         auto ins = r.result;
-        auto pw = r.instructions["pointwise"];
-
-
+        auto pw  = r.instructions["pointwise"];
     }
 };

From af835094b51c3bfb48c82da7b8fdcb92153a9189 Mon Sep 17 00:00:00 2001
From: Paul
Date: Thu, 18 Apr 2024 07:29:45 -0700
Subject: [PATCH 006/145] Add some initial code

---
 src/targets/gpu/fuse_mlir.cpp | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp
index 1f86845d758..87528c06e6d 100644
--- a/src/targets/gpu/fuse_mlir.cpp
+++ b/src/targets/gpu/fuse_mlir.cpp
@@ -616,6 +616,20 @@ struct find_pointwise_mlir
     {
         auto ins = r.result;
         auto pw  = r.instructions["pointwise"];
+
+        auto* mm = ins->module_inputs().front();
+        auto* pm = pw->module_inputs().front();
+
+        module_ref m = mpm.create_module(pm->name() + ":" + mm->name(), *pm);
+        m->fuse(*mm, ins->inputs());
+
+        // TODO: Use find_inputs
+        auto inputs = pw->inputs();
+        inputs.insert(inputs.end(), ins->inputs().begin(), ins->inputs().end());
+
+        mpm.get_module().replace_instruction(
+            ins, ins->get_operator(), inputs, {m});
+
     }
 };

From ac479548aa7996ee3b783dce856440625c3ec60b Mon Sep 17 00:00:00 2001
From: Paul
Date: Thu, 18 Apr 2024 07:29:51 -0700
Subject: [PATCH 007/145] Format

---
 src/targets/gpu/fuse_mlir.cpp | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp
index 87528c06e6d..f96350c1790 100644
--- a/src/targets/gpu/fuse_mlir.cpp
+++ b/src/targets/gpu/fuse_mlir.cpp
@@ -627,9 +627,7 @@ struct find_pointwise_mlir
         auto inputs = pw->inputs();
         inputs.insert(inputs.end(), ins->inputs().begin(), ins->inputs().end());
 
-        mpm.get_module().replace_instruction(
-            ins, ins->get_operator(), inputs, {m});
-
+        mpm.get_module().replace_instruction(ins, ins->get_operator(), inputs, {m});
     }
 };

From c9407aa1c8cb85ec8568c348e5d743a08056131d Mon Sep 17 00:00:00 2001
From: Paul
Date: Thu, 18 Apr 2024 13:48:53 -0700
Subject: [PATCH 008/145] Reuse find_inputs

---
 src/fuse_reduce.cpp                  | 33 ++++------------------
 src/include/migraphx/param_utils.hpp |  6 +++++
 src/param_utils.cpp                  | 29 ++++++++++++++++++++
 src/targets/gpu/fuse_mlir.cpp        | 16 +++++++++-----
 4 files changed, 50 insertions(+), 34 deletions(-)

diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp
index efb897d2d2b..1ba50a8cee0 100644
--- a/src/fuse_reduce.cpp
+++ b/src/fuse_reduce.cpp
@@ -33,6 +33,7 @@
 #include <migraphx/matcher.hpp>
 #include <migraphx/register_op.hpp>
 #include <migraphx/rewrite_reshapes.hpp>
+#include <migraphx/param_utils.hpp>
 #include <migraphx/algorithm.hpp>
 #include <map>
@@ -93,32 +94,6 @@ insert_module_in_submodule(module_ref sm,
     return sm->fuse(*ins->module_inputs().front(), ins->inputs(), map_ins, std::move(insert));
 }
 
-static std::vector<instruction_ref>
-find_inputs(const_module_ref sm,
-            const module& parent,
-            const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
-{
-    std::vector<instruction_ref> result;
-    std::map<std::string, instruction_ref> names;
-    for(auto&& [input, param] : map_ins)
-    {
-        if(not sm->has_instruction(param))
-            continue;
-        if(param->name() != "@param")
-            continue;
-        if(not parent.has_instruction(input))
-            continue;
-        auto v      = param->get_operator().to_value();
-        auto name   = v.at("parameter").to<std::string>();
-        names[name] = input;
-    }
-    std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) {
-        return p.second;
-    });
-    assert(result.size() == sm->get_parameter_shapes().size());
-    return result;
-}
-
 static void create_reduce_modules(module_pass_manager& mpm)
 {
     std::size_t n = 0;
@@ -168,7 +143,7 @@ struct find_pointwise_reduce
         // Insert fused_reduce
         rm->add_return(insert_module_in_submodule(rm, reduce, &map_ins));
 
-        auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
+        auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
         mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm});
     }
 };
@@ -209,7 +184,7 @@ struct find_reduce_pointwise
         auto out = rm->fuse({pw}, &map_ins);
         rm->replace_return(out);
 
-        auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
+        auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
         mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
     }
 };
@@ -253,7 +228,7 @@ struct find_reduce_reduce
         auto out = insert_module_in_submodule(rm, reduce1, &map_ins);
         rm->replace_return(out);
 
-        auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
+        auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
         mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm});
     }
 };

diff --git a/src/include/migraphx/param_utils.hpp b/src/include/migraphx/param_utils.hpp
index 1552c28300b..f594f8be7f7 100644
--- a/src/include/migraphx/param_utils.hpp
+++ b/src/include/migraphx/param_utils.hpp
@@ -27,6 +27,7 @@
 
 #include <migraphx/config.hpp>
 #include <migraphx/instruction_ref.hpp>
+#include <migraphx/module_ref.hpp>
 #include <string>
 #include <vector>
@@ -37,6 +38,11 @@ std::string param_name(std::size_t i, const std::string& prefix = "x");
 
 void sort_params(std::vector<instruction_ref>& params);
 
+std::vector<instruction_ref>
+find_inputs(const std::unordered_map<instruction_ref, instruction_ref>& map_ins,
+            const_module_ref parent,
+            const_module_ref sub);
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP

diff --git a/src/param_utils.cpp b/src/param_utils.cpp
index 61302a0afba..a3a07acaa26 100644
--- a/src/param_utils.cpp
+++ b/src/param_utils.cpp
@@ -25,6 +25,9 @@
 #include <migraphx/param_utils.hpp>
 #include <migraphx/builtin.hpp>
 #include <migraphx/ranges.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/module.hpp>
+#include <map>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -43,5 +46,31 @@ void sort_params(std::vector<instruction_ref>& params)
     }));
 }
 
+std::vector<instruction_ref>
+find_inputs(const std::unordered_map<instruction_ref, instruction_ref>& map_ins,
+            const_module_ref parent,
+            const_module_ref sub)
+{
+    std::vector<instruction_ref> result;
+    std::map<std::string, instruction_ref> names;
+    for(auto&& [input, param] : map_ins)
+    {
+        if(sub and not sub->has_instruction(param))
+            continue;
+        if(param->name() != "@param")
+            continue;
+        if(parent and not parent->has_instruction(input))
+            continue;
+        auto v      = param->get_operator().to_value();
+        auto name   = v.at("parameter").to<std::string>();
+        names[name] = input;
+    }
+    std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) {
+        return p.second;
+    });
+    assert(not sub or result.size() == sub->get_parameter_shapes().size());
+    return result;
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx

diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp
index f96350c1790..a71451176f9 100644
--- a/src/targets/gpu/fuse_mlir.cpp
+++ b/src/targets/gpu/fuse_mlir.cpp
@@ -29,6 +29,7 @@
 #include <migraphx/matcher.hpp>
 #include <migraphx/pass_manager.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/param_utils.hpp>
 #include <migraphx/register_op.hpp>
@@ -620,13 +621,16 @@ struct find_pointwise_mlir
         auto* mm = ins->module_inputs().front();
         auto* pm = pw->module_inputs().front();
 
-        module_ref m = mpm.create_module(pm->name() + ":" + mm->name(), *pm);
-        m->fuse(*mm, ins->inputs());
+        std::unordered_map<instruction_ref, instruction_ref> map_ins;
+        module_ref m = mpm.create_module(pm->name() + ":" + mm->name());
+        m->set_bypass();
+        auto rins = m->fuse(*pm, pw->inputs(), &map_ins).front();
+        map_ins[pw] = rins;
 
-        // TODO: Use find_inputs
-        auto inputs = pw->inputs();
-        inputs.insert(inputs.end(), ins->inputs().begin(), ins->inputs().end());
+        auto ret = m->fuse(*mm, ins->inputs(), &map_ins);
+        m->add_return({ret});
 
+        auto inputs = find_inputs(map_ins, &mpm.get_module(), m);
         mpm.get_module().replace_instruction(ins, ins->get_operator(), inputs, {m});
     }
 };
@@ -666,6 +670,8 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
             mpm,
             find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::fast)},
             find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::fast)});
+
+    match::find_matches(mpm, find_pointwise_mlir{});
 #else
     (void)mpm;
 #endif

From ea41fb98676f32eed6c7e3b28afee1fd7902724c Mon Sep 17 00:00:00 2001
From: Paul
Date: Thu, 18 Apr 2024 13:48:57 -0700
Subject: [PATCH 009/145] Format

---
 src/targets/gpu/fuse_mlir.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp
index a71451176f9..671ddd9042c 100644
--- a/src/targets/gpu/fuse_mlir.cpp
+++ b/src/targets/gpu/fuse_mlir.cpp
@@ -624,7 +624,7 @@ struct find_pointwise_mlir
         std::unordered_map<instruction_ref, instruction_ref> map_ins;
         module_ref m = mpm.create_module(pm->name() + ":" + mm->name());
         m->set_bypass();
-        auto rins = m->fuse(*pm, pw->inputs(), &map_ins).front();
+        auto rins   = m->fuse(*pm, pw->inputs(), &map_ins).front();
         map_ins[pw] = rins;
 
         auto ret = m->fuse(*mm, ins->inputs(), &map_ins);
From 3931cfca3f657c01829b146d440a406742f8f01d Mon Sep 17 00:00:00 2001
From: Paul
Date: Sat, 20 Apr 2024 09:07:49 -0700
Subject: [PATCH 010/145] Handle two reductions

---
 src/dom_info.cpp                  |  6 ++--
 src/include/migraphx/dom_info.hpp |  2 +-
 src/include/migraphx/ranges.hpp   |  6 ++++
 src/split_reduce.cpp              | 46 +++++++++++++------------------
 4 files changed, 29 insertions(+), 31 deletions(-)

diff --git a/src/dom_info.cpp b/src/dom_info.cpp
index 400dd80e0e1..6d71cad769c 100644
--- a/src/dom_info.cpp
+++ b/src/dom_info.cpp
@@ -48,8 +48,8 @@ bool dominator_info::strictly_dominate(instruction_ref ins1, instruction_ref ins
 
 struct module_visitor
 {
-    module* mm;
-    module& get_nodes() const { return *mm; }
+    const module* mm;
+    const module& get_nodes() const { return *mm; }
 
     const std::vector<instruction_ref>& get_children(instruction_ref ins) { return ins->inputs(); }
 };
@@ -91,7 +91,7 @@ dominator_info compute_dominator_generic(Visitor v)
     return info;
 }
 
-dominator_info compute_dominator(module& m)
+dominator_info compute_dominator(const module& m)
 {
     return compute_dominator_generic(module_visitor{&m});
 }

diff --git a/src/include/migraphx/dom_info.hpp b/src/include/migraphx/dom_info.hpp
index 7fd6db3a18e..23103919ad9 100644
--- a/src/include/migraphx/dom_info.hpp
+++ b/src/include/migraphx/dom_info.hpp
@@ -41,7 +41,7 @@ struct MIGRAPHX_EXPORT dominator_info
     std::unordered_map<instruction_ref, instruction_ref> ins2idom;
 };
 
-MIGRAPHX_EXPORT dominator_info compute_dominator(module& m);
+MIGRAPHX_EXPORT dominator_info compute_dominator(const module& m);
 
 // MIGRAPHX_EXPORT dominator_info compute_dominator_naive(const module& m);
 } // namespace MIGRAPHX_INLINE_NS

diff --git a/src/include/migraphx/ranges.hpp b/src/include/migraphx/ranges.hpp
index dfe2d0a577c..d4bfa2cb68f 100644
--- a/src/include/migraphx/ranges.hpp
+++ b/src/include/migraphx/ranges.hpp
@@ -192,6 +192,12 @@ void copy(Range&& r, Iterator it)
     std::copy(r.begin(), r.end(), it);
 }
 
+template <class Range, class Iterator, class Predicate>
+void copy_if(Range&& r, Iterator it, Predicate pred)
+{
+    std::copy_if(r.begin(), r.end(), it, pred);
+}
+
 template <class Range, class Iterator, class F>
 void transform(Range&& r, Iterator it, F f)
 {

diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp
index fefa5b80138..d26a1e2414f 100644
--- a/src/split_reduce.cpp
+++ b/src/split_reduce.cpp
@@ -23,6 +23,7 @@
  *
  */
 #include <migraphx/split_reduce.hpp>
+#include <migraphx/dom_info.hpp>
 #include <migraphx/pass_manager.hpp>
 #include <migraphx/module.hpp>
 #include <migraphx/instruction.hpp>
@@ -62,25 +63,13 @@ struct split_fused_reduce
             MIGRAPHX_THROW("Only one output supported");
         auto names = sm->get_parameter_names();
         check_shapes{inputs, *this}.has(names.size()).same_ndims();
-        std::sort(names.begin(), names.end());
-        auto shapes = sm->get_parameter_shapes();
-        // Check dimension matches for each input
-        if(not equal(names, inputs, [&](const auto& name, const auto& input) {
-               return shapes.at(name).lens() == input.lens();
-           }))
-            MIGRAPHX_THROW("Dimenstion does not match the submodule.");
-        const auto& s = inputs.at(0);
-        auto lens     = s.lens();
-        if(lens != sm->get_output_shapes().front().lens())
-        {
-            for(const auto& axis : axes)
-            {
-                lens[axis] = 1;
-            }
-        }
-        return shape::from_permutation(
-            sm->get_output_shapes().front().type(), lens, find_permutation(inputs));
+        auto result = sm->compute_shapes(
+            inputs,
+            {.name = name(), .strict_type = true, .strict_lens = true});
+        if(result.size() == 1)
+            return result.front();
+        return shape{result};
     }
 
     std::string name() const { return "split_fused_reduce"; }
@@ -92,18 +81,21 @@ static bool is_reduce(const instruction& ins) { return contains(ins.name(), "red
 static std::vector<instruction_ref> find_split(const_module_ref rm)
 {
     std::vector<instruction_ref> result;
-    auto reduce_ins = std::find_if(rm->begin(), rm->end(), &is_reduce);
-    if(reduce_ins == rm->end())
-        return result;
-    // Bail if there is more than one reduce for now
-    // TODO: Support multiple reductions
-    if(std::any_of(std::next(reduce_ins), rm->end(), &is_reduce))
-        return result;
+    copy_if(iterator_for(*rm), std::back_inserter(result), [](auto ins){ return is_reduce(*ins); });
+    // if(result.size() > 2)
+    if(result.size() > 1)
+        return {};
     // Only handle reduce_sum for now
     // TODO: Support other reduction types
-    if(reduce_ins->name() != "reduce_sum")
+    if(not std::all_of(result.begin(), result.end(), [](instruction_ref ins) { return ins->name() == "reduce_sum"; }))
+        return {};
+    if(result.size() < 2)
         return result;
-    result.push_back(reduce_ins);
+    dominator_info dom = compute_dominator(*rm);
+    if(dom.strictly_dominate(result[0], result[1]))
+        return {};
+    if(dom.strictly_dominate(result[1], result[0]))
+        return {};
     return result;
 }

From 0370543f396b2d2edc98497f4ba6c9d55b09a027 Mon Sep 17 00:00:00 2001
From: Paul
Date: Sat, 20 Apr 2024 09:08:22 -0700
Subject: [PATCH 011/145] Format

---
 src/split_reduce.cpp | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp
index d26a1e2414f..f564ec9aa5b 100644
--- a/src/split_reduce.cpp
+++ b/src/split_reduce.cpp
@@ -64,9 +64,8 @@ struct split_fused_reduce
         auto names = sm->get_parameter_names();
         check_shapes{inputs, *this}.has(names.size()).same_ndims();
 
-        auto result = sm->compute_shapes(
-            inputs,
-            {.name = name(), .strict_type = true, .strict_lens = true});
+        auto result =
+            sm->compute_shapes(inputs, {.name = name(), .strict_type = true, .strict_lens = true});
         if(result.size() == 1)
             return result.front();
         return shape{result};
@@ -80,13 +79,16 @@ static bool is_reduce(const instruction& ins) { return contains(ins.name(), "red
 static std::vector<instruction_ref> find_split(const_module_ref rm)
 {
     std::vector<instruction_ref> result;
-    copy_if(iterator_for(*rm), std::back_inserter(result), [](auto ins){ return is_reduce(*ins); });
+    copy_if(
+        iterator_for(*rm), std::back_inserter(result), [](auto ins) { return is_reduce(*ins); });
     // if(result.size() > 2)
     if(result.size() > 1)
         return {};
     // Only handle reduce_sum for now
     // TODO: Support other reduction types
-    if(not std::all_of(result.begin(), result.end(), [](instruction_ref ins) { return ins->name() == "reduce_sum"; }))
+    if(not std::all_of(result.begin(), result.end(), [](instruction_ref ins) {
+           return ins->name() == "reduce_sum";
+       }))
         return {};
     if(result.size() < 2)
         return result;
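The dominator check is what lets find_split accept two reductions: a split is only valid when the reductions are parallel, meaning neither strictly dominates the other, so both partial results can be computed in one pass over the split axis. Schematic submodule shapes (illustrative, not from the patch):

    // rsum1 = reduce_sum(x)           // parallel: neither result feeds the
    // rsum2 = reduce_sum(mul(x, x))   // other, so both reduce_sums can be
    //                                 // split together
    //
    // rsum1 = reduce_sum(x)
    // rsum2 = reduce_sum(f(rsum1))    // nested: one dominates the other,
    //                                 // find_split returns {}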
From d4db0f6ec01aa54ff54f05d010855919185e7afb Mon Sep 17 00:00:00 2001
From: Paul
Date: Sat, 20 Apr 2024 09:46:53 -0700
Subject: [PATCH 012/145] Handle multi outputs in split reduce

---
 src/split_reduce.cpp | 31 +++++++++++++++++++------------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp
index f564ec9aa5b..cab187e2d52 100644
--- a/src/split_reduce.cpp
+++ b/src/split_reduce.cpp
@@ -59,13 +59,12 @@ struct split_fused_reduce
         if(mods.size() != 1)
             MIGRAPHX_THROW("should have one submodule.");
         const auto* sm = mods.front();
-        if(sm->get_output_shapes().size() != 1)
-            MIGRAPHX_THROW("Only one output supported");
         auto names = sm->get_parameter_names();
         check_shapes{inputs, *this}.has(names.size()).same_ndims();
 
-        auto result =
-            sm->compute_shapes(inputs, {.name = name(), .strict_type = true, .strict_lens = true});
+        auto result = sm->compute_shapes(
+            inputs,
+            {.name = name(), .strict_type = true, .strict_lens = true});
         if(result.size() == 1)
             return result.front();
         return shape{result};
@@ -79,12 +78,15 @@ static bool is_reduce(const instruction& ins) { return contains(ins.name(), "red
 static std::vector<instruction_ref> find_split(const_module_ref rm)
 {
     std::vector<instruction_ref> result;
-    copy_if(
-        iterator_for(*rm), std::back_inserter(result), [](auto ins) { return is_reduce(*ins); });
-    // if(result.size() > 2)
-    if(result.size() > 1)
+    copy_if(iterator_for(*rm), std::back_inserter(result), [](auto ins){ return is_reduce(*ins); });
+    if(result.size() > 2)
         return {};
     // Only handle reduce_sum for now
     // TODO: Support other reduction types
-    if(not std::all_of(result.begin(), result.end(), [](instruction_ref ins) {
-           return ins->name() == "reduce_sum";
-       }))
+    if(not std::all_of(result.begin(), result.end(), [](instruction_ref ins) { return ins->name() == "reduce_sum"; }))
         return {};
     if(result.size() < 2)
         return result;
@@ -193,7 +195,19 @@ void split_reduce::apply(module_pass_manager& mpm) const
             mods[0].inputs,
             {splitm});
 
-        mods[1].replace(splits.front(), split_reduce);
+        std::vector<instruction_ref> split_reduce_each;
+        if(splits.size() == 1)
+        {
+            split_reduce_each = {split_reduce};
+        }
+        else
+        {
+            transform(range(splits.size()), std::back_inserter(split_reduce_each), [&](auto i) {
+                return mpm.get_module().insert_instruction(ins, make_op("get_tuple_elem", {{"index", i}}), split_reduce);
+            });
+        }
+
+        mods[1].replace(splits, split_reduce_each);
         auto replaced = insert_module_inline(mpm.get_module(), ins, mods[1]);
         assert(replaced.size() == 1);
         mpm.get_module().replace_instruction(ins, replaced.front());

From 5b3785366a9b8bde53170e7cd4845ddab13313 Mon Sep 17 00:00:00 2001
From: Paul
Date: Sat, 20 Apr 2024 09:46:58 -0700
Subject: [PATCH 013/145] Format

---
 src/split_reduce.cpp | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp
index cab187e2d52..08a6c535f09 100644
--- a/src/split_reduce.cpp
+++ b/src/split_reduce.cpp
@@ -62,9 +62,8 @@ struct split_fused_reduce
         auto names = sm->get_parameter_names();
         check_shapes{inputs, *this}.has(names.size()).same_ndims();
 
-        auto result = sm->compute_shapes(
-            inputs,
-            {.name = name(), .strict_type = true, .strict_lens = true});
+        auto result =
+            sm->compute_shapes(inputs, {.name = name(), .strict_type = true, .strict_lens = true});
         if(result.size() == 1)
             return result.front();
         return shape{result};
@@ -79,12 +78,15 @@ static bool is_reduce(const instruction& ins) { return contains(ins.name(), "red
 static std::vector<instruction_ref> find_split(const_module_ref rm)
 {
     std::vector<instruction_ref> result;
-    copy_if(iterator_for(*rm), std::back_inserter(result), [](auto ins){ return is_reduce(*ins); });
+    copy_if(
+        iterator_for(*rm), std::back_inserter(result), [](auto ins) { return is_reduce(*ins); });
     if(result.size() > 2)
         return {};
     // Only handle reduce_sum for now
     // TODO: Support other reduction types
-    if(not std::all_of(result.begin(), result.end(), [](instruction_ref ins) { return ins->name() == "reduce_sum"; }))
+    if(not std::all_of(result.begin(), result.end(), [](instruction_ref ins) {
+           return ins->name() == "reduce_sum";
+       }))
         return {};
     if(result.size() < 2)
         return result;
@@ -196,7 +198,8 @@ void split_reduce::apply(module_pass_manager& mpm) const
         else
         {
             transform(range(splits.size()), std::back_inserter(split_reduce_each), [&](auto i) {
-                return mpm.get_module().insert_instruction(ins, make_op("get_tuple_elem", {{"index", i}}), split_reduce);
+                return mpm.get_module().insert_instruction(
+                    ins, make_op("get_tuple_elem", {{"index", i}}), split_reduce);
             });
         }
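With multiple outputs, split_fused_reduce now computes a tuple shape, and the pass unpacks it with get_tuple_elem before inlining the second half of the module. Schematically, the replacement emitted by apply looks like this (a paraphrase of the code above, not new behavior):

    // One kernel produces both partial reductions...
    auto split_reduce = mpm.get_module().insert_instruction(
        ins, make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign}}), inputs, {splitm});
    // ...and each consumer reads its own element of the tuple.
    auto r0 = mpm.get_module().insert_instruction(
        ins, make_op("get_tuple_elem", {{"index", 0}}), split_reduce);
    auto r1 = mpm.get_module().insert_instruction(
        ins, make_op("get_tuple_elem", {{"index", 1}}), split_reduce);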
From ac747b220eb1cee31d9a0fca6085e9d7aeaf1155 Mon Sep 17 00:00:00 2001
From: Paul
Date: Sat, 20 Apr 2024 11:13:29 -0700
Subject: [PATCH 014/145] Split two reductions

---
 src/module.cpp                                |  2 +-
 src/split_reduce.cpp                          |  8 ++++--
 src/targets/gpu/compile_gen.cpp               |  1 +
 src/targets/gpu/hip.cpp                       | 14 +++++---
 src/targets/gpu/jit/reduce.cpp                | 17 +++++-----
 .../include/migraphx/kernels/functional.hpp   |  8 ++++++
 .../include/migraphx/kernels/reduce.hpp       | 27 ++++++++----
 7 files changed, 54 insertions(+), 23 deletions(-)

diff --git a/src/module.cpp b/src/module.cpp
index a90328db6f4..e7467f2eb54 100644
--- a/src/module.cpp
+++ b/src/module.cpp
@@ -717,7 +717,7 @@ std::vector<shape> module::compute_shapes(const std::vector<shape>& inputs,
                        [&](auto in) { return ins_shapes.at(in); });
         if(ins->name() == "@return")
             return input_shapes;
-        ins_shapes[ins] = ins->get_operator().compute_shape(input_shapes);
+        ins_shapes[ins] = ins->get_operator().compute_shape(input_shapes, ins->module_inputs());
     }
 }
     MIGRAPHX_THROW("No return found in the submodule");

diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp
index 08a6c535f09..91d12d0482f 100644
--- a/src/split_reduce.cpp
+++ b/src/split_reduce.cpp
@@ -103,15 +103,19 @@ static std::vector<instruction_ref> get_alive(const_module_ref rm,
 {
     std::vector<instruction_ref> result;
     bool stop = false;
-    liveness(*rm, [&](auto ins, const auto& live_set) {
+    liveness(*rm, [&](auto rins, const auto& live_set) {
         if(stop)
             return;
+        if(rins == rm->begin())
+            return;
+        // We want to know what instructions are live after the split instruction
+        auto ins = std::prev(rins);
         if(not contains(splits, ins))
             return;
         std::copy_if(live_set.begin(),
                      live_set.end(),
                      std::back_inserter(result),
-                     [](instruction_ref live) { return live->name() != "@param"; });
+                     [&](instruction_ref live) { return live->name() != "@param" and not contains(splits, live); });
         stop = true;
     });
     return result;
 }

diff --git a/src/targets/gpu/compile_gen.cpp b/src/targets/gpu/compile_gen.cpp
--- a/src/targets/gpu/compile_gen.cpp
+++ b/src/targets/gpu/compile_gen.cpp
@@ -292,6 +292,7 @@ std::string generate_reduce(module m, const std::string& name)
     run_passes(m, {optimize_module{}});
     m.sort();
     cpp_generator g;
+    g.always_return_tuple();
     auto param_shapes = m.get_parameter_shapes();
     auto max_shape    = std::max_element(param_shapes.begin(),

diff --git a/src/targets/gpu/hip.cpp b/src/targets/gpu/hip.cpp
index 49505bcf8be..cb1ff7fc6c9 100644
--- a/src/targets/gpu/hip.cpp
+++ b/src/targets/gpu/hip.cpp
@@ -304,9 +304,17 @@ argument get_preallocation(context& ctx, const std::string& id)
 
 void gpu_fill(context& ctx, const argument& dst, int value)
 {
-    // TODO: Handle non-packed tensor when value is not 0
-    assert(dst.get_shape().packed() and value == 0);
-    hip_async_memset(ctx, dst, value);
+    if(dst.get_sub_objects().empty())
+    {
+        // TODO: Handle non-packed tensor when value is not 0
+        assert(dst.get_shape().packed() and value == 0);
+        hip_async_memset(ctx, dst, value);
+    }
+    else
+    {
+        for(auto arg:dst.get_sub_objects())
+            gpu_fill(ctx, arg, value);
+    }
 }
 
 void store_preallocated_param(context& ctx, const std::string& id, const argument& a)

diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp
index 0dc8e34b855..6a63dbcb25b 100644
--- a/src/targets/gpu/jit/reduce.cpp
+++ b/src/targets/gpu/jit/reduce.cpp
@@ -293,7 +293,7 @@ namespace migraphx {
 extern "C" {
 MIGRAPHX_GLOBAL void ${kernel}(${params})
 {
-    transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
+    transform_args(make_tensors(), ${transformers}, rotate_and_pack_last<${noutputs}>())(${args})([](auto y, auto... xs) {
        fused_reduce<${algo}, ${reduced}>(y, ${assign}{}, partial(${lambda})(xs...));
    });
 }
@@ -312,9 +312,11 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
     {
         auto assign         = v.get("assign", "assign_none");
         auto axes           = v.at("axes").to_vector<std::size_t>();
-        auto virtual_inputs = inputs;
-        virtual_inputs.push_back(get_reduced_shape(inputs.front(), axes));
-        virtual_inputs.push_back(get_output_shape(inputs.front(), axes));
+        auto finputs = flatten(inputs);
+        auto noutputs = finputs.size() - inputs.size() + 1;
+        auto virtual_inputs = finputs;
+        virtual_inputs.push_back(get_reduced_shape(finputs.front(), axes));
+        virtual_inputs.push_back(get_output_shape(finputs.front(), axes));
         virtual_inputs = reduce_dims(normalize_permutation(virtual_inputs));
         if(assign != "assign_none")
             virtual_inputs = split_reduce(virtual_inputs);
@@ -324,7 +326,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
         virtual_inputs.pop_back();
 
         hip_compile_options options;
-        options.inputs         = inputs;
+        options.inputs         = finputs;
         options.output         = inputs.back();
         options.virtual_inputs = virtual_inputs;
         auto faxis             = find_fast_axis({options.virtual_inputs.front()});
@@ -367,13 +369,14 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
         auto src = interpolate_string(
             fused_reduce_kernel,
             {{"kernel", options.kernel_name},
-             {"params", enum_params(inputs.size(), "void * private_p")},
-             {"args", enum_params(inputs.size(), "private_p")},
+             {"params", enum_params(finputs.size(), "void * private_p")},
+             {"args", enum_params(finputs.size(), "private_p")},
              {"assign", assign},
              {"algo", algo},
             {"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"},
             {"lambda", v.at("lambda").to<std::string>()},
             {"transformers", make_transformer_args(vec)},
+            {"noutputs", std::to_string(noutputs)},
             {"preamble", v.get("preamble", std::string{})}});
         options.emplace_param("-Wno-float-equal");
         return compile_hip_code_object(src, options);

diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
index fab865c0587..83c73e1cc14 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
@@ -180,6 +180,14 @@ constexpr void each_args(F)
 {
 }
 
+template <class F, class... Ts>
+constexpr void each_args_unpack(F f, Ts&&... xs)
+{
+    each_args([&](auto&& p) {
+        p(f);
+    }, static_cast<Ts&&>(xs)...);
+}
+
 template <class F, class T>
 constexpr auto fold_impl(F&&, T&& x)
 {

diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
index b2fb0f4b00f..e160543e954 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
@@ -29,6 +29,7 @@
 #include <migraphx/kernels/index.hpp>
 #include <migraphx/kernels/dpp.hpp>
 #include <migraphx/kernels/ops.hpp>
+#include <migraphx/kernels/functional.hpp>
 #include <migraphx/kernels/vec.hpp>
 
 namespace migraphx {
@@ -732,18 +733,24 @@ simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOutput write)
 }
 
 template <class Algo, class Reduced, class Output, class Assign, class F>
-__device__ void fused_reduce(Output output, Assign assign, F f)
+__device__ void fused_reduce(Output output_pack, Assign assign, F f)
 {
     Algo::template run<Reduced>([&](auto out_idx, auto r) {
-        auto result = f(r, out_idx);
-        if constexpr(reduce::is_inner_storage<decltype(result)>{})
-        {
-            r.inner([&](auto& y, auto x) { assign(y, x); })(output, result);
-        }
-        else
-        {
-            r.outer([&] { assign(output[out_idx], implicit_conversion(result)); });
-        }
+        auto result_tuple = f(r, out_idx);
+        output_pack([&](auto... outputs) {
+            result_tuple([&](auto... results) {
+                each_args_unpack([&](auto output, auto result) {
+                    if constexpr(reduce::is_inner_storage<decltype(result)>{})
+                    {
+                        r.inner([&](auto& y, auto x) { assign(y, x); })(output, result);
+                    }
+                    else
+                    {
+                        r.outer([&] { assign(output[out_idx], implicit_conversion(result)); });
+                    }
+                }, pack(outputs, results)...);
+            });
+        });
     });
 }
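each_args_unpack is the small utility that makes the multi-output fused_reduce work: every argument is itself a pack (a callable that forwards its stored values), and each_args_unpack applies f to each pack's contents in turn, so f sees each (output, result) pair together. A host-side model of the semantics (simplified; the kernel versions are constexpr and use MIGraphX's own pack/each_args):

    #include <cstdio>

    template <class F, class... Ts>
    void each_args(F f, Ts... xs)
    {
        (f(xs), ...); // C++17 fold: call f on each argument in order
    }

    template <class... Ts>
    auto pack(Ts... xs)
    {
        return [=](auto f) { return f(xs...); };
    }

    template <class F, class... Packs>
    void each_args_unpack(F f, Packs... ps)
    {
        each_args([&](auto p) { p(f); }, ps...);
    }

    int main()
    {
        // Prints "1 2" then "3 4": each pack is unpacked into one call of f,
        // mirroring how fused_reduce pairs each output with its result above.
        each_args_unpack([](int a, int b) { std::printf("%d %d\n", a, b); },
                         pack(1, 2),
                         pack(3, 4));
    }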
From 2f7e96ca29fbf463511df0c274d23eebd235b944 Mon Sep 17 00:00:00 2001
From: Paul
Date: Sat, 20 Apr 2024 11:13:51 -0700
Subject: [PATCH 015/145] Format

---
 src/split_reduce.cpp                          |  4 +++-
 src/targets/gpu/hip.cpp                       |  2 +-
 src/targets/gpu/jit/reduce.cpp                |  4 ++--
 .../include/migraphx/kernels/functional.hpp   |  4 +---
 .../include/migraphx/kernels/reduce.hpp       | 22 ++++++++++---------
 5 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp
index 91d12d0482f..106522880de 100644
--- a/src/split_reduce.cpp
+++ b/src/split_reduce.cpp
@@ -115,7 +115,9 @@ static std::vector<instruction_ref> get_alive(const_module_ref rm,
         std::copy_if(live_set.begin(),
                      live_set.end(),
                      std::back_inserter(result),
-                     [&](instruction_ref live) { return live->name() != "@param" and not contains(splits, live); });
+                     [&](instruction_ref live) {
+                         return live->name() != "@param" and not contains(splits, live);
+                     });
         stop = true;
     });
     return result;

diff --git a/src/targets/gpu/hip.cpp b/src/targets/gpu/hip.cpp
index cb1ff7fc6c9..a1ebabb388b 100644
--- a/src/targets/gpu/hip.cpp
+++ b/src/targets/gpu/hip.cpp
@@ -312,7 +312,7 @@ void gpu_fill(context& ctx, const argument& dst, int value)
     }
     else
     {
-        for(auto arg:dst.get_sub_objects())
+        for(auto arg : dst.get_sub_objects())
             gpu_fill(ctx, arg, value);
     }
 }

diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp
index 6a63dbcb25b..8bad88e364d 100644
--- a/src/targets/gpu/jit/reduce.cpp
+++ b/src/targets/gpu/jit/reduce.cpp
@@ -312,8 +312,8 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
     {
         auto assign         = v.get("assign", "assign_none");
         auto axes           = v.at("axes").to_vector<std::size_t>();
-        auto finputs = flatten(inputs);
-        auto noutputs = finputs.size() - inputs.size() + 1;
+        auto finputs        = flatten(inputs);
+        auto noutputs       = finputs.size() - inputs.size() + 1;
         auto virtual_inputs = finputs;
         virtual_inputs.push_back(get_reduced_shape(finputs.front(), axes));

diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
index 83c73e1cc14..2c54b4b9d09 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
@@ -183,9 +183,7 @@ constexpr void each_args(F)
 template <class F, class... Ts>
 constexpr void each_args_unpack(F f, Ts&&... xs)
 {
-    each_args([&](auto&& p) {
-        p(f);
-    }, static_cast<Ts&&>(xs)...);
+    each_args([&](auto&& p) { p(f); }, static_cast<Ts&&>(xs)...);
 }

diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
index e160543e954..bd263fba821 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
@@ -739,16 +739,18 @@ __device__ void fused_reduce(Output output_pack, Assign assign, F f)
         auto result_tuple = f(r, out_idx);
         output_pack([&](auto... outputs) {
             result_tuple([&](auto... results) {
-                each_args_unpack([&](auto output, auto result) {
-                    if constexpr(reduce::is_inner_storage<decltype(result)>{})
-                    {
-                        r.inner([&](auto& y, auto x) { assign(y, x); })(output, result);
-                    }
-                    else
-                    {
-                        r.outer([&] { assign(output[out_idx], implicit_conversion(result)); });
-                    }
-                }, pack(outputs, results)...);
+                each_args_unpack(
+                    [&](auto output, auto result) {
+                        if constexpr(reduce::is_inner_storage<decltype(result)>{})
+                        {
+                            r.inner([&](auto& y, auto x) { assign(y, x); })(output, result);
+                        }
+                        else
+                        {
+                            r.outer([&] { assign(output[out_idx], implicit_conversion(result)); });
+                        }
+                    },
+                    pack(outputs, results)...);
             });
         });
     });

From c6a7caa9a4ebcba1a38e7250aebd21943adc0710 Mon Sep 17 00:00:00 2001
From: Paul
Date: Wed, 24 Apr 2024 17:12:08 -0700
Subject: [PATCH 016/145] Add split fix

---
 src/module.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/module.cpp b/src/module.cpp
index e7467f2eb54..e33a4f969b4 100644
--- a/src/module.cpp
+++ b/src/module.cpp
@@ -922,8 +922,7 @@ generic_split(const module& m,
         instructions2.push_back(ins);
     }
 
-    std::vector<instruction_ref> inputs2 = select_params(instructions2, param_map);
-    inputs2.insert(inputs2.begin(), splits.begin(), splits.end());
+    std::vector<instruction_ref> inputs2 = splits;
     module m2;
     std::size_t n = 0;
     std::unordered_map<instruction_ref, instruction_ref> map_ins2;
@@ -935,6 +934,7 @@ generic_split(const module& m,
             continue;
         if(not contains(instructions2, ins))
             continue;
+        inputs2.push_back(param_map[ins]);
         map_ins2[ins] = m2.add_parameter(param_name(n++), ins->get_shape().as_standard());
     }
     auto r = m2.add_instructions(instructions2, &map_ins2);
     m2.add_return(r);

From 25442a5367459fe82a59f1090e4ceb332fd9ae52 Mon Sep 17 00:00:00 2001
From: Paul
Date: Thu, 25 Apr 2024 12:46:59 -0700
Subject: [PATCH 017/145] Fix bug with live instruction after split

---
 src/module.cpp        |   8 ++-
 src/split_reduce.cpp  | 115 +++++++++++++++++++++++++-----------------
 test/split_reduce.cpp |  98 +++++++++++++++++++++++++++++++++--
 3 files changed, 169 insertions(+), 52 deletions(-)

diff --git a/src/module.cpp b/src/module.cpp
index e33a4f969b4..0d624841bbd 100644
--- a/src/module.cpp
+++ b/src/module.cpp
@@ -939,7 +939,7 @@ generic_split(const module& m,
     }
     auto r = m2.add_instructions(instructions2, &map_ins2);
     m2.add_return(r);
-    if(map_ins != nullptr)
+    if (map_ins != nullptr)
         *map_ins = map_ins2;
     return {{{std::move(m1), std::move(inputs1)}, {std::move(m2), std::move(inputs2)}}};
 }
@@ -960,6 +960,12 @@ std::array<module_with_inputs, 2> module::split(const std::vector<instruction_r
+    std::cout << "module:\n";
+    this->debug_print();
+    std::cout << "splits1:\n";
+    this->debug_print(splits1);
+    std::cout << "splits2:\n";
+    this->debug_print(splits2);
 
     std::vector<instruction_ref> new_splits2;
     std::transform(splits2.begin(), splits2.end(), std::back_inserter(new_splits2), [&](auto ins) {

diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp
index 106522880de..6510785ac0f 100644
--- a/src/split_reduce.cpp
+++ b/src/split_reduce.cpp
@@ -75,54 +75,74 @@ MIGRAPHX_REGISTER_OP(split_fused_reduce);
 
 static bool is_reduce(const instruction& ins) { return contains(ins.name(), "reduce"); }
 
-static std::vector<instruction_ref> find_split(const_module_ref rm)
-{
-    std::vector<instruction_ref> result;
-    copy_if(
-        iterator_for(*rm), std::back_inserter(result), [](auto ins) { return is_reduce(*ins); });
-    if(result.size() > 2)
-        return {};
-    // Only handle reduce_sum for now
-    // TODO: Support other reduction types
-    if(not std::all_of(result.begin(), result.end(), [](instruction_ref ins) {
-           return ins->name() == "reduce_sum";
-       }))
-        return {};
-    if(result.size() < 2)
-        return result;
-    dominator_info dom = compute_dominator(*rm);
-    if(dom.strictly_dominate(result[0], result[1]))
-        return {};
-    if(dom.strictly_dominate(result[1], result[0]))
-        return {};
-    return result;
-}
+namespace {
+    struct splitter
+    {
+        const_module_ref rm;
+        bool strictly_dominate(instruction_ref a, instruction_ref b)
+        {
+            if(not dom.has_value())
+                dom = compute_dominator(*rm);
+            return dom->strictly_dominate(a, b);
+        }
 
-static std::vector<instruction_ref> get_alive(const_module_ref rm,
-                                              const std::vector<instruction_ref>& splits)
-{
-    std::vector<instruction_ref> result;
-    bool stop = false;
-    liveness(*rm, [&](auto rins, const auto& live_set) {
-        if(stop)
-            return;
-        if(rins == rm->begin())
-            return;
-        // We want to know what instructions are live after the split instruction
-        auto ins = std::prev(rins);
-        if(not contains(splits, ins))
-            return;
-        std::copy_if(live_set.begin(),
-                     live_set.end(),
-                     std::back_inserter(result),
-                     [&](instruction_ref live) {
-                         return live->name() != "@param" and not contains(splits, live);
-                     });
-        stop = true;
-    });
-    return result;
-}
+        std::vector<instruction_ref> find_splits()
+        {
+            std::vector<instruction_ref> result;
+            copy_if(
+                iterator_for(*rm), std::back_inserter(result), [](auto ins) { return is_reduce(*ins); });
+            if(result.size() > 2)
+                return {};
+            // Only handle reduce_sum for now
+            // TODO: Support other reduction types
+            if(not std::all_of(result.begin(), result.end(), [](instruction_ref ins) {
+                return ins->name() == "reduce_sum";
+            }))
+                return {};
+            if(result.size() < 2)
+                return result;
+            if(this->strictly_dominate(result[0], result[1]))
+                return {};
+            if(this->strictly_dominate(result[1], result[0]))
+                return {};
+            return result;
+        }
+
+        std::vector<instruction_ref> find_alive(const std::vector<instruction_ref>& splits)
+        {
+            std::vector<instruction_ref> result;
+            bool stop = false;
+            liveness(*rm, [&](auto rins, const auto& live_set) {
+                if(stop)
+                    return;
+                if(rins == rm->begin())
+                    return;
+                // We want to know what instructions are live after the split instruction
+                auto ins = std::prev(rins);
+                if(not contains(splits, ins))
+                    return;
+                std::copy_if(live_set.begin(),
+                             live_set.end(),
+                             std::back_inserter(result),
+                             [&](instruction_ref live) {
+                                 if(live->name() == "@param")
+                                     return false;
+                                 if(contains(splits, live))
+                                     return false;
+                                 if(splits.size() > 1 and none_of(splits, [&](instruction_ref split) { return this->strictly_dominate(live, split); }))
+                                     return false;
+                                 return true;
+                             });
+                stop = true;
+            });
+            return result;
+        }
+
+        std::optional<dominator_info> dom = std::nullopt;
+    };
+}
 
 static std::string assign_op(const std::vector<instruction_ref>& splits)
 {
     static std::unordered_map<std::string, std::string> m = {
@@ -158,7 +178,8 @@ void split_reduce::apply(module_pass_manager& mpm) const
         auto* rm = ins->module_inputs().front();
         if(get_reduce_size(rm) < split_size)
             continue;
-        auto splits = find_split(rm);
+        splitter s{rm};
+        auto splits = s.find_splits();
         if(splits.empty())
             continue;
         // Only use split reduce with float for now
@@ -170,7 +191,7 @@ void split_reduce::apply(module_pass_manager& mpm) const
         auto v    = ins->get_operator().to_value();
         auto axes = v["axes"].to_vector<std::int64_t>();
 
-        auto alive = get_alive(rm, splits);
+        auto alive = s.find_alive(splits);
 
         std::array<module_with_inputs, 2> mods;
         if(not alive.empty())

diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp
index f757391216e..4a708feaf81 100644
--- a/test/split_reduce.cpp
+++ b/test/split_reduce.cpp
@@ -41,6 +41,7 @@ void run_pass(migraphx::program& p)
                          migraphx::fuse_reduce{},
                          migraphx::split_reduce{.split_size = 8192},
                          migraphx::dead_code_elimination{}});
+
 }
 
 void run_fuse_pass(migraphx::program& p)
@@ -59,16 +60,24 @@ bool all_instructions_are_local(const migraphx::module& m)
     });
 }
 
+void auto_add_return(migraphx::module_ref m, migraphx::instruction_ref ins)
+{
+    m->add_return({ins});
+}
+
+void auto_add_return(migraphx::module_ref m, std::vector<migraphx::instruction_ref> inss)
+{
+    m->add_return(inss);
+}
+
 template <class F>
-migraphx::instruction_ref add_reduce(migraphx::program& p,
+migraphx::module_ref add_reduce_module(migraphx::program& p,
                                      const std::string& name,
                                      std::vector<migraphx::instruction_ref> inputs,
                                      const std::vector<int64_t>& axes,
-                                     const std::string& assign,
                                      F f)
 {
     auto* rm = p.create_module(name);
-    auto* mm = p.get_main_module();
     rm->set_bypass();
     std::vector<migraphx::instruction_ref> params;
     std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) {
         return rm->add_parameter(
             "x" + std::to_string(params.size()),
             migraphx::shape{input->get_shape().type(), input->get_shape().lens()});
     });
     auto r = f(rm, params, axes);
-    rm->add_return({r});
+    auto_add_return(rm, r);
     EXPECT(all_instructions_are_local(*rm));
+    return rm;
+}
+
+template <class F>
+migraphx::instruction_ref add_reduce(migraphx::program& p,
+                                     const std::string& name,
+                                     std::vector<migraphx::instruction_ref> inputs,
+                                     const std::vector<int64_t>& axes,
+                                     F f)
+{
+    auto* mm = p.get_main_module();
+    auto rm = add_reduce_module(p, name, inputs, axes, f);
+    return mm->add_instruction(migraphx::make_op("fused_reduce", {{"axes", axes}}), inputs, {rm});
+}
+
+template <class F>
+migraphx::instruction_ref add_reduce(migraphx::program& p,
+                                     const std::string& name,
+                                     std::vector<migraphx::instruction_ref> inputs,
+                                     const std::vector<int64_t>& axes,
+                                     const std::string& assign,
+                                     F f)
+{
+    auto* mm = p.get_main_module();
+    auto rm = add_reduce_module(p, name, inputs, axes, f);
     return mm->add_instruction(
         migraphx::make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign}}),
         inputs,
@@ -224,4 +258,60 @@ TEST_CASE(sequence_reduces)
     EXPECT(p1 == p2);
 }
 
+TEST_CASE(double_split_live)
+{
+    migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}};
+    migraphx::program p1;
+    {
+        auto* mm = p1.get_main_module();
+        auto x   = mm->add_parameter("x", s);
+        auto rsum = add_reduce(p1,
+            "fuse_reduce0",
+            {x},
+            {2},
+            [&](auto* rm, const auto& inputs, const auto& axes) {
+                auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), inputs[0]);
+                auto sqrt  = add_pointwise(p1, rm, "main:pointwise0", {rsum1}, single_pointwise("sqrt"));
+                auto sqrtb = rm->add_instruction(
+                    migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), sqrt);
+                auto mul   = add_pointwise(p1, rm, "main:pointwise1", {inputs[0], inputs[0]}, single_pointwise("mul"));
+                auto rsum2 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), mul);
+                auto add   = add_pointwise(p1, rm, "main:pointwise2", {rsum2, sqrt}, single_pointwise("add"));
+                auto addb  = rm->add_instruction(
+                    migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), add);
+                return add_pointwise(p1, rm, "main:pointwise3", {addb, sqrtb}, single_pointwise("mul"));
+            });
+        mm->add_return({rsum});
+
+    }
+    run_pass(p1);
+    migraphx::program p2;
+    {
+        auto* mm = p2.get_main_module();
+        auto x   = mm->add_parameter("x", s);
+        auto rsums = add_reduce(p2,
+            "fuse_reduce0_split",
+            {x},
+            {2},
+            "assign_add",
+            [&](auto* rm, const auto& inputs, const auto& axes) -> std::vector<migraphx::instruction_ref> {
+                auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), inputs[0]);
+                auto mul   = add_pointwise(p2, rm, "main:pointwise1", {inputs[0], inputs[0]}, single_pointwise("mul"));
+                auto rsum2 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), mul);
+                return {rsum1, rsum2};
+            });
+        auto rsum1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), rsums);
+        auto rsum2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), rsums);
+        auto sqrt  = add_pointwise(p2, "main:pointwise0", {rsum1}, single_pointwise("sqrt"));
+        auto sqrtb = mm->add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), sqrt);
+        auto add   = add_pointwise(p2, "main:pointwise2", {rsum2, sqrt}, single_pointwise("add"));
+        auto addb  = mm->add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), add);
+        auto mul   = add_pointwise(p2, "main:pointwise3", {addb, sqrtb}, single_pointwise("mul"));
+        mm->add_return({mul});
+    }
+    EXPECT(p1 == p2);
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
reduce_sum for now - // TODO: Support other reduction types - if(not std::all_of(result.begin(), result.end(), [](instruction_ref ins) { - return ins->name() == "reduce_sum"; - })) - return {}; - if(result.size() < 2) - return result; - if(this->strictly_dominate(result[0], result[1])) - return {}; - if(this->strictly_dominate(result[1], result[0])) - return {}; - return result; - } + if(not dom.has_value()) + dom = compute_dominator(*rm); + return dom->strictly_dominate(a, b); + } - std::vector find_alive(const std::vector& splits) - { - std::vector result; - bool stop = false; - liveness(*rm, [&](auto rins, const auto& live_set) { - if(stop) - return; - if(rins == rm->begin()) - return; - // We want to know what instructions are live after the split instruction - auto ins = std::prev(rins); - if(not contains(splits, ins)) - return; - std::copy_if(live_set.begin(), - live_set.end(), - std::back_inserter(result), - [&](instruction_ref live) { - if(live->name() == "@param") - return false; - if(contains(splits, live)) - return false; - if(splits.size() > 1 and none_of(splits, [&](instruction_ref split) { return this->strictly_dominate(live, split); })) - return false; - return true; - }); - stop = true; - }); + std::vector find_splits() + { + std::vector result; + copy_if(iterator_for(*rm), std::back_inserter(result), [](auto ins) { + return is_reduce(*ins); + }); + if(result.size() > 2) + return {}; + // Only handle reduce_sum for now + // TODO: Support other reduction types + if(not std::all_of(result.begin(), result.end(), [](instruction_ref ins) { + return ins->name() == "reduce_sum"; + })) + return {}; + if(result.size() < 2) return result; - } + if(this->strictly_dominate(result[0], result[1])) + return {}; + if(this->strictly_dominate(result[1], result[0])) + return {}; + return result; + } - std::optional dom = std::nullopt; - }; -} + std::vector find_alive(const std::vector& splits) + { + std::vector result; + bool stop = false; + liveness(*rm, [&](auto rins, const auto& live_set) { + if(stop) + return; + if(rins == rm->begin()) + return; + // We want to know what instructions are live after the split instruction + auto ins = std::prev(rins); + if(not contains(splits, ins)) + return; + std::copy_if(live_set.begin(), + live_set.end(), + std::back_inserter(result), + [&](instruction_ref live) { + if(live->name() == "@param") + return false; + if(contains(splits, live)) + return false; + if(splits.size() > 1 and none_of(splits, [&](instruction_ref split) { + return this->strictly_dominate(live, split); + })) + return false; + return true; + }); + stop = true; + }); + return result; + } + std::optional dom = std::nullopt; +}; +} // namespace static std::string assign_op(const std::vector& splits) { diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index 4a708feaf81..64a74ce93a0 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -41,7 +41,6 @@ void run_pass(migraphx::program& p) migraphx::fuse_reduce{}, migraphx::split_reduce{.split_size = 8192}, migraphx::dead_code_elimination{}}); - } void run_fuse_pass(migraphx::program& p) @@ -72,10 +71,10 @@ void auto_add_return(migraphx::module_ref m, std::vector migraphx::module_ref add_reduce_module(migraphx::program& p, - const std::string& name, - std::vector inputs, - const std::vector& axes, - F f) + const std::string& name, + std::vector inputs, + const std::vector& axes, + F f) { auto* rm = p.create_module(name); rm->set_bypass(); @@ -99,7 +98,7 @@ migraphx::instruction_ref add_reduce(migraphx::program& p, F f) { 
auto* mm = p.get_main_module(); - auto rm = add_reduce_module(p, name, inputs, axes, f); + auto rm = add_reduce_module(p, name, inputs, axes, f); return mm->add_instruction(migraphx::make_op("fused_reduce", {{"axes", axes}}), inputs, {rm}); } @@ -112,7 +111,7 @@ migraphx::instruction_ref add_reduce(migraphx::program& p, F f) { auto* mm = p.get_main_module(); - auto rm = add_reduce_module(p, name, inputs, axes, f); + auto rm = add_reduce_module(p, name, inputs, axes, f); return mm->add_instruction( migraphx::make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign}}), inputs, @@ -263,52 +262,62 @@ TEST_CASE(double_split_live) migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}}; migraphx::program p1; { - auto* mm = p1.get_main_module(); - auto x = mm->add_parameter("x", s); - auto rsum = add_reduce(p1, - "fuse_reduce0", - {x}, - {2}, - [&](auto* rm, const auto& inputs, const auto& axes) { - auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), inputs[0]); - auto sqrt = add_pointwise(p1, rm, "main:pointwise0", {rsum1}, single_pointwise("sqrt")); - auto sqrtb = rm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), sqrt); - auto mul = add_pointwise(p1, rm, "main:pointwise1", {inputs[0], inputs[0]}, single_pointwise("mul")); - auto rsum2 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), mul); - auto add = add_pointwise(p1, rm, "main:pointwise2", {rsum2, sqrt}, single_pointwise("add")); - auto addb = rm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), add); - return add_pointwise(p1, rm, "main:pointwise3", {addb, sqrtb}, single_pointwise("mul")); - }); + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto rsum = add_reduce( + p1, "fuse_reduce0", {x}, {2}, [&](auto* rm, const auto& inputs, const auto& axes) { + auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + auto sqrt = + add_pointwise(p1, rm, "main:pointwise0", {rsum1}, single_pointwise("sqrt")); + auto sqrtb = rm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), sqrt); + auto mul = add_pointwise( + p1, rm, "main:pointwise1", {inputs[0], inputs[0]}, single_pointwise("mul")); + auto rsum2 = + rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), mul); + auto add = add_pointwise( + p1, rm, "main:pointwise2", {rsum2, sqrt}, single_pointwise("add")); + auto addb = rm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), add); + return add_pointwise( + p1, rm, "main:pointwise3", {addb, sqrtb}, single_pointwise("mul")); + }); mm->add_return({rsum}); - } run_pass(p1); migraphx::program p2; { auto* mm = p2.get_main_module(); auto x = mm->add_parameter("x", s); - auto rsums = add_reduce(p2, - "fuse_reduce0_split", - {x}, - {2}, - "assign_add", - [&](auto* rm, const auto& inputs, const auto& axes) -> std::vector { - auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), inputs[0]); - auto mul = add_pointwise(p2, rm, "main:pointwise1", {inputs[0], inputs[0]}, single_pointwise("mul")); - auto rsum2 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), mul); - return {rsum1, rsum2}; - }); - auto rsum1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), rsums); - auto rsum2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), rsums); - auto sqrt = add_pointwise(p2, 
"main:pointwise0", {rsum1}, single_pointwise("sqrt")); - auto sqrtb = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), sqrt); + auto rsums = add_reduce( + p2, + "fuse_reduce0_split", + {x}, + {2}, + "assign_add", + [&](auto* rm, + const auto& inputs, + const auto& axes) -> std::vector { + auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + auto mul = add_pointwise( + p2, rm, "main:pointwise1", {inputs[0], inputs[0]}, single_pointwise("mul")); + auto rsum2 = + rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), mul); + return {rsum1, rsum2}; + }); + auto rsum1 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), rsums); + auto rsum2 = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), rsums); + auto sqrt = add_pointwise(p2, "main:pointwise0", {rsum1}, single_pointwise("sqrt")); + auto sqrtb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), sqrt); auto add = add_pointwise(p2, "main:pointwise2", {rsum2, sqrt}, single_pointwise("add")); - auto addb = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), add); - auto mul = add_pointwise(p2, "main:pointwise3", {addb, sqrtb}, single_pointwise("mul")); + auto addb = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), add); + auto mul = add_pointwise(p2, "main:pointwise3", {addb, sqrtb}, single_pointwise("mul")); mm->add_return({mul}); } EXPECT(p1 == p2); From 1cfa65e98ea656fcdb0cfef5ee14eaac52e84ca3 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 25 Apr 2024 13:00:38 -0700 Subject: [PATCH 019/145] Remove debug prints --- src/module.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index 715853ecc3a..e33a4f969b4 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -960,12 +960,6 @@ std::array module::split(const std::vectordebug_print(); - std::cout << "splits1:\n"; - this->debug_print(splits1); - std::cout << "splits2:\n"; - this->debug_print(splits2); std::vector new_splits2; std::transform(splits2.begin(), splits2.end(), std::back_inserter(new_splits2), [&](auto ins) { From 61d788cd5788943ab78fc129ce3016c71a231523 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 25 Apr 2024 13:19:37 -0700 Subject: [PATCH 020/145] Enable with env var --- src/targets/gpu/fuse_mlir.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 671ddd9042c..3ae0ac35c35 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -40,6 +40,7 @@ struct module; namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR); /** * @brief Declares a new MIGraphX environment variable which forces to generate @@ -671,7 +672,8 @@ void fuse_mlir::apply(module_pass_manager& mpm) const find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::fast)}, find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::fast)}); - match::find_matches(mpm, find_pointwise_mlir{}); + if(enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) + match::find_matches(mpm, find_pointwise_mlir{}); #else (void)mpm; #endif From 78161de47b43c2c27f21eab6bc46929543594d50 Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 13 May 2024 10:36:30 -0700 Subject: [PATCH 021/145] Use reaches --- 
 src/include/migraphx/instruction.hpp | 2 ++
 src/instruction.cpp | 12 ++++++++++++
 src/split_reduce.cpp | 25 ++++++++++++++++++++-----
 src/targets/gpu/prepare_reduce.cpp | 12 ------------
 4 files changed, 34 insertions(+), 17 deletions(-)

diff --git a/src/include/migraphx/instruction.hpp b/src/include/migraphx/instruction.hpp
index 8827ce8dff0..0e722e23ad3 100644
--- a/src/include/migraphx/instruction.hpp
+++ b/src/include/migraphx/instruction.hpp
@@ -45,6 +45,8 @@ MIGRAPHX_EXPORT std::vector<shape> to_shapes(const std::vector<instruction_ref>&
 MIGRAPHX_EXPORT std::vector<shape> try_compute_shape(const operation& op,
                                                      const std::vector<shape>& inputs);
 
+bool reaches(instruction_ref start, instruction_ref end);
+
 struct MIGRAPHX_EXPORT instruction
 {
     instruction() {}
diff --git a/src/instruction.cpp b/src/instruction.cpp
index 543e7a13fbb..7182184e8d5 100644
--- a/src/instruction.cpp
+++ b/src/instruction.cpp
@@ -516,5 +516,17 @@ migraphx::instruction* as_address(const instruction_ref& ins) noexcept
     return std::addressof(*ins);
 }
 
+bool reaches(instruction_ref start, instruction_ref end)
+{
+    std::unordered_set<instruction_ref> visited;
+    return fix([&](auto self, auto ins) -> bool {
+        if(ins == start)
+            return true;
+        if(not visited.insert(ins).second)
+            return false;
+        return std::any_of(ins->inputs().begin(), ins->inputs().end(), self);
+    })(end);
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp
index 4759fc20c88..b100bb63445 100644
--- a/src/split_reduce.cpp
+++ b/src/split_reduce.cpp
@@ -86,7 +86,7 @@ struct splitter
         return dom->strictly_dominate(a, b);
     }
 
-    std::vector<instruction_ref> find_splits()
+    std::vector<instruction_ref> find_reduces()
     {
         std::vector<instruction_ref> result;
         copy_if(iterator_for(*rm), std::back_inserter(result), [](auto ins) {
@@ -102,13 +102,28 @@ struct splitter
             return {};
         if(result.size() < 2)
             return result;
-        if(this->strictly_dominate(result[0], result[1]))
-            return {};
-        if(this->strictly_dominate(result[1], result[0]))
+        if(reaches(result[0], result[1]))
             return {};
         return result;
     }
 
+    static instruction_ref find_split(instruction_ref reduce)
+    {
+        if(reduce->outputs().size() != 1)
+            return reduce;
+        auto output = reduce->outputs().front();
+        if(output->name() == "convert")
+            return find_split(output);
+        return reduce;
+    }
+
+    std::vector<instruction_ref> find_splits()
+    {
+        std::vector<instruction_ref> result;
+        transform(find_reduces(), std::back_inserter(result), &find_split);
+        return result;
+    }
+
     std::vector<instruction_ref> find_alive(const std::vector<instruction_ref>& splits)
     {
         std::vector<instruction_ref> result;
@@ -131,7 +146,7 @@ struct splitter
                              if(contains(splits, live))
                                  return false;
                              if(splits.size() > 1 and none_of(splits, [&](instruction_ref split) {
-                                    return this->strictly_dominate(live, split);
+                                    return reaches(live, split);
                                 }))
                                  return false;
                              return true;
diff --git a/src/targets/gpu/prepare_reduce.cpp b/src/targets/gpu/prepare_reduce.cpp
index ebb1752eeed..bd5abd42bb0 100644
--- a/src/targets/gpu/prepare_reduce.cpp
+++ b/src/targets/gpu/prepare_reduce.cpp
@@ -71,18 +71,6 @@ std::vector<instruction_ref> find_reduce(module& m)
     return result;
 }
 
-bool reaches(instruction_ref start, instruction_ref end)
-{
-    std::unordered_set<instruction_ref> visited;
-    return fix([&](auto self, auto ins) -> bool {
-        if(ins == start)
-            return true;
-        if(not visited.insert(ins).second)
-            return false;
-        return std::any_of(ins->inputs().begin(), ins->inputs().end(), self);
-    })(end);
-}
-
 std::vector<instruction_ref> find_parallel_reduce(const std::vector<instruction_ref>& r)
 {
     std::vector<instruction_ref> result;

From 6647d4b97db82dfcaad3e69d143b3de18a241f00 Mon Sep 17 00:00:00 2001
From: Paul
Date: Tue, 14 May 2024 14:12:21 -0700
Subject: [PATCH 022/145] Remove
dominator --- src/split_reduce.cpp | 29 +---------------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index b100bb63445..256cb0e7c9d 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -23,7 +23,6 @@ * */ #include -#include #include #include #include @@ -79,14 +78,7 @@ namespace { struct splitter { const_module_ref rm; - bool strictly_dominate(instruction_ref a, instruction_ref b) - { - if(not dom.has_value()) - dom = compute_dominator(*rm); - return dom->strictly_dominate(a, b); - } - - std::vector find_reduces() + std::vector find_splits() { std::vector result; copy_if(iterator_for(*rm), std::back_inserter(result), [](auto ins) { @@ -107,23 +99,6 @@ struct splitter return result; } - static instruction_ref find_split(instruction_ref reduce) - { - if(reduce->outputs().size() != 1) - return reduce; - auto output = reduce->outputs().front(); - if(output->name() == "convert") - return find_split(output); - return reduce; - } - - std::vector find_splits() - { - std::vector result; - transform(find_reduces(), std::back_inserter(result), &find_split); - return result; - } - std::vector find_alive(const std::vector& splits) { std::vector result; @@ -155,8 +130,6 @@ struct splitter }); return result; } - - std::optional dom = std::nullopt; }; } // namespace From 66e9d31830042be3ef3a64195e2ac6cb7933e7a4 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 14 May 2024 15:33:28 -0700 Subject: [PATCH 023/145] Update comments --- src/include/migraphx/module.hpp | 6 +++++- src/module.cpp | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 5c7f8aa24a4..dbe98b7f0d7 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -244,12 +244,16 @@ struct MIGRAPHX_EXPORT module std::array split(const std::vector& args, const std::vector& splits1, const std::vector& splits2) const; - + + // Fuse the instruction into the module by inserting the instructions and + // parameters for any missing inputs. 
std::vector fuse(const std::vector& inss, std::unordered_map* map_ins = nullptr, inserter insert = nullptr); + // Fuse another module into this module by inserting the instructions and + // parameters from the module std::vector fuse(const module& m, const std::vector& inputs, diff --git a/src/module.cpp b/src/module.cpp index 46a5aae7317..7db4125070b 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -998,7 +998,7 @@ module::fuse(const std::vector& inss, module::inserter insert) { std::unordered_map default_map_ins; - if(not map_ins) + if(map_ins == nullptr) map_ins = &default_map_ins; std::vector inputs; for(auto ins : inss) @@ -1023,7 +1023,7 @@ module::fuse(const module& m, module::inserter insert) { std::unordered_map default_map_ins; - if(not map_ins) + if(map_ins == nullptr) map_ins = &default_map_ins; insert_params(*this, inputs, *map_ins); auto param_map = m.get_ins_param_map(inputs); From 9107b26f93bfb7935b1dcf7b4ad74d751921dea0 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 14 May 2024 15:33:34 -0700 Subject: [PATCH 024/145] Format --- src/include/migraphx/module.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index dbe98b7f0d7..b9c5c2f3541 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -244,7 +244,7 @@ struct MIGRAPHX_EXPORT module std::array split(const std::vector& args, const std::vector& splits1, const std::vector& splits2) const; - + // Fuse the instruction into the module by inserting the instructions and // parameters for any missing inputs. std::vector From cbf3afcacfc50bf3976049a92eed542b3e594ccb Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 14 May 2024 15:35:36 -0700 Subject: [PATCH 025/145] Fix param_utils --- src/include/migraphx/param_utils.hpp | 2 ++ src/param_utils.cpp | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/include/migraphx/param_utils.hpp b/src/include/migraphx/param_utils.hpp index f594f8be7f7..f645229cd10 100644 --- a/src/include/migraphx/param_utils.hpp +++ b/src/include/migraphx/param_utils.hpp @@ -38,6 +38,8 @@ std::string param_name(std::size_t i, const std::string& prefix = "x"); void sort_params(std::vector& params); +// Find the inputs for a module by finding instructions that are mapped to the +// parameters in the module std::vector find_inputs(const std::unordered_map& map_ins, const_module_ref parent, diff --git a/src/param_utils.cpp b/src/param_utils.cpp index a3a07acaa26..4a447a4ae2e 100644 --- a/src/param_utils.cpp +++ b/src/param_utils.cpp @@ -55,11 +55,11 @@ find_inputs(const std::unordered_map& map_ins, std::map names; for(auto&& [input, param] : map_ins) { - if(sub and not sub->has_instruction(param)) + if(sub != nullptr and not sub->has_instruction(param)) continue; if(param->name() != "@param") continue; - if(parent and not parent->has_instruction(input)) + if(parent != nullptr and not parent->has_instruction(input)) continue; auto v = param->get_operator().to_value(); auto name = v.at("parameter").to(); From 3947949760b240174611de1225320eff6de88b6a Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 14 May 2024 15:39:16 -0700 Subject: [PATCH 026/145] Filter supported ops --- src/targets/gpu/fuse_mlir.cpp | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 9d2dcd0b524..f8c5820f5ac 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -392,14 +392,25 @@ 
bool is_pointwise_op_supported_by_mlir(const instruction& i) return false; } +bool is_pointwise_op_supported_by_mlir_for_input(const instruction& i) +{ + return is_pointwise_op_supported_by_mlir(i); +} + MIGRAPHX_PRED_MATCHER(mlir_pointwise, instruction_ref ins) { if(ins->name() != "pointwise") return false; auto* pm = ins->module_inputs().front(); - return std::all_of(pm->begin(), pm->end(), [&](const auto& i) { - return is_pointwise_op_supported_by_mlir(i); - }); + return std::all_of(pm->begin(), pm->end(), &is_pointwise_op_supported_by_mlir); +} + +MIGRAPHX_PRED_MATCHER(mlir_input_pointwise, instruction_ref ins) +{ + if(ins->name() != "pointwise") + return false; + auto* pm = ins->module_inputs().front(); + return std::all_of(pm->begin(), pm->end(), &is_pointwise_op_supported_by_mlir_for_input); } struct find_mlir_fused_ops @@ -563,7 +574,7 @@ struct find_pointwise_mlir auto matcher() const { return match::name("gpu::mlir_op")(match::any_of[match::inputs()]( - match::name("pointwise")(match::used_once()).bind("pointwise"))); + mlir_input_pointwise(match::used_once()).bind("pointwise"))); } void apply(module_pass_manager& mpm, const match::matcher_result& r) const From ffdba3cae478400570783730234bb178a4990c40 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 14 May 2024 15:41:20 -0700 Subject: [PATCH 027/145] Add another comment --- src/targets/gpu/fuse_mlir.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index f8c5820f5ac..749b89cdd32 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -392,6 +392,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) return false; } +// A seprate function so we can remove operators that are supported by mlir +// but not supported for an input fusion. 
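+// (For now it simply forwards to is_pointwise_op_supported_by_mlir; the
+// separate entry point leaves room to reject ops for input fusion only.)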
bool is_pointwise_op_supported_by_mlir_for_input(const instruction& i) { return is_pointwise_op_supported_by_mlir(i); From 019bb0d4c06dcfbc5f505be48e0799590eae6785 Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 15 May 2024 16:51:51 -0700 Subject: [PATCH 028/145] Add test for multi out split reduce --- test/split_reduce.cpp | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index 64a74ce93a0..6deb058ba21 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -125,6 +125,13 @@ inline auto single_reduce(const std::string& name) }; } +inline auto squared() +{ + return [](auto* pm, const auto& inputs) { + return pm->add_instruction(migraphx::make_op("mul"), inputs[0], inputs[0]); + }; +} + TEST_CASE(single) { migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}}; @@ -257,6 +264,41 @@ TEST_CASE(sequence_reduces) EXPECT(p1 == p2); } +TEST_CASE(parallel_reduce) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto xx = mm->add_instruction(migraphx::make_op("mul"), x, x); + auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x); + auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), xx); + auto mul = mm->add_instruction(migraphx::make_op("mul"), rsum1, rsum2); + mm->add_return({mul}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto rsum = add_reduce( + p2, "main:reduce_sum0:main:pointwise1:main:pointwise0:main:reduce_sum1_split", {x}, {2}, "assign_add", [&](auto* rm, const auto& inputs, const auto& axes) -> std::vector { + auto xx = add_pointwise( + p2, rm, "main:pointwise0", {inputs[0]}, squared()); + auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), inputs[0]); + auto rsum2 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), xx); + return {rsum2, rsum1}; + }); + auto rsum2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), rsum); + auto rsum1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), rsum); + auto mul = add_pointwise( + p2, mm, "main:pointwise1", {rsum1, rsum2}, single_pointwise("mul")); + mm->add_return({mul}); + } + EXPECT(p1.sort() == p2.sort()); +} + TEST_CASE(double_split_live) { migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}}; From c33f7fd59ddb87695bc343f9163b97aeb169ccfb Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 15 May 2024 16:51:59 -0700 Subject: [PATCH 029/145] Format --- test/split_reduce.cpp | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index 6deb058ba21..372b72de5b0 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -269,12 +269,12 @@ TEST_CASE(parallel_reduce) migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}}; migraphx::program p1; { - auto* mm = p1.get_main_module(); - auto x = mm->add_parameter("x", s); - auto xx = mm->add_instruction(migraphx::make_op("mul"), x, x); + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto xx = mm->add_instruction(migraphx::make_op("mul"), x, x); auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x); auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", 
{2}}}), xx); - auto mul = mm->add_instruction(migraphx::make_op("mul"), rsum1, rsum2); + auto mul = mm->add_instruction(migraphx::make_op("mul"), rsum1, rsum2); mm->add_return({mul}); } run_pass(p1); @@ -283,17 +283,25 @@ TEST_CASE(parallel_reduce) auto* mm = p2.get_main_module(); auto x = mm->add_parameter("x", s); auto rsum = add_reduce( - p2, "main:reduce_sum0:main:pointwise1:main:pointwise0:main:reduce_sum1_split", {x}, {2}, "assign_add", [&](auto* rm, const auto& inputs, const auto& axes) -> std::vector { - auto xx = add_pointwise( - p2, rm, "main:pointwise0", {inputs[0]}, squared()); - auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), inputs[0]); - auto rsum2 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), xx); + p2, + "main:reduce_sum0:main:pointwise1:main:pointwise0:main:reduce_sum1_split", + {x}, + {2}, + "assign_add", + [&](auto* rm, + const auto& inputs, + const auto& axes) -> std::vector { + auto xx = add_pointwise(p2, rm, "main:pointwise0", {inputs[0]}, squared()); + auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + auto rsum2 = + rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), xx); return {rsum2, rsum1}; }); auto rsum2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), rsum); auto rsum1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), rsum); - auto mul = add_pointwise( - p2, mm, "main:pointwise1", {rsum1, rsum2}, single_pointwise("mul")); + auto mul = + add_pointwise(p2, mm, "main:pointwise1", {rsum1, rsum2}, single_pointwise("mul")); mm->add_return({mul}); } EXPECT(p1.sort() == p2.sort()); From 86df8f13f47b18f55183d4d7bfcd2affde803d61 Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 15 May 2024 19:05:44 -0700 Subject: [PATCH 030/145] Add dominator back --- src/split_reduce.cpp | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 256cb0e7c9d..72f997b53dc 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -23,6 +23,7 @@ * */ #include +#include #include #include #include @@ -78,6 +79,14 @@ namespace { struct splitter { const_module_ref rm; + + bool strictly_dominate(instruction_ref a, instruction_ref b) + { + if(not dom.has_value()) + dom = compute_dominator(*rm); + return dom->strictly_dominate(a, b); + } + std::vector find_splits() { std::vector result; @@ -109,19 +118,19 @@ struct splitter if(rins == rm->begin()) return; // We want to know what instructions are live after the split instruction - auto ins = std::prev(rins); + auto ins = instruction::get_output_alias(std::prev(rins)); if(not contains(splits, ins)) return; std::copy_if(live_set.begin(), live_set.end(), std::back_inserter(result), - [&](instruction_ref live) { + [&](instruction_ref live) { if(live->name() == "@param") return false; if(contains(splits, live)) return false; if(splits.size() > 1 and none_of(splits, [&](instruction_ref split) { - return reaches(live, split); + return this->strictly_dominate(live, split); })) return false; return true; @@ -130,6 +139,8 @@ struct splitter }); return result; } + + std::optional dom = std::nullopt; }; } // namespace From 7e5babf164f834ffea9743bab7858ae834475b2d Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 15 May 2024 19:05:48 -0700 Subject: [PATCH 031/145] Format --- src/split_reduce.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp 
index 72f997b53dc..0a4714438ea 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -124,7 +124,7 @@ struct splitter std::copy_if(live_set.begin(), live_set.end(), std::back_inserter(result), - [&](instruction_ref live) { + [&](instruction_ref live) { if(live->name() == "@param") return false; if(contains(splits, live)) From 48712c9d6e7e86ebdc92684a76e22f7a6ca3e8bd Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 22 May 2024 17:07:26 -0700 Subject: [PATCH 032/145] Handle scalars --- src/targets/gpu/fuse_mlir.cpp | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index f24ca3075d4..f91b1cf9052 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -28,6 +28,8 @@ #include #include #include +#include +#include #include #include #include @@ -595,6 +597,16 @@ struct find_pointwise_mlir mlir_input_pointwise(match::used_once()).bind("pointwise"))); } + static instruction_ref insert_pointwise(module& m, + instruction_ref ins, + const operation& op, + const std::vector& inputs, + const std::vector& mod_args) + { + assert(mod_args.empty()); + return insert_common_op(m, ins, op, inputs); + } + void apply(module_pass_manager& mpm, const match::matcher_result& r) const { auto ins = r.result; @@ -606,7 +618,7 @@ struct find_pointwise_mlir std::unordered_map map_ins; module_ref m = mpm.create_module(pm->name() + ":" + mm->name()); m->set_bypass(); - auto rins = m->fuse(*pm, pw->inputs(), &map_ins).front(); + auto rins = m->fuse(*pm, pw->inputs(), &map_ins, &insert_pointwise).front(); map_ins[pw] = rins; auto ret = m->fuse(*mm, ins->inputs(), &map_ins); @@ -653,6 +665,8 @@ void fuse_mlir::apply(module_pass_manager& mpm) const find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::fast)}, find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::fast)}); + mpm.run_pass(dead_code_elimination{}); + if(enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) match::find_matches(mpm, find_pointwise_mlir{}); #else From c3cf902c588c9fc7d7d242e4aa9cf94212feea1c Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 22 May 2024 17:07:32 -0700 Subject: [PATCH 033/145] Format --- src/targets/gpu/fuse_mlir.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index f91b1cf9052..f5af1d3d51f 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -598,10 +598,10 @@ struct find_pointwise_mlir } static instruction_ref insert_pointwise(module& m, - instruction_ref ins, - const operation& op, - const std::vector& inputs, - const std::vector& mod_args) + instruction_ref ins, + const operation& op, + const std::vector& inputs, + const std::vector& mod_args) { assert(mod_args.empty()); return insert_common_op(m, ins, op, inputs); From 51d3ea915bbfbc1ccff595801964d06c4de1a5ef Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 18 Jun 2024 13:13:09 -0700 Subject: [PATCH 034/145] Add description --- src/module.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/module.cpp b/src/module.cpp index afd0185b3ba..9d2229dd222 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -979,6 +979,8 @@ std::array module::split(const std::vector& inputs, std::unordered_map& map_ins) From 0e600850cf00a7e95bbffa4d23260b27522fbb38 Mon Sep 17 00:00:00 2001 From: Chris Austen Date: Wed, 19 Jun 2024 17:03:00 -0400 Subject: [PATCH 035/145] Update src/targets/gpu/fuse_mlir.cpp Co-authored-by: Umang Yadav 
<29876643+umangyadav@users.noreply.github.com> --- src/targets/gpu/fuse_mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 57ec8d81d81..3e62934101f 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -400,7 +400,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) return false; } -// A seprate function so we can remove operators that are supported by mlir +// A separate function so we can remove operators that are supported by mlir // but not supported for an input fusion. bool is_pointwise_op_supported_by_mlir_for_input(const instruction& i) { From fd0b7f7ce4696ea9bd076fac014c35f4007a7e3f Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 19 Jun 2024 19:45:32 -0700 Subject: [PATCH 036/145] Add doc --- docs/dev/env_vars.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/dev/env_vars.rst b/docs/dev/env_vars.rst index 70135ad6836..0a7de7c4437 100644 --- a/docs/dev/env_vars.rst +++ b/docs/dev/env_vars.rst @@ -272,6 +272,11 @@ Performs exhaustive tuning for MLIR. Set to an integer greater than 1. Limits the number of solutions available to MLIR for tuning. +.. envvar:: MIGRAPHX_ENABLE_MLIR_INPUT_FUSION + +Set to "1", "enable", "enabled", "yes", or "true" to use. +Enable input fusions in MLIR. + CK vars ----------- From 9c4d6590036aa83013817d6f0a7383b78f8fbddb Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 20 Jun 2024 16:48:04 -0700 Subject: [PATCH 037/145] Add input fusion to jenkins --- Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index de6fa059a0b..7cb184f51d1 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -144,7 +144,7 @@ rocmtest clang_debug: rocmnode('mi100+') { cmake_build -> } }, mlir_debug: rocmnode('mi100+') { cmake_build -> stage('MLIR Debug') { - withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot']) { + withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1']) { def sanitizers = "undefined" // Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS. 
def debug_flags_cxx = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}" From 5eaaed3c90874ee55e174063824c8c8bd3dc03c0 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 20 Jun 2024 17:25:38 -0700 Subject: [PATCH 038/145] Add unit test for fuse module --- test/module_test.cpp | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/test/module_test.cpp b/test/module_test.cpp index 52981930bbb..a6cb1976fd1 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -28,10 +28,11 @@ #include #include #include -#include "test.hpp" #include #include +#include +#include migraphx::program create_program() { @@ -659,4 +660,35 @@ TEST_CASE(module_split3) EXPECT(bool{mods[2].inputs[1] == splits1.front()}); } +TEST_CASE(fuse_module) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m1; + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto add = add_pointwise(p, "main:pointwise0", {x, y}, single_pointwise("add")); + auto mul = add_pointwise(p, "main:pointwise1", {add, z}, single_pointwise("mul")); + + std::unordered_map map_ins; + auto rins = m1.fuse(*add->module_inputs().front(), add->inputs(), &map_ins).front(); + map_ins[add] = rins; + auto ret = m1.fuse(*mul->module_inputs().front(), mul->inputs(), &map_ins); + m1.add_return(ret); + } + migraphx::module m2; + { + auto x = m2.add_parameter("x0", s); + auto y = m2.add_parameter("x1", s); + auto z = m2.add_parameter("x2", s); + auto add = m2.add_instruction(migraphx::make_op("add"), x, y); + auto mul = m2.add_instruction(migraphx::make_op("mul"), add, z); + m2.add_return({mul}); + } + EXPECT(m1 == m2); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 1edac2d5a93e27ebd5bd0d4db8c9d7fc0b30867a Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 20 Jun 2024 17:25:43 -0700 Subject: [PATCH 039/145] Format --- test/module_test.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/module_test.cpp b/test/module_test.cpp index a6cb1976fd1..3b910f8dfdf 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -666,24 +666,24 @@ TEST_CASE(fuse_module) migraphx::module m1; { migraphx::program p; - auto* mm = p.get_main_module(); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("z", s); + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); auto add = add_pointwise(p, "main:pointwise0", {x, y}, single_pointwise("add")); auto mul = add_pointwise(p, "main:pointwise1", {add, z}, single_pointwise("mul")); std::unordered_map map_ins; - auto rins = m1.fuse(*add->module_inputs().front(), add->inputs(), &map_ins).front(); + auto rins = m1.fuse(*add->module_inputs().front(), add->inputs(), &map_ins).front(); map_ins[add] = rins; - auto ret = m1.fuse(*mul->module_inputs().front(), mul->inputs(), &map_ins); + auto ret = m1.fuse(*mul->module_inputs().front(), mul->inputs(), &map_ins); m1.add_return(ret); } migraphx::module m2; { - auto x = m2.add_parameter("x0", s); - auto y = m2.add_parameter("x1", s); - auto z = m2.add_parameter("x2", s); + auto x = m2.add_parameter("x0", s); + auto y = m2.add_parameter("x1", s); + auto z = m2.add_parameter("x2", s); auto add = m2.add_instruction(migraphx::make_op("add"), x, y); 
auto mul = m2.add_instruction(migraphx::make_op("mul"), add, z); m2.add_return({mul}); From 1ebdaf1dc33bd9c47d891528da935ef8bab30739 Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 21 Jun 2024 14:28:54 -0700 Subject: [PATCH 040/145] Add unit test --- src/targets/gpu/fuse_mlir.cpp | 2 +- test/gpu/fuse_mlir.cpp | 40 +++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 3e62934101f..d663fa3f1e2 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -607,7 +607,7 @@ struct find_pointwise_mlir instruction_ref ins, const operation& op, const std::vector& inputs, - const std::vector& mod_args) + const std::vector&) { assert(mod_args.empty()); return insert_common_op(m, ins, op, inputs); diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 6b646720d66..68b248c44ae 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -32,6 +32,8 @@ #include #include +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); + void run_pass(migraphx::program& p) { migraphx::run_passes( @@ -100,6 +102,44 @@ TEST_CASE(dot_add) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(add_dot) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto b = mm->add_parameter("b", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add")); + auto dot = mm->add_instruction(migraphx::make_op("dot"), add, b); + mm->add_return({dot}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto b = mm->add_parameter("b", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto fused = + add_mlir(p2, + "main:pointwise0:mlir_dot1", + {x, y, b}, + {"x0", "x1", "x2"}, + [=](auto* pm, const auto& inputs) { + auto add = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + auto dot = + pm->add_instruction(migraphx::make_op("dot"), add, inputs[2]); + return std::make_tuple(dot, dot); + }); + mm->add_return({fused}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + TEST_CASE(int_quant_dot_abs) { migraphx::shape s_a{migraphx::shape::int8_type, {5, 4}}; From d035f3bb7b155814d7a36c09bd9201987798857d Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 21 Jun 2024 14:29:01 -0700 Subject: [PATCH 041/145] Format --- test/gpu/fuse_mlir.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 68b248c44ae..e124b47da84 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -128,9 +128,9 @@ TEST_CASE(add_dot) {x, y, b}, {"x0", "x1", "x2"}, [=](auto* pm, const auto& inputs) { - auto add = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); - auto dot = - pm->add_instruction(migraphx::make_op("dot"), add, inputs[2]); + auto add = + pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + auto dot = pm->add_instruction(migraphx::make_op("dot"), add, inputs[2]); return std::make_tuple(dot, dot); }); mm->add_return({fused}); From 8c4b8f060e8d614dfaa479eee2425b8bff46517c Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 21 Jun 2024 14:40:35 -0700 Subject: [PATCH 042/145] Rename type --- .../gpu/kernels/include/migraphx/kernels/atomic.hpp | 8 ++++---- 1 file changed, 4 
insertions(+), 4 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp index 56fb8d6341f..82fee3e7b8d 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp @@ -28,12 +28,12 @@ template , T& x, T y, Op op) { MIGRAPHX_ATOMIC_CAS_WARNING(); - using U = conditional_t; - U* address = reinterpret_cast(&x); - U expected = __hip_atomic_load(address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + using storage = conditional_t; + storage* address = reinterpret_cast(&x); + storage expected = __hip_atomic_load(address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); while(not __hip_atomic_compare_exchange_strong(address, &expected, - bit_cast(op(bit_cast(expected), y)), + bit_cast(op(bit_cast(expected), y)), __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)) From 271ea78d992500fcf00c4f891d0aa977567941d7 Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 21 Jun 2024 16:41:23 -0700 Subject: [PATCH 043/145] Add verify test --- test/verify/test_add_dot.cpp | 49 ++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 test/verify/test_add_dot.cpp diff --git a/test/verify/test_add_dot.cpp b/test/verify/test_add_dot.cpp new file mode 100644 index 00000000000..eb62d0191fd --- /dev/null +++ b/test/verify/test_add_dot.cpp @@ -0,0 +1,49 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include + +template +struct test_add_dot : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{DType, {256, 256}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("y", s); + auto add = mm->add_instruction(migraphx::make_op("add"), x, y); + auto dot = mm->add_instruction(migraphx::make_op("dot"), add, z); + mm->add_return({dot}); + return p; + } +}; + +template struct test_add_dot; +template struct test_add_dot; From 43e76f503eaee61e72f64e59b6b57c9599f3d6d9 Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 21 Jun 2024 16:41:38 -0700 Subject: [PATCH 044/145] Format --- test/verify/test_add_dot.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/verify/test_add_dot.cpp b/test/verify/test_add_dot.cpp index eb62d0191fd..6028e211ced 100644 --- a/test/verify/test_add_dot.cpp +++ b/test/verify/test_add_dot.cpp @@ -35,9 +35,9 @@ struct test_add_dot : verify_program> migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape s{DType, {256, 256}}; - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("y", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("y", s); auto add = mm->add_instruction(migraphx::make_op("add"), x, y); auto dot = mm->add_instruction(migraphx::make_op("dot"), add, z); mm->add_return({dot}); From b357f943b0210705b6825863271b8511466b285e Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 21 Jun 2024 16:49:57 -0700 Subject: [PATCH 045/145] Fix tidy issue --- src/targets/gpu/fuse_mlir.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index d663fa3f1e2..e901dc24a2b 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -607,8 +607,10 @@ struct find_pointwise_mlir instruction_ref ins, const operation& op, const std::vector& inputs, - const std::vector&) + const std::vector& mod_args) { + // Only used in assert + (void)mod_args; assert(mod_args.empty()); return insert_common_op(m, ins, op, inputs); } From fbb630ed026c8a709e3b95dcc1db4ee209892ce4 Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 21 Jun 2024 17:35:47 -0700 Subject: [PATCH 046/145] Fix tidy --- src/split_reduce.cpp | 2 +- src/targets/gpu/hip.cpp | 2 +- .../include/migraphx/kernels/functional.hpp | 40 +++++++++++++++++-- .../include/migraphx/kernels/reduce.hpp | 10 +---- test/split_reduce.cpp | 2 +- 5 files changed, 42 insertions(+), 14 deletions(-) diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 0a4714438ea..1a4ecffbd51 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -87,7 +87,7 @@ struct splitter return dom->strictly_dominate(a, b); } - std::vector find_splits() + std::vector find_splits() const { std::vector result; copy_if(iterator_for(*rm), std::back_inserter(result), [](auto ins) { diff --git a/src/targets/gpu/hip.cpp b/src/targets/gpu/hip.cpp index a1ebabb388b..d467858c7c6 100644 --- a/src/targets/gpu/hip.cpp +++ b/src/targets/gpu/hip.cpp @@ -312,7 +312,7 @@ void gpu_fill(context& ctx, const argument& dst, int value) } else { - for(auto arg : dst.get_sub_objects()) + for(const auto& arg : dst.get_sub_objects()) gpu_fill(ctx, arg, value); } } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp 
b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp index 024b68db979..f46e194890b 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp @@ -180,10 +180,38 @@ constexpr void each_args(F) { } -template -constexpr void each_args_unpack(F f, Ts&&... xs) +template +constexpr void unpack_each(F f) { - each_args([&](auto&& p) { p(f); }, static_cast(xs)...); + f(); +} + +template +constexpr void unpack_each(F f, Pack p) +{ + p([&](auto&&... xs) { + each_args(f, static_cast(xs)...); + }); +} + +template +constexpr void unpack_each(F f, Pack1 p1, Pack2 p2) +{ + p1([&](auto&&... xs) { + p2([&](auto&& ys) { + each_args([&](auto&& p) { p(f); }, pack_forward(static_cast(xs), static_cast(ys))...); + }); + }); +} + +template +constexpr void unpack_each(F f, Pack1 p1, Pack2 p2, Packs... packs) +{ + unpack_each([&](auto&& x, auto&& y) { + unpack_each([&](auto&&... zs) { + f(static_cast(x), static_cast(y), static_cast(zs)...); + }, packs...); + }, p1, p2); } template @@ -238,6 +266,12 @@ constexpr auto pack(Ts... xs) return [=](auto f) { return f(xs...); }; } +template +constexpr auto pack_forward(Ts&&... xs) +{ + return [&](auto f) { return f(static_cast(xs)...); }; +} + template constexpr auto join(G g, F f) { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp index 65060c1483f..748b4bb1a8b 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp @@ -749,10 +749,7 @@ __device__ void fused_reduce(Output output_pack, Assign assign, F f) { Algo::template run([&](auto out_idx, auto r) { auto result_tuple = f(r, out_idx); - output_pack([&](auto... outputs) { - result_tuple([&](auto... 
results) { - each_args_unpack( - [&](auto output, auto result) { + unpack_each([&](auto output, auto result) { if constexpr(reduce::is_inner_storage{}) { r.inner([&](auto& y, auto x) { assign(y, x); })(output, result); @@ -761,10 +758,7 @@ __device__ void fused_reduce(Output output_pack, Assign assign, F f) { r.outer([&] { assign(output[out_idx], implicit_conversion(result)); }); } - }, - pack(outputs, results)...); - }); - }); + }, output_pack, result_tuple); }); } diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index 372b72de5b0..9616de45b68 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -66,7 +66,7 @@ void auto_add_return(migraphx::module_ref m, migraphx::instruction_ref ins) void auto_add_return(migraphx::module_ref m, std::vector inss) { - m->add_return(inss); + m->add_return(std::move(inss)); } template From a3ff01a34f5c9e22cd6c235d313278abf0715b76 Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 21 Jun 2024 17:35:57 -0700 Subject: [PATCH 047/145] Format --- .../include/migraphx/kernels/functional.hpp | 35 +++++++++++-------- .../include/migraphx/kernels/reduce.hpp | 23 ++++++------ 2 files changed, 34 insertions(+), 24 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp index f46e194890b..35cb4745705 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp @@ -180,40 +180,47 @@ constexpr void each_args(F) { } -template +template constexpr void unpack_each(F f) { f(); } -template +template constexpr void unpack_each(F f, Pack p) { - p([&](auto&&... xs) { - each_args(f, static_cast(xs)...); - }); + p([&](auto&&... xs) { each_args(f, static_cast(xs)...); }); } -template +template constexpr void unpack_each(F f, Pack1 p1, Pack2 p2) { p1([&](auto&&... xs) { p2([&](auto&& ys) { - each_args([&](auto&& p) { p(f); }, pack_forward(static_cast(xs), static_cast(ys))...); + each_args( + [&](auto&& p) { p(f); }, + pack_forward(static_cast(xs), static_cast(ys))...); }); }); } -template +template constexpr void unpack_each(F f, Pack1 p1, Pack2 p2, Packs... packs) { - unpack_each([&](auto&& x, auto&& y) { - unpack_each([&](auto&&... zs) { - f(static_cast(x), static_cast(y), static_cast(zs)...); - }, packs...); - }, p1, p2); + unpack_each( + [&](auto&& x, auto&& y) { + unpack_each( + [&](auto&&... 
zs) { + f(static_cast(x), + static_cast(y), + static_cast(zs)...); + }, + packs...); + }, + p1, + p2); } - + template constexpr void repeat_c(F&& f) { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp index 748b4bb1a8b..cfe362948de 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp @@ -749,16 +749,19 @@ __device__ void fused_reduce(Output output_pack, Assign assign, F f) { Algo::template run([&](auto out_idx, auto r) { auto result_tuple = f(r, out_idx); - unpack_each([&](auto output, auto result) { - if constexpr(reduce::is_inner_storage{}) - { - r.inner([&](auto& y, auto x) { assign(y, x); })(output, result); - } - else - { - r.outer([&] { assign(output[out_idx], implicit_conversion(result)); }); - } - }, output_pack, result_tuple); + unpack_each( + [&](auto output, auto result) { + if constexpr(reduce::is_inner_storage{}) + { + r.inner([&](auto& y, auto x) { assign(y, x); })(output, result); + } + else + { + r.outer([&] { assign(output[out_idx], implicit_conversion(result)); }); + } + }, + output_pack, + result_tuple); }); } From 5e848ba9758583b4811a1e205b66d67b0be02501 Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 21 Jun 2024 17:45:49 -0700 Subject: [PATCH 048/145] Fix parameter name --- test/verify/test_add_dot.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/verify/test_add_dot.cpp b/test/verify/test_add_dot.cpp index 6028e211ced..ad23cc5acf6 100644 --- a/test/verify/test_add_dot.cpp +++ b/test/verify/test_add_dot.cpp @@ -37,7 +37,7 @@ struct test_add_dot : verify_program> migraphx::shape s{DType, {256, 256}}; auto x = mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); auto add = mm->add_instruction(migraphx::make_op("add"), x, y); auto dot = mm->add_instruction(migraphx::make_op("dot"), add, z); mm->add_return({dot}); From 3964597e269af6a4ab16ae6dcb50c89e6fa3d259 Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 21 Jun 2024 17:55:19 -0700 Subject: [PATCH 049/145] Add line --- src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp index 35cb4745705..7101fd2a3c6 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp @@ -221,6 +221,7 @@ constexpr void unpack_each(F f, Pack1 p1, Pack2 p2, Packs... packs) p2); } + template constexpr void repeat_c(F&& f) { From efb1f7650f21bc9bc321a4f0b39b35a96143e3b8 Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 21 Jun 2024 17:55:27 -0700 Subject: [PATCH 050/145] Format --- src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp index 7101fd2a3c6..35cb4745705 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp @@ -221,7 +221,6 @@ constexpr void unpack_each(F f, Pack1 p1, Pack2 p2, Packs... 
packs) p2); } - template constexpr void repeat_c(F&& f) { From ae29e39d36bc28ed746f3ac9bed9c700b640288d Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 10 Jul 2024 18:24:04 +0000 Subject: [PATCH 051/145] add reshapes to fused mlir --- requirements.txt | 2 +- src/simplify_reshapes.cpp | 2 +- src/targets/gpu/fuse_mlir.cpp | 33 +++++++++++++++++++++------------ 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/requirements.txt b/requirements.txt index ab74785347a..2b491f66813 100755 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/composable_kernel@57cdd70b7cb14e5e3b60cd9a5f96ba8dc343763e -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/rocMLIR@da8969573d2ad408c7ad129126679838d82d9350 -DBUILD_FAT_LIBROCKCOMPILER=On +ROCm/rocMLIR@be181d7543ea7607d2e9402d5006fbf6e49632b2 -DBUILD_FAT_LIBROCKCOMPILER=On diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index b3c11820322..1332c3ca580 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -44,7 +44,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -const auto& reshaper_names() +static const auto& reshaper_names() { // clang-format off static const std::unordered_set names = { diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index e901dc24a2b..f3d505075ce 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -165,21 +165,30 @@ MIGRAPHX_REGISTER_OP(mlir_op); namespace { +static const auto& reshaper_names() +{ + // clang-format off + static const std::unordered_set names = { + "slice", + "transpose", + "multibroadcast", + "broadcast", + "contiguous", + "reshape", + "squeeze", + "flatten", + "unsqueeze" + }; + // clang-format on + return names; +} + std::tuple> get_fusable_input_op_stream(instruction_ref lower_input) { instruction_ref upper_input = lower_input; std::vector op_stream; - while(contains({"slice", - "transpose", - "multibroadcast", - "broadcast", - "contiguous", - "reshape", - "squeeze", - "flatten", - "unsqueeze"}, - upper_input->name())) + while(contains(reshaper_names(), upper_input->name())) { operation op = upper_input->get_operator(); if(contains({"squeeze", "flatten", "unsqueeze"}, upper_input->name())) @@ -443,7 +452,7 @@ struct find_mlir_fused_ops mlir_mode dot_mode = mlir_mode::none; auto matcher() const { - auto dot_or_conv = match::skip(match::name("contiguous"))( + auto dot_or_conv = match::skip(match::name(reshaper_names()))( match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op")); return mlir_pointwise()(match::any_of[match::inputs()](dot_or_conv.bind("x"))); } @@ -452,7 +461,7 @@ struct find_mlir_fused_ops { auto ins = r.result; auto gemm_based_op = r.instructions["gemm_based_op"]; - auto x_ins = r.instructions["x"]; // input after contiguous + auto x_ins = r.instructions["x"]; // input to pointwise after reshaper stream auto* pm = ins->module_inputs().front(); auto names = pm->get_parameter_names(); std::sort(names.begin(), names.end()); From 470984d5bfe8fc90fb1d69b22c2f4fa6817ac6f6 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 11 Jul 2024 12:43:21 +0000 Subject: [PATCH 052/145] use fuse instead of fold_pointwise --- src/targets/gpu/fuse_mlir.cpp | 50 +++++++++++------------------------ test/gpu/fuse_mlir.cpp | 2 +- 2 files changed, 17 insertions(+), 35 
deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index f3d505075ce..04f2704e2b5 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -312,30 +312,6 @@ create_param_map_with_literals(module_ref mm, const module* pm, const shape& sha return ins_map; } -std::vector -fold_pointwise_mod(instruction_ref pm_ins, - module_ref parent_mod, - const std::unordered_map& ins_map) -{ - auto* pm = pm_ins->module_inputs().front(); - auto names = pm->get_parameter_names(); - std::sort(names.begin(), names.end()); - std::unordered_map param_map = - create_param_map_with_literals(parent_mod, pm, pm_ins->get_shape()); - std::transform(names.begin(), - names.end(), - pm_ins->inputs().begin(), - std::inserter(param_map, param_map.end()), - [&](auto name, auto input) { - if(ins_map.count(input)) - return std::make_pair(pm->get_parameter(name), ins_map.at(input)); - return std::make_pair( - pm->get_parameter(name), - parent_mod->add_parameter(name, input->get_shape().as_standard())); - }); - return parent_mod->insert_instructions(parent_mod->end(), pm, ¶m_map); -} - // Whitelist supported fusion options, including imposing type constraints // for cases where MLIR only supports an operation (usually a pointwise function) // on particular types. @@ -452,7 +428,7 @@ struct find_mlir_fused_ops mlir_mode dot_mode = mlir_mode::none; auto matcher() const { - auto dot_or_conv = match::skip(match::name(reshaper_names()))( + auto dot_or_conv = match::skip(match::name("contiguous"))( match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op")); return mlir_pointwise()(match::any_of[match::inputs()](dot_or_conv.bind("x"))); } @@ -469,7 +445,10 @@ struct find_mlir_fused_ops mm->set_bypass(); auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op( mm, gemm_based_op->inputs(), gemm_based_op->get_operator()); - mm->add_return(fold_pointwise_mod(ins, mm, {{x_ins, anchor_op}})); + std::unordered_map param_map = + create_param_map_with_literals(mm, pm, ins->get_shape()); + param_map[x_ins] = anchor_op; + mm->add_return(mm->fuse(*pm, ins->inputs(), ¶m_map)); std::vector inputs; std::copy_if(ins->inputs().begin(), @@ -573,20 +552,23 @@ struct find_mlir_standalone_attention_op auto gemm1 = mm->add_instruction(make_op("dot"), {softmax, new_upper_v}); - std::unordered_map ins_map; - ins_map[gemm_softmax_gemm] = gemm1; - auto ins_to_replace = gemm1; + std::vector ins_to_replace = {gemm1}; auto ins_to_be_replaced = gemm_softmax_gemm; if(r.instructions.find("trailing_pm") != r.instructions.end()) { - ins_to_replace = fold_pointwise_mod(r.instructions["trailing_pm"], mm, ins_map)[0]; - std::copy_if(r.instructions["trailing_pm"]->inputs().begin(), - r.instructions["trailing_pm"]->inputs().end(), + auto trailing_pm_ins = r.instructions["trailing_pm"]; + auto ins_map = create_param_map_with_literals( + mm, trailing_pm_ins->module_inputs().front(), trailing_pm_ins->get_shape()); + ins_map[gemm_softmax_gemm] = gemm1; + ins_to_replace = mm->fuse( + *trailing_pm_ins->module_inputs().front(), trailing_pm_ins->inputs(), &ins_map); + std::copy_if(trailing_pm_ins->inputs().begin(), + trailing_pm_ins->inputs().end(), std::back_inserter(inputs), [&](auto input) { return input != gemm_softmax_gemm; }); - ins_to_be_replaced = r.instructions["trailing_pm"]; + ins_to_be_replaced = trailing_pm_ins; } - mm->add_return({ins_to_replace}); + mm->add_return(ins_to_replace); mpm.get_module().replace_instruction( ins_to_be_replaced, mlir_op{gemm1->get_operator()}, 
mlir_contiguous(mpm, inputs), {mm}); diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index e124b47da84..ef7833db825 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -90,7 +90,7 @@ TEST_CASE(dot_add) add_mlir(p2, "mlir_main:pointwise0", {x, a, b}, - {"x1", "y0", "y1"}, + {"x2", "y0", "y1"}, [=](auto* pm, const auto& inputs) { auto dot = pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]); From 593b11957d57ecf639eca3cc07d0281a41b8362c Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 11 Jul 2024 14:52:59 +0000 Subject: [PATCH 053/145] Passes make check --- src/targets/gpu/fuse_mlir.cpp | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 04f2704e2b5..fd9082044b7 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -428,17 +428,20 @@ struct find_mlir_fused_ops mlir_mode dot_mode = mlir_mode::none; auto matcher() const { - auto dot_or_conv = match::skip(match::name("contiguous"))( + auto reshapes = reshaper_names(); + // slice is not supported + reshapes.erase("slice"); + auto dot_or_conv = match::skip(match::name(reshapes))( match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op")); return mlir_pointwise()(match::any_of[match::inputs()](dot_or_conv.bind("x"))); } void apply(module_pass_manager& mpm, const match::matcher_result& r) const { - auto ins = r.result; + auto pw_ins = r.result; auto gemm_based_op = r.instructions["gemm_based_op"]; - auto x_ins = r.instructions["x"]; // input to pointwise after reshaper stream - auto* pm = ins->module_inputs().front(); + auto x_ins = r.instructions["x"]; // input to pointwise after reshaper op stream + auto* pm = pw_ins->module_inputs().front(); auto names = pm->get_parameter_names(); std::sort(names.begin(), names.end()); module_ref mm = mpm.create_module("mlir_" + pm->name()); @@ -446,18 +449,27 @@ struct find_mlir_fused_ops auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op( mm, gemm_based_op->inputs(), gemm_based_op->get_operator()); std::unordered_map param_map = - create_param_map_with_literals(mm, pm, ins->get_shape()); - param_map[x_ins] = anchor_op; - mm->add_return(mm->fuse(*pm, ins->inputs(), ¶m_map)); + create_param_map_with_literals(mm, pm, pw_ins->get_shape()); + auto [upper_input, op_stream] = get_fusable_input_op_stream(x_ins); + assert(upper_input == gemm_based_op); + auto prev_input = anchor_op; + for(const auto& op : reverse(op_stream)) + { + prev_input = mm->add_instruction(op, {prev_input}); + } + assert(prev_input->get_shape() == x_ins->get_shape()); + param_map[x_ins] = prev_input; // this is to avoid adding parameter for gemm/conv reshaped + // input to pointwise in new fused module + mm->add_return(mm->fuse(*pm, pw_ins->inputs(), ¶m_map)); std::vector inputs; - std::copy_if(ins->inputs().begin(), - ins->inputs().end(), + std::copy_if(pw_ins->inputs().begin(), + pw_ins->inputs().end(), std::back_inserter(inputs), - [&](auto input) { return input != gemm_based_op; }); + [&](auto input) { return input != x_ins; }); inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end()); mpm.get_module().replace_instruction( - ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm}); + pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm}); } }; From d49cfe3da80719b4ce602538f7175819bdd04467 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 12 Jul 
2024 13:48:29 +0000 Subject: [PATCH 054/145] pull in changes for find_dot_slice --- src/simplify_algebra.cpp | 79 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 326b5ef7752..534466ed796 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -266,6 +266,84 @@ struct find_mul_dot } }; +struct find_dot_slice +{ + auto matcher() const + { + return match::name("slice")( + match::args(match::name("dot", "quant_dot")(match::used_once()).bind("dot_ins"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto slice_ins = r.result; + auto dot_ins = r.instructions["dot_ins"]; + auto slice_op = slice_ins->get_operator().to_value(); + auto axes = slice_op["axes"].to_vector(); + auto starts = slice_op["starts"].to_vector(); + auto ends = slice_op["ends"].to_vector(); + assert(starts.size() == ends.size() and starts.size() == axes.size()); + auto has_neg_vals = [](auto vec) { + return std::any_of(vec.begin(), vec.end(), [](auto i) { return i < 0; }); + }; + if(has_neg_vals(starts) or has_neg_vals(ends) or has_neg_vals(axes)) + { + return; + } + auto dot_inputs = dot_ins->inputs(); + auto num_batch_dims = dot_ins->get_shape().lens().size() - 2; + std::vector slice_axes_1, starts_1, ends_1; // NOLINT + std::vector slice_axes_2, starts_2, ends_2; // NOLINT + for(auto i : range(axes.size())) + { + if(axes[i] < num_batch_dims) + { + slice_axes_1.push_back(axes[i]); + starts_1.push_back(starts[i]); + ends_1.push_back(ends[i]); + slice_axes_2.push_back(axes[i]); + starts_2.push_back(starts[i]); + ends_2.push_back(ends[i]); + } + else if(axes[i] == num_batch_dims) + { + slice_axes_1.push_back(axes[i]); + starts_1.push_back(starts[i]); + ends_1.push_back(ends[i]); + } + else if(axes[i] == num_batch_dims + 1) + { + slice_axes_2.push_back(axes[i]); + starts_2.push_back(starts[i]); + ends_2.push_back(ends[i]); + } + else + { + MIGRAPHX_THROW("FIND_DOT_SLICE: invalid case"); + } + } + auto slice_1 = dot_inputs.at(0); + if(not slice_axes_1.empty()) + { + slice_1 = m.insert_instruction( + slice_ins, + migraphx::make_op("slice", + {{"axes", slice_axes_1}, {"starts", starts_1}, {"ends", ends_1}}), + dot_inputs.at(0)); + } + auto slice_2 = dot_inputs.at(1); + if(not slice_axes_2.empty()) + { + slice_2 = m.insert_instruction( + slice_ins, + migraphx::make_op("slice", + {{"axes", slice_axes_2}, {"starts", starts_2}, {"ends", ends_2}}), + dot_inputs.at(1)); + } + m.replace_instruction(slice_ins, dot_ins->get_operator(), {slice_1, slice_2}); + } +}; + struct find_dot_mul { auto matcher() const @@ -1896,6 +1974,7 @@ void simplify_algebra::apply(module& m) const find_mul_conv{}, find_mul_slice_conv{}, find_mul_dot{}, + find_dot_slice{}, find_dot_mul{}, find_mul_add{}, find_unit_ops{}, From c12c6bcb2a64dbb83dbbda05fc29247f57621959 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 15 Jul 2024 17:44:51 +0000 Subject: [PATCH 055/145] add unittest --- src/targets/gpu/fuse_mlir.cpp | 2 +- test/gpu/fuse_mlir.cpp | 43 +++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index fd9082044b7..ae5814ac542 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -457,7 +457,7 @@ struct find_mlir_fused_ops { prev_input = mm->add_instruction(op, {prev_input}); } - assert(prev_input->get_shape() == x_ins->get_shape()); + assert(prev_input->get_shape().lens() == 
x_ins->get_shape().lens()); param_map[x_ins] = prev_input; // this is to avoid adding parameter for gemm/conv reshaped // input to pointwise in new fused module mm->add_return(mm->fuse(*pm, pw_ins->inputs(), ¶m_map)); diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index ef7833db825..0ad016410a0 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -21,6 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ +#include "migraphx/shape.hpp" #include #include #include @@ -66,6 +67,48 @@ migraphx::instruction_ref add_mlir(migraphx::program& p, {pm}); } +TEST_CASE(dot_reshape_add) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 3}}); + auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto dot_trans = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot); + auto dot_sq = mm->add_instruction(migraphx::make_op("squeeze"), dot_trans); + auto add = add_pointwise(p1, "main:pointwise0", {dot_sq, x}, single_pointwise("add")); + mm->add_return({add}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 3}}); + auto fused = add_mlir( + p2, + "mlir_main:pointwise0", + {x, a, b}, + {"x2", "y0", "y1"}, + [=](auto* pm, const auto& inputs) { + auto dot = pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]); + auto dot_trans = pm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), dot); + auto dot_rsp = pm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 3}}}), + dot_trans); + auto add = pm->add_instruction(migraphx::make_op("add"), dot_rsp, inputs[0]); + return std::make_tuple(dot, add); + }); + mm->add_return({fused}); + } + EXPECT(p1.sort() == p2.sort()); +} + TEST_CASE(dot_add) { migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; From 55c3c6d44222de0b9a951106306d54d7fecc838f Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 15 Jul 2024 19:15:02 +0000 Subject: [PATCH 056/145] add verify test --- src/targets/gpu/jit/mlir.cpp | 10 +++- test/gpu/fuse_mlir.cpp | 2 +- test/verify/test_gemm_reshapes_pointwise.cpp | 63 ++++++++++++++++++++ 3 files changed, 71 insertions(+), 4 deletions(-) create mode 100644 test/verify/test_gemm_reshapes_pointwise.cpp diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index 7ac0be64f5c..ffd1eae67c3 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -79,9 +79,13 @@ struct mlir_compiler : compiler auto gemm_like_ins = std::find_if(smod->begin(), smod->end(), [&](const auto& i) { return contains({"dot", "quant_dot", "convolution", "quant_convolution"}, i.name()); }); - // check if (a) module is fused (b) contains a "gemm/conv" instruction and (c) perfConfig - // can not allow fused module - if(gemm_like_ins != smod->end() and std::distance(gemm_like_ins, smod->end()) > 2 and + auto pointwise_ins = std::find_if(gemm_like_ins, smod->end(), [&](const auto& i) { + return i.get_operator().attributes().get("pointwise", false) == true; + }); + + // check if (a) module is fused (b) contains a "gemm/conv" instruction and (c) + // 
perfConfig can not allow fused module + if(gemm_like_ins != smod->end() and pointwise_ins != smod->end() and not is_module_fusible(*smod, ctx, solution)) { auto input_args = ins->inputs(); diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 0ad016410a0..750adcdc01d 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -67,7 +67,7 @@ migraphx::instruction_ref add_mlir(migraphx::program& p, {pm}); } -TEST_CASE(dot_reshape_add) +TEST_CASE(dot_reshapes_add) { migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; migraphx::program p1; diff --git a/test/verify/test_gemm_reshapes_pointwise.cpp b/test/verify/test_gemm_reshapes_pointwise.cpp new file mode 100644 index 00000000000..cb293fa69b3 --- /dev/null +++ b/test/verify/test_gemm_reshapes_pointwise.cpp @@ -0,0 +1,63 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include +#include + +template +struct test_gemm_reshapes_add : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{DType, {1, 2, 1024}}; + migraphx::shape m2_shape{DType, {1, 1024, 320}}; + migraphx::shape m3_shape{DType, {1, 2, 320}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto l2 = mm->add_parameter("2", m2_shape); + auto l3 = mm->add_parameter("3", m3_shape); + + auto dot = mm->add_instruction(migraphx::make_op("dot"), l1, l2); + // auto dot_sq = mm->add_instruction(migraphx::make_op("squeeze"), dot); + // auto dot_trans = + // mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), + // dot_sq); + mm->add_instruction(migraphx::make_op("add"), dot, l3); + return p; + } + + std::string section() const { return "gemm"; } + // Turn on Exhaustive-tune to enable split-k GEMM perf-configs from MLIR + migraphx::compile_options get_compile_options() const + { + return migraphx::compile_options{.exhaustive_tune = true}; + } +}; + +template struct test_gemm_reshapes_add; +template struct test_gemm_reshapes_add; From 1f76cc5d13415f34fcf5a31012acc029ce4e686d Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 16 Jul 2024 13:22:05 +0000 Subject: [PATCH 057/145] debugging --- src/module.cpp | 5 +- src/targets/gpu/compile_ops.cpp | 1 + src/targets/gpu/jit/mlir.cpp | 51 ++++++++++++++------ test/verify/test_gemm_add.cpp | 4 +- test/verify/test_gemm_reshapes_pointwise.cpp | 11 ++--- 5 files changed, 49 insertions(+), 23 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index 9d2229dd222..5f9aa2e61f8 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -901,7 +901,10 @@ generic_split(const module& m, splits.end(), std::back_inserter(outputs), [&](instruction_ref ins) { return map_ins1.at(ins); }); - m1.add_return(outputs); + if(not outputs.empty()) + { + m1.add_return(outputs); + } std::vector instructions2; for(auto ins : iterator_for(m)) diff --git a/src/targets/gpu/compile_ops.cpp b/src/targets/gpu/compile_ops.cpp index cc5a7fc24d7..0604e417200 100644 --- a/src/targets/gpu/compile_ops.cpp +++ b/src/targets/gpu/compile_ops.cpp @@ -227,6 +227,7 @@ struct compile_plan cr->replace.replace(*bench_mm, bench_ins); // do dead code elimination by directly removing instruction bench_mm->remove_instruction(bench_ins); + bench_prog.debug_print(); auto t = time_program(*ctx, bench_prog, 20); if(trace_level > 1) std::cout << t << "ms" << std::endl; diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index ffd1eae67c3..baf71cf026f 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -90,22 +90,29 @@ struct mlir_compiler : compiler { auto input_args = ins->inputs(); input_args.pop_back(); - auto mod_splits = smod->split(input_args, {gemm_like_ins}); + auto split_ins_pw = std::prev(pointwise_ins); + std::array mod_splits; + if(split_ins_pw == gemm_like_ins) + mod_splits = smod->split(input_args, {gemm_like_ins}, {}); + else + mod_splits = smod->split(input_args, {gemm_like_ins}, {split_ins_pw}); + std::cout << "gemm mod\n"; + mod_splits[0].mod.debug_print(); + std::cout << "reshapes mod\n"; + mod_splits[1].mod.debug_print(); + std::cout << "pointwise mod\n"; + mod_splits[2].mod.debug_print(); auto dot_mlir_inputs = to_shapes(mod_splits[0].inputs); dot_mlir_inputs.push_back(mod_splits[0].mod.get_output_shapes().front()); mlir_code_object cop1 = compile_mlir(ctx, 
mod_splits[0].mod, dot_mlir_inputs, solution); - auto pw_inputs = mod_splits[1].inputs; - auto dot_ins_idx = std::distance( - std::find(pw_inputs.begin(), pw_inputs.end(), gemm_like_ins), pw_inputs.begin()); - auto pw_shapes = to_shapes(mod_splits[1].inputs); - pw_shapes[dot_ins_idx] = cop1.cop.output; - pw_shapes.push_back(mod_splits[1].mod.get_output_shapes().front()); + auto pw_shapes = to_shapes(mod_splits[2].inputs); + pw_shapes.push_back(mod_splits[2].mod.get_output_shapes().front()); assert(pw_shapes.back() == ins->get_shape()); - auto pw_mod = create_pointwise_module(&mod_splits[1].mod); + auto pw_mod = create_pointwise_module(&mod_splits[2].mod); auto cop2 = compile_pointwise(ctx, pw_shapes, &pw_mod); std::vector cops = {cop1, mlir_code_object{any_cast(cop2)}}; - return insert(cops, mod_splits, ins, gemm_like_ins); + return insert(cops, mod_splits, ins, gemm_like_ins, split_ins_pw); } return insert(compile_mlir(ctx, *smod, to_shapes(ins->inputs()), solution)); } @@ -130,9 +137,10 @@ struct mlir_compiler : compiler } compiler_replace insert(const std::vector& mcos, - const std::array& mods, + const std::array& mods, instruction_ref precompile_ins, - instruction_ref split_ins) const + instruction_ref split_ins, + instruction_ref split_pw) const { std::vector cobjs(mcos.size()); std::transform( @@ -175,9 +183,24 @@ struct mlir_compiler : compiler }); auto mlir = insert_mlir(m, ins, any_cast(ops[0]), dot_inputs_updated); - assert(contains(mods[1].inputs, split_ins)); - auto pwm = mods[1]; - pwm.replace(split_ins, mlir); + auto reshape_ins = mlir; + if(mods[1].mod.size() > 0) + { + assert(contains(mods[1].inputs, split_ins)); + std::unordered_map reshape_mod_map; + for(const auto i : iterator_for(mods[1].mod)) + { + if(i->name() == "@param") + { + reshape_mod_map[i] = mlir; + break; + } + } + reshape_ins = + m.insert_instructions(ins, &mods[1].mod, &reshape_mod_map).front(); + } + auto pwm = mods[2]; + pwm.replace(split_pw, reshape_ins); auto pw_inputs = pwm.inputs; pw_inputs.push_back(ins->inputs().back()); std::vector pw_inputs_updated; diff --git a/test/verify/test_gemm_add.cpp b/test/verify/test_gemm_add.cpp index 5955c666eae..7b04fc295e5 100644 --- a/test/verify/test_gemm_add.cpp +++ b/test/verify/test_gemm_add.cpp @@ -35,8 +35,8 @@ struct test_gemm_add : verify_program> { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape m1_shape{DType, {1, 2, 1280}}; - migraphx::shape m2_shape{DType, {1, 1280, 320}}; + migraphx::shape m1_shape{DType, {1, 2, 1024}}; + migraphx::shape m2_shape{DType, {1, 1024, 320}}; migraphx::shape m3_shape{DType, {1, 2, 320}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_parameter("2", m2_shape); diff --git a/test/verify/test_gemm_reshapes_pointwise.cpp b/test/verify/test_gemm_reshapes_pointwise.cpp index cb293fa69b3..51f3d5e8d39 100644 --- a/test/verify/test_gemm_reshapes_pointwise.cpp +++ b/test/verify/test_gemm_reshapes_pointwise.cpp @@ -37,17 +37,16 @@ struct test_gemm_reshapes_add : verify_program> auto* mm = p.get_main_module(); migraphx::shape m1_shape{DType, {1, 2, 1024}}; migraphx::shape m2_shape{DType, {1, 1024, 320}}; - migraphx::shape m3_shape{DType, {1, 2, 320}}; + migraphx::shape m3_shape{DType, {320, 2}}; auto l1 = mm->add_parameter("1", m1_shape); auto l2 = mm->add_parameter("2", m2_shape); auto l3 = mm->add_parameter("3", m3_shape); auto dot = mm->add_instruction(migraphx::make_op("dot"), l1, l2); - // auto dot_sq = mm->add_instruction(migraphx::make_op("squeeze"), dot); - // auto dot_trans = - // 
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), - // dot_sq); - mm->add_instruction(migraphx::make_op("add"), dot, l3); + auto dot_sq = mm->add_instruction(migraphx::make_op("squeeze"), dot); + auto dot_trans = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), dot_sq); + mm->add_instruction(migraphx::make_op("add"), dot_trans, l3); return p; } From a238d2aa8f00406c33f102244be45e8cac841426 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 16 Jul 2024 13:45:40 +0000 Subject: [PATCH 058/145] add lowering for contiguous --- src/targets/gpu/jit/mlir.cpp | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index baf71cf026f..6aaa76186d0 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -197,7 +197,33 @@ struct mlir_compiler : compiler } } reshape_ins = - m.insert_instructions(ins, &mods[1].mod, &reshape_mod_map).front(); + m.insert_instructions( + ins, + &mods[1].mod, + &reshape_mod_map, + [](module& insert_mod, + instruction_ref insert_loc, + const operation& op, + const std::vector& inputs, + const std::vector& mod_args) -> instruction_ref { + if(op.name() == "contiguous") + { + auto contiguous_alloc = insert_mod.insert_instruction( + insert_loc, + make_op( + "hip::allocate", + {{"shape", + to_value(op.compute_shape(to_shapes(inputs)))}})); + auto contiguous_inputs = inputs; + contiguous_inputs.push_back(contiguous_alloc); + return insert_mod.insert_instruction( + insert_loc, make_op("gpu::contiguous"), contiguous_inputs); + } + else + return insert_mod.insert_instruction( + insert_loc, op, inputs, mod_args); + }) + .front(); } auto pwm = mods[2]; pwm.replace(split_pw, reshape_ins); From e26120ba9a10540cbed483419302582dc251c347 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 16 Jul 2024 14:13:24 +0000 Subject: [PATCH 059/145] use input_rep_map --- src/targets/gpu/jit/mlir.cpp | 40 ++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index 6aaa76186d0..58ff6a7a43c 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -96,12 +96,12 @@ struct mlir_compiler : compiler mod_splits = smod->split(input_args, {gemm_like_ins}, {}); else mod_splits = smod->split(input_args, {gemm_like_ins}, {split_ins_pw}); - std::cout << "gemm mod\n"; - mod_splits[0].mod.debug_print(); - std::cout << "reshapes mod\n"; - mod_splits[1].mod.debug_print(); - std::cout << "pointwise mod\n"; - mod_splits[2].mod.debug_print(); + // std::cout << "gemm mod\n"; + // mod_splits[0].mod.debug_print(); + // std::cout << "reshapes mod\n"; + // mod_splits[1].mod.debug_print(); + // std::cout << "pointwise mod\n"; + // mod_splits[2].mod.debug_print(); auto dot_mlir_inputs = to_shapes(mod_splits[0].inputs); dot_mlir_inputs.push_back(mod_splits[0].mod.get_output_shapes().front()); mlir_code_object cop1 = compile_mlir(ctx, mod_splits[0].mod, dot_mlir_inputs, solution); @@ -109,6 +109,8 @@ struct mlir_compiler : compiler pw_shapes.push_back(mod_splits[2].mod.get_output_shapes().front()); assert(pw_shapes.back() == ins->get_shape()); auto pw_mod = create_pointwise_module(&mod_splits[2].mod); + // std::cout << "pointwise module that is created\n"; + // pw_mod.debug_print(); auto cop2 = compile_pointwise(ctx, pw_shapes, &pw_mod); std::vector cops = {cop1, mlir_code_object{any_cast(cop2)}}; @@ -187,12 +189,11 @@ 
struct mlir_compiler : compiler if(mods[1].mod.size() > 0) { assert(contains(mods[1].inputs, split_ins)); - std::unordered_map reshape_mod_map; for(const auto i : iterator_for(mods[1].mod)) { if(i->name() == "@param") { - reshape_mod_map[i] = mlir; + inputs_rep_map[i] = mlir; break; } } @@ -200,7 +201,7 @@ struct mlir_compiler : compiler m.insert_instructions( ins, &mods[1].mod, - &reshape_mod_map, + &inputs_rep_map, [](module& insert_mod, instruction_ref insert_loc, const operation& op, @@ -208,16 +209,19 @@ struct mlir_compiler : compiler const std::vector& mod_args) -> instruction_ref { if(op.name() == "contiguous") { - auto contiguous_alloc = insert_mod.insert_instruction( - insert_loc, - make_op( - "hip::allocate", - {{"shape", - to_value(op.compute_shape(to_shapes(inputs)))}})); - auto contiguous_inputs = inputs; - contiguous_inputs.push_back(contiguous_alloc); return insert_mod.insert_instruction( - insert_loc, make_op("gpu::contiguous"), contiguous_inputs); + insert_loc, make_op("identity"), inputs); + // auto contiguous_alloc = insert_mod.insert_instruction( + // insert_loc, + // make_op( + // "hip::allocate", + // {{"shape", + // to_value(op.compute_shape(to_shapes(inputs)))}})); + // auto contiguous_inputs = inputs; + // contiguous_inputs.push_back(contiguous_alloc); + // return insert_mod.insert_instruction( + // insert_loc, make_op("gpu::contiguous"), + // contiguous_inputs); } else return insert_mod.insert_instruction( From c8b06d5f1c1c9d483c3c5c361c5005aab21b8a29 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 16 Jul 2024 14:56:46 +0000 Subject: [PATCH 060/145] add eliminate_contiguous --- src/targets/gpu/jit/mlir.cpp | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index 58ff6a7a43c..e9d9de70db2 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -28,6 +28,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -88,6 +91,9 @@ struct mlir_compiler : compiler if(gemm_like_ins != smod->end() and pointwise_ins != smod->end() and not is_module_fusible(*smod, ctx, solution)) { + migraphx::run_passes( + *smod, + {migraphx::eliminate_contiguous{"contiguous"}, migraphx::dead_code_elimination{}}); auto input_args = ins->inputs(); input_args.pop_back(); auto split_ins_pw = std::prev(pointwise_ins); @@ -96,21 +102,26 @@ struct mlir_compiler : compiler mod_splits = smod->split(input_args, {gemm_like_ins}, {}); else mod_splits = smod->split(input_args, {gemm_like_ins}, {split_ins_pw}); - // std::cout << "gemm mod\n"; - // mod_splits[0].mod.debug_print(); - // std::cout << "reshapes mod\n"; - // mod_splits[1].mod.debug_print(); - // std::cout << "pointwise mod\n"; - // mod_splits[2].mod.debug_print(); + std::cout << "gemm mod\n"; + mod_splits[0].mod.debug_print(); + std::cout << "reshapes mod\n"; + mod_splits[1].mod.debug_print(); + std::cout << "pointwise mod\n"; + mod_splits[2].mod.debug_print(); auto dot_mlir_inputs = to_shapes(mod_splits[0].inputs); dot_mlir_inputs.push_back(mod_splits[0].mod.get_output_shapes().front()); mlir_code_object cop1 = compile_mlir(ctx, mod_splits[0].mod, dot_mlir_inputs, solution); auto pw_shapes = to_shapes(mod_splits[2].inputs); + std::cout << "pointwise module inputs\n"; + for(const auto i : mod_splits[2].inputs) + { + i->debug_print(); + } pw_shapes.push_back(mod_splits[2].mod.get_output_shapes().front()); assert(pw_shapes.back() == ins->get_shape()); auto pw_mod = 
create_pointwise_module(&mod_splits[2].mod); - // std::cout << "pointwise module that is created\n"; - // pw_mod.debug_print(); + std::cout << "pointwise module that is created\n"; + pw_mod.debug_print(); auto cop2 = compile_pointwise(ctx, pw_shapes, &pw_mod); std::vector cops = {cop1, mlir_code_object{any_cast(cop2)}}; From 64642c9783e8bc49dd5fd1d167f60970636d92c8 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 16 Jul 2024 17:09:39 +0000 Subject: [PATCH 061/145] Add lowering for reshape --- src/targets/gpu/compile_ops.cpp | 1 - src/targets/gpu/jit/mlir.cpp | 44 ++++++++++++++------------------- 2 files changed, 18 insertions(+), 27 deletions(-) diff --git a/src/targets/gpu/compile_ops.cpp b/src/targets/gpu/compile_ops.cpp index 0604e417200..cc5a7fc24d7 100644 --- a/src/targets/gpu/compile_ops.cpp +++ b/src/targets/gpu/compile_ops.cpp @@ -227,7 +227,6 @@ struct compile_plan cr->replace.replace(*bench_mm, bench_ins); // do dead code elimination by directly removing instruction bench_mm->remove_instruction(bench_ins); - bench_prog.debug_print(); auto t = time_program(*ctx, bench_prog, 20); if(trace_level > 1) std::cout << t << "ms" << std::endl; diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index e9d9de70db2..f24cf47a293 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -102,26 +102,13 @@ struct mlir_compiler : compiler mod_splits = smod->split(input_args, {gemm_like_ins}, {}); else mod_splits = smod->split(input_args, {gemm_like_ins}, {split_ins_pw}); - std::cout << "gemm mod\n"; - mod_splits[0].mod.debug_print(); - std::cout << "reshapes mod\n"; - mod_splits[1].mod.debug_print(); - std::cout << "pointwise mod\n"; - mod_splits[2].mod.debug_print(); auto dot_mlir_inputs = to_shapes(mod_splits[0].inputs); dot_mlir_inputs.push_back(mod_splits[0].mod.get_output_shapes().front()); mlir_code_object cop1 = compile_mlir(ctx, mod_splits[0].mod, dot_mlir_inputs, solution); auto pw_shapes = to_shapes(mod_splits[2].inputs); - std::cout << "pointwise module inputs\n"; - for(const auto i : mod_splits[2].inputs) - { - i->debug_print(); - } pw_shapes.push_back(mod_splits[2].mod.get_output_shapes().front()); assert(pw_shapes.back() == ins->get_shape()); auto pw_mod = create_pointwise_module(&mod_splits[2].mod); - std::cout << "pointwise module that is created\n"; - pw_mod.debug_print(); auto cop2 = compile_pointwise(ctx, pw_shapes, &pw_mod); std::vector cops = {cop1, mlir_code_object{any_cast(cop2)}}; @@ -218,21 +205,26 @@ struct mlir_compiler : compiler const operation& op, const std::vector& inputs, const std::vector& mod_args) -> instruction_ref { - if(op.name() == "contiguous") + if(op.name() == "reshape") { return insert_mod.insert_instruction( - insert_loc, make_op("identity"), inputs); - // auto contiguous_alloc = insert_mod.insert_instruction( - // insert_loc, - // make_op( - // "hip::allocate", - // {{"shape", - // to_value(op.compute_shape(to_shapes(inputs)))}})); - // auto contiguous_inputs = inputs; - // contiguous_inputs.push_back(contiguous_alloc); - // return insert_mod.insert_instruction( - // insert_loc, make_op("gpu::contiguous"), - // contiguous_inputs); + insert_loc, + make_op("reshape_lazy", + {{"dims", + op.compute_shape(to_shapes(inputs)).lens()}}), + inputs); + } else if(op.name() == "contiguous") + { + auto contiguous_alloc = insert_mod.insert_instruction( + insert_loc, + make_op( + "hip::allocate", + {{"shape", + to_value(op.compute_shape(to_shapes(inputs)))}})); + auto contiguous_inputs = inputs; + 
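// note (a clarifying sketch, not behavior beyond what this branch already does): MIGraphX gpu:: ops are lowered in destination-passing style, so the hip::allocate buffer created above is appended as the kernel's final output argument, i.e. the emitted pair is roughly: alloc = hip::allocate(out_shape); gpu::contiguous(x, alloc) -> writes into and returns alloc +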
contiguous_inputs.push_back(contiguous_alloc); + return insert_mod.insert_instruction( + insert_loc, make_op("gpu::contiguous"), contiguous_inputs); } else return insert_mod.insert_instruction( From 886fc1ba4c8d51688648de7dae5bad4108acafde Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 16 Jul 2024 17:53:41 +0000 Subject: [PATCH 062/145] Fix cppcheck --- src/targets/gpu/jit/mlir.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index f24cf47a293..1ba4603d45e 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -187,14 +187,12 @@ struct mlir_compiler : compiler if(mods[1].mod.size() > 0) { assert(contains(mods[1].inputs, split_ins)); - for(const auto i : iterator_for(mods[1].mod)) - { - if(i->name() == "@param") - { - inputs_rep_map[i] = mlir; - break; - } - } + assert(mods[1].mod.get_parameters().size() == 1); + auto param_ins = + std::find_if(mods[1].mod.begin(), mods[1].mod.end(), [](const auto& i) { + return i->name() == "@param"; + }); + inputs_rep_map[param_ins] = mlir; reshape_ins = m.insert_instructions( ins, From 04e37ad872af64764281ff711315375765d19257 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 16 Jul 2024 18:06:16 +0000 Subject: [PATCH 063/145] fix tidy --- src/targets/gpu/fuse_mlir.cpp | 2 +- src/targets/gpu/jit/mlir.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index ae5814ac542..26694180ae2 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -165,7 +165,7 @@ MIGRAPHX_REGISTER_OP(mlir_op); namespace { -static const auto& reshaper_names() +const auto& reshaper_names() { // clang-format off static const std::unordered_set names = { diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index 1ba4603d45e..9a8856fbaa0 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -190,7 +190,7 @@ struct mlir_compiler : compiler assert(mods[1].mod.get_parameters().size() == 1); auto param_ins = std::find_if(mods[1].mod.begin(), mods[1].mod.end(), [](const auto& i) { - return i->name() == "@param"; + return i.name() == "@param"; }); inputs_rep_map[param_ins] = mlir; reshape_ins = From 240962239cf313d970a695acf4a1be87d1178dd7 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 16 Jul 2024 18:52:10 +0000 Subject: [PATCH 064/145] fixes --- src/targets/gpu/jit/mlir.cpp | 1 + test/gpu/fuse_mlir.cpp | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index 9a8856fbaa0..c6451a92b52 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -205,6 +205,7 @@ struct mlir_compiler : compiler const std::vector& mod_args) -> instruction_ref { if(op.name() == "reshape") { + // TODO: Add proper lowering for the reshape op return insert_mod.insert_instruction( insert_loc, make_op("reshape_lazy", diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 750adcdc01d..ac453a14dec 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -21,7 +21,6 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
*/ -#include "migraphx/shape.hpp" #include #include #include From 8a008b68e51bf277e64a1c3f2fbe6da053b4dfc9 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 16 Jul 2024 18:54:54 +0000 Subject: [PATCH 065/145] rename test file --- ...est_gemm_reshapes_pointwise.cpp => test_gemm_reshapes_add.cpp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/verify/{test_gemm_reshapes_pointwise.cpp => test_gemm_reshapes_add.cpp} (100%) diff --git a/test/verify/test_gemm_reshapes_pointwise.cpp b/test/verify/test_gemm_reshapes_add.cpp similarity index 100% rename from test/verify/test_gemm_reshapes_pointwise.cpp rename to test/verify/test_gemm_reshapes_add.cpp From 2a7582018ba084f631e4d64f445dce58f014eec6 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 16 Jul 2024 20:21:30 +0000 Subject: [PATCH 066/145] formatting --- src/targets/gpu/fuse_mlir.cpp | 4 ++-- src/targets/gpu/jit/mlir.cpp | 5 +++-- test/verify/test_gemm_reshapes_add.cpp | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 26694180ae2..ddbe7b119ea 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -440,8 +440,8 @@ struct find_mlir_fused_ops { auto pw_ins = r.result; auto gemm_based_op = r.instructions["gemm_based_op"]; - auto x_ins = r.instructions["x"]; // input to pointwise after reshaper op stream - auto* pm = pw_ins->module_inputs().front(); + auto x_ins = r.instructions["x"]; // input to pointwise after reshaper op stream + auto* pm = pw_ins->module_inputs().front(); auto names = pm->get_parameter_names(); std::sort(names.begin(), names.end()); module_ref mm = mpm.create_module("mlir_" + pm->name()); diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index c6451a92b52..686975b92ea 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -85,7 +85,7 @@ struct mlir_compiler : compiler auto pointwise_ins = std::find_if(gemm_like_ins, smod->end(), [&](const auto& i) { return i.get_operator().attributes().get("pointwise", false) == true; }); - + // check if (a) module is fused (b) contains a "gemm/conv" instruction and (c) // perfConfig can not allow fused module if(gemm_like_ins != smod->end() and pointwise_ins != smod->end() and @@ -212,7 +212,8 @@ struct mlir_compiler : compiler {{"dims", op.compute_shape(to_shapes(inputs)).lens()}}), inputs); - } else if(op.name() == "contiguous") + } + else if(op.name() == "contiguous") { auto contiguous_alloc = insert_mod.insert_instruction( insert_loc, diff --git a/test/verify/test_gemm_reshapes_add.cpp b/test/verify/test_gemm_reshapes_add.cpp index 51f3d5e8d39..7dd8b2dd396 100644 --- a/test/verify/test_gemm_reshapes_add.cpp +++ b/test/verify/test_gemm_reshapes_add.cpp @@ -42,7 +42,7 @@ struct test_gemm_reshapes_add : verify_program> auto l2 = mm->add_parameter("2", m2_shape); auto l3 = mm->add_parameter("3", m3_shape); - auto dot = mm->add_instruction(migraphx::make_op("dot"), l1, l2); + auto dot = mm->add_instruction(migraphx::make_op("dot"), l1, l2); auto dot_sq = mm->add_instruction(migraphx::make_op("squeeze"), dot); auto dot_trans = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), dot_sq); From 96ac474a3cb6fe46b84e33b1c77bf618fedaf0e5 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 16 Jul 2024 20:22:14 +0000 Subject: [PATCH 067/145] fix SLES --- src/targets/gpu/jit/mlir.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp 
index 686975b92ea..21f3d938dde 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -187,6 +187,7 @@ struct mlir_compiler : compiler if(mods[1].mod.size() > 0) { assert(contains(mods[1].inputs, split_ins)); + (void)(split_ins); assert(mods[1].mod.get_parameters().size() == 1); auto param_ins = std::find_if(mods[1].mod.begin(), mods[1].mod.end(), [](const auto& i) { From a46bbaa542f960a888fe932c0792ce8010179d72 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 18 Jul 2024 19:16:24 +0000 Subject: [PATCH 068/145] fix test --- test/gpu/fuse_mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index ac453a14dec..d49ab2ec56a 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -166,7 +166,7 @@ TEST_CASE(add_dot) auto y = mm->add_parameter("y", s); auto fused = add_mlir(p2, - "main:pointwise0:mlir_dot1", + "main:pointwise0:mlir_dot2", {x, y, b}, {"x0", "x1", "x2"}, [=](auto* pm, const auto& inputs) { From b88f6bd3a3e9512da0efcb9e0d2c18adbb1e7502 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 19 Jul 2024 13:07:14 +0000 Subject: [PATCH 069/145] use anonymous namespace --- src/simplify_reshapes.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 1332c3ca580..02e77759abe 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -44,7 +44,8 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -static const auto& reshaper_names() +namespace { +const auto& reshaper_names() { // clang-format off static const std::unordered_set names = { @@ -57,6 +58,7 @@ static const auto& reshaper_names() // clang-format on return names; } +} // namespace bool is_reshaper(instruction_ref ins) { return contains(reshaper_names(), ins->name()); } From c9f52019701b5fab17f5505ea21bb840af09000e Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 22 Jul 2024 19:42:32 +0000 Subject: [PATCH 070/145] multi use case --- test/verify/test_gemm_add_broadcast2.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/verify/test_gemm_add_broadcast2.cpp b/test/verify/test_gemm_add_broadcast2.cpp index 2292c2a4f24..145e8aa1526 100644 --- a/test/verify/test_gemm_add_broadcast2.cpp +++ b/test/verify/test_gemm_add_broadcast2.cpp @@ -45,7 +45,13 @@ struct test_gemm_add_broadcast2 : verify_program mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 2, 4}}}), l3); auto dot = mm->add_instruction(migraphx::make_op("dot"), l1, l2); - mm->add_instruction(migraphx::make_op("add"), dot, l3_b); + auto add = mm->add_instruction(migraphx::make_op("add"), dot, l3_b); + auto reduce_mean = + mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {-1}}}), dot); + auto mlb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 2, 4}}}), reduce_mean); + auto sub = mm->add_instruction(migraphx::make_op("sub"), dot, mlb); + mm->add_instruction(migraphx::make_op("div"), sub, add); return p; } std::string section() const { return "gemm"; } From 8a44a139bc552398971a2967134869f75ab9da0b Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 22 Jul 2024 22:27:29 +0000 Subject: [PATCH 071/145] fix replace --- src/targets/gpu/fuse_mlir.cpp | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index ddbe7b119ea..cafc1a61843 100644 --- 
a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -21,6 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ +#include "migraphx/instruction_ref.hpp" #include #include #include @@ -431,7 +432,7 @@ struct find_mlir_fused_ops auto reshapes = reshaper_names(); // slice is not supported reshapes.erase("slice"); - auto dot_or_conv = match::skip(match::name(reshapes))( + auto dot_or_conv = match::skip(match::all_of(match::name(reshapes), match::used_once()))( match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op")); return mlir_pointwise()(match::any_of[match::inputs()](dot_or_conv.bind("x"))); } @@ -440,6 +441,8 @@ struct find_mlir_fused_ops { auto pw_ins = r.result; auto gemm_based_op = r.instructions["gemm_based_op"]; + bool gemm_has_multi_outs = gemm_based_op->outputs().size() > 1; + std::cout << "gemm has multi outs: " << gemm_has_multi_outs << std::endl; auto x_ins = r.instructions["x"]; // input to pointwise after reshaper op stream auto* pm = pw_ins->module_inputs().front(); auto names = pm->get_parameter_names(); @@ -460,7 +463,18 @@ struct find_mlir_fused_ops assert(prev_input->get_shape().lens() == x_ins->get_shape().lens()); param_map[x_ins] = prev_input; // this is to avoid adding parameter for gemm/conv reshaped // input to pointwise in new fused module - mm->add_return(mm->fuse(*pm, pw_ins->inputs(), ¶m_map)); + auto return_vals = mm->fuse(*pm, pw_ins->inputs(), ¶m_map); + instruction_ref reshape_out = x_ins; + if(gemm_has_multi_outs) + { + while(x_ins != gemm_based_op && reshape_out->inputs().at(0) != gemm_based_op) + { + reshape_out = reshape_out->inputs().at(0); + } + return_vals.insert(return_vals.begin(), anchor_op); + reshape_out->debug_print(); + } + mm->add_return(return_vals); std::vector inputs; std::copy_if(pw_ins->inputs().begin(), @@ -468,8 +482,20 @@ struct find_mlir_fused_ops std::back_inserter(inputs), [&](auto input) { return input != x_ins; }); inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end()); - mpm.get_module().replace_instruction( - pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm}); + if(gemm_has_multi_outs) + { + auto fused_ins = mpm.get_module().insert_instruction( + pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm}); + mpm.get_module().replace_instruction( + gemm_based_op, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins); + mpm.get_module().replace_instruction( + pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused_ins); + } + else + { + mpm.get_module().replace_instruction( + pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm}); + } } }; From c984b83849267dd609f643f7b9f7d47a44e8faab Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 23 Jul 2024 13:59:12 +0000 Subject: [PATCH 072/145] clean up --- src/targets/gpu/fuse_mlir.cpp | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index cafc1a61843..ce81d9185b8 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -441,8 +441,6 @@ struct find_mlir_fused_ops { auto pw_ins = r.result; auto gemm_based_op = r.instructions["gemm_based_op"]; - bool gemm_has_multi_outs = gemm_based_op->outputs().size() > 1; - std::cout << "gemm has multi outs: " << gemm_has_multi_outs << std::endl; auto x_ins = r.instructions["x"]; // input to pointwise after reshaper op stream 
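// e.g. for a dot -> transpose -> squeeze -> add chain, "gemm_based_op" binds the dot and "x" binds the squeeze output feeding the pointwise add; the dot_reshapes_add unit test added earlier in this series exercises exactly this pattern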
auto* pm = pw_ins->module_inputs().front(); auto names = pm->get_parameter_names(); @@ -464,15 +462,10 @@ struct find_mlir_fused_ops param_map[x_ins] = prev_input; // this is to avoid adding parameter for gemm/conv reshaped // input to pointwise in new fused module auto return_vals = mm->fuse(*pm, pw_ins->inputs(), ¶m_map); - instruction_ref reshape_out = x_ins; + bool gemm_has_multi_outs = gemm_based_op->outputs().size() > 1; if(gemm_has_multi_outs) { - while(x_ins != gemm_based_op && reshape_out->inputs().at(0) != gemm_based_op) - { - reshape_out = reshape_out->inputs().at(0); - } return_vals.insert(return_vals.begin(), anchor_op); - reshape_out->debug_print(); } mm->add_return(return_vals); From ea3fdb782b00096b13490cc208bce24ee53bbd20 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 23 Jul 2024 14:27:34 +0000 Subject: [PATCH 073/145] add test --- src/targets/gpu/fuse_mlir.cpp | 6 ++-- test/gpu/fuse_mlir.cpp | 61 +++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index ce81d9185b8..481710abaea 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -479,16 +479,18 @@ struct find_mlir_fused_ops { auto fused_ins = mpm.get_module().insert_instruction( pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm}); - mpm.get_module().replace_instruction( - gemm_based_op, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins); mpm.get_module().replace_instruction( pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused_ins); + auto dot_ins = mpm.get_module().insert_instruction( + pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins); + mpm.get_module().replace_instruction(gemm_based_op, dot_ins); } else { mpm.get_module().replace_instruction( pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm}); } + mpm.get_module().debug_print(); } }; diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 6790abbc9cd..2a69cef77bb 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -144,6 +144,67 @@ TEST_CASE(dot_add) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(dot_multi_use_add_reduce_sub) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto x = mm->add_parameter("x", s); + auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = add_pointwise(p1, "main:pointwise0", {dot, x}, single_pointwise("add")); + auto pooling = + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::lpnorm}, + {"padding", {0, 0}}, + {"stride", {1, 1}}, + {"lengths", {3, 3}}, + {"lp_order", 2}}), + add); + auto pooling_mb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), pooling); + auto sub = add_pointwise(p1, "main:pointwise1", {dot, pooling_mb}, single_pointwise("sub")); + mm->add_return({sub}); + } + p1.debug_print(); + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + auto x = mm->add_parameter("x", s); + auto fused = add_mlir( + p2, + "mlir_main:pointwise0", + {x, a, b}, + {"x2", "y0", "y1"}, + [=](auto* pm, const auto& inputs) { + auto dot = pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]); + auto add = 
pm->add_instruction(migraphx::make_op("add"), dot, inputs[0]); + return std::make_tuple(dot, std::vector{dot, add}); + }); + auto fused_dot_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto pooling = + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::lpnorm}, + {"padding", {0, 0}}, + {"stride", {1, 1}}, + {"lengths", {3, 3}}, + {"lp_order", 2}}), + fused_dot_add); + auto pooling_mb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), pooling); + auto dot = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto sub = add_pointwise(p2, "main:pointwise1", {dot, pooling_mb}, single_pointwise("sub")); + mm->add_return({sub}); + } + EXPECT(p1.sort() == p2.sort()); +} + TEST_CASE(add_dot) { migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; From 329955b2a9beb9652d89dffa25333d5dac93482d Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 23 Jul 2024 14:58:31 +0000 Subject: [PATCH 074/145] add multi use case --- src/targets/gpu/fuse_mlir.cpp | 23 +++++--- test/gpu/fuse_mlir.cpp | 107 +++++++++++++++++++++++++++------- 2 files changed, 103 insertions(+), 27 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 481710abaea..c13ffb0b291 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -432,7 +432,7 @@ struct find_mlir_fused_ops auto reshapes = reshaper_names(); // slice is not supported reshapes.erase("slice"); - auto dot_or_conv = match::skip(match::all_of(match::name(reshapes), match::used_once()))( + auto dot_or_conv = match::skip(match::name(reshapes))( match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op")); return mlir_pointwise()(match::any_of[match::inputs()](dot_or_conv.bind("x"))); } @@ -461,8 +461,15 @@ struct find_mlir_fused_ops assert(prev_input->get_shape().lens() == x_ins->get_shape().lens()); param_map[x_ins] = prev_input; // this is to avoid adding parameter for gemm/conv reshaped // input to pointwise in new fused module - auto return_vals = mm->fuse(*pm, pw_ins->inputs(), ¶m_map); - bool gemm_has_multi_outs = gemm_based_op->outputs().size() > 1; + bool gemm_has_multi_outs = gemm_based_op->outputs().size() > 1; + auto reshaped_gemm = x_ins; + while(reshaped_gemm != gemm_based_op) + { + gemm_has_multi_outs |= reshaped_gemm->outputs().size() > 1; + reshaped_gemm = reshaped_gemm->inputs().at(0); + } + + auto return_vals = mm->fuse(*pm, pw_ins->inputs(), ¶m_map); if(gemm_has_multi_outs) { return_vals.insert(return_vals.begin(), anchor_op); @@ -477,12 +484,15 @@ struct find_mlir_fused_ops inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end()); if(gemm_has_multi_outs) { - auto fused_ins = mpm.get_module().insert_instruction( - pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm}); + auto fused_ins = + mpm.get_module().insert_instruction(gemm_based_op, + mlir_op{gemm_based_op->get_operator()}, + mlir_contiguous(mpm, inputs), + {mm}); mpm.get_module().replace_instruction( pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused_ins); auto dot_ins = mpm.get_module().insert_instruction( - pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins); + gemm_based_op, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins); mpm.get_module().replace_instruction(gemm_based_op, dot_ins); } else @@ -490,7 +500,6 @@ struct find_mlir_fused_ops 
mpm.get_module().replace_instruction( pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm}); } - mpm.get_module().debug_print(); } }; diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 2a69cef77bb..72a4630ae14 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -144,38 +144,102 @@ TEST_CASE(dot_add) EXPECT(p1.sort() == p2.sort()); } -TEST_CASE(dot_multi_use_add_reduce_sub) +TEST_CASE(multi_use_dot_trans_add_pooling_sub) { - migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 3}}; + migraphx::shape s1{migraphx::shape::float_type, {1, 1, 4, 5}}; + migraphx::shape s2{migraphx::shape::float_type, {1, 1, 5, 5}}; migraphx::program p1; { auto* mm = p1.get_main_module(); - auto a = mm->add_parameter("a", s); - auto b = mm->add_parameter("b", s); - auto x = mm->add_parameter("x", s); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b); - auto add = add_pointwise(p1, "main:pointwise0", {dot, x}, single_pointwise("add")); + auto dot_trans = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dot); + auto add = add_pointwise(p1, "main:pointwise0", {dot_trans, x}, single_pointwise("add")); auto pooling = mm->add_instruction(migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::lpnorm}, - {"padding", {0, 0}}, + {"padding", {0, 0, 0, 1}}, {"stride", {1, 1}}, - {"lengths", {3, 3}}, + {"lengths", {2, 1}}, {"lp_order", 2}}), add); - auto pooling_mb = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), pooling); - auto sub = add_pointwise(p1, "main:pointwise1", {dot, pooling_mb}, single_pointwise("sub")); + auto sub = add_pointwise(p1, "main:pointwise1", {dot, pooling}, single_pointwise("sub")); mm->add_return({sub}); } + run_pass(p1); p1.debug_print(); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); + auto fused = add_mlir( + p2, + "mlir_main:pointwise0", + {x, a, b}, + {"x2", "y0", "y1"}, + [=](auto* pm, const auto& inputs) { + auto dot = pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]); + auto dot_trans = pm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dot); + + auto add = pm->add_instruction(migraphx::make_op("add"), dot_trans, inputs[0]); + return std::make_tuple(dot, std::vector{dot, add}); + }); + auto fused_dot_add = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto pooling = + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::lpnorm}, + {"padding", {0, 0, 0, 1}}, + {"stride", {1, 1}}, + {"lengths", {2, 1}}, + {"lp_order", 2}}), + fused_dot_add); + auto dot = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto sub = add_pointwise(p2, "main:pointwise1", {dot, pooling}, single_pointwise("sub")); + mm->add_return({sub}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dot_multi_use_trans_add_pooling_sub) +{ + migraphx::shape s1{migraphx::shape::float_type, {1, 1, 4, 5}}; + migraphx::shape s2{migraphx::shape::float_type, {1, 1, 5, 5}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = 
mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); + auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto dot_trans = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dot); + auto add = add_pointwise(p1, "main:pointwise0", {dot_trans, x}, single_pointwise("add")); + auto pooling = + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::lpnorm}, + {"padding", {1, 0, 0, 0}}, + {"stride", {1, 1}}, + {"lengths", {2, 1}}, + {"lp_order", 2}}), + add); + auto sub = + add_pointwise(p1, "main:pointwise1", {dot_trans, pooling}, single_pointwise("sub")); + mm->add_return({sub}); + } run_pass(p1); migraphx::program p2; { auto* mm = p2.get_main_module(); - auto a = mm->add_parameter("a", s); - auto b = mm->add_parameter("b", s); - auto x = mm->add_parameter("x", s); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); auto fused = add_mlir( p2, "mlir_main:pointwise0", @@ -183,7 +247,9 @@ TEST_CASE(dot_multi_use_add_reduce_sub) {"x2", "y0", "y1"}, [=](auto* pm, const auto& inputs) { auto dot = pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]); - auto add = pm->add_instruction(migraphx::make_op("add"), dot, inputs[0]); + auto dot_trans = pm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dot); + auto add = pm->add_instruction(migraphx::make_op("add"), dot_trans, inputs[0]); return std::make_tuple(dot, std::vector{dot, add}); }); auto fused_dot_add = @@ -191,15 +257,16 @@ TEST_CASE(dot_multi_use_add_reduce_sub) auto pooling = mm->add_instruction(migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::lpnorm}, - {"padding", {0, 0}}, + {"padding", {1, 0, 0, 0}}, {"stride", {1, 1}}, - {"lengths", {3, 3}}, + {"lengths", {2, 1}}, {"lp_order", 2}}), fused_dot_add); - auto pooling_mb = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), pooling); auto dot = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); - auto sub = add_pointwise(p2, "main:pointwise1", {dot, pooling_mb}, single_pointwise("sub")); + auto dot_trans = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dot); + auto sub = + add_pointwise(p2, "main:pointwise1", {dot_trans, pooling}, single_pointwise("sub")); mm->add_return({sub}); } EXPECT(p1.sort() == p2.sort()); From d14cd663d0fff6a179f4b9a3faf6e34fd53a08d5 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 23 Jul 2024 14:59:20 +0000 Subject: [PATCH 075/145] revert test change --- test/verify/test_gemm_add_broadcast2.cpp | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/test/verify/test_gemm_add_broadcast2.cpp b/test/verify/test_gemm_add_broadcast2.cpp index 145e8aa1526..2292c2a4f24 100644 --- a/test/verify/test_gemm_add_broadcast2.cpp +++ b/test/verify/test_gemm_add_broadcast2.cpp @@ -45,13 +45,7 @@ struct test_gemm_add_broadcast2 : verify_program mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 2, 4}}}), l3); auto dot = mm->add_instruction(migraphx::make_op("dot"), l1, l2); - auto add = mm->add_instruction(migraphx::make_op("add"), dot, l3_b); - auto reduce_mean = - mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {-1}}}), dot); - auto mlb = 
mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, 2, 4}}}), reduce_mean); - auto sub = mm->add_instruction(migraphx::make_op("sub"), dot, mlb); - mm->add_instruction(migraphx::make_op("div"), sub, add); + mm->add_instruction(migraphx::make_op("add"), dot, l3_b); return p; } std::string section() const { return "gemm"; } From 1e981a2d0a4f7e6c4fd039e59acca0ed2dc46162 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 23 Jul 2024 15:06:38 +0000 Subject: [PATCH 076/145] add verify test --- src/targets/gpu/fuse_mlir.cpp | 1 - .../test_gemm_transpose_add_pooling_sub.cpp | 65 +++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 test/verify/test_gemm_transpose_add_pooling_sub.cpp diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index c13ffb0b291..95382cca173 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -21,7 +21,6 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -#include "migraphx/instruction_ref.hpp" #include #include #include diff --git a/test/verify/test_gemm_transpose_add_pooling_sub.cpp b/test/verify/test_gemm_transpose_add_pooling_sub.cpp new file mode 100644 index 00000000000..29d90aacccb --- /dev/null +++ b/test/verify/test_gemm_transpose_add_pooling_sub.cpp @@ -0,0 +1,65 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
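Annotation: the multi-output rewiring introduced across PATCH 073/074 follows one fixed pattern, restated here as a minimal sketch. Identifiers such as `mod`, `gemm`, and `fused` are illustrative stand-ins for `mpm.get_module()`, `gemm_based_op`, and the inserted instruction in find_mlir_fused_ops::apply; this is not verbatim repository code.

    // Sketch of the multi-output replacement pattern from find_mlir_fused_ops.
    // The fused mlir_op returns a tuple: element 0 carries the raw gemm/conv
    // result and element 1 the pointwise result, so any other readers of the
    // gemm still see a valid producer after the rewrite.
    auto fused = mod.insert_instruction(
        gemm, mlir_op{gemm->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
    mod.replace_instruction(
        pw_ins, make_op("get_tuple_elem", {{"index", 1}}), fused);
    auto gemm_out = mod.insert_instruction(
        gemm, make_op("get_tuple_elem", {{"index", 0}}), fused);
    mod.replace_instruction(gemm, gemm_out);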
+ */ + +#include "verify_program.hpp" +#include +#include +#include +#include +#include + +template +struct test_gemm_transpose_add_pooling_sub + : verify_program> +{ + migraphx::program create_program() const + { + migraphx::shape s1{migraphx::shape::float_type, {1, 1, 4, 5}}; + migraphx::shape s2{migraphx::shape::float_type, {1, 1, 5, 5}}; + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); + auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto dot_trans = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dot); + auto add = mm->add_instruction(migraphx::make_op("add"), {dot_trans, x}); + auto pooling = + mm->add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::lpnorm}, + {"padding", {1, 0, 0, 0}}, + {"stride", {1, 1}}, + {"lengths", {2, 1}}, + {"lp_order", 2}}), + add); + auto sub = mm->add_instruction(migraphx::make_op("sub"), dot_trans, pooling); + mm->add_return({sub}); + } + std::string section() const { return "gemm"; } +}; + +template struct test_gemm_transpose_add_pooling_sub; +template struct test_gemm_transpose_add_pooling_sub; +template struct test_gemm_transpose_add_pooling_sub; From 2a1c4cd5b2b0d2dd2d6bfb1dd7e20bfd8f0a2369 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 23 Jul 2024 15:09:17 +0000 Subject: [PATCH 077/145] fix return --- test/verify/test_gemm_transpose_add_pooling_sub.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/test/verify/test_gemm_transpose_add_pooling_sub.cpp b/test/verify/test_gemm_transpose_add_pooling_sub.cpp index 29d90aacccb..723e0b36763 100644 --- a/test/verify/test_gemm_transpose_add_pooling_sub.cpp +++ b/test/verify/test_gemm_transpose_add_pooling_sub.cpp @@ -56,6 +56,7 @@ struct test_gemm_transpose_add_pooling_sub add); auto sub = mm->add_instruction(migraphx::make_op("sub"), dot_trans, pooling); mm->add_return({sub}); + return p; } std::string section() const { return "gemm"; } }; From 9a9c2c4d11151b80dacf51ac4b626e79e801843f Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 23 Jul 2024 15:19:28 +0000 Subject: [PATCH 078/145] Foramtting --- test/gpu/fuse_mlir.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 72a4630ae14..f66a1644f01 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -154,7 +154,7 @@ TEST_CASE(multi_use_dot_trans_add_pooling_sub) auto a = mm->add_parameter("a", s1); auto b = mm->add_parameter("b", s2); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); - auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b); auto dot_trans = mm->add_instruction( migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dot); auto add = add_pointwise(p1, "main:pointwise0", {dot_trans, x}, single_pointwise("add")); @@ -236,9 +236,9 @@ TEST_CASE(dot_multi_use_trans_add_pooling_sub) run_pass(p1); migraphx::program p2; { - auto* mm = p2.get_main_module(); - auto a = mm->add_parameter("a", s1); - auto b = mm->add_parameter("b", s2); + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); auto fused = 
add_mlir( p2, From bb7652831420eee3812a60c2537a60234385e92d Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 23 Jul 2024 12:56:53 -0700 Subject: [PATCH 079/145] Add missing elipsis --- src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp index 35cb4745705..2fe9255294b 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp @@ -196,7 +196,7 @@ template constexpr void unpack_each(F f, Pack1 p1, Pack2 p2) { p1([&](auto&&... xs) { - p2([&](auto&& ys) { + p2([&](auto&&... ys) { each_args( [&](auto&& p) { p(f); }, pack_forward(static_cast(xs), static_cast(ys))...); From 51d3c5f1e3fc73ee968779c9e8b8f84316db846a Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 23 Jul 2024 13:01:59 -0700 Subject: [PATCH 080/145] Add licenses --- src/include/migraphx/float8.hpp | 2 +- src/include/migraphx/instruction.hpp | 2 +- src/include/migraphx/ranges.hpp | 2 +- src/targets/gpu/abs.cpp | 2 +- src/targets/gpu/include/migraphx/gpu/abs.hpp | 2 +- .../gpu/include/migraphx/gpu/convolution.hpp | 2 +- src/targets/gpu/include/migraphx/gpu/lrn.hpp | 2 +- .../gpu/include/migraphx/gpu/miopen.hpp | 2 +- .../gpu/include/migraphx/gpu/pooling.hpp | 2 +- .../include/migraphx/kernels/atomic.hpp | 24 +++++++++++++++++++ .../kernels/include/migraphx/kernels/rank.hpp | 24 +++++++++++++++++++ .../include/migraphx/kernels/types.hpp | 2 +- src/targets/gpu/lrn.cpp | 2 +- src/targets/gpu/pooling.cpp | 2 +- test/fp8e4m3fn.cpp | 2 +- test/fp8e4m3fnuz.cpp | 2 +- test/fp8e5m2.cpp | 2 +- test/fp8e5m2fnuz.cpp | 2 +- test/verify/test_reduce_op_large.cpp | 2 +- 19 files changed, 65 insertions(+), 17 deletions(-) diff --git a/src/include/migraphx/float8.hpp b/src/include/migraphx/float8.hpp index fd91291dabd..dcae919953a 100644 --- a/src/include/migraphx/float8.hpp +++ b/src/include/migraphx/float8.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/include/migraphx/instruction.hpp b/src/include/migraphx/instruction.hpp index 0e722e23ad3..c1a5b15e2af 100644 --- a/src/include/migraphx/instruction.hpp +++ b/src/include/migraphx/instruction.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/include/migraphx/ranges.hpp b/src/include/migraphx/ranges.hpp index d4bfa2cb68f..e51dac1e186 100644 --- a/src/include/migraphx/ranges.hpp +++ b/src/include/migraphx/ranges.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
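PATCH 079's one-character fix is easy to miss, so here is a standalone snippet (hypothetical illustration, not repository code) showing why the inner lambda in unpack_each needs the ellipsis:

    #include <cstdio>

    int main()
    {
        // A "pack" in the functional.hpp sense: a callable that forwards
        // several stored values to whatever function it is given.
        auto pack = [](auto g) { g(1, 2, 3); };
        // With `auto&& ys` (no ellipsis) the inner lambda accepts exactly one
        // argument and the call below fails to compile; `auto&&... ys`
        // declares a parameter pack that matches all three values, which is
        // what the fixed unpack_each now does.
        pack([](auto&&... ys) { std::printf("%d\n", (0 + ... + ys)); });
        return 0;
    }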
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/abs.cpp b/src/targets/gpu/abs.cpp index 6e5698b4aa3..8cd0a1d8bcf 100644 --- a/src/targets/gpu/abs.cpp +++ b/src/targets/gpu/abs.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/include/migraphx/gpu/abs.hpp b/src/targets/gpu/include/migraphx/gpu/abs.hpp index 4620c9ff2b3..1a9f4b87800 100644 --- a/src/targets/gpu/include/migraphx/gpu/abs.hpp +++ b/src/targets/gpu/include/migraphx/gpu/abs.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/include/migraphx/gpu/convolution.hpp b/src/targets/gpu/include/migraphx/gpu/convolution.hpp index 68c812a34da..1b1c3169830 100644 --- a/src/targets/gpu/include/migraphx/gpu/convolution.hpp +++ b/src/targets/gpu/include/migraphx/gpu/convolution.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/include/migraphx/gpu/lrn.hpp b/src/targets/gpu/include/migraphx/gpu/lrn.hpp index d1c16e1eaf6..8ccda7bba6a 100644 --- a/src/targets/gpu/include/migraphx/gpu/lrn.hpp +++ b/src/targets/gpu/include/migraphx/gpu/lrn.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/include/migraphx/gpu/miopen.hpp b/src/targets/gpu/include/migraphx/gpu/miopen.hpp index 9f447b57b44..fb61103538d 100644 --- a/src/targets/gpu/include/migraphx/gpu/miopen.hpp +++ b/src/targets/gpu/include/migraphx/gpu/miopen.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/include/migraphx/gpu/pooling.hpp b/src/targets/gpu/include/migraphx/gpu/pooling.hpp index e3676f3b210..7f6722b1130 100644 --- a/src/targets/gpu/include/migraphx/gpu/pooling.hpp +++ b/src/targets/gpu/include/migraphx/gpu/pooling.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. 
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp index 82fee3e7b8d..d097e9b223a 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp @@ -1,3 +1,27 @@ +/* +* The MIT License (MIT) +* +* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +* THE SOFTWARE. +* +*/ #ifndef MIGRAPHX_GUARD_KERNELS_ATOMIC_HPP #define MIGRAPHX_GUARD_KERNELS_ATOMIC_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/rank.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/rank.hpp index 4058de120fd..d2dbaaf2afd 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/rank.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/rank.hpp @@ -1,3 +1,27 @@ +/* +* The MIT License (MIT) +* +* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +* THE SOFTWARE. 
+* +*/ #ifndef MIGRAPHX_GUARD_KERNELS_RANK_HPP #define MIGRAPHX_GUARD_KERNELS_RANK_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp index a3e03507789..27f6303e6de 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/types.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/lrn.cpp b/src/targets/gpu/lrn.cpp index b247b2b8b58..2e99c208dd1 100644 --- a/src/targets/gpu/lrn.cpp +++ b/src/targets/gpu/lrn.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/pooling.cpp b/src/targets/gpu/pooling.cpp index 9dc03f46a87..a6f86f077cf 100644 --- a/src/targets/gpu/pooling.cpp +++ b/src/targets/gpu/pooling.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/fp8e4m3fn.cpp b/test/fp8e4m3fn.cpp index a691c4a1bb8..2ab04694214 100644 --- a/test/fp8e4m3fn.cpp +++ b/test/fp8e4m3fn.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/fp8e4m3fnuz.cpp b/test/fp8e4m3fnuz.cpp index 422b785d4a6..756cddd955b 100644 --- a/test/fp8e4m3fnuz.cpp +++ b/test/fp8e4m3fnuz.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/fp8e5m2.cpp b/test/fp8e5m2.cpp index 576f0e77f2e..211f2c6e411 100644 --- a/test/fp8e5m2.cpp +++ b/test/fp8e5m2.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/fp8e5m2fnuz.cpp b/test/fp8e5m2fnuz.cpp index 80cb4a7ca45..66ea7521f04 100644 --- a/test/fp8e5m2fnuz.cpp +++ b/test/fp8e5m2fnuz.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/verify/test_reduce_op_large.cpp b/test/verify/test_reduce_op_large.cpp index 131494b566a..9585b3a8764 100644 --- a/test/verify/test_reduce_op_large.cpp +++ b/test/verify/test_reduce_op_large.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From cb909a4e76f628a06d00bbd18005f93aaac0d920 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 23 Jul 2024 13:02:09 -0700 Subject: [PATCH 081/145] Format --- .../include/migraphx/kernels/atomic.hpp | 46 +++++++++---------- .../kernels/include/migraphx/kernels/rank.hpp | 46 +++++++++---------- 2 files changed, 46 insertions(+), 46 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp index d097e9b223a..76e0409cc81 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/atomic.hpp @@ -1,27 +1,27 @@ /* -* The MIT License (MIT) -* -* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. -* -* Permission is hereby granted, free of charge, to any person obtaining a copy -* of this software and associated documentation files (the "Software"), to deal -* in the Software without restriction, including without limitation the rights -* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -* copies of the Software, and to permit persons to whom the Software is -* furnished to do so, subject to the following conditions: -* -* The above copyright notice and this permission notice shall be included in -* all copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -* THE SOFTWARE. -* -*/ + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ #ifndef MIGRAPHX_GUARD_KERNELS_ATOMIC_HPP #define MIGRAPHX_GUARD_KERNELS_ATOMIC_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/rank.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/rank.hpp index d2dbaaf2afd..5765b4f3e5d 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/rank.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/rank.hpp @@ -1,27 +1,27 @@ /* -* The MIT License (MIT) -* -* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. -* -* Permission is hereby granted, free of charge, to any person obtaining a copy -* of this software and associated documentation files (the "Software"), to deal -* in the Software without restriction, including without limitation the rights -* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -* copies of the Software, and to permit persons to whom the Software is -* furnished to do so, subject to the following conditions: -* -* The above copyright notice and this permission notice shall be included in -* all copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -* THE SOFTWARE. -* -*/ + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ #ifndef MIGRAPHX_GUARD_KERNELS_RANK_HPP #define MIGRAPHX_GUARD_KERNELS_RANK_HPP From 3f4ef637a1f1c4b10c981424db4dbaac1bfc2186 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 24 Jul 2024 14:49:22 +0000 Subject: [PATCH 082/145] split-reduce fusion working --- src/module.cpp | 4 +- src/targets/gpu/fuse_mlir.cpp | 118 ++++++++++++++++++ .../include/migraphx/kernels/functional.hpp | 2 +- 3 files changed, 121 insertions(+), 3 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index e07ccfa7e6c..47587e437cf 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1031,8 +1031,8 @@ module::fuse(const module& m, if(map_ins == nullptr) map_ins = &default_map_ins; insert_params(*this, inputs, *map_ins); - auto param_map = m.get_ins_param_map(inputs); - for(auto&& [input, param] : param_map) + auto param_map = m.get_ins_param_map(inputs, true); + for(auto&& [param, input] : param_map) { (*map_ins)[param] = map_ins->at(input); } diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 95382cca173..64b1e59e965 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -355,8 +355,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) "erf", "exp", "floor", + "reduce_sum", "log", "recip", + "sqrt", "rsqrt", "sigmoid", "softmax", @@ -392,6 +394,22 @@ bool is_pointwise_op_supported_by_mlir_for_input(const instruction& i) return is_pointwise_op_supported_by_mlir(i); } +MIGRAPHX_PRED_MATCHER(mlir_split_reduce, instruction_ref ins) +{ + if(ins->name() != "split_fused_reduce") + return false; + std::vector sub_mods = ins->module_inputs().front()->get_sub_modules(true); + sub_mods.insert(sub_mods.begin(), ins->module_inputs().front()); + // for(const auto& mod : sub_mods) + // { + // if(not std::all_of(mod->begin(), mod->end(), &is_pointwise_op_supported_by_mlir)) + // { + // return false; + // } + // } + return true; +} + MIGRAPHX_PRED_MATCHER(mlir_pointwise, instruction_ref ins) { if(ins->name() != "pointwise") @@ -422,6 +440,99 @@ std::vector mlir_contiguous(module_pass_manager& mpm, return result; } +struct find_mlir_split_reduce +{ + mlir_mode conv_mode = mlir_mode::none; + mlir_mode dot_mode = mlir_mode::none; + auto matcher() const + { + auto dot_or_conv = match::name("gpu::mlir_op"); + return mlir_split_reduce()(match::any_of[match::inputs()](dot_or_conv.bind("gemm"))); + } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + auto reduce_ins = r.result; + auto gemm_ins = r.instructions["gemm"]; + auto* rm = reduce_ins->module_inputs().front(); + auto names = rm->get_parameter_names(); + std::sort(names.begin(), names.end()); + module_ref gemm_old_mm = gemm_ins->module_inputs().front(); + module_ref mm = mpm.create_module(gemm_old_mm->name() + "_split_fused_reduce"); + mm->add_instructions(gemm_old_mm); + mm->set_bypass(); + std::unordered_map param_map; + param_map[gemm_ins] = std::prev(mm->end())->inputs().front(); + bool gemm_has_multi_outs = gemm_ins->outputs().size() > 1; + 
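// Annotation on the next two steps: the trailing instruction of the copied
// gemm module is its old return; param_map above already records that the
// reduce module's gemm input resolves to that return's operand inside mm,
// so the return can be dropped and the reduce body fused directly on top
// of the gemm result rather than through a module parameter.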
mm->remove_instruction(std::prev(mm->end())); + auto return_vals = + mm->fuse(*rm, + reduce_ins->inputs(), + ¶m_map, + [&](module& main_mod, + instruction_ref pos, + const operation& op, + const std::vector& inputs, + const std::vector& mod_args) { + // todo handle broadcasted literals inside reduce mod + if(op.name() == "pointwise") + { + for(const auto& skip_param : inputs) + { + if(not contains(param_map, skip_param)) + { + param_map[skip_param] = + skip_param; // skip adding parameter for inputs of + // pointwise inside split_fused_reduce + } + } + auto sub_pm = mod_args.front(); + // todo: handle literals inside pointwise + auto param_map_2 = create_param_map_with_literals( + &main_mod, sub_pm, op.compute_shape(to_shapes(inputs), mod_args)); + for(const auto& i : param_map_2) + { + param_map.insert(i); + } + return main_mod.fuse(*sub_pm, inputs, ¶m_map).front(); + } + return main_mod.insert_instruction(pos, op, inputs, mod_args); + }); + if(gemm_has_multi_outs) + { + return_vals.insert(return_vals.end(), param_map[gemm_ins]); + } + mm->add_return(return_vals); + std::vector inputs; + std::copy_if(reduce_ins->inputs().begin(), + reduce_ins->inputs().end(), + std::back_inserter(inputs), + [&](auto input) { return input != gemm_ins; }); + inputs.insert(inputs.end(), gemm_ins->inputs().begin(), gemm_ins->inputs().end()); + if(gemm_has_multi_outs) + { + auto fused_ins = mpm.get_module().insert_instruction( + reduce_ins, mlir_op{gemm_ins->get_operator()}, mlir_contiguous(mpm, inputs), {mm}); + auto dot_ins = mpm.get_module().insert_instruction( + reduce_ins, + migraphx::make_op("get_tuple_elem", {{"index", return_vals.size() - 1}}), + fused_ins); + + mpm.get_module().replace_instruction(gemm_ins, dot_ins); + for(const auto outs : reduce_ins->outputs()) + { + assert(outs->get_operator().name() == "get_tuple_elem"); + mpm.get_module().replace_instruction(outs, outs->get_operator(), fused_ins); + } + } + else + { + mpm.get_module().replace_instruction( + reduce_ins, mlir_op{gemm_ins->get_operator()}, mlir_contiguous(mpm, inputs), {mm}); + } + } +}; + struct find_mlir_fused_ops { mlir_mode conv_mode = mlir_mode::none; @@ -700,6 +811,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const mpm, find_mlir_fused_ops{.conv_mode = get_mode("fused_convolution", mlir_mode::fast), .dot_mode = get_mode("fused_dot", mlir_mode::fast)}); + match::find_matches( mpm, find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::fast)}, @@ -708,7 +820,13 @@ void fuse_mlir::apply(module_pass_manager& mpm) const mpm.run_pass(dead_code_elimination{}); if(enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) + { + match::find_matches( + mpm, + find_mlir_split_reduce{.conv_mode = get_mode("fused_convolution", mlir_mode::fast), + .dot_mode = get_mode("fused_dot", mlir_mode::fast)}); match::find_matches(mpm, find_pointwise_mlir{}); + } #else (void)mpm; #endif diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp index 35cb4745705..2fe9255294b 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp @@ -196,7 +196,7 @@ template constexpr void unpack_each(F f, Pack1 p1, Pack2 p2) { p1([&](auto&&... xs) { - p2([&](auto&& ys) { + p2([&](auto&&... 
ys) { each_args( [&](auto&& p) { p(f); }, pack_forward(static_cast(xs), static_cast(ys))...); From 0f785f00f6c564dc820c085358009869e5221b9a Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Wed, 24 Jul 2024 09:56:39 -0500 Subject: [PATCH 083/145] Update test/split_reduce.cpp Co-authored-by: Umang Yadav <29876643+umangyadav@users.noreply.github.com> --- test/split_reduce.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index 9616de45b68..b7f02f0afa9 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -352,7 +352,7 @@ TEST_CASE(double_split_live) auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), inputs[0]); auto mul = add_pointwise( - p2, rm, "main:pointwise1", {inputs[0], inputs[0]}, single_pointwise("mul")); + p2, rm, "main:pointwise1", {inputs[0]}, squared()); auto rsum2 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), mul); return {rsum1, rsum2}; From 97e4861d7e64110ef5c497696d01debb3106558c Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Wed, 24 Jul 2024 09:56:58 -0500 Subject: [PATCH 084/145] Update test/split_reduce.cpp Co-authored-by: Umang Yadav <29876643+umangyadav@users.noreply.github.com> --- test/split_reduce.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index b7f02f0afa9..de964791565 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -323,7 +323,7 @@ TEST_CASE(double_split_live) auto sqrtb = rm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), sqrt); auto mul = add_pointwise( - p1, rm, "main:pointwise1", {inputs[0], inputs[0]}, single_pointwise("mul")); + p1, rm, "main:pointwise1", {inputs[0]}, squared()); auto rsum2 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), mul); auto add = add_pointwise( From 32140c9d9889436c125dd7153b638c470a678219 Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 24 Jul 2024 12:05:02 -0700 Subject: [PATCH 085/145] Fix test --- test/split_reduce.cpp | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index 9616de45b68..1e94b381d2c 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -361,16 +361,18 @@ TEST_CASE(double_split_live) mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), rsums); auto rsum2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), rsums); - auto sqrt = add_pointwise(p2, "main:pointwise0", {rsum1}, single_pointwise("sqrt")); - auto sqrtb = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), sqrt); - auto add = add_pointwise(p2, "main:pointwise2", {rsum2, sqrt}, single_pointwise("add")); - auto addb = - mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), add); - auto mul = add_pointwise(p2, "main:pointwise3", {addb, sqrtb}, single_pointwise("mul")); - mm->add_return({mul}); + auto rsum1b = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum1); + auto rsum2b = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum2); + auto sqrt_add_mul = add_pointwise(p2, "main:pointwise0", {rsum1b, rsum2b}, [](auto* pm, const auto& inputs) { + auto sqrt = pm->add_instruction(migraphx::make_op("sqrt"), inputs[0]); + auto add = pm->add_instruction(migraphx::make_op("add"), inputs[1], sqrt); + return 
pm->add_instruction(migraphx::make_op("mul"), add, sqrt); + }); + mm->add_return({sqrt_add_mul}); } - EXPECT(p1 == p2); + EXPECT(p1.sort() == p2.sort()); } int main(int argc, const char* argv[]) { test::run(argc, argv); } From daa607ca533eb4092a3e8b0a7ddf47009a202b42 Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 24 Jul 2024 12:05:08 -0700 Subject: [PATCH 086/145] Format --- test/split_reduce.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index 1e94b381d2c..8cd25125aeb 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -365,11 +365,12 @@ TEST_CASE(double_split_live) migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum1); auto rsum2b = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum2); - auto sqrt_add_mul = add_pointwise(p2, "main:pointwise0", {rsum1b, rsum2b}, [](auto* pm, const auto& inputs) { - auto sqrt = pm->add_instruction(migraphx::make_op("sqrt"), inputs[0]); - auto add = pm->add_instruction(migraphx::make_op("add"), inputs[1], sqrt); - return pm->add_instruction(migraphx::make_op("mul"), add, sqrt); - }); + auto sqrt_add_mul = add_pointwise( + p2, "main:pointwise0", {rsum1b, rsum2b}, [](auto* pm, const auto& inputs) { + auto sqrt = pm->add_instruction(migraphx::make_op("sqrt"), inputs[0]); + auto add = pm->add_instruction(migraphx::make_op("add"), inputs[1], sqrt); + return pm->add_instruction(migraphx::make_op("mul"), add, sqrt); + }); mm->add_return({sqrt_add_mul}); } EXPECT(p1.sort() == p2.sort()); From 94d9456ea426ad938e28cc703551a31816304929 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 25 Jul 2024 14:10:15 +0000 Subject: [PATCH 087/145] refactor pieces --- src/targets/gpu/mlir.cpp | 68 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 66 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 94badfe5bbd..1fc453c57e7 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -21,11 +21,14 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
*/ +#include +#include #include #include #include #include #include +#include #include #ifdef MIGRAPHX_MLIR @@ -669,6 +672,8 @@ struct mlir_program void parse(const module& m) { validate(m); + std::cout << "parsing:\n"; + m.debug_print(); sym_name = get_symbol_name(m); auto mbody = mlirModuleGetBody(mmodule.get()); std::unordered_map ins_map; @@ -951,11 +956,63 @@ struct mlir_program std::string sym_name; }; +static bool is_reduce(const instruction& ins) { return contains(ins.name(), "reduce"); } + +void rewrite_reduce(module& m) +{ + std::vector ins_to_remove; + for(auto i : iterator_for(m)) + { + if(is_reduce(*i)) + { + auto reduce_op = i->get_operator().to_value(); + auto reduce_axes = reduce_op["axes"].to_vector(); + auto reduce_lens = i->get_shape().lens(); + auto in_shape = i->inputs().front()->get_shape(); + auto in_lens = in_shape.lens(); + assert(in_shape.standard()); + assert(std::is_sorted(reduce_axes.begin(), reduce_axes.end())); + assert(reduce_lens.size() == in_lens.size()); + assert(std::adjacent_find( + reduce_axes.begin(), reduce_axes.end(), [](auto axis_1, auto axis_2) { + return axis_2 - axis_1 > 1; + }) == reduce_axes.end()); + + std::vector rsp_lens; + std::vector new_reduce_axes; + for(const auto axis : range(in_shape.ndim())) + { + if(reduce_lens[axis] != in_lens[axis] and new_reduce_axes.empty()) + { + assert(reduce_lens[axis] == 1); + rsp_lens.push_back(-1); + new_reduce_axes.push_back(axis); + } + else + { + rsp_lens.push_back(in_lens[axis]); + } + } + auto rsp_ins = m.insert_instruction( + i, migraphx::make_op("reshape", {{"dims", rsp_lens}}), i->inputs().front()); + auto new_reduce = m.insert_instruction( + i, migraphx::make_op("reduce_sum", {{"axes", new_reduce_axes}}), rsp_ins); + auto rsp_back = m.insert_instruction( + i, migraphx::make_op("reshape", {{"dims", reduce_lens}}), new_reduce); + m.replace_instruction(i, rsp_back); + ins_to_remove.push_back(i); + } + } + std::for_each(ins_to_remove.begin(), ins_to_remove.end(), [&](const auto& remove_ins) { m.remove_instruction(remove_ins);}); +} + bool is_module_fusible(const module& m, const context& migraphx_ctx, const value& solution) { + auto mm = m; + rewrite_reduce(mm); mlir_program mp; mp.set_gpu_properties(migraphx_ctx); - mp.parse(m); + mp.parse(mm); mp.run_high_level_pipeline(); return mlirIsModuleFusible(mp.mmodule.get(), make_mlir_string_ref(*solution.if_string())); } @@ -988,6 +1045,7 @@ std::string dump_mlir(const module& m, const std::vector& inputs) mr = &mm; adjust_param_shapes(mm, inputs); } + rewrite_reduce(mm); mlir_program mp; mp.parse(*mr); auto mod_op = mlirModuleGetOperation(mp.mmodule.get()); @@ -1001,13 +1059,17 @@ mlir_code_object compile_mlir(const context& migraphx_ctx, const std::vector& in_shapes, const value& solution) { + std::cout << "inside compile\n"; adjust_param_shapes(m, in_shapes); + rewrite_reduce(m); + m.debug_print(); const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); static std::mutex mutex; if(trace) { const std::lock_guard lock(mutex); + std::cout << "compiling module\n"; std::cout << m << std::endl; } @@ -1073,7 +1135,9 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx, bool exhaustive) { adjust_param_shapes(m, inputs); - + rewrite_reduce(m); + std::cout << "getting tuning\n"; + m.debug_print(); mlir_program mp; mp.set_gpu_properties(migraphx_ctx); mp.parse(m); From 072f8dcad510139b015daa6b4fc5df59605730d0 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 25 Jul 2024 14:10:39 +0000 Subject: [PATCH 088/145] formatting --- src/targets/gpu/mlir.cpp 
| 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 1fc453c57e7..1a1e9eb55f6 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -672,8 +672,6 @@ struct mlir_program void parse(const module& m) { validate(m); - std::cout << "parsing:\n"; - m.debug_print(); sym_name = get_symbol_name(m); auto mbody = mlirModuleGetBody(mmodule.get()); std::unordered_map ins_map; @@ -971,7 +969,6 @@ void rewrite_reduce(module& m) auto in_shape = i->inputs().front()->get_shape(); auto in_lens = in_shape.lens(); assert(in_shape.standard()); - assert(std::is_sorted(reduce_axes.begin(), reduce_axes.end())); assert(reduce_lens.size() == in_lens.size()); assert(std::adjacent_find( reduce_axes.begin(), reduce_axes.end(), [](auto axis_1, auto axis_2) { @@ -1003,7 +1000,9 @@ void rewrite_reduce(module& m) ins_to_remove.push_back(i); } } - std::for_each(ins_to_remove.begin(), ins_to_remove.end(), [&](const auto& remove_ins) { m.remove_instruction(remove_ins);}); + std::for_each(ins_to_remove.begin(), ins_to_remove.end(), [&](const auto& remove_ins) { + m.remove_instruction(remove_ins); + }); } bool is_module_fusible(const module& m, const context& migraphx_ctx, const value& solution) @@ -1059,17 +1058,14 @@ mlir_code_object compile_mlir(const context& migraphx_ctx, const std::vector& in_shapes, const value& solution) { - std::cout << "inside compile\n"; adjust_param_shapes(m, in_shapes); rewrite_reduce(m); - m.debug_print(); const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); static std::mutex mutex; if(trace) { const std::lock_guard lock(mutex); - std::cout << "compiling module\n"; std::cout << m << std::endl; } @@ -1136,8 +1132,6 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx, { adjust_param_shapes(m, inputs); rewrite_reduce(m); - std::cout << "getting tuning\n"; - m.debug_print(); mlir_program mp; mp.set_gpu_properties(migraphx_ctx); mp.parse(m); From 0a2a8d8dbe9882973c2d168b85a071d41be545a6 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 25 Jul 2024 14:11:58 +0000 Subject: [PATCH 089/145] renamed --- src/targets/gpu/mlir.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 1a1e9eb55f6..3408c4d5ee6 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -992,10 +992,10 @@ void rewrite_reduce(module& m) } auto rsp_ins = m.insert_instruction( i, migraphx::make_op("reshape", {{"dims", rsp_lens}}), i->inputs().front()); - auto new_reduce = m.insert_instruction( + auto collapsed_reduce = m.insert_instruction( i, migraphx::make_op("reduce_sum", {{"axes", new_reduce_axes}}), rsp_ins); auto rsp_back = m.insert_instruction( - i, migraphx::make_op("reshape", {{"dims", reduce_lens}}), new_reduce); + i, migraphx::make_op("reshape", {{"dims", reduce_lens}}), collapsed_reduce); m.replace_instruction(i, rsp_back); ins_to_remove.push_back(i); } From 0a0626057760e08d9dcf9bb4a2aece1bfb443a01 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 25 Jul 2024 14:21:05 +0000 Subject: [PATCH 090/145] refactor --- src/targets/gpu/mlir.cpp | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 3408c4d5ee6..d6a2d228789 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -956,7 +956,7 @@ struct mlir_program static bool is_reduce(const instruction& ins) { return contains(ins.name(), "reduce"); } -void 
rewrite_reduce(module& m) +static void rewrite_reduce(module& m) { std::vector ins_to_remove; for(auto i : iterator_for(m)) @@ -975,23 +975,29 @@ void rewrite_reduce(module& m) return axis_2 - axis_1 > 1; }) == reduce_axes.end()); - std::vector rsp_lens; + std::vector new_rsp_dims; std::vector new_reduce_axes; for(const auto axis : range(in_shape.ndim())) { - if(reduce_lens[axis] != in_lens[axis] and new_reduce_axes.empty()) + if(reduce_lens[axis] == in_lens[axis]) { - assert(reduce_lens[axis] == 1); - rsp_lens.push_back(-1); - new_reduce_axes.push_back(axis); + new_rsp_dims.push_back(in_lens[axis]); } - else + else if(new_reduce_axes.empty()) { - rsp_lens.push_back(in_lens[axis]); + assert(reduce_lens[axis] == 1); + new_rsp_dims.push_back(-1); + new_reduce_axes.push_back(axis); } } + for(const auto& rsp_dim : new_rsp_dims) + { + std::cout << rsp_dim << " "; + } + std::cout << "\n"; + m.debug_print(); auto rsp_ins = m.insert_instruction( - i, migraphx::make_op("reshape", {{"dims", rsp_lens}}), i->inputs().front()); + i, migraphx::make_op("reshape", {{"dims", new_rsp_dims}}), i->inputs().front()); auto collapsed_reduce = m.insert_instruction( i, migraphx::make_op("reduce_sum", {{"axes", new_reduce_axes}}), rsp_ins); auto rsp_back = m.insert_instruction( From dff3dd4f86692d844682d4fad355c363b1482dc9 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 25 Jul 2024 14:21:25 +0000 Subject: [PATCH 091/145] remove debug --- src/targets/gpu/mlir.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index d6a2d228789..5194cfc37d1 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -990,12 +990,6 @@ static void rewrite_reduce(module& m) new_reduce_axes.push_back(axis); } } - for(const auto& rsp_dim : new_rsp_dims) - { - std::cout << rsp_dim << " "; - } - std::cout << "\n"; - m.debug_print(); auto rsp_ins = m.insert_instruction( i, migraphx::make_op("reshape", {{"dims", new_rsp_dims}}), i->inputs().front()); auto collapsed_reduce = m.insert_instruction( From ff94e0490991e1a807f4974bc5381e84ec563aa6 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 25 Jul 2024 14:54:21 +0000 Subject: [PATCH 092/145] add logic for checking is mlir_split_reduce --- src/targets/gpu/fuse_mlir.cpp | 49 ++++++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 64b1e59e965..442a2dc8206 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -355,7 +355,6 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) "erf", "exp", "floor", - "reduce_sum", "log", "recip", "sqrt", @@ -387,6 +386,22 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) return false; } +bool is_reduce_op_supported_by_mlir(const instruction& i) +{ + using type_t = shape::type_t; + const auto& name = i.name(); + const auto result_type = i.get_shape().type(); + const std::initializer_list allowed_types = { + type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}; + // Preliminary type check. + if(not contains(allowed_types, result_type)) + { + return false; + } + const std::initializer_list reduce_ops = {"reduce_mean", "reduce_sum"}; + return contains(reduce_ops, i.name()); +} + // A separate function so we can remove operators that are supported by mlir // but not supported for an input fusion. 
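// Worked example for rewrite_reduce above (shapes chosen for illustration):
//   input {2, 32, 40, 8}, reduced over the contiguous axes {2, 3}
//     reshape {2, 32, 40, 8} -> {2, 32, -1}   // collapses to {2, 32, 320}
//     reduce_sum with axes {2}                // -> {2, 32, 1}
//     reshape back to the reduce output lens  // -> {2, 32, 1, 1}
// The adjacent_find assert is what guarantees the reduced axes are
// contiguous, so a single -1 dimension can absorb all of them.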
bool is_pointwise_op_supported_by_mlir_for_input(const instruction& i) @@ -398,15 +413,29 @@ MIGRAPHX_PRED_MATCHER(mlir_split_reduce, instruction_ref ins) { if(ins->name() != "split_fused_reduce") return false; - std::vector sub_mods = ins->module_inputs().front()->get_sub_modules(true); - sub_mods.insert(sub_mods.begin(), ins->module_inputs().front()); - // for(const auto& mod : sub_mods) - // { - // if(not std::all_of(mod->begin(), mod->end(), &is_pointwise_op_supported_by_mlir)) - // { - // return false; - // } - // } + auto mod_arg = ins->module_inputs().front(); + auto supported_reshapes = reshaper_names(); + supported_reshapes.erase("slice"); + std::unordered_set builtins = {"@param", "@literal", "@return"}; + for(const auto i : iterator_for(*mod_arg)) + { + if(is_reduce(*i)) + { + if(not is_reduce_op_supported_by_mlir(*i)) + return false; + } + else if(i->name() == "pointwise") + { + if(not std::all_of(i->module_inputs().front()->begin(), + i->module_inputs().front()->end(), + &is_pointwise_op_supported_by_mlir)) + return false; + } + else if(not contains(reshaper_names(), i->name()) and not contains(builtins, i->name())) + { + return false; + } + } return true; } From 9540c788802f478695f8c99279561ddab9efd875 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 25 Jul 2024 14:54:32 +0000 Subject: [PATCH 093/145] add logic for is_reduce in header files --- src/targets/gpu/include/migraphx/gpu/mlir.hpp | 2 ++ src/targets/gpu/mlir.cpp | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/include/migraphx/gpu/mlir.hpp b/src/targets/gpu/include/migraphx/gpu/mlir.hpp index 8f359fcd38f..dd1e0fa31da 100644 --- a/src/targets/gpu/include/migraphx/gpu/mlir.hpp +++ b/src/targets/gpu/include/migraphx/gpu/mlir.hpp @@ -50,6 +50,8 @@ struct MIGRAPHX_GPU_EXPORT mlir_code_object std::vector prefill_values = {}; }; +MIGRAPHX_GPU_EXPORT bool is_reduce(const instruction& i); + MIGRAPHX_GPU_EXPORT mlir_code_object compile_mlir(const context& migraphx_ctx, module m, const std::vector& in_shapes, diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 5194cfc37d1..2d187176dc2 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -954,7 +954,7 @@ struct mlir_program std::string sym_name; }; -static bool is_reduce(const instruction& ins) { return contains(ins.name(), "reduce"); } +bool is_reduce(const instruction& ins) { return contains(ins.name(), "reduce"); } static void rewrite_reduce(module& m) { From 00fef222b40cb11495a5177a7008f51141ea45e2 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 25 Jul 2024 08:07:30 -0700 Subject: [PATCH 094/145] Format --- test/split_reduce.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index bea5a7c56ef..0c4b8a1f983 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -322,8 +322,7 @@ TEST_CASE(double_split_live) add_pointwise(p1, rm, "main:pointwise0", {rsum1}, single_pointwise("sqrt")); auto sqrtb = rm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), sqrt); - auto mul = add_pointwise( - p1, rm, "main:pointwise1", {inputs[0]}, squared()); + auto mul = add_pointwise(p1, rm, "main:pointwise1", {inputs[0]}, squared()); auto rsum2 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), mul); auto add = add_pointwise( @@ -351,8 +350,7 @@ TEST_CASE(double_split_live) const auto& axes) -> std::vector { auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), 
inputs[0]); - auto mul = add_pointwise( - p2, rm, "main:pointwise1", {inputs[0]}, squared()); + auto mul = add_pointwise(p2, rm, "main:pointwise1", {inputs[0]}, squared()); auto rsum2 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), mul); return {rsum1, rsum2}; From 207f94e58580f5105e5c5dd6b718abc54942a409 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 25 Jul 2024 19:46:13 +0000 Subject: [PATCH 095/145] add TODO --- src/targets/gpu/fuse_mlir.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 442a2dc8206..dff63471105 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -476,6 +476,7 @@ struct find_mlir_split_reduce auto matcher() const { auto dot_or_conv = match::name("gpu::mlir_op"); + // TODO: Handle reshapes inbetween return mlir_split_reduce()(match::any_of[match::inputs()](dot_or_conv.bind("gemm"))); } From f022edb429c960ee9c972573d3d4f0a45a589af1 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 25 Jul 2024 19:57:24 +0000 Subject: [PATCH 096/145] add assert --- src/targets/gpu/jit/mlir.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index 21f3d938dde..f51eb8a27e3 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -206,7 +206,8 @@ struct mlir_compiler : compiler const std::vector& mod_args) -> instruction_ref { if(op.name() == "reshape") { - // TODO: Add proper lowering for the reshape op + // TODO: Add support for non-standard shapes + assert(inputs.front()->get_shape().standard()); return insert_mod.insert_instruction( insert_loc, make_op("reshape_lazy", From 244e62ef163bcca80716fee3a066eaa243cd5e35 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 26 Jul 2024 14:23:30 +0000 Subject: [PATCH 097/145] remove else --- src/targets/gpu/jit/mlir.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index f51eb8a27e3..88cd48c339d 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -228,9 +228,8 @@ struct mlir_compiler : compiler return insert_mod.insert_instruction( insert_loc, make_op("gpu::contiguous"), contiguous_inputs); } - else - return insert_mod.insert_instruction( - insert_loc, op, inputs, mod_args); + return insert_mod.insert_instruction( + insert_loc, op, inputs, mod_args); }) .front(); } From e4c9eb9e58f590de169517a642703e1a1020dd5e Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 26 Jul 2024 14:38:59 +0000 Subject: [PATCH 098/145] remove else --- src/targets/gpu/jit/mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index 88cd48c339d..9aa9ba36ac0 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -215,7 +215,7 @@ struct mlir_compiler : compiler op.compute_shape(to_shapes(inputs)).lens()}}), inputs); } - else if(op.name() == "contiguous") + if(op.name() == "contiguous") { auto contiguous_alloc = insert_mod.insert_instruction( insert_loc, From a1c5ad772beca9a991eebb658ddd9b34933f1c01 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 30 Jul 2024 12:37:44 +0000 Subject: [PATCH 099/145] use mlir for the reshapes --- src/targets/gpu/jit/mlir.cpp | 80 +++++++----------------------------- 1 file changed, 14 insertions(+), 66 deletions(-) diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp 
index 9aa9ba36ac0..93091a68739 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -95,24 +95,23 @@ struct mlir_compiler : compiler *smod, {migraphx::eliminate_contiguous{"contiguous"}, migraphx::dead_code_elimination{}}); auto input_args = ins->inputs(); + // remove alloc buffer input_args.pop_back(); - auto split_ins_pw = std::prev(pointwise_ins); - std::array mod_splits; - if(split_ins_pw == gemm_like_ins) - mod_splits = smod->split(input_args, {gemm_like_ins}, {}); - else - mod_splits = smod->split(input_args, {gemm_like_ins}, {split_ins_pw}); + auto split_ins = std::prev(pointwise_ins); + std::array mod_splits; + mod_splits = smod->split(input_args, {split_ins}); auto dot_mlir_inputs = to_shapes(mod_splits[0].inputs); + // add alloc for the gemm output dot_mlir_inputs.push_back(mod_splits[0].mod.get_output_shapes().front()); mlir_code_object cop1 = compile_mlir(ctx, mod_splits[0].mod, dot_mlir_inputs, solution); - auto pw_shapes = to_shapes(mod_splits[2].inputs); - pw_shapes.push_back(mod_splits[2].mod.get_output_shapes().front()); + auto pw_shapes = to_shapes(mod_splits[1].inputs); + pw_shapes.push_back(mod_splits[1].mod.get_output_shapes().front()); assert(pw_shapes.back() == ins->get_shape()); - auto pw_mod = create_pointwise_module(&mod_splits[2].mod); + auto pw_mod = create_pointwise_module(&mod_splits[1].mod); auto cop2 = compile_pointwise(ctx, pw_shapes, &pw_mod); std::vector cops = {cop1, mlir_code_object{any_cast(cop2)}}; - return insert(cops, mod_splits, ins, gemm_like_ins, split_ins_pw); + return insert(cops, mod_splits, ins, split_ins); } return insert(compile_mlir(ctx, *smod, to_shapes(ins->inputs()), solution)); } @@ -137,10 +136,9 @@ struct mlir_compiler : compiler } compiler_replace insert(const std::vector& mcos, - const std::array& mods, + const std::array& mods, instruction_ref precompile_ins, - instruction_ref split_ins, - instruction_ref split_pw) const + instruction_ref split_ins) const { std::vector cobjs(mcos.size()); std::transform( @@ -181,60 +179,10 @@ struct mlir_compiler : compiler } return i; }); - auto mlir = + auto mlir_ins = insert_mlir(m, ins, any_cast(ops[0]), dot_inputs_updated); - auto reshape_ins = mlir; - if(mods[1].mod.size() > 0) - { - assert(contains(mods[1].inputs, split_ins)); - (void)(split_ins); - assert(mods[1].mod.get_parameters().size() == 1); - auto param_ins = - std::find_if(mods[1].mod.begin(), mods[1].mod.end(), [](const auto& i) { - return i.name() == "@param"; - }); - inputs_rep_map[param_ins] = mlir; - reshape_ins = - m.insert_instructions( - ins, - &mods[1].mod, - &inputs_rep_map, - [](module& insert_mod, - instruction_ref insert_loc, - const operation& op, - const std::vector& inputs, - const std::vector& mod_args) -> instruction_ref { - if(op.name() == "reshape") - { - // TODO: Add support for non-standard shapes - assert(inputs.front()->get_shape().standard()); - return insert_mod.insert_instruction( - insert_loc, - make_op("reshape_lazy", - {{"dims", - op.compute_shape(to_shapes(inputs)).lens()}}), - inputs); - } - if(op.name() == "contiguous") - { - auto contiguous_alloc = insert_mod.insert_instruction( - insert_loc, - make_op( - "hip::allocate", - {{"shape", - to_value(op.compute_shape(to_shapes(inputs)))}})); - auto contiguous_inputs = inputs; - contiguous_inputs.push_back(contiguous_alloc); - return insert_mod.insert_instruction( - insert_loc, make_op("gpu::contiguous"), contiguous_inputs); - } - return insert_mod.insert_instruction( - insert_loc, op, inputs, mod_args); - }) - .front(); - } - 
auto pwm = mods[2]; - pwm.replace(split_pw, reshape_ins); + auto pwm = mods[1]; + pwm.replace(split_ins, mlir_ins); auto pw_inputs = pwm.inputs; pw_inputs.push_back(ins->inputs().back()); std::vector pw_inputs_updated; From ca11ca477aa39a6aaa705858df2a5c3ea66cbd3b Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 30 Jul 2024 12:39:13 +0000 Subject: [PATCH 100/145] fuse reshapes with dot --- src/targets/gpu/jit/mlir.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index 93091a68739..a5da7772408 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -91,9 +91,6 @@ struct mlir_compiler : compiler if(gemm_like_ins != smod->end() and pointwise_ins != smod->end() and not is_module_fusible(*smod, ctx, solution)) { - migraphx::run_passes( - *smod, - {migraphx::eliminate_contiguous{"contiguous"}, migraphx::dead_code_elimination{}}); auto input_args = ins->inputs(); // remove alloc buffer input_args.pop_back(); From 9a7aa0b9c15b0bcb697ed49d578555475196a428 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 30 Jul 2024 12:41:44 +0000 Subject: [PATCH 101/145] remove header --- src/targets/gpu/jit/mlir.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index a5da7772408..65d1fc3a9e5 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -28,9 +28,6 @@ #include #include #include -#include -#include -#include #include #include #include From d3ab2affd77a1f075c8bcfba42e22c49e322850d Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 30 Jul 2024 12:43:24 +0000 Subject: [PATCH 102/145] remove changes for module split --- src/module.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index cb80bbb07d5..0317a8e1737 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -914,10 +914,7 @@ generic_split(const module& m, splits.end(), std::back_inserter(outputs), [&](instruction_ref ins) { return map_ins1.at(ins); }); - if(not outputs.empty()) - { - m1.add_return(outputs); - } + m1.add_return(outputs); std::vector instructions2; for(auto ins : iterator_for(m)) From 83fd1607ac35454273a1d59fe96fcb17b4a5797f Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 30 Jul 2024 13:29:31 +0000 Subject: [PATCH 103/145] flatten outputs --- src/targets/gpu/code_object_op.cpp | 2 +- src/targets/gpu/mlir.cpp | 10 +++++++++- test/gpu/fuse_mlir.cpp | 1 - 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/targets/gpu/code_object_op.cpp b/src/targets/gpu/code_object_op.cpp index 3f640e59d63..e78316ab139 100644 --- a/src/targets/gpu/code_object_op.cpp +++ b/src/targets/gpu/code_object_op.cpp @@ -40,7 +40,7 @@ shape code_object_op::compute_shape(std::vector inputs) const std::transform(einputs.begin(), einputs.end(), einputs.begin(), [](const shape& s) { return s.normalize_standard(); }); - if(einputs != flatten(inputs)) + if(flatten(einputs) != flatten(inputs)) MIGRAPHX_THROW("Input shapes have changed: [" + to_string_range(einputs) + "] -> [" + to_string_range(inputs) + "]"); return output; diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 94badfe5bbd..7256109701f 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -1024,7 +1024,15 @@ mlir_code_object compile_mlir(const context& migraphx_ctx, auto co = mp.compile(solution); co.expected_inputs = in_shapes; - co.output = m.get_output_shapes().front(); + auto out_shapes = 
m.get_output_shapes(); + if(out_shapes.size() == 1) + { + co.output = m.get_output_shapes().front(); + } + else + { + co.output = shape{out_shapes}; + } mlir_code_object mco; mco.cop = co; size_t num_prefill_args = mlirGetNumPrefillArgs(mp.mmodule.get()); diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index f66a1644f01..21f824c4ce3 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -170,7 +170,6 @@ TEST_CASE(multi_use_dot_trans_add_pooling_sub) mm->add_return({sub}); } run_pass(p1); - p1.debug_print(); migraphx::program p2; { auto* mm = p2.get_main_module(); From 662a29dd334166ae3d76a2c367c1d99c941f44ff Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 30 Jul 2024 16:02:57 +0000 Subject: [PATCH 104/145] disable test --- test/gpu/fuse_mlir.cpp | 2 +- test/verify/test_layernorm.cpp | 37 +++++++++++++++++----------------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 21f824c4ce3..f9919626c8a 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -293,7 +293,7 @@ TEST_CASE(add_dot) auto y = mm->add_parameter("y", s); auto fused = add_mlir(p2, - "main:pointwise0:mlir_dot2", + "main:pointwise0:mlir_dot4", {x, y, b}, {"x0", "x1", "x2"}, [=](auto* pm, const auto& inputs) { diff --git a/test/verify/test_layernorm.cpp b/test/verify/test_layernorm.cpp index eed0f59a9f4..5dc4ac16850 100644 --- a/test/verify/test_layernorm.cpp +++ b/test/verify/test_layernorm.cpp @@ -177,21 +177,22 @@ struct test_layernorm_triadd_large : verify_program } }; -struct test_add_layernorm_add_gemm_nonstd : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto* mm = p.get_main_module(); - auto s = - migraphx::shape::from_permutation(migraphx::shape::float_type, {8, 1, 16}, {1, 2, 0}); - auto x = mm->add_parameter("x", s); - auto y = mm->add_parameter("y", s); - auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {8, 16, 64}}); - auto add = mm->add_instruction(migraphx::make_op("add"), x, y); - auto layernorm_ins = add_layernorm(*mm, add, s.lens()); - mm->add_instruction(migraphx::make_op("dot"), layernorm_ins, z); - return p; - } - std::string section() const { return "gemm"; } -}; +// struct test_add_layernorm_add_gemm_nonstd : verify_program +// { +// migraphx::program create_program() const +// { +// migraphx::program p; +// auto* mm = p.get_main_module(); +// auto s = +// migraphx::shape::from_permutation(migraphx::shape::float_type, {8, 1, 16}, {1, 2, +// 0}); +// auto x = mm->add_parameter("x", s); +// auto y = mm->add_parameter("y", s); +// auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {8, 16, +// 64}}); auto add = mm->add_instruction(migraphx::make_op("add"), x, y); auto +// layernorm_ins = add_layernorm(*mm, add, s.lens()); +// mm->add_instruction(migraphx::make_op("dot"), layernorm_ins, z); +// return p; +// } +// std::string section() const { return "gemm"; } +// }; From d4dd7af06639402b1b5ccd8e1da025ac761703d3 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 30 Jul 2024 16:39:42 +0000 Subject: [PATCH 105/145] remove TODO --- src/targets/gpu/fuse_mlir.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index dff63471105..7b9a003405f 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -504,7 +504,6 @@ struct find_mlir_split_reduce const operation& op, const std::vector& inputs, const 
std::vector& mod_args) { - // todo handle broadcasted literals inside reduce mod if(op.name() == "pointwise") { for(const auto& skip_param : inputs) @@ -516,8 +515,7 @@ struct find_mlir_split_reduce // pointwise inside split_fused_reduce } } - auto sub_pm = mod_args.front(); - // todo: handle literals inside pointwise + auto sub_pm = mod_args.front(); auto param_map_2 = create_param_map_with_literals( &main_mod, sub_pm, op.compute_shape(to_shapes(inputs), mod_args)); for(const auto& i : param_map_2) From c1cba50b7d7dfcdefc6554f8aaca45d60cc4a3d5 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 30 Jul 2024 10:02:00 -0700 Subject: [PATCH 106/145] Update TODO --- src/split_reduce.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 3f94a2500d7..91bdfc9924f 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -185,7 +185,7 @@ void split_reduce::apply(module_pass_manager& mpm) const if(splits.empty()) continue; // Only use split reduce with float for now - // TODO: Support half and other data types + // TODO: Support other data types if(not std::all_of(splits.begin(), splits.end(), [](instruction_ref split) { return contains({shape::float_type, shape::half_type}, split->get_shape().type()); })) From 805793dd1a413f84f97fd708afaaa97f82d4d9d2 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 30 Jul 2024 17:11:11 +0000 Subject: [PATCH 107/145] add verify test --- test/include/layernorm.hpp | 64 ++++++++++++++++++++ test/verify/test_conv_add_layernorm_conv.cpp | 56 +++++++++++++++++ test/verify/test_layernorm.cpp | 38 +----------- 3 files changed, 121 insertions(+), 37 deletions(-) create mode 100644 test/include/layernorm.hpp create mode 100644 test/verify/test_conv_add_layernorm_conv.cpp diff --git a/test/include/layernorm.hpp b/test/include/layernorm.hpp new file mode 100644 index 00000000000..deaf0a4fa1c --- /dev/null +++ b/test/include/layernorm.hpp @@ -0,0 +1,64 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ +#include +#include +#include +#include + +inline migraphx::instruction_ref add_layernorm(migraphx::module& m, + migraphx::instruction_ref x, + const std::vector& dims, + float eps = 0e-12f) +{ + auto mgx_type = x->get_shape().type(); + auto scale = m.add_parameter("scale", migraphx::shape{mgx_type, {dims.back()}}); + auto bias = m.add_parameter("bias", migraphx::shape{mgx_type, {dims.back()}}); + + auto epsilon = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {eps}}); + auto exponent = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {2.0f}}); + + auto mean = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), x); + auto mean_mbcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean); + auto sub = m.add_instruction(migraphx::make_op("sub"), x, mean_mbcast); + auto exponent_mbcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent); + auto pow = m.add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast); + auto var = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), pow); + auto epsilon_mbcast = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", var->get_shape().lens()}}), epsilon); + + auto add_epsilon = m.add_instruction(migraphx::make_op("add"), var, epsilon_mbcast); + auto sqrt = m.add_instruction(migraphx::make_op("sqrt"), add_epsilon); + auto sqrt_mbcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), sqrt); + auto div = m.add_instruction(migraphx::make_op("div"), sub, sqrt_mbcast); + auto scale_mbcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), scale); + auto mul = m.add_instruction(migraphx::make_op("mul"), div, scale_mbcast); + + auto bias_mbcast = + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), bias); + return m.add_instruction(migraphx::make_op("add"), mul, bias_mbcast); +} diff --git a/test/verify/test_conv_add_layernorm_conv.cpp b/test/verify/test_conv_add_layernorm_conv.cpp new file mode 100644 index 00000000000..a4bb972eb94 --- /dev/null +++ b/test/verify/test_conv_add_layernorm_conv.cpp @@ -0,0 +1,56 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ +#include "verify_program.hpp" +#include +#include +#include +#include + +template <migraphx::shape::type_t DType> +struct test_conv_add_layernorm_conv : verify_program<test_conv_add_layernorm_conv<DType>> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); + auto weights1 = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 1, 1}}); + auto bias_literal = + migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}}; + auto bias = mm->add_literal(bias_literal); + auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), input, weights1); + auto bcast_bias = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv1->get_shape().lens()}}), + bias); + auto bias_add = mm->add_instruction(migraphx::make_op("add"), conv1, bcast_bias); + add_layernorm(*mm, bias_add, {4, 4, 3, 3}); + return p; + } + std::string section() const { return "conv"; } +}; + +template struct test_conv_add_layernorm_conv<migraphx::shape::float_type>; +template struct test_conv_add_layernorm_conv<migraphx::shape::half_type>; +template struct test_conv_add_layernorm_conv<migraphx::shape::fp8e4m3fnuz_type>; diff --git a/test/verify/test_layernorm.cpp b/test/verify/test_layernorm.cpp index 5dc4ac16850..e1547d035c8 100644 --- a/test/verify/test_layernorm.cpp +++ b/test/verify/test_layernorm.cpp @@ -27,43 +27,7 @@ #include #include #include - -migraphx::instruction_ref add_layernorm(migraphx::module& m, - migraphx::instruction_ref x, - std::vector dims, - float eps = 1e-12f) -{ - auto mgx_type = x->get_shape().type(); - auto scale = m.add_parameter("scale", migraphx::shape{mgx_type, {dims.back()}}); - auto bias = m.add_parameter("bias", migraphx::shape{mgx_type, {dims.back()}}); - - auto epsilon = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {eps}}); - auto exponent = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {2.0f}}); - - auto mean = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), x); - auto mean_mbcast = - m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean); - auto sub = m.add_instruction(migraphx::make_op("sub"), x, mean_mbcast); - auto exponent_mbcast = - m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent); - auto pow = m.add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast); - auto var = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), pow); - auto epsilon_mbcast = m.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {dims.at(0), dims.at(1), 1}}}), epsilon); - - auto add_epsilon = m.add_instruction(migraphx::make_op("add"), var, epsilon_mbcast); - auto sqrt = m.add_instruction(migraphx::make_op("sqrt"), add_epsilon); - auto sqrt_mbcast = - m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), sqrt); - auto div = m.add_instruction(migraphx::make_op("div"), sub, sqrt_mbcast); - auto scale_mbcast = - m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), scale); - auto mul = m.add_instruction(migraphx::make_op("mul"), div, scale_mbcast); - - auto bias_mbcast = - m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), bias); - return m.add_instruction(migraphx::make_op("add"), mul, bias_mbcast); -} +#include struct test_layernorm : verify_program { From fd5a9a1bb9bb648df53696e07c08660097a19450 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 30 Jul 2024 19:35:56 +0000 Subject: [PATCH 108/145] increase reduce limit, disable rewrite_reduce to reduce_sum
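
Lowering split_size to 1 temporarily forces every eligible reduction through the split_fused_reduce path, find_splits now accepts reduce_mean alongside reduce_sum, and the find_reduce_mean rewrite is commented out so reduce_mean ops survive long enough to exercise that path. As a rough serial sketch of what a split reduce computes (the names r, s, and the std::accumulate loops are illustrative assumptions, not code from this patch):

    #include <algorithm>
    #include <numeric>
    #include <vector>

    int main()
    {
        std::size_t r = 10; // reduction length
        std::size_t s = 4;  // split_size
        std::vector<float> x(r, 1.0f);
        // Reduce ceil(r / s) independent groups, then reduce the partial
        // results, so no group has to synchronize with another.
        std::size_t groups = (r + s - 1) / s;
        std::vector<float> partial(groups, 0.0f);
        for(std::size_t g = 0; g < groups; ++g)
            partial[g] = std::accumulate(x.begin() + g * s,
                                         x.begin() + std::min((g + 1) * s, r),
                                         0.0f);
        float result = std::accumulate(partial.begin(), partial.end(), 0.0f);
        return result == 10.0f ? 0 : 1;
    }

The std::cout messages added to split_reduce::apply are temporary debugging traces.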
--- src/include/migraphx/split_reduce.hpp | 2 +- src/rewrite_reduce.cpp | 2 +- src/split_reduce.cpp | 8 +++++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/include/migraphx/split_reduce.hpp b/src/include/migraphx/split_reduce.hpp index 687f363b0d2..efbd88c0993 100644 --- a/src/include/migraphx/split_reduce.hpp +++ b/src/include/migraphx/split_reduce.hpp @@ -41,7 +41,7 @@ struct module_pass_manager; /// needing global synchronization. struct MIGRAPHX_EXPORT split_reduce { - std::size_t split_size = 8192; + std::size_t split_size = 1; std::string name() const { return "split_reduce"; } void apply(module_pass_manager& mpm) const; }; diff --git a/src/rewrite_reduce.cpp b/src/rewrite_reduce.cpp index 30834df3cca..abf85ee325a 100644 --- a/src/rewrite_reduce.cpp +++ b/src/rewrite_reduce.cpp @@ -145,7 +145,7 @@ struct find_reduce_mean void rewrite_reduce::apply(module& m) const { match::find_matches(m, find_softmax{}, find_reduce_mean_variance{}); - match::find_matches(m, find_reduce_mean{}); + // match::find_matches(m, find_reduce_mean{}); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 3f94a2500d7..90ceda1f95e 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -99,7 +99,7 @@ struct splitter // Only handle reduce_sum for now // TODO: Support other reduction types if(not std::all_of(result.begin(), result.end(), [](instruction_ref ins) { - return ins->name() == "reduce_sum"; + return ins->name() == "reduce_sum" or ins->name() == "reduce_mean"; })) return {}; if(result.size() < 2) @@ -183,13 +183,19 @@ void split_reduce::apply(module_pass_manager& mpm) const splitter s{rm}; auto splits = s.find_splits(); if(splits.empty()) + { + std::cout << "split are empty\n"; continue; + } // Only use split reduce with float for now // TODO: Support half and other data types if(not std::all_of(splits.begin(), splits.end(), [](instruction_ref split) { return contains({shape::float_type, shape::half_type}, split->get_shape().type()); })) + { + std::cout << "not supported now\n"; continue; + } auto v = ins->get_operator().to_value(); auto axes = v["axes"].to_vector(); From bd1eca32d1018b4282cc7c72802327a6e1aeb2d4 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 30 Jul 2024 14:03:36 -0700 Subject: [PATCH 109/145] Get correct data type for lane reductions
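
The accumulator type for a lane reduction was deduced from read applied to the first input alone; when read folds several inputs of mixed precision, the folded value can be wider than that first input, so the accumulation truncated. Deducing the type from the full call fixes it. A standalone illustration of the difference, with float and double standing in for the kernel's vectorized types (this is not the kernel code itself):

    #include <type_traits>

    int main()
    {
        float f  = 1.0f;
        double d = 2.0;
        // Fold over every input, the way the reduction's read callback does.
        auto read = [](auto x, auto... xs) { return (x * ... * xs); };
        using narrow = decltype(read(f));    // deduced from x alone: float
        using wide   = decltype(read(f, d)); // deduced from x and xs...: double
        static_assert(not std::is_same_v<narrow, wide>, "mixed inputs widen the type");
        return 0;
    }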
--- src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp index 22cff3e4acd..cc88b91a3b6 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp @@ -667,7 +667,7 @@ struct lane template <class Op, class T, class Read, class N, class U, class... Us> __device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const { - using type = remove_reference_t<decltype(read(x(0, _c<0>)))>; + using type = remove_reference_t<decltype(read(x(0, _c<0>), xs(0, _c<0>)...))>; type r = type(init); for(index_int j = 0; j < n; j++) { From 631127acad152df0579b4678ca2b010a07269b22 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 30 Jul 2024 22:16:31 +0000 Subject: [PATCH 110/145] enable test again --- test/verify/test_layernorm.cpp | 37 +++++++++++++++++----------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/test/verify/test_layernorm.cpp b/test/verify/test_layernorm.cpp index e1547d035c8..969ed373b02 100644 --- a/test/verify/test_layernorm.cpp +++ b/test/verify/test_layernorm.cpp @@ -141,22 +141,21 @@ struct test_layernorm_triadd_large : verify_program } }; -// struct test_add_layernorm_add_gemm_nonstd : verify_program -// { -// migraphx::program create_program() const -// { -// migraphx::program p; -// auto* mm = p.get_main_module(); -// auto s = -// migraphx::shape::from_permutation(migraphx::shape::float_type, {8, 1, 16}, {1, 2, -// 0}); -// auto x = mm->add_parameter("x", s); -// auto y = mm->add_parameter("y", s); -// auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {8, 16, -// 64}}); auto add = mm->add_instruction(migraphx::make_op("add"), x, y); auto -// layernorm_ins = add_layernorm(*mm, add, s.lens()); -// mm->add_instruction(migraphx::make_op("dot"), layernorm_ins, z); -// return p; -// } -// std::string section() const { return "gemm"; } -// }; +struct test_add_layernorm_add_gemm_nonstd : verify_program<test_add_layernorm_add_gemm_nonstd> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = + migraphx::shape::from_permutation(migraphx::shape::float_type, {8, 1, 16}, {1, 2, 0}); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {8, 16, 64}}); + auto add = mm->add_instruction(migraphx::make_op("add"), x, y); + auto layernorm_ins = add_layernorm(*mm, add, s.lens()); + mm->add_instruction(migraphx::make_op("dot"), layernorm_ins, z); + return p; + } + std::string section() const { return "gemm"; } +}; From e82daf17d2f2d418a01a1637a182bd236cf0463a Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 31 Jul 2024 12:12:10 +0000 Subject: [PATCH 111/145] revert back split size --- src/include/migraphx/split_reduce.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/migraphx/split_reduce.hpp b/src/include/migraphx/split_reduce.hpp index efbd88c0993..687f363b0d2 100644 --- a/src/include/migraphx/split_reduce.hpp +++ b/src/include/migraphx/split_reduce.hpp @@ -41,7 +41,7 @@ struct module_pass_manager; /// needing global synchronization.
struct MIGRAPHX_EXPORT split_reduce { - std::size_t split_size = 1; + std::size_t split_size = 8192; std::string name() const { return "split_reduce"; } void apply(module_pass_manager& mpm) const; }; From eb4f26257323b653fd7150de6e5e81fce0485ebe Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 31 Jul 2024 12:12:45 +0000 Subject: [PATCH 112/145] add MIGRAPHX_EXPORT For the reaches --- src/include/migraphx/instruction.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/migraphx/instruction.hpp b/src/include/migraphx/instruction.hpp index c1a5b15e2af..3d02f64f7de 100644 --- a/src/include/migraphx/instruction.hpp +++ b/src/include/migraphx/instruction.hpp @@ -45,7 +45,7 @@ MIGRAPHX_EXPORT std::vector to_shapes(const std::vector& MIGRAPHX_EXPORT std::vector try_compute_shape(const operation& op, const std::vector& inputs); -bool reaches(instruction_ref start, instruction_ref end); +MIGRAPHX_EXPORT bool reaches(instruction_ref start, instruction_ref end); struct MIGRAPHX_EXPORT instruction { From 1ac328b63fa84662c94a2e5699bb011b90bd5ddc Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 31 Jul 2024 12:27:39 +0000 Subject: [PATCH 113/145] add test for the MLIR slow bench --- test/verify/test_conv_add_layernorm_conv.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/test/verify/test_conv_add_layernorm_conv.cpp b/test/verify/test_conv_add_layernorm_conv.cpp index a4bb972eb94..a912eba1d87 100644 --- a/test/verify/test_conv_add_layernorm_conv.cpp +++ b/test/verify/test_conv_add_layernorm_conv.cpp @@ -35,17 +35,19 @@ struct test_conv_add_layernorm_conv : verify_programadd_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); - auto weights1 = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 1, 1}}); - auto bias_literal = - migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}}; + auto input = mm->add_parameter("x", migraphx::shape{DType, {2, 4, 64, 64}}); + auto weights1 = mm->add_parameter("w", migraphx::shape{DType, {320, 4, 3, 3}}); + auto bias_literal = abs(migraphx::generate_literal(migraphx::shape{DType, {320}}, 1)); auto bias = mm->add_literal(bias_literal); - auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), input, weights1); + auto conv1 = mm->add_instruction( + migraphx::make_op("convolution", {{"padding", {1, 1, 1, 1}}}), input, weights1); auto bcast_bias = mm->add_instruction( migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv1->get_shape().lens()}}), bias); auto bias_add = mm->add_instruction(migraphx::make_op("add"), conv1, bcast_bias); - add_layernorm(*mm, bias_add, {4, 4, 3, 3}); + auto rsp_add = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {0, 32, -1}}}), bias_add); + add_layernorm(*mm, rsp_add, rsp_add->get_shape().lens()); return p; } std::string section() const { return "conv"; } From 68a8afb586fa2326212acc7240abb970e5a63d1f Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 31 Jul 2024 13:23:05 +0000 Subject: [PATCH 114/145] fix merge --- src/targets/gpu/fuse_mlir.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index aba385364de..2e37680f15f 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -439,13 +439,10 @@ struct find_mlir_fused_ops void apply(module_pass_manager& mpm, const match::matcher_result& r) const { - auto pw_ins = r.result; auto pw_ins = r.result; auto gemm_based_op = r.instructions["gemm_based_op"]; auto x_ins = 
r.instructions["x"]; // input to pointwise after reshaper op stream auto* pm = pw_ins->module_inputs().front(); - auto x_ins = r.instructions["x"]; // input to pointwise after reshaper op stream - auto* pm = pw_ins->module_inputs().front(); auto names = pm->get_parameter_names(); std::sort(names.begin(), names.end()); module_ref mm = mpm.create_module("mlir_" + pm->name()); @@ -480,13 +477,10 @@ struct find_mlir_fused_ops mm->add_return(return_vals); std::vector inputs; - std::copy_if(pw_ins->inputs().begin(), - pw_ins->inputs().end(), std::copy_if(pw_ins->inputs().begin(), pw_ins->inputs().end(), std::back_inserter(inputs), [&](auto input) { return input != x_ins; }); - [&](auto input) { return input != x_ins; }); inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end()); if(gemm_has_multi_outs) { From 7e83db367d70934b41a9b44db11f5b461be7467e Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 31 Jul 2024 13:44:21 +0000 Subject: [PATCH 115/145] fix unit-test --- test/gpu/fuse_mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 21f824c4ce3..f9919626c8a 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -293,7 +293,7 @@ TEST_CASE(add_dot) auto y = mm->add_parameter("y", s); auto fused = add_mlir(p2, - "main:pointwise0:mlir_dot2", + "main:pointwise0:mlir_dot4", {x, y, b}, {"x0", "x1", "x2"}, [=](auto* pm, const auto& inputs) { From c5b70b7e7f31e8f9e1803e9987f0021e08d13169 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 31 Jul 2024 13:50:41 +0000 Subject: [PATCH 116/145] merge fixes --- src/split_reduce.cpp | 69 +------------------------------------------- 1 file changed, 1 insertion(+), 68 deletions(-) diff --git a/src/split_reduce.cpp b/src/split_reduce.cpp index 379ecc29065..91bdfc9924f 100644 --- a/src/split_reduce.cpp +++ b/src/split_reduce.cpp @@ -24,7 +24,6 @@ */ #include #include -#include #include #include #include @@ -69,12 +68,6 @@ struct split_fused_reduce if(result.size() == 1) return result.front(); return shape{result}; - - auto result = - sm->compute_shapes(inputs, {.name = name(), .strict_type = true, .strict_lens = true}); - if(result.size() == 1) - return result.front(); - return shape{result}; } std::string name() const { return "split_fused_reduce"; } @@ -106,7 +99,7 @@ struct splitter // Only handle reduce_sum for now // TODO: Support other reduction types if(not std::all_of(result.begin(), result.end(), [](instruction_ref ins) { - return ins->name() == "reduce_sum" or ins->name() == "reduce_mean"; + return ins->name() == "reduce_sum"; })) return {}; if(result.size() < 2) @@ -150,41 +143,6 @@ struct splitter std::optional dom = std::nullopt; }; -} // namespace - std::vector find_alive(const std::vector& splits) - { - std::vector result; - bool stop = false; - liveness(*rm, [&](auto rins, const auto& live_set) { - if(stop) - return; - if(rins == rm->begin()) - return; - // We want to know what instructions are live after the split instruction - auto ins = instruction::get_output_alias(std::prev(rins)); - if(not contains(splits, ins)) - return; - std::copy_if(live_set.begin(), - live_set.end(), - std::back_inserter(result), - [&](instruction_ref live) { - if(live->name() == "@param") - return false; - if(contains(splits, live)) - return false; - if(splits.size() > 1 and none_of(splits, [&](instruction_ref split) { - return this->strictly_dominate(live, split); - })) - return false; - return true; - }); - stop = true; - }); - return result; - } - - std::optional dom 
= std::nullopt; -}; } // namespace static std::string assign_op(const std::vector& splits) @@ -224,28 +182,17 @@ void split_reduce::apply(module_pass_manager& mpm) const continue; splitter s{rm}; auto splits = s.find_splits(); - splitter s{rm}; - auto splits = s.find_splits(); if(splits.empty()) - { - std::cout << "split are empty\n"; continue; - } // Only use split reduce with float for now // TODO: Support other data types - // TODO: Support other data types if(not std::all_of(splits.begin(), splits.end(), [](instruction_ref split) { return contains({shape::float_type, shape::half_type}, split->get_shape().type()); - return contains({shape::float_type, shape::half_type}, split->get_shape().type()); })) - { - std::cout << "not supported now\n"; continue; - } auto v = ins->get_operator().to_value(); auto axes = v["axes"].to_vector(); - auto alive = s.find_alive(splits); auto alive = s.find_alive(splits); std::array mods; @@ -285,20 +232,6 @@ void split_reduce::apply(module_pass_manager& mpm) const }); } - mods[1].replace(splits, split_reduce_each); - std::vector split_reduce_each; - if(splits.size() == 1) - { - split_reduce_each = {split_reduce}; - } - else - { - transform(range(splits.size()), std::back_inserter(split_reduce_each), [&](auto i) { - return mpm.get_module().insert_instruction( - ins, make_op("get_tuple_elem", {{"index", i}}), split_reduce); - }); - } - mods[1].replace(splits, split_reduce_each); auto replaced = insert_module_inline(mpm.get_module(), ins, mods[1]); assert(replaced.size() == 1); From ca7df9244c53b6236064cfd25a8ba35825090ddf Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 31 Jul 2024 14:20:17 +0000 Subject: [PATCH 117/145] fix return bug enable rewrite_reduce --- src/rewrite_reduce.cpp | 2 +- src/targets/gpu/fuse_mlir.cpp | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/rewrite_reduce.cpp b/src/rewrite_reduce.cpp index abf85ee325a..30834df3cca 100644 --- a/src/rewrite_reduce.cpp +++ b/src/rewrite_reduce.cpp @@ -145,7 +145,7 @@ struct find_reduce_mean void rewrite_reduce::apply(module& m) const { match::find_matches(m, find_softmax{}, find_reduce_mean_variance{}); - // match::find_matches(m, find_reduce_mean{}); + match::find_matches(m, find_reduce_mean{}); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 7b9a003405f..f68edf3806d 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -494,7 +494,6 @@ struct find_mlir_split_reduce std::unordered_map param_map; param_map[gemm_ins] = std::prev(mm->end())->inputs().front(); bool gemm_has_multi_outs = gemm_ins->outputs().size() > 1; - mm->remove_instruction(std::prev(mm->end())); auto return_vals = mm->fuse(*rm, reduce_ins->inputs(), From 9f56e6a2481779e1c4c173e189e19f88874d2948 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 31 Jul 2024 14:56:56 +0000 Subject: [PATCH 118/145] fix wiring --- src/targets/gpu/fuse_mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index f68edf3806d..c390cd24382 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -492,7 +492,7 @@ struct find_mlir_split_reduce mm->add_instructions(gemm_old_mm); mm->set_bypass(); std::unordered_map param_map; - param_map[gemm_ins] = std::prev(mm->end())->inputs().front(); + param_map[gemm_ins] = std::prev(mm->end()); bool gemm_has_multi_outs = gemm_ins->outputs().size() > 1; auto return_vals = mm->fuse(*rm, From 
f1550b189f5056bfaf9ccaee933295896f31034f Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 31 Jul 2024 16:36:04 +0000 Subject: [PATCH 119/145] fix output shape --- src/targets/gpu/jit/mlir.cpp | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index 65d1fc3a9e5..5170e941191 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -85,6 +85,8 @@ struct mlir_compiler : compiler // check if (a) module is fused (b) contains a "gemm/conv" instruction and (c) // perfConfig can not allow fused module + std::cout << "Compiling solution : \n"; + solution.debug_print(); if(gemm_like_ins != smod->end() and pointwise_ins != smod->end() and not is_module_fusible(*smod, ctx, solution)) { @@ -99,7 +101,16 @@ struct mlir_compiler : compiler dot_mlir_inputs.push_back(mod_splits[0].mod.get_output_shapes().front()); mlir_code_object cop1 = compile_mlir(ctx, mod_splits[0].mod, dot_mlir_inputs, solution); auto pw_shapes = to_shapes(mod_splits[1].inputs); - pw_shapes.push_back(mod_splits[1].mod.get_output_shapes().front()); + if(mod_splits[1].mod.get_output_shapes().size() == 1) + { + pw_shapes.push_back(mod_splits[1].mod.get_output_shapes().front()); + } + else + { + pw_shapes.push_back(shape{mod_splits[1].mod.get_output_shapes()}); + } + mod_splits[1].mod.debug_print(); + ins->debug_print(); assert(pw_shapes.back() == ins->get_shape()); auto pw_mod = create_pointwise_module(&mod_splits[1].mod); auto cop2 = compile_pointwise(ctx, pw_shapes, &pw_mod); From f276db52b0da083da6401e13d82c0b87f66460df Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 31 Jul 2024 17:15:33 +0000 Subject: [PATCH 120/145] remove debug prints --- src/targets/gpu/jit/mlir.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index 5170e941191..6565ba056a1 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -85,8 +85,6 @@ struct mlir_compiler : compiler // check if (a) module is fused (b) contains a "gemm/conv" instruction and (c) // perfConfig can not allow fused module - std::cout << "Compiling solution : \n"; - solution.debug_print(); if(gemm_like_ins != smod->end() and pointwise_ins != smod->end() and not is_module_fusible(*smod, ctx, solution)) { @@ -109,8 +107,6 @@ struct mlir_compiler : compiler { pw_shapes.push_back(shape{mod_splits[1].mod.get_output_shapes()}); } - mod_splits[1].mod.debug_print(); - ins->debug_print(); assert(pw_shapes.back() == ins->get_shape()); auto pw_mod = create_pointwise_module(&mod_splits[1].mod); auto cop2 = compile_pointwise(ctx, pw_shapes, &pw_mod); From 207692001889aa0e785ad4dd774070ec0fe4d7bb Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 1 Aug 2024 12:47:56 +0000 Subject: [PATCH 121/145] add env flag for the reduce fusion --- src/targets/gpu/fuse_mlir.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index c390cd24382..3db7e95ac54 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -43,6 +43,7 @@ namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR); /** * @brief Declares a new MIGraphX environment variable which forces to generate @@ -845,13 +846,16 @@ void 
fuse_mlir::apply(module_pass_manager& mpm) const find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::fast)}); mpm.run_pass(dead_code_elimination{}); - - if(enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) + if(enabled(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION{})) { match::find_matches( mpm, find_mlir_split_reduce{.conv_mode = get_mode("fused_convolution", mlir_mode::fast), .dot_mode = get_mode("fused_dot", mlir_mode::fast)}); + } + + if(enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) + { match::find_matches(mpm, find_pointwise_mlir{}); } #else From 43a22e5a92c7c27f069f276761445ac68924d49d Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 1 Aug 2024 13:19:28 +0000 Subject: [PATCH 122/145] add doc --- docs/dev/env_vars.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/dev/env_vars.rst b/docs/dev/env_vars.rst index 6295593db0a..be8fddf5d89 100644 --- a/docs/dev/env_vars.rst +++ b/docs/dev/env_vars.rst @@ -278,6 +278,11 @@ Limits the number of solutions available to MLIR for tuning. Set to "1", "enable", "enabled", "yes", or "true" to use. Enable input fusions in MLIR. +.. envvar:: MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION + +Set to "1", "enable", "enabled", "yes", or "true" to use. +Enable reduction fusions in MLIR. + .. envvar:: MIGRAPHX_MLIR_ENABLE_SPLITK Set to "1", "enable", "enabled", "yes", or "true" to use. From a4d546d8230ecb7a6c934453ddf3b6aeb2e38091 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 1 Aug 2024 13:21:28 +0000 Subject: [PATCH 123/145] formatting --- test/verify/test_conv_add_layernorm_conv.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/verify/test_conv_add_layernorm_conv.cpp b/test/verify/test_conv_add_layernorm_conv.cpp index a912eba1d87..ffd1e3d0b4e 100644 --- a/test/verify/test_conv_add_layernorm_conv.cpp +++ b/test/verify/test_conv_add_layernorm_conv.cpp @@ -34,11 +34,11 @@ struct test_conv_add_layernorm_conv : verify_programadd_parameter("x", migraphx::shape{DType, {2, 4, 64, 64}}); auto weights1 = mm->add_parameter("w", migraphx::shape{DType, {320, 4, 3, 3}}); auto bias_literal = abs(migraphx::generate_literal(migraphx::shape{DType, {320}}, 1)); - auto bias = mm->add_literal(bias_literal); + auto bias = mm->add_literal(bias_literal); auto conv1 = mm->add_instruction( migraphx::make_op("convolution", {{"padding", {1, 1, 1, 1}}}), input, weights1); auto bcast_bias = mm->add_instruction( From c64d2eeeec451611720770d91ebe6239190e6e61 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 1 Aug 2024 13:25:20 +0000 Subject: [PATCH 124/145] fix cppcheck --- src/targets/gpu/include/migraphx/gpu/mlir.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/include/migraphx/gpu/mlir.hpp b/src/targets/gpu/include/migraphx/gpu/mlir.hpp index dd1e0fa31da..8acb1807e31 100644 --- a/src/targets/gpu/include/migraphx/gpu/mlir.hpp +++ b/src/targets/gpu/include/migraphx/gpu/mlir.hpp @@ -50,7 +50,7 @@ struct MIGRAPHX_GPU_EXPORT mlir_code_object std::vector prefill_values = {}; }; -MIGRAPHX_GPU_EXPORT bool is_reduce(const instruction& i); +MIGRAPHX_GPU_EXPORT bool is_reduce(const instruction& ins); MIGRAPHX_GPU_EXPORT mlir_code_object compile_mlir(const context& migraphx_ctx, module m, From 40325f9c10dc7c34c2ccd87c35cac5606f974b09 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 1 Aug 2024 13:44:20 +0000 Subject: [PATCH 125/145] update problem_key && jenkins --- Jenkinsfile | 2 +- src/targets/gpu/mlir.cpp | 11 ++++++++++- test/verify/test_conv_add_layernorm_conv.cpp | 9 +++++++-- 3 files changed, 18 insertions(+), 
4 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index eb247971ad9..e271908ba9f 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -144,7 +144,7 @@ rocmtest clang_debug: rocmnode('mi100+') { cmake_build -> } }, mlir_debug: rocmnode('mi100+') { cmake_build -> stage('MLIR Debug') { - withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1', 'MIGRAPHX_MLIR_ENABLE_SPLITK=1']) { + withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1', 'MIGRAPHX_MLIR_ENABLE_SPLITK=1', 'MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION=1', 'MIGRAPHX_ENABLE_SPLIT_REDUCE=1','MIGRAPHX_DISABLE_LAYERNORM_FUSION=1']) { def sanitizers = "undefined" // Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS. def debug_flags_cxx = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}" diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index f56db4f10f1..189b244b1d8 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -1144,7 +1144,16 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx, mp.set_gpu_properties(migraphx_ctx); mp.parse(m); auto tc = mp.get_tuning_config(exhaustive); - + std::string problem_config = tc.problem.to(); + for(const auto i : iterator_for(m)) + { + if(starts_with(i->name(), "@")) + { + continue; + } + problem_config += " " + i->name(); + } + tc.problem = problem_config; const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); static std::mutex mutex; if(trace) diff --git a/test/verify/test_conv_add_layernorm_conv.cpp b/test/verify/test_conv_add_layernorm_conv.cpp index ffd1e3d0b4e..4d2d9c07056 100644 --- a/test/verify/test_conv_add_layernorm_conv.cpp +++ b/test/verify/test_conv_add_layernorm_conv.cpp @@ -36,7 +36,8 @@ struct test_conv_add_layernorm_conv : verify_programadd_parameter("x", migraphx::shape{DType, {2, 4, 64, 64}}); - auto weights1 = mm->add_parameter("w", migraphx::shape{DType, {320, 4, 3, 3}}); + auto weights1 = mm->add_parameter("w1", migraphx::shape{DType, {320, 4, 3, 3}}); + auto weights2 = mm->add_parameter("w2", migraphx::shape{DType, {4, 320, 3, 3}}); auto bias_literal = abs(migraphx::generate_literal(migraphx::shape{DType, {320}}, 1)); auto bias = mm->add_literal(bias_literal); auto conv1 = mm->add_instruction( @@ -47,7 +48,11 @@ struct test_conv_add_layernorm_conv : verify_programadd_instruction(migraphx::make_op("add"), conv1, bcast_bias); auto rsp_add = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {0, 32, -1}}}), bias_add); - add_layernorm(*mm, rsp_add, rsp_add->get_shape().lens()); + auto layernorm = add_layernorm(*mm, rsp_add, rsp_add->get_shape().lens()); + auto layernorm_rsp = mm->add_instruction( + migraphx::make_op("reshape", {{"dims", {0, 320, 64, 64}}}), layernorm); + mm->add_instruction( + migraphx::make_op("convolution", {{"padding", {1, 1, 1, 1}}}), layernorm_rsp, weights2); return p; } std::string section() const { return "conv"; } From 67ea3c66048af5f4bc084f9e241e8086d235e862 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Thu, 1 Aug 2024 17:30:40 +0000 Subject: [PATCH 126/145] change EPS --- test/include/layernorm.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/include/layernorm.hpp b/test/include/layernorm.hpp index deaf0a4fa1c..ed8cc008bd7 100644 --- a/test/include/layernorm.hpp +++ b/test/include/layernorm.hpp @@ -29,7 +29,7 @@ inline migraphx::instruction_ref 
add_layernorm(migraphx::module& m, migraphx::instruction_ref x, const std::vector& dims, - float eps = 0e-12f) + float eps = 1e-12f) { auto mgx_type = x->get_shape().type(); auto scale = m.add_parameter("scale", migraphx::shape{mgx_type, {dims.back()}}); From 5b51efd90cf38720e9d2da0e9cfd466557fd4c94 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 12 Aug 2024 18:24:12 +0000 Subject: [PATCH 127/145] merge fixes --- test/gpu/fuse_mlir.cpp | 127 ----------------------------------------- 1 file changed, 127 deletions(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index faa9e16adae..102bfdc5ba3 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -277,133 +277,6 @@ TEST_CASE(dot_multi_use_trans_add_pooling_sub) EXPECT(p1.sort() == p2.sort()); } -TEST_CASE(multi_use_dot_trans_add_pooling_sub) -{ - migraphx::shape s1{migraphx::shape::float_type, {1, 1, 4, 5}}; - migraphx::shape s2{migraphx::shape::float_type, {1, 1, 5, 5}}; - migraphx::program p1; - { - auto* mm = p1.get_main_module(); - auto a = mm->add_parameter("a", s1); - auto b = mm->add_parameter("b", s2); - auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); - auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b); - auto dot_trans = mm->add_instruction( - migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dot); - auto add = add_pointwise(p1, "main:pointwise0", {dot_trans, x}, single_pointwise("add")); - auto pooling = - mm->add_instruction(migraphx::make_op("pooling", - {{"mode", migraphx::op::pooling_mode::lpnorm}, - {"padding", {0, 0, 0, 1}}, - {"stride", {1, 1}}, - {"lengths", {2, 1}}, - {"lp_order", 2}}), - add); - auto sub = add_pointwise(p1, "main:pointwise1", {dot, pooling}, single_pointwise("sub")); - mm->add_return({sub}); - } - run_pass(p1); - migraphx::program p2; - { - auto* mm = p2.get_main_module(); - auto a = mm->add_parameter("a", s1); - auto b = mm->add_parameter("b", s2); - auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); - auto fused = add_mlir( - p2, - "mlir_main:pointwise0", - {x, a, b}, - {"x2", "y0", "y1"}, - [=](auto* pm, const auto& inputs) { - auto dot = pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]); - auto dot_trans = pm->add_instruction( - migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dot); - - auto add = pm->add_instruction(migraphx::make_op("add"), dot_trans, inputs[0]); - return std::make_tuple(dot, std::vector{dot, add}); - }); - auto fused_dot_add = - mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); - auto pooling = - mm->add_instruction(migraphx::make_op("pooling", - {{"mode", migraphx::op::pooling_mode::lpnorm}, - {"padding", {0, 0, 0, 1}}, - {"stride", {1, 1}}, - {"lengths", {2, 1}}, - {"lp_order", 2}}), - fused_dot_add); - auto dot = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); - auto sub = add_pointwise(p2, "main:pointwise1", {dot, pooling}, single_pointwise("sub")); - mm->add_return({sub}); - } - EXPECT(p1.sort() == p2.sort()); -} - -TEST_CASE(dot_multi_use_trans_add_pooling_sub) -{ - migraphx::shape s1{migraphx::shape::float_type, {1, 1, 4, 5}}; - migraphx::shape s2{migraphx::shape::float_type, {1, 1, 5, 5}}; - migraphx::program p1; - { - auto* mm = p1.get_main_module(); - auto a = mm->add_parameter("a", s1); - auto b = mm->add_parameter("b", s2); - auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); - auto dot 
= mm->add_instruction(migraphx::make_op("dot"), a, b); - auto dot_trans = mm->add_instruction( - migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dot); - auto add = add_pointwise(p1, "main:pointwise0", {dot_trans, x}, single_pointwise("add")); - auto pooling = - mm->add_instruction(migraphx::make_op("pooling", - {{"mode", migraphx::op::pooling_mode::lpnorm}, - {"padding", {1, 0, 0, 0}}, - {"stride", {1, 1}}, - {"lengths", {2, 1}}, - {"lp_order", 2}}), - add); - auto sub = - add_pointwise(p1, "main:pointwise1", {dot_trans, pooling}, single_pointwise("sub")); - mm->add_return({sub}); - } - run_pass(p1); - migraphx::program p2; - { - auto* mm = p2.get_main_module(); - auto a = mm->add_parameter("a", s1); - auto b = mm->add_parameter("b", s2); - auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); - auto fused = add_mlir( - p2, - "mlir_main:pointwise0", - {x, a, b}, - {"x2", "y0", "y1"}, - [=](auto* pm, const auto& inputs) { - auto dot = pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]); - auto dot_trans = pm->add_instruction( - migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dot); - auto add = pm->add_instruction(migraphx::make_op("add"), dot_trans, inputs[0]); - return std::make_tuple(dot, std::vector{dot, add}); - }); - auto fused_dot_add = - mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); - auto pooling = - mm->add_instruction(migraphx::make_op("pooling", - {{"mode", migraphx::op::pooling_mode::lpnorm}, - {"padding", {1, 0, 0, 0}}, - {"stride", {1, 1}}, - {"lengths", {2, 1}}, - {"lp_order", 2}}), - fused_dot_add); - auto dot = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); - auto dot_trans = mm->add_instruction( - migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dot); - auto sub = - add_pointwise(p2, "main:pointwise1", {dot_trans, pooling}, single_pointwise("sub")); - mm->add_return({sub}); - } - EXPECT(p1.sort() == p2.sort()); -} - TEST_CASE(add_dot) { migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; From 5e828ee71a8ef41ffa248b156836a19cf01b1c89 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 12 Aug 2024 22:12:50 +0000 Subject: [PATCH 128/145] fix tidy --- src/targets/gpu/fuse_mlir.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 2aa1402975d..9d7e4b15f39 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -414,7 +414,7 @@ MIGRAPHX_PRED_MATCHER(mlir_split_reduce, instruction_ref ins) { if(ins->name() != "split_fused_reduce") return false; - auto mod_arg = ins->module_inputs().front(); + auto* mod_arg = ins->module_inputs().front(); auto supported_reshapes = reshaper_names(); supported_reshapes.erase("slice"); std::unordered_set builtins = {"@param", "@literal", "@return"}; @@ -515,7 +515,7 @@ struct find_mlir_split_reduce // pointwise inside split_fused_reduce } } - auto sub_pm = mod_args.front(); + auto* sub_pm = mod_args.front(); auto param_map_2 = create_param_map_with_literals( &main_mod, sub_pm, op.compute_shape(to_shapes(inputs), mod_args)); for(const auto& i : param_map_2) From 69fef788a04cc4fc94cc320acb121150fb34ae2a Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 13 Aug 2024 14:17:58 +0000 Subject: [PATCH 129/145] change EPS For half and fp8 --- test/verify/test_conv_add_layernorm_conv.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git 
a/test/verify/test_conv_add_layernorm_conv.cpp b/test/verify/test_conv_add_layernorm_conv.cpp index 4d2d9c07056..989073aacb0 100644 --- a/test/verify/test_conv_add_layernorm_conv.cpp +++ b/test/verify/test_conv_add_layernorm_conv.cpp @@ -48,7 +48,17 @@ struct test_conv_add_layernorm_conv : verify_program<test_conv_add_layernorm_conv<DType>> auto bias_add = mm->add_instruction(migraphx::make_op("add"), conv1, bcast_bias); auto rsp_add = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {0, 32, -1}}}), bias_add); - auto layernorm = add_layernorm(*mm, rsp_add, rsp_add->get_shape().lens()); + float eps = 1e-12; + if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type) + { + // use 0.250 for fp8 + eps = 0.250; + } + else if constexpr((DType) == migraphx::shape::half_type) + { + eps = 1e-6; + } + auto layernorm = add_layernorm(*mm, rsp_add, rsp_add->get_shape().lens(), eps); auto layernorm_rsp = mm->add_instruction( migraphx::make_op("reshape", {{"dims", {0, 320, 64, 64}}}), layernorm); mm->add_instruction( From 1ebf2a3a91adca98a183665dc0c1ffbbe58c6434 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 14 Aug 2024 14:53:09 +0000 Subject: [PATCH 130/145] address review comments --- src/targets/gpu/fuse_mlir.cpp | 15 +++++++++------ src/targets/gpu/mlir.cpp | 8 +++----- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 17b68cd4477..450dc8b37f9 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -485,12 +485,18 @@ struct find_mlir_split_reduce { auto reduce_ins = r.result; auto gemm_ins = r.instructions["gemm"]; + assert(gemm_ins->get_shape().sub_shapes().empty()); auto* rm = reduce_ins->module_inputs().front(); auto names = rm->get_parameter_names(); std::sort(names.begin(), names.end()); module_ref gemm_old_mm = gemm_ins->module_inputs().front(); - module_ref mm = mpm.create_module(gemm_old_mm->name() + "_split_fused_reduce"); - mm->add_instructions(gemm_old_mm); + module_ref mm = + mpm.create_module(gemm_old_mm->name() + "_split_fused_reduce", *gemm_old_mm); + // remove last return instruction + if(std::prev(mm->end())->name() == "@return") + { + mm->remove_instruction(std::prev(mm->end())); + } mm->set_bypass(); std::unordered_map<instruction_ref, instruction_ref> param_map; param_map[gemm_ins] = std::prev(mm->end()); @@ -518,10 +524,7 @@ struct find_mlir_split_reduce auto* sub_pm = mod_args.front(); auto param_map_2 = create_param_map_with_literals( &main_mod, sub_pm, op.compute_shape(to_shapes(inputs), mod_args)); - for(const auto& i : param_map_2) - { - param_map.insert(i); - } + param_map.insert(param_map_2.begin(), param_map_2.end()); return main_mod.fuse(*sub_pm, inputs, &param_map).front(); } return main_mod.insert_instruction(pos, op, inputs, mod_args); diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 189b244b1d8..54fdc944e55 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -26,6 +26,8 @@ #include #include #include +#include <migraphx/dead_code_elimination.hpp> +#include <migraphx/pass_manager.hpp> #include #include #include @@ -958,7 +960,6 @@ bool is_reduce(const instruction& ins) { return contains(ins.name(), "reduce"); static void rewrite_reduce(module& m) { - std::vector<instruction_ref> ins_to_remove; for(auto i : iterator_for(m)) { if(is_reduce(*i)) @@ -997,12 +998,9 @@ static void rewrite_reduce(module& m) auto rsp_back = m.insert_instruction( i, migraphx::make_op("reshape", {{"dims", reduce_lens}}), collapsed_reduce); m.replace_instruction(i, rsp_back); - ins_to_remove.push_back(i); } } - std::for_each(ins_to_remove.begin(), ins_to_remove.end(), [&](const auto& remove_ins) { - m.remove_instruction(remove_ins); - }); + migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); } bool is_module_fusible(const module& m, const context& migraphx_ctx, const value& solution) From 102a24639cb921e20432fbc8edabb2c2a4034195 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Wed, 14 Aug 2024 15:01:08 +0000 Subject: [PATCH 131/145] formatting --- src/targets/gpu/fuse_mlir.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 450dc8b37f9..78dde6ce3d4 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -486,8 +486,8 @@ struct find_mlir_split_reduce auto reduce_ins = r.result; auto gemm_ins = r.instructions["gemm"]; assert(gemm_ins->get_shape().sub_shapes().empty()); - auto* rm = reduce_ins->module_inputs().front(); - auto names = rm->get_parameter_names(); + auto* rm = reduce_ins->module_inputs().front(); + auto names = rm->get_parameter_names(); std::sort(names.begin(), names.end()); module_ref gemm_old_mm = gemm_ins->module_inputs().front(); module_ref mm = From 335be33ba1d9fee2da7ea2ab275598dd6df7dd17 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Aug 2024 14:12:18 +0000 Subject: [PATCH 132/145] address review comments, add dump_mlir test --- src/targets/gpu/fuse_mlir.cpp | 18 +++++------- src/targets/gpu/include/migraphx/gpu/mlir.hpp | 4 +-- src/targets/gpu/mlir.cpp | 15 ++++------ test/gpu/mlir.cpp | 29 +++++++++++++++++++ 4 files changed, 44 insertions(+), 22 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 78dde6ce3d4..0b136710126 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -512,20 +512,16 @@ struct find_mlir_split_reduce const std::vector<module_ref>& mod_args) { if(op.name() == "pointwise") { - for(const auto& skip_param : inputs) - { - if(not contains(param_map, skip_param)) - { - param_map[skip_param] = - skip_param; // skip adding parameter for inputs of - // pointwise inside split_fused_reduce - } - } auto* sub_pm = mod_args.front(); auto param_map_2 = create_param_map_with_literals( &main_mod, sub_pm, op.compute_shape(to_shapes(inputs), mod_args)); - param_map.insert(param_map_2.begin(), param_map_2.end()); - return main_mod.fuse(*sub_pm, inputs, &param_map).front(); + // skip adding parameter for inputs of + // pointwise inside split_fused_reduc + for(const auto& skip_input : inputs) + { + param_map_2[skip_input] = skip_input; + } + return main_mod.fuse(*sub_pm, inputs, &param_map_2).front(); } return main_mod.insert_instruction(pos, op, inputs, mod_args); }); diff --git a/src/targets/gpu/include/migraphx/gpu/mlir.hpp b/src/targets/gpu/include/migraphx/gpu/mlir.hpp index 8acb1807e31..7f31f43dfd4 100644 --- a/src/targets/gpu/include/migraphx/gpu/mlir.hpp +++ b/src/targets/gpu/include/migraphx/gpu/mlir.hpp @@ -37,8 +37,8 @@ inline namespace MIGRAPHX_INLINE_NS { struct module; namespace gpu { -MIGRAPHX_GPU_EXPORT std::string dump_mlir(const module& m); -MIGRAPHX_GPU_EXPORT std::string dump_mlir(const module& m, const std::vector<shape>& inputs); +MIGRAPHX_GPU_EXPORT std::string dump_mlir(module m); +MIGRAPHX_GPU_EXPORT std::string dump_mlir(module m, const std::vector<shape>& inputs); MIGRAPHX_GPU_EXPORT bool is_module_fusible(const module& m, const context& migraphx_ctx, const value& solution); diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 54fdc944e55..25bc5349330 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -1032,24 +1032,21 @@ void
adjust_param_shapes(module& m, const std::vector& inputs) } } -std::string dump_mlir(const module& m, const std::vector& inputs) +std::string dump_mlir(module m, const std::vector& inputs) { - module mm; const_module_ref mr = &m; if(not inputs.empty()) { - mm = m; - mr = &mm; - adjust_param_shapes(mm, inputs); + adjust_param_shapes(m, inputs); } - rewrite_reduce(mm); + rewrite_reduce(m); mlir_program mp; mp.parse(*mr); auto mod_op = mlirModuleGetOperation(mp.mmodule.get()); return mlir_print(&mlirOperationPrint, mod_op); } -std::string dump_mlir(const module& m) { return dump_mlir(m, {}); } +std::string dump_mlir(module m) { return dump_mlir(m, {}); } mlir_code_object compile_mlir(const context& migraphx_ctx, module m, @@ -1171,9 +1168,9 @@ void use(T&) { } -std::string dump_mlir(const module&) { return {}; } +std::string dump_mlir(module) { return {}; } -std::string dump_mlir(const module& m, const std::vector& inputs) +std::string dump_mlir(module m, const std::vector& inputs) { use(m); use(inputs); diff --git a/test/gpu/mlir.cpp b/test/gpu/mlir.cpp index 7d41148876f..a23152bed06 100644 --- a/test/gpu/mlir.cpp +++ b/test/gpu/mlir.cpp @@ -275,6 +275,35 @@ module { EXPECT(verify_mlir(m)); } +TEST_CASE(conv_reduce_sum) +{ + std::string mlir_output = R"__migraphx__( +module { + func.func @mlir_convolution_reshape_reduce_sum_reshape(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x1x1xf32, 2x1x1x1> attributes ${attrs} { + %0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1> + %1 = migraphx.reshape %0 {dims = [1, 2, 4]} : <1x2x2x2xf32, 8x4x2x1> -> <1x2x4xf32, 8x4x1> + %2 = migraphx.reduce_sum %1 {axes = [2]} : <1x2x4xf32, 8x4x1> -> <1x2x1xf32, 2x1x1> + %3 = migraphx.reshape %2 {dims = [1, 2, 1, 1]} : <1x2x1xf32, 2x1x1> -> <1x2x1x1xf32, 2x1x1x1> + return %3 : !migraphx.shaped<1x2x1x1xf32, 2x1x1x1> + } +} +)__migraphx__"; + migraphx::module m; + auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}}); + auto w = m.add_parameter("w", {migraphx::shape::float_type, {2, 8, 3, 3}}); + auto conv = m.add_instruction(migraphx::make_op("convolution"), x, w); + auto reduce_sum = m.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2, 3}}}), conv); + m.add_return({reduce_sum}); + auto s = migraphx::gpu::dump_mlir(m); + // Skip test if MLIR is not enabled + if(s.empty()) + return; + auto mlir_output_with_attrs = + migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); + CHECK(encode(s) == encode(mlir_output_with_attrs)); + //EXPECT(verify_mlir(m)); +} + TEST_CASE(quant_dot_add) { std::string mlir_output = R"__migraphx__( From 112b14a235ccb3474e664d4d2353f21dc08bbed1 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Aug 2024 14:13:00 +0000 Subject: [PATCH 133/145] formatting --- test/gpu/mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gpu/mlir.cpp b/test/gpu/mlir.cpp index a23152bed06..cf140c16aa4 100644 --- a/test/gpu/mlir.cpp +++ b/test/gpu/mlir.cpp @@ -301,7 +301,7 @@ module { auto mlir_output_with_attrs = migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); CHECK(encode(s) == encode(mlir_output_with_attrs)); - //EXPECT(verify_mlir(m)); + // EXPECT(verify_mlir(m)); } TEST_CASE(quant_dot_add) From 57c550e9f55abb936e2c403df6d55f4a49454310 Mon Sep 17 00:00:00 2001 From: Umang 
Yadav Date: Fri, 16 Aug 2024 14:14:05 +0000 Subject: [PATCH 134/145] fix typo --- src/targets/gpu/fuse_mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 0b136710126..435493e7f6a 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -516,7 +516,7 @@ struct find_mlir_split_reduce auto param_map_2 = create_param_map_with_literals( &main_mod, sub_pm, op.compute_shape(to_shapes(inputs), mod_args)); // skip adding parameter for inputs of - // pointwise inside split_fused_reduc + // pointwise inside split_fused_reduce for(const auto& skip_input : inputs) { param_map_2[skip_input] = skip_input; From b02eb78e55f7148802293fe820461e4447449bcf Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Aug 2024 14:21:45 +0000 Subject: [PATCH 135/145] fix tidy --- src/targets/gpu/mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 25bc5349330..3097c36c1e2 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -1046,7 +1046,7 @@ std::string dump_mlir(module m, const std::vector& inputs) return mlir_print(&mlirOperationPrint, mod_op); } -std::string dump_mlir(module m) { return dump_mlir(m, {}); } +std::string dump_mlir(module m) { return dump_mlir(std::move(m), {}); } mlir_code_object compile_mlir(const context& migraphx_ctx, module m, From 86b98aaf4d48e80dd566669ef3393d666322c852 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Aug 2024 16:33:30 +0000 Subject: [PATCH 136/145] add test --- src/targets/gpu/fuse_mlir.cpp | 5 +- test/gpu/fuse_mlir.cpp | 185 +++++++++++++++++++++++++++++++--- 2 files changed, 173 insertions(+), 17 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 435493e7f6a..2e9f19df13e 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -490,8 +490,7 @@ struct find_mlir_split_reduce auto names = rm->get_parameter_names(); std::sort(names.begin(), names.end()); module_ref gemm_old_mm = gemm_ins->module_inputs().front(); - module_ref mm = - mpm.create_module(gemm_old_mm->name() + "_split_fused_reduce", *gemm_old_mm); + module_ref mm = mpm.create_module(gemm_old_mm->name() + "_" + rm->name(), *gemm_old_mm); // remove last return instruction if(std::prev(mm->end())->name() == "@return") { @@ -587,7 +586,7 @@ struct find_mlir_fused_ops return i != x_ins and reaches(gemm_based_op, i); })) return; - auto names = pm->get_parameter_names(); + auto names = pm->get_parameter_names(); std::sort(names.begin(), names.end()); module_ref mm = mpm.create_module("mlir_" + pm->name()); mm->set_bypass(); diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 87017d97c02..eeae41eabba 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -21,6 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
*/ +#include "migraphx/generate.hpp" #include #include #include @@ -33,6 +34,7 @@ #include MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION); void run_pass(migraphx::program& p) { @@ -40,6 +42,81 @@ void run_pass(migraphx::program& p) p, {migraphx::gpu::fuse_mlir{.enable_extra = true}, migraphx::dead_code_elimination{}}); } +bool all_instructions_are_local(const migraphx::module& m) +{ + return std::all_of(m.begin(), m.end(), [&](const auto& ins) { + return std::all_of(ins.inputs().begin(), ins.inputs().end(), [&](auto input) { + return m.has_instruction(input); + }); + }); +} + +void auto_add_return(migraphx::module_ref m, migraphx::instruction_ref ins) +{ + m->add_return({ins}); +} + +void auto_add_return(migraphx::module_ref m, std::vector inss) +{ + m->add_return(std::move(inss)); +} + +template +migraphx::module_ref add_reduce_module(migraphx::program& p, + const std::string& name, + std::vector inputs, + const std::vector& axes, + F f) +{ + auto* rm = p.create_module(name); + rm->set_bypass(); + std::vector params; + std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) { + return rm->add_parameter( + "x" + std::to_string(params.size()), + migraphx::shape{input->get_shape().type(), input->get_shape().lens()}); + }); + auto r = f(rm, params, axes); + auto_add_return(rm, r); + EXPECT(all_instructions_are_local(*rm)); + return rm; +} + +template +migraphx::instruction_ref add_reduce(migraphx::program& p, + const std::string& name, + std::vector inputs, + const std::vector& axes, + F f) +{ + auto* mm = p.get_main_module(); + auto rm = add_reduce_module(p, name, inputs, axes, f); + return mm->add_instruction(migraphx::make_op("fused_reduce", {{"axes", axes}}), inputs, {rm}); +} + +template +migraphx::instruction_ref add_reduce(migraphx::program& p, + const std::string& name, + std::vector inputs, + const std::vector& axes, + const std::string& assign, + F f) +{ + auto* mm = p.get_main_module(); + auto rm = add_reduce_module(p, name, inputs, axes, f); + return mm->add_instruction( + migraphx::make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign}}), + inputs, + {rm}); +} + +inline auto squared() +{ + return [](auto* pm, const auto& inputs) { + return pm->add_instruction(migraphx::make_op("mul"), inputs[0], inputs[0]); + }; +} + template migraphx::instruction_ref add_mlir(migraphx::program& p, const std::string& name, @@ -54,16 +131,14 @@ migraphx::instruction_ref add_mlir(migraphx::program& p, std::vector params; for(size_t i = 0, e = inputs.size(); i < e; ++i) { - params.push_back(pm->add_parameter(arg_names[i], inputs[i]->get_shape())); + params.push_back(pm->add_parameter(arg_names[i], inputs[i]->get_shape().as_standard())); } auto values = f(pm, params); auto root = std::get<0>(values); auto r = std::get<1>(values); pm->add_return({r}); return mm->add_instruction( - migraphx::make_op("gpu::mlir_op", {{"op", migraphx::to_value(root->get_operator())}}), - inputs, - {pm}); + migraphx::make_op("gpu::mlir_op", {{"op", migraphx::to_value(root)}}), inputs, {pm}); } TEST_CASE(dot_reshapes_add) @@ -101,7 +176,7 @@ TEST_CASE(dot_reshapes_add) auto dot_rsp = pm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 3}}}), dot_trans); auto add = pm->add_instruction(migraphx::make_op("add"), dot_rsp, inputs[0]); - return std::make_tuple(dot, add); + return std::make_tuple(dot->get_operator(), add); }); mm->add_return({fused}); } @@ -137,7 +212,7 @@ TEST_CASE(dot_add) auto dot 
= pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]); auto add = pm->add_instruction(migraphx::make_op("add"), dot, inputs[0]); - return std::make_tuple(dot, add); + return std::make_tuple(dot->get_operator(), add); }); mm->add_return({fused}); } @@ -187,7 +262,8 @@ TEST_CASE(multi_use_dot_trans_add_pooling_sub) migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dot); auto add = pm->add_instruction(migraphx::make_op("add"), dot_trans, inputs[0]); - return std::make_tuple(dot, std::vector{dot, add}); + return std::make_tuple(dot->get_operator(), + std::vector{dot, add}); }); auto fused_dot_add = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); @@ -253,7 +329,8 @@ TEST_CASE(dot_multi_use_trans_add_pooling_sub) auto dot_unsq = pm->add_instruction( migraphx::make_op("reshape", {{"dims", {1, 1, 5, 4}}}), dot_trans); auto add = pm->add_instruction(migraphx::make_op("add"), dot_unsq, inputs[0]); - return std::make_tuple(dot, std::vector{dot, add}); + return std::make_tuple(dot->get_operator(), + std::vector{dot, add}); }); auto fused_dot_add = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); @@ -302,12 +379,12 @@ TEST_CASE(dot_dot_pointwise) auto dot1 = add_mlir(p2, "mlir_dot4", {a, b}, {"y0", "y1"}, [=](auto* pm, const auto& inputs) { auto dot = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); - return std::make_tuple(dot, dot); + return std::make_tuple(dot->get_operator(), dot); }); auto dot2 = add_mlir(p2, "mlir_dot5", {dot1, c}, {"y0", "y1"}, [=](auto* pm, const auto& inputs) { auto dot = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); - return std::make_tuple(dot, dot); + return std::make_tuple(dot->get_operator(), dot); }); auto add = add_pointwise(p2, "main:pointwise0", {dot1, dot2}, single_pointwise("add")); mm->add_return({add}); @@ -343,7 +420,7 @@ TEST_CASE(dot_dot_pointwise_pointwise) auto dot1 = add_mlir(p2, "mlir_dot6", {a, b}, {"y0", "y1"}, [=](auto* pm, const auto& inputs) { auto dot = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); - return std::make_tuple(dot, dot); + return std::make_tuple(dot->get_operator(), dot); }); auto fused = add_mlir(p2, @@ -354,7 +431,7 @@ TEST_CASE(dot_dot_pointwise_pointwise) auto dot = pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]); auto add = pm->add_instruction(migraphx::make_op("add"), dot, inputs[0]); - return std::make_tuple(dot, add); + return std::make_tuple(dot->get_operator(), add); }); auto add2 = add_pointwise(p2, "main:pointwise1", {dot1, fused}, single_pointwise("add")); mm->add_return({add2}); @@ -391,7 +468,7 @@ TEST_CASE(add_dot) auto add = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); auto dot = pm->add_instruction(migraphx::make_op("dot"), add, inputs[2]); - return std::make_tuple(dot, dot); + return std::make_tuple(dot->get_operator(), dot); }); mm->add_return({fused}); } @@ -424,7 +501,7 @@ TEST_CASE(int_quant_dot_abs) auto dot = pm->add_instruction(migraphx::make_op("quant_dot"), inputs[0], inputs[1]); auto abs = pm->add_instruction(migraphx::make_op("abs"), dot); - return std::make_tuple(dot, abs); + return std::make_tuple(dot->get_operator(), abs); }); mm->add_return({fused}); } @@ -452,6 +529,86 @@ TEST_CASE(int_quant_dot_tanh_fails) EXPECT(has_pointwise); } +TEST_CASE(conv_split_reduce) +{ + migraphx::shape s_x{migraphx::shape::float_type, {2, 4, 64, 64}}; + migraphx::shape s_w{migraphx::shape::float_type, {320, 4, 3, 3}}; + 
migraphx::shape s_b{migraphx::shape::float_type, {32}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s_x); + auto w = mm->add_parameter("w", s_w); + auto b = mm->add_literal(migraphx::generate_literal(s_b)); + auto mb = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 32, 10, 64, 64}}}), b); + auto conv = mm->add_instruction( + migraphx::make_op("convolution", {{"padding", {1, 1, 1, 1}}}), x, w); + auto reshape = mm->add_instruction( + migraphx::make_op("reshape", {{"dims", {2, 32, 10, 64, 64}}}), conv); + auto add = add_pointwise(p1, "main:pointwise0", {reshape, mb}, single_pointwise("add")); + auto mean_var = add_reduce( + p1, + "main:split_reduce0", + {add}, + {2, 3, 4}, + "assign_add", + [&](auto* rm, + const auto& inputs, + const auto& axes) -> std::vector { + auto xx = add_pointwise(p1, rm, "main:pointwise1", {inputs[0]}, squared()); + auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + auto rsum2 = + rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), xx); + return {rsum2, rsum1}; + }); + auto var = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), mean_var); + auto mean = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), mean_var); + mm->add_return({var, mean}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s_x); + auto w = mm->add_parameter("w", s_w); + auto b = mm->add_literal(migraphx::generate_literal(s_b)); + auto mb = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 32, 10, 64, 64}}}), b); + auto fused = + add_mlir(p2, + "mlir_main:pointwise0_main:split_reduce0", + {mb, x, w}, + {"x2", "y0", "y1"}, + [=](auto* pm, const auto& inputs) { + auto conv = pm->add_instruction( + migraphx::make_op("convolution", {{"padding", {1, 1, 1, 1}}}), + inputs[1], + inputs[2]); + auto reshape = pm->add_instruction( + migraphx::make_op("reshape", {{"dims", {2, 32, 10, 64, 64}}}), conv); + auto add = + pm->add_instruction(migraphx::make_op("add"), reshape, inputs[0]); + auto mul = pm->add_instruction(migraphx::make_op("mul"), add, add); + auto mean = pm->add_instruction( + migraphx::make_op("reduce_sum", {{"axes", {2, 3, 4}}}), add); + auto var = pm->add_instruction( + migraphx::make_op("reduce_sum", {{"axes", {2, 3, 4}}}), mul); + return std::make_tuple( + migraphx::make_op("gpu::mlir_op", + {{"op", migraphx::to_value(conv->get_operator())}}), + std::vector{var, mean}); + }); + auto mean = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto var = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + mm->add_return({var, mean}); + } + EXPECT(p1.sort() == p2.sort()); +} + int main(int argc, const char* argv[]) { if(migraphx::gpu::mlir_enabled()) From df96690e33395f19f0033a300ccc643fd36809fc Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Aug 2024 16:39:27 +0000 Subject: [PATCH 137/145] add reduce.hpp header --- test/fuse_reduce.cpp | 39 +------------- test/gpu/fuse_mlir.cpp | 77 +-------------------------- test/include/reduce.hpp | 114 ++++++++++++++++++++++++++++++++++++++++ test/split_reduce.cpp | 83 +---------------------------- 4 files changed, 117 insertions(+), 196 deletions(-) create mode 100644 test/include/reduce.hpp diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp index ba6cf4579b5..8189906912b 
100644 --- a/test/fuse_reduce.cpp +++ b/test/fuse_reduce.cpp @@ -30,6 +30,7 @@ #include #include +#include #include void run_pass(migraphx::program& p) @@ -37,44 +38,6 @@ void run_pass(migraphx::program& p) migraphx::run_passes(p, {migraphx::fuse_reduce{}, migraphx::dead_code_elimination{}}); } -bool all_instructions_are_local(const migraphx::module& m) -{ - return std::all_of(m.begin(), m.end(), [&](const auto& ins) { - return std::all_of(ins.inputs().begin(), ins.inputs().end(), [&](auto input) { - return m.has_instruction(input); - }); - }); -} - -template -migraphx::instruction_ref add_reduce(migraphx::program& p, - const std::string& name, - std::vector inputs, - const std::vector& axes, - F f) -{ - auto* rm = p.create_module(name); - auto* mm = p.get_main_module(); - rm->set_bypass(); - std::vector params; - std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) { - return rm->add_parameter( - "x" + std::to_string(params.size()), - migraphx::shape{input->get_shape().type(), input->get_shape().lens()}); - }); - auto r = f(rm, params, axes); - rm->add_return({r}); - EXPECT(all_instructions_are_local(*rm)); - return mm->add_instruction(migraphx::make_op("fused_reduce", {{"axes", axes}}), inputs, {rm}); -} - -inline auto single_reduce(const std::string& name) -{ - return [=](auto* rm, const auto& inputs, const auto& axes) { - return rm->add_instruction(migraphx::make_op(name, {{"axes", axes}}), inputs); - }; -} - TEST_CASE(single) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index eeae41eabba..b2e6dce6e56 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -21,7 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
*/ -#include "migraphx/generate.hpp" +#include #include #include #include @@ -42,81 +42,6 @@ void run_pass(migraphx::program& p) p, {migraphx::gpu::fuse_mlir{.enable_extra = true}, migraphx::dead_code_elimination{}}); } -bool all_instructions_are_local(const migraphx::module& m) -{ - return std::all_of(m.begin(), m.end(), [&](const auto& ins) { - return std::all_of(ins.inputs().begin(), ins.inputs().end(), [&](auto input) { - return m.has_instruction(input); - }); - }); -} - -void auto_add_return(migraphx::module_ref m, migraphx::instruction_ref ins) -{ - m->add_return({ins}); -} - -void auto_add_return(migraphx::module_ref m, std::vector inss) -{ - m->add_return(std::move(inss)); -} - -template -migraphx::module_ref add_reduce_module(migraphx::program& p, - const std::string& name, - std::vector inputs, - const std::vector& axes, - F f) -{ - auto* rm = p.create_module(name); - rm->set_bypass(); - std::vector params; - std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) { - return rm->add_parameter( - "x" + std::to_string(params.size()), - migraphx::shape{input->get_shape().type(), input->get_shape().lens()}); - }); - auto r = f(rm, params, axes); - auto_add_return(rm, r); - EXPECT(all_instructions_are_local(*rm)); - return rm; -} - -template -migraphx::instruction_ref add_reduce(migraphx::program& p, - const std::string& name, - std::vector inputs, - const std::vector& axes, - F f) -{ - auto* mm = p.get_main_module(); - auto rm = add_reduce_module(p, name, inputs, axes, f); - return mm->add_instruction(migraphx::make_op("fused_reduce", {{"axes", axes}}), inputs, {rm}); -} - -template -migraphx::instruction_ref add_reduce(migraphx::program& p, - const std::string& name, - std::vector inputs, - const std::vector& axes, - const std::string& assign, - F f) -{ - auto* mm = p.get_main_module(); - auto rm = add_reduce_module(p, name, inputs, axes, f); - return mm->add_instruction( - migraphx::make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign}}), - inputs, - {rm}); -} - -inline auto squared() -{ - return [](auto* pm, const auto& inputs) { - return pm->add_instruction(migraphx::make_op("mul"), inputs[0], inputs[0]); - }; -} - template migraphx::instruction_ref add_mlir(migraphx::program& p, const std::string& name, diff --git a/test/include/reduce.hpp b/test/include/reduce.hpp new file mode 100644 index 00000000000..508710e85ec --- /dev/null +++ b/test/include/reduce.hpp @@ -0,0 +1,114 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_TEST_INCLUDE_REDUCE_HPP +#define MIGRAPHX_GUARD_TEST_INCLUDE_REDUCE_HPP + +#include +#include +#include +#include +#include + +inline bool all_instructions_are_local(const migraphx::module& m) +{ + return std::all_of(m.begin(), m.end(), [&](const auto& ins) { + return std::all_of(ins.inputs().begin(), ins.inputs().end(), [&](auto input) { + return m.has_instruction(input); + }); + }); +} + +inline void auto_add_return(migraphx::module_ref m, migraphx::instruction_ref ins) +{ + m->add_return({ins}); +} + +inline void auto_add_return(migraphx::module_ref m, std::vector inss) +{ + m->add_return(std::move(inss)); +} + +template +migraphx::module_ref add_reduce_module(migraphx::program& p, + const std::string& name, + std::vector inputs, + const std::vector& axes, + F f) +{ + auto* rm = p.create_module(name); + rm->set_bypass(); + std::vector params; + std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) { + return rm->add_parameter( + "x" + std::to_string(params.size()), + migraphx::shape{input->get_shape().type(), input->get_shape().lens()}); + }); + auto r = f(rm, params, axes); + auto_add_return(rm, r); + EXPECT(all_instructions_are_local(*rm)); + return rm; +} + +template +migraphx::instruction_ref add_reduce(migraphx::program& p, + const std::string& name, + std::vector inputs, + const std::vector& axes, + F f) +{ + auto* mm = p.get_main_module(); + auto rm = add_reduce_module(p, name, inputs, axes, f); + return mm->add_instruction(migraphx::make_op("fused_reduce", {{"axes", axes}}), inputs, {rm}); +} + +template +migraphx::instruction_ref add_reduce(migraphx::program& p, + const std::string& name, + std::vector inputs, + const std::vector& axes, + const std::string& assign, + F f) +{ + auto* mm = p.get_main_module(); + auto rm = add_reduce_module(p, name, inputs, axes, f); + return mm->add_instruction( + migraphx::make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign}}), + inputs, + {rm}); +} + +inline auto squared() +{ + return [](auto* pm, const auto& inputs) { + return pm->add_instruction(migraphx::make_op("mul"), inputs[0], inputs[0]); + }; +} + +inline auto single_reduce(const std::string& name) +{ + return [=](auto* rm, const auto& inputs, const auto& axes) { + return rm->add_instruction(migraphx::make_op(name, {{"axes", axes}}), inputs); + }; +} +#endif // MIGRAPHX_GUARD_TEST_INCLUDE_REDUCE_HPP diff --git a/test/split_reduce.cpp b/test/split_reduce.cpp index 0c4b8a1f983..e837b577882 100644 --- a/test/split_reduce.cpp +++ b/test/split_reduce.cpp @@ -33,6 +33,7 @@ #include #include +#include void run_pass(migraphx::program& p) { @@ -50,88 +51,6 @@ void run_fuse_pass(migraphx::program& p) {migraphx::fuse_pointwise{}, migraphx::fuse_reduce{}, migraphx::dead_code_elimination{}}); } -bool all_instructions_are_local(const migraphx::module& m) -{ - return std::all_of(m.begin(), m.end(), [&](const auto& ins) { - return std::all_of(ins.inputs().begin(), ins.inputs().end(), [&](auto input) { - return m.has_instruction(input); - }); - }); -} - -void auto_add_return(migraphx::module_ref m, migraphx::instruction_ref ins) -{ - m->add_return({ins}); -} - -void auto_add_return(migraphx::module_ref m, std::vector inss) -{ - 
m->add_return(std::move(inss)); -} - -template -migraphx::module_ref add_reduce_module(migraphx::program& p, - const std::string& name, - std::vector inputs, - const std::vector& axes, - F f) -{ - auto* rm = p.create_module(name); - rm->set_bypass(); - std::vector params; - std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) { - return rm->add_parameter( - "x" + std::to_string(params.size()), - migraphx::shape{input->get_shape().type(), input->get_shape().lens()}); - }); - auto r = f(rm, params, axes); - auto_add_return(rm, r); - EXPECT(all_instructions_are_local(*rm)); - return rm; -} - -template -migraphx::instruction_ref add_reduce(migraphx::program& p, - const std::string& name, - std::vector inputs, - const std::vector& axes, - F f) -{ - auto* mm = p.get_main_module(); - auto rm = add_reduce_module(p, name, inputs, axes, f); - return mm->add_instruction(migraphx::make_op("fused_reduce", {{"axes", axes}}), inputs, {rm}); -} - -template -migraphx::instruction_ref add_reduce(migraphx::program& p, - const std::string& name, - std::vector inputs, - const std::vector& axes, - const std::string& assign, - F f) -{ - auto* mm = p.get_main_module(); - auto rm = add_reduce_module(p, name, inputs, axes, f); - return mm->add_instruction( - migraphx::make_op("split_fused_reduce", {{"axes", axes}, {"assign", assign}}), - inputs, - {rm}); -} - -inline auto single_reduce(const std::string& name) -{ - return [=](auto* rm, const auto& inputs, const auto& axes) { - return rm->add_instruction(migraphx::make_op(name, {{"axes", axes}}), inputs); - }; -} - -inline auto squared() -{ - return [](auto* pm, const auto& inputs) { - return pm->add_instruction(migraphx::make_op("mul"), inputs[0], inputs[0]); - }; -} - TEST_CASE(single) { migraphx::shape s{migraphx::shape::float_type, {2, 3, 327680}}; From 8e0acd0b31b592546e76d2dd7db4f5a43d4e2ff4 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Aug 2024 17:09:23 +0000 Subject: [PATCH 138/145] add multi use unit-test --- test/gpu/fuse_mlir.cpp | 106 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index b2e6dce6e56..0fb62620b19 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -32,6 +32,7 @@ #include #include #include +#include MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION); @@ -531,6 +532,111 @@ TEST_CASE(conv_split_reduce) auto var = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); mm->add_return({var, mean}); } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(conv_add_split_reduce_multi_use) +{ + migraphx::shape s_x{migraphx::shape::float_type, {2, 4, 64, 64}}; + migraphx::shape s_w{migraphx::shape::float_type, {320, 4, 3, 3}}; + migraphx::shape s_b{migraphx::shape::float_type, {32}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s_x); + auto w = mm->add_parameter("w", s_w); + auto b = mm->add_literal(migraphx::generate_literal(s_b)); + auto mb = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 32, 10, 64, 64}}}), b); + auto conv = mm->add_instruction( + migraphx::make_op("convolution", {{"padding", {1, 1, 1, 1}}}), x, w); + auto reshape = mm->add_instruction( + migraphx::make_op("reshape", {{"dims", {2, 32, 10, 64, 64}}}), conv); + auto add = 
add_pointwise(p1, "main:pointwise0", {reshape, mb}, single_pointwise("add")); + auto mean_var = add_reduce( + p1, + "main:split_reduce0", + {add}, + {2, 3, 4}, + "assign_add", + [&](auto* rm, + const auto& inputs, + const auto& axes) -> std::vector { + auto xx = add_pointwise(p1, rm, "main:pointwise1", {inputs[0]}, squared()); + auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + auto rsum2 = + rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), xx); + return {rsum2, rsum1}; + }); + auto var = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), mean_var); + auto mean = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), mean_var); + auto mean_mb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", add->get_shape().lens()}}), mean); + auto var_mb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", add->get_shape().lens()}}), var); + auto norm = add_pointwise( + p1, "main:pointwise2", {add, mean_mb, var_mb}, [=](auto* pm, const auto& inputs) { + auto sub = + pm->add_instruction(migraphx::make_op("sub"), inputs.at(0), inputs.at(1)); + return pm->add_instruction(migraphx::make_op("div"), sub, inputs.at(2)); + }); + mm->add_return({norm}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s_x); + auto w = mm->add_parameter("w", s_w); + auto b = mm->add_literal(migraphx::generate_literal(s_b)); + auto mb = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 32, 10, 64, 64}}}), b); + auto fused = + add_mlir(p2, + "mlir_main:pointwise0_main:split_reduce0", + {mb, x, w}, + {"x2", "y0", "y1"}, + [=](auto* pm, const auto& inputs) { + auto conv = pm->add_instruction( + migraphx::make_op("convolution", {{"padding", {1, 1, 1, 1}}}), + inputs[1], + inputs[2]); + auto reshape = pm->add_instruction( + migraphx::make_op("reshape", {{"dims", {2, 32, 10, 64, 64}}}), conv); + auto add = + pm->add_instruction(migraphx::make_op("add"), reshape, inputs[0]); + auto mul = pm->add_instruction(migraphx::make_op("mul"), add, add); + auto mean = pm->add_instruction( + migraphx::make_op("reduce_sum", {{"axes", {2, 3, 4}}}), add); + auto var = pm->add_instruction( + migraphx::make_op("reduce_sum", {{"axes", {2, 3, 4}}}), mul); + return std::make_tuple( + migraphx::make_op("gpu::mlir_op", + {{"op", migraphx::to_value(conv->get_operator())}}), + std::vector{var, mean, add}); + }); + auto cba = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), fused); + auto var = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto mean = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + auto mean_mb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", cba->get_shape().lens()}}), mean); + auto var_mb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", cba->get_shape().lens()}}), var); + auto norm = add_pointwise( + p2, "main:pointwise2", {cba, mean_mb, var_mb}, [=](auto* pm, const auto& inputs) { + auto sub = + pm->add_instruction(migraphx::make_op("sub"), inputs.at(0), inputs.at(1)); + return pm->add_instruction(migraphx::make_op("div"), sub, inputs.at(2)); + }); + mm->add_return({norm}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION{})) + return; EXPECT(p1.sort() == p2.sort()); } From a5733c5533a862590fc319dc2c855ea943f9cab7 Mon 
Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Aug 2024 17:41:55 +0000 Subject: [PATCH 139/145] fix licensing --- src/include/migraphx/op/pad.hpp | 2 +- src/targets/gpu/jit/pad.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/include/migraphx/op/pad.hpp b/src/include/migraphx/op/pad.hpp index c24b83edf83..4fc99ae916c 100644 --- a/src/include/migraphx/op/pad.hpp +++ b/src/include/migraphx/op/pad.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/jit/pad.cpp b/src/targets/gpu/jit/pad.cpp index 02654dd46b1..d3e14864d7a 100644 --- a/src/targets/gpu/jit/pad.cpp +++ b/src/targets/gpu/jit/pad.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 070da3da543e79b660112aed75ebaacf85241709 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Aug 2024 20:01:46 +0000 Subject: [PATCH 140/145] revert problem_key changes --- src/targets/gpu/mlir.cpp | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 3097c36c1e2..e8ff586c8ff 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -1138,17 +1138,7 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx, mlir_program mp; mp.set_gpu_properties(migraphx_ctx); mp.parse(m); - auto tc = mp.get_tuning_config(exhaustive); - std::string problem_config = tc.problem.to(); - for(const auto i : iterator_for(m)) - { - if(starts_with(i->name(), "@")) - { - continue; - } - problem_config += " " + i->name(); - } - tc.problem = problem_config; + auto tc = mp.get_tuning_config(exhaustive); const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); static std::mutex mutex; if(trace) From dc71b68db90a0df4b9bacb08a3ad6d91956b804f Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Aug 2024 20:39:00 +0000 Subject: [PATCH 141/145] add one more test --- test/gpu/fuse_mlir.cpp | 130 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 0fb62620b19..082b367c0e3 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -640,6 +640,136 @@ TEST_CASE(conv_add_split_reduce_multi_use) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(conv_add_split_reduce_multi_use_conv) +{ + migraphx::shape s_x{migraphx::shape::float_type, {2, 4, 64, 64}}; + migraphx::shape s_w1{migraphx::shape::float_type, {320, 4, 3, 3}}; + migraphx::shape s_w2{migraphx::shape::float_type, {320, 320, 3, 3}}; + migraphx::shape s_b{migraphx::shape::float_type, {32}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s_x); + auto w1 = mm->add_parameter("w1", s_w1); + auto w2 = mm->add_parameter("w2", s_w2); + auto b = mm->add_literal(migraphx::generate_literal(s_b)); + auto mb = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 32, 10, 64, 64}}}), b); + 
auto conv = mm->add_instruction( + migraphx::make_op("convolution", {{"padding", {1, 1, 1, 1}}}), x, w1); + auto reshape = mm->add_instruction( + migraphx::make_op("reshape", {{"dims", {2, 32, 10, 64, 64}}}), conv); + auto add = add_pointwise(p1, "main:pointwise0", {reshape, mb}, single_pointwise("add")); + auto mean_var = add_reduce( + p1, + "main:split_reduce0", + {add}, + {2, 3, 4}, + "assign_add", + [&](auto* rm, + const auto& inputs, + const auto& axes) -> std::vector { + auto xx = add_pointwise(p1, rm, "main:pointwise1", {inputs[0]}, squared()); + auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + auto rsum2 = + rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), xx); + return {rsum2, rsum1}; + }); + auto var = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), mean_var); + auto mean = + mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), mean_var); + auto mean_mb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", add->get_shape().lens()}}), mean); + auto mean_rsp = mm->add_instruction( + migraphx::make_op("reshape", {{"dims", {2, 320, 64, 64}}}), mean_mb); + auto var_mb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", add->get_shape().lens()}}), var); + auto var_rsp = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 320, 64, 64}}}), var_mb); + auto add_rsp = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 320, 64, 64}}}), add); + auto norm = add_pointwise( + p1, "main:pointwise2", {add_rsp, mean_rsp, var_rsp}, [=](auto* pm, const auto& inputs) { + auto sub = + pm->add_instruction(migraphx::make_op("sub"), inputs.at(0), inputs.at(1)); + return pm->add_instruction(migraphx::make_op("div"), sub, inputs.at(2)); + }); + auto conv_2 = mm->add_instruction( + migraphx::make_op("convolution", {{"padding", {1, 1, 1, 1}}}), norm, w2); + mm->add_return({conv_2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s_x); + auto w1 = mm->add_parameter("w1", s_w1); + auto w2 = mm->add_parameter("w2", s_w2); + auto b = mm->add_literal(migraphx::generate_literal(s_b)); + auto mb = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 32, 10, 64, 64}}}), b); + auto fused = + add_mlir(p2, + "mlir_main:pointwise0_main:split_reduce0", + {mb, x, w1}, + {"x2", "y0", "y1"}, + [=](auto* pm, const auto& inputs) { + auto conv = pm->add_instruction( + migraphx::make_op("convolution", {{"padding", {1, 1, 1, 1}}}), + inputs[1], + inputs[2]); + auto reshape = pm->add_instruction( + migraphx::make_op("reshape", {{"dims", {2, 32, 10, 64, 64}}}), conv); + auto add = + pm->add_instruction(migraphx::make_op("add"), reshape, inputs[0]); + auto mul = pm->add_instruction(migraphx::make_op("mul"), add, add); + auto mean = pm->add_instruction( + migraphx::make_op("reduce_sum", {{"axes", {2, 3, 4}}}), add); + auto var = pm->add_instruction( + migraphx::make_op("reduce_sum", {{"axes", {2, 3, 4}}}), mul); + return std::make_tuple( + migraphx::make_op("gpu::mlir_op", + {{"op", migraphx::to_value(conv->get_operator())}}), + std::vector{var, mean, add}); + }); + auto cba = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 2}}), fused); + auto var = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused); + auto mean = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused); + 
auto mean_mb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", cba->get_shape().lens()}}), mean); + auto mean_rsp = mm->add_instruction( + migraphx::make_op("reshape", {{"dims", {2, 320, 64, 64}}}), mean_mb); + auto var_mb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", cba->get_shape().lens()}}), var); + auto var_rsp = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 320, 64, 64}}}), var_mb); + auto cba_rsp = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 320, 64, 64}}}), cba); + auto input_fused_conv = add_mlir( + p2, + "main:pointwise2:mlir_convolution3", + {cba_rsp, mean_rsp, var_rsp, w2}, + {"x0", "x1", "x2", "x3"}, + [=](auto* pm, const auto& inputs) { + auto sub = + pm->add_instruction(migraphx::make_op("sub"), inputs.at(0), inputs.at(1)); + auto div = pm->add_instruction(migraphx::make_op("div"), sub, inputs.at(2)); + auto conv = pm->add_instruction( + migraphx::make_op("convolution", {{"padding", {1, 1, 1, 1}}}), + div, + inputs.at(3)); + return std::make_tuple(conv->get_operator(), conv); + }); + mm->add_return({input_fused_conv}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION{}) or + not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + int main(int argc, const char* argv[]) { if(migraphx::gpu::mlir_enabled()) From 1b68e4588f514f4511909c7ebef343c2bf052a4b Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 16 Aug 2024 20:50:45 +0000 Subject: [PATCH 142/145] use auto_add_return --- test/gpu/fuse_mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 082b367c0e3..c6cff002402 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -62,7 +62,7 @@ migraphx::instruction_ref add_mlir(migraphx::program& p, auto values = f(pm, params); auto root = std::get<0>(values); auto r = std::get<1>(values); - pm->add_return({r}); + auto_add_return(pm, r); return mm->add_instruction( migraphx::make_op("gpu::mlir_op", {{"op", migraphx::to_value(root)}}), inputs, {pm}); } From 4e043c740cfec1fddc8d7e8553bb3966b8a08d48 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sat, 17 Aug 2024 13:17:26 +0000 Subject: [PATCH 143/145] use `insert_inline()` --- src/include/migraphx/module.hpp | 9 +++++++ src/module.cpp | 15 +++++++++++ src/targets/gpu/fuse_mlir.cpp | 44 ++++++++++++++------------------- 3 files changed, 43 insertions(+), 25 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 3fc8e51baa8..dd4342f1dc5 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -259,6 +259,15 @@ struct MIGRAPHX_EXPORT module const std::vector& inputs, std::unordered_map* map_ins = nullptr, inserter insert = nullptr); + /* + Insert instructions from module `m` to this module at position `ins` + */ + std::vector + insert_inline(instruction_ref ins, + const module& m, + const std::vector& inputs, + std::unordered_map* map_in = nullptr, + inserter insert = nullptr); void debug_print() const; void debug_print(instruction_ref ins) const; diff --git a/src/module.cpp b/src/module.cpp index 21bc3a80333..c42a2dca80a 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1049,6 +1049,21 @@ module::fuse(const module& m, return this->add_instructions(&m, map_ins, std::move(insert)); } +std::vector +module::insert_inline(instruction_ref ins, + const module& m, + const std::vector& inputs, + std::unordered_map* map_ins, 
+ module::inserter insert) +{ + std::unordered_map<instruction_ref, instruction_ref> default_map_ins; + if(map_ins == nullptr) + map_ins = &default_map_ins; + auto param_map = m.get_ins_param_map(inputs, true); + map_ins->insert(param_map.begin(), param_map.end()); + return this->insert_instructions(ins, &m, map_ins, std::move(insert)); +} + void module_with_inputs::replace(instruction_ref ins, instruction_ref rep) { auto it = std::find(inputs.begin(), inputs.end(), ins); diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 2e9f19df13e..e4579a40847 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -500,30 +500,24 @@ struct find_mlir_split_reduce std::unordered_map<instruction_ref, instruction_ref> param_map; param_map[gemm_ins] = std::prev(mm->end()); bool gemm_has_multi_outs = gemm_ins->outputs().size() > 1; - auto return_vals = - mm->fuse(*rm, - reduce_ins->inputs(), - &param_map, - [&](module& main_mod, - instruction_ref pos, - const operation& op, - const std::vector<instruction_ref>& inputs, - const std::vector<module_ref>& mod_args) { - if(op.name() == "pointwise") - { - auto* sub_pm = mod_args.front(); - auto param_map_2 = create_param_map_with_literals( - &main_mod, sub_pm, op.compute_shape(to_shapes(inputs), mod_args)); - // skip adding parameter for inputs of - // pointwise inside split_fused_reduce - for(const auto& skip_input : inputs) - { - param_map_2[skip_input] = skip_input; - } - return main_mod.fuse(*sub_pm, inputs, &param_map_2).front(); - } - return main_mod.insert_instruction(pos, op, inputs, mod_args); - }); + auto return_vals = mm->fuse( + *rm, + reduce_ins->inputs(), + &param_map, + [&](module& main_mod, + instruction_ref pos, + const operation& op, + const std::vector<instruction_ref>& inputs, + const std::vector<module_ref>& mod_args) { + if(op.name() == "pointwise") + { + auto* sub_pm = mod_args.front(); + auto param_map_2 = create_param_map_with_literals( + &main_mod, sub_pm, op.compute_shape(to_shapes(inputs), mod_args)); + return main_mod.insert_inline(pos, *sub_pm, inputs, &param_map_2).front(); + } + return main_mod.insert_instruction(pos, op, inputs, mod_args); + }); if(gemm_has_multi_outs) { return_vals.insert(return_vals.end(), param_map[gemm_ins]); @@ -744,7 +738,7 @@ struct find_mlir_standalone_attention_op auto gemm1 = mm->add_instruction(make_op("dot"), {softmax, new_upper_v}); std::vector<instruction_ref> ins_to_replace = {gemm1}; - auto ins_to_be_replaced = gemm_softmax_gemm; + auto ins_to_be_replaced = gemm_softmax_gemm; if(r.instructions.find("trailing_pm") != r.instructions.end()) { auto trailing_pm_ins = r.instructions["trailing_pm"]; From 848d80778876adfbc844d19494a0b50c90c44e53 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sat, 17 Aug 2024 15:50:32 +0000 Subject: [PATCH 144/145] fix cppcheck --- src/include/migraphx/module.hpp | 4 ++-- src/targets/gpu/fuse_mlir.cpp | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index dd4342f1dc5..779dd9ed0c6 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -266,8 +266,8 @@ struct MIGRAPHX_EXPORT module insert_inline(instruction_ref ins, const module& m, const std::vector<instruction_ref>& inputs, - std::unordered_map<instruction_ref, instruction_ref>* map_in = nullptr, - inserter insert = nullptr); + std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr, + inserter insert = nullptr); void debug_print() const; void debug_print(instruction_ref ins) const; diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index e4579a40847..f69491b1716 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -514,7 +514,8 @@ struct find_mlir_split_reduce auto* sub_pm = mod_args.front(); auto param_map_2 = create_param_map_with_literals( &main_mod, sub_pm, op.compute_shape(to_shapes(inputs), mod_args)); - return main_mod.insert_inline(pos, *sub_pm, inputs, &param_map_2).front(); + return main_mod.insert_inline(pos, *sub_pm, inputs, &param_map_2) + .front(); // cppcheck-suppress returnDanglingLifetime; } return main_mod.insert_instruction(pos, op, inputs, mod_args); }); From 94e112a4bdc06179e63439cc38ef04747f6cad74 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Sun, 18 Aug 2024 13:00:27 +0000 Subject: [PATCH 145/145] Formatting --- src/targets/gpu/fuse_mlir.cpp | 38 +++++++++++++++++------------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index f69491b1716..637f8e263f1 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -500,25 +500,25 @@ struct find_mlir_split_reduce std::unordered_map<instruction_ref, instruction_ref> param_map; param_map[gemm_ins] = std::prev(mm->end()); bool gemm_has_multi_outs = gemm_ins->outputs().size() > 1; - auto return_vals = mm->fuse( - *rm, - reduce_ins->inputs(), - &param_map, - [&](module& main_mod, - instruction_ref pos, - const operation& op, - const std::vector<instruction_ref>& inputs, - const std::vector<module_ref>& mod_args) { - if(op.name() == "pointwise") - { - auto* sub_pm = mod_args.front(); - auto param_map_2 = create_param_map_with_literals( - &main_mod, sub_pm, op.compute_shape(to_shapes(inputs), mod_args)); - return main_mod.insert_inline(pos, *sub_pm, inputs, &param_map_2) - .front(); // cppcheck-suppress returnDanglingLifetime; - } - return main_mod.insert_instruction(pos, op, inputs, mod_args); - }); + auto return_vals = + mm->fuse(*rm, + reduce_ins->inputs(), + &param_map, + [&](module& main_mod, + instruction_ref pos, + const operation& op, + const std::vector<instruction_ref>& inputs, + const std::vector<module_ref>& mod_args) { + if(op.name() == "pointwise") + { + auto* sub_pm = mod_args.front(); + auto param_map_2 = create_param_map_with_literals( + &main_mod, sub_pm, op.compute_shape(to_shapes(inputs), mod_args)); + return main_mod.insert_inline(pos, *sub_pm, inputs, &param_map_2) + .front(); // cppcheck-suppress returnDanglingLifetime; + } + return main_mod.insert_instruction(pos, op, inputs, mod_args); + }); if(gemm_has_multi_outs) { return_vals.insert(return_vals.end(), param_map[gemm_ins]);
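
A note on the module::insert_inline helper added in PATCH 143: module::fuse maps the parameters of a source module onto the given inputs and appends the mapped body at the end of the destination module, whereas insert_inline splices the mapped body in front of an existing instruction. That distinction is what lets find_mlir_split_reduce inline a pointwise submodule at the exact point where the reduce module references it. Below is a minimal sketch of the intended call pattern, assuming hypothetical modules main_mod and pw and an insertion point pos; this is illustrative only and not part of the patches above.

#include <migraphx/module.hpp>
#include <migraphx/instruction_ref.hpp>
#include <unordered_map>
#include <vector>

using migraphx::instruction_ref;
using migraphx::module;

// Splice the body of the pointwise module `pw` in front of `pos` inside
// `main_mod`, wiring the parameters of `pw` to `inputs`. Using fuse() here
// would append the body at the end of `main_mod` instead, which is wrong
// when the pointwise op sits in the middle of the fused module.
inline instruction_ref
inline_pointwise(module& main_mod,
                 instruction_ref pos,
                 const module& pw,
                 const std::vector<instruction_ref>& inputs)
{
    std::unordered_map<instruction_ref, instruction_ref> map_ins;
    return main_mod.insert_inline(pos, pw, inputs, &map_ins).front();
}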
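
Relatedly, the split-reduce tests above (conv_split_reduce and the multi-use variants) model the group-norm statistics pattern: the split_fused_reduce module returns the two partial sums rsum2 = sum(x*x) and rsum1 = sum(x) over axes {2, 3, 4}, and a later pointwise stage is expected to turn those sums into moments. A small self-contained sketch of that last step follows, assuming n is the reduction size (10*64*64 elements per group for the shapes in these tests); note that the tests' main:pointwise2 is only a stand-in and does not actually normalize this way.

#include <cstddef>

struct moments
{
    double mean;
    double var;
};

// Recover the mean and the biased variance from the two sums produced by
// the split reduce: E[x] = sum(x)/n and Var[x] = E[x^2] - E[x]^2.
inline moments moments_from_sums(double sum_x, double sum_xx, std::size_t n)
{
    moments m{};
    m.mean = sum_x / static_cast<double>(n);
    m.var  = sum_xx / static_cast<double>(n) - m.mean * m.mean;
    return m;
}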