Skip to content

Commit

Permalink
wasi-nn: add testing module
Browse files Browse the repository at this point in the history
This testing-only module has code (i.e., `check_test!`) to check whether
OpenVINO and some test artifacts are available. The test artifacts are
downloaded and cached if not present, expecting `curl` to be present on
the command line (as discussed in the previous version of this, bytecodealliance#6895).
  • Loading branch information
abrown committed Dec 13, 2023
1 parent 4cb031e commit 4cf8e1f
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 0 deletions.
1 change: 1 addition & 0 deletions crates/wasi-nn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub use ctx::{preload, WasiNnCtx};
pub use registry::{GraphRegistry, InMemoryRegistry};
pub mod wit;
pub mod witx;
pub mod testing;

use std::sync::Arc;

Expand Down
91 changes: 91 additions & 0 deletions crates/wasi-nn/src/testing.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
//! This is testing-specific code--it is public only so that it can be
//! accessible both in unit and integration tests.
//!
//! This module checks:
//! - that OpenVINO can be found in the environment
//! - that some ML model artifacts can be downloaded and cached.

use anyhow::{anyhow, Context, Result};
use std::{env, fs, path::Path, path::PathBuf, process::Command};

/// Return the directory in which the test artifacts are stored.
pub fn artifacts_dir() -> PathBuf {
PathBuf::from(env!("OUT_DIR")).join("mobilenet")
}

/// Early-return from a test if the test environment is not met. If the `CI`
/// or `FORCE_WASINN_TEST_CHECK` environment variables are set, though, this
/// will return an error instead.
#[macro_export]
macro_rules! check_test {
() => {
if let Err(e) = $crate::testing::check() {
if std::env::var_os("CI").is_some()
|| std::env::var_os("FORCE_WASINN_TEST_CHECK").is_some()
{
return Err(e);
} else {
println!("> ignoring test: {}", e);
return Ok(());
}
}
};
}

/// Return `Ok` if all checks pass.
pub fn check() -> Result<()> {
check_openvino_is_installed()?;
check_openvino_artifacts_are_available()?;
Ok(())
}

/// Return `Ok` if we find a working OpenVINO installation.
fn check_openvino_is_installed() -> Result<()> {
match std::panic::catch_unwind(|| println!("> found openvino version: {}", openvino::version()))
{
Ok(_) => Ok(()),
Err(e) => Err(anyhow!("unable to find an OpenVINO installation: {:?}", e)),
}
}

/// Return `Ok` if we find the cached MobileNet test artifacts; this will
/// download the artifacts if necessary.
fn check_openvino_artifacts_are_available() -> Result<()> {
const BASE_URL: &str =
"https://github.com/intel/openvino-rs/raw/main/crates/openvino/tests/fixtures/mobilenet";
let artifacts_dir = artifacts_dir();
if !artifacts_dir.is_dir() {
fs::create_dir(&artifacts_dir)?;
}
for (from, to) in [
("mobilenet.bin", "model.bin"),
("mobilenet.xml", "model.xml"),
("tensor-1x224x224x3-f32.bgr", "tensor.bgr"),
] {
let remote_url = [BASE_URL, from].join("/");
let local_path = artifacts_dir.join(to);
if !local_path.is_file() {
download(&remote_url, &local_path)
.with_context(|| "unable to retrieve test artifact")?;
} else {
println!("> using cached artifact: {}", local_path.display())
}
}
Ok(())
}

/// Retrieve the bytes at the `from` URL and place them in the `to` file.
fn download(from: &str, to: &Path) -> anyhow::Result<()> {
let mut curl = Command::new("curl");
curl.arg("--location").arg(from).arg("--output").arg(to);
println!("> downloading: {:?}", &curl);
let result = curl.output().unwrap();
if !result.status.success() {
panic!(
"curl failed: {}\n{}",
result.status,
String::from_utf8_lossy(&result.stderr)
);
}
Ok(())
}

0 comments on commit 4cf8e1f

Please sign in to comment.