Skip to content

Commit

Permalink
BTreeSet symmetric_difference & union optimized, cleaned
Browse files Browse the repository at this point in the history
  • Loading branch information
ssomers committed Oct 17, 2019
1 parent 3da6836 commit 5697432
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 121 deletions.
239 changes: 119 additions & 120 deletions src/liballoc/collections/btree/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// to TreeMap

use core::borrow::Borrow;
use core::cmp::Ordering::{self, Less, Greater, Equal};
use core::cmp::Ordering::{Less, Greater, Equal};
use core::cmp::{max, min};
use core::fmt::{self, Debug};
use core::iter::{Peekable, FromIterator, FusedIterator};
Expand Down Expand Up @@ -109,6 +109,77 @@ pub struct Range<'a, T: 'a> {
iter: btree_map::Range<'a, T, ()>,
}

/// Core of SymmetricDifference and Union.
/// More efficient than btree.map.MergeIter,
/// and crucially for SymmetricDifference, nexts() reports on both sides.
///
/// Merges two iterators by pulling one item from each side and holding
/// back ("peeking") whichever item compares greater until a later call.
#[derive(Clone)]
struct MergeIterInner<I>
where
    I: Iterator,
    I::Item: Copy,
{
    /// Left-hand source iterator.
    a: I,
    /// Right-hand source iterator.
    b: I,
    /// Item already taken out of `a` or `b` but not yet handed out by `nexts`.
    peeked: Option<MergeIterPeeked<I>>,
}

/// Records which source iterator a held-back item was taken from.
#[derive(Copy, Clone, Debug)]
enum MergeIterPeeked<I: Iterator> {
    /// The item came from `MergeIterInner::a`.
    A(I::Item),
    /// The item came from `MergeIterInner::b`.
    B(I::Item),
}

impl<I> MergeIterInner<I>
where
    I: ExactSizeIterator + FusedIterator,
    I::Item: Copy + Ord,
{
    /// Creates a merge over `a` and `b` with nothing held back yet.
    fn new(a: I, b: I) -> Self {
        MergeIterInner { a, b, peeked: None }
    }

    /// Returns the next pair of items, one slot per source.
    ///
    /// * `(Some(x), None)` - `x` comes from `a` only (it is smaller than the
    ///   next item of `b`, or `b` has run out).
    /// * `(None, Some(y))` - the mirror image for `b`.
    /// * `(Some(x), Some(y))` - both sources yielded items that compare equal.
    /// * `(None, None)` - both sources are exhausted.
    fn nexts(&mut self) -> (Option<I::Item>, Option<I::Item>) {
        // Re-use the item held back by the previous call, if any, and only
        // advance the source(s) that did not supply it.
        let (mut a_next, mut b_next) = match self.peeked.take() {
            Some(MergeIterPeeked::A(held)) => (Some(held), self.b.next()),
            Some(MergeIterPeeked::B(held)) => (self.a.next(), Some(held)),
            None => {
                let from_a = self.a.next();
                let from_b = self.b.next();
                (from_a, from_b)
            }
        };
        // When both sides produced an item, hand out only the smaller one and
        // hold the greater one back for the next call. If either side is
        // `None` there is nothing to hold back.
        if let (Some(a1), Some(b1)) = (a_next, b_next) {
            match a1.cmp(&b1) {
                Less => self.peeked = b_next.take().map(MergeIterPeeked::B),
                Greater => self.peeked = a_next.take().map(MergeIterPeeked::A),
                Equal => {}
            }
        }
        (a_next, b_next)
    }

    /// Returns the number of items still to be reported for each source,
    /// counting a held-back item towards the source it was taken from.
    fn lens(&self) -> (usize, usize) {
        match self.peeked {
            Some(MergeIterPeeked::A(_)) => (1 + self.a.len(), self.b.len()),
            Some(MergeIterPeeked::B(_)) => (self.a.len(), 1 + self.b.len()),
            None => (self.a.len(), self.b.len()),
        }
    }
}

impl<I> Debug for MergeIterInner<I>
where
    I: Iterator + Debug,
    I::Item: Copy + Debug,
{
    /// Formats both source iterators followed by the held-back item, so that
    /// an element already removed from `a` or `b` still shows up in the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("MergeIterInner")
            .field(&self.a)
            .field(&self.b)
            .field(&self.peeked)
            .finish()
    }
}

/// A lazy iterator producing elements in the difference of `BTreeSet`s.
///
/// This `struct` is created by the [`difference`] method on [`BTreeSet`].
Expand All @@ -120,6 +191,7 @@ pub struct Range<'a, T: 'a> {
pub struct Difference<'a, T: 'a> {
inner: DifferenceInner<'a, T>,
}
#[derive(Debug)]
enum DifferenceInner<'a, T: 'a> {
Stitch {
// iterate all of self and some of other, spotting matches along the way
Expand All @@ -137,21 +209,7 @@ enum DifferenceInner<'a, T: 'a> {
#[stable(feature = "collection_debug", since = "1.17.0")]
impl<T: fmt::Debug> fmt::Debug for Difference<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.inner {
DifferenceInner::Stitch {
self_iter,
other_iter,
} => f
.debug_tuple("Difference")
.field(&self_iter)
.field(&other_iter)
.finish(),
DifferenceInner::Search {
self_iter,
other_set: _,
} => f.debug_tuple("Difference").field(&self_iter).finish(),
DifferenceInner::Iterate(iter) => f.debug_tuple("Difference").field(&iter).finish(),
}
f.debug_tuple("Difference").field(&self.inner).finish()
}
}

Expand All @@ -163,18 +221,12 @@ impl<T: fmt::Debug> fmt::Debug for Difference<'_, T> {
/// [`BTreeSet`]: struct.BTreeSet.html
/// [`symmetric_difference`]: struct.BTreeSet.html#method.symmetric_difference
#[stable(feature = "rust1", since = "1.0.0")]
pub struct SymmetricDifference<'a, T: 'a> {
a: Peekable<Iter<'a, T>>,
b: Peekable<Iter<'a, T>>,
}
pub struct SymmetricDifference<'a, T: 'a>(MergeIterInner<Iter<'a, T>>);

#[stable(feature = "collection_debug", since = "1.17.0")]
impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("SymmetricDifference")
.field(&self.a)
.field(&self.b)
.finish()
f.debug_tuple("SymmetricDifference").field(&self.0).finish()
}
}

Expand All @@ -189,6 +241,7 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
pub struct Intersection<'a, T: 'a> {
inner: IntersectionInner<'a, T>,
}
#[derive(Debug)]
enum IntersectionInner<'a, T: 'a> {
Stitch {
// iterate similarly sized sets jointly, spotting matches along the way
Expand All @@ -206,23 +259,7 @@ enum IntersectionInner<'a, T: 'a> {
#[stable(feature = "collection_debug", since = "1.17.0")]
impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.inner {
IntersectionInner::Stitch {
a,
b,
} => f
.debug_tuple("Intersection")
.field(&a)
.field(&b)
.finish(),
IntersectionInner::Search {
small_iter,
large_set: _,
} => f.debug_tuple("Intersection").field(&small_iter).finish(),
IntersectionInner::Answer(answer) => {
f.debug_tuple("Intersection").field(&answer).finish()
}
}
f.debug_tuple("Intersection").field(&self.inner).finish()
}
}

Expand All @@ -234,18 +271,12 @@ impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> {
/// [`BTreeSet`]: struct.BTreeSet.html
/// [`union`]: struct.BTreeSet.html#method.union
#[stable(feature = "rust1", since = "1.0.0")]
pub struct Union<'a, T: 'a> {
a: Peekable<Iter<'a, T>>,
b: Peekable<Iter<'a, T>>,
}
pub struct Union<'a, T: 'a>(MergeIterInner<Iter<'a, T>>);

#[stable(feature = "collection_debug", since = "1.17.0")]
impl<T: fmt::Debug> fmt::Debug for Union<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("Union")
.field(&self.a)
.field(&self.b)
.finish()
f.debug_tuple("Union").field(&self.0).finish()
}
}

Expand Down Expand Up @@ -355,19 +386,16 @@ impl<T: Ord> BTreeSet<T> {
self_iter.next_back();
DifferenceInner::Iterate(self_iter)
}
_ => {
if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
DifferenceInner::Search {
self_iter: self.iter(),
other_set: other,
}
} else {
DifferenceInner::Stitch {
self_iter: self.iter(),
other_iter: other.iter().peekable(),
}
_ if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
DifferenceInner::Search {
self_iter: self.iter(),
other_set: other,
}
}
_ => DifferenceInner::Stitch {
self_iter: self.iter(),
other_iter: other.iter().peekable(),
},
},
}
}
Expand Down Expand Up @@ -396,10 +424,7 @@ impl<T: Ord> BTreeSet<T> {
pub fn symmetric_difference<'a>(&'a self,
other: &'a BTreeSet<T>)
-> SymmetricDifference<'a, T> {
SymmetricDifference {
a: self.iter().peekable(),
b: other.iter().peekable(),
}
SymmetricDifference(MergeIterInner::new(self.iter(), other.iter()))
}

/// Visits the values representing the intersection,
Expand Down Expand Up @@ -447,24 +472,22 @@ impl<T: Ord> BTreeSet<T> {
(Greater, _) | (_, Less) => IntersectionInner::Answer(None),
(Equal, _) => IntersectionInner::Answer(Some(self_min)),
(_, Equal) => IntersectionInner::Answer(Some(self_max)),
_ => {
if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
IntersectionInner::Search {
small_iter: self.iter(),
large_set: other,
}
} else if other.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
IntersectionInner::Search {
small_iter: other.iter(),
large_set: self,
}
} else {
IntersectionInner::Stitch {
a: self.iter(),
b: other.iter(),
}
_ if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
IntersectionInner::Search {
small_iter: self.iter(),
large_set: other,
}
}
_ if other.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
IntersectionInner::Search {
small_iter: other.iter(),
large_set: self,
}
}
_ => IntersectionInner::Stitch {
a: self.iter(),
b: other.iter(),
},
},
}
}
Expand All @@ -489,10 +512,7 @@ impl<T: Ord> BTreeSet<T> {
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn union<'a>(&'a self, other: &'a BTreeSet<T>) -> Union<'a, T> {
Union {
a: self.iter().peekable(),
b: other.iter().peekable(),
}
Union(MergeIterInner::new(self.iter(), other.iter()))
}

/// Clears the set, removing all values.
Expand Down Expand Up @@ -1166,15 +1186,6 @@ impl<'a, T> DoubleEndedIterator for Range<'a, T> {
#[stable(feature = "fused", since = "1.26.0")]
impl<T> FusedIterator for Range<'_, T> {}

/// Orders two optional values, treating a missing side as a fixed outcome.
///
/// Returns `short` when `x` is `None` (checked first, so it also wins when
/// both are `None`), `long` when only `y` is `None`, and otherwise the
/// ordinary comparison of the two referenced values.
fn cmp_opt<T: Ord>(x: Option<&T>, y: Option<&T>, short: Ordering, long: Ordering) -> Ordering {
    let lhs = match x {
        Some(value) => value,
        None => return short,
    };
    match y {
        Some(rhs) => lhs.cmp(rhs),
        None => long,
    }
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<T> Clone for Difference<'_, T> {
fn clone(&self) -> Self {
Expand Down Expand Up @@ -1261,10 +1272,7 @@ impl<T: Ord> FusedIterator for Difference<'_, T> {}
#[stable(feature = "rust1", since = "1.0.0")]
impl<T> Clone for SymmetricDifference<'_, T> {
fn clone(&self) -> Self {
SymmetricDifference {
a: self.a.clone(),
b: self.b.clone(),
}
SymmetricDifference(self.0.clone())
}
}
#[stable(feature = "rust1", since = "1.0.0")]
Expand All @@ -1273,19 +1281,19 @@ impl<'a, T: Ord> Iterator for SymmetricDifference<'a, T> {

fn next(&mut self) -> Option<&'a T> {
loop {
match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
Less => return self.a.next(),
Equal => {
self.a.next();
self.b.next();
}
Greater => return self.b.next(),
let (a_next, b_next) = self.0.nexts();
if a_next.and(b_next).is_none() {
return a_next.or(b_next);
}
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
(0, Some(self.a.len() + self.b.len()))
let (a_len, b_len) = self.0.lens();
// No checked_add, because even if a and b refer to the same set,
// and T is an empty type, the storage overhead of sets limits
// the number of elements to less than half the range of usize.
(0, Some(a_len + b_len))
}
}

Expand All @@ -1311,7 +1319,7 @@ impl<T> Clone for Intersection<'_, T> {
small_iter: small_iter.clone(),
large_set,
},
IntersectionInner::Answer(answer) => IntersectionInner::Answer(answer.clone()),
IntersectionInner::Answer(answer) => IntersectionInner::Answer(*answer),
},
}
}
Expand Down Expand Up @@ -1365,30 +1373,21 @@ impl<T: Ord> FusedIterator for Intersection<'_, T> {}
#[stable(feature = "rust1", since = "1.0.0")]
impl<T> Clone for Union<'_, T> {
fn clone(&self) -> Self {
Union {
a: self.a.clone(),
b: self.b.clone(),
}
Union(self.0.clone())
}
}
#[stable(feature = "rust1", since = "1.0.0")]
impl<'a, T: Ord> Iterator for Union<'a, T> {
type Item = &'a T;

fn next(&mut self) -> Option<&'a T> {
match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
Less => self.a.next(),
Equal => {
self.b.next();
self.a.next()
}
Greater => self.b.next(),
}
let (a_next, b_next) = self.0.nexts();
a_next.or(b_next)
}

fn size_hint(&self) -> (usize, Option<usize>) {
let a_len = self.a.len();
let b_len = self.b.len();
let (a_len, b_len) = self.0.lens();
// No checked_add - see SymmetricDifference::size_hint.
(max(a_len, b_len), Some(a_len + b_len))
}
}
Expand Down
Loading

0 comments on commit 5697432

Please sign in to comment.