ensure OnInvalidateToken() hook is called during token delete

This commit is contained in:
kim 2025-04-18 15:53:28 +01:00
commit 10cf211c2c

View file

@ -19,8 +19,10 @@ package bundb
import ( import (
"context" "context"
"errors"
"slices" "slices"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/paging"
@ -409,8 +411,11 @@ func (a *applicationDB) UpdateToken(ctx context.Context, token *gtsmodel.Token,
} }
func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error { func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error {
var token gtsmodel.Token
token.ID = id
_, err := a.db.NewDelete(). _, err := a.db.NewDelete().
Table("tokens"). Model(&token).
Where("? = ?", bun.Ident("id"), id). Where("? = ?", bun.Ident("id"), id).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil {
@ -418,68 +423,79 @@ func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error {
} }
a.state.Caches.DB.Token.Invalidate("ID", id) a.state.Caches.DB.Token.Invalidate("ID", id)
a.state.Caches.OnInvalidateToken(&token)
return nil return nil
} }
func (a *applicationDB) DeleteTokenByCode(ctx context.Context, code string) error { func (a *applicationDB) DeleteTokenByCode(ctx context.Context, code string) error {
var token gtsmodel.Token
_, err := a.db.NewDelete(). _, err := a.db.NewDelete().
Table("tokens"). Model(&token).
Where("? = ?", bun.Ident("code"), code). Where("? = ?", bun.Ident("code"), code).
Returning("?", bun.Ident("id")).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err return err
} }
a.state.Caches.DB.Token.Invalidate("Code", code) a.state.Caches.DB.Token.Invalidate("Code", code)
a.state.Caches.OnInvalidateToken(&token)
return nil return nil
} }
func (a *applicationDB) DeleteTokenByAccess(ctx context.Context, access string) error { func (a *applicationDB) DeleteTokenByAccess(ctx context.Context, access string) error {
var token gtsmodel.Token
_, err := a.db.NewDelete(). _, err := a.db.NewDelete().
Table("tokens"). Model(&token).
Where("? = ?", bun.Ident("access"), access). Where("? = ?", bun.Ident("access"), access).
Returning("?", bun.Ident("id")).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err return err
} }
a.state.Caches.DB.Token.Invalidate("Access", access) a.state.Caches.DB.Token.Invalidate("Access", access)
a.state.Caches.OnInvalidateToken(&token)
return nil return nil
} }
func (a *applicationDB) DeleteTokenByRefresh(ctx context.Context, refresh string) error { func (a *applicationDB) DeleteTokenByRefresh(ctx context.Context, refresh string) error {
var token gtsmodel.Token
_, err := a.db.NewDelete(). _, err := a.db.NewDelete().
Table("tokens"). Model(&token).
Where("? = ?", bun.Ident("refresh"), refresh). Where("? = ?", bun.Ident("refresh"), refresh).
Returning("?", bun.Ident("id")).
Exec(ctx) Exec(ctx)
if err != nil { if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err return err
} }
a.state.Caches.DB.Token.Invalidate("Refresh", refresh) a.state.Caches.DB.Token.Invalidate("Refresh", refresh)
a.state.Caches.OnInvalidateToken(&token)
return nil return nil
} }
func (a *applicationDB) DeleteTokensByClientID(ctx context.Context, clientID string) error { func (a *applicationDB) DeleteTokensByClientID(ctx context.Context, clientID string) error {
var tokens []*gtsmodel.Token
// Delete tokens owned by // Delete tokens owned by
// clientID and gather token IDs. // clientID and gather token IDs.
var tokenIDs []string if _, err := a.db.NewDelete().
if _, err := a.db. Model(&tokens).
NewDelete().
Table("tokens").
Where("? = ?", bun.Ident("client_id"), clientID). Where("? = ?", bun.Ident("client_id"), clientID).
Returning("id"). Returning("?", bun.Ident("id")).
Exec(ctx, &tokenIDs); err != nil { Exec(ctx); err != nil && !errors.Is(err, db.ErrNoEntries) {
return err return err
} }
if len(tokenIDs) == 0 { // Invalidate all deleted tokens.
// Nothing was deleted, for _, token := range tokens {
// nothing to invalidate. a.state.Caches.DB.Token.Invalidate("ID", token.ID)
return nil a.state.Caches.OnInvalidateToken(token)
} }
// Invalidate all deleted tokens.
a.state.Caches.DB.Token.InvalidateIDs("ID", tokenIDs)
return nil return nil
} }