diff all map key types
diff --git a/diff.go b/diff.go
index b08522d..6aa7f74 100644
--- a/diff.go
+++ b/diff.go
@@ -185,16 +185,53 @@
 
 // keyEqual compares a and b for equality.
 // Both a and b must be valid map keys.
-func keyEqual(a, b reflect.Value) bool {
-	if a.Type() != b.Type() {
+func keyEqual(av, bv reflect.Value) bool {
+	if !av.IsValid() && !bv.IsValid() {
+		return true
+	}
+	if !av.IsValid() || !bv.IsValid() || av.Type() != bv.Type() {
 		return false
 	}
-	switch kind := a.Kind(); kind {
-	case reflect.Int:
-		a, b := a.Int(), b.Int()
+	switch kind := av.Kind(); kind {
+	case reflect.Bool:
+		a, b := av.Bool(), bv.Bool()
 		return a == b
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		a, b := av.Int(), bv.Int()
+		return a == b
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+		a, b := av.Uint(), bv.Uint()
+		return a == b
+	case reflect.Float32, reflect.Float64:
+		a, b := av.Float(), bv.Float()
+		return a == b
+	case reflect.Complex64, reflect.Complex128:
+		a, b := av.Complex(), bv.Complex()
+		return a == b
+	case reflect.Array:
+		for i := 0; i < av.Len(); i++ {
+			if !keyEqual(av.Index(i), bv.Index(i)) {
+				return false
+			}
+		}
+		return true
+	case reflect.Chan, reflect.UnsafePointer, reflect.Ptr:
+		a, b := av.Pointer(), bv.Pointer()
+		return a == b
+	case reflect.Interface:
+		return keyEqual(av.Elem(), bv.Elem())
+	case reflect.String:
+		a, b := av.String(), bv.String()
+		return a == b
+	case reflect.Struct:
+		for i := 0; i < av.NumField(); i++ {
+			if !keyEqual(av.Field(i), bv.Field(i)) {
+				return false
+			}
+		}
+		return true
 	default:
-		panic("invalid map reflect Kind: " + kind.String())
+		panic("invalid map key type " + av.Type().String())
 	}
 }
 
diff --git a/diff_test.go b/diff_test.go
index c47c9f0..a951e4b 100644
--- a/diff_test.go
+++ b/diff_test.go
@@ -4,6 +4,7 @@
 	"bytes"
 	"fmt"
 	"log"
+	"reflect"
 	"testing"
 	"unsafe"
 )
@@ -115,7 +116,6 @@
 	{struct{ x string }{"a"}, struct{ x string }{"b"}, []string{`x: "a" != "b"`}},
 	{struct{ x N }{N{0}}, struct{ x N }{N{0}}, nil},
 	{struct{ x N }{N{0}}, struct{ x N }{N{1}}, []string{`x.N: 0 != 1`}},
-
 	{
 		struct{ x unsafe.Pointer }{unsafe.Pointer(uintptr(0))},
 		struct{ x unsafe.Pointer }{unsafe.Pointer(uintptr(0))},
@@ -146,6 +146,44 @@
 	}
 }
 
+func TestKeyEqual(t *testing.T) {
+	var emptyInterfaceZero interface{} = 0
+
+	cases := []interface{}{
+		new(bool),
+		new(int),
+		new(int8),
+		new(int16),
+		new(int32),
+		new(int64),
+		new(uint),
+		new(uint8),
+		new(uint16),
+		new(uint32),
+		new(uint64),
+		new(uintptr),
+		new(float32),
+		new(float64),
+		new(complex64),
+		new(complex128),
+		new([1]int),
+		new(chan int),
+		new(unsafe.Pointer),
+		new(interface{}),
+		&emptyInterfaceZero,
+		new(*int),
+		new(string),
+		new(struct{ int }),
+	}
+
+	for _, test := range cases {
+		rv := reflect.ValueOf(test).Elem()
+		if !keyEqual(rv, rv) {
+			t.Errorf("keyEqual(%s, %s) = false want true", rv.Type(), rv.Type())
+		}
+	}
+}
+
 func TestFdiff(t *testing.T) {
 	var buf bytes.Buffer
 	Fdiff(&buf, 0, 1)