cryptobyte: add (*Builder).Unwrite and (*Builder).SetError

Unwrite allows programs to rollback builders more reliably and
efficiently than by copying a Builder (which might waste an allocation
and depends on internal behavior). This is useful for example to remove
a length-prefixed field if it ends up being empty.

SetError allows simple Builder extensions to set errors without making
MarshalingValue wrappers.

Based on the experience of CL 144115.

Change-Id: I9a785b81b51b15af49418b5bdb71c4ef222ccc46
Reviewed-on: https://go-review.googlesource.com/c/145317
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/cryptobyte/builder.go b/cryptobyte/builder.go
index 765bf8e..d0c6c3f 100644
--- a/cryptobyte/builder.go
+++ b/cryptobyte/builder.go
@@ -50,6 +50,12 @@
 	}
 }
 
+// SetError sets the value to be returned as the error from Bytes. Writes
+// performed after calling SetError are ignored.
+func (b *Builder) SetError(err error) {
+	b.err = err
+}
+
 // Bytes returns the bytes written by the builder or an error if one has
 // occurred during building.
 func (b *Builder) Bytes() ([]byte, error) {
@@ -94,7 +100,7 @@
 	b.add(v...)
 }
 
-// BuilderContinuation is continuation-passing interface for building
+// BuilderContinuation is a continuation-passing interface for building
 // length-prefixed byte sequences. Builder methods for length-prefixed
 // sequences (AddUint8LengthPrefixed etc) will invoke the BuilderContinuation
 // supplied to them. The child builder passed to the continuation can be used
@@ -290,6 +296,26 @@
 	b.result = append(b.result, bytes...)
 }
 
+// Unwrite rolls back n bytes written directly to the Builder. An attempt by a
+// child builder passed to a continuation to unwrite bytes from its parent will
+// panic.
+func (b *Builder) Unwrite(n int) {
+	if b.err != nil {
+		return
+	}
+	if b.child != nil {
+		panic("attempted unwrite while child is pending")
+	}
+	length := len(b.result) - b.pendingLenLen - b.offset
+	if length < 0 {
+		panic("cryptobyte: internal error")
+	}
+	if n > length {
+		panic("cryptobyte: attempted to unwrite more than was written")
+	}
+	b.result = b.result[:len(b.result)-n]
+}
+
 // A MarshalingValue marshals itself into a Builder.
 type MarshalingValue interface {
 	// Marshal is called by Builder.AddValue. It receives a pointer to a builder
diff --git a/cryptobyte/cryptobyte_test.go b/cryptobyte/cryptobyte_test.go
index f294dd5..b859cc9 100644
--- a/cryptobyte/cryptobyte_test.go
+++ b/cryptobyte/cryptobyte_test.go
@@ -327,12 +327,14 @@
 	var b Builder
 	b.AddUint8LengthPrefixed(func(c *Builder) {
 		c.AddUint8LengthPrefixed(func(d *Builder) {
-			defer func() {
-				if recover() == nil {
-					t.Errorf("recover() = nil, want error; c.AddUint8() did not panic")
-				}
+			func() {
+				defer func() {
+					if recover() == nil {
+						t.Errorf("recover() = nil, want error; c.AddUint8() did not panic")
+					}
+				}()
+				c.AddUint8(2) // panics
 			}()
-			c.AddUint8(2) // panics
 
 			defer func() {
 				if recover() == nil {
@@ -351,6 +353,65 @@
 	})
 }
 
+func TestSetError(t *testing.T) {
+	const errorStr = "TestSetError"
+	var b Builder
+	b.SetError(errors.New(errorStr))
+
+	ret, err := b.Bytes()
+	if ret != nil {
+		t.Error("expected nil result")
+	}
+	if err == nil {
+		t.Fatal("unexpected nil error")
+	}
+	if s := err.Error(); s != errorStr {
+		t.Errorf("expected error %q, got %v", errorStr, s)
+	}
+}
+
+func TestUnwrite(t *testing.T) {
+	var b Builder
+	b.AddBytes([]byte{1, 2, 3, 4, 5})
+	b.Unwrite(2)
+	if err := builderBytesEq(&b, 1, 2, 3); err != nil {
+		t.Error(err)
+	}
+
+	func() {
+		defer func() {
+			if recover() == nil {
+				t.Errorf("recover() = nil, want error; b.Unwrite() did not panic")
+			}
+		}()
+		b.Unwrite(4) // panics
+	}()
+
+	b = Builder{}
+	b.AddBytes([]byte{1, 2, 3, 4, 5})
+	b.AddUint8LengthPrefixed(func(b *Builder) {
+		b.AddBytes([]byte{1, 2, 3, 4, 5})
+
+		defer func() {
+			if recover() == nil {
+				t.Errorf("recover() = nil, want error; b.Unwrite() did not panic")
+			}
+		}()
+		b.Unwrite(6) // panics
+	})
+
+	b = Builder{}
+	b.AddBytes([]byte{1, 2, 3, 4, 5})
+	b.AddUint8LengthPrefixed(func(c *Builder) {
+		defer func() {
+			if recover() == nil {
+				t.Errorf("recover() = nil, want error; b.Unwrite() did not panic")
+			}
+		}()
+		b.Unwrite(2) // panics (attempted unwrite while child is pending)
+	})
+}
+
 // ASN.1
 
 func TestASN1Int64(t *testing.T) {