diff --git a/library/core/src/slice/mod.rs b/library/core/src/slice/mod.rs index 6d3e625bef428..8d87e17cb2392 100644 --- a/library/core/src/slice/mod.rs +++ b/library/core/src/slice/mod.rs @@ -9,7 +9,7 @@ use crate::cmp::Ordering::{self, Equal, Greater, Less}; use crate::fmt; use crate::hint; -use crate::intrinsics::{exact_div, unchecked_sub}; +use crate::intrinsics::{exact_div, select_unpredictable, unchecked_sub}; use crate::mem::{self, SizedTypeProperties}; use crate::num::NonZero; use crate::ops::{Bound, OneSidedRange, Range, RangeBounds}; @@ -2787,41 +2787,54 @@ impl [T] { where F: FnMut(&'a T) -> Ordering, { - // INVARIANTS: - // - 0 <= left <= left + size = right <= self.len() - // - f returns Less for everything in self[..left] - // - f returns Greater for everything in self[right..] let mut size = self.len(); - let mut left = 0; - let mut right = size; - while left < right { - let mid = left + size / 2; - - // SAFETY: the while condition means `size` is strictly positive, so - // `size/2 < size`. Thus `left + size/2 < left + size`, which - // coupled with the `left + size <= self.len()` invariant means - // we have `left + size/2 < self.len()`, and this is in-bounds. + if size == 0 { + return Err(0); + } + let mut base = 0usize; + + // This loop intentionally doesn't have an early exit if the comparison + // returns Equal. We want the number of loop iterations to depend *only* + // on the size of the input slice so that the CPU can reliably predict + // the loop count. + while size > 1 { + let half = size / 2; + let mid = base + half; + + // SAFETY: the call is made safe by the following invariants: + // - `mid >= 0`: by definition + // - `mid < self.len()`: `base + size <= self.len()` always holds and `half < size` let cmp = f(unsafe { self.get_unchecked(mid) }); - // This control flow produces conditional moves, which results in - // fewer branches and instructions than if/else or matching on - // cmp::Ordering. 
- // This is x86 asm for u8: https://rust.godbolt.org/z/698eYffTx. - left = if cmp == Less { mid + 1 } else { left }; - right = if cmp == Greater { mid } else { right }; - if cmp == Equal { - // SAFETY: same as the `get_unchecked` above - unsafe { hint::assert_unchecked(mid < self.len()) }; - return Ok(mid); - } - - size = right - left; + // Binary search interacts poorly with branch prediction, so force + // the compiler to use conditional moves if supported by the target + // architecture. + base = select_unpredictable(cmp == Greater, base, mid); + + // This is imprecise in the case where `size` is odd and the + // comparison returns Greater: the mid element still gets included + // by `size` even though it's known to be larger than the element + // being searched for. + // + // This is fine though: we gain more performance by keeping the + // loop iteration count invariant (and thus predictable) than we + // lose from considering one additional element. + size -= half; } - // SAFETY: directly true from the overall invariant. - // Note that this is `<=`, unlike the assume in the `Ok` path. - unsafe { hint::assert_unchecked(left <= self.len()) }; - Err(left) + // SAFETY: `base` is 0 (the slice is non-empty) or a previous in-bounds `mid`. + let cmp = f(unsafe { self.get_unchecked(base) }); + if cmp == Equal { + // SAFETY: same as the `get_unchecked` above. + unsafe { hint::assert_unchecked(base < self.len()) }; + Ok(base) + } else { + let result = base + (cmp == Less) as usize; + // SAFETY: same as the `get_unchecked` above. + // Note that this is `<=`, unlike the assume in the `Ok` path. + unsafe { hint::assert_unchecked(result <= self.len()) }; + Err(result) + } } /// Binary searches this slice with a key extraction function. 
diff --git a/library/core/tests/slice.rs b/library/core/tests/slice.rs index 4cbbabb672ba0..9f526a85e69d3 100644 --- a/library/core/tests/slice.rs +++ b/library/core/tests/slice.rs @@ -69,13 +69,13 @@ fn test_binary_search() { assert_eq!(b.binary_search(&8), Err(5)); let b = [(); usize::MAX]; - assert_eq!(b.binary_search(&()), Ok(usize::MAX / 2)); + assert_eq!(b.binary_search(&()), Ok(usize::MAX - 1)); } #[test] fn test_binary_search_by_overflow() { let b = [(); usize::MAX]; - assert_eq!(b.binary_search_by(|_| Ordering::Equal), Ok(usize::MAX / 2)); + assert_eq!(b.binary_search_by(|_| Ordering::Equal), Ok(usize::MAX - 1)); assert_eq!(b.binary_search_by(|_| Ordering::Greater), Err(0)); assert_eq!(b.binary_search_by(|_| Ordering::Less), Err(usize::MAX)); } @@ -87,13 +87,13 @@ fn test_binary_search_implementation_details() { let b = [1, 1, 2, 2, 3, 3, 3]; assert_eq!(b.binary_search(&1), Ok(1)); assert_eq!(b.binary_search(&2), Ok(3)); - assert_eq!(b.binary_search(&3), Ok(5)); + assert_eq!(b.binary_search(&3), Ok(6)); let b = [1, 1, 1, 1, 1, 3, 3, 3, 3]; assert_eq!(b.binary_search(&1), Ok(4)); - assert_eq!(b.binary_search(&3), Ok(7)); + assert_eq!(b.binary_search(&3), Ok(8)); let b = [1, 1, 1, 1, 3, 3, 3, 3, 3]; - assert_eq!(b.binary_search(&1), Ok(2)); - assert_eq!(b.binary_search(&3), Ok(4)); + assert_eq!(b.binary_search(&1), Ok(3)); + assert_eq!(b.binary_search(&3), Ok(8)); } #[test]