Skip to content

Commit

Permalink
wasi-nn: add [named models]
Browse files Browse the repository at this point in the history
This change adds a way to retrieve preloaded ML models (i.e., "graphs"
in wasi-nn terms) from a registry. The wasi-nn specification includes a
new function, `load_by_name`, that can be used to access these models
more efficiently than before; previously, a user's only option was to
read/download/etc. all of the bytes of an ML model and pass them to the
`load` function.

[named models]: WebAssembly/wasi-nn#36

In Wasmtime's implementation of wasi-nn, we call the registry that holds
the models a `GraphRegistry`. We include a simplistic `InMemoryRegistry`
for use in the Wasmtime CLI (more on this later) but the idea is that
production use will involve some more complex caching and thus a new
implementation of a registry--a `Box<dyn GraphRegistry>`--passed into
the wasi-nn context. Note that, because we now must be able to `clone` a
graph out of the registry and into the "used graphs" table, the OpenVINO
`BackendGraph` is updated to be easier to copy around.

To allow experimentation with this "preload a named model"
functionality, this change also adds a new Wasmtime CLI flag: `--graph
<encoding>:<host dir>`. Wasmtime CLI users can now preload a model from
a directory; the directory `basename` is used as the model name. Loading
models from a directory is probably not desired in Wasmtime embeddings
so it is cordoned off into a separate `BackendFromDir` extension trait.
  • Loading branch information
abrown committed Aug 22, 2023
1 parent 819fad0 commit 7bfc3cf
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 34 deletions.
34 changes: 32 additions & 2 deletions crates/wasi-nn/src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod openvino;
use self::openvino::OpenvinoBackend;
use crate::wit::types::{ExecutionTarget, Tensor};
use crate::{ExecutionContext, Graph};
use std::{error::Error, fmt, path::Path, str::FromStr};
use thiserror::Error;
use wiggle::GuestError;

Expand All @@ -15,16 +16,28 @@ pub fn list() -> Vec<(BackendKind, Box<dyn Backend>)> {
vec![(BackendKind::OpenVINO, Box::new(OpenvinoBackend::default()))]
}

/// A [Backend] contains the necessary state to load [BackendGraph]s.
/// A [Backend] contains the necessary state to load [Graph]s.
pub trait Backend: Send + Sync {
fn name(&self) -> &str;
fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError>;
fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir>;
}

/// Some [Backend]s support loading a [Graph] from a directory on the
/// filesystem; this is not a general requirement for backends but is useful for
/// the Wasmtime CLI.
pub trait BackendFromDir: Backend {
fn load_from_dir(
&mut self,
builders: &Path,
target: ExecutionTarget,
) -> Result<Graph, BackendError>;
}

/// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing
/// implementation for a [crate::witx::types::Graph].
pub trait BackendGraph: Send + Sync {
fn init_execution_context(&mut self) -> Result<ExecutionContext, BackendError>;
fn init_execution_context(&self) -> Result<ExecutionContext, BackendError>;
}

/// A [BackendExecutionContext] performs the actual inference; this is the
Expand Down Expand Up @@ -53,3 +66,20 @@ pub enum BackendError {
pub enum BackendKind {
OpenVINO,
}
impl FromStr for BackendKind {
type Err = BackendKindParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"openvino" => Ok(BackendKind::OpenVINO),
_ => Err(BackendKindParseError(s.into())),
}
}
}
#[derive(Debug)]
pub struct BackendKindParseError(String);
impl fmt::Display for BackendKindParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "unknown backend: {}", self.0)
}
}
impl Error for BackendKindParseError {}
45 changes: 38 additions & 7 deletions crates/wasi-nn/src/backend/openvino.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
//! Implements a `wasi-nn` [`Backend`] using OpenVINO.

use super::{Backend, BackendError, BackendExecutionContext, BackendGraph};
use super::{Backend, BackendError, BackendExecutionContext, BackendFromDir, BackendGraph};
use crate::wit::types::{ExecutionTarget, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc};
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use std::{fs::File, io::Read, path::Path};

#[derive(Default)]
pub(crate) struct OpenvinoBackend(Option<openvino::Core>);
Expand Down Expand Up @@ -51,20 +52,42 @@ impl Backend for OpenvinoBackend {

let exec_network =
core.load_network(&cnn_network, map_execution_target_to_string(target))?;
let box_: Box<dyn BackendGraph> =
Box::new(OpenvinoGraph(Arc::new(cnn_network), exec_network));
let box_: Box<dyn BackendGraph> = Box::new(OpenvinoGraph(
Arc::new(cnn_network),
Arc::new(Mutex::new(exec_network)),
));
Ok(box_.into())
}

fn as_dir_loadable(&mut self) -> Option<&mut dyn BackendFromDir> {
Some(self)
}
}

struct OpenvinoGraph(Arc<openvino::CNNNetwork>, openvino::ExecutableNetwork);
impl BackendFromDir for OpenvinoBackend {
fn load_from_dir(
&mut self,
path: &Path,
target: ExecutionTarget,
) -> Result<Graph, BackendError> {
let model = read(&path.join("model.xml"))?;
let weights = read(&path.join("model.bin"))?;
self.load(&[&model, &weights], target)
}
}

struct OpenvinoGraph(
Arc<openvino::CNNNetwork>,
Arc<Mutex<openvino::ExecutableNetwork>>,
);

unsafe impl Send for OpenvinoGraph {}
unsafe impl Sync for OpenvinoGraph {}

impl BackendGraph for OpenvinoGraph {
fn init_execution_context(&mut self) -> Result<ExecutionContext, BackendError> {
let infer_request = self.1.create_infer_request()?;
fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
let mut network = self.1.lock().unwrap();
let infer_request = network.create_infer_request()?;
let box_: Box<dyn BackendExecutionContext> =
Box::new(OpenvinoExecutionContext(self.0.clone(), infer_request));
Ok(box_.into())
Expand Down Expand Up @@ -145,3 +168,11 @@ fn map_tensor_type_to_precision(tensor_type: TensorType) -> openvino::Precision
TensorType::Bf16 => todo!("not yet supported in `openvino` bindings"),
}
}

/// Read a file into a byte vector.
fn read(path: &Path) -> anyhow::Result<Vec<u8>> {
let mut file = File::open(path)?;
let mut buffer = vec![];
file.read_to_end(&mut buffer)?;
Ok(buffer)
}
54 changes: 43 additions & 11 deletions crates/wasi-nn/src/ctx.rs
Original file line number Diff line number Diff line change
@@ -1,38 +1,59 @@
//! Implements the host state for the `wasi-nn` API: [WasiNnCtx].

use crate::backend::{self, Backend, BackendError, BackendKind};
use crate::backend::{Backend, BackendError, BackendKind};
use crate::wit::types::GraphEncoding;
use crate::{ExecutionContext, Graph};
use std::{collections::HashMap, hash::Hash};
use crate::{ExecutionContext, Graph, GraphRegistry, InMemoryRegistry};
use anyhow::anyhow;
use std::{collections::HashMap, hash::Hash, path::Path};
use thiserror::Error;
use wiggle::GuestError;

type Backends = HashMap<BackendKind, Box<dyn Backend>>;
type Registry = Box<dyn GraphRegistry>;
type GraphId = u32;
type GraphExecutionContextId = u32;
type BackendName = String;
type GraphDirectory = String;

/// Construct an in-memory registry from the available backends and a list of
/// `(<backend name>, <graph directory>)`. This assumes graphs can be loaded
/// from a local directory, which is a safe assumption currently for the current
/// model types.
pub fn preload(
preload_graphs: &[(BackendName, GraphDirectory)],
) -> anyhow::Result<(Backends, Registry)> {
let mut backends: HashMap<_, _> = crate::backend::list().into_iter().collect();
let mut registry = InMemoryRegistry::new();
for (kind, path) in preload_graphs {
let backend = backends
.get_mut(&kind.parse()?)
.ok_or(anyhow!("unsupported backend: {}", kind))?
.as_dir_loadable()
.ok_or(anyhow!("{} does not support directory loading", kind))?;
registry.load(backend, Path::new(path))?;
}
Ok((backends, Box::new(registry)))
}

/// Capture the state necessary for calling into the backend ML libraries.
pub struct WasiNnCtx {
pub(crate) backends: Backends,
pub(crate) registry: Registry,
pub(crate) graphs: Table<GraphId, Graph>,
pub(crate) executions: Table<GraphExecutionContextId, ExecutionContext>,
}

impl WasiNnCtx {
/// Make a new context from the default state.
pub fn new(backends: Backends) -> Self {
pub fn new(backends: Backends, registry: Registry) -> Self {
Self {
backends,
registry,
graphs: Table::default(),
executions: Table::default(),
}
}
}
impl Default for WasiNnCtx {
fn default() -> Self {
WasiNnCtx::new(backend::list().into_iter().collect())
}
}

/// Possible errors while interacting with [WasiNnCtx].
#[derive(Debug, Error)]
Expand Down Expand Up @@ -90,6 +111,10 @@ where
key
}

pub fn get(&self, key: K) -> Option<&V> {
self.entries.get(&key)
}

pub fn get_mut(&mut self, key: K) -> Option<&mut V> {
self.entries.get_mut(&key)
}
Expand All @@ -106,7 +131,14 @@ mod test {
use super::*;

#[test]
fn instantiate() {
WasiNnCtx::default();
fn example() {
struct FakeRegistry;
impl GraphRegistry for FakeRegistry {
fn get_mut(&mut self, _: &str) -> Option<&mut Graph> {
None
}
}

let ctx = WasiNnCtx::new(HashMap::new(), Box::new(FakeRegistry));
}
}
16 changes: 8 additions & 8 deletions crates/wasi-nn/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
mod backend;
mod ctx;
mod registry;

pub use ctx::WasiNnCtx;
pub use ctx::{preload, WasiNnCtx};
pub use registry::{GraphRegistry, InMemoryRegistry};
pub mod wit;
pub mod witx;

use std::sync::Arc;

/// A backend-defined graph (i.e., ML model).
pub struct Graph(Box<dyn backend::BackendGraph>);
#[derive(Clone)]
pub struct Graph(Arc<dyn backend::BackendGraph>);
impl From<Box<dyn backend::BackendGraph>> for Graph {
fn from(value: Box<dyn backend::BackendGraph>) -> Self {
Self(value)
Self(value.into())
}
}
impl std::ops::Deref for Graph {
Expand All @@ -18,11 +23,6 @@ impl std::ops::Deref for Graph {
self.0.as_ref()
}
}
impl std::ops::DerefMut for Graph {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.as_mut()
}
}

/// A backend-defined execution context.
pub struct ExecutionContext(Box<dyn backend::BackendExecutionContext>);
Expand Down
43 changes: 43 additions & 0 deletions crates/wasi-nn/src/registry/in_memory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//! Implement a [`GraphRegistry`] with a hash map.

use super::{Graph, GraphRegistry};
use crate::backend::BackendFromDir;
use crate::wit::types::ExecutionTarget;
use anyhow::{anyhow, bail};
use std::{collections::HashMap, path::Path};

pub struct InMemoryRegistry(HashMap<String, Graph>);
impl InMemoryRegistry {
pub fn new() -> Self {
Self(HashMap::new())
}

/// Load a graph from the files contained in the `path` directory.
///
/// This expects the backend to know how to load graphs (i.e., ML model)
/// from a directory. The name used in the registry is the directory's last
/// suffix: if the backend can find the files it expects in `/my/model/foo`,
/// the registry will contain a new graph named `foo`.
pub fn load(&mut self, backend: &mut dyn BackendFromDir, path: &Path) -> anyhow::Result<()> {
if !path.is_dir() {
bail!(
"preload directory is not a valid directory: {}",
path.display()
);
}
let name = path
.file_name()
.map(|s| s.to_string_lossy())
.ok_or(anyhow!("no file name in path"))?;

let graph = backend.load_from_dir(path, ExecutionTarget::Cpu)?;
self.0.insert(name.into_owned(), graph);
Ok(())
}
}

impl GraphRegistry for InMemoryRegistry {
fn get_mut(&mut self, name: &str) -> Option<&mut Graph> {
self.0.get_mut(name)
}
}
16 changes: 16 additions & 0 deletions crates/wasi-nn/src/registry/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//! Define the registry API.
//!
//! A [`GraphRegistry`] is place to store backend graphs so they can be loaded
//! by name. This API does not mandate how a graph is loaded or how it must be
//! stored--it could be stored remotely and rematerialized when needed, e.g. A
//! naive in-memory implementation, [`InMemoryRegistry`] is provided for use
//! with the Wasmtime CLI.

mod in_memory;

use crate::Graph;
pub use in_memory::InMemoryRegistry;

pub trait GraphRegistry: Send + Sync {
fn get_mut(&mut self, name: &str) -> Option<&mut Graph>;
}
11 changes: 8 additions & 3 deletions crates/wasi-nn/src/wit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,14 @@ impl gen::graph::Host for WasiNnCtx {

fn load_by_name(
&mut self,
_name: String,
name: String,
) -> wasmtime::Result<Result<gen::graph::Graph, gen::errors::Error>> {
todo!()
if let Some(graph) = self.registry.get_mut(&name) {
let graph_id = self.graphs.insert(graph.clone().into());
Ok(Ok(graph_id))
} else {
return Err(UsageError::NotFound(name.to_string()).into());
}
}
}

Expand All @@ -67,7 +72,7 @@ impl gen::inference::Host for WasiNnCtx {
&mut self,
graph_id: gen::graph::Graph,
) -> wasmtime::Result<Result<gen::inference::GraphExecutionContext, gen::errors::Error>> {
let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id) {
let exec_context = if let Some(graph) = self.graphs.get(graph_id) {
graph.init_execution_context()?
} else {
return Err(UsageError::InvalidGraphHandle.into());
Expand Down
10 changes: 8 additions & 2 deletions crates/wasi-nn/src/witx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,14 @@ impl<'a> gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx {
Ok(graph_id.into())
}

fn load_by_name<'b>(&mut self, _name: &wiggle::GuestPtr<'b, str>) -> Result<gen::types::Graph> {
todo!()
fn load_by_name<'b>(&mut self, name: &wiggle::GuestPtr<'b, str>) -> Result<gen::types::Graph> {
let name = name.as_str()?.unwrap();
if let Some(graph) = self.registry.get_mut(&name) {
let graph_id = self.graphs.insert(graph.clone().into());
Ok(graph_id.into())
} else {
return Err(UsageError::NotFound(name.to_string()).into());
}
}

fn init_execution_context(
Expand Down
Loading

0 comments on commit 7bfc3cf

Please sign in to comment.