http2: add a more full-featured test net.Conn

Add a net.Conn implementation that plays nicely with synctestGroup,
implements read/write timeouts, and gives control over buffering
to let us write tests that cause writes to a Conn to block at
specific points in time.
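
For example, a test can bound one endpoint's read buffer so that the
peer's writes block until the test drains the conn. A sketch, using
the synctestNetPipe and SetReadBufferSize helpers added below:

	cli, srv := synctestNetPipe(group)
	srv.SetReadBufferSize(1024)      // srv buffers at most 1024 bytes of unread data
	go cli.Write(make([]byte, 4096)) // blocks after the first 1024 bytes
	srv.Read(make([]byte, 1024))     // draining srv unblocks more of the write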

Change-Id: I9d870b211ac9d938a8c4a221277981cdb821a6e4
Reviewed-on: https://go-review.googlesource.com/c/net/+/586246
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go
index 855f44e..36f080b 100644
--- a/http2/clientconn_test.go
+++ b/http2/clientconn_test.go
@@ -10,11 +10,10 @@
 import (
 	"bytes"
 	"context"
-	"errors"
 	"fmt"
 	"io"
-	"net"
 	"net/http"
+	"os"
 	"reflect"
 	"runtime"
 	"slices"
@@ -104,7 +103,7 @@
 
 	roundtrips []*testRoundTrip
 
-	netconn testClientConnNetConn
+	netconn *synctestNetConn
 }
 
 func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientConn {
@@ -114,22 +113,21 @@
 		cc:    cc,
 		group: cc.t.transportTestHooks.group.(*synctestGroup),
 	}
+	// Reads from and writes to the conn wait for the test group to idle
+	// (srv.autoWait), so that when the test examines the conn,
+	// all writes and reads are finished.
+	//
+	// cli is the ClientConn's side, srv is the side controlled by the test.
+	cli, srv := synctestNetPipe(tc.group)
+	srv.SetReadDeadline(tc.group.Now())
+	srv.autoWait = true
+	tc.netconn = srv
 	tc.enc = hpack.NewEncoder(&tc.encbuf)
-	tc.netconn.gate = newGate()
 
-	cc.tconn = &tc.netconn
-	tc.fr = NewFramer(
-		(*testClientConnNetConnWriteToClient)(&tc.netconn),
-		(*testClientConnNetConnReadFromClient)(&tc.netconn),
-	)
+	cc.tconn = cli
+	tc.fr = NewFramer(srv, srv)
 
 	tc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil)
 	tc.fr.SetMaxReadFrameSize(10 << 20)
 	t.Cleanup(func() {
-		tc.closeWrite(io.EOF)
+		tc.closeWrite()
 	})
 	return tc
 }
@@ -138,8 +136,7 @@
 	tc.t.Helper()
 	// Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames.
 	buf := make([]byte, len(clientPreface))
-	r := (*testClientConnNetConnReadFromClient)(&tc.netconn)
-	if _, err := io.ReadFull(r, buf); err != nil {
+	if _, err := io.ReadFull(tc.netconn, buf); err != nil {
 		tc.t.Fatalf("reading preface: %v", err)
 	}
 	if !bytes.Equal(buf, clientPreface) {
@@ -174,26 +171,23 @@
 
 // hasFrame reports whether a frame is available to be read.
 func (tc *testClientConn) hasFrame() bool {
-	tc.netconn.lock()
-	defer tc.netconn.unlock()
-	return tc.netconn.fromConn.Len() > 0
+	return len(tc.netconn.Peek()) > 0
 }
 
 func (tc *testClientConn) isClosed() bool {
-	tc.netconn.lock()
-	defer tc.netconn.unlock()
-	return tc.netconn.fromConnClosed
+	return tc.netconn.IsClosedByPeer()
 }
 
 // readFrame reads the next frame from the conn.
 func (tc *testClientConn) readFrame() Frame {
+	tc.t.Helper()
 	tc.sync()
 	fr, err := tc.fr.ReadFrame()
-	if err == io.EOF {
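+	// The conn's read deadline is set to the current time in
+	// newTestClientConnFromClientConn, so when no frame is buffered
+	// ReadFrame fails with os.ErrDeadlineExceeded rather than blocking.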
+	if err == io.EOF || err == os.ErrDeadlineExceeded {
 		return nil
 	}
 	if err != nil {
-		return nil
+		tc.t.Fatalf("ReadFrame: %v", err)
 	}
 	return fr
 }
@@ -597,10 +591,8 @@
 
 // closeWrite causes the net.Conn used by the ClientConn to return an error
 // from Read calls.
-func (tc *testClientConn) closeWrite(err error) {
-	tc.netconn.lock()
-	tc.netconn.toConnErr = err
-	tc.netconn.unlock()
+func (tc *testClientConn) closeWrite() {
+	tc.netconn.Close()
 	tc.sync()
 }
 
@@ -746,80 +738,6 @@
 	return fmt.Sprintf("got:  %v\nwant: %v", got, want)
 }
 
-// testClientConnNetConn implements net.Conn,
-// and is the Conn used by a ClientConn under test.
-type testClientConnNetConn struct {
-	gate           gate
-	toConn         bytes.Buffer
-	toConnErr      error
-	fromConn       bytes.Buffer
-	fromConnClosed bool
-}
-
-func (c *testClientConnNetConn) lock() {
-	c.gate.lock()
-}
-
-func (c *testClientConnNetConn) unlock() {
-	c.gate.unlock(c.toConn.Len() > 0 || c.toConnErr != nil)
-}
-
-func (c *testClientConnNetConn) Read(b []byte) (n int, err error) {
-	if err := c.gate.waitAndLock(context.Background()); err != nil {
-		return 0, err
-	}
-	defer c.unlock()
-	if c.toConn.Len() == 0 && c.toConnErr != nil {
-		return 0, c.toConnErr
-	}
-	return c.toConn.Read(b)
-}
-
-func (c *testClientConnNetConn) Write(b []byte) (n int, err error) {
-	c.lock()
-	defer c.unlock()
-	return c.fromConn.Write(b)
-}
-
-func (c *testClientConnNetConn) Close() error {
-	c.lock()
-	defer c.unlock()
-	c.fromConnClosed = true
-	c.toConn.Reset()
-	if c.toConnErr == nil {
-		c.toConnErr = errors.New("connection closed")
-	}
-	return nil
-}
-
-func (*testClientConnNetConn) LocalAddr() (_ net.Addr)            { return }
-func (*testClientConnNetConn) RemoteAddr() (_ net.Addr)           { return }
-func (*testClientConnNetConn) SetDeadline(t time.Time) error      { return nil }
-func (*testClientConnNetConn) SetReadDeadline(t time.Time) error  { return nil }
-func (*testClientConnNetConn) SetWriteDeadline(t time.Time) error { return nil }
-
-// testClientConnNetConnWriteToClient is a view on a testClientConnNetConn
-// that implements an io.Writer that sends to the client conn under test.
-type testClientConnNetConnWriteToClient testClientConnNetConn
-
-func (w *testClientConnNetConnWriteToClient) Write(b []byte) (n int, err error) {
-	c := (*testClientConnNetConn)(w)
-	c.gate.lock()
-	defer c.unlock()
-	return c.toConn.Write(b)
-}
-
-// testClientConnNetConnReadFromClient is a view on a testClientConnNetConn
-// that implements an io.Reader that reads data sent by the client conn under test.
-type testClientConnNetConnReadFromClient testClientConnNetConn
-
-func (w *testClientConnNetConnReadFromClient) Read(b []byte) (n int, err error) {
-	c := (*testClientConnNetConn)(w)
-	c.gate.lock()
-	defer c.unlock()
-	return c.fromConn.Read(b)
-}
-
 // A testTransport allows testing Transport.RoundTrip against fake servers.
 // Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling
 // should use testClientConn instead.
@@ -861,7 +779,7 @@
 			buf := make([]byte, 16*1024)
 			n := runtime.Stack(buf, true)
 			t.Logf("stacks:\n%s", buf[:n])
-			t.Fatalf("%v goroutines still running after test completed, expect 1", count-1)
+			t.Fatalf("%v goroutines still running after test completed, expect 1", count)
 		}
 	})
 
diff --git a/http2/netconn_test.go b/http2/netconn_test.go
new file mode 100644
index 0000000..8a61fbe
--- /dev/null
+++ b/http2/netconn_test.go
@@ -0,0 +1,350 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http2
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"io"
+	"math"
+	"net"
+	"net/netip"
+	"os"
+	"sync"
+	"time"
+)
+
+// synctestNetPipe creates an in-memory, full duplex network connection.
+// Read and write timeouts are managed by the synctest group.
+//
+// Unlike net.Pipe, the connection is not synchronous.
+// Writes are made to a buffer, and return immediately.
+// By default, the buffer size is unlimited.
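+//
+// A sketch of typical use (names are illustrative):
+//
+//	cli, srv := synctestNetPipe(group)
+//	cli.Write(data) // returns immediately; data is buffered in the pipe
+//	srv.Read(buf)   // reads the previously buffered data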
+func synctestNetPipe(group *synctestGroup) (r, w *synctestNetConn) {
+	s1addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8000"))
+	s2addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8001"))
+	s1 := newSynctestNetConnHalf(s1addr)
+	s2 := newSynctestNetConnHalf(s2addr)
+	return &synctestNetConn{group: group, loc: s1, rem: s2},
+		&synctestNetConn{group: group, loc: s2, rem: s1}
+}
+
+// A synctestNetConn is one endpoint of the connection created by synctestNetPipe.
+type synctestNetConn struct {
+	group *synctestGroup
+
+	// local and remote connection halves.
+	// Each half contains a buffer.
+	// Reads pull from the local buffer, and writes push to the remote buffer.
+	loc, rem *synctestNetConnHalf
+
+	// When set, group.Wait is automatically called before reads and after writes.
+	autoWait bool
+}
+
+// Read reads data from the connection.
+func (c *synctestNetConn) Read(b []byte) (n int, err error) {
+	if c.autoWait {
+		c.group.Wait()
+	}
+	return c.loc.read(b)
+}
+
+// Peek returns the available unread read buffer,
+// without consuming its contents.
+func (c *synctestNetConn) Peek() []byte {
+	if c.autoWait {
+		c.group.Wait()
+	}
+	return c.loc.peek()
+}
+
+// Write writes data to the connection.
+func (c *synctestNetConn) Write(b []byte) (n int, err error) {
+	if c.autoWait {
+		defer c.group.Wait()
+	}
+	return c.rem.write(b)
+}
+
+// IsClosedByPeer reports whether the peer has closed its end of the connection.
+func (c *synctestNetConn) IsClosedByPeer() bool {
+	if c.autoWait {
+		c.group.Wait()
+	}
+	return c.loc.isClosedByPeer()
+}
+
+// Close closes the connection.
+func (c *synctestNetConn) Close() error {
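+	// Fail the peer's future writes to us, and report io.EOF to the
+	// peer's reads once any already-buffered data has been drained.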
+	c.loc.setWriteError(errors.New("connection closed by peer"))
+	c.rem.setReadError(io.EOF)
+	if c.autoWait {
+		c.group.Wait()
+	}
+	return nil
+}
+
+// LocalAddr returns the (fake) local network address.
+func (c *synctestNetConn) LocalAddr() net.Addr {
+	return c.loc.addr
+}
+
+// RemoteAddr returns the (fake) remote network address.
+func (c *synctestNetConn) RemoteAddr() net.Addr {
+	return c.rem.addr
+}
+
+// SetDeadline sets the read and write deadlines for the connection.
+func (c *synctestNetConn) SetDeadline(t time.Time) error {
+	c.SetReadDeadline(t)
+	c.SetWriteDeadline(t)
+	return nil
+}
+
+// SetReadDeadline sets the read deadline for the connection.
+func (c *synctestNetConn) SetReadDeadline(t time.Time) error {
+	c.loc.rctx.setDeadline(c.group, t)
+	return nil
+}
+
+// SetWriteDeadline sets the write deadline for the connection.
+func (c *synctestNetConn) SetWriteDeadline(t time.Time) error {
+	c.rem.wctx.setDeadline(c.group, t)
+	return nil
+}
+
+// SetReadBufferSize sets the read buffer limit for the connection.
+// Writes by the peer will block so long as the buffer is full.
+func (c *synctestNetConn) SetReadBufferSize(size int) {
+	c.loc.setReadBufferSize(size)
+}
+
+// synctestNetConnHalf is one data flow in the connection created by synctestNetPipe.
+// Each half contains a buffer. Writes to the half push to the buffer, and reads pull from it.
+type synctestNetConnHalf struct {
+	addr net.Addr
+
+	// Read and write timeouts.
+	rctx, wctx deadlineContext
+
+	// A half can be readable and/or writable.
+	//
+	// These four channels act as a lock,
+	// and allow waiting for readability/writability.
+	// When the half is unlocked, exactly one channel contains a value.
+	// When the half is locked, all channels are empty.
+	lockr  chan struct{} // readable
+	lockw  chan struct{} // writable
+	lockrw chan struct{} // readable and writable
+	lockc  chan struct{} // neither readable nor writable
+
+	bufMax   int // maximum buffer size
+	buf      bytes.Buffer
+	readErr  error // error returned by reads
+	writeErr error // error returned by writes
+}
+
+func newSynctestNetConnHalf(addr net.Addr) *synctestNetConnHalf {
+	h := &synctestNetConnHalf{
+		addr:   addr,
+		lockw:  make(chan struct{}, 1),
+		lockr:  make(chan struct{}, 1),
+		lockrw: make(chan struct{}, 1),
+		lockc:  make(chan struct{}, 1),
+		bufMax: math.MaxInt, // unlimited
+	}
+	h.unlock()
+	return h
+}
+
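+// lock acquires the half's lock by taking the token from whichever
+// state channel currently holds it.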
+func (h *synctestNetConnHalf) lock() {
+	select {
+	case <-h.lockw:
+	case <-h.lockr:
+	case <-h.lockrw:
+	case <-h.lockc:
+	}
+}
+
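+// unlock releases the lock, depositing the token in the channel that
+// matches the half's current readability and writability.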
+func (h *synctestNetConnHalf) unlock() {
+	canRead := h.readErr != nil || h.buf.Len() > 0
+	canWrite := h.writeErr != nil || h.bufMax > h.buf.Len()
+	switch {
+	case canRead && canWrite:
+		h.lockrw <- struct{}{}
+	case canRead:
+		h.lockr <- struct{}{}
+	case canWrite:
+		h.lockw <- struct{}{}
+	default:
+		h.lockc <- struct{}{}
+	}
+}
+
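+// readWaitAndLock blocks until the half is readable, then acquires the lock.
+// It returns os.ErrDeadlineExceeded if the read deadline expires first.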
+func (h *synctestNetConnHalf) readWaitAndLock() error {
+	select {
+	case <-h.lockr:
+		return nil
+	case <-h.lockrw:
+		return nil
+	default:
+	}
+	ctx := h.rctx.context()
+	select {
+	case <-h.lockr:
+		return nil
+	case <-h.lockrw:
+		return nil
+	case <-ctx.Done():
+		return context.Cause(ctx)
+	}
+}
+
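+// writeWaitAndLock blocks until the half is writable, then acquires the lock.
+// It returns os.ErrDeadlineExceeded if the write deadline expires first.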
+func (h *synctestNetConnHalf) writeWaitAndLock() error {
+	select {
+	case <-h.lockw:
+		return nil
+	case <-h.lockrw:
+		return nil
+	default:
+	}
+	ctx := h.wctx.context()
+	select {
+	case <-h.lockw:
+		return nil
+	case <-h.lockrw:
+		return nil
+	case <-ctx.Done():
+		return context.Cause(ctx)
+	}
+}
+
+func (h *synctestNetConnHalf) peek() []byte {
+	h.lock()
+	defer h.unlock()
+	return h.buf.Bytes()
+}
+
+func (h *synctestNetConnHalf) isClosedByPeer() bool {
+	h.lock()
+	defer h.unlock()
+	return h.readErr != nil
+}
+
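+// read reads from the buffer. A pending read error (for example, io.EOF
+// after the peer closes the conn) is reported only after the buffer is drained.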
+func (h *synctestNetConnHalf) read(b []byte) (n int, err error) {
+	if err := h.readWaitAndLock(); err != nil {
+		return 0, err
+	}
+	defer h.unlock()
+	if h.buf.Len() == 0 && h.readErr != nil {
+		return 0, h.readErr
+	}
+	return h.buf.Read(b)
+}
+
+func (h *synctestNetConnHalf) setReadBufferSize(size int) {
+	h.lock()
+	defer h.unlock()
+	h.bufMax = size
+}
+
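+// write writes b to the buffer, blocking until every byte is accepted
+// or the write fails (the deadline expires or the conn is closed).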
+func (h *synctestNetConnHalf) write(b []byte) (n int, err error) {
+	for n < len(b) {
+		nn, err := h.writePartial(b[n:])
+		n += nn
+		if err != nil {
+			return n, err
+		}
+	}
+	return n, nil
+}
+
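+// writePartial writes as much of b as fits in the buffer,
+// waiting for space when the buffer is full.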
+func (h *synctestNetConnHalf) writePartial(b []byte) (n int, err error) {
+	if err := h.writeWaitAndLock(); err != nil {
+		return 0, err
+	}
+	defer h.unlock()
+	if h.writeErr != nil {
+		return 0, h.writeErr
+	}
+	writeMax := h.bufMax - h.buf.Len()
+	if writeMax < len(b) {
+		b = b[:writeMax]
+	}
+	return h.buf.Write(b)
+}
+
+func (h *synctestNetConnHalf) setReadError(err error) {
+	h.lock()
+	defer h.unlock()
+	if h.readErr == nil {
+		h.readErr = err
+	}
+}
+
+func (h *synctestNetConnHalf) setWriteError(err error) {
+	h.lock()
+	defer h.unlock()
+	if h.writeErr == nil {
+		h.writeErr = err
+	}
+}
+
+// deadlineContext converts a changeable deadline (as in net.Conn.SetDeadline) into a Context.
+type deadlineContext struct {
+	mu     sync.Mutex
+	ctx    context.Context
+	cancel context.CancelCauseFunc
+	timer  timer
+}
+
+// context returns a Context which expires when the deadline does.
+func (t *deadlineContext) context() context.Context {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	if t.ctx == nil {
+		t.ctx, t.cancel = context.WithCancelCause(context.Background())
+	}
+	return t.ctx
+}
+
+// setDeadline sets the current deadline.
+func (t *deadlineContext) setDeadline(group *synctestGroup, deadline time.Time) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	// If t.ctx is nil, this is the first call to setDeadline.
+	// If t.ctx is non-nil and t.cancel is nil, then t.ctx was canceled
+	// and we should create a new one.
+	if t.ctx == nil || t.cancel == nil {
+		t.ctx, t.cancel = context.WithCancelCause(context.Background())
+	}
+	// Stop any existing deadline from expiring.
+	if t.timer != nil {
+		t.timer.Stop()
+	}
+	if deadline.IsZero() {
+		// No deadline.
+		return
+	}
+	if !deadline.After(group.Now()) {
+		// Deadline has already expired.
+		t.cancel(os.ErrDeadlineExceeded)
+		t.cancel = nil
+		return
+	}
+	if t.timer != nil {
+		// Reuse existing deadline timer.
+		t.timer.Reset(deadline.Sub(group.Now()))
+		return
+	}
+	// Create a new timer to cancel the context at the deadline.
+	t.timer = group.AfterFunc(deadline.Sub(group.Now()), func() {
+		t.mu.Lock()
+		defer t.mu.Unlock()
+		t.cancel(os.ErrDeadlineExceeded)
+		t.cancel = nil
+	})
+}
diff --git a/http2/sync_test.go b/http2/sync_test.go
index 3f5cf31..bcbbe66 100644
--- a/http2/sync_test.go
+++ b/http2/sync_test.go
@@ -166,6 +166,13 @@
 	}
 }
 
+// Now returns the current synthetic time.
+func (g *synctestGroup) Now() time.Time {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+	return g.now
+}
+
 // TimeUntilEvent returns the amount of time until the next scheduled timer.
 func (g *synctestGroup) TimeUntilEvent() (d time.Duration, scheduled bool) {
 	g.mu.Lock()
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 2171359..d62407b 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -2491,7 +2491,7 @@
 	// the interesting parts of both.
 	tc.writeGoAway(5, ErrCodeNo, []byte(goAwayDebugData))
 	tc.writeGoAway(5, goAwayErrCode, nil)
-	tc.closeWrite(io.EOF)
+	tc.closeWrite()
 
 	res, err := rt.result()
 	whence := "RoundTrip"
@@ -5151,7 +5151,7 @@
 		})
 	}
 
-	tc.closeWrite(io.EOF)
+	tc.closeWrite()
 	err := rt.err()
 	if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr {
 		t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr)