diff --git a/handler.go b/handler.go index 68d6e60..f9b8d66 100644 --- a/handler.go +++ b/handler.go @@ -1,13 +1,44 @@ package ezhandler -import "net/http" +import ( + "io" + "net/http" +) +// Handler is similar to [http.Handler], but also may return an error. type Handler interface { ServeHTTP(w http.ResponseWriter, r *http.Request) error } +// HandlerFunc is similar to [http.HandlerFunc] but it can also return an error. type HandlerFunc func(w http.ResponseWriter, r *http.Request) error +var _ Handler = HandlerFunc(nil) + func (fn HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) error { return fn(w, r) } + +// ResponseHandler is similar to HandlerFunc but returns a [ResponseHelper], instead of passing an [http.ResponseWriter]. +type ResponseHandler func(r *http.Request) (resp ResponseHelper, err error) + +var _ Handler = ResponseHandler(nil) + +func (fn ResponseHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) error { + resp, err := fn(r) + if err != nil { + return err + } + w.WriteHeader(resp.Status()) + for key, values := range resp.Headers() { + for _, val := range values { + w.Header().Add(key, val) + } + } + body, err := resp.Body() + if err != nil { + return err + } + _, err = io.Copy(w, body) + return err +} diff --git a/helper.go b/helper.go index 9510f34..6595062 100644 --- a/helper.go +++ b/helper.go @@ -34,3 +34,9 @@ func (help *Helper) HandlerFunc(hnd HandlerFunc) http.Handler { } }) } + +// ResponderHandler returns an [http.Handler] for the provided [ResponderHandler]. +// If hnd returns an error, an appropriate error response is written using the ErrorHandler. +func (help *Helper) ResponderHandler(hnd ResponseHandler) http.Handler { + return help.Handler(hnd) +} diff --git a/helper_test.go b/helper_test.go index a2fee51..6c2a3ad 100644 --- a/helper_test.go +++ b/helper_test.go @@ -44,6 +44,14 @@ func runHelperTest(t *testing.T, name string, handlerErr, expectedErr error, exp return handlerErr }) wrappedHandler = helper.HandlerFunc(mockHandlerFunc) + case "ResponderHandler": + mockResponseHandler := ezhandler.ResponseHandler(func(r *http.Request) (ezhandler.ResponseHelper, error) { + if handlerErr != nil { + return nil, handlerErr + } + return ezhandler.JSONResponse(map[string]string{"status": "ok"}), nil + }) + wrappedHandler = helper.ResponderHandler(mockResponseHandler) } wrappedHandler.ServeHTTP(rec, req) @@ -61,3 +69,8 @@ func TestHelper_HandlerFunc(t *testing.T) { runHelperTest(t, "no error", nil, nil, false, "HandlerFunc") runHelperTest(t, "with error", errTest, errTest, true, "HandlerFunc") } + +func TestHelper_ResponderHandler(t *testing.T) { + runHelperTest(t, "no error", nil, nil, false, "ResponderHandler") + runHelperTest(t, "with error", errTest, errTest, true, "ResponderHandler") +} diff --git a/response.go b/response.go new file mode 100644 index 0000000..b101163 --- /dev/null +++ b/response.go @@ -0,0 +1,56 @@ +package ezhandler + +import ( + "bytes" + "encoding/json" + "io" + "net/http" +) + +// ResponseHelper is a simpler way to return a response. +type ResponseHelper interface { + // Body returns the body of the response. It is written to an [http.ResponseWriter]. + Body() (io.Reader, error) + // Status should be a valid HTTP status. + Status() int + // Headers returns headers that should be added to the response. + Headers() http.Header +} + +type jsonResponse struct { + value any + status int +} + +var _ ResponseHelper = new(jsonResponse) + +func (j *jsonResponse) Body() (io.Reader, error) { + b, err := json.Marshal(j.value) + if err != nil { + return nil, err + } + return bytes.NewReader(b), nil +} + +func (j *jsonResponse) Status() int { + return j.status +} + +func (j *jsonResponse) Headers() http.Header { + header := make(http.Header) + header.Set("Content-Type", "application/json") + return header +} + +// JSONResponse returns a [ResponseHelper] that JSON encodes value with a 200 response. +func JSONResponse(value any) ResponseHelper { + return JSONResponseWithStatus(value, http.StatusOK) +} + +// JSONResponseWithStatus returns a [ResponseHelper] that JSON encodes value with the provided status. +func JSONResponseWithStatus(value any, status int) ResponseHelper { + return &jsonResponse{ + value: value, + status: status, + } +} diff --git a/response_test.go b/response_test.go new file mode 100644 index 0000000..53faccb --- /dev/null +++ b/response_test.go @@ -0,0 +1,163 @@ +package ezhandler_test + +import ( + "codeberg.org/danjones000/ezhandler" + "encoding/json" + "errors" + "github.com/stretchr/testify/assert" + + "net/http" + "net/http/httptest" + "testing" +) + +var errCircularMarshal = errors.New("json: unsupported value: encountered a cycle") + +func TestJSONResponse(t *testing.T) { + type testData struct { + Message string `json:"message"` + Code int `json:"code"` + } + + data := testData{Message: "Hello", Code: 200} + respHelper := ezhandler.JSONResponse(data) + + assert.Equal(t, http.StatusOK, respHelper.Status()) + + headers := respHelper.Headers() + assert.Equal(t, "application/json", headers.Get("Content-Type")) + + body, err := respHelper.Body() + assert.NoError(t, err) + assert.NotNil(t, body) + + var decodedData testData + err = json.NewDecoder(body).Decode(&decodedData) + assert.NoError(t, err) + assert.Equal(t, data, decodedData) +} + +func TestJSONResponseWithStatus(t *testing.T) { + type testData struct { + Key string `json:"key"` + } + + data := testData{Key: "value"} + customStatus := http.StatusCreated + respHelper := ezhandler.JSONResponseWithStatus(data, customStatus) + + assert.Equal(t, customStatus, respHelper.Status()) + + headers := respHelper.Headers() + assert.Equal(t, "application/json", headers.Get("Content-Type")) + + body, err := respHelper.Body() + assert.NoError(t, err) + assert.NotNil(t, body) + + var decodedData testData + err = json.NewDecoder(body).Decode(&decodedData) + assert.NoError(t, err) + assert.Equal(t, data, decodedData) +} + +func TestJSONResponse_BodyError(t *testing.T) { + // This type cannot be marshaled to JSON due to a circular reference + type Circular struct { + Self *Circular + } + data := Circular{} + data.Self = &data + + respHelper := ezhandler.JSONResponse(data) + + body, err := respHelper.Body() + assert.Error(t, err) + assert.Nil(t, body) +} + +func TestResponseHandler_ServeHTTP_JSONResponse(t *testing.T) { + type testData struct { + Message string `json:"message"` + } + + tests := []struct { + name string + handler ezhandler.ResponseHandler + expectedStatus int + expectedBody string + expectedError error + }{ + { + name: "successful JSON response", + handler: ezhandler.ResponseHandler(func(r *http.Request) (ezhandler.ResponseHelper, error) { + return ezhandler.JSONResponse(testData{Message: "success"}), nil + }), + expectedStatus: http.StatusOK, + expectedBody: "{\"message\":\"success\"}", + expectedError: nil, + }, + { + name: "JSON response with custom status", + handler: ezhandler.ResponseHandler(func(r *http.Request) (ezhandler.ResponseHelper, error) { + return ezhandler.JSONResponseWithStatus(testData{Message: "created"}, http.StatusCreated), nil + }), + expectedStatus: http.StatusCreated, + expectedBody: "{\"message\":\"created\"}", + expectedError: nil, + }, + { + name: "error from ResponseHandler", + handler: ezhandler.ResponseHandler(func(r *http.Request) (ezhandler.ResponseHelper, error) { + return nil, errTest // Using errTest from helper_test.go + }), + expectedStatus: http.StatusOK, // Status won't be set if handler returns error before writing header + expectedBody: "", + expectedError: errTest, + }, + { + name: "error from Body() method", + handler: ezhandler.ResponseHandler(func(r *http.Request) (ezhandler.ResponseHelper, error) { + type Circular struct { + Self *Circular + } + data := Circular{} + data.Self = &data + return ezhandler.JSONResponse(data), nil + }), + expectedStatus: http.StatusOK, // Status won't be set if Body() returns error + expectedBody: "", + expectedError: errCircularMarshal, // The error is returned by ServeHTTP, not the handler func + }, + } + + for _, tt := range tests { + runResponseHandlerTest(t, tt) + } +} + +func runResponseHandlerTest(t *testing.T, tt struct { + name string + handler ezhandler.ResponseHandler + expectedStatus int + expectedBody string + expectedError error +}) { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + rec := httptest.NewRecorder() + + err := tt.handler.ServeHTTP(rec, req) + + if tt.expectedError != nil { + assert.ErrorContains(t, err, tt.expectedError.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedStatus, rec.Code) + assert.Equal(t, tt.expectedBody, rec.Body.String()) + if tt.expectedBody != "" { + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + } + } + }) +}