[feature] Refactor tokens, allow multiple app redirect_uris (#3849)

* [feature] Refactor tokens, allow multiple app redirect_uris

* move + tweak handlers a bit

* return error for unset oauth2.ClientStore funcs

* wrap UpdateToken with cache

* panic handling

* cheeky little time optimization

* unlock on error
This commit is contained in:
tobi 2025-03-03 16:03:36 +01:00 committed by GitHub
commit 1b37944f8b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
77 changed files with 963 additions and 594 deletions

View file

@ -22,30 +22,32 @@ import (
"errors"
"time"
"codeberg.org/gruf/go-mutexes"
"codeberg.org/superseriousbusiness/oauth2/v4"
"codeberg.org/superseriousbusiness/oauth2/v4/models"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
)
// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend.
type tokenStore struct {
oauth2.TokenStore
db db.DB
state *state.State
lastUsedLocks mutexes.MutexMap
}
// newTokenStore returns a token store that satisfies the oauth2.TokenStore interface.
//
// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
// the tokens in the DB once per minute and deletes any that have expired.
func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore {
ts := &tokenStore{
db: db,
}
func newTokenStore(ctx context.Context, state *state.State) oauth2.TokenStore {
ts := &tokenStore{state: state}
// set the token store to clean out expired tokens once per minute, or return if we're done
// Set the token store to clean out expired tokens
// once per minute, or return if we're done.
go func(ctx context.Context, ts *tokenStore) {
cleanloop:
for {
@ -64,25 +66,48 @@ func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore {
return ts
}
// sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so.
// sweep clears out old tokens that have expired;
// it should be run on a loop about once per minute or so.
func (ts *tokenStore) sweep(ctx context.Context) error {
// select *all* tokens from the db
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
tokens, err := ts.db.GetAllTokens(ctx)
// Select *all* tokens from the db
//
// TODO: if this becomes expensive
// (ie., there are fucking LOADS of
// tokens) then figure out a better way.
tokens, err := ts.state.DB.GetAllTokens(ctx)
if err != nil {
return err
}
// iterate through and remove expired tokens
// Remove any expired tokens, bearing
// in mind that zero time = no expiry.
now := time.Now()
for _, dbt := range tokens {
// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So:
// we only want to check if a token expired before now if the expiry time is *not zero*;
// ie., if it's been explicity set.
if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) {
if err := ts.db.DeleteTokenByID(ctx, dbt.ID); err != nil {
return err
}
for _, token := range tokens {
var expired bool
switch {
case !token.CodeExpiresAt.IsZero() && token.CodeExpiresAt.Before(now):
log.Tracef(ctx, "code token %s is expired", token.ID)
expired = true
case !token.RefreshExpiresAt.IsZero() && token.RefreshExpiresAt.Before(now):
log.Tracef(ctx, "refresh token %s is expired", token.ID)
expired = true
case !token.AccessExpiresAt.IsZero() && token.AccessExpiresAt.Before(now):
log.Tracef(ctx, "access token %s is expired", token.ID)
expired = true
}
if !expired {
// Token's
// still good.
continue
}
if err := ts.state.DB.DeleteTokenByID(ctx, token.ID); err != nil {
err := gtserror.Newf("db error expiring token %s: %w", token.ID, err)
return err
}
}
@ -90,7 +115,6 @@ func (ts *tokenStore) sweep(ctx context.Context) error {
}
// Create creates and store the new token information.
// For the original implementation, see https://codeberg.org/superseriousbusiness/oauth2/blob/master/store/token.go#L34
func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
t, ok := info.(*models.Token)
if !ok {
@ -99,55 +123,99 @@ func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
dbt := TokenToDBToken(t)
if dbt.ID == "" {
dbtID, err := id.NewRandomULID()
if err != nil {
return err
}
dbt.ID = dbtID
dbt.ID = id.NewULID()
}
return ts.db.PutToken(ctx, dbt)
return ts.state.DB.PutToken(ctx, dbt)
}
// RemoveByCode deletes a token from the DB based on the Code field
func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
return ts.db.DeleteTokenByCode(ctx, code)
return ts.state.DB.DeleteTokenByCode(ctx, code)
}
// RemoveByAccess deletes a token from the DB based on the Access field
func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
return ts.db.DeleteTokenByAccess(ctx, access)
return ts.state.DB.DeleteTokenByAccess(ctx, access)
}
// RemoveByRefresh deletes a token from the DB based on the Refresh field
func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
return ts.db.DeleteTokenByRefresh(ctx, refresh)
return ts.state.DB.DeleteTokenByRefresh(ctx, refresh)
}
// GetByCode selects a token from the DB based on the Code field
func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
token, err := ts.db.GetTokenByCode(ctx, code)
if err != nil {
return nil, err
}
return DBTokenToToken(token), nil
// GetByCode selects a token from
// the DB based on the Code field
func (ts *tokenStore) GetByCode(
ctx context.Context,
code string,
) (oauth2.TokenInfo, error) {
return ts.getUpdateToken(
ctx,
ts.state.DB.GetTokenByCode,
code,
)
}
// GetByAccess selects a token from the DB based on the Access field
func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
token, err := ts.db.GetTokenByAccess(ctx, access)
if err != nil {
return nil, err
}
return DBTokenToToken(token), nil
// GetByAccess selects a token from
// the DB based on the Access field.
func (ts *tokenStore) GetByAccess(
ctx context.Context,
access string,
) (oauth2.TokenInfo, error) {
return ts.getUpdateToken(
ctx,
ts.state.DB.GetTokenByAccess,
access,
)
}
// GetByRefresh selects a token from the DB based on the Refresh field
func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
token, err := ts.db.GetTokenByRefresh(ctx, refresh)
// GetByRefresh selects a token from
// the DB based on the Refresh field
func (ts *tokenStore) GetByRefresh(
ctx context.Context,
refresh string,
) (oauth2.TokenInfo, error) {
return ts.getUpdateToken(
ctx,
ts.state.DB.GetTokenByRefresh,
refresh,
)
}
// package-internal function for getting a token
// and potentially updating its last_used value.
func (ts *tokenStore) getUpdateToken(
ctx context.Context,
getBy func(context.Context, string) (*gtsmodel.Token, error),
key string,
) (oauth2.TokenInfo, error) {
// Hold a lock to get the token based on
// whatever func + key we've been given.
unlock := ts.lastUsedLocks.Lock(key)
token, err := getBy(ctx, key)
if err != nil {
// Unlock on error.
unlock()
return nil, err
}
// If token was last used more than
// an hour ago, update this in the db.
wasLastUsed := token.LastUsed
if now := time.Now(); now.Sub(wasLastUsed) > 1*time.Hour {
token.LastUsed = now
if err := ts.state.DB.UpdateToken(ctx, token, "last_used"); err != nil {
// Unlock on error.
unlock()
err := gtserror.Newf("error updating last_used on token: %w", err)
return nil, err
}
}
// We're done, unlock.
unlock()
return DBTokenToToken(token), nil
}