Skip to content

Commit

Permalink
Added support for OrtValue to run
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Sep 11, 2024
1 parent 16acd38 commit 3501a30
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
4 changes: 3 additions & 1 deletion lib/onnxruntime/inference_session.rb
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,9 @@ def create_input_tensor(input_feed)
inp = @inputs.find { |i| i[:name] == input_name.to_s }
raise Error, "Unknown input: #{input_name}" unless inp

if inp[:type] == "tensor(string)"
if input.is_a?(OrtValue)
input
elsif inp[:type] == "tensor(string)"
OrtValue.ortvalue_from_array(input, element_type: :string)
elsif (tensor_type = tensor_types[inp[:type]])
OrtValue.ortvalue_from_array(input, element_type: tensor_type)
Expand Down
7 changes: 7 additions & 0 deletions test/inference_session_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ def test_run_with_ort_values_invalid_type
assert_equal "Unexpected input data type. Actual: (tensor(double)) , expected: (tensor(float))", error.message
end

def test_run_ort_value_input
sess = OnnxRuntime::InferenceSession.new("test/support/lightgbm.onnx")
x = OnnxRuntime::OrtValue.ortvalue_from_numo(Numo::SFloat.cast([[5.8, 2.8]]))
output = sess.run(nil, {input: x})
assert_equal [1], output[0]
end

def test_providers
sess = OnnxRuntime::InferenceSession.new("test/support/model.onnx")
assert_includes sess.providers, "CPUExecutionProvider"
Expand Down

0 comments on commit 3501a30

Please sign in to comment.