From 469bd759995520e2bb7440dc7bbd8ffd7fc554b2 Mon Sep 17 00:00:00 2001 From: tobi Date: Mon, 3 Mar 2025 14:30:23 +0100 Subject: [PATCH] move + tweak handlers a bit --- cmd/gotosocial/action/server/server.go | 10 +- internal/api/util/auth.go | 54 --------- internal/oauth/handlers/handlers.go | 153 +++++++++++++++++++++++++ internal/oauth/server.go | 145 +++++++++-------------- testrig/oauthserver.go | 9 +- 5 files changed, 221 insertions(+), 150 deletions(-) create mode 100644 internal/oauth/handlers/handlers.go diff --git a/cmd/gotosocial/action/server/server.go b/cmd/gotosocial/action/server/server.go index 0c011510b..3c37c6ff6 100644 --- a/cmd/gotosocial/action/server/server.go +++ b/cmd/gotosocial/action/server/server.go @@ -52,6 +52,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/middleware" "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/oauth/handlers" "github.com/superseriousbusiness/gotosocial/internal/observability" "github.com/superseriousbusiness/gotosocial/internal/oidc" "github.com/superseriousbusiness/gotosocial/internal/processing" @@ -260,7 +261,14 @@ var Start action.GTSAction = func(ctx context.Context) error { // Build handlers used in later initializations. mediaManager := media.NewManager(state) - oauthServer := oauth.New(ctx, state, apiutil.GetClientScopeHandler(ctx, state)) + oauthServer := oauth.New(ctx, state, + handlers.GetValidateURIHandler(ctx), + handlers.GetClientScopeHandler(ctx, state), + handlers.GetAuthorizeScopeHandler(), + handlers.GetInternalErrorHandler(ctx), + handlers.GetResponseErrorHandler(ctx), + handlers.GetUserAuthorizationHandler(), + ) typeConverter := typeutils.NewConverter(state) visFilter := visibility.NewFilter(state) intFilter := interaction.NewFilter(state) diff --git a/internal/api/util/auth.go b/internal/api/util/auth.go index b56827998..fccdf38e1 100644 --- a/internal/api/util/auth.go +++ b/internal/api/util/auth.go @@ -18,21 +18,15 @@ package util import ( - "context" "errors" "slices" "strings" "codeberg.org/superseriousbusiness/oauth2/v4" - "codeberg.org/superseriousbusiness/oauth2/v4/server" "github.com/gin-gonic/gin" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/oauth" - "github.com/superseriousbusiness/gotosocial/internal/state" ) // Auth wraps an authorized token, application, user, and account. @@ -156,51 +150,3 @@ func TokenAuth( return a, nil } - -// GetClientScopeHandler returns a handler for testing scope on a TokenGenerateRequest. -func GetClientScopeHandler(ctx context.Context, state *state.State) server.ClientScopeHandler { - return func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) { - application, err := state.DB.GetApplicationByClientID( - gtscontext.SetBarebones(ctx), - tgr.ClientID, - ) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - log.Errorf(ctx, "database error getting application: %v", err) - return false, err - } - - if application == nil { - err := gtserror.Newf("no application found with client id %s", tgr.ClientID) - return false, err - } - - // Normalize scope. - if strings.TrimSpace(tgr.Scope) == "" { - tgr.Scope = "read" - } - - // Make sure requested scopes are all - // within scopes permitted by application. - hasScopes := strings.Split(application.Scopes, " ") - wantsScopes := strings.Split(tgr.Scope, " ") - for _, wantsScope := range wantsScopes { - thisOK := slices.ContainsFunc( - hasScopes, - func(hasScope string) bool { - has := Scope(hasScope) - wants := Scope(wantsScope) - return has.Permits(wants) - }, - ) - - if !thisOK { - // Requested unpermitted - // scope for this app. - return false, nil - } - } - - // All OK. - return true, nil - } -} diff --git a/internal/oauth/handlers/handlers.go b/internal/oauth/handlers/handlers.go new file mode 100644 index 000000000..f0af007f0 --- /dev/null +++ b/internal/oauth/handlers/handlers.go @@ -0,0 +1,153 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package handlers + +import ( + "context" + "errors" + "net/http" + "net/url" + "slices" + "strings" + + "codeberg.org/superseriousbusiness/oauth2/v4" + oautherr "codeberg.org/superseriousbusiness/oauth2/v4/errors" + "codeberg.org/superseriousbusiness/oauth2/v4/manage" + "codeberg.org/superseriousbusiness/oauth2/v4/server" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/state" +) + +// GetClientScopeHandler returns a handler for testing scope on a TokenGenerateRequest. +func GetClientScopeHandler(ctx context.Context, state *state.State) server.ClientScopeHandler { + return func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) { + application, err := state.DB.GetApplicationByClientID( + gtscontext.SetBarebones(ctx), + tgr.ClientID, + ) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + log.Errorf(ctx, "database error getting application: %v", err) + return false, err + } + + if application == nil { + err := gtserror.Newf("no application found with client id %s", tgr.ClientID) + return false, err + } + + // Normalize scope. + if strings.TrimSpace(tgr.Scope) == "" { + tgr.Scope = "read" + } + + // Make sure requested scopes are all + // within scopes permitted by application. + hasScopes := strings.Split(application.Scopes, " ") + wantsScopes := strings.Split(tgr.Scope, " ") + for _, wantsScope := range wantsScopes { + thisOK := slices.ContainsFunc( + hasScopes, + func(hasScope string) bool { + has := apiutil.Scope(hasScope) + wants := apiutil.Scope(wantsScope) + return has.Permits(wants) + }, + ) + + if !thisOK { + // Requested unpermitted + // scope for this app. + return false, nil + } + } + + // All OK. + return true, nil + } +} + +func GetValidateURIHandler(ctx context.Context) manage.ValidateURIHandler { + return func(hasRedirects string, wantsRedirect string) error { + // Normalize the wantsRedirect URI + // string by parsing + reserializing. + wantsRedirectURI, err := url.Parse(wantsRedirect) + if err != nil { + return err + } + wantsRedirect = wantsRedirectURI.String() + + // Redirect URIs are given to us as + // a list of URIs, newline-separated. + // + // They're already normalized on input so + // we don't need to parse + reserialize them. + // + // Ensure that one of them matches. + if slices.ContainsFunc( + strings.Split(hasRedirects, "\n"), + func(hasRedirect string) bool { + // Want an exact match. + // See: https://www.oauth.com/oauth2-servers/redirect-uris/redirect-uri-validation/ + return wantsRedirect == hasRedirect + }, + ) { + return nil + } + + return oautherr.ErrInvalidRedirectURI + } +} + +func GetAuthorizeScopeHandler() server.AuthorizeScopeHandler { + return func(_ http.ResponseWriter, r *http.Request) (string, error) { + // Use provided scope or + // fall back to default "read". + scope := r.FormValue("scope") + if strings.TrimSpace(scope) == "" { + scope = "read" + } + return scope, nil + } +} + +func GetInternalErrorHandler(ctx context.Context) server.InternalErrorHandler { + return func(err error) *oautherr.Response { + log.Errorf(ctx, "internal oauth error: %v", err) + return nil + } +} + +func GetResponseErrorHandler(ctx context.Context) server.ResponseErrorHandler { + return func(re *oautherr.Response) { + log.Errorf(ctx, "internal response error: %v", re.Error) + } +} + +func GetUserAuthorizationHandler() server.UserAuthorizationHandler { + return func(w http.ResponseWriter, r *http.Request) (string, error) { + userID := r.FormValue("userid") + if userID == "" { + return "", errors.New("userid was empty") + } + return userID, nil + } +} diff --git a/internal/oauth/server.go b/internal/oauth/server.go index 8475555ef..34fc4c60e 100644 --- a/internal/oauth/server.go +++ b/internal/oauth/server.go @@ -22,8 +22,6 @@ import ( "errors" "fmt" "net/http" - "net/url" - "slices" "strings" "codeberg.org/superseriousbusiness/oauth2/v4" @@ -65,7 +63,8 @@ const ( HelpfulAdviceGrant = "If you arrived at this error during a sign in/oauth flow, your client is trying to use an unsupported OAuth grant type. Supported grant types are: authorization_code, client_credentials; please reach out to developer of your client" ) -// Server wraps some oauth2 server functions in an interface, exposing only what is needed +// Server wraps some oauth2 server functions +// in an interface, exposing only what is needed. type Server interface { HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtserror.WithCode @@ -74,7 +73,8 @@ type Server interface { LoadAccessToken(ctx context.Context, access string) (accessToken oauth2.TokenInfo, err error) } -// s fulfils the Server interface using the underlying oauth2 server +// s fulfils the Server interface +// using the underlying oauth2 server. type s struct { server *server.Server } @@ -83,111 +83,66 @@ type s struct { func New( ctx context.Context, state *state.State, + validateURIHandler manage.ValidateURIHandler, clientScopeHandler server.ClientScopeHandler, + authorizeScopeHandler server.AuthorizeScopeHandler, + internalErrorHandler server.InternalErrorHandler, + responseErrorHandler server.ResponseErrorHandler, + userAuthorizationHandler server.UserAuthorizationHandler, ) Server { ts := newTokenStore(ctx, state) cs := NewClientStore(state) + // Set up OAuth2 manager. manager := manage.NewDefaultManager() + manager.SetValidateURIHandler(validateURIHandler) manager.MapTokenStorage(ts) manager.MapClientStorage(cs) - manager.SetAuthorizeCodeTokenCfg(&manage.Config{ - // Following the Mastodon API, - // access tokens don't expire. - AccessTokenExp: 0, - // Don't use refresh tokens. - IsGenerateRefresh: false, - }) + manager.SetAuthorizeCodeTokenCfg( + &manage.Config{ + // Following the Mastodon API, + // access tokens don't expire. + AccessTokenExp: 0, + // Don't use refresh tokens. + IsGenerateRefresh: false, + }, + ) - manager.SetValidateURIHandler(func(hasRedirectList, wantsRedirect string) error { - wantsRedirectURI, err := url.Parse(wantsRedirect) - if err != nil { - return err - } - - // Redirect URIs are given to us as - // a list of URIs, newline-separated. - // - // Ensure that one of them matches - // requested redirectURI. - hasRedirects := strings.Split(hasRedirectList, "\n") - - if slices.ContainsFunc( - hasRedirects, - func(hasRedirect string) bool { - hasRedirectURI, err := url.Parse(hasRedirect) - if err != nil { - log.Errorf(nil, "error parsing hasRedirect: %v", err) - return false - } - - // Want an exact match. - // See: https://www.oauth.com/oauth2-servers/redirect-uris/redirect-uri-validation/ - return wantsRedirectURI.String() == hasRedirectURI.String() + // Set up OAuth2 server. + srv := server.NewServer( + &server.Config{ + TokenType: "Bearer", + // Must follow the spec. + AllowGetAccessRequest: false, + // Support only the non-implicit flow. + AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code}, + // Allow: + // - Authorization Code (for first & third parties) + // - Client Credentials (for applications) + AllowedGrantTypes: []oauth2.GrantType{ + oauth2.AuthorizationCode, + oauth2.ClientCredentials, + }, + AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{ + oauth2.CodeChallengePlain, + oauth2.CodeChallengeS256, }, - ) { - return nil - } - - return oautherr.ErrInvalidRedirectURI - }) - - sc := &server.Config{ - TokenType: "Bearer", - // Must follow the spec. - AllowGetAccessRequest: false, - // Support only the non-implicit flow. - AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code}, - // Allow: - // - Authorization Code (for first & third parties) - // - Client Credentials (for applications) - AllowedGrantTypes: []oauth2.GrantType{ - oauth2.AuthorizationCode, - oauth2.ClientCredentials, }, - AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{ - oauth2.CodeChallengePlain, - oauth2.CodeChallengeS256, - }, - } - - srv := server.NewServer(sc, manager) - - srv.SetAuthorizeScopeHandler(func(w http.ResponseWriter, r *http.Request) (string, error) { - // Use provided scope or - // fall back to default "read". - scope := r.FormValue("scope") - if strings.TrimSpace(scope) == "" { - scope = "read" - } - return scope, nil - }) - + manager, + ) + srv.SetAuthorizeScopeHandler(authorizeScopeHandler) srv.SetClientScopeHandler(clientScopeHandler) - - srv.SetInternalErrorHandler(func(err error) *oautherr.Response { - log.Errorf(nil, "internal oauth error: %s", err) - return nil - }) - - srv.SetResponseErrorHandler(func(re *oautherr.Response) { - log.Errorf(nil, "internal response error: %s", re.Error) - }) - - srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) { - userID := r.FormValue("userid") - if userID == "" { - return "", errors.New("userid was empty") - } - return userID, nil - }) - + srv.SetInternalErrorHandler(internalErrorHandler) + srv.SetResponseErrorHandler(responseErrorHandler) + srv.SetUserAuthorizationHandler(userAuthorizationHandler) srv.SetClientInfoHandler(server.ClientFormHandler) return &s{srv} } -// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function +// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function, +// providing some custom error handling (with more informative messages), +// and a slightly different token serialization format. func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode) { ctx := r.Context() @@ -201,19 +156,23 @@ func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserro return nil, gtserror.NewErrorBadRequest(err, help, adv) } + // Get access token + do our own nicer error handling. ti, err := s.server.GetAccessToken(ctx, gt, tgr) switch { case err == nil: // No problem. break + case errors.Is(err, oautherr.ErrInvalidScope): help := fmt.Sprintf("requested scope %s was not covered by client scope", tgr.Scope) return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice) + case errors.Is(err, oautherr.ErrInvalidRedirectURI): help := fmt.Sprintf("requested redirect URI %s was not covered by client redirect URIs", tgr.RedirectURI) return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice) + default: - help := fmt.Sprintf("could not get access token: %s", err) + help := fmt.Sprintf("could not get access token: %v", err) return nil, gtserror.NewErrorBadRequest(err, help, HelpfulAdvice) } diff --git a/testrig/oauthserver.go b/testrig/oauthserver.go index df3caada3..9429e751b 100644 --- a/testrig/oauthserver.go +++ b/testrig/oauthserver.go @@ -20,8 +20,8 @@ package testrig import ( "context" - apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/oauth/handlers" "github.com/superseriousbusiness/gotosocial/internal/state" ) @@ -31,6 +31,11 @@ func NewTestOauthServer(state *state.State) oauth.Server { return oauth.New( ctx, state, - apiutil.GetClientScopeHandler(ctx, state), + handlers.GetValidateURIHandler(ctx), + handlers.GetClientScopeHandler(ctx, state), + handlers.GetAuthorizeScopeHandler(), + handlers.GetInternalErrorHandler(ctx), + handlers.GetResponseErrorHandler(ctx), + handlers.GetUserAuthorizationHandler(), ) }