chacha20: add SetCounter method

Fixes golang/go#35506

Change-Id: I5cfc6b4dc07ab368e370edaee11841c2c1377f82
GitHub-Last-Rev: 16147a1668a903532f2d3777b873ddad8f0f26f5
GitHub-Pull-Request: golang/crypto#108
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/206638
Run-TryBot: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
diff --git a/chacha20/chacha_generic.go b/chacha20/chacha_generic.go
index 098ec9f..7c498e9 100644
--- a/chacha20/chacha_generic.go
+++ b/chacha20/chacha_generic.go
@@ -136,6 +136,33 @@
 	return a, b, c, d
 }
 
+// SetCounter sets the Cipher counter. The next invocation of XORKeyStream will
+// behave as if (64 * counter) bytes had been encrypted so far.
+//
+// To prevent accidental counter reuse, SetCounter panics if counter is
+// less than the current value.
+func (s *Cipher) SetCounter(counter uint32) {
+	// Internally, s may buffer multiple blocks, which complicates this
+	// implementation slightly. When checking whether the counter has rolled
+	// back, we must use both s.counter and s.len to determine how many blocks
+	// we have already output.
+	outputCounter := s.counter - uint32(s.len)/blockSize
+	if counter < outputCounter {
+		panic("chacha20: SetCounter attempted to rollback counter")
+	}
+
+	// In the general case, we set the new counter value and reset s.len to 0,
+	// causing the next call to XORKeyStream to refill the buffer. However, if
+	// we're advancing within the existing buffer, we can save work by simply
+	// setting s.len.
+	if counter < s.counter {
+		s.len = int(s.counter-counter) * blockSize
+	} else {
+		s.counter = counter
+		s.len = 0
+	}
+}
+
 // XORKeyStream XORs each byte in the given slice with a byte from the
 // cipher's key stream. Dst and src must overlap entirely or not at all.
 //
diff --git a/chacha20/chacha_test.go b/chacha20/chacha_test.go
index 033867b..554afbf 100644
--- a/chacha20/chacha_test.go
+++ b/chacha20/chacha_test.go
@@ -110,6 +110,52 @@
 	}
 }
 
+func TestSetCounter(t *testing.T) {
+	newCipher := func() *Cipher {
+		s, _ := NewUnauthenticatedCipher(make([]byte, KeySize), make([]byte, NonceSize))
+		return s
+	}
+	s := newCipher()
+	src := bytes.Repeat([]byte("test"), 32) // two 64-byte blocks
+	dst1 := make([]byte, len(src))
+	s.XORKeyStream(dst1, src)
+	// advance counter to 1 and xor second block
+	s = newCipher()
+	s.SetCounter(1)
+	dst2 := make([]byte, len(src))
+	s.XORKeyStream(dst2[64:], src[64:])
+	if !bytes.Equal(dst1[64:], dst2[64:]) {
+		t.Error("failed to produce identical output using SetCounter")
+	}
+
+	// test again with unaligned blocks; SetCounter should reset the buffer
+	s = newCipher()
+	s.XORKeyStream(dst1[:70], src[:70])
+	s = newCipher()
+	s.XORKeyStream([]byte{0}, []byte{0})
+	s.SetCounter(1)
+	s.XORKeyStream(dst2[64:70], src[64:70])
+	if !bytes.Equal(dst1[64:70], dst2[64:70]) {
+		t.Error("SetCounter did not reset buffer")
+	}
+
+	// advancing to a lower counter value should cause a panic
+	panics := func(fn func()) (p bool) {
+		defer func() { p = recover() != nil }()
+		fn()
+		return
+	}
+	if !panics(func() { s.SetCounter(0) }) {
+		t.Error("counter decreasing should trigger a panic")
+	}
+	// advancing to ^uint32(0) and then calling XORKeyStream should cause a panic
+	s = newCipher()
+	s.SetCounter(^uint32(0))
+	if !panics(func() { s.XORKeyStream([]byte{0}, []byte{0}) }) {
+		t.Error("counter overflowing should trigger a panic")
+	}
+}
+
 func benchmarkChaCha20(b *testing.B, step, count int) {
 	tot := step * count
 	src := make([]byte, tot)