From f1b4c0082ddfaaf3b0785d368f7b2a26853bc41e Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Tue, 6 Feb 2024 15:11:58 -0800 Subject: [PATCH 01/10] test --- .../core/providers/dml/dml_provider_factory.h | 2 + .../core/optimizer/graph_transformer.cc | 3 +- .../core/optimizer/graph_transformer_utils.cc | 285 +++++++++--------- winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp | 4 +- 4 files changed, 152 insertions(+), 142 deletions(-) diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h index 7d7f05193f486..23cf551f3d997 100644 --- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h +++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h @@ -26,6 +26,8 @@ typedef struct IDMLDevice IDMLDevice; #include "onnxruntime_c_api.h" +#define ENABLE_NPU_ADAPTER_ENUMERATION + #ifdef __cplusplus extern "C" { #endif diff --git a/onnxruntime/core/optimizer/graph_transformer.cc b/onnxruntime/core/optimizer/graph_transformer.cc index 37093496a66fa..9b5d005cb9c3f 100644 --- a/onnxruntime/core/optimizer/graph_transformer.cc +++ b/onnxruntime/core/optimizer/graph_transformer.cc @@ -12,7 +12,8 @@ Status GraphTransformer::Apply(Graph& graph, bool& modified, const logging::Logg // ORT_RETURN_IF_ERROR(graph.Resolve()); auto status = ApplyImpl(graph, modified, 0, logger); - LOGS(logger, INFO) << "GraphTransformer " << Name() << " modified: " << modified << " with status: " << status; + LOGS(logger, INFO) << "[" << std::chrono::duration_cast(std::chrono::high_resolution_clock::now().time_since_epoch()).count() + << "] GraphTransformer " << Name() << " modified: " << modified << " with status: " << status; ORT_RETURN_IF_ERROR(status); #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index cd3c49be15aa4..f4fb262612698 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -85,6 +85,8 @@ #endif // !defined(ORT_MINIMAL_BUILD) +#define UNREFERENCED_PARAMETER(P) (P) + namespace onnxruntime::optimizer_utils { static void FilterTransformers(InlinedVector>& transformers, @@ -194,6 +196,12 @@ InlinedVector> GenerateTransformers( const InlinedHashSet dml_ep = {onnxruntime::kDmlExecutionProvider}; AllocatorPtr cpu_allocator = std::make_shared(); + UNREFERENCED_PARAMETER(disable_quant_qdq); + UNREFERENCED_PARAMETER(cpu_ep); + UNREFERENCED_PARAMETER(dml_ep); + UNREFERENCED_PARAMETER(session_options); + UNREFERENCED_PARAMETER(cpu_execution_provider); + switch (level) { case TransformerLevel::Level1: { // RewriteRule optimizations are the simplest (they generally remove unnecessary nodes and are cheap to run) @@ -214,13 +222,13 @@ InlinedVector> GenerateTransformers( // Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for // CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by // default, CSE will not merge them, because the different initializers are represented by different NodeArg. - InlinedHashSet excluded_initializers; - excluded_initializers.reserve(session_options.initializers_to_share_map.size()); - for (const auto& p : session_options.initializers_to_share_map) { - excluded_initializers.insert(p.first); - } - const InlinedHashSet no_limit_empty_ep_list = {}; - transformers.emplace_back(std::make_unique(no_limit_empty_ep_list, excluded_initializers)); + // InlinedHashSet excluded_initializers; + // excluded_initializers.reserve(session_options.initializers_to_share_map.size()); + // for (const auto& p : session_options.initializers_to_share_map) { + // excluded_initializers.insert(p.first); + // } + // const InlinedHashSet no_limit_empty_ep_list = {}; + // transformers.emplace_back(std::make_unique(no_limit_empty_ep_list, excluded_initializers)); transformers.emplace_back(std::make_unique()); transformers.emplace_back(std::make_unique(cpu_execution_provider, !disable_quant_qdq, @@ -240,144 +248,145 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique()); } - // add __backwardpass attribute to nodes after YieldOp, ROCm-only - const InlinedHashSet rocm_ep = {onnxruntime::kRocmExecutionProvider}; - transformers.emplace_back(std::make_unique(rocm_ep)); - // run TransposeOptimizer last as it works in a slightly different way by moving Transpose nodes around. // shouldn't affect the end result - just easier to debug any issue if it's last. transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); - } break; - - case TransformerLevel::Level2: { - // we run TransposeOptimizer again in Level2 for some CPU EP specific optimizations that can only be - // applied once nodes are assigned to the CPU EP (which happens between level 1 and level 2). - transformers.emplace_back(std::make_unique(std::move(cpu_allocator), kCpuExecutionProvider)); - - const bool enable_quant_qdq_cleanup = - session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableQuantQDQCleanup, "0") == "1"; -#if !defined(DISABLE_CONTRIB_OPS) - const bool qdq_is_int8_allowed = - session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQIsInt8Allowed, - QDQIsInt8Allowed() ? "1" : "0") == "1"; - const bool enable_gelu_approximation = - session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableGeluApproximation, "0") == "1"; - - const InlinedHashSet cuda_rocm_eps = {onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider}; - const InlinedHashSet cpu_cuda_rocm_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider}; - const InlinedHashSet cpu_cuda_dml_rocm_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider, - onnxruntime::kDmlExecutionProvider}; - const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider, - onnxruntime::kAclExecutionProvider, - onnxruntime::kArmNNExecutionProvider, - onnxruntime::kJsExecutionProvider}; - -#ifdef MLAS_TARGET_AMD64_IX86 - const bool avx2_precision_mode = - session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow(); -#else - const bool avx2_precision_mode = false; -#endif - if (!disable_quant_qdq) { - // currently we don't support QDQS8ToU8Transformer in a minimal build and if supported, this needs to run in - // Level 1 during export and not Level 2 at runtime as it would result in overlapping optimizations which - // runtime optimization does not support, so add session config value here to force qdqisint8allowed to be true. - if (!qdq_is_int8_allowed) { - transformers.emplace_back(std::make_unique(avx2_precision_mode, cpu_ep)); - } - transformers.emplace_back(std::make_unique(qdq_is_int8_allowed)); - } - - transformers.emplace_back(std::make_unique(cpu_ep)); - transformers.emplace_back(std::make_unique(cpu_ep)); - transformers.emplace_back(std::make_unique(cpu_ep)); - - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_js_eps)); - - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - // GeluApproximation has side effects which may change results. It needs to be manually enabled, - // or alternatively the model can be updated offline using a model conversion script - // e.g. fusion_gelu_approximation function used by onnxruntime/python/tools/transformers/onnx_model_bert.py - if (enable_gelu_approximation) { - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - } - -#ifdef ENABLE_TRITON - if (training::framework::triton::TritonOpExecutor::Instance().IsInitialized()) { - transformers.emplace_back( - std::make_unique(training::framework::triton::TritonOpExecutor::Instance().GetConfigJson(), - InlinedHashSet{onnxruntime::kCudaExecutionProvider})); - } -#endif // ENABLE_TRITON - - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cuda_rocm_eps)); -#ifdef ENABLE_TRAINING - transformers.emplace_back(std::make_unique(cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); -#endif - - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(dml_ep)); - -#ifdef MLAS_TARGET_AMD64_IX86 - if (avx2_precision_mode) { - transformers.emplace_back(std::make_unique(cpu_ep)); - } -#endif - -#endif // !defined(DISABLE_CONTRIB_OPS) - // The QDQFinalCleanupTransformer must run AFTER other transformers that fuse Q/DQ nodes. Otherwise, their - // fusions might be prevented if this one removes a Q/DQ node too early. - transformers.emplace_back(std::make_unique(enable_quant_qdq_cleanup)); + // add __backwardpass attribute to nodes after YieldOp, ROCm-only + // const InlinedHashSet rocm_ep = {onnxruntime::kRocmExecutionProvider}; + // transformers.emplace_back(std::make_unique(rocm_ep)); } break; - case TransformerLevel::Level3: { -#ifndef DISABLE_CONTRIB_OPS - // Register the NCHWc layout transformer if supported by the platform. - if (MlasNchwcGetBlockSize() > 1) { - transformers.emplace_back(std::make_unique()); - } - - auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); - auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); - if (nhwc_transformer->IsActive()) { - transformers.emplace_back(std::move(nhwc_transformer)); - } - - // NchwcTransformer must have a higher priority than ConvAddActivationFusion. NchwcTransformer does similar - // fusions targeting CPU but also reorders the layout to NCHWc which is expected to be more efficient but is - // only available on x86-64. - // PR #6351 implemented similar fusion-pattern for CUDA only, and can only fuse conv-add-relu, - // while we can fuse more activation. - transformers.emplace_back(std::make_unique(cpu_ep)); -#endif - - } break; +// case TransformerLevel::Level2: { +// // we run TransposeOptimizer again in Level2 for some CPU EP specific optimizations that can only be +// // applied once nodes are assigned to the CPU EP (which happens between level 1 and level 2). +// transformers.emplace_back(std::make_unique(std::move(cpu_allocator), kCpuExecutionProvider)); + +// const bool enable_quant_qdq_cleanup = +// session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableQuantQDQCleanup, "0") == "1"; +// #if !defined(DISABLE_CONTRIB_OPS) +// const bool qdq_is_int8_allowed = +// session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQIsInt8Allowed, +// QDQIsInt8Allowed() ? "1" : "0") == "1"; +// const bool enable_gelu_approximation = +// session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableGeluApproximation, "0") == "1"; + +// const InlinedHashSet cuda_rocm_eps = {onnxruntime::kCudaExecutionProvider, +// onnxruntime::kRocmExecutionProvider}; +// const InlinedHashSet cpu_cuda_rocm_eps = {onnxruntime::kCpuExecutionProvider, +// onnxruntime::kCudaExecutionProvider, +// onnxruntime::kRocmExecutionProvider}; +// const InlinedHashSet cpu_cuda_dml_rocm_eps = {onnxruntime::kCpuExecutionProvider, +// onnxruntime::kCudaExecutionProvider, +// onnxruntime::kRocmExecutionProvider, +// onnxruntime::kDmlExecutionProvider}; +// const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, +// onnxruntime::kCudaExecutionProvider, +// onnxruntime::kRocmExecutionProvider, +// onnxruntime::kAclExecutionProvider, +// onnxruntime::kArmNNExecutionProvider, +// onnxruntime::kJsExecutionProvider}; + +// #ifdef MLAS_TARGET_AMD64_IX86 +// const bool avx2_precision_mode = +// session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow(); +// #else +// const bool avx2_precision_mode = false; +// #endif +// if (!disable_quant_qdq) { +// // currently we don't support QDQS8ToU8Transformer in a minimal build and if supported, this needs to run in +// // Level 1 during export and not Level 2 at runtime as it would result in overlapping optimizations which +// // runtime optimization does not support, so add session config value here to force qdqisint8allowed to be true. +// if (!qdq_is_int8_allowed) { +// transformers.emplace_back(std::make_unique(avx2_precision_mode, cpu_ep)); +// } +// transformers.emplace_back(std::make_unique(qdq_is_int8_allowed)); +// } + +// transformers.emplace_back(std::make_unique(cpu_ep)); +// transformers.emplace_back(std::make_unique(cpu_ep)); +// transformers.emplace_back(std::make_unique(cpu_ep)); + +// transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_js_eps)); + +// //transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); +// //transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); +// transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); +// //transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); +// //transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); +// transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); +// transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + +// //transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); +// //transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + +// transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + +// transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); +// //transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + +// // GeluApproximation has side effects which may change results. It needs to be manually enabled, +// // or alternatively the model can be updated offline using a model conversion script +// // e.g. fusion_gelu_approximation function used by onnxruntime/python/tools/transformers/onnx_model_bert.py +// if (enable_gelu_approximation) { +// transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); +// } + +// #ifdef ENABLE_TRITON +// if (training::framework::triton::TritonOpExecutor::Instance().IsInitialized()) { +// transformers.emplace_back( +// std::make_unique(training::framework::triton::TritonOpExecutor::Instance().GetConfigJson(), +// InlinedHashSet{onnxruntime::kCudaExecutionProvider})); +// } +// #endif // ENABLE_TRITON + +// transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); +// transformers.emplace_back(std::make_unique(cuda_rocm_eps)); +// #ifdef ENABLE_TRAINING +// transformers.emplace_back(std::make_unique(cuda_rocm_eps)); +// transformers.emplace_back(std::make_unique(cuda_rocm_eps)); +// transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); +// #endif + +// //transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); +// //transformers.emplace_back(std::make_unique(dml_ep)); + +// #ifdef MLAS_TARGET_AMD64_IX86 +// if (avx2_precision_mode) { +// transformers.emplace_back(std::make_unique(cpu_ep)); +// } +// #endif + +// #endif // !defined(DISABLE_CONTRIB_OPS) +// // The QDQFinalCleanupTransformer must run AFTER other transformers that fuse Q/DQ nodes. Otherwise, their +// // fusions might be prevented if this one removes a Q/DQ node too early. +// transformers.emplace_back(std::make_unique(enable_quant_qdq_cleanup)); + +// } break; + +// case TransformerLevel::Level3: { +// #ifndef DISABLE_CONTRIB_OPS +// // Register the NCHWc layout transformer if supported by the platform. +// if (MlasNchwcGetBlockSize() > 1) { +// transformers.emplace_back(std::make_unique()); +// } + +// // auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); +// // auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); +// // if (nhwc_transformer->IsActive()) { +// // transformers.emplace_back(std::move(nhwc_transformer)); +// // } + +// // NchwcTransformer must have a higher priority than ConvAddActivationFusion. NchwcTransformer does similar +// // fusions targeting CPU but also reorders the layout to NCHWc which is expected to be more efficient but is +// // only available on x86-64. +// // PR #6351 implemented similar fusion-pattern for CUDA only, and can only fuse conv-add-relu, +// // while we can fuse more activation. +// //transformers.emplace_back(std::make_unique(cpu_ep)); +// #endif + + // } break; default: ORT_THROW("Unsupported optimization level: ", static_cast(level)); diff --git a/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp b/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp index 56a360218fa1d..6cc40f14cc2e3 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp @@ -164,9 +164,7 @@ static void __stdcall WinmlOrtLoggingCallback( ); } - if (debug_output_) { - OutputDebugStringA((std::string(message) + "\r\n").c_str()); - } + printf((std::string(message) + "\r\n").c_str()); } static void __stdcall WinmlOrtProfileEventCallback(const OrtProfilerEventRecord* profiler_record) noexcept { From a10768964daab7fca5af633e87bfc3662374d71c Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Thu, 8 Feb 2024 11:50:46 -0800 Subject: [PATCH 02/10] improve session creation performance --- onnxruntime/core/graph/graph.cc | 40 ++- onnxruntime/core/graph/graph_viewer.cc | 18 +- .../core/optimizer/graph_transformer_utils.cc | 276 +++++++++--------- .../ort_optimizer_api_impl.cc | 2 +- winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp | 4 +- 5 files changed, 184 insertions(+), 156 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 902839bee04ba..cf4eb2407f979 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1818,24 +1818,43 @@ void Graph::ReverseDFSFrom(gsl::span from, } } +struct PQ { + std::list list_; + const std::function& comparator_ = nullptr; + PQ(const std::function& comp) : + comparator_(comp) + {} + + void push_back(const Node* node) { + list_.push_back(node); + for (int i = 0; i < log(list_.size()); i++) { + comparator_(node, list_.front()); + } + } + bool empty() { return list_.empty();} + const Node* front(){ return list_.front(); } + void pop_front(){ list_.pop_front(); } +}; + #if !defined(ORT_MINIMAL_BUILD) void Graph::KahnsTopologicalSort(const std::function& enter, const std::function& comp) const { - std::unordered_map in_degree; - std::priority_queue, decltype(comp)> to_visit(comp); - std::vector topo_order; + InlinedVector in_degree(MaxNodeIndex(), 0); + //std::priority_queue, decltype(comp)> to_visit(comp); + PQ to_visit(comp); + InlinedVector topo_order; for (auto& node : Nodes()) { size_t input_edge_count = node.GetInputEdgesCount(); - in_degree.insert({node.Index(), input_edge_count}); + in_degree[node.Index()] = input_edge_count; if (input_edge_count == 0) { - to_visit.push(&node); + to_visit.push_back(&node); } } while (!to_visit.empty()) { - const Node* current = to_visit.top(); - to_visit.pop(); + const Node* current = to_visit.front(); + to_visit.pop_front(); if (!current) continue; @@ -1844,10 +1863,11 @@ void Graph::KahnsTopologicalSort(const std::function& enter, } for (auto node_it = current->OutputNodesBegin(); node_it != current->OutputNodesEnd(); ++node_it) { - in_degree[node_it->Index()]--; + auto& node_in_degree = in_degree[node_it->Index()]; + node_in_degree--; - if (in_degree[node_it->Index()] == 0) { - to_visit.push(&*node_it); + if (node_in_degree == 0) { + to_visit.push_back(&*node_it); } } topo_order.push_back(current->Index()); diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index acf7b3a16541f..119d420066a84 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -14,8 +14,8 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const { struct PriorityNodeCompare { inline bool IsHighPri(const Node* n) const { // local statics so we can compare std::strings in the checks - static const std::string shape_op("Shape"); - static const std::string size_op("Size"); + static constexpr std::string_view shape_op("Shape"); + static constexpr std::string_view size_op("Size"); const auto& op_type = n->OpType(); return op_type == shape_op || op_type == size_op; @@ -26,15 +26,20 @@ struct PriorityNodeCompare { // If return true, n2 will be output first bool operator()(const Node* n1, const Node* n2) const { // nodes in global high priority list will be output first - if (IsHighPri(n1) != IsHighPri(n2)) { - return IsHighPri(n2); + const bool isN1HighPri = IsHighPri(n1); + const bool isN2HighPri = IsHighPri(n2); + if (isN1HighPri != isN2HighPri) { + return isN2HighPri; } // nodes with lower priority value will be output first - if (n1->Priority() != n2->Priority()) { - return n1->Priority() > n2->Priority(); + const auto n1_priority = n1->Priority(); + const auto n2_priority = n2->Priority(); + if (n1_priority != n2_priority) { + return n1_priority > n2_priority; } +#ifdef ENABLE_TRAINING // nodes of forward pass will be output first auto n1_attrs = n1->GetAttributes(); auto n2_attrs = n2->GetAttributes(); @@ -45,6 +50,7 @@ struct PriorityNodeCompare { if (n1_is_forward != n2_is_forward) { return n2_is_forward > n1_is_forward; } +#endif // otherwise, nodes with lower index will be output first return n1->Index() > n2->Index(); diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index f4fb262612698..0b200bbe16348 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -222,13 +222,13 @@ InlinedVector> GenerateTransformers( // Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for // CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by // default, CSE will not merge them, because the different initializers are represented by different NodeArg. - // InlinedHashSet excluded_initializers; - // excluded_initializers.reserve(session_options.initializers_to_share_map.size()); - // for (const auto& p : session_options.initializers_to_share_map) { - // excluded_initializers.insert(p.first); - // } - // const InlinedHashSet no_limit_empty_ep_list = {}; - // transformers.emplace_back(std::make_unique(no_limit_empty_ep_list, excluded_initializers)); + InlinedHashSet excluded_initializers; + excluded_initializers.reserve(session_options.initializers_to_share_map.size()); + for (const auto& p : session_options.initializers_to_share_map) { + excluded_initializers.insert(p.first); + } + const InlinedHashSet no_limit_empty_ep_list = {}; + transformers.emplace_back(std::make_unique(no_limit_empty_ep_list, excluded_initializers)); transformers.emplace_back(std::make_unique()); transformers.emplace_back(std::make_unique(cpu_execution_provider, !disable_quant_qdq, @@ -253,140 +253,140 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); // add __backwardpass attribute to nodes after YieldOp, ROCm-only - // const InlinedHashSet rocm_ep = {onnxruntime::kRocmExecutionProvider}; - // transformers.emplace_back(std::make_unique(rocm_ep)); + const InlinedHashSet rocm_ep = {onnxruntime::kRocmExecutionProvider}; + transformers.emplace_back(std::make_unique(rocm_ep)); + + } break; + + case TransformerLevel::Level2: { + // we run TransposeOptimizer again in Level2 for some CPU EP specific optimizations that can only be + // applied once nodes are assigned to the CPU EP (which happens between level 1 and level 2). + transformers.emplace_back(std::make_unique(std::move(cpu_allocator), kCpuExecutionProvider)); + + const bool enable_quant_qdq_cleanup = + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableQuantQDQCleanup, "0") == "1"; +#if !defined(DISABLE_CONTRIB_OPS) + const bool qdq_is_int8_allowed = + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQIsInt8Allowed, + QDQIsInt8Allowed() ? "1" : "0") == "1"; + const bool enable_gelu_approximation = + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableGeluApproximation, "0") == "1"; + + const InlinedHashSet cuda_rocm_eps = {onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider}; + const InlinedHashSet cpu_cuda_rocm_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider}; + const InlinedHashSet cpu_cuda_dml_rocm_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider, + onnxruntime::kDmlExecutionProvider}; + const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kArmNNExecutionProvider, + onnxruntime::kJsExecutionProvider}; + +#ifdef MLAS_TARGET_AMD64_IX86 + const bool avx2_precision_mode = + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow(); +#else + const bool avx2_precision_mode = false; +#endif + if (!disable_quant_qdq) { + // currently we don't support QDQS8ToU8Transformer in a minimal build and if supported, this needs to run in + // Level 1 during export and not Level 2 at runtime as it would result in overlapping optimizations which + // runtime optimization does not support, so add session config value here to force qdqisint8allowed to be true. + if (!qdq_is_int8_allowed) { + transformers.emplace_back(std::make_unique(avx2_precision_mode, cpu_ep)); + } + transformers.emplace_back(std::make_unique(qdq_is_int8_allowed)); + } + + transformers.emplace_back(std::make_unique(cpu_ep)); + transformers.emplace_back(std::make_unique(cpu_ep)); + transformers.emplace_back(std::make_unique(cpu_ep)); + + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_js_eps)); + + transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + + transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + + transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + + // GeluApproximation has side effects which may change results. It needs to be manually enabled, + // or alternatively the model can be updated offline using a model conversion script + // e.g. fusion_gelu_approximation function used by onnxruntime/python/tools/transformers/onnx_model_bert.py + if (enable_gelu_approximation) { + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + } + +#ifdef ENABLE_TRITON + if (training::framework::triton::TritonOpExecutor::Instance().IsInitialized()) { + transformers.emplace_back( + std::make_unique(training::framework::triton::TritonOpExecutor::Instance().GetConfigJson(), + InlinedHashSet{onnxruntime::kCudaExecutionProvider})); + } +#endif // ENABLE_TRITON + + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cuda_rocm_eps)); +#ifdef ENABLE_TRAINING + transformers.emplace_back(std::make_unique(cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); +#endif + + transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(dml_ep)); + +#ifdef MLAS_TARGET_AMD64_IX86 + if (avx2_precision_mode) { + transformers.emplace_back(std::make_unique(cpu_ep)); + } +#endif + +#endif // !defined(DISABLE_CONTRIB_OPS) + // The QDQFinalCleanupTransformer must run AFTER other transformers that fuse Q/DQ nodes. Otherwise, their + // fusions might be prevented if this one removes a Q/DQ node too early. + transformers.emplace_back(std::make_unique(enable_quant_qdq_cleanup)); } break; -// case TransformerLevel::Level2: { -// // we run TransposeOptimizer again in Level2 for some CPU EP specific optimizations that can only be -// // applied once nodes are assigned to the CPU EP (which happens between level 1 and level 2). -// transformers.emplace_back(std::make_unique(std::move(cpu_allocator), kCpuExecutionProvider)); - -// const bool enable_quant_qdq_cleanup = -// session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableQuantQDQCleanup, "0") == "1"; -// #if !defined(DISABLE_CONTRIB_OPS) -// const bool qdq_is_int8_allowed = -// session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQIsInt8Allowed, -// QDQIsInt8Allowed() ? "1" : "0") == "1"; -// const bool enable_gelu_approximation = -// session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableGeluApproximation, "0") == "1"; - -// const InlinedHashSet cuda_rocm_eps = {onnxruntime::kCudaExecutionProvider, -// onnxruntime::kRocmExecutionProvider}; -// const InlinedHashSet cpu_cuda_rocm_eps = {onnxruntime::kCpuExecutionProvider, -// onnxruntime::kCudaExecutionProvider, -// onnxruntime::kRocmExecutionProvider}; -// const InlinedHashSet cpu_cuda_dml_rocm_eps = {onnxruntime::kCpuExecutionProvider, -// onnxruntime::kCudaExecutionProvider, -// onnxruntime::kRocmExecutionProvider, -// onnxruntime::kDmlExecutionProvider}; -// const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, -// onnxruntime::kCudaExecutionProvider, -// onnxruntime::kRocmExecutionProvider, -// onnxruntime::kAclExecutionProvider, -// onnxruntime::kArmNNExecutionProvider, -// onnxruntime::kJsExecutionProvider}; - -// #ifdef MLAS_TARGET_AMD64_IX86 -// const bool avx2_precision_mode = -// session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow(); -// #else -// const bool avx2_precision_mode = false; -// #endif -// if (!disable_quant_qdq) { -// // currently we don't support QDQS8ToU8Transformer in a minimal build and if supported, this needs to run in -// // Level 1 during export and not Level 2 at runtime as it would result in overlapping optimizations which -// // runtime optimization does not support, so add session config value here to force qdqisint8allowed to be true. -// if (!qdq_is_int8_allowed) { -// transformers.emplace_back(std::make_unique(avx2_precision_mode, cpu_ep)); -// } -// transformers.emplace_back(std::make_unique(qdq_is_int8_allowed)); -// } - -// transformers.emplace_back(std::make_unique(cpu_ep)); -// transformers.emplace_back(std::make_unique(cpu_ep)); -// transformers.emplace_back(std::make_unique(cpu_ep)); - -// transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_js_eps)); - -// //transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); -// //transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); -// transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); -// //transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); -// //transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); -// transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); -// transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - -// //transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); -// //transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - -// transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - -// transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); -// //transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - -// // GeluApproximation has side effects which may change results. It needs to be manually enabled, -// // or alternatively the model can be updated offline using a model conversion script -// // e.g. fusion_gelu_approximation function used by onnxruntime/python/tools/transformers/onnx_model_bert.py -// if (enable_gelu_approximation) { -// transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); -// } - -// #ifdef ENABLE_TRITON -// if (training::framework::triton::TritonOpExecutor::Instance().IsInitialized()) { -// transformers.emplace_back( -// std::make_unique(training::framework::triton::TritonOpExecutor::Instance().GetConfigJson(), -// InlinedHashSet{onnxruntime::kCudaExecutionProvider})); -// } -// #endif // ENABLE_TRITON - -// transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); -// transformers.emplace_back(std::make_unique(cuda_rocm_eps)); -// #ifdef ENABLE_TRAINING -// transformers.emplace_back(std::make_unique(cuda_rocm_eps)); -// transformers.emplace_back(std::make_unique(cuda_rocm_eps)); -// transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); -// #endif - -// //transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); -// //transformers.emplace_back(std::make_unique(dml_ep)); - -// #ifdef MLAS_TARGET_AMD64_IX86 -// if (avx2_precision_mode) { -// transformers.emplace_back(std::make_unique(cpu_ep)); -// } -// #endif - -// #endif // !defined(DISABLE_CONTRIB_OPS) -// // The QDQFinalCleanupTransformer must run AFTER other transformers that fuse Q/DQ nodes. Otherwise, their -// // fusions might be prevented if this one removes a Q/DQ node too early. -// transformers.emplace_back(std::make_unique(enable_quant_qdq_cleanup)); - -// } break; - -// case TransformerLevel::Level3: { -// #ifndef DISABLE_CONTRIB_OPS -// // Register the NCHWc layout transformer if supported by the platform. -// if (MlasNchwcGetBlockSize() > 1) { -// transformers.emplace_back(std::make_unique()); -// } - -// // auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); -// // auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); -// // if (nhwc_transformer->IsActive()) { -// // transformers.emplace_back(std::move(nhwc_transformer)); -// // } - -// // NchwcTransformer must have a higher priority than ConvAddActivationFusion. NchwcTransformer does similar -// // fusions targeting CPU but also reorders the layout to NCHWc which is expected to be more efficient but is -// // only available on x86-64. -// // PR #6351 implemented similar fusion-pattern for CUDA only, and can only fuse conv-add-relu, -// // while we can fuse more activation. -// //transformers.emplace_back(std::make_unique(cpu_ep)); -// #endif - - // } break; + case TransformerLevel::Level3: { +#ifndef DISABLE_CONTRIB_OPS + // Register the NCHWc layout transformer if supported by the platform. + if (MlasNchwcGetBlockSize() > 1) { + transformers.emplace_back(std::make_unique()); + } + + auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); + auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); + if (nhwc_transformer->IsActive()) { + transformers.emplace_back(std::move(nhwc_transformer)); + } + + // NchwcTransformer must have a higher priority than ConvAddActivationFusion. NchwcTransformer does similar + // fusions targeting CPU but also reorders the layout to NCHWc which is expected to be more efficient but is + // only available on x86-64. + // PR #6351 implemented similar fusion-pattern for CUDA only, and can only fuse conv-add-relu, + // while we can fuse more activation. + transformers.emplace_back(std::make_unique(cpu_ep)); +#endif + + } break; default: ORT_THROW("Unsupported optimization level: ", static_cast(level)); diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index d9f08ffe1171e..c532f56b3d3d9 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -115,7 +115,7 @@ class ApiGraph final : public api::GraphRef { const auto& graph_outputs = graph_.GetOutputs(); graph_outputs_.reserve(graph_outputs.size()); for (const auto* output : graph_outputs) { - graph_outputs_.insert(output->Name()); + graph_outputs_.emplace(output->Name()); } } diff --git a/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp b/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp index 6cc40f14cc2e3..56a360218fa1d 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp @@ -164,7 +164,9 @@ static void __stdcall WinmlOrtLoggingCallback( ); } - printf((std::string(message) + "\r\n").c_str()); + if (debug_output_) { + OutputDebugStringA((std::string(message) + "\r\n").c_str()); + } } static void __stdcall WinmlOrtProfileEventCallback(const OrtProfilerEventRecord* profiler_record) noexcept { From 1a6ec723d14a1c62781c9f0b57bb32c81779d0bc Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Thu, 8 Feb 2024 12:32:56 -0800 Subject: [PATCH 03/10] cleanup --- .../core/providers/dml/dml_provider_factory.h | 2 -- onnxruntime/core/graph/graph.cc | 31 ++++++++++--------- .../core/optimizer/graph_transformer.cc | 3 +- .../core/optimizer/graph_transformer_utils.cc | 7 ++--- 4 files changed, 21 insertions(+), 22 deletions(-) diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h index 23cf551f3d997..7d7f05193f486 100644 --- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h +++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h @@ -26,8 +26,6 @@ typedef struct IDMLDevice IDMLDevice; #include "onnxruntime_c_api.h" -#define ENABLE_NPU_ADAPTER_ENUMERATION - #ifdef __cplusplus extern "C" { #endif diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index cf4eb2407f979..9aac97e33d053 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1825,36 +1825,39 @@ struct PQ { comparator_(comp) {} - void push_back(const Node* node) { - list_.push_back(node); - for (int i = 0; i < log(list_.size()); i++) { - comparator_(node, list_.front()); - } + void push(const Node* node) { + list_.insert + ( + std::upper_bound( list_.begin(), list_.end(), node, comparator_), + node + ); } bool empty() { return list_.empty();} - const Node* front(){ return list_.front(); } - void pop_front(){ list_.pop_front(); } + const Node* top(){ return list_.front(); } + void pop(){ list_.pop_front(); } }; #if !defined(ORT_MINIMAL_BUILD) void Graph::KahnsTopologicalSort(const std::function& enter, const std::function& comp) const { InlinedVector in_degree(MaxNodeIndex(), 0); - //std::priority_queue, decltype(comp)> to_visit(comp); - PQ to_visit(comp); InlinedVector topo_order; + PQ to_visit(comp); + + auto number_of_nodes = NumberOfNodes(); + topo_order.reserve(number_of_nodes); for (auto& node : Nodes()) { size_t input_edge_count = node.GetInputEdgesCount(); in_degree[node.Index()] = input_edge_count; if (input_edge_count == 0) { - to_visit.push_back(&node); + to_visit.push(&node); } } while (!to_visit.empty()) { - const Node* current = to_visit.front(); - to_visit.pop_front(); + const Node* current = to_visit.top(); + to_visit.pop(); if (!current) continue; @@ -1867,13 +1870,13 @@ void Graph::KahnsTopologicalSort(const std::function& enter, node_in_degree--; if (node_in_degree == 0) { - to_visit.push_back(&*node_it); + to_visit.push(&*node_it); } } topo_order.push_back(current->Index()); } - if (NumberOfNodes() != static_cast(topo_order.size())) { + if (number_of_nodes != static_cast(topo_order.size())) { ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle."); } } diff --git a/onnxruntime/core/optimizer/graph_transformer.cc b/onnxruntime/core/optimizer/graph_transformer.cc index 9b5d005cb9c3f..37093496a66fa 100644 --- a/onnxruntime/core/optimizer/graph_transformer.cc +++ b/onnxruntime/core/optimizer/graph_transformer.cc @@ -12,8 +12,7 @@ Status GraphTransformer::Apply(Graph& graph, bool& modified, const logging::Logg // ORT_RETURN_IF_ERROR(graph.Resolve()); auto status = ApplyImpl(graph, modified, 0, logger); - LOGS(logger, INFO) << "[" << std::chrono::duration_cast(std::chrono::high_resolution_clock::now().time_since_epoch()).count() - << "] GraphTransformer " << Name() << " modified: " << modified << " with status: " << status; + LOGS(logger, INFO) << "GraphTransformer " << Name() << " modified: " << modified << " with status: " << status; ORT_RETURN_IF_ERROR(status); #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 0b200bbe16348..c60ecb07bf28d 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -248,14 +248,13 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique()); } - // run TransposeOptimizer last as it works in a slightly different way by moving Transpose nodes around. - // shouldn't affect the end result - just easier to debug any issue if it's last. - transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); - // add __backwardpass attribute to nodes after YieldOp, ROCm-only const InlinedHashSet rocm_ep = {onnxruntime::kRocmExecutionProvider}; transformers.emplace_back(std::make_unique(rocm_ep)); + // run TransposeOptimizer last as it works in a slightly different way by moving Transpose nodes around. + // shouldn't affect the end result - just easier to debug any issue if it's last. + transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); } break; case TransformerLevel::Level2: { From dc22d4277f51ac5ddf194eb6e5ebf5796f33fb3f Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Thu, 8 Feb 2024 16:53:49 -0800 Subject: [PATCH 04/10] template priority queue and switch to emplace --- onnxruntime/core/graph/graph.cc | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 9aac97e33d053..2127cbfdd4b75 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1818,22 +1818,22 @@ void Graph::ReverseDFSFrom(gsl::span from, } } -struct PQ { - std::list list_; - const std::function& comparator_ = nullptr; - PQ(const std::function& comp) : - comparator_(comp) - {} - - void push(const Node* node) { +template +struct VisitorPriorityQueue { + using ComparatorType = std::function; + std::list list_; + const ComparatorType comparator_ = nullptr; + VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {} + + void push(T node) { list_.insert ( - std::upper_bound( list_.begin(), list_.end(), node, comparator_), + std::upper_bound(list_.begin(), list_.end(), node, comparator_), node ); } - bool empty() { return list_.empty();} - const Node* top(){ return list_.front(); } + bool empty() { return list_.empty(); } + T top(){ return list_.front(); } void pop(){ list_.pop_front(); } }; @@ -1842,7 +1842,7 @@ void Graph::KahnsTopologicalSort(const std::function& enter, const std::function& comp) const { InlinedVector in_degree(MaxNodeIndex(), 0); InlinedVector topo_order; - PQ to_visit(comp); + VisitorPriorityQueue to_visit(comp); auto number_of_nodes = NumberOfNodes(); topo_order.reserve(number_of_nodes); @@ -1879,6 +1879,10 @@ void Graph::KahnsTopologicalSort(const std::function& enter, if (number_of_nodes != static_cast(topo_order.size())) { ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle."); } + + // for (auto i : topo_order) { + // printf("%d\n", static_cast(i)); + // } } GSL_SUPPRESS(es.84) // noisy warning about ignoring return value from insert(...) @@ -2866,7 +2870,7 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) { const gsl::not_null tensor_added{graph_proto_->add_initializer()}; *(tensor_added) = tensor; - name_to_initial_tensor_[tensor.name()] = tensor_added; + name_to_initial_tensor_.emplace(tensor.name(), tensor_added); SetGraphResolveNeeded(); if (!is_loaded_from_model_file_ && GetNodeArg(tensor.name()) == nullptr) { // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs may add it to the graph inputs. From 65ca6ad81b162de830419071b7aaf0e45a68b082 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Thu, 8 Feb 2024 18:21:54 -0800 Subject: [PATCH 05/10] limit noop to load initializers for only mul/div/sub/add --- .../core/optimizer/noop_elimination.cc | 74 +++++++++++-------- 1 file changed, 44 insertions(+), 30 deletions(-) diff --git a/onnxruntime/core/optimizer/noop_elimination.cc b/onnxruntime/core/optimizer/noop_elimination.cc index b3c2991d54b28..68dc7bc8c5120 100644 --- a/onnxruntime/core/optimizer/noop_elimination.cc +++ b/onnxruntime/core/optimizer/noop_elimination.cc @@ -42,49 +42,63 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con // if initializer_rank is bigger, the output is expected to be initializer_rank per broadcasting rule, // but it won't happen if the case is accepted, thus reject it - auto initializer_rank = initializer->dims().size(); + const auto& dims = initializer->dims(); + auto initializer_rank = dims.size(); const auto* other_input_shape = node.InputDefs()[input0_is_initializer ? 1 : 0]->Shape(); if (other_input_shape == nullptr || initializer_rank > other_input_shape->dim_size()) { return false; } - int32_t data_type = initializer->data_type(); - Initializer add_init(*initializer, graph.ModelPath()); - if (add_init.size() > 1) { + int64_t tensor_size = 1; + for (auto i : dims) { + tensor_size *= i; + } + + if (tensor_size > 1) { return false; } + // handle edge case where the total size of the initializer is 0 - if (add_init.size() == 0) { + if (tensor_size == 0) { return true; } - float value = 0.0f; - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - value = *add_init.data(); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - value = math::halfToFloat(add_init.data()->val); - break; - case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: - value = static_cast(*add_init.data()); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - value = static_cast(*add_init.data()); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - value = static_cast(*add_init.data()); - break; - default: + if (op_type == "Add" || + op_type == "Sub" || + op_type == "Mul" || + op_type == "Div") + { + int32_t data_type = initializer->data_type(); + Initializer add_init(*initializer, graph.ModelPath()); + + float value = 0.0f; + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + value = *add_init.data(); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + value = math::halfToFloat(add_init.data()->val); + break; + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + value = static_cast(*add_init.data()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + value = static_cast(*add_init.data()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + value = static_cast(*add_init.data()); + break; + default: + return false; + } + + if ((op_type == "Add" || op_type == "Sub") && value != 0.0f) { return false; - } + } - if ((op_type == "Add" || op_type == "Sub") && value != 0.0f) { - return false; - } - - if ((op_type == "Mul" || op_type == "Div") && value != 1.0f) { - return false; + if ((op_type == "Mul" || op_type == "Div") && value != 1.0f) { + return false; + } } // reject node output is graph output for now From 68f252c63ff090878d9ea149aa79ef8d8db87672 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Thu, 8 Feb 2024 18:38:33 -0800 Subject: [PATCH 06/10] cleanup --- onnxruntime/core/graph/graph.cc | 4 ---- onnxruntime/core/optimizer/graph_transformer_utils.cc | 6 ------ 2 files changed, 10 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 2127cbfdd4b75..e3ab351b27034 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1879,10 +1879,6 @@ void Graph::KahnsTopologicalSort(const std::function& enter, if (number_of_nodes != static_cast(topo_order.size())) { ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle."); } - - // for (auto i : topo_order) { - // printf("%d\n", static_cast(i)); - // } } GSL_SUPPRESS(es.84) // noisy warning about ignoring return value from insert(...) diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index c60ecb07bf28d..abfe4b8a74d6e 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -196,12 +196,6 @@ InlinedVector> GenerateTransformers( const InlinedHashSet dml_ep = {onnxruntime::kDmlExecutionProvider}; AllocatorPtr cpu_allocator = std::make_shared(); - UNREFERENCED_PARAMETER(disable_quant_qdq); - UNREFERENCED_PARAMETER(cpu_ep); - UNREFERENCED_PARAMETER(dml_ep); - UNREFERENCED_PARAMETER(session_options); - UNREFERENCED_PARAMETER(cpu_execution_provider); - switch (level) { case TransformerLevel::Level1: { // RewriteRule optimizations are the simplest (they generally remove unnecessary nodes and are cheap to run) From dfb9d3d2387909e568155737aa344f318c34d9ed Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Thu, 8 Feb 2024 18:39:32 -0800 Subject: [PATCH 07/10] cleanup --- onnxruntime/core/optimizer/graph_transformer_utils.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index abfe4b8a74d6e..cd3c49be15aa4 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -85,8 +85,6 @@ #endif // !defined(ORT_MINIMAL_BUILD) -#define UNREFERENCED_PARAMETER(P) (P) - namespace onnxruntime::optimizer_utils { static void FilterTransformers(InlinedVector>& transformers, From 51a6d356cfdaf2cf31d5701c8981ded7d395dd70 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Thu, 8 Feb 2024 20:58:14 -0800 Subject: [PATCH 08/10] reverse the returned item --- onnxruntime/core/graph/graph.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index e3ab351b27034..aef18b9d5fa79 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1833,8 +1833,8 @@ struct VisitorPriorityQueue { ); } bool empty() { return list_.empty(); } - T top(){ return list_.front(); } - void pop(){ list_.pop_front(); } + T top(){ return list_.back(); } + void pop(){ list_.pop_back(); } }; #if !defined(ORT_MINIMAL_BUILD) From 8ced95eb13f655eb01f419fc91512c114b7f9af8 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Fri, 9 Feb 2024 07:56:16 -0800 Subject: [PATCH 09/10] linrunner --- onnxruntime/core/graph/graph.cc | 12 +++++------- onnxruntime/core/optimizer/noop_elimination.cc | 3 +-- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index aef18b9d5fa79..305122c56b865 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1826,15 +1826,13 @@ struct VisitorPriorityQueue { VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {} void push(T node) { - list_.insert - ( - std::upper_bound(list_.begin(), list_.end(), node, comparator_), - node - ); + list_.insert( + std::upper_bound(list_.begin(), list_.end(), node, comparator_), + node); } bool empty() { return list_.empty(); } - T top(){ return list_.back(); } - void pop(){ list_.pop_back(); } + T top() { return list_.back(); } + void pop() { list_.pop_back(); } }; #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/optimizer/noop_elimination.cc b/onnxruntime/core/optimizer/noop_elimination.cc index 68dc7bc8c5120..eb9cbd3c7c2e6 100644 --- a/onnxruntime/core/optimizer/noop_elimination.cc +++ b/onnxruntime/core/optimizer/noop_elimination.cc @@ -66,8 +66,7 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con if (op_type == "Add" || op_type == "Sub" || op_type == "Mul" || - op_type == "Div") - { + op_type == "Div") { int32_t data_type = initializer->data_type(); Initializer add_init(*initializer, graph.ModelPath()); From a424542fc4404deb9f60eef9183cd96ae3bbf19d Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Thu, 15 Feb 2024 09:47:22 -0800 Subject: [PATCH 10/10] change order of value --- onnxruntime/core/optimizer/noop_elimination.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/optimizer/noop_elimination.cc b/onnxruntime/core/optimizer/noop_elimination.cc index eb9cbd3c7c2e6..bba39b698a27a 100644 --- a/onnxruntime/core/optimizer/noop_elimination.cc +++ b/onnxruntime/core/optimizer/noop_elimination.cc @@ -91,11 +91,11 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con return false; } - if ((op_type == "Add" || op_type == "Sub") && value != 0.0f) { + if (value != 0.0f && (op_type == "Add" || op_type == "Sub")) { return false; } - if ((op_type == "Mul" || op_type == "Div") && value != 1.0f) { + if (value != 1.0f && (op_type == "Mul" || op_type == "Div")) { return false; } }