diff --git a/src/module/functions/mod.rs b/src/module/functions/mod.rs index 78dd0703..2de3ab1c 100644 --- a/src/module/functions/mod.rs +++ b/src/module/functions/mod.rs @@ -1,5 +1,15 @@ //! Functions within a wasm module. +use std::cmp; +use std::collections::BTreeMap; + +use anyhow::{bail, Context}; +use wasm_encoder::Encode; +use wasmparser::{FuncValidator, FunctionBody, Range, ValidatorResources}; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + mod local_function; use crate::emit::{Emit, EmitContext}; @@ -11,13 +21,7 @@ use crate::parse::IndicesToIds; use crate::tombstone_arena::{Id, Tombstone, TombstoneArena}; use crate::ty::TypeId; use crate::ty::ValType; -use std::cmp; -use std::collections::BTreeMap; -use wasm_encoder::Encode; -use wasmparser::{FuncValidator, FunctionBody, Range, ValidatorResources}; - -#[cfg(feature = "parallel")] -use rayon::prelude::*; +use crate::{ExportItem, ImportKind, Memory, MemoryId}; pub use self::local_function::LocalFunction; @@ -206,6 +210,17 @@ impl ModuleFunctions { self.arena.delete(id); } + /// Remove a function from this module by name that corresponds to it's import or export. + /// + /// Unlike [`by_name()`](ModuleFunctions::by_name), "name" corresponds to either a functions import name or it's export name. + pub fn delete_by_name(&mut self, name: impl AsRef) { + // TODO: How do we differentiate between imports/exports here? + // add some sort of `FunctionLocation` enum (ex. FunctionLocation::ModuleImports/FunctionLocation::ModuleExports?) + // + // FunctionKind is almost there as a marker but there's no "Export" + todo!() + } + /// Get a shared reference to this module's functions. pub fn iter(&self) -> impl Iterator { self.arena.iter().map(|(_, f)| f) @@ -418,6 +433,66 @@ impl Module { Ok(()) } + + /// Retrieve an exported function by name + pub fn get_exported_func_by_name(&self, name: impl AsRef) -> Result { + self.exports + .iter() + // Find the export with the correct name and internal type + .filter_map(|expt| match expt.item { + ExportItem::Function(fid) if expt.name == name.as_ref() => Some(fid), + _ => None, + }) + .nth(0) + .with_context(|| format!("unable to find function export '{}'", name.as_ref())) + } + + /// Retrieve an imported function by name + pub fn get_imported_func_by_name(&self, name: impl AsRef) -> Result { + self.imports + .iter() + // Find the export with the correct name and internal type + .filter_map(|impt| match impt.kind { + ImportKind::Function(fid) if impt.name == name.as_ref() => Some(fid), + _ => None, + }) + .nth(0) + .with_context(|| format!("unable to find function export '{}'", name.as_ref())) + } + + /// Retrieve the ID for the first exported memory. + /// + /// This method does not work in contexts with [multi-memory enabled](https://github.com/WebAssembly/multi-memory), + /// and will error if more than one memory is present. + pub fn get_primary_memory_id(&self) -> Result { + if self.memories.len() > 1 { + bail!("multiple memories unsupported") + } + + self.memories + .iter() + .next() + .map(Memory::id) + .context("module does not export a memory") + } + + /// Replace a single function with the result + /// + /// When called, if `builder` produces a None value, the function in question will be + /// replaced with a stub that does nothing (more precisely, a function with an unreachable body). + pub fn replace_fn_by_name(&mut self, name: impl AsRef, builder: F) -> Result<()> + where + F: FnOnce() -> Result>, + { + let built_fn = builder().context("fn builder failed")?; + + // TODO: Hmnn.... how should we distinguish between which one to go for? + // this is a similar problem to the more general delete_by_name() above. + let imported_fn_id = self.get_imported_func_by_name(name.as_ref()); + let exported_fn_id = self.get_exported_func_by_name(name.as_ref()); + + Ok(()) + } } fn used_local_functions<'a>(cx: &mut EmitContext<'a>) -> Vec<(FunctionId, &'a LocalFunction, u64)> { diff --git a/src/module/memories.rs b/src/module/memories.rs index 55342528..2063d369 100644 --- a/src/module/memories.rs +++ b/src/module/memories.rs @@ -115,6 +115,11 @@ impl ModuleMemories { pub fn iter_mut(&mut self) -> impl Iterator { self.arena.iter_mut().map(|(_, f)| f) } + + /// Get the number of memories in this module + pub fn len(&self) -> usize { + self.arena.len() + } } impl Module {