diff --git a/cranelift/jit/src/backend.rs b/cranelift/jit/src/backend.rs index 226d3c9afba1..903fecd6c406 100644 --- a/cranelift/jit/src/backend.rs +++ b/cranelift/jit/src/backend.rs @@ -27,8 +27,8 @@ const READONLY_DATA_ALIGNMENT: u64 = 0x1; /// A builder for `JITModule`. pub struct JITBuilder { isa: OwnedTargetIsa, - symbols: HashMap, - lookup_symbols: Vec Option<*const u8>>>, + symbols: HashMap>, + lookup_symbols: Vec Option<*const u8> + Send>>, libcall_names: Box String + Send + Sync>, hotswap_enabled: bool, } @@ -116,7 +116,7 @@ impl JITBuilder { where K: Into, { - self.symbols.insert(name.into(), ptr); + self.symbols.insert(name.into(), SendWrapper(ptr)); self } @@ -129,7 +129,7 @@ impl JITBuilder { K: Into, { for (name, ptr) in symbols { - self.symbols.insert(name.into(), ptr); + self.symbols.insert(name.into(), SendWrapper(ptr)); } self } @@ -140,7 +140,7 @@ impl JITBuilder { /// symbol table. Symbol lookup fn's are called in reverse of the order in which they were added. pub fn symbol_lookup_fn( &mut self, - symbol_lookup_fn: Box Option<*const u8>>, + symbol_lookup_fn: Box Option<*const u8> + Send>, ) -> &mut Self { self.lookup_symbols.push(symbol_lookup_fn); self @@ -165,6 +165,15 @@ struct GotUpdate { ptr: *const u8, } +unsafe impl Send for GotUpdate {} + +/// A wrapper that impls Send for the contents. +/// +/// SAFETY: This must not be used for any types where it would be UB for them to be Send +#[derive(Copy, Clone)] +struct SendWrapper(T); +unsafe impl Send for SendWrapper {} + /// A `JITModule` implements `Module` and emits code and data into memory where it can be /// directly called and accessed. /// @@ -172,16 +181,16 @@ struct GotUpdate { pub struct JITModule { isa: OwnedTargetIsa, hotswap_enabled: bool, - symbols: RefCell>, - lookup_symbols: Vec Option<*const u8>>>, - libcall_names: Box String>, + symbols: RefCell>>, + lookup_symbols: Vec Option<*const u8> + Send>>, + libcall_names: Box String + Send + Sync>, memory: MemoryHandle, declarations: ModuleDeclarations, - function_got_entries: SecondaryMap>>>, - function_plt_entries: SecondaryMap>>, - data_object_got_entries: SecondaryMap>>>, - libcall_got_entries: HashMap>>, - libcall_plt_entries: HashMap>, + function_got_entries: SecondaryMap>>>>, + function_plt_entries: SecondaryMap>>>, + data_object_got_entries: SecondaryMap>>>>, + libcall_got_entries: HashMap>>>, + libcall_plt_entries: HashMap>>, compiled_functions: SecondaryMap>, compiled_data_objects: SecondaryMap>, functions_to_finalize: Vec, @@ -215,7 +224,7 @@ impl JITModule { fn lookup_symbol(&self, name: &str) -> Option<*const u8> { match self.symbols.borrow_mut().entry(name.to_owned()) { - std::collections::hash_map::Entry::Occupied(occ) => Some(*occ.get()), + std::collections::hash_map::Entry::Occupied(occ) => Some(occ.get().0), std::collections::hash_map::Entry::Vacant(vac) => { let ptr = self .lookup_symbols @@ -223,7 +232,7 @@ impl JITModule { .rev() // Try last lookup function first .find_map(|lookup| lookup(name)); if let Some(ptr) = ptr { - vac.insert(ptr); + vac.insert(SendWrapper(ptr)); } ptr } @@ -266,7 +275,7 @@ impl JITModule { fn new_func_plt_entry(&mut self, id: FuncId, val: *const u8) { let got_entry = self.new_got_entry(val); - self.function_got_entries[id] = Some(got_entry); + self.function_got_entries[id] = Some(SendWrapper(got_entry)); let plt_entry = self.new_plt_entry(got_entry); self.record_function_for_perf( plt_entry.as_ptr().cast(), @@ -276,12 +285,12 @@ impl JITModule { self.declarations.get_function_decl(id).linkage_name(id) ), ); - self.function_plt_entries[id] = Some(plt_entry); + self.function_plt_entries[id] = Some(SendWrapper(plt_entry)); } fn new_data_got_entry(&mut self, id: DataId, val: *const u8) { let got_entry = self.new_got_entry(val); - self.data_object_got_entries[id] = Some(got_entry); + self.data_object_got_entries[id] = Some(SendWrapper(got_entry)); } unsafe fn write_plt_entry_bytes(plt_ptr: *mut [u8; 16], got_ptr: NonNull>) { @@ -350,7 +359,7 @@ impl JITModule { /// Panics if there's no entry in the table for the given function. pub fn read_got_entry(&self, func_id: FuncId) -> *const u8 { let got_entry = self.function_got_entries[func_id].unwrap(); - unsafe { got_entry.as_ref() }.load(Ordering::SeqCst) + unsafe { got_entry.0.as_ref() }.load(Ordering::SeqCst) } fn get_got_address(&self, name: &ModuleRelocTarget) -> NonNull> { @@ -358,16 +367,18 @@ impl JITModule { ModuleRelocTarget::User { .. } => { if ModuleDeclarations::is_function(name) { let func_id = FuncId::from_name(name); - self.function_got_entries[func_id].unwrap() + self.function_got_entries[func_id].unwrap().0 } else { let data_id = DataId::from_name(name); - self.data_object_got_entries[data_id].unwrap() + self.data_object_got_entries[data_id].unwrap().0 } } - ModuleRelocTarget::LibCall(ref libcall) => *self - .libcall_got_entries - .get(libcall) - .unwrap_or_else(|| panic!("can't resolve libcall {}", libcall)), + ModuleRelocTarget::LibCall(ref libcall) => { + self.libcall_got_entries + .get(libcall) + .unwrap_or_else(|| panic!("can't resolve libcall {}", libcall)) + .0 + } _ => panic!("invalid name"), } } @@ -379,6 +390,7 @@ impl JITModule { let func_id = FuncId::from_name(name); self.function_plt_entries[func_id] .unwrap() + .0 .as_ptr() .cast::() } else { @@ -389,6 +401,7 @@ impl JITModule { .libcall_plt_entries .get(libcall) .unwrap_or_else(|| panic!("can't resolve libcall {}", libcall)) + .0 .as_ptr() .cast::(), _ => panic!("invalid name"), @@ -543,9 +556,13 @@ impl JITModule { continue; }; let got_entry = module.new_got_entry(addr); - module.libcall_got_entries.insert(libcall, got_entry); + module + .libcall_got_entries + .insert(libcall, SendWrapper(got_entry)); let plt_entry = module.new_plt_entry(got_entry); - module.libcall_plt_entries.insert(libcall, plt_entry); + module + .libcall_plt_entries + .insert(libcall, SendWrapper(plt_entry)); } module @@ -719,7 +736,7 @@ impl Module for JITModule { if self.isa.flags().is_pic() { self.pending_got_updates.push(GotUpdate { - entry: self.function_got_entries[id].unwrap(), + entry: self.function_got_entries[id].unwrap().0, ptr, }) } @@ -737,6 +754,7 @@ impl Module for JITModule { .libcall_plt_entries .get(libcall) .unwrap_or_else(|| panic!("can't resolve libcall {}", libcall)) + .0 .as_ptr() .cast::(), _ => panic!("invalid name"), @@ -802,7 +820,7 @@ impl Module for JITModule { if self.isa.flags().is_pic() { self.pending_got_updates.push(GotUpdate { - entry: self.function_got_entries[id].unwrap(), + entry: self.function_got_entries[id].unwrap().0, ptr, }) } @@ -907,7 +925,7 @@ impl Module for JITModule { self.data_objects_to_finalize.push(id); if self.isa.flags().is_pic() { self.pending_got_updates.push(GotUpdate { - entry: self.data_object_got_entries[id].unwrap(), + entry: self.data_object_got_entries[id].unwrap().0, ptr, }) } diff --git a/cranelift/jit/src/compiled_blob.rs b/cranelift/jit/src/compiled_blob.rs index 84e10460e182..4eab207ff011 100644 --- a/cranelift/jit/src/compiled_blob.rs +++ b/cranelift/jit/src/compiled_blob.rs @@ -17,6 +17,8 @@ pub(crate) struct CompiledBlob { pub(crate) relocs: Vec, } +unsafe impl Send for CompiledBlob {} + impl CompiledBlob { pub(crate) fn perform_relocations( &self, diff --git a/cranelift/jit/src/memory.rs b/cranelift/jit/src/memory.rs index a4ceca3ac403..abbb513c663f 100644 --- a/cranelift/jit/src/memory.rs +++ b/cranelift/jit/src/memory.rs @@ -132,6 +132,8 @@ pub(crate) struct Memory { branch_protection: BranchProtection, } +unsafe impl Send for Memory {} + impl Memory { pub(crate) fn new(branch_protection: BranchProtection) -> Self { Self {