bigquery: support array query parameters
Change-Id: I396aeed685b421d3c3ab07a552e28e4eded8f064
Reviewed-on: https://code-review.googlesource.com/9558
Reviewed-by: Ross Light <light@google.com>
diff --git a/bigquery/params.go b/bigquery/params.go
index 7d6fe36..046b874 100644
--- a/bigquery/params.go
+++ b/bigquery/params.go
@@ -17,6 +17,7 @@
import (
"encoding/base64"
"fmt"
+ "reflect"
"time"
bq "google.golang.org/api/bigquery/v2"
@@ -34,23 +35,32 @@
timestampParamType = &bq.QueryParameterType{Type: "TIMESTAMP"}
)
-func paramType(x interface{}) (*bq.QueryParameterType, error) {
- switch x.(type) {
- case int, int8, int16, int32, int64, uint8, uint16, uint32:
+var timeType = reflect.TypeOf(time.Time{})
+
+func paramType(t reflect.Type) (*bq.QueryParameterType, error) {
+ switch t.Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32:
return int64ParamType, nil
- case float32, float64:
+ case reflect.Float32, reflect.Float64:
return float64ParamType, nil
- case bool:
+ case reflect.Bool:
return boolParamType, nil
- case string:
+ case reflect.String:
return stringParamType, nil
- case time.Time:
- return timestampParamType, nil
- case []byte:
- return bytesParamType, nil
- default:
- return nil, fmt.Errorf("Go type %T cannot be represented as a parameter type", x)
+ case reflect.Slice, reflect.Array:
+ if t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 {
+ return bytesParamType, nil
+ }
+ et, err := paramType(t.Elem())
+ if err != nil {
+ return nil, err
+ }
+ return &bq.QueryParameterType{Type: "ARRAY", ArrayType: et}, nil
}
+ if t == timeType {
+ return timestampParamType, nil
+ }
+ return nil, fmt.Errorf("Go type %s cannot be represented as a parameter type", t)
}
func paramValue(x interface{}) (bq.QueryParameterValue, error) {
@@ -63,7 +73,20 @@
return sval(base64.StdEncoding.EncodeToString(x)), nil
case time.Time:
return sval(x.Format(timestampFormat)), nil
- default:
- return sval(fmt.Sprint(x)), nil
}
+ t := reflect.TypeOf(x)
+ switch t.Kind() {
+ case reflect.Slice, reflect.Array:
+ var vals []*bq.QueryParameterValue
+ v := reflect.ValueOf(x)
+ for i := 0; i < v.Len(); i++ {
+ val, err := paramValue(v.Index(i).Interface())
+ if err != nil {
+ return bq.QueryParameterValue{}, err
+ }
+ vals = append(vals, &val)
+ }
+ return bq.QueryParameterValue{ArrayValues: vals}, nil
+ }
+ return sval(fmt.Sprint(x)), nil
}
diff --git a/bigquery/params_test.go b/bigquery/params_test.go
index bd545bd..97c75cc 100644
--- a/bigquery/params_test.go
+++ b/bigquery/params_test.go
@@ -15,8 +15,8 @@
package bigquery
import (
- "bytes"
"context"
+ "errors"
"math"
"reflect"
"testing"
@@ -60,7 +60,31 @@
}
}
-func TestParamTypeScalar(t *testing.T) {
+func TestParamValueArray(t *testing.T) {
+ for _, test := range []struct {
+ val interface{}
+ want []string
+ }{
+ {[]int(nil), []string{}},
+ {[]int{}, []string{}},
+ {[]int{1, 2}, []string{"1", "2"}},
+ {[3]int{1, 2, 3}, []string{"1", "2", "3"}},
+ } {
+ got, err := paramValue(test.val)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var want bq.QueryParameterValue
+ for _, s := range test.want {
+ want.ArrayValues = append(want.ArrayValues, &bq.QueryParameterValue{Value: s})
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("%#v:\ngot %+v\nwant %+v", test.val, got, want)
+ }
+ }
+}
+
+func TestParamType(t *testing.T) {
for _, test := range []struct {
val interface{}
want *bq.QueryParameterType
@@ -75,42 +99,71 @@
{"string", stringParamType},
{time.Now(), timestampParamType},
{[]byte("foo"), bytesParamType},
+ {[]int{}, &bq.QueryParameterType{Type: "ARRAY", ArrayType: int64ParamType}},
+ {[3]bool{}, &bq.QueryParameterType{Type: "ARRAY", ArrayType: boolParamType}},
} {
- got, err := paramType(test.val)
+ got, err := paramType(reflect.TypeOf(test.val))
if err != nil {
t.Fatal(err)
}
- if got != test.want {
+ if !reflect.DeepEqual(got, test.want) {
t.Errorf("%v (%T): got %v, want %v", test.val, test.val, got, test.want)
}
}
}
func TestIntegration_ScalarParam(t *testing.T) {
- ctx := context.Background()
c := getClient(t)
for _, test := range scalarTests {
- q := c.Query("select ?")
- q.Parameters = []QueryParameter{{Value: test.val}}
- it, err := q.Read(ctx)
+ got, err := paramRoundTrip(c, test.val)
if err != nil {
t.Fatal(err)
}
- var val []Value
- err = it.Next(&val)
- if err != nil {
- t.Fatal(err)
- }
- if len(val) != 1 {
- t.Fatalf("got %d values, want 1", len(val))
- }
- got := val[0]
if !equal(got, test.val) {
t.Errorf("\ngot %#v (%T)\nwant %#v (%T)", got, got, test.val, test.val)
}
}
}
+func TestIntegration_ArrayParam(t *testing.T) {
+ c := getClient(t)
+ for _, test := range []struct {
+ val interface{}
+ want interface{}
+ }{
+ {[]int(nil), []Value(nil)},
+ {[]int{}, []Value(nil)},
+ {[]int{1, 2}, []Value{int64(1), int64(2)}},
+ {[3]int{1, 2, 3}, []Value{int64(1), int64(2), int64(3)}},
+ } {
+ got, err := paramRoundTrip(c, test.val)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !equal(got, test.want) {
+ t.Errorf("\ngot %#v (%T)\nwant %#v (%T)", got, got, test.want, test.want)
+ }
+ }
+}
+
+func paramRoundTrip(c *Client, x interface{}) (Value, error) {
+ q := c.Query("select ?")
+ q.Parameters = []QueryParameter{{Value: x}}
+ it, err := q.Read(context.Background())
+ if err != nil {
+ return nil, err
+ }
+ var val []Value
+ err = it.Next(&val)
+ if err != nil {
+ return nil, err
+ }
+ if len(val) != 1 {
+ return nil, errors.New("wrong number of values")
+ }
+ return val[0], nil
+}
+
func equal(x1, x2 interface{}) bool {
if reflect.TypeOf(x1) != reflect.TypeOf(x2) {
return false
@@ -124,9 +177,7 @@
case time.Time:
// BigQuery is only accurate to the microsecond.
return x1.Round(time.Microsecond).Equal(x2.(time.Time).Round(time.Microsecond))
- case []byte:
- return bytes.Equal(x1, x2.([]byte))
default:
- return x1 == x2
+ return reflect.DeepEqual(x1, x2)
}
}
diff --git a/bigquery/query.go b/bigquery/query.go
index 8ad2444..af40d23 100644
--- a/bigquery/query.go
+++ b/bigquery/query.go
@@ -15,6 +15,8 @@
package bigquery
import (
+ "reflect"
+
"golang.org/x/net/context"
bq "google.golang.org/api/bigquery/v2"
)
@@ -209,7 +211,7 @@
if err != nil {
return err
}
- pt, err := paramType(p.Value)
+ pt, err := paramType(reflect.TypeOf(p.Value))
if err != nil {
return err
}