spanner: add BatchDML method
Change-Id: I5c1defc759eabf4969529a83f6ed6e1e3f55d5f4
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/38751
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Chris Broadfoot <cbro@google.com>
diff --git a/internal/kokoro/vet.sh b/internal/kokoro/vet.sh
index da52d86..cd645c7 100755
--- a/internal/kokoro/vet.sh
+++ b/internal/kokoro/vet.sh
@@ -63,6 +63,7 @@
grep -v "internal/backoff" | \
grep -v "internal/trace" | \
grep -v "a blank import should be only in a main or test package" | \
+ grep -v "method ExecuteSql should be ExecuteSQL" | \
grep -vE "\.pb\.go:" || true) | tee /dev/stderr | (! read)
# TODO(deklerk) It doesn't seem like it, but is it possible to glob both before
diff --git a/spanner/integration_test.go b/spanner/integration_test.go
index 76f8a94..b4df029 100644
--- a/spanner/integration_test.go
+++ b/spanner/integration_test.go
@@ -38,6 +38,7 @@
"google.golang.org/api/option"
adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
"google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
)
var (
@@ -2069,6 +2070,182 @@
}
}
+func TestBatchDML(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+ defer cancel()
+ client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+ defer cleanup()
+
+ columns := []string{"SingerId", "FirstName", "LastName"}
+
+ // Populate the Singers table.
+ var muts []*Mutation
+ for _, row := range [][]interface{}{
+ {1, "Umm", "Kulthum"},
+ {2, "Eduard", "Khil"},
+ {3, "Audra", "McDonald"},
+ } {
+ muts = append(muts, Insert("Singers", columns, row))
+ }
+ if _, err := client.Apply(ctx, muts); err != nil {
+ t.Fatal(err)
+ }
+
+ var counts []int64
+ _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) {
+ counts, err = tx.BatchUpdate(ctx, []Statement{
+ {SQL: `UPDATE Singers SET Singers.FirstName = "changed 1" WHERE Singers.SingerId = 1`},
+ {SQL: `UPDATE Singers SET Singers.FirstName = "changed 2" WHERE Singers.SingerId = 2`},
+ {SQL: `UPDATE Singers SET Singers.FirstName = "changed 3" WHERE Singers.SingerId = 3`},
+ })
+ return err
+ })
+
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := []int64{1, 1, 1}; !testEqual(counts, want) {
+ t.Fatalf("got %d, want %d", counts, want)
+ }
+ got, err := readAll(client.Single().Read(ctx, "Singers", AllKeys(), columns))
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := [][]interface{}{
+ {int64(1), "changed 1", "Kulthum"},
+ {int64(2), "changed 2", "Khil"},
+ {int64(3), "changed 3", "McDonald"},
+ }
+ if !testEqual(got, want) {
+ t.Errorf("\ngot %v\nwant%v", got, want)
+ }
+}
+
+func TestBatchDML_NoStatements(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+ defer cancel()
+ client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+ defer cleanup()
+
+ _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) {
+ _, err = tx.BatchUpdate(ctx, []Statement{})
+ return err
+ })
+ if err == nil {
+ t.Fatal("expected error, got nil")
+ }
+ if s, ok := status.FromError(err); ok {
+ if s.Code() != codes.InvalidArgument {
+ t.Fatalf("expected InvalidArgument, got %v", err)
+ }
+ } else {
+ t.Fatalf("expected InvalidArgument, got %v", err)
+ }
+}
+
+func TestBatchDML_TwoStatements(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+ defer cancel()
+ client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+ defer cleanup()
+
+ columns := []string{"SingerId", "FirstName", "LastName"}
+
+ // Populate the Singers table.
+ var muts []*Mutation
+ for _, row := range [][]interface{}{
+ {1, "Umm", "Kulthum"},
+ {2, "Eduard", "Khil"},
+ {3, "Audra", "McDonald"},
+ } {
+ muts = append(muts, Insert("Singers", columns, row))
+ }
+ if _, err := client.Apply(ctx, muts); err != nil {
+ t.Fatal(err)
+ }
+
+ var updateCount int64
+ var batchCounts []int64
+ _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) {
+ batchCounts, err = tx.BatchUpdate(ctx, []Statement{
+ {SQL: `UPDATE Singers SET Singers.FirstName = "changed 1" WHERE Singers.SingerId = 1`},
+ {SQL: `UPDATE Singers SET Singers.FirstName = "changed 2" WHERE Singers.SingerId = 2`},
+ {SQL: `UPDATE Singers SET Singers.FirstName = "changed 3" WHERE Singers.SingerId = 3`},
+ })
+ if err != nil {
+ return err
+ }
+
+ updateCount, err = tx.Update(ctx, Statement{SQL: `UPDATE Singers SET Singers.FirstName = "changed 1" WHERE Singers.SingerId = 1`})
+ return err
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := []int64{1, 1, 1}; !testEqual(batchCounts, want) {
+ t.Fatalf("got %d, want %d", batchCounts, want)
+ }
+ if updateCount != 1 {
+ t.Fatalf("got %v, want 1", updateCount)
+ }
+}
+
+// TODO(deklerk) this currently does not work because the transaction appears to
+// get rolled back after a single statement fails. b/120158761
+func TestBatchDML_Error(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+ defer cancel()
+ client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements)
+ defer cleanup()
+
+ columns := []string{"SingerId", "FirstName", "LastName"}
+
+ // Populate the Singers table.
+ var muts []*Mutation
+ for _, row := range [][]interface{}{
+ {1, "Umm", "Kulthum"},
+ {2, "Eduard", "Khil"},
+ {3, "Audra", "McDonald"},
+ } {
+ muts = append(muts, Insert("Singers", columns, row))
+ }
+ if _, err := client.Apply(ctx, muts); err != nil {
+ t.Fatal(err)
+ }
+
+ _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) {
+ counts, err := tx.BatchUpdate(ctx, []Statement{
+ {SQL: `UPDATE Singers SET Singers.FirstName = "changed 1" WHERE Singers.SingerId = 1`},
+ {SQL: `some illegal statement`},
+ {SQL: `UPDATE Singers SET Singers.FirstName = "changed 3" WHERE Singers.SingerId = 3`},
+ })
+ if err == nil {
+ t.Fatal("expected err, got nil")
+ }
+ if want := []int64{1}; !testEqual(counts, want) {
+ t.Fatalf("got %d, want %d", counts, want)
+ }
+
+ got, err := readAll(tx.Read(ctx, "Singers", AllKeys(), columns))
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := [][]interface{}{
+ {int64(1), "changed 1", "Kulthum"},
+ {int64(2), "Eduard", "Khil"},
+ {int64(3), "Audra", "McDonald"},
+ }
+ if !testEqual(got, want) {
+ t.Errorf("\ngot %v\nwant%v", got, want)
+ }
+
+ return nil
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
// Prepare initializes Cloud Spanner testing DB and clients.
func prepareIntegrationTest(ctx context.Context, t *testing.T, statements []string) (*Client, string, func()) {
if admin == nil {
diff --git a/spanner/internal/testutil/funcmock.go b/spanner/internal/testutil/funcmock.go
index d669851..2536afc 100644
--- a/spanner/internal/testutil/funcmock.go
+++ b/spanner/internal/testutil/funcmock.go
@@ -29,32 +29,32 @@
// Note: if you overload a method, you're in charge of making sure
// MockCloudSpannerClient.ReceivedRequests receives the request appropriately.
type FuncMock struct {
- CommitFn func(c context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error)
- BeginTransactionFn func(c context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error)
- GetSessionFn func(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error)
- CreateSessionFn func(c context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error)
+ CommitFn func(ctx context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error)
+ BeginTransactionFn func(ctx context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error)
+ GetSessionFn func(ctx context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error)
+ CreateSessionFn func(ctx context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error)
*MockCloudSpannerClient
}
-func (s FuncMock) Commit(c context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) {
+func (s FuncMock) Commit(ctx context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) {
if s.CommitFn == nil {
- return s.MockCloudSpannerClient.Commit(c, r, opts...)
+ return s.MockCloudSpannerClient.Commit(ctx, r, opts...)
}
- return s.CommitFn(c, r, opts...)
+ return s.CommitFn(ctx, r, opts...)
}
-func (s FuncMock) BeginTransaction(c context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) {
+func (s FuncMock) BeginTransaction(ctx context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) {
if s.BeginTransactionFn == nil {
- return s.MockCloudSpannerClient.BeginTransaction(c, r, opts...)
+ return s.MockCloudSpannerClient.BeginTransaction(ctx, r, opts...)
}
- return s.BeginTransactionFn(c, r, opts...)
+ return s.BeginTransactionFn(ctx, r, opts...)
}
-func (s *FuncMock) GetSession(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
+func (s *FuncMock) GetSession(ctx context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
if s.GetSessionFn == nil {
- return s.MockCloudSpannerClient.GetSession(c, r, opts...)
+ return s.MockCloudSpannerClient.GetSession(ctx, r, opts...)
}
- return s.GetSessionFn(c, r, opts...)
+ return s.GetSessionFn(ctx, r, opts...)
}
func (s *FuncMock) CreateSession(c context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
diff --git a/spanner/internal/testutil/mockclient.go b/spanner/internal/testutil/mockclient.go
index f808458..2a3b918 100644
--- a/spanner/internal/testutil/mockclient.go
+++ b/spanner/internal/testutil/mockclient.go
@@ -28,6 +28,7 @@
"github.com/golang/protobuf/ptypes/empty"
proto3 "github.com/golang/protobuf/ptypes/struct"
pbt "github.com/golang/protobuf/ptypes/timestamp"
+ pbs "google.golang.org/genproto/googleapis/rpc/status"
sppb "google.golang.org/genproto/googleapis/spanner/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@@ -87,7 +88,7 @@
}
// CreateSession is a placeholder for SpannerClient.CreateSession.
-func (m *MockCloudSpannerClient) CreateSession(c context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
+func (m *MockCloudSpannerClient) CreateSession(ctx context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
m.ready()
m.ReceivedRequests <- r
@@ -105,7 +106,7 @@
}
// GetSession is a placeholder for SpannerClient.GetSession.
-func (m *MockCloudSpannerClient) GetSession(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
+func (m *MockCloudSpannerClient) GetSession(ctx context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
m.ready()
m.ReceivedRequests <- r
@@ -119,7 +120,7 @@
}
// DeleteSession is a placeholder for SpannerClient.DeleteSession.
-func (m *MockCloudSpannerClient) DeleteSession(c context.Context, r *sppb.DeleteSessionRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
+func (m *MockCloudSpannerClient) DeleteSession(ctx context.Context, r *sppb.DeleteSessionRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
m.ready()
m.ReceivedRequests <- r
@@ -134,8 +135,28 @@
return &empty.Empty{}, nil
}
+// ExecuteSql is a placeholder for SpannerClient.ExecuteSql.
+func (m *MockCloudSpannerClient) ExecuteSql(ctx context.Context, r *sppb.ExecuteSqlRequest, opts ...grpc.CallOption) (*sppb.ResultSet, error) {
+ m.ready()
+ m.ReceivedRequests <- r
+
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return &sppb.ResultSet{Stats: &sppb.ResultSetStats{RowCount: &sppb.ResultSetStats_RowCountExact{7}}}, nil
+}
+
+// ExecuteBatchDml is a placeholder for SpannerClient.ExecuteBatchDml.
+func (m *MockCloudSpannerClient) ExecuteBatchDml(ctx context.Context, r *sppb.ExecuteBatchDmlRequest, opts ...grpc.CallOption) (*sppb.ExecuteBatchDmlResponse, error) {
+ m.ready()
+ m.ReceivedRequests <- r
+
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return &sppb.ExecuteBatchDmlResponse{Status: &pbs.Status{Code: 0}, ResultSets: []*sppb.ResultSet{}}, nil
+}
+
// ExecuteStreamingSql is a mock implementation of SpannerClient.ExecuteStreamingSql.
-func (m *MockCloudSpannerClient) ExecuteStreamingSql(c context.Context, r *sppb.ExecuteSqlRequest, opts ...grpc.CallOption) (sppb.Spanner_ExecuteStreamingSqlClient, error) {
+func (m *MockCloudSpannerClient) ExecuteStreamingSql(ctx context.Context, r *sppb.ExecuteSqlRequest, opts ...grpc.CallOption) (sppb.Spanner_ExecuteStreamingSqlClient, error) {
m.ready()
m.ReceivedRequests <- r
@@ -170,7 +191,7 @@
}
// StreamingRead is a placeholder for SpannerClient.StreamingRead.
-func (m *MockCloudSpannerClient) StreamingRead(c context.Context, r *sppb.ReadRequest, opts ...grpc.CallOption) (sppb.Spanner_StreamingReadClient, error) {
+func (m *MockCloudSpannerClient) StreamingRead(ctx context.Context, r *sppb.ReadRequest, opts ...grpc.CallOption) (sppb.Spanner_StreamingReadClient, error) {
m.ready()
m.ReceivedRequests <- r
@@ -213,7 +234,7 @@
}
// BeginTransaction is a placeholder for SpannerClient.BeginTransaction.
-func (m *MockCloudSpannerClient) BeginTransaction(c context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) {
+func (m *MockCloudSpannerClient) BeginTransaction(ctx context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) {
m.ready()
m.ReceivedRequests <- r
@@ -227,7 +248,7 @@
}
// Commit is a placeholder for SpannerClient.Commit.
-func (m *MockCloudSpannerClient) Commit(c context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) {
+func (m *MockCloudSpannerClient) Commit(ctx context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) {
m.ready()
m.ReceivedRequests <- r
@@ -237,7 +258,7 @@
}
// Rollback is a placeholder for SpannerClient.Rollback.
-func (m *MockCloudSpannerClient) Rollback(c context.Context, r *sppb.RollbackRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
+func (m *MockCloudSpannerClient) Rollback(ctx context.Context, r *sppb.RollbackRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
m.ready()
m.ReceivedRequests <- r
diff --git a/spanner/transaction.go b/spanner/transaction.go
index d7f9ced..be75c30 100644
--- a/spanner/transaction.go
+++ b/spanner/transaction.go
@@ -205,8 +205,7 @@
t.release)
}
-func (t *txReadOnly) prepareExecuteSQL(ctx context.Context, stmt Statement, mode sppb.ExecuteSqlRequest_QueryMode) (
- *sppb.ExecuteSqlRequest, *sessionHandle, error) {
+func (t *txReadOnly) prepareExecuteSQL(ctx context.Context, stmt Statement, mode sppb.ExecuteSqlRequest_QueryMode) (*sppb.ExecuteSqlRequest, *sessionHandle, error) {
sh, ts, err := t.acquire(ctx)
if err != nil {
return nil, nil, err
@@ -558,9 +557,9 @@
// ReadWriteTransaction provides a locking read-write transaction.
//
// This type of transaction is the only way to write data into Cloud Spanner;
-// (*Client).Apply and (*Client).ApplyAtLeastOnce use transactions
-// internally. These transactions rely on pessimistic locking and, if
-// necessary, two-phase commit. Locking read-write transactions may abort,
+// (*Client).Apply, (*Client).ApplyAtLeastOnce, (*Client).PartitionedUpdate use
+// transactions internally. These transactions rely on pessimistic locking and,
+// if necessary, two-phase commit. Locking read-write transactions may abort,
// requiring the application to retry. However, the interface exposed by
// (*Client).ReadWriteTransaction eliminates the need for applications to write
// retry loops explicitly.
@@ -676,6 +675,64 @@
return extractRowCount(resultSet.Stats)
}
+// BatchUpdate groups one or more DML statements and sends them to Spanner in a
+// single RPC. This is an efficient way to execute multiple DML statements.
+//
+// A slice of counts is returned, where each count represents the number of
+// affected rows for the given query at the same index. If an error occurs,
+// counts will be returned up to the query that encountered the error.
+func (t *ReadWriteTransaction) BatchUpdate(ctx context.Context, stmts []Statement) (_ []int64, err error) {
+ ctx = startSpan(ctx, "cloud.google.com/go/spanner.BatchUpdate")
+ defer func() { endSpan(ctx, err) }()
+
+ sh, ts, err := t.acquire(ctx)
+ if err != nil {
+ return nil, err
+ }
+ // Cloud Spanner will return "Session not found" on bad sessions.
+ sid := sh.getID()
+ if sid == "" {
+ // Might happen if transaction is closed in the middle of a API call.
+ return nil, errSessionClosed(sh)
+ }
+
+ var sppbStmts []*sppb.ExecuteBatchDmlRequest_Statement
+ for _, st := range stmts {
+ params, paramTypes, err := st.convertParams()
+ if err != nil {
+ return nil, err
+ }
+ sppbStmts = append(sppbStmts, &sppb.ExecuteBatchDmlRequest_Statement{
+ Sql: st.SQL,
+ Params: params,
+ ParamTypes: paramTypes,
+ })
+ }
+
+ resp, err := sh.getClient().ExecuteBatchDml(ctx, &sppb.ExecuteBatchDmlRequest{
+ Session: sh.getID(),
+ Transaction: ts,
+ Statements: sppbStmts,
+ Seqno: atomic.AddInt64(&t.sequenceNumber, 1),
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ var counts []int64
+ for _, rs := range resp.ResultSets {
+ count, err := extractRowCount(rs.Stats)
+ if err != nil {
+ return nil, err
+ }
+ counts = append(counts, count)
+ }
+ if resp.Status.Code != 0 {
+ return counts, spannerErrorf(codes.Code(uint32(resp.Status.Code)), resp.Status.Message)
+ }
+ return counts, nil
+}
+
// acquire implements txReadEnv.acquire.
func (t *ReadWriteTransaction) acquire(ctx context.Context) (*sessionHandle, *sppb.TransactionSelector, error) {
ts := &sppb.TransactionSelector{
diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go
index a710ecf..c4e1b2d 100644
--- a/spanner/transaction_test.go
+++ b/spanner/transaction_test.go
@@ -50,7 +50,7 @@
}
// Only one CreateSessionRequest is sent.
- if err := shouldHaveReceived(mock, []interface{}{&sppb.CreateSessionRequest{}}); err != nil {
+ if _, err := shouldHaveReceived(mock, []interface{}{&sppb.CreateSessionRequest{}}); err != nil {
t.Fatal(err)
}
}
@@ -159,7 +159,7 @@
t.Fatalf("applyAtLeastOnce retry on abort, got %v, want nil.", e)
}
- if err := shouldHaveReceived(mock, []interface{}{
+ if _, err := shouldHaveReceived(mock, []interface{}{
&sppb.CreateSessionRequest{},
&sppb.CommitRequest{},
}); err != nil {
@@ -194,7 +194,7 @@
t.Fatalf("ReadWriteTransaction retry on abort, got %v, want nil.", e)
}
- if err := shouldHaveReceived(mock, []interface{}{
+ if _, err := shouldHaveReceived(mock, []interface{}{
&sppb.CreateSessionRequest{},
&sppb.BeginTransactionRequest{},
&sppb.CommitRequest{}, // First commit fails.
@@ -262,7 +262,7 @@
if got != want {
t.Fatalf("got %+v, want %+v", got, want)
}
- if err := shouldHaveReceived(mock, []interface{}{
+ if _, err := shouldHaveReceived(mock, []interface{}{
&sppb.CreateSessionRequest{},
&sppb.BeginTransactionRequest{},
&sppb.RollbackRequest{},
@@ -271,12 +271,62 @@
}
}
+func TestBatchDML_WithMultipleDML(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{})
+ defer cleanup()
+
+ _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) {
+ if _, err = tx.Update(ctx, Statement{SQL: "SELECT * FROM whatever"}); err != nil {
+ return err
+ }
+ if _, err = tx.BatchUpdate(ctx, []Statement{{SQL: "SELECT * FROM whatever"}, {SQL: "SELECT * FROM whatever"}}); err != nil {
+ return err
+ }
+ if _, err = tx.Update(ctx, Statement{SQL: "SELECT * FROM whatever"}); err != nil {
+ return err
+ }
+ _, err = tx.BatchUpdate(ctx, []Statement{{SQL: "SELECT * FROM whatever"}})
+ return err
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ gotReqs, err := shouldHaveReceived(mock, []interface{}{
+ &sppb.CreateSessionRequest{},
+ &sppb.BeginTransactionRequest{},
+ &sppb.ExecuteSqlRequest{},
+ &sppb.ExecuteBatchDmlRequest{},
+ &sppb.ExecuteSqlRequest{},
+ &sppb.ExecuteBatchDmlRequest{},
+ &sppb.CommitRequest{},
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if got, want := gotReqs[2].(*sppb.ExecuteSqlRequest).Seqno, int64(1); got != want {
+ t.Errorf("got %d, want %d", got, want)
+ }
+ if got, want := gotReqs[3].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(2); got != want {
+ t.Errorf("got %d, want %d", got, want)
+ }
+ if got, want := gotReqs[4].(*sppb.ExecuteSqlRequest).Seqno, int64(3); got != want {
+ t.Errorf("got %d, want %d", got, want)
+ }
+ if got, want := gotReqs[5].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(4); got != want {
+ t.Errorf("got %d, want %d", got, want)
+ }
+}
+
// shouldHaveReceived asserts that exactly expectedRequests were present in
// the server's ReceivedRequests channel. It only looks at type, not contents.
//
// Note: this in-place modifies serverClientMock by popping items off the
// ReceivedRequests channel.
-func shouldHaveReceived(mock *testutil.FuncMock, want []interface{}) error {
+func shouldHaveReceived(mock *testutil.FuncMock, want []interface{}) ([]interface{}, error) {
got := drainRequests(mock)
if len(got) != len(want) {
@@ -290,16 +340,16 @@
wantMsg += fmt.Sprintf("%v: %+v]\n", reflect.TypeOf(r), r)
}
- return fmt.Errorf("got %d requests, want %d requests:\ngot:\n%s\nwant:\n%s", len(got), len(want), gotMsg, wantMsg)
+ return got, fmt.Errorf("got %d requests, want %d requests:\ngot:\n%s\nwant:\n%s", len(got), len(want), gotMsg, wantMsg)
}
for i, want := range want {
if reflect.TypeOf(got[i]) != reflect.TypeOf(want) {
- return fmt.Errorf("request %d: got %+v, want %+v", i, reflect.TypeOf(got[i]), reflect.TypeOf(want))
+ return got, fmt.Errorf("request %d: got %+v, want %+v", i, reflect.TypeOf(got[i]), reflect.TypeOf(want))
}
}
- return nil
+ return got, nil
}
func drainRequests(mock *testutil.FuncMock) []interface{} {