Skip to content

Commit

Permalink
wasmtime: Refactor trap-handling (#9087)
Browse files Browse the repository at this point in the history
This commit groups together the registers that have to be collected from
a signal handler to correctly report a trap: namely, the program counter
and frame pointer, as of the time that the trap occurred.

I also moved the call to set_jit_trap inside test_if_trap for every
platform that uses both methods. Only the implementation for Mach ports
still needs to call set_jit_trap because it doesn't use test_if_trap.

In addition I'm fixing an unrelated doc comment that I stumbled across
while working on this.
  • Loading branch information
jameysharp authored Aug 7, 2024
1 parent fa9a78b commit 895180d
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 86 deletions.
4 changes: 2 additions & 2 deletions crates/wasmtime/src/runtime/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -984,8 +984,8 @@ impl Func {
/// # Panics
///
/// This function will panic if called on a function belonging to an async
/// store. Asynchronous stores must always use `call_async`.
/// initiates a panic. Also panics if `store` does not own this function.
/// store. Asynchronous stores must always use `call_async`. Also panics if
/// `store` does not own this function.
///
/// [`WasmBacktrace`]: crate::WasmBacktrace
pub fn call(
Expand Down
13 changes: 5 additions & 8 deletions crates/wasmtime/src/runtime/vm/sys/custom/traphandlers.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::runtime::vm::traphandlers::{tls, TrapTest};
use crate::runtime::vm::traphandlers::{tls, TrapRegisters, TrapTest};
use crate::runtime::vm::VMContext;
use core::mem;

Expand Down Expand Up @@ -31,7 +31,7 @@ impl TrapHandler {
pub fn validate_config(&self, _macos_use_mach_ports: bool) {}
}

extern "C" fn handle_trap(ip: usize, fp: usize, has_faulting_addr: bool, faulting_addr: usize) {
extern "C" fn handle_trap(pc: usize, fp: usize, has_faulting_addr: bool, faulting_addr: usize) {
tls::with(|info| {
let info = match info {
Some(info) => info,
Expand All @@ -42,17 +42,14 @@ extern "C" fn handle_trap(ip: usize, fp: usize, has_faulting_addr: bool, faultin
} else {
None
};
let ip = ip as *const u8;
let test = info.test_if_trap(ip, |_handler| {
let regs = TrapRegisters { pc, fp };
let test = info.test_if_trap(regs, faulting_addr, |_handler| {
panic!("custom signal handlers are not supported on this platform");
});
match test {
TrapTest::NotWasm => {}
TrapTest::HandledByEmbedder => unreachable!(),
TrapTest::Trap { jmp_buf, trap } => {
info.set_jit_trap(ip, fp, faulting_addr, trap);
unsafe { wasmtime_longjmp(jmp_buf) }
}
TrapTest::Trap { jmp_buf } => unsafe { wasmtime_longjmp(jmp_buf) },
}
})
}
Expand Down
13 changes: 4 additions & 9 deletions crates/wasmtime/src/runtime/vm/sys/unix/machports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

use crate::runtime::module::lookup_code;
use crate::runtime::vm::sys::traphandlers::wasmtime_longjmp;
use crate::runtime::vm::traphandlers::tls;
use crate::runtime::vm::traphandlers::{tls, TrapRegisters};
use mach2::exc::*;
use mach2::exception_types::*;
use mach2::kern_return::*;
Expand Down Expand Up @@ -384,21 +384,16 @@ unsafe fn handle_exception(request: &mut ExceptionRequest) -> bool {
/// a native backtrace once we've switched back to the thread itself. After
/// the backtrace is captured we can do the usual `longjmp` back to the source
/// of the wasm code.
unsafe extern "C" fn unwind(
wasm_pc: *const u8,
wasm_fp: usize,
fault1: usize,
fault2: usize,
trap: u8,
) -> ! {
unsafe extern "C" fn unwind(pc: usize, fp: usize, fault1: usize, fault2: usize, trap: u8) -> ! {
let jmp_buf = tls::with(|state| {
let state = state.unwrap();
let regs = TrapRegisters { pc, fp };
let faulting_addr = match fault1 {
0 => None,
_ => Some(fault2),
};
let trap = Trap::from_u8(trap).unwrap();
state.set_jit_trap(wasm_pc, wasm_fp, faulting_addr, trap);
state.set_jit_trap(regs, faulting_addr, trap);
state.take_jmp_buf()
});
debug_assert!(!jmp_buf.is_null());
Expand Down
95 changes: 48 additions & 47 deletions crates/wasmtime/src/runtime/vm/sys/unix/signals.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Trap handling on Unix based on POSIX signals.

use crate::runtime::vm::traphandlers::{tls, TrapTest};
use crate::runtime::vm::traphandlers::{tls, TrapRegisters, TrapTest};
use crate::runtime::vm::VMContext;
use std::cell::RefCell;
use std::io;
Expand Down Expand Up @@ -166,23 +166,24 @@ unsafe extern "C" fn trap_handler(
// Otherwise flag ourselves as handling a trap, do the trap
// handling, and reset our trap handling flag. Then we figure
// out what to do based on the result of the trap handling.
let (pc, fp) = get_pc_and_fp(context, signum);
let test = info.test_if_trap(pc, |handler| handler(signum, siginfo, context));
let faulting_addr = match signum {
libc::SIGSEGV | libc::SIGBUS => Some((*siginfo).si_addr() as usize),
_ => None,
};
let regs = get_trap_registers(context, signum);
let test = info.test_if_trap(regs, faulting_addr, |handler| {
handler(signum, siginfo, context)
});

// Figure out what to do based on the result of this handling of
// the trap. Note that our sentinel value of 1 means that the
// exception was handled by a custom exception handler, so we
// keep executing.
let (jmp_buf, trap) = match test {
let jmp_buf = match test {
TrapTest::NotWasm => return false,
TrapTest::HandledByEmbedder => return true,
TrapTest::Trap { jmp_buf, trap } => (jmp_buf, trap),
};
let faulting_addr = match signum {
libc::SIGSEGV | libc::SIGBUS => Some((*siginfo).si_addr() as usize),
_ => None,
TrapTest::Trap { jmp_buf } => jmp_buf,
};
info.set_jit_trap(pc, fp, faulting_addr, trap);
// On macOS this is a bit special, unfortunately. If we were to
// `siglongjmp` out of the signal handler that notably does
// *not* reset the sigaltstack state of our signal handler. This
Expand Down Expand Up @@ -247,20 +248,20 @@ unsafe extern "C" fn trap_handler(
}
}

unsafe fn get_pc_and_fp(cx: *mut libc::c_void, _signum: libc::c_int) -> (*const u8, usize) {
unsafe fn get_trap_registers(cx: *mut libc::c_void, _signum: libc::c_int) -> TrapRegisters {
cfg_if::cfg_if! {
if #[cfg(all(any(target_os = "linux", target_os = "android"), target_arch = "x86_64"))] {
let cx = &*(cx as *const libc::ucontext_t);
(
cx.uc_mcontext.gregs[libc::REG_RIP as usize] as *const u8,
cx.uc_mcontext.gregs[libc::REG_RBP as usize] as usize
)
TrapRegisters {
pc: cx.uc_mcontext.gregs[libc::REG_RIP as usize] as usize,
fp: cx.uc_mcontext.gregs[libc::REG_RBP as usize] as usize,
}
} else if #[cfg(all(any(target_os = "linux", target_os = "android"), target_arch = "aarch64"))] {
let cx = &*(cx as *const libc::ucontext_t);
(
cx.uc_mcontext.pc as *const u8,
cx.uc_mcontext.regs[29] as usize,
)
TrapRegisters {
pc: cx.uc_mcontext.pc as usize,
fp: cx.uc_mcontext.regs[29] as usize,
}
} else if #[cfg(all(target_os = "linux", target_arch = "s390x"))] {
// On s390x, SIGILL and SIGFPE are delivered with the PSW address
// pointing *after* the faulting instruction, while SIGSEGV and
Expand All @@ -277,46 +278,46 @@ unsafe fn get_pc_and_fp(cx: *mut libc::c_void, _signum: libc::c_int) -> (*const
_ => 0,
};
let cx = &*(cx as *const libc::ucontext_t);
(
(cx.uc_mcontext.psw.addr - trap_offset) as *const u8,
*(cx.uc_mcontext.gregs[15] as *const usize),
)
TrapRegisters {
pc: (cx.uc_mcontext.psw.addr - trap_offset) as usize,
fp: *(cx.uc_mcontext.gregs[15] as *const usize),
}
} else if #[cfg(all(target_os = "macos", target_arch = "x86_64"))] {
let cx = &*(cx as *const libc::ucontext_t);
(
(*cx.uc_mcontext).__ss.__rip as *const u8,
(*cx.uc_mcontext).__ss.__rbp as usize,
)
TrapRegisters {
pc: (*cx.uc_mcontext).__ss.__rip as usize,
fp: (*cx.uc_mcontext).__ss.__rbp as usize,
}
} else if #[cfg(all(target_os = "macos", target_arch = "aarch64"))] {
let cx = &*(cx as *const libc::ucontext_t);
(
(*cx.uc_mcontext).__ss.__pc as *const u8,
(*cx.uc_mcontext).__ss.__fp as usize,
)
TrapRegisters {
pc: (*cx.uc_mcontext).__ss.__pc as usize,
fp: (*cx.uc_mcontext).__ss.__fp as usize,
}
} else if #[cfg(all(target_os = "freebsd", target_arch = "x86_64"))] {
let cx = &*(cx as *const libc::ucontext_t);
(
cx.uc_mcontext.mc_rip as *const u8,
cx.uc_mcontext.mc_rbp as usize,
)
TrapRegisters {
pc: cx.uc_mcontext.mc_rip as usize,
fp: cx.uc_mcontext.mc_rbp as usize,
}
} else if #[cfg(all(target_os = "linux", target_arch = "riscv64"))] {
let cx = &*(cx as *const libc::ucontext_t);
(
cx.uc_mcontext.__gregs[libc::REG_PC] as *const u8,
cx.uc_mcontext.__gregs[libc::REG_S0] as usize,
)
TrapRegisters {
pc: cx.uc_mcontext.__gregs[libc::REG_PC] as usize,
fp: cx.uc_mcontext.__gregs[libc::REG_S0] as usize,
}
} else if #[cfg(all(target_os = "freebsd", target_arch = "aarch64"))] {
let cx = &*(cx as *const libc::mcontext_t);
(
cx.mc_gpregs.gp_elr as *const u8,
cx.mc_gpregs.gp_x[29] as usize,
)
TrapRegisters {
pc: cx.mc_gpregs.gp_elr as usize,
fp: cx.mc_gpregs.gp_x[29] as usize,
}
} else if #[cfg(all(target_os = "openbsd", target_arch = "x86_64"))] {
let cx = &*(cx as *const libc::ucontext_t);
(
cx.sc_rip as *const u8,
cx.sc_rbp as usize,
)
TrapRegisters {
pc: cx.sc_rip as usize,
fp: cx.sc_rbp as usize,
}
}
else {
compile_error!("unsupported platform");
Expand Down
22 changes: 12 additions & 10 deletions crates/wasmtime/src/runtime/vm/sys/windows/traphandlers.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::runtime::vm::traphandlers::{tls, TrapTest};
use crate::runtime::vm::traphandlers::{tls, TrapRegisters, TrapTest};
use crate::runtime::vm::VMContext;
use std::ffi::c_void;
use std::io;
Expand Down Expand Up @@ -96,13 +96,18 @@ unsafe extern "system" fn exception_handler(exception_info: *mut EXCEPTION_POINT
Some(info) => info,
None => return ExceptionContinueSearch,
};
let context = &*(*exception_info).ContextRecord;
cfg_if::cfg_if! {
if #[cfg(target_arch = "x86_64")] {
let ip = (*(*exception_info).ContextRecord).Rip as *const u8;
let fp = (*(*exception_info).ContextRecord).Rbp as usize;
let regs = TrapRegisters {
pc: context.Rip as usize,
fp: context.Rbp as usize,
};
} else if #[cfg(target_arch = "aarch64")] {
let ip = (*(*exception_info).ContextRecord).Pc as *const u8;
let fp = (*(*exception_info).ContextRecord).Anonymous.Anonymous.Fp as usize;
let regs = TrapRegisters {
pc: context.Pc as usize,
fp: context.Anonymous.Anonymous.Fp as usize,
};
} else {
compile_error!("unsupported platform");
}
Expand All @@ -117,13 +122,10 @@ unsafe extern "system" fn exception_handler(exception_info: *mut EXCEPTION_POINT
} else {
None
};
match info.test_if_trap(ip, |handler| handler(exception_info)) {
match info.test_if_trap(regs, faulting_addr, |handler| handler(exception_info)) {
TrapTest::NotWasm => ExceptionContinueSearch,
TrapTest::HandledByEmbedder => ExceptionContinueExecution,
TrapTest::Trap { jmp_buf, trap } => {
info.set_jit_trap(ip, fp, faulting_addr, trap);
wasmtime_longjmp(jmp_buf)
}
TrapTest::Trap { jmp_buf } => wasmtime_longjmp(jmp_buf),
}
})
}
Expand Down
24 changes: 14 additions & 10 deletions crates/wasmtime/src/runtime/vm/traphandlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ impl From<wasmtime_environ::Trap> for TrapReason {
}
}

pub(crate) struct TrapRegisters {
pub pc: usize,
pub fp: usize,
}

/// Return value from `test_if_trap`.
pub(crate) enum TrapTest {
/// Not a wasm trap, need to delegate to whatever process handler is next.
Expand All @@ -230,8 +235,6 @@ pub(crate) enum TrapTest {
Trap {
/// How to longjmp back to the original wasm frame.
jmp_buf: *const u8,
/// The trap code of this trap.
trap: wasmtime_environ::Trap,
},
}

Expand Down Expand Up @@ -455,7 +458,8 @@ impl CallThreadState {
#[cfg_attr(miri, allow(dead_code))] // miri doesn't handle traps yet
pub(crate) fn test_if_trap(
&self,
pc: *const u8,
regs: TrapRegisters,
faulting_addr: Option<usize>,
call_handler: impl Fn(&SignalHandler) -> bool,
) -> TrapTest {
// If we haven't even started to handle traps yet, bail out.
Expand All @@ -473,19 +477,20 @@ impl CallThreadState {
}

// If this fault wasn't in wasm code, then it's not our problem
let Some((code, text_offset)) = lookup_code(pc as usize) else {
let Some((code, text_offset)) = lookup_code(regs.pc) else {
return TrapTest::NotWasm;
};

let Some(trap) = code.lookup_trap_code(text_offset) else {
return TrapTest::NotWasm;
};

self.set_jit_trap(regs, faulting_addr, trap);

// If all that passed then this is indeed a wasm trap, so return the
// `jmp_buf` passed to `wasmtime_longjmp` to resume.
TrapTest::Trap {
jmp_buf: self.take_jmp_buf(),
trap,
}
}

Expand All @@ -496,17 +501,16 @@ impl CallThreadState {
#[cfg_attr(miri, allow(dead_code))] // miri doesn't handle traps yet
pub(crate) fn set_jit_trap(
&self,
pc: *const u8,
fp: usize,
TrapRegisters { pc, fp, .. }: TrapRegisters,
faulting_addr: Option<usize>,
trap: wasmtime_environ::Trap,
) {
let backtrace = self.capture_backtrace(self.limits, Some((pc as usize, fp)));
let coredump = self.capture_coredump(self.limits, Some((pc as usize, fp)));
let backtrace = self.capture_backtrace(self.limits, Some((pc, fp)));
let coredump = self.capture_coredump(self.limits, Some((pc, fp)));
unsafe {
(*self.unwind.get()).as_mut_ptr().write((
UnwindReason::Trap(TrapReason::Jit {
pc: pc as usize,
pc,
faulting_addr,
trap,
}),
Expand Down

0 comments on commit 895180d

Please sign in to comment.