wasi-nn: add named models (bytecodealliance#6854)
* wasi-nn: add [named models]

This change adds a way to retrieve preloaded ML models (i.e., "graphs"
in wasi-nn terms) from a registry. The wasi-nn specification includes a
new function, `load_by_name`, that can be used to access these models
more efficiently than before; previously, a user's only option was to
read/download/etc. all of the bytes of an ML model and pass them to the
`load` function.

[named models]: WebAssembly/wasi-nn#36

In Wasmtime's implementation of wasi-nn, we call the registry that holds
the models a `GraphRegistry`. We include a simplistic `InMemoryRegistry`
for use in the Wasmtime CLI (more on this later) but the idea is that
production use will involve some more complex caching and thus a new
implementation of a registry--a `Box<dyn GraphRegistry>`--passed into
the wasi-nn context. Note that, because we now must be able to `clone` a
graph out of the registry and into the "used graphs" table, the OpenVINO
`BackendGraph` is updated to be easier to copy around.
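The registry described above can be sketched as a small host-side trait plus a map-backed implementation. This is a minimal, hypothetical sketch for illustration only: the trait name matches the commit, but the `get` signature and the stand-in `Graph` type are assumptions, not Wasmtime's actual API.

```rust
use std::collections::HashMap;

// Stand-in for Wasmtime's graph type; the real `Graph` wraps a
// backend-specific handle that must be cheap to clone (which is why the
// OpenVINO `BackendGraph` was made easier to copy around).
#[derive(Clone, Debug, PartialEq)]
pub struct Graph(Vec<u8>);

// Hypothetical shape of the registry trait: look up a preloaded graph by name.
pub trait GraphRegistry {
    fn get(&self, name: &str) -> Option<Graph>;
}

// A simplistic in-memory registry backed by a map, mirroring the
// `InMemoryRegistry` the commit adds for CLI use.
pub struct InMemoryRegistry(HashMap<String, Graph>);

impl InMemoryRegistry {
    pub fn new() -> Self {
        Self(HashMap::new())
    }
    pub fn insert(&mut self, name: &str, graph: Graph) {
        self.0.insert(name.to_string(), graph);
    }
}

impl GraphRegistry for InMemoryRegistry {
    fn get(&self, name: &str) -> Option<Graph> {
        // Clone the graph out of the registry into the caller's table of
        // "used graphs".
        self.0.get(name).cloned()
    }
}

fn main() {
    let mut registry = InMemoryRegistry::new();
    registry.insert("mobilenet", Graph(vec![1, 2, 3]));
    // A production embedding would pass some `Box<dyn GraphRegistry>` with
    // caching logic into the wasi-nn context instead.
    let boxed: Box<dyn GraphRegistry> = Box::new(registry);
    assert!(boxed.get("mobilenet").is_some());
    assert!(boxed.get("unknown").is_none());
    println!("lookup ok");
}
```

The `cloned()` in `get` is the reason cheap graph copies matter: every `load_by_name` call pulls a copy out of the registry.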

To allow experimentation with this "preload a named model"
functionality, this change also adds a new Wasmtime CLI flag: `--graph
<encoding>:<host dir>`. Wasmtime CLI users can now preload a model from
a directory; the directory `basename` is used as the model name. Loading
models from a directory is probably not desired in Wasmtime embeddings
so it is cordoned off into a separate `BackendFromDir` extension trait.

* wasi-nn: add "named model" test

Add a new example crate which loads a model by name and performs image
classification. It uses the same MobileNet model as the existing test
but a new version of the Rust bindings. The new crate is built and run
with the new CLI flag in the `ci/run-wasi-nn-example.sh` script.

prtest:full

* review: rename `--graph` to `--wasi-nn-graph`
abrown authored and eduardomourar committed Sep 6, 2023
1 parent cd92093 commit d566f9e
Showing 14 changed files with 388 additions and 44 deletions.
37 changes: 27 additions & 10 deletions ci/run-wasi-nn-example.sh
@@ -1,10 +1,11 @@
#!/bin/bash

-# The following script demonstrates how to execute a machine learning inference using the wasi-nn
-# module optionally compiled into Wasmtime. Calling it will download the necessary model and tensor
-# files stored separately in $FIXTURE into $TMP_DIR (optionally pass a directory with existing files
-# as the first argument to re-try the script). Then, it will compile the example code in
-# crates/wasi-nn/tests/example into a Wasm file that is subsequently executed with the Wasmtime CLI.
+# The following script demonstrates how to execute a machine learning inference
+# using the wasi-nn module optionally compiled into Wasmtime. Calling it will
+# download the necessary model and tensor files stored separately in $FIXTURE
+# into $TMP_DIR (optionally pass a directory with existing files as the first
+# argument to re-try the script). Then, it will compile and run several examples
+# in the Wasmtime CLI.
set -e
WASMTIME_DIR=$(dirname "$0" | xargs dirname)
FIXTURE=https://github.com/intel/openvino-rs/raw/main/crates/openvino/tests/fixtures/mobilenet
@@ -18,24 +19,40 @@ else
REMOVE_TMP_DIR=0
fi

-# Build Wasmtime with wasi-nn enabled; we attempt this first to avoid extra work if the build fails.
+# One of the examples expects to be in a specifically-named directory.
+mkdir -p $TMP_DIR/mobilenet
+TMP_DIR=$TMP_DIR/mobilenet
+
+# Build Wasmtime with wasi-nn enabled; we attempt this first to avoid extra work
+# if the build fails.
cargo build -p wasmtime-cli --features wasi-nn

# Download all necessary test fixtures to the temporary directory.
wget --no-clobber $FIXTURE/mobilenet.bin --output-document=$TMP_DIR/model.bin
wget --no-clobber $FIXTURE/mobilenet.xml --output-document=$TMP_DIR/model.xml
wget --no-clobber $FIXTURE/tensor-1x224x224x3-f32.bgr --output-document=$TMP_DIR/tensor.bgr

-# Now build an example that uses the wasi-nn API.
+# Now build an example that uses the wasi-nn API. Run the example in Wasmtime
+# (note that the example uses `fixture` as the expected location of the
+# model/tensor files).
pushd $WASMTIME_DIR/crates/wasi-nn/examples/classification-example
cargo build --release --target=wasm32-wasi
cp target/wasm32-wasi/release/wasi-nn-example.wasm $TMP_DIR
popd
+cargo run -- run --mapdir fixture::$TMP_DIR \
+    --wasi-modules=experimental-wasi-nn $TMP_DIR/wasi-nn-example.wasm

-# Run the example in Wasmtime (note that the example uses `fixture` as the expected location of the model/tensor files).
-cargo run -- run --mapdir fixture::$TMP_DIR --wasi-modules=experimental-wasi-nn $TMP_DIR/wasi-nn-example.wasm
+# Build and run another example, this time using Wasmtime's graph flag to
+# preload the model.
+pushd $WASMTIME_DIR/crates/wasi-nn/examples/classification-example-named
+cargo build --release --target=wasm32-wasi
+cp target/wasm32-wasi/release/wasi-nn-example-named.wasm $TMP_DIR
+popd
+cargo run -- run --mapdir fixture::$TMP_DIR --wasi-nn-graph openvino::$TMP_DIR \
+    --wasi-modules=experimental-wasi-nn $TMP_DIR/wasi-nn-example-named.wasm

-# Clean up the temporary directory only if it was not specified (users may want to keep the directory around).
+# Clean up the temporary directory only if it was not specified (users may want
+# to keep the directory around).
if [[ $REMOVE_TMP_DIR -eq 1 ]]; then
rm -rf $TMP_DIR
fi
74 changes: 74 additions & 0 deletions crates/wasi-nn/examples/classification-example-named/Cargo.lock

Some generated files are not rendered by default.

15 changes: 15 additions & 0 deletions crates/wasi-nn/examples/classification-example-named/Cargo.toml
@@ -0,0 +1,15 @@
[package]
name = "wasi-nn-example-named"
version = "0.0.0"
authors = ["The Wasmtime Project Developers"]
readme = "README.md"
edition = "2021"
publish = false

[dependencies]
wasi-nn = "0.5.0"

# This crate is built with the wasm32-wasi target, so it's separate
# from the main Wasmtime build, so use this directive to exclude it
# from the parent directory's workspace.
[workspace]
@@ -0,0 +1,2 @@
This example project demonstrates using the `wasi-nn` API to perform ML inference. It consists of Rust code that is
built using the `wasm32-wasi` target. See `ci/run-wasi-nn-example.sh` for how this is used.
53 changes: 53 additions & 0 deletions crates/wasi-nn/examples/classification-example-named/src/main.rs
@@ -0,0 +1,53 @@
use std::fs;
use wasi_nn::*;

pub fn main() {
    let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU)
        .build_from_cache("mobilenet")
        .unwrap();
    println!("Loaded a graph: {:?}", graph);

    let mut context = graph.init_execution_context().unwrap();
    println!("Created an execution context: {:?}", context);

    // Load a tensor that precisely matches the graph input tensor (see
    // `fixture/frozen_inference_graph.xml`).
    let tensor_data = fs::read("fixture/tensor.bgr").unwrap();
    println!("Read input tensor, size in bytes: {}", tensor_data.len());
    context
        .set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor_data)
        .unwrap();

    // Execute the inference.
    context.compute().unwrap();
    println!("Executed graph inference");

    // Retrieve the output.
    let mut output_buffer = vec![0f32; 1001];
    context.get_output(0, &mut output_buffer[..]).unwrap();

    println!(
        "Found results, sorted top 5: {:?}",
        &sort_results(&output_buffer)[..5]
    )
}

// Sort the buffer of probabilities. The graph places the match probability for each class at the
// index for that class (e.g. the probability of class 42 is placed at buffer[42]). Here we convert
// to a wrapping InferenceResult and sort the results. It is unclear why the MobileNet output
// indices are "off by one" but the `.skip(1)` below seems necessary to get results that make sense
// (e.g. 763 = "revolver" vs 762 = "restaurant")
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .skip(1)
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}

// A wrapper for class ID and match probabilities.
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
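To see the `.skip(1)` offset concretely, the sorting helper can be exercised on a tiny hand-made buffer. This is a standalone copy of the example's own `sort_results` and `InferenceResult`, runnable on the host; the four-element buffer is invented for illustration.

```rust
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);

fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .skip(1) // drop index 0; the class probabilities appear shifted by one
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}

fn main() {
    // Probabilities for classes 0..=2 live at buffer[1..=3]; buffer[0] is
    // skipped entirely.
    let buffer = [0.9_f32, 0.1, 0.7, 0.2];
    let sorted = sort_results(&buffer);
    // Highest surviving probability is 0.7 at buffer[2], i.e. class 1.
    assert_eq!(sorted[0], InferenceResult(1, 0.7));
    println!("{:?}", sorted);
}
```

Note that the 0.9 at index 0 never appears in the output, which is exactly the "off by one" behavior the comment describes.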
34 changes: 32 additions & 2 deletions crates/wasi-nn/src/backend/mod.rs
@@ -7,6 +7,7 @@ mod openvino;
use self::openvino::OpenvinoBackend;
use crate::wit::types::{ExecutionTarget, Tensor};
use crate::{ExecutionContext, Graph};
+use std::{error::Error, fmt, path::Path, str::FromStr};
use thiserror::Error;
use wiggle::GuestError;

@@ -15,16 +16,28 @@ pub fn list() -> Vec<(BackendKind, Box<dyn Backend>)> {
    vec![(BackendKind::OpenVINO, Box::new(OpenvinoBackend::default()))]
}

-/// A [Backend] contains the necessary state to load [BackendGraph]s.
+/// A [Backend] contains the necessary state to load [Graph]s.
pub trait Backend: Send + Sync {
    fn name(&self) -> &str;
    fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError>;
+    fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir>;
}

+/// Some [Backend]s support loading a [Graph] from a directory on the
+/// filesystem; this is not a general requirement for backends but is useful for
+/// the Wasmtime CLI.
+pub trait BackendFromDir: Backend {
+    fn load_from_dir(
+        &mut self,
+        builders: &Path,
+        target: ExecutionTarget,
+    ) -> Result<Graph, BackendError>;
+}

/// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing
/// implementation for a [crate::witx::types::Graph].
pub trait BackendGraph: Send + Sync {
-    fn init_execution_context(&mut self) -> Result<ExecutionContext, BackendError>;
+    fn init_execution_context(&self) -> Result<ExecutionContext, BackendError>;
}

/// A [BackendExecutionContext] performs the actual inference; this is the
@@ -53,3 +66,20 @@ pub enum BackendError {
pub enum BackendKind {
    OpenVINO,
}
+impl FromStr for BackendKind {
+    type Err = BackendKindParseError;
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s.to_lowercase().as_str() {
+            "openvino" => Ok(BackendKind::OpenVINO),
+            _ => Err(BackendKindParseError(s.into())),
+        }
+    }
+}
+#[derive(Debug)]
+pub struct BackendKindParseError(String);
+impl fmt::Display for BackendKindParseError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "unknown backend: {}", self.0)
+    }
+}
+impl Error for BackendKindParseError {}
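This `FromStr` impl is what lets the CLI turn the encoding half of the new graph flag (e.g. `openvino`) into a `BackendKind`. A standalone copy of the parser, runnable on the host, can be exercised like so (the `derive`s here are added only so the demo can compare values):

```rust
use std::str::FromStr;

#[derive(Debug, PartialEq)]
enum BackendKind {
    OpenVINO,
}

#[derive(Debug)]
struct BackendKindParseError(String);

impl FromStr for BackendKind {
    type Err = BackendKindParseError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Case-insensitive: "openvino", "OpenVINO", etc. all parse.
        match s.to_lowercase().as_str() {
            "openvino" => Ok(BackendKind::OpenVINO),
            _ => Err(BackendKindParseError(s.into())),
        }
    }
}

fn main() {
    assert_eq!(
        "OpenVINO".parse::<BackendKind>().unwrap(),
        BackendKind::OpenVINO
    );
    // Unknown backends produce the parse error rather than a panic.
    assert!("tensorflow".parse::<BackendKind>().is_err());
    println!("parse ok");
}
```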
45 changes: 38 additions & 7 deletions crates/wasi-nn/src/backend/openvino.rs
@@ -1,10 +1,11 @@
//! Implements a `wasi-nn` [`Backend`] using OpenVINO.

-use super::{Backend, BackendError, BackendExecutionContext, BackendGraph};
+use super::{Backend, BackendError, BackendExecutionContext, BackendFromDir, BackendGraph};
use crate::wit::types::{ExecutionTarget, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc};
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
+use std::{fs::File, io::Read, path::Path};

#[derive(Default)]
pub(crate) struct OpenvinoBackend(Option<openvino::Core>);
@@ -51,20 +52,42 @@ impl Backend for OpenvinoBackend {

        let exec_network =
            core.load_network(&cnn_network, map_execution_target_to_string(target))?;
-        let box_: Box<dyn BackendGraph> =
-            Box::new(OpenvinoGraph(Arc::new(cnn_network), exec_network));
+        let box_: Box<dyn BackendGraph> = Box::new(OpenvinoGraph(
+            Arc::new(cnn_network),
+            Arc::new(Mutex::new(exec_network)),
+        ));
        Ok(box_.into())
    }

+    fn as_dir_loadable(&mut self) -> Option<&mut dyn BackendFromDir> {
+        Some(self)
+    }
}

-struct OpenvinoGraph(Arc<openvino::CNNNetwork>, openvino::ExecutableNetwork);
+impl BackendFromDir for OpenvinoBackend {
+    fn load_from_dir(
+        &mut self,
+        path: &Path,
+        target: ExecutionTarget,
+    ) -> Result<Graph, BackendError> {
+        let model = read(&path.join("model.xml"))?;
+        let weights = read(&path.join("model.bin"))?;
+        self.load(&[&model, &weights], target)
+    }
+}

+struct OpenvinoGraph(
+    Arc<openvino::CNNNetwork>,
+    Arc<Mutex<openvino::ExecutableNetwork>>,
+);

unsafe impl Send for OpenvinoGraph {}
unsafe impl Sync for OpenvinoGraph {}

impl BackendGraph for OpenvinoGraph {
-    fn init_execution_context(&mut self) -> Result<ExecutionContext, BackendError> {
-        let infer_request = self.1.create_infer_request()?;
+    fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
+        let mut network = self.1.lock().unwrap();
+        let infer_request = network.create_infer_request()?;
        let box_: Box<dyn BackendExecutionContext> =
            Box::new(OpenvinoExecutionContext(self.0.clone(), infer_request));
        Ok(box_.into())
@@ -145,3 +168,11 @@ fn map_tensor_type_to_precision(tensor_type: TensorType) -> openvino::Precision
        TensorType::Bf16 => todo!("not yet supported in `openvino` bindings"),
    }
}

+/// Read a file into a byte vector.
+fn read(path: &Path) -> anyhow::Result<Vec<u8>> {
+    let mut file = File::open(path)?;
+    let mut buffer = vec![];
+    file.read_to_end(&mut buffer)?;
+    Ok(buffer)
+}
