diff --git a/CHANGELOG.md b/CHANGELOG.md index 9145594..238ed47 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## 0.9.2 (unreleased) - Updated ONNX Runtime to 1.19.0 +- Added support for CoreML ## 0.9.1 (2024-05-22) diff --git a/lib/onnxruntime/ffi.rb b/lib/onnxruntime/ffi.rb index ac779d1..bbf5308 100644 --- a/lib/onnxruntime/ffi.rb +++ b/lib/onnxruntime/ffi.rb @@ -257,5 +257,11 @@ class Libc attach_function :mbstowcs, %i[pointer string size_t], :size_t end end + + # https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h + begin + attach_function :OrtSessionOptionsAppendExecutionProvider_CoreML, %i[pointer uint32], :pointer + rescue FFI::NotFoundError + end end end diff --git a/lib/onnxruntime/inference_session.rb b/lib/onnxruntime/inference_session.rb index bddc2ff..8cc22b0 100644 --- a/lib/onnxruntime/inference_session.rb +++ b/lib/onnxruntime/inference_session.rb @@ -66,6 +66,13 @@ def initialize(path_or_bytes, enable_cpu_mem_arena: true, enable_mem_pattern: tr check_status api[:CreateCUDAProviderOptions].call(cuda_options) check_status api[:SessionOptionsAppendExecutionProvider_CUDA_V2].call(session_options.read_pointer, cuda_options.read_pointer) release :CUDAProviderOptions, cuda_options + when "CoreMLExecutionProvider" + unless FFI.respond_to?(:OrtSessionOptionsAppendExecutionProvider_CoreML) + raise ArgumentError, "Provider not supported: #{provider}" + end + + coreml_flags = 0x010 # COREML_FLAG_CREATE_MLPROGRAM + check_status FFI.OrtSessionOptionsAppendExecutionProvider_CoreML(session_options.read_pointer, coreml_flags) when "CPUExecutionProvider" break else diff --git a/test/model_test.rb b/test/model_test.rb index 7142fec..0042181 100644 --- a/test/model_test.rb +++ b/test/model_test.rb @@ -286,6 +286,17 @@ def test_providers_cuda end end + def test_providers_coreml + skip unless mac? + + model = OnnxRuntime::InferenceSession.new("test/support/lightgbm.onnx", providers: ["CoreMLExecutionProvider", "CPUExecutionProvider"]) + x = [[5.8, 2.8]] + label, probabilities = model.run(nil, {input: x}) + assert_equal [1], label + assert_equal [0, 1, 2], probabilities[0].keys + assert_elements_in_delta [0.2593829035758972, 0.409047931432724, 0.3315691649913788], probabilities[0].values + end + def test_profiling sess = OnnxRuntime::InferenceSession.new("test/support/model.onnx", enable_profiling: true) file = sess.end_profiling diff --git a/test/test_helper.rb b/test/test_helper.rb index 25c3f3f..c26769b 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -22,6 +22,10 @@ def assert_elements_in_delta(expected, actual) end end + def mac? + RbConfig::CONFIG["host_os"] =~ /darwin/i + end + def stress? ENV["STRESS"] end