diff --git a/CMakeLists.txt b/CMakeLists.txt index 7535d2c8b2f..7fc0267392c 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -130,6 +130,7 @@ rocm_enable_clang_tidy( -bugprone-implicit-widening-of-multiplication-result -bugprone-macro-parentheses -bugprone-signed-char-misuse + -bugprone-unchecked-optional-access # Disable the aliased reserved identifiers -cert-dcl37-c -cert-dcl51-cpp diff --git a/Jenkinsfile b/Jenkinsfile index fda8245417d..7b5b6c65b2a 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -89,6 +89,8 @@ def rocmnodename(name) { node_name = "${rocmtest_name} && vega"; } else if(name == "navi21") { node_name = "${rocmtest_name} && navi21"; + } else if(name == "mi100+") { + node_name = "${rocmtest_name} && (gfx908 || gfx90a)"; } else if(name == "anygpu") { node_name = "${rocmtest_name} && (gfx908 || gfx90a || vega)"; } else if(name == "nogpu") { @@ -120,7 +122,7 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build -> } }, hiprtc_gpu_debug: rocmnode('vega') { cmake_build -> stage('HipRTC GPU Debug') { - cmake_build(flags: "-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_USE_HIPRTC=On", gpu_debug: true, hiprtc_workarounds: true) + cmake_build(flags: "-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_USE_HIPRTC=On", gpu_debug: true, hiprtc_workarounds: true) } }, all_targets_debug : rocmnode('vega') { cmake_build -> stage('All targets Release') { @@ -134,6 +136,12 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build -> cmake_build(flags: "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'") } } +}, ck_release: rocmnode('mi100+') { cmake_build -> + stage('CK Release') { + withEnv(['MIGRAPHX_ENABLE_CK=1', 'MIGRAPHX_TUNE_CK=1']) { + cmake_build(flags: "-DCMAKE_BUILD_TYPE=release") + } + } }, clang_asan: rocmnode('nogpu') { cmake_build -> stage('Clang ASAN') { def sanitizers = "undefined,address" diff --git a/requirements.txt b/requirements.txt index 18962a571dd..8ac67c6a300 100755 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,4 @@ ROCmSoftwarePlatform/half@rocm-5.4.2 pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCmSoftwarePlatform/composable_kernel@ac580f77a84c705c678816ef7195adfcc02bdda5 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On +ROCmSoftwarePlatform/composable_kernel@5172ec5280f14974beee2acf1af1db3b2670244c -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On diff --git a/src/dead_code_elimination.cpp b/src/dead_code_elimination.cpp index 1d84c4fd9a5..772555b3106 100644 --- a/src/dead_code_elimination.cpp +++ b/src/dead_code_elimination.cpp @@ -49,8 +49,10 @@ void dead_code_elimination::apply(module& m) const if(i == last) break; // Skip instruction with empty shape as output unless its [dynamic, builtin, undefined, - // identity, allocate] - if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and + // identity, allocate or tuple_type] + if((not i->get_shape().dynamic() and + (i->get_shape().elements() == 0 and + i->get_shape().type() != migraphx::shape::tuple_type)) and not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and not i->is_undefined()) continue; diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index 2d49f894681..c8428860c9e 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -370,31 +371,30 @@ match::matcher_result find_match(module& modl, M&& m) } MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES_FOR) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES) /// Find matches for an instruction in the module for per section of matchers template -void find_matches(size_t trace_pass, Mod& mod, instruction_ref ins, Ms&&... ms) -{ -#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5 - const -#endif - int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); -#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5 - const -#endif - bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); - bool match = false; +void find_matches_for(source_location location, Mod& mod, instruction_ref ins, Ms&&... ms) +{ + const int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); + const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); + const auto trace_filter = string_value_of(MIGRAPHX_TRACE_MATCHES_FOR{}); + const bool trace_for = not trace_filter.empty() and + (contains(std::string{location.file_name()}, trace_filter) or + contains(std::string{location.function_name()}, trace_filter)); + bool match = false; each_args( [&](auto&& m) { if(match) return; - if(trace > 1 or trace_pass > 1) + if(trace > 1 or trace_for) std::cout << "Match: " << get_type_name(m) << std::endl; auto r = match_instruction(get_module(mod), ins, m.matcher()); if(r.result == get_module(mod).end()) return; - if(trace > 0 or trace_pass > 0) + if(trace > 0 or trace_for) { std::cout << "Matched by " << get_type_name(m) << std::endl; get_module(mod).debug_print(ins); @@ -420,23 +420,19 @@ void find_matches(size_t trace_pass, Mod& mod, instruction_ref ins, Ms&&... ms) /// Find matches in a module template -void find_matches(Mod& mod, Ms&&... ms) +struct find_matches { - for(auto ins : iterator_for(get_module(mod))) + find_matches(Mod& mod, Ms&&... ms, source_location location = source_location::current()) { - find_matches(0, mod, ins, ms...); + for(auto ins : iterator_for(get_module(mod))) + { + find_matches_for(location, mod, ins, ms...); + } } -} +}; -/// Find matches in a pass template -void find_matches(size_t trace_pass, Mod& mod, Ms&&... ms) -{ - for(auto ins : iterator_for(get_module(mod))) - { - find_matches(trace_pass, mod, ins, ms...); - } -} +find_matches(Mod& mod, Ms&&... ms) -> find_matches; template struct find_generic_match diff --git a/src/include/migraphx/source_location.hpp b/src/include/migraphx/source_location.hpp new file mode 100644 index 00000000000..facf9c4fac6 --- /dev/null +++ b/src/include/migraphx/source_location.hpp @@ -0,0 +1,73 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP + +#include + +#if defined(CPPCHECK) +#define MIGRAPHX_HAS_SOURCE_LOCATION 1 +#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 1 +#elif defined(__has_include) +#if __has_include() && __cplusplus >= 202003L +#define MIGRAPHX_HAS_SOURCE_LOCATION 1 +#else +#define MIGRAPHX_HAS_SOURCE_LOCATION 0 +#endif +#if __has_include() && __cplusplus >= 201103L +#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 1 +#else +#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 0 +#endif +#else +#define MIGRAPHX_HAS_SOURCE_LOCATION 0 +#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 0 +#endif + +#if MIGRAPHX_HAS_SOURCE_LOCATION +#include +#elif MIGRAPHX_HAS_SOURCE_LOCATION_TS +#include +#endif + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +#if MIGRAPHX_HAS_SOURCE_LOCATION +using source_location = std::source_location; +#elif MIGRAPHX_HAS_SOURCE_LOCATION_TS +using source_location = std::experimental::source_location; +#else +struct source_location +{ + static constexpr source_location current() noexcept { return source_location{}; } + constexpr std::uint_least32_t line() const noexcept { return 0; } + constexpr std::uint_least32_t column() const noexcept { return 0; } + constexpr const char* file_name() const noexcept { return ""; } + constexpr const char* function_name() const noexcept { return ""; } +}; +#endif + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP diff --git a/src/module.cpp b/src/module.cpp index fe6775a21a9..e7fb920d92b 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -326,6 +326,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref if(ins == std::prev(this->end())) { + // "rep" instruction could be used earlier in the program and moving it at the end + // may cause invalid program, therefore make an identity operation in this case. return replace_instruction(ins, make_op("identity"), rep); } diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index b281e8bcd98..6a3c463df39 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -38,6 +38,9 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_ONNX_PARSER) static shape shape_from_dyn_dims(shape::type_t shape_type, const std::vector& dyn_dims) @@ -53,8 +56,6 @@ static shape shape_from_dyn_dims(shape::type_t shape_type, return {shape_type, dyn_dims}; } -namespace onnx { - static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node) { std::unordered_map result; @@ -297,16 +298,48 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model) return version; } -std::vector -onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining) +void print_added_instructions(module* mod, + const std::vector& args, + const std::vector& result) +{ + // Print instructions added by the parser not in args + std::vector added_instructions; + fix([&](auto self, auto r) { + for(auto ins : r) + { + if(contains(args, ins)) + continue; + if(contains(added_instructions, ins)) + continue; + self(ins->inputs()); + added_instructions.push_back(ins); + } + })(result); + mod->debug_print(added_instructions); +} + +std::unordered_map +parse_intializer(const onnx_parser& parser, module* mod, const onnx::GraphProto& graph) { std::unordered_map mod_insts; for(auto&& f : graph.initializer()) { + if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{})) + std::cout << "initializer: " << f.name() << std::endl; // backup instructions in parent mod - mod_insts[f.name()] = mod->add_literal(parse_tensor(f)); + mod_insts[f.name()] = mod->add_literal(parser.parse_tensor(f)); + if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{})) + mod->debug_print(mod_insts[f.name()]); } + return mod_insts; +} +std::unordered_map +parse_inputs(const onnx_parser& parser, + module* mod, + const onnx::GraphProto& graph, + std::unordered_map mod_insts) +{ for(auto&& input : graph.input()) { const std::string& name = input.name(); @@ -317,7 +350,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini // scenario that a nested subgraph contains a parameter with the // name existed in its parent graph. // In the current implementation, MIGraphX throws an exception for that. - if(contains(instructions, name)) + if(contains(parser.instructions, name)) { MIGRAPHX_THROW("module \"" + mod->name() + "\" has parameter name \"" + name + "\" existing in parent graph!"); @@ -325,28 +358,41 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini shape s; std::vector dims; - if(map_input_dims.count(name) > 0) + if(parser.map_input_dims.count(name) > 0) { - dims = map_input_dims.at(name); - s = parse_type(input.type(), dims); + dims = parser.map_input_dims.at(name); + s = parser.parse_type(input.type(), dims); } - else if(map_dyn_input_dims.count(name) > 0) + else if(parser.map_dyn_input_dims.count(name) > 0) { shape::type_t shape_type = get_type(input.type().tensor_type().elem_type()); - s = shape_from_dyn_dims(shape_type, map_dyn_input_dims.at(name)); + s = shape_from_dyn_dims(shape_type, parser.map_dyn_input_dims.at(name)); } else { - s = parse_type(input.type(), dims); + s = parser.parse_type(input.type(), dims); } mod_insts[name] = mod->add_parameter(name, s); } } + return mod_insts; +} + +std::vector +onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining) +{ + std::unordered_map mod_insts = + parse_intializer(*this, mod, graph); + + mod_insts = parse_inputs(*this, mod, graph, mod_insts); std::copy(mod_insts.begin(), mod_insts.end(), std::inserter(instructions, instructions.end())); for(auto&& node : graph.node()) { + if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{})) + std::cout << "operator: " << node.op_type() << std::endl; + std::vector args; for(auto&& input : node.input()) { @@ -384,6 +430,11 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini result.begin(), std::inserter(instructions, instructions.end()), [](auto&& x, auto&& y) { return std::make_pair(x, y); }); + + if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{})) + { + print_added_instructions(mod, args, result); + } } // Find instructions corresponding to the output diff --git a/src/program.cpp b/src/program.cpp index 7a01a6b88b3..6f42ca903ce 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -366,7 +366,7 @@ void print_statistics(std::ostream& os, const argument& a) os << "Min value: " << *std::min_element(t.begin(), t.end()) << ", "; os << "Max value: " << *std::max_element(t.begin(), t.end()) << ", "; double num_elements = t.size(); - auto mean = std::reduce(t.begin(), t.end(), 0.0) / num_elements; + auto mean = std::accumulate(t.begin(), t.end(), 0.0) / num_elements; auto stddev = std::sqrt( std::accumulate(t.begin(), t.end(), diff --git a/src/quantize_fp16.cpp b/src/quantize_fp16.cpp index c8577b6afe5..b6d8bed4e89 100644 --- a/src/quantize_fp16.cpp +++ b/src/quantize_fp16.cpp @@ -52,14 +52,6 @@ static void quantize_module(module& m, const std::vector& ins_names auto mod_inputs = ins->module_inputs(); auto s = ins->get_shape(); - // Convert back to original type before quantizing the inputs - if(mod_inputs.empty()) - { - auto r = m.insert_instruction( - std::next(ins), make_op("convert", {{"target_type", s.type()}}), ins); - m.replace_instruction(ins, r); - } - // Convert each of the inputs that are floating point to fp16 auto inputs = ins->inputs(); std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { @@ -70,8 +62,17 @@ static void quantize_module(module& m, const std::vector& ins_names ins, make_op("convert", {{"target_type", shape::half_type}}), input); }); - // Replace inputs - m.replace_instruction(ins, ins->get_operator(), inputs, mod_inputs); + // Insert quantized ins + auto converted_ins = m.insert_instruction(ins, ins->get_operator(), inputs, mod_inputs); + + // Convert back to original type after quantizing + if(mod_inputs.empty()) + { + converted_ins = m.insert_instruction( + ins, make_op("convert", {{"target_type", s.type()}}), converted_ins); + } + // Replace original instruction + m.replace_instruction(ins, converted_ins); } } diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 931528f34c2..5bd0226b017 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -39,8 +39,6 @@ #include #include -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES) - namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -1487,13 +1485,10 @@ struct find_split_transpose void simplify_algebra::apply(module& m) const { - size_t trace = value_of(MIGRAPHX_TRACE_SIMPLIFY_ALGEBRA_MATCHES{}); - // Run simplifications multiple times for(int i = 0; i < 8; i++) { - match::find_matches(trace, - m, + match::find_matches(m, find_inner_broadcast{}, find_dot_broadcast{}, find_double_add_lit_broadcast{}, diff --git a/src/targets/gpu/compile_ops.cpp b/src/targets/gpu/compile_ops.cpp index 39dec285b88..6bbe9c21d66 100644 --- a/src/targets/gpu/compile_ops.cpp +++ b/src/targets/gpu/compile_ops.cpp @@ -111,9 +111,27 @@ struct compile_plan context* ctx; operation preop; instruction_ref ins; - optional config = nullopt; - std::vector results = {}; - void update_config() { config = get_tuning_config(*ctx, ins, preop); } + optional config = nullopt; + std::vector> results = {}; + void update_config(bool exhaustive) + { + config = get_tuning_config(*ctx, ins, preop, exhaustive); + } + template + void insert_compiles(Vector& compiles, const value& solution, std::size_t i) + { + compiles.emplace_back([=] { + try + { + results[i] = compiled_result{compile(*ctx, ins, preop, solution), ins}; + } + catch(...) + { + results[i] = nullopt; + } + }); + } + template void add_compiles(Vector& compiles, problem_cache& pc) { @@ -127,9 +145,7 @@ struct compile_plan if(solution.is_null()) return; results.resize(1); - compiles.emplace_back([=] { - results[0] = compiled_result{compile(*ctx, ins, preop, solution), ins}; - }); + insert_compiles(compiles, solution, 0); } else { @@ -139,18 +155,14 @@ struct compile_plan for(auto i : range(solutions.size())) { auto solution = solutions[i]; - compiles.emplace_back([=] { - results[i] = compiled_result{compile(*ctx, ins, preop, solution), ins}; - }); + insert_compiles(compiles, solution, i); } } } else { results.resize(1); - compiles.emplace_back([=] { - results[0] = compiled_result{compile(*ctx, ins, preop, value{}), ins}; - }); + insert_compiles(compiles, value{}, 0); } } const compiled_result& benchmark(problem_cache& pc) const @@ -158,7 +170,11 @@ struct compile_plan if(results.empty()) MIGRAPHX_THROW("No configs to tune"); if(results.size() == 1) - return results.front(); + { + if(not results.front().has_value()) + MIGRAPHX_THROW("No configs to tune"); + return *results.front(); + } if(not config) MIGRAPHX_THROW("Multiple kernels without config"); std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs" @@ -167,11 +183,17 @@ struct compile_plan times.reserve(results.size()); std::transform( results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) { - return time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first; + if(not cr.has_value()) + return std::numeric_limits::max(); + return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20) + .first; }); auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end())); + std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl; pc.insert(preop.name(), config->problem, config->solutions.at(i)); - return results[i]; + if(not results[i].has_value()) + MIGRAPHX_THROW("No valid tuned compilation."); + return *results[i]; } void replace(module& m, problem_cache& pc) const { @@ -185,7 +207,10 @@ void par_compile(std::size_t n, F f) { if(n == 0) return; - par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f); + auto d = value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}); + if(d == 0) + d = n; + par_for(n, n / d, f); } struct compile_manager @@ -202,9 +227,7 @@ struct compile_manager void update_configs() { - if(not exhaustive) - return; - par_compile(cps.size(), [&](auto i) { cps[i].update_config(); }); + par_compile(cps.size(), [&](auto i) { cps[i].update_config(exhaustive); }); } void compile(module& m) diff --git a/src/targets/gpu/compiler.cpp b/src/targets/gpu/compiler.cpp index 1c3a3971e24..5a4c5d702c6 100644 --- a/src/targets/gpu/compiler.cpp +++ b/src/targets/gpu/compiler.cpp @@ -63,9 +63,10 @@ compile_op(const std::string& name, context& ctx, const std::vector& inpu return compiler_map().at(name).compile_op(ctx, inputs, v); } -optional get_tuning_config(context& ctx, instruction_ref ins, const operation& op) +optional +get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive) { - return compiler_map().at(op.name()).get_tuning_config(ctx, ins, op); + return compiler_map().at(op.name()).get_tuning_config(ctx, ins, op, exhaustive); } } // namespace gpu diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp index 64296273e32..fc3b3e773c8 100644 --- a/src/targets/gpu/fuse_ck.cpp +++ b/src/targets/gpu/fuse_ck.cpp @@ -83,10 +83,23 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) return false; auto a = ins->inputs().front()->get_shape(); auto b = ins->inputs().back()->get_shape(); + auto m = a.lens()[a.lens().size() - 2]; + auto n = b.lens().back(); + auto k = a.lens().back(); + // Integer gemms must be divisible by 4 in ck + if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type())) + { + if(m % 4 != 0) + return false; + if(n % 4 != 0) + return false; + if(k % 4 != 0) + return false; + } // Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy // to avoid poor-performing GEMM kernels from CK // To-do: Investigate a more precise strategy - return a.lens().back() <= 2048; + return k <= 2048; } struct find_ck_gemm_pointwise diff --git a/src/targets/gpu/include/migraphx/gpu/compiler.hpp b/src/targets/gpu/include/migraphx/gpu/compiler.hpp index 1aed5b84a44..d9d2262070d 100644 --- a/src/targets/gpu/include/migraphx/gpu/compiler.hpp +++ b/src/targets/gpu/include/migraphx/gpu/compiler.hpp @@ -79,7 +79,7 @@ using compiler_compile = using compiler_compile_op = std::function& inputs, const value&)>; using compiler_tuning_config = - std::function(context&, instruction_ref, const operation&)>; + std::function(context&, instruction_ref, const operation&, bool)>; void register_compiler(const std::string& name, compiler_compile c, @@ -91,7 +91,8 @@ compiler_replace compile(context& ctx, instruction_ref ins, const operation& op, const value& solution); operation compile_op(const std::string& name, context& ctx, const std::vector& inputs, const value& v); -optional get_tuning_config(context& ctx, instruction_ref ins, const operation& op); +optional +get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive); template void register_compiler() @@ -125,7 +126,8 @@ template struct compiler : auto_register_compiler { const Derived& derived() const { return static_cast(*this); } - optional get_tuning_config(context&, instruction_ref, const operation&) const + optional + get_tuning_config(context&, instruction_ref, const operation&, bool) const { return nullopt; } diff --git a/src/targets/gpu/jit/ck_gemm.cpp b/src/targets/gpu/jit/ck_gemm.cpp index 2bf084563f6..ac931ecebd1 100644 --- a/src/targets/gpu/jit/ck_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm.cpp @@ -50,6 +50,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING_VALUE); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK); // NOLINTNEXTLINE static const char* const ck_gemm_kernel = R"__migraphx__( @@ -265,7 +266,7 @@ struct ck_gemm_compiler : compiler s = shape{s.type(), {m1, m2}}; } - std::vector names() const { return {"gpu::ck_gemm"}; } + std::vector names() const { return {"ck_gemm", "gpu::ck_gemm"}; } static bool standard_batch(const shape& s) { @@ -418,9 +419,7 @@ struct ck_gemm_compiler : compiler { auto shapes = to_shapes(ins->inputs()); auto v = create_settings(ins, op); - if(solution.is_null()) - v["tuning_value"] = 4; - else + if(not solution.is_null()) v["tuning_value"] = solution; return {compile_op(ctx, shapes, v), [=](module& m, instruction_ref ins2, const operation& code_object) { @@ -436,8 +435,10 @@ struct ck_gemm_compiler : compiler } optional - get_tuning_config(context& ctx, instruction_ref ins, const operation& op) const + get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive) const { + if(not exhaustive and not enabled(MIGRAPHX_TUNE_CK{})) + return nullopt; tuning_config tc; auto shapes = to_shapes(ins->inputs()); auto problem = create_problem(shapes, create_settings(ins, op)); diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp index 87bd4e58985..fb032ca7e96 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp @@ -52,7 +52,7 @@ __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds) ck::make_tuple(to_ck_tensor()...), to_ck_tensor()); - static_assert(desc.is_valid, "Invalid ck gemm."); + static_assert(desc.IsValid(), "Invalid ck gemm."); G::Run(desc, to_ck_const_pointer(a.data()), diff --git a/test/dead_code_elimination_test.cpp b/test/dead_code_elimination_test.cpp index 673b628391a..53a93ac7077 100644 --- a/test/dead_code_elimination_test.cpp +++ b/test/dead_code_elimination_test.cpp @@ -232,7 +232,6 @@ TEST_CASE(reused_twice) auto count = std::distance(mm->begin(), mm->end()); run_pass(p); - p.debug_print(); EXPECT(std::distance(mm->begin(), mm->end()) != count); EXPECT(std::distance(mm->begin(), mm->end()) == 4); } @@ -274,4 +273,17 @@ TEST_CASE(param_not_eliminated) EXPECT(p == create_program()); } +TEST_CASE(tuple_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(tuple_op{}, one, two); + mm->add_return({one, two}); + auto count = std::distance(mm->begin(), mm->end()); + run_pass(p); + EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1)); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/include/basic_ops.hpp b/test/include/basic_ops.hpp index 9a0d7236046..836197358a6 100644 --- a/test/include/basic_ops.hpp +++ b/test/include/basic_ops.hpp @@ -186,6 +186,21 @@ struct nop migraphx::shape compute_shape(const std::vector&) const { return {}; } }; +struct tuple_op +{ + std::string name() const { return "tuple_op"; } + migraphx::shape compute_shape(const std::vector& inputs) const + { + return {inputs}; + } + migraphx::argument compute(migraphx::context&, + const migraphx::shape&, + const std::vector& input_args) const + { + return input_args; + } +}; + inline migraphx::literal get_2x2(int base = 0) { return migraphx::literal{{migraphx::shape::float_type, {2, 2}}, diff --git a/test/quantization.cpp b/test/quantization.cpp index 66b5062486f..6fac7443c35 100644 --- a/test/quantization.cpp +++ b/test/quantization.cpp @@ -82,13 +82,17 @@ TEST_CASE(param_add) auto hp1 = mm->add_instruction(migraphx::make_op("convert"), p1); auto hp2 = mm->add_instruction(migraphx::make_op("convert"), p2); auto hs = mm->add_instruction(migraphx::make_op("add"), hp1, hp2); - auto res = mm->add_instruction( + auto fs = mm->add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), hs); if(add_return) { - mm->add_return({res}); + mm->add_return({fs}); + } + else + { + mm->add_instruction(migraphx::make_op("identity"), {fs}); } return p; @@ -159,10 +163,10 @@ TEST_CASE(param_add_sub) auto diff = mm->add_instruction(migraphx::make_op("sub"), sum, p2); auto hdiff = mm->add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), diff); - auto res = mm->add_instruction(migraphx::make_op("add"), hdiff, hp1); - auto r = mm->add_instruction( - migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), res); - mm->add_return({r}); + auto hadd = mm->add_instruction(migraphx::make_op("add"), hdiff, hp1); + auto fadd = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), hadd); + mm->add_return({fadd}); return p; }; @@ -258,7 +262,8 @@ TEST_CASE(param_add_sub) }; auto p0 = create_program_float(); - migraphx::run_passes(p0, {migraphx::quantize_fp16_pass{{"all"}}}); + migraphx::run_passes( + p0, {migraphx::quantize_fp16_pass{{"all"}}, migraphx::dead_code_elimination{}}); EXPECT(p0 == create_program_fp16()); auto p1 = create_program_float(); @@ -278,7 +283,6 @@ TEST_CASE(literal_add) auto l1 = mm->add_literal(migraphx::literal(s, data)); auto l2 = mm->add_literal(migraphx::literal(s, data)); mm->add_instruction(migraphx::make_op("add"), l1, l2); - return p; }; @@ -291,11 +295,11 @@ TEST_CASE(literal_add) auto l1 = mm->add_literal(migraphx::literal(s, data)); auto l2 = mm->add_literal(migraphx::literal(s, data)); auto hs = mm->add_instruction(migraphx::make_op("add"), l1, l2); - mm->add_instruction( + auto fs = mm->add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), hs); - + mm->add_instruction(migraphx::make_op("identity"), fs); return p; }; diff --git a/test/verify/gemm_add_broadcast_half.cpp b/test/verify/gemm_add_broadcast_half.cpp new file mode 100644 index 00000000000..fb1918b1715 --- /dev/null +++ b/test/verify/gemm_add_broadcast_half.cpp @@ -0,0 +1,49 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include +#include +struct gemm_add_broadcast_half : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::half_type, {1, 2, 3}}; + migraphx::shape m2_shape{migraphx::shape::half_type, {1, 3, 4}}; + migraphx::shape m3_shape{migraphx::shape::half_type, {1, 1, 4}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto l2 = mm->add_parameter("2", m2_shape); + auto l3 = mm->add_parameter("3", m3_shape); + auto l3_b = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 2, 4}}}), l3); + + auto dot = mm->add_instruction(migraphx::make_op("dot"), l1, l2); + mm->add_instruction(migraphx::make_op("add"), dot, l3_b); + return p; + } +}; diff --git a/test/verify/gemm_add_half.cpp b/test/verify/gemm_add_half.cpp new file mode 100644 index 00000000000..168fc853e6e --- /dev/null +++ b/test/verify/gemm_add_half.cpp @@ -0,0 +1,47 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include +#include +struct gemm_add_half : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape m1_shape{migraphx::shape::half_type, {1, 2, 3}}; + migraphx::shape m2_shape{migraphx::shape::half_type, {1, 3, 4}}; + migraphx::shape m3_shape{migraphx::shape::half_type, {1, 2, 4}}; + auto l1 = mm->add_parameter("1", m1_shape); + auto l2 = mm->add_parameter("2", m2_shape); + auto l3 = mm->add_parameter("3", m3_shape); + + auto dot = mm->add_instruction(migraphx::make_op("dot"), l1, l2); + mm->add_instruction(migraphx::make_op("add"), dot, l3); + return p; + } +};