blob: fb195829f7124660885a92dc34b54fc93668e4f0 [file] [log] [blame]
/*
Copyright 2020 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package spannertest
import (
"fmt"
"math"
"strconv"
"strings"
"time"
"cloud.google.com/go/civil"
"cloud.google.com/go/spanner/spansql"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// This file contains implementations of query functions.
// function describes the implementation of a scalar SQL function as stored
// in the functions map.
type function struct {
	// Eval evaluates the result of the function using the given input.
	// values holds the evaluated arguments. The meaning of types depends on
	// the function: CAST and SAFE_CAST, for example, take the target type
	// from types[0], while most other functions ignore it.
	// Eval returns the result value together with its SQL type.
	Eval func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error)
}
// firstErr reports the earliest non-nil entry of errors. It returns nil when
// the slice is empty or holds only nil errors.
func firstErr(errors []error) error {
	for i := 0; i < len(errors); i++ {
		if e := errors[i]; e != nil {
			return e
		}
	}
	return nil
}
// functions maps the name of each supported scalar SQL function to its
// implementation.
//
// Where Spanner propagates NULL (a NULL argument yields a NULL result rather
// than an error), the implementations below return a nil value of the
// function's result type.
var functions = map[string]function{
	"STARTS_WITH": {
		Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
			// TODO: Refine error messages to exactly match Spanner.
			// Check input values first.
			if len(values) != 2 {
				return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function STARTS_WITH for the given argument types")
			}
			// A NULL input yields a NULL BOOL result.
			if values[0] == nil || values[1] == nil {
				return nil, spansql.Type{Base: spansql.Bool}, nil
			}
			// TODO: STARTS_WITH also supports BYTES as input parameters.
			s, okArg1 := values[0].(string)
			prefix, okArg2 := values[1].(string)
			if !(okArg1 && okArg2) {
				return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function STARTS_WITH for the given argument types")
			}
			return strings.HasPrefix(s, prefix), spansql.Type{Base: spansql.Bool}, nil
		},
	},
	"LOWER": {
		Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
			if len(values) != 1 {
				return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function LOWER for the given argument types")
			}
			if values[0] == nil {
				return nil, spansql.Type{Base: spansql.String}, nil
			}
			s, ok := values[0].(string)
			if !ok {
				return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function LOWER for the given argument types")
			}
			return strings.ToLower(s), spansql.Type{Base: spansql.String}, nil
		},
	},
	"CAST": {
		Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
			return cast(values, types, false)
		},
	},
	"SAFE_CAST": {
		Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
			return cast(values, types, true)
		},
	},
	"JSON_VALUE": {
		Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
			if len(values) != 2 {
				return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function JSON_VALUE for the given argument types")
			}
			if values[0] == nil || values[1] == nil {
				return nil, spansql.Type{Base: spansql.String}, nil
			}
			_, okArg1 := values[0].(string)
			_, okArg2 := values[1].(string)
			if !(okArg1 && okArg2) {
				return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function JSON_VALUE for the given argument types")
			}
			// This function currently has no implementation and always returns
			// an empty string, as it would otherwise require an XPath query
			// engine.
			return "", spansql.Type{Base: spansql.String}, nil
		},
	},
	"EXTRACT": {
		Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
			// NOTE(review): the single argument appears to be the already
			// extracted part (a civil.Date or an int64). This function only
			// validates it and attaches the matching type. Confirm against the
			// caller.
			if len(values) != 1 {
				return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function EXTRACT for the given argument types")
			}
			switch v := values[0].(type) {
			case civil.Date:
				return v, spansql.Type{Base: spansql.Date}, nil
			case int64:
				return v, spansql.Type{Base: spansql.Int64}, nil
			}
			return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function EXTRACT for the given argument types")
		},
	},
	"TIMESTAMP": {
		Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
			if len(values) != 1 {
				return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function TIMESTAMP for the given argument types")
			}
			// A NULL input yields a NULL TIMESTAMP result.
			if values[0] == nil {
				return nil, spansql.Type{Base: spansql.Timestamp}, nil
			}
			t, ok := values[0].(string)
			if !ok {
				return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function TIMESTAMP for the given argument types")
			}
			timestamp, err := time.Parse(time.RFC3339, t)
			if err != nil {
				return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function TIMESTAMP for the given argument types")
			}
			return timestamp, spansql.Type{Base: spansql.Timestamp}, nil
		},
	},
	"FARM_FINGERPRINT": {
		Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
			// Check input values first.
			if len(values) != 1 {
				return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function FARM_FINGERPRINT for the given argument types")
			}
			// NOTE(review): real Spanner returns NULL for a NULL input. This
			// stub keeps returning its placeholder value instead.
			if values[0] == nil {
				return int64(1), spansql.Type{Base: spansql.Int64}, nil
			}
			if _, ok := values[0].(string); !ok {
				return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function FARM_FINGERPRINT for the given argument types")
			}
			// This function currently has no implementation and always returns
			// the same value, as it would otherwise require a fingerprint function.
			return int64(1), spansql.Type{Base: spansql.Int64}, nil
		},
	},
	"MOD": {
		Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
			// Check input values first.
			if len(values) != 2 {
				return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function MOD for the given argument types")
			}
			// A NULL input yields a NULL INT64 result, even when the divisor
			// would otherwise be zero.
			if values[0] == nil || values[1] == nil {
				return nil, spansql.Type{Base: spansql.Int64}, nil
			}
			x, okArg1 := values[0].(int64)
			y, okArg2 := values[1].(int64)
			if !(okArg1 && okArg2) {
				return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function MOD for the given argument types")
			}
			if y == 0 {
				return nil, spansql.Type{}, status.Error(codes.OutOfRange, "Division by zero")
			}
			// Go defines math.MinInt64 % -1 == 0, so this cannot panic.
			return x % y, spansql.Type{Base: spansql.Int64}, nil
		},
	},
}
// cast implements CAST (safe == false) and SAFE_CAST (safe == true).
//
// types must hold exactly one entry, which is the target type. values must
// hold exactly one entry: the already-converted value. If that value is an
// error, the conversion failed. SAFE_CAST then yields NULL of the target
// type, while CAST returns the error.
func cast(values []interface{}, types []spansql.Type, safe bool) (interface{}, spansql.Type, error) {
	var fnName string
	if safe {
		fnName = "SAFE_CAST"
	} else {
		fnName = "CAST"
	}
	switch {
	case len(types) != 1:
		return nil, spansql.Type{}, status.Errorf(codes.InvalidArgument, "No type information for function %s for the given arguments", fnName)
	case len(values) != 1:
		return nil, spansql.Type{}, status.Errorf(codes.InvalidArgument, "No matching signature for function %s for the given arguments", fnName)
	}
	target := types[0]
	convErr, failed := values[0].(error)
	if !failed {
		return values[0], target, nil
	}
	if safe {
		return nil, target, nil
	}
	return nil, target, convErr
}
// convert converts val to the SQL type tp.
//
// A NULL (nil) val converts to NULL of any non-array type. A conversion that
// is supported but fails for this particular value, such as a STRING that
// does not parse as INT64, is returned as the result *value* with a nil
// error. This matches how cast() treats an error-valued argument: SAFE_CAST
// turns it into NULL and CAST reports it. An unsupported conversion, including
// any ARRAY target, NUMERIC and JSON, is returned as a codes.Unimplemented
// error.
func convert(val interface{}, tp spansql.Type) (interface{}, error) {
	// TODO: Implement more conversions.
	if tp.Array {
		return nil, status.Errorf(codes.Unimplemented, "conversion to ARRAY types is not implemented")
	}
	if val == nil {
		// NULL casts to NULL of the target type. Without this check the
		// per-type converters would report an unsupported conversion.
		return nil, nil
	}
	var (
		res             interface{}
		convertErr, err error
	)
	switch tp.Base {
	case spansql.Int64:
		res, convertErr, err = convertToInt64(val)
	case spansql.Float64:
		res, convertErr, err = convertToFloat64(val)
	case spansql.String:
		res, convertErr, err = convertToString(val)
	case spansql.Bool:
		res, convertErr, err = convertToBool(val)
	case spansql.Date:
		res, convertErr, err = convertToDate(val)
	case spansql.Timestamp:
		res, convertErr, err = convertToTimestamp(val)
	case spansql.Numeric, spansql.JSON:
		// Not implemented yet. res stays nil and the Unimplemented error below
		// is returned.
	}
	if err != nil {
		return nil, err
	}
	if convertErr != nil {
		// Hand the failure back as a value so SAFE_CAST can swallow it.
		res = convertErr
	}
	if res != nil {
		return res, nil
	}
	return nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to %v", val, tp.Base.SQL())
}
// convertToInt64 converts val to an INT64 value.
//
// Supported inputs:
//   - INT64: returned unchanged.
//   - FLOAT64: rounded to the nearest integer, with halfway cases rounded
//     away from zero. NaN, ±Inf and values outside the INT64 range fail.
//   - BOOL: TRUE becomes 1 and FALSE becomes 0.
//   - STRING: parsed as a base-10 integer literal.
//
// A value of a supported type that cannot be converted is reported through
// convertErr. An input of an unsupported type is reported through err.
func convertToInt64(val interface{}) (res int64, convertErr error, err error) {
	switch v := val.(type) {
	case int64:
		return v, nil, nil
	case float64:
		r := math.Round(v) // rounds half away from zero
		// 2^63 is exactly representable as float64 (math.MaxInt64 is not), so
		// the valid range is [-2^63, 2^63). Inf fails these comparisons, and
		// NaN is checked explicitly because it compares false to everything.
		const limit = 1 << 63
		if math.IsNaN(r) || r < -limit || r >= limit {
			return 0, status.Errorf(codes.OutOfRange, "invalid value for INT64: %v", v), nil
		}
		return int64(r), nil, nil
	case bool:
		if v {
			return 1, nil, nil
		}
		return 0, nil, nil
	case string:
		parsed, perr := strconv.ParseInt(v, 10, 64)
		if perr != nil {
			return 0, status.Errorf(codes.InvalidArgument, "invalid value for INT64: %q", v), nil
		}
		return parsed, nil, nil
	}
	return 0, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to INT64", val)
}
func convertToFloat64(val interface{}) (res float64, convertErr error, err error) {
switch v := val.(type) {
case int64:
return float64(v), nil, nil
case float64:
return v, nil, nil
case string:
res, err := strconv.ParseFloat(v, 64)
if err != nil {
return 0, status.Errorf(codes.InvalidArgument, "invalid value for FLOAT64: %q", v), nil
}
return res, nil, nil
}
return 0, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to FLOAT64", val)
}
// convertToString renders val as a STRING value.
//
// STRING passes through. BOOL, INT64 and FLOAT64 use fmt's %v formatting. DATE
// uses civil.Date's String form, and TIMESTAMP is formatted in UTC as
// RFC 3339 with nanoseconds. No supported input can fail, so convertErr is
// always nil. Other types yield an Unimplemented err.
func convertToString(val interface{}) (res string, convertErr error, err error) {
	switch v := val.(type) {
	case string:
		res = v
	case bool, int64, float64:
		res = fmt.Sprintf("%v", v)
	case civil.Date:
		res = v.String()
	case time.Time:
		res = v.UTC().Format(time.RFC3339Nano)
	default:
		return "", nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to STRING", val)
	}
	return res, nil, nil
}
// convertToBool converts val to a BOOL value.
//
// Supported inputs:
//   - BOOL: returned unchanged.
//   - INT64: 0 becomes FALSE and every other value becomes TRUE.
//   - STRING: "true" or "false", compared case-insensitively. Any other
//     string is an invalid value.
//
// strconv.ParseBool is deliberately not used. It accepts forms such as "1",
// "t" and "F" that Spanner rejects, and it rejects mixed case such as "tRuE",
// which Spanner accepts.
//
// Invalid values are reported through convertErr. Unsupported input types
// are reported through err.
func convertToBool(val interface{}) (res bool, convertErr error, err error) {
	switch v := val.(type) {
	case bool:
		return v, nil, nil
	case int64:
		return v != 0, nil, nil
	case string:
		switch {
		case strings.EqualFold(v, "true"):
			return true, nil, nil
		case strings.EqualFold(v, "false"):
			return false, nil, nil
		}
		return false, status.Errorf(codes.InvalidArgument, "invalid value for BOOL: %q", v), nil
	}
	return false, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to BOOL", val)
}
// convertToDate converts val to a DATE value.
//
// DATE input passes through and STRING input is parsed with civil.ParseDate.
// An unparsable string is reported through convertErr. Any other input type is
// reported through err as Unimplemented.
func convertToDate(val interface{}) (res civil.Date, convertErr error, err error) {
	if d, ok := val.(civil.Date); ok {
		return d, nil, nil
	}
	s, isString := val.(string)
	if !isString {
		return civil.Date{}, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to DATE", val)
	}
	parsed, perr := civil.ParseDate(s)
	if perr != nil {
		return civil.Date{}, status.Errorf(codes.InvalidArgument, "invalid value for DATE: %q", s), nil
	}
	return parsed, nil, nil
}
func convertToTimestamp(val interface{}) (res time.Time, convertErr error, err error) {
switch v := val.(type) {
case time.Time:
return v, nil, nil
case string:
res, err := time.Parse(time.RFC3339Nano, v)
if err != nil {
return time.Time{}, status.Errorf(codes.InvalidArgument, "invalid value for TIMESTAMP: %q", v), nil
}
return res, nil, nil
}
return time.Time{}, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to TIMESTAMP", val)
}
// aggregateFunc describes the implementation of an aggregate SQL function as
// stored in the aggregateFuncs map.
type aggregateFunc struct {
	// Whether the function can take a * arg (only COUNT).
	AcceptStar bool
	// Every aggregate func takes one expression.
	// Eval receives that expression's values across the input rows together
	// with the expression's type. It returns the aggregate result and the
	// result's type, which may differ from the input type (e.g. COUNT, AVG).
	Eval func(values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error)
	// TODO: Handle qualifiers such as DISTINCT.
}
// aggregateFuncs maps the name of each supported aggregate SQL function to
// its implementation.
//
// TODO: more aggregate funcs.
var aggregateFuncs = map[string]aggregateFunc{
	"ANY_VALUE": {
		// https://cloud.google.com/spanner/docs/aggregate_functions#any_value
		Eval: func(values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error) {
			// Return the first non-NULL value.
			for _, v := range values {
				if v != nil {
					return v, typ, nil
				}
			}
			// Either all values are NULL, or there are no values.
			return nil, typ, nil
		},
	},
	"ARRAY_AGG": {
		// https://cloud.google.com/spanner/docs/aggregate_functions#array_agg
		Eval: func(values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error) {
			if typ.Array {
				return nil, spansql.Type{}, fmt.Errorf("ARRAY_AGG unsupported on values of type %v", typ.SQL())
			}
			typ.Array = true // use as return type
			if len(values) == 0 {
				// "If there are zero input rows, this function returns NULL."
				return nil, typ, nil
			}
			return values, typ, nil
		},
	},
	"COUNT": {
		AcceptStar: true,
		Eval: func(values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error) {
			// Count the number of non-NULL values.
			// COUNT(*) receives a list of non-NULL placeholders rather than values,
			// so every value will be non-NULL.
			var n int64
			for _, v := range values {
				if v != nil {
					n++
				}
			}
			return n, int64Type, nil
		},
	},
	"MAX": {Eval: func(values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error) {
		return evalMinMax("MAX", false, values, typ)
	}},
	"MIN": {Eval: func(values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error) {
		return evalMinMax("MIN", true, values, typ)
	}},
	"SUM": {
		// https://cloud.google.com/spanner/docs/aggregate_functions#sum
		Eval: func(values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error) {
			if typ.Array || !(typ.Base == spansql.Int64 || typ.Base == spansql.Float64) {
				return nil, spansql.Type{}, fmt.Errorf("SUM only supports arguments of INT64 or FLOAT64 type, not %s", typ.SQL())
			}
			if typ.Base == spansql.Int64 {
				var seen bool
				var sum int64
				for _, v := range values {
					if v == nil {
						continue
					}
					x := v.(int64)
					// "Returns an error if an overflow occurs." Go integer
					// addition wraps silently, so check before adding.
					if (x > 0 && sum > math.MaxInt64-x) || (x < 0 && sum < math.MinInt64-x) {
						return nil, spansql.Type{}, fmt.Errorf("int64 overflow: %d + %d", sum, x)
					}
					seen = true
					sum += x
				}
				if !seen {
					// "Returns NULL if the input contains only NULLs".
					return nil, typ, nil
				}
				return sum, typ, nil
			}
			var seen bool
			var sum float64
			for _, v := range values {
				if v == nil {
					continue
				}
				seen = true
				sum += v.(float64)
			}
			if !seen {
				// "Returns NULL if the input contains only NULLs".
				return nil, typ, nil
			}
			return sum, typ, nil
		},
	},
	"AVG": {
		// https://cloud.google.com/spanner/docs/aggregate_functions#avg
		// AVG always produces FLOAT64, including the NULL result.
		Eval: func(values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error) {
			if typ.Array || !(typ.Base == spansql.Int64 || typ.Base == spansql.Float64) {
				return nil, spansql.Type{}, fmt.Errorf("AVG only supports arguments of INT64 or FLOAT64 type, not %s", typ.SQL())
			}
			if typ.Base == spansql.Int64 {
				// NOTE(review): this running INT64 sum can itself wrap for very
				// large inputs. Spanner presumably does not have that limitation.
				var sum int64
				var n float64
				for _, v := range values {
					if v == nil {
						continue
					}
					sum += v.(int64)
					n++
				}
				if n == 0 {
					// "Returns NULL if the input contains only NULLs". The NULL
					// is still FLOAT64, matching the non-NULL result type below.
					return nil, float64Type, nil
				}
				return (float64(sum) / n), float64Type, nil
			}
			var sum float64
			var n float64
			for _, v := range values {
				if v == nil {
					continue
				}
				sum += v.(float64)
				n++
			}
			if n == 0 {
				// "Returns NULL if the input contains only NULLs".
				return nil, typ, nil
			}
			return (sum / n), typ, nil
		},
	},
}
// evalMinMax computes MIN (isMin == true) or MAX (isMin == false) over values,
// whose element type is typ. name is used only in the error message.
//
// NULL inputs are skipped. The result is NULL when there are no rows or when
// every row is NULL. For FLOAT64 input, the first NaN seen is returned
// immediately. The result type is always typ.
func evalMinMax(name string, isMin bool, values []interface{}, typ spansql.Type) (interface{}, spansql.Type, error) {
	if typ.Array {
		return nil, spansql.Type{}, fmt.Errorf("%s only supports non-array arguments, not %s", name, typ.SQL())
	}
	// With zero input rows the loop never runs and best stays nil, giving
	// "Returns NULL if there are zero input rows".
	isFloat := typ.Base == spansql.Float64
	var best interface{}
	for _, cur := range values {
		switch {
		case cur == nil:
			// "Returns the {maximum|minimum} value of non-NULL expressions".
			continue
		case isFloat && math.IsNaN(cur.(float64)):
			// "Returns NaN if the input contains a NaN".
			return cur, typ, nil
		}
		if best == nil {
			best = cur
			continue
		}
		// For MIN, replace when cur sorts strictly before best. For MAX,
		// replace when it does not.
		sortsBefore := compareVals(cur, best) < 0
		if sortsBefore == isMin {
			best = cur
		}
	}
	return best, typ, nil
}