From 63119ea8deec43729c33422c37adce888727f377 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Thu, 30 Jun 2022 13:38:35 -0700 Subject: [PATCH] Skip Constant Folding for ops producing an optional type output (#11839) --- onnxruntime/core/optimizer/constant_folding.cc | 8 ++++---- .../core/providers/cpu/optional/optional_ops.cc | 6 +++++- onnxruntime/test/shared_lib/test_inference.cc | 11 +++++++++++ onnxruntime/test/testdata/gh_issue_11717.onnx | Bin 0 -> 1253 bytes 4 files changed, 20 insertions(+), 5 deletions(-) create mode 100644 onnxruntime/test/testdata/gh_issue_11717.onnx diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index e6735c2e9ad0d..f5dc2c0b3c054 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -196,11 +196,11 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, ORT_ENFORCE(fetches.size() == node->OutputDefs().size()); converted_to_constant = true; for (size_t fetch_idx = 0; fetch_idx < fetches.size(); ++fetch_idx) { - OrtValue& ort_value = fetches[fetch_idx]; + const auto& constant_arg_out = *node->OutputDefs()[fetch_idx]; // XXX: Add support for SparseTensors outputs when we have sparse outputs - if (!ort_value.IsTensor()) { - LOGS(logger, WARNING) << "Unsupported output type of " << ort_value.Type() - << ". Can't constant fold " << node->OpType() << " node '" << node->Name() << "'"; + if (!utils::HasTensorType(*constant_arg_out.TypeAsProto())) { + LOGS(logger, INFO) << "Unsupported output type of " << constant_arg_out.Type() + << ". Can't constant fold " << node->OpType() << " node '" << node->Name() << "'"; converted_to_constant = false; break; } diff --git a/onnxruntime/core/providers/cpu/optional/optional_ops.cc b/onnxruntime/core/providers/cpu/optional/optional_ops.cc index 1d3ead1980c2d..25f69bcc9d29e 100644 --- a/onnxruntime/core/providers/cpu/optional/optional_ops.cc +++ b/onnxruntime/core/providers/cpu/optional/optional_ops.cc @@ -118,7 +118,11 @@ Status Optional::Compute(OpKernelContext* ctx) const { } else { // No input was provided - we use the type proto to construct the output OrtValue - CheckValidTypeProto(*type_proto_); + if (!CheckValidTypeProto(*type_proto_)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The TypeProto attribute in the Optional op ", + "can only be of type(tensor) or (seq(tensor))"); + } // type is either Tensor or TensorSeq (we have validated this already in CheckValidTypeProto()) if (type_proto_->has_tensor_type()) { diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index ebe820f157473..5c37bb77a01f3 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -182,6 +182,7 @@ static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/f static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_bar_1.onnx"); static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx"); static constexpr PATH_TYPE CUSTOM_OP_MODEL_WITH_ATTRIBUTES_URI = TSTR("testdata/foo_bar_3.onnx"); +static constexpr PATH_TYPE OPTIONAL_TYPE_GH_11717_MODEL = TSTR("testdata/gh_issue_11717.onnx"); #if !defined(DISABLE_SPARSE_TENSORS) static constexpr PATH_TYPE SPARSE_OUTPUT_MODEL_URI = TSTR("testdata/sparse_initializer_as_output.onnx"); #ifndef DISABLE_CONTRIB_OPS @@ -2164,3 +2165,13 @@ TEST(CApiTest, TestCudaMemcpyToHostWithSequenceTensors) { } #endif + +#if !defined(DISABLE_OPTIONAL_TYPE) +TEST(CApiTest, GH_11717) { + const auto* model_path = OPTIONAL_TYPE_GH_11717_MODEL; + Ort::SessionOptions session_options{}; + // Just check if the model loads fine without a segmentation fault + // in the default CPU EP + EXPECT_NO_THROW(Ort::Session session(*ort_env, model_path, session_options)); +} +#endif diff --git a/onnxruntime/test/testdata/gh_issue_11717.onnx b/onnxruntime/test/testdata/gh_issue_11717.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4ea10d042a35a859df9247a63ac5a34b1738eca1 GIT binary patch literal 1253 zcmb7@L66cv6vvrT*!C4oouJvRQDYC9G=?xOrG=xbQ8?(`m^5YW1`M_YMq{u05lp;! z;QM&?J9zUe_&PW+uyWWY?MyrWdGp@?{c0HWWKr^Bc`RP290x79gd?;Tbt9)(-lv`jj*>72C%nAVWS*f zCVi*5(YtyzdM0SRI5pQqgC9C5y1nQTB{8j+o!HT<7_K(D^nMet_%zFt$ui9^rv5Ji ze&qqm{9Wy@GA?V6cD3)m!UG_TY7F}_`hX72c$vhmA?Fa{8g@aNmJD1&&b>3xdAZ1z zx1nV;sR5c%$6H^MTNFg^laDjbCN045v