Skip to content

Commit

Permalink
impl Send for JITModule (#8718)
Browse files Browse the repository at this point in the history
* Impl Send for JITModule

Ref: https://bytecodealliance.zulipchat.com/#narrow/stream/206238-general/topic/Cross.20thread.20JITModule/near/441536899

* impl Send on Memory,CompiledBlob,GotUpdate and wrap remaining Sendable fields in JITModule with a SendWrapper

* Derive Copy,Clone for SendWrapper
  • Loading branch information
MolotovCherry authored Jun 3, 2024
1 parent bda1a64 commit b010bfd
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 30 deletions.
78 changes: 48 additions & 30 deletions cranelift/jit/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ const READONLY_DATA_ALIGNMENT: u64 = 0x1;
/// A builder for `JITModule`.
pub struct JITBuilder {
isa: OwnedTargetIsa,
symbols: HashMap<String, *const u8>,
lookup_symbols: Vec<Box<dyn Fn(&str) -> Option<*const u8>>>,
symbols: HashMap<String, SendWrapper<*const u8>>,
lookup_symbols: Vec<Box<dyn Fn(&str) -> Option<*const u8> + Send>>,
libcall_names: Box<dyn Fn(ir::LibCall) -> String + Send + Sync>,
hotswap_enabled: bool,
}
Expand Down Expand Up @@ -116,7 +116,7 @@ impl JITBuilder {
where
K: Into<String>,
{
self.symbols.insert(name.into(), ptr);
self.symbols.insert(name.into(), SendWrapper(ptr));
self
}

Expand All @@ -129,7 +129,7 @@ impl JITBuilder {
K: Into<String>,
{
for (name, ptr) in symbols {
self.symbols.insert(name.into(), ptr);
self.symbols.insert(name.into(), SendWrapper(ptr));
}
self
}
Expand All @@ -140,7 +140,7 @@ impl JITBuilder {
/// symbol table. Symbol lookup fn's are called in reverse of the order in which they were added.
pub fn symbol_lookup_fn(
&mut self,
symbol_lookup_fn: Box<dyn Fn(&str) -> Option<*const u8>>,
symbol_lookup_fn: Box<dyn Fn(&str) -> Option<*const u8> + Send>,
) -> &mut Self {
self.lookup_symbols.push(symbol_lookup_fn);
self
Expand All @@ -165,23 +165,32 @@ struct GotUpdate {
ptr: *const u8,
}

unsafe impl Send for GotUpdate {}

/// A wrapper that impls Send for the contents.
///
/// SAFETY: This must not be used for any types where it would be UB for them to be Send
#[derive(Copy, Clone)]
struct SendWrapper<T>(T);
unsafe impl<T> Send for SendWrapper<T> {}

/// A `JITModule` implements `Module` and emits code and data into memory where it can be
/// directly called and accessed.
///
/// See the `JITBuilder` for a convenient way to construct `JITModule` instances.
pub struct JITModule {
isa: OwnedTargetIsa,
hotswap_enabled: bool,
symbols: RefCell<HashMap<String, *const u8>>,
lookup_symbols: Vec<Box<dyn Fn(&str) -> Option<*const u8>>>,
libcall_names: Box<dyn Fn(ir::LibCall) -> String>,
symbols: RefCell<HashMap<String, SendWrapper<*const u8>>>,
lookup_symbols: Vec<Box<dyn Fn(&str) -> Option<*const u8> + Send>>,
libcall_names: Box<dyn Fn(ir::LibCall) -> String + Send + Sync>,
memory: MemoryHandle,
declarations: ModuleDeclarations,
function_got_entries: SecondaryMap<FuncId, Option<NonNull<AtomicPtr<u8>>>>,
function_plt_entries: SecondaryMap<FuncId, Option<NonNull<[u8; 16]>>>,
data_object_got_entries: SecondaryMap<DataId, Option<NonNull<AtomicPtr<u8>>>>,
libcall_got_entries: HashMap<ir::LibCall, NonNull<AtomicPtr<u8>>>,
libcall_plt_entries: HashMap<ir::LibCall, NonNull<[u8; 16]>>,
function_got_entries: SecondaryMap<FuncId, Option<SendWrapper<NonNull<AtomicPtr<u8>>>>>,
function_plt_entries: SecondaryMap<FuncId, Option<SendWrapper<NonNull<[u8; 16]>>>>,
data_object_got_entries: SecondaryMap<DataId, Option<SendWrapper<NonNull<AtomicPtr<u8>>>>>,
libcall_got_entries: HashMap<ir::LibCall, SendWrapper<NonNull<AtomicPtr<u8>>>>,
libcall_plt_entries: HashMap<ir::LibCall, SendWrapper<NonNull<[u8; 16]>>>,
compiled_functions: SecondaryMap<FuncId, Option<CompiledBlob>>,
compiled_data_objects: SecondaryMap<DataId, Option<CompiledBlob>>,
functions_to_finalize: Vec<FuncId>,
Expand Down Expand Up @@ -215,15 +224,15 @@ impl JITModule {

fn lookup_symbol(&self, name: &str) -> Option<*const u8> {
match self.symbols.borrow_mut().entry(name.to_owned()) {
std::collections::hash_map::Entry::Occupied(occ) => Some(*occ.get()),
std::collections::hash_map::Entry::Occupied(occ) => Some(occ.get().0),
std::collections::hash_map::Entry::Vacant(vac) => {
let ptr = self
.lookup_symbols
.iter()
.rev() // Try last lookup function first
.find_map(|lookup| lookup(name));
if let Some(ptr) = ptr {
vac.insert(ptr);
vac.insert(SendWrapper(ptr));
}
ptr
}
Expand Down Expand Up @@ -266,7 +275,7 @@ impl JITModule {

fn new_func_plt_entry(&mut self, id: FuncId, val: *const u8) {
let got_entry = self.new_got_entry(val);
self.function_got_entries[id] = Some(got_entry);
self.function_got_entries[id] = Some(SendWrapper(got_entry));
let plt_entry = self.new_plt_entry(got_entry);
self.record_function_for_perf(
plt_entry.as_ptr().cast(),
Expand All @@ -276,12 +285,12 @@ impl JITModule {
self.declarations.get_function_decl(id).linkage_name(id)
),
);
self.function_plt_entries[id] = Some(plt_entry);
self.function_plt_entries[id] = Some(SendWrapper(plt_entry));
}

fn new_data_got_entry(&mut self, id: DataId, val: *const u8) {
let got_entry = self.new_got_entry(val);
self.data_object_got_entries[id] = Some(got_entry);
self.data_object_got_entries[id] = Some(SendWrapper(got_entry));
}

unsafe fn write_plt_entry_bytes(plt_ptr: *mut [u8; 16], got_ptr: NonNull<AtomicPtr<u8>>) {
Expand Down Expand Up @@ -350,24 +359,26 @@ impl JITModule {
/// Panics if there's no entry in the table for the given function.
pub fn read_got_entry(&self, func_id: FuncId) -> *const u8 {
let got_entry = self.function_got_entries[func_id].unwrap();
unsafe { got_entry.as_ref() }.load(Ordering::SeqCst)
unsafe { got_entry.0.as_ref() }.load(Ordering::SeqCst)
}

fn get_got_address(&self, name: &ModuleRelocTarget) -> NonNull<AtomicPtr<u8>> {
match *name {
ModuleRelocTarget::User { .. } => {
if ModuleDeclarations::is_function(name) {
let func_id = FuncId::from_name(name);
self.function_got_entries[func_id].unwrap()
self.function_got_entries[func_id].unwrap().0
} else {
let data_id = DataId::from_name(name);
self.data_object_got_entries[data_id].unwrap()
self.data_object_got_entries[data_id].unwrap().0
}
}
ModuleRelocTarget::LibCall(ref libcall) => *self
.libcall_got_entries
.get(libcall)
.unwrap_or_else(|| panic!("can't resolve libcall {}", libcall)),
ModuleRelocTarget::LibCall(ref libcall) => {
self.libcall_got_entries
.get(libcall)
.unwrap_or_else(|| panic!("can't resolve libcall {}", libcall))
.0
}
_ => panic!("invalid name"),
}
}
Expand All @@ -379,6 +390,7 @@ impl JITModule {
let func_id = FuncId::from_name(name);
self.function_plt_entries[func_id]
.unwrap()
.0
.as_ptr()
.cast::<u8>()
} else {
Expand All @@ -389,6 +401,7 @@ impl JITModule {
.libcall_plt_entries
.get(libcall)
.unwrap_or_else(|| panic!("can't resolve libcall {}", libcall))
.0
.as_ptr()
.cast::<u8>(),
_ => panic!("invalid name"),
Expand Down Expand Up @@ -543,9 +556,13 @@ impl JITModule {
continue;
};
let got_entry = module.new_got_entry(addr);
module.libcall_got_entries.insert(libcall, got_entry);
module
.libcall_got_entries
.insert(libcall, SendWrapper(got_entry));
let plt_entry = module.new_plt_entry(got_entry);
module.libcall_plt_entries.insert(libcall, plt_entry);
module
.libcall_plt_entries
.insert(libcall, SendWrapper(plt_entry));
}

module
Expand Down Expand Up @@ -719,7 +736,7 @@ impl Module for JITModule {

if self.isa.flags().is_pic() {
self.pending_got_updates.push(GotUpdate {
entry: self.function_got_entries[id].unwrap(),
entry: self.function_got_entries[id].unwrap().0,
ptr,
})
}
Expand All @@ -737,6 +754,7 @@ impl Module for JITModule {
.libcall_plt_entries
.get(libcall)
.unwrap_or_else(|| panic!("can't resolve libcall {}", libcall))
.0
.as_ptr()
.cast::<u8>(),
_ => panic!("invalid name"),
Expand Down Expand Up @@ -802,7 +820,7 @@ impl Module for JITModule {

if self.isa.flags().is_pic() {
self.pending_got_updates.push(GotUpdate {
entry: self.function_got_entries[id].unwrap(),
entry: self.function_got_entries[id].unwrap().0,
ptr,
})
}
Expand Down Expand Up @@ -907,7 +925,7 @@ impl Module for JITModule {
self.data_objects_to_finalize.push(id);
if self.isa.flags().is_pic() {
self.pending_got_updates.push(GotUpdate {
entry: self.data_object_got_entries[id].unwrap(),
entry: self.data_object_got_entries[id].unwrap().0,
ptr,
})
}
Expand Down
2 changes: 2 additions & 0 deletions cranelift/jit/src/compiled_blob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pub(crate) struct CompiledBlob {
pub(crate) relocs: Vec<ModuleReloc>,
}

unsafe impl Send for CompiledBlob {}

impl CompiledBlob {
pub(crate) fn perform_relocations(
&self,
Expand Down
2 changes: 2 additions & 0 deletions cranelift/jit/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ pub(crate) struct Memory {
branch_protection: BranchProtection,
}

unsafe impl Send for Memory {}

impl Memory {
pub(crate) fn new(branch_protection: BranchProtection) -> Self {
Self {
Expand Down

0 comments on commit b010bfd

Please sign in to comment.