From d03116fea387aab6a59c0870d607eed78555ceaa Mon Sep 17 00:00:00 2001 From: kim Date: Thu, 12 Sep 2024 19:15:33 +0100 Subject: [PATCH] rename some funcs, allow deleting list entries for multiple follow IDs at a time, fix up more tests --- internal/db/bundb/list.go | 149 +++-- internal/db/bundb/list_test.go | 507 +++++++++--------- internal/db/bundb/relationship_follow.go | 15 +- internal/db/list.go | 6 +- .../processing/workers/surfacetimeline.go | 2 +- 5 files changed, 313 insertions(+), 366 deletions(-) diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go index c7e0d1caf..3f609f387 100644 --- a/internal/db/bundb/list.go +++ b/internal/db/bundb/list.go @@ -123,7 +123,7 @@ func (l *listDB) GetAccountsInList(ctx context.Context, listID string, page *pag return l.state.DB.GetAccountsByIDs(ctx, accountIDs) } -func (l *listDB) IsAccountInListID(ctx context.Context, listID string, accountID string) (bool, error) { +func (l *listDB) IsAccountInList(ctx context.Context, listID string, accountID string) (bool, error) { accountIDs, err := l.GetAccountIDsInList(ctx, listID, nil) return slices.Contains(accountIDs, accountID), err } @@ -215,25 +215,11 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error { // Invalidate the main list database cache. l.state.Caches.DB.List.Invalidate("ID", id) - // Invalidate account / follow IDs in list. - l.state.Caches.DB.ListedIDs.Invalidate( - "a"+id, - "f"+id, - ) + // Invalidate cache of list IDs owned by account. + l.state.Caches.DB.ListIDs.Invalidate("a" + accountID) - // Generate ListID keys to invalidate. - keys := followIDs // just reuse slice. - for i, followID := range keys { - - // List IDs containing follow. - keys[i] = "f" + followID - } - - // ListIDs owned by account with ID. - keys = append(keys, "a"+accountID) - - // Invalidate ListID slice cache entries. - l.state.Caches.DB.ListIDs.Invalidate(keys...) + // Invalidate all related entry caches for this list. + l.invalidateEntryCaches(ctx, []string{id}, followIDs) return nil } @@ -410,6 +396,61 @@ func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEnt return e.FollowID }) + // Invalidate all related list entry caches. + l.invalidateEntryCaches(ctx, listIDs, followIDs) + + return nil +} + +func (l *listDB) DeleteListEntry(ctx context.Context, listID string, followID string) error { + // Delete list entry with given + // ID, returning its list ID. + if _, err := l.db.NewDelete(). + Table("list_entries"). + Where("? = ?", bun.Ident("list_id"), listID). + Where("? = ?", bun.Ident("follow_id"), followID). + Exec(ctx, &listID); err != nil && + !errors.Is(err, db.ErrNoEntries) { + return err + } + + // Invalidate all related list entry caches. + l.invalidateEntryCaches(ctx, []string{listID}, + []string{followID}) + + return nil +} + +func (l *listDB) DeleteAllListEntriesByFollowIDs(ctx context.Context, followIDs ...string) error { + var listIDs []string + + // Check for empty list. + if len(followIDs) == 0 { + return nil + } + + // Delete all entries with follow + // ID, returning IDs and list IDs. + if _, err := l.db.NewDelete(). + Table("list_entries"). + Where("? IN (?)", bun.Ident("follow_id"), followIDs). + Returning("?", bun.Ident("list_id")). + Exec(ctx, &listIDs); err != nil && + !errors.Is(err, db.ErrNoEntries) { + return err + } + + // Deduplicate IDs before invalidate. + listIDs = util.Deduplicate(listIDs) + + // Invalidate all related list entry caches. + l.invalidateEntryCaches(ctx, listIDs, followIDs) + + return nil +} + +// invalidateEntryCaches will invalidate all related ListEntry caches for given list IDs and follow IDs, including timelines. +func (l *listDB) invalidateEntryCaches(ctx context.Context, listIDs, followIDs []string) { var keys []string // Generate ListedID keys to invalidate. @@ -437,74 +478,4 @@ func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEnt // Invalidate ListID slice cache entries. l.state.Caches.DB.ListIDs.Invalidate(keys...) - - return nil -} - -func (l *listDB) DeleteListEntry(ctx context.Context, listID string, followID string) error { - // Delete list entry with given - // ID, returning its list ID. - if _, err := l.db.NewDelete(). - Table("list_entries"). - Where("? = ?", bun.Ident("list_id"), listID). - Where("? = ?", bun.Ident("follow_id"), followID). - Exec(ctx, &listID); err != nil && - !errors.Is(err, db.ErrNoEntries) { - return err - } - - // Invalidate list IDs containing follow. - l.state.Caches.DB.ListIDs.Invalidate( - "f" + followID, - ) - - // Invalidate account / follow IDs in list. - l.state.Caches.DB.ListedIDs.Invalidate( - "a"+listID, - "f"+listID, - ) - - // Invalidate the timeline for the list this entry belongs to. - if err := l.state.Timelines.List.RemoveTimeline(ctx, listID); err != nil { - log.Errorf(ctx, "error invalidating list timeline: %q", err) - } - - return nil -} - -func (l *listDB) DeleteListEntriesTargettingFollowID(ctx context.Context, followID string) error { - var listIDs []string - - // Delete all entries with follow - // ID, returning IDs and list IDs. - if _, err := l.db.NewDelete(). - Table("list_entries"). - Where("? = ?", bun.Ident("follow_id"), followID). - Returning("?", bun.Ident("list_id")). - Exec(ctx, &listIDs); err != nil && - !errors.Is(err, db.ErrNoEntries) { - return err - } - - // Invalidate list IDs containing follow. - l.state.Caches.DB.ListIDs.Invalidate( - "f" + followID, - ) - - // Iterate through list IDs of deleted entries. - for _, listID := range util.Deduplicate(listIDs) { - - // Invalidate account / follow IDs in list. - l.state.Caches.DB.ListedIDs.Invalidate( - "a"+listID, - "f"+listID, - ) - - // Invalidate the timeline for the list this entry belongs to. - if err := l.state.Timelines.List.RemoveTimeline(ctx, listID); err != nil { - log.Errorf(ctx, "error invalidating list timeline: %q", err) - } - } - - return nil } diff --git a/internal/db/bundb/list_test.go b/internal/db/bundb/list_test.go index 9e94d1cf7..d5e5f315e 100644 --- a/internal/db/bundb/list_test.go +++ b/internal/db/bundb/list_test.go @@ -18,329 +18,312 @@ package bundb_test import ( + "context" + "slices" "testing" "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) type ListTestSuite struct { BunDBStandardTestSuite } -// func (suite *ListTestSuite) testStructs() (*gtsmodel.List, *gtsmodel.Account) { -// testList := >smodel.List{} -// *testList = *suite.testLists["local_account_1_list_1"] +func (suite *ListTestSuite) testStructs() (*gtsmodel.List, []*gtsmodel.ListEntry, *gtsmodel.Account) { + testList := >smodel.List{} + *testList = *suite.testLists["local_account_1_list_1"] -// // Populate entries on this list as we'd expect them back from the db. -// entries := make([]*gtsmodel.ListEntry, 0, len(suite.testListEntries)) -// for _, entry := range suite.testListEntries { -// entries = append(entries, entry) -// } + // Populate entries on this list as we'd expect them back from the db. + entries := make([]*gtsmodel.ListEntry, 0, len(suite.testListEntries)) + for _, entry := range suite.testListEntries { + entries = append(entries, entry) + } -// // Sort by ID descending (again, as we'd expect from the db). -// slices.SortFunc(entries, func(a, b *gtsmodel.ListEntry) int { -// const k = -1 -// switch { -// case a.ID > b.ID: -// return +k -// case a.ID < b.ID: -// return -k -// default: -// return 0 -// } -// }) + // Sort by ID descending (again, as we'd expect from the db). + slices.SortFunc(entries, func(a, b *gtsmodel.ListEntry) int { + const k = -1 + switch { + case a.ID > b.ID: + return +k + case a.ID < b.ID: + return -k + default: + return 0 + } + }) -// testList.ListEntries = entries + testAccount := >smodel.Account{} + *testAccount = *suite.testAccounts["local_account_1"] -// testAccount := >smodel.Account{} -// *testAccount = *suite.testAccounts["local_account_1"] + return testList, entries, testAccount +} -// return testList, testAccount -// } +func (suite *ListTestSuite) checkList(expected *gtsmodel.List, actual *gtsmodel.List) { + suite.Equal(expected.ID, actual.ID) + suite.Equal(expected.Title, actual.Title) + suite.Equal(expected.AccountID, actual.AccountID) + suite.Equal(expected.RepliesPolicy, actual.RepliesPolicy) + suite.NotNil(actual.Account) +} -// func (suite *ListTestSuite) checkList(expected *gtsmodel.List, actual *gtsmodel.List) { -// suite.Equal(expected.ID, actual.ID) -// suite.Equal(expected.Title, actual.Title) -// suite.Equal(expected.AccountID, actual.AccountID) -// suite.Equal(expected.RepliesPolicy, actual.RepliesPolicy) -// suite.NotNil(actual.Account) -// } +func (suite *ListTestSuite) checkListEntry(expected *gtsmodel.ListEntry, actual *gtsmodel.ListEntry) { + suite.Equal(expected.ID, actual.ID) + suite.Equal(expected.ListID, actual.ListID) + suite.Equal(expected.FollowID, actual.FollowID) +} -// func (suite *ListTestSuite) checkListEntry(expected *gtsmodel.ListEntry, actual *gtsmodel.ListEntry) { -// suite.Equal(expected.ID, actual.ID) -// suite.Equal(expected.ListID, actual.ListID) -// suite.Equal(expected.FollowID, actual.FollowID) -// } +func (suite *ListTestSuite) checkListEntries(expected []*gtsmodel.ListEntry, actual []*gtsmodel.ListEntry) { + var ( + lExpected = len(expected) + lActual = len(actual) + ) -// func (suite *ListTestSuite) checkListEntries(expected []*gtsmodel.ListEntry, actual []*gtsmodel.ListEntry) { -// var ( -// lExpected = len(expected) -// lActual = len(actual) -// ) + if lExpected != lActual { + suite.FailNow("", "expected %d list entries, got %d", lExpected, lActual) + } -// if lExpected != lActual { -// suite.FailNow("", "expected %d list entries, got %d", lExpected, lActual) -// } + var topID string + for i, expectedEntry := range expected { + actualEntry := actual[i] -// var topID string -// for i, expectedEntry := range expected { -// actualEntry := actual[i] + // Ensure ID descending. + if topID == "" { + topID = actualEntry.ID + } else { + suite.Less(actualEntry.ID, topID) + } -// // Ensure ID descending. -// if topID == "" { -// topID = actualEntry.ID -// } else { -// suite.Less(actualEntry.ID, topID) -// } + suite.checkListEntry(expectedEntry, actualEntry) + } +} -// suite.checkListEntry(expectedEntry, actualEntry) -// } -// } +func (suite *ListTestSuite) TestGetListByID() { + testList, _, _ := suite.testStructs() -// func (suite *ListTestSuite) TestGetListByID() { -// testList, _ := suite.testStructs() + dbList, err := suite.db.GetListByID(context.Background(), testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } -// dbList, err := suite.db.GetListByID(context.Background(), testList.ID) -// if err != nil { -// suite.FailNow(err.Error()) -// } + suite.checkList(testList, dbList) +} -// suite.checkList(testList, dbList) -// suite.checkListEntries(testList.ListEntries, dbList.ListEntries) -// } +func (suite *ListTestSuite) TestGetListsForAccountID() { + testList, _, testAccount := suite.testStructs() -// func (suite *ListTestSuite) TestGetListsForAccountID() { -// testList, testAccount := suite.testStructs() + dbLists, err := suite.db.GetListsByAccountID(context.Background(), testAccount.ID) + if err != nil { + suite.FailNow(err.Error()) + } -// dbLists, err := suite.db.GetListsForAccountID(context.Background(), testAccount.ID) -// if err != nil { -// suite.FailNow(err.Error()) -// } + if l := len(dbLists); l != 1 { + suite.FailNow("", "expected %d lists, got %d", 1, l) + } -// if l := len(dbLists); l != 1 { -// suite.FailNow("", "expected %d lists, got %d", 1, l) -// } + suite.checkList(testList, dbLists[0]) +} -// suite.checkList(testList, dbLists[0]) -// } +func (suite *ListTestSuite) TestPutList() { + ctx := context.Background() + _, _, testAccount := suite.testStructs() -// func (suite *ListTestSuite) TestGetListEntries() { -// testList, _ := suite.testStructs() + testList := >smodel.List{ + ID: "01H0J2PMYM54618VCV8Y8QYAT4", + Title: "Test List!", + AccountID: testAccount.ID, + } -// dbListEntries, err := suite.db.GetListEntries(context.Background(), testList.ID, "", "", "", 0) -// if err != nil { -// suite.FailNow(err.Error()) -// } + if err := suite.db.PutList(ctx, testList); err != nil { + suite.FailNow(err.Error()) + } -// suite.checkListEntries(testList.ListEntries, dbListEntries) -// } + dbList, err := suite.db.GetListByID(ctx, testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } -// func (suite *ListTestSuite) TestPutList() { -// ctx := context.Background() -// _, testAccount := suite.testStructs() + // Bodge testlist as though default had been set. + testList.RepliesPolicy = gtsmodel.RepliesPolicyFollowed + suite.checkList(testList, dbList) +} -// testList := >smodel.List{ -// ID: "01H0J2PMYM54618VCV8Y8QYAT4", -// Title: "Test List!", -// AccountID: testAccount.ID, -// } +func (suite *ListTestSuite) TestUpdateList() { + ctx := context.Background() + testList, _, _ := suite.testStructs() -// if err := suite.db.PutList(ctx, testList); err != nil { -// suite.FailNow(err.Error()) -// } + // Get List in the cache first. + dbList, err := suite.db.GetListByID(ctx, testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } -// dbList, err := suite.db.GetListByID(ctx, testList.ID) -// if err != nil { -// suite.FailNow(err.Error()) -// } + // Now do the update. + testList.Title = "New Title!" + if err := suite.db.UpdateList(ctx, testList, "title"); err != nil { + suite.FailNow(err.Error()) + } -// // Bodge testlist as though default had been set. -// testList.RepliesPolicy = gtsmodel.RepliesPolicyFollowed -// suite.checkList(testList, dbList) -// } + // Cache should be invalidated + // + we should have updated list. + dbList, err = suite.db.GetListByID(ctx, testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } -// func (suite *ListTestSuite) TestUpdateList() { -// ctx := context.Background() -// testList, _ := suite.testStructs() + suite.checkList(testList, dbList) +} -// // Get List in the cache first. -// dbList, err := suite.db.GetListByID(ctx, testList.ID) -// if err != nil { -// suite.FailNow(err.Error()) -// } +func (suite *ListTestSuite) TestDeleteList() { + ctx := context.Background() + testList, _, _ := suite.testStructs() -// // Now do the update. -// testList.Title = "New Title!" -// if err := suite.db.UpdateList(ctx, testList, "title"); err != nil { -// suite.FailNow(err.Error()) -// } + // Get List in the cache first. + if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil { + suite.FailNow(err.Error()) + } -// // Cache should be invalidated -// // + we should have updated list. -// dbList, err = suite.db.GetListByID(ctx, testList.ID) -// if err != nil { -// suite.FailNow(err.Error()) -// } + // Now do the delete. + if err := suite.db.DeleteListByID(ctx, testList.ID); err != nil { + suite.FailNow(err.Error()) + } -// suite.checkList(testList, dbList) -// } + // Cache should be invalidated + // + we should have no list. + _, err := suite.db.GetListByID(ctx, testList.ID) + suite.ErrorIs(err, db.ErrNoEntries) -// func (suite *ListTestSuite) TestDeleteList() { -// ctx := context.Background() -// testList, _ := suite.testStructs() + // All accounts / follows attached to this + // list should now be return empty values. + listAccounts, err1 := suite.db.GetAccountsInList(ctx, testList.ID, nil) + listFollows, err2 := suite.db.GetFollowsInList(ctx, testList.ID, nil) + suite.NoError(err1) + suite.NoError(err2) + suite.Empty(listAccounts) + suite.Empty(listFollows) +} -// // Get List in the cache first. -// if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil { -// suite.FailNow(err.Error()) -// } +func (suite *ListTestSuite) TestPutListEntries() { + ctx := context.Background() + testList, testEntries, _ := suite.testStructs() -// // Now do the delete. -// if err := suite.db.DeleteListByID(ctx, testList.ID); err != nil { -// suite.FailNow(err.Error()) -// } + listEntries := []*gtsmodel.ListEntry{ + { + ID: "01H0MKMQY69HWDSDR2SWGA17R4", + ListID: testList.ID, + FollowID: "01H0MKNFRFZS8R9WV6DBX31Y03", // random id, doesn't exist + }, + { + ID: "01H0MKPGQF0E7QAVW5BKTHZ630", + ListID: testList.ID, + FollowID: "01H0MKP6RR8VEHN3GVWFBP2H30", // random id, doesn't exist + }, + { + ID: "01H0MKPPP2DT68FRBMR1FJM32T", + ListID: testList.ID, + FollowID: "01H0MKQ0KA29C6NFJ27GTZD16J", // random id, doesn't exist + }, + } -// // Cache should be invalidated -// // + we should have no list. -// _, err := suite.db.GetListByID(ctx, testList.ID) -// suite.ErrorIs(err, db.ErrNoEntries) + if err := suite.db.PutListEntries(ctx, listEntries); err != nil { + suite.FailNow(err.Error()) + } -// // All entries belonging to this -// // list should now be deleted. -// listEntries, err := suite.db.GetListEntries(ctx, testList.ID, "", "", "", 0) -// if err != nil { -// suite.FailNow(err.Error()) -// } -// suite.Empty(listEntries) -// } + // Add these entries to the test list. + testEntries = append(testEntries, listEntries...) -// func (suite *ListTestSuite) TestPutListEntries() { -// ctx := context.Background() -// testList, _ := suite.testStructs() + // Now get all list entries from the db. + // Use barebones for this because the ones + // we just added will fail if we try to get + // the nonexistent follows. + dbListEntries, err := suite.db.GetListEntries( + gtscontext.SetBarebones(ctx), + testList.ID, + "", "", "", 0) + if err != nil { + suite.FailNow(err.Error()) + } -// listEntries := []*gtsmodel.ListEntry{ -// { -// ID: "01H0MKMQY69HWDSDR2SWGA17R4", -// ListID: testList.ID, -// FollowID: "01H0MKNFRFZS8R9WV6DBX31Y03", // random id, doesn't exist -// }, -// { -// ID: "01H0MKPGQF0E7QAVW5BKTHZ630", -// ListID: testList.ID, -// FollowID: "01H0MKP6RR8VEHN3GVWFBP2H30", // random id, doesn't exist -// }, -// { -// ID: "01H0MKPPP2DT68FRBMR1FJM32T", -// ListID: testList.ID, -// FollowID: "01H0MKQ0KA29C6NFJ27GTZD16J", // random id, doesn't exist -// }, -// } + suite.checkListEntries(testList.ListEntries, dbListEntries) +} -// if err := suite.db.PutListEntries(ctx, listEntries); err != nil { -// suite.FailNow(err.Error()) -// } +func (suite *ListTestSuite) TestDeleteListEntry() { + ctx := context.Background() + testList, testEntries, _ := suite.testStructs() -// // Add these entries to the test list, sort it again -// // to reflect what we'd expect to get from the db. -// testList.ListEntries = append(testList.ListEntries, listEntries...) -// slices.SortFunc(testList.ListEntries, func(a, b *gtsmodel.ListEntry) int { -// const k = -1 -// switch { -// case a.ID > b.ID: -// return +k -// case a.ID < b.ID: -// return -k -// default: -// return 0 -// } -// }) + // Get List in the cache first. + if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil { + suite.FailNow(err.Error()) + } -// // Now get all list entries from the db. -// // Use barebones for this because the ones -// // we just added will fail if we try to get -// // the nonexistent follows. -// dbListEntries, err := suite.db.GetListEntries( -// gtscontext.SetBarebones(ctx), -// testList.ID, -// "", "", "", 0) -// if err != nil { -// suite.FailNow(err.Error()) -// } + // Delete the first entry. + if err := suite.db.DeleteListEntry(ctx, + testEntries[0].ListID, + testEntries[0].FollowID, + ); err != nil { + suite.FailNow(err.Error()) + } -// suite.checkListEntries(testList.ListEntries, dbListEntries) -// } + // Get list from the db again. + dbList, err := suite.db.GetListByID(ctx, testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } -// func (suite *ListTestSuite) TestDeleteListEntry() { -// ctx := context.Background() -// testList, _ := suite.testStructs() + // Bodge the testlist as though + // we'd removed the first entry. + testList.ListEntries = testList.ListEntries[1:] + suite.checkList(testList, dbList) +} -// // Get List in the cache first. -// if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil { -// suite.FailNow(err.Error()) -// } +func (suite *ListTestSuite) TestDeleteAllListEntriesByFollowID() { + ctx := context.Background() + testList, testEntries, _ := suite.testStructs() -// // Delete the first entry. -// if err := suite.db.DeleteListEntry(ctx, testList.ListEntries[0].ID); err != nil { -// suite.FailNow(err.Error()) -// } + // Get List in the cache first. + if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil { + suite.FailNow(err.Error()) + } -// // Get list from the db again. -// dbList, err := suite.db.GetListByID(ctx, testList.ID) -// if err != nil { -// suite.FailNow(err.Error()) -// } + // Delete the first entry. + if err := suite.db.DeleteAllListEntriesByFollowIDs(ctx, testEntries[0].FollowID); err != nil { + suite.FailNow(err.Error()) + } -// // Bodge the testlist as though -// // we'd removed the first entry. -// testList.ListEntries = testList.ListEntries[1:] -// suite.checkList(testList, dbList) -// } + // Get list from the db again. + dbList, err := suite.db.GetListByID(ctx, testList.ID) + if err != nil { + suite.FailNow(err.Error()) + } -// func (suite *ListTestSuite) TestDeleteListEntriesForFollowID() { -// ctx := context.Background() -// testList, _ := suite.testStructs() + // Bodge the testlist as though + // we'd removed the first entry. + testList.ListEntries = testList.ListEntries[1:] + suite.checkList(testList, dbList) +} -// // Get List in the cache first. -// if _, err := suite.db.GetListByID(ctx, testList.ID); err != nil { -// suite.FailNow(err.Error()) -// } +func (suite *ListTestSuite) TestListIncludesAccount() { + ctx := context.Background() + testList, _, _ := suite.testStructs() -// // Delete the first entry. -// if err := suite.db.DeleteListEntriesTargettingFollowID(ctx, testList.ListEntries[0].FollowID); err != nil { -// suite.FailNow(err.Error()) -// } + for accountID, expected := range map[string]bool{ + suite.testAccounts["admin_account"].ID: true, + suite.testAccounts["local_account_1"].ID: false, + suite.testAccounts["local_account_2"].ID: true, + "01H7074GEZJ56J5C86PFB0V2CT": false, + } { + includes, err := suite.db.IsAccountInList(ctx, testList.ID, accountID) + if err != nil { + suite.FailNow(err.Error()) + } -// // Get list from the db again. -// dbList, err := suite.db.GetListByID(ctx, testList.ID) -// if err != nil { -// suite.FailNow(err.Error()) -// } - -// // Bodge the testlist as though -// // we'd removed the first entry. -// testList.ListEntries = testList.ListEntries[1:] -// suite.checkList(testList, dbList) -// } - -// func (suite *ListTestSuite) TestListIncludesAccount() { -// ctx := context.Background() -// testList, _ := suite.testStructs() - -// for accountID, expected := range map[string]bool{ -// suite.testAccounts["admin_account"].ID: true, -// suite.testAccounts["local_account_1"].ID: false, -// suite.testAccounts["local_account_2"].ID: true, -// "01H7074GEZJ56J5C86PFB0V2CT": false, -// } { -// includes, err := suite.db.ListIncludesAccount(ctx, testList.ID, accountID) -// if err != nil { -// suite.FailNow(err.Error()) -// } - -// if includes != expected { -// suite.FailNow("", "expected %t for accountID %s got %t", expected, accountID, includes) -// } -// } -// } + if includes != expected { + suite.FailNow("", "expected %t for accountID %s got %t", expected, accountID, includes) + } + } +} func TestListTestSuite(t *testing.T) { suite.Run(t, new(ListTestSuite)) diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go index f7ed19bf9..c440962ba 100644 --- a/internal/db/bundb/relationship_follow.go +++ b/internal/db/bundb/relationship_follow.go @@ -20,7 +20,6 @@ package bundb import ( "context" "errors" - "fmt" "slices" "time" @@ -255,8 +254,8 @@ func (r *relationshipDB) deleteFollow(ctx context.Context, id string) error { } // Delete every list entry that used this followID. - if err := r.state.DB.DeleteListEntriesTargettingFollowID(ctx, id); err != nil { - return fmt.Errorf("deleteFollow: error deleting list entries: %w", err) + if err := r.state.DB.DeleteAllListEntriesByFollowIDs(ctx, id); err != nil { + return gtserror.Newf("error deleting list entries: %w", err) } return nil @@ -373,12 +372,6 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str return err } - for _, id := range followIDs { - // Finally, delete all list entries associated with each follow ID. - if err := r.state.DB.DeleteListEntriesTargettingFollowID(ctx, id); err != nil { - return err - } - } - - return nil + // Finally, delete all list entries associated with the follow IDs. + return r.state.DB.DeleteAllListEntriesByFollowIDs(ctx, followIDs...) } diff --git a/internal/db/list.go b/internal/db/list.go index 222f87808..71d9d09b2 100644 --- a/internal/db/list.go +++ b/internal/db/list.go @@ -53,7 +53,7 @@ type List interface { GetAccountsInList(ctx context.Context, listID string, page *paging.Page) ([]*gtsmodel.Account, error) // IsAccountInListID returns whether given account with ID is in the list with ID. - IsAccountInListID(ctx context.Context, listID string, accountID string) (bool, error) + IsAccountInList(ctx context.Context, listID string, accountID string) (bool, error) // PopulateList ensures that the list's struct fields are populated. PopulateList(ctx context.Context, list *gtsmodel.List) error @@ -75,6 +75,6 @@ type List interface { // DeleteListEntry deletes the list entry with given list ID and follow ID. DeleteListEntry(ctx context.Context, listID string, followID string) error - // DeleteListEntryForFollowID deletes all list entries with the given followID. - DeleteListEntriesTargettingFollowID(ctx context.Context, followID string) error + // DeleteAllListEntryByFollowID deletes all list entries with the given followIDs. + DeleteAllListEntriesByFollowIDs(ctx context.Context, followIDs ...string) error } diff --git a/internal/processing/workers/surfacetimeline.go b/internal/processing/workers/surfacetimeline.go index be1ffdd23..f9e380dcb 100644 --- a/internal/processing/workers/surfacetimeline.go +++ b/internal/processing/workers/surfacetimeline.go @@ -333,7 +333,7 @@ func (s *Surface) listEligible( // // Check if replied-to account is // also included in this list. - in, err := s.State.DB.IsAccountInListID(ctx, + in, err := s.State.DB.IsAccountInList(ctx, list.ID, status.InReplyToAccountID, )