From 4cf8e1f978df6d1097be505b9f39860b98e79af0 Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Tue, 12 Dec 2023 17:18:57 -0800 Subject: [PATCH] wasi-nn: add `testing` module This testing-only module has code (i.e., `check_test!`) to check whether OpenVINO and some test artifacts are available. The test artifacts are downloaded and cached if not present, expecting `curl` to be present on the command line (as discussed in the previous version of this, #6895). --- crates/wasi-nn/src/lib.rs | 1 + crates/wasi-nn/src/testing.rs | 91 +++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 crates/wasi-nn/src/testing.rs diff --git a/crates/wasi-nn/src/lib.rs b/crates/wasi-nn/src/lib.rs index 71d089d07489..f367e0a1582d 100644 --- a/crates/wasi-nn/src/lib.rs +++ b/crates/wasi-nn/src/lib.rs @@ -6,6 +6,7 @@ pub use ctx::{preload, WasiNnCtx}; pub use registry::{GraphRegistry, InMemoryRegistry}; pub mod wit; pub mod witx; +pub mod testing; use std::sync::Arc; diff --git a/crates/wasi-nn/src/testing.rs b/crates/wasi-nn/src/testing.rs new file mode 100644 index 000000000000..2294fea9de4b --- /dev/null +++ b/crates/wasi-nn/src/testing.rs @@ -0,0 +1,91 @@ +//! This is testing-specific code--it is public only so that it can be +//! accessible both in unit and integration tests. +//! +//! This module checks: +//! - that OpenVINO can be found in the environment +//! - that some ML model artifacts can be downloaded and cached. + +use anyhow::{anyhow, Context, Result}; +use std::{env, fs, path::Path, path::PathBuf, process::Command}; + +/// Return the directory in which the test artifacts are stored. +pub fn artifacts_dir() -> PathBuf { + PathBuf::from(env!("OUT_DIR")).join("mobilenet") +} + +/// Early-return from a test if the test environment is not met. If the `CI` +/// or `FORCE_WASINN_TEST_CHECK` environment variables are set, though, this +/// will return an error instead. +#[macro_export] +macro_rules! check_test { + () => { + if let Err(e) = $crate::testing::check() { + if std::env::var_os("CI").is_some() + || std::env::var_os("FORCE_WASINN_TEST_CHECK").is_some() + { + return Err(e); + } else { + println!("> ignoring test: {}", e); + return Ok(()); + } + } + }; +} + +/// Return `Ok` if all checks pass. +pub fn check() -> Result<()> { + check_openvino_is_installed()?; + check_openvino_artifacts_are_available()?; + Ok(()) +} + +/// Return `Ok` if we find a working OpenVINO installation. +fn check_openvino_is_installed() -> Result<()> { + match std::panic::catch_unwind(|| println!("> found openvino version: {}", openvino::version())) + { + Ok(_) => Ok(()), + Err(e) => Err(anyhow!("unable to find an OpenVINO installation: {:?}", e)), + } +} + +/// Return `Ok` if we find the cached MobileNet test artifacts; this will +/// download the artifacts if necessary. +fn check_openvino_artifacts_are_available() -> Result<()> { + const BASE_URL: &str = + "https://github.com/intel/openvino-rs/raw/main/crates/openvino/tests/fixtures/mobilenet"; + let artifacts_dir = artifacts_dir(); + if !artifacts_dir.is_dir() { + fs::create_dir(&artifacts_dir)?; + } + for (from, to) in [ + ("mobilenet.bin", "model.bin"), + ("mobilenet.xml", "model.xml"), + ("tensor-1x224x224x3-f32.bgr", "tensor.bgr"), + ] { + let remote_url = [BASE_URL, from].join("/"); + let local_path = artifacts_dir.join(to); + if !local_path.is_file() { + download(&remote_url, &local_path) + .with_context(|| "unable to retrieve test artifact")?; + } else { + println!("> using cached artifact: {}", local_path.display()) + } + } + Ok(()) +} + +/// Retrieve the bytes at the `from` URL and place them in the `to` file. +fn download(from: &str, to: &Path) -> anyhow::Result<()> { + let mut curl = Command::new("curl"); + curl.arg("--location").arg(from).arg("--output").arg(to); + println!("> downloading: {:?}", &curl); + let result = curl.output().unwrap(); + if !result.status.success() { + panic!( + "curl failed: {}\n{}", + result.status, + String::from_utf8_lossy(&result.stderr) + ); + } + Ok(()) +}