Auto merge of #81090 - ssomers:btree_drainy_refactor_2, r=Mark-Simulacrum

BTreeMap: offer merge in variants with more clarity

r? `@Mark-Simulacrum`
diff --git a/library/alloc/src/collections/btree/node.rs b/library/alloc/src/collections/btree/node.rs
index 928c6f9..097e3e6 100644
--- a/library/alloc/src/collections/btree/node.rs
+++ b/library/alloc/src/collections/btree/node.rs
@@ -1282,25 +1282,25 @@
         self.right_child
     }
 
-    /// Returns `true` if it is valid to call `.merge()` in the balancing context,
-    /// i.e., whether there is enough room in a node to hold the combination of
-    /// both adjacent child nodes, along with the key-value pair in the parent.
+    /// Returns whether merging is possible, i.e., whether there is enough room
+    /// in a node to combine the central KV with both adjacent child nodes.
     pub fn can_merge(&self) -> bool {
         self.left_child.len() + 1 + self.right_child.len() <= CAPACITY
     }
 }
 
 impl<'a, K: 'a, V: 'a> BalancingContext<'a, K, V> {
-    /// Merges the parent's key-value pair and both adjacent child nodes into
-    /// the left node and returns an edge handle in that expanded left node.
-    /// If `track_edge_idx` is given some value, the returned edge corresponds
-    /// to where the edge in that child node ended up,
-    ///
-    /// Panics unless we `.can_merge()`.
-    pub fn merge(
+    /// Performs a merge and lets a closure decide what to return.
+    fn do_merge<
+        F: FnOnce(
+            NodeRef<marker::Mut<'a>, K, V, marker::Internal>,
+            NodeRef<marker::Mut<'a>, K, V, marker::LeafOrInternal>,
+        ) -> R,
+        R,
+    >(
         self,
-        track_edge_idx: Option<LeftOrRight<usize>>,
-    ) -> Handle<NodeRef<marker::Mut<'a>, K, V, marker::LeafOrInternal>, marker::Edge> {
+        result: F,
+    ) -> R {
         let Handle { node: mut parent_node, idx: parent_idx, _marker } = self.parent;
         let old_parent_len = parent_node.len();
         let mut left_node = self.left_child;
@@ -1310,11 +1310,6 @@
         let new_left_len = old_left_len + 1 + right_len;
 
         assert!(new_left_len <= CAPACITY);
-        assert!(match track_edge_idx {
-            None => true,
-            Some(LeftOrRight::Left(idx)) => idx <= old_left_len,
-            Some(LeftOrRight::Right(idx)) => idx <= right_len,
-        });
 
         unsafe {
             *left_node.len_mut() = new_left_len as u16;
@@ -1353,14 +1348,47 @@
             } else {
                 Global.deallocate(right_node.node.cast(), Layout::new::<LeafNode<K, V>>());
             }
-
-            let new_idx = match track_edge_idx {
-                None => 0,
-                Some(LeftOrRight::Left(idx)) => idx,
-                Some(LeftOrRight::Right(idx)) => old_left_len + 1 + idx,
-            };
-            Handle::new_edge(left_node, new_idx)
         }
+        result(parent_node, left_node)
+    }
+
+    /// Merges the parent's key-value pair and both adjacent child nodes into
+    /// the left child node and returns the shrunk parent node.
+    ///
+    /// Panics unless we `.can_merge()`.
+    pub fn merge_tracking_parent(self) -> NodeRef<marker::Mut<'a>, K, V, marker::Internal> {
+        self.do_merge(|parent, _child| parent)
+    }
+
+    /// Merges the parent's key-value pair and both adjacent child nodes into
+    /// the left child node and returns that child node.
+    ///
+    /// Panics unless we `.can_merge()`.
+    pub fn merge_tracking_child(self) -> NodeRef<marker::Mut<'a>, K, V, marker::LeafOrInternal> {
+        self.do_merge(|_parent, child| child)
+    }
+
+    /// Merges the parent's key-value pair and both adjacent child nodes into
+    /// the left child node and returns the edge handle in that child node
+    /// where the tracked child edge ended up.
+    ///
+    /// Panics unless we `.can_merge()`.
+    pub fn merge_tracking_child_edge(
+        self,
+        track_edge_idx: LeftOrRight<usize>,
+    ) -> Handle<NodeRef<marker::Mut<'a>, K, V, marker::LeafOrInternal>, marker::Edge> {
+        let old_left_len = self.left_child.len();
+        let right_len = self.right_child.len();
+        assert!(match track_edge_idx {
+            LeftOrRight::Left(idx) => idx <= old_left_len,
+            LeftOrRight::Right(idx) => idx <= right_len,
+        });
+        let child = self.merge_tracking_child();
+        let new_idx = match track_edge_idx {
+            LeftOrRight::Left(idx) => idx,
+            LeftOrRight::Right(idx) => old_left_len + 1 + idx,
+        };
+        unsafe { Handle::new_edge(child, new_idx) }
     }
 
     /// Removes a key-value pair from the left child and places it in the key-value storage
diff --git a/library/alloc/src/collections/btree/remove.rs b/library/alloc/src/collections/btree/remove.rs
index 7aeb39c..04683e0 100644
--- a/library/alloc/src/collections/btree/remove.rs
+++ b/library/alloc/src/collections/btree/remove.rs
@@ -33,7 +33,7 @@
                 Ok(Left(left_parent_kv)) => {
                     debug_assert!(left_parent_kv.right_child_len() == MIN_LEN - 1);
                     if left_parent_kv.can_merge() {
-                        left_parent_kv.merge(Some(Right(idx)))
+                        left_parent_kv.merge_tracking_child_edge(Right(idx))
                     } else {
                         debug_assert!(left_parent_kv.left_child_len() > MIN_LEN);
                         left_parent_kv.steal_left(idx)
@@ -42,7 +42,7 @@
                 Ok(Right(right_parent_kv)) => {
                     debug_assert!(right_parent_kv.left_child_len() == MIN_LEN - 1);
                     if right_parent_kv.can_merge() {
-                        right_parent_kv.merge(Some(Left(idx)))
+                        right_parent_kv.merge_tracking_child_edge(Left(idx))
                     } else {
                         debug_assert!(right_parent_kv.right_child_len() > MIN_LEN);
                         right_parent_kv.steal_right(idx)
@@ -124,9 +124,8 @@
             Ok(Left(left_parent_kv)) => {
                 debug_assert_eq!(left_parent_kv.right_child_len(), MIN_LEN - 1);
                 if left_parent_kv.can_merge() {
-                    let pos = left_parent_kv.merge(None);
-                    let parent_edge = unsafe { unwrap_unchecked(pos.into_node().ascend().ok()) };
-                    Some(parent_edge.into_node())
+                    let parent = left_parent_kv.merge_tracking_parent();
+                    Some(parent)
                 } else {
                     debug_assert!(left_parent_kv.left_child_len() > MIN_LEN);
                     left_parent_kv.steal_left(0);
@@ -136,9 +135,8 @@
             Ok(Right(right_parent_kv)) => {
                 debug_assert_eq!(right_parent_kv.left_child_len(), MIN_LEN - 1);
                 if right_parent_kv.can_merge() {
-                    let pos = right_parent_kv.merge(None);
-                    let parent_edge = unsafe { unwrap_unchecked(pos.into_node().ascend().ok()) };
-                    Some(parent_edge.into_node())
+                    let parent = right_parent_kv.merge_tracking_parent();
+                    Some(parent)
                 } else {
                     debug_assert!(right_parent_kv.right_child_len() > MIN_LEN);
                     right_parent_kv.steal_right(0);
diff --git a/library/alloc/src/collections/btree/split.rs b/library/alloc/src/collections/btree/split.rs
index 4561c8e..375c617 100644
--- a/library/alloc/src/collections/btree/split.rs
+++ b/library/alloc/src/collections/btree/split.rs
@@ -66,7 +66,7 @@
                 let mut last_kv = node.last_kv().consider_for_balancing();
 
                 if last_kv.can_merge() {
-                    cur_node = last_kv.merge(None).into_node();
+                    cur_node = last_kv.merge_tracking_child();
                 } else {
                     let right_len = last_kv.right_child_len();
                     // `MIN_LEN + 1` to avoid readjust if merge happens on the next level.
@@ -93,7 +93,7 @@
                 let mut first_kv = node.first_kv().consider_for_balancing();
 
                 if first_kv.can_merge() {
-                    cur_node = first_kv.merge(None).into_node();
+                    cur_node = first_kv.merge_tracking_child();
                 } else {
                     let left_len = first_kv.left_child_len();
                     // `MIN_LEN + 1` to avoid readjust if merge happens on the next level.