proto: add RangeExtensions and adjust HasExtension and GetExtension
Two changes:
* Add RangeExtensions as a more suitable replacement for legacy
proto.ClearExtensions and proto.ExtensionDescs functions.
* Make HasExtension and GetExtension treat nil message interface
as an empty message to more consistently match legacy behavior.
Change-Id: I8eb1887a33d0737f2f80a2b80358cc296087ba3b
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/229157
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/proto/extension.go b/proto/extension.go
index 73f2431..94af03f 100644
--- a/proto/extension.go
+++ b/proto/extension.go
@@ -11,6 +11,11 @@
// HasExtension reports whether an extension field is populated.
// It panics if ext does not extend m.
func HasExtension(m Message, ext protoreflect.ExtensionType) bool {
+ // Treat nil message interface as an empty message; no populated fields.
+ if m == nil {
+ return false
+ }
+
return m.ProtoReflect().Has(ext.TypeDescriptor())
}
@@ -23,9 +28,14 @@
// GetExtension retrieves the value for an extension field.
// If the field is unpopulated, it returns the default value for
-// scalars and an immutable, empty value for lists, maps, or messages.
+// scalars and an immutable, empty value for lists or messages.
// It panics if ext does not extend m.
func GetExtension(m Message, ext protoreflect.ExtensionType) interface{} {
+ // Treat nil message interface as an empty message; return the default.
+ if m == nil {
+ return ext.InterfaceOf(ext.Zero())
+ }
+
return ext.InterfaceOf(m.ProtoReflect().Get(ext.TypeDescriptor()))
}
@@ -34,3 +44,24 @@
func SetExtension(m Message, ext protoreflect.ExtensionType, value interface{}) {
m.ProtoReflect().Set(ext.TypeDescriptor(), ext.ValueOf(value))
}
+
+// RangeExtensions iterates over every populated extension field in m in an
+// undefined order, calling f for each extension type and value encountered.
+// It returns immediately if f returns false.
+// While iterating, mutating operations may only be performed
+// on the current extension field.
+func RangeExtensions(m Message, f func(protoreflect.ExtensionType, interface{}) bool) {
+ // Treat nil message interface as an empty message; nothing to range over.
+ if m == nil {
+ return
+ }
+
+ m.ProtoReflect().Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
+ if fd.IsExtension() {
+ xt := fd.(protoreflect.ExtensionTypeDescriptor).Type()
+ vi := xt.InterfaceOf(v)
+ return f(xt, vi)
+ }
+ return true
+ })
+}
diff --git a/proto/extension_test.go b/proto/extension_test.go
index 0c58cc3..113160f 100644
--- a/proto/extension_test.go
+++ b/proto/extension_test.go
@@ -14,6 +14,7 @@
"google.golang.org/protobuf/proto"
pref "google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoimpl"
+ "google.golang.org/protobuf/testing/protocmp"
legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2_20160225_2fc053c5"
testpb "google.golang.org/protobuf/internal/testprotos/test"
@@ -68,6 +69,33 @@
}
}
+func TestExtensionRanger(t *testing.T) {
+ want := map[pref.ExtensionType]interface{}{
+ testpb.E_OptionalInt32: int32(5),
+ testpb.E_OptionalString: string("hello"),
+ testpb.E_OptionalNestedMessage: &testpb.TestAllExtensions_NestedMessage{},
+ testpb.E_OptionalNestedEnum: testpb.TestAllTypes_BAZ,
+ testpb.E_RepeatedFloat: []float32{+32.32, -32.32},
+ testpb.E_RepeatedNestedMessage: []*testpb.TestAllExtensions_NestedMessage{{}},
+ testpb.E_RepeatedNestedEnum: []testpb.TestAllTypes_NestedEnum{testpb.TestAllTypes_BAZ},
+ }
+
+ m := &testpb.TestAllExtensions{}
+ for xt, v := range want {
+ proto.SetExtension(m, xt, v)
+ }
+
+ got := make(map[pref.ExtensionType]interface{})
+ proto.RangeExtensions(m, func(xt pref.ExtensionType, v interface{}) bool {
+ got[xt] = v
+ return true
+ })
+
+ if diff := cmp.Diff(want, got, protocmp.Transform()); diff != "" {
+ t.Errorf("proto.RangeExtensions mismatch (-want +got):\n%s", diff)
+ }
+}
+
func TestExtensionGetRace(t *testing.T) {
// Concurrently fetch an extension value while marshaling the message containing it.
// Create the message with proto.Unmarshal to give lazy extension decoding (if present)
diff --git a/proto/nil_test.go b/proto/nil_test.go
index d4563a8..9d13b2b 100644
--- a/proto/nil_test.go
+++ b/proto/nil_test.go
@@ -8,6 +8,7 @@
"testing"
"google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/reflect/protoreflect"
testpb "google.golang.org/protobuf/internal/testprotos/test"
)
@@ -19,6 +20,7 @@
func TestNil(t *testing.T) {
nilMsg := (*testpb.TestAllExtensions)(nil)
extType := testpb.E_OptionalBool
+ extRanger := func(protoreflect.ExtensionType, interface{}) bool { return true }
tests := []struct {
label string
@@ -89,11 +91,9 @@
}, {
label: "HasExtension",
test: func() { proto.HasExtension(nil, nil) },
- panic: true,
}, {
label: "HasExtension",
test: func() { proto.HasExtension(nil, extType) },
- panic: true,
}, {
label: "HasExtension",
test: func() { proto.HasExtension(nilMsg, nil) },
@@ -108,7 +108,6 @@
}, {
label: "GetExtension",
test: func() { proto.GetExtension(nil, extType) },
- panic: true,
}, {
label: "GetExtension",
test: func() { proto.GetExtension(nilMsg, nil) },
@@ -148,6 +147,18 @@
label: "ClearExtension",
test: func() { proto.ClearExtension(nilMsg, extType) },
panic: true,
+ }, {
+ label: "RangeExtensions",
+ test: func() { proto.RangeExtensions(nil, nil) },
+ }, {
+ label: "RangeExtensions",
+ test: func() { proto.RangeExtensions(nil, extRanger) },
+ }, {
+ label: "RangeExtensions",
+ test: func() { proto.RangeExtensions(nilMsg, nil) },
+ }, {
+ label: "RangeExtensions",
+ test: func() { proto.RangeExtensions(nilMsg, extRanger) },
}}
for _, tt := range tests {