Skip to content

Commit

Permalink
Skip Constant Folding for ops producing an optional type output (#11839)
Browse files Browse the repository at this point in the history
  • Loading branch information
hariharans29 authored Jun 30, 2022
1 parent 0fa2041 commit 2e27a7e
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 5 deletions.
8 changes: 4 additions & 4 deletions onnxruntime/core/optimizer/constant_folding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,11 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
ORT_ENFORCE(fetches.size() == node->OutputDefs().size());
converted_to_constant = true;
for (size_t fetch_idx = 0; fetch_idx < fetches.size(); ++fetch_idx) {
OrtValue& ort_value = fetches[fetch_idx];
const auto& constant_arg_out = *node->OutputDefs()[fetch_idx];
// XXX: Add support for SparseTensors outputs when we have sparse outputs
if (!ort_value.IsTensor()) {
LOGS(logger, WARNING) << "Unsupported output type of " << ort_value.Type()
<< ". Can't constant fold " << node->OpType() << " node '" << node->Name() << "'";
if (!utils::HasTensorType(*constant_arg_out.TypeAsProto())) {
LOGS(logger, INFO) << "Unsupported output type of " << constant_arg_out.Type()
<< ". Can't constant fold " << node->OpType() << " node '" << node->Name() << "'";
converted_to_constant = false;
break;
}
Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/core/providers/cpu/optional/optional_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ Status Optional::Compute(OpKernelContext* ctx) const {

} else { // No input was provided - we use the type proto to construct the output OrtValue

CheckValidTypeProto(*type_proto_);
if (!CheckValidTypeProto(*type_proto_)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"The TypeProto attribute in the Optional op ",
"can only be of type(tensor) or (seq(tensor))");
}

// type is either Tensor or TensorSeq (we have validated this already in CheckValidTypeProto())
if (type_proto_->has_tensor_type()) {
Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/f
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_bar_1.onnx");
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx");
static constexpr PATH_TYPE CUSTOM_OP_MODEL_WITH_ATTRIBUTES_URI = TSTR("testdata/foo_bar_3.onnx");
static constexpr PATH_TYPE OPTIONAL_TYPE_GH_11717_MODEL = TSTR("testdata/gh_issue_11717.onnx");
#if !defined(DISABLE_SPARSE_TENSORS)
static constexpr PATH_TYPE SPARSE_OUTPUT_MODEL_URI = TSTR("testdata/sparse_initializer_as_output.onnx");
#ifndef DISABLE_CONTRIB_OPS
Expand Down Expand Up @@ -2164,3 +2165,13 @@ TEST(CApiTest, TestCudaMemcpyToHostWithSequenceTensors) {
}

#endif

#if !defined(DISABLE_OPTIONAL_TYPE)
// Regression test for GitHub issue #11717: a model containing an op with an
// optional-type output must load cleanly (no segmentation fault) when the
// default CPU EP runs graph optimizations such as constant folding.
TEST(CApiTest, GH_11717) {
  Ort::SessionOptions session_options{};
  const auto* model_uri = OPTIONAL_TYPE_GH_11717_MODEL;
  // Session construction alone exercises the crash path; throwing (or not)
  // during load is the only behavior under test.
  EXPECT_NO_THROW(Ort::Session session(*ort_env, model_uri, session_options));
}
#endif
Binary file added onnxruntime/test/testdata/gh_issue_11717.onnx
Binary file not shown.

0 comments on commit 2e27a7e

Please sign in to comment.