Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

impl Send for JITModule #8718

Merged
merged 3 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 49 additions & 28 deletions cranelift/jit/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const READONLY_DATA_ALIGNMENT: u64 = 0x1;
/// A builder for `JITModule`.
pub struct JITBuilder {
isa: OwnedTargetIsa,
symbols: HashMap<String, *const u8>,
symbols: HashMap<String, SendWrapper<*const u8>>,
lookup_symbols: Vec<Box<dyn Fn(&str) -> Option<*const u8> + Send>>,
libcall_names: Box<dyn Fn(ir::LibCall) -> String + Send + Sync>,
hotswap_enabled: bool,
Expand Down Expand Up @@ -116,7 +116,7 @@ impl JITBuilder {
where
K: Into<String>,
{
self.symbols.insert(name.into(), ptr);
self.symbols.insert(name.into(), SendWrapper(ptr));
self
}

Expand All @@ -129,7 +129,7 @@ impl JITBuilder {
K: Into<String>,
{
for (name, ptr) in symbols {
self.symbols.insert(name.into(), ptr);
self.symbols.insert(name.into(), SendWrapper(ptr));
}
self
}
Expand Down Expand Up @@ -165,23 +165,37 @@ struct GotUpdate {
ptr: *const u8,
}

unsafe impl Send for GotUpdate {}

/// A wrapper that impls Send for the contents.
///
/// SAFETY: This must not be used for any types where it would be UB for them to be Send
struct SendWrapper<T>(T);
MolotovCherry marked this conversation as resolved.
Show resolved Hide resolved
unsafe impl<T> Send for SendWrapper<T> {}
impl<T: Copy> Copy for SendWrapper<T> {}
impl<T: Clone> Clone for SendWrapper<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

/// A `JITModule` implements `Module` and emits code and data into memory where it can be
/// directly called and accessed.
///
/// See the `JITBuilder` for a convenient way to construct `JITModule` instances.
pub struct JITModule {
isa: OwnedTargetIsa,
hotswap_enabled: bool,
symbols: RefCell<HashMap<String, *const u8>>,
symbols: RefCell<HashMap<String, SendWrapper<*const u8>>>,
lookup_symbols: Vec<Box<dyn Fn(&str) -> Option<*const u8> + Send>>,
libcall_names: Box<dyn Fn(ir::LibCall) -> String + Send + Sync>,
memory: MemoryHandle,
declarations: ModuleDeclarations,
function_got_entries: SecondaryMap<FuncId, Option<NonNull<AtomicPtr<u8>>>>,
function_plt_entries: SecondaryMap<FuncId, Option<NonNull<[u8; 16]>>>,
data_object_got_entries: SecondaryMap<DataId, Option<NonNull<AtomicPtr<u8>>>>,
libcall_got_entries: HashMap<ir::LibCall, NonNull<AtomicPtr<u8>>>,
libcall_plt_entries: HashMap<ir::LibCall, NonNull<[u8; 16]>>,
function_got_entries: SecondaryMap<FuncId, Option<SendWrapper<NonNull<AtomicPtr<u8>>>>>,
function_plt_entries: SecondaryMap<FuncId, Option<SendWrapper<NonNull<[u8; 16]>>>>,
data_object_got_entries: SecondaryMap<DataId, Option<SendWrapper<NonNull<AtomicPtr<u8>>>>>,
libcall_got_entries: HashMap<ir::LibCall, SendWrapper<NonNull<AtomicPtr<u8>>>>,
libcall_plt_entries: HashMap<ir::LibCall, SendWrapper<NonNull<[u8; 16]>>>,
compiled_functions: SecondaryMap<FuncId, Option<CompiledBlob>>,
compiled_data_objects: SecondaryMap<DataId, Option<CompiledBlob>>,
functions_to_finalize: Vec<FuncId>,
Expand All @@ -191,8 +205,6 @@ pub struct JITModule {
pending_got_updates: Vec<GotUpdate>,
}

unsafe impl Send for JITModule {}

/// A handle to allow freeing memory allocated by the `Module`.
struct MemoryHandle {
code: Memory,
Expand All @@ -217,15 +229,15 @@ impl JITModule {

fn lookup_symbol(&self, name: &str) -> Option<*const u8> {
match self.symbols.borrow_mut().entry(name.to_owned()) {
std::collections::hash_map::Entry::Occupied(occ) => Some(*occ.get()),
std::collections::hash_map::Entry::Occupied(occ) => Some(occ.get().0),
std::collections::hash_map::Entry::Vacant(vac) => {
let ptr = self
.lookup_symbols
.iter()
.rev() // Try last lookup function first
.find_map(|lookup| lookup(name));
if let Some(ptr) = ptr {
vac.insert(ptr);
vac.insert(SendWrapper(ptr));
}
ptr
}
Expand Down Expand Up @@ -268,7 +280,7 @@ impl JITModule {

fn new_func_plt_entry(&mut self, id: FuncId, val: *const u8) {
let got_entry = self.new_got_entry(val);
self.function_got_entries[id] = Some(got_entry);
self.function_got_entries[id] = Some(SendWrapper(got_entry));
let plt_entry = self.new_plt_entry(got_entry);
self.record_function_for_perf(
plt_entry.as_ptr().cast(),
Expand All @@ -278,12 +290,12 @@ impl JITModule {
self.declarations.get_function_decl(id).linkage_name(id)
),
);
self.function_plt_entries[id] = Some(plt_entry);
self.function_plt_entries[id] = Some(SendWrapper(plt_entry));
}

fn new_data_got_entry(&mut self, id: DataId, val: *const u8) {
let got_entry = self.new_got_entry(val);
self.data_object_got_entries[id] = Some(got_entry);
self.data_object_got_entries[id] = Some(SendWrapper(got_entry));
}

unsafe fn write_plt_entry_bytes(plt_ptr: *mut [u8; 16], got_ptr: NonNull<AtomicPtr<u8>>) {
Expand Down Expand Up @@ -352,24 +364,26 @@ impl JITModule {
/// Panics if there's no entry in the table for the given function.
pub fn read_got_entry(&self, func_id: FuncId) -> *const u8 {
let got_entry = self.function_got_entries[func_id].unwrap();
unsafe { got_entry.as_ref() }.load(Ordering::SeqCst)
unsafe { got_entry.0.as_ref() }.load(Ordering::SeqCst)
}

fn get_got_address(&self, name: &ModuleRelocTarget) -> NonNull<AtomicPtr<u8>> {
match *name {
ModuleRelocTarget::User { .. } => {
if ModuleDeclarations::is_function(name) {
let func_id = FuncId::from_name(name);
self.function_got_entries[func_id].unwrap()
self.function_got_entries[func_id].unwrap().0
} else {
let data_id = DataId::from_name(name);
self.data_object_got_entries[data_id].unwrap()
self.data_object_got_entries[data_id].unwrap().0
}
}
ModuleRelocTarget::LibCall(ref libcall) => *self
.libcall_got_entries
.get(libcall)
.unwrap_or_else(|| panic!("can't resolve libcall {}", libcall)),
ModuleRelocTarget::LibCall(ref libcall) => {
self.libcall_got_entries
.get(libcall)
.unwrap_or_else(|| panic!("can't resolve libcall {}", libcall))
.0
}
_ => panic!("invalid name"),
}
}
Expand All @@ -381,6 +395,7 @@ impl JITModule {
let func_id = FuncId::from_name(name);
self.function_plt_entries[func_id]
.unwrap()
.0
.as_ptr()
.cast::<u8>()
} else {
Expand All @@ -391,6 +406,7 @@ impl JITModule {
.libcall_plt_entries
.get(libcall)
.unwrap_or_else(|| panic!("can't resolve libcall {}", libcall))
.0
.as_ptr()
.cast::<u8>(),
_ => panic!("invalid name"),
Expand Down Expand Up @@ -545,9 +561,13 @@ impl JITModule {
continue;
};
let got_entry = module.new_got_entry(addr);
module.libcall_got_entries.insert(libcall, got_entry);
module
.libcall_got_entries
.insert(libcall, SendWrapper(got_entry));
let plt_entry = module.new_plt_entry(got_entry);
module.libcall_plt_entries.insert(libcall, plt_entry);
module
.libcall_plt_entries
.insert(libcall, SendWrapper(plt_entry));
}

module
Expand Down Expand Up @@ -721,7 +741,7 @@ impl Module for JITModule {

if self.isa.flags().is_pic() {
self.pending_got_updates.push(GotUpdate {
entry: self.function_got_entries[id].unwrap(),
entry: self.function_got_entries[id].unwrap().0,
ptr,
})
}
Expand All @@ -739,6 +759,7 @@ impl Module for JITModule {
.libcall_plt_entries
.get(libcall)
.unwrap_or_else(|| panic!("can't resolve libcall {}", libcall))
.0
.as_ptr()
.cast::<u8>(),
_ => panic!("invalid name"),
Expand Down Expand Up @@ -804,7 +825,7 @@ impl Module for JITModule {

if self.isa.flags().is_pic() {
self.pending_got_updates.push(GotUpdate {
entry: self.function_got_entries[id].unwrap(),
entry: self.function_got_entries[id].unwrap().0,
ptr,
})
}
Expand Down Expand Up @@ -909,7 +930,7 @@ impl Module for JITModule {
self.data_objects_to_finalize.push(id);
if self.isa.flags().is_pic() {
self.pending_got_updates.push(GotUpdate {
entry: self.data_object_got_entries[id].unwrap(),
entry: self.data_object_got_entries[id].unwrap().0,
ptr,
})
}
Expand Down
2 changes: 2 additions & 0 deletions cranelift/jit/src/compiled_blob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pub(crate) struct CompiledBlob {
pub(crate) relocs: Vec<ModuleReloc>,
}

unsafe impl Send for CompiledBlob {}

impl CompiledBlob {
pub(crate) fn perform_relocations(
&self,
Expand Down
2 changes: 2 additions & 0 deletions cranelift/jit/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ pub(crate) struct Memory {
branch_protection: BranchProtection,
}

unsafe impl Send for Memory {}

impl Memory {
pub(crate) fn new(branch_protection: BranchProtection) -> Self {
Self {
Expand Down