graph/path: simplify and clean up union-find code and Kruskal
diff --git a/graph/path/disjoint.go b/graph/path/disjoint.go
index 235694c..bfc05fb 100644
--- a/graph/path/disjoint.go
+++ b/graph/path/disjoint.go
@@ -4,84 +4,84 @@
package path
-// A disjoint set is a collection of non-overlapping sets. That is, for any two sets in the
-// disjoint set, their intersection is the empty set.
-//
-// A disjoint set has three principle operations: Make Set, Find, and Union.
-//
-// Make set creates a new set for an element (presuming it does not already exist in any set in
-// the disjoint set), Find finds the set containing that element (if any), and Union merges two
-// sets in the disjoint set. In general, algorithms operating on disjoint sets are "union-find"
-// algorithms, where two sets are found with Find, and then joined with Union.
-//
-// A concrete example of a union-find algorithm can be found as discrete.Kruskal -- which unions
-// two sets when an edge is created between two vertices, and refuses to make an edge between two
-// vertices if they're part of the same set.
-type disjointSet struct {
- master map[int64]*disjointSetNode
-}
+// djSet implements a disjoint set finder using the union-find algorithm.
+//
+// Each element, identified by its int64 ID, maps to a dsNode. The nodes
+// form a forest in which two elements belong to the same set exactly when
+// their nodes share a root. Sets are merged with union and roots are
+// located with find.
+type djSet map[int64]*dsNode
-type disjointSetNode struct {
- parent *disjointSetNode
- rank int
-}
-
-func newDisjointSet() *disjointSet {
- return &disjointSet{master: make(map[int64]*disjointSetNode)}
-}
-
-// If the element isn't already somewhere in there, adds it to the master set and its own tiny set.
-func (ds *disjointSet) makeSet(e int64) {
- if _, ok := ds.master[e]; ok {
+// add adds e to the collection of sets held by the disjoint set as a new
+// singleton set. If e is already present, add does nothing and e keeps its
+// current set membership.
+func (s djSet) add(e int64) {
+ if _, ok := s[e]; ok {
return
}
- dsNode := &disjointSetNode{rank: 0}
- dsNode.parent = dsNode
- ds.master[e] = dsNode
+ // A nil parent marks the new node as the root of its own set.
+ s[e] = &dsNode{}
}
-// Returns the set the element belongs to, or nil if none.
-func (ds *disjointSet) find(e int64) *disjointSetNode {
- dsNode, ok := ds.master[e]
+// union joins the sets containing a and b within the collection of sets
+// held by the disjoint set. a and b need not be roots, but they must be
+// non-nil. Joining a set with itself is a no-op.
+//
+// The union is by rank: the root of lower rank is attached under the
+// other root. When the ranks are equal, b's root is attached under a's
+// root and a's root's rank is incremented. This keeps the trees shallow.
+func (djSet) union(a, b *dsNode) {
+ ra := find(a)
+ rb := find(b)
+ if ra == rb {
+ return
+ }
+ if ra.rank < rb.rank {
+ ra.parent = rb
+ return
+ }
+ rb.parent = ra
+ if ra.rank == rb.rank {
+ ra.rank++
+ }
+}
+
+// find returns the root of the set containing e, or nil if e has not been
+// added to the disjoint set.
+func (s djSet) find(e int64) *dsNode {
+ n, ok := s[e]
if !ok {
return nil
}
-
- return find(dsNode)
+ return find(n)
}
-func find(dsNode *disjointSetNode) *disjointSetNode {
- if dsNode.parent != dsNode {
- dsNode.parent = find(dsNode.parent)
+// find returns the root of the set containing the set node, n.
+//
+// find also compresses the path it walks: every node on the path from n to
+// the root is re-pointed directly at that root, so later finds from those
+// nodes are shorter. The root each node belongs to does not change.
+//
+// n must be non-nil.
+func find(n *dsNode) *dsNode {
+ root := n
+ for root.parent != nil {
+ root = root.parent
}
-
- return dsNode.parent
+ for n != root {
+ // Re-point n at the root, then continue from n's former parent.
+ next := n.parent
+ n.parent = root
+ n = next
+ }
+ return root
}
-// Unions two subsets within the disjointSet.
-//
-// If x or y are not in this disjoint set, the behavior is undefined. If either pointer is nil,
-// this function will panic.
-func (ds *disjointSet) union(x, y *disjointSetNode) {
- if x == nil || y == nil {
- panic("Disjoint Set union on nil sets")
- }
- xRoot := find(x)
- yRoot := find(y)
- if xRoot == nil || yRoot == nil {
- return
- }
-
- if xRoot == yRoot {
- return
- }
-
- if xRoot.rank < yRoot.rank {
- xRoot.parent = yRoot
- } else if yRoot.rank < xRoot.rank {
- yRoot.parent = xRoot
- } else {
- yRoot.parent = xRoot
- xRoot.rank++
- }
+// dsNode is a disjoint set element.
+//
+// A node with a nil parent is the root of its set. rank is only changed by
+// union while the node is a root, and is used there to decide which of two
+// roots is attached under the other.
+type dsNode struct {
+ parent *dsNode
+ rank int
}
diff --git a/graph/path/disjoint_test.go b/graph/path/disjoint_test.go
index d1cd78d..7e833a5 100644
--- a/graph/path/disjoint_test.go
+++ b/graph/path/disjoint_test.go
@@ -10,30 +10,26 @@
func TestDisjointSetMakeSet(t *testing.T) {
t.Parallel()
- ds := newDisjointSet()
- if ds.master == nil {
- t.Fatal("Internal disjoint set map erroneously nil")
- } else if len(ds.master) != 0 {
- t.Error("Disjoint set master map of wrong size")
+
+ ds := make(djSet)
+ ds.add(3)
+ if len(ds) != 1 {
+ t.Error("Disjoint set map of wrong size")
}
- ds.makeSet(3)
- if len(ds.master) != 1 {
- t.Error("Disjoint set master map of wrong size")
- }
-
- if node, ok := ds.master[3]; !ok {
- t.Error("Make set did not successfully add element")
+ node, ok := ds[3]
+ if !ok {
+ t.Error("add did not successfully add element")
} else {
if node == nil {
- t.Fatal("Disjoint set node from makeSet is nil")
+ t.Fatal("Disjoint set node from add is nil")
}
if node.rank != 0 {
t.Error("Node rank set incorrectly")
}
- if node.parent != node {
+ if node.parent != nil {
t.Error("Node parent set incorrectly")
}
}
@@ -41,10 +37,17 @@
func TestDisjointSetFind(t *testing.T) {
t.Parallel()
- ds := newDisjointSet()
- ds.makeSet(3)
- ds.makeSet(5)
+ ds := make(djSet)
+ ds.add(3)
+ ds.add(4)
+ ds.add(5)
+ ds.union(ds.find(3), ds.find(4))
+
+ // 3 and 4 were joined, so find must report the same root for both.
+ if ds.find(3) != ds.find(4) {
+ t.Error("Joined sets incorrectly found to be different")
+ }
if ds.find(3) == ds.find(5) {
t.Error("Disjoint sets incorrectly found to be the same")
@@ -53,12 +56,13 @@
func TestUnion(t *testing.T) {
t.Parallel()
- ds := newDisjointSet()
- ds.makeSet(3)
- ds.makeSet(5)
-
- ds.union(ds.find(3), ds.find(5))
+ ds := make(djSet)
+ ds.add(3)
+ ds.add(4)
+ ds.add(5)
+ ds.union(ds.find(3), ds.find(4))
+ ds.union(ds.find(4), ds.find(5))
if ds.find(3) != ds.find(5) {
t.Error("Sets found to be disjoint after union")
diff --git a/graph/path/spanning_tree.go b/graph/path/spanning_tree.go
index 55b6192..e4e67fc 100644
--- a/graph/path/spanning_tree.go
+++ b/graph/path/spanning_tree.go
@@ -165,10 +165,12 @@
edges := graph.WeightedEdgesOf(g.WeightedEdges())
sort.Sort(byWeight(edges))
- ds := newDisjointSet()
- for _, node := range graph.NodesOf(g.Nodes()) {
- dst.AddNode(node)
- ds.makeSet(node.ID())
+ ds := make(djSet) // djSet is a map, so it must be made before add writes to it.
+ it := g.Nodes()
+ for it.Next() { // Copy every node of g into dst.
+ n := it.Node()
+ dst.AddNode(n)
+ ds.add(n.ID()) // Each node starts out in a singleton set of its own.
}
var w float64