http2: track reused connections
nextStreamID was used as a means to determine if the connection was
being reused. Multiple requests can see a new connection because the
nextStreamID is updated after a ClientTrace reports it is being reused.
Updates golang/go#31982
Change-Id: Iaa4b62b217f015423cddb99fd86de75a352f8320
Reviewed-on: https://go-review.googlesource.com/c/net/+/176720
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 4ec0792..c0c80d8 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -28,6 +28,7 @@
"strconv"
"strings"
"sync"
+ "sync/atomic"
"time"
"golang.org/x/net/http/httpguts"
@@ -199,6 +200,7 @@
t *Transport
tconn net.Conn // usually *tls.Conn, except specialized impls
tlsState *tls.ConnectionState // nil only for specialized impls
+ reused uint32 // whether conn is being reused; atomic
singleUse bool // whether being used for a single http.Request
// readLoop goroutine fields:
@@ -440,7 +442,8 @@
t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err)
return nil, err
}
- traceGotConn(req, cc)
+ reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1)
+ traceGotConn(req, cc, reused)
res, gotErrAfterReqBodyWrite, err := cc.roundTrip(req)
if err != nil && retry <= 6 {
if req, err = shouldRetryRequest(req, err, gotErrAfterReqBodyWrite); err == nil {
@@ -2559,15 +2562,15 @@
trace.GetConn(hostPort)
}
-func traceGotConn(req *http.Request, cc *ClientConn) {
+func traceGotConn(req *http.Request, cc *ClientConn, reused bool) {
trace := httptrace.ContextClientTrace(req.Context())
if trace == nil || trace.GotConn == nil {
return
}
ci := httptrace.GotConnInfo{Conn: cc.tconn}
+ ci.Reused = reused
cc.mu.Lock()
- ci.Reused = cc.nextStreamID > 1
- ci.WasIdle = len(cc.streams) == 0 && ci.Reused
+ ci.WasIdle = len(cc.streams) == 0 && reused
if ci.WasIdle && !cc.lastActive.IsZero() {
ci.IdleTime = time.Now().Sub(cc.lastActive)
}
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 6c4bed3..567eb74 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -19,6 +19,7 @@
"net"
"net/http"
"net/http/httptest"
+ "net/http/httptrace"
"net/textproto"
"net/url"
"os"
@@ -98,6 +99,15 @@
if err != nil {
t.Fatal(err)
}
+ var gotConnCnt int32
+ trace := &httptrace.ClientTrace{
+ GotConn: func(connInfo httptrace.GotConnInfo) {
+ if !connInfo.Reused {
+ atomic.AddInt32(&gotConnCnt, 1)
+ }
+ },
+ }
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
tr := &Transport{
AllowHTTP: true,
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
@@ -118,6 +128,9 @@
if got, want := string(body), "Hello, /foobar, http: true"; got != want {
t.Fatalf("response got %v, want %v", got, want)
}
+ if got, want := gotConnCnt, int32(1); got != want {
+ t.Errorf("Too many got connections: %d", gotConnCnt)
+ }
}
func TestTransport(t *testing.T) {
@@ -244,6 +257,14 @@
mu sync.Mutex
dials = map[string]int{}
)
+ var gotConnCnt int32
+ trace := &httptrace.ClientTrace{
+ GotConn: func(connInfo httptrace.GotConnInfo) {
+ if !connInfo.Reused {
+ atomic.AddInt32(&gotConnCnt, 1)
+ }
+ },
+ }
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
@@ -254,6 +275,7 @@
t.Error(err)
return
}
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
res, err := tr.RoundTrip(req)
if err != nil {
t.Error(err)
@@ -298,6 +320,9 @@
}); err != nil {
t.Errorf("State of pool after CloseIdleConnections: %v", err)
}
+ if got, want := gotConnCnt, int32(1); got != want {
+ t.Errorf("Too many got connections: %d", gotConnCnt)
+ }
}
func retry(tries int, delay time.Duration, fn func() error) error {