mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-11-26 03:23:32 -06:00
[feature] Refactor tokens, allow multiple app redirect_uris (#3849)
* [feature] Refactor tokens, allow multiple app redirect_uris * move + tweak handlers a bit * return error for unset oauth2.ClientStore funcs * wrap UpdateToken with cache * panic handling * cheeky little time optimization * unlock on error
This commit is contained in:
parent
c80810eae8
commit
1b37944f8b
77 changed files with 963 additions and 594 deletions
|
|
@ -21,45 +21,30 @@ import (
|
|||
"context"
|
||||
|
||||
"codeberg.org/superseriousbusiness/oauth2/v4"
|
||||
"codeberg.org/superseriousbusiness/oauth2/v4/models"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"codeberg.org/superseriousbusiness/oauth2/v4/errors"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
)
|
||||
|
||||
type clientStore struct {
|
||||
db db.DB
|
||||
state *state.State
|
||||
}
|
||||
|
||||
// NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend.
|
||||
func NewClientStore(db db.DB) oauth2.ClientStore {
|
||||
pts := &clientStore{
|
||||
db: db,
|
||||
}
|
||||
return pts
|
||||
// NewClientStore returns a minimal implementation of
|
||||
// oauth2.ClientStore interface, using state as storage.
|
||||
//
|
||||
// Only GetByID is implemented, Set and Delete are stubs.
|
||||
func NewClientStore(state *state.State) oauth2.ClientStore {
|
||||
return &clientStore{state: state}
|
||||
}
|
||||
|
||||
func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
|
||||
client, err := cs.db.GetClientByID(ctx, clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return models.New(
|
||||
client.ID,
|
||||
client.Secret,
|
||||
client.Domain,
|
||||
client.UserID,
|
||||
), nil
|
||||
return cs.state.DB.GetApplicationByClientID(ctx, clientID)
|
||||
}
|
||||
|
||||
func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error {
|
||||
return cs.db.PutClient(ctx, >smodel.Client{
|
||||
ID: cli.GetID(),
|
||||
Secret: cli.GetSecret(),
|
||||
Domain: cli.GetDomain(),
|
||||
UserID: cli.GetUserID(),
|
||||
})
|
||||
func (cs *clientStore) Set(_ context.Context, _ string, _ oauth2.ClientInfo) error {
|
||||
return errors.New("func oauth2.ClientStore.Set not implemented")
|
||||
}
|
||||
|
||||
func (cs *clientStore) Delete(ctx context.Context, id string) error {
|
||||
return cs.db.DeleteClientByID(ctx, id)
|
||||
func (cs *clientStore) Delete(_ context.Context, _ string) error {
|
||||
return errors.New("func oauth2.ClientStore.Delete not implemented")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,93 +21,58 @@ import (
|
|||
"context"
|
||||
"testing"
|
||||
|
||||
"codeberg.org/superseriousbusiness/oauth2/v4/models"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/admin"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
type PgClientStoreTestSuite struct {
|
||||
type ClientStoreTestSuite struct {
|
||||
suite.Suite
|
||||
db db.DB
|
||||
state state.State
|
||||
testClientID string
|
||||
testClientSecret string
|
||||
testClientDomain string
|
||||
testClientUserID string
|
||||
testApplications map[string]*gtsmodel.Application
|
||||
}
|
||||
|
||||
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
|
||||
func (suite *PgClientStoreTestSuite) SetupSuite() {
|
||||
suite.testClientID = "01FCVB74EW6YBYAEY7QG9CQQF6"
|
||||
suite.testClientSecret = "4cc87402-259b-4a35-9485-2c8bf54f3763"
|
||||
suite.testClientDomain = "https://example.org"
|
||||
suite.testClientUserID = "01FEGYXKVCDB731QF9MVFXA4F5"
|
||||
func (suite *ClientStoreTestSuite) SetupSuite() {
|
||||
suite.testApplications = testrig.NewTestApplications()
|
||||
}
|
||||
|
||||
// SetupTest creates a postgres connection and creates the oauth_clients table before each test
|
||||
func (suite *PgClientStoreTestSuite) SetupTest() {
|
||||
func (suite *ClientStoreTestSuite) SetupTest() {
|
||||
suite.state.Caches.Init()
|
||||
testrig.InitTestLog()
|
||||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
suite.db = testrig.NewTestDB(&suite.state)
|
||||
suite.state.DB = suite.db
|
||||
suite.state.AdminActions = admin.New(suite.state.DB, &suite.state.Workers)
|
||||
testrig.StandardDBSetup(suite.db, nil)
|
||||
}
|
||||
|
||||
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
|
||||
func (suite *PgClientStoreTestSuite) TearDownTest() {
|
||||
func (suite *ClientStoreTestSuite) TearDownTest() {
|
||||
testrig.StandardDBTeardown(suite.db)
|
||||
}
|
||||
|
||||
func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() {
|
||||
// set a new client in the store
|
||||
cs := oauth.NewClientStore(suite.db)
|
||||
if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
func (suite *ClientStoreTestSuite) TestClientStoreGet() {
|
||||
testApp := suite.testApplications["application_1"]
|
||||
cs := oauth.NewClientStore(&suite.state)
|
||||
|
||||
// fetch that client from the store
|
||||
client, err := cs.GetByID(context.Background(), suite.testClientID)
|
||||
// Fetch clientInfo from the store.
|
||||
clientInfo, err := cs.GetByID(context.Background(), testApp.ClientID)
|
||||
if err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
// check that the values are the same
|
||||
suite.NotNil(client)
|
||||
suite.EqualValues(models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID), client)
|
||||
// Check expected values.
|
||||
suite.NotNil(clientInfo)
|
||||
suite.Equal(testApp.ClientID, clientInfo.GetID())
|
||||
suite.Equal(testApp.ClientSecret, clientInfo.GetSecret())
|
||||
suite.Equal(testApp.RedirectURIs[0], clientInfo.GetDomain())
|
||||
suite.Equal(testApp.ManagedByUserID, clientInfo.GetUserID())
|
||||
}
|
||||
|
||||
func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() {
|
||||
// set a new client in the store
|
||||
cs := oauth.NewClientStore(suite.db)
|
||||
if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
// fetch the client from the store
|
||||
client, err := cs.GetByID(context.Background(), suite.testClientID)
|
||||
if err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
// check that the values are the same
|
||||
suite.NotNil(client)
|
||||
suite.EqualValues(models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID), client)
|
||||
if err := cs.Delete(context.Background(), suite.testClientID); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
// try to get the deleted client; we should get an error
|
||||
deletedClient, err := cs.GetByID(context.Background(), suite.testClientID)
|
||||
suite.Assert().Nil(deletedClient)
|
||||
suite.Assert().EqualValues(db.ErrNoEntries, err)
|
||||
}
|
||||
|
||||
func TestPgClientStoreTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(PgClientStoreTestSuite))
|
||||
func TestClientStoreTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(ClientStoreTestSuite))
|
||||
}
|
||||
|
|
|
|||
153
internal/oauth/handlers/handlers.go
Normal file
153
internal/oauth/handlers/handlers.go
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"codeberg.org/superseriousbusiness/oauth2/v4"
|
||||
oautherr "codeberg.org/superseriousbusiness/oauth2/v4/errors"
|
||||
"codeberg.org/superseriousbusiness/oauth2/v4/manage"
|
||||
"codeberg.org/superseriousbusiness/oauth2/v4/server"
|
||||
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
)
|
||||
|
||||
// GetClientScopeHandler returns a handler for testing scope on a TokenGenerateRequest.
|
||||
func GetClientScopeHandler(ctx context.Context, state *state.State) server.ClientScopeHandler {
|
||||
return func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) {
|
||||
application, err := state.DB.GetApplicationByClientID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
tgr.ClientID,
|
||||
)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
log.Errorf(ctx, "database error getting application: %v", err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
if application == nil {
|
||||
err := gtserror.Newf("no application found with client id %s", tgr.ClientID)
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Normalize scope.
|
||||
if strings.TrimSpace(tgr.Scope) == "" {
|
||||
tgr.Scope = "read"
|
||||
}
|
||||
|
||||
// Make sure requested scopes are all
|
||||
// within scopes permitted by application.
|
||||
hasScopes := strings.Split(application.Scopes, " ")
|
||||
wantsScopes := strings.Split(tgr.Scope, " ")
|
||||
for _, wantsScope := range wantsScopes {
|
||||
thisOK := slices.ContainsFunc(
|
||||
hasScopes,
|
||||
func(hasScope string) bool {
|
||||
has := apiutil.Scope(hasScope)
|
||||
wants := apiutil.Scope(wantsScope)
|
||||
return has.Permits(wants)
|
||||
},
|
||||
)
|
||||
|
||||
if !thisOK {
|
||||
// Requested unpermitted
|
||||
// scope for this app.
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// All OK.
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetValidateURIHandler(ctx context.Context) manage.ValidateURIHandler {
|
||||
return func(hasRedirects string, wantsRedirect string) error {
|
||||
// Normalize the wantsRedirect URI
|
||||
// string by parsing + reserializing.
|
||||
wantsRedirectURI, err := url.Parse(wantsRedirect)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
wantsRedirect = wantsRedirectURI.String()
|
||||
|
||||
// Redirect URIs are given to us as
|
||||
// a list of URIs, newline-separated.
|
||||
//
|
||||
// They're already normalized on input so
|
||||
// we don't need to parse + reserialize them.
|
||||
//
|
||||
// Ensure that one of them matches.
|
||||
if slices.ContainsFunc(
|
||||
strings.Split(hasRedirects, "\n"),
|
||||
func(hasRedirect string) bool {
|
||||
// Want an exact match.
|
||||
// See: https://www.oauth.com/oauth2-servers/redirect-uris/redirect-uri-validation/
|
||||
return wantsRedirect == hasRedirect
|
||||
},
|
||||
) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return oautherr.ErrInvalidRedirectURI
|
||||
}
|
||||
}
|
||||
|
||||
func GetAuthorizeScopeHandler() server.AuthorizeScopeHandler {
|
||||
return func(_ http.ResponseWriter, r *http.Request) (string, error) {
|
||||
// Use provided scope or
|
||||
// fall back to default "read".
|
||||
scope := r.FormValue("scope")
|
||||
if strings.TrimSpace(scope) == "" {
|
||||
scope = "read"
|
||||
}
|
||||
return scope, nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetInternalErrorHandler(ctx context.Context) server.InternalErrorHandler {
|
||||
return func(err error) *oautherr.Response {
|
||||
log.Errorf(ctx, "internal oauth error: %v", err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetResponseErrorHandler(ctx context.Context) server.ResponseErrorHandler {
|
||||
return func(re *oautherr.Response) {
|
||||
log.Errorf(ctx, "internal response error: %v", re.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func GetUserAuthorizationHandler() server.UserAuthorizationHandler {
|
||||
return func(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
userID := r.FormValue("userid")
|
||||
if userID == "" {
|
||||
return "", errors.New("userid was empty")
|
||||
}
|
||||
return userID, nil
|
||||
}
|
||||
}
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package oauth_test
|
||||
|
||||
// TODO: write tests
|
||||
|
|
@ -30,7 +30,10 @@ import (
|
|||
"codeberg.org/superseriousbusiness/oauth2/v4/server"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -60,7 +63,8 @@ const (
|
|||
HelpfulAdviceGrant = "If you arrived at this error during a sign in/oauth flow, your client is trying to use an unsupported OAuth grant type. Supported grant types are: authorization_code, client_credentials; please reach out to developer of your client"
|
||||
)
|
||||
|
||||
// Server wraps some oauth2 server functions in an interface, exposing only what is needed
|
||||
// Server wraps some oauth2 server functions
|
||||
// in an interface, exposing only what is needed.
|
||||
type Server interface {
|
||||
HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode)
|
||||
HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtserror.WithCode
|
||||
|
|
@ -69,66 +73,76 @@ type Server interface {
|
|||
LoadAccessToken(ctx context.Context, access string) (accessToken oauth2.TokenInfo, err error)
|
||||
}
|
||||
|
||||
// s fulfils the Server interface using the underlying oauth2 server
|
||||
// s fulfils the Server interface
|
||||
// using the underlying oauth2 server.
|
||||
type s struct {
|
||||
server *server.Server
|
||||
}
|
||||
|
||||
// New returns a new oauth server that implements the Server interface
|
||||
func New(ctx context.Context, database db.DB) Server {
|
||||
ts := newTokenStore(ctx, database)
|
||||
cs := NewClientStore(database)
|
||||
func New(
|
||||
ctx context.Context,
|
||||
state *state.State,
|
||||
validateURIHandler manage.ValidateURIHandler,
|
||||
clientScopeHandler server.ClientScopeHandler,
|
||||
authorizeScopeHandler server.AuthorizeScopeHandler,
|
||||
internalErrorHandler server.InternalErrorHandler,
|
||||
responseErrorHandler server.ResponseErrorHandler,
|
||||
userAuthorizationHandler server.UserAuthorizationHandler,
|
||||
) Server {
|
||||
ts := newTokenStore(ctx, state)
|
||||
cs := NewClientStore(state)
|
||||
|
||||
// Set up OAuth2 manager.
|
||||
manager := manage.NewDefaultManager()
|
||||
manager.SetValidateURIHandler(validateURIHandler)
|
||||
manager.MapTokenStorage(ts)
|
||||
manager.MapClientStorage(cs)
|
||||
manager.SetAuthorizeCodeTokenCfg(&manage.Config{
|
||||
AccessTokenExp: 0, // access tokens don't expire -- they must be revoked
|
||||
IsGenerateRefresh: false, // don't use refresh tokens
|
||||
})
|
||||
sc := &server.Config{
|
||||
TokenType: "Bearer",
|
||||
// Must follow the spec.
|
||||
AllowGetAccessRequest: false,
|
||||
// Support only the non-implicit flow.
|
||||
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
|
||||
// Allow:
|
||||
// - Authorization Code (for first & third parties)
|
||||
// - Client Credentials (for applications)
|
||||
AllowedGrantTypes: []oauth2.GrantType{
|
||||
oauth2.AuthorizationCode,
|
||||
oauth2.ClientCredentials,
|
||||
manager.SetAuthorizeCodeTokenCfg(
|
||||
&manage.Config{
|
||||
// Following the Mastodon API,
|
||||
// access tokens don't expire.
|
||||
AccessTokenExp: 0,
|
||||
// Don't use refresh tokens.
|
||||
IsGenerateRefresh: false,
|
||||
},
|
||||
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{
|
||||
oauth2.CodeChallengePlain,
|
||||
oauth2.CodeChallengeS256,
|
||||
)
|
||||
|
||||
// Set up OAuth2 server.
|
||||
srv := server.NewServer(
|
||||
&server.Config{
|
||||
TokenType: "Bearer",
|
||||
// Must follow the spec.
|
||||
AllowGetAccessRequest: false,
|
||||
// Support only the non-implicit flow.
|
||||
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
|
||||
// Allow:
|
||||
// - Authorization Code (for first & third parties)
|
||||
// - Client Credentials (for applications)
|
||||
AllowedGrantTypes: []oauth2.GrantType{
|
||||
oauth2.AuthorizationCode,
|
||||
oauth2.ClientCredentials,
|
||||
},
|
||||
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{
|
||||
oauth2.CodeChallengePlain,
|
||||
oauth2.CodeChallengeS256,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srv := server.NewServer(sc, manager)
|
||||
srv.SetInternalErrorHandler(func(err error) *oautherr.Response {
|
||||
log.Errorf(nil, "internal oauth error: %s", err)
|
||||
return nil
|
||||
})
|
||||
|
||||
srv.SetResponseErrorHandler(func(re *oautherr.Response) {
|
||||
log.Errorf(nil, "internal response error: %s", re.Error)
|
||||
})
|
||||
|
||||
srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
userID := r.FormValue("userid")
|
||||
if userID == "" {
|
||||
return "", errors.New("userid was empty")
|
||||
}
|
||||
return userID, nil
|
||||
})
|
||||
manager,
|
||||
)
|
||||
srv.SetAuthorizeScopeHandler(authorizeScopeHandler)
|
||||
srv.SetClientScopeHandler(clientScopeHandler)
|
||||
srv.SetInternalErrorHandler(internalErrorHandler)
|
||||
srv.SetResponseErrorHandler(responseErrorHandler)
|
||||
srv.SetUserAuthorizationHandler(userAuthorizationHandler)
|
||||
srv.SetClientInfoHandler(server.ClientFormHandler)
|
||||
return &s{
|
||||
server: srv,
|
||||
}
|
||||
|
||||
return &s{srv}
|
||||
}
|
||||
|
||||
// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function
|
||||
// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function,
|
||||
// providing some custom error handling (with more informative messages),
|
||||
// and a slightly different token serialization format.
|
||||
func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode) {
|
||||
ctx := r.Context()
|
||||
|
||||
|
|
@ -142,32 +156,43 @@ func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserro
|
|||
return nil, gtserror.NewErrorBadRequest(err, help, adv)
|
||||
}
|
||||
|
||||
// Get access token + do our own nicer error handling.
|
||||
ti, err := s.server.GetAccessToken(ctx, gt, tgr)
|
||||
if err != nil {
|
||||
help := fmt.Sprintf("could not get access token: %s", err)
|
||||
switch {
|
||||
case err == nil:
|
||||
// No problem.
|
||||
break
|
||||
|
||||
case errors.Is(err, oautherr.ErrInvalidScope):
|
||||
help := fmt.Sprintf("requested scope %s was not covered by client scope", tgr.Scope)
|
||||
return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice)
|
||||
|
||||
case errors.Is(err, oautherr.ErrInvalidRedirectURI):
|
||||
help := fmt.Sprintf("requested redirect URI %s was not covered by client redirect URIs", tgr.RedirectURI)
|
||||
return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice)
|
||||
|
||||
default:
|
||||
help := fmt.Sprintf("could not get access token: %v", err)
|
||||
return nil, gtserror.NewErrorBadRequest(err, help, HelpfulAdvice)
|
||||
}
|
||||
|
||||
// Wrangle data a bit.
|
||||
data := s.server.GetTokenData(ti)
|
||||
|
||||
// Add created_at for Mastodon API compatibility.
|
||||
data["created_at"] = ti.GetAccessCreateAt().Unix()
|
||||
|
||||
// If expires_in is 0 or less, omit it
|
||||
// from serialization so that clients don't
|
||||
// interpret the token as already expired.
|
||||
if expiresInI, ok := data["expires_in"]; ok {
|
||||
switch expiresIn := expiresInI.(type) {
|
||||
case int64:
|
||||
// remove this key from the returned map
|
||||
// if the value is 0 or less, so that clients
|
||||
// don't interpret the token as already expired
|
||||
if expiresIn <= 0 {
|
||||
delete(data, "expires_in")
|
||||
}
|
||||
default:
|
||||
err := errors.New("expires_in was set on token response, but was not an int64")
|
||||
return nil, gtserror.NewErrorInternalError(err, HelpfulAdvice)
|
||||
// This will panic if expiresIn is
|
||||
// not an int64, which is what we want.
|
||||
if expiresInI.(int64) <= 0 {
|
||||
delete(data, "expires_in")
|
||||
}
|
||||
}
|
||||
|
||||
// add this for mastodon api compatibility
|
||||
data["created_at"] = ti.GetAccessCreateAt().Unix()
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
|
|
@ -207,7 +232,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
|
|||
}
|
||||
req.UserID = userID
|
||||
|
||||
// specify the scope of authorization
|
||||
// Specify the scope of authorization.
|
||||
if fn := s.server.AuthorizeScopeHandler; fn != nil {
|
||||
scope, err := fn(w, r)
|
||||
if err != nil {
|
||||
|
|
@ -217,7 +242,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
|
|||
}
|
||||
}
|
||||
|
||||
// specify the expiration time of access token
|
||||
// Specify the expiration time of access token.
|
||||
if fn := s.server.AccessTokenExpHandler; fn != nil {
|
||||
exp, err := fn(w, r)
|
||||
if err != nil {
|
||||
|
|
@ -231,13 +256,24 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
|
|||
return s.errorOrRedirect(err, w, req)
|
||||
}
|
||||
|
||||
// If the redirect URI is empty, the default domain provided by the client is used.
|
||||
// If the redirect URI is empty, use the
|
||||
// first of the client's redirect URIs.
|
||||
if req.RedirectURI == "" {
|
||||
client, err := s.server.Manager.GetClient(ctx, req.ClientID)
|
||||
if err != nil {
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
// Real error.
|
||||
err := gtserror.Newf("db error getting application with client id %s: %w", req.ClientID, err)
|
||||
return gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
if util.IsNil(client) {
|
||||
// Application just not found.
|
||||
return gtserror.NewErrorUnauthorized(err, HelpfulAdvice)
|
||||
}
|
||||
req.RedirectURI = client.GetDomain()
|
||||
|
||||
// This will panic if client is not a
|
||||
// *gtsmodel.Application, which is what we want.
|
||||
req.RedirectURI = client.(*gtsmodel.Application).RedirectURIs[0]
|
||||
}
|
||||
|
||||
uri, err := s.server.GetRedirectURI(req, s.server.GetAuthorizeData(req.ResponseType, ti))
|
||||
|
|
|
|||
|
|
@ -22,30 +22,32 @@ import (
|
|||
"errors"
|
||||
"time"
|
||||
|
||||
"codeberg.org/gruf/go-mutexes"
|
||||
"codeberg.org/superseriousbusiness/oauth2/v4"
|
||||
"codeberg.org/superseriousbusiness/oauth2/v4/models"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
)
|
||||
|
||||
// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend.
|
||||
type tokenStore struct {
|
||||
oauth2.TokenStore
|
||||
db db.DB
|
||||
state *state.State
|
||||
lastUsedLocks mutexes.MutexMap
|
||||
}
|
||||
|
||||
// newTokenStore returns a token store that satisfies the oauth2.TokenStore interface.
|
||||
//
|
||||
// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
|
||||
// the tokens in the DB once per minute and deletes any that have expired.
|
||||
func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore {
|
||||
ts := &tokenStore{
|
||||
db: db,
|
||||
}
|
||||
func newTokenStore(ctx context.Context, state *state.State) oauth2.TokenStore {
|
||||
ts := &tokenStore{state: state}
|
||||
|
||||
// set the token store to clean out expired tokens once per minute, or return if we're done
|
||||
// Set the token store to clean out expired tokens
|
||||
// once per minute, or return if we're done.
|
||||
go func(ctx context.Context, ts *tokenStore) {
|
||||
cleanloop:
|
||||
for {
|
||||
|
|
@ -64,25 +66,48 @@ func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore {
|
|||
return ts
|
||||
}
|
||||
|
||||
// sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so.
|
||||
// sweep clears out old tokens that have expired;
|
||||
// it should be run on a loop about once per minute or so.
|
||||
func (ts *tokenStore) sweep(ctx context.Context) error {
|
||||
// select *all* tokens from the db
|
||||
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
|
||||
tokens, err := ts.db.GetAllTokens(ctx)
|
||||
// Select *all* tokens from the db
|
||||
//
|
||||
// TODO: if this becomes expensive
|
||||
// (ie., there are fucking LOADS of
|
||||
// tokens) then figure out a better way.
|
||||
tokens, err := ts.state.DB.GetAllTokens(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// iterate through and remove expired tokens
|
||||
// Remove any expired tokens, bearing
|
||||
// in mind that zero time = no expiry.
|
||||
now := time.Now()
|
||||
for _, dbt := range tokens {
|
||||
// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So:
|
||||
// we only want to check if a token expired before now if the expiry time is *not zero*;
|
||||
// ie., if it's been explicity set.
|
||||
if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) {
|
||||
if err := ts.db.DeleteTokenByID(ctx, dbt.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, token := range tokens {
|
||||
var expired bool
|
||||
|
||||
switch {
|
||||
case !token.CodeExpiresAt.IsZero() && token.CodeExpiresAt.Before(now):
|
||||
log.Tracef(ctx, "code token %s is expired", token.ID)
|
||||
expired = true
|
||||
|
||||
case !token.RefreshExpiresAt.IsZero() && token.RefreshExpiresAt.Before(now):
|
||||
log.Tracef(ctx, "refresh token %s is expired", token.ID)
|
||||
expired = true
|
||||
|
||||
case !token.AccessExpiresAt.IsZero() && token.AccessExpiresAt.Before(now):
|
||||
log.Tracef(ctx, "access token %s is expired", token.ID)
|
||||
expired = true
|
||||
}
|
||||
|
||||
if !expired {
|
||||
// Token's
|
||||
// still good.
|
||||
continue
|
||||
}
|
||||
|
||||
if err := ts.state.DB.DeleteTokenByID(ctx, token.ID); err != nil {
|
||||
err := gtserror.Newf("db error expiring token %s: %w", token.ID, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -90,7 +115,6 @@ func (ts *tokenStore) sweep(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// Create creates and store the new token information.
|
||||
// For the original implementation, see https://codeberg.org/superseriousbusiness/oauth2/blob/master/store/token.go#L34
|
||||
func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
|
||||
t, ok := info.(*models.Token)
|
||||
if !ok {
|
||||
|
|
@ -99,55 +123,99 @@ func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
|
|||
|
||||
dbt := TokenToDBToken(t)
|
||||
if dbt.ID == "" {
|
||||
dbtID, err := id.NewRandomULID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dbt.ID = dbtID
|
||||
dbt.ID = id.NewULID()
|
||||
}
|
||||
|
||||
return ts.db.PutToken(ctx, dbt)
|
||||
return ts.state.DB.PutToken(ctx, dbt)
|
||||
}
|
||||
|
||||
// RemoveByCode deletes a token from the DB based on the Code field
|
||||
func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
|
||||
return ts.db.DeleteTokenByCode(ctx, code)
|
||||
return ts.state.DB.DeleteTokenByCode(ctx, code)
|
||||
}
|
||||
|
||||
// RemoveByAccess deletes a token from the DB based on the Access field
|
||||
func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
|
||||
return ts.db.DeleteTokenByAccess(ctx, access)
|
||||
return ts.state.DB.DeleteTokenByAccess(ctx, access)
|
||||
}
|
||||
|
||||
// RemoveByRefresh deletes a token from the DB based on the Refresh field
|
||||
func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
|
||||
return ts.db.DeleteTokenByRefresh(ctx, refresh)
|
||||
return ts.state.DB.DeleteTokenByRefresh(ctx, refresh)
|
||||
}
|
||||
|
||||
// GetByCode selects a token from the DB based on the Code field
|
||||
func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
|
||||
token, err := ts.db.GetTokenByCode(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DBTokenToToken(token), nil
|
||||
// GetByCode selects a token from
|
||||
// the DB based on the Code field
|
||||
func (ts *tokenStore) GetByCode(
|
||||
ctx context.Context,
|
||||
code string,
|
||||
) (oauth2.TokenInfo, error) {
|
||||
return ts.getUpdateToken(
|
||||
ctx,
|
||||
ts.state.DB.GetTokenByCode,
|
||||
code,
|
||||
)
|
||||
}
|
||||
|
||||
// GetByAccess selects a token from the DB based on the Access field
|
||||
func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
|
||||
token, err := ts.db.GetTokenByAccess(ctx, access)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DBTokenToToken(token), nil
|
||||
// GetByAccess selects a token from
|
||||
// the DB based on the Access field.
|
||||
func (ts *tokenStore) GetByAccess(
|
||||
ctx context.Context,
|
||||
access string,
|
||||
) (oauth2.TokenInfo, error) {
|
||||
return ts.getUpdateToken(
|
||||
ctx,
|
||||
ts.state.DB.GetTokenByAccess,
|
||||
access,
|
||||
)
|
||||
}
|
||||
|
||||
// GetByRefresh selects a token from the DB based on the Refresh field
|
||||
func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
|
||||
token, err := ts.db.GetTokenByRefresh(ctx, refresh)
|
||||
// GetByRefresh selects a token from
|
||||
// the DB based on the Refresh field
|
||||
func (ts *tokenStore) GetByRefresh(
|
||||
ctx context.Context,
|
||||
refresh string,
|
||||
) (oauth2.TokenInfo, error) {
|
||||
return ts.getUpdateToken(
|
||||
ctx,
|
||||
ts.state.DB.GetTokenByRefresh,
|
||||
refresh,
|
||||
)
|
||||
}
|
||||
|
||||
// package-internal function for getting a token
|
||||
// and potentially updating its last_used value.
|
||||
func (ts *tokenStore) getUpdateToken(
|
||||
ctx context.Context,
|
||||
getBy func(context.Context, string) (*gtsmodel.Token, error),
|
||||
key string,
|
||||
) (oauth2.TokenInfo, error) {
|
||||
// Hold a lock to get the token based on
|
||||
// whatever func + key we've been given.
|
||||
unlock := ts.lastUsedLocks.Lock(key)
|
||||
|
||||
token, err := getBy(ctx, key)
|
||||
if err != nil {
|
||||
// Unlock on error.
|
||||
unlock()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If token was last used more than
|
||||
// an hour ago, update this in the db.
|
||||
wasLastUsed := token.LastUsed
|
||||
if now := time.Now(); now.Sub(wasLastUsed) > 1*time.Hour {
|
||||
token.LastUsed = now
|
||||
if err := ts.state.DB.UpdateToken(ctx, token, "last_used"); err != nil {
|
||||
// Unlock on error.
|
||||
unlock()
|
||||
err := gtserror.Newf("error updating last_used on token: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// We're done, unlock.
|
||||
unlock()
|
||||
return DBTokenToToken(token), nil
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue