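bigquery: support struct query parameters

Query parameters may now be Go structs or pointers to structs; they are
sent as BigQuery STRUCT parameters built from the exported fields.
See #390.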

Change-Id: I4006c8f17dc7131997d096701420ea3fc1e13174
Reviewed-on: https://code-review.googlesource.com/9559
Reviewed-by: Ross Light <light@google.com>
diff --git a/bigquery/params.go b/bigquery/params.go
index 046b874..97ebcc3 100644
--- a/bigquery/params.go
+++ b/bigquery/params.go
@@ -16,16 +16,21 @@
 
 import (
 	"encoding/base64"
+	"errors"
 	"fmt"
 	"reflect"
 	"time"
 
+	"cloud.google.com/go/internal/fields"
+
 	bq "google.golang.org/api/bigquery/v2"
 )
 
 // See https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#timestamp-type.
 var timestampFormat = "2006-01-02 15:04:05.999999-07:00"
 
+var fieldCache = fields.NewCache(nil) // supplies the struct fields that paramType and paramValue encode
+
 var (
 	int64ParamType     = &bq.QueryParameterType{Type: "INT64"}
 	float64ParamType   = &bq.QueryParameterType{Type: "FLOAT64"}
@@ -38,55 +43,184 @@
 var timeType = reflect.TypeOf(time.Time{})
 
+// paramType returns the BigQuery query parameter type that represents Go
+// type t. A pointer to a struct is described by the struct's type. It
+// returns an error for a nil type, for Go types with no BigQuery
+// equivalent, and for recursive struct types.
 func paramType(t reflect.Type) (*bq.QueryParameterType, error) {
+	return paramTypeVisit(t, map[reflect.Type]bool{})
+}
+
+// paramTypeVisit implements paramType. visiting holds the struct types
+// currently being expanded on the path from the top-level type; meeting
+// one of them again means the type is recursive, which is reported as an
+// error instead of recursing until the stack overflows.
+func paramTypeVisit(t reflect.Type, visiting map[reflect.Type]bool) (*bq.QueryParameterType, error) {
+	if t == nil {
+		return nil, errors.New("bigquery: nil parameter")
+	}
+	if t == timeType {
+		return timestampParamType, nil
+	}
 	switch t.Kind() {
 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32:
 		return int64ParamType, nil
+
 	case reflect.Float32, reflect.Float64:
 		return float64ParamType, nil
+
 	case reflect.Bool:
 		return boolParamType, nil
+
 	case reflect.String:
 		return stringParamType, nil
-	case reflect.Slice, reflect.Array:
-		if t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 {
+
+	case reflect.Slice:
+		// Every slice of uint8-kind elements, including named types such as
+		// json.RawMessage, is BYTES; paramValue base64-encodes all of them.
+		if t.Elem().Kind() == reflect.Uint8 {
 			return bytesParamType, nil
 		}
+		fallthrough
+
+	case reflect.Array:
-		et, err := paramType(t.Elem())
+		et, err := paramTypeVisit(t.Elem(), visiting)
 		if err != nil {
 			return nil, err
 		}
 		return &bq.QueryParameterType{Type: "ARRAY", ArrayType: et}, nil
+
+	case reflect.Ptr:
+		if t.Elem().Kind() != reflect.Struct {
+			break
+		}
+		t = t.Elem()
+		fallthrough
+
+	case reflect.Struct:
+		if visiting[t] {
+			return nil, fmt.Errorf("bigquery: recursive Go type %s cannot be represented as a parameter type", t)
+		}
+		visiting[t] = true
+		defer delete(visiting, t)
+		var fts []*bq.QueryParameterTypeStructTypes
+		for _, f := range fieldCache.Fields(t) {
+			pt, err := paramTypeVisit(f.Type, visiting)
+			if err != nil {
+				return nil, err
+			}
+			fts = append(fts, &bq.QueryParameterTypeStructTypes{
+				Name: f.Name,
+				Type: pt,
+			})
+		}
+		return &bq.QueryParameterType{Type: "STRUCT", StructTypes: fts}, nil
 	}
-	if t == timeType {
-		return timestampParamType, nil
-	}
-	return nil, fmt.Errorf("Go type %s cannot be represented as a parameter type", t)
+	return nil, fmt.Errorf("bigquery: Go type %s cannot be represented as a parameter type", t)
 }
 
-func paramValue(x interface{}) (bq.QueryParameterValue, error) {
-	// convenience function for scalar value
-	sval := func(s string) bq.QueryParameterValue {
-		return bq.QueryParameterValue{Value: s}
+// paramValue converts v into the BigQuery query parameter value that
+// paramType(v.Type()) describes: time.Time as a timestamp string, byte
+// slices as base64 text, slices and arrays as ARRAY values, structs and
+// non-nil pointers to structs as STRUCT values keyed by field name, and
+// supported scalars as text. A nil pointer to a struct yields an empty
+// value. paramValue does not reject every type that paramType rejects, so
+// callers must validate the type with paramType as well.
+func paramValue(v reflect.Value) (bq.QueryParameterValue, error) {
+	return paramValueVisit(v, map[reflect.Type]bool{})
+}
+
+// paramValueVisit implements paramValue. visiting holds the struct types
+// being expanded on the current path, so a value that refers back to its
+// own struct type produces an error instead of unbounded recursion.
+func paramValueVisit(v reflect.Value, visiting map[reflect.Type]bool) (bq.QueryParameterValue, error) {
+	var res bq.QueryParameterValue
+	if !v.IsValid() {
+		return res, errors.New("bigquery: nil parameter")
 	}
-	switch x := x.(type) {
-	case []byte:
-		return sval(base64.StdEncoding.EncodeToString(x)), nil
-	case time.Time:
-		return sval(x.Format(timestampFormat)), nil
+	t := v.Type()
+	if t == timeType {
+		res.Value = v.Interface().(time.Time).Format(timestampFormat)
+		return res, nil
 	}
-	t := reflect.TypeOf(x)
 	switch t.Kind() {
-	case reflect.Slice, reflect.Array:
+	case reflect.Slice:
+		if t.Elem().Kind() == reflect.Uint8 {
+			// v.Bytes, unlike a v.Interface().([]byte) assertion, also
+			// accepts named byte-slice types such as json.RawMessage.
+			res.Value = base64.StdEncoding.EncodeToString(v.Bytes())
+			return res, nil
+		}
+		fallthrough
+
+	case reflect.Array:
 		var vals []*bq.QueryParameterValue
-		v := reflect.ValueOf(x)
 		for i := 0; i < v.Len(); i++ {
-			val, err := paramValue(v.Index(i).Interface())
+			val, err := paramValueVisit(v.Index(i), visiting)
 			if err != nil {
 				return bq.QueryParameterValue{}, err
 			}
 			vals = append(vals, &val)
 		}
 		return bq.QueryParameterValue{ArrayValues: vals}, nil
+
+	case reflect.Ptr:
+		if t.Elem().Kind() != reflect.Struct {
+			return res, fmt.Errorf("bigquery: Go type %s cannot be represented as a parameter value", t)
+		}
+		t = t.Elem()
+		v = v.Elem()
+		if !v.IsValid() {
+			// A nil pointer to a struct becomes an empty value.
+			return res, nil
+		}
+		fallthrough
+
+	case reflect.Struct:
+		if visiting[t] {
+			return res, fmt.Errorf("bigquery: recursive Go type %s cannot be represented as a parameter value", t)
+		}
+		visiting[t] = true
+		defer delete(visiting, t)
+		res.StructValues = map[string]bq.QueryParameterValue{}
+		for _, f := range fieldCache.Fields(t) {
+			fp, err := paramValueVisit(v.FieldByIndex(f.Index), visiting)
+			if err != nil {
+				return bq.QueryParameterValue{}, err
+			}
+			res.StructValues[f.Name] = fp
+		}
+		return res, nil
+
+	// Scalars are formatted from their kind rather than with
+	// fmt.Sprint(v.Interface()), which would call a String method and send,
+	// for example, "1s" for a time.Duration that paramType reports as INT64.
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		res.Value = fmt.Sprint(v.Int())
+		return res, nil
+
+	case reflect.Uint8, reflect.Uint16, reflect.Uint32:
+		res.Value = fmt.Sprint(v.Uint())
+		return res, nil
+
+	case reflect.Float32:
+		// v.Float widens to float64; convert back so fmt prints the float32 form.
+		res.Value = fmt.Sprint(float32(v.Float()))
+		return res, nil
+
+	case reflect.Float64:
+		res.Value = fmt.Sprint(v.Float())
+		return res, nil
+
+	case reflect.Bool:
+		res.Value = fmt.Sprint(v.Bool())
+		return res, nil
+
+	case reflect.String:
+		res.Value = v.String()
+		return res, nil
 	}
-	return sval(fmt.Sprint(x)), nil
+	// None of the above: a kind with no BigQuery equivalent, which
+	// paramType rejects; this text value is never meaningful.
+	res.Value = fmt.Sprint(v.Interface())
+	return res, nil
 }
diff --git a/bigquery/params_test.go b/bigquery/params_test.go
index 97c75cc..5a68204 100644
--- a/bigquery/params_test.go
+++ b/bigquery/params_test.go
@@ -41,45 +41,94 @@
 		"2016-03-20 04:22:09.000005-01:02"},
 }
 
+type S1 struct {
+	A int
+	B *S2
+	C bool
+}
+
+type S2 struct {
+	D string
+	e int // unexported: struct parameters omit it
+}
+
+var s1 = S1{
+	A: 1,
+	B: &S2{D: "s"},
+	C: true,
+}
+
+func sval(s string) bq.QueryParameterValue {
+	return bq.QueryParameterValue{Value: s}
+}
+
 func TestParamValueScalar(t *testing.T) {
 	for _, test := range scalarTests {
-		got, err := paramValue(test.val)
+		got, err := paramValue(reflect.ValueOf(test.val))
 		if err != nil {
 			t.Errorf("%v: got %v, want nil", test.val, err)
 			continue
 		}
-		if got.ArrayValues != nil {
-			t.Errorf("%v, ArrayValues: got %v, expected nil", test.val, got.ArrayValues)
-		}
-		if got.StructValues != nil {
-			t.Errorf("%v, StructValues: got %v, expected nil", test.val, got.StructValues)
-		}
-		if got.Value != test.want {
-			t.Errorf("%v: got %q, want %q", test.val, got.Value, test.want)
+		want := sval(test.want)
+		if !reflect.DeepEqual(got, want) {
+			t.Errorf("%v:\ngot  %+v\nwant %+v", test.val, got, want)
 		}
 	}
 }
 
 func TestParamValueArray(t *testing.T) {
+	// qpv is the expected ARRAY encoding of the elements 1, 2.
+	qpv := bq.QueryParameterValue{ArrayValues: []*bq.QueryParameterValue{
+		{Value: "1"},
+		{Value: "2"},
+	}}
 	for _, test := range []struct {
 		val  interface{}
-		want []string
+		want bq.QueryParameterValue
 	}{
-		{[]int(nil), []string{}},
-		{[]int{}, []string{}},
-		{[]int{1, 2}, []string{"1", "2"}},
-		{[3]int{1, 2, 3}, []string{"1", "2", "3"}},
+		{[]int(nil), bq.QueryParameterValue{}},
+		{[]int{}, bq.QueryParameterValue{}},
+		{[]int{1, 2}, qpv},
+		{[2]int{1, 2}, qpv},
 	} {
-		got, err := paramValue(test.val)
+		got, err := paramValue(reflect.ValueOf(test.val))
 		if err != nil {
 			t.Fatal(err)
 		}
-		var want bq.QueryParameterValue
-		for _, s := range test.want {
-			want.ArrayValues = append(want.ArrayValues, &bq.QueryParameterValue{Value: s})
+		if !reflect.DeepEqual(got, test.want) {
+			t.Errorf("%#v:\ngot  %+v\nwant %+v", test.val, got, test.want)
 		}
-		if !reflect.DeepEqual(got, want) {
-			t.Errorf("%#v:\ngot  %+v\nwant %+v", test.val, got, want)
+	}
+}
+
+func TestParamValueStruct(t *testing.T) {
+	got, err := paramValue(reflect.ValueOf(s1))
+	if err != nil {
+		t.Fatal(err)
+	}
+	want := bq.QueryParameterValue{
+		StructValues: map[string]bq.QueryParameterValue{
+			"A": sval("1"),
+			"B": {
+				StructValues: map[string]bq.QueryParameterValue{
+					"D": sval("s"),
+				},
+			},
+			"C": sval("true"),
+		},
+	}
+	if !reflect.DeepEqual(got, want) {
+		t.Errorf("got  %+v\nwant %+v", got, want)
+	}
+}
+
+func TestParamValueErrors(t *testing.T) {
+	// paramValue lets a few invalid types through, but paramType catches them.
+	// Since we never call one without the other that's fine.
+	for _, val := range []interface{}{nil, new([]int)} {
+		_, err := paramValue(reflect.ValueOf(val))
+		if err == nil {
+			t.Errorf("%v (%T): got nil, want error", val, val)
 		}
 	}
 }
@@ -101,6 +150,19 @@
 		{[]byte("foo"), bytesParamType},
 		{[]int{}, &bq.QueryParameterType{Type: "ARRAY", ArrayType: int64ParamType}},
 		{[3]bool{}, &bq.QueryParameterType{Type: "ARRAY", ArrayType: boolParamType}},
+		{S1{}, &bq.QueryParameterType{
+			Type: "STRUCT",
+			StructTypes: []*bq.QueryParameterTypeStructTypes{
+				{Name: "A", Type: int64ParamType},
+				{Name: "B", Type: &bq.QueryParameterType{ // *S2 is described by S2's STRUCT type
+					Type: "STRUCT",
+					StructTypes: []*bq.QueryParameterTypeStructTypes{
+						{Name: "D", Type: stringParamType}, // S2.e is unexported, so it is omitted
+					},
+				}},
+				{Name: "C", Type: boolParamType},
+			},
+		}},
 	} {
 		got, err := paramType(reflect.TypeOf(test.val))
 		if err != nil {
@@ -112,6 +174,17 @@
 	}
 }
 
+func TestParamTypeErrors(t *testing.T) {
+	for _, val := range []interface{}{
+		nil, uint(0), new([]int), make(chan int), // nil type, unsupported uint, pointer to non-struct, channel
+	} {
+		_, err := paramType(reflect.TypeOf(val))
+		if err == nil {
+			t.Errorf("%v (%T): got nil, want error", val, val)
+		}
+	}
+}
+
 func TestIntegration_ScalarParam(t *testing.T) {
 	c := getClient(t)
 	for _, test := range scalarTests {
@@ -125,7 +198,7 @@
 	}
 }
 
-func TestIntegration_ArrayParam(t *testing.T) {
+func TestIntegration_OtherParam(t *testing.T) {
 	c := getClient(t)
 	for _, test := range []struct {
 		val  interface{}
@@ -135,6 +208,8 @@
 		{[]int{}, []Value(nil)},
 		{[]int{1, 2}, []Value{int64(1), int64(2)}},
 		{[3]int{1, 2, 3}, []Value{int64(1), int64(2), int64(3)}},
+		{S1{}, []Value{int64(0), nil, false}},       // nil B (*S2) comes back as nil
+		{s1, []Value{int64(1), []Value{"s"}, true}}, // B comes back as a nested []Value holding only the exported D
 	} {
 		got, err := paramRoundTrip(c, test.val)
 		if err != nil {
diff --git a/bigquery/query.go b/bigquery/query.go
index af40d23..ce6eb86 100644
--- a/bigquery/query.go
+++ b/bigquery/query.go
@@ -123,6 +123,8 @@
 	// string: STRING
 	// []byte: BYTES
 	// time.Time: TIMESTAMP
+	// Arrays and slices of the above: ARRAY
+	// Structs and pointers to structs with fields of the above: STRUCT (unexported fields are ignored)
 	Value interface{}
 }
 
@@ -207,7 +209,7 @@
 		conf.DestinationTable = q.Dst.tableRefProto()
 	}
 	for _, p := range q.Parameters {
-		pv, err := paramValue(p.Value)
+		pv, err := paramValue(reflect.ValueOf(p.Value)) // a nil p.Value gives an invalid Value, which paramValue rejects
 		if err != nil {
 			return err
 		}