From 7bfc3cfeb54aebfe63c602dbe5d9a58472873429 Mon Sep 17 00:00:00 2001
From: Andrew Brown
Date: Wed, 16 Aug 2023 17:24:36 -0700
Subject: [PATCH] wasi-nn: add [named models]

This change adds a way to retrieve preloaded ML models (i.e., "graphs" in
wasi-nn terms) from a registry. The wasi-nn specification includes a new
function, `load_by_name`, that can be used to access these models more
efficiently than before; previously, a user's only option was to
read/download/etc. all of the bytes of an ML model and pass them to the
`load` function.

[named models]: https://github.com/WebAssembly/wasi-nn/issues/36

In Wasmtime's implementation of wasi-nn, we call the registry that holds
the models a `GraphRegistry`. We include a simplistic `InMemoryRegistry`
for use in the Wasmtime CLI (more on this later), but the idea is that
production use will involve more complex caching and thus a new
implementation of a registry--a `Box<dyn GraphRegistry>`--passed into the
wasi-nn context. Note that, because we now must be able to `clone` a graph
out of the registry and into the "used graphs" table, the OpenVINO
`BackendGraph` is updated to be easier to copy around.

To allow experimentation with this "preload a named model" functionality,
this change also adds a new Wasmtime CLI flag:
`--graph <format>::<host dir>`. Wasmtime CLI users can now preload a model
from a directory; the directory `basename` is used as the model name.
Loading models from a directory is probably not desired in Wasmtime
embeddings, so it is cordoned off into a separate `BackendFromDir`
extension trait.
---
 crates/wasi-nn/src/backend/mod.rs        | 34 ++++++++++++++-
 crates/wasi-nn/src/backend/openvino.rs   | 45 +++++++++++++++++---
 crates/wasi-nn/src/ctx.rs                | 54 +++++++++++++++++++-----
 crates/wasi-nn/src/lib.rs                | 16 +++----
 crates/wasi-nn/src/registry/in_memory.rs | 43 +++++++++++++++++++
 crates/wasi-nn/src/registry/mod.rs       | 16 +++++++
 crates/wasi-nn/src/wit.rs                | 11 +++--
 crates/wasi-nn/src/witx.rs               | 10 ++++-
 src/commands/run.rs                      | 22 +++++++++-
 9 files changed, 217 insertions(+), 34 deletions(-)
 create mode 100644 crates/wasi-nn/src/registry/in_memory.rs
 create mode 100644 crates/wasi-nn/src/registry/mod.rs

diff --git a/crates/wasi-nn/src/backend/mod.rs b/crates/wasi-nn/src/backend/mod.rs
index ad929a9a74e0..5912b26b606d 100644
--- a/crates/wasi-nn/src/backend/mod.rs
+++ b/crates/wasi-nn/src/backend/mod.rs
@@ -7,6 +7,7 @@ mod openvino;
 use self::openvino::OpenvinoBackend;
 use crate::wit::types::{ExecutionTarget, Tensor};
 use crate::{ExecutionContext, Graph};
+use std::{error::Error, fmt, path::Path, str::FromStr};
 use thiserror::Error;
 use wiggle::GuestError;
 
@@ -15,16 +16,28 @@ pub fn list() -> Vec<(BackendKind, Box<dyn Backend>)> {
     vec![(BackendKind::OpenVINO, Box::new(OpenvinoBackend::default()))]
 }
 
-/// A [Backend] contains the necessary state to load [BackendGraph]s.
+/// A [Backend] contains the necessary state to load [Graph]s.
 pub trait Backend: Send + Sync {
     fn name(&self) -> &str;
     fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError>;
+    fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir>;
+}
+
+/// Some [Backend]s support loading a [Graph] from a directory on the
+/// filesystem; this is not a general requirement for backends but is useful for
+/// the Wasmtime CLI.
+pub trait BackendFromDir: Backend {
+    fn load_from_dir(
+        &mut self,
+        builders: &Path,
+        target: ExecutionTarget,
+    ) -> Result<Graph, BackendError>;
 }
 
 /// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing
 /// implementation for a [crate::witx::types::Graph].
 pub trait BackendGraph: Send + Sync {
-    fn init_execution_context(&mut self) -> Result<ExecutionContext, BackendError>;
+    fn init_execution_context(&self) -> Result<ExecutionContext, BackendError>;
 }
 
 /// A [BackendExecutionContext] performs the actual inference; this is the
@@ -53,3 +66,20 @@ pub enum BackendError {
 pub enum BackendKind {
     OpenVINO,
 }
+impl FromStr for BackendKind {
+    type Err = BackendKindParseError;
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s.to_lowercase().as_str() {
+            "openvino" => Ok(BackendKind::OpenVINO),
+            _ => Err(BackendKindParseError(s.into())),
+        }
+    }
+}
+#[derive(Debug)]
+pub struct BackendKindParseError(String);
+impl fmt::Display for BackendKindParseError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "unknown backend: {}", self.0)
+    }
+}
+impl Error for BackendKindParseError {}
diff --git a/crates/wasi-nn/src/backend/openvino.rs b/crates/wasi-nn/src/backend/openvino.rs
index 93f51771c95f..a478ec9a4bdc 100644
--- a/crates/wasi-nn/src/backend/openvino.rs
+++ b/crates/wasi-nn/src/backend/openvino.rs
@@ -1,10 +1,11 @@
 //! Implements a `wasi-nn` [`Backend`] using OpenVINO.
 
-use super::{Backend, BackendError, BackendExecutionContext, BackendGraph};
+use super::{Backend, BackendError, BackendExecutionContext, BackendFromDir, BackendGraph};
 use crate::wit::types::{ExecutionTarget, Tensor, TensorType};
 use crate::{ExecutionContext, Graph};
 use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc};
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
+use std::{fs::File, io::Read, path::Path};
 
 #[derive(Default)]
 pub(crate) struct OpenvinoBackend(Option<openvino::Core>);
@@ -51,20 +52,42 @@ impl Backend for OpenvinoBackend {
         let exec_network =
             core.load_network(&cnn_network, map_execution_target_to_string(target))?;
-        let box_: Box<dyn BackendGraph> =
-            Box::new(OpenvinoGraph(Arc::new(cnn_network), exec_network));
+        let box_: Box<dyn BackendGraph> = Box::new(OpenvinoGraph(
+            Arc::new(cnn_network),
+            Arc::new(Mutex::new(exec_network)),
+        ));
         Ok(box_.into())
     }
+
+    fn as_dir_loadable(&mut self) -> Option<&mut dyn BackendFromDir> {
+        Some(self)
+    }
 }
 
-struct OpenvinoGraph(Arc<openvino::CNNNetwork>, openvino::ExecutableNetwork);
+impl BackendFromDir for OpenvinoBackend {
+    fn load_from_dir(
+        &mut self,
+        path: &Path,
+        target: ExecutionTarget,
+    ) -> Result<Graph, BackendError> {
+        let model = read(&path.join("model.xml"))?;
+        let weights = read(&path.join("model.bin"))?;
+        self.load(&[&model, &weights], target)
+    }
+}
+
+struct OpenvinoGraph(
+    Arc<openvino::CNNNetwork>,
+    Arc<Mutex<openvino::ExecutableNetwork>>,
+);
 
 unsafe impl Send for OpenvinoGraph {}
 unsafe impl Sync for OpenvinoGraph {}
 
 impl BackendGraph for OpenvinoGraph {
-    fn init_execution_context(&mut self) -> Result<ExecutionContext, BackendError> {
-        let infer_request = self.1.create_infer_request()?;
+    fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
+        let mut network = self.1.lock().unwrap();
+        let infer_request = network.create_infer_request()?;
         let box_: Box<dyn BackendExecutionContext> =
             Box::new(OpenvinoExecutionContext(self.0.clone(), infer_request));
         Ok(box_.into())
@@ -145,3 +168,11 @@ fn map_tensor_type_to_precision(tensor_type: TensorType) -> openvino::Precision
         TensorType::Bf16 => todo!("not yet supported in `openvino` bindings"),
     }
 }
+
+/// Read a file into a byte vector.
+fn read(path: &Path) -> anyhow::Result<Vec<u8>> {
+    let mut file = File::open(path)?;
+    let mut buffer = vec![];
+    file.read_to_end(&mut buffer)?;
+    Ok(buffer)
+}
diff --git a/crates/wasi-nn/src/ctx.rs b/crates/wasi-nn/src/ctx.rs
index e961938a8f3d..913205c6d81c 100644
--- a/crates/wasi-nn/src/ctx.rs
+++ b/crates/wasi-nn/src/ctx.rs
@@ -1,38 +1,59 @@
 //! Implements the host state for the `wasi-nn` API: [WasiNnCtx].
 
-use crate::backend::{self, Backend, BackendError, BackendKind};
+use crate::backend::{Backend, BackendError, BackendKind};
 use crate::wit::types::GraphEncoding;
-use crate::{ExecutionContext, Graph};
-use std::{collections::HashMap, hash::Hash};
+use crate::{ExecutionContext, Graph, GraphRegistry, InMemoryRegistry};
+use anyhow::anyhow;
+use std::{collections::HashMap, hash::Hash, path::Path};
 use thiserror::Error;
 use wiggle::GuestError;
 
 type Backends = HashMap<BackendKind, Box<dyn Backend>>;
+type Registry = Box<dyn GraphRegistry>;
 type GraphId = u32;
 type GraphExecutionContextId = u32;
+type BackendName = String;
+type GraphDirectory = String;
+
+/// Construct an in-memory registry from the available backends and a list of
+/// `(<backend name>, <graph directory>)` pairs. This assumes graphs can be
+/// loaded from a local directory, which is a safe assumption for the current
+/// model types.
+pub fn preload(
+    preload_graphs: &[(BackendName, GraphDirectory)],
+) -> anyhow::Result<(Backends, Registry)> {
+    let mut backends: HashMap<_, _> = crate::backend::list().into_iter().collect();
+    let mut registry = InMemoryRegistry::new();
+    for (kind, path) in preload_graphs {
+        let backend = backends
+            .get_mut(&kind.parse()?)
+            .ok_or(anyhow!("unsupported backend: {}", kind))?
+            .as_dir_loadable()
+            .ok_or(anyhow!("{} does not support directory loading", kind))?;
+        registry.load(backend, Path::new(path))?;
+    }
+    Ok((backends, Box::new(registry)))
+}
 
 /// Capture the state necessary for calling into the backend ML libraries.
 pub struct WasiNnCtx {
     pub(crate) backends: Backends,
+    pub(crate) registry: Registry,
     pub(crate) graphs: Table<GraphId, Graph>,
     pub(crate) executions: Table<GraphExecutionContextId, ExecutionContext>,
 }
 
 impl WasiNnCtx {
     /// Make a new context from the default state.
-    pub fn new(backends: Backends) -> Self {
+    pub fn new(backends: Backends, registry: Registry) -> Self {
         Self {
             backends,
+            registry,
             graphs: Table::default(),
             executions: Table::default(),
         }
     }
 }
-impl Default for WasiNnCtx {
-    fn default() -> Self {
-        WasiNnCtx::new(backend::list().into_iter().collect())
-    }
-}
 
 /// Possible errors while interacting with [WasiNnCtx].
 #[derive(Debug, Error)]
@@ -90,6 +111,10 @@ where
         key
     }
 
+    pub fn get(&self, key: K) -> Option<&V> {
+        self.entries.get(&key)
+    }
+
     pub fn get_mut(&mut self, key: K) -> Option<&mut V> {
         self.entries.get_mut(&key)
     }
@@ -106,7 +131,14 @@ mod test {
     use super::*;
 
     #[test]
-    fn instantiate() {
-        WasiNnCtx::default();
+    fn example() {
+        struct FakeRegistry;
+        impl GraphRegistry for FakeRegistry {
+            fn get_mut(&mut self, _: &str) -> Option<&mut Graph> {
+                None
+            }
+        }
+
+        let ctx = WasiNnCtx::new(HashMap::new(), Box::new(FakeRegistry));
     }
 }
diff --git a/crates/wasi-nn/src/lib.rs b/crates/wasi-nn/src/lib.rs
index 1abd6c0b1372..f5c9bfe641d8 100644
--- a/crates/wasi-nn/src/lib.rs
+++ b/crates/wasi-nn/src/lib.rs
@@ -1,15 +1,20 @@
 mod backend;
 mod ctx;
+mod registry;
 
-pub use ctx::WasiNnCtx;
+pub use ctx::{preload, WasiNnCtx};
+pub use registry::{GraphRegistry, InMemoryRegistry};
 pub mod wit;
 pub mod witx;
 
+use std::sync::Arc;
+
 /// A backend-defined graph (i.e., ML model).
-pub struct Graph(Box<dyn BackendGraph>);
+#[derive(Clone)]
+pub struct Graph(Arc<dyn BackendGraph>);
 impl From<Box<dyn BackendGraph>> for Graph {
     fn from(value: Box<dyn BackendGraph>) -> Self {
-        Self(value)
+        Self(value.into())
     }
 }
 impl std::ops::Deref for Graph {
@@ -18,11 +23,6 @@ impl std::ops::Deref for Graph {
         self.0.as_ref()
     }
 }
-impl std::ops::DerefMut for Graph {
-    fn deref_mut(&mut self) -> &mut Self::Target {
-        self.0.as_mut()
-    }
-}
 
 /// A backend-defined execution context.
 pub struct ExecutionContext(Box<dyn BackendExecutionContext>);
diff --git a/crates/wasi-nn/src/registry/in_memory.rs b/crates/wasi-nn/src/registry/in_memory.rs
new file mode 100644
index 000000000000..b008f7f43684
--- /dev/null
+++ b/crates/wasi-nn/src/registry/in_memory.rs
@@ -0,0 +1,43 @@
+//! Implement a [`GraphRegistry`] with a hash map.
+
+use super::{Graph, GraphRegistry};
+use crate::backend::BackendFromDir;
+use crate::wit::types::ExecutionTarget;
+use anyhow::{anyhow, bail};
+use std::{collections::HashMap, path::Path};
+
+pub struct InMemoryRegistry(HashMap<String, Graph>);
+impl InMemoryRegistry {
+    pub fn new() -> Self {
+        Self(HashMap::new())
+    }
+
+    /// Load a graph from the files contained in the `path` directory.
+    ///
+    /// This expects the backend to know how to load graphs (i.e., ML models)
+    /// from a directory. The name used in the registry is the directory's last
+    /// suffix: if the backend can find the files it expects in `/my/model/foo`,
+    /// the registry will contain a new graph named `foo`.
+    pub fn load(&mut self, backend: &mut dyn BackendFromDir, path: &Path) -> anyhow::Result<()> {
+        if !path.is_dir() {
+            bail!(
+                "preload directory is not a valid directory: {}",
+                path.display()
+            );
+        }
+        let name = path
+            .file_name()
+            .map(|s| s.to_string_lossy())
+            .ok_or(anyhow!("no file name in path"))?;
+
+        let graph = backend.load_from_dir(path, ExecutionTarget::Cpu)?;
+        self.0.insert(name.into_owned(), graph);
+        Ok(())
+    }
+}
+
+impl GraphRegistry for InMemoryRegistry {
+    fn get_mut(&mut self, name: &str) -> Option<&mut Graph> {
+        self.0.get_mut(name)
+    }
+}
diff --git a/crates/wasi-nn/src/registry/mod.rs b/crates/wasi-nn/src/registry/mod.rs
new file mode 100644
index 000000000000..83f88e4dca0e
--- /dev/null
+++ b/crates/wasi-nn/src/registry/mod.rs
@@ -0,0 +1,16 @@
+//! Define the registry API.
+//!
+//! A [`GraphRegistry`] is a place to store backend graphs so they can be
+//! loaded by name. This API does not mandate how a graph is loaded or how it
+//! must be stored--it could be stored remotely and rematerialized when needed,
+//! for example. A naive in-memory implementation, [`InMemoryRegistry`], is
+//! provided for use with the Wasmtime CLI.
+
+mod in_memory;
+
+use crate::Graph;
+pub use in_memory::InMemoryRegistry;
+
+pub trait GraphRegistry: Send + Sync {
+    fn get_mut(&mut self, name: &str) -> Option<&mut Graph>;
+}
diff --git a/crates/wasi-nn/src/wit.rs b/crates/wasi-nn/src/wit.rs
index 25510374d5cf..2b2032bc3109 100644
--- a/crates/wasi-nn/src/wit.rs
+++ b/crates/wasi-nn/src/wit.rs
@@ -53,9 +53,14 @@ impl gen::graph::Host for WasiNnCtx {
 
     fn load_by_name(
         &mut self,
-        _name: String,
+        name: String,
     ) -> wasmtime::Result<Result<gen::graph::Graph, gen::errors::Error>> {
-        todo!()
+        if let Some(graph) = self.registry.get_mut(&name) {
+            let graph_id = self.graphs.insert(graph.clone().into());
+            Ok(Ok(graph_id))
+        } else {
+            return Err(UsageError::NotFound(name.to_string()).into());
+        }
     }
 }
 
@@ -67,7 +72,7 @@ impl gen::inference::Host for WasiNnCtx {
         &mut self,
         graph_id: gen::graph::Graph,
     ) -> wasmtime::Result<Result<gen::inference::GraphExecutionContext, gen::errors::Error>> {
-        let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id) {
+        let exec_context = if let Some(graph) = self.graphs.get(graph_id) {
             graph.init_execution_context()?
         } else {
             return Err(UsageError::InvalidGraphHandle.into());
diff --git a/crates/wasi-nn/src/witx.rs b/crates/wasi-nn/src/witx.rs
index b339d9a8d389..bf5a57b8e980 100644
--- a/crates/wasi-nn/src/witx.rs
+++ b/crates/wasi-nn/src/witx.rs
@@ -78,8 +78,14 @@ impl<'a> gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx {
         Ok(graph_id.into())
     }
 
-    fn load_by_name<'b>(&mut self, _name: &wiggle::GuestPtr<'b, str>) -> Result<Graph> {
-        todo!()
+    fn load_by_name<'b>(&mut self, name: &wiggle::GuestPtr<'b, str>) -> Result<Graph> {
+        let name = name.as_str()?.unwrap();
+        if let Some(graph) = self.registry.get_mut(&name) {
+            let graph_id = self.graphs.insert(graph.clone().into());
+            Ok(graph_id.into())
+        } else {
+            return Err(UsageError::NotFound(name.to_string()).into());
+        }
     }
 
     fn init_execution_context(
diff --git a/src/commands/run.rs b/src/commands/run.rs
index 89a050e1b105..43d2f0d237a9 100644
--- a/src/commands/run.rs
+++ b/src/commands/run.rs
@@ -51,6 +51,14 @@ fn parse_map_dirs(s: &str) -> Result<(String, String)> {
     Ok((parts[0].into(), parts[1].into()))
 }
 
+fn parse_graphs(s: &str) -> Result<(String, String)> {
+    let parts: Vec<&str> = s.split("::").collect();
+    if parts.len() != 2 {
+        bail!("must contain exactly one double colon ('::')");
+    }
+    Ok((parts[0].into(), parts[1].into()))
+}
+
 fn parse_dur(s: &str) -> Result<Duration> {
     // assume an integer without a unit specified is a number of seconds ...
    if let Ok(val) = s.parse() {
@@ -158,6 +166,17 @@ pub struct RunCommand {
     #[clap(long = "mapdir", number_of_values = 1, value_name = "GUEST_DIR::HOST_DIR", value_parser = parse_map_dirs)]
     map_dirs: Vec<(String, String)>,
+    /// Pre-load machine learning graphs (i.e., models) for use by wasi-nn.
+    ///
+    /// Each use of the flag will preload an ML model from the host directory
+    /// using the given model encoding. The model will be mapped to the
+    /// directory name: e.g., `--graph openvino::/foo/bar` will preload an
+    /// OpenVINO model named `bar`. Note that which model encodings are
+    /// available depends on the backends implemented in the
+    /// `wasmtime_wasi_nn` crate.
+    #[clap(long = "graph", value_name = "FORMAT::HOST_DIR", value_parser = parse_graphs)]
+    graphs: Vec<(String, String)>,
+
     /// Load the given WebAssembly module before the main module
     #[clap(
         long = "preload",
         number_of_values = 1,
@@ -922,7 +941,8 @@ impl RunCommand {
                 })?;
             }
         }
-        store.data_mut().wasi_nn = Some(Arc::new(WasiNnCtx::default()));
+        let (backends, registry) = wasmtime_wasi_nn::preload(&self.graphs)?;
+        store.data_mut().wasi_nn = Some(Arc::new(WasiNnCtx::new(backends, registry)));
     }
 }
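
Reviewer note: the two sketches below are illustrative only and are not part
of the change to apply. The module `app.wasm` and the model directory
`/opt/models/mobilenet` are hypothetical placeholders, the CLI line omits
whatever other flags a particular build needs to enable wasi-nn, and the Rust
sketch assumes only the `preload` helper and `WasiNnCtx::new` constructor
added above (plus `anyhow` for error handling in the embedding).

With this patch, a CLI user can preload a model directory and a guest can then
fetch it by its basename:

    # "mobilenet" (the directory basename) becomes the graph name that the
    # guest passes to wasi-nn's `load_by_name`.
    wasmtime run --graph openvino::/opt/models/mobilenet app.wasm

An embedder can do the same wiring without the CLI:

    fn build_wasi_nn_ctx() -> anyhow::Result<wasmtime_wasi_nn::WasiNnCtx> {
        // One (backend, directory) pair per preloaded model; `preload` parses
        // the backend name ("openvino") and loads the directory's contents
        // into an `InMemoryRegistry`.
        let graphs = vec![("openvino".to_string(), "/opt/models/mobilenet".to_string())];
        let (backends, registry) = wasmtime_wasi_nn::preload(&graphs)?;
        Ok(wasmtime_wasi_nn::WasiNnCtx::new(backends, registry))
    }

A production embedding would more likely implement `GraphRegistry` itself
(e.g., backed by a model cache or a remote store) and pass that
`Box<dyn GraphRegistry>` to `WasiNnCtx::new` instead of relying on the
directory-based `InMemoryRegistry`.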