Improve BTreeSet::Intersection::size_hint
The commented invariant that an iterator is smaller than other iterator
was violated after next is called and two iterators are consumed at
different rates.
diff --git a/src/liballoc/collections/btree/set.rs b/src/liballoc/collections/btree/set.rs
index d3af910..0cb91ba 100644
--- a/src/liballoc/collections/btree/set.rs
+++ b/src/liballoc/collections/btree/set.rs
@@ -3,7 +3,7 @@
use core::borrow::Borrow;
use core::cmp::Ordering::{self, Less, Greater, Equal};
-use core::cmp::max;
+use core::cmp::{max, min};
use core::fmt::{self, Debug};
use core::iter::{Peekable, FromIterator, FusedIterator};
use core::ops::{BitOr, BitAnd, BitXor, Sub, RangeBounds};
@@ -187,8 +187,8 @@
}
enum IntersectionInner<'a, T: 'a> {
Stitch {
- small_iter: Iter<'a, T>, // for size_hint, should be the smaller of the sets
- other_iter: Iter<'a, T>,
+ a: Iter<'a, T>,
+ b: Iter<'a, T>,
},
Search {
small_iter: Iter<'a, T>,
@@ -201,12 +201,12 @@
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.inner {
IntersectionInner::Stitch {
- small_iter,
- other_iter,
+ a,
+ b,
} => f
.debug_tuple("Intersection")
- .field(&small_iter)
- .field(&other_iter)
+ .field(&a)
+ .field(&b)
.finish(),
IntersectionInner::Search {
small_iter,
@@ -397,8 +397,8 @@
// Iterate both sets jointly, spotting matches along the way.
Intersection {
inner: IntersectionInner::Stitch {
- small_iter: small.iter(),
- other_iter: other.iter(),
+ a: small.iter(),
+ b: other.iter(),
},
}
} else {
@@ -1221,11 +1221,11 @@
Intersection {
inner: match &self.inner {
IntersectionInner::Stitch {
- small_iter,
- other_iter,
+ a,
+ b,
} => IntersectionInner::Stitch {
- small_iter: small_iter.clone(),
- other_iter: other_iter.clone(),
+ a: a.clone(),
+ b: b.clone(),
},
IntersectionInner::Search {
small_iter,
@@ -1245,16 +1245,16 @@
fn next(&mut self) -> Option<&'a T> {
match &mut self.inner {
IntersectionInner::Stitch {
- small_iter,
- other_iter,
+ a,
+ b,
} => {
- let mut small_next = small_iter.next()?;
- let mut other_next = other_iter.next()?;
+ let mut a_next = a.next()?;
+ let mut b_next = b.next()?;
loop {
- match Ord::cmp(small_next, other_next) {
- Less => small_next = small_iter.next()?,
- Greater => other_next = other_iter.next()?,
- Equal => return Some(small_next),
+ match Ord::cmp(a_next, b_next) {
+ Less => a_next = a.next()?,
+ Greater => b_next = b.next()?,
+ Equal => return Some(a_next),
}
}
}
@@ -1272,7 +1272,7 @@
fn size_hint(&self) -> (usize, Option<usize>) {
let min_len = match &self.inner {
- IntersectionInner::Stitch { small_iter, .. } => small_iter.len(),
+ IntersectionInner::Stitch { a, b } => min(a.len(), b.len()),
IntersectionInner::Search { small_iter, .. } => small_iter.len(),
};
(0, Some(min_len))
diff --git a/src/liballoc/tests/btree/set.rs b/src/liballoc/tests/btree/set.rs
index 62ccb53..35db18c 100644
--- a/src/liballoc/tests/btree/set.rs
+++ b/src/liballoc/tests/btree/set.rs
@@ -91,6 +91,17 @@
}
#[test]
+fn test_intersection_size_hint() {
+ let x: BTreeSet<i32> = [3, 4].iter().copied().collect();
+ let y: BTreeSet<i32> = [1, 2, 3].iter().copied().collect();
+ let mut iter = x.intersection(&y);
+ assert_eq!(iter.size_hint(), (0, Some(2)));
+ assert_eq!(iter.next(), Some(&3));
+ assert_eq!(iter.size_hint(), (0, Some(0)));
+ assert_eq!(iter.next(), None);
+}
+
+#[test]
fn test_difference() {
fn check_difference(a: &[i32], b: &[i32], expected: &[i32]) {
check(a, b, expected, |x, y, f| x.difference(y).all(f))