wasi-nn: add "named model" test
Add a new example crate which loads a model by name and performs image
classification. It uses the same MobileNet model as the existing test
but a new version of the Rust bindings. The new crate is built and run
with the new `--graph` CLI flag in the `ci/run-wasi-nn-example.sh` script.

prtest:full
abrown committed Aug 22, 2023
1 parent 7bfc3cf commit c03abac
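For context on what "named model" loading buys us: the existing classification example hands the model files to the runtime as raw bytes, while the new crate asks the runtime for a graph the host has already registered under a name. A minimal sketch of the two loading styles follows; the bytes-based call is an assumption based on the wasi-nn 0.5 bindings' `build_from_bytes` constructor (the existing example is not part of this diff), while the named variant appears verbatim in the new `main.rs` below:

use std::fs;
use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding};

fn main() {
    // Existing style: hand the runtime the model topology and weights directly.
    let xml = fs::read("fixture/model.xml").unwrap();
    let weights = fs::read("fixture/model.bin").unwrap();
    let _by_bytes = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU)
        .build_from_bytes([xml, weights])
        .unwrap();

    // New style: look up a graph the host preloaded under a name (registered
    // in this commit via Wasmtime's `--graph openvino::<dir>` flag).
    let _by_name = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU)
        .build_from_cache("mobilenet")
        .unwrap();
}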
Showing 5 changed files with 171 additions and 10 deletions.
37 changes: 27 additions & 10 deletions ci/run-wasi-nn-example.sh
@@ -1,10 +1,11 @@
 #!/bin/bash
 
-# The following script demonstrates how to execute a machine learning inference using the wasi-nn
-# module optionally compiled into Wasmtime. Calling it will download the necessary model and tensor
-# files stored separately in $FIXTURE into $TMP_DIR (optionally pass a directory with existing files
-# as the first argument to re-try the script). Then, it will compile the example code in
-# crates/wasi-nn/tests/example into a Wasm file that is subsequently executed with the Wasmtime CLI.
+# The following script demonstrates how to execute a machine learning inference
+# using the wasi-nn module optionally compiled into Wasmtime. Calling it will
+# download the necessary model and tensor files stored separately in $FIXTURE
+# into $TMP_DIR (optionally pass a directory with existing files as the first
+# argument to re-try the script). Then, it will compile and run several examples
+# in the Wasmtime CLI.
 set -e
 WASMTIME_DIR=$(dirname "$0" | xargs dirname)
 FIXTURE=https://github.com/intel/openvino-rs/raw/main/crates/openvino/tests/fixtures/mobilenet
@@ -18,24 +19,40 @@ else
 REMOVE_TMP_DIR=0
 fi
 
-# Build Wasmtime with wasi-nn enabled; we attempt this first to avoid extra work if the build fails.
+# One of the examples expects to be in a specifically-named directory.
+mkdir -p $TMP_DIR/mobilenet
+TMP_DIR=$TMP_DIR/mobilenet
+
+# Build Wasmtime with wasi-nn enabled; we attempt this first to avoid extra work
+# if the build fails.
 cargo build -p wasmtime-cli --features wasi-nn
 
 # Download all necessary test fixtures to the temporary directory.
 wget --no-clobber $FIXTURE/mobilenet.bin --output-document=$TMP_DIR/model.bin
 wget --no-clobber $FIXTURE/mobilenet.xml --output-document=$TMP_DIR/model.xml
 wget --no-clobber $FIXTURE/tensor-1x224x224x3-f32.bgr --output-document=$TMP_DIR/tensor.bgr
 
-# Now build an example that uses the wasi-nn API.
+# Now build an example that uses the wasi-nn API. Run the example in Wasmtime
+# (note that the example uses `fixture` as the expected location of the
+# model/tensor files).
 pushd $WASMTIME_DIR/crates/wasi-nn/examples/classification-example
 cargo build --release --target=wasm32-wasi
 cp target/wasm32-wasi/release/wasi-nn-example.wasm $TMP_DIR
 popd
+cargo run -- run --mapdir fixture::$TMP_DIR \
+  --wasi-modules=experimental-wasi-nn $TMP_DIR/wasi-nn-example.wasm
 
-# Run the example in Wasmtime (note that the example uses `fixture` as the expected location of the model/tensor files).
-cargo run -- run --mapdir fixture::$TMP_DIR --wasi-modules=experimental-wasi-nn $TMP_DIR/wasi-nn-example.wasm
+# Build and run another example, this time using Wasmtime's --graph flag to
+# preload the model.
+pushd $WASMTIME_DIR/crates/wasi-nn/examples/classification-example-named
+cargo build --release --target=wasm32-wasi
+cp target/wasm32-wasi/release/wasi-nn-example-named.wasm $TMP_DIR
+popd
+cargo run -- run --mapdir fixture::$TMP_DIR --graph openvino::$TMP_DIR \
+  --wasi-modules=experimental-wasi-nn $TMP_DIR/wasi-nn-example-named.wasm
 
-# Clean up the temporary directory only if it was not specified (users may want to keep the directory around).
+# Clean up the temporary directory only if it was not specified (users may want
+# to keep the directory around).
 if [[ $REMOVE_TMP_DIR -eq 1 ]]; then
 rm -rf $TMP_DIR
 fi
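A note on how the two invocations above differ: `--mapdir fixture::$TMP_DIR` exposes the downloaded files to the guest at the path `fixture`, while `--graph openvino::$TMP_DIR` preloads the model on the host side. The earlier `mkdir -p $TMP_DIR/mobilenet` step is what names the preloaded graph: the directory's base name evidently serves as the lookup key, which is why the new example calls `build_from_cache("mobilenet")`.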
74 changes: 74 additions & 0 deletions crates/wasi-nn/examples/classification-example-named/Cargo.lock


15 changes: 15 additions & 0 deletions crates/wasi-nn/examples/classification-example-named/Cargo.toml
@@ -0,0 +1,15 @@
[package]
name = "wasi-nn-example-named"
version = "0.0.0"
authors = ["The Wasmtime Project Developers"]
readme = "README.md"
edition = "2021"
publish = false

[dependencies]
wasi-nn = "0.5.0"

# This crate is built with the wasm32-wasi target, separate from the main
# Wasmtime build, so use this directive to exclude it from the parent
# directory's workspace.
[workspace]
2 changes: 2 additions & 0 deletions crates/wasi-nn/examples/classification-example-named/README.md
@@ -0,0 +1,2 @@
This example project demonstrates using the `wasi-nn` API to perform ML inference. It consists of Rust code that is
built using the `wasm32-wasi` target. See `ci/run-wasi-nn-example.sh` for how this is used.
53 changes: 53 additions & 0 deletions crates/wasi-nn/examples/classification-example-named/src/main.rs
@@ -0,0 +1,53 @@
use std::fs;
use wasi_nn::*;

pub fn main() {
    let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU)
        .build_from_cache("mobilenet")
        .unwrap();
    println!("Loaded a graph: {:?}", graph);

    let mut context = graph.init_execution_context().unwrap();
    println!("Created an execution context: {:?}", context);

    // Load a tensor that precisely matches the graph input tensor (see
    // `fixture/model.xml`).
    let tensor_data = fs::read("fixture/tensor.bgr").unwrap();
    println!("Read input tensor, size in bytes: {}", tensor_data.len());
    context
        .set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor_data)
        .unwrap();

    // Execute the inference.
    context.compute().unwrap();
    println!("Executed graph inference");

    // Retrieve the output.
    let mut output_buffer = vec![0f32; 1001];
    context.get_output(0, &mut output_buffer[..]).unwrap();

    println!(
        "Found results, sorted top 5: {:?}",
        &sort_results(&output_buffer)[..5]
    )
}

// Sort the buffer of probabilities. The graph places the match probability for
// each class at the index for that class (e.g. the probability of class 42 is
// placed at buffer[42]). Here we convert to a wrapping InferenceResult and sort
// the results. It is unclear why the MobileNet output indices are "off by one,"
// but the `.skip(1)` below seems necessary to get results that make sense
// (e.g. 763 = "revolver" vs. 762 = "restaurant").
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .skip(1)
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}

// A wrapper for class ID and match probabilities.
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
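A plausible explanation for the off-by-one noted in the comment above: this MobileNet variant emits 1001 scores (hence the `vec![0f32; 1001]` buffer), where index 0 is a catch-all "background" class, so skipping it re-aligns the remaining scores with the 1000 ImageNet class labels.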
