spanner: add BatchDML method

Change-Id: I5c1defc759eabf4969529a83f6ed6e1e3f55d5f4
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/38751
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Chris Broadfoot <cbro@google.com>
diff --git a/internal/kokoro/vet.sh b/internal/kokoro/vet.sh
index da52d86..cd645c7 100755
--- a/internal/kokoro/vet.sh
+++ b/internal/kokoro/vet.sh
@@ -63,6 +63,7 @@
     grep -v "internal/backoff" | \
     grep -v "internal/trace" | \
     grep -v "a blank import should be only in a main or test package" | \
+    grep -v "method ExecuteSql should be ExecuteSQL" | \
     grep -vE "\.pb\.go:" || true) | tee /dev/stderr | (! read)
 
 # TODO(deklerk) It doesn't seem like it, but is it possible to glob both before
diff --git a/spanner/integration_test.go b/spanner/integration_test.go
index 76f8a94..b4df029 100644
--- a/spanner/integration_test.go
+++ b/spanner/integration_test.go
@@ -38,6 +38,7 @@
 	"google.golang.org/api/option"
 	adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
 	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
 )
 
 var (
@@ -2069,6 +2070,182 @@
 	}
 }
 
+func TestBatchDML(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+	defer cancel()
+	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	defer cleanup()
+
+	columns := []string{"SingerId", "FirstName", "LastName"}
+
+	// Populate the Singers table.
+	var muts []*Mutation
+	for _, row := range [][]interface{}{
+		{1, "Umm", "Kulthum"},
+		{2, "Eduard", "Khil"},
+		{3, "Audra", "McDonald"},
+	} {
+		muts = append(muts, Insert("Singers", columns, row))
+	}
+	if _, err := client.Apply(ctx, muts); err != nil {
+		t.Fatal(err)
+	}
+
+	var counts []int64
+	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) {
+		counts, err = tx.BatchUpdate(ctx, []Statement{
+			{SQL: `UPDATE Singers SET Singers.FirstName = "changed 1" WHERE Singers.SingerId = 1`},
+			{SQL: `UPDATE Singers SET Singers.FirstName = "changed 2" WHERE Singers.SingerId = 2`},
+			{SQL: `UPDATE Singers SET Singers.FirstName = "changed 3" WHERE Singers.SingerId = 3`},
+		})
+		return err
+	})
+
+	if err != nil {
+		t.Fatal(err)
+	}
+	if want := []int64{1, 1, 1}; !testEqual(counts, want) {
+		t.Fatalf("got %d, want %d", counts, want)
+	}
+	got, err := readAll(client.Single().Read(ctx, "Singers", AllKeys(), columns))
+	if err != nil {
+		t.Fatal(err)
+	}
+	want := [][]interface{}{
+		{int64(1), "changed 1", "Kulthum"},
+		{int64(2), "changed 2", "Khil"},
+		{int64(3), "changed 3", "McDonald"},
+	}
+	if !testEqual(got, want) {
+		t.Errorf("\ngot %v\nwant%v", got, want)
+	}
+}
+
+func TestBatchDML_NoStatements(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+	defer cancel()
+	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	defer cleanup()
+
+	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) {
+		_, err = tx.BatchUpdate(ctx, []Statement{})
+		return err
+	})
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	}
+	if s, ok := status.FromError(err); ok {
+		if s.Code() != codes.InvalidArgument {
+			t.Fatalf("expected InvalidArgument, got %v", err)
+		}
+	} else {
+		t.Fatalf("expected InvalidArgument, got %v", err)
+	}
+}
+
+func TestBatchDML_TwoStatements(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+	defer cancel()
+	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	defer cleanup()
+
+	columns := []string{"SingerId", "FirstName", "LastName"}
+
+	// Populate the Singers table.
+	var muts []*Mutation
+	for _, row := range [][]interface{}{
+		{1, "Umm", "Kulthum"},
+		{2, "Eduard", "Khil"},
+		{3, "Audra", "McDonald"},
+	} {
+		muts = append(muts, Insert("Singers", columns, row))
+	}
+	if _, err := client.Apply(ctx, muts); err != nil {
+		t.Fatal(err)
+	}
+
+	var updateCount int64
+	var batchCounts []int64
+	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) {
+		batchCounts, err = tx.BatchUpdate(ctx, []Statement{
+			{SQL: `UPDATE Singers SET Singers.FirstName = "changed 1" WHERE Singers.SingerId = 1`},
+			{SQL: `UPDATE Singers SET Singers.FirstName = "changed 2" WHERE Singers.SingerId = 2`},
+			{SQL: `UPDATE Singers SET Singers.FirstName = "changed 3" WHERE Singers.SingerId = 3`},
+		})
+		if err != nil {
+			return err
+		}
+
+		updateCount, err = tx.Update(ctx, Statement{SQL: `UPDATE Singers SET Singers.FirstName = "changed 1" WHERE Singers.SingerId = 1`})
+		return err
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if want := []int64{1, 1, 1}; !testEqual(batchCounts, want) {
+		t.Fatalf("got %d, want %d", batchCounts, want)
+	}
+	if updateCount != 1 {
+		t.Fatalf("got %v, want 1", updateCount)
+	}
+}
+
+// TODO(deklerk) this currently does not work because the transaction appears to
+// get rolled back after a single statement fails. b/120158761
+func TestBatchDML_Error(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+	defer cancel()
+	client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+	defer cleanup()
+
+	columns := []string{"SingerId", "FirstName", "LastName"}
+
+	// Populate the Singers table.
+	var muts []*Mutation
+	for _, row := range [][]interface{}{
+		{1, "Umm", "Kulthum"},
+		{2, "Eduard", "Khil"},
+		{3, "Audra", "McDonald"},
+	} {
+		muts = append(muts, Insert("Singers", columns, row))
+	}
+	if _, err := client.Apply(ctx, muts); err != nil {
+		t.Fatal(err)
+	}
+
+	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) {
+		counts, err := tx.BatchUpdate(ctx, []Statement{
+			{SQL: `UPDATE Singers SET Singers.FirstName = "changed 1" WHERE Singers.SingerId = 1`},
+			{SQL: `some illegal statement`},
+			{SQL: `UPDATE Singers SET Singers.FirstName = "changed 3" WHERE Singers.SingerId = 3`},
+		})
+		if err == nil {
+			t.Fatal("expected err, got nil")
+		}
+		if want := []int64{1}; !testEqual(counts, want) {
+			t.Fatalf("got %d, want %d", counts, want)
+		}
+
+		got, err := readAll(tx.Read(ctx, "Singers", AllKeys(), columns))
+		if err != nil {
+			t.Fatal(err)
+		}
+		want := [][]interface{}{
+			{int64(1), "changed 1", "Kulthum"},
+			{int64(2), "Eduard", "Khil"},
+			{int64(3), "Audra", "McDonald"},
+		}
+		if !testEqual(got, want) {
+			t.Errorf("\ngot %v\nwant%v", got, want)
+		}
+
+		return nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
 // Prepare initializes Cloud Spanner testing DB and clients.
 func prepareIntegrationTest(ctx context.Context, t *testing.T, statements []string) (*Client, string, func()) {
 	if admin == nil {
diff --git a/spanner/internal/testutil/funcmock.go b/spanner/internal/testutil/funcmock.go
index d669851..2536afc 100644
--- a/spanner/internal/testutil/funcmock.go
+++ b/spanner/internal/testutil/funcmock.go
@@ -29,32 +29,32 @@
 // Note: if you overload a method, you're in charge of making sure
 // MockCloudSpannerClient.ReceivedRequests receives the request appropriately.
 type FuncMock struct {
-	CommitFn           func(c context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error)
-	BeginTransactionFn func(c context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error)
-	GetSessionFn       func(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error)
-	CreateSessionFn    func(c context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error)
+	CommitFn           func(ctx context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error)
+	BeginTransactionFn func(ctx context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error)
+	GetSessionFn       func(ctx context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error)
+	CreateSessionFn    func(ctx context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error)
 	*MockCloudSpannerClient
 }
 
-func (s FuncMock) Commit(c context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) {
+func (s FuncMock) Commit(ctx context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) {
 	if s.CommitFn == nil {
-		return s.MockCloudSpannerClient.Commit(c, r, opts...)
+		return s.MockCloudSpannerClient.Commit(ctx, r, opts...)
 	}
-	return s.CommitFn(c, r, opts...)
+	return s.CommitFn(ctx, r, opts...)
 }
 
-func (s FuncMock) BeginTransaction(c context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) {
+func (s FuncMock) BeginTransaction(ctx context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) {
 	if s.BeginTransactionFn == nil {
-		return s.MockCloudSpannerClient.BeginTransaction(c, r, opts...)
+		return s.MockCloudSpannerClient.BeginTransaction(ctx, r, opts...)
 	}
-	return s.BeginTransactionFn(c, r, opts...)
+	return s.BeginTransactionFn(ctx, r, opts...)
 }
 
-func (s *FuncMock) GetSession(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
+func (s *FuncMock) GetSession(ctx context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
 	if s.GetSessionFn == nil {
-		return s.MockCloudSpannerClient.GetSession(c, r, opts...)
+		return s.MockCloudSpannerClient.GetSession(ctx, r, opts...)
 	}
-	return s.GetSessionFn(c, r, opts...)
+	return s.GetSessionFn(ctx, r, opts...)
 }
 
 func (s *FuncMock) CreateSession(c context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
diff --git a/spanner/internal/testutil/mockclient.go b/spanner/internal/testutil/mockclient.go
index f808458..2a3b918 100644
--- a/spanner/internal/testutil/mockclient.go
+++ b/spanner/internal/testutil/mockclient.go
@@ -28,6 +28,7 @@
 	"github.com/golang/protobuf/ptypes/empty"
 	proto3 "github.com/golang/protobuf/ptypes/struct"
 	pbt "github.com/golang/protobuf/ptypes/timestamp"
+	pbs "google.golang.org/genproto/googleapis/rpc/status"
 	sppb "google.golang.org/genproto/googleapis/spanner/v1"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
@@ -87,7 +88,7 @@
 }
 
 // CreateSession is a placeholder for SpannerClient.CreateSession.
-func (m *MockCloudSpannerClient) CreateSession(c context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
+func (m *MockCloudSpannerClient) CreateSession(ctx context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
 	m.ready()
 	m.ReceivedRequests <- r
 
@@ -105,7 +106,7 @@
 }
 
 // GetSession is a placeholder for SpannerClient.GetSession.
-func (m *MockCloudSpannerClient) GetSession(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
+func (m *MockCloudSpannerClient) GetSession(ctx context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
 	m.ready()
 	m.ReceivedRequests <- r
 
@@ -119,7 +120,7 @@
 }
 
 // DeleteSession is a placeholder for SpannerClient.DeleteSession.
-func (m *MockCloudSpannerClient) DeleteSession(c context.Context, r *sppb.DeleteSessionRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
+func (m *MockCloudSpannerClient) DeleteSession(ctx context.Context, r *sppb.DeleteSessionRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
 	m.ready()
 	m.ReceivedRequests <- r
 
@@ -134,8 +135,28 @@
 	return &empty.Empty{}, nil
 }
 
+// ExecuteSql is a placeholder for SpannerClient.ExecuteSql.
+func (m *MockCloudSpannerClient) ExecuteSql(ctx context.Context, r *sppb.ExecuteSqlRequest, opts ...grpc.CallOption) (*sppb.ResultSet, error) {
+	m.ready()
+	m.ReceivedRequests <- r
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	return &sppb.ResultSet{Stats: &sppb.ResultSetStats{RowCount: &sppb.ResultSetStats_RowCountExact{7}}}, nil
+}
+
+// ExecuteBatchDml is a placeholder for SpannerClient.ExecuteBatchDml.
+func (m *MockCloudSpannerClient) ExecuteBatchDml(ctx context.Context, r *sppb.ExecuteBatchDmlRequest, opts ...grpc.CallOption) (*sppb.ExecuteBatchDmlResponse, error) {
+	m.ready()
+	m.ReceivedRequests <- r
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	return &sppb.ExecuteBatchDmlResponse{Status: &pbs.Status{Code: 0}, ResultSets: []*sppb.ResultSet{}}, nil
+}
+
 // ExecuteStreamingSql is a mock implementation of SpannerClient.ExecuteStreamingSql.
-func (m *MockCloudSpannerClient) ExecuteStreamingSql(c context.Context, r *sppb.ExecuteSqlRequest, opts ...grpc.CallOption) (sppb.Spanner_ExecuteStreamingSqlClient, error) {
+func (m *MockCloudSpannerClient) ExecuteStreamingSql(ctx context.Context, r *sppb.ExecuteSqlRequest, opts ...grpc.CallOption) (sppb.Spanner_ExecuteStreamingSqlClient, error) {
 	m.ready()
 	m.ReceivedRequests <- r
 
@@ -170,7 +191,7 @@
 }
 
 // StreamingRead is a placeholder for SpannerClient.StreamingRead.
-func (m *MockCloudSpannerClient) StreamingRead(c context.Context, r *sppb.ReadRequest, opts ...grpc.CallOption) (sppb.Spanner_StreamingReadClient, error) {
+func (m *MockCloudSpannerClient) StreamingRead(ctx context.Context, r *sppb.ReadRequest, opts ...grpc.CallOption) (sppb.Spanner_StreamingReadClient, error) {
 	m.ready()
 	m.ReceivedRequests <- r
 
@@ -213,7 +234,7 @@
 }
 
 // BeginTransaction is a placeholder for SpannerClient.BeginTransaction.
-func (m *MockCloudSpannerClient) BeginTransaction(c context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) {
+func (m *MockCloudSpannerClient) BeginTransaction(ctx context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) {
 	m.ready()
 	m.ReceivedRequests <- r
 
@@ -227,7 +248,7 @@
 }
 
 // Commit is a placeholder for SpannerClient.Commit.
-func (m *MockCloudSpannerClient) Commit(c context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) {
+func (m *MockCloudSpannerClient) Commit(ctx context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) {
 	m.ready()
 	m.ReceivedRequests <- r
 
@@ -237,7 +258,7 @@
 }
 
 // Rollback is a placeholder for SpannerClient.Rollback.
-func (m *MockCloudSpannerClient) Rollback(c context.Context, r *sppb.RollbackRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
+func (m *MockCloudSpannerClient) Rollback(ctx context.Context, r *sppb.RollbackRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
 	m.ready()
 	m.ReceivedRequests <- r
 
diff --git a/spanner/transaction.go b/spanner/transaction.go
index d7f9ced..be75c30 100644
--- a/spanner/transaction.go
+++ b/spanner/transaction.go
@@ -205,8 +205,7 @@
 		t.release)
 }
 
-func (t *txReadOnly) prepareExecuteSQL(ctx context.Context, stmt Statement, mode sppb.ExecuteSqlRequest_QueryMode) (
-	*sppb.ExecuteSqlRequest, *sessionHandle, error) {
+func (t *txReadOnly) prepareExecuteSQL(ctx context.Context, stmt Statement, mode sppb.ExecuteSqlRequest_QueryMode) (*sppb.ExecuteSqlRequest, *sessionHandle, error) {
 	sh, ts, err := t.acquire(ctx)
 	if err != nil {
 		return nil, nil, err
@@ -558,9 +557,9 @@
 // ReadWriteTransaction provides a locking read-write transaction.
 //
 // This type of transaction is the only way to write data into Cloud Spanner;
-// (*Client).Apply and (*Client).ApplyAtLeastOnce use transactions
-// internally. These transactions rely on pessimistic locking and, if
-// necessary, two-phase commit. Locking read-write transactions may abort,
+// (*Client).Apply, (*Client).ApplyAtLeastOnce, (*Client).PartitionedUpdate use
+// transactions internally. These transactions rely on pessimistic locking and,
+// if necessary, two-phase commit. Locking read-write transactions may abort,
 // requiring the application to retry. However, the interface exposed by
 // (*Client).ReadWriteTransaction eliminates the need for applications to write
 // retry loops explicitly.
@@ -676,6 +675,64 @@
 	return extractRowCount(resultSet.Stats)
 }
 
+// BatchUpdate groups one or more DML statements and sends them to Spanner in a
+// single RPC. This is an efficient way to execute multiple DML statements.
+//
+// A slice of counts is returned, where each count represents the number of
+// affected rows for the given query at the same index. If an error occurs,
+// counts will be returned up to the query that encountered the error.
+func (t *ReadWriteTransaction) BatchUpdate(ctx context.Context, stmts []Statement) (_ []int64, err error) {
+	ctx = startSpan(ctx, "cloud.google.com/go/spanner.BatchUpdate")
+	defer func() { endSpan(ctx, err) }()
+
+	sh, ts, err := t.acquire(ctx)
+	if err != nil {
+		return nil, err
+	}
+	// Cloud Spanner will return "Session not found" on bad sessions.
+	sid := sh.getID()
+	if sid == "" {
+		// Might happen if transaction is closed in the middle of a API call.
+		return nil, errSessionClosed(sh)
+	}
+
+	var sppbStmts []*sppb.ExecuteBatchDmlRequest_Statement
+	for _, st := range stmts {
+		params, paramTypes, err := st.convertParams()
+		if err != nil {
+			return nil, err
+		}
+		sppbStmts = append(sppbStmts, &sppb.ExecuteBatchDmlRequest_Statement{
+			Sql:        st.SQL,
+			Params:     params,
+			ParamTypes: paramTypes,
+		})
+	}
+
+	resp, err := sh.getClient().ExecuteBatchDml(ctx, &sppb.ExecuteBatchDmlRequest{
+		Session:     sh.getID(),
+		Transaction: ts,
+		Statements:  sppbStmts,
+		Seqno:       atomic.AddInt64(&t.sequenceNumber, 1),
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	var counts []int64
+	for _, rs := range resp.ResultSets {
+		count, err := extractRowCount(rs.Stats)
+		if err != nil {
+			return nil, err
+		}
+		counts = append(counts, count)
+	}
+	if resp.Status.Code != 0 {
+		return counts, spannerErrorf(codes.Code(uint32(resp.Status.Code)), resp.Status.Message)
+	}
+	return counts, nil
+}
+
 // acquire implements txReadEnv.acquire.
 func (t *ReadWriteTransaction) acquire(ctx context.Context) (*sessionHandle, *sppb.TransactionSelector, error) {
 	ts := &sppb.TransactionSelector{
diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go
index a710ecf..c4e1b2d 100644
--- a/spanner/transaction_test.go
+++ b/spanner/transaction_test.go
@@ -50,7 +50,7 @@
 	}
 
 	// Only one CreateSessionRequest is sent.
-	if err := shouldHaveReceived(mock, []interface{}{&sppb.CreateSessionRequest{}}); err != nil {
+	if _, err := shouldHaveReceived(mock, []interface{}{&sppb.CreateSessionRequest{}}); err != nil {
 		t.Fatal(err)
 	}
 }
@@ -159,7 +159,7 @@
 		t.Fatalf("applyAtLeastOnce retry on abort, got %v, want nil.", e)
 	}
 
-	if err := shouldHaveReceived(mock, []interface{}{
+	if _, err := shouldHaveReceived(mock, []interface{}{
 		&sppb.CreateSessionRequest{},
 		&sppb.CommitRequest{},
 	}); err != nil {
@@ -194,7 +194,7 @@
 		t.Fatalf("ReadWriteTransaction retry on abort, got %v, want nil.", e)
 	}
 
-	if err := shouldHaveReceived(mock, []interface{}{
+	if _, err := shouldHaveReceived(mock, []interface{}{
 		&sppb.CreateSessionRequest{},
 		&sppb.BeginTransactionRequest{},
 		&sppb.CommitRequest{}, // First commit fails.
@@ -262,7 +262,7 @@
 	if got != want {
 		t.Fatalf("got %+v, want %+v", got, want)
 	}
-	if err := shouldHaveReceived(mock, []interface{}{
+	if _, err := shouldHaveReceived(mock, []interface{}{
 		&sppb.CreateSessionRequest{},
 		&sppb.BeginTransactionRequest{},
 		&sppb.RollbackRequest{},
@@ -271,12 +271,62 @@
 	}
 }
 
+func TestBatchDML_WithMultipleDML(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{})
+	defer cleanup()
+
+	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) {
+		if _, err = tx.Update(ctx, Statement{SQL: "SELECT * FROM whatever"}); err != nil {
+			return err
+		}
+		if _, err = tx.BatchUpdate(ctx, []Statement{{SQL: "SELECT * FROM whatever"}, {SQL: "SELECT * FROM whatever"}}); err != nil {
+			return err
+		}
+		if _, err = tx.Update(ctx, Statement{SQL: "SELECT * FROM whatever"}); err != nil {
+			return err
+		}
+		_, err = tx.BatchUpdate(ctx, []Statement{{SQL: "SELECT * FROM whatever"}})
+		return err
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	gotReqs, err := shouldHaveReceived(mock, []interface{}{
+		&sppb.CreateSessionRequest{},
+		&sppb.BeginTransactionRequest{},
+		&sppb.ExecuteSqlRequest{},
+		&sppb.ExecuteBatchDmlRequest{},
+		&sppb.ExecuteSqlRequest{},
+		&sppb.ExecuteBatchDmlRequest{},
+		&sppb.CommitRequest{},
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if got, want := gotReqs[2].(*sppb.ExecuteSqlRequest).Seqno, int64(1); got != want {
+		t.Errorf("got %d, want %d", got, want)
+	}
+	if got, want := gotReqs[3].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(2); got != want {
+		t.Errorf("got %d, want %d", got, want)
+	}
+	if got, want := gotReqs[4].(*sppb.ExecuteSqlRequest).Seqno, int64(3); got != want {
+		t.Errorf("got %d, want %d", got, want)
+	}
+	if got, want := gotReqs[5].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(4); got != want {
+		t.Errorf("got %d, want %d", got, want)
+	}
+}
+
 // shouldHaveReceived asserts that exactly expectedRequests were present in
 // the server's ReceivedRequests channel. It only looks at type, not contents.
 //
 // Note: this in-place modifies serverClientMock by popping items off the
 // ReceivedRequests channel.
-func shouldHaveReceived(mock *testutil.FuncMock, want []interface{}) error {
+func shouldHaveReceived(mock *testutil.FuncMock, want []interface{}) ([]interface{}, error) {
 	got := drainRequests(mock)
 
 	if len(got) != len(want) {
@@ -290,16 +340,16 @@
 			wantMsg += fmt.Sprintf("%v: %+v]\n", reflect.TypeOf(r), r)
 		}
 
-		return fmt.Errorf("got %d requests, want %d requests:\ngot:\n%s\nwant:\n%s", len(got), len(want), gotMsg, wantMsg)
+		return got, fmt.Errorf("got %d requests, want %d requests:\ngot:\n%s\nwant:\n%s", len(got), len(want), gotMsg, wantMsg)
 	}
 
 	for i, want := range want {
 		if reflect.TypeOf(got[i]) != reflect.TypeOf(want) {
-			return fmt.Errorf("request %d: got %+v, want %+v", i, reflect.TypeOf(got[i]), reflect.TypeOf(want))
+			return got, fmt.Errorf("request %d: got %+v, want %+v", i, reflect.TypeOf(got[i]), reflect.TypeOf(want))
 		}
 	}
 
-	return nil
+	return got, nil
 }
 
 func drainRequests(mock *testutil.FuncMock) []interface{} {