From 98fff18a7b54948b95d2dc0184827bb5b52851ec Mon Sep 17 00:00:00 2001 From: Hari Seshadri Date: Mon, 13 Jun 2022 15:08:23 -0700 Subject: [PATCH] Skip ConstantFolding for ops producing an optional type output --- onnxruntime/core/optimizer/constant_folding.cc | 6 +++--- onnxruntime/core/providers/cpu/optional/optional_ops.cc | 6 +++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index e6735c2e9ad0d..96fdacb12e23e 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -196,10 +196,10 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, ORT_ENFORCE(fetches.size() == node->OutputDefs().size()); converted_to_constant = true; for (size_t fetch_idx = 0; fetch_idx < fetches.size(); ++fetch_idx) { - OrtValue& ort_value = fetches[fetch_idx]; + auto* constant_arg_out = node->MutableOutputDefs()[fetch_idx]; // XXX: Add support for SparseTensors outputs when we have sparse outputs - if (!ort_value.IsTensor()) { - LOGS(logger, WARNING) << "Unsupported output type of " << ort_value.Type() + if (!utils::HasTensorType(*constant_arg_out->TypeAsProto())) { + LOGS(logger, INFO) << "Unsupported output type of " << constant_arg_out->Type() << ". Can't constant fold " << node->OpType() << " node '" << node->Name() << "'"; converted_to_constant = false; break; diff --git a/onnxruntime/core/providers/cpu/optional/optional_ops.cc b/onnxruntime/core/providers/cpu/optional/optional_ops.cc index 1d3ead1980c2d..25f69bcc9d29e 100644 --- a/onnxruntime/core/providers/cpu/optional/optional_ops.cc +++ b/onnxruntime/core/providers/cpu/optional/optional_ops.cc @@ -118,7 +118,11 @@ Status Optional::Compute(OpKernelContext* ctx) const { } else { // No input was provided - we use the type proto to construct the output OrtValue - CheckValidTypeProto(*type_proto_); + if (!CheckValidTypeProto(*type_proto_)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The TypeProto attribute in the Optional op ", + "can only be of type(tensor) or (seq(tensor))"); + } // type is either Tensor or TensorSeq (we have validated this already in CheckValidTypeProto()) if (type_proto_->has_tensor_type()) {