[performance] overhaul struct (+ result) caching library for simplicity, performance and multiple-result lookups (#2535)

* rewrite cache library as codeberg.org/gruf/go-structr, implement in gotosocial

* use actual go-structr release version (not just commit hash)

* revert go toolchain changes (damn you go for auto changing this)

* fix go mod woes

* ensure %w is used in calls to errs.Appendf()

* fix error checking

* fix possible panic

* remove unnecessary start/stop functions, move to main Cache{} struct, add note regarding which caches require start/stop

* fix copy-paste artifact... 😇

* fix all comment copy-paste artifacts

* remove dropID() function, now we can just use slices.DeleteFunc()

* use util.Deduplicate() instead of collate(), move collate to util

* move orderByIDs() to util package and "generify"

* add a util.DeleteIf() function, use this to delete entries on failed population

* use slices.DeleteFunc() instead of util.DeleteIf() (i had the logic mixed up in my head somehow lol)

* add note about how collate differs from deduplicate
This commit is contained in:
kim 2024-01-19 12:57:29 +00:00 committed by GitHub
commit 7ec1e1332e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
66 changed files with 4038 additions and 2711 deletions

30
internal/cache/ap.go vendored
View file

@ -1,30 +0,0 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package cache
type APCaches struct{}
// Init will initialize all the ActivityPub caches in this collection.
// NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe.
func (c *APCaches) Init() {}
// Start will attempt to start all of the ActivityPub caches, or panic.
func (c *APCaches) Start() {}
// Stop will attempt to stop all of the ActivityPub caches, or panic.
func (c *APCaches) Stop() {}

View file

@ -18,8 +18,9 @@
package cache
import (
"time"
"github.com/superseriousbusiness/gotosocial/internal/cache/headerfilter"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
@ -49,198 +50,59 @@ type Caches struct {
func (c *Caches) Init() {
log.Infof(nil, "init: %p", c)
c.GTS.Init()
c.Visibility.Init()
// Setup cache invalidate hooks.
// !! READ THE METHOD COMMENT
c.setuphooks()
c.initAccount()
c.initAccountNote()
c.initApplication()
c.initBlock()
c.initBlockIDs()
c.initBoostOfIDs()
c.initDomainAllow()
c.initDomainBlock()
c.initEmoji()
c.initEmojiCategory()
c.initFollow()
c.initFollowIDs()
c.initFollowRequest()
c.initFollowRequestIDs()
c.initInReplyToIDs()
c.initInstance()
c.initList()
c.initListEntry()
c.initMarker()
c.initMedia()
c.initMention()
c.initNotification()
c.initPoll()
c.initPollVote()
c.initPollVoteIDs()
c.initReport()
c.initStatus()
c.initStatusFave()
c.initTag()
c.initThreadMute()
c.initStatusFaveIDs()
c.initTombstone()
c.initUser()
c.initWebfinger()
c.initVisibility()
}
// Start will start both the GTS and AP cache collections.
// Start will start any caches that require a background
// routine, which usually means any kind of TTL caches.
func (c *Caches) Start() {
log.Infof(nil, "start: %p", c)
c.GTS.Start()
c.Visibility.Start()
tryUntil("starting *gtsmodel.Webfinger cache", 5, func() bool {
return c.GTS.Webfinger.Start(5 * time.Minute)
})
}
// Stop will stop both the GTS and AP cache collections.
// Stop will stop any caches that require a background
// routine, which usually means any kind of TTL caches.
func (c *Caches) Stop() {
log.Infof(nil, "stop: %p", c)
c.GTS.Stop()
c.Visibility.Stop()
}
// setuphooks sets necessary cache invalidation hooks between caches,
// as an invalidation indicates a database INSERT / UPDATE / DELETE.
// NOTE THEY ARE ONLY CALLED WHEN THE ITEM IS IN THE CACHE, SO FOR
// HOOKS TO BE CALLED ON DELETE YOU MUST FIRST POPULATE IT IN THE CACHE.
func (c *Caches) setuphooks() {
c.GTS.Account().SetInvalidateCallback(func(account *gtsmodel.Account) {
// Invalidate account ID cached visibility.
c.Visibility.Invalidate("ItemID", account.ID)
c.Visibility.Invalidate("RequesterID", account.ID)
// Invalidate this account's
// following / follower lists.
// (see FollowIDs() comment for details).
c.GTS.FollowIDs().InvalidateAll(
">"+account.ID,
"l>"+account.ID,
"<"+account.ID,
"l<"+account.ID,
)
// Invalidate this account's
// follow requesting / request lists.
// (see FollowRequestIDs() comment for details).
c.GTS.FollowRequestIDs().InvalidateAll(
">"+account.ID,
"<"+account.ID,
)
// Invalidate this account's block lists.
c.GTS.BlockIDs().Invalidate(account.ID)
})
c.GTS.Block().SetInvalidateCallback(func(block *gtsmodel.Block) {
// Invalidate block origin account ID cached visibility.
c.Visibility.Invalidate("ItemID", block.AccountID)
c.Visibility.Invalidate("RequesterID", block.AccountID)
// Invalidate block target account ID cached visibility.
c.Visibility.Invalidate("ItemID", block.TargetAccountID)
c.Visibility.Invalidate("RequesterID", block.TargetAccountID)
// Invalidate source account's block lists.
c.GTS.BlockIDs().Invalidate(block.AccountID)
})
c.GTS.EmojiCategory().SetInvalidateCallback(func(category *gtsmodel.EmojiCategory) {
// Invalidate any emoji in this category.
c.GTS.Emoji().Invalidate("CategoryID", category.ID)
})
c.GTS.Follow().SetInvalidateCallback(func(follow *gtsmodel.Follow) {
// Invalidate follow request with this same ID.
c.GTS.FollowRequest().Invalidate("ID", follow.ID)
// Invalidate any related list entries.
c.GTS.ListEntry().Invalidate("FollowID", follow.ID)
// Invalidate follow origin account ID cached visibility.
c.Visibility.Invalidate("ItemID", follow.AccountID)
c.Visibility.Invalidate("RequesterID", follow.AccountID)
// Invalidate follow target account ID cached visibility.
c.Visibility.Invalidate("ItemID", follow.TargetAccountID)
c.Visibility.Invalidate("RequesterID", follow.TargetAccountID)
// Invalidate source account's following
// lists, and destination's follwer lists.
// (see FollowIDs() comment for details).
c.GTS.FollowIDs().InvalidateAll(
">"+follow.AccountID,
"l>"+follow.AccountID,
"<"+follow.AccountID,
"l<"+follow.AccountID,
"<"+follow.TargetAccountID,
"l<"+follow.TargetAccountID,
">"+follow.TargetAccountID,
"l>"+follow.TargetAccountID,
)
})
c.GTS.FollowRequest().SetInvalidateCallback(func(followReq *gtsmodel.FollowRequest) {
// Invalidate follow with this same ID.
c.GTS.Follow().Invalidate("ID", followReq.ID)
// Invalidate source account's followreq
// lists, and destinations follow req lists.
// (see FollowRequestIDs() comment for details).
c.GTS.FollowRequestIDs().InvalidateAll(
">"+followReq.AccountID,
"<"+followReq.AccountID,
">"+followReq.TargetAccountID,
"<"+followReq.TargetAccountID,
)
})
c.GTS.List().SetInvalidateCallback(func(list *gtsmodel.List) {
// Invalidate all cached entries of this list.
c.GTS.ListEntry().Invalidate("ListID", list.ID)
})
c.GTS.Media().SetInvalidateCallback(func(media *gtsmodel.MediaAttachment) {
if *media.Avatar || *media.Header {
// Invalidate cache of attaching account.
c.GTS.Account().Invalidate("ID", media.AccountID)
}
if media.StatusID != "" {
// Invalidate cache of attaching status.
c.GTS.Status().Invalidate("ID", media.StatusID)
}
})
c.GTS.Poll().SetInvalidateCallback(func(poll *gtsmodel.Poll) {
// Invalidate all cached votes of this poll.
c.GTS.PollVote().Invalidate("PollID", poll.ID)
// Invalidate cache of poll vote IDs.
c.GTS.PollVoteIDs().Invalidate(poll.ID)
})
c.GTS.PollVote().SetInvalidateCallback(func(vote *gtsmodel.PollVote) {
// Invalidate cached poll (contains no. votes).
c.GTS.Poll().Invalidate("ID", vote.PollID)
// Invalidate cache of poll vote IDs.
c.GTS.PollVoteIDs().Invalidate(vote.PollID)
})
c.GTS.Status().SetInvalidateCallback(func(status *gtsmodel.Status) {
// Invalidate status ID cached visibility.
c.Visibility.Invalidate("ItemID", status.ID)
for _, id := range status.AttachmentIDs {
// Invalidate each media by the IDs we're aware of.
// This must be done as the status table is aware of
// the media IDs in use before the media table is
// aware of the status ID they are linked to.
//
// c.GTS.Media().Invalidate("StatusID") will not work.
c.GTS.Media().Invalidate("ID", id)
}
if status.BoostOfID != "" {
// Invalidate boost ID list of the original status.
c.GTS.BoostOfIDs().Invalidate(status.BoostOfID)
}
if status.InReplyToID != "" {
// Invalidate in reply to ID list of original status.
c.GTS.InReplyToIDs().Invalidate(status.InReplyToID)
}
if status.PollID != "" {
// Invalidate cache of attached poll ID.
c.GTS.Poll().Invalidate("ID", status.PollID)
}
})
c.GTS.StatusFave().SetInvalidateCallback(func(fave *gtsmodel.StatusFave) {
// Invalidate status fave ID list for this status.
c.GTS.StatusFaveIDs().Invalidate(fave.StatusID)
})
c.GTS.User().SetInvalidateCallback(func(user *gtsmodel.User) {
// Invalidate local account ID cached visibility.
c.Visibility.Invalidate("ItemID", user.AccountID)
c.Visibility.Invalidate("RequesterID", user.AccountID)
})
tryUntil("stopping *gtsmodel.Webfinger cache", 5, c.GTS.Webfinger.Stop)
}
// Sweep will sweep all the available caches to ensure none
@ -250,30 +112,30 @@ func (c *Caches) setuphooks() {
// require an eviction on every single write, which adds
// significant overhead to all cache writes.
func (c *Caches) Sweep(threshold float64) {
c.GTS.Account().Trim(threshold)
c.GTS.AccountNote().Trim(threshold)
c.GTS.Block().Trim(threshold)
c.GTS.BlockIDs().Trim(threshold)
c.GTS.Emoji().Trim(threshold)
c.GTS.EmojiCategory().Trim(threshold)
c.GTS.Follow().Trim(threshold)
c.GTS.FollowIDs().Trim(threshold)
c.GTS.FollowRequest().Trim(threshold)
c.GTS.FollowRequestIDs().Trim(threshold)
c.GTS.Instance().Trim(threshold)
c.GTS.List().Trim(threshold)
c.GTS.ListEntry().Trim(threshold)
c.GTS.Marker().Trim(threshold)
c.GTS.Media().Trim(threshold)
c.GTS.Mention().Trim(threshold)
c.GTS.Notification().Trim(threshold)
c.GTS.Poll().Trim(threshold)
c.GTS.Report().Trim(threshold)
c.GTS.Status().Trim(threshold)
c.GTS.StatusFave().Trim(threshold)
c.GTS.Tag().Trim(threshold)
c.GTS.ThreadMute().Trim(threshold)
c.GTS.Tombstone().Trim(threshold)
c.GTS.User().Trim(threshold)
c.GTS.Account.Trim(threshold)
c.GTS.AccountNote.Trim(threshold)
c.GTS.Block.Trim(threshold)
c.GTS.BlockIDs.Trim(threshold)
c.GTS.Emoji.Trim(threshold)
c.GTS.EmojiCategory.Trim(threshold)
c.GTS.Follow.Trim(threshold)
c.GTS.FollowIDs.Trim(threshold)
c.GTS.FollowRequest.Trim(threshold)
c.GTS.FollowRequestIDs.Trim(threshold)
c.GTS.Instance.Trim(threshold)
c.GTS.List.Trim(threshold)
c.GTS.ListEntry.Trim(threshold)
c.GTS.Marker.Trim(threshold)
c.GTS.Media.Trim(threshold)
c.GTS.Mention.Trim(threshold)
c.GTS.Notification.Trim(threshold)
c.GTS.Poll.Trim(threshold)
c.GTS.Report.Trim(threshold)
c.GTS.Status.Trim(threshold)
c.GTS.StatusFave.Trim(threshold)
c.GTS.Tag.Trim(threshold)
c.GTS.ThreadMute.Trim(threshold)
c.GTS.Tombstone.Trim(threshold)
c.GTS.User.Trim(threshold)
c.Visibility.Trim(threshold)
}

1071
internal/cache/db.go vendored Normal file

File diff suppressed because it is too large Load diff

1119
internal/cache/gts.go vendored

File diff suppressed because it is too large Load diff

192
internal/cache/invalidate.go vendored Normal file
View file

@ -0,0 +1,192 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package cache
import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Below are cache invalidation hooks between other caches,
// as an invalidation indicates a database INSERT / UPDATE / DELETE.
// NOTE THEY ARE ONLY CALLED WHEN THE ITEM IS IN THE CACHE, SO FOR
// HOOKS TO BE CALLED ON DELETE YOU MUST FIRST POPULATE IT IN THE CACHE.
func (c *Caches) OnInvalidateAccount(account *gtsmodel.Account) {
// Invalidate account ID cached visibility.
c.Visibility.Invalidate("ItemID", account.ID)
c.Visibility.Invalidate("RequesterID", account.ID)
// Invalidate this account's
// following / follower lists.
// (see FollowIDs() comment for details).
c.GTS.FollowIDs.InvalidateAll(
">"+account.ID,
"l>"+account.ID,
"<"+account.ID,
"l<"+account.ID,
)
// Invalidate this account's
// follow requesting / request lists.
// (see FollowRequestIDs() comment for details).
c.GTS.FollowRequestIDs.InvalidateAll(
">"+account.ID,
"<"+account.ID,
)
// Invalidate this account's block lists.
c.GTS.BlockIDs.Invalidate(account.ID)
}
func (c *Caches) OnInvalidateBlock(block *gtsmodel.Block) {
// Invalidate block origin account ID cached visibility.
c.Visibility.Invalidate("ItemID", block.AccountID)
c.Visibility.Invalidate("RequesterID", block.AccountID)
// Invalidate block target account ID cached visibility.
c.Visibility.Invalidate("ItemID", block.TargetAccountID)
c.Visibility.Invalidate("RequesterID", block.TargetAccountID)
// Invalidate source account's block lists.
c.GTS.BlockIDs.Invalidate(block.AccountID)
}
func (c *Caches) OnInvalidateEmojiCategory(category *gtsmodel.EmojiCategory) {
// Invalidate any emoji in this category.
c.GTS.Emoji.Invalidate("CategoryID", category.ID)
}
func (c *Caches) OnInvalidateFollow(follow *gtsmodel.Follow) {
// Invalidate follow request with this same ID.
c.GTS.FollowRequest.Invalidate("ID", follow.ID)
// Invalidate any related list entries.
c.GTS.ListEntry.Invalidate("FollowID", follow.ID)
// Invalidate follow origin account ID cached visibility.
c.Visibility.Invalidate("ItemID", follow.AccountID)
c.Visibility.Invalidate("RequesterID", follow.AccountID)
// Invalidate follow target account ID cached visibility.
c.Visibility.Invalidate("ItemID", follow.TargetAccountID)
c.Visibility.Invalidate("RequesterID", follow.TargetAccountID)
// Invalidate source account's following
// lists, and destination's follwer lists.
// (see FollowIDs() comment for details).
c.GTS.FollowIDs.InvalidateAll(
">"+follow.AccountID,
"l>"+follow.AccountID,
"<"+follow.AccountID,
"l<"+follow.AccountID,
"<"+follow.TargetAccountID,
"l<"+follow.TargetAccountID,
">"+follow.TargetAccountID,
"l>"+follow.TargetAccountID,
)
}
func (c *Caches) OnInvalidateFollowRequest(followReq *gtsmodel.FollowRequest) {
// Invalidate follow with this same ID.
c.GTS.Follow.Invalidate("ID", followReq.ID)
// Invalidate source account's followreq
// lists, and destinations follow req lists.
// (see FollowRequestIDs() comment for details).
c.GTS.FollowRequestIDs.InvalidateAll(
">"+followReq.AccountID,
"<"+followReq.AccountID,
">"+followReq.TargetAccountID,
"<"+followReq.TargetAccountID,
)
}
func (c *Caches) OnInvalidateList(list *gtsmodel.List) {
// Invalidate all cached entries of this list.
c.GTS.ListEntry.Invalidate("ListID", list.ID)
}
func (c *Caches) OnInvalidateMedia(media *gtsmodel.MediaAttachment) {
if (media.Avatar != nil && *media.Avatar) ||
(media.Header != nil && *media.Header) {
// Invalidate cache of attaching account.
c.GTS.Account.Invalidate("ID", media.AccountID)
}
if media.StatusID != "" {
// Invalidate cache of attaching status.
c.GTS.Status.Invalidate("ID", media.StatusID)
}
}
func (c *Caches) OnInvalidatePoll(poll *gtsmodel.Poll) {
// Invalidate all cached votes of this poll.
c.GTS.PollVote.Invalidate("PollID", poll.ID)
// Invalidate cache of poll vote IDs.
c.GTS.PollVoteIDs.Invalidate(poll.ID)
}
func (c *Caches) OnInvalidatePollVote(vote *gtsmodel.PollVote) {
// Invalidate cached poll (contains no. votes).
c.GTS.Poll.Invalidate("ID", vote.PollID)
// Invalidate cache of poll vote IDs.
c.GTS.PollVoteIDs.Invalidate(vote.PollID)
}
func (c *Caches) OnInvalidateStatus(status *gtsmodel.Status) {
// Invalidate status ID cached visibility.
c.Visibility.Invalidate("ItemID", status.ID)
for _, id := range status.AttachmentIDs {
// Invalidate each media by the IDs we're aware of.
// This must be done as the status table is aware of
// the media IDs in use before the media table is
// aware of the status ID they are linked to.
//
// c.GTS.Media().Invalidate("StatusID") will not work.
c.GTS.Media.Invalidate("ID", id)
}
if status.BoostOfID != "" {
// Invalidate boost ID list of the original status.
c.GTS.BoostOfIDs.Invalidate(status.BoostOfID)
}
if status.InReplyToID != "" {
// Invalidate in reply to ID list of original status.
c.GTS.InReplyToIDs.Invalidate(status.InReplyToID)
}
if status.PollID != "" {
// Invalidate cache of attached poll ID.
c.GTS.Poll.Invalidate("ID", status.PollID)
}
}
func (c *Caches) OnInvalidateStatusFave(fave *gtsmodel.StatusFave) {
// Invalidate status fave ID list for this status.
c.GTS.StatusFaveIDs.Invalidate(fave.StatusID)
}
func (c *Caches) OnInvalidateUser(user *gtsmodel.User) {
// Invalidate local account ID cached visibility.
c.Visibility.Invalidate("ItemID", user.AccountID)
c.Visibility.Invalidate("RequesterID", user.AccountID)
}

View file

@ -18,18 +18,16 @@
package cache
import (
"codeberg.org/gruf/go-cache/v3/result"
"codeberg.org/gruf/go-structr"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
type VisibilityCache struct {
*result.Cache[*CachedVisibility]
structr.Cache[*CachedVisibility]
}
// Init will initialize the visibility cache in this collection.
// NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe.
func (c *VisibilityCache) Init() {
func (c *Caches) initVisibility() {
// Calculate maximum cache size.
cap := calculateResultCacheMax(
sizeofVisibility(), // model in-mem size.
@ -38,25 +36,22 @@ func (c *VisibilityCache) Init() {
log.Infof(nil, "Visibility cache size = %d", cap)
c.Cache = result.New([]result.Lookup{
{Name: "ItemID", Multi: true},
{Name: "RequesterID", Multi: true},
{Name: "Type.RequesterID.ItemID"},
}, func(v1 *CachedVisibility) *CachedVisibility {
copyF := func(v1 *CachedVisibility) *CachedVisibility {
v2 := new(CachedVisibility)
*v2 = *v1
return v2
}, cap)
}
c.Cache.IgnoreErrors(ignoreErrors)
}
// Start will attempt to start the visibility cache, or panic.
func (c *VisibilityCache) Start() {
}
// Stop will attempt to stop the visibility cache, or panic.
func (c *VisibilityCache) Stop() {
c.Visibility.Init(structr.Config[*CachedVisibility]{
Indices: []structr.IndexConfig{
{Fields: "ItemID", Multiple: true},
{Fields: "RequesterID", Multiple: true},
{Fields: "Type,RequesterID,ItemID"},
},
MaxSize: cap,
IgnoreErr: ignoreErrors,
CopyValue: copyF,
})
}
// VisibilityType represents a visibility lookup type.

View file

@ -116,7 +116,7 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
return a.getAccount(
ctx,
"Username.Domain",
"Username,Domain",
func(account *gtsmodel.Account) error {
q := a.db.NewSelect().
Model(account)
@ -224,7 +224,7 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Account) error, keyParts ...any) (*gtsmodel.Account, error) {
// Fetch account from database cache with loader callback
account, err := a.state.Caches.GTS.Account().Load(lookup, func() (*gtsmodel.Account, error) {
account, err := a.state.Caches.GTS.Account.LoadOne(lookup, func() (*gtsmodel.Account, error) {
var account gtsmodel.Account
// Not cached! Perform database query
@ -325,7 +325,7 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou
}
func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) error {
return a.state.Caches.GTS.Account().Store(account, func() error {
return a.state.Caches.GTS.Account.Store(account, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
@ -354,7 +354,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
columns = append(columns, "updated_at")
}
return a.state.Caches.GTS.Account().Store(account, func() error {
return a.state.Caches.GTS.Account.Store(account, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
@ -393,7 +393,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
}
func (a *accountDB) DeleteAccount(ctx context.Context, id string) error {
defer a.state.Caches.GTS.Account().Invalidate("ID", id)
defer a.state.Caches.GTS.Account.Invalidate("ID", id)
// Load account into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
@ -635,6 +635,10 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
return nil, err
}
if len(statusIDs) == 0 {
return nil, db.ErrNoEntries
}
// If we're paging up, we still want statuses
// to be sorted by ID desc, so reverse ids slice.
// https://zchee.github.io/golang-wiki/SliceTricks/#reversing
@ -644,7 +648,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
}
}
return a.statusesFromIDs(ctx, statusIDs)
return a.state.DB.GetStatusesByIDs(ctx, statusIDs)
}
func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID string) ([]*gtsmodel.Status, error) {
@ -662,7 +666,11 @@ func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID stri
return nil, err
}
return a.statusesFromIDs(ctx, statusIDs)
if len(statusIDs) == 0 {
return nil, db.ErrNoEntries
}
return a.state.DB.GetStatusesByIDs(ctx, statusIDs)
}
func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) {
@ -710,29 +718,9 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,
return nil, err
}
return a.statusesFromIDs(ctx, statusIDs)
}
func (a *accountDB) statusesFromIDs(ctx context.Context, statusIDs []string) ([]*gtsmodel.Status, error) {
// Catch case of no statuses early
if len(statusIDs) == 0 {
return nil, db.ErrNoEntries
}
// Allocate return slice (will be at most len statusIDS)
statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
for _, id := range statusIDs {
// Fetch from status from database by ID
status, err := a.state.DB.GetStatusByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting status %q: %v", id, err)
continue
}
// Append to return slice
statuses = append(statuses, status)
}
return statuses, nil
return a.state.DB.GetStatusesByIDs(ctx, statusIDs)
}

View file

@ -53,7 +53,7 @@ func (a *applicationDB) GetApplicationByClientID(ctx context.Context, clientID s
}
func (a *applicationDB) getApplication(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Application) error, keyParts ...any) (*gtsmodel.Application, error) {
return a.state.Caches.GTS.Application().Load(lookup, func() (*gtsmodel.Application, error) {
return a.state.Caches.GTS.Application.LoadOne(lookup, func() (*gtsmodel.Application, error) {
var app gtsmodel.Application
// Not cached! Perform database query.
@ -66,7 +66,7 @@ func (a *applicationDB) getApplication(ctx context.Context, lookup string, dbQue
}
func (a *applicationDB) PutApplication(ctx context.Context, app *gtsmodel.Application) error {
return a.state.Caches.GTS.Application().Store(app, func() error {
return a.state.Caches.GTS.Application.Store(app, func() error {
_, err := a.db.NewInsert().Model(app).Exec(ctx)
return err
})
@ -91,7 +91,7 @@ func (a *applicationDB) DeleteApplicationByClientID(ctx context.Context, clientI
//
// Clear application from the cache.
a.state.Caches.GTS.Application().Invalidate("ClientID", clientID)
a.state.Caches.GTS.Application.Invalidate("ClientID", clientID)
return nil
}

View file

@ -258,7 +258,7 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
state: state,
},
Tag: &tagDB{
conn: db,
db: db,
state: state,
},
Thread: &threadDB{

View file

@ -51,7 +51,7 @@ func (d *domainDB) CreateDomainAllow(ctx context.Context, allow *gtsmodel.Domain
}
// Clear the domain allow cache (for later reload)
d.state.Caches.GTS.DomainAllow().Clear()
d.state.Caches.GTS.DomainAllow.Clear()
return nil
}
@ -126,7 +126,7 @@ func (d *domainDB) DeleteDomainAllow(ctx context.Context, domain string) error {
}
// Clear the domain allow cache (for later reload)
d.state.Caches.GTS.DomainAllow().Clear()
d.state.Caches.GTS.DomainAllow.Clear()
return nil
}
@ -147,7 +147,7 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain
}
// Clear the domain block cache (for later reload)
d.state.Caches.GTS.DomainBlock().Clear()
d.state.Caches.GTS.DomainBlock.Clear()
return nil
}
@ -222,7 +222,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) error {
}
// Clear the domain block cache (for later reload)
d.state.Caches.GTS.DomainBlock().Clear()
d.state.Caches.GTS.DomainBlock.Clear()
return nil
}
@ -241,7 +241,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er
}
// Check the cache for an explicit domain allow (hydrating the cache with callback if necessary).
explicitAllow, err := d.state.Caches.GTS.DomainAllow().Matches(domain, func() ([]string, error) {
explicitAllow, err := d.state.Caches.GTS.DomainAllow.Matches(domain, func() ([]string, error) {
var domains []string
// Scan list of all explicitly allowed domains from DB
@ -259,7 +259,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er
}
// Check the cache for a domain block (hydrating the cache with callback if necessary)
explicitBlock, err := d.state.Caches.GTS.DomainBlock().Matches(domain, func() ([]string, error) {
explicitBlock, err := d.state.Caches.GTS.DomainBlock.Matches(domain, func() ([]string, error) {
var domains []string
// Scan list of all blocked domains from DB

View file

@ -21,6 +21,7 @@ import (
"context"
"database/sql"
"errors"
"slices"
"strings"
"time"
@ -30,6 +31,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect"
)
@ -40,7 +42,7 @@ type emojiDB struct {
}
func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) error {
return e.state.Caches.GTS.Emoji().Store(emoji, func() error {
return e.state.Caches.GTS.Emoji.Store(emoji, func() error {
_, err := e.db.NewInsert().Model(emoji).Exec(ctx)
return err
})
@ -54,7 +56,7 @@ func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, column
}
// Update the emoji model in the database.
return e.state.Caches.GTS.Emoji().Store(emoji, func() error {
return e.state.Caches.GTS.Emoji.Store(emoji, func() error {
_, err := e.db.
NewUpdate().
Model(emoji).
@ -74,21 +76,21 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {
defer func() {
// Invalidate cached emoji.
e.state.Caches.GTS.
Emoji().
Emoji.
Invalidate("ID", id)
for _, id := range accountIDs {
for _, accountID := range accountIDs {
// Invalidate cached account.
e.state.Caches.GTS.
Account().
Invalidate("ID", id)
Account.
Invalidate("ID", accountID)
}
for _, id := range statusIDs {
for _, statusID := range statusIDs {
// Invalidate cached account.
e.state.Caches.GTS.
Status().
Invalidate("ID", id)
Status.
Invalidate("ID", statusID)
}
}()
@ -129,26 +131,28 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {
return err
}
for _, id := range statusIDs {
for _, statusID := range statusIDs {
var emojiIDs []string
// Select statuses with ID.
if _, err := tx.NewSelect().
Table("statuses").
Column("emojis").
Where("? = ?", bun.Ident("id"), id).
Where("? = ?", bun.Ident("id"), statusID).
Exec(ctx); err != nil &&
err != sql.ErrNoRows {
return err
}
// Drop ID from account emojis.
emojiIDs = dropID(emojiIDs, id)
// Delete all instances of this emoji ID from status emojis.
emojiIDs = slices.DeleteFunc(emojiIDs, func(emojiID string) bool {
return emojiID == id
})
// Update status emoji IDs.
if _, err := tx.NewUpdate().
Table("statuses").
Where("? = ?", bun.Ident("id"), id).
Where("? = ?", bun.Ident("id"), statusID).
Set("emojis = ?", emojiIDs).
Exec(ctx); err != nil &&
err != sql.ErrNoRows {
@ -156,26 +160,28 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {
}
}
for _, id := range accountIDs {
for _, accountID := range accountIDs {
var emojiIDs []string
// Select account with ID.
if _, err := tx.NewSelect().
Table("accounts").
Column("emojis").
Where("? = ?", bun.Ident("id"), id).
Where("? = ?", bun.Ident("id"), accountID).
Exec(ctx); err != nil &&
err != sql.ErrNoRows {
return err
}
// Drop ID from account emojis.
emojiIDs = dropID(emojiIDs, id)
// Delete all instances of this emoji ID from account emojis.
emojiIDs = slices.DeleteFunc(emojiIDs, func(emojiID string) bool {
return emojiID == id
})
// Update account emoji IDs.
if _, err := tx.NewUpdate().
Table("accounts").
Where("? = ?", bun.Ident("id"), id).
Where("? = ?", bun.Ident("id"), accountID).
Set("emojis = ?", emojiIDs).
Exec(ctx); err != nil &&
err != sql.ErrNoRows {
@ -431,7 +437,7 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj
func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, error) {
return e.getEmoji(
ctx,
"Shortcode.Domain",
"Shortcode,Domain",
func(emoji *gtsmodel.Emoji) error {
q := e.db.
NewSelect().
@ -468,7 +474,7 @@ func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string
}
func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) error {
return e.state.Caches.GTS.EmojiCategory().Store(emojiCategory, func() error {
return e.state.Caches.GTS.EmojiCategory.Store(emojiCategory, func() error {
_, err := e.db.NewInsert().Model(emojiCategory).Exec(ctx)
return err
})
@ -520,7 +526,7 @@ func (e *emojiDB) GetEmojiCategoryByName(ctx context.Context, name string) (*gts
func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Emoji) error, keyParts ...any) (*gtsmodel.Emoji, error) {
// Fetch emoji from database cache with loader callback
emoji, err := e.state.Caches.GTS.Emoji().Load(lookup, func() (*gtsmodel.Emoji, error) {
emoji, err := e.state.Caches.GTS.Emoji.LoadOne(lookup, func() (*gtsmodel.Emoji, error) {
var emoji gtsmodel.Emoji
// Not cached! Perform database query
@ -568,28 +574,72 @@ func (e *emojiDB) PopulateEmoji(ctx context.Context, emoji *gtsmodel.Emoji) erro
return errs.Combine()
}
func (e *emojiDB) GetEmojisByIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, error) {
if len(emojiIDs) == 0 {
func (e *emojiDB) GetEmojisByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Emoji, error) {
if len(ids) == 0 {
return nil, db.ErrNoEntries
}
emojis := make([]*gtsmodel.Emoji, 0, len(emojiIDs))
// Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
for _, id := range emojiIDs {
emoji, err := e.GetEmojiByID(ctx, id)
if err != nil {
log.Errorf(ctx, "emojisFromIDs: error getting emoji %q: %v", id, err)
continue
}
// Load all emoji IDs via cache loader callbacks.
emojis, err := e.state.Caches.GTS.Emoji.Load("ID",
emojis = append(emojis, emoji)
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached emoji loader function.
func() ([]*gtsmodel.Emoji, error) {
// Preallocate expected length of uncached emojis.
emojis := make([]*gtsmodel.Emoji, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := e.db.NewSelect().
Model(&emojis).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return emojis, nil
},
)
if err != nil {
return nil, err
}
// Reorder the emojis by their
// IDs to ensure in correct order.
getID := func(e *gtsmodel.Emoji) string { return e.ID }
util.OrderBy(emojis, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return emojis, nil
}
// Populate all loaded emojis, removing those we fail to
// populate (removes needing so many nil checks everywhere).
emojis = slices.DeleteFunc(emojis, func(emoji *gtsmodel.Emoji) bool {
if err := e.PopulateEmoji(ctx, emoji); err != nil {
log.Errorf(ctx, "error populating emoji %s: %v", emoji.ID, err)
return true
}
return false
})
return emojis, nil
}
func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery func(*gtsmodel.EmojiCategory) error, keyParts ...any) (*gtsmodel.EmojiCategory, error) {
return e.state.Caches.GTS.EmojiCategory().Load(lookup, func() (*gtsmodel.EmojiCategory, error) {
return e.state.Caches.GTS.EmojiCategory.LoadOne(lookup, func() (*gtsmodel.EmojiCategory, error) {
var category gtsmodel.EmojiCategory
// Not cached! Perform database query
@ -601,36 +651,51 @@ func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery f
}, keyParts...)
}
func (e *emojiDB) GetEmojiCategoriesByIDs(ctx context.Context, emojiCategoryIDs []string) ([]*gtsmodel.EmojiCategory, error) {
if len(emojiCategoryIDs) == 0 {
func (e *emojiDB) GetEmojiCategoriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.EmojiCategory, error) {
if len(ids) == 0 {
return nil, db.ErrNoEntries
}
emojiCategories := make([]*gtsmodel.EmojiCategory, 0, len(emojiCategoryIDs))
// Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
for _, id := range emojiCategoryIDs {
emojiCategory, err := e.GetEmojiCategory(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting emoji category %q: %v", id, err)
continue
}
// Load all category IDs via cache loader callbacks.
categories, err := e.state.Caches.GTS.EmojiCategory.Load("ID",
emojiCategories = append(emojiCategories, emojiCategory)
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached emoji loader function.
func() ([]*gtsmodel.EmojiCategory, error) {
// Preallocate expected length of uncached categories.
categories := make([]*gtsmodel.EmojiCategory, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := e.db.NewSelect().
Model(&categories).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return categories, nil
},
)
if err != nil {
return nil, err
}
return emojiCategories, nil
}
// Reorder the categories by their
// IDs to ensure in correct order.
getID := func(c *gtsmodel.EmojiCategory) string { return c.ID }
util.OrderBy(categories, ids, getID)
// dropIDs drops given ID string from IDs slice.
func dropID(ids []string, id string) []string {
for i := 0; i < len(ids); {
if ids[i] == id {
// Remove this reference.
copy(ids[i:], ids[i+1:])
ids = ids[:len(ids)-1]
continue
}
i++
}
return ids
return categories, nil
}

View file

@ -143,7 +143,7 @@ func (i *instanceDB) GetInstanceByID(ctx context.Context, id string) (*gtsmodel.
func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Instance) error, keyParts ...any) (*gtsmodel.Instance, error) {
// Fetch instance from database cache with loader callback
instance, err := i.state.Caches.GTS.Instance().Load(lookup, func() (*gtsmodel.Instance, error) {
instance, err := i.state.Caches.GTS.Instance.LoadOne(lookup, func() (*gtsmodel.Instance, error) {
var instance gtsmodel.Instance
// Not cached! Perform database query.
@ -219,7 +219,7 @@ func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instanc
return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err)
}
return i.state.Caches.GTS.Instance().Store(instance, func() error {
return i.state.Caches.GTS.Instance.Store(instance, func() error {
_, err := i.db.NewInsert().Model(instance).Exec(ctx)
return err
})
@ -239,7 +239,7 @@ func (i *instanceDB) UpdateInstance(ctx context.Context, instance *gtsmodel.Inst
columns = append(columns, "updated_at")
}
return i.state.Caches.GTS.Instance().Store(instance, func() error {
return i.state.Caches.GTS.Instance.Store(instance, func() error {
_, err := i.db.
NewUpdate().
Model(instance).

View file

@ -21,6 +21,7 @@ import (
"context"
"errors"
"fmt"
"slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
@ -29,6 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@ -56,7 +58,7 @@ func (l *listDB) GetListByID(ctx context.Context, id string) (*gtsmodel.List, er
}
func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmodel.List) error, keyParts ...any) (*gtsmodel.List, error) {
list, err := l.state.Caches.GTS.List().Load(lookup, func() (*gtsmodel.List, error) {
list, err := l.state.Caches.GTS.List.LoadOne(lookup, func() (*gtsmodel.List, error) {
var list gtsmodel.List
// Not cached! Perform database query.
@ -100,18 +102,8 @@ func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([]
return nil, nil
}
// Select each list using its ID to ensure cache used.
lists := make([]*gtsmodel.List, 0, len(listIDs))
for _, id := range listIDs {
list, err := l.state.DB.GetListByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching list %q: %v", id, err)
continue
}
lists = append(lists, list)
}
return lists, nil
// Return lists by their IDs.
return l.GetListsByIDs(ctx, listIDs)
}
func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {
@ -147,7 +139,7 @@ func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error {
}
func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error {
return l.state.Caches.GTS.List().Store(list, func() error {
return l.state.Caches.GTS.List.Store(list, func() error {
_, err := l.db.NewInsert().Model(list).Exec(ctx)
return err
})
@ -162,7 +154,7 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ..
defer func() {
// Invalidate all entries for this list ID.
l.state.Caches.GTS.ListEntry().Invalidate("ListID", list.ID)
l.state.Caches.GTS.ListEntry.Invalidate("ListID", list.ID)
// Invalidate this entire list's timeline.
if err := l.state.Timelines.List.RemoveTimeline(ctx, list.ID); err != nil {
@ -170,7 +162,7 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns ..
}
}()
return l.state.Caches.GTS.List().Store(list, func() error {
return l.state.Caches.GTS.List.Store(list, func() error {
_, err := l.db.NewUpdate().
Model(list).
Where("? = ?", bun.Ident("list.id"), list.ID).
@ -198,7 +190,7 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error {
defer func() {
// Invalidate this list from cache.
l.state.Caches.GTS.List().Invalidate("ID", id)
l.state.Caches.GTS.List.Invalidate("ID", id)
// Invalidate this entire list's timeline.
if err := l.state.Timelines.List.RemoveTimeline(ctx, id); err != nil {
@ -243,7 +235,7 @@ func (l *listDB) GetListEntryByID(ctx context.Context, id string) (*gtsmodel.Lis
}
func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(*gtsmodel.ListEntry) error, keyParts ...any) (*gtsmodel.ListEntry, error) {
listEntry, err := l.state.Caches.GTS.ListEntry().Load(lookup, func() (*gtsmodel.ListEntry, error) {
listEntry, err := l.state.Caches.GTS.ListEntry.LoadOne(lookup, func() (*gtsmodel.ListEntry, error) {
var listEntry gtsmodel.ListEntry
// Not cached! Perform database query.
@ -344,18 +336,128 @@ func (l *listDB) GetListEntries(ctx context.Context,
}
}
// Select each list entry using its ID to ensure cache used.
listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs))
for _, id := range entryIDs {
listEntry, err := l.state.DB.GetListEntryByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching list entry %q: %v", id, err)
continue
}
listEntries = append(listEntries, listEntry)
// Return list entries by their IDs.
return l.GetListEntriesByIDs(ctx, entryIDs)
}
func (l *listDB) GetListsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.List, error) {
// Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
// Load all list IDs via cache loader callbacks.
lists, err := l.state.Caches.GTS.List.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached list loader function.
func() ([]*gtsmodel.List, error) {
// Preallocate expected length of uncached lists.
lists := make([]*gtsmodel.List, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := l.db.NewSelect().
Model(&lists).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return lists, nil
},
)
if err != nil {
return nil, err
}
return listEntries, nil
// Reorder the lists by their
// IDs to ensure in correct order.
getID := func(l *gtsmodel.List) string { return l.ID }
util.OrderBy(lists, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return lists, nil
}
// Populate all loaded lists, removing those we fail to
// populate (removes needing so many nil checks everywhere).
lists = slices.DeleteFunc(lists, func(list *gtsmodel.List) bool {
if err := l.PopulateList(ctx, list); err != nil {
log.Errorf(ctx, "error populating list %s: %v", list.ID, err)
return true
}
return false
})
return lists, nil
}
func (l *listDB) GetListEntriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.ListEntry, error) {
// Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
// Load all entry IDs via cache loader callbacks.
entries, err := l.state.Caches.GTS.ListEntry.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached entry loader function.
func() ([]*gtsmodel.ListEntry, error) {
// Preallocate expected length of uncached entries.
entries := make([]*gtsmodel.ListEntry, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := l.db.NewSelect().
Model(&entries).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return entries, nil
},
)
if err != nil {
return nil, err
}
// Reorder the entries by their
// IDs to ensure in correct order.
getID := func(e *gtsmodel.ListEntry) string { return e.ID }
util.OrderBy(entries, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return entries, nil
}
// Populate all loaded entries, removing those we fail to
// populate (removes needing so many nil checks everywhere).
entries = slices.DeleteFunc(entries, func(entry *gtsmodel.ListEntry) bool {
if err := l.PopulateListEntry(ctx, entry); err != nil {
log.Errorf(ctx, "error populating entry %s: %v", entry.ID, err)
return true
}
return false
})
return entries, nil
}
func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) ([]*gtsmodel.ListEntry, error) {
@ -376,18 +478,8 @@ func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string)
return nil, nil
}
// Select each list entry using its ID to ensure cache used.
listEntries := make([]*gtsmodel.ListEntry, 0, len(entryIDs))
for _, id := range entryIDs {
listEntry, err := l.state.DB.GetListEntryByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching list entry %q: %v", id, err)
continue
}
listEntries = append(listEntries, listEntry)
}
return listEntries, nil
// Return list entries by their IDs.
return l.GetListEntriesByIDs(ctx, entryIDs)
}
func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.ListEntry) error {
@ -409,10 +501,10 @@ func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.List
func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEntry) error {
defer func() {
// Collect unique list IDs from the entries.
listIDs := collate(func(i int) string {
return entries[i].ListID
}, len(entries))
// Collect unique list IDs from the provided entries.
listIDs := util.Collate(entries, func(e *gtsmodel.ListEntry) string {
return e.ListID
})
for _, id := range listIDs {
// Invalidate the timeline for the list this entry belongs to.
@ -426,7 +518,7 @@ func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEnt
return l.db.RunInTx(ctx, func(tx Tx) error {
for _, entry := range entries {
entry := entry // rescope
if err := l.state.Caches.GTS.ListEntry().Store(entry, func() error {
if err := l.state.Caches.GTS.ListEntry.Store(entry, func() error {
_, err := tx.
NewInsert().
Model(entry).
@ -459,7 +551,7 @@ func (l *listDB) DeleteListEntry(ctx context.Context, id string) error {
defer func() {
// Invalidate this list entry upon delete.
l.state.Caches.GTS.ListEntry().Invalidate("ID", id)
l.state.Caches.GTS.ListEntry.Invalidate("ID", id)
// Invalidate the timeline for the list this entry belongs to.
if err := l.state.Timelines.List.RemoveTimeline(ctx, entry.ListID); err != nil {
@ -514,24 +606,3 @@ func (l *listDB) ListIncludesAccount(ctx context.Context, listID string, account
return exists, err
}
// collate will collect the values of type T from an expected slice of length 'len',
// passing the expected index to each call of 'get' and deduplicating the end result.
func collate[T comparable](get func(int) T, len int) []T {
ts := make([]T, 0, len)
tm := make(map[T]struct{}, len)
for i := 0; i < len; i++ {
// Get next.
t := get(i)
if _, ok := tm[t]; !ok {
// New value, add
// to map + slice.
ts = append(ts, t)
tm[t] = struct{}{}
}
}
return ts
}

View file

@ -39,8 +39,8 @@ type markerDB struct {
*/
func (m *markerDB) GetMarker(ctx context.Context, accountID string, name gtsmodel.MarkerName) (*gtsmodel.Marker, error) {
marker, err := m.state.Caches.GTS.Marker().Load(
"AccountID.Name",
marker, err := m.state.Caches.GTS.Marker.LoadOne(
"AccountID,Name",
func() (*gtsmodel.Marker, error) {
var marker gtsmodel.Marker
@ -52,9 +52,7 @@ func (m *markerDB) GetMarker(ctx context.Context, accountID string, name gtsmode
}
return &marker, nil
},
accountID,
name,
}, accountID, name,
)
if err != nil {
return nil, err // already processed
@ -74,7 +72,7 @@ func (m *markerDB) UpdateMarker(ctx context.Context, marker *gtsmodel.Marker) er
marker.Version = prevMarker.Version + 1
}
return m.state.Caches.GTS.Marker().Store(marker, func() error {
return m.state.Caches.GTS.Marker.Store(marker, func() error {
if prevMarker == nil {
if _, err := m.db.NewInsert().
Model(marker).

View file

@ -20,14 +20,15 @@ package bundb
import (
"context"
"errors"
"slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@ -51,25 +52,52 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
}
func (m *mediaDB) GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error) {
attachments := make([]*gtsmodel.MediaAttachment, 0, len(ids))
// Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
for _, id := range ids {
// Attempt fetch from DB
attachment, err := m.GetAttachmentByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting attachment %q: %v", id, err)
continue
}
// Load all media IDs via cache loader callbacks.
media, err := m.state.Caches.GTS.Media.Load("ID",
// Append attachment
attachments = append(attachments, attachment)
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached media loader function.
func() ([]*gtsmodel.MediaAttachment, error) {
// Preallocate expected length of uncached media attachments.
media := make([]*gtsmodel.MediaAttachment, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := m.db.NewSelect().
Model(&media).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return media, nil
},
)
if err != nil {
return nil, err
}
return attachments, nil
// Reorder the media by their
// IDs to ensure in correct order.
getID := func(m *gtsmodel.MediaAttachment) string { return m.ID }
util.OrderBy(media, ids, getID)
return media, nil
}
func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, error) {
return m.state.Caches.GTS.Media().Load(lookup, func() (*gtsmodel.MediaAttachment, error) {
return m.state.Caches.GTS.Media.LoadOne(lookup, func() (*gtsmodel.MediaAttachment, error) {
var attachment gtsmodel.MediaAttachment
// Not cached! Perform database query
@ -82,7 +110,7 @@ func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func
}
func (m *mediaDB) PutAttachment(ctx context.Context, media *gtsmodel.MediaAttachment) error {
return m.state.Caches.GTS.Media().Store(media, func() error {
return m.state.Caches.GTS.Media.Store(media, func() error {
_, err := m.db.NewInsert().Model(media).Exec(ctx)
return err
})
@ -95,7 +123,7 @@ func (m *mediaDB) UpdateAttachment(ctx context.Context, media *gtsmodel.MediaAtt
columns = append(columns, "updated_at")
}
return m.state.Caches.GTS.Media().Store(media, func() error {
return m.state.Caches.GTS.Media.Store(media, func() error {
_, err := m.db.NewUpdate().
Model(media).
Where("? = ?", bun.Ident("media_attachment.id"), media.ID).
@ -119,7 +147,7 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
}
// On return, ensure that media with ID is invalidated.
defer m.state.Caches.GTS.Media().Invalidate("ID", id)
defer m.state.Caches.GTS.Media.Invalidate("ID", id)
// Delete media attachment in new transaction.
err = m.db.RunInTx(ctx, func(tx Tx) error {
@ -171,8 +199,12 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
return gtserror.Newf("error selecting status: %w", err)
}
if updatedIDs := dropID(status.AttachmentIDs, id); // nocollapse
len(updatedIDs) != len(status.AttachmentIDs) {
// Delete all instances of this deleted media ID from status attachments.
updatedIDs := slices.DeleteFunc(status.AttachmentIDs, func(s string) bool {
return s == id
})
if len(updatedIDs) != len(status.AttachmentIDs) {
// Note: this handles not found.
//
// Attachments changed, update the status.

View file

@ -20,6 +20,7 @@ package bundb
import (
"context"
"errors"
"slices"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
@ -27,6 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@ -36,7 +38,7 @@ type mentionDB struct {
}
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, error) {
mention, err := m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) {
mention, err := m.state.Caches.GTS.Mention.LoadOne("ID", func() (*gtsmodel.Mention, error) {
var mention gtsmodel.Mention
q := m.db.
@ -63,21 +65,64 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio
}
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, error) {
mentions := make([]*gtsmodel.Mention, 0, len(ids))
// Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
for _, id := range ids {
// Attempt fetch from DB
mention, err := m.GetMention(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting mention %q: %v", id, err)
continue
}
// Load all mention IDs via cache loader callbacks.
mentions, err := m.state.Caches.GTS.Mention.Load("ID",
// Append mention
mentions = append(mentions, mention)
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached mention loader function.
func() ([]*gtsmodel.Mention, error) {
// Preallocate expected length of uncached mentions.
mentions := make([]*gtsmodel.Mention, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := m.db.NewSelect().
Model(&mentions).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return mentions, nil
},
)
if err != nil {
return nil, err
}
// Reorder the mentions by their
// IDs to ensure in correct order.
getID := func(m *gtsmodel.Mention) string { return m.ID }
util.OrderBy(mentions, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return mentions, nil
}
// Populate all loaded mentions, removing those we fail to
// populate (removes needing so many nil checks everywhere).
mentions = slices.DeleteFunc(mentions, func(mention *gtsmodel.Mention) bool {
if err := m.PopulateMention(ctx, mention); err != nil {
log.Errorf(ctx, "error populating mention %s: %v", mention.ID, err)
return true
}
return false
})
return mentions, nil
}
func (m *mentionDB) PopulateMention(ctx context.Context, mention *gtsmodel.Mention) (err error) {
@ -120,14 +165,14 @@ func (m *mentionDB) PopulateMention(ctx context.Context, mention *gtsmodel.Menti
}
func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error {
return m.state.Caches.GTS.Mention().Store(mention, func() error {
return m.state.Caches.GTS.Mention.Store(mention, func() error {
_, err := m.db.NewInsert().Model(mention).Exec(ctx)
return err
})
}
func (m *mentionDB) DeleteMentionByID(ctx context.Context, id string) error {
defer m.state.Caches.GTS.Mention().Invalidate("ID", id)
defer m.state.Caches.GTS.Mention.Invalidate("ID", id)
// Load mention into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate

View file

@ -20,6 +20,7 @@ package bundb
import (
"context"
"errors"
"slices"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
@ -28,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@ -37,18 +39,17 @@ type notificationDB struct {
}
func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error) {
return n.state.Caches.GTS.Notification().Load("ID", func() (*gtsmodel.Notification, error) {
var notif gtsmodel.Notification
q := n.db.NewSelect().
Model(&notif).
Where("? = ?", bun.Ident("notification.id"), id)
if err := q.Scan(ctx); err != nil {
return nil, err
}
return &notif, nil
}, id)
return n.getNotification(
ctx,
"ID",
func(notif *gtsmodel.Notification) error {
return n.db.NewSelect().
Model(notif).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx)
},
id,
)
}
func (n *notificationDB) GetNotification(
@ -58,42 +59,113 @@ func (n *notificationDB) GetNotification(
originAccountID string,
statusID string,
) (*gtsmodel.Notification, error) {
notif, err := n.state.Caches.GTS.Notification().Load("NotificationType.TargetAccountID.OriginAccountID.StatusID", func() (*gtsmodel.Notification, error) {
return n.getNotification(
ctx,
"NotificationType,TargetAccountID,OriginAccountID,StatusID",
func(notif *gtsmodel.Notification) error {
return n.db.NewSelect().
Model(notif).
Where("? = ?", bun.Ident("notification_type"), notificationType).
Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
Where("? = ?", bun.Ident("origin_account_id"), originAccountID).
Where("? = ?", bun.Ident("status_id"), statusID).
Scan(ctx)
},
notificationType, targetAccountID, originAccountID, statusID,
)
}
func (n *notificationDB) getNotification(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Notification) error, keyParts ...any) (*gtsmodel.Notification, error) {
// Fetch notification from cache with loader callback
notif, err := n.state.Caches.GTS.Notification.LoadOne(lookup, func() (*gtsmodel.Notification, error) {
var notif gtsmodel.Notification
q := n.db.NewSelect().
Model(&notif).
Where("? = ?", bun.Ident("notification_type"), notificationType).
Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
Where("? = ?", bun.Ident("origin_account_id"), originAccountID).
Where("? = ?", bun.Ident("status_id"), statusID)
if err := q.Scan(ctx); err != nil {
// Not cached! Perform database query
if err := dbQuery(&notif); err != nil {
return nil, err
}
return &notif, nil
}, notificationType, targetAccountID, originAccountID, statusID)
}, keyParts...)
if err != nil {
return nil, err
}
if gtscontext.Barebones(ctx) {
// no need to fully populate.
// Only a barebones model was requested.
return notif, nil
}
// Further populate the notif fields where applicable.
if err := n.PopulateNotification(ctx, notif); err != nil {
if err := n.state.DB.PopulateNotification(ctx, notif); err != nil {
return nil, err
}
return notif, nil
}
func (n *notificationDB) GetNotificationsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Notification, error) {
// Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
// Load all notif IDs via cache loader callbacks.
notifs, err := n.state.Caches.GTS.Notification.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached notification loader function.
func() ([]*gtsmodel.Notification, error) {
// Preallocate expected length of uncached notifications.
notifs := make([]*gtsmodel.Notification, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := n.db.NewSelect().
Model(&notifs).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return notifs, nil
},
)
if err != nil {
return nil, err
}
// Reorder the notifs by their
// IDs to ensure in correct order.
getID := func(n *gtsmodel.Notification) string { return n.ID }
util.OrderBy(notifs, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return notifs, nil
}
// Populate all loaded notifs, removing those we fail to
// populate (removes needing so many nil checks everywhere).
notifs = slices.DeleteFunc(notifs, func(notif *gtsmodel.Notification) bool {
if err := n.PopulateNotification(ctx, notif); err != nil {
log.Errorf(ctx, "error populating notif %s: %v", notif.ID, err)
return true
}
return false
})
return notifs, nil
}
func (n *notificationDB) PopulateNotification(ctx context.Context, notif *gtsmodel.Notification) error {
var (
errs = gtserror.NewMultiError(2)
errs gtserror.MultiError
err error
)
@ -211,31 +283,19 @@ func (n *notificationDB) GetAccountNotifications(
}
}
notifs := make([]*gtsmodel.Notification, 0, len(notifIDs))
for _, id := range notifIDs {
// Attempt fetch from DB
notif, err := n.GetNotificationByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching notification %q: %v", id, err)
continue
}
// Append notification
notifs = append(notifs, notif)
}
return notifs, nil
// Fetch notification models by their IDs.
return n.GetNotificationsByIDs(ctx, notifIDs)
}
func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error {
return n.state.Caches.GTS.Notification().Store(notif, func() error {
return n.state.Caches.GTS.Notification.Store(notif, func() error {
_, err := n.db.NewInsert().Model(notif).Exec(ctx)
return err
})
}
func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) error {
defer n.state.Caches.GTS.Notification().Invalidate("ID", id)
defer n.state.Caches.GTS.Notification.Invalidate("ID", id)
// Load notif into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
@ -288,7 +348,7 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string
defer func() {
// Invalidate all IDs on return.
for _, id := range notifIDs {
n.state.Caches.GTS.Notification().Invalidate("ID", id)
n.state.Caches.GTS.Notification.Invalidate("ID", id)
}
}()
@ -326,7 +386,7 @@ func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statu
defer func() {
// Invalidate all IDs on return.
for _, id := range notifIDs {
n.state.Caches.GTS.Notification().Invalidate("ID", id)
n.state.Caches.GTS.Notification.Invalidate("ID", id)
}
}()

View file

@ -20,6 +20,7 @@ package bundb
import (
"context"
"errors"
"slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
@ -28,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@ -52,7 +54,7 @@ func (p *pollDB) GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, er
func (p *pollDB) getPoll(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Poll) error, keyParts ...any) (*gtsmodel.Poll, error) {
// Fetch poll from database cache with loader callback
poll, err := p.state.Caches.GTS.Poll().Load(lookup, func() (*gtsmodel.Poll, error) {
poll, err := p.state.Caches.GTS.Poll.LoadOne(lookup, func() (*gtsmodel.Poll, error) {
var poll gtsmodel.Poll
// Not cached! Perform database query.
@ -140,7 +142,7 @@ func (p *pollDB) PutPoll(ctx context.Context, poll *gtsmodel.Poll) error {
// is non nil and set.
poll.CheckVotes()
return p.state.Caches.GTS.Poll().Store(poll, func() error {
return p.state.Caches.GTS.Poll.Store(poll, func() error {
_, err := p.db.NewInsert().Model(poll).Exec(ctx)
return err
})
@ -151,7 +153,7 @@ func (p *pollDB) UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...st
// is non nil and set.
poll.CheckVotes()
return p.state.Caches.GTS.Poll().Store(poll, func() error {
return p.state.Caches.GTS.Poll.Store(poll, func() error {
return p.db.RunInTx(ctx, func(tx Tx) error {
// Update the status' "updated_at" field.
if _, err := tx.NewUpdate().
@ -184,8 +186,8 @@ func (p *pollDB) DeletePollByID(ctx context.Context, id string) error {
}
// Invalidate poll by ID from cache.
p.state.Caches.GTS.Poll().Invalidate("ID", id)
p.state.Caches.GTS.PollVoteIDs().Invalidate(id)
p.state.Caches.GTS.Poll.Invalidate("ID", id)
p.state.Caches.GTS.PollVoteIDs.Invalidate(id)
return nil
}
@ -207,7 +209,7 @@ func (p *pollDB) GetPollVoteByID(ctx context.Context, id string) (*gtsmodel.Poll
func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID string) (*gtsmodel.PollVote, error) {
return p.getPollVote(
ctx,
"PollID.AccountID",
"PollID,AccountID",
func(vote *gtsmodel.PollVote) error {
return p.db.NewSelect().
Model(vote).
@ -222,7 +224,7 @@ func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID str
func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.PollVote) error, keyParts ...any) (*gtsmodel.PollVote, error) {
// Fetch vote from database cache with loader callback
vote, err := p.state.Caches.GTS.PollVote().Load(lookup, func() (*gtsmodel.PollVote, error) {
vote, err := p.state.Caches.GTS.PollVote.LoadOne(lookup, func() (*gtsmodel.PollVote, error) {
var vote gtsmodel.PollVote
// Not cached! Perform database query.
@ -250,7 +252,9 @@ func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*g
}
func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.PollVote, error) {
voteIDs, err := p.state.Caches.GTS.PollVoteIDs().Load(pollID, func() ([]string, error) {
// Load vote IDs known for given poll ID using loader callback.
voteIDs, err := p.state.Caches.GTS.PollVoteIDs.Load(pollID, func() ([]string, error) {
var voteIDs []string
// Vote IDs not in cache, perform DB query!
@ -266,21 +270,62 @@ func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.P
return nil, err
}
// Preallocate slice of expected length.
votes := make([]*gtsmodel.PollVote, 0, len(voteIDs))
// Preallocate at-worst possible length.
uncached := make([]string, 0, len(voteIDs))
for _, id := range voteIDs {
// Fetch poll vote model for this ID.
vote, err := p.GetPollVoteByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting poll vote %s: %v", id, err)
continue
}
// Load all votes from IDs via cache loader callbacks.
votes, err := p.state.Caches.GTS.PollVote.Load("ID",
// Append to return slice.
votes = append(votes, vote)
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range voteIDs {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached poll vote loader function.
func() ([]*gtsmodel.PollVote, error) {
// Preallocate expected length of uncached votes.
votes := make([]*gtsmodel.PollVote, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := p.db.NewSelect().
Model(&votes).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return votes, nil
},
)
if err != nil {
return nil, err
}
// Reorder the poll votes by their
// IDs to ensure in correct order.
getID := func(v *gtsmodel.PollVote) string { return v.ID }
util.OrderBy(votes, voteIDs, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return votes, nil
}
// Populate all loaded votes, removing those we fail to
// populate (removes needing so many nil checks everywhere).
votes = slices.DeleteFunc(votes, func(vote *gtsmodel.PollVote) bool {
if err := p.PopulatePollVote(ctx, vote); err != nil {
log.Errorf(ctx, "error populating vote %s: %v", vote.ID, err)
return true
}
return false
})
return votes, nil
}
@ -316,7 +361,7 @@ func (p *pollDB) PopulatePollVote(ctx context.Context, vote *gtsmodel.PollVote)
}
func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error {
return p.state.Caches.GTS.PollVote().Store(vote, func() error {
return p.state.Caches.GTS.PollVote.Store(vote, func() error {
return p.db.RunInTx(ctx, func(tx Tx) error {
// Try insert vote into database.
if _, err := tx.NewInsert().
@ -416,9 +461,9 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
}
// Invalidate poll vote and poll entry from caches.
p.state.Caches.GTS.Poll().Invalidate("ID", pollID)
p.state.Caches.GTS.PollVote().Invalidate("PollID", pollID)
p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID)
p.state.Caches.GTS.Poll.Invalidate("ID", pollID)
p.state.Caches.GTS.PollVote.Invalidate("PollID", pollID)
p.state.Caches.GTS.PollVoteIDs.Invalidate(pollID)
return nil
}
@ -428,7 +473,7 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID
// Slice should only ever be of length
// 0 or 1; it's a slice of slices only
// because we can't LIMIT deletes to 1.
var choicesSl [][]int
var choicesSlice [][]int
// Delete vote in poll by account,
// returning the ID + choices of the vote.
@ -437,17 +482,19 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID
Where("? = ?", bun.Ident("poll_id"), pollID).
Where("? = ?", bun.Ident("account_id"), accountID).
Returning("?", bun.Ident("choices")).
Scan(ctx, &choicesSl); err != nil {
Scan(ctx, &choicesSlice); err != nil {
// irrecoverable.
return err
}
if len(choicesSl) != 1 {
if len(choicesSlice) != 1 {
// No poll votes by this
// acct on this poll.
return nil
}
choices := choicesSl[0]
// Extract the *actual* choices.
choices := choicesSlice[0]
// Select current poll counts from DB,
// taking minimal columns needed to
@ -489,9 +536,9 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID
}
// Invalidate poll vote and poll entry from caches.
p.state.Caches.GTS.Poll().Invalidate("ID", pollID)
p.state.Caches.GTS.PollVote().Invalidate("PollID.AccountID", pollID, accountID)
p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID)
p.state.Caches.GTS.Poll.Invalidate("ID", pollID)
p.state.Caches.GTS.PollVote.Invalidate("PollID,AccountID", pollID, accountID)
p.state.Caches.GTS.PollVoteIDs.Invalidate(pollID)
return nil
}

View file

@ -194,7 +194,7 @@ func (r *relationshipDB) CountAccountBlocks(ctx context.Context, accountID strin
}
func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), ">"+accountID, page, func() ([]string, error) {
return loadPagedIDs(r.state.Caches.GTS.FollowIDs, ">"+accountID, page, func() ([]string, error) {
var followIDs []string
// Follow IDs not in cache, perform DB query!
@ -209,7 +209,7 @@ func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID stri
}
func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) {
return r.state.Caches.GTS.FollowIDs().Load("l>"+accountID, func() ([]string, error) {
return r.state.Caches.GTS.FollowIDs.Load("l>"+accountID, func() ([]string, error) {
var followIDs []string
// Follow IDs not in cache, perform DB query!
@ -224,7 +224,7 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID
}
func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), "<"+accountID, page, func() ([]string, error) {
return loadPagedIDs(r.state.Caches.GTS.FollowIDs, "<"+accountID, page, func() ([]string, error) {
var followIDs []string
// Follow IDs not in cache, perform DB query!
@ -239,7 +239,7 @@ func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID st
}
func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) {
return r.state.Caches.GTS.FollowIDs().Load("l<"+accountID, func() ([]string, error) {
return r.state.Caches.GTS.FollowIDs.Load("l<"+accountID, func() ([]string, error) {
var followIDs []string
// Follow IDs not in cache, perform DB query!
@ -254,7 +254,7 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account
}
func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), ">"+accountID, page, func() ([]string, error) {
return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs, ">"+accountID, page, func() ([]string, error) {
var followReqIDs []string
// Follow request IDs not in cache, perform DB query!
@ -269,7 +269,7 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account
}
func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), "<"+accountID, page, func() ([]string, error) {
return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs, "<"+accountID, page, func() ([]string, error) {
var followReqIDs []string
// Follow request IDs not in cache, perform DB query!
@ -284,7 +284,7 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco
}
func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
return loadPagedIDs(r.state.Caches.GTS.BlockIDs(), accountID, page, func() ([]string, error) {
return loadPagedIDs(r.state.Caches.GTS.BlockIDs, accountID, page, func() ([]string, error) {
var blockIDs []string
// Block IDs not in cache, perform DB query!

View file

@ -20,12 +20,14 @@ package bundb
import (
"context"
"errors"
"slices"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@ -86,7 +88,7 @@ func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmod
func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Block, error) {
return r.getBlock(
ctx,
"AccountID.TargetAccountID",
"AccountID,TargetAccountID",
func(block *gtsmodel.Block) error {
return r.db.NewSelect().Model(block).
Where("? = ?", bun.Ident("block.account_id"), sourceAccountID).
@ -99,27 +101,68 @@ func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, t
}
func (r *relationshipDB) GetBlocksByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Block, error) {
// Preallocate slice of expected length.
blocks := make([]*gtsmodel.Block, 0, len(ids))
// Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
for _, id := range ids {
// Fetch block model for this ID.
block, err := r.GetBlockByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting block %q: %v", id, err)
continue
}
// Load all blocks IDs via cache loader callbacks.
blocks, err := r.state.Caches.GTS.Block.Load("ID",
// Append to return slice.
blocks = append(blocks, block)
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached block loader function.
func() ([]*gtsmodel.Block, error) {
// Preallocate expected length of uncached blocks.
blocks := make([]*gtsmodel.Block, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := r.db.NewSelect().
Model(&blocks).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return blocks, nil
},
)
if err != nil {
return nil, err
}
// Reorder the blocks by their
// IDs to ensure in correct order.
getID := func(b *gtsmodel.Block) string { return b.ID }
util.OrderBy(blocks, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return blocks, nil
}
// Populate all loaded blocks, removing those we fail to
// populate (removes needing so many nil checks everywhere).
blocks = slices.DeleteFunc(blocks, func(block *gtsmodel.Block) bool {
if err := r.PopulateBlock(ctx, block); err != nil {
log.Errorf(ctx, "error populating block %s: %v", block.ID, err)
return true
}
return false
})
return blocks, nil
}
func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) {
// Fetch block from cache with loader callback
block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) {
block, err := r.state.Caches.GTS.Block.LoadOne(lookup, func() (*gtsmodel.Block, error) {
var block gtsmodel.Block
// Not cached! Perform database query
@ -148,8 +191,8 @@ func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery fu
func (r *relationshipDB) PopulateBlock(ctx context.Context, block *gtsmodel.Block) error {
var (
errs gtserror.MultiError
err error
errs = gtserror.NewMultiError(2)
)
if block.Account == nil {
@ -178,7 +221,7 @@ func (r *relationshipDB) PopulateBlock(ctx context.Context, block *gtsmodel.Bloc
}
func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error {
return r.state.Caches.GTS.Block().Store(block, func() error {
return r.state.Caches.GTS.Block.Store(block, func() error {
_, err := r.db.NewInsert().Model(block).Exec(ctx)
return err
})
@ -198,7 +241,7 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {
}
// Drop this now-cached block on return after delete.
defer r.state.Caches.GTS.Block().Invalidate("ID", id)
defer r.state.Caches.GTS.Block.Invalidate("ID", id)
// Finally delete block from DB.
_, err = r.db.NewDelete().
@ -222,7 +265,7 @@ func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error
}
// Drop this now-cached block on return after delete.
defer r.state.Caches.GTS.Block().Invalidate("URI", uri)
defer r.state.Caches.GTS.Block.Invalidate("URI", uri)
// Finally delete block from DB.
_, err = r.db.NewDelete().
@ -251,22 +294,20 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri
defer func() {
// Invalidate all account's incoming / outoing blocks on return.
r.state.Caches.GTS.Block().Invalidate("AccountID", accountID)
r.state.Caches.GTS.Block().Invalidate("TargetAccountID", accountID)
r.state.Caches.GTS.Block.Invalidate("AccountID", accountID)
r.state.Caches.GTS.Block.Invalidate("TargetAccountID", accountID)
}()
// Load all blocks into cache, this *really* isn't great
// but it is the only way we can ensure we invalidate all
// related caches correctly (e.g. visibility).
for _, id := range blockIDs {
_, err := r.GetBlockByID(ctx, id)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
_, err := r.GetAccountBlocks(ctx, accountID, nil)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
// Finally delete all from DB.
_, err := r.db.NewDelete().
_, err = r.db.NewDelete().
Table("blocks").
Where("? IN (?)", bun.Ident("id"), bun.In(blockIDs)).
Exec(ctx)

View file

@ -21,6 +21,7 @@ import (
"context"
"errors"
"fmt"
"slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
@ -28,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@ -62,7 +64,7 @@ func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmo
func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
return r.getFollow(
ctx,
"AccountID.TargetAccountID",
"AccountID,TargetAccountID",
func(follow *gtsmodel.Follow) error {
return r.db.NewSelect().
Model(follow).
@ -76,21 +78,62 @@ func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string,
}
func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Follow, error) {
// Preallocate slice of expected length.
follows := make([]*gtsmodel.Follow, 0, len(ids))
// Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
for _, id := range ids {
// Fetch follow model for this ID.
follow, err := r.GetFollowByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting follow %q: %v", id, err)
continue
}
// Load all follow IDs via cache loader callbacks.
follows, err := r.state.Caches.GTS.Follow.Load("ID",
// Append to return slice.
follows = append(follows, follow)
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached follow loader function.
func() ([]*gtsmodel.Follow, error) {
// Preallocate expected length of uncached follows.
follows := make([]*gtsmodel.Follow, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := r.db.NewSelect().
Model(&follows).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return follows, nil
},
)
if err != nil {
return nil, err
}
// Reorder the follows by their
// IDs to ensure in correct order.
getID := func(f *gtsmodel.Follow) string { return f.ID }
util.OrderBy(follows, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return follows, nil
}
// Populate all loaded follows, removing those we fail to
// populate (removes needing so many nil checks everywhere).
follows = slices.DeleteFunc(follows, func(follow *gtsmodel.Follow) bool {
if err := r.PopulateFollow(ctx, follow); err != nil {
log.Errorf(ctx, "error populating follow %s: %v", follow.ID, err)
return true
}
return false
})
return follows, nil
}
@ -130,7 +173,7 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 strin
func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Follow) error, keyParts ...any) (*gtsmodel.Follow, error) {
// Fetch follow from database cache with loader callback
follow, err := r.state.Caches.GTS.Follow().Load(lookup, func() (*gtsmodel.Follow, error) {
follow, err := r.state.Caches.GTS.Follow.LoadOne(lookup, func() (*gtsmodel.Follow, error) {
var follow gtsmodel.Follow
// Not cached! Perform database query
@ -189,7 +232,7 @@ func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Fo
}
func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error {
return r.state.Caches.GTS.Follow().Store(follow, func() error {
return r.state.Caches.GTS.Follow.Store(follow, func() error {
_, err := r.db.NewInsert().Model(follow).Exec(ctx)
return err
})
@ -202,7 +245,7 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll
columns = append(columns, "updated_at")
}
return r.state.Caches.GTS.Follow().Store(follow, func() error {
return r.state.Caches.GTS.Follow.Store(follow, func() error {
if _, err := r.db.NewUpdate().
Model(follow).
Where("? = ?", bun.Ident("follow.id"), follow.ID).
@ -250,7 +293,7 @@ func (r *relationshipDB) DeleteFollow(ctx context.Context, sourceAccountID strin
}
// Drop this now-cached follow on return after delete.
defer r.state.Caches.GTS.Follow().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID)
defer r.state.Caches.GTS.Follow.Invalidate("AccountID,TargetAccountID", sourceAccountID, targetAccountID)
// Finally delete follow from DB.
return r.deleteFollow(ctx, follow.ID)
@ -270,7 +313,7 @@ func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error
}
// Drop this now-cached follow on return after delete.
defer r.state.Caches.GTS.Follow().Invalidate("ID", id)
defer r.state.Caches.GTS.Follow.Invalidate("ID", id)
// Finally delete follow from DB.
return r.deleteFollow(ctx, follow.ID)
@ -290,7 +333,7 @@ func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) erro
}
// Drop this now-cached follow on return after delete.
defer r.state.Caches.GTS.Follow().Invalidate("URI", uri)
defer r.state.Caches.GTS.Follow.Invalidate("URI", uri)
// Finally delete follow from DB.
return r.deleteFollow(ctx, follow.ID)
@ -316,22 +359,30 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str
defer func() {
// Invalidate all account's incoming / outoing follows on return.
r.state.Caches.GTS.Follow().Invalidate("AccountID", accountID)
r.state.Caches.GTS.Follow().Invalidate("TargetAccountID", accountID)
r.state.Caches.GTS.Follow.Invalidate("AccountID", accountID)
r.state.Caches.GTS.Follow.Invalidate("TargetAccountID", accountID)
}()
// Load all follows into cache, this *really* isn't great
// but it is the only way we can ensure we invalidate all
// related caches correctly (e.g. visibility).
for _, id := range followIDs {
follow, err := r.GetFollowByID(ctx, id)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
_, err := r.GetAccountFollows(ctx, accountID, nil)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
// Delete each follow from DB.
if err := r.deleteFollow(ctx, follow.ID); err != nil &&
!errors.Is(err, db.ErrNoEntries) {
// Delete all follows from DB.
_, err = r.db.NewDelete().
Table("follows").
Where("? IN (?)", bun.Ident("id"), bun.In(followIDs)).
Exec(ctx)
if err != nil {
return err
}
for _, id := range followIDs {
// Finally, delete all list entries associated with each follow ID.
if err := r.state.DB.DeleteListEntriesForFollowID(ctx, id); err != nil {
return err
}
}

View file

@ -20,6 +20,7 @@ package bundb
import (
"context"
"errors"
"slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
@ -27,6 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@ -61,7 +63,7 @@ func (r *relationshipDB) GetFollowRequestByURI(ctx context.Context, uri string)
func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) {
return r.getFollowRequest(
ctx,
"AccountID.TargetAccountID",
"AccountID,TargetAccountID",
func(followReq *gtsmodel.FollowRequest) error {
return r.db.NewSelect().
Model(followReq).
@ -75,22 +77,63 @@ func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID s
}
func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.FollowRequest, error) {
// Preallocate slice of expected length.
followReqs := make([]*gtsmodel.FollowRequest, 0, len(ids))
// Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
for _, id := range ids {
// Fetch follow request model for this ID.
followReq, err := r.GetFollowRequestByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting follow request %q: %v", id, err)
continue
}
// Load all follow IDs via cache loader callbacks.
follows, err := r.state.Caches.GTS.FollowRequest.Load("ID",
// Append to return slice.
followReqs = append(followReqs, followReq)
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached follow req loader function.
func() ([]*gtsmodel.FollowRequest, error) {
// Preallocate expected length of uncached followReqs.
follows := make([]*gtsmodel.FollowRequest, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := r.db.NewSelect().
Model(&follows).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return follows, nil
},
)
if err != nil {
return nil, err
}
return followReqs, nil
// Reorder the requests by their
// IDs to ensure in correct order.
getID := func(f *gtsmodel.FollowRequest) string { return f.ID }
util.OrderBy(follows, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return follows, nil
}
// Populate all loaded followreqs, removing those we fail to
// populate (removes needing so many nil checks everywhere).
follows = slices.DeleteFunc(follows, func(follow *gtsmodel.FollowRequest) bool {
if err := r.PopulateFollowRequest(ctx, follow); err != nil {
log.Errorf(ctx, "error populating follow request %s: %v", follow.ID, err)
return true
}
return false
})
return follows, nil
}
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) {
@ -107,7 +150,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID
func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, dbQuery func(*gtsmodel.FollowRequest) error, keyParts ...any) (*gtsmodel.FollowRequest, error) {
// Fetch follow request from database cache with loader callback
followReq, err := r.state.Caches.GTS.FollowRequest().Load(lookup, func() (*gtsmodel.FollowRequest, error) {
followReq, err := r.state.Caches.GTS.FollowRequest.LoadOne(lookup, func() (*gtsmodel.FollowRequest, error) {
var followReq gtsmodel.FollowRequest
// Not cached! Perform database query
@ -166,7 +209,7 @@ func (r *relationshipDB) PopulateFollowRequest(ctx context.Context, follow *gtsm
}
func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error {
return r.state.Caches.GTS.FollowRequest().Store(follow, func() error {
return r.state.Caches.GTS.FollowRequest.Store(follow, func() error {
_, err := r.db.NewInsert().Model(follow).Exec(ctx)
return err
})
@ -179,7 +222,7 @@ func (r *relationshipDB) UpdateFollowRequest(ctx context.Context, followRequest
columns = append(columns, "updated_at")
}
return r.state.Caches.GTS.FollowRequest().Store(followRequest, func() error {
return r.state.Caches.GTS.FollowRequest.Store(followRequest, func() error {
if _, err := r.db.NewUpdate().
Model(followRequest).
Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
@ -212,7 +255,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI
Notify: followReq.Notify,
}
if err := r.state.Caches.GTS.Follow().Store(follow, func() error {
if err := r.state.Caches.GTS.Follow.Store(follow, func() error {
// If the follow already exists, just
// replace the URI with the new one.
_, err := r.db.
@ -274,7 +317,7 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI
}
// Drop this now-cached follow request on return after delete.
defer r.state.Caches.GTS.FollowRequest().Invalidate("AccountID.TargetAccountID", sourceAccountID, targetAccountID)
defer r.state.Caches.GTS.FollowRequest.Invalidate("AccountID,TargetAccountID", sourceAccountID, targetAccountID)
// Finally delete followreq from DB.
_, err = r.db.NewDelete().
@ -298,7 +341,7 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string)
}
// Drop this now-cached follow request on return after delete.
defer r.state.Caches.GTS.FollowRequest().Invalidate("ID", id)
defer r.state.Caches.GTS.FollowRequest.Invalidate("ID", id)
// Finally delete followreq from DB.
_, err = r.db.NewDelete().
@ -322,7 +365,7 @@ func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri strin
}
// Drop this now-cached follow request on return after delete.
defer r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri)
defer r.state.Caches.GTS.FollowRequest.Invalidate("URI", uri)
// Finally delete followreq from DB.
_, err = r.db.NewDelete().
@ -352,22 +395,20 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun
defer func() {
// Invalidate all account's incoming / outoing follow requests on return.
r.state.Caches.GTS.FollowRequest().Invalidate("AccountID", accountID)
r.state.Caches.GTS.FollowRequest().Invalidate("TargetAccountID", accountID)
r.state.Caches.GTS.FollowRequest.Invalidate("AccountID", accountID)
r.state.Caches.GTS.FollowRequest.Invalidate("TargetAccountID", accountID)
}()
// Load all followreqs into cache, this *really* isn't
// great but it is the only way we can ensure we invalidate
// all related caches correctly (e.g. visibility).
for _, id := range followReqIDs {
_, err := r.GetFollowRequestByID(ctx, id)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
_, err := r.GetAccountFollowRequests(ctx, accountID, nil)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
// Finally delete all from DB.
_, err := r.db.NewDelete().
_, err = r.db.NewDelete().
Table("follow_requests").
Where("? IN (?)", bun.Ident("id"), bun.In(followReqIDs)).
Exec(ctx)

View file

@ -30,7 +30,7 @@ import (
func (r *relationshipDB) GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error) {
return r.getNote(
ctx,
"AccountID.TargetAccountID",
"AccountID,TargetAccountID",
func(note *gtsmodel.AccountNote) error {
return r.db.NewSelect().Model(note).
Where("? = ?", bun.Ident("account_id"), sourceAccountID).
@ -44,7 +44,7 @@ func (r *relationshipDB) GetNote(ctx context.Context, sourceAccountID string, ta
func (r *relationshipDB) getNote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.AccountNote) error, keyParts ...any) (*gtsmodel.AccountNote, error) {
// Fetch note from cache with loader callback
note, err := r.state.Caches.GTS.AccountNote().Load(lookup, func() (*gtsmodel.AccountNote, error) {
note, err := r.state.Caches.GTS.AccountNote.LoadOne(lookup, func() (*gtsmodel.AccountNote, error) {
var note gtsmodel.AccountNote
// Not cached! Perform database query
@ -105,7 +105,7 @@ func (r *relationshipDB) PopulateNote(ctx context.Context, note *gtsmodel.Accoun
func (r *relationshipDB) PutNote(ctx context.Context, note *gtsmodel.AccountNote) error {
note.UpdatedAt = time.Now()
return r.state.Caches.GTS.AccountNote().Store(note, func() error {
return r.state.Caches.GTS.AccountNote.Store(note, func() error {
_, err := r.db.
NewInsert().
Model(note).

View file

@ -120,7 +120,7 @@ func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID str
func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Report) error, keyParts ...any) (*gtsmodel.Report, error) {
// Fetch report from database cache with loader callback
report, err := r.state.Caches.GTS.Report().Load(lookup, func() (*gtsmodel.Report, error) {
report, err := r.state.Caches.GTS.Report.LoadOne(lookup, func() (*gtsmodel.Report, error) {
var report gtsmodel.Report
// Not cached! Perform database query
@ -215,7 +215,7 @@ func (r *reportDB) PopulateReport(ctx context.Context, report *gtsmodel.Report)
}
func (r *reportDB) PutReport(ctx context.Context, report *gtsmodel.Report) error {
return r.state.Caches.GTS.Report().Store(report, func() error {
return r.state.Caches.GTS.Report.Store(report, func() error {
_, err := r.db.NewInsert().Model(report).Exec(ctx)
return err
})
@ -237,12 +237,12 @@ func (r *reportDB) UpdateReport(ctx context.Context, report *gtsmodel.Report, co
return nil, err
}
r.state.Caches.GTS.Report().Invalidate("ID", report.ID)
r.state.Caches.GTS.Report.Invalidate("ID", report.ID)
return report, nil
}
func (r *reportDB) DeleteReportByID(ctx context.Context, id string) error {
defer r.state.Caches.GTS.Report().Invalidate("ID", id)
defer r.state.Caches.GTS.Report.Invalidate("ID", id)
// Load status into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate

View file

@ -125,7 +125,7 @@ func (r *ruleDB) PutRule(ctx context.Context, rule *gtsmodel.Rule) error {
}
// invalidate cached local instance response, so it gets updated with the new rules
r.state.Caches.GTS.Instance().Invalidate("Domain", config.GetHost())
r.state.Caches.GTS.Instance.Invalidate("Domain", config.GetHost())
return nil
}
@ -143,7 +143,7 @@ func (r *ruleDB) UpdateRule(ctx context.Context, rule *gtsmodel.Rule) (*gtsmodel
}
// invalidate cached local instance response, so it gets updated with the new rules
r.state.Caches.GTS.Instance().Invalidate("Domain", config.GetHost())
r.state.Caches.GTS.Instance.Invalidate("Domain", config.GetHost())
return rule, nil
}

View file

@ -20,6 +20,7 @@ package bundb
import (
"context"
"errors"
"slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
@ -28,6 +29,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@ -48,20 +50,62 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat
}
func (s *statusDB) GetStatusesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Status, error) {
statuses := make([]*gtsmodel.Status, 0, len(ids))
// Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
for _, id := range ids {
// Attempt to fetch status from DB.
status, err := s.GetStatusByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting status %q: %v", id, err)
continue
}
// Load all status IDs via cache loader callbacks.
statuses, err := s.state.Caches.GTS.Status.Load("ID",
// Append status to return slice.
statuses = append(statuses, status)
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached statuses loader function.
func() ([]*gtsmodel.Status, error) {
// Preallocate expected length of uncached statuses.
statuses := make([]*gtsmodel.Status, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) status IDs.
if err := s.db.NewSelect().
Model(&statuses).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return statuses, nil
},
)
if err != nil {
return nil, err
}
// Reorder the statuses by their
// IDs to ensure in correct order.
getID := func(s *gtsmodel.Status) string { return s.ID }
util.OrderBy(statuses, ids, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return statuses, nil
}
// Populate all loaded statuses, removing those we fail to
// populate (removes needing so many nil checks everywhere).
statuses = slices.DeleteFunc(statuses, func(status *gtsmodel.Status) bool {
if err := s.PopulateStatus(ctx, status); err != nil {
log.Errorf(ctx, "error populating status %s: %v", status.ID, err)
return true
}
return false
})
return statuses, nil
}
@ -101,7 +145,7 @@ func (s *statusDB) GetStatusByPollID(ctx context.Context, pollID string) (*gtsmo
func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccountID string) (*gtsmodel.Status, error) {
return s.getStatus(
ctx,
"BoostOfID.AccountID",
"BoostOfID,AccountID",
func(status *gtsmodel.Status) error {
return s.db.NewSelect().Model(status).
Where("status.boost_of_id = ?", boostOfID).
@ -120,7 +164,7 @@ func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccou
func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Status) error, keyParts ...any) (*gtsmodel.Status, error) {
// Fetch status from database cache with loader callback
status, err := s.state.Caches.GTS.Status().Load(lookup, func() (*gtsmodel.Status, error) {
status, err := s.state.Caches.GTS.Status.LoadOne(lookup, func() (*gtsmodel.Status, error) {
var status gtsmodel.Status
// Not cached! Perform database query.
@ -282,7 +326,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
}
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error {
return s.state.Caches.GTS.Status().Store(status, func() error {
return s.state.Caches.GTS.Status.Store(status, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
@ -366,7 +410,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
columns = append(columns, "updated_at")
}
return s.state.Caches.GTS.Status().Store(status, func() error {
return s.state.Caches.GTS.Status.Store(status, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
@ -463,7 +507,7 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error {
}
// On return ensure status invalidated from cache.
defer s.state.Caches.GTS.Status().Invalidate("ID", id)
defer s.state.Caches.GTS.Status.Invalidate("ID", id)
return s.db.RunInTx(ctx, func(tx Tx) error {
// delete links between this status and any emojis it uses
@ -585,7 +629,7 @@ func (s *statusDB) CountStatusReplies(ctx context.Context, statusID string) (int
}
func (s *statusDB) getStatusReplyIDs(ctx context.Context, statusID string) ([]string, error) {
return s.state.Caches.GTS.InReplyToIDs().Load(statusID, func() ([]string, error) {
return s.state.Caches.GTS.InReplyToIDs.Load(statusID, func() ([]string, error) {
var statusIDs []string
// Status reply IDs not in cache, perform DB query!
@ -629,7 +673,7 @@ func (s *statusDB) CountStatusBoosts(ctx context.Context, statusID string) (int,
}
func (s *statusDB) getStatusBoostIDs(ctx context.Context, statusID string) ([]string, error) {
return s.state.Caches.GTS.BoostOfIDs().Load(statusID, func() ([]string, error) {
return s.state.Caches.GTS.BoostOfIDs.Load(statusID, func() ([]string, error) {
var statusIDs []string
// Status boost IDs not in cache, perform DB query!

View file

@ -22,6 +22,7 @@ import (
"database/sql"
"errors"
"fmt"
"slices"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
@ -29,6 +30,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
@ -40,7 +42,7 @@ type statusFaveDB struct {
func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, error) {
return s.getStatusFave(
ctx,
"AccountID.StatusID",
"AccountID,StatusID",
func(fave *gtsmodel.StatusFave) error {
return s.db.
NewSelect().
@ -77,7 +79,7 @@ func (s *statusFaveDB) GetStatusFaveByID(ctx context.Context, id string) (*gtsmo
func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery func(*gtsmodel.StatusFave) error, keyParts ...any) (*gtsmodel.StatusFave, error) {
// Fetch status fave from database cache with loader callback
fave, err := s.state.Caches.GTS.StatusFave().Load(lookup, func() (*gtsmodel.StatusFave, error) {
fave, err := s.state.Caches.GTS.StatusFave.LoadOne(lookup, func() (*gtsmodel.StatusFave, error) {
var fave gtsmodel.StatusFave
// Not cached! Perform database query.
@ -111,19 +113,62 @@ func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*
return nil, err
}
// Preallocate a slice of expected status fave capacity.
faves := make([]*gtsmodel.StatusFave, 0, len(faveIDs))
// Preallocate at-worst possible length.
uncached := make([]string, 0, len(faveIDs))
for _, id := range faveIDs {
// Fetch status fave model for each ID.
fave, err := s.GetStatusFaveByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting status fave %q: %v", id, err)
continue
}
faves = append(faves, fave)
// Load all fave IDs via cache loader callbacks.
faves, err := s.state.Caches.GTS.StatusFave.Load("ID",
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range faveIDs {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached status faves loader function.
func() ([]*gtsmodel.StatusFave, error) {
// Preallocate expected length of uncached faves.
faves := make([]*gtsmodel.StatusFave, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) fave IDs.
if err := s.db.NewSelect().
Model(&faves).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return faves, nil
},
)
if err != nil {
return nil, err
}
// Reorder the statuses by their
// IDs to ensure in correct order.
getID := func(f *gtsmodel.StatusFave) string { return f.ID }
util.OrderBy(faves, faveIDs, getID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return faves, nil
}
// Populate all loaded faves, removing those we fail to
// populate (removes needing so many nil checks everywhere).
faves = slices.DeleteFunc(faves, func(fave *gtsmodel.StatusFave) bool {
if err := s.PopulateStatusFave(ctx, fave); err != nil {
log.Errorf(ctx, "error populating fave %s: %v", fave.ID, err)
return true
}
return false
})
return faves, nil
}
@ -141,7 +186,7 @@ func (s *statusFaveDB) CountStatusFaves(ctx context.Context, statusID string) (i
}
func (s *statusFaveDB) getStatusFaveIDs(ctx context.Context, statusID string) ([]string, error) {
return s.state.Caches.GTS.StatusFaveIDs().Load(statusID, func() ([]string, error) {
return s.state.Caches.GTS.StatusFaveIDs.Load(statusID, func() ([]string, error) {
var faveIDs []string
// Status fave IDs not in cache, perform DB query!
@ -201,7 +246,7 @@ func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmo
}
func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) error {
return s.state.Caches.GTS.StatusFave().Store(fave, func() error {
return s.state.Caches.GTS.StatusFave.Store(fave, func() error {
_, err := s.db.
NewInsert().
Model(fave).
@ -230,10 +275,10 @@ func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) erro
if statusID != "" {
// Invalidate any cached status faves for this status.
s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
s.state.Caches.GTS.StatusFave.Invalidate("ID", id)
// Invalidate any cached status fave IDs for this status.
s.state.Caches.GTS.StatusFaveIDs().Invalidate(statusID)
s.state.Caches.GTS.StatusFaveIDs.Invalidate(statusID)
}
return nil
@ -270,17 +315,15 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st
return err
}
// Collate (deduplicating) status IDs.
statusIDs = collate(func(i int) string {
return statusIDs[i]
}, len(statusIDs))
// Deduplicate determined status IDs.
statusIDs = util.Deduplicate(statusIDs)
for _, id := range statusIDs {
// Invalidate any cached status faves for this status.
s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
s.state.Caches.GTS.StatusFave.Invalidate("ID", id)
// Invalidate any cached status fave IDs for this status.
s.state.Caches.GTS.StatusFaveIDs().Invalidate(id)
s.state.Caches.GTS.StatusFaveIDs.Invalidate(id)
}
return nil
@ -296,10 +339,10 @@ func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID
}
// Invalidate any cached status faves for this status.
s.state.Caches.GTS.StatusFave().Invalidate("ID", statusID)
s.state.Caches.GTS.StatusFave.Invalidate("ID", statusID)
// Invalidate any cached status fave IDs for this status.
s.state.Caches.GTS.StatusFaveIDs().Invalidate(statusID)
s.state.Caches.GTS.StatusFaveIDs.Invalidate(statusID)
return nil
}

View file

@ -22,21 +22,21 @@ import (
"strings"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
type tagDB struct {
conn *DB
db *DB
state *state.State
}
func (m *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) {
return m.state.Caches.GTS.Tag().Load("ID", func() (*gtsmodel.Tag, error) {
func (t *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) {
return t.state.Caches.GTS.Tag.LoadOne("ID", func() (*gtsmodel.Tag, error) {
var tag gtsmodel.Tag
q := m.conn.
q := t.db.
NewSelect().
Model(&tag).
Where("? = ?", bun.Ident("tag.id"), id)
@ -49,15 +49,15 @@ func (m *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) {
}, id)
}
func (m *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, error) {
func (t *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, error) {
// Normalize 'name' string.
name = strings.TrimSpace(name)
name = strings.ToLower(name)
return m.state.Caches.GTS.Tag().Load("Name", func() (*gtsmodel.Tag, error) {
return t.state.Caches.GTS.Tag.LoadOne("Name", func() (*gtsmodel.Tag, error) {
var tag gtsmodel.Tag
q := m.conn.
q := t.db.
NewSelect().
Model(&tag).
Where("? = ?", bun.Ident("tag.name"), name)
@ -70,25 +70,52 @@ func (m *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, e
}, name)
}
func (m *tagDB) GetTags(ctx context.Context, ids []string) ([]*gtsmodel.Tag, error) {
tags := make([]*gtsmodel.Tag, 0, len(ids))
func (t *tagDB) GetTags(ctx context.Context, ids []string) ([]*gtsmodel.Tag, error) {
// Preallocate at-worst possible length.
uncached := make([]string, 0, len(ids))
for _, id := range ids {
// Attempt fetch from DB
tag, err := m.GetTag(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting tag %q: %v", id, err)
continue
}
// Load all tag IDs via cache loader callbacks.
tags, err := t.state.Caches.GTS.Tag.Load("ID",
// Append tag
tags = append(tags, tag)
// Load cached + check for uncached.
func(load func(keyParts ...any) bool) {
for _, id := range ids {
if !load(id) {
uncached = append(uncached, id)
}
}
},
// Uncached tag loader function.
func() ([]*gtsmodel.Tag, error) {
// Preallocate expected length of uncached tags.
tags := make([]*gtsmodel.Tag, 0, len(uncached))
// Perform database query scanning
// the remaining (uncached) IDs.
if err := t.db.NewSelect().
Model(&tags).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); err != nil {
return nil, err
}
return tags, nil
},
)
if err != nil {
return nil, err
}
// Reorder the tags by their
// IDs to ensure in correct order.
getID := func(t *gtsmodel.Tag) string { return t.ID }
util.OrderBy(tags, ids, getID)
return tags, nil
}
func (m *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error {
func (t *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error {
// Normalize 'name' string before it enters
// the db, without changing tag we were given.
//
@ -101,8 +128,8 @@ func (m *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error {
t2.Name = strings.ToLower(t2.Name)
// Insert the copy.
if err := m.state.Caches.GTS.Tag().Store(t2, func() error {
_, err := m.conn.NewInsert().Model(t2).Exec(ctx)
if err := t.state.Caches.GTS.Tag.Store(t2, func() error {
_, err := t.db.NewInsert().Model(t2).Exec(ctx)
return err
}); err != nil {
return err // err already processed

View file

@ -42,7 +42,7 @@ func (t *threadDB) PutThread(ctx context.Context, thread *gtsmodel.Thread) error
}
func (t *threadDB) GetThreadMute(ctx context.Context, id string) (*gtsmodel.ThreadMute, error) {
return t.state.Caches.GTS.ThreadMute().Load("ID", func() (*gtsmodel.ThreadMute, error) {
return t.state.Caches.GTS.ThreadMute.LoadOne("ID", func() (*gtsmodel.ThreadMute, error) {
var threadMute gtsmodel.ThreadMute
q := t.db.
@ -63,7 +63,7 @@ func (t *threadDB) GetThreadMutedByAccount(
threadID string,
accountID string,
) (*gtsmodel.ThreadMute, error) {
return t.state.Caches.GTS.ThreadMute().Load("ThreadID.AccountID", func() (*gtsmodel.ThreadMute, error) {
return t.state.Caches.GTS.ThreadMute.LoadOne("ThreadID,AccountID", func() (*gtsmodel.ThreadMute, error) {
var threadMute gtsmodel.ThreadMute
q := t.db.
@ -98,7 +98,7 @@ func (t *threadDB) IsThreadMutedByAccount(
}
func (t *threadDB) PutThreadMute(ctx context.Context, threadMute *gtsmodel.ThreadMute) error {
return t.state.Caches.GTS.ThreadMute().Store(threadMute, func() error {
return t.state.Caches.GTS.ThreadMute.Store(threadMute, func() error {
_, err := t.db.NewInsert().Model(threadMute).Exec(ctx)
return err
})
@ -112,6 +112,6 @@ func (t *threadDB) DeleteThreadMute(ctx context.Context, id string) error {
return err
}
t.state.Caches.GTS.ThreadMute().Invalidate("ID", id)
t.state.Caches.GTS.ThreadMute.Invalidate("ID", id)
return nil
}

View file

@ -29,7 +29,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/uptrace/bun"
)
@ -155,20 +154,8 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
}
}
statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
for _, id := range statusIDs {
// Fetch status from db for ID
status, err := t.state.DB.GetStatusByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching status %q: %v", id, err)
continue
}
// Append status to slice
statuses = append(statuses, status)
}
return statuses, nil
// Return status IDs loaded from cache + db.
return t.state.DB.GetStatusesByIDs(ctx, statusIDs)
}
func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
@ -256,20 +243,8 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI
}
}
statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
for _, id := range statusIDs {
// Fetch status from db for ID
status, err := t.state.DB.GetStatusByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching status %q: %v", id, err)
continue
}
// Append status to slice
statuses = append(statuses, status)
}
return statuses, nil
// Return status IDs loaded from cache + db.
return t.state.DB.GetStatusesByIDs(ctx, statusIDs)
}
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
@ -323,18 +298,15 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max
}
})
statuses := make([]*gtsmodel.Status, 0, len(faves))
// Convert fave IDs to status IDs.
statusIDs := make([]string, len(faves))
for i, fave := range faves {
statusIDs[i] = fave.StatusID
}
for _, fave := range faves {
// Fetch status from db for corresponding favourite
status, err := t.state.DB.GetStatusByID(ctx, fave.StatusID)
if err != nil {
log.Errorf(ctx, "error fetching status for fave %q: %v", fave.ID, err)
continue
}
// Append status to slice
statuses = append(statuses, status)
statuses, err := t.state.DB.GetStatusesByIDs(ctx, statusIDs)
if err != nil {
return nil, "", "", err
}
nextMaxID := faves[len(faves)-1].ID
@ -453,20 +425,8 @@ func (t *timelineDB) GetListTimeline(
}
}
statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
for _, id := range statusIDs {
// Fetch status from db for ID
status, err := t.state.DB.GetStatusByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching status %q: %v", id, err)
continue
}
// Append status to slice
statuses = append(statuses, status)
}
return statuses, nil
// Return status IDs loaded from cache + db.
return t.state.DB.GetStatusesByIDs(ctx, statusIDs)
}
func (t *timelineDB) GetTagTimeline(
@ -561,18 +521,6 @@ func (t *timelineDB) GetTagTimeline(
}
}
statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
for _, id := range statusIDs {
// Fetch status from db for ID
status, err := t.state.DB.GetStatusByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error fetching status %q: %v", id, err)
continue
}
// Append status to slice
statuses = append(statuses, status)
}
return statuses, nil
// Return status IDs loaded from cache + db.
return t.state.DB.GetStatusesByIDs(ctx, statusIDs)
}

View file

@ -32,7 +32,7 @@ type tombstoneDB struct {
}
func (t *tombstoneDB) GetTombstoneByURI(ctx context.Context, uri string) (*gtsmodel.Tombstone, error) {
return t.state.Caches.GTS.Tombstone().Load("URI", func() (*gtsmodel.Tombstone, error) {
return t.state.Caches.GTS.Tombstone.LoadOne("URI", func() (*gtsmodel.Tombstone, error) {
var tomb gtsmodel.Tombstone
q := t.db.
@ -57,7 +57,7 @@ func (t *tombstoneDB) TombstoneExistsWithURI(ctx context.Context, uri string) (b
}
func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tombstone) error {
return t.state.Caches.GTS.Tombstone().Store(tombstone, func() error {
return t.state.Caches.GTS.Tombstone.Store(tombstone, func() error {
_, err := t.db.
NewInsert().
Model(tombstone).
@ -67,7 +67,7 @@ func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tomb
}
func (t *tombstoneDB) DeleteTombstone(ctx context.Context, id string) error {
defer t.state.Caches.GTS.Tombstone().Invalidate("ID", id)
defer t.state.Caches.GTS.Tombstone.Invalidate("ID", id)
// Delete tombstone from DB.
_, err := t.db.NewDelete().

View file

@ -116,7 +116,7 @@ func (u *userDB) GetUserByConfirmationToken(ctx context.Context, token string) (
func (u *userDB) getUser(ctx context.Context, lookup string, dbQuery func(*gtsmodel.User) error, keyParts ...any) (*gtsmodel.User, error) {
// Fetch user from database cache with loader callback.
user, err := u.state.Caches.GTS.User().Load(lookup, func() (*gtsmodel.User, error) {
user, err := u.state.Caches.GTS.User.LoadOne(lookup, func() (*gtsmodel.User, error) {
var user gtsmodel.User
// Not cached! perform database query.
@ -179,7 +179,7 @@ func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) {
}
func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) error {
return u.state.Caches.GTS.User().Store(user, func() error {
return u.state.Caches.GTS.User.Store(user, func() error {
_, err := u.db.
NewInsert().
Model(user).
@ -197,7 +197,7 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ..
columns = append(columns, "updated_at")
}
return u.state.Caches.GTS.User().Store(user, func() error {
return u.state.Caches.GTS.User.Store(user, func() error {
_, err := u.db.
NewUpdate().
Model(user).
@ -209,7 +209,7 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ..
}
func (u *userDB) DeleteUserByID(ctx context.Context, userID string) error {
defer u.state.Caches.GTS.User().Invalidate("ID", userID)
defer u.state.Caches.GTS.User.Invalidate("ID", userID)
// Load user into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate

View file

@ -27,6 +27,9 @@ type List interface {
// GetListByID gets one list with the given id.
GetListByID(ctx context.Context, id string) (*gtsmodel.List, error)
// GetListsByIDs fetches all lists with the provided IDs.
GetListsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.List, error)
// GetListsForAccountID gets all lists owned by the given accountID.
GetListsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.List, error)
@ -46,6 +49,9 @@ type List interface {
// GetListEntryByID gets one list entry with the given ID.
GetListEntryByID(ctx context.Context, id string) (*gtsmodel.ListEntry, error)
// GetListEntriesyIDs fetches all list entries with the provided IDs.
GetListEntriesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.ListEntry, error)
// GetListEntries gets list entries from the given listID, using the given parameters.
GetListEntries(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.ListEntry, error)

View file

@ -33,6 +33,9 @@ type Notification interface {
// GetNotification returns one notification according to its id.
GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error)
// GetNotificationsByIDs returns a slice of notifications of the the provided IDs.
GetNotificationsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Notification, error)
// GetNotification gets one notification according to the provided parameters, if it exists.
// Since not all notifications are about a status, statusID can be an empty string.
GetNotification(ctx context.Context, notificationType gtsmodel.NotificationType, targetAccountID string, originAccountID string, statusID string) (*gtsmodel.Notification, error)

View file

@ -107,19 +107,21 @@ func (d *Dereferencer) EnrichAnnounce(
// All good baby.
case errors.Is(err, db.ErrAlreadyExists):
uri := boost.URI
// DATA RACE! We likely lost out to another goroutine
// in a call to db.Put(Status). Look again in DB by URI.
boost, err = d.state.DB.GetStatusByURI(ctx, boost.URI)
boost, err = d.state.DB.GetStatusByURI(ctx, uri)
if err != nil {
err = gtserror.Newf(
return nil, gtserror.Newf(
"error getting boost wrapper status %s from database after race: %w",
boost.URI, err,
uri, err,
)
}
default:
// Proper database error.
err = gtserror.Newf("db error inserting status: %w", err)
return nil, gtserror.Newf("db error inserting status: %w", err)
}
return boost, err

View file

@ -79,9 +79,7 @@ func (suite *AnnounceTestSuite) TestAnnounceTwice() {
// Insert the boost-of status into the
// DB cache to emulate processor handling
boost.ID, _ = id.NewULIDFromTime(boost.CreatedAt)
suite.state.Caches.GTS.Status().Store(boost, func() error {
return nil
})
suite.state.Caches.GTS.Status.Put(boost)
// only the URI will be set for the boosted status
// because it still needs to be dereferenced

View file

@ -55,7 +55,6 @@ func (m *Manager) RefetchEmojis(ctx context.Context, domain string, dereferenceM
emojis, err := m.state.DB.GetEmojisBy(ctx, domain, false, true, "", maxShortcodeDomain, "", 20)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
// an actual error has occurred
log.Errorf(ctx, "error fetching emojis from database: %s", err)
}
break

View file

@ -229,6 +229,7 @@ func (p *Processor) processMediaIDs(ctx context.Context, form *apimodel.Advanced
attachments := []*gtsmodel.MediaAttachment{}
attachmentIDs := []string{}
for _, mediaID := range form.MediaIDs {
attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {

View file

@ -82,7 +82,7 @@ func (t *transport) BatchDeliver(ctx context.Context, b []byte, recipients []*ur
// Attempt to deliver data to recipient.
if err := t.deliver(ctx, b, to); err != nil {
mutex.Lock() // safely append err to accumulator.
errs.Appendf("error delivering to %s: %v", to, err)
errs.Appendf("error delivering to %s: %w", to, err)
mutex.Unlock()
}
}

View file

@ -36,7 +36,8 @@ import (
func (t *transport) webfingerURLFor(targetDomain string) (string, bool) {
url := "https://" + targetDomain + "/.well-known/webfinger"
wc := t.controller.state.Caches.GTS.Webfinger()
wc := t.controller.state.Caches.GTS.Webfinger
// We're doing the manual locking/unlocking here to be able to
// safely call Cache.Get instead of Get, as the latter updates the
// item expiry which we don't want to do here
@ -95,7 +96,7 @@ func (t *transport) Finger(ctx context.Context, targetUsername string, targetDom
// If we got a response we consider successful on a cached URL, i.e one set
// by us later on when a host-meta based webfinger request succeeded, set it
// again here to renew the TTL
t.controller.state.Caches.GTS.Webfinger().Set(targetDomain, url)
t.controller.state.Caches.GTS.Webfinger.Set(targetDomain, url)
}
if rsp.StatusCode == http.StatusGone {
return nil, fmt.Errorf("account has been deleted/is gone")
@ -151,7 +152,7 @@ func (t *transport) Finger(ctx context.Context, targetUsername string, targetDom
// we asked for is gone. This means the endpoint itself is valid and we should
// cache it for future queries to the same domain
if rsp.StatusCode == http.StatusGone {
t.controller.state.Caches.GTS.Webfinger().Set(targetDomain, host)
t.controller.state.Caches.GTS.Webfinger.Set(targetDomain, host)
return nil, fmt.Errorf("account has been deleted/is gone")
}
// We've reached the end of the line here, both the original request
@ -162,7 +163,7 @@ func (t *transport) Finger(ctx context.Context, targetUsername string, targetDom
// Set the URL in cache here, since host-meta told us this should be the
// valid one, it's different from the default and our request to it did
// not fail in any manner
t.controller.state.Caches.GTS.Webfinger().Set(targetDomain, host)
t.controller.state.Caches.GTS.Webfinger.Set(targetDomain, host)
return io.ReadAll(rsp.Body)
}

View file

@ -31,7 +31,7 @@ type FingerTestSuite struct {
}
func (suite *FingerTestSuite) TestFinger() {
wc := suite.state.Caches.GTS.Webfinger()
wc := suite.state.Caches.GTS.Webfinger
suite.Equal(0, wc.Len(), "expect webfinger cache to be empty")
_, err := suite.transport.Finger(context.TODO(), "brand_new_person", "unknown-instance.com")
@ -43,7 +43,7 @@ func (suite *FingerTestSuite) TestFinger() {
}
func (suite *FingerTestSuite) TestFingerWithHostMeta() {
wc := suite.state.Caches.GTS.Webfinger()
wc := suite.state.Caches.GTS.Webfinger
suite.Equal(0, wc.Len(), "expect webfinger cache to be empty")
_, err := suite.transport.Finger(context.TODO(), "someone", "misconfigured-instance.com")
@ -60,7 +60,7 @@ func (suite *FingerTestSuite) TestFingerWithHostMetaCacheStrategy() {
suite.T().Skip("this test is flaky on CI for as of yet unknown reasons")
}
wc := suite.state.Caches.GTS.Webfinger()
wc := suite.state.Caches.GTS.Webfinger
// Reset the sweep frequency so nothing interferes with the test
wc.Stop()

View file

@ -794,7 +794,6 @@ func (c *Converter) getASAttributedToAccount(ctx context.Context, id string, wit
}
return account, nil
}
func (c *Converter) getASObjectAccount(ctx context.Context, id string, with ap.WithObject) (*gtsmodel.Account, error) {

View file

@ -491,7 +491,7 @@ func (c *Converter) StatusToAS(ctx context.Context, s *gtsmodel.Status) (ap.Stat
// tag -- mentions
mentions := s.Mentions
if len(s.MentionIDs) > len(mentions) {
if len(s.MentionIDs) != len(mentions) {
mentions, err = c.state.DB.GetMentions(ctx, s.MentionIDs)
if err != nil {
return nil, gtserror.Newf("error getting mentions: %w", err)
@ -507,14 +507,10 @@ func (c *Converter) StatusToAS(ctx context.Context, s *gtsmodel.Status) (ap.Stat
// tag -- emojis
emojis := s.Emojis
if len(s.EmojiIDs) > len(emojis) {
emojis = []*gtsmodel.Emoji{}
for _, emojiID := range s.EmojiIDs {
emoji, err := c.state.DB.GetEmojiByID(ctx, emojiID)
if err != nil {
return nil, gtserror.Newf("error getting emoji %s from database: %w", emojiID, err)
}
emojis = append(emojis, emoji)
if len(s.EmojiIDs) != len(emojis) {
emojis, err = c.state.DB.GetEmojisByIDs(ctx, s.EmojiIDs)
if err != nil {
return nil, gtserror.Newf("error getting emojis from database: %w", err)
}
}
for _, emoji := range emojis {
@ -527,7 +523,7 @@ func (c *Converter) StatusToAS(ctx context.Context, s *gtsmodel.Status) (ap.Stat
// tag -- hashtags
hashtags := s.Tags
if len(s.TagIDs) > len(hashtags) {
if len(s.TagIDs) != len(hashtags) {
hashtags, err = c.state.DB.GetTags(ctx, s.TagIDs)
if err != nil {
return nil, gtserror.Newf("error getting tags: %w", err)
@ -623,14 +619,10 @@ func (c *Converter) StatusToAS(ctx context.Context, s *gtsmodel.Status) (ap.Stat
// attachments
attachmentProp := streams.NewActivityStreamsAttachmentProperty()
attachments := s.Attachments
if len(s.AttachmentIDs) > len(attachments) {
attachments = []*gtsmodel.MediaAttachment{}
for _, attachmentID := range s.AttachmentIDs {
attachment, err := c.state.DB.GetAttachmentByID(ctx, attachmentID)
if err != nil {
return nil, gtserror.Newf("error getting attachment %s from database: %w", attachmentID, err)
}
attachments = append(attachments, attachment)
if len(s.AttachmentIDs) != len(attachments) {
attachments, err = c.state.DB.GetAttachmentsByIDs(ctx, s.AttachmentIDs)
if err != nil {
return nil, gtserror.Newf("error getting attachments from database: %w", err)
}
}
for _, a := range attachments {

View file

@ -1563,20 +1563,15 @@ func (c *Converter) PollToAPIPoll(ctx context.Context, requester *gtsmodel.Accou
func (c *Converter) convertAttachmentsToAPIAttachments(ctx context.Context, attachments []*gtsmodel.MediaAttachment, attachmentIDs []string) ([]*apimodel.Attachment, error) {
var errs gtserror.MultiError
if len(attachments) == 0 {
if len(attachments) == 0 && len(attachmentIDs) > 0 {
// GTS model attachments were not populated
// Preallocate expected GTS slice
attachments = make([]*gtsmodel.MediaAttachment, 0, len(attachmentIDs))
var err error
// Fetch GTS models for attachment IDs
for _, id := range attachmentIDs {
attachment, err := c.state.DB.GetAttachmentByID(ctx, id)
if err != nil {
errs.Appendf("error fetching attachment %s from database: %v", id, err)
continue
}
attachments = append(attachments, attachment)
attachments, err = c.state.DB.GetAttachmentsByIDs(ctx, attachmentIDs)
if err != nil {
errs.Appendf("error fetching attachments from database: %w", err)
}
}
@ -1587,7 +1582,7 @@ func (c *Converter) convertAttachmentsToAPIAttachments(ctx context.Context, atta
for _, attachment := range attachments {
apiAttachment, err := c.AttachmentToAPIAttachment(ctx, attachment)
if err != nil {
errs.Appendf("error converting attchment %s to api attachment: %v", attachment.ID, err)
errs.Appendf("error converting attchment %s to api attachment: %w", attachment.ID, err)
continue
}
apiAttachments = append(apiAttachments, &apiAttachment)
@ -1600,20 +1595,15 @@ func (c *Converter) convertAttachmentsToAPIAttachments(ctx context.Context, atta
func (c *Converter) convertEmojisToAPIEmojis(ctx context.Context, emojis []*gtsmodel.Emoji, emojiIDs []string) ([]apimodel.Emoji, error) {
var errs gtserror.MultiError
if len(emojis) == 0 {
if len(emojis) == 0 && len(emojiIDs) > 0 {
// GTS model attachments were not populated
// Preallocate expected GTS slice
emojis = make([]*gtsmodel.Emoji, 0, len(emojiIDs))
var err error
// Fetch GTS models for emoji IDs
for _, id := range emojiIDs {
emoji, err := c.state.DB.GetEmojiByID(ctx, id)
if err != nil {
errs.Appendf("error fetching emoji %s from database: %v", id, err)
continue
}
emojis = append(emojis, emoji)
emojis, err = c.state.DB.GetEmojisByIDs(ctx, emojiIDs)
if err != nil {
errs.Appendf("error fetching emojis from database: %w", err)
}
}
@ -1624,7 +1614,7 @@ func (c *Converter) convertEmojisToAPIEmojis(ctx context.Context, emojis []*gtsm
for _, emoji := range emojis {
apiEmoji, err := c.EmojiToAPIEmoji(ctx, emoji)
if err != nil {
errs.Appendf("error converting emoji %s to api emoji: %v", emoji.ID, err)
errs.Appendf("error converting emoji %s to api emoji: %w", emoji.ID, err)
continue
}
apiEmojis = append(apiEmojis, apiEmoji)
@ -1637,7 +1627,7 @@ func (c *Converter) convertEmojisToAPIEmojis(ctx context.Context, emojis []*gtsm
func (c *Converter) convertMentionsToAPIMentions(ctx context.Context, mentions []*gtsmodel.Mention, mentionIDs []string) ([]apimodel.Mention, error) {
var errs gtserror.MultiError
if len(mentions) == 0 {
if len(mentions) == 0 && len(mentionIDs) > 0 {
var err error
// GTS model mentions were not populated
@ -1645,7 +1635,7 @@ func (c *Converter) convertMentionsToAPIMentions(ctx context.Context, mentions [
// Fetch GTS models for mention IDs
mentions, err = c.state.DB.GetMentions(ctx, mentionIDs)
if err != nil {
errs.Appendf("error fetching mentions from database: %v", err)
errs.Appendf("error fetching mentions from database: %w", err)
}
}
@ -1656,7 +1646,7 @@ func (c *Converter) convertMentionsToAPIMentions(ctx context.Context, mentions [
for _, mention := range mentions {
apiMention, err := c.MentionToAPIMention(ctx, mention)
if err != nil {
errs.Appendf("error converting mention %s to api mention: %v", mention.ID, err)
errs.Appendf("error converting mention %s to api mention: %w", mention.ID, err)
continue
}
apiMentions = append(apiMentions, apiMention)
@ -1669,12 +1659,12 @@ func (c *Converter) convertMentionsToAPIMentions(ctx context.Context, mentions [
func (c *Converter) convertTagsToAPITags(ctx context.Context, tags []*gtsmodel.Tag, tagIDs []string) ([]apimodel.Tag, error) {
var errs gtserror.MultiError
if len(tags) == 0 {
if len(tags) == 0 && len(tagIDs) > 0 {
var err error
tags, err = c.state.DB.GetTags(ctx, tagIDs)
if err != nil {
errs.Appendf("error fetching tags from database: %v", err)
errs.Appendf("error fetching tags from database: %w", err)
}
}
@ -1685,7 +1675,7 @@ func (c *Converter) convertTagsToAPITags(ctx context.Context, tags []*gtsmodel.T
for _, tag := range tags {
apiTag, err := c.TagToAPITag(ctx, tag, false)
if err != nil {
errs.Appendf("error converting tag %s to api tag: %v", tag.ID, err)
errs.Appendf("error converting tag %s to api tag: %w", tag.ID, err)
continue
}
apiTags = append(apiTags, apiTag)

View file

@ -61,3 +61,75 @@ func DeduplicateFunc[T any, C comparable](in []T, key func(v T) C) []T {
return deduped
}
// Collate will collect the values of type K from input type []T,
// passing each item to 'get' and deduplicating the end result.
// Compared to Deduplicate() this returns []K, NOT input type []T.
func Collate[T any, K comparable](in []T, get func(T) K) []K {
ks := make([]K, 0, len(in))
km := make(map[K]struct{}, len(in))
for i := 0; i < len(in); i++ {
// Get next k.
k := get(in[i])
if _, ok := km[k]; !ok {
// New value, add
// to map + slice.
ks = append(ks, k)
km[k] = struct{}{}
}
}
return ks
}
// OrderBy orders a slice of given type by the provided alternative slice of comparable type.
func OrderBy[T any, K comparable](in []T, keys []K, key func(T) K) {
var (
start int
offset int
)
for i := 0; i < len(keys); i++ {
var (
// key at index.
k = keys[i]
// sentinel
// idx value.
idx = -1
)
// Look for model with key in slice.
for j := start; j < len(in); j++ {
if key(in[j]) == k {
idx = j
break
}
}
if idx == -1 {
// model with key
// was not found.
offset++
continue
}
// Update
// start
start++
// Expected ID index.
exp := i - offset
if idx == exp {
// Model is in expected
// location, keep going.
continue
}
// Swap models at current and expected.
in[idx], in[exp] = in[exp], in[idx]
}
}

View file

@ -39,7 +39,7 @@ func (f *Filter) AccountVisible(ctx context.Context, requester *gtsmodel.Account
requesterID = requester.ID
}
visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) {
visibility, err := f.state.Caches.Visibility.LoadOne("Type,RequesterID,ItemID", func() (*cache.CachedVisibility, error) {
// Visibility not yet cached, perform visibility lookup.
visible, err := f.isAccountVisibleTo(ctx, requester, account)
if err != nil {

View file

@ -42,7 +42,7 @@ func (f *Filter) StatusHomeTimelineable(ctx context.Context, owner *gtsmodel.Acc
requesterID = owner.ID
}
visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) {
visibility, err := f.state.Caches.Visibility.LoadOne("Type,RequesterID,ItemID", func() (*cache.CachedVisibility, error) {
// Visibility not yet cached, perform timeline visibility lookup.
visible, err := f.isStatusHomeTimelineable(ctx, owner, status)
if err != nil {

View file

@ -40,7 +40,7 @@ func (f *Filter) StatusPublicTimelineable(ctx context.Context, requester *gtsmod
requesterID = requester.ID
}
visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) {
visibility, err := f.state.Caches.Visibility.LoadOne("Type,RequesterID,ItemID", func() (*cache.CachedVisibility, error) {
// Visibility not yet cached, perform timeline visibility lookup.
visible, err := f.isStatusPublicTimelineable(ctx, requester, status)
if err != nil {

View file

@ -53,7 +53,7 @@ func (f *Filter) StatusVisible(ctx context.Context, requester *gtsmodel.Account,
requesterID = requester.ID
}
visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) {
visibility, err := f.state.Caches.Visibility.LoadOne("Type,RequesterID,ItemID", func() (*cache.CachedVisibility, error) {
// Visibility not yet cached, perform visibility lookup.
visible, err := f.isStatusVisible(ctx, requester, status)
if err != nil {