diff --git a/lib/onnxruntime/ffi.rb b/lib/onnxruntime/ffi.rb index 1763029..8755959 100644 --- a/lib/onnxruntime/ffi.rb +++ b/lib/onnxruntime/ffi.rb @@ -11,7 +11,7 @@ module FFI # enums TensorElementDataType = enum(:undefined, :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :string, :bool, :float16, :double, :uint32, :uint64, :complex64, :complex128, :bfloat16) - OnnxType = enum(:unknown, :tensor, :sequence, :map, :opaque, :sparsetensor) + OnnxType = enum(:unknown, :tensor, :sequence, :map, :opaque, :sparsetensor, :optional) class Api < ::FFI::Struct layout \ diff --git a/lib/onnxruntime/inference_session.rb b/lib/onnxruntime/inference_session.rb index 88eec67..54a318e 100644 --- a/lib/onnxruntime/inference_session.rb +++ b/lib/onnxruntime/inference_session.rb @@ -250,7 +250,7 @@ def load_inputs # freed in node_info typeinfo = ::FFI::MemoryPointer.new(:pointer) check_status api[:SessionGetInputTypeInfo].call(read_pointer, i, typeinfo) - inputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo)) + inputs << {name: name_ptr.read_pointer.read_string}.merge(Utils.node_info(typeinfo)) allocator_free name_ptr end inputs @@ -266,7 +266,7 @@ def load_outputs # freed in node_info typeinfo = ::FFI::MemoryPointer.new(:pointer) check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo) - outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo)) + outputs << {name: name_ptr.read_pointer.read_string}.merge(Utils.node_info(typeinfo)) allocator_free name_ptr end outputs @@ -482,58 +482,6 @@ def check_status(status) Utils.check_status(status) end - def node_info(typeinfo) - onnx_type = ::FFI::MemoryPointer.new(:int) - check_status api[:GetOnnxTypeFromTypeInfo].call(typeinfo.read_pointer, onnx_type) - - type = FFI::OnnxType[onnx_type.read_int] - case type - when :tensor - tensor_info = ::FFI::MemoryPointer.new(:pointer) - # don't free tensor_info - check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info) - - type, shape = Utils.tensor_type_and_shape(tensor_info) - { - type: "tensor(#{FFI::TensorElementDataType[type]})", - shape: shape - } - when :sequence - sequence_type_info = ::FFI::MemoryPointer.new(:pointer) - check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info) - nested_type_info = ::FFI::MemoryPointer.new(:pointer) - check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info) - v = node_info(nested_type_info)[:type] - - { - type: "seq(#{v})", - shape: [] - } - when :map - map_type_info = ::FFI::MemoryPointer.new(:pointer) - check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info) - - # key - key_type = ::FFI::MemoryPointer.new(:int) - check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type) - k = FFI::TensorElementDataType[key_type.read_int] - - # value - value_type_info = ::FFI::MemoryPointer.new(:pointer) - check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info) - v = node_info(value_type_info)[:type] - - { - type: "map(#{k},#{v})", - shape: [] - } - else - Utils.unsupported_type("ONNX", type) - end - ensure - release :TypeInfo, typeinfo - end - def tensor_types @tensor_types ||= [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h end diff --git a/lib/onnxruntime/ort_value.rb b/lib/onnxruntime/ort_value.rb index 3e86175..e5ed719 100644 --- a/lib/onnxruntime/ort_value.rb +++ b/lib/onnxruntime/ort_value.rb @@ -29,13 +29,10 @@ def tensor? end def data_type - type = FFI::OnnxType[value_type] - - if type == :tensor - elem_type = FFI::TensorElementDataType[element_type] - "tensor(#{elem_type})" - else - Utils.unsupported_type("ONNX", type) + @data_type ||= begin + typeinfo = ::FFI::MemoryPointer.new(:pointer) + Utils.check_status FFI.api[:GetTypeInfo].call(out_ptr, typeinfo) + Utils.node_info(typeinfo)[:type] end end diff --git a/lib/onnxruntime/utils.rb b/lib/onnxruntime/utils.rb index fa6989b..4d068cd 100644 --- a/lib/onnxruntime/utils.rb +++ b/lib/onnxruntime/utils.rb @@ -52,5 +52,57 @@ def self.tensor_type_and_shape(tensor_info) [type.read_int, dims] end + + def self.node_info(typeinfo) + onnx_type = ::FFI::MemoryPointer.new(:int) + check_status api[:GetOnnxTypeFromTypeInfo].call(typeinfo.read_pointer, onnx_type) + + type = FFI::OnnxType[onnx_type.read_int] + case type + when :tensor + tensor_info = ::FFI::MemoryPointer.new(:pointer) + # don't free tensor_info + check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info) + + type, shape = Utils.tensor_type_and_shape(tensor_info) + { + type: "tensor(#{FFI::TensorElementDataType[type]})", + shape: shape + } + when :sequence + sequence_type_info = ::FFI::MemoryPointer.new(:pointer) + check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info) + nested_type_info = ::FFI::MemoryPointer.new(:pointer) + check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info) + v = node_info(nested_type_info)[:type] + + { + type: "seq(#{v})", + shape: [] + } + when :map + map_type_info = ::FFI::MemoryPointer.new(:pointer) + check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info) + + # key + key_type = ::FFI::MemoryPointer.new(:int) + check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type) + k = FFI::TensorElementDataType[key_type.read_int] + + # value + value_type_info = ::FFI::MemoryPointer.new(:pointer) + check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info) + v = node_info(value_type_info)[:type] + + { + type: "map(#{k},#{v})", + shape: [] + } + else + Utils.unsupported_type("ONNX", type) + end + ensure + release :TypeInfo, typeinfo + end end end diff --git a/test/model_test.rb b/test/model_test.rb index 6f98b2e..c79d8f0 100644 --- a/test/model_test.rb +++ b/test/model_test.rb @@ -226,7 +226,10 @@ def test_run_with_ort_values sess = OnnxRuntime::InferenceSession.new("test/support/lightgbm.onnx") x = OnnxRuntime::OrtValue.ortvalue_from_numo(Numo::SFloat.cast([[5.8, 2.8]])) output = sess.run_with_ort_values(nil, {input: x}) + assert_equal true, output[0].tensor? assert_equal "tensor(int64)", output[0].data_type + assert_equal false, output[1].tensor? + assert_equal "seq(map(int64,tensor(float)))", output[1].data_type end def test_invalid_rank