[feature] Refactor tokens, allow multiple app redirect_uris (#3849)

* [feature] Refactor tokens, allow multiple app redirect_uris

* move + tweak handlers a bit

* return error for unset oauth2.ClientStore funcs

* wrap UpdateToken with cache

* panic handling

* cheeky little time optimization

* unlock on error
This commit is contained in:
tobi 2025-03-03 16:03:36 +01:00 committed by GitHub
commit 1b37944f8b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
77 changed files with 963 additions and 594 deletions

View file

@ -21,45 +21,30 @@ import (
"context"
"codeberg.org/superseriousbusiness/oauth2/v4"
"codeberg.org/superseriousbusiness/oauth2/v4/models"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"codeberg.org/superseriousbusiness/oauth2/v4/errors"
"github.com/superseriousbusiness/gotosocial/internal/state"
)
type clientStore struct {
db db.DB
state *state.State
}
// NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend.
func NewClientStore(db db.DB) oauth2.ClientStore {
pts := &clientStore{
db: db,
}
return pts
// NewClientStore returns a minimal implementation of
// oauth2.ClientStore interface, using state as storage.
//
// Only GetByID is implemented, Set and Delete are stubs.
func NewClientStore(state *state.State) oauth2.ClientStore {
return &clientStore{state: state}
}
func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
client, err := cs.db.GetClientByID(ctx, clientID)
if err != nil {
return nil, err
}
return models.New(
client.ID,
client.Secret,
client.Domain,
client.UserID,
), nil
return cs.state.DB.GetApplicationByClientID(ctx, clientID)
}
func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error {
return cs.db.PutClient(ctx, &gtsmodel.Client{
ID: cli.GetID(),
Secret: cli.GetSecret(),
Domain: cli.GetDomain(),
UserID: cli.GetUserID(),
})
func (cs *clientStore) Set(_ context.Context, _ string, _ oauth2.ClientInfo) error {
return errors.New("func oauth2.ClientStore.Set not implemented")
}
func (cs *clientStore) Delete(ctx context.Context, id string) error {
return cs.db.DeleteClientByID(ctx, id)
func (cs *clientStore) Delete(_ context.Context, _ string) error {
return errors.New("func oauth2.ClientStore.Delete not implemented")
}

View file

@ -21,93 +21,58 @@ import (
"context"
"testing"
"codeberg.org/superseriousbusiness/oauth2/v4/models"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/admin"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type PgClientStoreTestSuite struct {
type ClientStoreTestSuite struct {
suite.Suite
db db.DB
state state.State
testClientID string
testClientSecret string
testClientDomain string
testClientUserID string
testApplications map[string]*gtsmodel.Application
}
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
func (suite *PgClientStoreTestSuite) SetupSuite() {
suite.testClientID = "01FCVB74EW6YBYAEY7QG9CQQF6"
suite.testClientSecret = "4cc87402-259b-4a35-9485-2c8bf54f3763"
suite.testClientDomain = "https://example.org"
suite.testClientUserID = "01FEGYXKVCDB731QF9MVFXA4F5"
func (suite *ClientStoreTestSuite) SetupSuite() {
suite.testApplications = testrig.NewTestApplications()
}
// SetupTest creates a postgres connection and creates the oauth_clients table before each test
func (suite *PgClientStoreTestSuite) SetupTest() {
func (suite *ClientStoreTestSuite) SetupTest() {
suite.state.Caches.Init()
testrig.InitTestLog()
testrig.InitTestConfig()
testrig.InitTestLog()
suite.db = testrig.NewTestDB(&suite.state)
suite.state.DB = suite.db
suite.state.AdminActions = admin.New(suite.state.DB, &suite.state.Workers)
testrig.StandardDBSetup(suite.db, nil)
}
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
func (suite *PgClientStoreTestSuite) TearDownTest() {
func (suite *ClientStoreTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
}
func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() {
// set a new client in the store
cs := oauth.NewClientStore(suite.db)
if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
suite.FailNow(err.Error())
}
func (suite *ClientStoreTestSuite) TestClientStoreGet() {
testApp := suite.testApplications["application_1"]
cs := oauth.NewClientStore(&suite.state)
// fetch that client from the store
client, err := cs.GetByID(context.Background(), suite.testClientID)
// Fetch clientInfo from the store.
clientInfo, err := cs.GetByID(context.Background(), testApp.ClientID)
if err != nil {
suite.FailNow(err.Error())
}
// check that the values are the same
suite.NotNil(client)
suite.EqualValues(models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID), client)
// Check expected values.
suite.NotNil(clientInfo)
suite.Equal(testApp.ClientID, clientInfo.GetID())
suite.Equal(testApp.ClientSecret, clientInfo.GetSecret())
suite.Equal(testApp.RedirectURIs[0], clientInfo.GetDomain())
suite.Equal(testApp.ManagedByUserID, clientInfo.GetUserID())
}
func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() {
// set a new client in the store
cs := oauth.NewClientStore(suite.db)
if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
suite.FailNow(err.Error())
}
// fetch the client from the store
client, err := cs.GetByID(context.Background(), suite.testClientID)
if err != nil {
suite.FailNow(err.Error())
}
// check that the values are the same
suite.NotNil(client)
suite.EqualValues(models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID), client)
if err := cs.Delete(context.Background(), suite.testClientID); err != nil {
suite.FailNow(err.Error())
}
// try to get the deleted client; we should get an error
deletedClient, err := cs.GetByID(context.Background(), suite.testClientID)
suite.Assert().Nil(deletedClient)
suite.Assert().EqualValues(db.ErrNoEntries, err)
}
func TestPgClientStoreTestSuite(t *testing.T) {
suite.Run(t, new(PgClientStoreTestSuite))
func TestClientStoreTestSuite(t *testing.T) {
suite.Run(t, new(ClientStoreTestSuite))
}

View file

@ -0,0 +1,153 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package handlers
import (
"context"
"errors"
"net/http"
"net/url"
"slices"
"strings"
"codeberg.org/superseriousbusiness/oauth2/v4"
oautherr "codeberg.org/superseriousbusiness/oauth2/v4/errors"
"codeberg.org/superseriousbusiness/oauth2/v4/manage"
"codeberg.org/superseriousbusiness/oauth2/v4/server"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
)
// GetClientScopeHandler returns a handler for testing scope on a TokenGenerateRequest.
func GetClientScopeHandler(ctx context.Context, state *state.State) server.ClientScopeHandler {
return func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) {
application, err := state.DB.GetApplicationByClientID(
gtscontext.SetBarebones(ctx),
tgr.ClientID,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
log.Errorf(ctx, "database error getting application: %v", err)
return false, err
}
if application == nil {
err := gtserror.Newf("no application found with client id %s", tgr.ClientID)
return false, err
}
// Normalize scope.
if strings.TrimSpace(tgr.Scope) == "" {
tgr.Scope = "read"
}
// Make sure requested scopes are all
// within scopes permitted by application.
hasScopes := strings.Split(application.Scopes, " ")
wantsScopes := strings.Split(tgr.Scope, " ")
for _, wantsScope := range wantsScopes {
thisOK := slices.ContainsFunc(
hasScopes,
func(hasScope string) bool {
has := apiutil.Scope(hasScope)
wants := apiutil.Scope(wantsScope)
return has.Permits(wants)
},
)
if !thisOK {
// Requested unpermitted
// scope for this app.
return false, nil
}
}
// All OK.
return true, nil
}
}
func GetValidateURIHandler(ctx context.Context) manage.ValidateURIHandler {
return func(hasRedirects string, wantsRedirect string) error {
// Normalize the wantsRedirect URI
// string by parsing + reserializing.
wantsRedirectURI, err := url.Parse(wantsRedirect)
if err != nil {
return err
}
wantsRedirect = wantsRedirectURI.String()
// Redirect URIs are given to us as
// a list of URIs, newline-separated.
//
// They're already normalized on input so
// we don't need to parse + reserialize them.
//
// Ensure that one of them matches.
if slices.ContainsFunc(
strings.Split(hasRedirects, "\n"),
func(hasRedirect string) bool {
// Want an exact match.
// See: https://www.oauth.com/oauth2-servers/redirect-uris/redirect-uri-validation/
return wantsRedirect == hasRedirect
},
) {
return nil
}
return oautherr.ErrInvalidRedirectURI
}
}
func GetAuthorizeScopeHandler() server.AuthorizeScopeHandler {
return func(_ http.ResponseWriter, r *http.Request) (string, error) {
// Use provided scope or
// fall back to default "read".
scope := r.FormValue("scope")
if strings.TrimSpace(scope) == "" {
scope = "read"
}
return scope, nil
}
}
func GetInternalErrorHandler(ctx context.Context) server.InternalErrorHandler {
return func(err error) *oautherr.Response {
log.Errorf(ctx, "internal oauth error: %v", err)
return nil
}
}
func GetResponseErrorHandler(ctx context.Context) server.ResponseErrorHandler {
return func(re *oautherr.Response) {
log.Errorf(ctx, "internal response error: %v", re.Error)
}
}
func GetUserAuthorizationHandler() server.UserAuthorizationHandler {
return func(w http.ResponseWriter, r *http.Request) (string, error) {
userID := r.FormValue("userid")
if userID == "" {
return "", errors.New("userid was empty")
}
return userID, nil
}
}

View file

@ -1,20 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package oauth_test
// TODO: write tests

View file

@ -30,7 +30,10 @@ import (
"codeberg.org/superseriousbusiness/oauth2/v4/server"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
const (
@ -60,7 +63,8 @@ const (
HelpfulAdviceGrant = "If you arrived at this error during a sign in/oauth flow, your client is trying to use an unsupported OAuth grant type. Supported grant types are: authorization_code, client_credentials; please reach out to developer of your client"
)
// Server wraps some oauth2 server functions in an interface, exposing only what is needed
// Server wraps some oauth2 server functions
// in an interface, exposing only what is needed.
type Server interface {
HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode)
HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtserror.WithCode
@ -69,66 +73,76 @@ type Server interface {
LoadAccessToken(ctx context.Context, access string) (accessToken oauth2.TokenInfo, err error)
}
// s fulfils the Server interface using the underlying oauth2 server
// s fulfils the Server interface
// using the underlying oauth2 server.
type s struct {
server *server.Server
}
// New returns a new oauth server that implements the Server interface
func New(ctx context.Context, database db.DB) Server {
ts := newTokenStore(ctx, database)
cs := NewClientStore(database)
func New(
ctx context.Context,
state *state.State,
validateURIHandler manage.ValidateURIHandler,
clientScopeHandler server.ClientScopeHandler,
authorizeScopeHandler server.AuthorizeScopeHandler,
internalErrorHandler server.InternalErrorHandler,
responseErrorHandler server.ResponseErrorHandler,
userAuthorizationHandler server.UserAuthorizationHandler,
) Server {
ts := newTokenStore(ctx, state)
cs := NewClientStore(state)
// Set up OAuth2 manager.
manager := manage.NewDefaultManager()
manager.SetValidateURIHandler(validateURIHandler)
manager.MapTokenStorage(ts)
manager.MapClientStorage(cs)
manager.SetAuthorizeCodeTokenCfg(&manage.Config{
AccessTokenExp: 0, // access tokens don't expire -- they must be revoked
IsGenerateRefresh: false, // don't use refresh tokens
})
sc := &server.Config{
TokenType: "Bearer",
// Must follow the spec.
AllowGetAccessRequest: false,
// Support only the non-implicit flow.
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
// Allow:
// - Authorization Code (for first & third parties)
// - Client Credentials (for applications)
AllowedGrantTypes: []oauth2.GrantType{
oauth2.AuthorizationCode,
oauth2.ClientCredentials,
manager.SetAuthorizeCodeTokenCfg(
&manage.Config{
// Following the Mastodon API,
// access tokens don't expire.
AccessTokenExp: 0,
// Don't use refresh tokens.
IsGenerateRefresh: false,
},
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{
oauth2.CodeChallengePlain,
oauth2.CodeChallengeS256,
)
// Set up OAuth2 server.
srv := server.NewServer(
&server.Config{
TokenType: "Bearer",
// Must follow the spec.
AllowGetAccessRequest: false,
// Support only the non-implicit flow.
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
// Allow:
// - Authorization Code (for first & third parties)
// - Client Credentials (for applications)
AllowedGrantTypes: []oauth2.GrantType{
oauth2.AuthorizationCode,
oauth2.ClientCredentials,
},
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{
oauth2.CodeChallengePlain,
oauth2.CodeChallengeS256,
},
},
}
srv := server.NewServer(sc, manager)
srv.SetInternalErrorHandler(func(err error) *oautherr.Response {
log.Errorf(nil, "internal oauth error: %s", err)
return nil
})
srv.SetResponseErrorHandler(func(re *oautherr.Response) {
log.Errorf(nil, "internal response error: %s", re.Error)
})
srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) {
userID := r.FormValue("userid")
if userID == "" {
return "", errors.New("userid was empty")
}
return userID, nil
})
manager,
)
srv.SetAuthorizeScopeHandler(authorizeScopeHandler)
srv.SetClientScopeHandler(clientScopeHandler)
srv.SetInternalErrorHandler(internalErrorHandler)
srv.SetResponseErrorHandler(responseErrorHandler)
srv.SetUserAuthorizationHandler(userAuthorizationHandler)
srv.SetClientInfoHandler(server.ClientFormHandler)
return &s{
server: srv,
}
return &s{srv}
}
// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function
// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function,
// providing some custom error handling (with more informative messages),
// and a slightly different token serialization format.
func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode) {
ctx := r.Context()
@ -142,32 +156,43 @@ func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserro
return nil, gtserror.NewErrorBadRequest(err, help, adv)
}
// Get access token + do our own nicer error handling.
ti, err := s.server.GetAccessToken(ctx, gt, tgr)
if err != nil {
help := fmt.Sprintf("could not get access token: %s", err)
switch {
case err == nil:
// No problem.
break
case errors.Is(err, oautherr.ErrInvalidScope):
help := fmt.Sprintf("requested scope %s was not covered by client scope", tgr.Scope)
return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice)
case errors.Is(err, oautherr.ErrInvalidRedirectURI):
help := fmt.Sprintf("requested redirect URI %s was not covered by client redirect URIs", tgr.RedirectURI)
return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice)
default:
help := fmt.Sprintf("could not get access token: %v", err)
return nil, gtserror.NewErrorBadRequest(err, help, HelpfulAdvice)
}
// Wrangle data a bit.
data := s.server.GetTokenData(ti)
// Add created_at for Mastodon API compatibility.
data["created_at"] = ti.GetAccessCreateAt().Unix()
// If expires_in is 0 or less, omit it
// from serialization so that clients don't
// interpret the token as already expired.
if expiresInI, ok := data["expires_in"]; ok {
switch expiresIn := expiresInI.(type) {
case int64:
// remove this key from the returned map
// if the value is 0 or less, so that clients
// don't interpret the token as already expired
if expiresIn <= 0 {
delete(data, "expires_in")
}
default:
err := errors.New("expires_in was set on token response, but was not an int64")
return nil, gtserror.NewErrorInternalError(err, HelpfulAdvice)
// This will panic if expiresIn is
// not an int64, which is what we want.
if expiresInI.(int64) <= 0 {
delete(data, "expires_in")
}
}
// add this for mastodon api compatibility
data["created_at"] = ti.GetAccessCreateAt().Unix()
return data, nil
}
@ -207,7 +232,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
}
req.UserID = userID
// specify the scope of authorization
// Specify the scope of authorization.
if fn := s.server.AuthorizeScopeHandler; fn != nil {
scope, err := fn(w, r)
if err != nil {
@ -217,7 +242,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
}
}
// specify the expiration time of access token
// Specify the expiration time of access token.
if fn := s.server.AccessTokenExpHandler; fn != nil {
exp, err := fn(w, r)
if err != nil {
@ -231,13 +256,24 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
return s.errorOrRedirect(err, w, req)
}
// If the redirect URI is empty, the default domain provided by the client is used.
// If the redirect URI is empty, use the
// first of the client's redirect URIs.
if req.RedirectURI == "" {
client, err := s.server.Manager.GetClient(ctx, req.ClientID)
if err != nil {
if err != nil && !errors.Is(err, db.ErrNoEntries) {
// Real error.
err := gtserror.Newf("db error getting application with client id %s: %w", req.ClientID, err)
return gtserror.NewErrorInternalError(err)
}
if util.IsNil(client) {
// Application just not found.
return gtserror.NewErrorUnauthorized(err, HelpfulAdvice)
}
req.RedirectURI = client.GetDomain()
// This will panic if client is not a
// *gtsmodel.Application, which is what we want.
req.RedirectURI = client.(*gtsmodel.Application).RedirectURIs[0]
}
uri, err := s.server.GetRedirectURI(req, s.server.GetAuthorizeData(req.ResponseType, ti))

View file

@ -22,30 +22,32 @@ import (
"errors"
"time"
"codeberg.org/gruf/go-mutexes"
"codeberg.org/superseriousbusiness/oauth2/v4"
"codeberg.org/superseriousbusiness/oauth2/v4/models"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
)
// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend.
type tokenStore struct {
oauth2.TokenStore
db db.DB
state *state.State
lastUsedLocks mutexes.MutexMap
}
// newTokenStore returns a token store that satisfies the oauth2.TokenStore interface.
//
// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
// the tokens in the DB once per minute and deletes any that have expired.
func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore {
ts := &tokenStore{
db: db,
}
func newTokenStore(ctx context.Context, state *state.State) oauth2.TokenStore {
ts := &tokenStore{state: state}
// set the token store to clean out expired tokens once per minute, or return if we're done
// Set the token store to clean out expired tokens
// once per minute, or return if we're done.
go func(ctx context.Context, ts *tokenStore) {
cleanloop:
for {
@ -64,25 +66,48 @@ func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore {
return ts
}
// sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so.
// sweep clears out old tokens that have expired;
// it should be run on a loop about once per minute or so.
func (ts *tokenStore) sweep(ctx context.Context) error {
// select *all* tokens from the db
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
tokens, err := ts.db.GetAllTokens(ctx)
// Select *all* tokens from the db
//
// TODO: if this becomes expensive
// (ie., there are fucking LOADS of
// tokens) then figure out a better way.
tokens, err := ts.state.DB.GetAllTokens(ctx)
if err != nil {
return err
}
// iterate through and remove expired tokens
// Remove any expired tokens, bearing
// in mind that zero time = no expiry.
now := time.Now()
for _, dbt := range tokens {
// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So:
// we only want to check if a token expired before now if the expiry time is *not zero*;
// ie., if it's been explicity set.
if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) {
if err := ts.db.DeleteTokenByID(ctx, dbt.ID); err != nil {
return err
}
for _, token := range tokens {
var expired bool
switch {
case !token.CodeExpiresAt.IsZero() && token.CodeExpiresAt.Before(now):
log.Tracef(ctx, "code token %s is expired", token.ID)
expired = true
case !token.RefreshExpiresAt.IsZero() && token.RefreshExpiresAt.Before(now):
log.Tracef(ctx, "refresh token %s is expired", token.ID)
expired = true
case !token.AccessExpiresAt.IsZero() && token.AccessExpiresAt.Before(now):
log.Tracef(ctx, "access token %s is expired", token.ID)
expired = true
}
if !expired {
// Token's
// still good.
continue
}
if err := ts.state.DB.DeleteTokenByID(ctx, token.ID); err != nil {
err := gtserror.Newf("db error expiring token %s: %w", token.ID, err)
return err
}
}
@ -90,7 +115,6 @@ func (ts *tokenStore) sweep(ctx context.Context) error {
}
// Create creates and store the new token information.
// For the original implementation, see https://codeberg.org/superseriousbusiness/oauth2/blob/master/store/token.go#L34
func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
t, ok := info.(*models.Token)
if !ok {
@ -99,55 +123,99 @@ func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
dbt := TokenToDBToken(t)
if dbt.ID == "" {
dbtID, err := id.NewRandomULID()
if err != nil {
return err
}
dbt.ID = dbtID
dbt.ID = id.NewULID()
}
return ts.db.PutToken(ctx, dbt)
return ts.state.DB.PutToken(ctx, dbt)
}
// RemoveByCode deletes a token from the DB based on the Code field
func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
return ts.db.DeleteTokenByCode(ctx, code)
return ts.state.DB.DeleteTokenByCode(ctx, code)
}
// RemoveByAccess deletes a token from the DB based on the Access field
func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
return ts.db.DeleteTokenByAccess(ctx, access)
return ts.state.DB.DeleteTokenByAccess(ctx, access)
}
// RemoveByRefresh deletes a token from the DB based on the Refresh field
func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
return ts.db.DeleteTokenByRefresh(ctx, refresh)
return ts.state.DB.DeleteTokenByRefresh(ctx, refresh)
}
// GetByCode selects a token from the DB based on the Code field
func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
token, err := ts.db.GetTokenByCode(ctx, code)
if err != nil {
return nil, err
}
return DBTokenToToken(token), nil
// GetByCode selects a token from
// the DB based on the Code field
func (ts *tokenStore) GetByCode(
ctx context.Context,
code string,
) (oauth2.TokenInfo, error) {
return ts.getUpdateToken(
ctx,
ts.state.DB.GetTokenByCode,
code,
)
}
// GetByAccess selects a token from the DB based on the Access field
func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
token, err := ts.db.GetTokenByAccess(ctx, access)
if err != nil {
return nil, err
}
return DBTokenToToken(token), nil
// GetByAccess selects a token from
// the DB based on the Access field.
func (ts *tokenStore) GetByAccess(
ctx context.Context,
access string,
) (oauth2.TokenInfo, error) {
return ts.getUpdateToken(
ctx,
ts.state.DB.GetTokenByAccess,
access,
)
}
// GetByRefresh selects a token from the DB based on the Refresh field
func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
token, err := ts.db.GetTokenByRefresh(ctx, refresh)
// GetByRefresh selects a token from
// the DB based on the Refresh field
func (ts *tokenStore) GetByRefresh(
ctx context.Context,
refresh string,
) (oauth2.TokenInfo, error) {
return ts.getUpdateToken(
ctx,
ts.state.DB.GetTokenByRefresh,
refresh,
)
}
// package-internal function for getting a token
// and potentially updating its last_used value.
func (ts *tokenStore) getUpdateToken(
ctx context.Context,
getBy func(context.Context, string) (*gtsmodel.Token, error),
key string,
) (oauth2.TokenInfo, error) {
// Hold a lock to get the token based on
// whatever func + key we've been given.
unlock := ts.lastUsedLocks.Lock(key)
token, err := getBy(ctx, key)
if err != nil {
// Unlock on error.
unlock()
return nil, err
}
// If token was last used more than
// an hour ago, update this in the db.
wasLastUsed := token.LastUsed
if now := time.Now(); now.Sub(wasLastUsed) > 1*time.Hour {
token.LastUsed = now
if err := ts.state.DB.UpdateToken(ctx, token, "last_used"); err != nil {
// Unlock on error.
unlock()
err := gtserror.Newf("error updating last_used on token: %w", err)
return nil, err
}
}
// We're done, unlock.
unlock()
return DBTokenToToken(token), nil
}