mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-12-23 01:46:16 -06:00
[feature] Refactor tokens, allow multiple app redirect_uris
This commit is contained in:
parent
67a2b3650c
commit
3b1b842890
77 changed files with 860 additions and 554 deletions
|
|
@ -22,6 +22,8 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"codeberg.org/superseriousbusiness/oauth2/v4"
|
||||
|
|
@ -30,7 +32,10 @@ import (
|
|||
"codeberg.org/superseriousbusiness/oauth2/v4/server"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -75,17 +80,58 @@ type s struct {
|
|||
}
|
||||
|
||||
// New returns a new oauth server that implements the Server interface
|
||||
func New(ctx context.Context, database db.DB) Server {
|
||||
ts := newTokenStore(ctx, database)
|
||||
cs := NewClientStore(database)
|
||||
func New(
|
||||
ctx context.Context,
|
||||
state *state.State,
|
||||
clientScopeHandler server.ClientScopeHandler,
|
||||
) Server {
|
||||
ts := newTokenStore(ctx, state)
|
||||
cs := NewClientStore(state)
|
||||
|
||||
manager := manage.NewDefaultManager()
|
||||
manager.MapTokenStorage(ts)
|
||||
manager.MapClientStorage(cs)
|
||||
manager.SetAuthorizeCodeTokenCfg(&manage.Config{
|
||||
AccessTokenExp: 0, // access tokens don't expire -- they must be revoked
|
||||
IsGenerateRefresh: false, // don't use refresh tokens
|
||||
// Following the Mastodon API,
|
||||
// access tokens don't expire.
|
||||
AccessTokenExp: 0,
|
||||
// Don't use refresh tokens.
|
||||
IsGenerateRefresh: false,
|
||||
})
|
||||
|
||||
manager.SetValidateURIHandler(func(hasRedirectList, wantsRedirect string) error {
|
||||
wantsRedirectURI, err := url.Parse(wantsRedirect)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Redirect URIs are given to us as
|
||||
// a list of URIs, newline-separated.
|
||||
//
|
||||
// Ensure that one of them matches
|
||||
// requested redirectURI.
|
||||
hasRedirects := strings.Split(hasRedirectList, "\n")
|
||||
|
||||
if slices.ContainsFunc(
|
||||
hasRedirects,
|
||||
func(hasRedirect string) bool {
|
||||
hasRedirectURI, err := url.Parse(hasRedirect)
|
||||
if err != nil {
|
||||
log.Errorf(nil, "error parsing hasRedirect: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Want an exact match.
|
||||
// See: https://www.oauth.com/oauth2-servers/redirect-uris/redirect-uri-validation/
|
||||
return wantsRedirectURI.String() == hasRedirectURI.String()
|
||||
},
|
||||
) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return oautherr.ErrInvalidRedirectURI
|
||||
})
|
||||
|
||||
sc := &server.Config{
|
||||
TokenType: "Bearer",
|
||||
// Must follow the spec.
|
||||
|
|
@ -106,6 +152,19 @@ func New(ctx context.Context, database db.DB) Server {
|
|||
}
|
||||
|
||||
srv := server.NewServer(sc, manager)
|
||||
|
||||
srv.SetAuthorizeScopeHandler(func(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
// Use provided scope or
|
||||
// fall back to default "read".
|
||||
scope := r.FormValue("scope")
|
||||
if strings.TrimSpace(scope) == "" {
|
||||
scope = "read"
|
||||
}
|
||||
return scope, nil
|
||||
})
|
||||
|
||||
srv.SetClientScopeHandler(clientScopeHandler)
|
||||
|
||||
srv.SetInternalErrorHandler(func(err error) *oautherr.Response {
|
||||
log.Errorf(nil, "internal oauth error: %s", err)
|
||||
return nil
|
||||
|
|
@ -122,10 +181,10 @@ func New(ctx context.Context, database db.DB) Server {
|
|||
}
|
||||
return userID, nil
|
||||
})
|
||||
|
||||
srv.SetClientInfoHandler(server.ClientFormHandler)
|
||||
return &s{
|
||||
server: srv,
|
||||
}
|
||||
|
||||
return &s{srv}
|
||||
}
|
||||
|
||||
// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function
|
||||
|
|
@ -143,31 +202,42 @@ func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserro
|
|||
}
|
||||
|
||||
ti, err := s.server.GetAccessToken(ctx, gt, tgr)
|
||||
if err != nil {
|
||||
switch {
|
||||
case err == nil:
|
||||
// No problem.
|
||||
break
|
||||
case errors.Is(err, oautherr.ErrInvalidScope):
|
||||
help := fmt.Sprintf("requested scope %s was not covered by client scope", tgr.Scope)
|
||||
return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice)
|
||||
case errors.Is(err, oautherr.ErrInvalidRedirectURI):
|
||||
help := fmt.Sprintf("requested redirect URI %s was not covered by client redirect URIs", tgr.RedirectURI)
|
||||
return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice)
|
||||
default:
|
||||
help := fmt.Sprintf("could not get access token: %s", err)
|
||||
return nil, gtserror.NewErrorBadRequest(err, help, HelpfulAdvice)
|
||||
}
|
||||
|
||||
// Wrangle data a bit.
|
||||
data := s.server.GetTokenData(ti)
|
||||
|
||||
// Add created_at for Mastodon API compatibility.
|
||||
data["created_at"] = ti.GetAccessCreateAt().Unix()
|
||||
|
||||
// If expires_in is 0 or less, omit it
|
||||
// from serialization so that clients don't
|
||||
// interpret the token as already expired.
|
||||
if expiresInI, ok := data["expires_in"]; ok {
|
||||
switch expiresIn := expiresInI.(type) {
|
||||
case int64:
|
||||
// remove this key from the returned map
|
||||
// if the value is 0 or less, so that clients
|
||||
// don't interpret the token as already expired
|
||||
if expiresIn <= 0 {
|
||||
delete(data, "expires_in")
|
||||
}
|
||||
default:
|
||||
err := errors.New("expires_in was set on token response, but was not an int64")
|
||||
return nil, gtserror.NewErrorInternalError(err, HelpfulAdvice)
|
||||
expiresIn, ok := expiresInI.(int64)
|
||||
if !ok {
|
||||
log.Panicf(ctx, "could not cast expires_in %T as int64", expiresInI)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if expiresIn <= 0 {
|
||||
delete(data, "expires_in")
|
||||
}
|
||||
}
|
||||
|
||||
// add this for mastodon api compatibility
|
||||
data["created_at"] = ti.GetAccessCreateAt().Unix()
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
|
|
@ -207,7 +277,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
|
|||
}
|
||||
req.UserID = userID
|
||||
|
||||
// specify the scope of authorization
|
||||
// Specify the scope of authorization.
|
||||
if fn := s.server.AuthorizeScopeHandler; fn != nil {
|
||||
scope, err := fn(w, r)
|
||||
if err != nil {
|
||||
|
|
@ -217,7 +287,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
|
|||
}
|
||||
}
|
||||
|
||||
// specify the expiration time of access token
|
||||
// Specify the expiration time of access token.
|
||||
if fn := s.server.AccessTokenExpHandler; fn != nil {
|
||||
exp, err := fn(w, r)
|
||||
if err != nil {
|
||||
|
|
@ -231,13 +301,28 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
|
|||
return s.errorOrRedirect(err, w, req)
|
||||
}
|
||||
|
||||
// If the redirect URI is empty, the default domain provided by the client is used.
|
||||
// If the redirect URI is empty, use the
|
||||
// first of the client's redirect URIs.
|
||||
if req.RedirectURI == "" {
|
||||
client, err := s.server.Manager.GetClient(ctx, req.ClientID)
|
||||
if err != nil {
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
// Real error.
|
||||
err := gtserror.Newf("db error getting application with client id %s: %w", req.ClientID, err)
|
||||
return gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
if util.IsNil(client) {
|
||||
// Application just not found.
|
||||
return gtserror.NewErrorUnauthorized(err, HelpfulAdvice)
|
||||
}
|
||||
req.RedirectURI = client.GetDomain()
|
||||
|
||||
app, ok := client.(*gtsmodel.Application)
|
||||
if !ok {
|
||||
log.Panicf(ctx, "could not cast %T to *gtsmodel.Application", client)
|
||||
return nil
|
||||
}
|
||||
|
||||
req.RedirectURI = app.RedirectURIs[0]
|
||||
}
|
||||
|
||||
uri, err := s.server.GetRedirectURI(req, s.server.GetAuthorizeData(req.ResponseType, ti))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue