WIP: Prevent "slow retrieval" attacks
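
Wrap remote metadata and target downloads in a timeoutReader which
enforces a minimum read rate, so that a malicious mirror cannot stall
a client indefinitely by responding to requests arbitrarily slowly.

Reads larger than chunkSize are split into chunks, each of which must
complete within one second; smaller reads get a proportionally shorter
timeout. With the default 8KB chunkSize, for example, a read into a
4KB buffer must complete within 500ms:

    stream := newTimeoutReader(r)             // minimum rate: 8KB/s
    n, err := stream.Read(make([]byte, 4096)) // times out after 500ms

Still to do: incorporate gracePeriod into the timeout calculation.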

Signed-off-by: Lewis Marshall <lewis@lmars.net>
diff --git a/client/client.go b/client/client.go
index 0be0c1a..5787606 100644
--- a/client/client.go
+++ b/client/client.go
@@ -309,6 +309,9 @@
 	}
 	defer r.Close()
 
+	// wrap in a timeoutReader to prevent slow retrieval attacks
+	stream := newTimeoutReader(r)
+
 	// return ErrMetaTooLarge if the reported size is greater than maxMetaSize
 	if size > maxMetaSize {
 		return nil, ErrMetaTooLarge{name, size}
@@ -317,7 +320,7 @@
 	// although the size has been checked above, use a LimitReader in case
 	// the reported size is inaccurate, or size is -1 which indicates an
 	// unknown length
-	return ioutil.ReadAll(io.LimitReader(r, maxMetaSize))
+	return ioutil.ReadAll(io.LimitReader(stream, maxMetaSize))
 }
 
 // downloadMeta downloads top-level metadata from remote storage and verifies
@@ -337,8 +340,9 @@
 		return nil, ErrWrongSize{name, size, m.Length}
 	}
 
-	// wrap the data in a LimitReader so we download at most m.Length bytes
-	stream := io.LimitReader(r, m.Length)
+	// wrap the data in a timeoutReader to prevent slow retrieval attacks,
+	// and a LimitReader so we download at most m.Length bytes
+	stream := newTimeoutReader(io.LimitReader(r, m.Length))
 
 	// read the data, simultaneously writing it to buf and generating metadata
 	var buf bytes.Buffer
@@ -465,8 +469,9 @@
 		return ErrWrongSize{name, size, localMeta.Length}
 	}
 
-	// wrap the data in a LimitReader so we download at most localMeta.Length bytes
-	stream := io.LimitReader(r, localMeta.Length)
+	// wrap the data in a timeoutReader to prevent slow retrieval attacks,
+	// and a LimitReader so we download at most localMeta.Length bytes
+	stream := newTimeoutReader(io.LimitReader(r, localMeta.Length))
 
 	// read the data, simultaneously writing it to dest and generating metadata
 	actual, err := util.GenerateFileMeta(io.TeeReader(stream, dest))
diff --git a/client/client_test.go b/client/client_test.go
index 54fe718..aa55b57 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -45,13 +45,29 @@
 	return &fakeFile{buf: bytes.NewReader(b), size: int64(len(b))}
 }
 
+func newBlockingFakeFile(b []byte) *fakeFile {
+	return &fakeFile{typ: "blocking", buf: bytes.NewReader(b), size: int64(len(b))}
+}
+
+func newSlowFakeFile(b []byte) *fakeFile {
+	return &fakeFile{typ: "slow", buf: bytes.NewReader(b), size: int64(len(b))}
+}
+
 type fakeFile struct {
+	typ       string
 	buf       *bytes.Reader
 	bytesRead int
 	size      int64
 }
 
 func (f *fakeFile) Read(p []byte) (int, error) {
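+	// simulate a slow retrieval attack according to the file type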
+	switch f.typ {
+	case "blocking":
+		<-make(chan struct{})
+	case "slow":
+		time.Sleep(2 * time.Second)
+	}
 	n, err := f.buf.Read(p)
 	f.bytesRead += n
 	return n, err
@@ -645,6 +660,24 @@
 	c.Assert(err, DeepEquals, ErrWrongSize{"targets.json", int64(len(tamperedJSON)), int64(len(targetsJSON))})
 }
 
+func (s *ClientSuite) TestUpdateSlowRetrievalAttack(c *C) {
+	meta, err := s.store.GetMeta()
+	c.Assert(err, IsNil)
+	snapshot, ok := meta["snapshot.json"]
+	if !ok {
+		c.Fatal("missing snapshot.json")
+	}
+	client := s.newClient(c)
+
+	s.remote["snapshot.json"] = newBlockingFakeFile(snapshot)
+	_, err = client.Update()
+	// c.Assert(err, DeepEquals, ErrWrongSize{"snapshot.json", 0, int64(len(snapshot))})
+	c.Assert(err, DeepEquals, ErrDownloadFailed{"snapshot.json", util.ErrWrongLength})
+
+	s.remote["snapshot.json"] = newSlowFakeFile(snapshot)
+	_, err = client.Update()
+	// c.Assert(err, DeepEquals, ErrWrongSize{"snapshot.json", 16384, int64(len(snapshot))})
+	c.Assert(err, DeepEquals, ErrDownloadFailed{"snapshot.json", util.ErrWrongLength})
+}
+
 type testDestination struct {
 	bytes.Buffer
 	deleted bool
diff --git a/client/timeout_reader.go b/client/timeout_reader.go
new file mode 100644
index 0000000..a1d7612
--- /dev/null
+++ b/client/timeout_reader.go
@@ -0,0 +1,72 @@
+package client
+
+import (
+	"errors"
+	"io"
+	"time"
+)
+
+// ErrTimeout is returned by Read when no data is received before the timeout
+var ErrTimeout = errors.New("timeout")
+
+// timeoutReader wraps an io.Reader and returns ErrTimeout from Read if data
+// arrives at a rate slower than chunkSize bytes per second
+// TODO: use gracePeriod
+type timeoutReader struct {
+	r           io.Reader
+	gracePeriod time.Duration
+	chunkSize   int
+}
+
+const (
+	defaultGracePeriod = 4 * time.Second
+	defaultChunkSize   = 8 * 1024
+)
+
+// newTimeoutReader returns a timeoutReader with default gracePeriod and chunkSize
+func newTimeoutReader(r io.Reader) *timeoutReader {
+	return &timeoutReader{r, defaultGracePeriod, defaultChunkSize}
+}
+
+// readResult represents the return value of a read
+type readResult struct {
+	n   int
+	err error
+}
+
+// Read reads from t.r, timing out if the read rate is slower than t.chunkSize bytes per second
+func (t *timeoutReader) Read(p []byte) (int, error) {
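+	// a read smaller than chunkSize gets a proportionally shorter timeout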
+	if len(p) < t.chunkSize {
+		timeout := (time.Duration(len(p)) * time.Second) / time.Duration(t.chunkSize)
+		return t.readWithTimeout(p, timeout)
+	}
+	var pos int
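+	// read in chunks of at most chunkSize bytes, allowing one second per chunk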
+	for {
+		size := t.chunkSize
+		if size > len(p)-pos {
+			size = len(p) - pos
+		}
+		m, err := t.readWithTimeout(p[pos:pos+size], time.Second)
+		pos += m
+		if pos == len(p) || err != nil {
+			return pos, err
+		}
+	}
+}
+
+func (t *timeoutReader) readWithTimeout(p []byte, timeout time.Duration) (int, error) {
+	done := make(chan *readResult, 1) // buffered so the read goroutine can exit after a timeout
+	go func() {
+		res := &readResult{}
+		res.n, res.err = t.r.Read(p)
+		done <- res
+	}()
+	select {
+	case res := <-done:
+		return res.n, res.err
+	case <-time.After(timeout):
+		return 0, ErrTimeout
+	}
+}