mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 14:42:26 -05:00 
			
		
		
		
	ensure OnInvalidateToken() hook is called during token delete
This commit is contained in:
		
					parent
					
						
							
								ab8cdd5df5
							
						
					
				
			
			
				commit
				
					
						10cf211c2c
					
				
			
		
					 1 changed files with 35 additions and 19 deletions
				
			
		|  | @ -19,8 +19,10 @@ package bundb | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"errors" | ||||||
| 	"slices" | 	"slices" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/paging" | 	"github.com/superseriousbusiness/gotosocial/internal/paging" | ||||||
|  | @ -409,8 +411,11 @@ func (a *applicationDB) UpdateToken(ctx context.Context, token *gtsmodel.Token, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error { | func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error { | ||||||
|  | 	var token gtsmodel.Token | ||||||
|  | 	token.ID = id | ||||||
|  | 
 | ||||||
| 	_, err := a.db.NewDelete(). | 	_, err := a.db.NewDelete(). | ||||||
| 		Table("tokens"). | 		Model(&token). | ||||||
| 		Where("? = ?", bun.Ident("id"), id). | 		Where("? = ?", bun.Ident("id"), id). | ||||||
| 		Exec(ctx) | 		Exec(ctx) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -418,68 +423,79 @@ func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	a.state.Caches.DB.Token.Invalidate("ID", id) | 	a.state.Caches.DB.Token.Invalidate("ID", id) | ||||||
|  | 	a.state.Caches.OnInvalidateToken(&token) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (a *applicationDB) DeleteTokenByCode(ctx context.Context, code string) error { | func (a *applicationDB) DeleteTokenByCode(ctx context.Context, code string) error { | ||||||
|  | 	var token gtsmodel.Token | ||||||
|  | 
 | ||||||
| 	_, err := a.db.NewDelete(). | 	_, err := a.db.NewDelete(). | ||||||
| 		Table("tokens"). | 		Model(&token). | ||||||
| 		Where("? = ?", bun.Ident("code"), code). | 		Where("? = ?", bun.Ident("code"), code). | ||||||
|  | 		Returning("?", bun.Ident("id")). | ||||||
| 		Exec(ctx) | 		Exec(ctx) | ||||||
| 	if err != nil { | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	a.state.Caches.DB.Token.Invalidate("Code", code) | 	a.state.Caches.DB.Token.Invalidate("Code", code) | ||||||
|  | 	a.state.Caches.OnInvalidateToken(&token) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (a *applicationDB) DeleteTokenByAccess(ctx context.Context, access string) error { | func (a *applicationDB) DeleteTokenByAccess(ctx context.Context, access string) error { | ||||||
|  | 	var token gtsmodel.Token | ||||||
|  | 
 | ||||||
| 	_, err := a.db.NewDelete(). | 	_, err := a.db.NewDelete(). | ||||||
| 		Table("tokens"). | 		Model(&token). | ||||||
| 		Where("? = ?", bun.Ident("access"), access). | 		Where("? = ?", bun.Ident("access"), access). | ||||||
|  | 		Returning("?", bun.Ident("id")). | ||||||
| 		Exec(ctx) | 		Exec(ctx) | ||||||
| 	if err != nil { | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	a.state.Caches.DB.Token.Invalidate("Access", access) | 	a.state.Caches.DB.Token.Invalidate("Access", access) | ||||||
|  | 	a.state.Caches.OnInvalidateToken(&token) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (a *applicationDB) DeleteTokenByRefresh(ctx context.Context, refresh string) error { | func (a *applicationDB) DeleteTokenByRefresh(ctx context.Context, refresh string) error { | ||||||
|  | 	var token gtsmodel.Token | ||||||
|  | 
 | ||||||
| 	_, err := a.db.NewDelete(). | 	_, err := a.db.NewDelete(). | ||||||
| 		Table("tokens"). | 		Model(&token). | ||||||
| 		Where("? = ?", bun.Ident("refresh"), refresh). | 		Where("? = ?", bun.Ident("refresh"), refresh). | ||||||
|  | 		Returning("?", bun.Ident("id")). | ||||||
| 		Exec(ctx) | 		Exec(ctx) | ||||||
| 	if err != nil { | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	a.state.Caches.DB.Token.Invalidate("Refresh", refresh) | 	a.state.Caches.DB.Token.Invalidate("Refresh", refresh) | ||||||
|  | 	a.state.Caches.OnInvalidateToken(&token) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (a *applicationDB) DeleteTokensByClientID(ctx context.Context, clientID string) error { | func (a *applicationDB) DeleteTokensByClientID(ctx context.Context, clientID string) error { | ||||||
|  | 	var tokens []*gtsmodel.Token | ||||||
|  | 
 | ||||||
| 	// Delete tokens owned by | 	// Delete tokens owned by | ||||||
| 	// clientID and gather token IDs. | 	// clientID and gather token IDs. | ||||||
| 	var tokenIDs []string | 	if _, err := a.db.NewDelete(). | ||||||
| 	if _, err := a.db. | 		Model(&tokens). | ||||||
| 		NewDelete(). |  | ||||||
| 		Table("tokens"). |  | ||||||
| 		Where("? = ?", bun.Ident("client_id"), clientID). | 		Where("? = ?", bun.Ident("client_id"), clientID). | ||||||
| 		Returning("id"). | 		Returning("?", bun.Ident("id")). | ||||||
| 		Exec(ctx, &tokenIDs); err != nil { | 		Exec(ctx); err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if len(tokenIDs) == 0 { | 	// Invalidate all deleted tokens. | ||||||
| 		// Nothing was deleted, | 	for _, token := range tokens { | ||||||
| 		// nothing to invalidate. | 		a.state.Caches.DB.Token.Invalidate("ID", token.ID) | ||||||
| 		return nil | 		a.state.Caches.OnInvalidateToken(token) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Invalidate all deleted tokens. |  | ||||||
| 	a.state.Caches.DB.Token.InvalidateIDs("ID", tokenIDs) |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue