proto: add RangeExtensions and adjust HasExtension and GetExtension

Two changes:
* Add RangeExtensions as a more suitable replacement for legacy
proto.ClearExtensions and proto.ExtensionDescs functions.
* Make HasExtension and GetExtension treat nil message interface
as an empty message to more consistently match legacy behavior.

Change-Id: I8eb1887a33d0737f2f80a2b80358cc296087ba3b
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/229157
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/proto/extension.go b/proto/extension.go
index 73f2431..94af03f 100644
--- a/proto/extension.go
+++ b/proto/extension.go
@@ -11,6 +11,11 @@
 // HasExtension reports whether an extension field is populated.
 // It panics if ext does not extend m.
 func HasExtension(m Message, ext protoreflect.ExtensionType) bool {
+	// Treat nil message interface as an empty message; no populated fields.
+	if m == nil {
+		return false
+	}
+
 	return m.ProtoReflect().Has(ext.TypeDescriptor())
 }
 
@@ -23,9 +28,14 @@
 
 // GetExtension retrieves the value for an extension field.
 // If the field is unpopulated, it returns the default value for
-// scalars and an immutable, empty value for lists, maps, or messages.
+// scalars and an immutable, empty value for lists or messages.
 // It panics if ext does not extend m.
 func GetExtension(m Message, ext protoreflect.ExtensionType) interface{} {
+	// Treat nil message interface as an empty message; return the default.
+	if m == nil {
+		return ext.InterfaceOf(ext.Zero())
+	}
+
 	return ext.InterfaceOf(m.ProtoReflect().Get(ext.TypeDescriptor()))
 }
 
@@ -34,3 +44,24 @@
 func SetExtension(m Message, ext protoreflect.ExtensionType, value interface{}) {
 	m.ProtoReflect().Set(ext.TypeDescriptor(), ext.ValueOf(value))
 }
+
+// RangeExtensions iterates over every populated extension field in m in an
+// undefined order, calling f for each extension type and value encountered.
+// It returns immediately if f returns false.
+// While iterating, mutating operations may only be performed
+// on the current extension field.
+func RangeExtensions(m Message, f func(protoreflect.ExtensionType, interface{}) bool) {
+	// Treat nil message interface as an empty message; nothing to range over.
+	if m == nil {
+		return
+	}
+
+	m.ProtoReflect().Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
+		if fd.IsExtension() {
+			xt := fd.(protoreflect.ExtensionTypeDescriptor).Type()
+			vi := xt.InterfaceOf(v)
+			return f(xt, vi)
+		}
+		return true
+	})
+}
diff --git a/proto/extension_test.go b/proto/extension_test.go
index 0c58cc3..113160f 100644
--- a/proto/extension_test.go
+++ b/proto/extension_test.go
@@ -14,6 +14,7 @@
 	"google.golang.org/protobuf/proto"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 	"google.golang.org/protobuf/runtime/protoimpl"
+	"google.golang.org/protobuf/testing/protocmp"
 
 	legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2_20160225_2fc053c5"
 	testpb "google.golang.org/protobuf/internal/testprotos/test"
@@ -68,6 +69,33 @@
 	}
 }
 
+func TestExtensionRanger(t *testing.T) {
+	want := map[pref.ExtensionType]interface{}{
+		testpb.E_OptionalInt32:         int32(5),
+		testpb.E_OptionalString:        string("hello"),
+		testpb.E_OptionalNestedMessage: &testpb.TestAllExtensions_NestedMessage{},
+		testpb.E_OptionalNestedEnum:    testpb.TestAllTypes_BAZ,
+		testpb.E_RepeatedFloat:         []float32{+32.32, -32.32},
+		testpb.E_RepeatedNestedMessage: []*testpb.TestAllExtensions_NestedMessage{{}},
+		testpb.E_RepeatedNestedEnum:    []testpb.TestAllTypes_NestedEnum{testpb.TestAllTypes_BAZ},
+	}
+
+	m := &testpb.TestAllExtensions{}
+	for xt, v := range want {
+		proto.SetExtension(m, xt, v)
+	}
+
+	got := make(map[pref.ExtensionType]interface{})
+	proto.RangeExtensions(m, func(xt pref.ExtensionType, v interface{}) bool {
+		got[xt] = v
+		return true
+	})
+
+	if diff := cmp.Diff(want, got, protocmp.Transform()); diff != "" {
+		t.Errorf("proto.RangeExtensions mismatch (-want +got):\n%s", diff)
+	}
+}
+
 func TestExtensionGetRace(t *testing.T) {
 	// Concurrently fetch an extension value while marshaling the message containing it.
 	// Create the message with proto.Unmarshal to give lazy extension decoding (if present)
diff --git a/proto/nil_test.go b/proto/nil_test.go
index d4563a8..9d13b2b 100644
--- a/proto/nil_test.go
+++ b/proto/nil_test.go
@@ -8,6 +8,7 @@
 	"testing"
 
 	"google.golang.org/protobuf/proto"
+	"google.golang.org/protobuf/reflect/protoreflect"
 
 	testpb "google.golang.org/protobuf/internal/testprotos/test"
 )
@@ -19,6 +20,7 @@
 func TestNil(t *testing.T) {
 	nilMsg := (*testpb.TestAllExtensions)(nil)
 	extType := testpb.E_OptionalBool
+	extRanger := func(protoreflect.ExtensionType, interface{}) bool { return true }
 
 	tests := []struct {
 		label string
@@ -89,11 +91,9 @@
 	}, {
 		label: "HasExtension",
 		test:  func() { proto.HasExtension(nil, nil) },
-		panic: true,
 	}, {
 		label: "HasExtension",
 		test:  func() { proto.HasExtension(nil, extType) },
-		panic: true,
 	}, {
 		label: "HasExtension",
 		test:  func() { proto.HasExtension(nilMsg, nil) },
@@ -108,7 +108,6 @@
 	}, {
 		label: "GetExtension",
 		test:  func() { proto.GetExtension(nil, extType) },
-		panic: true,
 	}, {
 		label: "GetExtension",
 		test:  func() { proto.GetExtension(nilMsg, nil) },
@@ -148,6 +147,18 @@
 		label: "ClearExtension",
 		test:  func() { proto.ClearExtension(nilMsg, extType) },
 		panic: true,
+	}, {
+		label: "RangeExtensions",
+		test:  func() { proto.RangeExtensions(nil, nil) },
+	}, {
+		label: "RangeExtensions",
+		test:  func() { proto.RangeExtensions(nil, extRanger) },
+	}, {
+		label: "RangeExtensions",
+		test:  func() { proto.RangeExtensions(nilMsg, nil) },
+	}, {
+		label: "RangeExtensions",
+		test:  func() { proto.RangeExtensions(nilMsg, extRanger) },
 	}}
 
 	for _, tt := range tests {