diff --git a/src/commands/run.rs b/src/commands/run.rs index a75dfdb6c2a1..199cee273481 100644 --- a/src/commands/run.rs +++ b/src/commands/run.rs @@ -5,6 +5,8 @@ allow(irrefutable_let_patterns, unreachable_patterns) )] +use crate::common::{Profile, RunCommon}; + use anyhow::{anyhow, bail, Context as _, Error, Result}; use clap::Parser; use std::fs::File; @@ -12,13 +14,10 @@ use std::io::Write; use std::path::{Path, PathBuf}; use std::sync::Arc; use std::thread; -use std::time::Duration; use wasmtime::{ AsContextMut, Engine, Func, GuestProfiler, Module, Precompiled, Store, StoreLimits, StoreLimitsBuilder, UpdateDeadline, Val, ValType, }; -use wasmtime_cli_flags::opt::WasmtimeOptionValue; -use wasmtime_cli_flags::CommonOptions; use wasmtime_wasi::maybe_exit_on_error; use wasmtime_wasi::preview2; use wasmtime_wasi::sync::{ambient_authority, Dir, TcpListener, WasiCtxBuilder}; @@ -61,43 +60,12 @@ fn parse_preloads(s: &str) -> Result<(String, PathBuf)> { Ok((parts[0].into(), parts[1].into())) } -fn parse_profile(s: &str) -> Result { - let parts = s.split(',').collect::>(); - match &parts[..] { - ["perfmap"] => Ok(Profile::Native(wasmtime::ProfilingStrategy::PerfMap)), - ["jitdump"] => Ok(Profile::Native(wasmtime::ProfilingStrategy::JitDump)), - ["vtune"] => Ok(Profile::Native(wasmtime::ProfilingStrategy::VTune)), - ["guest"] => Ok(Profile::Guest { - path: "wasmtime-guest-profile.json".to_string(), - interval: Duration::from_millis(10), - }), - ["guest", path] => Ok(Profile::Guest { - path: path.to_string(), - interval: Duration::from_millis(10), - }), - ["guest", path, dur] => Ok(Profile::Guest { - path: path.to_string(), - interval: WasmtimeOptionValue::parse(Some(dur))?, - }), - _ => bail!("unknown profiling strategy: {s}"), - } -} - /// Runs a WebAssembly module #[derive(Parser)] #[structopt(name = "run")] pub struct RunCommand { #[clap(flatten)] - common: CommonOptions, - - /// Allow executing precompiled WebAssembly modules as `*.cwasm` files. - /// - /// Note that this option is not safe to pass if the module being passed in - /// is arbitrary user input. Only `wasmtime`-precompiled modules generated - /// via the `wasmtime compile` command or equivalent should be passed as an - /// argument with this option specified. - #[clap(long = "allow-precompiled")] - allow_precompiled: bool, + run: RunCommon, /// Grant access of a host directory to a guest. /// @@ -131,28 +99,6 @@ pub struct RunCommand { )] preloads: Vec<(String, PathBuf)>, - /// Profiling strategy (valid options are: perfmap, jitdump, vtune, guest) - /// - /// The perfmap, jitdump, and vtune profiling strategies integrate Wasmtime - /// with external profilers such as `perf`. The guest profiling strategy - /// enables in-process sampling and will write the captured profile to - /// `wasmtime-guest-profile.json` by default which can be viewed at - /// https://profiler.firefox.com/. - /// - /// The `guest` option can be additionally configured as: - /// - /// --profile=guest[,path[,interval]] - /// - /// where `path` is where to write the profile and `interval` is the - /// duration between samples. When used with `--wasm-timeout` the timeout - /// will be rounded up to the nearest multiple of this interval. - #[clap( - long, - value_name = "STRATEGY", - value_parser = parse_profile, - )] - profile: Option, - /// The WebAssembly module to run and arguments to pass to it. /// /// Arguments passed to the wasm module will be configured as WASI CLI @@ -162,12 +108,6 @@ pub struct RunCommand { module_and_args: Vec, } -#[derive(Clone)] -enum Profile { - Native(wasmtime::ProfilingStrategy), - Guest { path: String, interval: Duration }, -} - enum CliLinker { Core(wasmtime::Linker), #[cfg(feature = "component-model")] @@ -201,14 +141,14 @@ impl CliModule { impl RunCommand { /// Executes the command. pub fn execute(mut self) -> Result<()> { - self.common.init_logging(); + self.run.common.init_logging(); - let mut config = self.common.config(None)?; + let mut config = self.run.common.config(None)?; - if self.common.wasm.timeout.is_some() { + if self.run.common.wasm.timeout.is_some() { config.epoch_interruption(true); } - match self.profile { + match self.run.profile { Some(Profile::Native(s)) => { config.profiler(s); } @@ -225,7 +165,7 @@ impl RunCommand { let main = self.load_module(&engine, &self.module_and_args[0])?; // Validate coredump-on-trap argument - if let Some(path) = &self.common.debug.coredump { + if let Some(path) = &self.run.common.debug.coredump { if path.contains("%") { bail!("the coredump-on-trap path does not support patterns yet.") } @@ -238,7 +178,7 @@ impl RunCommand { CliLinker::Component(wasmtime::component::Linker::new(&engine)) } }; - if let Some(enable) = self.common.wasm.unknown_exports_allow { + if let Some(enable) = self.run.common.wasm.unknown_exports_allow { match &mut linker { CliLinker::Core(l) => { l.allow_unknown_exports(enable); @@ -255,22 +195,22 @@ impl RunCommand { self.populate_with_wasi(&mut linker, &mut store, &main)?; let mut limits = StoreLimitsBuilder::new(); - if let Some(max) = self.common.wasm.max_memory_size { + if let Some(max) = self.run.common.wasm.max_memory_size { limits = limits.memory_size(max); } - if let Some(max) = self.common.wasm.max_table_elements { + if let Some(max) = self.run.common.wasm.max_table_elements { limits = limits.table_elements(max); } - if let Some(max) = self.common.wasm.max_instances { + if let Some(max) = self.run.common.wasm.max_instances { limits = limits.instances(max); } - if let Some(max) = self.common.wasm.max_tables { + if let Some(max) = self.run.common.wasm.max_tables { limits = limits.tables(max); } - if let Some(max) = self.common.wasm.max_memories { + if let Some(max) = self.run.common.wasm.max_memories { limits = limits.memories(max); } - if let Some(enable) = self.common.wasm.trap_on_grow_failure { + if let Some(enable) = self.run.common.wasm.trap_on_grow_failure { limits = limits.trap_on_grow_failure(enable); } store.data_mut().limits = limits.build(); @@ -278,7 +218,7 @@ impl RunCommand { // If fuel has been configured, we want to add the configured // fuel amount to this store. - if let Some(fuel) = self.common.wasm.fuel { + if let Some(fuel) = self.run.common.wasm.fuel { store.add_fuel(fuel)?; } @@ -350,7 +290,7 @@ impl RunCommand { fn compute_preopen_sockets(&self) -> Result> { let mut listeners = vec![]; - for address in &self.common.wasi.tcplisten { + for address in &self.run.common.wasi.tcplisten { let stdlistener = std::net::TcpListener::bind(address) .with_context(|| format!("failed to bind to address '{}'", address))?; @@ -387,7 +327,7 @@ impl RunCommand { store: &mut Store, modules: Vec<(String, Module)>, ) -> Box)> { - if let Some(Profile::Guest { path, interval }) = &self.profile { + if let Some(Profile::Guest { path, interval }) = &self.run.profile { let module_name = self.module_and_args[0].to_str().unwrap_or("
"); let interval = *interval; store.data_mut().guest_profiler = @@ -406,7 +346,7 @@ impl RunCommand { store.as_context_mut().data_mut().guest_profiler = Some(profiler); } - if let Some(timeout) = self.common.wasm.timeout { + if let Some(timeout) = self.run.common.wasm.timeout { let mut timeout = (timeout.as_secs_f64() / interval.as_secs_f64()).ceil() as u64; assert!(timeout > 0); store.epoch_deadline_callback(move |mut store| { @@ -448,7 +388,7 @@ impl RunCommand { }); } - if let Some(timeout) = self.common.wasm.timeout { + if let Some(timeout) = self.run.common.wasm.timeout { store.set_epoch_deadline(1); let engine = store.engine().clone(); thread::spawn(move || { @@ -469,7 +409,7 @@ impl RunCommand { ) -> Result<()> { // The main module might be allowed to have unknown imports, which // should be defined as traps: - if self.common.wasm.unknown_imports_trap == Some(true) { + if self.run.common.wasm.unknown_imports_trap == Some(true) { match linker { CliLinker::Core(linker) => { linker.define_unknown_imports_as_traps(module.unwrap_core())?; @@ -479,7 +419,7 @@ impl RunCommand { } // ...or as default values. - if self.common.wasm.unknown_imports_default == Some(true) { + if self.run.common.wasm.unknown_imports_default == Some(true) { match linker { CliLinker::Core(linker) => { linker.define_unknown_imports_as_default_values(module.unwrap_core())?; @@ -620,7 +560,7 @@ impl RunCommand { } fn handle_core_dump(&self, store: &mut Store, err: Error) -> Error { - let coredump_path = match &self.common.debug.coredump { + let coredump_path = match &self.run.common.debug.coredump { Some(path) => path, None => return err, }; @@ -736,7 +676,7 @@ impl RunCommand { } fn ensure_allow_precompiled(&self) -> Result<()> { - if self.allow_precompiled { + if self.run.allow_precompiled { Ok(()) } else { bail!("running a precompiled module requires the `--allow-precompiled` flag") @@ -745,8 +685,8 @@ impl RunCommand { #[cfg(feature = "component-model")] fn ensure_allow_components(&self) -> Result<()> { - if self.common.wasm.component_model != Some(true) { - bail!("cannot execute a component without `--wasm component-model`"); + if self.run.common.wasm.component_model != Some(true) { + bail!("cannot execute a component without `--wasm-features component-model`"); } Ok(()) @@ -759,10 +699,10 @@ impl RunCommand { store: &mut Store, module: &CliModule, ) -> Result<()> { - if self.common.wasi.common != Some(false) { + if self.run.common.wasi.common != Some(false) { match linker { CliLinker::Core(linker) => { - if self.common.wasi.preview2 == Some(true) { + if self.run.common.wasi.preview2 == Some(true) { preview2::preview1::add_to_linker_sync(linker)?; self.set_preview2_ctx(store)?; } else { @@ -780,7 +720,7 @@ impl RunCommand { } } - if self.common.wasi.nn == Some(true) { + if self.run.common.wasi.nn == Some(true) { #[cfg(not(feature = "wasi-nn"))] { bail!("Cannot enable wasi-nn when the binary is not compiled with this feature."); @@ -810,6 +750,7 @@ impl RunCommand { } } let graphs = self + .run .common .wasi .nn_graph @@ -821,7 +762,7 @@ impl RunCommand { } } - if self.common.wasi.threads == Some(true) { + if self.run.common.wasi.threads == Some(true) { #[cfg(not(feature = "wasi-threads"))] { // Silence the unused warning for `module` as it is only used in the @@ -849,7 +790,7 @@ impl RunCommand { } } - if self.common.wasi.http == Some(true) { + if self.run.common.wasi.http == Some(true) { #[cfg(not(all(feature = "wasi-http", feature = "component-model")))] { bail!("Cannot enable wasi-http when the binary is not compiled with this feature."); @@ -887,7 +828,7 @@ impl RunCommand { let mut num_fd: usize = 3; - if self.common.wasi.listenfd == Some(true) { + if self.run.common.wasi.listenfd == Some(true) { num_fd = ctx_set_listenfd(num_fd, &mut builder)?; } @@ -917,7 +858,7 @@ impl RunCommand { builder.env(key, &value); } - if self.common.wasi.listenfd == Some(true) { + if self.run.common.wasi.listenfd == Some(true) { bail!("components do not support --listenfd"); } for _ in self.compute_preopen_sockets()? { @@ -933,7 +874,7 @@ impl RunCommand { ); } - if self.common.wasi.inherit_network == Some(true) { + if self.run.common.wasi.inherit_network == Some(true) { builder.inherit_network(ambient_authority()); } diff --git a/src/commands/serve.rs b/src/commands/serve.rs index bfca73c92cd2..5525a5dc94ce 100644 --- a/src/commands/serve.rs +++ b/src/commands/serve.rs @@ -1,26 +1,22 @@ -use anyhow::Result; +use crate::common::{Profile, RunCommon}; +use anyhow::{bail, Result}; use clap::Parser; use std::{path::PathBuf, pin::Pin, sync::Arc}; use wasmtime::component::{Component, InstancePre, Linker}; use wasmtime::{Engine, Store}; -use wasmtime_cli_flags::CommonOptions; -use wasmtime_wasi::preview2::{Table, WasiCtx, WasiCtxBuilder, WasiView}; +use wasmtime_wasi::preview2::{self, Table, WasiCtx, WasiCtxBuilder, WasiView}; use wasmtime_wasi_http::{body::HyperOutgoingBody, WasiHttpCtx, WasiHttpView}; +#[cfg(feature = "wasi-nn")] +use wasmtime_wasi_nn::WasiNnCtx; + struct Host { table: Table, ctx: WasiCtx, http: WasiHttpCtx, -} -impl Host { - fn new() -> Self { - Host { - table: Table::new(), - ctx: WasiCtxBuilder::new().build(), - http: WasiHttpCtx, - } - } + #[cfg(feature = "wasi-nn")] + nn: Option, } impl WasiView for Host { @@ -51,16 +47,21 @@ impl WasiHttpView for Host { } } +const DEFAULT_ADDR: std::net::SocketAddr = std::net::SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), + 8080, +); + /// Runs a WebAssembly module #[derive(Parser)] #[structopt(name = "run")] pub struct ServeCommand { #[clap(flatten)] - common: CommonOptions, + run: RunCommon, - /// Socket address for the web server to bind to. Defaults to 0.0.0.0:8080. - #[clap(long = "addr", value_name = "SOCKADDR")] - addr: Option, + /// Socket address for the web server to bind to. + #[clap(long = "addr", value_name = "SOCKADDR", default_value_t = DEFAULT_ADDR )] + addr: std::net::SocketAddr, /// The WebAssembly component to run. #[clap(value_name = "WASM", required = true)] @@ -68,12 +69,8 @@ pub struct ServeCommand { } impl ServeCommand { - fn addr(&self) -> std::net::SocketAddr { - self.addr.unwrap_or("0.0.0.0:8080".parse().unwrap()) - } - /// Start a server to run the given wasi-http proxy component - pub fn execute(mut self) -> Result<()> { + pub fn execute(self) -> Result<()> { let runtime = tokio::runtime::Builder::new_multi_thread() .enable_time() .enable_io() @@ -94,38 +91,131 @@ impl ServeCommand { Ok(()) } + fn new_store(&self, engine: &Engine) -> Result> { + let mut builder = WasiCtxBuilder::new(); + + // TODO: connect stdio to logging infrastructure + + let mut host = Host { + table: Table::new(), + ctx: builder.build(), + http: WasiHttpCtx, + + #[cfg(feature = "wasi-nn")] + nn: None, + }; + + if self.run.common.wasi.nn == Some(true) { + #[cfg(not(feature = "wasi-nn"))] + { + bail!("Cannot enable wasi-nn when the binary is not compiled with this feature."); + } + #[cfg(feature = "wasi-nn")] + { + let graphs = self + .run + .common + .wasi + .nn_graph + .iter() + .map(|g| (g.format.clone(), g.dir.clone())) + .collect::>(); + let (backends, registry) = wasmtime_wasi_nn::preload(&graphs)?; + host.nn.replace(WasiNnCtx::new(backends, registry)); + } + } + + let mut store = Store::new(engine, host); + + if let Some(Profile::Guest { .. }) = &self.run.profile { + bail!("Cannot use the guest profiler with components"); + } + + if self.run.common.wasm.timeout.is_some() { + store.set_epoch_deadline(1); + } + + Ok(store) + } + fn add_to_linker(&self, linker: &mut Linker) -> Result<()> { + // wasi-http and the component model are implicitly enabled for `wasmtime serve`, so we + // don't test for `self.run.common.wasi.common` or `self.run.common.wasi.http` in this + // function. + wasmtime_wasi_http::proxy::add_to_linker(linker)?; + + if self.run.common.wasi.nn == Some(true) { + #[cfg(not(feature = "wasi-nn"))] + { + bail!("Cannot enable wasi-nn when the binary is not compiled with this feature."); + } + #[cfg(feature = "wasi-nn")] + { + wasmtime_wasi_nn::wit::ML::add_to_linker(linker, |host| host.nn.as_mut().unwrap())?; + } + } + + if self.run.common.wasi.threads == Some(true) { + bail!("wasi-threads does not support components yet") + } + Ok(()) } - async fn serve(&mut self) -> Result<()> { + async fn serve(mut self) -> Result<()> { use hyper::server::conn::http1; - let mut config = self.common.config(None)?; + self.run.common.init_logging(); + + let mut config = self.run.common.config(None)?; config.wasm_component_model(true); config.async_support(true); - let engine = Arc::new(Engine::new(&config)?); + if self.run.common.wasm.timeout.is_some() { + config.epoch_interruption(true); + } + + match self.run.profile { + Some(Profile::Native(s)) => { + config.profiler(s); + } + Some(Profile::Guest { .. }) => { + bail!("guest profiling not yet available with components"); + } + None => {} + } + + let engine = Engine::new(&config)?; let mut linker = Linker::new(&engine); self.add_to_linker(&mut linker)?; let component = Component::from_file(&engine, &self.component)?; - let instance = Arc::new(linker.instantiate_pre(&component)?); + let instance = linker.instantiate_pre(&component)?; + + let listener = tokio::net::TcpListener::bind(self.addr).await?; - let listener = tokio::net::TcpListener::bind(self.addr()).await?; + let _epoch_thread = if let Some(timeout) = self.run.common.wasm.timeout { + let engine = engine.clone(); + Some(preview2::spawn(async move { + tokio::time::sleep(timeout).await; + engine.increment_epoch(); + })) + } else { + None + }; + + let handler = ProxyHandler::new(self, engine, instance); loop { let (stream, _) = listener.accept().await?; - let engine = Arc::clone(&engine); - let instance = Arc::clone(&instance); + let h = handler.clone(); tokio::task::spawn(async move { - let handler = ProxyHandler::new(engine, instance); if let Err(e) = http1::Builder::new() .keep_alive(true) - .serve_connection(stream, handler) + .serve_connection(stream, h) .await { eprintln!("error: {e:?}"); @@ -135,18 +225,22 @@ impl ServeCommand { } } -#[derive(Clone)] -struct ProxyHandler { - engine: Arc, - instance_pre: Arc>, +struct ProxyHandlerInner { + cmd: ServeCommand, + engine: Engine, + instance_pre: InstancePre, } +#[derive(Clone)] +struct ProxyHandler(Arc); + impl ProxyHandler { - fn new(engine: Arc, instance_pre: Arc>) -> Self { - Self { + fn new(cmd: ServeCommand, engine: Engine, instance_pre: InstancePre) -> Self { + Self(Arc::new(ProxyHandlerInner { + cmd, engine, instance_pre, - } + })) } } @@ -166,7 +260,7 @@ impl hyper::service::Service for ProxyHandler { // TODO: need to track the join handle, but don't want to block the response on it tokio::task::spawn(async move { - let mut store = Store::new(&handler.engine, Host::new()); + let mut store = handler.0.cmd.new_store(&handler.0.engine)?; let req = store.data_mut().new_incoming_request( req.map(|body| body.map_err(|e| anyhow::anyhow!(e)).boxed()), @@ -176,7 +270,7 @@ impl hyper::service::Service for ProxyHandler { let (proxy, _inst) = wasmtime_wasi_http::proxy::Proxy::instantiate_pre( &mut store, - &handler.instance_pre, + &handler.0.instance_pre, ) .await?; diff --git a/src/common.rs b/src/common.rs new file mode 100644 index 000000000000..0c95f69cf4a2 --- /dev/null +++ b/src/common.rs @@ -0,0 +1,76 @@ +//! Common functionality shared between command implementations. + +use anyhow::{bail, Result}; +use clap::Parser; +use std::time::Duration; +use wasmtime_cli_flags::{opt::WasmtimeOptionValue, CommonOptions}; + +/// Common command line arguments for run commands. +#[derive(Parser)] +#[structopt(name = "run")] +pub struct RunCommon { + #[clap(flatten)] + pub common: CommonOptions, + + /// Allow executing precompiled WebAssembly modules as `*.cwasm` files. + /// + /// Note that this option is not safe to pass if the module being passed in + /// is arbitrary user input. Only `wasmtime`-precompiled modules generated + /// via the `wasmtime compile` command or equivalent should be passed as an + /// argument with this option specified. + #[clap(long = "allow-precompiled")] + pub allow_precompiled: bool, + + /// Profiling strategy (valid options are: perfmap, jitdump, vtune, guest) + /// + /// The perfmap, jitdump, and vtune profiling strategies integrate Wasmtime + /// with external profilers such as `perf`. The guest profiling strategy + /// enables in-process sampling and will write the captured profile to + /// `wasmtime-guest-profile.json` by default which can be viewed at + /// https://profiler.firefox.com/. + /// + /// The `guest` option can be additionally configured as: + /// + /// --profile=guest[,path[,interval]] + /// + /// where `path` is where to write the profile and `interval` is the + /// duration between samples. When used with `--wasm-timeout` the timeout + /// will be rounded up to the nearest multiple of this interval. + #[clap( + long, + value_name = "STRATEGY", + value_parser = Profile::parse, + )] + pub profile: Option, +} + +#[derive(Clone)] +pub enum Profile { + Native(wasmtime::ProfilingStrategy), + Guest { path: String, interval: Duration }, +} + +impl Profile { + /// Parse the `profile` argument to either the `run` or `serve` commands. + pub fn parse(s: &str) -> Result { + let parts = s.split(',').collect::>(); + match &parts[..] { + ["perfmap"] => Ok(Profile::Native(wasmtime::ProfilingStrategy::PerfMap)), + ["jitdump"] => Ok(Profile::Native(wasmtime::ProfilingStrategy::JitDump)), + ["vtune"] => Ok(Profile::Native(wasmtime::ProfilingStrategy::VTune)), + ["guest"] => Ok(Profile::Guest { + path: "wasmtime-guest-profile.json".to_string(), + interval: Duration::from_millis(10), + }), + ["guest", path] => Ok(Profile::Guest { + path: path.to_string(), + interval: Duration::from_millis(10), + }), + ["guest", path, dur] => Ok(Profile::Guest { + path: path.to_string(), + interval: WasmtimeOptionValue::parse(Some(dur))?, + }), + _ => bail!("unknown profiling strategy: {s}"), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 5cf21f6fe768..11473981a467 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,3 +11,5 @@ #![warn(unused_import_braces)] pub mod commands; + +pub(crate) mod common;