bigquery: check for recursive types during schema inference

Keep track of which types we've seen to avoid infinite recursion
on recursive types.

Change-Id: Ibb36190adde8199f1ebfc32f9260661b3834c747
Reviewed-on: https://code-review.googlesource.com/9831
Reviewed-by: Sarah Adams <shadams@google.com>
diff --git a/bigquery/schema.go b/bigquery/schema.go
index b632f9d..e202d9c 100644
--- a/bigquery/schema.go
+++ b/bigquery/schema.go
@@ -16,6 +16,7 @@
 
 import (
 	"errors"
+	"fmt"
 	"reflect"
 
 	bq "google.golang.org/api/bigquery/v2"
@@ -128,34 +129,43 @@
 // InferSchema tries to derive a BigQuery schema from the supplied struct value.
 // NOTE: All fields in the returned Schema are configured to be required,
 // unless the corresponding field in the supplied struct is a slice or array.
+//
 // It is considered an error if the struct (including nested structs) contains
 // any exported fields that are pointers or one of the following types:
 // uint, uint64, uintptr, map, interface, complex64, complex128, func, chan.
 // In these cases, an error will be returned.
 // Future versions may handle these cases without error.
+//
+// Recursively defined structs are also disallowed.
 func InferSchema(st interface{}) (Schema, error) {
-	return inferStruct(reflect.TypeOf(st))
+	return inferSchemaReflect(reflect.TypeOf(st))
 }
 
-func inferStruct(rt reflect.Type) (Schema, error) {
-	switch rt.Kind() {
+func inferSchemaReflect(t reflect.Type) (Schema, error) {
+	return inferStruct(t, map[reflect.Type]bool{})
+}
+func inferStruct(t reflect.Type, seen map[reflect.Type]bool) (Schema, error) {
+	if seen[t] {
+		return nil, fmt.Errorf("bigquery: schema inference for recursive type %s", t)
+	}
+	seen[t] = true
+	switch t.Kind() {
 	case reflect.Ptr:
-		if rt.Elem().Kind() != reflect.Struct {
+		if t.Elem().Kind() != reflect.Struct {
 			return nil, errNoStruct
 		}
-		rt = rt.Elem()
+		t = t.Elem()
 		fallthrough
 
 	case reflect.Struct:
-		return inferFields(rt)
+		return inferFields(t, seen)
 	default:
 		return nil, errNoStruct
 	}
-
 }
 
 // inferFieldSchema infers the FieldSchema for a Go type
-func inferFieldSchema(rt reflect.Type) (*FieldSchema, error) {
+func inferFieldSchema(rt reflect.Type, seen map[reflect.Type]bool) (*FieldSchema, error) {
 	switch rt {
 	case typeOfByteSlice:
 		return &FieldSchema{Required: true, Type: BytesFieldType}, nil
@@ -179,7 +189,7 @@
 			return nil, errUnsupportedFieldType
 		}
 
-		f, err := inferFieldSchema(et)
+		f, err := inferFieldSchema(et, seen)
 		if err != nil {
 			return nil, err
 		}
@@ -187,7 +197,7 @@
 		f.Required = false
 		return f, nil
 	case reflect.Struct, reflect.Ptr:
-		nested, err := inferStruct(rt)
+		nested, err := inferStruct(rt, seen)
 		if err != nil {
 			return nil, err
 		}
@@ -204,14 +214,14 @@
 }
 
 // inferFields extracts all exported field types from struct type.
-func inferFields(rt reflect.Type) (Schema, error) {
+func inferFields(rt reflect.Type, seen map[reflect.Type]bool) (Schema, error) {
 	var s Schema
 	fields, err := fieldCache.Fields(rt)
 	if err != nil {
 		return nil, err
 	}
 	for _, field := range fields {
-		f, err := inferFieldSchema(field.Type)
+		f, err := inferFieldSchema(field.Type, seen)
 		if err != nil {
 			return nil, err
 		}
diff --git a/bigquery/schema_test.go b/bigquery/schema_test.go
index a223a8d..c3eef55 100644
--- a/bigquery/schema_test.go
+++ b/bigquery/schema_test.go
@@ -496,6 +496,18 @@
 	}
 }
 
+func TestRecursiveInference(t *testing.T) {
+	type List struct {
+		Val  int
+		Next *List
+	}
+
+	_, err := InferSchema(List{})
+	if err == nil {
+		t.Fatal("got nil, want error")
+	}
+}
+
 type withTags struct {
 	NoTag         int
 	ExcludeTag    int `bigquery:"-"`
diff --git a/bigquery/uploader.go b/bigquery/uploader.go
index 5ebaeb0..a0ed0ae 100644
--- a/bigquery/uploader.go
+++ b/bigquery/uploader.go
@@ -123,7 +123,7 @@
 		return nil, false, nil
 	}
 	// TODO(jba): cache schema inference to speed this up.
-	schema, err := inferStruct(v.Type())
+	schema, err := inferSchemaReflect(v.Type())
 	if err != nil {
 		return nil, false, err
 	}