feat(spanner): add DCP DirectPath fallback
diff --git a/spanner/client.go b/spanner/client.go index 4e391e4..59fb9f5 100644 --- a/spanner/client.go +++ b/spanner/client.go
@@ -82,6 +82,12 @@ // MinSessions for Experimental Host connection experimentalHostMinSessions = 0 + + // DirectPath fallback policy used by both non-DCP grpc-gcp fallback and the + // DCP DirectPath/CloudPath wrapper. + directPathFallbackErrorRateThreshold = float32(1) + directPathFallbackMinFailedCalls = 1 + directPathFallbackPeriod = time.Minute * 3 ) const ( @@ -579,7 +585,13 @@ dial := func(dialCtx context.Context) (gtransport.ConnPool, error) { return gtransport.DialPool(dialCtx, allClientOpts(1, config.Compression, config.EnableDirectAccess, dcpOpts...)...) } - dcp, err := newDynamicChannelPool(ctx, sc, config.DynamicChannelPoolConfig, 0, dial) + var fallbackDial func(context.Context) (gtransport.ConnPool, error) + if isFallbackEnabled && isDirectPathEnabled { + fallbackDial = func(dialCtx context.Context) (gtransport.ConnPool, error) { + return gtransport.DialPool(dialCtx, append(allClientOpts(1, config.Compression, config.EnableDirectAccess, dcpOpts...), internaloption.EnableDirectPath(false))...) + } + } + dcp, err := newDynamicChannelPool(ctx, sc, config.DynamicChannelPoolConfig, 0, dial, fallbackDial) if err != nil { return nil, err } @@ -620,9 +632,9 @@ fbOpts := grpcgcp.NewGCPFallbackOptions() fbOpts.EnableFallback = true - fbOpts.ErrorRateThreshold = 1 - fbOpts.MinFailedCalls = 1 - fbOpts.Period = time.Minute * 3 + fbOpts.ErrorRateThreshold = directPathFallbackErrorRateThreshold + fbOpts.MinFailedCalls = directPathFallbackMinFailedCalls + fbOpts.Period = directPathFallbackPeriod if metricsTracerFactory != nil && metricsTracerFactory.meterProvider != nil { fbOpts.MeterProvider = metricsTracerFactory.meterProvider
diff --git a/spanner/dynamic_channel_pool.go b/spanner/dynamic_channel_pool.go index db481b9..05e8cee 100644 --- a/spanner/dynamic_channel_pool.go +++ b/spanner/dynamic_channel_pool.go
@@ -32,6 +32,7 @@ gtransport "google.golang.org/api/transport/grpc" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) const ( @@ -183,7 +184,10 @@ database string disableRouteToLeader bool - dial func(context.Context) (gtransport.ConnPool, error) + dial func(context.Context) (gtransport.ConnPool, error) + fallbackDial func(context.Context) (gtransport.ConnPool, error) + fallbackState *dcpFallbackState + rrIndex atomic.Uint64 nextID atomic.Uint64 totalRPCLoad atomic.Int32 @@ -202,7 +206,9 @@ drainingCount atomic.Int64 } -// dcpEntry represents one logical DCP slot. +// dcpEntry represents one logical DCP slot. In DirectPath fallback mode the +// entry pool is a wrapper containing one DirectPath channel and one CloudPath +// fallback channel. type dcpEntry struct { id uint64 metricSlot int64 // bounded slot id used for metric cardinality @@ -219,7 +225,7 @@ } // newDynamicChannelPool creates the initial channel set and starts scale workers. -func newDynamicChannelPool(ctx context.Context, sc *sessionClient, cfg DynamicChannelPoolConfig, initial int, dial func(context.Context) (gtransport.ConnPool, error)) (*dynamicChannelPool, error) { +func newDynamicChannelPool(ctx context.Context, sc *sessionClient, cfg DynamicChannelPoolConfig, initial int, dial func(context.Context) (gtransport.ConnPool, error), fallbackDial func(context.Context) (gtransport.ConnPool, error)) (*dynamicChannelPool, error) { cfg, err := normalizeDCPConfig(cfg) if err != nil { return nil, err @@ -237,6 +243,8 @@ database: sc.database, disableRouteToLeader: sc.disableRouteToLeader, dial: dial, + fallbackDial: fallbackDial, + fallbackState: &dcpFallbackState{}, scaleUpSignal: make(chan struct{}, 1), done: make(chan struct{}), } @@ -381,18 +389,29 @@ p.metricSlotMu.Unlock() } -// newEntry dials one DCP entry. +// newEntry dials one DCP entry. When fallbackDial is set, the entry uses a +// DirectPath/CloudPath wrapper but still appears as one logical DCP slot. 
func (p *dynamicChannelPool) newEntry(ctx context.Context, prime bool) (*dcpEntry, error) { id := p.nextID.Add(1) metricSlot, err := p.allocateMetricSlot() if err != nil { return nil, err } - entryPool, err := p.dial(ctx) + primary, err := p.dial(ctx) if err != nil { p.releaseMetricSlot(metricSlot) return nil, err } + var entryPool gtransport.ConnPool = primary + if p.fallbackDial != nil { + fallback, err := p.fallbackDial(ctx) + if err != nil { + primary.Close() + p.releaseMetricSlot(metricSlot) + return nil, err + } + entryPool = &dcpFallbackSlot{id: id, direct: primary, cloud: fallback, state: p.fallbackState} + } e := &dcpEntry{id: id, metricSlot: metricSlot, pool: entryPool, parent: p} now := time.Now().UnixNano() e.createdAt.Store(now) @@ -872,6 +891,148 @@ func (e *dcpEntry) applyPenalty(ctx context.Context, err error) {} +// dcpFallbackState is shared by every fallback slot in the pool: DirectPath failures seen by any slot +// count toward one window, and activating fallback moves every slot to CloudPath at once. +type dcpFallbackState struct { + fallbackActive atomic.Bool + primarySuccesses atomic.Uint64 + primaryFailures atomic.Uint64 + lastPrimaryReset atomic.Int64 +} + +// dcpFallbackSlot is one logical DCP slot backed by two physical channels: direct (DirectPath, primary) and +// cloud (CloudPath, fallback). Each call routes by the shared fallbackActive flag; streams are scored by a NewStream error or, once, by the first RecvMsg error (io.EOF counts as success). +type dcpFallbackSlot struct { + id uint64 + direct gtransport.ConnPool + cloud gtransport.ConnPool + state *dcpFallbackState +} + +func (s *dcpFallbackSlot) Conn() *grpc.ClientConn { + if s.state.fallbackActive.Load() { + return s.cloud.Conn() + } + return s.direct.Conn() +} + +func (s *dcpFallbackSlot) Num() int { return 1 } + +func (s *dcpFallbackSlot) Close() error { + e1 := s.direct.Close() + e2 := s.cloud.Close() + return errors.Join(e1, e2) +} + +func (s *dcpFallbackSlot) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error { + if s.state.fallbackActive.Load() { + err := s.cloud.Invoke(ctx, method, args, reply, opts...)
+ s.recordFallback(err) + return err + } + err := s.direct.Invoke(ctx, method, args, reply, opts...) + s.recordPrimary(err) + return err +} + +func (s *dcpFallbackSlot) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if s.state.fallbackActive.Load() { + st, err := s.cloud.NewStream(ctx, desc, method, opts...) + if err != nil { + s.recordFallback(err) + return st, err + } + return &dcpFallbackMonitoredStream{ClientStream: st, record: s.recordFallback}, nil + } + st, err := s.direct.NewStream(ctx, desc, method, opts...) + if err != nil { + s.recordPrimary(err) + return st, err + } + return &dcpFallbackMonitoredStream{ClientStream: st, record: s.recordPrimary}, nil +} + +// recordPrimary counts one DirectPath outcome in the current window: codes.Unavailable as a failure (which may activate fallback), anything else, including nil, as a success. +func (s *dcpFallbackSlot) recordPrimary(err error) { + s.resetPrimaryFallbackWindowIfNeeded() + if isDCPFallbackFailure(err) { + s.state.primaryFailures.Add(1) + s.maybeActivateFallback() + } else { + s.state.primarySuccesses.Add(1) + } +} + +// resetPrimaryFallbackWindowIfNeeded zeroes the DirectPath counters once directPathFallbackPeriod has elapsed since the last reset (the first call only stamps the window start). +// The CAS elects a single resetter, but Adds racing with its two Store(0) calls can be lost, so window counts are approximate. +func (s *dcpFallbackSlot) resetPrimaryFallbackWindowIfNeeded() { + now := time.Now().UnixNano() + last := s.state.lastPrimaryReset.Load() + if last == 0 { + s.state.lastPrimaryReset.CompareAndSwap(0, now) + return + } + if time.Duration(now-last) < directPathFallbackPeriod { + return + } + if s.state.lastPrimaryReset.CompareAndSwap(last, now) { + s.state.primaryFailures.Store(0) + s.state.primarySuccesses.Store(0) + } +} + +// recordFallback is intentionally a no-op: nothing in this type ever clears fallbackActive, so once activated +// the pool stays on CloudPath (intended to match non-DCP grpc-gcp fallback behavior). +func (s *dcpFallbackSlot) recordFallback(err error) { +} + +// maybeActivateFallback sets the sticky fallbackActive flag when the current window has at least directPathFallbackMinFailedCalls failures and a failure ratio
of at least directPathFallbackErrorRateThreshold. It runs on every failure against a partially filled window, so with threshold 1 and minimum 1 a window whose first recorded call is Unavailable (including a fresh pool's first call) activates fallback immediately. +// NOTE(review): the constants are shared with grpc-gcp fallback, which presumably evaluates the rate once per completed Period rather than per call; confirm this divergence is intended. +func (s *dcpFallbackSlot) maybeActivateFallback() { + failures := s.state.primaryFailures.Load() + successes := s.state.primarySuccesses.Load() + total := failures + successes + if total == 0 || failures < uint64(directPathFallbackMinFailedCalls) { + return + } + if float32(failures)/float32(total) < directPathFallbackErrorRateThreshold { + return + } + s.state.fallbackActive.Store(true) +} + +// isDCPFallbackFailure reports whether err counts as a DirectPath failure for fallback accounting. +// Only codes.Unavailable qualifies; nil and every other status code are counted as successes by recordPrimary. +func isDCPFallbackFailure(err error) bool { + c := status.Code(err) + return c == codes.Unavailable +} + +type dcpFallbackMonitoredStream struct { + grpc.ClientStream + once sync.Once + record func(error) +} + +func (s *dcpFallbackMonitoredStream) RecvMsg(m interface{}) error { + err := s.ClientStream.RecvMsg(m) + if err != nil { + s.once.Do(func() { + if errors.Is(err, io.EOF) { + s.record(nil) + } else { + s.record(err) + } + }) + } + return err +} + +func (s *dcpFallbackMonitoredStream) CloseSend() error { + return s.ClientStream.CloseSend() +} + func (p *dynamicChannelPool) recordScaleUp(added int) {} func (p *dynamicChannelPool) recordScaleDown(draining int) {}
diff --git a/spanner/dynamic_channel_pool_test.go b/spanner/dynamic_channel_pool_test.go index d3c6ce0..6418e83 100644 --- a/spanner/dynamic_channel_pool_test.go +++ b/spanner/dynamic_channel_pool_test.go
@@ -259,6 +259,30 @@ return nil, f.invokeErr } +func TestDynamicChannelPoolDirectPathFallbackUsesSharedState(t *testing.T) { + state := &dcpFallbackState{} + primary1 := &fakeDCPConnPool{invokeErr: status.Error(codes.Unavailable, "directpath unavailable")} + cloud1 := &fakeDCPConnPool{} + primary2 := &fakeDCPConnPool{} + cloud2 := &fakeDCPConnPool{} + slot1 := &dcpFallbackSlot{id: 1, direct: primary1, cloud: cloud1, state: state} + slot2 := &dcpFallbackSlot{id: 2, direct: primary2, cloud: cloud2, state: state} + + _ = slot1.Invoke(context.Background(), "/test", nil, nil) + if !state.fallbackActive.Load() { + t.Fatal("shared fallback state inactive after DirectPath failure threshold") + } + if err := slot2.Invoke(context.Background(), "/test", nil, nil); err != nil { + t.Fatalf("fallback slot invoke failed: %v", err) + } + if got := primary2.invokeCount; got != 0 { + t.Fatalf("slot2 primary invoke count = %d, want 0 after shared fallback", got) + } + if got := cloud2.invokeCount; got != 1 { + t.Fatalf("slot2 cloud invoke count = %d, want 1 after shared fallback", got) + } +} + func TestDynamicChannelPoolScaleUpPrimeFailureDoesNotPublishEntry(t *testing.T) { server, client, teardown := setupDCPMockedTestServer(t, testDCPConfig(1, 1, 2)) defer teardown() @@ -576,6 +600,53 @@ } } +func TestDynamicChannelPoolDirectPathFallbackSlotStaysPinnedAcrossFallback(t *testing.T) { + state := &dcpFallbackState{} + direct1 := &fakeDCPConnPool{} + cloud1 := &fakeDCPConnPool{} + direct2 := &fakeDCPConnPool{} + cloud2 := &fakeDCPConnPool{} + slot1 := &dcpFallbackSlot{id: 7, direct: direct1, cloud: cloud1, state: state} + slot2 := &dcpFallbackSlot{id: 8, direct: direct2, cloud: cloud2, state: state} + p := &dynamicChannelPool{cfg: testDCPConfig(2, 1, 2)} + entry1 := &dcpEntry{id: slot1.id, pool: slot1, parent: p} + entry2 := &dcpEntry{id: slot2.id, pool: slot2, parent: p} + entry1.state.Store(dcpStateActive) + entry2.state.Store(dcpStateActive) + entries := []*dcpEntry{entry1, entry2} 
+ p.entries.Store(&entries) + + picked, err := p.pick(context.Background()) + if err != nil { + t.Fatalf("pick failed: %v", err) + } + if err := picked.pool.Invoke(context.Background(), "/test", nil, nil); err != nil { + t.Fatalf("direct invoke failed: %v", err) + } + state.fallbackActive.Store(true) + if err := picked.pool.Invoke(context.Background(), "/test", nil, nil); err != nil { + t.Fatalf("fallback invoke failed: %v", err) + } + + var pickedDirect, pickedCloud, otherDirect, otherCloud *fakeDCPConnPool + if picked.id == slot1.id { + pickedDirect, pickedCloud, otherDirect, otherCloud = direct1, cloud1, direct2, cloud2 + } else if picked.id == slot2.id { + pickedDirect, pickedCloud, otherDirect, otherCloud = direct2, cloud2, direct1, cloud1 + } else { + t.Fatalf("picked unexpected slot id %d", picked.id) + } + if got, want := pickedDirect.invokeCount, 1; got != want { + t.Fatalf("picked direct invoke count = %d, want %d", got, want) + } + if got, want := pickedCloud.invokeCount, 1; got != want { + t.Fatalf("picked cloud invoke count = %d, want %d", got, want) + } + if got := otherDirect.invokeCount + otherCloud.invokeCount; got != 0 { + t.Fatalf("other slot invoke count = %d, want 0", got) + } +} + func TestDynamicChannelPoolConfigDefaultsInitialChannelsToMinWhenInitialUnset(t *testing.T) { cfg, err := normalizeDCPConfig(DynamicChannelPoolConfig{DCPEnabled: true, DCPMinChannels: 8, DCPMaxChannels: 10}) if err != nil {