diff --git a/internal/cache/invalidate.go b/internal/cache/invalidate.go index ab7a385c0..b80c164f8 100644 --- a/internal/cache/invalidate.go +++ b/internal/cache/invalidate.go @@ -145,7 +145,12 @@ func (c *Caches) OnInvalidateFollowRequest(followReq *gtsmodel.FollowRequest) { } func (c *Caches) OnInvalidateList(list *gtsmodel.List) { - // Invalidate list ID entries. + // Invalidate list IDs cache. + c.DB.ListIDs.Invalidate( + "a" + list.AccountID, + ) + + // Invalidate listed IDs cache. c.DB.ListedIDs.Invalidate( "a"+list.ID, "f"+list.ID, diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go index 4c75776d2..2a7c58772 100644 --- a/internal/db/bundb/list.go +++ b/internal/db/bundb/list.go @@ -149,6 +149,8 @@ func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error { } func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error { + // note that inserting list will call OnInvalidateList() + // which will handle clearing caches other than List cache. return l.state.Caches.DB.List.Store(list, func() error { _, err := l.db.NewInsert().Model(list).Exec(ctx) return err @@ -162,50 +164,78 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns .. columns = append(columns, "updated_at") } - defer func() { - // Invalidate this entire list's timeline. - if err := l.state.Timelines.List.RemoveTimeline(ctx, list.ID); err != nil { - log.Errorf(ctx, "error invalidating list timeline: %q", err) - } - }() - - return l.state.Caches.DB.List.Store(list, func() error { + // Update list in the database, invalidating main list cache. + if err := l.state.Caches.DB.List.Store(list, func() error { _, err := l.db.NewUpdate(). Model(list). Where("? = ?", bun.Ident("list.id"), list.ID). Column(columns...). Exec(ctx) return err - }) + }); err != nil { + return err + } + + // Invalidate this entire list's timeline. + if err := l.state.Timelines.List.RemoveTimeline(ctx, list.ID); err != nil { + log.Errorf(ctx, "error invalidating list timeline: %q", err) + } + + return nil } func (l *listDB) DeleteListByID(ctx context.Context, id string) error { - defer func() { - // Invalidate this list from cache. - l.state.Caches.DB.List.Invalidate("ID", id) + // Acquire list owner ID. + var accountID string - // Invalidate this entire list's timeline. - if err := l.state.Timelines.List.RemoveTimeline(ctx, id); err != nil { - log.Errorf(ctx, "error invalidating list timeline: %q", err) - } - }() + // Gather follow IDs of all + // entries contained in list. + var followIDs []string - return l.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { - // Delete all entries attached to list. + // Delete all list entries associated with list, and list itself in transaction. + if err := l.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { if _, err := tx.NewDelete(). Table("list_entries"). Where("? = ?", bun.Ident("list_id"), id). - Exec(ctx); err != nil { + Returning("?", bun.Ident("follow_id")). + Exec(ctx, &followIDs); err != nil { return err } - // Delete the list itself. _, err := tx.NewDelete(). Table("lists"). Where("? = ?", bun.Ident("id"), id). - Exec(ctx) + Returning("?", bun.Ident("account_id")). + Exec(ctx, &accountID) return err - }) + }); err != nil { + return err + } + + // Invalidate the main list database cache. + l.state.Caches.DB.List.Invalidate("ID", id) + + // Invalidate account / follow IDs in list. + l.state.Caches.DB.ListedIDs.Invalidate( + "a"+id, + "f"+id, + ) + + // Generate ListID keys to invalidate. + keys := followIDs // just reuse slice. + for i, followID := range keys { + + // List IDs containing follow. + keys[i] = "f" + followID + } + + // ListIDs owned by account with ID. + keys = append(keys, "a"+accountID) + + // Invalidate ListID slice cache entries. + l.state.Caches.DB.ListIDs.Invalidate(keys...) + + return nil } func (l *listDB) getListIDsByAccountID(ctx context.Context, accountID string) ([]string, error) { @@ -358,22 +388,8 @@ func (l *listDB) PopulateListEntry(ctx context.Context, listEntry *gtsmodel.List } func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEntry) error { - defer func() { - // Collect unique list IDs from the provided entries. - listIDs := util.Collate(entries, func(e *gtsmodel.ListEntry) string { - return e.ListID - }) - - for _, id := range listIDs { - // Invalidate the timeline for the list this entry belongs to. - if err := l.state.Timelines.List.RemoveTimeline(ctx, id); err != nil { - log.Errorf(ctx, "error invalidating list timeline: %q", err) - } - } - }() - - // Finally, insert each list entry into the database. - return l.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + // Insert all entries into the database in a single transaction (all or nothing!). + if err := l.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { for _, entry := range entries { if _, err := tx. NewInsert(). @@ -383,7 +399,49 @@ func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEnt } } return nil + }); err != nil { + return err + } + + // Collect unique list IDs from the provided list entries. + listIDs := util.Collate(entries, func(e *gtsmodel.ListEntry) string { + return e.ListID }) + + // Collect unique follow IDs from the provided list entries. + followIDs := util.Collate(entries, func(e *gtsmodel.ListEntry) string { + return e.FollowID + }) + + var keys []string + + // Generate ListedID keys to invalidate. + keys = slices.Grow(keys[:0], 2*len(listIDs)) + for _, listID := range listIDs { + keys = append(keys, + "a"+listID, + "f"+listID, + ) + + // Invalidate the timeline for the list this entry belongs to. + if err := l.state.Timelines.List.RemoveTimeline(ctx, listID); err != nil { + log.Errorf(ctx, "error invalidating list timeline: %q", err) + } + } + + // Invalidate ListedID slice cache entries. + l.state.Caches.DB.ListedIDs.Invalidate(keys...) + + // Generate ListID keys to invalidate. + keys = slices.Grow(keys[:0], len(followIDs)) + for _, followID := range followIDs { + keys = append(keys, "f"+followID) + } + + // Invalidate ListID slice cache entries. + l.state.Caches.DB.ListIDs.Invalidate(keys...) + + return nil } func (l *listDB) DeleteListEntry(ctx context.Context, listID string, followID string) error {