proto: add hooks.go to switch-over to the new implementation

Using the proto_reimpl build tag, we can control whether we go through
the current logic or the re-implemented logic. This allows us to re-use
all of the current tests with little to no modification to test the
new implementation.

Change-Id: I6d6beec05b859014f63193bf2c7530afa49eccd4
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/167767
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/integration_test.go b/integration_test.go
index f8fcadd..ba6d3ce 100644
--- a/integration_test.go
+++ b/integration_test.go
@@ -64,6 +64,7 @@
 			workDir := filepath.Join(goPath, "src", modulePath)
 			runGo("Build", workDir, "go", "build", "./...")
 			runGo("TestNormal", workDir, "go", "test", "-race", "./...")
+			runGo("TestReimpl", workDir, "go", "test", "-race", "-tags", "proto_reimpl", "./...")
 		})
 	}
 
diff --git a/internal/proto/common.go b/internal/proto/common.go
index 7258bae..f8ec496 100644
--- a/internal/proto/common.go
+++ b/internal/proto/common.go
@@ -9,9 +9,9 @@
 // that they would otherwise be able to call directly.
 
 import (
-	"github.com/golang/protobuf/proto"
+	"github.com/golang/protobuf/protoapi"
 
 	_ "github.com/golang/protobuf/v2/runtime/protolegacy"
 )
 
-type Message = proto.Message
+type Message = protoapi.Message
diff --git a/proto/all_test.go b/proto/all_test.go
index 2471011..a16c056 100644
--- a/proto/all_test.go
+++ b/proto/all_test.go
@@ -18,7 +18,6 @@
 	"testing"
 	"time"
 
-	protoV1a "github.com/golang/protobuf/internal/proto"
 	. "github.com/golang/protobuf/proto"
 	pb3 "github.com/golang/protobuf/proto/proto3_proto"
 	. "github.com/golang/protobuf/proto/test_proto"
@@ -1170,28 +1169,6 @@
 	}
 }
 
-func TestBadWireTypeUnknown2(t *testing.T) {
-	var b []byte
-	fmt.Sscanf("0a01780d00000000080b101612036161611521000000202c220362626225370000002203636363214200000000000000584d5a036464645900000000000056405d63000000", "%x", &b)
-
-	m := new(MyMessage)
-	if err := Unmarshal(b, m); err != nil {
-		t.Errorf("unexpected Unmarshal error: %v", err)
-	}
-
-	var unknown []byte
-	fmt.Sscanf("0a01780d0000000010161521000000202c2537000000214200000000000000584d5a036464645d63000000", "%x", &unknown)
-	if !bytes.Equal(m.XXX_unrecognized, unknown) {
-		t.Errorf("unknown bytes mismatch:\ngot  %x\nwant %x", m.XXX_unrecognized, unknown)
-	}
-	protoV1a.DiscardUnknown(m)
-
-	want := &MyMessage{Count: Int32(11), Name: String("aaa"), Pet: []string{"bbb", "ccc"}, Bigfloat: Float64(88)}
-	if !Equal(m, want) {
-		t.Errorf("message mismatch:\ngot  %v\nwant %v", m, want)
-	}
-}
-
 func encodeDecode(t *testing.T, in, out Message, msg string) {
 	buf, err := Marshal(in)
 	if err != nil {
@@ -1394,38 +1371,6 @@
 	}
 }
 
-func TestAllSetDefaults2(t *testing.T) {
-	// Exercise SetDefaults with all scalar field types.
-	m := &Defaults{
-		// NaN != NaN, so override that here.
-		F_Nan: Float32(1.7),
-	}
-	expected := &Defaults{
-		F_Bool:    Bool(true),
-		F_Int32:   Int32(32),
-		F_Int64:   Int64(64),
-		F_Fixed32: Uint32(320),
-		F_Fixed64: Uint64(640),
-		F_Uint32:  Uint32(3200),
-		F_Uint64:  Uint64(6400),
-		F_Float:   Float32(314159),
-		F_Double:  Float64(271828),
-		F_String:  String(`hello, "world!"` + "\n"),
-		F_Bytes:   []byte("Bignose"),
-		F_Sint32:  Int32(-32),
-		F_Sint64:  Int64(-64),
-		F_Enum:    Defaults_GREEN.Enum(),
-		F_Pinf:    Float32(float32(math.Inf(1))),
-		F_Ninf:    Float32(float32(math.Inf(-1))),
-		F_Nan:     Float32(1.7),
-		StrZero:   String(""),
-	}
-	protoV1a.SetDefaults(m)
-	if !Equal(m, expected) {
-		t.Errorf("SetDefaults failed\n got %v\nwant %v", m, expected)
-	}
-}
-
 func TestSetDefaultsWithSetField(t *testing.T) {
 	// Check that a set value is not overridden.
 	m := &Defaults{
@@ -1437,17 +1382,6 @@
 	}
 }
 
-func TestSetDefaultsWithSetField2(t *testing.T) {
-	// Check that a set value is not overridden.
-	m := &Defaults{
-		F_Int32: Int32(12),
-	}
-	protoV1a.SetDefaults(m)
-	if v := m.GetF_Int32(); v != 12 {
-		t.Errorf("m.FInt32 = %v, want 12", v)
-	}
-}
-
 func TestSetDefaultsWithSubMessage(t *testing.T) {
 	m := &OtherMessage{
 		Key: Int64(123),
@@ -1468,26 +1402,6 @@
 	}
 }
 
-func TestSetDefaultsWithSubMessage2(t *testing.T) {
-	m := &OtherMessage{
-		Key: Int64(123),
-		Inner: &InnerMessage{
-			Host: String("gopher"),
-		},
-	}
-	expected := &OtherMessage{
-		Key: Int64(123),
-		Inner: &InnerMessage{
-			Host: String("gopher"),
-			Port: Int32(4000),
-		},
-	}
-	protoV1a.SetDefaults(m)
-	if !Equal(m, expected) {
-		t.Errorf("\n got %v\nwant %v", m, expected)
-	}
-}
-
 func TestSetDefaultsWithRepeatedSubMessage(t *testing.T) {
 	m := &MyMessage{
 		RepInner: []*InnerMessage{{}},
@@ -1503,21 +1417,6 @@
 	}
 }
 
-func TestSetDefaultsWithRepeatedSubMessage2(t *testing.T) {
-	m := &MyMessage{
-		RepInner: []*InnerMessage{{}},
-	}
-	expected := &MyMessage{
-		RepInner: []*InnerMessage{{
-			Port: Int32(4000),
-		}},
-	}
-	protoV1a.SetDefaults(m)
-	if !Equal(m, expected) {
-		t.Errorf("\n got %v\nwant %v", m, expected)
-	}
-}
-
 func TestSetDefaultWithRepeatedNonMessage(t *testing.T) {
 	m := &MyMessage{
 		Pet: []string{"turtle", "wombat"},
@@ -1529,17 +1428,6 @@
 	}
 }
 
-func TestSetDefaultWithRepeatedNonMessage2(t *testing.T) {
-	m := &MyMessage{
-		Pet: []string{"turtle", "wombat"},
-	}
-	expected := Clone(m)
-	protoV1a.SetDefaults(m)
-	if !Equal(m, expected) {
-		t.Errorf("\n got %v\nwant %v", m, expected)
-	}
-}
-
 func TestMaximumTagNumber(t *testing.T) {
 	m := &MaxTag{
 		LastField: String("natural goat essence"),
diff --git a/proto/discard.go b/proto/discard.go
index c850e09..fe5a140 100644
--- a/proto/discard.go
+++ b/proto/discard.go
@@ -18,6 +18,8 @@
 	XXX_DiscardUnknown()
 }
 
+var discardUnknownAlt func(Message) // populated by hooks.go
+
 // DiscardUnknown recursively discards all unknown fields from this message
 // and all embedded messages.
 //
@@ -30,6 +32,11 @@
 // For proto2 messages, the unknown fields of message extensions are only
 // discarded from messages that have been accessed via GetExtension.
 func DiscardUnknown(m Message) {
+	if discardUnknownAlt != nil {
+		discardUnknownAlt(m)
+		return
+	}
+
 	if m, ok := m.(generatedDiscarder); ok {
 		m.XXX_DiscardUnknown()
 		return
diff --git a/proto/discard_test.go b/proto/discard_test.go
index 1f4b6eb..1c9dab9 100644
--- a/proto/discard_test.go
+++ b/proto/discard_test.go
@@ -7,7 +7,6 @@
 import (
 	"testing"
 
-	protoV1a "github.com/golang/protobuf/internal/proto"
 	"github.com/golang/protobuf/proto"
 
 	proto3pb "github.com/golang/protobuf/proto/proto3_proto"
@@ -104,39 +103,6 @@
 		}(),
 	}}
 
-	// Test the reflection code path.
-	for _, tt := range tests {
-		// Clone the input so that we don't alter the original.
-		in := tt.in
-		if in != nil {
-			in = proto.Clone(tt.in)
-		}
-
-		protoV1a.DiscardUnknown(tt.in)
-		if !proto.Equal(tt.in, tt.want) {
-			t.Errorf("test %s, expected unknown fields to be discarded\ngot  %v\nwant %v", tt.desc, tt.in, tt.want)
-		}
-	}
-
-	// Test the legacy code path.
-	for _, tt := range tests {
-		// Clone the input so that we don't alter the original.
-		in := tt.in
-		if in != nil {
-			in = proto.Clone(tt.in)
-		}
-
-		var m LegacyMessage
-		m.Message, _ = in.(*proto3pb.Message)
-		m.Communique, _ = in.(*pb.Communique)
-		m.MessageWithMap, _ = in.(*pb.MessageWithMap)
-		m.MyMessage, _ = in.(*pb.MyMessage)
-		proto.DiscardUnknown(&m)
-		if !proto.Equal(in, tt.want) {
-			t.Errorf("test %s/Legacy, expected unknown fields to be discarded\ngot  %v\nwant %v", tt.desc, in, tt.want)
-		}
-	}
-
 	for _, tt := range tests {
 		proto.DiscardUnknown(tt.in)
 		if !proto.Equal(tt.in, tt.want) {
@@ -144,17 +110,3 @@
 		}
 	}
 }
-
-// LegacyMessage is a proto.Message that has several nested messages.
-// This does not have the XXX_DiscardUnknown method and so forces DiscardUnknown
-// to use the legacy fallback logic.
-type LegacyMessage struct {
-	Message        *proto3pb.Message
-	Communique     *pb.Communique
-	MessageWithMap *pb.MessageWithMap
-	MyMessage      *pb.MyMessage
-}
-
-func (m *LegacyMessage) Reset()         { *m = LegacyMessage{} }
-func (m *LegacyMessage) String() string { return proto.CompactTextString(m) }
-func (*LegacyMessage) ProtoMessage()    {}
diff --git a/proto/hooks.go b/proto/hooks.go
new file mode 100644
index 0000000..cd4da7d
--- /dev/null
+++ b/proto/hooks.go
@@ -0,0 +1,14 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build proto_reimpl
+
+package proto
+
+import "github.com/golang/protobuf/internal/proto"
+
+func init() {
+	setDefaultsAlt = proto.SetDefaults
+	discardUnknownAlt = proto.DiscardUnknown
+}
diff --git a/proto/lib.go b/proto/lib.go
index 3106066..a35f5db 100644
--- a/proto/lib.go
+++ b/proto/lib.go
@@ -309,10 +309,16 @@
 	p.index = index
 }
 
+var setDefaultsAlt func(Message) // populated by hooks.go
+
 // SetDefaults sets unset protocol buffer fields to their default values.
 // It only modifies fields that are both unset and have defined defaults.
 // It recursively sets default values in any non-nil sub-messages.
 func SetDefaults(pb Message) {
+	if setDefaultsAlt != nil {
+		setDefaultsAlt(pb)
+		return
+	}
 	setDefaults(reflect.ValueOf(pb), true, false)
 }
 
diff --git a/proto/proto3_test.go b/proto/proto3_test.go
index a09f0fb..943ca20 100644
--- a/proto/proto3_test.go
+++ b/proto/proto3_test.go
@@ -8,7 +8,6 @@
 	"bytes"
 	"testing"
 
-	protoV1a "github.com/golang/protobuf/internal/proto"
 	"github.com/golang/protobuf/proto"
 	pb "github.com/golang/protobuf/proto/proto3_proto"
 	tpb "github.com/golang/protobuf/proto/test_proto"
@@ -109,37 +108,6 @@
 	}
 }
 
-func TestProto3SetDefaults2(t *testing.T) {
-	in := &pb.Message{
-		Terrain: map[string]*pb.Nested{
-			"meadow": new(pb.Nested),
-		},
-		Proto2Field: new(tpb.SubDefaults),
-		Proto2Value: map[string]*tpb.SubDefaults{
-			"badlands": new(tpb.SubDefaults),
-		},
-	}
-
-	got := proto.Clone(in).(*pb.Message)
-	protoV1a.SetDefaults(got)
-
-	// There are no defaults in proto3.  Everything should be the zero value, but
-	// we need to remember to set defaults for nested proto2 messages.
-	want := &pb.Message{
-		Terrain: map[string]*pb.Nested{
-			"meadow": new(pb.Nested),
-		},
-		Proto2Field: &tpb.SubDefaults{N: proto.Int64(7)},
-		Proto2Value: map[string]*tpb.SubDefaults{
-			"badlands": &tpb.SubDefaults{N: proto.Int64(7)},
-		},
-	}
-
-	if !proto.Equal(got, want) {
-		t.Errorf("with in = %v\nproto.SetDefaults(in) =>\ngot %v\nwant %v", in, got, want)
-	}
-}
-
 func TestUnknownFieldPreservation(t *testing.T) {
 	b1 := "\x0a\x05David"      // Known tag 1
 	b2 := "\xc2\x0c\x06Google" // Unknown tag 200