| // Copyright ©2019 The Gonum Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package kdtree |
| |
| import ( |
| "flag" |
| "fmt" |
| "math" |
| "os" |
| "reflect" |
| "sort" |
| "strings" |
| "testing" |
| "unsafe" |
| |
| "golang.org/x/exp/rand" |
| ) |
| |
// Command-line flags controlling Graphviz DOT output for failing trees.
var (
	// genDot asks failing tests to write a DOT rendering of their tree.
	genDot = flag.Bool("dot", false, "generate dot code for failing trees")
	// dotLimit is the largest tree.Len() for which DOT output is written.
	dotLimit = flag.Int("dotmax", 100, "specify maximum size for tree output for dot format")
)
| |
| var ( |
| // Using example from WP article: https://en.wikipedia.org/w/index.php?title=K-d_tree&oldid=887573572. |
| wpData = Points{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}} |
| nbWpData = nbPoints{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}} |
| wpBound = &Bounding{Point{2, 1}, Point{9, 7}} |
| ) |
| |
| var newTests = []struct { |
| data Interface |
| bounding bool |
| wantBounds *Bounding |
| }{ |
| {data: wpData, bounding: false, wantBounds: nil}, |
| {data: nbWpData, bounding: false, wantBounds: nil}, |
| {data: wpData, bounding: true, wantBounds: wpBound}, |
| {data: nbWpData, bounding: true, wantBounds: nil}, |
| } |
| |
| func TestNew(t *testing.T) { |
| for i, test := range newTests { |
| var tree *Tree |
| var panicked bool |
| func() { |
| defer func() { |
| if r := recover(); r != nil { |
| panicked = true |
| } |
| }() |
| tree = New(test.data, test.bounding) |
| }() |
| if panicked { |
| t.Errorf("unexpected panic for test %d", i) |
| continue |
| } |
| |
| if !tree.Root.isKDTree() { |
| t.Errorf("tree %d is not k-d tree", i) |
| } |
| |
| switch data := test.data.(type) { |
| case Points: |
| for _, p := range data { |
| if !tree.Contains(p) { |
| t.Errorf("failed to find point %.3f in test %d", p, i) |
| } |
| } |
| case nbPoints: |
| for _, p := range data { |
| if !tree.Contains(p) { |
| t.Errorf("failed to find point %.3f in test %d", p, i) |
| } |
| } |
| default: |
| t.Fatalf("bad test: unknown data type: %T", test.data) |
| } |
| |
| if !reflect.DeepEqual(tree.Root.Bounding, test.wantBounds) { |
| t.Errorf("unexpected bounding box for test %d with data type %T: got:%v want:%v", |
| i, test.data, tree.Root.Bounding, test.wantBounds) |
| } |
| |
| if t.Failed() && *genDot && tree.Len() <= *dotLimit { |
| err := dotFile(tree, fmt.Sprintf("TestNew%T", test.data), "") |
| if err != nil { |
| t.Fatalf("failed to write DOT file: %v", err) |
| } |
| } |
| } |
| } |
| |
| var insertTests = []struct { |
| data Interface |
| insert []Comparable |
| wantBounds *Bounding |
| }{ |
| { |
| data: wpData, |
| insert: []Comparable{Point{0, 0}, Point{10, 10}}, |
| wantBounds: &Bounding{Point{0, 0}, Point{10, 10}}, |
| }, |
| { |
| data: nbWpData, |
| insert: []Comparable{nbPoint{0, 0}, nbPoint{10, 10}}, |
| wantBounds: nil, |
| }, |
| } |
| |
| func TestInsert(t *testing.T) { |
| for i, test := range insertTests { |
| tree := New(test.data, true) |
| for _, v := range test.insert { |
| tree.Insert(v, true) |
| } |
| |
| if !tree.Root.isKDTree() { |
| t.Errorf("tree %d is not k-d tree", i) |
| } |
| |
| if !reflect.DeepEqual(tree.Root.Bounding, test.wantBounds) { |
| t.Errorf("unexpected bounding box for test %d with data type %T: got:%v want:%v", |
| i, test.data, tree.Root.Bounding, test.wantBounds) |
| } |
| |
| if t.Failed() && *genDot && tree.Len() <= *dotLimit { |
| err := dotFile(tree, fmt.Sprintf("TestInsert%T", test.data), "") |
| if err != nil { |
| t.Fatalf("failed to write DOT file: %v", err) |
| } |
| } |
| } |
| } |
| |
// compFn classifies the float64 result of a Compare call; isPartitioned
// treats a true result as a partitioning violation.
type compFn func(float64) bool

// left reports whether v is less than or equal to zero (false for NaN).
func left(v float64) bool { return 0 >= v }

// right is the complement of left, so it is true for positive v and for NaN.
func right(v float64) bool { return !left(v) }
| |
| func (n *Node) isKDTree() bool { |
| if n == nil { |
| return true |
| } |
| d := n.Point.Dims() |
| // Together these define the property of minimal orthogonal bounding. |
| if !(n.isContainedBy(n.Bounding) && n.Bounding.planesHaveCoincidentPointsIn(n, [2][]bool{make([]bool, d), make([]bool, d)})) { |
| return false |
| } |
| if !n.Left.isPartitioned(n.Point, left, n.Plane) { |
| return false |
| } |
| if !n.Right.isPartitioned(n.Point, right, n.Plane) { |
| return false |
| } |
| return n.Left.isKDTree() && n.Right.isKDTree() |
| } |
| |
| func (n *Node) isPartitioned(pivot Comparable, fn compFn, plane Dim) bool { |
| if n == nil { |
| return true |
| } |
| if n.Left != nil && fn(pivot.Compare(n.Left.Point, plane)) { |
| return false |
| } |
| if n.Right != nil && fn(pivot.Compare(n.Right.Point, plane)) { |
| return false |
| } |
| return n.Left.isPartitioned(pivot, fn, plane) && n.Right.isPartitioned(pivot, fn, plane) |
| } |
| |
| func (n *Node) isContainedBy(b *Bounding) bool { |
| if n == nil { |
| return true |
| } |
| if !b.Contains(n.Point) { |
| return false |
| } |
| return n.Left.isContainedBy(b) && n.Right.isContainedBy(b) |
| } |
| |
| func (b *Bounding) planesHaveCoincidentPointsIn(n *Node, tight [2][]bool) bool { |
| if b == nil { |
| return true |
| } |
| if n == nil { |
| return true |
| } |
| |
| b.planesHaveCoincidentPointsIn(n.Left, tight) |
| b.planesHaveCoincidentPointsIn(n.Right, tight) |
| |
| var ok = true |
| for i := range tight { |
| for d := 0; d < n.Point.Dims(); d++ { |
| if c := n.Point.Compare(b.Min, Dim(d)); c == 0 { |
| tight[i][d] = true |
| } |
| ok = ok && tight[i][d] |
| } |
| } |
| return ok |
| } |
| |
| func nearest(q Point, p Points) (Point, float64) { |
| min := q.Distance(p[0]) |
| var r int |
| for i := 1; i < p.Len(); i++ { |
| d := q.Distance(p[i]) |
| if d < min { |
| min = d |
| r = i |
| } |
| } |
| return p[r], min |
| } |
| |
| func TestNearestRandom(t *testing.T) { |
| rnd := rand.New(rand.NewSource(1)) |
| |
| const ( |
| min = 0.0 |
| max = 1000.0 |
| |
| dims = 4 |
| setSize = 10000 |
| ) |
| |
| var randData Points |
| for i := 0; i < setSize; i++ { |
| p := make(Point, dims) |
| for j := 0; j < dims; j++ { |
| p[j] = (max-min)*rnd.Float64() + min |
| } |
| randData = append(randData, p) |
| } |
| tree := New(randData, false) |
| |
| for i := 0; i < setSize; i++ { |
| q := make(Point, dims) |
| for j := 0; j < dims; j++ { |
| q[j] = (max-min)*rnd.Float64() + min |
| } |
| |
| got, _ := tree.Nearest(q) |
| want, _ := nearest(q, randData) |
| if !reflect.DeepEqual(got, want) { |
| t.Fatalf("unexpected result from query %d %.3f: got:%.3f want:%.3f", i, q, got, want) |
| } |
| } |
| } |
| |
| func TestNearest(t *testing.T) { |
| tree := New(wpData, false) |
| for _, q := range append([]Point{ |
| {4, 6}, |
| {7, 5}, |
| {8, 7}, |
| {6, -5}, |
| {1e5, 1e5}, |
| {1e5, -1e5}, |
| {-1e5, 1e5}, |
| {-1e5, -1e5}, |
| {1e5, 0}, |
| {0, -1e5}, |
| {0, 1e5}, |
| {-1e5, 0}, |
| }, wpData...) { |
| gotP, gotD := tree.Nearest(q) |
| wantP, wantD := nearest(q, wpData) |
| if !reflect.DeepEqual(gotP, wantP) { |
| t.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP) |
| } |
| if gotD != wantD { |
| t.Errorf("unexpected distance for query %.3f : got:%v want:%v", q, gotD, wantD) |
| } |
| } |
| } |
| |
| func nearestN(n int, q Point, p Points) []ComparableDist { |
| nk := NewNKeeper(n) |
| for i := 0; i < p.Len(); i++ { |
| nk.Keep(ComparableDist{Comparable: p[i], Dist: q.Distance(p[i])}) |
| } |
| if len(nk.Heap) == 1 { |
| return nk.Heap |
| } |
| sort.Sort(nk) |
| for i, j := 0, len(nk.Heap)-1; i < j; i, j = i+1, j-1 { |
| nk.Heap[i], nk.Heap[j] = nk.Heap[j], nk.Heap[i] |
| } |
| return nk.Heap |
| } |
| |
| func TestNearestSetN(t *testing.T) { |
| data := append([]Point{ |
| {4, 6}, |
| {7, 5}, |
| {8, 7}, |
| {6, -5}, |
| {1e5, 1e5}, |
| {1e5, -1e5}, |
| {-1e5, 1e5}, |
| {-1e5, -1e5}, |
| {1e5, 0}, |
| {0, -1e5}, |
| {0, 1e5}, |
| {-1e5, 0}}, |
| wpData[:len(wpData)-1]...) |
| |
| tree := New(wpData, false) |
| for k := 1; k <= len(wpData); k++ { |
| for _, q := range data { |
| wantP := nearestN(k, q, wpData) |
| |
| nk := NewNKeeper(k) |
| tree.NearestSet(nk, q) |
| |
| var max float64 |
| wantD := make(map[float64]map[string]struct{}) |
| for _, p := range wantP { |
| if p.Dist > max { |
| max = p.Dist |
| } |
| d, ok := wantD[p.Dist] |
| if !ok { |
| d = make(map[string]struct{}) |
| } |
| d[fmt.Sprint(p.Comparable)] = struct{}{} |
| wantD[p.Dist] = d |
| } |
| gotD := make(map[float64]map[string]struct{}) |
| for _, p := range nk.Heap { |
| if p.Dist > max { |
| t.Errorf("unexpected distance for point %.3f: got:%v want:<=%v", p.Comparable, p.Dist, max) |
| } |
| d, ok := gotD[p.Dist] |
| if !ok { |
| d = make(map[string]struct{}) |
| } |
| d[fmt.Sprint(p.Comparable)] = struct{}{} |
| gotD[p.Dist] = d |
| } |
| |
| // If the available number of slots does not fit all the coequal furthest points |
| // we will fail the check. So remove, but check them minimally here. |
| if !reflect.DeepEqual(wantD[max], gotD[max]) { |
| // The best we can do at this stage is confirm that there are an equal number of matches at this distance. |
| if len(gotD[max]) != len(wantD[max]) { |
| t.Errorf("unexpected number of maximal distance points: got:%d want:%d", len(gotD[max]), len(wantD[max])) |
| } |
| delete(wantD, max) |
| delete(gotD, max) |
| } |
| |
| if !reflect.DeepEqual(gotD, wantD) { |
| t.Errorf("unexpected result for k=%d query %.3f: got:%v want:%v", k, q, gotD, wantD) |
| } |
| } |
| } |
| } |
| |
| var nearestSetDistTests = []Point{ |
| {4, 6}, |
| {7, 5}, |
| {8, 7}, |
| {6, -5}, |
| } |
| |
| func TestNearestSetDist(t *testing.T) { |
| tree := New(wpData, false) |
| for i, q := range nearestSetDistTests { |
| for d := 1.0; d < 100; d += 0.1 { |
| dk := NewDistKeeper(d) |
| tree.NearestSet(dk, q) |
| |
| hits := make(map[string]float64) |
| for _, p := range wpData { |
| hits[fmt.Sprint(p)] = p.Distance(q) |
| } |
| |
| for _, p := range dk.Heap { |
| var done bool |
| if p.Comparable == nil { |
| done = true |
| continue |
| } |
| delete(hits, fmt.Sprint(p.Comparable)) |
| if done { |
| t.Error("expectedly finished heap iteration") |
| break |
| } |
| dist := p.Comparable.Distance(q) |
| if dist > d { |
| t.Errorf("Test %d: query %v found %v expect %.3f <= %.3f", i, q, p, dist, d) |
| } |
| } |
| |
| for p, dist := range hits { |
| if dist <= d { |
| t.Errorf("Test %d: query %v missed %v expect %.3f > %.3f", i, q, p, dist, d) |
| } |
| } |
| } |
| } |
| } |
| |
| func TestDo(t *testing.T) { |
| tree := New(wpData, false) |
| var got Points |
| fn := func(c Comparable, _ *Bounding, _ int) (done bool) { |
| got = append(got, c.(Point)) |
| return |
| } |
| killed := tree.Do(fn) |
| if !reflect.DeepEqual(got, wpData) { |
| t.Errorf("unexpected result from tree iteration: got:%v want:%v", got, wpData) |
| } |
| if killed { |
| t.Error("tree iteration unexpectedly killed") |
| } |
| } |
| |
| var doBoundedTests = []struct { |
| bounds *Bounding |
| want Points |
| }{ |
| { |
| bounds: nil, |
| want: wpData, |
| }, |
| { |
| bounds: &Bounding{Point{0, 0}, Point{10, 10}}, |
| want: wpData, |
| }, |
| { |
| bounds: &Bounding{Point{3, 4}, Point{10, 10}}, |
| want: Points{Point{5, 4}, Point{4, 7}, Point{9, 6}}, |
| }, |
| { |
| bounds: &Bounding{Point{3, 3}, Point{10, 10}}, |
| want: Points{Point{5, 4}, Point{4, 7}, Point{9, 6}}, |
| }, |
| { |
| bounds: &Bounding{Point{0, 0}, Point{6, 5}}, |
| want: Points{Point{2, 3}, Point{5, 4}}, |
| }, |
| { |
| bounds: &Bounding{Point{5, 2}, Point{7, 4}}, |
| want: Points{Point{5, 4}, Point{7, 2}}, |
| }, |
| { |
| bounds: &Bounding{Point{2, 2}, Point{7, 4}}, |
| want: Points{Point{2, 3}, Point{5, 4}, Point{7, 2}}, |
| }, |
| { |
| bounds: &Bounding{Point{2, 3}, Point{9, 6}}, |
| want: Points{Point{2, 3}, Point{5, 4}, Point{9, 6}}, |
| }, |
| { |
| bounds: &Bounding{Point{7, 2}, Point{7, 2}}, |
| want: Points{Point{7, 2}}, |
| }, |
| } |
| |
| func TestDoBounded(t *testing.T) { |
| for _, test := range doBoundedTests { |
| tree := New(wpData, false) |
| var got Points |
| fn := func(c Comparable, _ *Bounding, _ int) (done bool) { |
| got = append(got, c.(Point)) |
| return |
| } |
| killed := tree.DoBounded(test.bounds, fn) |
| if !reflect.DeepEqual(got, test.want) { |
| t.Errorf("unexpected result from bounded tree iteration: got:%v want:%v", got, test.want) |
| } |
| if killed { |
| t.Error("tree iteration unexpectedly killed") |
| } |
| } |
| } |
| |
| func BenchmarkNew(b *testing.B) { |
| rnd := rand.New(rand.NewSource(1)) |
| p := make(Points, 1e5) |
| for i := range p { |
| p[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} |
| } |
| b.ResetTimer() |
| for i := 0; i < b.N; i++ { |
| _ = New(p, false) |
| } |
| } |
| |
| func BenchmarkNewBounds(b *testing.B) { |
| rnd := rand.New(rand.NewSource(1)) |
| p := make(Points, 1e5) |
| for i := range p { |
| p[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} |
| } |
| b.ResetTimer() |
| for i := 0; i < b.N; i++ { |
| _ = New(p, true) |
| } |
| } |
| |
| func BenchmarkInsert(b *testing.B) { |
| rnd := rand.New(rand.NewSource(1)) |
| t := &Tree{} |
| for i := 0; i < b.N; i++ { |
| t.Insert(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, false) |
| } |
| } |
| |
| func BenchmarkInsertBounds(b *testing.B) { |
| rnd := rand.New(rand.NewSource(1)) |
| t := &Tree{} |
| for i := 0; i < b.N; i++ { |
| t.Insert(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, true) |
| } |
| } |
| |
| func Benchmark(b *testing.B) { |
| rnd := rand.New(rand.NewSource(1)) |
| data := make(Points, 1e2) |
| for i := range data { |
| data[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} |
| } |
| tree := New(data, true) |
| |
| if !tree.Root.isKDTree() { |
| b.Fatal("tree is not k-d tree") |
| } |
| |
| for i := 0; i < 1e3; i++ { |
| q := Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} |
| gotP, gotD := tree.Nearest(q) |
| wantP, wantD := nearest(q, data) |
| if !reflect.DeepEqual(gotP, wantP) { |
| b.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP) |
| } |
| if gotD != wantD { |
| b.Errorf("unexpected distance for query %.3f : got:%v want:%v", q, gotD, wantD) |
| } |
| } |
| |
| if b.Failed() && *genDot && tree.Len() <= *dotLimit { |
| err := dotFile(tree, "TestBenches", "") |
| if err != nil { |
| b.Fatalf("failed to write DOT file: %v", err) |
| } |
| return |
| } |
| |
| var r Comparable |
| var d float64 |
| queryBenchmarks := []struct { |
| name string |
| fn func(*testing.B) |
| }{ |
| { |
| name: "Nearest", fn: func(b *testing.B) { |
| for i := 0; i < b.N; i++ { |
| r, d = tree.Nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}) |
| } |
| if r == nil { |
| b.Error("unexpected nil result") |
| } |
| if math.IsNaN(d) { |
| b.Error("unexpected NaN result") |
| } |
| }, |
| }, |
| { |
| name: "NearestBrute", fn: func(b *testing.B) { |
| for i := 0; i < b.N; i++ { |
| r, d = nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data) |
| } |
| if r == nil { |
| b.Error("unexpected nil result") |
| } |
| if math.IsNaN(d) { |
| b.Error("unexpected NaN result") |
| } |
| }, |
| }, |
| { |
| name: "NearestSetN10", fn: func(b *testing.B) { |
| nk := NewNKeeper(10) |
| for i := 0; i < b.N; i++ { |
| tree.NearestSet(nk, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}) |
| if nk.Len() != 10 { |
| b.Error("unexpected result length") |
| } |
| nk.Heap = nk.Heap[:1] |
| nk.Heap[0] = ComparableDist{Dist: inf} |
| } |
| }, |
| }, |
| { |
| name: "NearestBruteN10", fn: func(b *testing.B) { |
| var r []ComparableDist |
| for i := 0; i < b.N; i++ { |
| r = nearestN(10, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data) |
| } |
| if len(r) != 10 { |
| b.Error("unexpected result length", len(r)) |
| } |
| }, |
| }, |
| } |
| for _, bench := range queryBenchmarks { |
| b.Run(bench.name, bench.fn) |
| } |
| } |
| |
| func dot(t *Tree, label string) string { |
| if t == nil { |
| return "" |
| } |
| var ( |
| s []string |
| follow func(*Node) |
| ) |
| follow = func(n *Node) { |
| id := uintptr(unsafe.Pointer(n)) |
| c := fmt.Sprintf("%d[label = \"<Left> |<Elem> %s/%.3f\\n%.3f|<Right>\"];", |
| id, n, n.Point.(Point)[n.Plane], *n.Bounding) |
| if n.Left != nil { |
| c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Left -> \"%d\":Elem;", |
| id, uintptr(unsafe.Pointer(n.Left))) |
| follow(n.Left) |
| } |
| if n.Right != nil { |
| c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Right -> \"%d\":Elem;", |
| id, uintptr(unsafe.Pointer(n.Right))) |
| follow(n.Right) |
| } |
| s = append(s, c) |
| } |
| if t.Root != nil { |
| follow(t.Root) |
| } |
| return fmt.Sprintf("digraph %s {\n\tnode [shape=record,height=0.1];\n\t%s\n}\n", |
| label, |
| strings.Join(s, "\n\t"), |
| ) |
| } |
| |
| func dotFile(t *Tree, label, dotString string) (err error) { |
| if t == nil && dotString == "" { |
| return |
| } |
| f, err := os.Create(label + ".dot") |
| if err != nil { |
| return |
| } |
| defer f.Close() |
| if dotString == "" { |
| fmt.Fprint(f, dot(t, label)) |
| } else { |
| fmt.Fprint(f, dotString) |
| } |
| return |
| } |