Revert "Revert "spanner: DML support""
This reverts commit 31ef90783ca56d389e4a7168f104088269c145fc.
Reason for revert: Spanner DML backend issue is fixed.
Change-Id: Iae2f25c78f36981e0c5af6aa2dfb50d438f741e0
Reviewed-on: https://code-review.googlesource.com/c/34110
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Eno Compton <enocom@google.com>
diff --git a/spanner/big_pdml_test.go b/spanner/big_pdml_test.go
new file mode 100644
index 0000000..b7e2f31
--- /dev/null
+++ b/spanner/big_pdml_test.go
@@ -0,0 +1,61 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build bigtest
+
+// An integration test for PDML using a relatively large database.
+
+package spanner
+
+import (
+ "fmt"
+ "testing"
+
+ "golang.org/x/net/context"
+)
+
+func TestBigPDML(t *testing.T) {
+ const nRows int = 1e4
+
+ ctx := context.Background()
+ client, _, tearDown := prepare(ctx, t, singerDBStatements)
+ defer tearDown()
+
+ columns := []string{"SingerId", "FirstName", "LastName"}
+
+ // Populate the Singers table with random data.
+ const rowsPerApply = 1000
+ for i := 0; i < nRows; i += rowsPerApply {
+ var muts []*Mutation
+ for j := 0; j < rowsPerApply; j++ {
+ id := i + j
+ row := []interface{}{id, fmt.Sprintf("FirstName%d", id), fmt.Sprintf("LastName%d", id)}
+ muts = append(muts, Insert("Singers", columns, row))
+ }
+ if _, err := client.Apply(ctx, muts); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ // Run a PDML statement.
+ count, err := client.PartitionedUpdate(ctx, Statement{
+ SQL: `UPDATE Singers SET Singers.FirstName = "changed" WHERE Singers.SingerId != -1`,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := int64(nRows); count != want {
+ t.Errorf("got %d, want %d", count, want)
+ }
+}
diff --git a/spanner/client.go b/spanner/client.go
index 309b903..2ff856f 100644
--- a/spanner/client.go
+++ b/spanner/client.go
@@ -18,7 +18,6 @@
import (
"fmt"
- "log"
"regexp"
"sync/atomic"
"time"
@@ -255,26 +254,12 @@
)
defer func() {
if err != nil && sh != nil {
- e := runRetryable(ctx, func(ctx context.Context) error {
- _, e := s.client.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: s.getID()})
- return e
- })
- if e != nil {
- log.Printf("Failed to delete session %v. Error: %v", s.getID(), e)
- }
+ s.delete(ctx)
}
}()
// create session
sc := c.rrNext()
- err = runRetryable(ctx, func(ctx context.Context) error {
- sid, e := sc.CreateSession(ctx, &sppb.CreateSessionRequest{Database: c.database, Session: &sppb.Session{Labels: c.sessionLabels}})
- if e != nil {
- return e
- }
- // If no error, construct the new session.
- s = &session{valid: true, client: sc, id: sid.Name, createTime: time.Now(), md: c.md}
- return nil
- })
+ s, err = createSession(ctx, sc, c.database, c.sessionLabels, c.md)
if err != nil {
return nil, err
}
diff --git a/spanner/doc.go b/spanner/doc.go
index 8f0303b..4954082 100644
--- a/spanner/doc.go
+++ b/spanner/doc.go
@@ -306,6 +306,19 @@
return nil
})
+
+DML and Partitioned DML
+
+Spanner supports DML statements like INSERT, UPDATE and DELETE. Use
+ReadWriteTransaction.Update to run DML statements. It returns the number of rows
+affected. (You can call use ReadWriteTransaction.Query with a DML statement. The first
+call to Next on the resulting RowIterator will return iterator.Done, and the RowCount
+field of the iterator will hold the number of affected rows.)
+
+For large databases, it may be more efficient to partition the DML statement. Use
+client.PartitionedUpdate to run a DML statement in this way. Not all DML statements
+can be partitioned.
+
Tracing
This client has been instrumented to use OpenCensus tracing (http://opencensus.io).
diff --git a/spanner/integration_test.go b/spanner/integration_test.go
index 9b28fef..eabfe49 100644
--- a/spanner/integration_test.go
+++ b/spanner/integration_test.go
@@ -949,7 +949,7 @@
// Test client recovery on database recreation.
func TestDbRemovalRecovery(t *testing.T) {
t.Parallel()
- ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
client, dbPath, tearDown := prepare(ctx, t, singerDBStatements)
defer tearDown()
@@ -1924,3 +1924,210 @@
t.Error(msg)
}
}
+
+func TestDML(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+ defer cancel()
+ client, _, tearDown := prepare(ctx, t, singerDBStatements)
+ defer tearDown()
+
+ // Function that reads a single row's first name from within a transaction.
+ readFirstName := func(tx *ReadWriteTransaction, key int) (string, error) {
+ row, err := tx.ReadRow(ctx, "Singers", Key{key}, []string{"FirstName"})
+ if err != nil {
+ return "", err
+ }
+ var fn string
+ if err := row.Column(0, &fn); err != nil {
+ return "", err
+ }
+ return fn, nil
+ }
+
+ // Function that reads multiple rows' first names from outside a read/write transaction.
+ readFirstNames := func(keys ...int) []string {
+ var ks []KeySet
+ for _, k := range keys {
+ ks = append(ks, Key{k})
+ }
+ iter := client.Single().Read(ctx, "Singers", KeySets(ks...), []string{"FirstName"})
+ var got []string
+ var fn string
+ err := iter.Do(func(row *Row) error {
+ if err := row.Column(0, &fn); err != nil {
+ return err
+ }
+ got = append(got, fn)
+ return nil
+ })
+ if err != nil {
+ t.Fatalf("readFirstNames(%v): %v", keys, err)
+ }
+ return got
+ }
+
+ // Use ReadWriteTransaction.Query to execute a DML statement.
+ _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+ iter := tx.Query(ctx, Statement{
+ SQL: `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, "Umm", "Kulthum")`,
+ })
+ defer iter.Stop()
+ if row, err := iter.Next(); err != iterator.Done {
+ t.Fatalf("got results from iterator, want none: %#v, err = %v\n", row, err)
+ }
+ if iter.RowCount != 1 {
+ t.Errorf("row count: got %d, want 1", iter.RowCount)
+ }
+ // The results of the DML statement should be visible to the transaction.
+ got, err := readFirstName(tx, 1)
+ if err != nil {
+ return err
+ }
+ if want := "Umm"; got != want {
+ t.Errorf("got %q, want %q", got, want)
+ }
+ return nil
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Use ReadWriteTransaction.Update to execute a DML statement.
+ _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+ count, err := tx.Update(ctx, Statement{
+ SQL: `Insert INTO Singers (SingerId, FirstName, LastName) VALUES (2, "Eduard", "Khil")`,
+ })
+ if count != 1 {
+ t.Errorf("row count: got %d, want 1", count)
+ }
+ got, err := readFirstName(tx, 2)
+ if err != nil {
+ return err
+ }
+ if want := "Eduard"; got != want {
+ t.Errorf("got %q, want %q", got, want)
+ }
+ return nil
+
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Roll back a DML statement and confirm that it didn't happen.
+ var fail = errors.New("fail")
+ _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+ _, err := tx.Update(ctx, Statement{
+ SQL: `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (3, "Audra", "McDonald")`,
+ })
+ if err != nil {
+ return err
+ }
+ return fail
+ })
+ if err != fail {
+ t.Fatalf("rolling back: got error %v, want the error 'fail'", err)
+ }
+ _, err = client.Single().ReadRow(ctx, "Singers", Key{3}, []string{"FirstName"})
+ if got, want := ErrCode(err), codes.NotFound; got != want {
+ t.Errorf("got %s, want %s", got, want)
+ }
+
+ // Run two DML statements in the same transaction.
+ _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+ _, err := tx.Update(ctx, Statement{SQL: `UPDATE Singers SET FirstName = "Oum" WHERE SingerId = 1`})
+ if err != nil {
+ return err
+ }
+ _, err = tx.Update(ctx, Statement{SQL: `UPDATE Singers SET FirstName = "Eddie" WHERE SingerId = 2`})
+ if err != nil {
+ return err
+ }
+ return nil
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ got := readFirstNames(1, 2)
+ want := []string{"Oum", "Eddie"}
+ if !testEqual(got, want) {
+ t.Errorf("got %v, want %v", got, want)
+ }
+
+ // Run a DML statement and an ordinary mutation in the same transaction.
+ _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+ _, err := tx.Update(ctx, Statement{
+ SQL: `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (3, "Audra", "McDonald")`,
+ })
+ if err != nil {
+ return err
+ }
+ tx.BufferWrite([]*Mutation{
+ Insert("Singers", []string{"SingerId", "FirstName", "LastName"},
+ []interface{}{4, "Andy", "Irvine"}),
+ })
+ return nil
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ got = readFirstNames(3, 4)
+ want = []string{"Audra", "Andy"}
+ if !testEqual(got, want) {
+ t.Errorf("got %v, want %v", got, want)
+ }
+
+ // Attempt to run a query using update.
+ _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+ _, err := tx.Update(ctx, Statement{SQL: `SELECT FirstName from Singers`})
+ return err
+ })
+ if got, want := ErrCode(err), codes.InvalidArgument; got != want {
+ t.Errorf("got %s, want %s", got, want)
+ }
+}
+
+func TestPDML(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+ defer cancel()
+ client, _, tearDown := prepare(ctx, t, singerDBStatements)
+ defer tearDown()
+
+ columns := []string{"SingerId", "FirstName", "LastName"}
+
+ // Populate the Singers table.
+ var muts []*Mutation
+ for _, row := range [][]interface{}{
+ {1, "Umm", "Kulthum"},
+ {2, "Eduard", "Khil"},
+ {3, "Audra", "McDonald"},
+ } {
+ muts = append(muts, Insert("Singers", columns, row))
+ }
+ if _, err := client.Apply(ctx, muts); err != nil {
+ t.Fatal(err)
+ }
+ // Identifiers in PDML statements must be fully qualified.
+ // TODO(jba): revisit the above.
+ count, err := client.PartitionedUpdate(ctx, Statement{
+ SQL: `UPDATE Singers SET Singers.FirstName = "changed" WHERE Singers.SingerId >= 1 AND Singers.SingerId <= 3`,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := int64(3); count != want {
+ t.Errorf("got %d, want %d", count, want)
+ }
+ got, err := readAll(client.Single().Read(ctx, "Singers", AllKeys(), columns))
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := [][]interface{}{
+ {int64(1), "changed", "Kulthum"},
+ {int64(2), "changed", "Khil"},
+ {int64(3), "changed", "McDonald"},
+ }
+ if !testEqual(got, want) {
+ t.Errorf("\ngot %v\nwant%v", got, want)
+ }
+}
diff --git a/spanner/internal/testutil/mockserver.go b/spanner/internal/testutil/mockserver.go
index 6fccbe4..6cd41c1 100644
--- a/spanner/internal/testutil/mockserver.go
+++ b/spanner/internal/testutil/mockserver.go
@@ -155,6 +155,16 @@
switch r.Sql {
case "SELECT * from t_unavailable":
return status.Errorf(codes.Unavailable, "mock table unavailable")
+
+ case "UPDATE t SET x = 2 WHERE x = 1":
+ err := s.Send(&sppb.PartialResultSet{
+ Stats: &sppb.ResultSetStats{RowCount: &sppb.ResultSetStats_RowCountLowerBound{3}},
+ })
+ if err != nil {
+ panic(err)
+ }
+ return nil
+
case "SELECT t.key key, t.value value FROM t_mock t":
if r.ResumeToken != nil {
s, err := DecodeResumeToken(r.ResumeToken)
@@ -233,6 +243,13 @@
go m.s.Serve(lis)
}
+// BeginTransaction is a placeholder for SpannerServer.BeginTransaction.
+func (m *MockCloudSpanner) BeginTransaction(_ context.Context, r *sppb.BeginTransactionRequest) (*sppb.Transaction, error) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return &sppb.Transaction{}, nil
+}
+
// Stop terminates MockCloudSpanner and closes the serving port.
func (m *MockCloudSpanner) Stop() {
m.s.Stop()
diff --git a/spanner/pdml.go b/spanner/pdml.go
new file mode 100644
index 0000000..8cf486e
--- /dev/null
+++ b/spanner/pdml.go
@@ -0,0 +1,101 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package spanner
+
+import (
+ "time"
+
+ "golang.org/x/net/context"
+ "google.golang.org/api/iterator"
+ sppb "google.golang.org/genproto/googleapis/spanner/v1"
+ "google.golang.org/grpc/codes"
+)
+
+// PartitionedUpdate executes a DML statement in parallel across the database, using
+// separate, internal transactions that commit independently. The DML statement must
+// be fully partitionable: it must be expressible as the union of many statements
+// each of which accesses only a single row of the table. The statement should also be
+// idempotent, because it may be applied more than once.
+//
+// PartitionedUpdate returns an estimated count of the number of rows affected. The actual
+// number of affected rows may be greater than the estimate.
+func (c *Client) PartitionedUpdate(ctx context.Context, statement Statement) (count int64, err error) {
+ ctx = traceStartSpan(ctx, "cloud.google.com/go/spanner.PartitionedUpdate")
+ defer func() { traceEndSpan(ctx, err) }()
+ if err := checkNestedTxn(ctx); err != nil {
+ return 0, err
+ }
+
+ var (
+ tx transactionID
+ s *session
+ sh *sessionHandle
+ )
+ // create session
+ sc := c.rrNext()
+ s, err = createSession(ctx, sc, c.database, c.sessionLabels, c.md)
+ if err != nil {
+ return 0, toSpannerError(err)
+ }
+ defer s.delete(ctx)
+ sh = &sessionHandle{session: s}
+ // begin transaction
+ err = runRetryable(contextWithOutgoingMetadata(ctx, sh.getMetadata()), func(ctx context.Context) error {
+ res, e := sc.BeginTransaction(ctx, &sppb.BeginTransactionRequest{
+ Session: sh.getID(),
+ Options: &sppb.TransactionOptions{
+ Mode: &sppb.TransactionOptions_PartitionedDml_{PartitionedDml: &sppb.TransactionOptions_PartitionedDml{}},
+ },
+ })
+ if e != nil {
+ return e
+ }
+ tx = res.Id
+ return nil
+ })
+ if err != nil {
+ return 0, toSpannerError(err)
+ }
+ req := &sppb.ExecuteSqlRequest{
+ Session: sh.getID(),
+ Transaction: &sppb.TransactionSelector{
+ Selector: &sppb.TransactionSelector_Id{Id: tx},
+ },
+ Sql: statement.SQL,
+ }
+ rpc := func(ctx context.Context, resumeToken []byte) (streamingReceiver, error) {
+ req.ResumeToken = resumeToken
+ return sc.ExecuteStreamingSql(ctx, req)
+ }
+ iter := stream(contextWithOutgoingMetadata(ctx, sh.getMetadata()),
+ rpc, func(time.Time) {}, func(error) {})
+ // TODO(jba): factor out the following code from here and ReadWriteTransaction.Update.
+ defer iter.Stop()
+ for {
+ _, err := iter.Next()
+ if err == iterator.Done {
+ break
+ }
+ if err != nil {
+ return 0, toSpannerError(err)
+ }
+ time.Sleep(time.Second)
+ }
+
+ if !iter.sawStats {
+ return 0, spannerErrorf(codes.InvalidArgument, "query passed to Update: %q", statement.SQL)
+ }
+ return iter.RowCount, nil
+}
diff --git a/spanner/pdml_test.go b/spanner/pdml_test.go
new file mode 100644
index 0000000..495d2af
--- /dev/null
+++ b/spanner/pdml_test.go
@@ -0,0 +1,61 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package spanner
+
+import (
+ "io"
+ "testing"
+
+ "cloud.google.com/go/spanner/internal/testutil"
+ "golang.org/x/net/context"
+ sppb "google.golang.org/genproto/googleapis/spanner/v1"
+ "google.golang.org/grpc/codes"
+)
+
+func TestMockPartitionedUpdate(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ ms := testutil.NewMockCloudSpanner(t, trxTs)
+ ms.Serve()
+ mc := sppb.NewSpannerClient(dialMock(t, ms))
+ client := &Client{database: "mockdb"}
+ client.clients = append(client.clients, mc)
+ stmt := NewStatement("UPDATE t SET x = 2 WHERE x = 1")
+ rowCount, err := client.PartitionedUpdate(ctx, stmt)
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := int64(3)
+ if rowCount != want {
+ t.Errorf("got %d, want %d", rowCount, want)
+ }
+}
+
+func TestMockPartitionedUpdateWithQuery(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ ms := testutil.NewMockCloudSpanner(t, trxTs)
+ ms.AddMsg(io.EOF, true)
+ ms.Serve()
+ mc := sppb.NewSpannerClient(dialMock(t, ms))
+ client := &Client{database: "mockdb"}
+ client.clients = append(client.clients, mc)
+ stmt := NewStatement("SELECT t.key key, t.value value FROM t_mock t")
+ _, err := client.PartitionedUpdate(ctx, stmt)
+ wantCode := codes.InvalidArgument
+ if serr, ok := err.(*Error); !ok || serr.Code != wantCode {
+ t.Errorf("got error %v, want code %s", err, wantCode)
+ }
+}
diff --git a/spanner/read.go b/spanner/read.go
index 08cbfa6..b9459e9 100644
--- a/spanner/read.go
+++ b/spanner/read.go
@@ -67,6 +67,10 @@
// if QueryWithStats was called.
QueryStats map[string]interface{}
+ // For a DML statement, the number of rows affected. For PDML, this is a lower bound.
+ // Available for DML statements after RowIterator.Next returns iterator.Done.
+ RowCount int64
+
streamd *resumableStreamDecoder
rowd *partialResultSetDecoder
setTimestamp func(time.Time)
@@ -74,6 +78,7 @@
cancel func()
err error
rows []*Row
+ sawStats bool
}
// Next returns the next result. Its second return value is iterator.Done if
@@ -86,8 +91,16 @@
for len(r.rows) == 0 && r.streamd.next() {
prs := r.streamd.get()
if prs.Stats != nil {
+ r.sawStats = true
r.QueryPlan = prs.Stats.QueryPlan
r.QueryStats = protostruct.DecodeToMap(prs.Stats.QueryStats)
+ if prs.Stats.RowCount != nil {
+ rc, err := extractRowCount(prs.Stats)
+ if err != nil {
+ return nil, err
+ }
+ r.RowCount = rc
+ }
}
r.rows, r.err = r.rowd.add(prs)
if r.err != nil {
@@ -113,6 +126,20 @@
return nil, r.err
}
+func extractRowCount(stats *sppb.ResultSetStats) (int64, error) {
+ if stats.RowCount == nil {
+ return 0, spannerErrorf(codes.Internal, "missing RowCount")
+ }
+ switch rc := stats.RowCount.(type) {
+ case *sppb.ResultSetStats_RowCountExact:
+ return rc.RowCountExact, nil
+ case *sppb.ResultSetStats_RowCountLowerBound:
+ return rc.RowCountLowerBound, nil
+ default:
+ return 0, spannerErrorf(codes.Internal, "unknown RowCount type %T", stats.RowCount)
+ }
+}
+
// Do calls the provided function once in sequence for each row in the iteration. If the
// function returns a non-nil error, Do immediately returns that error.
//
diff --git a/spanner/session.go b/spanner/session.go
index 1c922a8..142992e 100644
--- a/spanner/session.go
+++ b/spanner/session.go
@@ -260,6 +260,11 @@
// Remove s from Cloud Spanner service.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
+ s.delete(ctx)
+ return true
+}
+
+func (s *session) delete(ctx context.Context) {
// Ignore the error returned by runRetryable because even if we fail to explicitly destroy the session,
// it will be eventually garbage collected by Cloud Spanner.
err := runRetryable(ctx, func(ctx context.Context) error {
@@ -269,7 +274,6 @@
if err != nil {
log.Printf("Failed to delete session %v. Error: %v", s.getID(), err)
}
- return true
}
// prepareForWrite prepares the session for write if it is not already in that state.
@@ -464,29 +468,38 @@
doneCreate(false)
return nil, err
}
- var s *session
- err = runRetryable(ctx, func(ctx context.Context) error {
- sid, e := sc.CreateSession(ctx, &sppb.CreateSessionRequest{
- Database: p.db,
- Session: &sppb.Session{Labels: p.sessionLabels},
- })
- if e != nil {
- return e
- }
- // If no error, construct the new session.
- s = &session{valid: true, client: sc, id: sid.Name, pool: p, createTime: time.Now(), md: p.md}
- p.hc.register(s)
- return nil
- })
+ s, err := createSession(ctx, sc, p.db, p.sessionLabels, p.md)
if err != nil {
doneCreate(false)
// Should return error directly because of the previous retries on CreateSession RPC.
return nil, err
}
+ s.pool = p
+ p.hc.register(s)
doneCreate(true)
return s, nil
}
+func createSession(ctx context.Context, sc sppb.SpannerClient, db string, labels map[string]string, md metadata.MD) (*session, error) {
+ var s *session
+ err := runRetryable(ctx, func(ctx context.Context) error {
+ sid, e := sc.CreateSession(ctx, &sppb.CreateSessionRequest{
+ Database: db,
+ Session: &sppb.Session{Labels: labels},
+ })
+ if e != nil {
+ return e
+ }
+ // If no error, construct the new session.
+ s = &session{valid: true, client: sc, id: sid.Name, createTime: time.Now(), md: md}
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+ return s, nil
+}
+
func (p *sessionPool) isHealthy(s *session) bool {
if s.getNextCheck().Add(2 * p.hc.getInterval()).Before(time.Now()) {
// TODO: figure out if we need to schedule a new healthcheck worker here.
diff --git a/spanner/transaction.go b/spanner/transaction.go
index b261031..7624fe3 100644
--- a/spanner/transaction.go
+++ b/spanner/transaction.go
@@ -18,6 +18,7 @@
import (
"sync"
+ "sync/atomic"
"time"
"golang.org/x/net/context"
@@ -46,6 +47,8 @@
type txReadOnly struct {
// read-transaction environment for performing transactional read operations.
txReadEnv
+
+ sequenceNumber int64 // Atomic. Only needed for DML statements, but used for all.
}
// errSessionClosed returns error for using a recycled/destroyed session
@@ -159,7 +162,7 @@
return t.query(ctx, statement, sppb.ExecuteSqlRequest_NORMAL)
}
-// Query executes a query against the database. It returns a RowIterator
+// Query executes a SQL statement against the database. It returns a RowIterator
// for retrieving the resulting rows. The RowIterator will also be populated
// with a query plan and execution statistics.
func (t *txReadOnly) QueryWithStats(ctx context.Context, statement Statement) *RowIterator {
@@ -188,29 +191,11 @@
func (t *txReadOnly) query(ctx context.Context, statement Statement, mode sppb.ExecuteSqlRequest_QueryMode) (ri *RowIterator) {
ctx = traceStartSpan(ctx, "cloud.google.com/go/spanner.Query")
defer func() { traceEndSpan(ctx, ri.err) }()
- var (
- sh *sessionHandle
- ts *sppb.TransactionSelector
- err error
- )
- if sh, ts, err = t.acquire(ctx); err != nil {
+ req, sh, err := t.prepareExecuteSql(ctx, statement, mode)
+ if err != nil {
return &RowIterator{err: err}
}
- // Cloud Spanner will return "Session not found" on bad sessions.
- sid, client := sh.getID(), sh.getClient()
- if sid == "" || client == nil {
- // Might happen if transaction is closed in the middle of a API call.
- return &RowIterator{err: errSessionClosed(sh)}
- }
- req := &sppb.ExecuteSqlRequest{
- Session: sid,
- Transaction: ts,
- Sql: statement.SQL,
- QueryMode: mode,
- }
- if err := statement.bindParams(req); err != nil {
- return &RowIterator{err: err}
- }
+ client := sh.getClient()
return stream(
contextWithOutgoingMetadata(ctx, sh.getMetadata()),
func(ctx context.Context, resumeToken []byte) (streamingReceiver, error) {
@@ -221,6 +206,31 @@
t.release)
}
+func (t *txReadOnly) prepareExecuteSql(ctx context.Context, stmt Statement, mode sppb.ExecuteSqlRequest_QueryMode) (
+ *sppb.ExecuteSqlRequest, *sessionHandle, error) {
+ sh, ts, err := t.acquire(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+ // Cloud Spanner will return "Session not found" on bad sessions.
+ sid := sh.getID()
+ if sid == "" {
+ // Might happen if transaction is closed in the middle of a API call.
+ return nil, nil, errSessionClosed(sh)
+ }
+ req := &sppb.ExecuteSqlRequest{
+ Session: sid,
+ Transaction: ts,
+ Sql: stmt.SQL,
+ QueryMode: mode,
+ Seqno: atomic.AddInt64(&t.sequenceNumber, 1),
+ }
+ if err := stmt.bindParams(req); err != nil {
+ return nil, nil, err
+ }
+ return req, sh, nil
+}
+
// txState is the status of a transaction.
type txState int
@@ -648,6 +658,27 @@
return nil
}
+// Update executes a DML statement against the database. It returns the number of
+// affected rows.
+// Update returns an error if the statement is a query. However, the
+// query is executed, and any data read will be validated upon commit.
+func (t *ReadWriteTransaction) Update(ctx context.Context, stmt Statement) (rowCount int64, err error) {
+ ctx = traceStartSpan(ctx, "cloud.google.com/go/spanner.Update")
+ defer func() { traceEndSpan(ctx, err) }()
+ req, sh, err := t.prepareExecuteSql(ctx, stmt, sppb.ExecuteSqlRequest_NORMAL)
+ if err != nil {
+ return 0, err
+ }
+ resultSet, err := sh.getClient().ExecuteSql(ctx, req)
+ if err != nil {
+ return 0, err
+ }
+ if resultSet.Stats == nil {
+ return 0, spannerErrorf(codes.InvalidArgument, "query passed to Update: %q", stmt.SQL)
+ }
+ return extractRowCount(resultSet.Stats)
+}
+
// acquire implements txReadEnv.acquire.
func (t *ReadWriteTransaction) acquire(ctx context.Context) (*sessionHandle, *sppb.TransactionSelector, error) {
ts := &sppb.TransactionSelector{