feat(spanner/spannertest): support UPDATE DML (#3201)
Fixes #3162.
diff --git a/spanner/spannertest/README.md b/spanner/spannertest/README.md
index 6f820f9..d05f9d2 100644
--- a/spanner/spannertest/README.md
+++ b/spanner/spannertest/README.md
@@ -19,16 +19,16 @@
- expression functions
- more aggregation functions
-- INSERT/UPDATE DML statements
- SELECT HAVING
- case insensitivity
-- FULL JOIN
+- FULL JOIN, multiple joins
- alternate literal types (esp. strings)
- STRUCT types
- transaction simulation
- expression type casting, coercion
- subselects
- FOREIGN KEY and CHECK constraints
+- INSERT DML statements
- set operations (UNION, INTERSECT, EXCEPT)
- partition support
- conditional expressions
diff --git a/spanner/spannertest/db.go b/spanner/spannertest/db.go
index 810bcec..d1a9ef2 100644
--- a/spanner/spannertest/db.go
+++ b/spanner/spannertest/db.go
@@ -999,6 +999,65 @@
i++
}
return n, nil
+ case *spansql.Update:
+ t, err := d.table(stmt.Table)
+ if err != nil {
+ return 0, err
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ ec := evalContext{
+ cols: t.cols,
+ params: params,
+ }
+
+ // Build parallel slices of destination column index and expressions to evaluate.
+ var dstIndex []int
+ var expr []spansql.Expr
+ for _, ui := range stmt.Items {
+ i, err := ec.resolveColumnIndex(ui.Column)
+ if err != nil {
+ return 0, err
+ }
+ // TODO: Enforce "A column can appear only once in the SET clause.".
+ if i < t.pkCols {
+ return 0, status.Errorf(codes.InvalidArgument, "cannot update primary key %s", ui.Column)
+ }
+ dstIndex = append(dstIndex, i)
+ expr = append(expr, ui.Value)
+ }
+
+ n := 0
+ values := make(row, len(stmt.Items)) // scratch space for new values
+ for i := 0; i < len(t.rows); i++ {
+ ec.row = t.rows[i]
+ b, err := ec.evalBoolExpr(stmt.Where)
+ if err != nil {
+ return 0, err
+ }
+ if b != nil && *b {
+ // Compute every update item.
+ for j := range dstIndex {
+ if expr[j] == nil { // DEFAULT
+ values[j] = nil
+ continue
+ }
+ v, err := ec.evalExpr(expr[j])
+ if err != nil {
+ return 0, err
+ }
+ values[j] = v
+ }
+ // Write them to the row.
+ for j, v := range values {
+ t.rows[i][dstIndex[j]] = v
+ }
+ n++
+ }
+ }
+ return n, nil
}
}
diff --git a/spanner/spannertest/integration_test.go b/spanner/spannertest/integration_test.go
index adbe329..f70b01f 100644
--- a/spanner/spannertest/integration_test.go
+++ b/spanner/spannertest/integration_test.go
@@ -412,7 +412,7 @@
"Staff",
"PlayerStats",
"JoinA", "JoinB", "JoinC", "JoinD", "JoinE", "JoinF",
- "SomeStrings",
+ "SomeStrings", "Updateable",
}
errc := make(chan error)
for _, table := range allTables {
@@ -618,6 +618,11 @@
`CREATE TABLE JoinF ( y INT64, z STRING(MAX) ) PRIMARY KEY (y, z)`,
// Some other test tables.
`CREATE TABLE SomeStrings ( i INT64, str STRING(MAX) ) PRIMARY KEY (i)`,
+ `CREATE TABLE Updateable (
+ id INT64,
+ first STRING(MAX),
+ last STRING(MAX),
+ ) PRIMARY KEY (id)`,
)
if err != nil {
t.Fatalf("Creating sample tables: %v", err)
@@ -661,11 +666,39 @@
spanner.Insert("SomeStrings", []string{"i", "str"}, []interface{}{1, "abar"}),
spanner.Insert("SomeStrings", []string{"i", "str"}, []interface{}{2, nil}),
spanner.Insert("SomeStrings", []string{"i", "str"}, []interface{}{3, "bbar"}),
+
+ spanner.Insert("Updateable", []string{"id", "first", "last"}, []interface{}{0, "joe", nil}),
+ spanner.Insert("Updateable", []string{"id", "first", "last"}, []interface{}{1, "doe", "joan"}),
+ spanner.Insert("Updateable", []string{"id", "first", "last"}, []interface{}{2, "wong", "wong"}),
})
if err != nil {
t.Fatalf("Inserting sample data: %v", err)
}
+ // Perform UPDATE DML; the results are checked later on.
+ n = 0
+ _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error {
+ for _, u := range []string{
+ `UPDATE Updateable SET last = "bloggs" WHERE id = 0`,
+ `UPDATE Updateable SET first = last, last = first WHERE id = 1`,
+ `UPDATE Updateable SET last = DEFAULT WHERE id = 2`,
+ `UPDATE Updateable SET first = "noname" WHERE id = 3`, // no id=3
+ } {
+ nr, err := tx.Update(ctx, spanner.NewStatement(u))
+ if err != nil {
+ return err
+ }
+ n += nr
+ }
+ return nil
+ })
+ if err != nil {
+ t.Fatalf("Updating with DML: %v", err)
+ }
+ if n != 3 {
+ t.Errorf("Updating with DML affected %d rows, want 3", n)
+ }
+
// Do some complex queries.
tests := []struct {
q string
@@ -976,6 +1009,16 @@
{int64(4), nil, "p"},
},
},
+ // Check the output of the UPDATE DML.
+ {
+ `SELECT id, first, last FROM Updateable ORDER BY id`,
+ nil,
+ [][]interface{}{
+ {int64(0), "joe", "bloggs"},
+ {int64(1), "joan", "doe"},
+ {int64(2), "wong", nil},
+ },
+ },
// Regression test for aggregating no rows; it used to return an empty row.
// https://github.com/googleapis/google-cloud-go/issues/2793
{