bigquery: support array query parameters

Change-Id: I396aeed685b421d3c3ab07a552e28e4eded8f064
Reviewed-on: https://code-review.googlesource.com/9558
Reviewed-by: Ross Light <light@google.com>
diff --git a/bigquery/params.go b/bigquery/params.go
index 7d6fe36..046b874 100644
--- a/bigquery/params.go
+++ b/bigquery/params.go
@@ -17,6 +17,7 @@
 import (
 	"encoding/base64"
 	"fmt"
+	"reflect"
 	"time"
 
 	bq "google.golang.org/api/bigquery/v2"
@@ -34,23 +35,32 @@
 	timestampParamType = &bq.QueryParameterType{Type: "TIMESTAMP"}
 )
 
-func paramType(x interface{}) (*bq.QueryParameterType, error) {
-	switch x.(type) {
-	case int, int8, int16, int32, int64, uint8, uint16, uint32:
+var timeType = reflect.TypeOf(time.Time{})
+
+func paramType(t reflect.Type) (*bq.QueryParameterType, error) {
+	switch t.Kind() {
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32:
 		return int64ParamType, nil
-	case float32, float64:
+	case reflect.Float32, reflect.Float64:
 		return float64ParamType, nil
-	case bool:
+	case reflect.Bool:
 		return boolParamType, nil
-	case string:
+	case reflect.String:
 		return stringParamType, nil
-	case time.Time:
-		return timestampParamType, nil
-	case []byte:
-		return bytesParamType, nil
-	default:
-		return nil, fmt.Errorf("Go type %T cannot be represented as a parameter type", x)
+	case reflect.Slice, reflect.Array:
+		if t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 {
+			return bytesParamType, nil
+		}
+		et, err := paramType(t.Elem())
+		if err != nil {
+			return nil, err
+		}
+		return &bq.QueryParameterType{Type: "ARRAY", ArrayType: et}, nil
 	}
+	if t == timeType {
+		return timestampParamType, nil
+	}
+	return nil, fmt.Errorf("Go type %s cannot be represented as a parameter type", t)
 }
 
 func paramValue(x interface{}) (bq.QueryParameterValue, error) {
@@ -63,7 +73,20 @@
 		return sval(base64.StdEncoding.EncodeToString(x)), nil
 	case time.Time:
 		return sval(x.Format(timestampFormat)), nil
-	default:
-		return sval(fmt.Sprint(x)), nil
 	}
+	t := reflect.TypeOf(x)
+	switch t.Kind() {
+	case reflect.Slice, reflect.Array:
+		var vals []*bq.QueryParameterValue
+		v := reflect.ValueOf(x)
+		for i := 0; i < v.Len(); i++ {
+			val, err := paramValue(v.Index(i).Interface())
+			if err != nil {
+				return bq.QueryParameterValue{}, err
+			}
+			vals = append(vals, &val)
+		}
+		return bq.QueryParameterValue{ArrayValues: vals}, nil
+	}
+	return sval(fmt.Sprint(x)), nil
 }
diff --git a/bigquery/params_test.go b/bigquery/params_test.go
index bd545bd..97c75cc 100644
--- a/bigquery/params_test.go
+++ b/bigquery/params_test.go
@@ -15,8 +15,8 @@
 package bigquery
 
 import (
-	"bytes"
 	"context"
+	"errors"
 	"math"
 	"reflect"
 	"testing"
@@ -60,7 +60,31 @@
 	}
 }
 
-func TestParamTypeScalar(t *testing.T) {
+func TestParamValueArray(t *testing.T) {
+	for _, test := range []struct {
+		val  interface{}
+		want []string
+	}{
+		{[]int(nil), []string{}},
+		{[]int{}, []string{}},
+		{[]int{1, 2}, []string{"1", "2"}},
+		{[3]int{1, 2, 3}, []string{"1", "2", "3"}},
+	} {
+		got, err := paramValue(test.val)
+		if err != nil {
+			t.Fatal(err)
+		}
+		var want bq.QueryParameterValue
+		for _, s := range test.want {
+			want.ArrayValues = append(want.ArrayValues, &bq.QueryParameterValue{Value: s})
+		}
+		if !reflect.DeepEqual(got, want) {
+			t.Errorf("%#v:\ngot  %+v\nwant %+v", test.val, got, want)
+		}
+	}
+}
+
+func TestParamType(t *testing.T) {
 	for _, test := range []struct {
 		val  interface{}
 		want *bq.QueryParameterType
@@ -75,42 +99,71 @@
 		{"string", stringParamType},
 		{time.Now(), timestampParamType},
 		{[]byte("foo"), bytesParamType},
+		{[]int{}, &bq.QueryParameterType{Type: "ARRAY", ArrayType: int64ParamType}},
+		{[3]bool{}, &bq.QueryParameterType{Type: "ARRAY", ArrayType: boolParamType}},
 	} {
-		got, err := paramType(test.val)
+		got, err := paramType(reflect.TypeOf(test.val))
 		if err != nil {
 			t.Fatal(err)
 		}
-		if got != test.want {
+		if !reflect.DeepEqual(got, test.want) {
 			t.Errorf("%v (%T): got %v, want %v", test.val, test.val, got, test.want)
 		}
 	}
 }
 
 func TestIntegration_ScalarParam(t *testing.T) {
-	ctx := context.Background()
 	c := getClient(t)
 	for _, test := range scalarTests {
-		q := c.Query("select ?")
-		q.Parameters = []QueryParameter{{Value: test.val}}
-		it, err := q.Read(ctx)
+		got, err := paramRoundTrip(c, test.val)
 		if err != nil {
 			t.Fatal(err)
 		}
-		var val []Value
-		err = it.Next(&val)
-		if err != nil {
-			t.Fatal(err)
-		}
-		if len(val) != 1 {
-			t.Fatalf("got %d values, want 1", len(val))
-		}
-		got := val[0]
 		if !equal(got, test.val) {
 			t.Errorf("\ngot  %#v (%T)\nwant %#v (%T)", got, got, test.val, test.val)
 		}
 	}
 }
 
+func TestIntegration_ArrayParam(t *testing.T) {
+	c := getClient(t)
+	for _, test := range []struct {
+		val  interface{}
+		want interface{}
+	}{
+		{[]int(nil), []Value(nil)},
+		{[]int{}, []Value(nil)},
+		{[]int{1, 2}, []Value{int64(1), int64(2)}},
+		{[3]int{1, 2, 3}, []Value{int64(1), int64(2), int64(3)}},
+	} {
+		got, err := paramRoundTrip(c, test.val)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if !equal(got, test.want) {
+			t.Errorf("\ngot  %#v (%T)\nwant %#v (%T)", got, got, test.want, test.want)
+		}
+	}
+}
+
+func paramRoundTrip(c *Client, x interface{}) (Value, error) {
+	q := c.Query("select ?")
+	q.Parameters = []QueryParameter{{Value: x}}
+	it, err := q.Read(context.Background())
+	if err != nil {
+		return nil, err
+	}
+	var val []Value
+	err = it.Next(&val)
+	if err != nil {
+		return nil, err
+	}
+	if len(val) != 1 {
+		return nil, errors.New("wrong number of values")
+	}
+	return val[0], nil
+}
+
 func equal(x1, x2 interface{}) bool {
 	if reflect.TypeOf(x1) != reflect.TypeOf(x2) {
 		return false
@@ -124,9 +177,7 @@
 	case time.Time:
 		// BigQuery is only accurate to the microsecond.
 		return x1.Round(time.Microsecond).Equal(x2.(time.Time).Round(time.Microsecond))
-	case []byte:
-		return bytes.Equal(x1, x2.([]byte))
 	default:
-		return x1 == x2
+		return reflect.DeepEqual(x1, x2)
 	}
 }
diff --git a/bigquery/query.go b/bigquery/query.go
index 8ad2444..af40d23 100644
--- a/bigquery/query.go
+++ b/bigquery/query.go
@@ -15,6 +15,8 @@
 package bigquery
 
 import (
+	"reflect"
+
 	"golang.org/x/net/context"
 	bq "google.golang.org/api/bigquery/v2"
 )
@@ -209,7 +211,7 @@
 		if err != nil {
 			return err
 		}
-		pt, err := paramType(p.Value)
+		pt, err := paramType(reflect.TypeOf(p.Value))
 		if err != nil {
 			return err
 		}