mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-30 22:32:25 -05:00 
			
		
		
		
	[performance] cached oauth database types (#2838)
* update token + client code to use struct caches * add code comments * slight tweak to default mem ratios * fix envparsing * add appropriate invalidate hooks * update the tokenstore sweeping function to rely on caches * update to use PutClient() * add ClientID to list of token struct indices
This commit is contained in:
		
					parent
					
						
							
								8b30709791
							
						
					
				
			
			
				commit
				
					
						f79d50b9b2
					
				
			
		
					 18 changed files with 428 additions and 67 deletions
				
			
		|  | @ -98,8 +98,8 @@ var Start action.GTSAction = func(ctx context.Context) error { | ||||||
| 	testrig.StandardStorageSetup(state.Storage, "./testrig/media") | 	testrig.StandardStorageSetup(state.Storage, "./testrig/media") | ||||||
| 
 | 
 | ||||||
| 	// Initialize workers. | 	// Initialize workers. | ||||||
| 	state.Workers.Start() | 	testrig.StartNoopWorkers(&state) | ||||||
| 	defer state.Workers.Stop() | 	defer testrig.StopWorkers(&state) | ||||||
| 
 | 
 | ||||||
| 	// build backend handlers | 	// build backend handlers | ||||||
| 	transportController := testrig.NewTestTransportController(&state, testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) { | 	transportController := testrig.NewTestTransportController(&state, testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) { | ||||||
|  |  | ||||||
|  | @ -49,7 +49,7 @@ func (m *Module) TokenPOSTHandler(c *gin.Context) { | ||||||
| 
 | 
 | ||||||
| 	form := &tokenRequestForm{} | 	form := &tokenRequestForm{} | ||||||
| 	if err := c.ShouldBind(form); err != nil { | 	if err := c.ShouldBind(form); err != nil { | ||||||
| 		apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), err.Error())) | 		apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.ErrInvalidRequest, err.Error())) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -98,7 +98,7 @@ func (m *Module) TokenPOSTHandler(c *gin.Context) { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if len(help) != 0 { | 	if len(help) != 0 { | ||||||
| 		apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), help...)) | 		apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.ErrInvalidRequest, help...)) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										4
									
								
								internal/cache/cache.go
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								internal/cache/cache.go
									
										
									
									
										vendored
									
									
								
							|  | @ -59,6 +59,7 @@ func (c *Caches) Init() { | ||||||
| 	c.initBlock() | 	c.initBlock() | ||||||
| 	c.initBlockIDs() | 	c.initBlockIDs() | ||||||
| 	c.initBoostOfIDs() | 	c.initBoostOfIDs() | ||||||
|  | 	c.initClient() | ||||||
| 	c.initDomainAllow() | 	c.initDomainAllow() | ||||||
| 	c.initDomainBlock() | 	c.initDomainBlock() | ||||||
| 	c.initEmoji() | 	c.initEmoji() | ||||||
|  | @ -85,9 +86,10 @@ func (c *Caches) Init() { | ||||||
| 	c.initReport() | 	c.initReport() | ||||||
| 	c.initStatus() | 	c.initStatus() | ||||||
| 	c.initStatusFave() | 	c.initStatusFave() | ||||||
|  | 	c.initStatusFaveIDs() | ||||||
| 	c.initTag() | 	c.initTag() | ||||||
| 	c.initThreadMute() | 	c.initThreadMute() | ||||||
| 	c.initStatusFaveIDs() | 	c.initToken() | ||||||
| 	c.initTombstone() | 	c.initTombstone() | ||||||
| 	c.initUser() | 	c.initUser() | ||||||
| 	c.initWebfinger() | 	c.initWebfinger() | ||||||
|  |  | ||||||
							
								
								
									
										64
									
								
								internal/cache/db.go
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										64
									
								
								internal/cache/db.go
									
										
									
									
										vendored
									
									
								
							|  | @ -58,6 +58,9 @@ type GTSCaches struct { | ||||||
| 	// BoostOfIDs provides access to the boost of IDs list database cache. | 	// BoostOfIDs provides access to the boost of IDs list database cache. | ||||||
| 	BoostOfIDs SliceCache[string] | 	BoostOfIDs SliceCache[string] | ||||||
| 
 | 
 | ||||||
|  | 	// Client provides access to the gtsmodel Client database cache. | ||||||
|  | 	Client StructCache[*gtsmodel.Client] | ||||||
|  | 
 | ||||||
| 	// DomainAllow provides access to the domain allow database cache. | 	// DomainAllow provides access to the domain allow database cache. | ||||||
| 	DomainAllow *domain.Cache | 	DomainAllow *domain.Cache | ||||||
| 
 | 
 | ||||||
|  | @ -150,6 +153,9 @@ type GTSCaches struct { | ||||||
| 	// Tag provides access to the gtsmodel Tag database cache. | 	// Tag provides access to the gtsmodel Tag database cache. | ||||||
| 	Tag StructCache[*gtsmodel.Tag] | 	Tag StructCache[*gtsmodel.Tag] | ||||||
| 
 | 
 | ||||||
|  | 	// Token provides access to the gtsmodel Token database cache. | ||||||
|  | 	Token StructCache[*gtsmodel.Token] | ||||||
|  | 
 | ||||||
| 	// Tombstone provides access to the gtsmodel Tombstone database cache. | 	// Tombstone provides access to the gtsmodel Tombstone database cache. | ||||||
| 	Tombstone StructCache[*gtsmodel.Tombstone] | 	Tombstone StructCache[*gtsmodel.Tombstone] | ||||||
| 
 | 
 | ||||||
|  | @ -312,6 +318,7 @@ func (c *Caches) initApplication() { | ||||||
| 		MaxSize:    cap, | 		MaxSize:    cap, | ||||||
| 		IgnoreErr:  ignoreErrors, | 		IgnoreErr:  ignoreErrors, | ||||||
| 		Copy:       copyF, | 		Copy:       copyF, | ||||||
|  | 		Invalidate: c.OnInvalidateApplication, | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -374,6 +381,32 @@ func (c *Caches) initBoostOfIDs() { | ||||||
| 	c.GTS.BoostOfIDs.Init(0, cap) | 	c.GTS.BoostOfIDs.Init(0, cap) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (c *Caches) initClient() { | ||||||
|  | 	// Calculate maximum cache size. | ||||||
|  | 	cap := calculateResultCacheMax( | ||||||
|  | 		sizeofClient(), // model in-mem size. | ||||||
|  | 		config.GetCacheClientMemRatio(), | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	log.Infof(nil, "cache size = %d", cap) | ||||||
|  | 
 | ||||||
|  | 	copyF := func(c1 *gtsmodel.Client) *gtsmodel.Client { | ||||||
|  | 		c2 := new(gtsmodel.Client) | ||||||
|  | 		*c2 = *c1 | ||||||
|  | 		return c2 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	c.GTS.Client.Init(structr.CacheConfig[*gtsmodel.Client]{ | ||||||
|  | 		Indices: []structr.IndexConfig{ | ||||||
|  | 			{Fields: "ID"}, | ||||||
|  | 		}, | ||||||
|  | 		MaxSize:    cap, | ||||||
|  | 		IgnoreErr:  ignoreErrors, | ||||||
|  | 		Copy:       copyF, | ||||||
|  | 		Invalidate: c.OnInvalidateClient, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (c *Caches) initDomainAllow() { | func (c *Caches) initDomainAllow() { | ||||||
| 	c.GTS.DomainAllow = new(domain.Cache) | 	c.GTS.DomainAllow = new(domain.Cache) | ||||||
| } | } | ||||||
|  | @ -1135,7 +1168,7 @@ func (c *Caches) initTag() { | ||||||
| 
 | 
 | ||||||
| func (c *Caches) initThreadMute() { | func (c *Caches) initThreadMute() { | ||||||
| 	cap := calculateResultCacheMax( | 	cap := calculateResultCacheMax( | ||||||
| 		sizeOfThreadMute(), // model in-mem size. | 		sizeofThreadMute(), // model in-mem size. | ||||||
| 		config.GetCacheThreadMuteMemRatio(), | 		config.GetCacheThreadMuteMemRatio(), | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
|  | @ -1160,6 +1193,35 @@ func (c *Caches) initThreadMute() { | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (c *Caches) initToken() { | ||||||
|  | 	// Calculate maximum cache size. | ||||||
|  | 	cap := calculateResultCacheMax( | ||||||
|  | 		sizeofToken(), // model in-mem size. | ||||||
|  | 		config.GetCacheTokenMemRatio(), | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	log.Infof(nil, "cache size = %d", cap) | ||||||
|  | 
 | ||||||
|  | 	copyF := func(t1 *gtsmodel.Token) *gtsmodel.Token { | ||||||
|  | 		t2 := new(gtsmodel.Token) | ||||||
|  | 		*t2 = *t1 | ||||||
|  | 		return t2 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	c.GTS.Token.Init(structr.CacheConfig[*gtsmodel.Token]{ | ||||||
|  | 		Indices: []structr.IndexConfig{ | ||||||
|  | 			{Fields: "ID"}, | ||||||
|  | 			{Fields: "Code"}, | ||||||
|  | 			{Fields: "Access"}, | ||||||
|  | 			{Fields: "Refresh"}, | ||||||
|  | 			{Fields: "ClientID", Multiple: true}, | ||||||
|  | 		}, | ||||||
|  | 		MaxSize:   cap, | ||||||
|  | 		IgnoreErr: ignoreErrors, | ||||||
|  | 		Copy:      copyF, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (c *Caches) initTombstone() { | func (c *Caches) initTombstone() { | ||||||
| 	// Calculate maximum cache size. | 	// Calculate maximum cache size. | ||||||
| 	cap := calculateResultCacheMax( | 	cap := calculateResultCacheMax( | ||||||
|  |  | ||||||
							
								
								
									
										10
									
								
								internal/cache/invalidate.go
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										10
									
								
								internal/cache/invalidate.go
									
										
									
									
										vendored
									
									
								
							|  | @ -60,6 +60,11 @@ func (c *Caches) OnInvalidateAccount(account *gtsmodel.Account) { | ||||||
| 	c.GTS.Move.Invalidate("TargetURI", account.URI) | 	c.GTS.Move.Invalidate("TargetURI", account.URI) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (c *Caches) OnInvalidateApplication(app *gtsmodel.Application) { | ||||||
|  | 	// Invalidate cached client of this application. | ||||||
|  | 	c.GTS.Client.Invalidate("ID", app.ClientID) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (c *Caches) OnInvalidateBlock(block *gtsmodel.Block) { | func (c *Caches) OnInvalidateBlock(block *gtsmodel.Block) { | ||||||
| 	// Invalidate block origin account ID cached visibility. | 	// Invalidate block origin account ID cached visibility. | ||||||
| 	c.Visibility.Invalidate("ItemID", block.AccountID) | 	c.Visibility.Invalidate("ItemID", block.AccountID) | ||||||
|  | @ -73,6 +78,11 @@ func (c *Caches) OnInvalidateBlock(block *gtsmodel.Block) { | ||||||
| 	c.GTS.BlockIDs.Invalidate(block.AccountID) | 	c.GTS.BlockIDs.Invalidate(block.AccountID) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (c *Caches) OnInvalidateClient(client *gtsmodel.Client) { | ||||||
|  | 	// Invalidate any tokens under this client. | ||||||
|  | 	c.GTS.Token.Invalidate("ClientID", client.ID) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (c *Caches) OnInvalidateEmojiCategory(category *gtsmodel.EmojiCategory) { | func (c *Caches) OnInvalidateEmojiCategory(category *gtsmodel.EmojiCategory) { | ||||||
| 	// Invalidate any emoji in this category. | 	// Invalidate any emoji in this category. | ||||||
| 	c.GTS.Emoji.Invalidate("CategoryID", category.ID) | 	c.GTS.Emoji.Invalidate("CategoryID", category.ID) | ||||||
|  |  | ||||||
							
								
								
									
										38
									
								
								internal/cache/size.go
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										38
									
								
								internal/cache/size.go
									
										
									
									
										vendored
									
									
								
							|  | @ -176,6 +176,7 @@ func totalOfRatios() float64 { | ||||||
| 		config.GetCacheBlockMemRatio() + | 		config.GetCacheBlockMemRatio() + | ||||||
| 		config.GetCacheBlockIDsMemRatio() + | 		config.GetCacheBlockIDsMemRatio() + | ||||||
| 		config.GetCacheBoostOfIDsMemRatio() + | 		config.GetCacheBoostOfIDsMemRatio() + | ||||||
|  | 		config.GetCacheClientMemRatio() + | ||||||
| 		config.GetCacheEmojiMemRatio() + | 		config.GetCacheEmojiMemRatio() + | ||||||
| 		config.GetCacheEmojiCategoryMemRatio() + | 		config.GetCacheEmojiCategoryMemRatio() + | ||||||
| 		config.GetCacheFollowMemRatio() + | 		config.GetCacheFollowMemRatio() + | ||||||
|  | @ -198,6 +199,7 @@ func totalOfRatios() float64 { | ||||||
| 		config.GetCacheStatusFaveIDsMemRatio() + | 		config.GetCacheStatusFaveIDsMemRatio() + | ||||||
| 		config.GetCacheTagMemRatio() + | 		config.GetCacheTagMemRatio() + | ||||||
| 		config.GetCacheThreadMuteMemRatio() + | 		config.GetCacheThreadMuteMemRatio() + | ||||||
|  | 		config.GetCacheTokenMemRatio() + | ||||||
| 		config.GetCacheTombstoneMemRatio() + | 		config.GetCacheTombstoneMemRatio() + | ||||||
| 		config.GetCacheUserMemRatio() + | 		config.GetCacheUserMemRatio() + | ||||||
| 		config.GetCacheWebfingerMemRatio() + | 		config.GetCacheWebfingerMemRatio() + | ||||||
|  | @ -287,6 +289,17 @@ func sizeofBlock() uintptr { | ||||||
| 	})) | 	})) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func sizeofClient() uintptr { | ||||||
|  | 	return uintptr(size.Of(>smodel.Client{ | ||||||
|  | 		ID:        exampleID, | ||||||
|  | 		CreatedAt: exampleTime, | ||||||
|  | 		UpdatedAt: exampleTime, | ||||||
|  | 		Secret:    exampleID, | ||||||
|  | 		Domain:    exampleURI, | ||||||
|  | 		UserID:    exampleID, | ||||||
|  | 	})) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func sizeofEmoji() uintptr { | func sizeofEmoji() uintptr { | ||||||
| 	return uintptr(size.Of(>smodel.Emoji{ | 	return uintptr(size.Of(>smodel.Emoji{ | ||||||
| 		ID:                     exampleID, | 		ID:                     exampleID, | ||||||
|  | @ -591,7 +604,7 @@ func sizeofTag() uintptr { | ||||||
| 	})) | 	})) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func sizeOfThreadMute() uintptr { | func sizeofThreadMute() uintptr { | ||||||
| 	return uintptr(size.Of(>smodel.ThreadMute{ | 	return uintptr(size.Of(>smodel.ThreadMute{ | ||||||
| 		ID:        exampleID, | 		ID:        exampleID, | ||||||
| 		CreatedAt: exampleTime, | 		CreatedAt: exampleTime, | ||||||
|  | @ -601,6 +614,29 @@ func sizeOfThreadMute() uintptr { | ||||||
| 	})) | 	})) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func sizeofToken() uintptr { | ||||||
|  | 	return uintptr(size.Of(>smodel.Token{ | ||||||
|  | 		ID:                  exampleID, | ||||||
|  | 		CreatedAt:           exampleTime, | ||||||
|  | 		UpdatedAt:           exampleTime, | ||||||
|  | 		ClientID:            exampleID, | ||||||
|  | 		UserID:              exampleID, | ||||||
|  | 		RedirectURI:         exampleURI, | ||||||
|  | 		Scope:               "r:w", | ||||||
|  | 		Code:                "", // TODO | ||||||
|  | 		CodeChallenge:       "", // TODO | ||||||
|  | 		CodeChallengeMethod: "", // TODO | ||||||
|  | 		CodeCreateAt:        exampleTime, | ||||||
|  | 		CodeExpiresAt:       exampleTime, | ||||||
|  | 		Access:              exampleID + exampleID, | ||||||
|  | 		AccessCreateAt:      exampleTime, | ||||||
|  | 		AccessExpiresAt:     exampleTime, | ||||||
|  | 		Refresh:             "", // TODO: clients don't really support this very well yet | ||||||
|  | 		RefreshCreateAt:     exampleTime, | ||||||
|  | 		RefreshExpiresAt:    exampleTime, | ||||||
|  | 	})) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func sizeofTombstone() uintptr { | func sizeofTombstone() uintptr { | ||||||
| 	return uintptr(size.Of(>smodel.Tombstone{ | 	return uintptr(size.Of(>smodel.Tombstone{ | ||||||
| 		ID:        exampleID, | 		ID:        exampleID, | ||||||
|  |  | ||||||
|  | @ -199,6 +199,7 @@ type CacheConfiguration struct { | ||||||
| 	BlockMemRatio            float64       `name:"block-mem-ratio"` | 	BlockMemRatio            float64       `name:"block-mem-ratio"` | ||||||
| 	BlockIDsMemRatio         float64       `name:"block-mem-ratio"` | 	BlockIDsMemRatio         float64       `name:"block-mem-ratio"` | ||||||
| 	BoostOfIDsMemRatio       float64       `name:"boost-of-ids-mem-ratio"` | 	BoostOfIDsMemRatio       float64       `name:"boost-of-ids-mem-ratio"` | ||||||
|  | 	ClientMemRatio           float64       `name:"client-mem-ratio"` | ||||||
| 	EmojiMemRatio            float64       `name:"emoji-mem-ratio"` | 	EmojiMemRatio            float64       `name:"emoji-mem-ratio"` | ||||||
| 	EmojiCategoryMemRatio    float64       `name:"emoji-category-mem-ratio"` | 	EmojiCategoryMemRatio    float64       `name:"emoji-category-mem-ratio"` | ||||||
| 	FilterMemRatio           float64       `name:"filter-mem-ratio"` | 	FilterMemRatio           float64       `name:"filter-mem-ratio"` | ||||||
|  | @ -226,6 +227,7 @@ type CacheConfiguration struct { | ||||||
| 	StatusFaveIDsMemRatio    float64       `name:"status-fave-ids-mem-ratio"` | 	StatusFaveIDsMemRatio    float64       `name:"status-fave-ids-mem-ratio"` | ||||||
| 	TagMemRatio              float64       `name:"tag-mem-ratio"` | 	TagMemRatio              float64       `name:"tag-mem-ratio"` | ||||||
| 	ThreadMuteMemRatio       float64       `name:"thread-mute-mem-ratio"` | 	ThreadMuteMemRatio       float64       `name:"thread-mute-mem-ratio"` | ||||||
|  | 	TokenMemRatio            float64       `name:"token-mem-ratio"` | ||||||
| 	TombstoneMemRatio        float64       `name:"tombstone-mem-ratio"` | 	TombstoneMemRatio        float64       `name:"tombstone-mem-ratio"` | ||||||
| 	UserMemRatio             float64       `name:"user-mem-ratio"` | 	UserMemRatio             float64       `name:"user-mem-ratio"` | ||||||
| 	WebfingerMemRatio        float64       `name:"webfinger-mem-ratio"` | 	WebfingerMemRatio        float64       `name:"webfinger-mem-ratio"` | ||||||
|  |  | ||||||
|  | @ -163,6 +163,7 @@ var Defaults = Configuration{ | ||||||
| 		BlockMemRatio:            2, | 		BlockMemRatio:            2, | ||||||
| 		BlockIDsMemRatio:         3, | 		BlockIDsMemRatio:         3, | ||||||
| 		BoostOfIDsMemRatio:       3, | 		BoostOfIDsMemRatio:       3, | ||||||
|  | 		ClientMemRatio:           0.1, | ||||||
| 		EmojiMemRatio:            3, | 		EmojiMemRatio:            3, | ||||||
| 		EmojiCategoryMemRatio:    0.1, | 		EmojiCategoryMemRatio:    0.1, | ||||||
| 		FilterMemRatio:           0.5, | 		FilterMemRatio:           0.5, | ||||||
|  | @ -190,6 +191,7 @@ var Defaults = Configuration{ | ||||||
| 		StatusFaveIDsMemRatio:    3, | 		StatusFaveIDsMemRatio:    3, | ||||||
| 		TagMemRatio:              2, | 		TagMemRatio:              2, | ||||||
| 		ThreadMuteMemRatio:       0.2, | 		ThreadMuteMemRatio:       0.2, | ||||||
|  | 		TokenMemRatio:            0.75, | ||||||
| 		TombstoneMemRatio:        0.5, | 		TombstoneMemRatio:        0.5, | ||||||
| 		UserMemRatio:             0.25, | 		UserMemRatio:             0.25, | ||||||
| 		WebfingerMemRatio:        0.1, | 		WebfingerMemRatio:        0.1, | ||||||
|  |  | ||||||
|  | @ -2925,6 +2925,31 @@ func GetCacheBoostOfIDsMemRatio() float64 { return global.GetCacheBoostOfIDsMemR | ||||||
| // SetCacheBoostOfIDsMemRatio safely sets the value for global configuration 'Cache.BoostOfIDsMemRatio' field | // SetCacheBoostOfIDsMemRatio safely sets the value for global configuration 'Cache.BoostOfIDsMemRatio' field | ||||||
| func SetCacheBoostOfIDsMemRatio(v float64) { global.SetCacheBoostOfIDsMemRatio(v) } | func SetCacheBoostOfIDsMemRatio(v float64) { global.SetCacheBoostOfIDsMemRatio(v) } | ||||||
| 
 | 
 | ||||||
|  | // GetCacheClientMemRatio safely fetches the Configuration value for state's 'Cache.ClientMemRatio' field | ||||||
|  | func (st *ConfigState) GetCacheClientMemRatio() (v float64) { | ||||||
|  | 	st.mutex.RLock() | ||||||
|  | 	v = st.config.Cache.ClientMemRatio | ||||||
|  | 	st.mutex.RUnlock() | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SetCacheClientMemRatio safely sets the Configuration value for state's 'Cache.ClientMemRatio' field | ||||||
|  | func (st *ConfigState) SetCacheClientMemRatio(v float64) { | ||||||
|  | 	st.mutex.Lock() | ||||||
|  | 	defer st.mutex.Unlock() | ||||||
|  | 	st.config.Cache.ClientMemRatio = v | ||||||
|  | 	st.reloadToViper() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CacheClientMemRatioFlag returns the flag name for the 'Cache.ClientMemRatio' field | ||||||
|  | func CacheClientMemRatioFlag() string { return "cache-client-mem-ratio" } | ||||||
|  | 
 | ||||||
|  | // GetCacheClientMemRatio safely fetches the value for global configuration 'Cache.ClientMemRatio' field | ||||||
|  | func GetCacheClientMemRatio() float64 { return global.GetCacheClientMemRatio() } | ||||||
|  | 
 | ||||||
|  | // SetCacheClientMemRatio safely sets the value for global configuration 'Cache.ClientMemRatio' field | ||||||
|  | func SetCacheClientMemRatio(v float64) { global.SetCacheClientMemRatio(v) } | ||||||
|  | 
 | ||||||
| // GetCacheEmojiMemRatio safely fetches the Configuration value for state's 'Cache.EmojiMemRatio' field | // GetCacheEmojiMemRatio safely fetches the Configuration value for state's 'Cache.EmojiMemRatio' field | ||||||
| func (st *ConfigState) GetCacheEmojiMemRatio() (v float64) { | func (st *ConfigState) GetCacheEmojiMemRatio() (v float64) { | ||||||
| 	st.mutex.RLock() | 	st.mutex.RLock() | ||||||
|  | @ -3600,6 +3625,31 @@ func GetCacheThreadMuteMemRatio() float64 { return global.GetCacheThreadMuteMemR | ||||||
| // SetCacheThreadMuteMemRatio safely sets the value for global configuration 'Cache.ThreadMuteMemRatio' field | // SetCacheThreadMuteMemRatio safely sets the value for global configuration 'Cache.ThreadMuteMemRatio' field | ||||||
| func SetCacheThreadMuteMemRatio(v float64) { global.SetCacheThreadMuteMemRatio(v) } | func SetCacheThreadMuteMemRatio(v float64) { global.SetCacheThreadMuteMemRatio(v) } | ||||||
| 
 | 
 | ||||||
|  | // GetCacheTokenMemRatio safely fetches the Configuration value for state's 'Cache.TokenMemRatio' field | ||||||
|  | func (st *ConfigState) GetCacheTokenMemRatio() (v float64) { | ||||||
|  | 	st.mutex.RLock() | ||||||
|  | 	v = st.config.Cache.TokenMemRatio | ||||||
|  | 	st.mutex.RUnlock() | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SetCacheTokenMemRatio safely sets the Configuration value for state's 'Cache.TokenMemRatio' field | ||||||
|  | func (st *ConfigState) SetCacheTokenMemRatio(v float64) { | ||||||
|  | 	st.mutex.Lock() | ||||||
|  | 	defer st.mutex.Unlock() | ||||||
|  | 	st.config.Cache.TokenMemRatio = v | ||||||
|  | 	st.reloadToViper() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CacheTokenMemRatioFlag returns the flag name for the 'Cache.TokenMemRatio' field | ||||||
|  | func CacheTokenMemRatioFlag() string { return "cache-token-mem-ratio" } | ||||||
|  | 
 | ||||||
|  | // GetCacheTokenMemRatio safely fetches the value for global configuration 'Cache.TokenMemRatio' field | ||||||
|  | func GetCacheTokenMemRatio() float64 { return global.GetCacheTokenMemRatio() } | ||||||
|  | 
 | ||||||
|  | // SetCacheTokenMemRatio safely sets the value for global configuration 'Cache.TokenMemRatio' field | ||||||
|  | func SetCacheTokenMemRatio(v float64) { global.SetCacheTokenMemRatio(v) } | ||||||
|  | 
 | ||||||
| // GetCacheTombstoneMemRatio safely fetches the Configuration value for state's 'Cache.TombstoneMemRatio' field | // GetCacheTombstoneMemRatio safely fetches the Configuration value for state's 'Cache.TombstoneMemRatio' field | ||||||
| func (st *ConfigState) GetCacheTombstoneMemRatio() (v float64) { | func (st *ConfigState) GetCacheTombstoneMemRatio() (v float64) { | ||||||
| 	st.mutex.RLock() | 	st.mutex.RLock() | ||||||
|  |  | ||||||
|  | @ -35,4 +35,40 @@ type Application interface { | ||||||
| 
 | 
 | ||||||
| 	// DeleteApplicationByClientID deletes the application with corresponding client_id value from the database. | 	// DeleteApplicationByClientID deletes the application with corresponding client_id value from the database. | ||||||
| 	DeleteApplicationByClientID(ctx context.Context, clientID string) error | 	DeleteApplicationByClientID(ctx context.Context, clientID string) error | ||||||
|  | 
 | ||||||
|  | 	// GetClientByID ... | ||||||
|  | 	GetClientByID(ctx context.Context, id string) (*gtsmodel.Client, error) | ||||||
|  | 
 | ||||||
|  | 	// PutClient ... | ||||||
|  | 	PutClient(ctx context.Context, client *gtsmodel.Client) error | ||||||
|  | 
 | ||||||
|  | 	// DeleteClientByID ... | ||||||
|  | 	DeleteClientByID(ctx context.Context, id string) error | ||||||
|  | 
 | ||||||
|  | 	// GetAllTokens ... | ||||||
|  | 	GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error) | ||||||
|  | 
 | ||||||
|  | 	// GetTokenByCode ... | ||||||
|  | 	GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) | ||||||
|  | 
 | ||||||
|  | 	// GetTokenByAccess ... | ||||||
|  | 	GetTokenByAccess(ctx context.Context, access string) (*gtsmodel.Token, error) | ||||||
|  | 
 | ||||||
|  | 	// GetTokenByRefresh ... | ||||||
|  | 	GetTokenByRefresh(ctx context.Context, refresh string) (*gtsmodel.Token, error) | ||||||
|  | 
 | ||||||
|  | 	// PutToken ... | ||||||
|  | 	PutToken(ctx context.Context, token *gtsmodel.Token) error | ||||||
|  | 
 | ||||||
|  | 	// DeleteTokenByID ... | ||||||
|  | 	DeleteTokenByID(ctx context.Context, id string) error | ||||||
|  | 
 | ||||||
|  | 	// DeleteTokenByCode ... | ||||||
|  | 	DeleteTokenByCode(ctx context.Context, code string) error | ||||||
|  | 
 | ||||||
|  | 	// DeleteTokenByAccess ... | ||||||
|  | 	DeleteTokenByAccess(ctx context.Context, access string) error | ||||||
|  | 
 | ||||||
|  | 	// DeleteTokenByRefresh ... | ||||||
|  | 	DeleteTokenByRefresh(ctx context.Context, refresh string) error | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -397,7 +397,7 @@ func (a *adminDB) CreateInstanceApplication(ctx context.Context) error { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Store it. | 	// Store it. | ||||||
| 	return a.state.DB.Put(ctx, oc) | 	return a.state.DB.PutClient(ctx, oc) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (a *adminDB) GetInstanceApplication(ctx context.Context) (*gtsmodel.Application, error) { | func (a *adminDB) GetInstanceApplication(ctx context.Context) (*gtsmodel.Application, error) { | ||||||
|  |  | ||||||
|  | @ -22,6 +22,7 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/state" | 	"github.com/superseriousbusiness/gotosocial/internal/state" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/util" | ||||||
| 	"github.com/uptrace/bun" | 	"github.com/uptrace/bun" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -95,3 +96,181 @@ func (a *applicationDB) DeleteApplicationByClientID(ctx context.Context, clientI | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func (a *applicationDB) GetClientByID(ctx context.Context, id string) (*gtsmodel.Client, error) { | ||||||
|  | 	return a.state.Caches.GTS.Client.LoadOne("ID", func() (*gtsmodel.Client, error) { | ||||||
|  | 		var client gtsmodel.Client | ||||||
|  | 
 | ||||||
|  | 		if err := a.db.NewSelect(). | ||||||
|  | 			Model(&client). | ||||||
|  | 			Where("? = ?", bun.Ident("id"), id). | ||||||
|  | 			Scan(ctx); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return &client, nil | ||||||
|  | 	}, id) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (a *applicationDB) PutClient(ctx context.Context, client *gtsmodel.Client) error { | ||||||
|  | 	return a.state.Caches.GTS.Client.Store(client, func() error { | ||||||
|  | 		_, err := a.db.NewInsert().Model(client).Exec(ctx) | ||||||
|  | 		return err | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (a *applicationDB) DeleteClientByID(ctx context.Context, id string) error { | ||||||
|  | 	_, err := a.db.NewDelete(). | ||||||
|  | 		Table("clients"). | ||||||
|  | 		Where("? = ?", bun.Ident("id"), id). | ||||||
|  | 		Exec(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	a.state.Caches.GTS.Client.Invalidate("ID", id) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (a *applicationDB) GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error) { | ||||||
|  | 	var tokenIDs []string | ||||||
|  | 
 | ||||||
|  | 	// Select ALL token IDs. | ||||||
|  | 	if err := a.db.NewSelect(). | ||||||
|  | 		Table("tokens"). | ||||||
|  | 		Column("id"). | ||||||
|  | 		Scan(ctx, &tokenIDs); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Load all input token IDs via cache loader callback. | ||||||
|  | 	tokens, err := a.state.Caches.GTS.Token.LoadIDs("ID", | ||||||
|  | 		tokenIDs, | ||||||
|  | 		func(uncached []string) ([]*gtsmodel.Token, error) { | ||||||
|  | 			// Preallocate expected length of uncached tokens. | ||||||
|  | 			tokens := make([]*gtsmodel.Token, 0, len(uncached)) | ||||||
|  | 
 | ||||||
|  | 			// Perform database query scanning | ||||||
|  | 			// the remaining (uncached) token IDs. | ||||||
|  | 			if err := a.db.NewSelect(). | ||||||
|  | 				Model(tokens). | ||||||
|  | 				Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). | ||||||
|  | 				Scan(ctx); err != nil { | ||||||
|  | 				return nil, err | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			return tokens, nil | ||||||
|  | 		}, | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Reoroder the tokens by their | ||||||
|  | 	// IDs to ensure in correct order. | ||||||
|  | 	getID := func(t *gtsmodel.Token) string { return t.ID } | ||||||
|  | 	util.OrderBy(tokens, tokenIDs, getID) | ||||||
|  | 
 | ||||||
|  | 	return tokens, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (a *applicationDB) GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) { | ||||||
|  | 	return a.getTokenBy( | ||||||
|  | 		"Code", | ||||||
|  | 		func(t *gtsmodel.Token) error { | ||||||
|  | 			return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("code"), code).Scan(ctx) | ||||||
|  | 		}, | ||||||
|  | 		code, | ||||||
|  | 	) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (a *applicationDB) GetTokenByAccess(ctx context.Context, access string) (*gtsmodel.Token, error) { | ||||||
|  | 	return a.getTokenBy( | ||||||
|  | 		"Access", | ||||||
|  | 		func(t *gtsmodel.Token) error { | ||||||
|  | 			return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("access"), access).Scan(ctx) | ||||||
|  | 		}, | ||||||
|  | 		access, | ||||||
|  | 	) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (a *applicationDB) GetTokenByRefresh(ctx context.Context, refresh string) (*gtsmodel.Token, error) { | ||||||
|  | 	return a.getTokenBy( | ||||||
|  | 		"Refresh", | ||||||
|  | 		func(t *gtsmodel.Token) error { | ||||||
|  | 			return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("refresh"), refresh).Scan(ctx) | ||||||
|  | 		}, | ||||||
|  | 		refresh, | ||||||
|  | 	) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (a *applicationDB) getTokenBy(lookup string, dbQuery func(*gtsmodel.Token) error, keyParts ...any) (*gtsmodel.Token, error) { | ||||||
|  | 	return a.state.Caches.GTS.Token.LoadOne(lookup, func() (*gtsmodel.Token, error) { | ||||||
|  | 		var token gtsmodel.Token | ||||||
|  | 
 | ||||||
|  | 		if err := dbQuery(&token); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return &token, nil | ||||||
|  | 	}, keyParts...) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (a *applicationDB) PutToken(ctx context.Context, token *gtsmodel.Token) error { | ||||||
|  | 	return a.state.Caches.GTS.Token.Store(token, func() error { | ||||||
|  | 		_, err := a.db.NewInsert().Model(token).Exec(ctx) | ||||||
|  | 		return err | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error { | ||||||
|  | 	_, err := a.db.NewDelete(). | ||||||
|  | 		Table("tokens"). | ||||||
|  | 		Where("? = ?", bun.Ident("id"), id). | ||||||
|  | 		Exec(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	a.state.Caches.GTS.Token.Invalidate("ID", id) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (a *applicationDB) DeleteTokenByCode(ctx context.Context, code string) error { | ||||||
|  | 	_, err := a.db.NewDelete(). | ||||||
|  | 		Table("tokens"). | ||||||
|  | 		Where("? = ?", bun.Ident("code"), code). | ||||||
|  | 		Exec(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	a.state.Caches.GTS.Token.Invalidate("Code", code) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (a *applicationDB) DeleteTokenByAccess(ctx context.Context, access string) error { | ||||||
|  | 	_, err := a.db.NewDelete(). | ||||||
|  | 		Table("tokens"). | ||||||
|  | 		Where("? = ?", bun.Ident("access"), access). | ||||||
|  | 		Exec(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	a.state.Caches.GTS.Token.Invalidate("Access", access) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (a *applicationDB) DeleteTokenByRefresh(ctx context.Context, refresh string) error { | ||||||
|  | 	_, err := a.db.NewDelete(). | ||||||
|  | 		Table("tokens"). | ||||||
|  | 		Where("? = ?", bun.Ident("refresh"), refresh). | ||||||
|  | 		Exec(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	a.state.Caches.GTS.Token.Invalidate("Refresh", refresh) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -27,11 +27,11 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type clientStore struct { | type clientStore struct { | ||||||
| 	db db.Basic | 	db db.DB | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend. | // NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend. | ||||||
| func NewClientStore(db db.Basic) oauth2.ClientStore { | func NewClientStore(db db.DB) oauth2.ClientStore { | ||||||
| 	pts := &clientStore{ | 	pts := &clientStore{ | ||||||
| 		db: db, | 		db: db, | ||||||
| 	} | 	} | ||||||
|  | @ -39,26 +39,27 @@ func NewClientStore(db db.Basic) oauth2.ClientStore { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { | func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { | ||||||
| 	poc := >smodel.Client{} | 	client, err := cs.db.GetClientByID(ctx, clientID) | ||||||
| 	if err := cs.db.GetByID(ctx, clientID, poc); err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil | 	return models.New( | ||||||
|  | 		client.ID, | ||||||
|  | 		client.Secret, | ||||||
|  | 		client.Domain, | ||||||
|  | 		client.UserID, | ||||||
|  | 	), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error { | func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error { | ||||||
| 	poc := >smodel.Client{ | 	return cs.db.PutClient(ctx, >smodel.Client{ | ||||||
| 		ID:     cli.GetID(), | 		ID:     cli.GetID(), | ||||||
| 		Secret: cli.GetSecret(), | 		Secret: cli.GetSecret(), | ||||||
| 		Domain: cli.GetDomain(), | 		Domain: cli.GetDomain(), | ||||||
| 		UserID: cli.GetUserID(), | 		UserID: cli.GetUserID(), | ||||||
| 	} | 	}) | ||||||
| 	return cs.db.Put(ctx, poc) |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (cs *clientStore) Delete(ctx context.Context, id string) error { | func (cs *clientStore) Delete(ctx context.Context, id string) error { | ||||||
| 	poc := >smodel.Client{ | 	return cs.db.DeleteClientByID(ctx, id) | ||||||
| 		ID: id, |  | ||||||
| 	} |  | ||||||
| 	return cs.db.DeleteByID(ctx, id, poc) |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -19,7 +19,5 @@ package oauth | ||||||
| 
 | 
 | ||||||
| import "github.com/superseriousbusiness/oauth2/v4/errors" | import "github.com/superseriousbusiness/oauth2/v4/errors" | ||||||
| 
 | 
 | ||||||
| // InvalidRequest returns an oauth spec compliant 'invalid_request' error. | // ErrInvalidRequest is an oauth spec compliant 'invalid_request' error. | ||||||
| func InvalidRequest() error { | var ErrInvalidRequest = errors.New("invalid_request") | ||||||
| 	return errors.New("invalid_request") |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -75,7 +75,7 @@ type s struct { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // New returns a new oauth server that implements the Server interface | // New returns a new oauth server that implements the Server interface | ||||||
| func New(ctx context.Context, database db.Basic) Server { | func New(ctx context.Context, database db.DB) Server { | ||||||
| 	ts := newTokenStore(ctx, database) | 	ts := newTokenStore(ctx, database) | ||||||
| 	cs := NewClientStore(database) | 	cs := NewClientStore(database) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -20,7 +20,6 @@ package oauth | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" |  | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
|  | @ -34,14 +33,14 @@ import ( | ||||||
| // tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend. | // tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend. | ||||||
| type tokenStore struct { | type tokenStore struct { | ||||||
| 	oauth2.TokenStore | 	oauth2.TokenStore | ||||||
| 	db db.Basic | 	db db.DB | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // newTokenStore returns a token store that satisfies the oauth2.TokenStore interface. | // newTokenStore returns a token store that satisfies the oauth2.TokenStore interface. | ||||||
| // | // | ||||||
| // In order to allow tokens to 'expire', it will also set off a goroutine that iterates through | // In order to allow tokens to 'expire', it will also set off a goroutine that iterates through | ||||||
| // the tokens in the DB once per minute and deletes any that have expired. | // the tokens in the DB once per minute and deletes any that have expired. | ||||||
| func newTokenStore(ctx context.Context, db db.Basic) oauth2.TokenStore { | func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore { | ||||||
| 	ts := &tokenStore{ | 	ts := &tokenStore{ | ||||||
| 		db: db, | 		db: db, | ||||||
| 	} | 	} | ||||||
|  | @ -69,19 +68,19 @@ func newTokenStore(ctx context.Context, db db.Basic) oauth2.TokenStore { | ||||||
| func (ts *tokenStore) sweep(ctx context.Context) error { | func (ts *tokenStore) sweep(ctx context.Context) error { | ||||||
| 	// select *all* tokens from the db | 	// select *all* tokens from the db | ||||||
| 	// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. | 	// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. | ||||||
| 	tokens := new([]*gtsmodel.Token) | 	tokens, err := ts.db.GetAllTokens(ctx) | ||||||
| 	if err := ts.db.GetAll(ctx, tokens); err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// iterate through and remove expired tokens | 	// iterate through and remove expired tokens | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
| 	for _, dbt := range *tokens { | 	for _, dbt := range tokens { | ||||||
| 		// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So: | 		// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So: | ||||||
| 		// we only want to check if a token expired before now if the expiry time is *not zero*; | 		// we only want to check if a token expired before now if the expiry time is *not zero*; | ||||||
| 		// ie., if it's been explicity set. | 		// ie., if it's been explicity set. | ||||||
| 		if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) { | 		if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) { | ||||||
| 			if err := ts.db.DeleteByID(ctx, dbt.ID, dbt); err != nil { | 			if err := ts.db.DeleteTokenByID(ctx, dbt.ID); err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | @ -107,67 +106,49 @@ func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { | ||||||
| 		dbt.ID = dbtID | 		dbt.ID = dbtID | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := ts.db.Put(ctx, dbt); err != nil { | 	return ts.db.PutToken(ctx, dbt) | ||||||
| 		return fmt.Errorf("error in tokenstore create: %s", err) |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RemoveByCode deletes a token from the DB based on the Code field | // RemoveByCode deletes a token from the DB based on the Code field | ||||||
| func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error { | func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error { | ||||||
| 	return ts.db.DeleteWhere(ctx, []db.Where{{Key: "code", Value: code}}, >smodel.Token{}) | 	return ts.db.DeleteTokenByCode(ctx, code) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RemoveByAccess deletes a token from the DB based on the Access field | // RemoveByAccess deletes a token from the DB based on the Access field | ||||||
| func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { | func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { | ||||||
| 	return ts.db.DeleteWhere(ctx, []db.Where{{Key: "access", Value: access}}, >smodel.Token{}) | 	return ts.db.DeleteTokenByAccess(ctx, access) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RemoveByRefresh deletes a token from the DB based on the Refresh field | // RemoveByRefresh deletes a token from the DB based on the Refresh field | ||||||
| func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { | func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { | ||||||
| 	return ts.db.DeleteWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, >smodel.Token{}) | 	return ts.db.DeleteTokenByRefresh(ctx, refresh) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetByCode selects a token from the DB based on the Code field | // GetByCode selects a token from the DB based on the Code field | ||||||
| func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { | func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { | ||||||
| 	if code == "" { | 	token, err := ts.db.GetTokenByCode(ctx, code) | ||||||
| 		return nil, nil | 	if err != nil { | ||||||
| 	} |  | ||||||
| 	dbt := >smodel.Token{ |  | ||||||
| 		Code: code, |  | ||||||
| 	} |  | ||||||
| 	if err := ts.db.GetWhere(ctx, []db.Where{{Key: "code", Value: code}}, dbt); err != nil { |  | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	return DBTokenToToken(dbt), nil | 	return DBTokenToToken(token), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetByAccess selects a token from the DB based on the Access field | // GetByAccess selects a token from the DB based on the Access field | ||||||
| func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { | func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { | ||||||
| 	if access == "" { | 	token, err := ts.db.GetTokenByAccess(ctx, access) | ||||||
| 		return nil, nil | 	if err != nil { | ||||||
| 	} |  | ||||||
| 	dbt := >smodel.Token{ |  | ||||||
| 		Access: access, |  | ||||||
| 	} |  | ||||||
| 	if err := ts.db.GetWhere(ctx, []db.Where{{Key: "access", Value: access}}, dbt); err != nil { |  | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	return DBTokenToToken(dbt), nil | 	return DBTokenToToken(token), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetByRefresh selects a token from the DB based on the Refresh field | // GetByRefresh selects a token from the DB based on the Refresh field | ||||||
| func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { | func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { | ||||||
| 	if refresh == "" { | 	token, err := ts.db.GetTokenByRefresh(ctx, refresh) | ||||||
| 		return nil, nil | 	if err != nil { | ||||||
| 	} |  | ||||||
| 	dbt := >smodel.Token{ |  | ||||||
| 		Refresh: refresh, |  | ||||||
| 	} |  | ||||||
| 	if err := ts.db.GetWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, dbt); err != nil { |  | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	return DBTokenToToken(dbt), nil | 	return DBTokenToToken(token), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* | /* | ||||||
|  |  | ||||||
|  | @ -75,7 +75,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// chuck it in the db | 	// chuck it in the db | ||||||
| 	if err := p.state.DB.Put(ctx, oc); err != nil { | 	if err := p.state.DB.PutClient(ctx, oc); err != nil { | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -29,6 +29,7 @@ EXPECT=$(cat << "EOF" | ||||||
|         "application-mem-ratio": 0.1, |         "application-mem-ratio": 0.1, | ||||||
|         "block-mem-ratio": 3, |         "block-mem-ratio": 3, | ||||||
|         "boost-of-ids-mem-ratio": 3, |         "boost-of-ids-mem-ratio": 3, | ||||||
|  |         "client-mem-ratio": 0.1, | ||||||
|         "emoji-category-mem-ratio": 0.1, |         "emoji-category-mem-ratio": 0.1, | ||||||
|         "emoji-mem-ratio": 3, |         "emoji-mem-ratio": 3, | ||||||
|         "filter-keyword-mem-ratio": 0.5, |         "filter-keyword-mem-ratio": 0.5, | ||||||
|  | @ -57,6 +58,7 @@ EXPECT=$(cat << "EOF" | ||||||
|         "status-mem-ratio": 5, |         "status-mem-ratio": 5, | ||||||
|         "tag-mem-ratio": 2, |         "tag-mem-ratio": 2, | ||||||
|         "thread-mute-mem-ratio": 0.2, |         "thread-mute-mem-ratio": 0.2, | ||||||
|  |         "token-mem-ratio": 0.75, | ||||||
|         "tombstone-mem-ratio": 0.5, |         "tombstone-mem-ratio": 0.5, | ||||||
|         "user-mem-ratio": 0.25, |         "user-mem-ratio": 0.25, | ||||||
|         "visibility-mem-ratio": 2, |         "visibility-mem-ratio": 2, | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue