bigquery: support struct query parameters
See #390.
Change-Id: I4006c8f17dc7131997d096701420ea3fc1e13174
Reviewed-on: https://code-review.googlesource.com/9559
Reviewed-by: Ross Light <light@google.com>
diff --git a/bigquery/params.go b/bigquery/params.go
index 046b874..97ebcc3 100644
--- a/bigquery/params.go
+++ b/bigquery/params.go
@@ -16,16 +16,21 @@
import (
"encoding/base64"
+ "errors"
"fmt"
"reflect"
"time"
+ "cloud.google.com/go/internal/fields"
+
bq "google.golang.org/api/bigquery/v2"
)
// See https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#timestamp-type.
+//
+// The ".999999" element prints fractional seconds with trailing zeros removed
+// (at most microsecond precision), and "-07:00" prints the zone as a numeric
+// UTC offset.
var timestampFormat = "2006-01-02 15:04:05.999999-07:00"
+// fieldCache supplies the per-struct-type field lists (name, type, index)
+// that paramType and paramValue use for STRUCT parameters.
+// NOTE(review): the nil argument presumably selects default field handling
+// with no custom struct-tag parsing; confirm against internal/fields.NewCache.
+var fieldCache = fields.NewCache(nil)
+
var (
int64ParamType = &bq.QueryParameterType{Type: "INT64"}
float64ParamType = &bq.QueryParameterType{Type: "FLOAT64"}
@@ -38,55 +43,118 @@
+// timeType lets paramType and paramValue single out time.Time, which has
+// struct kind, and map it to a TIMESTAMP instead of a STRUCT.
var timeType = reflect.TypeOf(time.Time{})
func paramType(t reflect.Type) (*bq.QueryParameterType, error) {
+ if t == nil {
+ return nil, errors.New("bigquery: nil parameter")
+ }
+ if t == timeType {
+ return timestampParamType, nil
+ }
switch t.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32:
return int64ParamType, nil
+
case reflect.Float32, reflect.Float64:
return float64ParamType, nil
+
case reflect.Bool:
return boolParamType, nil
+
case reflect.String:
return stringParamType, nil
- case reflect.Slice, reflect.Array:
- if t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 {
+
+ case reflect.Slice:
+ if t.Elem().Kind() == reflect.Uint8 {
return bytesParamType, nil
}
+ fallthrough
+
+ case reflect.Array:
et, err := paramType(t.Elem())
if err != nil {
return nil, err
}
return &bq.QueryParameterType{Type: "ARRAY", ArrayType: et}, nil
+
+ case reflect.Ptr:
+ if t.Elem().Kind() != reflect.Struct {
+ break
+ }
+ t = t.Elem()
+ fallthrough
+
+ case reflect.Struct:
+ var fts []*bq.QueryParameterTypeStructTypes
+ for _, f := range fieldCache.Fields(t) {
+ pt, err := paramType(f.Type)
+ if err != nil {
+ return nil, err
+ }
+ fts = append(fts, &bq.QueryParameterTypeStructTypes{
+ Name: f.Name,
+ Type: pt,
+ })
+ }
+ return &bq.QueryParameterType{Type: "STRUCT", StructTypes: fts}, nil
}
- if t == timeType {
- return timestampParamType, nil
- }
- return nil, fmt.Errorf("Go type %s cannot be represented as a parameter type", t)
+ return nil, fmt.Errorf("bigquery: Go type %s cannot be represented as a parameter type", t)
}
-func paramValue(x interface{}) (bq.QueryParameterValue, error) {
- // convenience function for scalar value
- sval := func(s string) bq.QueryParameterValue {
- return bq.QueryParameterValue{Value: s}
+func paramValue(v reflect.Value) (bq.QueryParameterValue, error) {
+ var res bq.QueryParameterValue
+ if !v.IsValid() {
+ return res, errors.New("bigquery: nil parameter")
}
- switch x := x.(type) {
- case []byte:
- return sval(base64.StdEncoding.EncodeToString(x)), nil
- case time.Time:
- return sval(x.Format(timestampFormat)), nil
+ t := v.Type()
+ if t == timeType {
+ res.Value = v.Interface().(time.Time).Format(timestampFormat)
+ return res, nil
}
- t := reflect.TypeOf(x)
switch t.Kind() {
- case reflect.Slice, reflect.Array:
+ case reflect.Slice:
+ if t.Elem().Kind() == reflect.Uint8 {
+ res.Value = base64.StdEncoding.EncodeToString(v.Interface().([]byte))
+ return res, nil
+ }
+ fallthrough
+
+ case reflect.Array:
var vals []*bq.QueryParameterValue
- v := reflect.ValueOf(x)
for i := 0; i < v.Len(); i++ {
- val, err := paramValue(v.Index(i).Interface())
+ val, err := paramValue(v.Index(i))
if err != nil {
return bq.QueryParameterValue{}, err
}
vals = append(vals, &val)
}
return bq.QueryParameterValue{ArrayValues: vals}, nil
+
+ case reflect.Ptr:
+ if t.Elem().Kind() != reflect.Struct {
+ return res, fmt.Errorf("bigquery: Go type %s cannot be represented as a parameter value", t)
+ }
+ t = t.Elem()
+ v = v.Elem()
+ if !v.IsValid() {
+ // nil pointer becomes empty value
+ return res, nil
+ }
+ fallthrough
+
+ case reflect.Struct:
+ fields := fieldCache.Fields(t)
+ res.StructValues = map[string]bq.QueryParameterValue{}
+ for _, f := range fields {
+ fv := v.FieldByIndex(f.Index)
+ fp, err := paramValue(fv)
+ if err != nil {
+ return bq.QueryParameterValue{}, err
+ }
+ res.StructValues[f.Name] = fp
+ }
+ return res, nil
}
- return sval(fmt.Sprint(x)), nil
+ // None of the above: assume a scalar type. (If it's not a valid type,
+ // paramType will catch the error.)
+ res.Value = fmt.Sprint(v.Interface())
+ return res, nil
}
diff --git a/bigquery/params_test.go b/bigquery/params_test.go
index 97c75cc..5a68204 100644
--- a/bigquery/params_test.go
+++ b/bigquery/params_test.go
@@ -41,45 +41,94 @@
"2016-03-20 04:22:09.000005-01:02"},
}
+// S1 and S2 are struct fixtures for the parameter tests. S2.e is unexported,
+// so it must not appear in the generated parameter types or values.
+type (
+ S1 struct {
+ A int
+ B *S2
+ C bool
+ }
+
+ S2 struct {
+ D string
+ e int
+ }
+)
+
+// s1 is a populated S1 whose nested S2 sets only the exported field.
+var s1 = S1{A: 1, B: &S2{D: "s"}, C: true}
+
+// sval wraps s as a scalar query parameter value.
+func sval(s string) bq.QueryParameterValue {
+ var v bq.QueryParameterValue
+ v.Value = s
+ return v
+}
+
func TestParamValueScalar(t *testing.T) {
for _, test := range scalarTests {
- got, err := paramValue(test.val)
+ rv := reflect.ValueOf(test.val)
+ got, err := paramValue(rv)
if err != nil {
t.Errorf("%v: got %v, want nil", test.val, err)
continue
}
- if got.ArrayValues != nil {
- t.Errorf("%v, ArrayValues: got %v, expected nil", test.val, got.ArrayValues)
- }
- if got.StructValues != nil {
- t.Errorf("%v, StructValues: got %v, expected nil", test.val, got.StructValues)
- }
- if got.Value != test.want {
- t.Errorf("%v: got %q, want %q", test.val, got.Value, test.want)
+ // A scalar must produce only Value, with no array or struct parts.
+ if want := sval(test.want); !reflect.DeepEqual(got, want) {
+ t.Errorf("%v:\ngot %+v\nwant %+v", test.val, got, want)
}
}
}
func TestParamValueArray(t *testing.T) {
+ qpv := bq.QueryParameterValue{ArrayValues: []*bq.QueryParameterValue{
+ {Value: "1"},
+ {Value: "2"},
+ },
+ }
for _, test := range []struct {
val interface{}
- want []string
+ want bq.QueryParameterValue
}{
- {[]int(nil), []string{}},
- {[]int{}, []string{}},
- {[]int{1, 2}, []string{"1", "2"}},
- {[3]int{1, 2, 3}, []string{"1", "2", "3"}},
+ {[]int(nil), bq.QueryParameterValue{}},
+ {[]int{}, bq.QueryParameterValue{}},
+ {[]int{1, 2}, qpv},
+ {[2]int{1, 2}, qpv},
} {
- got, err := paramValue(test.val)
+ got, err := paramValue(reflect.ValueOf(test.val))
if err != nil {
t.Fatal(err)
}
- var want bq.QueryParameterValue
- for _, s := range test.want {
- want.ArrayValues = append(want.ArrayValues, &bq.QueryParameterValue{Value: s})
+ if !reflect.DeepEqual(got, test.want) {
+ t.Errorf("%#v:\ngot %+v\nwant %+v", test.val, got, test.want)
}
- if !reflect.DeepEqual(got, want) {
- t.Errorf("%#v:\ngot %+v\nwant %+v", test.val, got, want)
+ }
+}
+
+func TestParamValueStruct(t *testing.T) {
+ // Nested structs nest as StructValues; the unexported S2.e is absent.
+ want := bq.QueryParameterValue{StructValues: map[string]bq.QueryParameterValue{
+ "A": sval("1"),
+ "B": {StructValues: map[string]bq.QueryParameterValue{"D": sval("s")}},
+ "C": sval("true"),
+ }}
+ got, err := paramValue(reflect.ValueOf(s1))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("got %+v\nwant %+v", got, want)
+ }
+}
+
+func TestParamValueErrors(t *testing.T) {
+ // paramValue itself rejects only a nil interface and a pointer to a
+ // non-struct. Other invalid types are let through for paramType to catch;
+ // since we never call one without the other that's fine.
+ bad := []interface{}{nil, new([]int)}
+ for i := range bad {
+ val := bad[i]
+ if _, err := paramValue(reflect.ValueOf(val)); err == nil {
+ t.Errorf("%v (%T): got nil, want error", val, val)
}
}
}
@@ -101,6 +150,19 @@
{[]byte("foo"), bytesParamType},
{[]int{}, &bq.QueryParameterType{Type: "ARRAY", ArrayType: int64ParamType}},
{[3]bool{}, &bq.QueryParameterType{Type: "ARRAY", ArrayType: boolParamType}},
+ {S1{}, &bq.QueryParameterType{
+ Type: "STRUCT",
+ StructTypes: []*bq.QueryParameterTypeStructTypes{
+ {Name: "A", Type: int64ParamType},
+ {Name: "B", Type: &bq.QueryParameterType{
+ Type: "STRUCT",
+ StructTypes: []*bq.QueryParameterTypeStructTypes{
+ {Name: "D", Type: stringParamType},
+ },
+ }},
+ {Name: "C", Type: boolParamType},
+ },
+ }},
} {
got, err := paramType(reflect.TypeOf(test.val))
if err != nil {
@@ -112,6 +174,17 @@
}
}
+// TestParamTypeErrors checks that paramType rejects nil, uint, pointers to
+// non-struct types, and channels.
+func TestParamTypeErrors(t *testing.T) {
+ unsupported := []interface{}{nil, uint(0), new([]int), make(chan int)}
+ for _, val := range unsupported {
+ if _, err := paramType(reflect.TypeOf(val)); err == nil {
+ t.Errorf("%v (%T): got nil, want error", val, val)
+ }
+ }
+}
+
func TestIntegration_ScalarParam(t *testing.T) {
c := getClient(t)
for _, test := range scalarTests {
@@ -125,7 +198,7 @@
}
}
-func TestIntegration_ArrayParam(t *testing.T) {
+func TestIntegration_OtherParam(t *testing.T) {
c := getClient(t)
for _, test := range []struct {
val interface{}
@@ -135,6 +208,8 @@
{[]int{}, []Value(nil)},
{[]int{1, 2}, []Value{int64(1), int64(2)}},
{[3]int{1, 2, 3}, []Value{int64(1), int64(2), int64(3)}},
+ {S1{}, []Value{int64(0), nil, false}},
+ {s1, []Value{int64(1), []Value{"s"}, true}},
} {
got, err := paramRoundTrip(c, test.val)
if err != nil {
diff --git a/bigquery/query.go b/bigquery/query.go
index af40d23..ce6eb86 100644
--- a/bigquery/query.go
+++ b/bigquery/query.go
@@ -123,6 +123,8 @@
// string: STRING
// []byte: BYTES
// time.Time: TIMESTAMP
+ // Arrays and slices of the above types, or of structs.
+ // Structs, or pointers to structs, whose fields are of supported types. Only the exported fields are used.
Value interface{}
}
@@ -207,7 +209,7 @@
conf.DestinationTable = q.Dst.tableRefProto()
}
for _, p := range q.Parameters {
- pv, err := paramValue(p.Value)
+ pv, err := paramValue(reflect.ValueOf(p.Value))
if err != nil {
return err
}