// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package bigquery

import (
"bytes"
"encoding/base64"
"errors"
"fmt"
"io"
"math/big"
"cloud.google.com/go/civil"
"github.com/apache/arrow/go/v15/arrow"
"github.com/apache/arrow/go/v15/arrow/array"
"github.com/apache/arrow/go/v15/arrow/ipc"
"github.com/apache/arrow/go/v15/arrow/memory"
"google.golang.org/api/iterator"
)

// ArrowRecordBatch represents an Arrow RecordBatch with the source PartitionID.
type ArrowRecordBatch struct {
reader io.Reader
// Serialized Arrow Record Batch.
Data []byte
// Serialized Arrow Schema.
Schema []byte
	// Source partition ID. When reading via the Storage API, this identifies the originating ReadStream.
PartitionID string
}

// Read makes ArrowRecordBatch implement io.Reader.
func (r *ArrowRecordBatch) Read(p []byte) (int, error) {
if r.reader == nil {
buf := bytes.NewBuffer(r.Schema)
buf.Write(r.Data)
r.reader = buf
}
return r.reader.Read(p)
}
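
// A minimal usage sketch (hypothetical: "batch" is assumed to be an
// *ArrowRecordBatch obtained from an ArrowIterator). Because Read prepends the
// serialized schema, copying the whole reader yields a self-contained Arrow
// IPC stream:
//
//	var buf bytes.Buffer
//	if _, err := io.Copy(&buf, batch); err != nil {
//		// handle err
//	}
//	// buf now holds the schema followed by the record batch.
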
// ArrowIterator represents a way to iterate through a stream of arrow records.
// Experimental: this interface is experimental and may be modified or removed in future versions,
// regardless of any other documented package stability guarantees.
type ArrowIterator interface {
Next() (*ArrowRecordBatch, error)
Schema() Schema
SerializedArrowSchema() []byte
}

// NewArrowIteratorReader allows consuming an ArrowIterator as an io.Reader.
// Experimental: this interface is experimental and may be modified or removed in future versions,
// regardless of any other documented package stability guarantees.
func NewArrowIteratorReader(it ArrowIterator) io.Reader {
return &arrowIteratorReader{
it: it,
}
}
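
// A consumption sketch ("it" is an assumed ArrowIterator obtained elsewhere in
// this package). Since the stream starts with the serialized schema, the
// reader can be handed directly to the Arrow IPC machinery:
//
//	r, err := ipc.NewReader(NewArrowIteratorReader(it))
//	if err != nil {
//		// handle err
//	}
//	defer r.Release()
//	for r.Next() {
//		rec := r.Record()
//		// use rec before the next call to r.Next
//	}
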
type arrowIteratorReader struct {
buf *bytes.Buffer
it ArrowIterator
}

// Read makes arrowIteratorReader implement io.Reader.
func (r *arrowIteratorReader) Read(p []byte) (int, error) {
	if r.it == nil {
		// io.Reader requires 0 <= n <= len(p); report the error with n == 0, not -1.
		return 0, errors.New("bigquery: nil ArrowIterator")
	}
if r.buf == nil { // init with schema
buf := bytes.NewBuffer(r.it.SerializedArrowSchema())
r.buf = buf
}
n, err := r.buf.Read(p)
	if err == io.EOF {
		batch, err := r.it.Next()
		if err == iterator.Done {
			return 0, io.EOF
		}
		if err != nil { // propagate iterator failures instead of dereferencing a nil batch
			return 0, err
		}
		r.buf.Write(batch.Data)
		return r.Read(p)
	}
return n, err
}

type arrowDecoder struct {
allocator memory.Allocator
tableSchema Schema
arrowSchema *arrow.Schema
}

func newArrowDecoder(arrowSerializedSchema []byte, schema Schema) (*arrowDecoder, error) {
buf := bytes.NewBuffer(arrowSerializedSchema)
r, err := ipc.NewReader(buf)
if err != nil {
return nil, err
}
defer r.Release()
p := &arrowDecoder{
tableSchema: schema,
arrowSchema: r.Schema(),
allocator: memory.DefaultAllocator,
}
return p, nil
}
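
// A construction sketch: both arguments typically come from the same
// ArrowIterator ("it" is assumed), so the decoder sees the Arrow schema and
// the BigQuery table schema for the same stream:
//
//	dec, err := newArrowDecoder(it.SerializedArrowSchema(), it.Schema())
//	if err != nil {
//		// handle err
//	}
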
func (ap *arrowDecoder) createIPCReaderForBatch(arrowRecordBatch *ArrowRecordBatch) (*ipc.Reader, error) {
return ipc.NewReader(
arrowRecordBatch,
ipc.WithSchema(ap.arrowSchema),
ipc.WithAllocator(ap.allocator),
)
}

// decodeArrowRecords decodes a BigQuery ArrowRecordBatch into rows of []Value.
func (ap *arrowDecoder) decodeArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([][]Value, error) {
r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
if err != nil {
return nil, err
}
defer r.Release()
rs := make([][]Value, 0)
for r.Next() {
rec := r.Record()
values, err := ap.convertArrowRecordValue(rec)
if err != nil {
return nil, err
}
rs = append(rs, values...)
}
return rs, nil
}
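
// A decode-loop sketch under the same assumptions ("it" and "dec" as above):
// each batch from the iterator expands into zero or more []Value rows.
//
//	for {
//		batch, err := it.Next()
//		if err == iterator.Done {
//			break
//		}
//		if err != nil {
//			// handle err
//		}
//		rows, err := dec.decodeArrowRecords(batch)
//		if err != nil {
//			// handle err
//		}
//		// use rows
//	}
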
// decodeRetainedArrowRecords decodes a BigQuery ArrowRecordBatch into a list of retained arrow.Record.
func (ap *arrowDecoder) decodeRetainedArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([]arrow.Record, error) {
r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
if err != nil {
return nil, err
}
defer r.Release()
records := []arrow.Record{}
for r.Next() {
rec := r.Record()
rec.Retain()
records = append(records, rec)
}
return records, nil
}
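
// Unlike decodeArrowRecords, the records returned here are retained, so the
// caller owns a reference and must release each one; a sketch:
//
//	records, err := dec.decodeRetainedArrowRecords(batch)
//	if err != nil {
//		// handle err
//	}
//	for _, rec := range records {
//		// ... use rec ...
//		rec.Release() // caller's responsibility once done with rec
//	}
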
// convertArrowRecordValue converts an arrow.Record into a series of Value slices.
func (ap *arrowDecoder) convertArrowRecordValue(record arrow.Record) ([][]Value, error) {
rs := make([][]Value, record.NumRows())
for i := range rs {
rs[i] = make([]Value, record.NumCols())
}
for j, col := range record.Columns() {
fs := ap.tableSchema[j]
ft := ap.arrowSchema.Field(j).Type
for i := 0; i < col.Len(); i++ {
v, err := convertArrowValue(col, i, ft, fs)
if err != nil {
return nil, fmt.Errorf("found arrow type %s, but could not convert value: %v", ap.arrowSchema.Field(j).Type, err)
}
rs[i][j] = v
}
}
return rs, nil
}

// convertArrowValue gets the row value in the given column and converts it to a Value.
// Arrow is a columnar format, so we navigate first by column and then fetch the row value.
// More details on conversions can be seen here: https://cloud.google.com/bigquery/docs/reference/storage#arrow_schema_details
func convertArrowValue(col arrow.Array, i int, ft arrow.DataType, fs *FieldSchema) (Value, error) {
if !col.IsValid(i) {
return nil, nil
}
switch ft.(type) {
case *arrow.BooleanType:
v := col.(*array.Boolean).Value(i)
return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
case *arrow.Int8Type:
v := col.(*array.Int8).Value(i)
return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
case *arrow.Int16Type:
v := col.(*array.Int16).Value(i)
return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
case *arrow.Int32Type:
v := col.(*array.Int32).Value(i)
return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
case *arrow.Int64Type:
v := col.(*array.Int64).Value(i)
return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
case *arrow.Float16Type:
v := col.(*array.Float16).Value(i)
return convertBasicType(fmt.Sprintf("%v", v.Float32()), fs.Type)
case *arrow.Float32Type:
v := col.(*array.Float32).Value(i)
return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
case *arrow.Float64Type:
v := col.(*array.Float64).Value(i)
return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
case *arrow.BinaryType:
v := col.(*array.Binary).Value(i)
encoded := base64.StdEncoding.EncodeToString(v)
return convertBasicType(encoded, fs.Type)
case *arrow.StringType:
v := col.(*array.String).Value(i)
return convertBasicType(v, fs.Type)
case *arrow.Date32Type:
v := col.(*array.Date32).Value(i)
return convertBasicType(v.FormattedString(), fs.Type)
case *arrow.Date64Type:
v := col.(*array.Date64).Value(i)
return convertBasicType(v.FormattedString(), fs.Type)
case *arrow.TimestampType:
v := col.(*array.Timestamp).Value(i)
dft := ft.(*arrow.TimestampType)
t := v.ToTime(dft.Unit)
if dft.TimeZone == "" { // Datetime
return Value(civil.DateTimeOf(t)), nil
}
return Value(t.UTC()), nil // Timestamp
case *arrow.Time32Type:
v := col.(*array.Time32).Value(i)
return convertBasicType(v.FormattedString(arrow.Microsecond), fs.Type)
case *arrow.Time64Type:
v := col.(*array.Time64).Value(i)
return convertBasicType(v.FormattedString(arrow.Microsecond), fs.Type)
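	// NUMERIC/BIGNUMERIC arrive as a scaled integer: the decoded value is
	// unscaled / 10^scale. For example, unscaled 12345 with scale 2
	// represents 12345/100 = 123.45.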
	case *arrow.Decimal128Type:
		dft := ft.(*arrow.Decimal128Type)
		v := col.(*array.Decimal128).Value(i)
		// SetFrac keeps the sign of the unscaled integer; building the numerator
		// via SetBytes would silently drop it for negative values.
		denom := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(dft.Scale)), nil)
		return Value(new(big.Rat).SetFrac(v.BigInt(), denom)), nil
	case *arrow.Decimal256Type:
		dft := ft.(*arrow.Decimal256Type)
		v := col.(*array.Decimal256).Value(i)
		// Same sign-preserving conversion as the Decimal128 case above.
		denom := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(dft.Scale)), nil)
		return Value(new(big.Rat).SetFrac(v.BigInt(), denom)), nil
case *arrow.ListType:
arr := col.(*array.List)
dft := ft.(*arrow.ListType)
values := []Value{}
start, end := arr.ValueOffsets(i)
slice := array.NewSlice(arr.ListValues(), start, end)
for j := 0; j < slice.Len(); j++ {
v, err := convertArrowValue(slice, j, dft.Elem(), fs)
if err != nil {
return nil, err
}
values = append(values, v)
}
return values, nil
case *arrow.StructType:
arr := col.(*array.Struct)
nestedValues := []Value{}
fields := ft.(*arrow.StructType).Fields()
if fs.Type == RangeFieldType {
rangeFieldSchema := &FieldSchema{
Type: fs.RangeElementType.Type,
}
start, err := convertArrowValue(arr.Field(0), i, fields[0].Type, rangeFieldSchema)
if err != nil {
return nil, err
}
end, err := convertArrowValue(arr.Field(1), i, fields[1].Type, rangeFieldSchema)
if err != nil {
return nil, err
}
rangeValue := &RangeValue{Start: start, End: end}
return Value(rangeValue), nil
}
for fIndex, f := range fields {
v, err := convertArrowValue(arr.Field(fIndex), i, f.Type, fs.Schema[fIndex])
if err != nil {
return nil, err
}
nestedValues = append(nestedValues, v)
}
return nestedValues, nil
default:
return nil, fmt.Errorf("unknown arrow type: %v", ft)
}
}