http2: add a more full-featured test net.Conn
Add a net.Conn implementation that plays nicely with synctestGroup,
implements read/write timeouts, and gives control over buffering
to let us write tests that cause writes to a Conn to block at
specific points in time.
Change-Id: I9d870b211ac9d938a8c4a221277981cdb821a6e4
Reviewed-on: https://go-review.googlesource.com/c/net/+/586246
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go
index 855f44e..36f080b 100644
--- a/http2/clientconn_test.go
+++ b/http2/clientconn_test.go
@@ -10,11 +10,10 @@
import (
"bytes"
"context"
- "errors"
"fmt"
"io"
- "net"
"net/http"
+ "os"
"reflect"
"runtime"
"slices"
@@ -104,7 +103,7 @@
roundtrips []*testRoundTrip
- netconn testClientConnNetConn
+ netconn *synctestNetConn
}
func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientConn {
@@ -114,22 +113,21 @@
cc: cc,
group: cc.t.transportTestHooks.group.(*synctestGroup),
}
+ cli, srv := synctestNetPipe(tc.group)
+ srv.SetReadDeadline(tc.group.Now())
+ tc.netconn = srv
tc.enc = hpack.NewEncoder(&tc.encbuf)
- tc.netconn.gate = newGate()
// all writes and reads are finished.
//
// cli is the ClientConn's side, srv is the side controlled by the test.
- cc.tconn = &tc.netconn
- tc.fr = NewFramer(
- (*testClientConnNetConnWriteToClient)(&tc.netconn),
- (*testClientConnNetConnReadFromClient)(&tc.netconn),
- )
+ cc.tconn = cli
+ tc.fr = NewFramer(srv, srv)
tc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil)
tc.fr.SetMaxReadFrameSize(10 << 20)
t.Cleanup(func() {
- tc.closeWrite(io.EOF)
+ tc.closeWrite()
})
return tc
}
@@ -138,8 +136,7 @@
tc.t.Helper()
// Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames.
buf := make([]byte, len(clientPreface))
- r := (*testClientConnNetConnReadFromClient)(&tc.netconn)
- if _, err := io.ReadFull(r, buf); err != nil {
+ if _, err := io.ReadFull(tc.netconn, buf); err != nil {
tc.t.Fatalf("reading preface: %v", err)
}
if !bytes.Equal(buf, clientPreface) {
@@ -174,26 +171,23 @@
// hasFrame reports whether a frame is available to be read.
func (tc *testClientConn) hasFrame() bool {
- tc.netconn.lock()
- defer tc.netconn.unlock()
- return tc.netconn.fromConn.Len() > 0
+ return len(tc.netconn.Peek()) > 0
}
func (tc *testClientConn) isClosed() bool {
- tc.netconn.lock()
- defer tc.netconn.unlock()
- return tc.netconn.fromConnClosed
+ return tc.netconn.IsClosedByPeer()
}
// readFrame reads the next frame from the conn.
func (tc *testClientConn) readFrame() Frame {
+ tc.t.Helper()
tc.sync()
fr, err := tc.fr.ReadFrame()
- if err == io.EOF {
+ if err == io.EOF || err == os.ErrDeadlineExceeded {
return nil
}
if err != nil {
- return nil
+ tc.t.Fatalf("ReadFrame: %v", err)
}
return fr
}
@@ -597,10 +591,8 @@
// closeWrite causes the net.Conn used by the ClientConn to return a error
// from Read calls.
-func (tc *testClientConn) closeWrite(err error) {
- tc.netconn.lock()
- tc.netconn.toConnErr = err
- tc.netconn.unlock()
+func (tc *testClientConn) closeWrite() {
+ tc.netconn.Close()
tc.sync()
}
@@ -746,80 +738,6 @@
return fmt.Sprintf("got: %v\nwant: %v", got, want)
}
-// testClientConnNetConn implements net.Conn,
-// and is the Conn used by a ClientConn under test.
-type testClientConnNetConn struct {
- gate gate
- toConn bytes.Buffer
- toConnErr error
- fromConn bytes.Buffer
- fromConnClosed bool
-}
-
-func (c *testClientConnNetConn) lock() {
- c.gate.lock()
-}
-
-func (c *testClientConnNetConn) unlock() {
- c.gate.unlock(c.toConn.Len() > 0 || c.toConnErr != nil)
-}
-
-func (c *testClientConnNetConn) Read(b []byte) (n int, err error) {
- if err := c.gate.waitAndLock(context.Background()); err != nil {
- return 0, err
- }
- defer c.unlock()
- if c.toConn.Len() == 0 && c.toConnErr != nil {
- return 0, c.toConnErr
- }
- return c.toConn.Read(b)
-}
-
-func (c *testClientConnNetConn) Write(b []byte) (n int, err error) {
- c.lock()
- defer c.unlock()
- return c.fromConn.Write(b)
-}
-
-func (c *testClientConnNetConn) Close() error {
- c.lock()
- defer c.unlock()
- c.fromConnClosed = true
- c.toConn.Reset()
- if c.toConnErr == nil {
- c.toConnErr = errors.New("connection closed")
- }
- return nil
-}
-
-func (*testClientConnNetConn) LocalAddr() (_ net.Addr) { return }
-func (*testClientConnNetConn) RemoteAddr() (_ net.Addr) { return }
-func (*testClientConnNetConn) SetDeadline(t time.Time) error { return nil }
-func (*testClientConnNetConn) SetReadDeadline(t time.Time) error { return nil }
-func (*testClientConnNetConn) SetWriteDeadline(t time.Time) error { return nil }
-
-// testClientConnNetConnWriteToClient is a view on a testClientConnNetConn
-// that implements an io.Writer that sends to the client conn under test.
-type testClientConnNetConnWriteToClient testClientConnNetConn
-
-func (w *testClientConnNetConnWriteToClient) Write(b []byte) (n int, err error) {
- c := (*testClientConnNetConn)(w)
- c.gate.lock()
- defer c.unlock()
- return c.toConn.Write(b)
-}
-
-// testClientConnNetConnReadFromClient is a view on a testClientConnNetConn
-// that implements an io.Reader that reads data sent by the client conn under test.
-type testClientConnNetConnReadFromClient testClientConnNetConn
-
-func (w *testClientConnNetConnReadFromClient) Read(b []byte) (n int, err error) {
- c := (*testClientConnNetConn)(w)
- c.gate.lock()
- defer c.unlock()
- return c.fromConn.Read(b)
-}
-
// A testTransport allows testing Transport.RoundTrip against fake servers.
// Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling
// should use testClientConn instead.
@@ -861,7 +779,7 @@
buf := make([]byte, 16*1024)
n := runtime.Stack(buf, true)
t.Logf("stacks:\n%s", buf[:n])
- t.Fatalf("%v goroutines still running after test completed, expect 1", count-1)
+ t.Fatalf("%v goroutines still running after test completed, expect 1", count)
}
})
diff --git a/http2/netconn_test.go b/http2/netconn_test.go
new file mode 100644
index 0000000..8a61fbe
--- /dev/null
+++ b/http2/netconn_test.go
@@ -0,0 +1,350 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http2
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "io"
+ "math"
+ "net"
+ "net/netip"
+ "os"
+ "sync"
+ "time"
+)
+
+// synctestNetPipe returns the two endpoints of an in-memory, full-duplex network
+// connection whose read and write deadlines are measured on the synctest group's clock.
+//
+// Unlike net.Pipe, the connection is not synchronous: a write is appended to the
+// peer's buffer without waiting for a matching read. By default the buffer is
+// unbounded; SetReadBufferSize limits it, after which writes block while it is full.
+func synctestNetPipe(group *synctestGroup) (r, w *synctestNetConn) {
+	a := newSynctestNetConnHalf(net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8000")))
+	b := newSynctestNetConnHalf(net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8001")))
+	// Each endpoint reads from its own half and writes into the other's.
+	r = &synctestNetConn{group: group, loc: a, rem: b}
+	w = &synctestNetConn{group: group, loc: b, rem: a}
+	return r, w
+}
+
+// A synctestNetConn is one endpoint of the connection created by synctestNetPipe.
+type synctestNetConn struct {
+	// loc is the half this endpoint reads from, and rem is the half it writes to
+	// (the peer's loc). Reads drain loc's buffer; writes append to rem's buffer,
+	// blocking while it is full.
+	loc, rem *synctestNetConnHalf
+
+	// group supplies the clock for deadlines. When autoWait is set, group.Wait
+	// is also called before Read, Peek, and IsClosedByPeer, and after Write and Close.
+	group    *synctestGroup
+	autoWait bool
+}
+
+// Read reads data written by the peer, blocking until data is buffered, the peer closes, or the read deadline passes.
+func (c *synctestNetConn) Read(b []byte) (n int, err error) {
+	if c.autoWait {
+		c.group.Wait()
+	}
+	return c.loc.read(b)
+}
+
+// Peek returns the data currently buffered for reading, without consuming it.
+func (c *synctestNetConn) Peek() []byte {
+	if c.autoWait {
+		c.group.Wait()
+	}
+	return c.loc.peek()
+}
+
+// Write writes b to the peer, blocking while the peer's read buffer is full.
+func (c *synctestNetConn) Write(b []byte) (n int, err error) {
+	if c.autoWait {
+		defer c.group.Wait()
+	}
+	return c.rem.write(b)
+}
+
+// IsClosedByPeer reports whether the peer has closed its end of the connection.
+func (c *synctestNetConn) IsClosedByPeer() bool {
+	if c.autoWait {
+		c.group.Wait()
+	}
+	return c.loc.isClosedByPeer()
+}
+
+// Close closes the connection: the peer's writes fail, and its reads return io.EOF
+// once drained. Reads and writes on this endpoint itself are not failed.
+func (c *synctestNetConn) Close() error {
+	c.loc.setWriteError(errors.New("connection closed by peer"))
+	c.rem.setReadError(io.EOF)
+	if c.autoWait {
+		c.group.Wait()
+	}
+	return nil
+}
+
+// LocalAddr returns the (fake) local network address.
+func (c *synctestNetConn) LocalAddr() net.Addr {
+	return c.loc.addr
+}
+
+// RemoteAddr returns the (fake) remote network address.
+func (c *synctestNetConn) RemoteAddr() net.Addr {
+	return c.rem.addr
+}
+
+// SetDeadline sets the read and write deadlines for the connection.
+func (c *synctestNetConn) SetDeadline(t time.Time) error {
+	c.SetReadDeadline(t)
+	c.SetWriteDeadline(t)
+	return nil
+}
+
+// SetReadDeadline sets the read deadline, measured on the synctest group's clock.
+func (c *synctestNetConn) SetReadDeadline(t time.Time) error {
+	c.loc.rctx.setDeadline(c.group, t)
+	return nil
+}
+
+// SetWriteDeadline sets the write deadline, measured on the synctest group's clock.
+func (c *synctestNetConn) SetWriteDeadline(t time.Time) error {
+	c.rem.wctx.setDeadline(c.group, t)
+	return nil
+}
+
+// SetReadBufferSize sets the read buffer limit for the connection.
+// Writes by the peer will block so long as the buffer is full.
+func (c *synctestNetConn) SetReadBufferSize(size int) {
+	c.loc.setReadBufferSize(size)
+}
+
+// synctestNetConnHalf is one direction of data flow in the connection created by
+// synctestNetPipe: one endpoint's writes append to buf and the other's reads drain it.
+type synctestNetConnHalf struct {
+	addr net.Addr // LocalAddr of the reading endpoint, RemoteAddr of the writing one
+
+	bufMax   int          // writes block while buf holds at least bufMax bytes
+	buf      bytes.Buffer // data written but not yet read
+	readErr  error        // returned by reads once buf is empty
+	writeErr error        // returned by writes
+
+	// Deadlines for blocked reads (rctx) and writes (wctx).
+	rctx, wctx deadlineContext
+
+	// The half is "readable" when buf is non-empty or readErr is set, and
+	// "writable" when writeErr is set or buf holds fewer than bufMax bytes.
+	// The four channels below act as a lock that also allows waiting for a
+	// particular state: while unlocked, exactly one of them holds a token
+	// describing the current state; while locked, all of them are empty.
+	// unlock hands out the token; lock and the *WaitAndLock methods take it.
+	lockr  chan struct{} // readable
+	lockw  chan struct{} // writable
+	lockrw chan struct{} // readable and writable
+	lockc  chan struct{} // neither readable nor writable
+}
+
+// newSynctestNetConnHalf returns an empty, unlocked half with an unbounded buffer.
+func newSynctestNetConnHalf(addr net.Addr) *synctestNetConnHalf {
+	h := &synctestNetConnHalf{addr: addr, bufMax: math.MaxInt}
+	// Each state channel holds at most one token.
+	h.lockr, h.lockw = make(chan struct{}, 1), make(chan struct{}, 1)
+	h.lockrw, h.lockc = make(chan struct{}, 1), make(chan struct{}, 1)
+	h.unlock()
+	return h
+}
+
+// lock acquires the half by taking the token from whichever channel holds it.
+func (h *synctestNetConnHalf) lock() {
+	select {
+	case <-h.lockw:
+	case <-h.lockr:
+	case <-h.lockrw:
+	case <-h.lockc:
+	}
+}
+
+// unlock releases the half, recomputing its state from the buffer and errors.
+func (h *synctestNetConnHalf) unlock() {
+	readable := h.readErr != nil || h.buf.Len() > 0
+	writable := h.writeErr != nil || h.bufMax > h.buf.Len()
+	// Hand the token to the channel that matches the current state.
+	token := h.lockc
+	switch {
+	case readable && writable:
+		token = h.lockrw
+	case readable:
+		token = h.lockr
+	case writable:
+		token = h.lockw
+	}
+	token <- struct{}{}
+}
+
+// readWaitAndLock locks the half once it is readable, or returns the read
+// deadline's cause if the deadline expires first.
+func (h *synctestNetConnHalf) readWaitAndLock() error {
+	return h.waitAndLockOn(h.lockr, h.lockrw, &h.rctx)
+}
+
+// writeWaitAndLock locks the half once it is writable, or returns the write
+// deadline's cause if the deadline expires first.
+func (h *synctestNetConnHalf) writeWaitAndLock() error {
+	return h.waitAndLockOn(h.lockw, h.lockrw, &h.wctx)
+}
+
+// waitAndLockOn locks the half by taking the token from ready or readyRW.
+// If neither holds the token, it blocks until one does or until the deadline
+// tracked by dctx expires, in which case the context's cause is returned.
+func (h *synctestNetConnHalf) waitAndLockOn(ready, readyRW chan struct{}, dctx *deadlineContext) error {
+	// Try without consulting the deadline first, so that an expired deadline
+	// never hides a half that is already in the wanted state.
+	select {
+	case <-ready:
+		return nil
+	case <-readyRW:
+		return nil
+	default:
+	}
+	// Otherwise block. If the token and the deadline become ready together,
+	// select may pick either.
+	ctx := dctx.context()
+	select {
+	case <-ready:
+		return nil
+	case <-readyRW:
+		return nil
+	case <-ctx.Done():
+		return context.Cause(ctx)
+	}
+}
+
+// peek returns a copy of the unread buffered data, so later reads and writes cannot alter it.
+func (h *synctestNetConnHalf) peek() []byte {
+	h.lock()
+	defer h.unlock()
+	return bytes.Clone(h.buf.Bytes())
+}
+
+// isClosedByPeer reports whether a read error has been set, which happens when the peer closes.
+func (h *synctestNetConnHalf) isClosedByPeer() bool {
+	h.lock()
+	defer h.unlock()
+	return h.readErr != nil
+}
+
+// read blocks until the half is readable, then returns buffered data, or readErr once the buffer is empty.
+func (h *synctestNetConnHalf) read(b []byte) (n int, err error) {
+	if err := h.readWaitAndLock(); err != nil {
+		return 0, err
+	}
+	defer h.unlock()
+	if h.buf.Len() == 0 && h.readErr != nil {
+		return 0, h.readErr
+	}
+	return h.buf.Read(b)
+}
+
+// setReadBufferSize sets bufMax; writers block while the buffer holds at least size bytes.
+func (h *synctestNetConnHalf) setReadBufferSize(size int) {
+	h.lock()
+	defer h.unlock()
+	h.bufMax = size
+}
+
+// write appends all of b, blocking while the buffer is full. On error it returns the count written so far.
+func (h *synctestNetConnHalf) write(b []byte) (n int, err error) {
+	for n < len(b) && err == nil {
+		var nn int
+		nn, err = h.writePartial(b[n:])
+		n += nn
+	}
+	return n, err
+}
+
+// writePartial waits until the half is writable, then appends as much of b as fits under bufMax.
+func (h *synctestNetConnHalf) writePartial(b []byte) (n int, err error) {
+	if err := h.writeWaitAndLock(); err != nil {
+		return 0, err
+	}
+	defer h.unlock()
+	if h.writeErr != nil {
+		return 0, h.writeErr
+	}
+	if writeMax := h.bufMax - h.buf.Len(); writeMax < len(b) {
+		b = b[:writeMax]
+	}
+	return h.buf.Write(b)
+}
+
+// setReadError and setWriteError record the error for later reads and writes; the first error set wins.
+func (h *synctestNetConnHalf) setReadError(err error)  { h.setErrorOnce(&h.readErr, err) }
+func (h *synctestNetConnHalf) setWriteError(err error) { h.setErrorOnce(&h.writeErr, err) }
+
+// setErrorOnce stores err in *p, under the lock, unless an error is already stored there.
+func (h *synctestNetConnHalf) setErrorOnce(p *error, err error) {
+	h.lock()
+	defer h.unlock()
+	if *p == nil {
+		*p = err
+	}
+}
+
+// deadlineContext converts a changeable deadline (as in net.Conn.SetDeadline) into a Context.
+type deadlineContext struct {
+	mu     sync.Mutex              // guards the fields below
+	ctx    context.Context         // nil until first use
+	cancel context.CancelCauseFunc // cancels ctx; nil before first use and once ctx has expired
+	timer  timer                   // nil until the first future deadline is set
+}
+
+// context returns the Context for the current deadline (creating it on first use); it is canceled with cause os.ErrDeadlineExceeded once that deadline passes.
+func (t *deadlineContext) context() context.Context {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	if t.ctx == nil {
+		t.ctx, t.cancel = context.WithCancelCause(context.Background())
+	}
+	return t.ctx
+}
+
+// setDeadline sets the deadline (zero means none), measured on group's synthetic clock.
+func (t *deadlineContext) setDeadline(group *synctestGroup, deadline time.Time) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	// A non-nil ctx with a nil cancel has already expired. Start a fresh
+	// context so that the new deadline governs subsequent waits.
+	if t.ctx == nil || t.cancel == nil {
+		t.ctx, t.cancel = context.WithCancelCause(context.Background())
+	}
+	// Disarm any pending expiry; it is re-armed below if still needed.
+	if t.timer != nil {
+		t.timer.Stop()
+	}
+	if deadline.IsZero() {
+		// No deadline: the context stays live until a later setDeadline.
+		return
+	}
+	if !deadline.After(group.Now()) {
+		// The deadline has already passed: expire the context immediately.
+		t.cancel(os.ErrDeadlineExceeded)
+		t.cancel = nil
+		return
+	}
+	if t.timer != nil {
+		// Re-arm the existing timer for the new deadline.
+		t.timer.Reset(deadline.Sub(group.Now()))
+		return
+	}
+	// First future deadline: create the timer. NOTE(review): its callback calls t.cancel unconditionally; confirm a stopped timer can never fire late.
+	t.timer = group.AfterFunc(deadline.Sub(group.Now()), func() {
+		t.mu.Lock()
+		defer t.mu.Unlock()
+		t.cancel(os.ErrDeadlineExceeded)
+		t.cancel = nil
+	})
+}
diff --git a/http2/sync_test.go b/http2/sync_test.go
index 3f5cf31..bcbbe66 100644
--- a/http2/sync_test.go
+++ b/http2/sync_test.go
@@ -166,6 +166,13 @@
}
}
+// Now returns the group's current synthetic time, read under g.mu.
+func (g *synctestGroup) Now() time.Time {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+	return g.now
+}
+
// TimeUntilEvent returns the amount of time until the next scheduled timer.
func (g *synctestGroup) TimeUntilEvent() (d time.Duration, scheduled bool) {
g.mu.Lock()
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 2171359..d62407b 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -2491,7 +2491,7 @@
// the interesting parts of both.
tc.writeGoAway(5, ErrCodeNo, []byte(goAwayDebugData))
tc.writeGoAway(5, goAwayErrCode, nil)
- tc.closeWrite(io.EOF)
+ tc.closeWrite()
res, err := rt.result()
whence := "RoundTrip"
@@ -5151,7 +5151,7 @@
})
}
- tc.closeWrite(io.EOF)
+ tc.closeWrite()
err := rt.err()
if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr {
t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr)