Skip to content

Commit

Permalink
Improved data_type method [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Sep 11, 2024
1 parent 0c64caf commit 3114e7a
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 62 deletions.
2 changes: 1 addition & 1 deletion lib/onnxruntime/ffi.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module FFI

# enums
TensorElementDataType = enum(:undefined, :float, :uint8, :int8, :uint16, :int16, :int32, :int64, :string, :bool, :float16, :double, :uint32, :uint64, :complex64, :complex128, :bfloat16)
OnnxType = enum(:unknown, :tensor, :sequence, :map, :opaque, :sparsetensor)
OnnxType = enum(:unknown, :tensor, :sequence, :map, :opaque, :sparsetensor, :optional)

class Api < ::FFI::Struct
layout \
Expand Down
56 changes: 2 additions & 54 deletions lib/onnxruntime/inference_session.rb
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def load_inputs
# freed in node_info
typeinfo = ::FFI::MemoryPointer.new(:pointer)
check_status api[:SessionGetInputTypeInfo].call(read_pointer, i, typeinfo)
inputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
inputs << {name: name_ptr.read_pointer.read_string}.merge(Utils.node_info(typeinfo))
allocator_free name_ptr
end
inputs
Expand All @@ -266,7 +266,7 @@ def load_outputs
# freed in node_info
typeinfo = ::FFI::MemoryPointer.new(:pointer)
check_status api[:SessionGetOutputTypeInfo].call(read_pointer, i, typeinfo)
outputs << {name: name_ptr.read_pointer.read_string}.merge(node_info(typeinfo))
outputs << {name: name_ptr.read_pointer.read_string}.merge(Utils.node_info(typeinfo))
allocator_free name_ptr
end
outputs
Expand Down Expand Up @@ -482,58 +482,6 @@ def check_status(status)
Utils.check_status(status)
end

def node_info(typeinfo)
onnx_type = ::FFI::MemoryPointer.new(:int)
check_status api[:GetOnnxTypeFromTypeInfo].call(typeinfo.read_pointer, onnx_type)

type = FFI::OnnxType[onnx_type.read_int]
case type
when :tensor
tensor_info = ::FFI::MemoryPointer.new(:pointer)
# don't free tensor_info
check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info)

type, shape = Utils.tensor_type_and_shape(tensor_info)
{
type: "tensor(#{FFI::TensorElementDataType[type]})",
shape: shape
}
when :sequence
sequence_type_info = ::FFI::MemoryPointer.new(:pointer)
check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info)
nested_type_info = ::FFI::MemoryPointer.new(:pointer)
check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info)
v = node_info(nested_type_info)[:type]

{
type: "seq(#{v})",
shape: []
}
when :map
map_type_info = ::FFI::MemoryPointer.new(:pointer)
check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info)

# key
key_type = ::FFI::MemoryPointer.new(:int)
check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type)
k = FFI::TensorElementDataType[key_type.read_int]

# value
value_type_info = ::FFI::MemoryPointer.new(:pointer)
check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info)
v = node_info(value_type_info)[:type]

{
type: "map(#{k},#{v})",
shape: []
}
else
Utils.unsupported_type("ONNX", type)
end
ensure
release :TypeInfo, typeinfo
end

def tensor_types
@tensor_types ||= [:float, :uint8, :int8, :uint16, :int16, :int32, :int64, :bool, :double, :uint32, :uint64].map { |v| ["tensor(#{v})", v] }.to_h
end
Expand Down
11 changes: 4 additions & 7 deletions lib/onnxruntime/ort_value.rb
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,10 @@ def tensor?
end

def data_type
type = FFI::OnnxType[value_type]

if type == :tensor
elem_type = FFI::TensorElementDataType[element_type]
"tensor(#{elem_type})"
else
Utils.unsupported_type("ONNX", type)
@data_type ||= begin
typeinfo = ::FFI::MemoryPointer.new(:pointer)
Utils.check_status FFI.api[:GetTypeInfo].call(out_ptr, typeinfo)
Utils.node_info(typeinfo)[:type]
end
end

Expand Down
52 changes: 52 additions & 0 deletions lib/onnxruntime/utils.rb
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,57 @@ def self.tensor_type_and_shape(tensor_info)

[type.read_int, dims]
end

def self.node_info(typeinfo)
onnx_type = ::FFI::MemoryPointer.new(:int)
check_status api[:GetOnnxTypeFromTypeInfo].call(typeinfo.read_pointer, onnx_type)

type = FFI::OnnxType[onnx_type.read_int]
case type
when :tensor
tensor_info = ::FFI::MemoryPointer.new(:pointer)
# don't free tensor_info
check_status api[:CastTypeInfoToTensorInfo].call(typeinfo.read_pointer, tensor_info)

type, shape = Utils.tensor_type_and_shape(tensor_info)
{
type: "tensor(#{FFI::TensorElementDataType[type]})",
shape: shape
}
when :sequence
sequence_type_info = ::FFI::MemoryPointer.new(:pointer)
check_status api[:CastTypeInfoToSequenceTypeInfo].call(typeinfo.read_pointer, sequence_type_info)
nested_type_info = ::FFI::MemoryPointer.new(:pointer)
check_status api[:GetSequenceElementType].call(sequence_type_info.read_pointer, nested_type_info)
v = node_info(nested_type_info)[:type]

{
type: "seq(#{v})",
shape: []
}
when :map
map_type_info = ::FFI::MemoryPointer.new(:pointer)
check_status api[:CastTypeInfoToMapTypeInfo].call(typeinfo.read_pointer, map_type_info)

# key
key_type = ::FFI::MemoryPointer.new(:int)
check_status api[:GetMapKeyType].call(map_type_info.read_pointer, key_type)
k = FFI::TensorElementDataType[key_type.read_int]

# value
value_type_info = ::FFI::MemoryPointer.new(:pointer)
check_status api[:GetMapValueType].call(map_type_info.read_pointer, value_type_info)
v = node_info(value_type_info)[:type]

{
type: "map(#{k},#{v})",
shape: []
}
else
Utils.unsupported_type("ONNX", type)
end
ensure
release :TypeInfo, typeinfo
end
end
end
3 changes: 3 additions & 0 deletions test/model_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,10 @@ def test_run_with_ort_values
sess = OnnxRuntime::InferenceSession.new("test/support/lightgbm.onnx")
x = OnnxRuntime::OrtValue.ortvalue_from_numo(Numo::SFloat.cast([[5.8, 2.8]]))
output = sess.run_with_ort_values(nil, {input: x})
assert_equal true, output[0].tensor?
assert_equal "tensor(int64)", output[0].data_type
assert_equal false, output[1].tensor?
assert_equal "seq(map(int64,tensor(float)))", output[1].data_type
end

def test_invalid_rank
Expand Down

0 comments on commit 3114e7a

Please sign in to comment.