-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This change adds a way to retrieve preloaded ML models (i.e., "graphs" in wasi-nn terms) from a registry. The wasi-nn specification includes a new function, `load_by_name`, that can be used to access these models more efficiently than before; previously, a user's only option was to read/download/etc. all of the bytes of an ML model and pass them to the `load` function. [named models]: WebAssembly/wasi-nn#36 In Wasmtime's implementation of wasi-nn, we call the registry that holds the models a `GraphRegistry`. We include a simplistic `InMemoryRegistry` for use in the Wasmtime CLI (more on this later) but the idea is that production use will involve some more complex caching and thus a new implementation of a registry--a `Box<dyn GraphRegistry>`--passed into the wasi-nn context. Note that, because we now must be able to `clone` a graph out of the registry and into the "used graphs" table, the OpenVINO `BackendGraph` is updated to be easier to copy around. To allow experimentation with this "preload a named model" functionality, this change also adds a new Wasmtime CLI flag: `--graph <encoding>:<host dir>`. Wasmtime CLI users can now preload a model from a directory; the directory `basename` is used as the model name. Loading models from a directory is probably not desired in Wasmtime embeddings so it is cordoned off into a separate `BackendFromDir` extension trait.
- Loading branch information
Showing
9 changed files
with
217 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
//! Implement a [`GraphRegistry`] with a hash map. | ||
|
||
use super::{Graph, GraphRegistry}; | ||
use crate::backend::BackendFromDir; | ||
use crate::wit::types::ExecutionTarget; | ||
use anyhow::{anyhow, bail}; | ||
use std::{collections::HashMap, path::Path}; | ||
|
||
pub struct InMemoryRegistry(HashMap<String, Graph>); | ||
impl InMemoryRegistry { | ||
pub fn new() -> Self { | ||
Self(HashMap::new()) | ||
} | ||
|
||
/// Load a graph from the files contained in the `path` directory. | ||
/// | ||
/// This expects the backend to know how to load graphs (i.e., ML model) | ||
/// from a directory. The name used in the registry is the directory's last | ||
/// suffix: if the backend can find the files it expects in `/my/model/foo`, | ||
/// the registry will contain a new graph named `foo`. | ||
pub fn load(&mut self, backend: &mut dyn BackendFromDir, path: &Path) -> anyhow::Result<()> { | ||
if !path.is_dir() { | ||
bail!( | ||
"preload directory is not a valid directory: {}", | ||
path.display() | ||
); | ||
} | ||
let name = path | ||
.file_name() | ||
.map(|s| s.to_string_lossy()) | ||
.ok_or(anyhow!("no file name in path"))?; | ||
|
||
let graph = backend.load_from_dir(path, ExecutionTarget::Cpu)?; | ||
self.0.insert(name.into_owned(), graph); | ||
Ok(()) | ||
} | ||
} | ||
|
||
impl GraphRegistry for InMemoryRegistry { | ||
fn get_mut(&mut self, name: &str) -> Option<&mut Graph> { | ||
self.0.get_mut(name) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
//! Define the registry API. | ||
//! | ||
//! A [`GraphRegistry`] is place to store backend graphs so they can be loaded | ||
//! by name. This API does not mandate how a graph is loaded or how it must be | ||
//! stored--it could be stored remotely and rematerialized when needed, e.g. A | ||
//! naive in-memory implementation, [`InMemoryRegistry`] is provided for use | ||
//! with the Wasmtime CLI. | ||
|
||
mod in_memory; | ||
|
||
use crate::Graph; | ||
pub use in_memory::InMemoryRegistry; | ||
|
||
pub trait GraphRegistry: Send + Sync { | ||
fn get_mut(&mut self, name: &str) -> Option<&mut Graph>; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.