summaryrefslogtreecommitdiff
path: root/pkg/errorhandling
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/errorhandling')
-rw-r--r--pkg/errorhandling/errorhandling.go21
-rw-r--r--pkg/errorhandling/errorhandling_test.go53
2 files changed, 73 insertions, 1 deletions
diff --git a/pkg/errorhandling/errorhandling.go b/pkg/errorhandling/errorhandling.go
index fc6772c08..9b456c9c0 100644
--- a/pkg/errorhandling/errorhandling.go
+++ b/pkg/errorhandling/errorhandling.go
@@ -1,11 +1,11 @@
package errorhandling
import (
+ "errors"
"os"
"strings"
"github.com/hashicorp/go-multierror"
- "github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
@@ -121,3 +121,22 @@ func (e PodConflictErrorModel) Error() string {
func (e PodConflictErrorModel) Code() int {
return 409
}
+
+// Cause returns the most underlying error for the provided one. There is a
+// maximum error depth of 100 to avoid endless loops. An additional error log
+// message will be created if this maximum has reached.
+func Cause(err error) (cause error) {
+ cause = err
+
+ const maxDepth = 100
+ for i := 0; i <= maxDepth; i++ {
+ res := errors.Unwrap(cause)
+ if res == nil {
+ return cause
+ }
+ cause = res
+ }
+
+ logrus.Errorf("Max error depth of %d reached, cannot unwrap until root cause: %v", maxDepth, err)
+ return cause
+}
diff --git a/pkg/errorhandling/errorhandling_test.go b/pkg/errorhandling/errorhandling_test.go
new file mode 100644
index 000000000..ec720c5e7
--- /dev/null
+++ b/pkg/errorhandling/errorhandling_test.go
@@ -0,0 +1,53 @@
+package errorhandling
+
+import (
+ "errors"
+ "fmt"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestCause(t *testing.T) {
+ t.Parallel()
+
+ for _, tc := range []struct {
+ name string
+ err func() error
+ expectedErr error
+ }{
+ {
+ name: "nil error",
+ err: func() error { return nil },
+ expectedErr: nil,
+ },
+ {
+ name: "equal errors",
+ err: func() error { return errors.New("foo") },
+ expectedErr: errors.New("foo"),
+ },
+ {
+ name: "wrapped error",
+ err: func() error { return fmt.Errorf("baz: %w", fmt.Errorf("bar: %w", errors.New("foo"))) },
+ expectedErr: errors.New("foo"),
+ },
+ {
+ name: "max depth reached",
+ err: func() error {
+ err := errors.New("error")
+ for i := 0; i <= 101; i++ {
+ err = fmt.Errorf("%d: %w", i, err)
+ }
+ return err
+ },
+ expectedErr: fmt.Errorf("0: %w", errors.New("error")),
+ },
+ } {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ err := Cause(tc.err())
+ assert.Equal(t, tc.expectedErr, err)
+ })
+ }
+}