diff --git a/middleware.go b/middleware.go index 9492589..3f18098 100644 --- a/middleware.go +++ b/middleware.go @@ -7,7 +7,16 @@ import ( "github.com/gin-gonic/gin" ) -func ErrorMiddleware() gin.HandlerFunc { +func ErrorMiddleware(opts ...Option) gin.HandlerFunc { + conf := config{} + for _, opt := range opts { + conf = opt(conf) + } + + if len(conf.transformers) == 0 { + conf.transformers = []Transformer{ginTransformer} + } + return func(c *gin.Context) { c.Next() err := c.Errors.Last() @@ -16,24 +25,27 @@ func ErrorMiddleware() gin.HandlerFunc { } var re rErrors.ResponsableError + errors.As(err, &re) + // If we have at least one that's a ResponsableError, we should use it for _, err = range c.Errors { - errors.As(err, &re) if re != nil { break } + errors.As(err, &re) } - // @todo we need to add some way to do custom handling - - // @todo Refactor this with 👆 - if re == nil { - switch err.Type { - case gin.ErrorTypePrivate: - re = rErrors.NewInternalError("%w", err) - default: - re = rErrors.NewBadRequest("%w", err) + // Next, let's check our transformers + for _, trans := range conf.transformers { + if re != nil { + break } + re = trans(err) + } + + // Still couldn't find one, so it's a 500 + if re == nil { + re = rErrors.NewInternalError("%w", err) } c.JSON(re.Status(), re) diff --git a/middleware_test.go b/middleware_test.go index 798747f..6a99f4d 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -43,6 +43,10 @@ func (s *MiddlewareTestSuite) do(err ...error) { func (s *MiddlewareTestSuite) doParse(err ...error) map[string]any { s.do(err...) + return s.parse() +} + +func (s *MiddlewareTestSuite) parse() map[string]any { var out map[string]any jsonErr := json.Unmarshal(s.w.Body.Bytes(), &out) s.Assert().Nil(jsonErr) @@ -98,3 +102,18 @@ func (s *MiddlewareTestSuite) TestOtherError() { s.Assert().Equal("Unknown Error", outMsg) s.Assert().Equal(http.StatusInternalServerError, s.w.Code) } + +func (s *MiddlewareTestSuite) TestNoWorkingTransformer() { + var noop Transformer = func(err error) rErrors.ResponsableError { + return nil + } + err := errors.New("Foo") + s.ctx.Error(err) + ErrorMiddleware(WithTransformer(noop))(s.ctx) + + out := s.parse() + outMsg, ok := out["error"].(string) + s.Assert().True(ok) + s.Assert().Equal("Unknown Error", outMsg) + s.Assert().Equal(http.StatusInternalServerError, s.w.Code) +} diff --git a/options.go b/options.go new file mode 100644 index 0000000..ba31f25 --- /dev/null +++ b/options.go @@ -0,0 +1,41 @@ +package handler + +import ( + "errors" + + rErrors "codeberg.org/danjones000/responsable-errors" + "github.com/gin-gonic/gin" +) + +type config struct { + transformers []Transformer +} + +type Transformer func(error) rErrors.ResponsableError + +type Option func(config) config + +func WithTransformer(tr Transformer) Option { + return func(c config) config { + c.transformers = append(c.transformers, tr) + return c + } +} + +func WithDefaultTransformer() Option { + return WithTransformer(ginTransformer) +} + +func ginTransformer(er error) rErrors.ResponsableError { + var err *gin.Error + if !errors.As(er, &err) { + return nil + } + + switch err.Type { + case gin.ErrorTypePrivate: + return rErrors.NewInternalError("%w", err) + default: + return rErrors.NewBadRequest("%w", err) + } +} diff --git a/options_test.go b/options_test.go new file mode 100644 index 0000000..86995fd --- /dev/null +++ b/options_test.go @@ -0,0 +1,38 @@ +package handler + +import ( + "errors" + "fmt" + "testing" + + rErrors "codeberg.org/danjones000/responsable-errors" + "github.com/stretchr/testify/assert" +) + +var someErr error = errors.New("I am a teapot") + +var noop Transformer = func(err error) rErrors.ResponsableError { + return nil +} + +func TestWithTrans(t *testing.T) { + c := config{} + c = WithTransformer(noop)(c) + assert.Len(t, c.transformers, 1) + exp := fmt.Sprintf("%p", noop) + fd := fmt.Sprintf("%p", c.transformers[0]) + assert.Equal(t, exp, fd) +} + +func TestWithDef(t *testing.T) { + c := config{} + c = WithDefaultTransformer()(c) + assert.Len(t, c.transformers, 1) + exp := fmt.Sprintf("%p", ginTransformer) + fd := fmt.Sprintf("%p", c.transformers[0]) + assert.Equal(t, exp, fd) +} + +func TestGinTransNotGinError(t *testing.T) { + assert.Nil(t, ginTransformer(someErr)) +}