[feature] Refactor tokens, allow multiple app redirect_uris

This commit is contained in:
tobi 2025-03-03 11:45:45 +01:00
commit 3b1b842890
77 changed files with 860 additions and 554 deletions

View file

@ -22,6 +22,8 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"slices"
"strings"
"codeberg.org/superseriousbusiness/oauth2/v4"
@ -30,7 +32,10 @@ import (
"codeberg.org/superseriousbusiness/oauth2/v4/server"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
const (
@ -75,17 +80,58 @@ type s struct {
}
// New returns a new oauth server that implements the Server interface
func New(ctx context.Context, database db.DB) Server {
ts := newTokenStore(ctx, database)
cs := NewClientStore(database)
func New(
ctx context.Context,
state *state.State,
clientScopeHandler server.ClientScopeHandler,
) Server {
ts := newTokenStore(ctx, state)
cs := NewClientStore(state)
manager := manage.NewDefaultManager()
manager.MapTokenStorage(ts)
manager.MapClientStorage(cs)
manager.SetAuthorizeCodeTokenCfg(&manage.Config{
AccessTokenExp: 0, // access tokens don't expire -- they must be revoked
IsGenerateRefresh: false, // don't use refresh tokens
// Following the Mastodon API,
// access tokens don't expire.
AccessTokenExp: 0,
// Don't use refresh tokens.
IsGenerateRefresh: false,
})
manager.SetValidateURIHandler(func(hasRedirectList, wantsRedirect string) error {
wantsRedirectURI, err := url.Parse(wantsRedirect)
if err != nil {
return err
}
// Redirect URIs are given to us as
// a list of URIs, newline-separated.
//
// Ensure that one of them matches
// requested redirectURI.
hasRedirects := strings.Split(hasRedirectList, "\n")
if slices.ContainsFunc(
hasRedirects,
func(hasRedirect string) bool {
hasRedirectURI, err := url.Parse(hasRedirect)
if err != nil {
log.Errorf(nil, "error parsing hasRedirect: %v", err)
return false
}
// Want an exact match.
// See: https://www.oauth.com/oauth2-servers/redirect-uris/redirect-uri-validation/
return wantsRedirectURI.String() == hasRedirectURI.String()
},
) {
return nil
}
return oautherr.ErrInvalidRedirectURI
})
sc := &server.Config{
TokenType: "Bearer",
// Must follow the spec.
@ -106,6 +152,19 @@ func New(ctx context.Context, database db.DB) Server {
}
srv := server.NewServer(sc, manager)
srv.SetAuthorizeScopeHandler(func(w http.ResponseWriter, r *http.Request) (string, error) {
// Use provided scope or
// fall back to default "read".
scope := r.FormValue("scope")
if strings.TrimSpace(scope) == "" {
scope = "read"
}
return scope, nil
})
srv.SetClientScopeHandler(clientScopeHandler)
srv.SetInternalErrorHandler(func(err error) *oautherr.Response {
log.Errorf(nil, "internal oauth error: %s", err)
return nil
@ -122,10 +181,10 @@ func New(ctx context.Context, database db.DB) Server {
}
return userID, nil
})
srv.SetClientInfoHandler(server.ClientFormHandler)
return &s{
server: srv,
}
return &s{srv}
}
// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function
@ -143,31 +202,42 @@ func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserro
}
ti, err := s.server.GetAccessToken(ctx, gt, tgr)
if err != nil {
switch {
case err == nil:
// No problem.
break
case errors.Is(err, oautherr.ErrInvalidScope):
help := fmt.Sprintf("requested scope %s was not covered by client scope", tgr.Scope)
return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice)
case errors.Is(err, oautherr.ErrInvalidRedirectURI):
help := fmt.Sprintf("requested redirect URI %s was not covered by client redirect URIs", tgr.RedirectURI)
return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice)
default:
help := fmt.Sprintf("could not get access token: %s", err)
return nil, gtserror.NewErrorBadRequest(err, help, HelpfulAdvice)
}
// Wrangle data a bit.
data := s.server.GetTokenData(ti)
// Add created_at for Mastodon API compatibility.
data["created_at"] = ti.GetAccessCreateAt().Unix()
// If expires_in is 0 or less, omit it
// from serialization so that clients don't
// interpret the token as already expired.
if expiresInI, ok := data["expires_in"]; ok {
switch expiresIn := expiresInI.(type) {
case int64:
// remove this key from the returned map
// if the value is 0 or less, so that clients
// don't interpret the token as already expired
if expiresIn <= 0 {
delete(data, "expires_in")
}
default:
err := errors.New("expires_in was set on token response, but was not an int64")
return nil, gtserror.NewErrorInternalError(err, HelpfulAdvice)
expiresIn, ok := expiresInI.(int64)
if !ok {
log.Panicf(ctx, "could not cast expires_in %T as int64", expiresInI)
return nil, nil
}
if expiresIn <= 0 {
delete(data, "expires_in")
}
}
// add this for mastodon api compatibility
data["created_at"] = ti.GetAccessCreateAt().Unix()
return data, nil
}
@ -207,7 +277,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
}
req.UserID = userID
// specify the scope of authorization
// Specify the scope of authorization.
if fn := s.server.AuthorizeScopeHandler; fn != nil {
scope, err := fn(w, r)
if err != nil {
@ -217,7 +287,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
}
}
// specify the expiration time of access token
// Specify the expiration time of access token.
if fn := s.server.AccessTokenExpHandler; fn != nil {
exp, err := fn(w, r)
if err != nil {
@ -231,13 +301,28 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
return s.errorOrRedirect(err, w, req)
}
// If the redirect URI is empty, the default domain provided by the client is used.
// If the redirect URI is empty, use the
// first of the client's redirect URIs.
if req.RedirectURI == "" {
client, err := s.server.Manager.GetClient(ctx, req.ClientID)
if err != nil {
if err != nil && !errors.Is(err, db.ErrNoEntries) {
// Real error.
err := gtserror.Newf("db error getting application with client id %s: %w", req.ClientID, err)
return gtserror.NewErrorInternalError(err)
}
if util.IsNil(client) {
// Application just not found.
return gtserror.NewErrorUnauthorized(err, HelpfulAdvice)
}
req.RedirectURI = client.GetDomain()
app, ok := client.(*gtsmodel.Application)
if !ok {
log.Panicf(ctx, "could not cast %T to *gtsmodel.Application", client)
return nil
}
req.RedirectURI = app.RedirectURIs[0]
}
uri, err := s.server.GetRedirectURI(req, s.server.GetAuthorizeData(req.ResponseType, ti))