Support Go 1.13 error chains in `Cause` (#215)
diff --git a/cause.go b/cause.go
new file mode 100644
index 0000000..566f88b
--- /dev/null
+++ b/cause.go
@@ -0,0 +1,29 @@
+// +build !go1.13
+
+package errors
+
+// Cause recursively unwraps an error chain and returns the underlying cause of
+// the error, if possible. An error value has a cause if it implements the
+// following interface:
+//
+// type causer interface {
+// Cause() error
+// }
+//
+// If the error does not implement Cause, the original error will
+// be returned. If the error is nil, nil will be returned without further
+// investigation.
+func Cause(err error) error {
+ type causer interface {
+ Cause() error
+ }
+
+ for err != nil {
+ cause, ok := err.(causer)
+ if !ok {
+ break
+ }
+ err = cause.Cause()
+ }
+ return err
+}
diff --git a/errors.go b/errors.go
index 161aea2..a9840ec 100644
--- a/errors.go
+++ b/errors.go
@@ -260,29 +260,3 @@
io.WriteString(s, w.Error())
}
}
-
-// Cause returns the underlying cause of the error, if possible.
-// An error value has a cause if it implements the following
-// interface:
-//
-// type causer interface {
-// Cause() error
-// }
-//
-// If the error does not implement Cause, the original error will
-// be returned. If the error is nil, nil will be returned without further
-// investigation.
-func Cause(err error) error {
- type causer interface {
- Cause() error
- }
-
- for err != nil {
- cause, ok := err.(causer)
- if !ok {
- break
- }
- err = cause.Cause()
- }
- return err
-}
diff --git a/go113.go b/go113.go
index be0d10d..ed0dc7a 100644
--- a/go113.go
+++ b/go113.go
@@ -36,3 +36,36 @@
func Unwrap(err error) error {
return stderrors.Unwrap(err)
}
+
+// Cause recursively unwraps an error chain and returns the underlying cause of
+// the error, if possible. There are two ways that an error value may provide a
+// cause. First, the error may implement the following interface:
+//
+// type causer interface {
+// Cause() error
+// }
+//
+// Second, the error may return a non-nil value when passed as an argument to
+// the Unwrap function. This makes Cause forwards-compatible with Go 1.13 error
+// chains.
+//
+// If an error value satisfies both methods of unwrapping, Cause will use the
+// causer interface.
+//
+// If the error is nil, nil will be returned without further investigation.
+func Cause(err error) error {
+ type causer interface {
+ Cause() error
+ }
+
+ for err != nil {
+ if cause, ok := err.(causer); ok {
+ err = cause.Cause()
+ } else if unwrapped := Unwrap(err); unwrapped != nil {
+ err = unwrapped
+ } else {
+ break
+ }
+ }
+ return err
+}
diff --git a/go113_test.go b/go113_test.go
index 4ea37e6..7da3788 100644
--- a/go113_test.go
+++ b/go113_test.go
@@ -9,7 +9,29 @@
"testing"
)
-func TestErrorChainCompat(t *testing.T) {
+func TestCauseErrorChainCompat(t *testing.T) {
+ err := stderrors.New("the cause!")
+
+ // Wrap error using the standard library
+ wrapped := fmt.Errorf("wrapped with stdlib: %w", err)
+ if Cause(wrapped) != err {
+ t.Errorf("Cause does not support Go 1.13 error chains")
+ }
+
+ // Wrap in another layer using pkg/errors
+ wrapped = WithMessage(wrapped, "wrapped with pkg/errors")
+ if Cause(wrapped) != err {
+ t.Errorf("Cause does not support Go 1.13 error chains")
+ }
+
+ // Wrap in another layer using the standard library
+ wrapped = fmt.Errorf("wrapped with stdlib: %w", wrapped)
+ if Cause(wrapped) != err {
+ t.Errorf("Cause does not support Go 1.13 error chains")
+ }
+}
+
+func TestWrapErrorChainCompat(t *testing.T) {
err := stderrors.New("error that gets wrapped")
wrapped := Wrap(err, "wrapped up")
if !stderrors.Is(wrapped, err) {