wasi-nn: adapt to new test infrastructure #7679

Merged · 6 commits · Dec 15, 2023
Changes from 5 commits
4 changes: 4 additions & 0 deletions .github/workflows/main.yml
@@ -422,6 +422,10 @@ jobs:
     - run: echo CARGO_BUILD_TARGET=${{ matrix.target }} >> $GITHUB_ENV
       if: matrix.target != ''
 
+    # Install OpenVINO for testing wasmtime-wasi-nn.
+    - uses: abrown/install-openvino-action@v8
+      if: runner.arch == 'X64'
+
     # Fix an ICE for now in gcc when compiling zstd with debuginfo (??)
     - run: echo CFLAGS=-g0 >> $GITHUB_ENV
       if: matrix.target == 'x86_64-pc-windows-gnu'
13 changes: 13 additions & 0 deletions Cargo.lock

(Generated file; contents not rendered.)

1 change: 1 addition & 0 deletions crates/test-programs/Cargo.toml
@@ -12,6 +12,7 @@ workspace = true
 [dependencies]
 anyhow = { workspace = true }
 wasi = "0.11.0"
+wasi-nn = "0.6.0"
 wit-bindgen = { workspace = true, features = ['default'] }
 libc = { workspace = true }
 getrandom = "0.2.9"
21 changes: 13 additions & 8 deletions crates/test-programs/artifacts/build.rs
@@ -65,14 +65,6 @@ fn build_and_generate_tests() {
 
         generated_code += &format!("pub const {camel}: &'static str = {wasm:?};\n");
 
-        let adapter = match target.as_str() {
-            "reactor" => &reactor_adapter,
-            s if s.starts_with("api_proxy") => &proxy_adapter,
-            _ => &command_adapter,
-        };
-        let path = compile_component(&wasm, adapter);
-        generated_code += &format!("pub const {camel}_COMPONENT: &'static str = {path:?};\n");
-
         // Bucket, based on the name of the test, into a "kind" which generates
         // a `foreach_*` macro below.
         let kind = match target.as_str() {
@@ -81,6 +73,7 @@
             s if s.starts_with("preview2_") => "preview2",
             s if s.starts_with("cli_") => "cli",
             s if s.starts_with("api_") => "api",
+            s if s.starts_with("nn_") => "nn",
             // If you're reading this because you hit this panic, either add it
             // to a test suite above or add a new "suite". The purpose of the
             // categorization above is to have a static assertion that tests
@@ -93,6 +86,18 @@
         if !kind.is_empty() {
             kinds.entry(kind).or_insert(Vec::new()).push(target);
        }
+
+        // Generate a component from each test.
+        if kind == "nn" {
+            continue;
+        }
+        let adapter = match target.as_str() {
+            "reactor" => &reactor_adapter,
+            s if s.starts_with("api_proxy") => &proxy_adapter,
+            _ => &command_adapter,
+        };
+        let path = compile_component(&wasm, adapter);
+        generated_code += &format!("pub const {camel}_COMPONENT: &'static str = {path:?};\n");
     }
 
     for (kind, targets) in kinds {
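
The new "nn" bucket means the build script emits a foreach_nn! macro covering the two nn_* test programs added below. As a rough sketch (inferred from the existing foreach_* pattern; the exact generated code is an assumption, not part of this diff), the emitted macro would look like:

// Assumed shape of the macro generated into test-programs-artifacts; a
// consumer crate invokes it with its own macro to stamp out one test per
// nn_* program.
#[macro_export]
macro_rules! foreach_nn {
    ($mac:ident) => {
        $mac!(nn_image_classification);
        $mac!(nn_image_classification_named);
    };
}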
59 changes: 59 additions & 0 deletions crates/test-programs/src/bin/nn_image_classification.rs
@@ -0,0 +1,59 @@
use anyhow::Result;
use std::fs;
use wasi_nn::*;

pub fn main() -> Result<()> {
    let xml = fs::read_to_string("fixture/model.xml").unwrap();
    println!("Read graph XML, first 50 characters: {}", &xml[..50]);

    let weights = fs::read("fixture/model.bin").unwrap();
    println!("Read graph weights, size in bytes: {}", weights.len());

    let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU)
        .build_from_bytes([&xml.into_bytes(), &weights])?;
    println!("Loaded graph into wasi-nn with ID: {}", graph);

    let mut context = graph.init_execution_context()?;
    println!("Created wasi-nn execution context with ID: {}", context);

    // Load a tensor that precisely matches the graph input tensor (see
    // `fixture/frozen_inference_graph.xml`).
    let data = fs::read("fixture/tensor.bgr").unwrap();
    println!("Read input tensor, size in bytes: {}", data.len());
    context.set_input(0, wasi_nn::TensorType::F32, &[1, 3, 224, 224], &data)?;

    // Execute the inference.
    context.compute()?;
    println!("Executed graph inference");

    // Retrieve the output.
    let mut output_buffer = vec![0f32; 1001];
    context.get_output(0, &mut output_buffer[..])?;
    println!(
        "Found results, sorted top 5: {:?}",
        &sort_results(&output_buffer)[..5]
    );

    Ok(())
}

// Sort the buffer of probabilities. The graph places the match probability for
// each class at the index for that class (e.g. the probability of class 42 is
// placed at buffer[42]). Here we convert to a wrapping InferenceResult and sort
// the results. It is unclear why the MobileNet output indices are "off by one"
// but the `.skip(1)` below seems necessary to get results that make sense (e.g.
// 763 = "revolver" vs 762 = "restaurant").
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .skip(1)
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}

// A wrapper for class ID and match probabilities.
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
53 changes: 53 additions & 0 deletions crates/test-programs/src/bin/nn_image_classification_named.rs
@@ -0,0 +1,53 @@
use anyhow::Result;
use std::fs;
use wasi_nn::*;

pub fn main() -> Result<()> {
    let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU)
        .build_from_cache("mobilenet")?;
    println!("Loaded a graph: {:?}", graph);

    let mut context = graph.init_execution_context()?;
    println!("Created an execution context: {:?}", context);

    // Load a tensor that precisely matches the graph input tensor (see
    // `fixture/frozen_inference_graph.xml`).
    let tensor_data = fs::read("fixture/tensor.bgr")?;
    println!("Read input tensor, size in bytes: {}", tensor_data.len());
    context.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor_data)?;

    // Execute the inference.
    context.compute()?;
    println!("Executed graph inference");

    // Retrieve the output.
    let mut output_buffer = vec![0f32; 1001];
    context.get_output(0, &mut output_buffer[..])?;

    println!(
        "Found results, sorted top 5: {:?}",
        &sort_results(&output_buffer)[..5]
    );
    Ok(())
}

// Sort the buffer of probabilities. The graph places the match probability for
// each class at the index for that class (e.g. the probability of class 42 is
// placed at buffer[42]). Here we convert to a wrapping InferenceResult and sort
// the results. It is unclear why the MobileNet output indices are "off by one"
// but the `.skip(1)` below seems necessary to get results that make sense (e.g.
// 763 = "revolver" vs 762 = "restaurant").
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .skip(1)
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}

// A wrapper for class ID and match probabilities.
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
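
For build_from_cache("mobilenet") to succeed, the host must have registered a graph under that name before the test runs. A minimal host-side sketch using the preload helper re-exported from this crate's lib.rs (the signature, and the convention that the directory's base name becomes the graph name, are assumptions based on how the wasmtime CLI wires up wasi-nn graphs; the crate's real test harness may differ):

// Host-side sketch (assumed API shape; not part of this diff).
use wasmtime_wasi_nn::{preload, WasiNnCtx};

fn host_ctx() -> anyhow::Result<WasiNnCtx> {
    // (encoding, directory) pairs; "./mobilenet" is assumed to contain
    // model.xml and model.bin, and its base name to become the graph name.
    let graphs = vec![("openvino".to_string(), "./mobilenet".to_string())];
    let (backends, registry) = preload(&graphs)?;
    Ok(WasiNnCtx::new(backends, registry))
}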
6 changes: 6 additions & 0 deletions crates/wasi-nn/Cargo.toml
@@ -29,3 +29,9 @@ thiserror = { workspace = true }
 
 [build-dependencies]
 walkdir = { workspace = true }
+
+[dev-dependencies]
+cap-std = { workspace = true }
+test-programs-artifacts = { workspace = true }
+wasmtime-wasi = { workspace = true, features = ["sync"] }
+wasmtime = { workspace = true, features = ["cranelift"] }
1 change: 1 addition & 0 deletions crates/wasi-nn/src/lib.rs
@@ -4,6 +4,7 @@ mod registry;
 pub mod backend;
 pub use ctx::{preload, WasiNnCtx};
 pub use registry::{GraphRegistry, InMemoryRegistry};
+pub mod testing;
 pub mod wit;
 pub mod witx;
 
97 changes: 97 additions & 0 deletions crates/wasi-nn/src/testing.rs
@@ -0,0 +1,97 @@
//! This is testing-specific code--it is public only so that it can be
//! accessible both in unit and integration tests.
//!
//! This module checks:
//! - that OpenVINO can be found in the environment
//! - that some ML model artifacts can be downloaded and cached.

use anyhow::{anyhow, Context, Result};
use std::{env, fs, path::Path, path::PathBuf, process::Command, sync::Mutex};

/// Return the directory in which the test artifacts are stored.
pub fn artifacts_dir() -> PathBuf {
    PathBuf::from(env!("OUT_DIR")).join("mobilenet")
}

/// Early-return from a test if the test environment is not met. If the `CI`
/// or `FORCE_WASINN_TEST_CHECK` environment variables are set, though, this
/// will return an error instead.
#[macro_export]
macro_rules! check_test {
    () => {
        if let Err(e) = $crate::testing::check() {
            if std::env::var_os("CI").is_some()
                || std::env::var_os("FORCE_WASINN_TEST_CHECK").is_some()
            {
                return Err(e);
            } else {
                println!("> ignoring test: {}", e);
                return Ok(());
            }
        }
    };
}

/// Return `Ok` if all checks pass.
pub fn check() -> Result<()> {
    check_openvino_is_installed()?;
    check_openvino_artifacts_are_available()?;
    Ok(())
}

/// Return `Ok` if we find a working OpenVINO installation.
fn check_openvino_is_installed() -> Result<()> {
    match std::panic::catch_unwind(|| println!("> found openvino version: {}", openvino::version()))
    {
        Ok(_) => Ok(()),
        Err(e) => Err(anyhow!("unable to find an OpenVINO installation: {:?}", e)),
    }
}

/// Protect `check_openvino_artifacts_are_available` from concurrent access;
/// when running tests in parallel, we want to avoid two threads attempting to
/// create the same directory or download the same file.
static ARTIFACTS: Mutex<()> = Mutex::new(());

/// Return `Ok` if we find the cached MobileNet test artifacts; this will
/// download the artifacts if necessary.
fn check_openvino_artifacts_are_available() -> Result<()> {
    let _exclusively_retrieve_artifacts = ARTIFACTS.lock().unwrap();
    const BASE_URL: &str =
        "https://github.com/intel/openvino-rs/raw/main/crates/openvino/tests/fixtures/mobilenet";
    let artifacts_dir = artifacts_dir();
    if !artifacts_dir.is_dir() {
        fs::create_dir(&artifacts_dir)?;
    }
    for (from, to) in [
        ("mobilenet.bin", "model.bin"),
        ("mobilenet.xml", "model.xml"),
        ("tensor-1x224x224x3-f32.bgr", "tensor.bgr"),
    ] {
        let remote_url = [BASE_URL, from].join("/");
        let local_path = artifacts_dir.join(to);
        if !local_path.is_file() {
            download(&remote_url, &local_path)
                .with_context(|| "unable to retrieve test artifact")?;
        } else {
            println!("> using cached artifact: {}", local_path.display())
        }
    }
    Ok(())
}

/// Retrieve the bytes at the `from` URL and place them in the `to` file.
fn download(from: &str, to: &Path) -> anyhow::Result<()> {
    let mut curl = Command::new("curl");
    curl.arg("--location").arg(from).arg("--output").arg(to);
    println!("> downloading: {:?}", &curl);
    let result = curl.output().unwrap();
    if !result.status.success() {
        panic!(
            "curl failed: {}\n{}",
            result.status,
            String::from_utf8_lossy(&result.stderr)
        );
    }
    Ok(())
}
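
As a usage sketch, an integration test would invoke check_test! before touching OpenVINO. The test below is hypothetical (not part of this diff), but the macro's contract is as defined above: early Ok(()) when running locally without the environment, a hard Err on CI.

// Hypothetical integration test; assumes the crate name wasmtime_wasi_nn.
#[test]
fn nn_smoke() -> anyhow::Result<()> {
    // Skips locally when OpenVINO or the artifacts are missing; fails on CI.
    wasmtime_wasi_nn::check_test!();
    // After a successful check, the MobileNet artifacts are cached here.
    let dir = wasmtime_wasi_nn::testing::artifacts_dir();
    assert!(dir.join("model.xml").is_file());
    Ok(())
}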