Detect cyclic references to prevent infinite recursion.
Prevents panics from stack overflow with a heuristic to detect cycles.
diff --git a/formatter.go b/formatter.go
index c834d46..8dacda2 100644
--- a/formatter.go
+++ b/formatter.go
@@ -57,7 +57,7 @@
func (fo formatter) Format(f fmt.State, c rune) {
if fo.force || c == 'v' && f.Flag('#') && f.Flag(' ') {
w := tabwriter.NewWriter(f, 4, 4, 1, ' ', 0)
- p := &printer{tw: w, Writer: w}
+ p := &printer{tw: w, Writer: w, visited: make(map[visit]int)}
p.printValue(reflect.ValueOf(fo.x), true, fo.quote)
w.Flush()
return
@@ -67,7 +67,9 @@
type printer struct {
io.Writer
- tw *tabwriter.Writer
+ tw *tabwriter.Writer
+ visited map[visit]int
+ depth int
}
func (p *printer) indent() *printer {
@@ -86,7 +88,19 @@
}
}
+// printValue must keep track of already-printed pointer values to avoid
+// infinite recursion.
+type visit struct {
+ v uintptr
+ typ reflect.Type
+}
+
func (p *printer) printValue(v reflect.Value, showType, quote bool) {
+ if p.depth > 10 {
+ io.WriteString(p, "!%v(DEPTH EXCEEDED)")
+ return
+ }
+
switch v.Kind() {
case reflect.Bool:
p.printInline(v, v.Bool(), showType)
@@ -138,6 +152,16 @@
writeByte(p, '}')
case reflect.Struct:
t := v.Type()
+ if v.CanAddr() {
+ addr := v.UnsafeAddr()
+ vis := visit{addr, t}
+ if vd, ok := p.visited[vis]; ok && vd < p.depth {
+ p.fmtString(t.String()+"{(CYCLIC REFERENCE)}", false)
+ break // don't print v again
+ }
+ p.visited[vis] = p.depth
+ }
+
if showType {
io.WriteString(p, t.String())
}
@@ -176,7 +200,9 @@
case e.Kind() == reflect.Invalid:
io.WriteString(p, "nil")
case e.IsValid():
- p.printValue(e, showType, true)
+ pp := *p
+ pp.depth++
+ pp.printValue(e, showType, true)
default:
io.WriteString(p, v.Type().String())
io.WriteString(p, "(nil)")
@@ -221,8 +247,10 @@
io.WriteString(p, v.Type().String())
io.WriteString(p, ")(nil)")
} else {
- writeByte(p, '&')
- p.printValue(e, true, true)
+ pp := *p
+ pp.depth++
+ writeByte(pp, '&')
+ pp.printValue(e, true, true)
}
case reflect.Chan:
x := v.Pointer()
diff --git a/formatter_test.go b/formatter_test.go
index ec9bbca..4342f1b 100644
--- a/formatter_test.go
+++ b/formatter_test.go
@@ -3,6 +3,7 @@
import (
"fmt"
"io"
+ "strings"
"testing"
"unsafe"
)
@@ -146,3 +147,113 @@
}
}
}
+
+type I struct {
+ i int
+ R interface{}
+}
+
+func (i *I) I() *I { return i.R.(*I) }
+
+func TestCycle(t *testing.T) {
+ type A struct{ *A }
+ v := &A{}
+ v.A = v
+
+ // panics from stack overflow without cycle detection
+ t.Logf("Example cycle:\n%# v", Formatter(v))
+
+ p := &A{}
+ s := fmt.Sprintf("%# v", Formatter([]*A{p, p}))
+ if strings.Contains(s, "CYCLIC") {
+ t.Errorf("Repeated address detected as cyclic reference:\n%s", s)
+ }
+
+ type R struct {
+ i int
+ *R
+ }
+ r := &R{
+ i: 1,
+ R: &R{
+ i: 2,
+ R: &R{
+ i: 3,
+ },
+ },
+ }
+ r.R.R.R = r
+ t.Logf("Example longer cycle:\n%# v", Formatter(r))
+
+ r = &R{
+ i: 1,
+ R: &R{
+ i: 2,
+ R: &R{
+ i: 3,
+ R: &R{
+ i: 4,
+ R: &R{
+ i: 5,
+ R: &R{
+ i: 6,
+ R: &R{
+ i: 7,
+ R: &R{
+ i: 8,
+ R: &R{
+ i: 9,
+ R: &R{
+ i: 10,
+ R: &R{
+ i: 11,
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ }
+ // here be pirates
+ r.R.R.R.R.R.R.R.R.R.R.R = r
+ t.Logf("Example long interface cycle:\n%# v", Formatter(r))
+
+ i := &I{
+ i: 1,
+ R: &I{
+ i: 2,
+ R: &I{
+ i: 3,
+ R: &I{
+ i: 4,
+ R: &I{
+ i: 5,
+ R: &I{
+ i: 6,
+ R: &I{
+ i: 7,
+ R: &I{
+ i: 8,
+ R: &I{
+ i: 9,
+ R: &I{
+ i: 10,
+ R: &I{
+ i: 11,
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ },
+ }
+ t.Logf("Example very long cycle:\n%# v", Formatter(i))
+}