cryptobyte: add (*Builder).Unwrite and (*Builder).SetError
Unwrite allows programs to rollback builders more reliably and
efficiently than by copying a Builder (which might waste an allocation
and depends on internal behavior). This is useful for example to remove
a length-prefixed field if it ends up being empty.
SetError allows simple Builder extensions to set errors without making
MarshalingValue wrappers.
Based on the experience of CL 144115.
Change-Id: I9a785b81b51b15af49418b5bdb71c4ef222ccc46
Reviewed-on: https://go-review.googlesource.com/c/145317
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/cryptobyte/builder.go b/cryptobyte/builder.go
index 765bf8e..d0c6c3f 100644
--- a/cryptobyte/builder.go
+++ b/cryptobyte/builder.go
@@ -50,6 +50,12 @@
}
}
+// SetError sets the value to be returned as the error from Bytes. Writes
+// performed after calling SetError are ignored.
+func (b *Builder) SetError(err error) {
+ b.err = err
+}
+
// Bytes returns the bytes written by the builder or an error if one has
// occurred during building.
func (b *Builder) Bytes() ([]byte, error) {
@@ -94,7 +100,7 @@
b.add(v...)
}
-// BuilderContinuation is continuation-passing interface for building
+// BuilderContinuation is a continuation-passing interface for building
// length-prefixed byte sequences. Builder methods for length-prefixed
// sequences (AddUint8LengthPrefixed etc) will invoke the BuilderContinuation
// supplied to them. The child builder passed to the continuation can be used
@@ -290,6 +296,26 @@
b.result = append(b.result, bytes...)
}
+// Unwrite rolls back n bytes written directly to the Builder. An attempt by a
+// child builder passed to a continuation to unwrite bytes from its parent will
+// panic.
+func (b *Builder) Unwrite(n int) {
+ if b.err != nil {
+ return
+ }
+ if b.child != nil {
+ panic("attempted unwrite while child is pending")
+ }
+ length := len(b.result) - b.pendingLenLen - b.offset
+ if length < 0 {
+ panic("cryptobyte: internal error")
+ }
+ if n > length {
+ panic("cryptobyte: attempted to unwrite more than was written")
+ }
+ b.result = b.result[:len(b.result)-n]
+}
+
// A MarshalingValue marshals itself into a Builder.
type MarshalingValue interface {
// Marshal is called by Builder.AddValue. It receives a pointer to a builder
diff --git a/cryptobyte/cryptobyte_test.go b/cryptobyte/cryptobyte_test.go
index f294dd5..b859cc9 100644
--- a/cryptobyte/cryptobyte_test.go
+++ b/cryptobyte/cryptobyte_test.go
@@ -327,12 +327,14 @@
var b Builder
b.AddUint8LengthPrefixed(func(c *Builder) {
c.AddUint8LengthPrefixed(func(d *Builder) {
- defer func() {
- if recover() == nil {
- t.Errorf("recover() = nil, want error; c.AddUint8() did not panic")
- }
+ func() {
+ defer func() {
+ if recover() == nil {
+ t.Errorf("recover() = nil, want error; c.AddUint8() did not panic")
+ }
+ }()
+ c.AddUint8(2) // panics
}()
- c.AddUint8(2) // panics
defer func() {
if recover() == nil {
@@ -351,6 +353,65 @@
})
}
+func TestSetError(t *testing.T) {
+ const errorStr = "TestSetError"
+ var b Builder
+ b.SetError(errors.New(errorStr))
+
+ ret, err := b.Bytes()
+ if ret != nil {
+ t.Error("expected nil result")
+ }
+ if err == nil {
+ t.Fatal("unexpected nil error")
+ }
+ if s := err.Error(); s != errorStr {
+ t.Errorf("expected error %q, got %v", errorStr, s)
+ }
+}
+
+func TestUnwrite(t *testing.T) {
+ var b Builder
+ b.AddBytes([]byte{1, 2, 3, 4, 5})
+ b.Unwrite(2)
+ if err := builderBytesEq(&b, 1, 2, 3); err != nil {
+ t.Error(err)
+ }
+
+ func() {
+ defer func() {
+ if recover() == nil {
+ t.Errorf("recover() = nil, want error; b.Unwrite() did not panic")
+ }
+ }()
+ b.Unwrite(4) // panics
+ }()
+
+ b = Builder{}
+ b.AddBytes([]byte{1, 2, 3, 4, 5})
+ b.AddUint8LengthPrefixed(func(b *Builder) {
+ b.AddBytes([]byte{1, 2, 3, 4, 5})
+
+ defer func() {
+ if recover() == nil {
+ t.Errorf("recover() = nil, want error; b.Unwrite() did not panic")
+ }
+ }()
+ b.Unwrite(6) // panics
+ })
+
+ b = Builder{}
+ b.AddBytes([]byte{1, 2, 3, 4, 5})
+ b.AddUint8LengthPrefixed(func(c *Builder) {
+ defer func() {
+ if recover() == nil {
+ t.Errorf("recover() = nil, want error; b.Unwrite() did not panic")
+ }
+ }()
+ b.Unwrite(2) // panics (attempted unwrite while child is pending)
+ })
+}
+
// ASN.1
func TestASN1Int64(t *testing.T) {