diff all map key types
diff --git a/diff.go b/diff.go
index b08522d..6aa7f74 100644
--- a/diff.go
+++ b/diff.go
@@ -185,16 +185,53 @@
// keyEqual compares a and b for equality.
// Both a and b must be valid map keys.
-func keyEqual(a, b reflect.Value) bool {
- if a.Type() != b.Type() {
+func keyEqual(av, bv reflect.Value) bool {
+ if !av.IsValid() && !bv.IsValid() {
+ return true
+ }
+ if !av.IsValid() || !bv.IsValid() || av.Type() != bv.Type() {
return false
}
- switch kind := a.Kind(); kind {
- case reflect.Int:
- a, b := a.Int(), b.Int()
+ switch kind := av.Kind(); kind {
+ case reflect.Bool:
+ a, b := av.Bool(), bv.Bool()
return a == b
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ a, b := av.Int(), bv.Int()
+ return a == b
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ a, b := av.Uint(), bv.Uint()
+ return a == b
+ case reflect.Float32, reflect.Float64:
+ a, b := av.Float(), bv.Float()
+ return a == b
+ case reflect.Complex64, reflect.Complex128:
+ a, b := av.Complex(), bv.Complex()
+ return a == b
+ case reflect.Array:
+ for i := 0; i < av.Len(); i++ {
+ if !keyEqual(av.Index(i), bv.Index(i)) {
+ return false
+ }
+ }
+ return true
+ case reflect.Chan, reflect.UnsafePointer, reflect.Ptr:
+ a, b := av.Pointer(), bv.Pointer()
+ return a == b
+ case reflect.Interface:
+ return keyEqual(av.Elem(), bv.Elem())
+ case reflect.String:
+ a, b := av.String(), bv.String()
+ return a == b
+ case reflect.Struct:
+ for i := 0; i < av.NumField(); i++ {
+ if !keyEqual(av.Field(i), bv.Field(i)) {
+ return false
+ }
+ }
+ return true
default:
- panic("invalid map reflect Kind: " + kind.String())
+ panic("invalid map key type " + av.Type().String())
}
}
diff --git a/diff_test.go b/diff_test.go
index c47c9f0..a951e4b 100644
--- a/diff_test.go
+++ b/diff_test.go
@@ -4,6 +4,7 @@
"bytes"
"fmt"
"log"
+ "reflect"
"testing"
"unsafe"
)
@@ -115,7 +116,6 @@
{struct{ x string }{"a"}, struct{ x string }{"b"}, []string{`x: "a" != "b"`}},
{struct{ x N }{N{0}}, struct{ x N }{N{0}}, nil},
{struct{ x N }{N{0}}, struct{ x N }{N{1}}, []string{`x.N: 0 != 1`}},
-
{
struct{ x unsafe.Pointer }{unsafe.Pointer(uintptr(0))},
struct{ x unsafe.Pointer }{unsafe.Pointer(uintptr(0))},
@@ -146,6 +146,44 @@
}
}
+func TestKeyEqual(t *testing.T) {
+ var emptyInterfaceZero interface{} = 0
+
+ cases := []interface{}{
+ new(bool),
+ new(int),
+ new(int8),
+ new(int16),
+ new(int32),
+ new(int64),
+ new(uint),
+ new(uint8),
+ new(uint16),
+ new(uint32),
+ new(uint64),
+ new(uintptr),
+ new(float32),
+ new(float64),
+ new(complex64),
+ new(complex128),
+ new([1]int),
+ new(chan int),
+ new(unsafe.Pointer),
+ new(interface{}),
+ &emptyInterfaceZero,
+ new(*int),
+ new(string),
+ new(struct{ int }),
+ }
+
+ for _, test := range cases {
+ rv := reflect.ValueOf(test).Elem()
+ if !keyEqual(rv, rv) {
+ t.Errorf("keyEqual(%s, %s) = false want true", rv.Type(), rv.Type())
+ }
+ }
+}
+
func TestFdiff(t *testing.T) {
var buf bytes.Buffer
Fdiff(&buf, 0, 1)