Fix tests; add coverage for Wrap and Fprint
diff --git a/errors.go b/errors.go
index 8250584..dcc1456 100644
--- a/errors.go
+++ b/errors.go
@@ -3,107 +3,110 @@
import (
"fmt"
+ "io"
+ "os"
"runtime"
)
// New returns an error that formats as the given text.
-func New(message string) error {
- pc, _, _, _ := runtime.Caller(1) // the caller of New
+// The returned value also records the program counter of New's caller.
+func New(text string) error {
return struct {
error
pc uintptr
}{
- fmt.Errorf(message),
- pc,
+ // Pass text as an argument rather than as the format string so that
+ // any '%' it contains is reproduced literally instead of parsed as a verb.
+ fmt.Errorf("%s", text),
+ pc(),
}
}
-// Errorf returns a formatted error.
-func Errorf(format string, args ...interface{}) error {
- pc, _, _, _ := runtime.Caller(1) // the caller of Errorf
- return struct {
- error
- pc uintptr
- }{
- fmt.Errorf(format, args...),
- pc,
+// e is the error type returned by Wrap. It pairs a message with the
+// error that caused it, plus the program counter of Wrap's caller.
+type e struct {
+ cause error
+ message string
+ pc uintptr
+}
+
+// Error returns the message, then ": ", then the text of the cause.
+// The receiver is named w so that it does not shadow the type name e.
+func (w *e) Error() string {
+ return w.message + ": " + w.cause.Error()
+}
+
+// Cause returns the error that w wraps.
+func (w *e) Cause() error {
+ return w.cause
+}
+
+// Wrap returns an error annotating the cause with message.
+// The returned error's Cause method returns cause.
+// If cause is nil, Wrap returns nil.
+func Wrap(cause error, message string) error {
+ if cause == nil {
+ return nil
}
+ return &e{
+ cause: cause,
+ message: message,
+ pc: pc(),
+ }
+}
+
+// causer is implemented by errors that expose an underlying cause;
+// Cause and Fprint use it to walk an error chain.
+type causer interface {
+ Cause() error
}
// Cause returns the underlying cause of the error, if possible.
// An error value has a cause if it implements the following
-// method:
+// interface:
//
-// Cause() error
-//
+// type causer interface {
+// Cause() error
+// }
//
// If the error does not implement Cause, the original error will
// be returned. If the error is nil, nil will be returned without further
// investigation.
+//
+// Unwrapping is repeated: Cause keeps following Cause methods until it
+// reaches an error that has no Cause method, or nil, and returns that value.
func Cause(err error) error {
- if err == nil {
- return nil
- }
- type causer interface {
- Cause() error
- }
- if err, ok := err.(causer); ok {
- return err.Cause()
+ for {
+ // A comma-ok type assertion on a nil interface reports false, so a nil
+ // err (initial, or returned by a Cause method) also ends the loop.
+ next, hasCause := err.(causer)
+ if !hasCause {
+ break
+ }
+ err = next.Cause()
}
return err
}
-func underlying(err error) (error, bool) {
- if err == nil {
- return nil, false
- }
- type underlying interface {
- underlying() error
- }
- if err, ok := err.(underlying); ok {
- return err.underlying(), true
- }
- return nil, false
+// locationer is implemented by errors that can report the source file and
+// line associated with them. Fprint checks for it and, when present, writes
+// that location before the error's message.
+type locationer interface {
+ Location() (file string, line int)
}
-type traced struct {
- error // underlying error
- pc uintptr
+// Print writes err to os.Stderr in exactly the format produced by Fprint:
+// one line per error in its chain of causes. A nil err produces no output.
+func Print(err error) {
+ Fprint(os.Stderr, err)
}
-func (t *traced) underlying() error { return t.error }
+// Fprint prints the error to the supplied writer.
+// The format of the output is the same as Print.
+// If err is nil, nothing is printed.
+//
+// Each error in the chain of causes is written on its own line, outermost
+// first. An error that implements locationer is prefixed with "file:line: ".
+// An error created by Wrap contributes only its own message, because its
+// cause is printed on the following line.
+func Fprint(w io.Writer, err error) {
+ for err != nil {
+ if loc, ok := err.(locationer); ok {
+ file, line := loc.Location()
+ // Fprintf (not Fprint) so the %s/%d verbs are actually formatted.
+ fmt.Fprintf(w, "%s:%d: ", file, line)
+ }
+ switch v := err.(type) {
+ case *e:
+ fmt.Fprintln(w, v.message)
+ default:
+ fmt.Fprintln(w, v.Error())
+ }
-// Trace adds caller information to the error.
-// If error is nil, nil will be returned.
-func Trace(err error) error {
- if err == nil {
- return nil
- }
- pc, _, _, _ := runtime.Caller(1) // the caller of Trace
- return traced{
- error: err,
- pc: pc,
+ cause, ok := err.(causer)
+ if !ok {
+ break
+ }
+ err = cause.Cause()
}
}
-type annotated struct {
- error // underlying error
- pc uintptr
-}
-
-func (a *annotated) Cause() error { return a.error }
-
-// Annotate returns a new error annotating the error provided
-// with the message, and the location of the caller of Annotate.
-// The underlying error can be recovered by calling Cause.
-// If err is nil, nil will be returned.
-func Annotate(err error, message string) error {
- if err == nil {
- return nil
- }
- pc, _, _, _ := runtime.Caller(1) // the caller of Annotate
- return annotated{
- error: err,
- pc: pc,
- }
+// pc returns the program counter of the function that called pc's caller.
+// New and Wrap use it so that the recorded location is that of their caller.
+func pc() uintptr {
+ const skip = 2 // 0 is pc itself, 1 is its caller (e.g. New or Wrap), 2 is that function's caller
+ callerPC, _, _, _ := runtime.Caller(skip)
+ return callerPC
}
diff --git a/errors_test.go b/errors_test.go
index c284106..2eb3de7 100644
--- a/errors_test.go
+++ b/errors_test.go
@@ -1,13 +1,14 @@
package errors
import (
+ "bytes"
"fmt"
"io"
"reflect"
"testing"
)
-func TestNewError(t *testing.T) {
+func TestNew(t *testing.T) {
tests := []struct {
err string
want error
@@ -25,23 +26,28 @@
}
}
-func TestNewEqualNew(t *testing.T) {
- // test that two calls to New return the same error when called from the same location
- var errs []error
- for i := 0; i < 2; i++ {
- errs = append(errs, New("error"))
- }
- a, b := errs[0], errs[1]
- if !reflect.DeepEqual(a, b) {
- t.Errorf("Expected two calls to New from the same location to give the same error: %#v, %#v", a, b)
+// TestWrapNil checks that Wrap passes a nil cause through as nil.
+func TestWrapNil(t *testing.T) {
+ if got := Wrap(nil, "no error"); got != nil {
+ t.Errorf("Wrap(nil, \"no error\"): got %#v, expected nil", got)
}
}
-func TestNewNotEqualNew(t *testing.T) {
- // test that two calls to New return different errors when called from different locations
- a, b := New("error"), New("error")
- if reflect.DeepEqual(a, b) {
- t.Errorf("Expected two calls to New from the different locations give the same error: %#v, %#v", a, b)
+// TestWrap checks the text produced by wrapping a plain error and by
+// wrapping an error that is itself already wrapped.
+func TestWrap(t *testing.T) {
+ cases := []struct {
+ err error
+ message string
+ want string
+ }{
+ {io.EOF, "read error", "read error: EOF"},
+ {Wrap(io.EOF, "read error"), "client error", "client error: read error: EOF"},
+ }
+
+ for _, tc := range cases {
+ if got := Wrap(tc.err, tc.message).Error(); got != tc.want {
+ t.Errorf("Wrap(%v, %q): got: %v, want %v", tc.err, tc.message, got, tc.want)
+ }
}
}
@@ -94,23 +100,42 @@
}
}
-func TestTraceNotEqual(t *testing.T) {
- // test that two calls to trace do not return identical errors
- err := New("error")
- a := err
- var errs []error
- for i := 0; i < 2; i++ {
- err = Trace(err)
- errs = append(errs, err)
- }
- b, c := errs[0], errs[1]
- if reflect.DeepEqual(a, b) {
- t.Errorf("a and b equal: %#v, %#v", a, b)
- }
- if reflect.DeepEqual(b, c) {
- t.Errorf("b and c equal: %#v, %#v", b, c)
- }
- if reflect.DeepEqual(a, c) {
- t.Errorf("a and c equal: %#v, %#v", a, c)
+// TestFprint checks that Fprint writes one line per error in the chain of
+// causes, outermost first, and writes nothing for a nil error.
+func TestFprint(t *testing.T) {
+ x := New("error")
+ tests := []struct {
+ err error
+ want string
+ }{{
+ // nil error prints nothing
+ err: nil,
+ }, {
+ // nil explicitly converted to error prints nothing
+ err: (error)(nil),
+ }, {
+ // an error without a cause prints only its own text
+ err: io.EOF,
+ want: "EOF\n",
+ }, {
+ // a causer prints its own text, then its cause
+ err: &causeError{cause: io.EOF},
+ want: "cause error\nEOF\n",
+ }, {
+ err: x, // error returned by New
+ want: "error\n",
+ }, {
+ err: Wrap(x, "message"),
+ want: "message\nerror\n",
+ }, {
+ err: Wrap(Wrap(x, "message"), "another message"),
+ want: "another message\nmessage\nerror\n",
+ }}
+
+ for i, tt := range tests {
+ var w bytes.Buffer
+ Fprint(&w, tt.err)
+ got := w.String()
+ if got != tt.want {
+ // %v rather than %q: %q renders a nil error as "%!q(<nil>)".
+ t.Errorf("test %d: Fprint(w, %v): got %q, want %q", i+1, tt.err, got, tt.want)
+ }
}
}