unflake TestReadWriteTransaction_ErrorReturned

The mehod shouldHaveReceived was called twice
inside the test case, but calling this method
drains the requests from the mock server, which
will cause the second call to always return an
empty range of requests.

Fixes #1409

Change-Id: I25e07d787f206b07554de36e045c033c128e75ca
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/40410
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jean de Klerk <deklerk@google.com>
diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go
index 90bc547..f15f83d 100644
--- a/spanner/transaction_test.go
+++ b/spanner/transaction_test.go
@@ -262,20 +262,22 @@
 	if got != want {
 		t.Fatalf("got %+v, want %+v", got, want)
 	}
-	if _, err := shouldHaveReceived(mock, []interface{}{
+	requests := drainRequests(mock)
+	if err := compareRequests([]interface{}{
 		&sppb.CreateSessionRequest{},
 		&sppb.BeginTransactionRequest{},
-		&sppb.RollbackRequest{},
-	}); err != nil {
+		&sppb.RollbackRequest{}}, requests); err != nil {
+		// TODO: remove this once the session pool maintainer has been changed
+		// so that is doesn't delete sessions already during the first
+		// maintenance window.
 		// If we failed to get 3, it might have because - due to timing - we got
 		// a fourth request. If this request is DeleteSession, that's OK and
 		// expected.
-		if _, err := shouldHaveReceived(mock, []interface{}{
+		if err := compareRequests([]interface{}{
 			&sppb.CreateSessionRequest{},
 			&sppb.BeginTransactionRequest{},
 			&sppb.RollbackRequest{},
-			&sppb.DeleteSessionRequest{},
-		}); err != nil {
+			&sppb.DeleteSessionRequest{}}, requests); err != nil {
 			t.Fatal(err)
 		}
 	}
@@ -338,7 +340,11 @@
 // ReceivedRequests channel.
 func shouldHaveReceived(mock *testutil.FuncMock, want []interface{}) ([]interface{}, error) {
 	got := drainRequests(mock)
+	return got, compareRequests(want, got)
+}
 
+// Compares expected requests (want) with actual requests (got).
+func compareRequests(want []interface{}, got []interface{}) error {
 	if len(got) != len(want) {
 		var gotMsg string
 		for _, r := range got {
@@ -350,16 +356,15 @@
 			wantMsg += fmt.Sprintf("%v: %+v]\n", reflect.TypeOf(r), r)
 		}
 
-		return got, fmt.Errorf("got %d requests, want %d requests:\ngot:\n%s\nwant:\n%s", len(got), len(want), gotMsg, wantMsg)
+		return fmt.Errorf("got %d requests, want %d requests:\ngot:\n%s\nwant:\n%s", len(got), len(want), gotMsg, wantMsg)
 	}
 
 	for i, want := range want {
 		if reflect.TypeOf(got[i]) != reflect.TypeOf(want) {
-			return got, fmt.Errorf("request %d: got %+v, want %+v", i, reflect.TypeOf(got[i]), reflect.TypeOf(want))
+			return fmt.Errorf("request %d: got %+v, want %+v", i, reflect.TypeOf(got[i]), reflect.TypeOf(want))
 		}
 	}
-
-	return got, nil
+	return nil
 }
 
 func drainRequests(mock *testutil.FuncMock) []interface{} {