Skip to content

Commit

Permalink
Merge branch 'develop' into dump-mlir
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 committed Sep 20, 2024
2 parents 97663b7 + 32a84a8 commit 161b7c6
Show file tree
Hide file tree
Showing 25 changed files with 1,804 additions and 138 deletions.
11 changes: 11 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ else()
option(MIGRAPHX_USE_MIOPEN "Enable MIGraphX to use MIOpen" ON)
endif()

if(WIN32)
option(MIGRAPHX_USE_HIPBLASLT "Enable MIGraphX to use hipBLASLt" OFF)
else()
option(MIGRAPHX_USE_HIPBLASLT "Enable MIGraphX to use hipBLASLt" ON)
endif()

# By default build shared libraries
option(BUILD_SHARED_LIBS "Create shared libraries" ON)

Expand Down Expand Up @@ -315,6 +321,7 @@ rocm_enable_cppcheck(
${CMAKE_CURRENT_SOURCE_DIR}/test/include
DEFINE
MIGRAPHX_MLIR=1
MIGRAPHX_ENABLE_HIPBLASLT_GEMM=1
MIGRAPHX_HAS_EXECUTORS=0
CPPCHECK=1
MIGRAPHX_USE_MIOPEN=1
Expand Down Expand Up @@ -364,6 +371,10 @@ if(MIGRAPHX_USE_ROCBLAS)
list(APPEND PACKAGE_DEPENDS rocblas)
endif()

if(MIGRAPHX_USE_HIPBLASLT)
list(APPEND PACKAGE_DEPENDS hipblaslt)
endif()

rocm_package_add_deb_dependencies(SHARED_DEPENDS "hip-dev")
rocm_package_add_rpm_dependencies(SHARED_DEPENDS "hip-devel")

Expand Down
3 changes: 3 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ if(MIGRAPHX_ENABLE_GPU)
if(MIGRAPHX_USE_ROCBLAS)
list(APPEND MIGRAPHX_CONFIG_DEPENDS PACKAGE rocblas)
endif()
if(MIGRAPHX_USE_HIPBLASLT)
list(APPEND MIGRAPHX_CONFIG_DEPENDS PACKAGE hipblaslt)
endif()
add_subdirectory(targets/gpu)
target_link_libraries(migraphx_all_targets INTERFACE migraphx_gpu)
target_compile_definitions(migraphx_all_targets INTERFACE -DHAVE_GPU)
Expand Down
12 changes: 6 additions & 6 deletions src/include/migraphx/instruction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ struct MIGRAPHX_EXPORT instruction

void clear_arguments();

friend bool operator==(const instruction& i, instruction_ref ref);
MIGRAPHX_EXPORT friend bool operator==(const instruction& i, instruction_ref ref);

bool valid(instruction_ref start, bool check_order = false) const;

Expand All @@ -86,15 +86,15 @@ struct MIGRAPHX_EXPORT instruction
/// Where this instruction is used as an input to another instruction
const std::vector<instruction_ref>& outputs() const;

friend bool operator==(const instruction& x, const instruction& y);
MIGRAPHX_EXPORT friend bool operator==(const instruction& x, const instruction& y);

friend bool operator!=(const instruction& x, const instruction& y);
MIGRAPHX_EXPORT friend bool operator!=(const instruction& x, const instruction& y);

friend bool operator==(instruction_ref ref, const instruction& i);
MIGRAPHX_EXPORT friend bool operator==(instruction_ref ref, const instruction& i);

friend bool operator!=(const instruction& i, instruction_ref ref);
MIGRAPHX_EXPORT friend bool operator!=(const instruction& i, instruction_ref ref);

friend bool operator!=(instruction_ref ref, const instruction& i);
MIGRAPHX_EXPORT friend bool operator!=(instruction_ref ref, const instruction& i);

void add_output(instruction_ref ins);

Expand Down
72 changes: 72 additions & 0 deletions src/include/migraphx/match/softmax.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MATCH_SOFTMAX_HPP
#define MIGRAPHX_GUARD_MATCH_SOFTMAX_HPP

#include <migraphx/config.hpp>
#include <migraphx/matcher.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace match {

namespace detail {
template <class F>
struct softmax_matcher
{
F f;

auto exp_x_minus_max() const
{
auto x_minus_max =
f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(f("reduce_max"))));
return f("exp")(arg(0)(x_minus_max));
}

auto softmax_base_ops() const
{
auto sum_exp_x_minus_max = f("reduce_sum")(arg(0)(exp_x_minus_max()));
return f("div")(arg(0)(exp_x_minus_max()), arg(1)(skip_broadcasts(sum_exp_x_minus_max)));
}

auto matcher() const { return softmax_base_ops(); }
};
} // namespace detail

template <class F>
auto softmax(F f)
{
return detail::softmax_matcher<F>{f}.matcher();
}

inline auto softmax()
{
return softmax([](auto x) { return name(x); });
}

} // namespace match
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
48 changes: 48 additions & 0 deletions src/include/migraphx/matcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ struct matcher_context
return ins == std::prev(mod->end());
}

void debug_print(instruction_ref ins) const { mod->debug_print(ins); }

private:
module* mod = nullptr;
};
Expand Down Expand Up @@ -582,6 +584,52 @@ inline auto outputs()
};
}

inline auto trace(const std::string& s)
{
return [=](auto m) {
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
std::cout << s << ": ";
ctx.debug_print(ins);
optional<instruction_ref> result = m.match(ctx, ins);
if(result.has_value())
std::cout << "Found\n";
else
std::cout << "Not Found\n";
return result;
});
};
}

inline auto trace_found(const std::string& s)
{
return [=](auto m) {
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
optional<instruction_ref> result = m.match(ctx, ins);
if(result.has_value())
{
std::cout << "Found: " << s << ": ";
ctx.debug_print(ins);
}
return result;
});
};
}

inline auto trace_not_found(const std::string& s)
{
return [=](auto m) {
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
optional<instruction_ref> result = m.match(ctx, ins);
if(not result.has_value())
{
std::cout << "Not Found: " << s << ": ";
ctx.debug_print(ins);
}
return result;
});
};
}

MIGRAPHX_PRED_MATCHER(any, instruction_ref) { return true; }
MIGRAPHX_PRED_MATCHER(none, instruction_ref) { return false; }
MIGRAPHX_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
Expand Down
11 changes: 11 additions & 0 deletions src/include/migraphx/module.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,12 @@ struct MIGRAPHX_EXPORT module
std::unordered_map<instruction_ref, instruction_ref>
get_ins_param_map(const std::vector<instruction_ref>& inputs, bool reverse = false) const;

/// Given a mapping from submodule instructions to parent module instructions
/// construct a vector of inputs with parent module instructions in the
/// correct order
std::vector<instruction_ref>
get_inputs(const std::unordered_map<instruction_ref, instruction_ref>& map_ins) const;

using with_inputs = module_with_inputs;

/// This will split the module into two parts at the instruction splits.
Expand All @@ -245,6 +251,11 @@ struct MIGRAPHX_EXPORT module
const std::vector<instruction_ref>& splits1,
const std::vector<instruction_ref>& splits2) const;

// Insert params to module based on given input instructions and add
// mappings from inputs to corresponding params in instructions map
void add_params(const std::vector<instruction_ref>& inputs,
std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr);

// Fuse the instruction into the module by inserting the instructions and
// parameters for any missing inputs.
std::vector<instruction_ref>
Expand Down
21 changes: 21 additions & 0 deletions src/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,21 @@ module::get_ins_param_map(const std::vector<instruction_ref>& inputs, bool rever
return result;
}

std::vector<instruction_ref>
module::get_inputs(const std::unordered_map<instruction_ref, instruction_ref>& map_ins) const
{
std::vector<instruction_ref> inputs;
auto params = this->get_parameters();
sort_params(params);

std::transform(params.begin(),
params.end(),
std::back_inserter(inputs),
[&](instruction_ref param) { return map_ins.at(param); });

return inputs;
}

static std::vector<instruction_ref>
select_params(const std::vector<instruction_ref>& instructions,
const std::unordered_map<instruction_ref, instruction_ref>& param_map)
Expand Down Expand Up @@ -1008,6 +1023,12 @@ static void insert_params(module& m,
}
}

void module::add_params(const std::vector<instruction_ref>& inputs,
std::unordered_map<instruction_ref, instruction_ref>* map_ins)
{
insert_params(*this, inputs, *map_ins);
}

std::vector<instruction_ref>
module::fuse(const std::vector<instruction_ref>& inss,
std::unordered_map<instruction_ref, instruction_ref>* map_ins,
Expand Down
2 changes: 1 addition & 1 deletion src/optimize_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ void optimize_module::apply(module_pass_manager& mpm) const
{
mpm.get_module().repeat_while_changes(2, [&] {
// loop to further optimize after initial transformations
mpm.get_module().repeat_while_changes(3, [&] {
mpm.get_module().repeat_while_changes(4, [&] {
mpm.run_pass(simplify_reshapes{});
mpm.run_pass(eliminate_convert{});
mpm.run_pass(dead_code_elimination{});
Expand Down
28 changes: 28 additions & 0 deletions src/targets/gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ else()
message(STATUS "MIGraphX build without rocBLAS")
endif()

if(MIGRAPHX_USE_HIPBLASLT)
# hipblaslt
find_package(hipblaslt REQUIRED)
# Making hipblas required to workaround the broken hipblaslt package.
find_package(hipblas REQUIRED)
message(STATUS "MIGraphx build with hipBLAS and hipBLASLt")
else()
message(STATUS "MIGraphX build without hipBLAS and hipBLASLt")
endif()

if(MIGRAPHX_USE_COMPOSABLEKERNEL)
find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library)
endif()
Expand Down Expand Up @@ -147,6 +157,8 @@ add_library(migraphx_gpu
fuse_ops.cpp
gemm_impl.cpp
hip.cpp
hipblaslt.cpp
hip_gemm_impl.cpp
kernel.cpp
lowering.cpp
logsoftmax.cpp
Expand Down Expand Up @@ -217,6 +229,12 @@ if(MIGRAPHX_USE_ROCBLAS)
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
INCLUDES migraphx/gpu/context.hpp)
endif()
if(MIGRAPHX_USE_HIPBLASLT)
register_op(migraphx_gpu
HEADER migraphx/gpu/hip_gemm.hpp
OPERATORS gpu::hip_gemm<op::dot> gpu::hip_gemm<op::quant_dot>
INCLUDES migraphx/gpu/context.hpp)
endif()
if (MIGRAPHX_USE_MIOPEN)
register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::convolution_backwards> gpu::miopen_convolution<op::quant_convolution>
Expand Down Expand Up @@ -305,6 +323,12 @@ else()
target_compile_definitions(migraphx_gpu PUBLIC MIGRAPHX_USE_ROCBLAS=0)
endif()

if(MIGRAPHX_USE_HIPBLASLT)
target_compile_definitions(migraphx_gpu PUBLIC MIGRAPHX_USE_HIPBLASLT=1)
else()
target_compile_definitions(migraphx_gpu PUBLIC MIGRAPHX_USE_HIPBLASLT=0)
endif()

if(MIGRAPHX_USE_MIOPEN)
set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")

Expand Down Expand Up @@ -349,6 +373,10 @@ if(MIGRAPHX_USE_ROCBLAS)
target_link_libraries(migraphx_gpu PUBLIC roc::rocblas)
endif()

if(MIGRAPHX_USE_HIPBLASLT)
target_link_libraries(migraphx_gpu PUBLIC roc::hipblaslt)
endif()

if(WIN32)
# Temporary workaround on rocMLIR not exporting correctly libraries it depends on.
target_link_libraries(migraphx_gpu PRIVATE ntdll)
Expand Down
Loading

0 comments on commit 161b7c6

Please sign in to comment.