websocket: limit incoming payload size

Codec's Receive method calls io.ReadAll of the whole frame payload,
which can be abused by user sending large payloads in order to exhaust
server memory.

Introduce limit on received payload size defined by
Conn.MaxPayloadBytes. If payload size of the message read with
Codec.Receive exceeds limit, ErrFrameTooLarge error is returned; the
connection can still be recovered if required: the next call to Receive
would at first discard leftovers of previous oversized message before
processing the next one.

Fixes golang/go#5082.

Change-Id: Ib04acd7038474fee39a1719324daaec1c0c496b1
Reviewed-on: https://go-review.googlesource.com/23590
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/websocket/websocket.go b/websocket/websocket.go
index c4c7991..a7731d9 100644
--- a/websocket/websocket.go
+++ b/websocket/websocket.go
@@ -32,6 +32,8 @@
 	PingFrame         = 9
 	PongFrame         = 10
 	UnknownFrame      = 255
+
+	DefaultMaxPayloadBytes = 32 << 20 // 32MB
 )
 
 // ProtocolError represents WebSocket protocol errors.
@@ -58,6 +60,10 @@
 	ErrNotSupported         = &ProtocolError{"not supported"}
 )
 
+// ErrFrameTooLarge is returned by Codec's Receive method if payload size
+// exceeds limit set by Conn.MaxPayloadBytes
+var ErrFrameTooLarge = errors.New("websocket: frame payload size exceeds limit")
+
 // Addr is an implementation of net.Addr for WebSocket.
 type Addr struct {
 	*url.URL
@@ -166,6 +172,10 @@
 	frameHandler
 	PayloadType        byte
 	defaultCloseStatus int
+
+	// MaxPayloadBytes limits the size of frame payload received over Conn
+	// by Codec's Receive method. If zero, DefaultMaxPayloadBytes is used.
+	MaxPayloadBytes int
 }
 
 // Read implements the io.Reader interface:
@@ -302,7 +312,12 @@
 	return err
 }
 
-// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores in v.
+// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores
+// in v. The whole frame payload is read to an in-memory buffer; max size of
+// payload is defined by ws.MaxPayloadBytes. If frame payload size exceeds
+// limit, ErrFrameTooLarge is returned; in this case frame is not read off wire
+// completely. The next call to Receive would read and discard leftover data of
+// previous oversized frame before processing next frame.
 func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
 	ws.rio.Lock()
 	defer ws.rio.Unlock()
@@ -325,6 +340,19 @@
 	if frame == nil {
 		goto again
 	}
+	maxPayloadBytes := ws.MaxPayloadBytes
+	if maxPayloadBytes == 0 {
+		maxPayloadBytes = DefaultMaxPayloadBytes
+	}
+	if hf, ok := frame.(*hybiFrameReader); ok && hf.header.Length > int64(maxPayloadBytes) {
+		// payload size exceeds limit, no need to call Unmarshal
+		//
+		// set frameReader to current oversized frame so that
+		// the next call to this function can drain leftover
+		// data before processing the next frame
+		ws.frameReader = frame
+		return ErrFrameTooLarge
+	}
 	payloadType := frame.PayloadType()
 	data, err := ioutil.ReadAll(frame)
 	if err != nil {
diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go
index 4a76a7e..4cd674b 100644
--- a/websocket/websocket_test.go
+++ b/websocket/websocket_test.go
@@ -9,6 +9,7 @@
 	"fmt"
 	"io"
 	"log"
+	"math/rand"
 	"net"
 	"net/http"
 	"net/http/httptest"
@@ -605,3 +606,60 @@
 		}
 	}
 }
+
+func TestCodec_ReceiveLimited(t *testing.T) {
+	const limit = 2048
+	var payloads [][]byte
+	for _, size := range []int{
+		1024,
+		2048,
+		4096, // receive of this message would be interrupted due to limit
+		2048, // this one is to make sure next receive recovers discarding leftovers
+	} {
+		b := make([]byte, size)
+		rand.Read(b)
+		payloads = append(payloads, b)
+	}
+	handlerDone := make(chan struct{})
+	limitedHandler := func(ws *Conn) {
+		defer close(handlerDone)
+		ws.MaxPayloadBytes = limit
+		defer ws.Close()
+		for i, p := range payloads {
+			t.Logf("payload #%d (size %d, exceeds limit: %v)", i, len(p), len(p) > limit)
+			var recv []byte
+			err := Message.Receive(ws, &recv)
+			switch err {
+			case nil:
+			case ErrFrameTooLarge:
+				if len(p) <= limit {
+					t.Fatalf("unexpected frame size limit: expected %d bytes of payload having limit at %d", len(p), limit)
+				}
+				continue
+			default:
+				t.Fatalf("unexpected error: %v (want either nil or ErrFrameTooLarge)", err)
+			}
+			if len(recv) > limit {
+				t.Fatalf("received %d bytes of payload having limit at %d", len(recv), limit)
+			}
+			if !bytes.Equal(p, recv) {
+				t.Fatalf("received payload differs:\ngot:\t%v\nwant:\t%v", recv, p)
+			}
+		}
+	}
+	server := httptest.NewServer(Handler(limitedHandler))
+	defer server.CloseClientConnections()
+	defer server.Close()
+	addr := server.Listener.Addr().String()
+	ws, err := Dial("ws://"+addr+"/", "", "http://localhost/")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer ws.Close()
+	for i, p := range payloads {
+		if err := Message.Send(ws, p); err != nil {
+			t.Fatalf("payload #%d (size %d): %v", i, len(p), err)
+		}
+	}
+	<-handlerDone
+}