Skip to content

Commit

Permalink
Optimize calling a WebAssembly function (#2757)
Browse files Browse the repository at this point in the history
This commit implements a few optimizations, mainly inlining, that should
improve the performance of calling a WebAssembly function. This code
path can be quite hot depending on the embedding case and we hadn't
really put much effort into optimizing the nitty gritty.

The predominant optimization here is adding `#[inline]` to trivial
functions so performance is improved without having to compile with LTO.
Another optimization is to call `lazy_per_thread_init` when traps are
initialized per-thread (when a `Store` is created) rather than each time
a function is called. The next optimization is to change the unwind
reason in the `CallThreadState` to `MaybeUninit` to avoid extra checks
in the default case about whether we need to drop its variants (since in
the happy path we never need to drop it). The final optimization is to
skip a few checks when the `async` feature is compiled out, for a small
speed boost.

In a small benchmark where wasmtime calls a simple wasm function, the
overhead on my macOS computer dropped from 110ns to 86ns, a ~20% decrease. The
macOS overhead is still largely dominated by the global lock acquisition
and hash table management for traps right now, but I suspect the Linux
overhead is much better (should be on the order of ~30 or so ns).

We still have a long way to go to compete with SpiderMonkey which, in
testing, seems to have ~6ns overhead in calling the same wasm function on
my computer.
  • Loading branch information
alexcrichton authored Mar 23, 2021
1 parent 49ef2c6 commit c95971a
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 29 deletions.
2 changes: 2 additions & 0 deletions crates/runtime/src/externref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,7 @@ impl VMExternRefActivationsTable {
/// // call has returned.
/// drop(auto_reset_canary);
/// ```
#[inline]
pub fn set_stack_canary<'a>(&'a self, canary: &u8) -> impl Drop + 'a {
let should_reset = if self.stack_canary.get().is_none() {
let canary = canary as *const u8 as *mut u8;
Expand All @@ -775,6 +776,7 @@ impl VMExternRefActivationsTable {
}

impl Drop for AutoResetCanary<'_> {
#[inline]
fn drop(&mut self) {
if self.should_reset {
debug_assert!(self.table.stack_canary.get().is_some());
Expand Down
52 changes: 26 additions & 26 deletions crates/runtime/src/traphandlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
use crate::VMInterrupts;
use backtrace::Backtrace;
use std::any::Any;
use std::cell::Cell;
use std::cell::{Cell, UnsafeCell};
use std::error::Error;
use std::mem::MaybeUninit;
use std::ptr;
use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};
use std::sync::Once;
Expand Down Expand Up @@ -47,9 +48,10 @@ pub use sys::SignalHandler;
/// function needs to be called at the end of the startup process, after other
/// handlers have been installed. This function can thus be called multiple
/// times, having no effect after the first call.
pub fn init_traps() {
pub fn init_traps() -> Result<(), Trap> {
static INIT: Once = Once::new();
INIT.call_once(|| unsafe { sys::platform_init() });
sys::lazy_per_thread_init()
}

/// Raises a user-defined trap immediately.
Expand Down Expand Up @@ -155,8 +157,6 @@ pub unsafe fn catch_traps<F>(trap_info: &impl TrapInfo, mut closure: F) -> Resul
where
F: FnMut(),
{
sys::lazy_per_thread_init()?;

return CallThreadState::new(trap_info).with(|cx| {
RegisterSetjmp(
cx.jmp_buf.as_ptr(),
Expand Down Expand Up @@ -191,7 +191,7 @@ pub fn out_of_gas() {
/// Temporary state stored on the stack which is registered in the `tls` module
/// below for calls into wasm.
pub struct CallThreadState<'a> {
unwind: Cell<UnwindReason>,
unwind: UnsafeCell<MaybeUninit<UnwindReason>>,
jmp_buf: Cell<*const u8>,
handling_trap: Cell<bool>,
trap_info: &'a (dyn TrapInfo + 'a),
Expand Down Expand Up @@ -232,17 +232,17 @@ pub unsafe trait TrapInfo {
}

enum UnwindReason {
None,
Panic(Box<dyn Any + Send>),
UserTrap(Box<dyn Error + Send + Sync>),
LibTrap(Trap),
JitTrap { backtrace: Backtrace, pc: usize },
}

impl<'a> CallThreadState<'a> {
#[inline]
fn new(trap_info: &'a (dyn TrapInfo + 'a)) -> CallThreadState<'a> {
CallThreadState {
unwind: Cell::new(UnwindReason::None),
unwind: UnsafeCell::new(MaybeUninit::uninit()),
jmp_buf: Cell::new(ptr::null()),
handling_trap: Cell::new(false),
trap_info,
Expand All @@ -253,18 +253,13 @@ impl<'a> CallThreadState<'a> {
fn with(self, closure: impl FnOnce(&CallThreadState) -> i32) -> Result<(), Trap> {
let _reset = self.update_stack_limit()?;
let ret = tls::set(&self, || closure(&self));
match self.unwind.replace(UnwindReason::None) {
UnwindReason::None => {
debug_assert_eq!(ret, 1);
Ok(())
}
UnwindReason::UserTrap(data) => {
debug_assert_eq!(ret, 0);
Err(Trap::User(data))
}
if ret != 0 {
return Ok(());
}
match unsafe { (*self.unwind.get()).as_ptr().read() } {
UnwindReason::UserTrap(data) => Err(Trap::User(data)),
UnwindReason::LibTrap(trap) => Err(trap),
UnwindReason::JitTrap { backtrace, pc } => {
debug_assert_eq!(ret, 0);
let interrupts = self.trap_info.interrupts();
let maybe_interrupted =
interrupts.stack_limit.load(SeqCst) == wasmtime_environ::INTERRUPTED;
Expand All @@ -274,10 +269,7 @@ impl<'a> CallThreadState<'a> {
maybe_interrupted,
})
}
UnwindReason::Panic(panic) => {
debug_assert_eq!(ret, 0);
std::panic::resume_unwind(panic)
}
UnwindReason::Panic(panic) => std::panic::resume_unwind(panic),
}
}

Expand Down Expand Up @@ -310,6 +302,7 @@ impl<'a> CallThreadState<'a> {
///
/// Note that this function must be called with `self` on the stack, not the
/// heap/etc.
#[inline]
fn update_stack_limit(&self) -> Result<impl Drop + '_, Trap> {
// Determine the stack pointer where, after which, any wasm code will
// immediately trap. This is checked on the entry to all wasm functions.
Expand Down Expand Up @@ -361,6 +354,7 @@ impl<'a> CallThreadState<'a> {
struct Reset<'a>(bool, &'a AtomicUsize);

impl Drop for Reset<'_> {
#[inline]
fn drop(&mut self) {
if self.0 {
self.1.store(usize::max_value(), SeqCst);
Expand All @@ -372,8 +366,8 @@ impl<'a> CallThreadState<'a> {
}

fn unwind_with(&self, reason: UnwindReason) -> ! {
self.unwind.replace(reason);
unsafe {
(*self.unwind.get()).as_mut_ptr().write(reason);
Unwind(self.jmp_buf.get());
}
}
Expand Down Expand Up @@ -432,16 +426,21 @@ impl<'a> CallThreadState<'a> {

fn capture_backtrace(&self, pc: *const u8) {
let backtrace = Backtrace::new_unresolved();
self.unwind.replace(UnwindReason::JitTrap {
backtrace,
pc: pc as usize,
});
unsafe {
(*self.unwind.get())
.as_mut_ptr()
.write(UnwindReason::JitTrap {
backtrace,
pc: pc as usize,
});
}
}
}

struct ResetCell<'a, T: Copy>(&'a Cell<T>, T);

impl<T: Copy> Drop for ResetCell<'_, T> {
#[inline]
fn drop(&mut self) {
self.0.set(self.1);
}
Expand Down Expand Up @@ -544,6 +543,7 @@ mod tls {
struct Reset<'a, 'b>(&'a CallThreadState<'b>);

impl Drop for Reset<'_, '_> {
#[inline]
fn drop(&mut self) {
raw::replace(self.0.prev.replace(ptr::null()));
}
Expand Down
1 change: 1 addition & 0 deletions crates/wasmtime/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ impl Engine {
}

/// Returns the configuration settings that this engine is using.
#[inline]
pub fn config(&self) -> &Config {
&self.inner.config
}
Expand Down
4 changes: 3 additions & 1 deletion crates/wasmtime/src/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ impl Func {
/// initiates a panic.
pub fn call(&self, params: &[Val]) -> Result<Box<[Val]>> {
assert!(
!self.store().async_support(),
!cfg!(feature = "async") || !self.store().async_support(),
"must use `call_async` when async support is enabled on the config",
);
self._call(params)
Expand Down Expand Up @@ -926,6 +926,7 @@ impl Func {
}

/// Get a reference to this function's store.
#[inline]
pub fn store(&self) -> &Store {
&self.instance.store
}
Expand Down Expand Up @@ -1414,6 +1415,7 @@ impl Caller<'_> {
}

/// Get a reference to the caller's store.
#[inline]
pub fn store(&self) -> &Store {
self.store
}
Expand Down
2 changes: 1 addition & 1 deletion crates/wasmtime/src/func/typed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ where
/// connected to an asynchronous store.
pub fn call(&self, params: Params) -> Result<Results, Trap> {
assert!(
!self.func.store().async_support(),
!cfg!(feature = "async") || !self.func.store().async_support(),
"must use `call_async` with async stores"
);
unsafe { self._call(params) }
Expand Down
9 changes: 8 additions & 1 deletion crates/wasmtime/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ impl Store {
// once-per-thread. Platforms like Unix, however, only require this
// once-per-program. In any case this is safe to call many times and
// each one that's not relevant just won't do anything.
wasmtime_runtime::init_traps();
wasmtime_runtime::init_traps().expect("failed to initialize trap handling");

Store {
inner: Rc::new(StoreInner {
Expand Down Expand Up @@ -209,6 +209,7 @@ impl Store {
}

/// Returns the [`Engine`] that this store is associated with.
#[inline]
pub fn engine(&self) -> &Engine {
&self.inner.engine
}
Expand Down Expand Up @@ -503,10 +504,12 @@ impl Store {
}
}

#[inline]
pub(crate) fn externref_activations_table(&self) -> &VMExternRefActivationsTable {
&self.inner.externref_activations_table
}

#[inline]
pub(crate) fn stack_map_registry(&self) -> &StackMapRegistry {
&self.inner.stack_map_registry
}
Expand Down Expand Up @@ -655,6 +658,7 @@ impl Store {
});
}

#[inline]
pub(crate) fn async_support(&self) -> bool {
self.inner.engine.config().async_support
}
Expand Down Expand Up @@ -915,6 +919,7 @@ impl Store {
}

unsafe impl TrapInfo for Store {
#[inline]
fn as_any(&self) -> &dyn Any {
self
}
Expand All @@ -930,6 +935,7 @@ unsafe impl TrapInfo for Store {
false
}

#[inline]
fn max_wasm_stack(&self) -> usize {
self.engine().config().max_wasm_stack
}
Expand All @@ -956,6 +962,7 @@ unsafe impl TrapInfo for Store {
}
}

#[inline]
fn interrupts(&self) -> &VMInterrupts {
&self.inner.interrupts
}
Expand Down

0 comments on commit c95971a

Please sign in to comment.