Revert "Revert "spanner: DML support""

This reverts commit 31ef90783ca56d389e4a7168f104088269c145fc.

Reason for revert: Spanner DML backend issue is fixed.

Change-Id: Iae2f25c78f36981e0c5af6aa2dfb50d438f741e0
Reviewed-on: https://code-review.googlesource.com/c/34110
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Eno Compton <enocom@google.com>
diff --git a/spanner/big_pdml_test.go b/spanner/big_pdml_test.go
new file mode 100644
index 0000000..b7e2f31
--- /dev/null
+++ b/spanner/big_pdml_test.go
@@ -0,0 +1,61 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build bigtest
+
+// An integration test for PDML using a relatively large database.
+
+package spanner
+
+import (
+	"fmt"
+	"testing"
+
+	"golang.org/x/net/context"
+)
+
+func TestBigPDML(t *testing.T) {
+	const nRows int = 1e4
+
+	ctx := context.Background()
+	client, _, tearDown := prepare(ctx, t, singerDBStatements)
+	defer tearDown()
+
+	columns := []string{"SingerId", "FirstName", "LastName"}
+
+	// Populate the Singers table with random data.
+	const rowsPerApply = 1000
+	for i := 0; i < nRows; i += rowsPerApply {
+		var muts []*Mutation
+		for j := 0; j < rowsPerApply; j++ {
+			id := i + j
+			row := []interface{}{id, fmt.Sprintf("FirstName%d", id), fmt.Sprintf("LastName%d", id)}
+			muts = append(muts, Insert("Singers", columns, row))
+		}
+		if _, err := client.Apply(ctx, muts); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	// Run a PDML statement.
+	count, err := client.PartitionedUpdate(ctx, Statement{
+		SQL: `UPDATE Singers SET Singers.FirstName = "changed" WHERE Singers.SingerId != -1`,
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if want := int64(nRows); count != want {
+		t.Errorf("got %d, want %d", count, want)
+	}
+}
diff --git a/spanner/client.go b/spanner/client.go
index 309b903..2ff856f 100644
--- a/spanner/client.go
+++ b/spanner/client.go
@@ -18,7 +18,6 @@
 
 import (
 	"fmt"
-	"log"
 	"regexp"
 	"sync/atomic"
 	"time"
@@ -255,26 +254,12 @@
 	)
 	defer func() {
 		if err != nil && sh != nil {
-			e := runRetryable(ctx, func(ctx context.Context) error {
-				_, e := s.client.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: s.getID()})
-				return e
-			})
-			if e != nil {
-				log.Printf("Failed to delete session %v. Error: %v", s.getID(), e)
-			}
+			s.delete(ctx)
 		}
 	}()
 	// create session
 	sc := c.rrNext()
-	err = runRetryable(ctx, func(ctx context.Context) error {
-		sid, e := sc.CreateSession(ctx, &sppb.CreateSessionRequest{Database: c.database, Session: &sppb.Session{Labels: c.sessionLabels}})
-		if e != nil {
-			return e
-		}
-		// If no error, construct the new session.
-		s = &session{valid: true, client: sc, id: sid.Name, createTime: time.Now(), md: c.md}
-		return nil
-	})
+	s, err = createSession(ctx, sc, c.database, c.sessionLabels, c.md)
 	if err != nil {
 		return nil, err
 	}
diff --git a/spanner/doc.go b/spanner/doc.go
index 8f0303b..4954082 100644
--- a/spanner/doc.go
+++ b/spanner/doc.go
@@ -306,6 +306,19 @@
         return nil
     })
 
+
+DML and Partitioned DML
+
+Spanner supports DML statements like INSERT, UPDATE and DELETE. Use
+ReadWriteTransaction.Update to run DML statements. It returns the number of rows
+affected. (You can call use ReadWriteTransaction.Query with a DML statement. The first
+call to Next on the resulting RowIterator will return iterator.Done, and the RowCount
+field of the iterator will hold the number of affected rows.)
+
+For large databases, it may be more efficient to partition the DML statement. Use
+client.PartitionedUpdate to run a DML statement in this way. Not all DML statements
+can be partitioned.
+
 Tracing
 
 This client has been instrumented to use OpenCensus tracing (http://opencensus.io).
diff --git a/spanner/integration_test.go b/spanner/integration_test.go
index 9b28fef..eabfe49 100644
--- a/spanner/integration_test.go
+++ b/spanner/integration_test.go
@@ -949,7 +949,7 @@
 // Test client recovery on database recreation.
 func TestDbRemovalRecovery(t *testing.T) {
 	t.Parallel()
-	ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
 	defer cancel()
 	client, dbPath, tearDown := prepare(ctx, t, singerDBStatements)
 	defer tearDown()
@@ -1924,3 +1924,210 @@
 		t.Error(msg)
 	}
 }
+
+func TestDML(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+	defer cancel()
+	client, _, tearDown := prepare(ctx, t, singerDBStatements)
+	defer tearDown()
+
+	// Function that reads a single row's first name from within a transaction.
+	readFirstName := func(tx *ReadWriteTransaction, key int) (string, error) {
+		row, err := tx.ReadRow(ctx, "Singers", Key{key}, []string{"FirstName"})
+		if err != nil {
+			return "", err
+		}
+		var fn string
+		if err := row.Column(0, &fn); err != nil {
+			return "", err
+		}
+		return fn, nil
+	}
+
+	// Function that reads multiple rows' first names from outside a read/write transaction.
+	readFirstNames := func(keys ...int) []string {
+		var ks []KeySet
+		for _, k := range keys {
+			ks = append(ks, Key{k})
+		}
+		iter := client.Single().Read(ctx, "Singers", KeySets(ks...), []string{"FirstName"})
+		var got []string
+		var fn string
+		err := iter.Do(func(row *Row) error {
+			if err := row.Column(0, &fn); err != nil {
+				return err
+			}
+			got = append(got, fn)
+			return nil
+		})
+		if err != nil {
+			t.Fatalf("readFirstNames(%v): %v", keys, err)
+		}
+		return got
+	}
+
+	// Use ReadWriteTransaction.Query to execute a DML statement.
+	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+		iter := tx.Query(ctx, Statement{
+			SQL: `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (1, "Umm", "Kulthum")`,
+		})
+		defer iter.Stop()
+		if row, err := iter.Next(); err != iterator.Done {
+			t.Fatalf("got results from iterator, want none: %#v, err = %v\n", row, err)
+		}
+		if iter.RowCount != 1 {
+			t.Errorf("row count: got %d, want 1", iter.RowCount)
+		}
+		// The results of the DML statement should be visible to the transaction.
+		got, err := readFirstName(tx, 1)
+		if err != nil {
+			return err
+		}
+		if want := "Umm"; got != want {
+			t.Errorf("got %q, want %q", got, want)
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Use ReadWriteTransaction.Update to execute a DML statement.
+	_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+		count, err := tx.Update(ctx, Statement{
+			SQL: `Insert INTO Singers (SingerId, FirstName, LastName) VALUES (2, "Eduard", "Khil")`,
+		})
+		if count != 1 {
+			t.Errorf("row count: got %d, want 1", count)
+		}
+		got, err := readFirstName(tx, 2)
+		if err != nil {
+			return err
+		}
+		if want := "Eduard"; got != want {
+			t.Errorf("got %q, want %q", got, want)
+		}
+		return nil
+
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Roll back a DML statement and confirm that it didn't happen.
+	var fail = errors.New("fail")
+	_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+		_, err := tx.Update(ctx, Statement{
+			SQL: `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (3, "Audra", "McDonald")`,
+		})
+		if err != nil {
+			return err
+		}
+		return fail
+	})
+	if err != fail {
+		t.Fatalf("rolling back: got error %v, want the error 'fail'", err)
+	}
+	_, err = client.Single().ReadRow(ctx, "Singers", Key{3}, []string{"FirstName"})
+	if got, want := ErrCode(err), codes.NotFound; got != want {
+		t.Errorf("got %s, want %s", got, want)
+	}
+
+	// Run two DML statements in the same transaction.
+	_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+		_, err := tx.Update(ctx, Statement{SQL: `UPDATE Singers SET FirstName = "Oum" WHERE SingerId = 1`})
+		if err != nil {
+			return err
+		}
+		_, err = tx.Update(ctx, Statement{SQL: `UPDATE Singers SET FirstName = "Eddie" WHERE SingerId = 2`})
+		if err != nil {
+			return err
+		}
+		return nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	got := readFirstNames(1, 2)
+	want := []string{"Oum", "Eddie"}
+	if !testEqual(got, want) {
+		t.Errorf("got %v, want %v", got, want)
+	}
+
+	// Run a DML statement and an ordinary mutation in the same transaction.
+	_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+		_, err := tx.Update(ctx, Statement{
+			SQL: `INSERT INTO Singers (SingerId, FirstName, LastName) VALUES (3, "Audra", "McDonald")`,
+		})
+		if err != nil {
+			return err
+		}
+		tx.BufferWrite([]*Mutation{
+			Insert("Singers", []string{"SingerId", "FirstName", "LastName"},
+				[]interface{}{4, "Andy", "Irvine"}),
+		})
+		return nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	got = readFirstNames(3, 4)
+	want = []string{"Audra", "Andy"}
+	if !testEqual(got, want) {
+		t.Errorf("got %v, want %v", got, want)
+	}
+
+	// Attempt to run a query using update.
+	_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
+		_, err := tx.Update(ctx, Statement{SQL: `SELECT FirstName from Singers`})
+		return err
+	})
+	if got, want := ErrCode(err), codes.InvalidArgument; got != want {
+		t.Errorf("got %s, want %s", got, want)
+	}
+}
+
+func TestPDML(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+	defer cancel()
+	client, _, tearDown := prepare(ctx, t, singerDBStatements)
+	defer tearDown()
+
+	columns := []string{"SingerId", "FirstName", "LastName"}
+
+	// Populate the Singers table.
+	var muts []*Mutation
+	for _, row := range [][]interface{}{
+		{1, "Umm", "Kulthum"},
+		{2, "Eduard", "Khil"},
+		{3, "Audra", "McDonald"},
+	} {
+		muts = append(muts, Insert("Singers", columns, row))
+	}
+	if _, err := client.Apply(ctx, muts); err != nil {
+		t.Fatal(err)
+	}
+	// Identifiers in PDML statements must be fully qualified.
+	// TODO(jba): revisit the above.
+	count, err := client.PartitionedUpdate(ctx, Statement{
+		SQL: `UPDATE Singers SET Singers.FirstName = "changed" WHERE Singers.SingerId >= 1 AND Singers.SingerId <= 3`,
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if want := int64(3); count != want {
+		t.Errorf("got %d, want %d", count, want)
+	}
+	got, err := readAll(client.Single().Read(ctx, "Singers", AllKeys(), columns))
+	if err != nil {
+		t.Fatal(err)
+	}
+	want := [][]interface{}{
+		{int64(1), "changed", "Kulthum"},
+		{int64(2), "changed", "Khil"},
+		{int64(3), "changed", "McDonald"},
+	}
+	if !testEqual(got, want) {
+		t.Errorf("\ngot %v\nwant%v", got, want)
+	}
+}
diff --git a/spanner/internal/testutil/mockserver.go b/spanner/internal/testutil/mockserver.go
index 6fccbe4..6cd41c1 100644
--- a/spanner/internal/testutil/mockserver.go
+++ b/spanner/internal/testutil/mockserver.go
@@ -155,6 +155,16 @@
 	switch r.Sql {
 	case "SELECT * from t_unavailable":
 		return status.Errorf(codes.Unavailable, "mock table unavailable")
+
+	case "UPDATE t SET x = 2 WHERE x = 1":
+		err := s.Send(&sppb.PartialResultSet{
+			Stats: &sppb.ResultSetStats{RowCount: &sppb.ResultSetStats_RowCountLowerBound{3}},
+		})
+		if err != nil {
+			panic(err)
+		}
+		return nil
+
 	case "SELECT t.key key, t.value value FROM t_mock t":
 		if r.ResumeToken != nil {
 			s, err := DecodeResumeToken(r.ResumeToken)
@@ -233,6 +243,13 @@
 	go m.s.Serve(lis)
 }
 
+// BeginTransaction is a placeholder for SpannerServer.BeginTransaction.
+func (m *MockCloudSpanner) BeginTransaction(_ context.Context, r *sppb.BeginTransactionRequest) (*sppb.Transaction, error) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	return &sppb.Transaction{}, nil
+}
+
 // Stop terminates MockCloudSpanner and closes the serving port.
 func (m *MockCloudSpanner) Stop() {
 	m.s.Stop()
diff --git a/spanner/pdml.go b/spanner/pdml.go
new file mode 100644
index 0000000..8cf486e
--- /dev/null
+++ b/spanner/pdml.go
@@ -0,0 +1,101 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package spanner
+
+import (
+	"time"
+
+	"golang.org/x/net/context"
+	"google.golang.org/api/iterator"
+	sppb "google.golang.org/genproto/googleapis/spanner/v1"
+	"google.golang.org/grpc/codes"
+)
+
+// PartitionedUpdate executes a DML statement in parallel across the database, using
+// separate, internal transactions that commit independently. The DML statement must
+// be fully partitionable: it must be expressible as the union of many statements
+// each of which accesses only a single row of the table. The statement should also be
+// idempotent, because it may be applied more than once.
+//
+// PartitionedUpdate returns an estimated count of the number of rows affected. The actual
+// number of affected rows may be greater than the estimate.
+func (c *Client) PartitionedUpdate(ctx context.Context, statement Statement) (count int64, err error) {
+	ctx = traceStartSpan(ctx, "cloud.google.com/go/spanner.PartitionedUpdate")
+	defer func() { traceEndSpan(ctx, err) }()
+	if err := checkNestedTxn(ctx); err != nil {
+		return 0, err
+	}
+
+	var (
+		tx transactionID
+		s  *session
+		sh *sessionHandle
+	)
+	// create session
+	sc := c.rrNext()
+	s, err = createSession(ctx, sc, c.database, c.sessionLabels, c.md)
+	if err != nil {
+		return 0, toSpannerError(err)
+	}
+	defer s.delete(ctx)
+	sh = &sessionHandle{session: s}
+	// begin transaction
+	err = runRetryable(contextWithOutgoingMetadata(ctx, sh.getMetadata()), func(ctx context.Context) error {
+		res, e := sc.BeginTransaction(ctx, &sppb.BeginTransactionRequest{
+			Session: sh.getID(),
+			Options: &sppb.TransactionOptions{
+				Mode: &sppb.TransactionOptions_PartitionedDml_{PartitionedDml: &sppb.TransactionOptions_PartitionedDml{}},
+			},
+		})
+		if e != nil {
+			return e
+		}
+		tx = res.Id
+		return nil
+	})
+	if err != nil {
+		return 0, toSpannerError(err)
+	}
+	req := &sppb.ExecuteSqlRequest{
+		Session: sh.getID(),
+		Transaction: &sppb.TransactionSelector{
+			Selector: &sppb.TransactionSelector_Id{Id: tx},
+		},
+		Sql: statement.SQL,
+	}
+	rpc := func(ctx context.Context, resumeToken []byte) (streamingReceiver, error) {
+		req.ResumeToken = resumeToken
+		return sc.ExecuteStreamingSql(ctx, req)
+	}
+	iter := stream(contextWithOutgoingMetadata(ctx, sh.getMetadata()),
+		rpc, func(time.Time) {}, func(error) {})
+	// TODO(jba): factor out the following code from here and ReadWriteTransaction.Update.
+	defer iter.Stop()
+	for {
+		_, err := iter.Next()
+		if err == iterator.Done {
+			break
+		}
+		if err != nil {
+			return 0, toSpannerError(err)
+		}
+		time.Sleep(time.Second)
+	}
+
+	if !iter.sawStats {
+		return 0, spannerErrorf(codes.InvalidArgument, "query passed to Update: %q", statement.SQL)
+	}
+	return iter.RowCount, nil
+}
diff --git a/spanner/pdml_test.go b/spanner/pdml_test.go
new file mode 100644
index 0000000..495d2af
--- /dev/null
+++ b/spanner/pdml_test.go
@@ -0,0 +1,61 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package spanner
+
+import (
+	"io"
+	"testing"
+
+	"cloud.google.com/go/spanner/internal/testutil"
+	"golang.org/x/net/context"
+	sppb "google.golang.org/genproto/googleapis/spanner/v1"
+	"google.golang.org/grpc/codes"
+)
+
+func TestMockPartitionedUpdate(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	ms := testutil.NewMockCloudSpanner(t, trxTs)
+	ms.Serve()
+	mc := sppb.NewSpannerClient(dialMock(t, ms))
+	client := &Client{database: "mockdb"}
+	client.clients = append(client.clients, mc)
+	stmt := NewStatement("UPDATE t SET x = 2 WHERE x = 1")
+	rowCount, err := client.PartitionedUpdate(ctx, stmt)
+	if err != nil {
+		t.Fatal(err)
+	}
+	want := int64(3)
+	if rowCount != want {
+		t.Errorf("got %d, want %d", rowCount, want)
+	}
+}
+
+func TestMockPartitionedUpdateWithQuery(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	ms := testutil.NewMockCloudSpanner(t, trxTs)
+	ms.AddMsg(io.EOF, true)
+	ms.Serve()
+	mc := sppb.NewSpannerClient(dialMock(t, ms))
+	client := &Client{database: "mockdb"}
+	client.clients = append(client.clients, mc)
+	stmt := NewStatement("SELECT t.key key, t.value value FROM t_mock t")
+	_, err := client.PartitionedUpdate(ctx, stmt)
+	wantCode := codes.InvalidArgument
+	if serr, ok := err.(*Error); !ok || serr.Code != wantCode {
+		t.Errorf("got error %v, want code %s", err, wantCode)
+	}
+}
diff --git a/spanner/read.go b/spanner/read.go
index 08cbfa6..b9459e9 100644
--- a/spanner/read.go
+++ b/spanner/read.go
@@ -67,6 +67,10 @@
 	// if QueryWithStats was called.
 	QueryStats map[string]interface{}
 
+	// For a DML statement, the number of rows affected. For PDML, this is a lower bound.
+	// Available for DML statements after RowIterator.Next returns iterator.Done.
+	RowCount int64
+
 	streamd      *resumableStreamDecoder
 	rowd         *partialResultSetDecoder
 	setTimestamp func(time.Time)
@@ -74,6 +78,7 @@
 	cancel       func()
 	err          error
 	rows         []*Row
+	sawStats     bool
 }
 
 // Next returns the next result. Its second return value is iterator.Done if
@@ -86,8 +91,16 @@
 	for len(r.rows) == 0 && r.streamd.next() {
 		prs := r.streamd.get()
 		if prs.Stats != nil {
+			r.sawStats = true
 			r.QueryPlan = prs.Stats.QueryPlan
 			r.QueryStats = protostruct.DecodeToMap(prs.Stats.QueryStats)
+			if prs.Stats.RowCount != nil {
+				rc, err := extractRowCount(prs.Stats)
+				if err != nil {
+					return nil, err
+				}
+				r.RowCount = rc
+			}
 		}
 		r.rows, r.err = r.rowd.add(prs)
 		if r.err != nil {
@@ -113,6 +126,20 @@
 	return nil, r.err
 }
 
+func extractRowCount(stats *sppb.ResultSetStats) (int64, error) {
+	if stats.RowCount == nil {
+		return 0, spannerErrorf(codes.Internal, "missing RowCount")
+	}
+	switch rc := stats.RowCount.(type) {
+	case *sppb.ResultSetStats_RowCountExact:
+		return rc.RowCountExact, nil
+	case *sppb.ResultSetStats_RowCountLowerBound:
+		return rc.RowCountLowerBound, nil
+	default:
+		return 0, spannerErrorf(codes.Internal, "unknown RowCount type %T", stats.RowCount)
+	}
+}
+
 // Do calls the provided function once in sequence for each row in the iteration.  If the
 // function returns a non-nil error, Do immediately returns that error.
 //
diff --git a/spanner/session.go b/spanner/session.go
index 1c922a8..142992e 100644
--- a/spanner/session.go
+++ b/spanner/session.go
@@ -260,6 +260,11 @@
 	// Remove s from Cloud Spanner service.
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 	defer cancel()
+	s.delete(ctx)
+	return true
+}
+
+func (s *session) delete(ctx context.Context) {
 	// Ignore the error returned by runRetryable because even if we fail to explicitly destroy the session,
 	// it will be eventually garbage collected by Cloud Spanner.
 	err := runRetryable(ctx, func(ctx context.Context) error {
@@ -269,7 +274,6 @@
 	if err != nil {
 		log.Printf("Failed to delete session %v. Error: %v", s.getID(), err)
 	}
-	return true
 }
 
 // prepareForWrite prepares the session for write if it is not already in that state.
@@ -464,29 +468,38 @@
 		doneCreate(false)
 		return nil, err
 	}
-	var s *session
-	err = runRetryable(ctx, func(ctx context.Context) error {
-		sid, e := sc.CreateSession(ctx, &sppb.CreateSessionRequest{
-			Database: p.db,
-			Session:  &sppb.Session{Labels: p.sessionLabels},
-		})
-		if e != nil {
-			return e
-		}
-		// If no error, construct the new session.
-		s = &session{valid: true, client: sc, id: sid.Name, pool: p, createTime: time.Now(), md: p.md}
-		p.hc.register(s)
-		return nil
-	})
+	s, err := createSession(ctx, sc, p.db, p.sessionLabels, p.md)
 	if err != nil {
 		doneCreate(false)
 		// Should return error directly because of the previous retries on CreateSession RPC.
 		return nil, err
 	}
+	s.pool = p
+	p.hc.register(s)
 	doneCreate(true)
 	return s, nil
 }
 
+func createSession(ctx context.Context, sc sppb.SpannerClient, db string, labels map[string]string, md metadata.MD) (*session, error) {
+	var s *session
+	err := runRetryable(ctx, func(ctx context.Context) error {
+		sid, e := sc.CreateSession(ctx, &sppb.CreateSessionRequest{
+			Database: db,
+			Session:  &sppb.Session{Labels: labels},
+		})
+		if e != nil {
+			return e
+		}
+		// If no error, construct the new session.
+		s = &session{valid: true, client: sc, id: sid.Name, createTime: time.Now(), md: md}
+		return nil
+	})
+	if err != nil {
+		return nil, err
+	}
+	return s, nil
+}
+
 func (p *sessionPool) isHealthy(s *session) bool {
 	if s.getNextCheck().Add(2 * p.hc.getInterval()).Before(time.Now()) {
 		// TODO: figure out if we need to schedule a new healthcheck worker here.
diff --git a/spanner/transaction.go b/spanner/transaction.go
index b261031..7624fe3 100644
--- a/spanner/transaction.go
+++ b/spanner/transaction.go
@@ -18,6 +18,7 @@
 
 import (
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"golang.org/x/net/context"
@@ -46,6 +47,8 @@
 type txReadOnly struct {
 	// read-transaction environment for performing transactional read operations.
 	txReadEnv
+
+	sequenceNumber int64 // Atomic. Only needed for DML statements, but used for all.
 }
 
 // errSessionClosed returns error for using a recycled/destroyed session
@@ -159,7 +162,7 @@
 	return t.query(ctx, statement, sppb.ExecuteSqlRequest_NORMAL)
 }
 
-// Query executes a query against the database. It returns a RowIterator
+// Query executes a SQL statement against the database. It returns a RowIterator
 // for retrieving the resulting rows. The RowIterator will also be populated
 // with a query plan and execution statistics.
 func (t *txReadOnly) QueryWithStats(ctx context.Context, statement Statement) *RowIterator {
@@ -188,29 +191,11 @@
 func (t *txReadOnly) query(ctx context.Context, statement Statement, mode sppb.ExecuteSqlRequest_QueryMode) (ri *RowIterator) {
 	ctx = traceStartSpan(ctx, "cloud.google.com/go/spanner.Query")
 	defer func() { traceEndSpan(ctx, ri.err) }()
-	var (
-		sh  *sessionHandle
-		ts  *sppb.TransactionSelector
-		err error
-	)
-	if sh, ts, err = t.acquire(ctx); err != nil {
+	req, sh, err := t.prepareExecuteSql(ctx, statement, mode)
+	if err != nil {
 		return &RowIterator{err: err}
 	}
-	// Cloud Spanner will return "Session not found" on bad sessions.
-	sid, client := sh.getID(), sh.getClient()
-	if sid == "" || client == nil {
-		// Might happen if transaction is closed in the middle of a API call.
-		return &RowIterator{err: errSessionClosed(sh)}
-	}
-	req := &sppb.ExecuteSqlRequest{
-		Session:     sid,
-		Transaction: ts,
-		Sql:         statement.SQL,
-		QueryMode:   mode,
-	}
-	if err := statement.bindParams(req); err != nil {
-		return &RowIterator{err: err}
-	}
+	client := sh.getClient()
 	return stream(
 		contextWithOutgoingMetadata(ctx, sh.getMetadata()),
 		func(ctx context.Context, resumeToken []byte) (streamingReceiver, error) {
@@ -221,6 +206,31 @@
 		t.release)
 }
 
+func (t *txReadOnly) prepareExecuteSql(ctx context.Context, stmt Statement, mode sppb.ExecuteSqlRequest_QueryMode) (
+	*sppb.ExecuteSqlRequest, *sessionHandle, error) {
+	sh, ts, err := t.acquire(ctx)
+	if err != nil {
+		return nil, nil, err
+	}
+	// Cloud Spanner will return "Session not found" on bad sessions.
+	sid := sh.getID()
+	if sid == "" {
+		// Might happen if transaction is closed in the middle of a API call.
+		return nil, nil, errSessionClosed(sh)
+	}
+	req := &sppb.ExecuteSqlRequest{
+		Session:     sid,
+		Transaction: ts,
+		Sql:         stmt.SQL,
+		QueryMode:   mode,
+		Seqno:       atomic.AddInt64(&t.sequenceNumber, 1),
+	}
+	if err := stmt.bindParams(req); err != nil {
+		return nil, nil, err
+	}
+	return req, sh, nil
+}
+
 // txState is the status of a transaction.
 type txState int
 
@@ -648,6 +658,27 @@
 	return nil
 }
 
+// Update executes a DML statement against the database. It returns the number of
+// affected rows.
+// Update returns an error if the statement is a query. However, the
+// query is executed, and any data read will be validated upon commit.
+func (t *ReadWriteTransaction) Update(ctx context.Context, stmt Statement) (rowCount int64, err error) {
+	ctx = traceStartSpan(ctx, "cloud.google.com/go/spanner.Update")
+	defer func() { traceEndSpan(ctx, err) }()
+	req, sh, err := t.prepareExecuteSql(ctx, stmt, sppb.ExecuteSqlRequest_NORMAL)
+	if err != nil {
+		return 0, err
+	}
+	resultSet, err := sh.getClient().ExecuteSql(ctx, req)
+	if err != nil {
+		return 0, err
+	}
+	if resultSet.Stats == nil {
+		return 0, spannerErrorf(codes.InvalidArgument, "query passed to Update: %q", stmt.SQL)
+	}
+	return extractRowCount(resultSet.Stats)
+}
+
 // acquire implements txReadEnv.acquire.
 func (t *ReadWriteTransaction) acquire(ctx context.Context) (*sessionHandle, *sppb.TransactionSelector, error) {
 	ts := &sppb.TransactionSelector{