Graph optimization that fuses Add + LayerNorm into SkipLayerNorm doesn't work with ONNX models generated using newer versions of onnx and onnxruntime #12916
The 'Add' operator had some additional types added in opset 14. Those types aren't relevant here, so the existing implementation should work as is. Can you try updating the supported opset versions for Add in all places where it has '{7, 13}' to be '{7, 13, 14}'? e.g. onnxruntime/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc Lines 165 to 167 in 1e34440
I applied the changes as follows, recompiled, and reinstalled ORT, but it gives me the same results.

```diff
diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc
index a81ca6705..c980905d8 100644
--- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc
+++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc
@@ -163,8 +163,8 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
   // Format 1
   std::vector<graph_utils::EdgeEndToMatch> format1_parent_path{
-      {0, 0, "Add", {7, 13}, kOnnxDomain},
-      {0, 0, "Add", {7, 13}, kOnnxDomain}};
+      {0, 0, "Add", {7, 13, 14}, kOnnxDomain},
+      {0, 0, "Add", {7, 13, 14}, kOnnxDomain}};
   std::vector<const Node::EdgeEnd*> edges;
   if (graph_utils::FindPath(ln_node, true, format1_parent_path, edges, logger)) {
@@ -182,8 +182,8 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
   if (matched_format == Format::None) {
     // Format 2
     std::vector<graph_utils::EdgeEndToMatch> format2_parent_path{
-        {0, 0, "Add", {7, 13}, kOnnxDomain},
-        {0, 1, "Add", {7, 13}, kOnnxDomain}};
+        {0, 0, "Add", {7, 13, 14}, kOnnxDomain},
+        {0, 1, "Add", {7, 13, 14}, kOnnxDomain}};
     if (graph_utils::FindPath(ln_node, true, format2_parent_path, edges, logger)) {
       p_add1 = const_cast<Node*>(&edges[0]->GetNode());
@@ -201,7 +201,7 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
   if (matched_format == Format::None) {
     // Format 3
     std::vector<graph_utils::EdgeEndToMatch> format3_parent_path{
-      {0, 0, "Add", {7, 13}, kOnnxDomain}};
+      {0, 0, "Add", {7, 13, 14}, kOnnxDomain}};
     if (graph_utils::FindPath(ln_node, true, format3_parent_path, edges, logger)) {
       p_add1 = const_cast<Node*>(&edges[0]->GetNode());
```

One more thing: when I checked graph optimization with CPUExecutionProvider, for both cases the desired and undesired results were exactly as described in the question. But when I change the execution provider to DnnlExecutionProvider in ort.InferenceSession, for both cases I get the following error:

```
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Unable to serialize model as it contains compiled nodes. Please disable any execution providers which generate compiled nodes.
```
A couple of factors here.

Originally LayerNormalization was an ORT-internal operator. Unfortunately, when it was created it was defined as using the ONNX domain, which slightly confuses things, as it uses the ONNX opset instead of the ORT domain's opset to match a kernel. The latest version of ONNX adds an official LayerNormalization operator, and because the internal operator uses the ONNX domain instead of our internal domain, it now matches the ONNX operator in opset 17.

The ONNX LayerNormalization is defined as a function. As there is no kernel registered for ONNX opset 17 in ORT to handle LayerNormalization, the function gets expanded into the nodes that you see, and that expansion happens prior to fusion.

The quickest option would be to use opset 16 for your model (this would still require the update to add opset 14 of 'Add' to the optimizer). This draft PR (#12978) will do the fusion as expected, but I haven't verified that there are no diffs between the ONNX spec and what the contrib op does.
**Description**: LayerNormalization is now part of the ONNX spec as of opset 17. We had a LayerNormalization contrib op, which (incorrectly) was registered in the ONNX domain. Use that implementation for the ONNX operator. Update skip_layer_norm_fusion.cc. There are other optimizers that use LayerNormalization that need updates as well.

**Motivation and Context**: #12916
Describe the issue
to work with the latest onnx version 8 and opset id 17, but this gives a compilation error.
To reproduce
Use zip provided for reference.
sln_graph_opt_issue.zip
Applying graph optimization on skip_layer_norm_no_beta.onnx (a model whose onnx version is 7 and opset id is 12, provided in the zip), it gets converted to skip_layer_norm_no_beta_after_graph_optimization.onnx (also provided in the zip file). This is the desired result.
Applying graph optimization on skip_layer_norm_no_beta_newer_version.onnx (a model whose onnx version is 8 and opset id is 17, provided in the zip), it gets converted to skip_layer_norm_no_beta_newer_version_after_graph_optimization.onnx (also provided in the zip file). This is the undesired result.
Urgency
Medium to Low
Platform
Windows
OS Version
11
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
e3b5011
ONNX Runtime API
Python
Architecture
X86
Execution Provider
oneDNN
Execution Provider Library Version
2.6