mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-10-29 14:42:24 -05:00
[performance] refactoring + add fave / follow / request / visibility caching (#1607)
* refactor visibility checking, add caching for visibility
* invalidate visibility cache items on account / status deletes
* fix requester ID passed to visibility cache nil ptr
* de-interface caches, fix home / public timeline caching + visibility
* finish adding code comments for visibility filter
* fix angry goconst linter warnings
* actually finish adding filter visibility code comments for timeline functions
* move home timeline status author check to after visibility
* remove now-unused code
* add more code comments
* add TODO code comment, update printed cache start names
* update printed cache names on stop
* start adding separate follow(request) delete db functions, add specific visibility cache tests
* add relationship type caching
* fix getting local account follows / followed-bys, other small codebase improvements
* simplify invalidation using cache hooks, add more GetAccountBy___() functions
* fix boosting to return 404 if not boostable but no error (to not leak status ID)
* remove dead code
* improved placement of cache invalidation
* update license headers
* add example follow, follow-request config entries
* add example visibility cache configuration to config file
* use specific PutFollowRequest() instead of just Put()
* add tests for all GetAccountBy()
* add GetBlockBy() tests
* update block to check primitive fields
* update and finish adding Get{Account,Block,Follow,FollowRequest}By() tests
* fix copy-pasted code
* update envparsing test
* whitespace
* fix bun struct tag
* add license header to gtscontext
* fix old license header
* improved error creation to not use fmt.Errorf() when not needed
* fix various rebase conflicts, fix account test
* remove commented-out code, fix-up mention caching
* fix mention select bun statement
* ensure mention target account populated, pass in context to customrenderer logging
* remove more uncommented code, fix typeutil test
* add statusfave database model caching
* add status fave cache configuration
* add status fave cache example config
* woops, catch missed error. nice catch linter!
* add back testrig panic on nil db
* update example configuration to match defaults, slight tweak to cache configuration defaults
* update envparsing test with new defaults
* fetch followingget to use the follow target account
* use accounnt.IsLocal() instead of empty domain check
* use constants for the cache visibility type check
* use bun.In() for notification type restriction in db query
* include replies when fetching PublicTimeline() (to account for single-author threads in Visibility{}.StatusPublicTimelineable())
* use bun query building for nested select statements to ensure working with postgres
* update public timeline future status checks to match visibility filter
* same as previous, for home timeline
* update public timeline tests to dynamically check for appropriate statuses
* migrate accounts to allow unique constraint on public_key
* provide minimal account with publicKey
---------
Signed-off-by: kim <grufwub@gmail.com>
Co-authored-by: tsmethurst <tobi.smethurst@protonmail.com>
This commit is contained in:
parent
7d09863393
commit
de6e3e5f2a
100 changed files with 4423 additions and 2367 deletions
|
|
@ -41,6 +41,21 @@ type Account interface {
|
|||
// GetAccountByPubkeyID returns one account with the given public key URI (ID), or an error if something goes wrong.
|
||||
GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, Error)
|
||||
|
||||
// GetAccountByInboxURI returns one account with the given inbox_uri, or an error if something goes wrong.
|
||||
GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
|
||||
|
||||
// GetAccountByOutboxURI returns one account with the given outbox_uri, or an error if something goes wrong.
|
||||
GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
|
||||
|
||||
// GetAccountByFollowingURI returns one account with the given following_uri, or an error if something goes wrong.
|
||||
GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
|
||||
|
||||
// GetAccountByFollowersURI returns one account with the given followers_uri, or an error if something goes wrong.
|
||||
GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
|
||||
|
||||
// PopulateAccount ensures that all sub-models of an account are populated (e.g. avatar, header etc).
|
||||
PopulateAccount(ctx context.Context, account *gtsmodel.Account) error
|
||||
|
||||
// PutAccount puts one account in the database.
|
||||
PutAccount(ctx context.Context, account *gtsmodel.Account) Error
|
||||
|
||||
|
|
|
|||
|
|
@ -20,11 +20,13 @@ package bundb
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
|
|
@ -37,18 +39,15 @@ type accountDB struct {
|
|||
state *state.State
|
||||
}
|
||||
|
||||
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
|
||||
return a.conn.
|
||||
NewSelect().
|
||||
Model(account)
|
||||
}
|
||||
|
||||
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
|
||||
return a.getAccount(
|
||||
ctx,
|
||||
"ID",
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx)
|
||||
return a.conn.NewSelect().
|
||||
Model(account).
|
||||
Where("? = ?", bun.Ident("account.id"), id).
|
||||
Scan(ctx)
|
||||
},
|
||||
id,
|
||||
)
|
||||
|
|
@ -59,7 +58,10 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
|
|||
ctx,
|
||||
"URI",
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx)
|
||||
return a.conn.NewSelect().
|
||||
Model(account).
|
||||
Where("? = ?", bun.Ident("account.uri"), uri).
|
||||
Scan(ctx)
|
||||
},
|
||||
uri,
|
||||
)
|
||||
|
|
@ -70,7 +72,10 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
|
|||
ctx,
|
||||
"URL",
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx)
|
||||
return a.conn.NewSelect().
|
||||
Model(account).
|
||||
Where("? = ?", bun.Ident("account.url"), url).
|
||||
Scan(ctx)
|
||||
},
|
||||
url,
|
||||
)
|
||||
|
|
@ -81,7 +86,8 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
|
|||
ctx,
|
||||
"Username.Domain",
|
||||
func(account *gtsmodel.Account) error {
|
||||
q := a.newAccountQ(account)
|
||||
q := a.conn.NewSelect().
|
||||
Model(account)
|
||||
|
||||
if domain != "" {
|
||||
q = q.
|
||||
|
|
@ -105,12 +111,71 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo
|
|||
ctx,
|
||||
"PublicKeyURI",
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx)
|
||||
return a.conn.NewSelect().
|
||||
Model(account).
|
||||
Where("? = ?", bun.Ident("account.public_key_uri"), id).
|
||||
Scan(ctx)
|
||||
},
|
||||
id,
|
||||
)
|
||||
}
|
||||
|
||||
func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
|
||||
return a.getAccount(
|
||||
ctx,
|
||||
"InboxURI",
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.conn.NewSelect().
|
||||
Model(account).
|
||||
Where("? = ?", bun.Ident("account.inbox_uri"), uri).
|
||||
Scan(ctx)
|
||||
},
|
||||
uri,
|
||||
)
|
||||
}
|
||||
|
||||
func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
|
||||
return a.getAccount(
|
||||
ctx,
|
||||
"OutboxURI",
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.conn.NewSelect().
|
||||
Model(account).
|
||||
Where("? = ?", bun.Ident("account.outbox_uri"), uri).
|
||||
Scan(ctx)
|
||||
},
|
||||
uri,
|
||||
)
|
||||
}
|
||||
|
||||
func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
|
||||
return a.getAccount(
|
||||
ctx,
|
||||
"FollowersURI",
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.conn.NewSelect().
|
||||
Model(account).
|
||||
Where("? = ?", bun.Ident("account.followers_uri"), uri).
|
||||
Scan(ctx)
|
||||
},
|
||||
uri,
|
||||
)
|
||||
}
|
||||
|
||||
func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
|
||||
return a.getAccount(
|
||||
ctx,
|
||||
"FollowingURI",
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.conn.NewSelect().
|
||||
Model(account).
|
||||
Where("? = ?", bun.Ident("account.following_uri"), uri).
|
||||
Scan(ctx)
|
||||
},
|
||||
uri,
|
||||
)
|
||||
}
|
||||
|
||||
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
|
||||
var username string
|
||||
|
||||
|
|
@ -141,33 +206,58 @@ func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if account.AvatarMediaAttachmentID != "" {
|
||||
// Set the account's related avatar
|
||||
account.AvatarMediaAttachment, err = a.state.DB.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting account %s avatar: %v", account.ID, err)
|
||||
}
|
||||
if gtscontext.Barebones(ctx) {
|
||||
// no need to fully populate.
|
||||
return account, nil
|
||||
}
|
||||
|
||||
if account.HeaderMediaAttachmentID != "" {
|
||||
// Set the account's related header
|
||||
account.HeaderMediaAttachment, err = a.state.DB.GetAttachmentByID(ctx, account.HeaderMediaAttachmentID)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting account %s header: %v", account.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(account.EmojiIDs) > 0 {
|
||||
// Set the account's related emojis
|
||||
account.Emojis, err = a.state.DB.GetEmojisByIDs(ctx, account.EmojiIDs)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting account %s emojis: %v", account.ID, err)
|
||||
}
|
||||
// Further populate the account fields where applicable.
|
||||
if err := a.PopulateAccount(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Account) error {
|
||||
var err error
|
||||
|
||||
if account.AvatarMediaAttachment == nil && account.AvatarMediaAttachmentID != "" {
|
||||
// Account avatar attachment is not set, fetch from database.
|
||||
account.AvatarMediaAttachment, err = a.state.DB.GetAttachmentByID(
|
||||
ctx, // these are already barebones
|
||||
account.AvatarMediaAttachmentID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error populating account avatar: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if account.HeaderMediaAttachment == nil && account.HeaderMediaAttachmentID != "" {
|
||||
// Account header attachment is not set, fetch from database.
|
||||
account.HeaderMediaAttachment, err = a.state.DB.GetAttachmentByID(
|
||||
ctx, // these are already barebones
|
||||
account.HeaderMediaAttachmentID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error populating account header: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !account.EmojisPopulated() {
|
||||
// Account emojis are out-of-date with IDs, repopulate.
|
||||
account.Emojis, err = a.state.DB.GetEmojisByIDs(
|
||||
ctx, // these are already barebones
|
||||
account.EmojiIDs,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error populating account emojis: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error {
|
||||
return a.state.Caches.GTS.Account().Store(account, func() error {
|
||||
// It is safe to run this database transaction within cache.Store
|
||||
|
|
@ -198,7 +288,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
|
|||
columns = append(columns, "updated_at")
|
||||
}
|
||||
|
||||
return a.state.Caches.GTS.Account().Store(account, func() error {
|
||||
err := a.state.Caches.GTS.Account().Store(account, func() error {
|
||||
// It is safe to run this database transaction within cache.Store
|
||||
// as the cache does not attempt a mutex lock until AFTER hook.
|
||||
//
|
||||
|
|
@ -234,6 +324,11 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
|
|||
return err
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
|
||||
|
|
@ -258,7 +353,9 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
|
|||
return err
|
||||
}
|
||||
|
||||
// Invalidate account from database lookups.
|
||||
a.state.Caches.GTS.Account().Invalidate("ID", id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,8 @@ import (
|
|||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -61,44 +63,149 @@ func (suite *AccountTestSuite) TestGetAccountStatusesMediaOnly() {
|
|||
suite.Len(statuses, 1)
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
|
||||
account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID)
|
||||
if err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
func (suite *AccountTestSuite) TestGetAccountBy() {
|
||||
t := suite.T()
|
||||
|
||||
// Create a new context for this test.
|
||||
ctx, cncl := context.WithCancel(context.Background())
|
||||
defer cncl()
|
||||
|
||||
// Sentinel error to mark avoiding a test case.
|
||||
sentinelErr := errors.New("sentinel")
|
||||
|
||||
// isEqual checks if 2 account models are equal.
|
||||
isEqual := func(a1, a2 gtsmodel.Account) bool {
|
||||
// Clear populated sub-models.
|
||||
a1.HeaderMediaAttachment = nil
|
||||
a2.HeaderMediaAttachment = nil
|
||||
a1.AvatarMediaAttachment = nil
|
||||
a2.AvatarMediaAttachment = nil
|
||||
a1.Emojis = nil
|
||||
a2.Emojis = nil
|
||||
|
||||
// Clear database-set fields.
|
||||
a1.CreatedAt = time.Time{}
|
||||
a2.CreatedAt = time.Time{}
|
||||
a1.UpdatedAt = time.Time{}
|
||||
a2.UpdatedAt = time.Time{}
|
||||
|
||||
// Manually compare keys.
|
||||
pk1 := a1.PublicKey
|
||||
pv1 := a1.PrivateKey
|
||||
pk2 := a2.PublicKey
|
||||
pv2 := a2.PrivateKey
|
||||
a1.PublicKey = nil
|
||||
a1.PrivateKey = nil
|
||||
a2.PublicKey = nil
|
||||
a2.PrivateKey = nil
|
||||
|
||||
return reflect.DeepEqual(a1, a2) &&
|
||||
((pk1 == nil && pk2 == nil) || pk1.Equal(pk2)) &&
|
||||
((pv1 == nil && pv2 == nil) || pv1.Equal(pv2))
|
||||
}
|
||||
suite.NotNil(account)
|
||||
suite.NotNil(account.AvatarMediaAttachment)
|
||||
suite.NotEmpty(account.AvatarMediaAttachment.URL)
|
||||
suite.NotNil(account.HeaderMediaAttachment)
|
||||
suite.NotEmpty(account.HeaderMediaAttachment.URL)
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() {
|
||||
testAccount1 := suite.testAccounts["local_account_1"]
|
||||
account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount1.Username, testAccount1.Domain)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(account1)
|
||||
for _, account := range suite.testAccounts {
|
||||
for lookup, dbfunc := range map[string]func() (*gtsmodel.Account, error){
|
||||
"id": func() (*gtsmodel.Account, error) {
|
||||
return suite.db.GetAccountByID(ctx, account.ID)
|
||||
},
|
||||
|
||||
testAccount2 := suite.testAccounts["remote_account_1"]
|
||||
account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount2.Username, testAccount2.Domain)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(account2)
|
||||
}
|
||||
"uri": func() (*gtsmodel.Account, error) {
|
||||
return suite.db.GetAccountByURI(ctx, account.URI)
|
||||
},
|
||||
|
||||
func (suite *AccountTestSuite) TestGetAccountByUsernameDomainMixedCase() {
|
||||
testAccount := suite.testAccounts["remote_account_2"]
|
||||
"url": func() (*gtsmodel.Account, error) {
|
||||
if account.URL == "" {
|
||||
return nil, sentinelErr
|
||||
}
|
||||
return suite.db.GetAccountByURL(ctx, account.URL)
|
||||
},
|
||||
|
||||
account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount.Username, testAccount.Domain)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(account1)
|
||||
"username@domain": func() (*gtsmodel.Account, error) {
|
||||
return suite.db.GetAccountByUsernameDomain(ctx, account.Username, account.Domain)
|
||||
},
|
||||
|
||||
account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), strings.ToUpper(testAccount.Username), testAccount.Domain)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(account2)
|
||||
"username_upper@domain": func() (*gtsmodel.Account, error) {
|
||||
return suite.db.GetAccountByUsernameDomain(ctx, strings.ToUpper(account.Username), account.Domain)
|
||||
},
|
||||
|
||||
account3, err := suite.db.GetAccountByUsernameDomain(context.Background(), strings.ToLower(testAccount.Username), testAccount.Domain)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(account3)
|
||||
"username_lower@domain": func() (*gtsmodel.Account, error) {
|
||||
return suite.db.GetAccountByUsernameDomain(ctx, strings.ToLower(account.Username), account.Domain)
|
||||
},
|
||||
|
||||
"public_key_uri": func() (*gtsmodel.Account, error) {
|
||||
if account.PublicKeyURI == "" {
|
||||
return nil, sentinelErr
|
||||
}
|
||||
return suite.db.GetAccountByPubkeyID(ctx, account.PublicKeyURI)
|
||||
},
|
||||
|
||||
"inbox_uri": func() (*gtsmodel.Account, error) {
|
||||
if account.InboxURI == "" {
|
||||
return nil, sentinelErr
|
||||
}
|
||||
return suite.db.GetAccountByInboxURI(ctx, account.InboxURI)
|
||||
},
|
||||
|
||||
"outbox_uri": func() (*gtsmodel.Account, error) {
|
||||
if account.OutboxURI == "" {
|
||||
return nil, sentinelErr
|
||||
}
|
||||
return suite.db.GetAccountByOutboxURI(ctx, account.OutboxURI)
|
||||
},
|
||||
|
||||
"following_uri": func() (*gtsmodel.Account, error) {
|
||||
if account.FollowingURI == "" {
|
||||
return nil, sentinelErr
|
||||
}
|
||||
return suite.db.GetAccountByFollowingURI(ctx, account.FollowingURI)
|
||||
},
|
||||
|
||||
"followers_uri": func() (*gtsmodel.Account, error) {
|
||||
if account.FollowersURI == "" {
|
||||
return nil, sentinelErr
|
||||
}
|
||||
return suite.db.GetAccountByFollowersURI(ctx, account.FollowersURI)
|
||||
},
|
||||
} {
|
||||
|
||||
// Clear database caches.
|
||||
suite.state.Caches.Init()
|
||||
|
||||
t.Logf("checking database lookup %q", lookup)
|
||||
|
||||
// Perform database function.
|
||||
checkAcc, err := dbfunc()
|
||||
if err != nil {
|
||||
if err == sentinelErr {
|
||||
continue
|
||||
}
|
||||
|
||||
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check received account data.
|
||||
if !isEqual(*checkAcc, *account) {
|
||||
t.Errorf("account does not contain expected data: %+v", checkAcc)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check that avatar attachment populated.
|
||||
if account.AvatarMediaAttachmentID != "" &&
|
||||
(checkAcc.AvatarMediaAttachment == nil || checkAcc.AvatarMediaAttachment.ID != account.AvatarMediaAttachmentID) {
|
||||
t.Errorf("account avatar media attachment not correctly populated for: %+v", account)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check that header attachment populated.
|
||||
if account.HeaderMediaAttachmentID != "" &&
|
||||
(checkAcc.HeaderMediaAttachment == nil || checkAcc.HeaderMediaAttachment.ID != account.HeaderMediaAttachmentID) {
|
||||
t.Errorf("account header media attachment not correctly populated for: %+v", account)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestUpdateAccount() {
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ package bundb_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -40,6 +42,12 @@ func (suite *BasicTestSuite) TestGetAccountByID() {
|
|||
}
|
||||
|
||||
func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
// Create an account that only just matches constraints.
|
||||
testAccount := >smodel.Account{
|
||||
ID: "01GADR1AH9VCKH8YYCM86XSZ00",
|
||||
Username: "test",
|
||||
|
|
@ -49,6 +57,7 @@ func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() {
|
|||
OutboxURI: "https://example.org/users/test/outbox",
|
||||
ActorType: "Person",
|
||||
PublicKeyURI: "https://example.org/test#main-key",
|
||||
PublicKey: &key.PublicKey,
|
||||
}
|
||||
|
||||
if err := suite.db.Put(context.Background(), testAccount); err != nil {
|
||||
|
|
@ -99,7 +108,7 @@ func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() {
|
|||
suite.Empty(a.FeaturedCollectionURI)
|
||||
suite.Equal(testAccount.ActorType, a.ActorType)
|
||||
suite.Nil(a.PrivateKey)
|
||||
suite.Nil(a.PublicKey)
|
||||
suite.EqualValues(key.PublicKey, *a.PublicKey)
|
||||
suite.Equal(testAccount.PublicKeyURI, a.PublicKeyURI)
|
||||
suite.Zero(a.SensitizedAt)
|
||||
suite.Zero(a.SilencedAt)
|
||||
|
|
|
|||
|
|
@ -47,6 +47,24 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
|
|||
)
|
||||
}
|
||||
|
||||
func (m *mediaDB) GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error) {
|
||||
attachments := make([]*gtsmodel.MediaAttachment, 0, len(ids))
|
||||
|
||||
for _, id := range ids {
|
||||
// Attempt fetch from DB
|
||||
attachment, err := m.GetAttachmentByID(ctx, id)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting attachment %q: %v", id, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Append attachment
|
||||
attachments = append(attachments, attachment)
|
||||
}
|
||||
|
||||
return attachments, nil
|
||||
}
|
||||
|
||||
func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, db.Error) {
|
||||
return m.state.Caches.GTS.Media().Load(lookup, func() (*gtsmodel.MediaAttachment, error) {
|
||||
var attachment gtsmodel.MediaAttachment
|
||||
|
|
@ -118,7 +136,7 @@ func (m *mediaDB) GetRemoteOlderThan(ctx context.Context, olderThan time.Time, l
|
|||
return nil, m.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return m.getAttachments(ctx, attachmentIDs)
|
||||
return m.GetAttachmentsByIDs(ctx, attachmentIDs)
|
||||
}
|
||||
|
||||
func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) {
|
||||
|
|
@ -163,7 +181,7 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit
|
|||
return nil, m.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return m.getAttachments(ctx, attachmentIDs)
|
||||
return m.GetAttachmentsByIDs(ctx, attachmentIDs)
|
||||
}
|
||||
|
||||
func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, db.Error) {
|
||||
|
|
@ -189,7 +207,7 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim
|
|||
return nil, m.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return m.getAttachments(ctx, attachmentIDs)
|
||||
return m.GetAttachmentsByIDs(ctx, attachmentIDs)
|
||||
}
|
||||
|
||||
func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) {
|
||||
|
|
@ -211,21 +229,3 @@ func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan t
|
|||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (m *mediaDB) getAttachments(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, db.Error) {
|
||||
attachments := make([]*gtsmodel.MediaAttachment, 0, len(ids))
|
||||
|
||||
for _, id := range ids {
|
||||
// Attempt fetch from DB
|
||||
attachment, err := m.GetAttachmentByID(ctx, id)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting attachment %q: %v", id, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Append attachment
|
||||
attachments = append(attachments, attachment)
|
||||
}
|
||||
|
||||
return attachments, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,8 +19,10 @@ package bundb
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
|
|
@ -32,20 +34,13 @@ type mentionDB struct {
|
|||
state *state.State
|
||||
}
|
||||
|
||||
func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
|
||||
return m.conn.
|
||||
NewSelect().
|
||||
Model(i).
|
||||
Relation("Status").
|
||||
Relation("OriginAccount").
|
||||
Relation("TargetAccount")
|
||||
}
|
||||
|
||||
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
|
||||
return m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) {
|
||||
mention, err := m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) {
|
||||
var mention gtsmodel.Mention
|
||||
|
||||
q := m.newMentionQ(&mention).
|
||||
q := m.conn.
|
||||
NewSelect().
|
||||
Model(&mention).
|
||||
Where("? = ?", bun.Ident("mention.id"), id)
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
|
|
@ -54,6 +49,38 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio
|
|||
|
||||
return &mention, nil
|
||||
}, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set the mention originating status.
|
||||
mention.Status, err = m.state.DB.GetStatusByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
mention.StatusID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating mention status: %w", err)
|
||||
}
|
||||
|
||||
// Set the mention origin account model.
|
||||
mention.OriginAccount, err = m.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
mention.OriginAccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating mention origin account: %w", err)
|
||||
}
|
||||
|
||||
// Set the mention target account model.
|
||||
mention.TargetAccount, err = m.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
mention.TargetAccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating mention target account: %w", err)
|
||||
}
|
||||
|
||||
return mention, nil
|
||||
}
|
||||
|
||||
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) {
|
||||
|
|
@ -73,3 +100,25 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.
|
|||
|
||||
return mentions, nil
|
||||
}
|
||||
|
||||
func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error {
|
||||
return m.state.Caches.GTS.Mention().Store(mention, func() error {
|
||||
_, err := m.conn.NewInsert().Model(mention).Exec(ctx)
|
||||
return m.conn.ProcessError(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *mentionDB) DeleteMentionByID(ctx context.Context, id string) error {
|
||||
if _, err := m.conn.
|
||||
NewDelete().
|
||||
Table("mentions").
|
||||
Where("? = ?", bun.Ident("id"), id).
|
||||
Exec(ctx); err != nil {
|
||||
return m.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Invalidate mention from the lookup cache.
|
||||
m.state.Caches.GTS.Mention().Invalidate("ID", id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
167
internal/db/bundb/migrations/20230328105630_chore_refactoring.go
Normal file
167
internal/db/bundb/migrations/20230328105630_chore_refactoring.go
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func init() {
|
||||
up := func(ctx context.Context, db *bun.DB) error {
|
||||
// To update unique constraint on public key, we need to migrate accounts into a new table.
|
||||
// See section 7 here: https://www.sqlite.org/lang_altertable.html
|
||||
|
||||
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
|
||||
// Create the new accounts table.
|
||||
if _, err := tx.
|
||||
NewCreateTable().
|
||||
ModelTableExpr("new_accounts").
|
||||
Model(>smodel.Account{}).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If we don't specify columns explicitly,
|
||||
// Postgres gives the following error when
|
||||
// transferring accounts to new_accounts:
|
||||
//
|
||||
// ERROR: column "fetched_at" is of type timestamp with time zone but expression is of type character varying at character 35
|
||||
// HINT: You will need to rewrite or cast the expression.
|
||||
//
|
||||
// Rather than do funky casting to fix this,
|
||||
// it's simpler to just specify all columns.
|
||||
columns := []string{
|
||||
"id",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"fetched_at",
|
||||
"username",
|
||||
"domain",
|
||||
"avatar_media_attachment_id",
|
||||
"avatar_remote_url",
|
||||
"header_media_attachment_id",
|
||||
"header_remote_url",
|
||||
"display_name",
|
||||
"emojis",
|
||||
"fields",
|
||||
"note",
|
||||
"note_raw",
|
||||
"memorial",
|
||||
"also_known_as",
|
||||
"moved_to_account_id",
|
||||
"bot",
|
||||
"reason",
|
||||
"locked",
|
||||
"discoverable",
|
||||
"privacy",
|
||||
"sensitive",
|
||||
"language",
|
||||
"status_content_type",
|
||||
"custom_css",
|
||||
"uri",
|
||||
"url",
|
||||
"inbox_uri",
|
||||
"shared_inbox_uri",
|
||||
"outbox_uri",
|
||||
"following_uri",
|
||||
"followers_uri",
|
||||
"featured_collection_uri",
|
||||
"actor_type",
|
||||
"private_key",
|
||||
"public_key",
|
||||
"public_key_uri",
|
||||
"sensitized_at",
|
||||
"silenced_at",
|
||||
"suspended_at",
|
||||
"hide_collections",
|
||||
"suspension_origin",
|
||||
"enable_rss",
|
||||
}
|
||||
|
||||
// Copy all accounts to the new table.
|
||||
if _, err := tx.
|
||||
NewInsert().
|
||||
Table("new_accounts").
|
||||
Table("accounts").
|
||||
Column(columns...).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Drop the old table.
|
||||
if _, err := tx.
|
||||
NewDropTable().
|
||||
Table("accounts").
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Rename new table to old table.
|
||||
if _, err := tx.
|
||||
ExecContext(
|
||||
ctx,
|
||||
"ALTER TABLE ? RENAME TO ?",
|
||||
bun.Ident("new_accounts"),
|
||||
bun.Ident("accounts"),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add all account indexes to the new table.
|
||||
for index, columns := range map[string][]string{
|
||||
// Standard indices.
|
||||
"accounts_id_idx": {"id"},
|
||||
"accounts_suspended_at_idx": {"suspended_at"},
|
||||
"accounts_domain_idx": {"domain"},
|
||||
"accounts_username_domain_idx": {"username", "domain"},
|
||||
// URI indices.
|
||||
"accounts_uri_idx": {"uri"},
|
||||
"accounts_url_idx": {"url"},
|
||||
"accounts_inbox_uri_idx": {"inbox_uri"},
|
||||
"accounts_outbox_uri_idx": {"outbox_uri"},
|
||||
"accounts_followers_uri_idx": {"followers_uri"},
|
||||
"accounts_following_uri_idx": {"following_uri"},
|
||||
"accounts_public_key_uri_idx": {"public_key_uri"},
|
||||
} {
|
||||
if _, err := tx.
|
||||
NewCreateIndex().
|
||||
Table("accounts").
|
||||
Index(index).
|
||||
Column(columns...).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
down := func(ctx context.Context, db *bun.DB) error {
|
||||
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := Migrations.Register(up, down); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
|
@ -33,7 +33,7 @@ type notificationDB struct {
|
|||
state *state.State
|
||||
}
|
||||
|
||||
func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
|
||||
func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
|
||||
return n.state.Caches.GTS.Notification().Load("ID", func() (*gtsmodel.Notification, error) {
|
||||
var notif gtsmodel.Notification
|
||||
|
||||
|
|
@ -48,7 +48,7 @@ func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmo
|
|||
}, id)
|
||||
}
|
||||
|
||||
func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
|
||||
func (n *notificationDB) GetAccountNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
|
||||
// Ensure reasonable
|
||||
if limit < 0 {
|
||||
limit = 0
|
||||
|
|
@ -92,7 +92,7 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
|
|||
// reason for this is that for each notif, we can instead get it from our cache if it's cached
|
||||
for _, id := range notifIDs {
|
||||
// Attempt fetch from DB
|
||||
notif, err := n.GetNotification(ctx, id)
|
||||
notif, err := n.GetNotificationByID(ctx, id)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting notification %q: %v", id, err)
|
||||
continue
|
||||
|
|
@ -105,7 +105,14 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
|
|||
return notifs, nil
|
||||
}
|
||||
|
||||
func (n *notificationDB) DeleteNotification(ctx context.Context, id string) db.Error {
|
||||
func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error {
|
||||
return n.state.Caches.GTS.Notification().Store(notif, func() error {
|
||||
_, err := n.conn.NewInsert().Model(notif).Exec(ctx)
|
||||
return n.conn.ProcessError(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) db.Error {
|
||||
if _, err := n.conn.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
|
||||
|
|
@ -118,19 +125,23 @@ func (n *notificationDB) DeleteNotification(ctx context.Context, id string) db.E
|
|||
return nil
|
||||
}
|
||||
|
||||
func (n *notificationDB) DeleteNotifications(ctx context.Context, targetAccountID string, originAccountID string) db.Error {
|
||||
func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) db.Error {
|
||||
if targetAccountID == "" && originAccountID == "" {
|
||||
return errors.New("DeleteNotifications: one of targetAccountID or originAccountID must be set")
|
||||
}
|
||||
|
||||
// Capture notification IDs in a RETURNING statement.
|
||||
ids := []string{}
|
||||
var ids []string
|
||||
|
||||
q := n.conn.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
|
||||
Returning("?", bun.Ident("id"))
|
||||
|
||||
if len(types) > 0 {
|
||||
q = q.Where("? IN (?)", bun.Ident("notification.notification_type"), bun.In(types))
|
||||
}
|
||||
|
||||
if targetAccountID != "" {
|
||||
q = q.Where("? = ?", bun.Ident("notification.target_account_id"), targetAccountID)
|
||||
}
|
||||
|
|
@ -153,7 +164,7 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, targetAccountI
|
|||
|
||||
func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) db.Error {
|
||||
// Capture notification IDs in a RETURNING statement.
|
||||
ids := []string{}
|
||||
var ids []string
|
||||
|
||||
q := n.conn.
|
||||
NewDelete().
|
||||
|
|
|
|||
|
|
@ -85,11 +85,11 @@ type NotificationTestSuite struct {
|
|||
BunDBStandardTestSuite
|
||||
}
|
||||
|
||||
func (suite *NotificationTestSuite) TestGetNotificationsWithSpam() {
|
||||
func (suite *NotificationTestSuite) TestGetAccountNotificationsWithSpam() {
|
||||
suite.spamNotifs()
|
||||
testAccount := suite.testAccounts["local_account_1"]
|
||||
before := time.Now()
|
||||
notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
|
||||
notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
|
||||
suite.NoError(err)
|
||||
timeTaken := time.Since(before)
|
||||
fmt.Printf("\n\n\n withSpam: got %d notifications in %s\n\n\n", len(notifications), timeTaken)
|
||||
|
|
@ -100,10 +100,10 @@ func (suite *NotificationTestSuite) TestGetNotificationsWithSpam() {
|
|||
}
|
||||
}
|
||||
|
||||
func (suite *NotificationTestSuite) TestGetNotificationsWithoutSpam() {
|
||||
func (suite *NotificationTestSuite) TestGetAccountNotificationsWithoutSpam() {
|
||||
testAccount := suite.testAccounts["local_account_1"]
|
||||
before := time.Now()
|
||||
notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
|
||||
notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
|
||||
suite.NoError(err)
|
||||
timeTaken := time.Since(before)
|
||||
fmt.Printf("\n\n\n withoutSpam: got %d notifications in %s\n\n\n", len(notifications), timeTaken)
|
||||
|
|
@ -117,10 +117,10 @@ func (suite *NotificationTestSuite) TestGetNotificationsWithoutSpam() {
|
|||
func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() {
|
||||
suite.spamNotifs()
|
||||
testAccount := suite.testAccounts["local_account_1"]
|
||||
err := suite.db.DeleteNotifications(context.Background(), testAccount.ID, "")
|
||||
err := suite.db.DeleteNotifications(context.Background(), nil, testAccount.ID, "")
|
||||
suite.NoError(err)
|
||||
|
||||
notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
|
||||
notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(notifications)
|
||||
suite.Empty(notifications)
|
||||
|
|
@ -129,10 +129,10 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() {
|
|||
func (suite *NotificationTestSuite) TestDeleteNotificationsWithTwoAccounts() {
|
||||
suite.spamNotifs()
|
||||
testAccount := suite.testAccounts["local_account_1"]
|
||||
err := suite.db.DeleteNotifications(context.Background(), testAccount.ID, "")
|
||||
err := suite.db.DeleteNotifications(context.Background(), nil, testAccount.ID, "")
|
||||
suite.NoError(err)
|
||||
|
||||
notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
|
||||
notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(notifications)
|
||||
suite.Empty(notifications)
|
||||
|
|
@ -146,7 +146,7 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithTwoAccounts() {
|
|||
func (suite *NotificationTestSuite) TestDeleteNotificationsOriginatingFromAccount() {
|
||||
testAccount := suite.testAccounts["local_account_2"]
|
||||
|
||||
if err := suite.db.DeleteNotifications(context.Background(), "", testAccount.ID); err != nil {
|
||||
if err := suite.db.DeleteNotifications(context.Background(), nil, "", testAccount.ID); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
|
|
@ -166,7 +166,7 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsOriginatingFromAndTar
|
|||
originAccount := suite.testAccounts["local_account_2"]
|
||||
targetAccount := suite.testAccounts["admin_account"]
|
||||
|
||||
if err := suite.db.DeleteNotifications(context.Background(), targetAccount.ID, originAccount.ID); err != nil {
|
||||
if err := suite.db.DeleteNotifications(context.Background(), nil, targetAccount.ID, originAccount.ID); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,8 +23,8 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
|
@ -34,603 +34,212 @@ type relationshipDB struct {
|
|||
state *state.State
|
||||
}
|
||||
|
||||
func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
|
||||
// Look for a block in direction of account1->account2
|
||||
block1, err := r.getBlock(ctx, account1, account2)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if block1 != nil {
|
||||
// account1 blocks account2
|
||||
return true, nil
|
||||
} else if !eitherDirection {
|
||||
// Don't check for mutli-directional
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Look for a block in direction of account2->account1
|
||||
block2, err := r.getBlock(ctx, account2, account1)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return (block2 != nil), nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
|
||||
// Fetch block from database
|
||||
block, err := r.getBlock(ctx, account1, account2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set the block originating account
|
||||
block.Account, err = r.state.DB.GetAccountByID(ctx, block.AccountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set the block target account
|
||||
block.TargetAccount, err = r.state.DB.GetAccountByID(ctx, block.TargetAccountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return block, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) getBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
|
||||
return r.state.Caches.GTS.Block().Load("AccountID.TargetAccountID", func() (*gtsmodel.Block, error) {
|
||||
var block gtsmodel.Block
|
||||
|
||||
q := r.conn.NewSelect().Model(&block).
|
||||
Where("? = ?", bun.Ident("block.account_id"), account1).
|
||||
Where("? = ?", bun.Ident("block.target_account_id"), account2)
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return &block, nil
|
||||
}, account1, account2)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) db.Error {
|
||||
return r.state.Caches.GTS.Block().Store(block, func() error {
|
||||
_, err := r.conn.NewInsert().Model(block).Exec(ctx)
|
||||
return r.conn.ProcessError(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) db.Error {
|
||||
if _, err := r.conn.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
|
||||
Where("? = ?", bun.Ident("block.id"), id).
|
||||
Exec(ctx); err != nil {
|
||||
return r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Drop any old value from cache by this ID
|
||||
r.state.Caches.GTS.Block().Invalidate("ID", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) db.Error {
|
||||
if _, err := r.conn.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
|
||||
Where("? = ?", bun.Ident("block.uri"), uri).
|
||||
Exec(ctx); err != nil {
|
||||
return r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Drop any old value from cache by this URI
|
||||
r.state.Caches.GTS.Block().Invalidate("URI", uri)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) DeleteBlocksByOriginAccountID(ctx context.Context, originAccountID string) db.Error {
|
||||
blockIDs := []string{}
|
||||
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
|
||||
Column("block.id").
|
||||
Where("? = ?", bun.Ident("block.account_id"), originAccountID)
|
||||
|
||||
if err := q.Scan(ctx, &blockIDs); err != nil {
|
||||
return r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
for _, blockID := range blockIDs {
|
||||
if err := r.DeleteBlockByID(ctx, blockID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) DeleteBlocksByTargetAccountID(ctx context.Context, targetAccountID string) db.Error {
|
||||
blockIDs := []string{}
|
||||
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
|
||||
Column("block.id").
|
||||
Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID)
|
||||
|
||||
if err := q.Scan(ctx, &blockIDs); err != nil {
|
||||
return r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
for _, blockID := range blockIDs {
|
||||
if err := r.DeleteBlockByID(ctx, blockID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
|
||||
rel := >smodel.Relationship{
|
||||
ID: targetAccount,
|
||||
var rel gtsmodel.Relationship
|
||||
rel.ID = targetAccount
|
||||
|
||||
// check if the requesting follows the target
|
||||
follow, err := r.GetFollow(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
requestingAccount,
|
||||
targetAccount,
|
||||
)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return nil, fmt.Errorf("GetRelationship: error fetching follow: %w", err)
|
||||
}
|
||||
|
||||
// check if the requesting account follows the target account
|
||||
follow := >smodel.Follow{}
|
||||
if err := r.conn.
|
||||
NewSelect().
|
||||
Model(follow).
|
||||
Column("follow.show_reblogs", "follow.notify").
|
||||
Where("? = ?", bun.Ident("follow.account_id"), requestingAccount).
|
||||
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount).
|
||||
Limit(1).
|
||||
Scan(ctx); err != nil {
|
||||
if err := r.conn.ProcessError(err); err != db.ErrNoEntries {
|
||||
return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err)
|
||||
}
|
||||
// no follow exists so these are all false
|
||||
rel.Following = false
|
||||
rel.ShowingReblogs = false
|
||||
rel.Notifying = false
|
||||
} else {
|
||||
if follow != nil {
|
||||
// follow exists so we can fill these fields out...
|
||||
rel.Following = true
|
||||
rel.ShowingReblogs = *follow.ShowReblogs
|
||||
rel.Notifying = *follow.Notify
|
||||
}
|
||||
|
||||
// check if the target account follows the requesting account
|
||||
followedByQ := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
|
||||
Column("follow.id").
|
||||
Where("? = ?", bun.Ident("follow.account_id"), targetAccount).
|
||||
Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount)
|
||||
followedBy, err := r.conn.Exists(ctx, followedByQ)
|
||||
// check if the target follows the requesting
|
||||
rel.FollowedBy, err = r.IsFollowing(ctx,
|
||||
targetAccount,
|
||||
requestingAccount,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err)
|
||||
return nil, fmt.Errorf("GetRelationship: error checking followedBy: %w", err)
|
||||
}
|
||||
rel.FollowedBy = followedBy
|
||||
|
||||
// check if there's a pending following request from requesting account to target account
|
||||
requestedQ := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
|
||||
Column("follow_request.id").
|
||||
Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount).
|
||||
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount)
|
||||
requested, err := r.conn.Exists(ctx, requestedQ)
|
||||
// check if requesting has follow requested target
|
||||
rel.Requested, err = r.IsFollowRequested(ctx,
|
||||
requestingAccount,
|
||||
targetAccount,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err)
|
||||
return nil, fmt.Errorf("GetRelationship: error checking requested: %w", err)
|
||||
}
|
||||
rel.Requested = requested
|
||||
|
||||
// check if the requesting account is blocking the target account
|
||||
blockA2T, err := r.getBlock(ctx, requestingAccount, targetAccount)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err)
|
||||
rel.Blocking, err = r.IsBlocked(ctx, requestingAccount, targetAccount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetRelationship: error checking blocking: %w", err)
|
||||
}
|
||||
rel.Blocking = (blockA2T != nil)
|
||||
|
||||
// check if the requesting account is blocked by the target account
|
||||
blockT2A, err := r.getBlock(ctx, targetAccount, requestingAccount)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err)
|
||||
}
|
||||
rel.BlockedBy = (blockT2A != nil)
|
||||
|
||||
return rel, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
|
||||
if sourceAccount == nil || targetAccount == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
|
||||
Column("follow.id").
|
||||
Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID).
|
||||
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID)
|
||||
|
||||
return r.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
|
||||
if sourceAccount == nil || targetAccount == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
|
||||
Column("follow_request.id").
|
||||
Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID).
|
||||
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID)
|
||||
|
||||
return r.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
|
||||
if account1 == nil || account2 == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// make sure account 1 follows account 2
|
||||
f1, err := r.IsFollowing(ctx, account1, account2)
|
||||
rel.BlockedBy, err = r.IsBlocked(ctx, targetAccount, requestingAccount)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %w", err)
|
||||
}
|
||||
|
||||
// make sure account 2 follows account 1
|
||||
f2, err := r.IsFollowing(ctx, account2, account1)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return f1 && f2, nil
|
||||
return &rel, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
|
||||
// Get original follow request.
|
||||
var followRequestID string
|
||||
if err := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
|
||||
Column("follow_request.id").
|
||||
Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
|
||||
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
|
||||
Scan(ctx, &followRequestID); err != nil {
|
||||
func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
|
||||
var followIDs []string
|
||||
if err := newSelectFollows(r.conn, accountID).
|
||||
Scan(ctx, &followIDs); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
followRequest, err := r.getFollowRequest(ctx, followRequestID)
|
||||
if err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Create a new follow to 'replace'
|
||||
// the original follow request with.
|
||||
follow := >smodel.Follow{
|
||||
ID: followRequest.ID,
|
||||
AccountID: originAccountID,
|
||||
Account: followRequest.Account,
|
||||
TargetAccountID: targetAccountID,
|
||||
TargetAccount: followRequest.TargetAccount,
|
||||
URI: followRequest.URI,
|
||||
}
|
||||
|
||||
// If the follow already exists, just
|
||||
// replace the URI with the new one.
|
||||
if _, err := r.conn.
|
||||
NewInsert().
|
||||
Model(follow).
|
||||
On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Delete original follow request.
|
||||
if _, err := r.conn.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
|
||||
Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Delete original follow request notification.
|
||||
if err := r.deleteFollowRequestNotif(ctx, originAccountID, targetAccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// return the new follow
|
||||
return follow, nil
|
||||
return r.GetFollowsByIDs(ctx, followIDs)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) {
|
||||
// Get original follow request.
|
||||
var followRequestID string
|
||||
if err := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
|
||||
Column("follow_request.id").
|
||||
Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
|
||||
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
|
||||
Scan(ctx, &followRequestID); err != nil {
|
||||
func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
|
||||
var followIDs []string
|
||||
if err := newSelectLocalFollows(r.conn, accountID).
|
||||
Scan(ctx, &followIDs); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
followRequest, err := r.getFollowRequest(ctx, followRequestID)
|
||||
if err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Delete original follow request.
|
||||
if _, err := r.conn.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
|
||||
Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Delete original follow request notification.
|
||||
if err := r.deleteFollowRequestNotif(ctx, originAccountID, targetAccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return the now deleted follow request.
|
||||
return followRequest, nil
|
||||
return r.GetFollowsByIDs(ctx, followIDs)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) deleteFollowRequestNotif(ctx context.Context, originAccountID string, targetAccountID string) db.Error {
|
||||
var id string
|
||||
if err := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
|
||||
Column("notification.id").
|
||||
Where("? = ?", bun.Ident("notification.origin_account_id"), originAccountID).
|
||||
Where("? = ?", bun.Ident("notification.target_account_id"), targetAccountID).
|
||||
Where("? = ?", bun.Ident("notification.notification_type"), gtsmodel.NotificationFollowRequest).
|
||||
Limit(1). // There should only be one!
|
||||
Scan(ctx, &id); err != nil {
|
||||
err = r.conn.ProcessError(err)
|
||||
if errors.Is(err, db.ErrNoEntries) {
|
||||
// If no entries, the notif didn't
|
||||
// exist anyway so nothing to do here.
|
||||
return nil
|
||||
}
|
||||
// Return on real error.
|
||||
return err
|
||||
}
|
||||
|
||||
return r.state.DB.DeleteNotification(ctx, id)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) getFollow(ctx context.Context, id string) (*gtsmodel.Follow, db.Error) {
|
||||
follow := >smodel.Follow{}
|
||||
|
||||
err := r.conn.
|
||||
NewSelect().
|
||||
Model(follow).
|
||||
Where("? = ?", bun.Ident("follow.id"), id).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
|
||||
var followIDs []string
|
||||
if err := newSelectFollowers(r.conn, accountID).
|
||||
Scan(ctx, &followIDs); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
follow.Account, err = r.state.DB.GetAccountByID(ctx, follow.AccountID)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting follow account %q: %v", follow.AccountID, err)
|
||||
}
|
||||
|
||||
follow.TargetAccount, err = r.state.DB.GetAccountByID(ctx, follow.TargetAccountID)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting follow target account %q: %v", follow.TargetAccountID, err)
|
||||
}
|
||||
|
||||
return follow, nil
|
||||
return r.GetFollowsByIDs(ctx, followIDs)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetLocalFollowersIDs(ctx context.Context, targetAccountID string) ([]string, db.Error) {
|
||||
accountIDs := []string{}
|
||||
func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
|
||||
var followIDs []string
|
||||
if err := newSelectLocalFollowers(r.conn, accountID).
|
||||
Scan(ctx, &followIDs); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
return r.GetFollowsByIDs(ctx, followIDs)
|
||||
}
|
||||
|
||||
// Select only the account ID of each follow.
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
|
||||
ColumnExpr("? AS ?", bun.Ident("follow.account_id"), bun.Ident("account_id")).
|
||||
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
|
||||
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) {
|
||||
n, err := newSelectFollows(r.conn, accountID).Count(ctx)
|
||||
return n, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Join on accounts table to select only
|
||||
// those with NULL domain (local accounts).
|
||||
q = q.
|
||||
Join("JOIN ? AS ? ON ? = ?",
|
||||
bun.Ident("accounts"),
|
||||
bun.Ident("account"),
|
||||
bun.Ident("follow.account_id"),
|
||||
bun.Ident("account.id"),
|
||||
func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) {
|
||||
n, err := newSelectLocalFollows(r.conn, accountID).Count(ctx)
|
||||
return n, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) {
|
||||
n, err := newSelectFollowers(r.conn, accountID).Count(ctx)
|
||||
return n, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) {
|
||||
n, err := newSelectLocalFollowers(r.conn, accountID).Count(ctx)
|
||||
return n, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
|
||||
var followReqIDs []string
|
||||
if err := newSelectFollowRequests(r.conn, accountID).
|
||||
Scan(ctx, &followReqIDs); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
return r.GetFollowRequestsByIDs(ctx, followReqIDs)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
|
||||
var followReqIDs []string
|
||||
if err := newSelectFollowRequesting(r.conn, accountID).
|
||||
Scan(ctx, &followReqIDs); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
return r.GetFollowRequestsByIDs(ctx, followReqIDs)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) {
|
||||
n, err := newSelectFollowRequests(r.conn, accountID).Count(ctx)
|
||||
return n, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) {
|
||||
n, err := newSelectFollowRequesting(r.conn, accountID).Count(ctx)
|
||||
return n, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID.
|
||||
func newSelectFollowRequests(conn *DBConn, accountID string) *bun.SelectQuery {
|
||||
return conn.NewSelect().
|
||||
TableExpr("?", bun.Ident("follow_requests")).
|
||||
ColumnExpr("?", bun.Ident("id")).
|
||||
Where("? = ?", bun.Ident("target_account_id"), accountID).
|
||||
OrderExpr("? DESC", bun.Ident("updated_at"))
|
||||
}
|
||||
|
||||
// newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID.
|
||||
func newSelectFollowRequesting(conn *DBConn, accountID string) *bun.SelectQuery {
|
||||
return conn.NewSelect().
|
||||
TableExpr("?", bun.Ident("follow_requests")).
|
||||
ColumnExpr("?", bun.Ident("id")).
|
||||
Where("? = ?", bun.Ident("target_account_id"), accountID).
|
||||
OrderExpr("? DESC", bun.Ident("updated_at"))
|
||||
}
|
||||
|
||||
// newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID.
|
||||
func newSelectFollows(conn *DBConn, accountID string) *bun.SelectQuery {
|
||||
return conn.NewSelect().
|
||||
Table("follows").
|
||||
Column("id").
|
||||
Where("? = ?", bun.Ident("account_id"), accountID).
|
||||
OrderExpr("? DESC", bun.Ident("updated_at"))
|
||||
}
|
||||
|
||||
// newSelectLocalFollows returns a new select query for all rows in the follows table with
|
||||
// account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
|
||||
func newSelectLocalFollows(conn *DBConn, accountID string) *bun.SelectQuery {
|
||||
return conn.NewSelect().
|
||||
Table("follows").
|
||||
Column("id").
|
||||
Where("? = ? AND ? IN (?)",
|
||||
bun.Ident("account_id"),
|
||||
accountID,
|
||||
bun.Ident("target_account_id"),
|
||||
conn.NewSelect().
|
||||
Table("accounts").
|
||||
Column("id").
|
||||
Where("? IS NULL", bun.Ident("domain")),
|
||||
).
|
||||
Where("? IS NULL", bun.Ident("account.domain"))
|
||||
|
||||
// We don't *really* need to order these,
|
||||
// but it makes it more consistent to do so.
|
||||
q = q.Order("account_id DESC")
|
||||
|
||||
if err := q.Scan(ctx, &accountIDs); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return accountIDs, nil
|
||||
OrderExpr("? DESC", bun.Ident("updated_at"))
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetFollows(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.Follow, db.Error) {
|
||||
ids := []string{}
|
||||
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
|
||||
Column("follow.id").
|
||||
Order("follow.updated_at DESC")
|
||||
|
||||
if accountID != "" {
|
||||
q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
|
||||
}
|
||||
|
||||
if targetAccountID != "" {
|
||||
q = q.Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
|
||||
}
|
||||
|
||||
if err := q.Scan(ctx, &ids); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
follows := make([]*gtsmodel.Follow, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
follow, err := r.getFollow(ctx, id)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting follow %q: %v", id, err)
|
||||
continue
|
||||
}
|
||||
|
||||
follows = append(follows, follow)
|
||||
}
|
||||
|
||||
return follows, nil
|
||||
// newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID.
|
||||
func newSelectFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
|
||||
return conn.NewSelect().
|
||||
Table("follows").
|
||||
Column("id").
|
||||
Where("? = ?", bun.Ident("target_account_id"), accountID).
|
||||
OrderExpr("? DESC", bun.Ident("updated_at"))
|
||||
}
|
||||
|
||||
func (r *relationshipDB) CountFollows(ctx context.Context, accountID string, targetAccountID string) (int, db.Error) {
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
|
||||
Column("follow.id")
|
||||
|
||||
if accountID != "" {
|
||||
q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
|
||||
}
|
||||
|
||||
if targetAccountID != "" {
|
||||
q = q.Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
|
||||
}
|
||||
|
||||
return q.Count(ctx)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) getFollowRequest(ctx context.Context, id string) (*gtsmodel.FollowRequest, db.Error) {
|
||||
followRequest := >smodel.FollowRequest{}
|
||||
|
||||
err := r.conn.
|
||||
NewSelect().
|
||||
Model(followRequest).
|
||||
Where("? = ?", bun.Ident("follow_request.id"), id).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
followRequest.Account, err = r.state.DB.GetAccountByID(ctx, followRequest.AccountID)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting follow request account %q: %v", followRequest.AccountID, err)
|
||||
}
|
||||
|
||||
followRequest.TargetAccount, err = r.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting follow request target account %q: %v", followRequest.TargetAccountID, err)
|
||||
}
|
||||
|
||||
return followRequest, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetFollowRequests(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.FollowRequest, db.Error) {
|
||||
ids := []string{}
|
||||
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
|
||||
Column("follow_request.id")
|
||||
|
||||
if accountID != "" {
|
||||
q = q.Where("? = ?", bun.Ident("follow_request.account_id"), accountID)
|
||||
}
|
||||
|
||||
if targetAccountID != "" {
|
||||
q = q.Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID)
|
||||
}
|
||||
|
||||
if err := q.Scan(ctx, &ids); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
followRequests := make([]*gtsmodel.FollowRequest, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
followRequest, err := r.getFollowRequest(ctx, id)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting follow request %q: %v", id, err)
|
||||
continue
|
||||
}
|
||||
|
||||
followRequests = append(followRequests, followRequest)
|
||||
}
|
||||
|
||||
return followRequests, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) CountFollowRequests(ctx context.Context, accountID string, targetAccountID string) (int, db.Error) {
|
||||
q := r.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
|
||||
Column("follow_request.id").
|
||||
Order("follow_request.updated_at DESC")
|
||||
|
||||
if accountID != "" {
|
||||
q = q.Where("? = ?", bun.Ident("follow_request.account_id"), accountID)
|
||||
}
|
||||
|
||||
if targetAccountID != "" {
|
||||
q = q.Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID)
|
||||
}
|
||||
|
||||
return q.Count(ctx)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) Unfollow(ctx context.Context, originAccountID string, targetAccountID string) (string, db.Error) {
|
||||
uri := new(string)
|
||||
|
||||
_, err := r.conn.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
|
||||
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID).
|
||||
Where("? = ?", bun.Ident("follow.account_id"), originAccountID).
|
||||
Returning("?", bun.Ident("uri")).Exec(ctx, uri)
|
||||
|
||||
// Only return proper errors.
|
||||
if err = r.conn.ProcessError(err); err != db.ErrNoEntries {
|
||||
return *uri, err
|
||||
}
|
||||
|
||||
return *uri, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) UnfollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (string, db.Error) {
|
||||
uri := new(string)
|
||||
|
||||
_, err := r.conn.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
|
||||
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
|
||||
Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
|
||||
Returning("?", bun.Ident("uri")).Exec(ctx, uri)
|
||||
|
||||
// Only return proper errors.
|
||||
if err = r.conn.ProcessError(err); err != db.ErrNoEntries {
|
||||
return *uri, err
|
||||
}
|
||||
|
||||
return *uri, nil
|
||||
// newSelectLocalFollowers returns a new select query for all rows in the follows table with
|
||||
// target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
|
||||
func newSelectLocalFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
|
||||
return conn.NewSelect().
|
||||
Table("follows").
|
||||
Column("id").
|
||||
Where("? = ? AND ? IN (?)",
|
||||
bun.Ident("target_account_id"),
|
||||
accountID,
|
||||
bun.Ident("account_id"),
|
||||
conn.NewSelect().
|
||||
Table("accounts").
|
||||
Column("id").
|
||||
Where("? IS NULL", bun.Ident("domain")),
|
||||
).
|
||||
OrderExpr("? DESC", bun.Ident("updated_at"))
|
||||
}
|
||||
|
|
|
|||
218
internal/db/bundb/relationship_block.go
Normal file
218
internal/db/bundb/relationship_block.go
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package bundb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func (r *relationshipDB) IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
|
||||
block, err := r.GetBlock(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
sourceAccountID,
|
||||
targetAccountID,
|
||||
)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return false, err
|
||||
}
|
||||
return (block != nil), nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error) {
|
||||
// Look for a block in direction of account1->account2
|
||||
b1, err := r.IsBlocked(ctx, accountID1, accountID2)
|
||||
if err != nil || b1 {
|
||||
return true, err
|
||||
}
|
||||
|
||||
// Look for a block in direction of account2->account1
|
||||
b2, err := r.IsBlocked(ctx, accountID2, accountID1)
|
||||
if err != nil || b2 {
|
||||
return true, err
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetBlockByID(ctx context.Context, id string) (*gtsmodel.Block, error) {
|
||||
return r.getBlock(
|
||||
ctx,
|
||||
"ID",
|
||||
func(block *gtsmodel.Block) error {
|
||||
return r.conn.NewSelect().Model(block).
|
||||
Where("? = ?", bun.Ident("block.id"), id).
|
||||
Scan(ctx)
|
||||
},
|
||||
id,
|
||||
)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmodel.Block, error) {
|
||||
return r.getBlock(
|
||||
ctx,
|
||||
"URI",
|
||||
func(block *gtsmodel.Block) error {
|
||||
return r.conn.NewSelect().Model(block).
|
||||
Where("? = ?", bun.Ident("block.uri"), uri).
|
||||
Scan(ctx)
|
||||
},
|
||||
uri,
|
||||
)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Block, error) {
|
||||
return r.getBlock(
|
||||
ctx,
|
||||
"AccountID.TargetAccountID",
|
||||
func(block *gtsmodel.Block) error {
|
||||
return r.conn.NewSelect().Model(block).
|
||||
Where("? = ?", bun.Ident("block.account_id"), sourceAccountID).
|
||||
Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID).
|
||||
Scan(ctx)
|
||||
},
|
||||
sourceAccountID,
|
||||
targetAccountID,
|
||||
)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) {
|
||||
// Fetch block from cache with loader callback
|
||||
block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) {
|
||||
var block gtsmodel.Block
|
||||
|
||||
// Not cached! Perform database query
|
||||
if err := dbQuery(&block); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return &block, nil
|
||||
}, keyParts...)
|
||||
if err != nil {
|
||||
// already processe
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gtscontext.Barebones(ctx) {
|
||||
// Only a barebones model was requested.
|
||||
return block, nil
|
||||
}
|
||||
|
||||
// Set the block source account
|
||||
block.Account, err = r.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
block.AccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting block source account: %w", err)
|
||||
}
|
||||
|
||||
// Set the block target account
|
||||
block.TargetAccount, err = r.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
block.TargetAccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting block target account: %w", err)
|
||||
}
|
||||
|
||||
return block, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error {
|
||||
err := r.state.Caches.GTS.Block().Store(block, func() error {
|
||||
_, err := r.conn.NewInsert().Model(block).Exec(ctx)
|
||||
return r.conn.ProcessError(err)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Invalidate block origin account ID cached visibility.
|
||||
r.state.Caches.Visibility.Invalidate("ItemID", block.AccountID)
|
||||
r.state.Caches.Visibility.Invalidate("RequesterID", block.AccountID)
|
||||
|
||||
// Invalidate block target account ID cached visibility.
|
||||
r.state.Caches.Visibility.Invalidate("ItemID", block.TargetAccountID)
|
||||
r.state.Caches.Visibility.Invalidate("RequesterID", block.TargetAccountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {
|
||||
block, err := r.GetBlockByID(gtscontext.SetBarebones(ctx), id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.deleteBlock(ctx, block)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error {
|
||||
block, err := r.GetBlockByURI(gtscontext.SetBarebones(ctx), uri)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.deleteBlock(ctx, block)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) deleteBlock(ctx context.Context, block *gtsmodel.Block) error {
|
||||
if _, err := r.conn.
|
||||
NewDelete().
|
||||
Table("blocks").
|
||||
Where("? = ?", bun.Ident("id"), block.ID).
|
||||
Exec(ctx); err != nil {
|
||||
return r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Invalidate block from cache lookups.
|
||||
r.state.Caches.GTS.Block().Invalidate("ID", block.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID string) error {
|
||||
var blockIDs []string
|
||||
|
||||
if err := r.conn.NewSelect().
|
||||
Table("blocks").
|
||||
ColumnExpr("?", bun.Ident("id")).
|
||||
WhereOr("? = ? OR ? = ?",
|
||||
bun.Ident("account_id"),
|
||||
accountID,
|
||||
bun.Ident("target_account_id"),
|
||||
accountID,
|
||||
).
|
||||
Scan(ctx, &blockIDs); err != nil {
|
||||
return r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
for _, id := range blockIDs {
|
||||
if err := r.DeleteBlockByID(ctx, id); err != nil {
|
||||
log.Errorf(ctx, "error deleting block %q: %v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
243
internal/db/bundb/relationship_follow.go
Normal file
243
internal/db/bundb/relationship_follow.go
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package bundb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func (r *relationshipDB) GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error) {
|
||||
return r.getFollow(
|
||||
ctx,
|
||||
"ID",
|
||||
func(follow *gtsmodel.Follow) error {
|
||||
return r.conn.NewSelect().
|
||||
Model(follow).
|
||||
Where("? = ?", bun.Ident("id"), id).
|
||||
Scan(ctx)
|
||||
},
|
||||
id,
|
||||
)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmodel.Follow, error) {
|
||||
return r.getFollow(
|
||||
ctx,
|
||||
"URI",
|
||||
func(follow *gtsmodel.Follow) error {
|
||||
return r.conn.NewSelect().
|
||||
Model(follow).
|
||||
Where("? = ?", bun.Ident("uri"), uri).
|
||||
Scan(ctx)
|
||||
},
|
||||
uri,
|
||||
)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
|
||||
return r.getFollow(
|
||||
ctx,
|
||||
"AccountID.TargetAccountID",
|
||||
func(follow *gtsmodel.Follow) error {
|
||||
return r.conn.NewSelect().
|
||||
Model(follow).
|
||||
Where("? = ?", bun.Ident("account_id"), sourceAccountID).
|
||||
Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
|
||||
Scan(ctx)
|
||||
},
|
||||
sourceAccountID,
|
||||
targetAccountID,
|
||||
)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Follow, error) {
|
||||
// Preallocate slice of expected length.
|
||||
follows := make([]*gtsmodel.Follow, 0, len(ids))
|
||||
|
||||
for _, id := range ids {
|
||||
// Fetch follow model for this ID.
|
||||
follow, err := r.GetFollowByID(ctx, id)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting follow %q: %v", id, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Append to return slice.
|
||||
follows = append(follows, follow)
|
||||
}
|
||||
|
||||
return follows, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
|
||||
follow, err := r.GetFollow(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
sourceAccountID,
|
||||
targetAccountID,
|
||||
)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return false, err
|
||||
}
|
||||
return (follow != nil), nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 string, accountID2 string) (bool, db.Error) {
|
||||
// make sure account 1 follows account 2
|
||||
f1, err := r.IsFollowing(ctx,
|
||||
accountID1,
|
||||
accountID2,
|
||||
)
|
||||
if !f1 /* f1 = false when err != nil */ {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// make sure account 2 follows account 1
|
||||
f2, err := r.IsFollowing(ctx,
|
||||
accountID2,
|
||||
accountID1,
|
||||
)
|
||||
if !f2 /* f2 = false when err != nil */ {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Follow) error, keyParts ...any) (*gtsmodel.Follow, error) {
|
||||
// Fetch follow from database cache with loader callback
|
||||
follow, err := r.state.Caches.GTS.Follow().Load(lookup, func() (*gtsmodel.Follow, error) {
|
||||
var follow gtsmodel.Follow
|
||||
|
||||
// Not cached! Perform database query
|
||||
if err := dbQuery(&follow); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return &follow, nil
|
||||
}, keyParts...)
|
||||
if err != nil {
|
||||
// error already processed
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gtscontext.Barebones(ctx) {
|
||||
// Only a barebones model was requested.
|
||||
return follow, nil
|
||||
}
|
||||
|
||||
// Set the follow source account
|
||||
follow.Account, err = r.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
follow.AccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting follow source account: %w", err)
|
||||
}
|
||||
|
||||
// Set the follow target account
|
||||
follow.TargetAccount, err = r.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
follow.TargetAccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting follow target account: %w", err)
|
||||
}
|
||||
|
||||
return follow, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error {
|
||||
err := r.state.Caches.GTS.Follow().Store(follow, func() error {
|
||||
_, err := r.conn.NewInsert().Model(follow).Exec(ctx)
|
||||
return r.conn.ProcessError(err)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Invalidate follow origin account ID cached visibility.
|
||||
r.state.Caches.Visibility.Invalidate("ItemID", follow.AccountID)
|
||||
r.state.Caches.Visibility.Invalidate("RequesterID", follow.AccountID)
|
||||
|
||||
// Invalidate follow target account ID cached visibility.
|
||||
r.state.Caches.Visibility.Invalidate("ItemID", follow.TargetAccountID)
|
||||
r.state.Caches.Visibility.Invalidate("RequesterID", follow.TargetAccountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error {
|
||||
if _, err := r.conn.NewDelete().
|
||||
Table("follows").
|
||||
Where("? = ?", bun.Ident("id"), id).
|
||||
Exec(ctx); err != nil {
|
||||
return r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Invalidate follow from cache lookups.
|
||||
r.state.Caches.GTS.Follow().Invalidate("ID", id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) error {
|
||||
if _, err := r.conn.NewDelete().
|
||||
Table("follows").
|
||||
Where("? = ?", bun.Ident("uri"), uri).
|
||||
Exec(ctx); err != nil {
|
||||
return r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Invalidate follow from cache lookups.
|
||||
r.state.Caches.GTS.Follow().Invalidate("URI", uri)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID string) error {
|
||||
var followIDs []string
|
||||
|
||||
if _, err := r.conn.
|
||||
NewDelete().
|
||||
Table("follows").
|
||||
WhereOr("? = ? OR ? = ?",
|
||||
bun.Ident("account_id"),
|
||||
accountID,
|
||||
bun.Ident("target_account_id"),
|
||||
accountID,
|
||||
).
|
||||
Returning("?", bun.Ident("id")).
|
||||
Exec(ctx, &followIDs); err != nil {
|
||||
return r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Invalidate each returned ID.
|
||||
for _, id := range followIDs {
|
||||
r.state.Caches.GTS.Follow().Invalidate("ID", id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
293
internal/db/bundb/relationship_follow_req.go
Normal file
293
internal/db/bundb/relationship_follow_req.go
Normal file
|
|
@ -0,0 +1,293 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package bundb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func (r *relationshipDB) GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error) {
|
||||
return r.getFollowRequest(
|
||||
ctx,
|
||||
"ID",
|
||||
func(followReq *gtsmodel.FollowRequest) error {
|
||||
return r.conn.NewSelect().
|
||||
Model(followReq).
|
||||
Where("? = ?", bun.Ident("id"), id).
|
||||
Scan(ctx)
|
||||
},
|
||||
id,
|
||||
)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetFollowRequestByURI(ctx context.Context, uri string) (*gtsmodel.FollowRequest, error) {
|
||||
return r.getFollowRequest(
|
||||
ctx,
|
||||
"URI",
|
||||
func(followReq *gtsmodel.FollowRequest) error {
|
||||
return r.conn.NewSelect().
|
||||
Model(followReq).
|
||||
Where("? = ?", bun.Ident("uri"), uri).
|
||||
Scan(ctx)
|
||||
},
|
||||
uri,
|
||||
)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) {
|
||||
return r.getFollowRequest(
|
||||
ctx,
|
||||
"AccountID.TargetAccountID",
|
||||
func(followReq *gtsmodel.FollowRequest) error {
|
||||
return r.conn.NewSelect().
|
||||
Model(followReq).
|
||||
Where("? = ?", bun.Ident("account_id"), sourceAccountID).
|
||||
Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
|
||||
Scan(ctx)
|
||||
},
|
||||
sourceAccountID,
|
||||
targetAccountID,
|
||||
)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.FollowRequest, error) {
|
||||
// Preallocate slice of expected length.
|
||||
followReqs := make([]*gtsmodel.FollowRequest, 0, len(ids))
|
||||
|
||||
for _, id := range ids {
|
||||
// Fetch follow request model for this ID.
|
||||
followReq, err := r.GetFollowRequestByID(ctx, id)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting follow request %q: %v", id, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Append to return slice.
|
||||
followReqs = append(followReqs, followReq)
|
||||
}
|
||||
|
||||
return followReqs, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
|
||||
followReq, err := r.GetFollowRequest(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
sourceAccountID,
|
||||
targetAccountID,
|
||||
)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return false, err
|
||||
}
|
||||
return (followReq != nil), nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, dbQuery func(*gtsmodel.FollowRequest) error, keyParts ...any) (*gtsmodel.FollowRequest, error) {
|
||||
// Fetch follow request from database cache with loader callback
|
||||
followReq, err := r.state.Caches.GTS.FollowRequest().Load(lookup, func() (*gtsmodel.FollowRequest, error) {
|
||||
var followReq gtsmodel.FollowRequest
|
||||
|
||||
// Not cached! Perform database query
|
||||
if err := dbQuery(&followReq); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return &followReq, nil
|
||||
}, keyParts...)
|
||||
if err != nil {
|
||||
// error already processed
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gtscontext.Barebones(ctx) {
|
||||
// Only a barebones model was requested.
|
||||
return followReq, nil
|
||||
}
|
||||
|
||||
// Set the follow request source account
|
||||
followReq.Account, err = r.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
followReq.AccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting follow request source account: %w", err)
|
||||
}
|
||||
|
||||
// Set the follow request target account
|
||||
followReq.TargetAccount, err = r.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
followReq.TargetAccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting follow request target account: %w", err)
|
||||
}
|
||||
|
||||
return followReq, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error {
|
||||
err := r.state.Caches.GTS.FollowRequest().Store(follow, func() error {
|
||||
_, err := r.conn.NewInsert().Model(follow).Exec(ctx)
|
||||
return r.conn.ProcessError(err)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Invalidate follow request origin account ID cached visibility.
|
||||
r.state.Caches.Visibility.Invalidate("ItemID", follow.AccountID)
|
||||
r.state.Caches.Visibility.Invalidate("RequesterID", follow.AccountID)
|
||||
|
||||
// Invalidate follow request target account ID cached visibility.
|
||||
r.state.Caches.Visibility.Invalidate("ItemID", follow.TargetAccountID)
|
||||
r.state.Caches.Visibility.Invalidate("RequesterID", follow.TargetAccountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
|
||||
// Get original follow request.
|
||||
followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create a new follow to 'replace'
|
||||
// the original follow request with.
|
||||
follow := >smodel.Follow{
|
||||
ID: followReq.ID,
|
||||
AccountID: sourceAccountID,
|
||||
Account: followReq.Account,
|
||||
TargetAccountID: targetAccountID,
|
||||
TargetAccount: followReq.TargetAccount,
|
||||
URI: followReq.URI,
|
||||
}
|
||||
|
||||
// If the follow already exists, just
|
||||
// replace the URI with the new one.
|
||||
if _, err := r.conn.
|
||||
NewInsert().
|
||||
Model(follow).
|
||||
On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Delete original follow request.
|
||||
if _, err := r.conn.
|
||||
NewDelete().
|
||||
Table("follow_requests").
|
||||
Where("? = ?", bun.Ident("id"), followReq.ID).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Invalidate follow request from cache lookups.
|
||||
r.state.Caches.GTS.FollowRequest().Invalidate("ID", followReq.ID)
|
||||
|
||||
// Delete original follow request notification
|
||||
if err := r.state.DB.DeleteNotifications(ctx, []string{
|
||||
string(gtsmodel.NotificationFollowRequest),
|
||||
}, targetAccountID, sourceAccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return follow, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) db.Error {
|
||||
// Get original follow request.
|
||||
followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete original follow request.
|
||||
if _, err := r.conn.
|
||||
NewDelete().
|
||||
Table("follow_requests").
|
||||
Where("? = ?", bun.Ident("id"), followReq.ID).
|
||||
Exec(ctx); err != nil {
|
||||
return r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Delete original follow request notification
|
||||
return r.state.DB.DeleteNotifications(ctx, []string{
|
||||
string(gtsmodel.NotificationFollowRequest),
|
||||
}, targetAccountID, sourceAccountID)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error {
|
||||
if _, err := r.conn.NewDelete().
|
||||
Table("follow_requests").
|
||||
Where("? = ?", bun.Ident("id"), id).
|
||||
Exec(ctx); err != nil {
|
||||
return r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Invalidate follow request from cache lookups.
|
||||
r.state.Caches.GTS.FollowRequest().Invalidate("ID", id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error {
|
||||
if _, err := r.conn.NewDelete().
|
||||
Table("follow_requests").
|
||||
Where("? = ?", bun.Ident("uri"), uri).
|
||||
Exec(ctx); err != nil {
|
||||
return r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Invalidate follow request from cache lookups.
|
||||
r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accountID string) error {
|
||||
var followIDs []string
|
||||
|
||||
if _, err := r.conn.
|
||||
NewDelete().
|
||||
Table("follow_requests").
|
||||
WhereOr("? = ? OR ? = ?",
|
||||
bun.Ident("account_id"),
|
||||
accountID,
|
||||
bun.Ident("target_account_id"),
|
||||
accountID,
|
||||
).
|
||||
Returning("?", bun.Ident("id")).
|
||||
Exec(ctx, &followIDs); err != nil {
|
||||
return r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Invalidate each returned ID.
|
||||
for _, id := range followIDs {
|
||||
r.state.Caches.GTS.FollowRequest().Invalidate("ID", id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -19,17 +19,359 @@ package bundb_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
)
|
||||
|
||||
type RelationshipTestSuite struct {
|
||||
BunDBStandardTestSuite
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetBlockBy() {
|
||||
t := suite.T()
|
||||
|
||||
// Create a new context for this test.
|
||||
ctx, cncl := context.WithCancel(context.Background())
|
||||
defer cncl()
|
||||
|
||||
// Sentinel error to mark avoiding a test case.
|
||||
sentinelErr := errors.New("sentinel")
|
||||
|
||||
// isEqual checks if 2 block models are equal.
|
||||
isEqual := func(b1, b2 gtsmodel.Block) bool {
|
||||
// Clear populated sub-models.
|
||||
b1.Account = nil
|
||||
b2.Account = nil
|
||||
b1.TargetAccount = nil
|
||||
b2.TargetAccount = nil
|
||||
|
||||
// Clear database-set fields.
|
||||
b1.CreatedAt = time.Time{}
|
||||
b2.CreatedAt = time.Time{}
|
||||
b1.UpdatedAt = time.Time{}
|
||||
b2.UpdatedAt = time.Time{}
|
||||
|
||||
return reflect.DeepEqual(b1, b2)
|
||||
}
|
||||
|
||||
var testBlocks []*gtsmodel.Block
|
||||
|
||||
for _, account1 := range suite.testAccounts {
|
||||
for _, account2 := range suite.testAccounts {
|
||||
if account1.ID == account2.ID {
|
||||
// don't block *yourself* ...
|
||||
continue
|
||||
}
|
||||
|
||||
// Create new account block.
|
||||
block := >smodel.Block{
|
||||
ID: id.NewULID(),
|
||||
URI: "http://127.0.0.1:8080/" + id.NewULID(),
|
||||
AccountID: account1.ID,
|
||||
TargetAccountID: account2.ID,
|
||||
}
|
||||
|
||||
// Attempt to place the block in database (if not already).
|
||||
if err := suite.db.PutBlock(ctx, block); err != nil {
|
||||
if err != db.ErrAlreadyExists {
|
||||
// Unrecoverable database error.
|
||||
t.Fatalf("error creating block: %v", err)
|
||||
}
|
||||
|
||||
// Fetch existing block from database between accounts.
|
||||
block, _ = suite.db.GetBlock(ctx, account1.ID, account2.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
// Append generated block to test cases.
|
||||
testBlocks = append(testBlocks, block)
|
||||
}
|
||||
}
|
||||
|
||||
for _, block := range testBlocks {
|
||||
for lookup, dbfunc := range map[string]func() (*gtsmodel.Block, error){
|
||||
"id": func() (*gtsmodel.Block, error) {
|
||||
return suite.db.GetBlockByID(ctx, block.ID)
|
||||
},
|
||||
|
||||
"uri": func() (*gtsmodel.Block, error) {
|
||||
return suite.db.GetBlockByURI(ctx, block.URI)
|
||||
},
|
||||
|
||||
"origin_target": func() (*gtsmodel.Block, error) {
|
||||
return suite.db.GetBlock(ctx, block.AccountID, block.TargetAccountID)
|
||||
},
|
||||
} {
|
||||
|
||||
// Clear database caches.
|
||||
suite.state.Caches.Init()
|
||||
|
||||
t.Logf("checking database lookup %q", lookup)
|
||||
|
||||
// Perform database function.
|
||||
checkBlock, err := dbfunc()
|
||||
if err != nil {
|
||||
if err == sentinelErr {
|
||||
continue
|
||||
}
|
||||
|
||||
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check received block data.
|
||||
if !isEqual(*checkBlock, *block) {
|
||||
t.Errorf("block does not contain expected data: %+v", checkBlock)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check that block origin account populated.
|
||||
if checkBlock.Account == nil || checkBlock.Account.ID != block.AccountID {
|
||||
t.Errorf("block origin account not correctly populated for: %+v", checkBlock)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check that block target account populated.
|
||||
if checkBlock.TargetAccount == nil || checkBlock.TargetAccount.ID != block.TargetAccountID {
|
||||
t.Errorf("block target account not correctly populated for: %+v", checkBlock)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetFollowBy() {
|
||||
t := suite.T()
|
||||
|
||||
// Create a new context for this test.
|
||||
ctx, cncl := context.WithCancel(context.Background())
|
||||
defer cncl()
|
||||
|
||||
// Sentinel error to mark avoiding a test case.
|
||||
sentinelErr := errors.New("sentinel")
|
||||
|
||||
// isEqual checks if 2 follow models are equal.
|
||||
isEqual := func(f1, f2 gtsmodel.Follow) bool {
|
||||
// Clear populated sub-models.
|
||||
f1.Account = nil
|
||||
f2.Account = nil
|
||||
f1.TargetAccount = nil
|
||||
f2.TargetAccount = nil
|
||||
|
||||
// Clear database-set fields.
|
||||
f1.CreatedAt = time.Time{}
|
||||
f2.CreatedAt = time.Time{}
|
||||
f1.UpdatedAt = time.Time{}
|
||||
f2.UpdatedAt = time.Time{}
|
||||
|
||||
return reflect.DeepEqual(f1, f2)
|
||||
}
|
||||
|
||||
var testFollows []*gtsmodel.Follow
|
||||
|
||||
for _, account1 := range suite.testAccounts {
|
||||
for _, account2 := range suite.testAccounts {
|
||||
if account1.ID == account2.ID {
|
||||
// don't follow *yourself* ...
|
||||
continue
|
||||
}
|
||||
|
||||
// Create new account follow.
|
||||
follow := >smodel.Follow{
|
||||
ID: id.NewULID(),
|
||||
URI: "http://127.0.0.1:8080/" + id.NewULID(),
|
||||
AccountID: account1.ID,
|
||||
TargetAccountID: account2.ID,
|
||||
}
|
||||
|
||||
// Attempt to place the follow in database (if not already).
|
||||
if err := suite.db.PutFollow(ctx, follow); err != nil {
|
||||
if err != db.ErrAlreadyExists {
|
||||
// Unrecoverable database error.
|
||||
t.Fatalf("error creating follow: %v", err)
|
||||
}
|
||||
|
||||
// Fetch existing follow from database between accounts.
|
||||
follow, _ = suite.db.GetFollow(ctx, account1.ID, account2.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
// Append generated follow to test cases.
|
||||
testFollows = append(testFollows, follow)
|
||||
}
|
||||
}
|
||||
|
||||
for _, follow := range testFollows {
|
||||
for lookup, dbfunc := range map[string]func() (*gtsmodel.Follow, error){
|
||||
"id": func() (*gtsmodel.Follow, error) {
|
||||
return suite.db.GetFollowByID(ctx, follow.ID)
|
||||
},
|
||||
|
||||
"uri": func() (*gtsmodel.Follow, error) {
|
||||
return suite.db.GetFollowByURI(ctx, follow.URI)
|
||||
},
|
||||
|
||||
"origin_target": func() (*gtsmodel.Follow, error) {
|
||||
return suite.db.GetFollow(ctx, follow.AccountID, follow.TargetAccountID)
|
||||
},
|
||||
} {
|
||||
// Clear database caches.
|
||||
suite.state.Caches.Init()
|
||||
|
||||
t.Logf("checking database lookup %q", lookup)
|
||||
|
||||
// Perform database function.
|
||||
checkFollow, err := dbfunc()
|
||||
if err != nil {
|
||||
if err == sentinelErr {
|
||||
continue
|
||||
}
|
||||
|
||||
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check received follow data.
|
||||
if !isEqual(*checkFollow, *follow) {
|
||||
t.Errorf("follow does not contain expected data: %+v", checkFollow)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check that follow origin account populated.
|
||||
if checkFollow.Account == nil || checkFollow.Account.ID != follow.AccountID {
|
||||
t.Errorf("follow origin account not correctly populated for: %+v", checkFollow)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check that follow target account populated.
|
||||
if checkFollow.TargetAccount == nil || checkFollow.TargetAccount.ID != follow.TargetAccountID {
|
||||
t.Errorf("follow target account not correctly populated for: %+v", checkFollow)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetFollowRequestBy() {
|
||||
t := suite.T()
|
||||
|
||||
// Create a new context for this test.
|
||||
ctx, cncl := context.WithCancel(context.Background())
|
||||
defer cncl()
|
||||
|
||||
// Sentinel error to mark avoiding a test case.
|
||||
sentinelErr := errors.New("sentinel")
|
||||
|
||||
// isEqual checks if 2 follow request models are equal.
|
||||
isEqual := func(f1, f2 gtsmodel.FollowRequest) bool {
|
||||
// Clear populated sub-models.
|
||||
f1.Account = nil
|
||||
f2.Account = nil
|
||||
f1.TargetAccount = nil
|
||||
f2.TargetAccount = nil
|
||||
|
||||
// Clear database-set fields.
|
||||
f1.CreatedAt = time.Time{}
|
||||
f2.CreatedAt = time.Time{}
|
||||
f1.UpdatedAt = time.Time{}
|
||||
f2.UpdatedAt = time.Time{}
|
||||
|
||||
return reflect.DeepEqual(f1, f2)
|
||||
}
|
||||
|
||||
var testFollowReqs []*gtsmodel.FollowRequest
|
||||
|
||||
for _, account1 := range suite.testAccounts {
|
||||
for _, account2 := range suite.testAccounts {
|
||||
if account1.ID == account2.ID {
|
||||
// don't follow *yourself* ...
|
||||
continue
|
||||
}
|
||||
|
||||
// Create new account follow request.
|
||||
followReq := >smodel.FollowRequest{
|
||||
ID: id.NewULID(),
|
||||
URI: "http://127.0.0.1:8080/" + id.NewULID(),
|
||||
AccountID: account1.ID,
|
||||
TargetAccountID: account2.ID,
|
||||
}
|
||||
|
||||
// Attempt to place the follow in database (if not already).
|
||||
if err := suite.db.PutFollowRequest(ctx, followReq); err != nil {
|
||||
if err != db.ErrAlreadyExists {
|
||||
// Unrecoverable database error.
|
||||
t.Fatalf("error creating follow request: %v", err)
|
||||
}
|
||||
|
||||
// Fetch existing follow request from database between accounts.
|
||||
followReq, _ = suite.db.GetFollowRequest(ctx, account1.ID, account2.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
// Append generated follow request to test cases.
|
||||
testFollowReqs = append(testFollowReqs, followReq)
|
||||
}
|
||||
}
|
||||
|
||||
for _, followReq := range testFollowReqs {
|
||||
for lookup, dbfunc := range map[string]func() (*gtsmodel.FollowRequest, error){
|
||||
"id": func() (*gtsmodel.FollowRequest, error) {
|
||||
return suite.db.GetFollowRequestByID(ctx, followReq.ID)
|
||||
},
|
||||
|
||||
"uri": func() (*gtsmodel.FollowRequest, error) {
|
||||
return suite.db.GetFollowRequestByURI(ctx, followReq.URI)
|
||||
},
|
||||
|
||||
"origin_target": func() (*gtsmodel.FollowRequest, error) {
|
||||
return suite.db.GetFollowRequest(ctx, followReq.AccountID, followReq.TargetAccountID)
|
||||
},
|
||||
} {
|
||||
|
||||
// Clear database caches.
|
||||
suite.state.Caches.Init()
|
||||
|
||||
t.Logf("checking database lookup %q", lookup)
|
||||
|
||||
// Perform database function.
|
||||
checkFollowReq, err := dbfunc()
|
||||
if err != nil {
|
||||
if err == sentinelErr {
|
||||
continue
|
||||
}
|
||||
|
||||
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check received follow request data.
|
||||
if !isEqual(*checkFollowReq, *followReq) {
|
||||
t.Errorf("follow request does not contain expected data: %+v", checkFollowReq)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check that follow request origin account populated.
|
||||
if checkFollowReq.Account == nil || checkFollowReq.Account.ID != followReq.AccountID {
|
||||
t.Errorf("follow request origin account not correctly populated for: %+v", checkFollowReq)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check that follow request target account populated.
|
||||
if checkFollowReq.TargetAccount == nil || checkFollowReq.TargetAccount.ID != followReq.TargetAccountID {
|
||||
t.Errorf("follow request target account not correctly populated for: %+v", checkFollowReq)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestIsBlocked() {
|
||||
ctx := context.Background()
|
||||
|
||||
|
|
@ -37,11 +379,11 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
|
|||
account2 := suite.testAccounts["local_account_2"].ID
|
||||
|
||||
// no blocks exist between account 1 and account 2
|
||||
blocked, err := suite.db.IsBlocked(ctx, account1, account2, false)
|
||||
blocked, err := suite.db.IsBlocked(ctx, account1, account2)
|
||||
suite.NoError(err)
|
||||
suite.False(blocked)
|
||||
|
||||
blocked, err = suite.db.IsBlocked(ctx, account2, account1, false)
|
||||
blocked, err = suite.db.IsBlocked(ctx, account2, account1)
|
||||
suite.NoError(err)
|
||||
suite.False(blocked)
|
||||
|
||||
|
|
@ -56,45 +398,24 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
|
|||
}
|
||||
|
||||
// account 1 now blocks account 2
|
||||
blocked, err = suite.db.IsBlocked(ctx, account1, account2, false)
|
||||
blocked, err = suite.db.IsBlocked(ctx, account1, account2)
|
||||
suite.NoError(err)
|
||||
suite.True(blocked)
|
||||
|
||||
// account 2 doesn't block account 1
|
||||
blocked, err = suite.db.IsBlocked(ctx, account2, account1, false)
|
||||
blocked, err = suite.db.IsBlocked(ctx, account2, account1)
|
||||
suite.NoError(err)
|
||||
suite.False(blocked)
|
||||
|
||||
// a block exists in either direction between the two
|
||||
blocked, err = suite.db.IsBlocked(ctx, account1, account2, true)
|
||||
blocked, err = suite.db.IsEitherBlocked(ctx, account1, account2)
|
||||
suite.NoError(err)
|
||||
suite.True(blocked)
|
||||
blocked, err = suite.db.IsBlocked(ctx, account2, account1, true)
|
||||
blocked, err = suite.db.IsEitherBlocked(ctx, account2, account1)
|
||||
suite.NoError(err)
|
||||
suite.True(blocked)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetBlock() {
|
||||
ctx := context.Background()
|
||||
|
||||
account1 := suite.testAccounts["local_account_1"].ID
|
||||
account2 := suite.testAccounts["local_account_2"].ID
|
||||
|
||||
if err := suite.db.PutBlock(ctx, >smodel.Block{
|
||||
ID: "01G202BCSXXJZ70BHB5KCAHH8C",
|
||||
URI: "http://localhost:8080/some_block_uri_1",
|
||||
AccountID: account1,
|
||||
TargetAccountID: account2,
|
||||
}); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
block, err := suite.db.GetBlock(ctx, account1, account2)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(block)
|
||||
suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestDeleteBlockByID() {
|
||||
ctx := context.Background()
|
||||
|
||||
|
|
@ -157,7 +478,7 @@ func (suite *RelationshipTestSuite) TestDeleteBlockByURI() {
|
|||
suite.Nil(block)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestDeleteBlocksByOriginAccountID() {
|
||||
func (suite *RelationshipTestSuite) TestDeleteAccountBlocks() {
|
||||
ctx := context.Background()
|
||||
|
||||
// put a block in first
|
||||
|
|
@ -179,38 +500,7 @@ func (suite *RelationshipTestSuite) TestDeleteBlocksByOriginAccountID() {
|
|||
suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
|
||||
|
||||
// delete the block by originAccountID
|
||||
err = suite.db.DeleteBlocksByOriginAccountID(ctx, account1)
|
||||
suite.NoError(err)
|
||||
|
||||
// block should be gone
|
||||
block, err = suite.db.GetBlock(ctx, account1, account2)
|
||||
suite.ErrorIs(err, db.ErrNoEntries)
|
||||
suite.Nil(block)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestDeleteBlocksByTargetAccountID() {
|
||||
ctx := context.Background()
|
||||
|
||||
// put a block in first
|
||||
account1 := suite.testAccounts["local_account_1"].ID
|
||||
account2 := suite.testAccounts["local_account_2"].ID
|
||||
if err := suite.db.PutBlock(ctx, >smodel.Block{
|
||||
ID: "01G202BCSXXJZ70BHB5KCAHH8C",
|
||||
URI: "http://localhost:8080/some_block_uri_1",
|
||||
AccountID: account1,
|
||||
TargetAccountID: account2,
|
||||
}); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
// make sure the block is in the db
|
||||
block, err := suite.db.GetBlock(ctx, account1, account2)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(block)
|
||||
suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
|
||||
|
||||
// delete the block by targetAccountID
|
||||
err = suite.db.DeleteBlocksByTargetAccountID(ctx, account2)
|
||||
err = suite.db.DeleteAccountBlocks(ctx, account1)
|
||||
suite.NoError(err)
|
||||
|
||||
// block should be gone
|
||||
|
|
@ -244,7 +534,7 @@ func (suite *RelationshipTestSuite) TestGetRelationship() {
|
|||
func (suite *RelationshipTestSuite) TestIsFollowingYes() {
|
||||
requestingAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["admin_account"]
|
||||
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
|
||||
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.True(isFollowing)
|
||||
}
|
||||
|
|
@ -252,7 +542,7 @@ func (suite *RelationshipTestSuite) TestIsFollowingYes() {
|
|||
func (suite *RelationshipTestSuite) TestIsFollowingNo() {
|
||||
requestingAccount := suite.testAccounts["admin_account"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
|
||||
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.False(isFollowing)
|
||||
}
|
||||
|
|
@ -260,7 +550,7 @@ func (suite *RelationshipTestSuite) TestIsFollowingNo() {
|
|||
func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
|
||||
requestingAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["admin_account"]
|
||||
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
|
||||
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.True(isMutualFollowing)
|
||||
}
|
||||
|
|
@ -268,7 +558,7 @@ func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
|
|||
func (suite *RelationshipTestSuite) TestIsMutualFollowingNo() {
|
||||
requestingAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
|
||||
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.True(isMutualFollowing)
|
||||
}
|
||||
|
|
@ -306,7 +596,7 @@ func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() {
|
|||
suite.Equal(followRequest.URI, follow.URI)
|
||||
|
||||
// Ensure notification is deleted.
|
||||
notification, err := suite.db.GetNotification(ctx, followRequestNotification.ID)
|
||||
notification, err := suite.db.GetNotificationByID(ctx, followRequestNotification.ID)
|
||||
suite.ErrorIs(err, db.ErrNoEntries)
|
||||
suite.Nil(notification)
|
||||
}
|
||||
|
|
@ -389,7 +679,7 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() {
|
|||
TargetAccountID: targetAccount.ID,
|
||||
}
|
||||
|
||||
if err := suite.db.Put(ctx, followRequest); err != nil {
|
||||
if err := suite.db.PutFollowRequest(ctx, followRequest); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
|
|
@ -404,12 +694,11 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() {
|
|||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||
err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(rejectedFollowRequest)
|
||||
|
||||
// Ensure notification is deleted.
|
||||
notification, err := suite.db.GetNotification(ctx, followRequestNotification.ID)
|
||||
notification, err := suite.db.GetNotificationByID(ctx, followRequestNotification.ID)
|
||||
suite.ErrorIs(err, db.ErrNoEntries)
|
||||
suite.Nil(notification)
|
||||
}
|
||||
|
|
@ -419,9 +708,8 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestNotExisting() {
|
|||
account := suite.testAccounts["admin_account"]
|
||||
targetAccount := suite.testAccounts["local_account_2"]
|
||||
|
||||
rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||
err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
|
||||
suite.ErrorIs(err, db.ErrNoEntries)
|
||||
suite.Nil(rejectedFollowRequest)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {
|
||||
|
|
@ -440,42 +728,49 @@ func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {
|
|||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
followRequests, err := suite.db.GetFollowRequests(ctx, "", targetAccount.ID)
|
||||
followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.Len(followRequests, 1)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetAccountFollows() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
follows, err := suite.db.GetFollows(context.Background(), account.ID, "")
|
||||
follows, err := suite.db.GetAccountFollows(context.Background(), account.ID)
|
||||
suite.NoError(err)
|
||||
suite.Len(follows, 2)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestCountAccountFollows() {
|
||||
func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
followsCount, err := suite.db.CountFollows(context.Background(), account.ID, "")
|
||||
followsCount, err := suite.db.CountAccountLocalFollows(context.Background(), account.ID)
|
||||
suite.NoError(err)
|
||||
suite.Equal(2, followsCount)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetAccountFollowedBy() {
|
||||
func (suite *RelationshipTestSuite) TestCountAccountFollows() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
follows, err := suite.db.GetFollows(context.Background(), "", account.ID)
|
||||
followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID)
|
||||
suite.NoError(err)
|
||||
suite.Equal(2, followsCount)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetAccountFollowers() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID)
|
||||
suite.NoError(err)
|
||||
suite.Len(follows, 2)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestGetLocalFollowersIDs() {
|
||||
func (suite *RelationshipTestSuite) TestCountAccountFollowers() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
accountIDs, err := suite.db.GetLocalFollowersIDs(context.Background(), account.ID)
|
||||
followsCount, err := suite.db.CountAccountFollowers(context.Background(), account.ID)
|
||||
suite.NoError(err)
|
||||
suite.EqualValues([]string{"01F8MH5NBDF2MV7CTC4Q5128HF", "01F8MH17FWEB39HZJ76B6VXSKF"}, accountIDs)
|
||||
suite.Equal(2, followsCount)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestCountAccountFollowedBy() {
|
||||
func (suite *RelationshipTestSuite) TestCountAccountFollowersLocalOnly() {
|
||||
account := suite.testAccounts["local_account_1"]
|
||||
followsCount, err := suite.db.CountFollows(context.Background(), "", account.ID)
|
||||
followsCount, err := suite.db.CountAccountLocalFollowers(context.Background(), account.ID)
|
||||
suite.NoError(err)
|
||||
suite.Equal(2, followsCount)
|
||||
}
|
||||
|
|
@ -484,18 +779,25 @@ func (suite *RelationshipTestSuite) TestUnfollowExisting() {
|
|||
originAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccount := suite.testAccounts["admin_account"]
|
||||
|
||||
uri, err := suite.db.Unfollow(context.Background(), originAccount.ID, targetAccount.ID)
|
||||
follow, err := suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.Equal("http://localhost:8080/users/the_mighty_zork/follow/01F8PY8RHWRQZV038T4E8T9YK8", uri)
|
||||
suite.NotNil(follow)
|
||||
|
||||
err = suite.db.DeleteFollowByID(context.Background(), follow.ID)
|
||||
suite.NoError(err)
|
||||
|
||||
follow, err = suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID)
|
||||
suite.EqualError(err, db.ErrNoEntries.Error())
|
||||
suite.Nil(follow)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestUnfollowNotExisting() {
|
||||
originAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccountID := "01GTVD9N484CZ6AM90PGGNY7GQ"
|
||||
|
||||
uri, err := suite.db.Unfollow(context.Background(), originAccount.ID, targetAccountID)
|
||||
suite.NoError(err)
|
||||
suite.Empty(uri)
|
||||
follow, err := suite.db.GetFollow(context.Background(), originAccount.ID, targetAccountID)
|
||||
suite.EqualError(err, db.ErrNoEntries.Error())
|
||||
suite.Nil(follow)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestUnfollowRequestExisting() {
|
||||
|
|
@ -510,22 +812,29 @@ func (suite *RelationshipTestSuite) TestUnfollowRequestExisting() {
|
|||
TargetAccountID: targetAccount.ID,
|
||||
}
|
||||
|
||||
if err := suite.db.Put(ctx, followRequest); err != nil {
|
||||
if err := suite.db.PutFollowRequest(ctx, followRequest); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
uri, err := suite.db.UnfollowRequest(context.Background(), originAccount.ID, targetAccount.ID)
|
||||
followRequest, err := suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccount.ID)
|
||||
suite.NoError(err)
|
||||
suite.Equal("http://localhost:8080/weeeeeeeeeeeeeeeee", uri)
|
||||
suite.NotNil(followRequest)
|
||||
|
||||
err = suite.db.DeleteFollowRequestByID(context.Background(), followRequest.ID)
|
||||
suite.NoError(err)
|
||||
|
||||
followRequest, err = suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccount.ID)
|
||||
suite.EqualError(err, db.ErrNoEntries.Error())
|
||||
suite.Nil(followRequest)
|
||||
}
|
||||
|
||||
func (suite *RelationshipTestSuite) TestUnfollowRequestNotExisting() {
|
||||
originAccount := suite.testAccounts["local_account_1"]
|
||||
targetAccountID := "01GTVD9N484CZ6AM90PGGNY7GQ"
|
||||
|
||||
uri, err := suite.db.UnfollowRequest(context.Background(), originAccount.ID, targetAccountID)
|
||||
suite.NoError(err)
|
||||
suite.Empty(uri)
|
||||
followRequest, err := suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccountID)
|
||||
suite.EqualError(err, db.ErrNoEntries.Error())
|
||||
suite.Nil(followRequest)
|
||||
}
|
||||
|
||||
func TestRelationshipTestSuite(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
|
|
@ -41,7 +42,6 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
|
|||
return s.conn.
|
||||
NewSelect().
|
||||
Model(status).
|
||||
Relation("Attachments").
|
||||
Relation("Tags").
|
||||
Relation("CreatedWithApplication")
|
||||
}
|
||||
|
|
@ -102,81 +102,143 @@ func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*g
|
|||
status, err := s.state.Caches.GTS.Status().Load(lookup, func() (*gtsmodel.Status, error) {
|
||||
var status gtsmodel.Status
|
||||
|
||||
// Not cached! Perform database query
|
||||
// Not cached! Perform database query.
|
||||
if err := dbQuery(&status); err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
if status.InReplyToID != "" {
|
||||
// Also load in-reply-to status
|
||||
status.InReplyTo = new(gtsmodel.Status)
|
||||
err := s.conn.NewSelect().Model(status.InReplyTo).
|
||||
Where("? = ?", bun.Ident("status.id"), status.InReplyToID).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
}
|
||||
|
||||
if status.BoostOfID != "" {
|
||||
// Also load original boosted status
|
||||
status.BoostOf = new(gtsmodel.Status)
|
||||
err := s.conn.NewSelect().Model(status.BoostOf).
|
||||
Where("? = ?", bun.Ident("status.id"), status.BoostOfID).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &status, nil
|
||||
}, keyParts...)
|
||||
if err != nil {
|
||||
// error already processed
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set the status author account
|
||||
status.Account, err = s.state.DB.GetAccountByID(ctx, status.AccountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting status account: %w", err)
|
||||
if gtscontext.Barebones(ctx) {
|
||||
// no need to fully populate.
|
||||
return status, nil
|
||||
}
|
||||
|
||||
if id := status.BoostOfAccountID; id != "" {
|
||||
// Set boost of status' author account
|
||||
status.BoostOfAccount, err = s.state.DB.GetAccountByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting boosted status account: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if id := status.InReplyToAccountID; id != "" {
|
||||
// Set in-reply-to status' author account
|
||||
status.InReplyToAccount, err = s.state.DB.GetAccountByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting in reply to status account: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(status.EmojiIDs) > 0 {
|
||||
// Fetch status emojis
|
||||
status.Emojis, err = s.state.DB.GetEmojisByIDs(ctx, status.EmojiIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting status emojis: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(status.MentionIDs) > 0 {
|
||||
// Fetch status mentions
|
||||
status.Mentions, err = s.state.DB.GetMentions(ctx, status.MentionIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting status mentions: %w", err)
|
||||
}
|
||||
// Further populate the status fields where applicable.
|
||||
if err := s.PopulateStatus(ctx, status); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) error {
|
||||
var err error
|
||||
|
||||
if status.Account == nil {
|
||||
// Status author is not set, fetch from database.
|
||||
status.Account, err = s.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
status.AccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error populating status author: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if status.InReplyToID != "" && status.InReplyTo == nil {
|
||||
// Status parent is not set, fetch from database.
|
||||
status.InReplyTo, err = s.GetStatusByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
status.InReplyToID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error populating status parent: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if status.InReplyToID != "" {
|
||||
if status.InReplyTo == nil {
|
||||
// Status parent is not set, fetch from database.
|
||||
status.InReplyTo, err = s.GetStatusByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
status.InReplyToID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error populating status parent: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if status.InReplyToAccount == nil {
|
||||
// Status parent author is not set, fetch from database.
|
||||
status.InReplyToAccount, err = s.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
status.InReplyToAccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error populating status parent author: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if status.BoostOfID != "" {
|
||||
if status.BoostOf == nil {
|
||||
// Status boost is not set, fetch from database.
|
||||
status.BoostOf, err = s.GetStatusByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
status.BoostOfID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error populating status boost: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if status.BoostOfAccount == nil {
|
||||
// Status boost author is not set, fetch from database.
|
||||
status.BoostOfAccount, err = s.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
status.BoostOfAccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error populating status boost author: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !status.AttachmentsPopulated() {
|
||||
// Status attachments are out-of-date with IDs, repopulate.
|
||||
status.Attachments, err = s.state.DB.GetAttachmentsByIDs(
|
||||
ctx, // these are already barebones
|
||||
status.AttachmentIDs,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error populating status attachments: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: once we don't fetch using relations.
|
||||
// if !status.TagsPopulated() {
|
||||
// }
|
||||
|
||||
if !status.MentionsPopulated() {
|
||||
// Status mentions are out-of-date with IDs, repopulate.
|
||||
status.Mentions, err = s.state.DB.GetMentions(
|
||||
ctx, // leave fully populated for now
|
||||
status.MentionIDs,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error populating status mentions: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !status.EmojisPopulated() {
|
||||
// Status emojis are out-of-date with IDs, repopulate.
|
||||
status.Emojis, err = s.state.DB.GetEmojisByIDs(
|
||||
ctx, // these are already barebones
|
||||
status.EmojiIDs,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error populating status emojis: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
|
||||
err := s.state.Caches.GTS.Status().Store(status, func() error {
|
||||
// It is safe to run this database transaction within cache.Store
|
||||
|
|
@ -239,12 +301,16 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
|
|||
})
|
||||
})
|
||||
if err != nil {
|
||||
// already processed
|
||||
return err
|
||||
}
|
||||
|
||||
for _, id := range status.AttachmentIDs {
|
||||
// Clear updated media attachment IDs from cache
|
||||
// Invalidate media attachments from cache.
|
||||
//
|
||||
// NOTE: this is needed due to the way in which
|
||||
// we upload status attachments, and only after
|
||||
// update them with a known status ID. This is
|
||||
// not the case for header/avatar attachments.
|
||||
s.state.Caches.GTS.Media().Invalidate("ID", id)
|
||||
}
|
||||
|
||||
|
|
@ -322,14 +388,19 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
|
|||
return err
|
||||
}
|
||||
|
||||
// Invalidate status from database lookups.
|
||||
s.state.Caches.GTS.Status().Invalidate("ID", status.ID)
|
||||
|
||||
for _, id := range status.AttachmentIDs {
|
||||
// Clear updated media attachment IDs from cache
|
||||
// Invalidate media attachments from cache.
|
||||
//
|
||||
// NOTE: this is needed due to the way in which
|
||||
// we upload status attachments, and only after
|
||||
// update them with a known status ID. This is
|
||||
// not the case for header/avatar attachments.
|
||||
s.state.Caches.GTS.Media().Invalidate("ID", id)
|
||||
}
|
||||
|
||||
// Drop any old status value from cache by this ID
|
||||
s.state.Caches.GTS.Status().Invalidate("ID", status.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -367,8 +438,12 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
|
|||
return err
|
||||
}
|
||||
|
||||
// Drop any old value from cache by this ID
|
||||
// Invalidate status from database lookups.
|
||||
s.state.Caches.GTS.Status().Invalidate("ID", id)
|
||||
|
||||
// Invalidate status from all visibility lookups.
|
||||
s.state.Caches.Visibility.Invalidate("ItemID", id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
|
|
@ -34,29 +35,82 @@ type statusFaveDB struct {
|
|||
state *state.State
|
||||
}
|
||||
|
||||
func (s *statusFaveDB) GetStatusFave(ctx context.Context, id string) (*gtsmodel.StatusFave, db.Error) {
|
||||
fave := new(gtsmodel.StatusFave)
|
||||
func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, db.Error) {
|
||||
return s.getStatusFave(
|
||||
ctx,
|
||||
"AccountID.StatusID",
|
||||
func(fave *gtsmodel.StatusFave) error {
|
||||
return s.conn.
|
||||
NewSelect().
|
||||
Model(fave).
|
||||
Where("? = ?", bun.Ident("account_id"), accountID).
|
||||
Where("? = ?", bun.Ident("status_id"), statusID).
|
||||
Scan(ctx)
|
||||
},
|
||||
accountID,
|
||||
statusID,
|
||||
)
|
||||
}
|
||||
|
||||
err := s.conn.
|
||||
NewSelect().
|
||||
Model(fave).
|
||||
Where("? = ?", bun.Ident("status_fave.ID"), id).
|
||||
Scan(ctx)
|
||||
func (s *statusFaveDB) GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, db.Error) {
|
||||
return s.getStatusFave(
|
||||
ctx,
|
||||
"ID",
|
||||
func(fave *gtsmodel.StatusFave) error {
|
||||
return s.conn.
|
||||
NewSelect().
|
||||
Model(fave).
|
||||
Where("? = ?", bun.Ident("id"), id).
|
||||
Scan(ctx)
|
||||
},
|
||||
id,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery func(*gtsmodel.StatusFave) error, keyParts ...any) (*gtsmodel.StatusFave, error) {
|
||||
// Fetch status fave from database cache with loader callback
|
||||
fave, err := s.state.Caches.GTS.StatusFave().Load(lookup, func() (*gtsmodel.StatusFave, error) {
|
||||
var fave gtsmodel.StatusFave
|
||||
|
||||
// Not cached! Perform database query.
|
||||
if err := dbQuery(&fave); err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return &fave, nil
|
||||
}, keyParts...)
|
||||
if err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fave.Account, err = s.state.DB.GetAccountByID(ctx, fave.AccountID)
|
||||
if gtscontext.Barebones(ctx) {
|
||||
// no need to fully populate.
|
||||
return fave, nil
|
||||
}
|
||||
|
||||
// Fetch the status fave author account.
|
||||
fave.Account, err = s.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
fave.AccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting status fave account %q: %w", fave.AccountID, err)
|
||||
}
|
||||
|
||||
fave.TargetAccount, err = s.state.DB.GetAccountByID(ctx, fave.TargetAccountID)
|
||||
// Fetch the status fave target account.
|
||||
fave.TargetAccount, err = s.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
fave.TargetAccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting status fave target account %q: %w", fave.TargetAccountID, err)
|
||||
}
|
||||
|
||||
fave.Status, err = s.state.DB.GetStatusByID(ctx, fave.StatusID)
|
||||
// Fetch the status fave target status.
|
||||
fave.Status, err = s.state.DB.GetStatusByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
fave.StatusID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting status fave status %q: %w", fave.StatusID, err)
|
||||
}
|
||||
|
|
@ -64,38 +118,22 @@ func (s *statusFaveDB) GetStatusFave(ctx context.Context, id string) (*gtsmodel.
|
|||
return fave, nil
|
||||
}
|
||||
|
||||
func (s *statusFaveDB) GetStatusFaveByAccountID(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, db.Error) {
|
||||
var id string
|
||||
|
||||
err := s.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
|
||||
Column("status_fave.id").
|
||||
Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
|
||||
Where("? = ?", bun.Ident("status_fave.status_id"), statusID).
|
||||
Scan(ctx, &id)
|
||||
if err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return s.GetStatusFave(ctx, id)
|
||||
}
|
||||
|
||||
func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, db.Error) {
|
||||
func (s *statusFaveDB) GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, db.Error) {
|
||||
ids := []string{}
|
||||
|
||||
if err := s.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
|
||||
Column("status_fave.id").
|
||||
Where("? = ?", bun.Ident("status_fave.status_id"), statusID).
|
||||
Table("status_faves").
|
||||
Column("id").
|
||||
Where("? = ?", bun.Ident("status_id"), statusID).
|
||||
Scan(ctx, &ids); err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
faves := make([]*gtsmodel.StatusFave, 0, len(ids))
|
||||
|
||||
for _, id := range ids {
|
||||
fave, err := s.GetStatusFave(ctx, id)
|
||||
fave, err := s.GetStatusFaveByID(ctx, id)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting status fave %q: %v", id, err)
|
||||
continue
|
||||
|
|
@ -107,23 +145,27 @@ func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*
|
|||
return faves, nil
|
||||
}
|
||||
|
||||
func (s *statusFaveDB) PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) db.Error {
|
||||
_, err := s.conn.
|
||||
NewInsert().
|
||||
Model(statusFave).
|
||||
Exec(ctx)
|
||||
|
||||
return s.conn.ProcessError(err)
|
||||
func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) db.Error {
|
||||
return s.state.Caches.GTS.StatusFave().Store(fave, func() error {
|
||||
_, err := s.conn.
|
||||
NewInsert().
|
||||
Model(fave).
|
||||
Exec(ctx)
|
||||
return s.conn.ProcessError(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *statusFaveDB) DeleteStatusFave(ctx context.Context, id string) db.Error {
|
||||
_, err := s.conn.
|
||||
func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) db.Error {
|
||||
if _, err := s.conn.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
|
||||
Where("? = ?", bun.Ident("status_fave.id"), id).
|
||||
Exec(ctx)
|
||||
Table("status_faves").
|
||||
Where("? = ?", bun.Ident("id"), id).
|
||||
Exec(ctx); err != nil {
|
||||
return s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return s.conn.ProcessError(err)
|
||||
s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) db.Error {
|
||||
|
|
@ -131,42 +173,52 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st
|
|||
return errors.New("DeleteStatusFaves: one of targetAccountID or originAccountID must be set")
|
||||
}
|
||||
|
||||
// TODO: Capture fave IDs in a RETURNING
|
||||
// statement (when faves have a cache),
|
||||
// + use the IDs to invalidate cache entries.
|
||||
// Capture fave IDs in a RETURNING statement.
|
||||
var faveIDs []string
|
||||
|
||||
q := s.conn.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave"))
|
||||
Table("status_faves").
|
||||
Returning("?", bun.Ident("id"))
|
||||
|
||||
if targetAccountID != "" {
|
||||
q = q.Where("? = ?", bun.Ident("status_fave.target_account_id"), targetAccountID)
|
||||
q = q.Where("? = ?", bun.Ident("target_account_id"), targetAccountID)
|
||||
}
|
||||
|
||||
if originAccountID != "" {
|
||||
q = q.Where("? = ?", bun.Ident("status_fave.account_id"), originAccountID)
|
||||
q = q.Where("? = ?", bun.Ident("account_id"), originAccountID)
|
||||
}
|
||||
|
||||
if _, err := q.Exec(ctx); err != nil {
|
||||
if _, err := q.Exec(ctx, &faveIDs); err != nil {
|
||||
return s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
for _, id := range faveIDs {
|
||||
// Invalidate each of the returned status fave IDs.
|
||||
s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID string) db.Error {
|
||||
// TODO: Capture fave IDs in a RETURNING
|
||||
// statement (when faves have a cache),
|
||||
// + use the IDs to invalidate cache entries.
|
||||
// Capture fave IDs in a RETURNING statement.
|
||||
var faveIDs []string
|
||||
|
||||
q := s.conn.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
|
||||
Where("? = ?", bun.Ident("status_fave.status_id"), statusID)
|
||||
Table("status_faves").
|
||||
Where("? = ?", bun.Ident("status_id"), statusID).
|
||||
Returning("?", bun.Ident("id"))
|
||||
|
||||
if _, err := q.Exec(ctx); err != nil {
|
||||
if _, err := q.Exec(ctx, &faveIDs); err != nil {
|
||||
return s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
for _, id := range faveIDs {
|
||||
// Invalidate each of the returned status fave IDs.
|
||||
s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ type StatusFaveTestSuite struct {
|
|||
func (suite *StatusFaveTestSuite) TestGetStatusFaves() {
|
||||
testStatus := suite.testStatuses["admin_account_status_1"]
|
||||
|
||||
faves, err := suite.db.GetStatusFaves(context.Background(), testStatus.ID)
|
||||
faves, err := suite.db.GetStatusFavesForStatus(context.Background(), testStatus.ID)
|
||||
if err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
|
@ -51,7 +51,7 @@ func (suite *StatusFaveTestSuite) TestGetStatusFaves() {
|
|||
func (suite *StatusFaveTestSuite) TestGetStatusFavesNone() {
|
||||
testStatus := suite.testStatuses["admin_account_status_4"]
|
||||
|
||||
faves, err := suite.db.GetStatusFaves(context.Background(), testStatus.ID)
|
||||
faves, err := suite.db.GetStatusFavesForStatus(context.Background(), testStatus.ID)
|
||||
if err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
|
@ -63,7 +63,7 @@ func (suite *StatusFaveTestSuite) TestGetStatusFaveByAccountID() {
|
|||
testAccount := suite.testAccounts["local_account_1"]
|
||||
testStatus := suite.testStatuses["admin_account_status_1"]
|
||||
|
||||
fave, err := suite.db.GetStatusFaveByAccountID(context.Background(), testAccount.ID, testStatus.ID)
|
||||
fave, err := suite.db.GetStatusFave(context.Background(), testAccount.ID, testStatus.ID)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(fave)
|
||||
}
|
||||
|
|
@ -129,17 +129,17 @@ func (suite *StatusFaveTestSuite) TestDeleteStatusFave() {
|
|||
testFave := suite.testFaves["local_account_1_admin_account_status_1"]
|
||||
ctx := context.Background()
|
||||
|
||||
if err := suite.db.DeleteStatusFave(ctx, testFave.ID); err != nil {
|
||||
if err := suite.db.DeleteStatusFaveByID(ctx, testFave.ID); err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
fave, err := suite.db.GetStatusFave(ctx, testFave.ID)
|
||||
fave, err := suite.db.GetStatusFaveByID(ctx, testFave.ID)
|
||||
suite.ErrorIs(err, db.ErrNoEntries)
|
||||
suite.Nil(fave)
|
||||
}
|
||||
|
||||
func (suite *StatusFaveTestSuite) TestDeleteStatusFaveNonExisting() {
|
||||
err := suite.db.DeleteStatusFave(context.Background(), "01GVAV715K6Y2SG9ZKS9ZA8G7G")
|
||||
err := suite.db.DeleteStatusFaveByID(context.Background(), "01GVAV715K6Y2SG9ZKS9ZA8G7G")
|
||||
suite.NoError(err)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -61,9 +61,12 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
|
|||
Order("status.id DESC")
|
||||
|
||||
if maxID == "" {
|
||||
const future = 24 * time.Hour
|
||||
|
||||
var err error
|
||||
// don't return statuses more than five minutes in the future
|
||||
maxID, err = id.NewULIDFromTime(time.Now().Add(5 * time.Minute))
|
||||
|
||||
// don't return statuses more than 24hr in the future
|
||||
maxID, err = id.NewULIDFromTime(time.Now().Add(future))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -138,15 +141,16 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI
|
|||
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
|
||||
Column("status.id").
|
||||
Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_id")).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")).
|
||||
Order("status.id DESC")
|
||||
|
||||
if maxID == "" {
|
||||
const future = 24 * time.Hour
|
||||
|
||||
var err error
|
||||
// don't return statuses more than five minutes in the future
|
||||
maxID, err = id.NewULIDFromTime(time.Now().Add(5 * time.Minute))
|
||||
|
||||
// don't return statuses more than 24hr in the future
|
||||
maxID, err = id.NewULIDFromTime(time.Now().Add(future))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -34,15 +34,32 @@ type TimelineTestSuite struct {
|
|||
}
|
||||
|
||||
func (suite *TimelineTestSuite) TestGetPublicTimeline() {
|
||||
ctx := context.Background()
|
||||
var count int
|
||||
|
||||
for _, status := range suite.testStatuses {
|
||||
if status.Visibility == gtsmodel.VisibilityPublic &&
|
||||
status.BoostOfID == "" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
|
||||
suite.NoError(err)
|
||||
|
||||
suite.Len(s, 6)
|
||||
suite.Len(s, count)
|
||||
}
|
||||
|
||||
func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
|
||||
var count int
|
||||
|
||||
for _, status := range suite.testStatuses {
|
||||
if status.Visibility == gtsmodel.VisibilityPublic &&
|
||||
status.BoostOfID == "" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
futureStatus := getFutureStatus()
|
||||
|
|
@ -53,7 +70,7 @@ func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
|
|||
suite.NoError(err)
|
||||
|
||||
suite.NotContains(s, futureStatus)
|
||||
suite.Len(s, 6)
|
||||
suite.Len(s, count)
|
||||
}
|
||||
|
||||
func (suite *TimelineTestSuite) TestGetHomeTimeline() {
|
||||
|
|
|
|||
|
|
@ -29,6 +29,9 @@ type Media interface {
|
|||
// GetAttachmentByID gets a single attachment by its ID.
|
||||
GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, Error)
|
||||
|
||||
// GetAttachmentsByIDs fetches a list of media attachments for given IDs.
|
||||
GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error)
|
||||
|
||||
// PutAttachment inserts the given attachment into the database.
|
||||
PutAttachment(ctx context.Context, media *gtsmodel.MediaAttachment) error
|
||||
|
||||
|
|
|
|||
|
|
@ -30,4 +30,10 @@ type Mention interface {
|
|||
|
||||
// GetMentions gets multiple mentions.
|
||||
GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, Error)
|
||||
|
||||
// PutMention will insert the given mention into the database.
|
||||
PutMention(ctx context.Context, mention *gtsmodel.Mention) error
|
||||
|
||||
// DeleteMentionByID will delete mention with given ID from the database.
|
||||
DeleteMentionByID(ctx context.Context, id string) error
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,14 +28,17 @@ type Notification interface {
|
|||
// GetNotifications returns a slice of notifications that pertain to the given accountID.
|
||||
//
|
||||
// Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest).
|
||||
GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
|
||||
GetAccountNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
|
||||
|
||||
// GetNotification returns one notification according to its id.
|
||||
GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, Error)
|
||||
GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, Error)
|
||||
|
||||
// DeleteNotification deletes one notification according to its id,
|
||||
// PutNotification will insert the given notification into the database.
|
||||
PutNotification(ctx context.Context, notif *gtsmodel.Notification) error
|
||||
|
||||
// DeleteNotificationByID deletes one notification according to its id,
|
||||
// and removes that notification from the in-memory cache.
|
||||
DeleteNotification(ctx context.Context, id string) Error
|
||||
DeleteNotificationByID(ctx context.Context, id string) Error
|
||||
|
||||
// DeleteNotifications mass deletes notifications targeting targetAccountID
|
||||
// and/or originating from originAccountID.
|
||||
|
|
@ -50,7 +53,7 @@ type Notification interface {
|
|||
// originate from originAccountID will be deleted.
|
||||
//
|
||||
// At least one parameter must not be an empty string.
|
||||
DeleteNotifications(ctx context.Context, targetAccountID string, originAccountID string) Error
|
||||
DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) Error
|
||||
|
||||
// DeleteNotificationsForStatus deletes all notifications that relate to
|
||||
// the given statusID. This function is useful when a status has been deleted,
|
||||
|
|
|
|||
|
|
@ -25,42 +25,86 @@ import (
|
|||
|
||||
// Relationship contains functions for getting or modifying the relationship between two accounts.
|
||||
type Relationship interface {
|
||||
// IsBlocked checks whether account 1 has a block in place against account2.
|
||||
// If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1.
|
||||
IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, Error)
|
||||
// IsBlocked checks whether source account has a block in place against target.
|
||||
IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
|
||||
|
||||
// IsEitherBlocked checks whether there is a block in place between either of account1 and account2.
|
||||
IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error)
|
||||
|
||||
// GetBlockByID fetches block with given ID from the database.
|
||||
GetBlockByID(ctx context.Context, id string) (*gtsmodel.Block, error)
|
||||
|
||||
// GetBlockByURI fetches block with given AP URI from the database.
|
||||
GetBlockByURI(ctx context.Context, uri string) (*gtsmodel.Block, error)
|
||||
|
||||
// GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't.
|
||||
//
|
||||
// Because this is slower than Blocked, only use it if you need the actual Block struct for some reason,
|
||||
// not if you're just checking for the existence of a block.
|
||||
GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, Error)
|
||||
GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, error)
|
||||
|
||||
// PutBlock attempts to place the given account block in the database.
|
||||
PutBlock(ctx context.Context, block *gtsmodel.Block) Error
|
||||
PutBlock(ctx context.Context, block *gtsmodel.Block) error
|
||||
|
||||
// DeleteBlockByID removes block with given ID from the database.
|
||||
DeleteBlockByID(ctx context.Context, id string) Error
|
||||
DeleteBlockByID(ctx context.Context, id string) error
|
||||
|
||||
// DeleteBlockByURI removes block with given AP URI from the database.
|
||||
DeleteBlockByURI(ctx context.Context, uri string) Error
|
||||
DeleteBlockByURI(ctx context.Context, uri string) error
|
||||
|
||||
// DeleteBlocksByOriginAccountID removes any blocks with accountID equal to originAccountID.
|
||||
DeleteBlocksByOriginAccountID(ctx context.Context, originAccountID string) Error
|
||||
|
||||
// DeleteBlocksByTargetAccountID removes any blocks with given targetAccountID.
|
||||
DeleteBlocksByTargetAccountID(ctx context.Context, targetAccountID string) Error
|
||||
// DeleteAccountBlocks will delete all database blocks to / from the given account ID.
|
||||
DeleteAccountBlocks(ctx context.Context, accountID string) error
|
||||
|
||||
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
|
||||
GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
|
||||
|
||||
// IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
|
||||
IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
|
||||
// GetFollowByID fetches follow with given ID from the database.
|
||||
GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error)
|
||||
|
||||
// IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
|
||||
IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
|
||||
// GetFollowByURI fetches follow with given AP URI from the database.
|
||||
GetFollowByURI(ctx context.Context, uri string) (*gtsmodel.Follow, error)
|
||||
|
||||
// GetFollow retrieves a follow if it exists between source and target accounts.
|
||||
GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error)
|
||||
|
||||
// GetFollowRequestByID fetches follow request with given ID from the database.
|
||||
GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error)
|
||||
|
||||
// GetFollowRequestByURI fetches follow request with given AP URI from the database.
|
||||
GetFollowRequestByURI(ctx context.Context, uri string) (*gtsmodel.FollowRequest, error)
|
||||
|
||||
// GetFollowRequest retrieves a follow request if it exists between source and target accounts.
|
||||
GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error)
|
||||
|
||||
// IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
|
||||
IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
|
||||
|
||||
// IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
|
||||
IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
|
||||
IsMutualFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
|
||||
|
||||
// IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
|
||||
IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
|
||||
|
||||
// PutFollow attempts to place the given account follow in the database.
|
||||
PutFollow(ctx context.Context, follow *gtsmodel.Follow) error
|
||||
|
||||
// PutFollowRequest attempts to place the given account follow request in the database.
|
||||
PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error
|
||||
|
||||
// DeleteFollowByID deletes a follow from the database with the given ID.
|
||||
DeleteFollowByID(ctx context.Context, id string) error
|
||||
|
||||
// DeleteFollowByURI deletes a follow from the database with the given URI.
|
||||
DeleteFollowByURI(ctx context.Context, uri string) error
|
||||
|
||||
// DeleteFollowRequestByID deletes a follow request from the database with the given ID.
|
||||
DeleteFollowRequestByID(ctx context.Context, id string) error
|
||||
|
||||
// DeleteFollowRequestByURI deletes a follow request from the database with the given URI.
|
||||
DeleteFollowRequestByURI(ctx context.Context, uri string) error
|
||||
|
||||
// DeleteAccountFollows will delete all database follows to / from the given account ID.
|
||||
DeleteAccountFollows(ctx context.Context, accountID string) error
|
||||
|
||||
// DeleteAccountFollowRequests will delete all database follow requests to / from the given account ID.
|
||||
DeleteAccountFollowRequests(ctx context.Context, accountID string) error
|
||||
|
||||
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
|
||||
// In other words, it should create the follow, and delete the existing follow request.
|
||||
|
|
@ -69,65 +113,41 @@ type Relationship interface {
|
|||
AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
|
||||
|
||||
// RejectFollowRequest fetches a follow request from the database, and then deletes it.
|
||||
//
|
||||
// The deleted follow request will be returned so that further processing can be done on it.
|
||||
RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, Error)
|
||||
RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) Error
|
||||
|
||||
// GetFollows returns a slice of follows owned by the given accountID, and/or
|
||||
// targeting the given account id.
|
||||
//
|
||||
// If accountID is set and targetAccountID isn't, then all follows created by
|
||||
// accountID will be returned.
|
||||
//
|
||||
// If targetAccountID is set and accountID isn't, then all follows targeting
|
||||
// targetAccountID will be returned.
|
||||
//
|
||||
// If both accountID and targetAccountID are set, then only 0 or 1 follows will
|
||||
// be in the returned slice.
|
||||
GetFollows(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.Follow, Error)
|
||||
// GetAccountFollows returns a slice of follows owned by the given accountID.
|
||||
GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
|
||||
|
||||
// GetLocalFollowersIDs returns a list of local account IDs which follow the
|
||||
// targetAccountID. The returned IDs are not guaranteed to be ordered in any
|
||||
// particular way, so take care.
|
||||
GetLocalFollowersIDs(ctx context.Context, targetAccountID string) ([]string, Error)
|
||||
// GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance.
|
||||
GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
|
||||
|
||||
// CountFollows is like GetFollows, but just counts rather than returning.
|
||||
CountFollows(ctx context.Context, accountID string, targetAccountID string) (int, Error)
|
||||
// CountAccountFollows returns the amount of accounts that the given accountID is following.
|
||||
CountAccountFollows(ctx context.Context, accountID string) (int, error)
|
||||
|
||||
// GetFollowRequests returns a slice of follows requests owned by the given
|
||||
// accountID, and/or targeting the given account id.
|
||||
//
|
||||
// If accountID is set and targetAccountID isn't, then all requests created by
|
||||
// accountID will be returned.
|
||||
//
|
||||
// If targetAccountID is set and accountID isn't, then all requests targeting
|
||||
// targetAccountID will be returned.
|
||||
//
|
||||
// If both accountID and targetAccountID are set, then only 0 or 1 requests will
|
||||
// be in the returned slice.
|
||||
GetFollowRequests(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.FollowRequest, Error)
|
||||
// CountAccountLocalFollows returns the amount of accounts that the given accountID is following, only including follows from this instance.
|
||||
CountAccountLocalFollows(ctx context.Context, accountID string) (int, error)
|
||||
|
||||
// CountFollowRequests is like GetFollowRequests, but just counts rather than returning.
|
||||
CountFollowRequests(ctx context.Context, accountID string, targetAccountID string) (int, Error)
|
||||
// GetAccountFollowers fetches follows that target given accountID.
|
||||
GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
|
||||
|
||||
// Unfollow removes a follow targeting targetAccountID and originating
|
||||
// from originAccountID.
|
||||
//
|
||||
// If a follow was removed this way, the AP URI of the follow will be
|
||||
// returned to the caller, so that further processing can take place
|
||||
// if necessary.
|
||||
//
|
||||
// If no follow was removed this way, the returned string will be empty.
|
||||
Unfollow(ctx context.Context, originAccountID string, targetAccountID string) (string, Error)
|
||||
// GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance.
|
||||
GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
|
||||
|
||||
// UnfollowRequest removes a follow request targeting targetAccountID
|
||||
// and originating from originAccountID.
|
||||
//
|
||||
// If a follow request was removed this way, the AP URI of the follow
|
||||
// request will be returned to the caller, so that further processing
|
||||
// can take place if necessary.
|
||||
//
|
||||
// If no follow request was removed this way, the returned string will
|
||||
// be empty.
|
||||
UnfollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (string, Error)
|
||||
// CountAccountFollowers returns the amounts that the given ID is followed by.
|
||||
CountAccountFollowers(ctx context.Context, accountID string) (int, error)
|
||||
|
||||
// CountAccountLocalFollowers returns the amounts that the given ID is followed by, only including follows from this instance.
|
||||
CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error)
|
||||
|
||||
// GetAccountFollowRequests returns all follow requests targeting the given account.
|
||||
GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error)
|
||||
|
||||
// GetAccountFollowRequesting returns all follow requests originating from the given account.
|
||||
GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error)
|
||||
|
||||
// CountAccountFollowRequests returns number of follow requests targeting the given account.
|
||||
CountAccountFollowRequests(ctx context.Context, accountID string) (int, error)
|
||||
|
||||
// CountAccountFollowerRequests returns number of follow requests originating from the given account.
|
||||
CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,6 +37,9 @@ type Status interface {
|
|||
// GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs
|
||||
GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error)
|
||||
|
||||
// PopulateStatus ensures that all sub-models of a status are populated (e.g. mentions, attachments, etc).
|
||||
PopulateStatus(ctx context.Context, status *gtsmodel.Status) error
|
||||
|
||||
// PutStatus stores one status in the database.
|
||||
PutStatus(ctx context.Context, status *gtsmodel.Status) Error
|
||||
|
||||
|
|
|
|||
|
|
@ -24,22 +24,22 @@ import (
|
|||
)
|
||||
|
||||
type StatusFave interface {
|
||||
// GetStatusFave returns one status fave with the given id.
|
||||
GetStatusFave(ctx context.Context, id string) (*gtsmodel.StatusFave, Error)
|
||||
|
||||
// GetStatusFaveByAccountID gets one status fave created by the given
|
||||
// accountID, targeting the given statusID.
|
||||
GetStatusFaveByAccountID(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, Error)
|
||||
GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, Error)
|
||||
|
||||
// GetStatusFave returns one status fave with the given id.
|
||||
GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, Error)
|
||||
|
||||
// GetStatusFaves returns a slice of faves/likes of the given status.
|
||||
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
|
||||
GetStatusFaves(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, Error)
|
||||
GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, Error)
|
||||
|
||||
// PutStatusFave inserts the given statusFave into the database.
|
||||
PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) Error
|
||||
|
||||
// DeleteStatusFave deletes one status fave with the given id.
|
||||
DeleteStatusFave(ctx context.Context, id string) Error
|
||||
DeleteStatusFaveByID(ctx context.Context, id string) Error
|
||||
|
||||
// DeleteStatusFaves mass deletes status faves targeting targetAccountID
|
||||
// and/or originating from originAccountID and/or faving statusID.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue