proto: add hooks.go to switch-over to the new implementation
Using the proto_reimpl build tag, we can control whether we go through
the current logic or the re-implemented logic. This allows us to re-use
all of the current tests with little to no modification to test the
new implementation.
Change-Id: I6d6beec05b859014f63193bf2c7530afa49eccd4
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/167767
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/integration_test.go b/integration_test.go
index f8fcadd..ba6d3ce 100644
--- a/integration_test.go
+++ b/integration_test.go
@@ -64,6 +64,7 @@
workDir := filepath.Join(goPath, "src", modulePath)
runGo("Build", workDir, "go", "build", "./...")
runGo("TestNormal", workDir, "go", "test", "-race", "./...")
+ runGo("TestReimpl", workDir, "go", "test", "-race", "-tags", "proto_reimpl", "./...")
})
}
diff --git a/internal/proto/common.go b/internal/proto/common.go
index 7258bae..f8ec496 100644
--- a/internal/proto/common.go
+++ b/internal/proto/common.go
@@ -9,9 +9,9 @@
// that they would otherwise be able to call directly.
import (
- "github.com/golang/protobuf/proto"
+ "github.com/golang/protobuf/protoapi"
_ "github.com/golang/protobuf/v2/runtime/protolegacy"
)
-type Message = proto.Message
+type Message = protoapi.Message
diff --git a/proto/all_test.go b/proto/all_test.go
index 2471011..a16c056 100644
--- a/proto/all_test.go
+++ b/proto/all_test.go
@@ -18,7 +18,6 @@
"testing"
"time"
- protoV1a "github.com/golang/protobuf/internal/proto"
. "github.com/golang/protobuf/proto"
pb3 "github.com/golang/protobuf/proto/proto3_proto"
. "github.com/golang/protobuf/proto/test_proto"
@@ -1170,28 +1169,6 @@
}
}
-func TestBadWireTypeUnknown2(t *testing.T) {
- var b []byte
- fmt.Sscanf("0a01780d00000000080b101612036161611521000000202c220362626225370000002203636363214200000000000000584d5a036464645900000000000056405d63000000", "%x", &b)
-
- m := new(MyMessage)
- if err := Unmarshal(b, m); err != nil {
- t.Errorf("unexpected Unmarshal error: %v", err)
- }
-
- var unknown []byte
- fmt.Sscanf("0a01780d0000000010161521000000202c2537000000214200000000000000584d5a036464645d63000000", "%x", &unknown)
- if !bytes.Equal(m.XXX_unrecognized, unknown) {
- t.Errorf("unknown bytes mismatch:\ngot %x\nwant %x", m.XXX_unrecognized, unknown)
- }
- protoV1a.DiscardUnknown(m)
-
- want := &MyMessage{Count: Int32(11), Name: String("aaa"), Pet: []string{"bbb", "ccc"}, Bigfloat: Float64(88)}
- if !Equal(m, want) {
- t.Errorf("message mismatch:\ngot %v\nwant %v", m, want)
- }
-}
-
func encodeDecode(t *testing.T, in, out Message, msg string) {
buf, err := Marshal(in)
if err != nil {
@@ -1394,38 +1371,6 @@
}
}
-func TestAllSetDefaults2(t *testing.T) {
- // Exercise SetDefaults with all scalar field types.
- m := &Defaults{
- // NaN != NaN, so override that here.
- F_Nan: Float32(1.7),
- }
- expected := &Defaults{
- F_Bool: Bool(true),
- F_Int32: Int32(32),
- F_Int64: Int64(64),
- F_Fixed32: Uint32(320),
- F_Fixed64: Uint64(640),
- F_Uint32: Uint32(3200),
- F_Uint64: Uint64(6400),
- F_Float: Float32(314159),
- F_Double: Float64(271828),
- F_String: String(`hello, "world!"` + "\n"),
- F_Bytes: []byte("Bignose"),
- F_Sint32: Int32(-32),
- F_Sint64: Int64(-64),
- F_Enum: Defaults_GREEN.Enum(),
- F_Pinf: Float32(float32(math.Inf(1))),
- F_Ninf: Float32(float32(math.Inf(-1))),
- F_Nan: Float32(1.7),
- StrZero: String(""),
- }
- protoV1a.SetDefaults(m)
- if !Equal(m, expected) {
- t.Errorf("SetDefaults failed\n got %v\nwant %v", m, expected)
- }
-}
-
func TestSetDefaultsWithSetField(t *testing.T) {
// Check that a set value is not overridden.
m := &Defaults{
@@ -1437,17 +1382,6 @@
}
}
-func TestSetDefaultsWithSetField2(t *testing.T) {
- // Check that a set value is not overridden.
- m := &Defaults{
- F_Int32: Int32(12),
- }
- protoV1a.SetDefaults(m)
- if v := m.GetF_Int32(); v != 12 {
- t.Errorf("m.FInt32 = %v, want 12", v)
- }
-}
-
func TestSetDefaultsWithSubMessage(t *testing.T) {
m := &OtherMessage{
Key: Int64(123),
@@ -1468,26 +1402,6 @@
}
}
-func TestSetDefaultsWithSubMessage2(t *testing.T) {
- m := &OtherMessage{
- Key: Int64(123),
- Inner: &InnerMessage{
- Host: String("gopher"),
- },
- }
- expected := &OtherMessage{
- Key: Int64(123),
- Inner: &InnerMessage{
- Host: String("gopher"),
- Port: Int32(4000),
- },
- }
- protoV1a.SetDefaults(m)
- if !Equal(m, expected) {
- t.Errorf("\n got %v\nwant %v", m, expected)
- }
-}
-
func TestSetDefaultsWithRepeatedSubMessage(t *testing.T) {
m := &MyMessage{
RepInner: []*InnerMessage{{}},
@@ -1503,21 +1417,6 @@
}
}
-func TestSetDefaultsWithRepeatedSubMessage2(t *testing.T) {
- m := &MyMessage{
- RepInner: []*InnerMessage{{}},
- }
- expected := &MyMessage{
- RepInner: []*InnerMessage{{
- Port: Int32(4000),
- }},
- }
- protoV1a.SetDefaults(m)
- if !Equal(m, expected) {
- t.Errorf("\n got %v\nwant %v", m, expected)
- }
-}
-
func TestSetDefaultWithRepeatedNonMessage(t *testing.T) {
m := &MyMessage{
Pet: []string{"turtle", "wombat"},
@@ -1529,17 +1428,6 @@
}
}
-func TestSetDefaultWithRepeatedNonMessage2(t *testing.T) {
- m := &MyMessage{
- Pet: []string{"turtle", "wombat"},
- }
- expected := Clone(m)
- protoV1a.SetDefaults(m)
- if !Equal(m, expected) {
- t.Errorf("\n got %v\nwant %v", m, expected)
- }
-}
-
func TestMaximumTagNumber(t *testing.T) {
m := &MaxTag{
LastField: String("natural goat essence"),
diff --git a/proto/discard.go b/proto/discard.go
index c850e09..fe5a140 100644
--- a/proto/discard.go
+++ b/proto/discard.go
@@ -18,6 +18,8 @@
XXX_DiscardUnknown()
}
+var discardUnknownAlt func(Message) // populated by hooks.go
+
// DiscardUnknown recursively discards all unknown fields from this message
// and all embedded messages.
//
@@ -30,6 +32,11 @@
// For proto2 messages, the unknown fields of message extensions are only
// discarded from messages that have been accessed via GetExtension.
func DiscardUnknown(m Message) {
+ if discardUnknownAlt != nil {
+ discardUnknownAlt(m)
+ return
+ }
+
if m, ok := m.(generatedDiscarder); ok {
m.XXX_DiscardUnknown()
return
diff --git a/proto/discard_test.go b/proto/discard_test.go
index 1f4b6eb..1c9dab9 100644
--- a/proto/discard_test.go
+++ b/proto/discard_test.go
@@ -7,7 +7,6 @@
import (
"testing"
- protoV1a "github.com/golang/protobuf/internal/proto"
"github.com/golang/protobuf/proto"
proto3pb "github.com/golang/protobuf/proto/proto3_proto"
@@ -104,39 +103,6 @@
}(),
}}
- // Test the reflection code path.
- for _, tt := range tests {
- // Clone the input so that we don't alter the original.
- in := tt.in
- if in != nil {
- in = proto.Clone(tt.in)
- }
-
- protoV1a.DiscardUnknown(tt.in)
- if !proto.Equal(tt.in, tt.want) {
- t.Errorf("test %s, expected unknown fields to be discarded\ngot %v\nwant %v", tt.desc, tt.in, tt.want)
- }
- }
-
- // Test the legacy code path.
- for _, tt := range tests {
- // Clone the input so that we don't alter the original.
- in := tt.in
- if in != nil {
- in = proto.Clone(tt.in)
- }
-
- var m LegacyMessage
- m.Message, _ = in.(*proto3pb.Message)
- m.Communique, _ = in.(*pb.Communique)
- m.MessageWithMap, _ = in.(*pb.MessageWithMap)
- m.MyMessage, _ = in.(*pb.MyMessage)
- proto.DiscardUnknown(&m)
- if !proto.Equal(in, tt.want) {
- t.Errorf("test %s/Legacy, expected unknown fields to be discarded\ngot %v\nwant %v", tt.desc, in, tt.want)
- }
- }
-
for _, tt := range tests {
proto.DiscardUnknown(tt.in)
if !proto.Equal(tt.in, tt.want) {
@@ -144,17 +110,3 @@
}
}
}
-
-// LegacyMessage is a proto.Message that has several nested messages.
-// This does not have the XXX_DiscardUnknown method and so forces DiscardUnknown
-// to use the legacy fallback logic.
-type LegacyMessage struct {
- Message *proto3pb.Message
- Communique *pb.Communique
- MessageWithMap *pb.MessageWithMap
- MyMessage *pb.MyMessage
-}
-
-func (m *LegacyMessage) Reset() { *m = LegacyMessage{} }
-func (m *LegacyMessage) String() string { return proto.CompactTextString(m) }
-func (*LegacyMessage) ProtoMessage() {}
diff --git a/proto/hooks.go b/proto/hooks.go
new file mode 100644
index 0000000..cd4da7d
--- /dev/null
+++ b/proto/hooks.go
@@ -0,0 +1,14 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build proto_reimpl
+
+package proto
+
+import "github.com/golang/protobuf/internal/proto"
+
+func init() {
+ setDefaultsAlt = proto.SetDefaults
+ discardUnknownAlt = proto.DiscardUnknown
+}
diff --git a/proto/lib.go b/proto/lib.go
index 3106066..a35f5db 100644
--- a/proto/lib.go
+++ b/proto/lib.go
@@ -309,10 +309,16 @@
p.index = index
}
+var setDefaultsAlt func(Message) // populated by hooks.go
+
// SetDefaults sets unset protocol buffer fields to their default values.
// It only modifies fields that are both unset and have defined defaults.
// It recursively sets default values in any non-nil sub-messages.
func SetDefaults(pb Message) {
+ if setDefaultsAlt != nil {
+ setDefaultsAlt(pb)
+ return
+ }
setDefaults(reflect.ValueOf(pb), true, false)
}
diff --git a/proto/proto3_test.go b/proto/proto3_test.go
index a09f0fb..943ca20 100644
--- a/proto/proto3_test.go
+++ b/proto/proto3_test.go
@@ -8,7 +8,6 @@
"bytes"
"testing"
- protoV1a "github.com/golang/protobuf/internal/proto"
"github.com/golang/protobuf/proto"
pb "github.com/golang/protobuf/proto/proto3_proto"
tpb "github.com/golang/protobuf/proto/test_proto"
@@ -109,37 +108,6 @@
}
}
-func TestProto3SetDefaults2(t *testing.T) {
- in := &pb.Message{
- Terrain: map[string]*pb.Nested{
- "meadow": new(pb.Nested),
- },
- Proto2Field: new(tpb.SubDefaults),
- Proto2Value: map[string]*tpb.SubDefaults{
- "badlands": new(tpb.SubDefaults),
- },
- }
-
- got := proto.Clone(in).(*pb.Message)
- protoV1a.SetDefaults(got)
-
- // There are no defaults in proto3. Everything should be the zero value, but
- // we need to remember to set defaults for nested proto2 messages.
- want := &pb.Message{
- Terrain: map[string]*pb.Nested{
- "meadow": new(pb.Nested),
- },
- Proto2Field: &tpb.SubDefaults{N: proto.Int64(7)},
- Proto2Value: map[string]*tpb.SubDefaults{
- "badlands": &tpb.SubDefaults{N: proto.Int64(7)},
- },
- }
-
- if !proto.Equal(got, want) {
- t.Errorf("with in = %v\nproto.SetDefaults(in) =>\ngot %v\nwant %v", in, got, want)
- }
-}
-
func TestUnknownFieldPreservation(t *testing.T) {
b1 := "\x0a\x05David" // Known tag 1
b2 := "\xc2\x0c\x06Google" // Unknown tag 200