Add New, Print and Fprint; make Cause unwrap the full chain; fix tests and add coverage for Wrap and Fprint
diff --git a/errors.go b/errors.go
index 4d0d8fb..dcc1456 100644
--- a/errors.go
+++ b/errors.go
@@ -1,9 +1,28 @@
// Package errors implements functions to manipulate errors.
package errors
+import (
+ "fmt"
+ "io"
+ "os"
+ "runtime"
+)
+
+// New returns an error whose Error() is exactly text; '%' in text is not a verb.
+func New(text string) error {
+ return struct {
+ error
+ pc uintptr // PC of New's caller, from pc(); not read by any visible code
+ }{
+ fmt.Errorf("%s", text), // constant format so text is never parsed for verbs
+ pc(),
+ }
+}
+
type e struct {
cause error
message string
+ pc uintptr // program counter captured via pc() when the error is constructed
}
func (e *e) Error() string {
@@ -20,7 +39,15 @@
if cause == nil {
return nil
}
- return &e{cause: cause, message: message}
+ return &e{
+ cause: cause,
+ message: message,
+ pc: pc(),
+ }
+}
+
+type causer interface { // implemented by errors that report an underlying cause; used by Cause and Fprint
+ Cause() error
}
// Cause returns the underlying cause of the error, if possible.
@@ -35,14 +62,51 @@
// be returned. If the error is nil, nil will be returned without further
// investigation.
func Cause(err error) error {
- if err == nil {
- return nil
- }
- type causer interface {
- Cause() error
- }
- if err, ok := err.(causer); ok {
- return err.Cause()
+ for err != nil { // unwrap repeatedly, not just one level
+ cause, ok := err.(causer)
+ if !ok {
+ break // err has no Cause method: it is the innermost error
+ }
+ err = cause.Cause() // NOTE(review): a nil Cause() ends the loop and Cause returns nil, dropping the non-nil wrapper
}
return err
}
+
+type locationer interface { // implemented by errors that can report a file and line; checked by Fprint
+ Location() (string, int)
+}
+
+// Print writes err and its chain of causes to os.Stderr; it is Fprint(os.Stderr, err).
+func Print(err error) {
+ Fprint(os.Stderr, err)
+}
+
+// Fprint writes err to w, then each error in its cause chain, one per line.
+// Errors that implement Location are prefixed with "file:line: ".
+// If err is nil, nothing is printed. Write errors are ignored.
+func Fprint(w io.Writer, err error) {
+ for err != nil {
+ location, ok := err.(locationer)
+ if ok {
+ file, line := location.Location()
+ fmt.Fprintf(w, "%s:%d: ", file, line)
+ }
+ switch err := err.(type) {
+ case *e:
+ fmt.Fprintln(w, err.message) // message only: Error() would repeat the cause printed next
+ default:
+ fmt.Fprintln(w, err.Error())
+ }
+
+ cause, ok := err.(causer)
+ if !ok {
+ break
+ }
+ err = cause.Cause()
+ }
+}
+
+func pc() uintptr { // PC two frames up: skips pc itself and its direct caller (e.g. New)
+ pc, _, _, _ := runtime.Caller(2) // ok is discarded; the PC is then unreliable (likely 0)
+ return pc
+}
diff --git a/errors_test.go b/errors_test.go
index 6b6321e..2eb3de7 100644
--- a/errors_test.go
+++ b/errors_test.go
@@ -1,11 +1,56 @@
package errors
import (
+ "bytes"
+ "fmt"
"io"
"reflect"
"testing"
)
+func TestNew(t *testing.T) {
+ tests := []struct {
+ text string
+ want error
+ }{
+ {"", fmt.Errorf("")},
+ {"foo", fmt.Errorf("foo")},
+ {"foo", New("foo")},
+ }
+
+ for _, tt := range tests {
+ got := New(tt.text)
+ if got.Error() != tt.want.Error() {
+ t.Errorf("New(%q).Error(): got %q, want %q", tt.text, got, tt.want)
+ }
+ }
+}
+
+func TestWrapNil(t *testing.T) {
+ // Wrapping a nil error must give back nil rather than a non-nil wrapper.
+ if got := Wrap(nil, "no error"); got != nil {
+ t.Errorf("Wrap(nil, \"no error\"): got %#v, expected nil", got)
+ }
+}
+
+func TestWrap(t *testing.T) {
+ cases := []struct {
+ cause error
+ msg string
+ want string
+ }{
+ {cause: io.EOF, msg: "read error", want: "read error: EOF"},
+ {cause: Wrap(io.EOF, "read error"), msg: "client error", want: "client error: read error: EOF"},
+ }
+
+ // Wrap each cause with msg and compare the full Error() text.
+ for _, c := range cases {
+ if got := Wrap(c.cause, c.msg).Error(); got != c.want {
+ t.Errorf("Wrap(%v, %q): got: %v, want %v", c.cause, c.msg, got, c.want)
+ }
+ }
+}
+
type nilError struct{}
func (nilError) Error() string { return "nil error" }
@@ -18,6 +63,7 @@
func (e *causeError) Cause() error { return e.cause }
func TestCause(t *testing.T) {
+ x := New("error")
tests := []struct {
err error
want error
@@ -41,6 +87,9 @@
// caused error returns cause
err: &causeError{cause: io.EOF},
want: io.EOF,
+ }, {
+ err: x, // return from errors.New
+ want: x,
}}
for i, tt := range tests {
@@ -50,3 +99,43 @@
}
}
}
+
+func TestFprint(t *testing.T) {
+ x := New("error")
+ tests := []struct {
+ err error
+ want string
+ }{{
+ // nil error prints nothing
+ err: nil,
+ }, {
+ // nil converted to error explicitly; also prints nothing
+ err: (error)(nil),
+ }, {
+ // uncaused error prints just its own message
+ err: io.EOF,
+ want: "EOF\n",
+ }, {
+ // caused error prints itself, then its cause
+ err: &causeError{cause: io.EOF},
+ want: "cause error\nEOF\n",
+ }, {
+ err: x, // return from errors.New
+ want: "error\n",
+ }, {
+ err: Wrap(x, "message"),
+ want: "message\nerror\n",
+ }, {
+ err: Wrap(Wrap(x, "message"), "another message"),
+ want: "another message\nmessage\nerror\n",
+ }}
+
+ for i, tt := range tests {
+ var w bytes.Buffer
+ Fprint(&w, tt.err)
+ got := w.String()
+ if got != tt.want {
+ t.Errorf("test %d: Fprint(w, %v): got %q, want %q", i+1, tt.err, got, tt.want)
+ }
+ }
+}