WIP Prevent "Slow Retrieval" attacks
Signed-off-by: Lewis Marshall <lewis@lmars.net>
diff --git a/client/client.go b/client/client.go
index 0be0c1a..5787606 100644
--- a/client/client.go
+++ b/client/client.go
@@ -309,6 +309,9 @@
}
defer r.Close()
+ // wrap in a timeoutReader to prevent slow retrieval attacks
+ stream := newTimeoutReader(r)
+
// return ErrMetaTooLarge if the reported size is greater than maxMetaSize
if size > maxMetaSize {
return nil, ErrMetaTooLarge{name, size}
@@ -317,7 +320,7 @@
// although the size has been checked above, use a LimitReader in case
// the reported size is inaccurate, or size is -1 which indicates an
// unknown length
- return ioutil.ReadAll(io.LimitReader(r, maxMetaSize))
+ return ioutil.ReadAll(io.LimitReader(stream, maxMetaSize))
}
// downloadMeta downloads top-level metadata from remote storage and verifies
@@ -337,8 +340,9 @@
return nil, ErrWrongSize{name, size, m.Length}
}
- // wrap the data in a LimitReader so we download at most m.Length bytes
- stream := io.LimitReader(r, m.Length)
+ // wrap the data in a timeoutReader to prevent slow retrieval attacks,
+ // and a LimitReader so we download at most m.Length bytes
+ stream := newTimeoutReader(io.LimitReader(r, m.Length))
// read the data, simultaneously writing it to buf and generating metadata
var buf bytes.Buffer
@@ -465,8 +469,9 @@
return ErrWrongSize{name, size, localMeta.Length}
}
- // wrap the data in a LimitReader so we download at most localMeta.Length bytes
- stream := io.LimitReader(r, localMeta.Length)
+ // wrap the data in a timeoutReader to prevent slow retrieval attacks,
+ // and a LimitReader so we download at most localMeta.Length bytes
+ stream := newTimeoutReader(io.LimitReader(r, localMeta.Length))
// read the data, simultaneously writing it to dest and generating metadata
actual, err := util.GenerateFileMeta(io.TeeReader(stream, dest))
diff --git a/client/client_test.go b/client/client_test.go
index 54fe718..aa55b57 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -45,13 +45,28 @@
return &fakeFile{buf: bytes.NewReader(b), size: int64(len(b))}
}
+func newBlockingFakeFile(b []byte) *fakeFile { // fakeFile serving b whose Read blocks forever (stalled server)
+ return &fakeFile{typ: "blocking", buf: bytes.NewReader(b), size: int64(len(b))}
+}
+
+func newSlowFakeFile(b []byte) *fakeFile { // fakeFile serving b whose every Read first sleeps for 2 seconds
+ return &fakeFile{typ: "slow", buf: bytes.NewReader(b), size: int64(len(b))}
+}
+
type fakeFile struct {
+ typ string // Read behaviour: "" (plain, as set by newFakeFile), "blocking" or "slow"
buf *bytes.Reader
bytesRead int
size int64
}
func (f *fakeFile) Read(p []byte) (int, error) {
+ switch f.typ {
+ case "blocking":
+ <-make(chan struct{})
+ case "slow":
+ time.Sleep(2 * time.Second)
+ }
n, err := f.buf.Read(p)
f.bytesRead += n
return n, err
@@ -645,6 +660,26 @@
c.Assert(err, DeepEquals, ErrWrongSize{"targets.json", int64(len(tamperedJSON)), int64(len(targetsJSON))})
}
+func (s *ClientSuite) TestUpdateSlowRetrievalAttack(c *C) {
+ meta, err := s.store.GetMeta()
+ c.Assert(err, IsNil)
+ snapshot, ok := meta["snapshot.json"]
+ if !ok {
+ c.Fatal("missing snapshot.json")
+ }
+ client := s.newClient(c)
+ // a remote snapshot.json whose Read never returns must not hang Update
+ s.remote["snapshot.json"] = newBlockingFakeFile(snapshot)
+ _, err = client.Update()
+ // the stalled download is abandoned, so the received data fails the length check
+ c.Assert(err, DeepEquals, ErrDownloadFailed{"snapshot.json", util.ErrWrongLength})
+ // a remote snapshot.json that waits 2s before every Read must also be cut off
+ s.remote["snapshot.json"] = newSlowFakeFile(snapshot)
+ _, err = client.Update()
+ // likewise, the slow download ends early and fails the length check
+ c.Assert(err, DeepEquals, ErrDownloadFailed{"snapshot.json", util.ErrWrongLength})
+}
+
type testDestination struct {
bytes.Buffer
deleted bool
diff --git a/client/timeout_reader.go b/client/timeout_reader.go
new file mode 100644
index 0000000..a1d7612
--- /dev/null
+++ b/client/timeout_reader.go
@@ -0,0 +1,118 @@
+package client
+
+import (
+	"errors"
+	"io"
+	"time"
+)
+
+// ErrTimeout is returned by a timeoutReader when the wrapped reader does not
+// deliver data within the allowed time.
+var ErrTimeout = errors.New("timeout")
+
+// timeoutReader wraps an io.Reader and times out if the read rate is lower
+// than chunkSize bytes per second.
+//
+// Each read of the wrapped reader runs in its own goroutine, because a blocked
+// Read cannot be interrupted. When a read times out that goroutine is
+// abandoned, so the timeout is sticky: every later Read returns ErrTimeout
+// without touching r again, so r never sees overlapping Reads from us.
+//
+// NOTE(review): the rate is only checked per underlying Read. For buffers of
+// at least chunkSize bytes each sub-read gets a full second however few bytes
+// it returns, so a peer trickling a few bytes per second is not detected.
+//
+// TODO: use gracePeriod
+type timeoutReader struct {
+	r           io.Reader
+	gracePeriod time.Duration
+	chunkSize   int
+
+	// buf is scratch space passed to r.Read instead of the caller's slice, so
+	// an abandoned goroutine can never write into memory the caller owns.
+	buf []byte
+
+	// err is set once a read has timed out and is returned by every later Read.
+	err error
+}
+
+const (
+	defaultGracePeriod = 4 * time.Second
+	defaultChunkSize   = 8 * 1024
+)
+
+// newTimeoutReader returns a timeoutReader with default gracePeriod and chunkSize.
+func newTimeoutReader(r io.Reader) *timeoutReader {
+	return &timeoutReader{r: r, gracePeriod: defaultGracePeriod, chunkSize: defaultChunkSize}
+}
+
+// readResult represents the return values of a single Read of the wrapped reader.
+type readResult struct {
+	n   int
+	err error
+}
+
+// Read reads from t.r, timing out if the read rate is lower than t.chunkSize
+// bytes per second.
+//
+// Buffers shorter than chunkSize get a proportionally shorter timeout. Longer
+// buffers are filled in pieces of at most chunkSize bytes, each allowed one
+// second, until p is full or the wrapped reader returns an error.
+//
+// NOTE(review): small buffers get very short timeouts (512 bytes gives ~62ms),
+// which may be too strict for real network latency; gracePeriod is presumably
+// meant to address this.
+func (t *timeoutReader) Read(p []byte) (int, error) {
+	if t.err != nil {
+		return 0, t.err
+	}
+	// A zero-length read would get a zero timeout, which could fire before
+	// the wrapped reader has a chance to return; there is nothing to read.
+	if len(p) == 0 {
+		return 0, nil
+	}
+	if len(p) < t.chunkSize {
+		timeout := time.Duration(len(p)) * time.Second / time.Duration(t.chunkSize)
+		return t.readWithTimeout(p, timeout)
+	}
+	var pos int
+	for {
+		end := pos + t.chunkSize
+		if end > len(p) {
+			end = len(p)
+		}
+		n, err := t.readWithTimeout(p[pos:end], time.Second)
+		pos += n
+		if pos == len(p) || err != nil {
+			return pos, err
+		}
+	}
+}
+
+// readWithTimeout performs a single Read of t.r into p, giving up with
+// ErrTimeout if it does not complete within timeout. A timeout is recorded in
+// t.err and the in-flight Read is abandoned.
+func (t *timeoutReader) readWithTimeout(p []byte, timeout time.Duration) (int, error) {
+	if cap(t.buf) < len(p) {
+		t.buf = make([]byte, len(p))
+	}
+	buf := t.buf[:len(p)]
+
+	// Buffered so the goroutine can always deliver its result and exit, even
+	// after we have stopped waiting for it; otherwise it would leak forever.
+	done := make(chan readResult, 1)
+	go func() {
+		n, err := t.r.Read(buf)
+		done <- readResult{n, err}
+	}()
+
+	timer := time.NewTimer(timeout)
+	defer timer.Stop()
+	select {
+	case res := <-done:
+		return copy(p, buf[:res.n]), res.err
+	case <-timer.C:
+		t.err = ErrTimeout
+		return 0, ErrTimeout
+	}
+}