internal/btree: add SetWithIndex

Change-Id: I936f99b0fa1ac19dec938c95db0d5853540cc904
Reviewed-on: https://code-review.googlesource.com/26010
Reviewed-by: Jean de Klerk <deklerk@google.com>
diff --git a/internal/btree/README.md b/internal/btree/README.md
index d323715..601ff54 100644
--- a/internal/btree/README.md
+++ b/internal/btree/README.md
@@ -1,5 +1,5 @@
 This package is a fork of github.com/jba/btree at commit
-aa53f88384b4d43de7e047ebe8d2c0fbb84fce89, which itself was a fork of
+d4edd57f39b8425fc2c631047ff4dc6024d82a4f, which itself was a fork of
 github.com/google/btree at 316fb6d3f031ae8f4d457c6c5186b9e3ded70435.
 
 This directory makes the following modifications:
diff --git a/internal/btree/btree.go b/internal/btree/btree.go
index 4c81e6f..7dfd78e 100644
--- a/internal/btree/btree.go
+++ b/internal/btree/btree.go
@@ -1,4 +1,5 @@
-// Copyright 2014 Google Inc. All Rights Reserved.
+// Copyright 2014 Google Inc.
+// Modified 2018 by Jonathan Amsterdam (jbamsterdam@gmail.com)
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -273,17 +274,22 @@
 // insert inserts an item into the subtree rooted at this node, making sure
 // no nodes in the subtree exceed maxItems items.  Should an equivalent item be
 // be found/replaced by insert, its value will be returned.
-func (n *node) insert(m item, maxItems int, less lessFunc) (old Value, present bool) {
+//
+// If computeIndex is true, the third return value is the index of the value with respect to n.
+func (n *node) insert(m item, maxItems int, less lessFunc, computeIndex bool) (old Value, present bool, idx int) {
 	i, found := n.items.find(m.key, less)
 	if found {
 		out := n.items[i]
 		n.items[i] = m
-		return out.value, true
+		if computeIndex {
+			idx = n.itemIndex(i)
+		}
+		return out.value, true, idx
 	}
 	if len(n.children) == 0 {
 		n.items.insertAt(i, m)
 		n.size++
-		return old, false
+		return old, false, i
 	}
 	if n.maybeSplitChild(i, maxItems) {
 		inTree := n.items[i]
@@ -295,31 +301,33 @@
 		default:
 			out := n.items[i]
 			n.items[i] = m
-			return out.value, true
+			if computeIndex {
+				idx = n.itemIndex(i)
+			}
+			return out.value, true, idx
 		}
 	}
-	old, present = n.mutableChild(i).insert(m, maxItems, less)
+	old, present, idx = n.mutableChild(i).insert(m, maxItems, less, computeIndex)
 	if !present {
 		n.size++
 	}
-	return old, present
+	if computeIndex {
+		idx += n.partialSize(i)
+	}
+	return old, present, idx
 }
 
 // get finds the given key in the subtree and returns the corresponding item, along with a boolean reporting
 // whether it was found.
-// If withIndex is true, it also returns the index of the key relative to the node's subtree.
-func (n *node) get(k Key, withIndex bool, less lessFunc) (item, bool, int) {
+// If computeIndex is true, it also returns the index of the key relative to the node's subtree.
+func (n *node) get(k Key, computeIndex bool, less lessFunc) (item, bool, int) {
 	i, found := n.items.find(k, less)
 	if found {
-		idx := i
-		if withIndex && len(n.children) > 0 {
-			idx = n.partialSize(i+1) - 1
-		}
-		return n.items[i], true, idx
+		return n.items[i], true, n.itemIndex(i)
 	}
 	if len(n.children) > 0 {
-		m, found, idx := n.children[i].get(k, withIndex, less)
-		if withIndex && found {
+		m, found, idx := n.children[i].get(k, computeIndex, less)
+		if computeIndex && found {
 			idx += n.partialSize(i)
 		}
 		return m, found, idx
@@ -327,6 +335,16 @@
 	return item{}, false, -1
 }
 
+// itemIndex returns the index w.r.t. n of the ith item in n.
+func (n *node) itemIndex(i int) int {
+	if len(n.children) == 0 {
+		return i
+	}
+	// Get the size of the node up to but not including the child to the right of
+	// item i. Subtract 1 because the index is 0-based.
+	return n.partialSize(i+1) - 1
+}
+
 // Returns the size of the non-leaf node up to but not including child i.
 func (n *node) partialSize(i int) int {
 	var sz int
@@ -617,11 +635,20 @@
 // return value of true. If the key is not in the tree, it is added, and the second
 // return value is false.
 func (t *BTree) Set(k Key, v Value) (old Value, present bool) {
+	old, present, _ = t.set(k, v, false)
+	return old, present
+}
+
+func (t *BTree) SetWithIndex(k Key, v Value) (old Value, present bool, index int) {
+	return t.set(k, v, true)
+}
+
+func (t *BTree) set(k Key, v Value, computeIndex bool) (old Value, present bool, idx int) {
 	if t.root == nil {
 		t.root = t.cow.newNode()
 		t.root.items = append(t.root.items, item{k, v})
 		t.root.size = 1
-		return old, false
+		return old, false, 0
 	}
 	t.root = t.root.mutableFor(t.cow)
 	if len(t.root.items) >= t.maxItems() {
@@ -634,7 +661,7 @@
 		t.root.size = sz
 	}
 
-	return t.root.insert(item{k, v}, t.maxItems(), t.less)
+	return t.root.insert(item{k, v}, t.maxItems(), t.less, computeIndex)
 }
 
 // Delete removes the item with the given key, returning its value. The second return value
diff --git a/internal/btree/btree_test.go b/internal/btree/btree_test.go
index 0381edc..0a12104 100644
--- a/internal/btree/btree_test.go
+++ b/internal/btree/btree_test.go
@@ -1,4 +1,5 @@
-// Copyright 2014 Google Inc. All Rights Reserved.
+// Copyright 2014 Google Inc.
+// Modified 2018 by Jonathan Amsterdam (jbamsterdam@gmail.com)
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -19,6 +20,7 @@
 	"fmt"
 	"math/rand"
 	"os"
+	"sort"
 	"sync"
 	"testing"
 	"time"
@@ -98,9 +100,13 @@
 			}
 		}
 		for _, m := range perm(treeSize) {
-			if _, ok := tr.Set(m.Key, m.Value); !ok {
+			_, ok, idx := tr.SetWithIndex(m.Key, m.Value)
+			if !ok {
 				t.Fatal("set didn't find item", m)
 			}
+			if idx != m.Index {
+				t.Fatalf("got index %d, want %d", idx, m.Index)
+			}
 		}
 		mink, minv := tr.Min()
 		if want := 0; mink != want || minv != want {
@@ -159,6 +165,26 @@
 	}
 }
 
+func TestSetWithIndex(t *testing.T) {
+	tr := New(4, less) // use a small degree to cover more cases
+	var contents []int
+	for _, m := range perm(100) {
+		_, _, idx := tr.SetWithIndex(m.Key, m.Value)
+		contents = append(contents, m.Index)
+		sort.Ints(contents)
+		want := -1
+		for i, c := range contents {
+			if c == m.Index {
+				want = i
+				break
+			}
+		}
+		if idx != want {
+			t.Fatalf("got %d, want %d", idx, want)
+		}
+	}
+}
+
 func TestDeleteMin(t *testing.T) {
 	tr := New(3, less)
 	for _, m := range perm(100) {