diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go index 1a600d620..562614e5e 100644 --- a/internal/db/bundb/application.go +++ b/internal/db/bundb/application.go @@ -19,8 +19,10 @@ package bundb import ( "context" + "errors" "slices" + "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/paging" @@ -409,8 +411,11 @@ func (a *applicationDB) UpdateToken(ctx context.Context, token *gtsmodel.Token, } func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error { + var token gtsmodel.Token + token.ID = id + _, err := a.db.NewDelete(). - Table("tokens"). + Model(&token). Where("? = ?", bun.Ident("id"), id). Exec(ctx) if err != nil { @@ -418,68 +423,79 @@ func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error { } a.state.Caches.DB.Token.Invalidate("ID", id) + a.state.Caches.OnInvalidateToken(&token) return nil } func (a *applicationDB) DeleteTokenByCode(ctx context.Context, code string) error { + var token gtsmodel.Token + _, err := a.db.NewDelete(). - Table("tokens"). + Model(&token). Where("? = ?", bun.Ident("code"), code). + Returning("?", bun.Ident("id")). Exec(ctx) - if err != nil { + if err != nil && !errors.Is(err, db.ErrNoEntries) { return err } a.state.Caches.DB.Token.Invalidate("Code", code) + a.state.Caches.OnInvalidateToken(&token) return nil } func (a *applicationDB) DeleteTokenByAccess(ctx context.Context, access string) error { + var token gtsmodel.Token + _, err := a.db.NewDelete(). - Table("tokens"). + Model(&token). Where("? = ?", bun.Ident("access"), access). + Returning("?", bun.Ident("id")). Exec(ctx) - if err != nil { + if err != nil && !errors.Is(err, db.ErrNoEntries) { return err } a.state.Caches.DB.Token.Invalidate("Access", access) + a.state.Caches.OnInvalidateToken(&token) return nil } func (a *applicationDB) DeleteTokenByRefresh(ctx context.Context, refresh string) error { + var token gtsmodel.Token + _, err := a.db.NewDelete(). - Table("tokens"). + Model(&token). Where("? = ?", bun.Ident("refresh"), refresh). + Returning("?", bun.Ident("id")). Exec(ctx) - if err != nil { + if err != nil && !errors.Is(err, db.ErrNoEntries) { return err } a.state.Caches.DB.Token.Invalidate("Refresh", refresh) + a.state.Caches.OnInvalidateToken(&token) return nil } func (a *applicationDB) DeleteTokensByClientID(ctx context.Context, clientID string) error { + var tokens []*gtsmodel.Token + // Delete tokens owned by // clientID and gather token IDs. - var tokenIDs []string - if _, err := a.db. - NewDelete(). - Table("tokens"). + if _, err := a.db.NewDelete(). + Model(&tokens). Where("? = ?", bun.Ident("client_id"), clientID). - Returning("id"). - Exec(ctx, &tokenIDs); err != nil { + Returning("?", bun.Ident("id")). + Exec(ctx); err != nil && !errors.Is(err, db.ErrNoEntries) { return err } - if len(tokenIDs) == 0 { - // Nothing was deleted, - // nothing to invalidate. - return nil + // Invalidate all deleted tokens. + for _, token := range tokens { + a.state.Caches.DB.Token.Invalidate("ID", token.ID) + a.state.Caches.OnInvalidateToken(token) } - // Invalidate all deleted tokens. - a.state.Caches.DB.Token.InvalidateIDs("ID", tokenIDs) return nil }