Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip Constant Folding for ops producing an optional type output #11839

Merged
merged 8 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions onnxruntime/core/optimizer/constant_folding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,11 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
ORT_ENFORCE(fetches.size() == node->OutputDefs().size());
converted_to_constant = true;
for (size_t fetch_idx = 0; fetch_idx < fetches.size(); ++fetch_idx) {
OrtValue& ort_value = fetches[fetch_idx];
const auto& constant_arg_out = *node->OutputDefs()[fetch_idx];
// XXX: Add support for SparseTensors outputs when we have sparse outputs
skottmckay marked this conversation as resolved.
Show resolved Hide resolved
if (!ort_value.IsTensor()) {
LOGS(logger, WARNING) << "Unsupported output type of " << ort_value.Type()
<< ". Can't constant fold " << node->OpType() << " node '" << node->Name() << "'";
if (!utils::HasTensorType(*constant_arg_out.TypeAsProto())) {
LOGS(logger, INFO) << "Unsupported output type of " << constant_arg_out.Type()
<< ". Can't constant fold " << node->OpType() << " node '" << node->Name() << "'";
converted_to_constant = false;
break;
}
Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/core/providers/cpu/optional/optional_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ Status Optional::Compute(OpKernelContext* ctx) const {

} else { // No input was provided - we use the type proto to construct the output OrtValue

CheckValidTypeProto(*type_proto_);
if (!CheckValidTypeProto(*type_proto_)) {
hariharans29 marked this conversation as resolved.
Show resolved Hide resolved
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"The TypeProto attribute in the Optional op ",
"can only be of type(tensor) or (seq(tensor))");
}

skottmckay marked this conversation as resolved.
Show resolved Hide resolved
// type is either Tensor or TensorSeq (we have validated this already in CheckValidTypeProto())
if (type_proto_->has_tensor_type()) {
Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/f
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_bar_1.onnx");
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx");
static constexpr PATH_TYPE CUSTOM_OP_MODEL_WITH_ATTRIBUTES_URI = TSTR("testdata/foo_bar_3.onnx");
static constexpr PATH_TYPE OPTIONAL_TYPE_GH_11717_MODEL = TSTR("testdata/gh_issue_11717.onnx");
#if !defined(DISABLE_SPARSE_TENSORS)
static constexpr PATH_TYPE SPARSE_OUTPUT_MODEL_URI = TSTR("testdata/sparse_initializer_as_output.onnx");
#ifndef DISABLE_CONTRIB_OPS
Expand Down Expand Up @@ -2164,3 +2165,13 @@ TEST(CApiTest, TestCudaMemcpyToHostWithSequenceTensors) {
}

#endif

#if !defined(DISABLE_OPTIONAL_TYPE)
TEST(CApiTest, GH_11717) {
const auto* model_path = OPTIONAL_TYPE_GH_11717_MODEL;
Ort::SessionOptions session_options{};
// Just check if the model loads fine without a segmentation fault
// in the default CPU EP
EXPECT_NO_THROW(Ort::Session session(*ort_env, model_path, session_options));
}
hariharans29 marked this conversation as resolved.
Show resolved Hide resolved
#endif
Binary file added onnxruntime/test/testdata/gh_issue_11717.onnx
Binary file not shown.