fix(spanner): use 'proto' as the streaming codec name
diff --git a/spanner/fast_decoder.go b/spanner/fast_decoder.go index 8ed8143..2f57d19 100644 --- a/spanner/fast_decoder.go +++ b/spanner/fast_decoder.go
@@ -37,7 +37,7 @@ } func (c customSpannerCodec) Name() string { - return "spanner-streaming-codec" + return "proto" // used as the gRPC content-subtype ("application/grpc+proto"); must be static } func (c customSpannerCodec) Marshal(v interface{}) ([]byte, error) { @@ -119,6 +119,7 @@ func (fastRowKind) isValue_Kind() {} func (p *partialResultSetDecoder) decodeFastPartialResultSet(data []byte, prs *sppb.PartialResultSet) error { + p.lastChunkSize = int32(len(data)) // raw message size; read.go adds it to bytesBetweenResumeTokens for fast decoding for len(data) > 0 { num, wire, n := protowire.ConsumeTag(data) if n < 0 { @@ -139,7 +140,7 @@ if err := proto.Unmarshal(b, prs.Metadata); err != nil { return err } - if p.row.fields == nil { + if p.row.fields == nil && prs.Metadata.RowType != nil { p.row.fields = prs.Metadata.RowType.Fields } case 4: // resume_token @@ -175,6 +176,16 @@ if err := proto.Unmarshal(b, prs.Stats); err != nil { return err } + case 9: // last + if wire != protowire.VarintType { + return spannerErrorf(codes.Internal, "invalid last wire type") + } + b, n := protowire.ConsumeVarint(data) + if n < 0 { + return spannerErrorf(codes.Internal, "corrupt last value") + } + data = data[n:] + prs.Last = b != 0 case 2: // values if wire != protowire.BytesType { return spannerErrorf(codes.Internal, "invalid values wire type") @@ -195,6 +206,22 @@ } } + if p.chunked { // the previous message set ChunkedValue: merge this value into the current row's last cell instead of starting a new cell + p.chunked = false + lastIdx := len(p.curFastRow.cells) - 1 + if lastIdx < 0 { + return spannerErrorf(codes.FailedPrecondition, "got invalid chunked fast PartialResultSet with empty row") + } + temp := &SpannerValue{} + if err := decodeFastSpannerValueBytes(b, temp); err != nil { + return err + } + if err := mergeFast(&p.curFastRow.cells[lastIdx], temp); err != nil { + return err + } + continue + } + var cell *SpannerValue if len(p.curFastRow.cells) < cap(p.curFastRow.cells) { p.curFastRow.cells = p.curFastRow.cells[:len(p.curFastRow.cells)+1] @@ -213,10 +240,6 @@ if err := decodeFastSpannerValueBytes(b, cell); err != nil { return err } - if len(p.row.fields) > 0 && len(p.curFastRow.cells) == len(p.row.fields) { - p.completedFastRows = 
append(p.completedFastRows, p.curFastRow) - p.curFastRow = nil - } case 8: // precommit_token if wire != protowire.BytesType { return spannerErrorf(codes.Internal, "invalid precommit token wire type") @@ -238,9 +261,54 @@ data = data[vn:] } } + if prs.ChunkedValue { + p.chunked = true + } return nil } +func hasMoreValues(data []byte) bool { // NOTE(review): no caller in this diff; remove it if nothing else in the package uses it + if len(data) == 0 { + return false + } + num, _, n := protowire.ConsumeTag(data) + return n >= 0 && num == 2 +} + +func isMergeableFast(a *SpannerValue) bool { + return a.valType == 3 || a.valType == 6 +} + +func mergeFast(a, b *SpannerValue) error { // merges chunk b into a in place; only valType 3 (strVal) and 6 (listVal) can be merged + if a.valType != b.valType { + return spannerErrorf(codes.FailedPrecondition, "incompatible type in chunked fast decoding. expected valType %d, got %d", a.valType, b.valType) + } + switch a.valType { + case 3: + a.strVal += b.strVal + case 6: + if len(b.listVal) == 0 { + return nil + } + if len(a.listVal) == 0 { + a.listVal = b.listVal + return nil + } + la := len(a.listVal) - 1 + if isMergeableFast(a.listVal[la]) { + if err := mergeFast(a.listVal[la], b.listVal[0]); err != nil { + return err + } + b.listVal = b.listVal[1:] + } + a.listVal = append(a.listVal, b.listVal...) 
+ default: + return spannerErrorf(codes.FailedPrecondition, "unsupported type merge in fast decoding (%d)", a.valType) + } + return nil +} + + func decodeFastSpannerValueBytes(valData []byte, cell *SpannerValue) error { for len(valData) > 0 { vnum, vwire, vn := protowire.ConsumeTag(valData) @@ -320,8 +388,8 @@ entryData = entryData[en:] _ = decodeFastSpannerValueBytes(vb, val) default: - en = protowire.ConsumeFieldValue(enum, ewire, entryData) - entryData = entryData[en:] + vn := protowire.ConsumeFieldValue(enum, ewire, entryData) + entryData = entryData[vn:] } } if key != "" { @@ -353,7 +421,30 @@ } func (p *partialResultSetDecoder) addFast(r *sppb.PartialResultSet) ([]*Row, *sppb.ResultSetMetadata, error) { + if r.Metadata != nil && p.row.fields == nil && r.Metadata.RowType != nil { + p.row.fields = r.Metadata.RowType.Fields + } var rows []*Row + lenFields := len(p.row.fields) + if lenFields > 0 && p.curFastRow != nil { + for len(p.curFastRow.cells) >= lenFields { + if len(p.curFastRow.cells) == lenFields && p.chunked { + // The last cell of this row is chunked, so the row is not complete yet. + break + } + // We have a complete row of lenFields cells. Cap its capacity (full slice expression) so that growing it later via cells[:len+1] cannot write into remainingCells, which shares the same backing array. + completedRow := &fastRowData{cells: p.curFastRow.cells[:lenFields:lenFields]} + p.completedFastRows = append(p.completedFastRows, completedRow) + // Slice curFastRow.cells for remaining cells + remainingCells := p.curFastRow.cells[lenFields:] + if len(remainingCells) == 0 { + p.curFastRow = nil + break + } + p.curFastRow = &fastRowData{cells: remainingCells} + } + } + for _, fast := range p.completedFastRows { var fresh *Row if len(p.rowPool) > 0 {
diff --git a/spanner/read.go b/spanner/read.go index 2955bda..7ad8c17 100644 --- a/spanner/read.go +++ b/spanner/read.go
@@ -781,7 +781,11 @@ return } if d.state == queueingRetryable && !d.isNewResumeToken(res.ResumeToken) { - d.bytesBetweenResumeTokens += int32(proto.Size(res)) + if d.rowd != nil && d.rowd.fastDecoding { // fast path: count the raw message length recorded in decodeFastPartialResultSet + d.bytesBetweenResumeTokens += d.rowd.lastChunkSize + } else { // regular path: count the encoded size of the decoded message + d.bytesBetweenResumeTokens += int32(proto.Size(res)) + } } d.changeState(d.state) return @@ -846,6 +850,7 @@ fastPool []*fastRowData curFastRow *fastRowData completedFastRows []*fastRowData + lastChunkSize int32 // Byte length of the last message passed to decodeFastPartialResultSet; used for flow-control accounting when fast decoding is enabled. }