Skip to content

Commit

Permalink
Added support for CoreML
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Sep 2, 2024
1 parent 77c46dd commit f73ea63
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## 0.9.2 (unreleased)

- Updated ONNX Runtime to 1.19.0
- Added support for CoreML

## 0.9.1 (2024-05-22)

Expand Down
6 changes: 6 additions & 0 deletions lib/onnxruntime/ffi.rb
Original file line number Diff line number Diff line change
Expand Up @@ -257,5 +257,11 @@ class Libc
attach_function :mbstowcs, %i[pointer string size_t], :size_t
end
end

# https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
begin
attach_function :OrtSessionOptionsAppendExecutionProvider_CoreML, %i[pointer uint32], :pointer
rescue FFI::NotFoundError
end
end
end
7 changes: 7 additions & 0 deletions lib/onnxruntime/inference_session.rb
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: tr
check_status api[:CreateCUDAProviderOptions].call(cuda_options)
check_status api[:SessionOptionsAppendExecutionProvider_CUDA_V2].call(session_options.read_pointer, cuda_options.read_pointer)
release :CUDAProviderOptions, cuda_options
when "CoreMLExecutionProvider"
unless FFI.respond_to?(:OrtSessionOptionsAppendExecutionProvider_CoreML)
raise ArgumentError, "Provider not supported: #{provider}"
end

coreml_flags = 0x010 # COREML_FLAG_CREATE_MLPROGRAM
check_status FFI.OrtSessionOptionsAppendExecutionProvider_CoreML(session_options.read_pointer, coreml_flags)
when "CPUExecutionProvider"
break
else
Expand Down
11 changes: 11 additions & 0 deletions test/model_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,17 @@ def test_providers_cuda
end
end

def test_providers_coreml
skip unless mac?

model = OnnxRuntime::InferenceSession.new("test/support/lightgbm.onnx", providers: ["CoreMLExecutionProvider", "CPUExecutionProvider"])
x = [[5.8, 2.8]]
label, probabilities = model.run(nil, {input: x})
assert_equal [1], label
assert_equal [0, 1, 2], probabilities[0].keys
assert_elements_in_delta [0.2593829035758972, 0.409047931432724, 0.3315691649913788], probabilities[0].values
end

def test_profiling
sess = OnnxRuntime::InferenceSession.new("test/support/model.onnx", enable_profiling: true)
file = sess.end_profiling
Expand Down
4 changes: 4 additions & 0 deletions test/test_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def assert_elements_in_delta(expected, actual)
end
end

def mac?
RbConfig::CONFIG["host_os"] =~ /darwin/i
end

def stress?
ENV["STRESS"]
end
Expand Down

0 comments on commit f73ea63

Please sign in to comment.