Skip to content

Commit ee7a3ba

Browse files
authored
Merge pull request #9 from aserto-dev/wrapped_error
add WrappedError and logger extraction from ctx
2 parents 8e92a25 + 769fd0c commit ee7a3ba

4 files changed

Lines changed: 226 additions & 5 deletions

File tree

.gitignore

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
11
/cover.out
22
/tenant.db
3-
/.ext
3+
/.ext
4+
5+
# https://github.com/golang/go/issues/53502
6+
# go.work.sum is machine specific and should not be checked in
7+
# go.work.sum
8+
9+
.DS_Store

context.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package errors
2+
3+
import (
4+
"context"
5+
6+
"github.com/pkg/errors"
7+
)
8+
9+
// ContextError represents a standard error
10+
// that can also encapsulate a context.
11+
type ContextError struct {
12+
Err error
13+
Ctx context.Context
14+
}
15+
16+
func WithContext(err error, ctx context.Context) *ContextError {
17+
return &ContextError{
18+
Err: err,
19+
Ctx: ctx,
20+
}
21+
}
22+
23+
func WrapContext(err error, ctx context.Context, message string) *ContextError {
24+
return WithContext(errors.Wrap(err, message), ctx)
25+
}
26+
27+
func WrapfContext(err error, ctx context.Context, format string, args ...interface{}) *ContextError {
28+
return WithContext(errors.Wrapf(err, format, args...), ctx)
29+
}
30+
31+
func (ce *ContextError) Error() string {
32+
return ce.Err.Error()
33+
}
34+
35+
func (ce *ContextError) Cause() error {
36+
return errors.Cause(ce.Unwrap())
37+
}
38+
39+
func (ce *ContextError) Unwrap() error {
40+
return ce.Err
41+
}

errors.go

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package errors
22

33
import (
4+
"context"
45
"fmt"
56
"io"
67
"net/http"
@@ -204,7 +205,6 @@ func (e *AsertoError) Time(key string, value time.Time) *AsertoError {
204205
func (e *AsertoError) FromReader(key string, value io.Reader) *AsertoError {
205206
buf := &strings.Builder{}
206207
_, err := io.Copy(buf, value)
207-
208208
if err != nil {
209209
return e.Err(err)
210210
}
@@ -252,7 +252,6 @@ func (e *AsertoError) GRPCStatus() *status.Status {
252252
Metadata: e.Data(),
253253
Domain: e.Code,
254254
})
255-
256255
if err != nil {
257256
return status.New(codes.Internal, "internal failure setting up error details, please contact the administrator")
258257
}
@@ -272,6 +271,10 @@ func (e *AsertoError) WithHTTPStatus(httpStatus int) *AsertoError {
272271
return c
273272
}
274273

274+
func (e *AsertoError) Ctx(ctx context.Context) error {
275+
return WithContext(e, ctx)
276+
}
277+
275278
// Returns an Aserto error based on a given grpcStatus. The details that are not of type errdetails.ErrorInfo are dropped.
276279
// and if there are details from multiple errors, the aserto error will be constructed based on the first one.
277280
func FromGRPCStatus(grpcStatus status.Status) *AsertoError {
@@ -297,6 +300,33 @@ func FromGRPCStatus(grpcStatus status.Status) *AsertoError {
297300
return result
298301
}
299302

303+
/**
304+
* Retrieves the most inner logger associated with an error.
305+
*/
306+
func Logger(err error) *zerolog.Logger {
307+
var logger *zerolog.Logger
308+
var ce *ContextError
309+
310+
if err == nil {
311+
return logger
312+
}
313+
314+
for {
315+
if errors.As(err, &ce) {
316+
if ctxLogger := extractLogger(ce.Ctx); ctxLogger != nil {
317+
logger = ctxLogger
318+
}
319+
}
320+
321+
err = errors.Unwrap(err)
322+
if err == nil {
323+
break
324+
}
325+
}
326+
327+
return logger
328+
}
329+
300330
func UnwrapAsertoError(err error) *AsertoError {
301331
if err == nil {
302332
return nil
@@ -351,3 +381,19 @@ func Equals(err1, err2 error) bool {
351381
func CodeToAsertoError(code string) *AsertoError {
352382
return asertoErrors[code]
353383
}
384+
385+
/**
386+
* Retrieve the logger associated with the context using zerolog.Ctx(ctx).
387+
* If the retrieved logger is either the default context logger or has a disabled level, it returns nil.
388+
*/
389+
func extractLogger(ctx context.Context) *zerolog.Logger {
390+
if ctx == nil {
391+
return nil
392+
}
393+
logger := zerolog.Ctx(ctx)
394+
if logger == zerolog.DefaultContextLogger || logger.GetLevel() == zerolog.Disabled {
395+
logger = nil
396+
}
397+
398+
return logger
399+
}

errors_test.go

Lines changed: 130 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package errors_test
22

33
import (
4+
"context"
45
"net/http"
6+
"os"
57
"testing"
68

79
"github.com/pkg/errors"
10+
"github.com/rs/zerolog"
811

912
cerr "github.com/aserto-dev/errors"
1013
"github.com/stretchr/testify/require"
@@ -98,7 +101,6 @@ func TestFromGRPCStatus(t *testing.T) {
98101
Metadata: initialErr.Data(),
99102
Domain: initialErr.Code,
100103
})
101-
102104
if err != nil {
103105
assert.Fail(err.Error())
104106
}
@@ -126,7 +128,6 @@ func TestEquals(t *testing.T) {
126128
err2 := ErrAlreadyExists.Msgf("error 2").Str("key2", "val2").Err(errors.New("zoom"))
127129

128130
assert.True(cerr.Equals(err1, err2))
129-
130131
}
131132

132133
func TestEqualsNil(t *testing.T) {
@@ -198,3 +199,130 @@ func TestWithHttpError(t *testing.T) {
198199
unAerr := cerr.UnwrapAsertoError(aerr)
199200
assert.Equal(http.StatusNotAcceptable, unAerr.HTTPCode)
200201
}
202+
203+
// returns nil logger if error is nil.
204+
func TestLoggerWithNilError(t *testing.T) {
205+
assert := require.New(t)
206+
207+
var err error
208+
logger := cerr.Logger(err)
209+
assert.Nil(logger)
210+
}
211+
212+
func TestLoggerWithWrappedNilError(t *testing.T) {
213+
assert := require.New(t)
214+
215+
var err error
216+
ctx := context.Background()
217+
218+
logger := cerr.Logger(cerr.WithContext(err, ctx))
219+
assert.Nil(logger)
220+
}
221+
222+
func TestLoggerWithWrappedErrorsWithEmptyContext(t *testing.T) {
223+
assert := require.New(t)
224+
225+
ctx := context.Background()
226+
err := cerr.WithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx)
227+
wrappedErr := errors.Wrap(err, "wrapped error")
228+
229+
logger := cerr.Logger(wrappedErr)
230+
assert.Nil(logger)
231+
}
232+
233+
func TestLoggerWithWrappedErrorsWithLoggerContext(t *testing.T) {
234+
assert := require.New(t)
235+
initialLogger := zerolog.New(os.Stderr)
236+
237+
ctx := context.Background()
238+
ctx = initialLogger.WithContext(ctx)
239+
err := cerr.WithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx)
240+
wrappedErr := errors.Wrap(err, "wrapped error")
241+
242+
logger := cerr.Logger(wrappedErr)
243+
assert.NotNil(logger)
244+
assert.Equal(logger, zerolog.Ctx(ctx))
245+
}
246+
247+
func TestLoggerWithWrappedMultipleWithoutErrorsWithContext(t *testing.T) {
248+
assert := require.New(t)
249+
initialLogger := zerolog.New(os.Stderr)
250+
251+
ctx := context.Background()
252+
ctx = initialLogger.WithContext(ctx)
253+
err := cerr.WithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx)
254+
errWithoutCtx := cerr.NewAsertoError("E00002", codes.Internal, http.StatusInternalServerError, "internal error")
255+
wrappedErr := errWithoutCtx.Err(errors.Wrap(err, "wrapped error"))
256+
257+
logger := cerr.Logger(wrappedErr)
258+
assert.NotNil(logger)
259+
assert.Equal(logger, zerolog.Ctx(ctx))
260+
}
261+
262+
func TestLoggerWithWrappedMultipleErrorsWithContext(t *testing.T) {
263+
assert := require.New(t)
264+
initialLogger := zerolog.New(os.Stderr)
265+
266+
ctx := context.Background()
267+
ctx = initialLogger.WithContext(ctx)
268+
err := cerr.WithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx)
269+
errWithoutCtx := cerr.NewAsertoError("E00002", codes.Internal, http.StatusInternalServerError, "internal error")
270+
wrappedErr := errors.Wrap(errWithoutCtx.Err(err), "wrapped error")
271+
272+
logger := cerr.Logger(wrappedErr)
273+
assert.NotNil(logger)
274+
assert.Equal(logger, zerolog.Ctx(ctx))
275+
}
276+
277+
func TestLoggerWithWrappedMultipleErrorsWithMultipleContexts(t *testing.T) {
278+
assert := require.New(t)
279+
initialLogger := zerolog.New(os.Stderr)
280+
ctx1 := context.Background()
281+
ctx2 := initialLogger.WithContext(ctx1)
282+
err := cerr.WithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx1)
283+
wrappedErr := cerr.WithContext(cerr.WithContext(err, ctx2), ctx1)
284+
285+
logger := cerr.Logger(wrappedErr)
286+
ctx1Logger := zerolog.Ctx(ctx1)
287+
ctx2Logger := zerolog.Ctx(ctx2)
288+
289+
assert.NotNil(logger)
290+
assert.NotEqual(logger, ctx1Logger)
291+
assert.Equal(logger, ctx2Logger)
292+
}
293+
294+
func TestLoggerWithWrappedMultipleErrorsWithMultipleContextsOuter(t *testing.T) {
295+
assert := require.New(t)
296+
initialLogger := zerolog.New(os.Stderr)
297+
ctx1 := context.Background()
298+
ctx2 := initialLogger.WithContext(ctx1)
299+
err := cerr.WithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx1)
300+
err2 := cerr.WithContext(cerr.NewAsertoError("E00002", codes.Internal, http.StatusInternalServerError, "internal error"), ctx2)
301+
wrappedErr := errors.Wrap(errors.Wrap(err2, err.Error()), "wrapped error")
302+
303+
logger := cerr.Logger(wrappedErr)
304+
ctx1Logger := zerolog.Ctx(ctx1)
305+
ctx2Logger := zerolog.Ctx(ctx2)
306+
307+
assert.NotNil(logger)
308+
assert.NotEqual(logger, ctx1Logger)
309+
assert.Equal(logger, ctx2Logger)
310+
}
311+
312+
func TestLoggerWithWrappedMultipleAsertoErrorsWithMultipleContextsOuter(t *testing.T) {
313+
assert := require.New(t)
314+
initialLogger := zerolog.New(os.Stderr)
315+
ctx1 := context.Background()
316+
ctx2 := initialLogger.WithContext(ctx1)
317+
err := cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error").Ctx(ctx1)
318+
err2 := cerr.NewAsertoError("E00002", codes.Internal, http.StatusInternalServerError, "internal error").Ctx(ctx2)
319+
wrappedErr := errors.Wrap(errors.Wrap(err2, err.Error()), "wrapped error")
320+
321+
logger := cerr.Logger(wrappedErr)
322+
ctx1Logger := zerolog.Ctx(ctx1)
323+
ctx2Logger := zerolog.Ctx(ctx2)
324+
325+
assert.NotNil(logger)
326+
assert.NotEqual(logger, ctx1Logger)
327+
assert.Equal(logger, ctx2Logger)
328+
}

0 commit comments

Comments
 (0)