| /* |
| Copyright 2019 Google LLC |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| */ |
| |
| /* |
| Package spannertest contains test helpers for working with Cloud Spanner. |
| |
This package is EXPERIMENTAL and lacks several features. See the README.md
file in this directory for more details.
| |
| In-memory fake |
| |
| This package has an in-memory fake implementation of spanner. To use it, |
| create a Server, and then connect to it with no security: |
| srv, err := spannertest.NewServer("localhost:0") |
| ... |
| conn, err := grpc.DialContext(ctx, srv.Addr, grpc.WithInsecure()) |
| ... |
| client, err := spanner.NewClient(ctx, db, option.WithGRPCConn(conn)) |
| ... |
| |
| Alternatively, create a Server, then set the SPANNER_EMULATOR_HOST environment |
| variable and use the regular spanner.NewClient: |
| srv, err := spannertest.NewServer("localhost:0") |
| ... |
| os.Setenv("SPANNER_EMULATOR_HOST", srv.Addr) |
| client, err := spanner.NewClient(ctx, db) |
| ... |
| |
| The same server also supports database admin operations for use with |
| the cloud.google.com/go/spanner/admin/database/apiv1 package. This only |
| simulates the existence of a single database; its name is ignored. |
| */ |
| package spannertest |
| |
| import ( |
| "context" |
| "encoding/base64" |
| "fmt" |
| "io" |
| "log" |
| "math/rand" |
| "net" |
| "strconv" |
| "sync" |
| "sync/atomic" |
| "time" |
| |
| "github.com/golang/protobuf/proto" |
| "github.com/golang/protobuf/ptypes" |
| "google.golang.org/grpc" |
| "google.golang.org/grpc/codes" |
| "google.golang.org/grpc/status" |
| |
| anypb "github.com/golang/protobuf/ptypes/any" |
| emptypb "github.com/golang/protobuf/ptypes/empty" |
| structpb "github.com/golang/protobuf/ptypes/struct" |
| timestamppb "github.com/golang/protobuf/ptypes/timestamp" |
| lropb "google.golang.org/genproto/googleapis/longrunning" |
| adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" |
| spannerpb "google.golang.org/genproto/googleapis/spanner/v1" |
| |
| "cloud.google.com/go/civil" |
| "cloud.google.com/go/spanner/spansql" |
| ) |
| |
// Server is an in-memory Cloud Spanner fake.
// It is unauthenticated, non-performant, and only a rough approximation.
type Server struct {
	// Addr is the resolved address of the listener (l.Addr().String()),
	// suitable for dialing or for SPANNER_EMULATOR_HOST.
	Addr string

	l   net.Listener // TCP listener served by srv; closed by Close
	srv *grpc.Server // gRPC server with the admin, spanner and LRO services registered
	s   *server      // the fake's implementation, backing all registered services
}
| |
// server is the real implementation of the fake.
// It is a separate and unexported type so the API won't be cluttered with
// methods that are only relevant to the fake's implementation.
type server struct {
	logf Logger // destination for diagnostic messages; replaced by Server.SetLogger

	db    database  // the single simulated database shared by all sessions
	start time.Time // server creation time; reported as the database CreateTime

	mu       sync.Mutex          // guards sessions and lros
	sessions map[string]*session // keyed by session name (see newSession)
	lros     map[string]*lro     // keyed by full operation name (see UpdateDatabaseDdl)

	// Any unimplemented methods will cause a panic.
	// TODO: Switch to Unimplemented at some point? spannerpb would need regenerating.
	adminpb.DatabaseAdminServer
	spannerpb.SpannerServer
	lropb.OperationsServer
}
| |
// session is the server-side state of one Spanner session.
type session struct {
	name     string    // session name; also its key in server.sessions
	creation time.Time // when the session was created

	// This context tracks the lifetime of this session.
	// It is canceled in DeleteSession.
	ctx    context.Context
	cancel func()

	mu           sync.Mutex              // guards lastUse and transactions
	lastUse      time.Time               // refreshed by the RPCs that operate on this session
	transactions map[string]*transaction // begun by BeginTransaction, removed by popTx; keyed by transaction ID
}
| |
| func (s *session) Proto() *spannerpb.Session { |
| s.mu.Lock() |
| defer s.mu.Unlock() |
| m := &spannerpb.Session{ |
| Name: s.name, |
| CreateTime: timestampProto(s.creation), |
| ApproximateLastUseTime: timestampProto(s.lastUse), |
| } |
| return m |
| } |
| |
| // timestampProto returns a valid timestamp.Timestamp, |
| // or nil if the given time is zero or isn't representable. |
| func timestampProto(t time.Time) *timestamppb.Timestamp { |
| if t.IsZero() { |
| return nil |
| } |
| ts, err := ptypes.TimestampProto(t) |
| if err != nil { |
| return nil |
| } |
| return ts |
| } |
| |
// lro represents a Long-Running Operation, generally a schema change.
type lro struct {
	mu    sync.Mutex      // guards state
	state *lropb.Operation // current operation state; copied out by State

	// waitc is closed when anyone starts waiting on the LRO.
	// waitatom is CAS'd from 0 to 1 to make that closing safe.
	waitc    chan struct{}
	waitatom int32
}
| |
| func newLRO(initState *lropb.Operation) *lro { |
| return &lro{ |
| state: initState, |
| waitc: make(chan struct{}), |
| } |
| } |
| |
| // noWait causes the LRO to stop the artificial delay when applying the operation. |
| // It returns whether this was the first invocation of this method. |
| func (l *lro) noWait() bool { |
| if atomic.CompareAndSwapInt32(&l.waitatom, 0, 1) { |
| close(l.waitc) |
| return true |
| } |
| return false |
| } |
| |
| func (l *lro) State() *lropb.Operation { |
| l.mu.Lock() |
| defer l.mu.Unlock() |
| return proto.Clone(l.state).(*lropb.Operation) |
| } |
| |
// Logger is something that can be used for logging.
// It is matched by log.Printf and testing.T.Logf.
// The server calls it with a printf-style format string and arguments.
type Logger func(format string, args ...interface{})
| |
| // NewServer creates a new Server. |
| // The Server will be listening for gRPC connections, without TLS, on the provided TCP address. |
| // The resolved address is available in the Addr field. |
| func NewServer(laddr string) (*Server, error) { |
| l, err := net.Listen("tcp", laddr) |
| if err != nil { |
| return nil, err |
| } |
| |
| s := &Server{ |
| Addr: l.Addr().String(), |
| l: l, |
| srv: grpc.NewServer(), |
| s: &server{ |
| logf: func(format string, args ...interface{}) { |
| log.Printf("spannertest.inmem: "+format, args...) |
| }, |
| start: time.Now(), |
| sessions: make(map[string]*session), |
| lros: make(map[string]*lro), |
| }, |
| } |
| adminpb.RegisterDatabaseAdminServer(s.srv, s.s) |
| spannerpb.RegisterSpannerServer(s.srv, s.s) |
| lropb.RegisterOperationsServer(s.srv, s.s) |
| |
| go s.srv.Serve(s.l) |
| |
| return s, nil |
| } |
| |
| // SetLogger sets a logger for the server. |
| // You can use a *testing.T as this argument to collate extra information |
| // from the execution of the server. |
| func (s *Server) SetLogger(l Logger) { s.s.logf = l } |
| |
| // Close shuts down the server. |
| func (s *Server) Close() { |
| s.srv.Stop() |
| s.l.Close() |
| } |
| |
// genRandomSession returns 4 random bytes as 8 lowercase hex digits.
func genRandomSession() string {
	buf := make([]byte, 4)
	rand.Read(buf)
	return fmt.Sprintf("%x", buf)
}
| |
// genRandomTransaction returns "tx-" followed by 6 random bytes in lowercase hex.
func genRandomTransaction() string {
	buf := make([]byte, 6)
	rand.Read(buf)
	return "tx-" + fmt.Sprintf("%x", buf)
}
| |
// genRandomOperation returns "op-" followed by 3 random bytes in lowercase hex.
func genRandomOperation() string {
	buf := make([]byte, 3)
	rand.Read(buf)
	return "op-" + fmt.Sprintf("%x", buf)
}
| |
| func (s *server) GetOperation(ctx context.Context, req *lropb.GetOperationRequest) (*lropb.Operation, error) { |
| s.mu.Lock() |
| lro, ok := s.lros[req.Name] |
| s.mu.Unlock() |
| if !ok { |
| return nil, status.Errorf(codes.NotFound, "unknown LRO %q", req.Name) |
| } |
| |
| // Someone is waiting on this LRO. Disable sleeping in its Run method. |
| if lro.noWait() { |
| // The sleeping has been canceled for the first time. |
| // Have a slight pause to give the LRO a chance to complete. |
| time.Sleep(50 * time.Millisecond) |
| } |
| |
| return lro.State(), nil |
| } |
| |
| func (s *server) GetDatabase(ctx context.Context, req *adminpb.GetDatabaseRequest) (*adminpb.Database, error) { |
| s.logf("GetDatabase(%q)", req.Name) |
| |
| return &adminpb.Database{ |
| Name: req.Name, |
| State: adminpb.Database_READY, |
| CreateTime: timestampProto(s.start), |
| }, nil |
| } |
| |
| // UpdateDDL applies the given DDL to the server. |
| // |
| // This is a convenience method for tests that may assume an existing schema. |
| // The more general approach is to dial this server using an admin client, and |
| // use the UpdateDatabaseDdl RPC method. |
| func (s *Server) UpdateDDL(ddl *spansql.DDL) error { |
| ctx := context.Background() |
| for _, stmt := range ddl.List { |
| if st := s.s.runOneDDL(ctx, stmt); st.Code() != codes.OK { |
| return st.Err() |
| } |
| } |
| return nil |
| } |
| |
| func (s *server) UpdateDatabaseDdl(ctx context.Context, req *adminpb.UpdateDatabaseDdlRequest) (*lropb.Operation, error) { |
| // Parse all the DDL statements first. |
| var stmts []spansql.DDLStmt |
| for _, s := range req.Statements { |
| stmt, err := spansql.ParseDDLStmt(s) |
| if err != nil { |
| // TODO: check what code the real Spanner returns here. |
| return nil, status.Errorf(codes.InvalidArgument, "bad DDL statement %q: %v", s, err) |
| } |
| stmts = append(stmts, stmt) |
| } |
| |
| // Nothing should be depending on the exact structure of this, |
| // but it is specified in google/spanner/admin/database/v1/spanner_database_admin.proto. |
| id := "projects/fake-proj/instances/fake-instance/databases/fake-db/operations/" + genRandomOperation() |
| lro := newLRO(&lropb.Operation{Name: id}) |
| s.mu.Lock() |
| s.lros[id] = lro |
| s.mu.Unlock() |
| |
| go lro.Run(s, stmts) |
| return lro.State(), nil |
| } |
| |
| func (l *lro) Run(s *server, stmts []spansql.DDLStmt) { |
| ctx := context.Background() |
| |
| for _, stmt := range stmts { |
| // Simulate delayed DDL application, but only if nobody is waiting. |
| select { |
| case <-time.After(100 * time.Millisecond): |
| case <-l.waitc: |
| } |
| |
| if st := s.runOneDDL(ctx, stmt); st.Code() != codes.OK { |
| l.mu.Lock() |
| l.state.Done = true |
| l.state.Result = &lropb.Operation_Error{st.Proto()} |
| l.mu.Unlock() |
| return |
| } |
| } |
| |
| l.mu.Lock() |
| l.state.Done = true |
| l.state.Result = &lropb.Operation_Response{&anypb.Any{}} |
| l.mu.Unlock() |
| } |
| |
| func (s *server) runOneDDL(ctx context.Context, stmt spansql.DDLStmt) *status.Status { |
| return s.db.ApplyDDL(stmt) |
| } |
| |
| func (s *server) GetDatabaseDdl(ctx context.Context, req *adminpb.GetDatabaseDdlRequest) (*adminpb.GetDatabaseDdlResponse, error) { |
| s.logf("GetDatabaseDdl(%q)", req.Database) |
| |
| var resp adminpb.GetDatabaseDdlResponse |
| for _, stmt := range s.db.GetDDL() { |
| resp.Statements = append(resp.Statements, stmt.SQL()) |
| } |
| return &resp, nil |
| } |
| |
| func (s *server) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) { |
| //s.logf("CreateSession(%q)", req.Database) |
| return s.newSession(), nil |
| } |
| |
| func (s *server) newSession() *spannerpb.Session { |
| id := genRandomSession() |
| now := time.Now() |
| sess := &session{ |
| name: id, |
| creation: now, |
| lastUse: now, |
| transactions: make(map[string]*transaction), |
| } |
| sess.ctx, sess.cancel = context.WithCancel(context.Background()) |
| |
| s.mu.Lock() |
| s.sessions[id] = sess |
| s.mu.Unlock() |
| |
| return sess.Proto() |
| } |
| |
| func (s *server) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) { |
| //s.logf("BatchCreateSessions(%q)", req.Database) |
| |
| var sessions []*spannerpb.Session |
| for i := int32(0); i < req.GetSessionCount(); i++ { |
| sessions = append(sessions, s.newSession()) |
| } |
| |
| return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil |
| } |
| |
| func (s *server) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) { |
| s.mu.Lock() |
| sess, ok := s.sessions[req.Name] |
| s.mu.Unlock() |
| |
| if !ok { |
| // TODO: what error does the real Spanner return? |
| return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Name) |
| } |
| |
| return sess.Proto(), nil |
| } |
| |
| // TODO: ListSessions |
| |
| func (s *server) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) { |
| //s.logf("DeleteSession(%q)", req.Name) |
| |
| s.mu.Lock() |
| sess, ok := s.sessions[req.Name] |
| delete(s.sessions, req.Name) |
| s.mu.Unlock() |
| |
| if !ok { |
| // TODO: what error does the real Spanner return? |
| return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Name) |
| } |
| |
| // Terminate any operations in this session. |
| sess.cancel() |
| |
| return &emptypb.Empty{}, nil |
| } |
| |
| // popTx returns an existing transaction, removing it from the session. |
| // This is called when a transaction is finishing (Commit, Rollback). |
| func (s *server) popTx(sessionID, tid string) (tx *transaction, err error) { |
| s.mu.Lock() |
| sess, ok := s.sessions[sessionID] |
| s.mu.Unlock() |
| if !ok { |
| // TODO: what error does the real Spanner return? |
| return nil, status.Errorf(codes.NotFound, "unknown session %q", sessionID) |
| } |
| |
| sess.mu.Lock() |
| sess.lastUse = time.Now() |
| tx, ok = sess.transactions[tid] |
| if ok { |
| delete(sess.transactions, tid) |
| } |
| sess.mu.Unlock() |
| if !ok { |
| // TODO: what error does the real Spanner return? |
| return nil, status.Errorf(codes.NotFound, "unknown transaction ID %q", tid) |
| } |
| return tx, nil |
| } |
| |
| // readTx returns a transaction for the given session and transaction selector. |
| // It is used by read/query operations (ExecuteStreamingSql, StreamingRead). |
| func (s *server) readTx(ctx context.Context, session string, tsel *spannerpb.TransactionSelector) (tx *transaction, cleanup func(), err error) { |
| s.mu.Lock() |
| sess, ok := s.sessions[session] |
| s.mu.Unlock() |
| if !ok { |
| // TODO: what error does the real Spanner return? |
| return nil, nil, status.Errorf(codes.NotFound, "unknown session %q", session) |
| } |
| |
| sess.mu.Lock() |
| sess.lastUse = time.Now() |
| sess.mu.Unlock() |
| |
| // Only give a read-only transaction regardless of whether the selector |
| // is requesting a read-write or read-only one, since this is in readTx |
| // and so shouldn't be mutating anyway. |
| singleUse := func() (*transaction, func(), error) { |
| tx := s.db.NewReadOnlyTransaction() |
| return tx, tx.Rollback, nil |
| } |
| |
| if tsel.GetSelector() == nil { |
| return singleUse() |
| } |
| |
| switch sel := tsel.Selector.(type) { |
| default: |
| return nil, nil, fmt.Errorf("TransactionSelector type %T not supported", sel) |
| case *spannerpb.TransactionSelector_SingleUse: |
| // Ignore options (e.g. timestamps). |
| switch mode := sel.SingleUse.Mode.(type) { |
| case *spannerpb.TransactionOptions_ReadOnly_: |
| return singleUse() |
| case *spannerpb.TransactionOptions_ReadWrite_: |
| return singleUse() |
| default: |
| return nil, nil, fmt.Errorf("single use transaction in mode %T not supported", mode) |
| } |
| case *spannerpb.TransactionSelector_Id: |
| sess.mu.Lock() |
| tx, ok := sess.transactions[string(sel.Id)] |
| sess.mu.Unlock() |
| if !ok { |
| return nil, nil, fmt.Errorf("no transaction with id %q", sel.Id) |
| } |
| return tx, func() {}, nil |
| } |
| } |
| |
| func (s *server) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) { |
| // Assume this is probably a DML statement or a ping from the session pool. |
| // Queries normally use ExecuteStreamingSql. |
| // TODO: Expand this to support more things. |
| |
| // If it is a single-use transaction we assume it is a query. |
| if req.Transaction.GetSelector() == nil || req.Transaction.GetSingleUse().GetReadOnly() != nil { |
| ri, err := s.executeQuery(req) |
| if err != nil { |
| return nil, err |
| } |
| return s.resultSet(ri) |
| } |
| |
| obj, ok := req.Transaction.Selector.(*spannerpb.TransactionSelector_Id) |
| if !ok { |
| return nil, fmt.Errorf("unsupported transaction type %T", req.Transaction.Selector) |
| } |
| tid := string(obj.Id) |
| _ = tid // TODO: lookup an existing transaction by ID. |
| |
| stmt, err := spansql.ParseDMLStmt(req.Sql) |
| if err != nil { |
| return nil, status.Errorf(codes.InvalidArgument, "bad DML: %v", err) |
| } |
| params, err := parseQueryParams(req.GetParams(), req.ParamTypes) |
| if err != nil { |
| return nil, err |
| } |
| |
| s.logf("Executing: %s", stmt.SQL()) |
| if len(params) > 0 { |
| s.logf(" â–¹ %v", params) |
| } |
| |
| n, err := s.db.Execute(stmt, params) |
| if err != nil { |
| return nil, err |
| } |
| return &spannerpb.ResultSet{ |
| Stats: &spannerpb.ResultSetStats{ |
| RowCount: &spannerpb.ResultSetStats_RowCountExact{int64(n)}, |
| }, |
| }, nil |
| } |
| |
| func (s *server) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error { |
| tx, cleanup, err := s.readTx(stream.Context(), req.Session, req.Transaction) |
| if err != nil { |
| return err |
| } |
| defer cleanup() |
| |
| ri, err := s.executeQuery(req) |
| if err != nil { |
| return err |
| } |
| return s.readStream(stream.Context(), tx, stream.Send, ri) |
| } |
| |
| func (s *server) executeQuery(req *spannerpb.ExecuteSqlRequest) (ri rowIter, err error) { |
| q, err := spansql.ParseQuery(req.Sql) |
| if err != nil { |
| // TODO: check what code the real Spanner returns here. |
| return nil, status.Errorf(codes.InvalidArgument, "bad query: %v", err) |
| } |
| |
| params, err := parseQueryParams(req.GetParams(), req.ParamTypes) |
| if err != nil { |
| return nil, err |
| } |
| |
| s.logf("Querying: %s", q.SQL()) |
| if len(params) > 0 { |
| s.logf(" â–¹ %v", params) |
| } |
| |
| return s.db.Query(q, params) |
| } |
| |
| // TODO: Read |
| |
| func (s *server) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error { |
| tx, cleanup, err := s.readTx(stream.Context(), req.Session, req.Transaction) |
| if err != nil { |
| return err |
| } |
| defer cleanup() |
| |
| // Bail out if various advanced features are being used. |
| if req.Index != "" { |
| // This is okay; we can still return results. |
| s.logf("Warning: index reads (%q) not supported", req.Index) |
| } |
| if len(req.ResumeToken) > 0 { |
| // This should only happen if we send resume_token ourselves. |
| return fmt.Errorf("read resumption not supported") |
| } |
| if len(req.PartitionToken) > 0 { |
| return fmt.Errorf("partition restrictions not supported") |
| } |
| |
| var ri rowIter |
| if req.KeySet.All { |
| s.logf("Reading all from %s (cols: %v)", req.Table, req.Columns) |
| ri, err = s.db.ReadAll(spansql.ID(req.Table), idList(req.Columns), req.Limit) |
| } else { |
| s.logf("Reading rows from %d keys and %d ranges from %s (cols: %v)", len(req.KeySet.Keys), len(req.KeySet.Ranges), req.Table, req.Columns) |
| ri, err = s.db.Read(spansql.ID(req.Table), idList(req.Columns), req.KeySet.Keys, makeKeyRangeList(req.KeySet.Ranges), req.Limit) |
| } |
| if err != nil { |
| return err |
| } |
| |
| // TODO: Figure out the right contexts to use here. There's the session one (sess.ctx), |
| // but also this specific RPC one (stream.Context()). Which takes precedence? |
| // They appear to be independent. |
| |
| return s.readStream(stream.Context(), tx, stream.Send, ri) |
| } |
| |
| func (s *server) resultSet(ri rowIter) (*spannerpb.ResultSet, error) { |
| rsm, err := s.buildResultSetMetadata(ri) |
| if err != nil { |
| return nil, err |
| } |
| rs := &spannerpb.ResultSet{ |
| Metadata: rsm, |
| } |
| for { |
| row, err := ri.Next() |
| if err == io.EOF { |
| break |
| } else if err != nil { |
| return nil, err |
| } |
| |
| values := make([]*structpb.Value, len(row)) |
| for i, x := range row { |
| v, err := spannerValueFromValue(x) |
| if err != nil { |
| return nil, err |
| } |
| values[i] = v |
| } |
| rs.Rows = append(rs.Rows, &structpb.ListValue{Values: values}) |
| } |
| return rs, nil |
| } |
| |
| func (s *server) readStream(ctx context.Context, tx *transaction, send func(*spannerpb.PartialResultSet) error, ri rowIter) error { |
| rsm, err := s.buildResultSetMetadata(ri) |
| if err != nil { |
| return err |
| } |
| |
| for { |
| row, err := ri.Next() |
| if err == io.EOF { |
| break |
| } else if err != nil { |
| return err |
| } |
| |
| values := make([]*structpb.Value, len(row)) |
| for i, x := range row { |
| v, err := spannerValueFromValue(x) |
| if err != nil { |
| return err |
| } |
| values[i] = v |
| } |
| |
| prs := &spannerpb.PartialResultSet{ |
| Metadata: rsm, |
| Values: values, |
| } |
| if err := send(prs); err != nil { |
| return err |
| } |
| |
| // ResultSetMetadata is only set for the first PartialResultSet. |
| rsm = nil |
| } |
| |
| return nil |
| } |
| |
| func (s *server) buildResultSetMetadata(ri rowIter) (*spannerpb.ResultSetMetadata, error) { |
| // Build the result set metadata. |
| rsm := &spannerpb.ResultSetMetadata{ |
| RowType: &spannerpb.StructType{}, |
| // TODO: transaction info? |
| } |
| for _, ci := range ri.Cols() { |
| st, err := spannerTypeFromType(ci.Type) |
| if err != nil { |
| return nil, err |
| } |
| rsm.RowType.Fields = append(rsm.RowType.Fields, &spannerpb.StructType_Field{ |
| Name: string(ci.Name), |
| Type: st, |
| }) |
| } |
| return rsm, nil |
| } |
| |
| func (s *server) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) { |
| //s.logf("BeginTransaction(%v)", req) |
| |
| s.mu.Lock() |
| sess, ok := s.sessions[req.Session] |
| s.mu.Unlock() |
| if !ok { |
| // TODO: what error does the real Spanner return? |
| return nil, status.Errorf(codes.NotFound, "unknown session %q", req.Session) |
| } |
| |
| id := genRandomTransaction() |
| tx := s.db.NewTransaction() |
| |
| sess.mu.Lock() |
| sess.lastUse = time.Now() |
| sess.transactions[id] = tx |
| sess.mu.Unlock() |
| |
| tr := &spannerpb.Transaction{Id: []byte(id)} |
| |
| if req.GetOptions().GetReadOnly().GetReturnReadTimestamp() { |
| // Return the last commit timestamp. |
| // This isn't wholly accurate, but may be good enough for simple use cases. |
| tr.ReadTimestamp = timestampProto(s.db.LastCommitTimestamp()) |
| } |
| |
| return tr, nil |
| } |
| |
| func (s *server) Commit(ctx context.Context, req *spannerpb.CommitRequest) (resp *spannerpb.CommitResponse, err error) { |
| //s.logf("Commit(%q, %q)", req.Session, req.Transaction) |
| |
| obj, ok := req.Transaction.(*spannerpb.CommitRequest_TransactionId) |
| if !ok { |
| return nil, fmt.Errorf("unsupported transaction type %T", req.Transaction) |
| } |
| tid := string(obj.TransactionId) |
| |
| tx, err := s.popTx(req.Session, tid) |
| if err != nil { |
| return nil, err |
| } |
| defer func() { |
| if err != nil { |
| tx.Rollback() |
| } |
| }() |
| tx.Start() |
| |
| for _, m := range req.Mutations { |
| switch op := m.Operation.(type) { |
| default: |
| return nil, fmt.Errorf("unsupported mutation operation type %T", op) |
| case *spannerpb.Mutation_Insert: |
| ins := op.Insert |
| err := s.db.Insert(tx, spansql.ID(ins.Table), idList(ins.Columns), ins.Values) |
| if err != nil { |
| return nil, err |
| } |
| case *spannerpb.Mutation_Update: |
| up := op.Update |
| err := s.db.Update(tx, spansql.ID(up.Table), idList(up.Columns), up.Values) |
| if err != nil { |
| return nil, err |
| } |
| case *spannerpb.Mutation_InsertOrUpdate: |
| iou := op.InsertOrUpdate |
| err := s.db.InsertOrUpdate(tx, spansql.ID(iou.Table), idList(iou.Columns), iou.Values) |
| if err != nil { |
| return nil, err |
| } |
| case *spannerpb.Mutation_Delete_: |
| del := op.Delete |
| ks := del.KeySet |
| |
| err := s.db.Delete(tx, spansql.ID(del.Table), ks.Keys, makeKeyRangeList(ks.Ranges), ks.All) |
| if err != nil { |
| return nil, err |
| } |
| } |
| |
| } |
| |
| ts, err := tx.Commit() |
| if err != nil { |
| return nil, err |
| } |
| |
| return &spannerpb.CommitResponse{ |
| CommitTimestamp: timestampProto(ts), |
| }, nil |
| } |
| |
| func (s *server) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) { |
| s.logf("Rollback(%v)", req) |
| |
| tx, err := s.popTx(req.Session, string(req.TransactionId)) |
| if err != nil { |
| return nil, err |
| } |
| |
| tx.Rollback() |
| |
| return &emptypb.Empty{}, nil |
| } |
| |
| // TODO: PartitionQuery, PartitionRead |
| |
| func parseQueryParams(p *structpb.Struct, types map[string]*spannerpb.Type) (queryParams, error) { |
| params := make(queryParams) |
| for k, v := range p.GetFields() { |
| p, err := parseQueryParam(v, types[k]) |
| if err != nil { |
| return nil, err |
| } |
| params[k] = p |
| } |
| return params, nil |
| } |
| |
| func parseQueryParam(v *structpb.Value, typ *spannerpb.Type) (queryParam, error) { |
| // TODO: Use valForType and typeFromSpannerType more comprehensively here? |
| // They are only used for StringValue vs, since that's what mostly needs parsing. |
| |
| rawv := v |
| switch v := v.Kind.(type) { |
| default: |
| return queryParam{}, fmt.Errorf("unsupported well-known type value kind %T", v) |
| case *structpb.Value_NullValue: |
| return queryParam{Value: nil}, nil // TODO: set a type? |
| case *structpb.Value_BoolValue: |
| return queryParam{Value: v.BoolValue, Type: boolType}, nil |
| case *structpb.Value_NumberValue: |
| return queryParam{Value: v.NumberValue, Type: float64Type}, nil |
| case *structpb.Value_StringValue: |
| t, err := typeFromSpannerType(typ) |
| if err != nil { |
| return queryParam{}, err |
| } |
| val, err := valForType(rawv, t) |
| if err != nil { |
| return queryParam{}, err |
| } |
| return queryParam{Value: val, Type: t}, nil |
| case *structpb.Value_ListValue: |
| var list []interface{} |
| for _, elem := range v.ListValue.Values { |
| // TODO: Change the type parameter passed through? We only look at the code. |
| p, err := parseQueryParam(elem, typ) |
| if err != nil { |
| return queryParam{}, err |
| } |
| list = append(list, p.Value) |
| } |
| t, err := typeFromSpannerType(typ) |
| if err != nil { |
| return queryParam{}, err |
| } |
| return queryParam{Value: list, Type: t}, nil |
| } |
| } |
| |
| func typeFromSpannerType(st *spannerpb.Type) (spansql.Type, error) { |
| switch st.Code { |
| default: |
| return spansql.Type{}, fmt.Errorf("unhandled spanner type code %v", st.Code) |
| case spannerpb.TypeCode_BOOL: |
| return spansql.Type{Base: spansql.Bool}, nil |
| case spannerpb.TypeCode_INT64: |
| return spansql.Type{Base: spansql.Int64}, nil |
| case spannerpb.TypeCode_FLOAT64: |
| return spansql.Type{Base: spansql.Float64}, nil |
| case spannerpb.TypeCode_TIMESTAMP: |
| return spansql.Type{Base: spansql.Timestamp}, nil |
| case spannerpb.TypeCode_DATE: |
| return spansql.Type{Base: spansql.Date}, nil |
| case spannerpb.TypeCode_STRING: |
| return spansql.Type{Base: spansql.String}, nil // no len |
| case spannerpb.TypeCode_BYTES: |
| return spansql.Type{Base: spansql.Bytes}, nil // no len |
| case spannerpb.TypeCode_ARRAY: |
| typ, err := typeFromSpannerType(st.ArrayElementType) |
| if err != nil { |
| return spansql.Type{}, err |
| } |
| typ.Array = true |
| return typ, nil |
| } |
| } |
| |
| func spannerTypeFromType(typ spansql.Type) (*spannerpb.Type, error) { |
| var code spannerpb.TypeCode |
| switch typ.Base { |
| default: |
| return nil, fmt.Errorf("unhandled base type %d", typ.Base) |
| case spansql.Bool: |
| code = spannerpb.TypeCode_BOOL |
| case spansql.Int64: |
| code = spannerpb.TypeCode_INT64 |
| case spansql.Float64: |
| code = spannerpb.TypeCode_FLOAT64 |
| case spansql.String: |
| code = spannerpb.TypeCode_STRING |
| case spansql.Bytes: |
| code = spannerpb.TypeCode_BYTES |
| case spansql.Date: |
| code = spannerpb.TypeCode_DATE |
| case spansql.Timestamp: |
| code = spannerpb.TypeCode_TIMESTAMP |
| } |
| st := &spannerpb.Type{Code: code} |
| if typ.Array { |
| st = &spannerpb.Type{ |
| Code: spannerpb.TypeCode_ARRAY, |
| ArrayElementType: st, |
| } |
| } |
| return st, nil |
| } |
| |
| func spannerValueFromValue(x interface{}) (*structpb.Value, error) { |
| switch x := x.(type) { |
| default: |
| return nil, fmt.Errorf("unhandled database value type %T", x) |
| case bool: |
| return &structpb.Value{Kind: &structpb.Value_BoolValue{x}}, nil |
| case int64: |
| // The Spanner int64 is actually a decimal string. |
| s := strconv.FormatInt(x, 10) |
| return &structpb.Value{Kind: &structpb.Value_StringValue{s}}, nil |
| case float64: |
| return &structpb.Value{Kind: &structpb.Value_NumberValue{x}}, nil |
| case string: |
| return &structpb.Value{Kind: &structpb.Value_StringValue{x}}, nil |
| case []byte: |
| return &structpb.Value{Kind: &structpb.Value_StringValue{base64.StdEncoding.EncodeToString(x)}}, nil |
| case civil.Date: |
| // RFC 3339 date format. |
| return &structpb.Value{Kind: &structpb.Value_StringValue{x.String()}}, nil |
| case time.Time: |
| // RFC 3339 timestamp format with zone Z. |
| s := x.Format("2006-01-02T15:04:05.999999999Z") |
| return &structpb.Value{Kind: &structpb.Value_StringValue{s}}, nil |
| case nil: |
| return &structpb.Value{Kind: &structpb.Value_NullValue{}}, nil |
| case []interface{}: |
| var vs []*structpb.Value |
| for _, elem := range x { |
| v, err := spannerValueFromValue(elem) |
| if err != nil { |
| return nil, err |
| } |
| vs = append(vs, v) |
| } |
| return &structpb.Value{Kind: &structpb.Value_ListValue{ |
| &structpb.ListValue{Values: vs}, |
| }}, nil |
| } |
| } |
| |
| func makeKeyRangeList(ranges []*spannerpb.KeyRange) keyRangeList { |
| var krl keyRangeList |
| for _, r := range ranges { |
| krl = append(krl, makeKeyRange(r)) |
| } |
| return krl |
| } |
| |
| func makeKeyRange(r *spannerpb.KeyRange) *keyRange { |
| var kr keyRange |
| switch s := r.StartKeyType.(type) { |
| case *spannerpb.KeyRange_StartClosed: |
| kr.start = s.StartClosed |
| kr.startClosed = true |
| case *spannerpb.KeyRange_StartOpen: |
| kr.start = s.StartOpen |
| } |
| switch e := r.EndKeyType.(type) { |
| case *spannerpb.KeyRange_EndClosed: |
| kr.end = e.EndClosed |
| kr.endClosed = true |
| case *spannerpb.KeyRange_EndOpen: |
| kr.end = e.EndOpen |
| } |
| return &kr |
| } |
| |
| func idList(ss []string) (ids []spansql.ID) { |
| for _, s := range ss { |
| ids = append(ids, spansql.ID(s)) |
| } |
| return |
| } |