[performance] refactoring + add fave / follow / request / visibility caching (#1607)

* refactor visibility checking, add caching for visibility

* invalidate visibility cache items on account / status deletes

* fix requester ID passed to visibility cache nil ptr

* de-interface caches, fix home / public timeline caching + visibility

* finish adding code comments for visibility filter

* fix angry goconst linter warnings

* actually finish adding filter visibility code comments for timeline functions

* move home timeline status author check to after visibility

* remove now-unused code

* add more code comments

* add TODO code comment, update printed cache start names

* update printed cache names on stop

* start adding separate follow(request) delete db functions, add specific visibility cache tests

* add relationship type caching

* fix getting local account follows / followed-bys, other small codebase improvements

* simplify invalidation using cache hooks, add more GetAccountBy___() functions

* fix boosting to return 404 if not boostable but no error (to not leak status ID)

* remove dead code

* improved placement of cache invalidation

* update license headers

* add example follow, follow-request config entries

* add example visibility cache configuration to config file

* use specific PutFollowRequest() instead of just Put()

* add tests for all GetAccountBy()

* add GetBlockBy() tests

* update block to check primitive fields

* update and finish adding Get{Account,Block,Follow,FollowRequest}By() tests

* fix copy-pasted code

* update envparsing test

* whitespace

* fix bun struct tag

* add license header to gtscontext

* fix old license header

* improved error creation to not use fmt.Errorf() when not needed

* fix various rebase conflicts, fix account test

* remove commented-out code, fix-up mention caching

* fix mention select bun statement

* ensure mention target account populated, pass in context to customrenderer logging

* remove more uncommented code, fix typeutil test

* add statusfave database model caching

* add status fave cache configuration

* add status fave cache example config

* woops, catch missed error. nice catch linter!

* add back testrig panic on nil db

* update example configuration to match defaults, slight tweak to cache configuration defaults

* update envparsing test with new defaults

* fetch followingget to use the follow target account

* use accounnt.IsLocal() instead of empty domain check

* use constants for the cache visibility type check

* use bun.In() for notification type restriction in db query

* include replies when fetching PublicTimeline() (to account for single-author threads in Visibility{}.StatusPublicTimelineable())

* use bun query building for nested select statements to ensure working with postgres

* update public timeline future status checks to match visibility filter

* same as previous, for home timeline

* update public timeline tests to dynamically check for appropriate statuses

* migrate accounts to allow unique constraint on public_key

* provide minimal account with publicKey

---------

Signed-off-by: kim <grufwub@gmail.com>
Co-authored-by: tsmethurst <tobi.smethurst@protonmail.com>
This commit is contained in:
kim 2023-03-28 14:03:14 +01:00 committed by GitHub
commit de6e3e5f2a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
100 changed files with 4423 additions and 2367 deletions

View file

@ -41,6 +41,21 @@ type Account interface {
// GetAccountByPubkeyID returns one account with the given public key URI (ID), or an error if something goes wrong.
GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, Error)
// GetAccountByInboxURI returns one account with the given inbox_uri, or an error if something goes wrong.
GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// GetAccountByOutboxURI returns one account with the given outbox_uri, or an error if something goes wrong.
GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// GetAccountByFollowingURI returns one account with the given following_uri, or an error if something goes wrong.
GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// GetAccountByFollowersURI returns one account with the given followers_uri, or an error if something goes wrong.
GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// PopulateAccount ensures that all sub-models of an account are populated (e.g. avatar, header etc).
PopulateAccount(ctx context.Context, account *gtsmodel.Account) error
// PutAccount puts one account in the database.
PutAccount(ctx context.Context, account *gtsmodel.Account) Error

View file

@ -20,11 +20,13 @@ package bundb
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@ -37,18 +39,15 @@ type accountDB struct {
state *state.State
}
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
return a.conn.
NewSelect().
Model(account)
}
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
"ID",
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx)
return a.conn.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.id"), id).
Scan(ctx)
},
id,
)
@ -59,7 +58,10 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
ctx,
"URI",
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx)
return a.conn.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.uri"), uri).
Scan(ctx)
},
uri,
)
@ -70,7 +72,10 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
ctx,
"URL",
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx)
return a.conn.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.url"), url).
Scan(ctx)
},
url,
)
@ -81,7 +86,8 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
ctx,
"Username.Domain",
func(account *gtsmodel.Account) error {
q := a.newAccountQ(account)
q := a.conn.NewSelect().
Model(account)
if domain != "" {
q = q.
@ -105,12 +111,71 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo
ctx,
"PublicKeyURI",
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx)
return a.conn.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.public_key_uri"), id).
Scan(ctx)
},
id,
)
}
func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
"InboxURI",
func(account *gtsmodel.Account) error {
return a.conn.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.inbox_uri"), uri).
Scan(ctx)
},
uri,
)
}
func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
"OutboxURI",
func(account *gtsmodel.Account) error {
return a.conn.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.outbox_uri"), uri).
Scan(ctx)
},
uri,
)
}
func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
"FollowersURI",
func(account *gtsmodel.Account) error {
return a.conn.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.followers_uri"), uri).
Scan(ctx)
},
uri,
)
}
func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
"FollowingURI",
func(account *gtsmodel.Account) error {
return a.conn.NewSelect().
Model(account).
Where("? = ?", bun.Ident("account.following_uri"), uri).
Scan(ctx)
},
uri,
)
}
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
var username string
@ -141,33 +206,58 @@ func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(
return nil, err
}
if account.AvatarMediaAttachmentID != "" {
// Set the account's related avatar
account.AvatarMediaAttachment, err = a.state.DB.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID)
if err != nil {
log.Errorf(ctx, "error getting account %s avatar: %v", account.ID, err)
}
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return account, nil
}
if account.HeaderMediaAttachmentID != "" {
// Set the account's related header
account.HeaderMediaAttachment, err = a.state.DB.GetAttachmentByID(ctx, account.HeaderMediaAttachmentID)
if err != nil {
log.Errorf(ctx, "error getting account %s header: %v", account.ID, err)
}
}
if len(account.EmojiIDs) > 0 {
// Set the account's related emojis
account.Emojis, err = a.state.DB.GetEmojisByIDs(ctx, account.EmojiIDs)
if err != nil {
log.Errorf(ctx, "error getting account %s emojis: %v", account.ID, err)
}
// Further populate the account fields where applicable.
if err := a.PopulateAccount(ctx, account); err != nil {
return nil, err
}
return account, nil
}
func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Account) error {
var err error
if account.AvatarMediaAttachment == nil && account.AvatarMediaAttachmentID != "" {
// Account avatar attachment is not set, fetch from database.
account.AvatarMediaAttachment, err = a.state.DB.GetAttachmentByID(
ctx, // these are already barebones
account.AvatarMediaAttachmentID,
)
if err != nil {
return fmt.Errorf("error populating account avatar: %w", err)
}
}
if account.HeaderMediaAttachment == nil && account.HeaderMediaAttachmentID != "" {
// Account header attachment is not set, fetch from database.
account.HeaderMediaAttachment, err = a.state.DB.GetAttachmentByID(
ctx, // these are already barebones
account.HeaderMediaAttachmentID,
)
if err != nil {
return fmt.Errorf("error populating account header: %w", err)
}
}
if !account.EmojisPopulated() {
// Account emojis are out-of-date with IDs, repopulate.
account.Emojis, err = a.state.DB.GetEmojisByIDs(
ctx, // these are already barebones
account.EmojiIDs,
)
if err != nil {
return fmt.Errorf("error populating account emojis: %w", err)
}
}
return nil
}
func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error {
return a.state.Caches.GTS.Account().Store(account, func() error {
// It is safe to run this database transaction within cache.Store
@ -198,7 +288,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
columns = append(columns, "updated_at")
}
return a.state.Caches.GTS.Account().Store(account, func() error {
err := a.state.Caches.GTS.Account().Store(account, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
@ -234,6 +324,11 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
return err
})
})
if err != nil {
return err
}
return nil
}
func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
@ -258,7 +353,9 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
return err
}
// Invalidate account from database lookups.
a.state.Caches.GTS.Account().Invalidate("ID", id)
return nil
}

View file

@ -21,6 +21,8 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
"errors"
"reflect"
"strings"
"testing"
"time"
@ -61,44 +63,149 @@ func (suite *AccountTestSuite) TestGetAccountStatusesMediaOnly() {
suite.Len(statuses, 1)
}
func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID)
if err != nil {
suite.FailNow(err.Error())
func (suite *AccountTestSuite) TestGetAccountBy() {
t := suite.T()
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
// Sentinel error to mark avoiding a test case.
sentinelErr := errors.New("sentinel")
// isEqual checks if 2 account models are equal.
isEqual := func(a1, a2 gtsmodel.Account) bool {
// Clear populated sub-models.
a1.HeaderMediaAttachment = nil
a2.HeaderMediaAttachment = nil
a1.AvatarMediaAttachment = nil
a2.AvatarMediaAttachment = nil
a1.Emojis = nil
a2.Emojis = nil
// Clear database-set fields.
a1.CreatedAt = time.Time{}
a2.CreatedAt = time.Time{}
a1.UpdatedAt = time.Time{}
a2.UpdatedAt = time.Time{}
// Manually compare keys.
pk1 := a1.PublicKey
pv1 := a1.PrivateKey
pk2 := a2.PublicKey
pv2 := a2.PrivateKey
a1.PublicKey = nil
a1.PrivateKey = nil
a2.PublicKey = nil
a2.PrivateKey = nil
return reflect.DeepEqual(a1, a2) &&
((pk1 == nil && pk2 == nil) || pk1.Equal(pk2)) &&
((pv1 == nil && pv2 == nil) || pv1.Equal(pv2))
}
suite.NotNil(account)
suite.NotNil(account.AvatarMediaAttachment)
suite.NotEmpty(account.AvatarMediaAttachment.URL)
suite.NotNil(account.HeaderMediaAttachment)
suite.NotEmpty(account.HeaderMediaAttachment.URL)
}
func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() {
testAccount1 := suite.testAccounts["local_account_1"]
account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount1.Username, testAccount1.Domain)
suite.NoError(err)
suite.NotNil(account1)
for _, account := range suite.testAccounts {
for lookup, dbfunc := range map[string]func() (*gtsmodel.Account, error){
"id": func() (*gtsmodel.Account, error) {
return suite.db.GetAccountByID(ctx, account.ID)
},
testAccount2 := suite.testAccounts["remote_account_1"]
account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount2.Username, testAccount2.Domain)
suite.NoError(err)
suite.NotNil(account2)
}
"uri": func() (*gtsmodel.Account, error) {
return suite.db.GetAccountByURI(ctx, account.URI)
},
func (suite *AccountTestSuite) TestGetAccountByUsernameDomainMixedCase() {
testAccount := suite.testAccounts["remote_account_2"]
"url": func() (*gtsmodel.Account, error) {
if account.URL == "" {
return nil, sentinelErr
}
return suite.db.GetAccountByURL(ctx, account.URL)
},
account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount.Username, testAccount.Domain)
suite.NoError(err)
suite.NotNil(account1)
"username@domain": func() (*gtsmodel.Account, error) {
return suite.db.GetAccountByUsernameDomain(ctx, account.Username, account.Domain)
},
account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), strings.ToUpper(testAccount.Username), testAccount.Domain)
suite.NoError(err)
suite.NotNil(account2)
"username_upper@domain": func() (*gtsmodel.Account, error) {
return suite.db.GetAccountByUsernameDomain(ctx, strings.ToUpper(account.Username), account.Domain)
},
account3, err := suite.db.GetAccountByUsernameDomain(context.Background(), strings.ToLower(testAccount.Username), testAccount.Domain)
suite.NoError(err)
suite.NotNil(account3)
"username_lower@domain": func() (*gtsmodel.Account, error) {
return suite.db.GetAccountByUsernameDomain(ctx, strings.ToLower(account.Username), account.Domain)
},
"public_key_uri": func() (*gtsmodel.Account, error) {
if account.PublicKeyURI == "" {
return nil, sentinelErr
}
return suite.db.GetAccountByPubkeyID(ctx, account.PublicKeyURI)
},
"inbox_uri": func() (*gtsmodel.Account, error) {
if account.InboxURI == "" {
return nil, sentinelErr
}
return suite.db.GetAccountByInboxURI(ctx, account.InboxURI)
},
"outbox_uri": func() (*gtsmodel.Account, error) {
if account.OutboxURI == "" {
return nil, sentinelErr
}
return suite.db.GetAccountByOutboxURI(ctx, account.OutboxURI)
},
"following_uri": func() (*gtsmodel.Account, error) {
if account.FollowingURI == "" {
return nil, sentinelErr
}
return suite.db.GetAccountByFollowingURI(ctx, account.FollowingURI)
},
"followers_uri": func() (*gtsmodel.Account, error) {
if account.FollowersURI == "" {
return nil, sentinelErr
}
return suite.db.GetAccountByFollowersURI(ctx, account.FollowersURI)
},
} {
// Clear database caches.
suite.state.Caches.Init()
t.Logf("checking database lookup %q", lookup)
// Perform database function.
checkAcc, err := dbfunc()
if err != nil {
if err == sentinelErr {
continue
}
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
continue
}
// Check received account data.
if !isEqual(*checkAcc, *account) {
t.Errorf("account does not contain expected data: %+v", checkAcc)
continue
}
// Check that avatar attachment populated.
if account.AvatarMediaAttachmentID != "" &&
(checkAcc.AvatarMediaAttachment == nil || checkAcc.AvatarMediaAttachment.ID != account.AvatarMediaAttachmentID) {
t.Errorf("account avatar media attachment not correctly populated for: %+v", account)
continue
}
// Check that header attachment populated.
if account.HeaderMediaAttachmentID != "" &&
(checkAcc.HeaderMediaAttachment == nil || checkAcc.HeaderMediaAttachment.ID != account.HeaderMediaAttachmentID) {
t.Errorf("account header media attachment not correctly populated for: %+v", account)
continue
}
}
}
}
func (suite *AccountTestSuite) TestUpdateAccount() {

View file

@ -19,6 +19,8 @@ package bundb_test
import (
"context"
"crypto/rand"
"crypto/rsa"
"testing"
"time"
@ -40,6 +42,12 @@ func (suite *BasicTestSuite) TestGetAccountByID() {
}
func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
suite.FailNow(err.Error())
}
// Create an account that only just matches constraints.
testAccount := &gtsmodel.Account{
ID: "01GADR1AH9VCKH8YYCM86XSZ00",
Username: "test",
@ -49,6 +57,7 @@ func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() {
OutboxURI: "https://example.org/users/test/outbox",
ActorType: "Person",
PublicKeyURI: "https://example.org/test#main-key",
PublicKey: &key.PublicKey,
}
if err := suite.db.Put(context.Background(), testAccount); err != nil {
@ -99,7 +108,7 @@ func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() {
suite.Empty(a.FeaturedCollectionURI)
suite.Equal(testAccount.ActorType, a.ActorType)
suite.Nil(a.PrivateKey)
suite.Nil(a.PublicKey)
suite.EqualValues(key.PublicKey, *a.PublicKey)
suite.Equal(testAccount.PublicKeyURI, a.PublicKeyURI)
suite.Zero(a.SensitizedAt)
suite.Zero(a.SilencedAt)

View file

@ -47,6 +47,24 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
)
}
func (m *mediaDB) GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error) {
attachments := make([]*gtsmodel.MediaAttachment, 0, len(ids))
for _, id := range ids {
// Attempt fetch from DB
attachment, err := m.GetAttachmentByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting attachment %q: %v", id, err)
continue
}
// Append attachment
attachments = append(attachments, attachment)
}
return attachments, nil
}
func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, db.Error) {
return m.state.Caches.GTS.Media().Load(lookup, func() (*gtsmodel.MediaAttachment, error) {
var attachment gtsmodel.MediaAttachment
@ -118,7 +136,7 @@ func (m *mediaDB) GetRemoteOlderThan(ctx context.Context, olderThan time.Time, l
return nil, m.conn.ProcessError(err)
}
return m.getAttachments(ctx, attachmentIDs)
return m.GetAttachmentsByIDs(ctx, attachmentIDs)
}
func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) {
@ -163,7 +181,7 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit
return nil, m.conn.ProcessError(err)
}
return m.getAttachments(ctx, attachmentIDs)
return m.GetAttachmentsByIDs(ctx, attachmentIDs)
}
func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, db.Error) {
@ -189,7 +207,7 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim
return nil, m.conn.ProcessError(err)
}
return m.getAttachments(ctx, attachmentIDs)
return m.GetAttachmentsByIDs(ctx, attachmentIDs)
}
func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) {
@ -211,21 +229,3 @@ func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan t
return count, nil
}
func (m *mediaDB) getAttachments(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, db.Error) {
attachments := make([]*gtsmodel.MediaAttachment, 0, len(ids))
for _, id := range ids {
// Attempt fetch from DB
attachment, err := m.GetAttachmentByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting attachment %q: %v", id, err)
continue
}
// Append attachment
attachments = append(attachments, attachment)
}
return attachments, nil
}

View file

@ -19,8 +19,10 @@ package bundb
import (
"context"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@ -32,20 +34,13 @@ type mentionDB struct {
state *state.State
}
func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
return m.conn.
NewSelect().
Model(i).
Relation("Status").
Relation("OriginAccount").
Relation("TargetAccount")
}
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
return m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) {
mention, err := m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) {
var mention gtsmodel.Mention
q := m.newMentionQ(&mention).
q := m.conn.
NewSelect().
Model(&mention).
Where("? = ?", bun.Ident("mention.id"), id)
if err := q.Scan(ctx); err != nil {
@ -54,6 +49,38 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio
return &mention, nil
}, id)
if err != nil {
return nil, err
}
// Set the mention originating status.
mention.Status, err = m.state.DB.GetStatusByID(
gtscontext.SetBarebones(ctx),
mention.StatusID,
)
if err != nil {
return nil, fmt.Errorf("error populating mention status: %w", err)
}
// Set the mention origin account model.
mention.OriginAccount, err = m.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
mention.OriginAccountID,
)
if err != nil {
return nil, fmt.Errorf("error populating mention origin account: %w", err)
}
// Set the mention target account model.
mention.TargetAccount, err = m.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
mention.TargetAccountID,
)
if err != nil {
return nil, fmt.Errorf("error populating mention target account: %w", err)
}
return mention, nil
}
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) {
@ -73,3 +100,25 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.
return mentions, nil
}
func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error {
return m.state.Caches.GTS.Mention().Store(mention, func() error {
_, err := m.conn.NewInsert().Model(mention).Exec(ctx)
return m.conn.ProcessError(err)
})
}
func (m *mentionDB) DeleteMentionByID(ctx context.Context, id string) error {
if _, err := m.conn.
NewDelete().
Table("mentions").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return m.conn.ProcessError(err)
}
// Invalidate mention from the lookup cache.
m.state.Caches.GTS.Mention().Invalidate("ID", id)
return nil
}

View file

@ -0,0 +1,167 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package migrations
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
func init() {
up := func(ctx context.Context, db *bun.DB) error {
// To update unique constraint on public key, we need to migrate accounts into a new table.
// See section 7 here: https://www.sqlite.org/lang_altertable.html
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Create the new accounts table.
if _, err := tx.
NewCreateTable().
ModelTableExpr("new_accounts").
Model(&gtsmodel.Account{}).
Exec(ctx); err != nil {
return err
}
// If we don't specify columns explicitly,
// Postgres gives the following error when
// transferring accounts to new_accounts:
//
// ERROR: column "fetched_at" is of type timestamp with time zone but expression is of type character varying at character 35
// HINT: You will need to rewrite or cast the expression.
//
// Rather than do funky casting to fix this,
// it's simpler to just specify all columns.
columns := []string{
"id",
"created_at",
"updated_at",
"fetched_at",
"username",
"domain",
"avatar_media_attachment_id",
"avatar_remote_url",
"header_media_attachment_id",
"header_remote_url",
"display_name",
"emojis",
"fields",
"note",
"note_raw",
"memorial",
"also_known_as",
"moved_to_account_id",
"bot",
"reason",
"locked",
"discoverable",
"privacy",
"sensitive",
"language",
"status_content_type",
"custom_css",
"uri",
"url",
"inbox_uri",
"shared_inbox_uri",
"outbox_uri",
"following_uri",
"followers_uri",
"featured_collection_uri",
"actor_type",
"private_key",
"public_key",
"public_key_uri",
"sensitized_at",
"silenced_at",
"suspended_at",
"hide_collections",
"suspension_origin",
"enable_rss",
}
// Copy all accounts to the new table.
if _, err := tx.
NewInsert().
Table("new_accounts").
Table("accounts").
Column(columns...).
Exec(ctx); err != nil {
return err
}
// Drop the old table.
if _, err := tx.
NewDropTable().
Table("accounts").
Exec(ctx); err != nil {
return err
}
// Rename new table to old table.
if _, err := tx.
ExecContext(
ctx,
"ALTER TABLE ? RENAME TO ?",
bun.Ident("new_accounts"),
bun.Ident("accounts"),
); err != nil {
return err
}
// Add all account indexes to the new table.
for index, columns := range map[string][]string{
// Standard indices.
"accounts_id_idx": {"id"},
"accounts_suspended_at_idx": {"suspended_at"},
"accounts_domain_idx": {"domain"},
"accounts_username_domain_idx": {"username", "domain"},
// URI indices.
"accounts_uri_idx": {"uri"},
"accounts_url_idx": {"url"},
"accounts_inbox_uri_idx": {"inbox_uri"},
"accounts_outbox_uri_idx": {"outbox_uri"},
"accounts_followers_uri_idx": {"followers_uri"},
"accounts_following_uri_idx": {"following_uri"},
"accounts_public_key_uri_idx": {"public_key_uri"},
} {
if _, err := tx.
NewCreateIndex().
Table("accounts").
Index(index).
Column(columns...).
Exec(ctx); err != nil {
return err
}
}
return nil
})
}
down := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return nil
})
}
if err := Migrations.Register(up, down); err != nil {
panic(err)
}
}

View file

@ -33,7 +33,7 @@ type notificationDB struct {
state *state.State
}
func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
return n.state.Caches.GTS.Notification().Load("ID", func() (*gtsmodel.Notification, error) {
var notif gtsmodel.Notification
@ -48,7 +48,7 @@ func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmo
}, id)
}
func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
func (n *notificationDB) GetAccountNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
// Ensure reasonable
if limit < 0 {
limit = 0
@ -92,7 +92,7 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
// reason for this is that for each notif, we can instead get it from our cache if it's cached
for _, id := range notifIDs {
// Attempt fetch from DB
notif, err := n.GetNotification(ctx, id)
notif, err := n.GetNotificationByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting notification %q: %v", id, err)
continue
@ -105,7 +105,14 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
return notifs, nil
}
func (n *notificationDB) DeleteNotification(ctx context.Context, id string) db.Error {
func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error {
return n.state.Caches.GTS.Notification().Store(notif, func() error {
_, err := n.conn.NewInsert().Model(notif).Exec(ctx)
return n.conn.ProcessError(err)
})
}
func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) db.Error {
if _, err := n.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
@ -118,19 +125,23 @@ func (n *notificationDB) DeleteNotification(ctx context.Context, id string) db.E
return nil
}
func (n *notificationDB) DeleteNotifications(ctx context.Context, targetAccountID string, originAccountID string) db.Error {
func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) db.Error {
if targetAccountID == "" && originAccountID == "" {
return errors.New("DeleteNotifications: one of targetAccountID or originAccountID must be set")
}
// Capture notification IDs in a RETURNING statement.
ids := []string{}
var ids []string
q := n.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
Returning("?", bun.Ident("id"))
if len(types) > 0 {
q = q.Where("? IN (?)", bun.Ident("notification.notification_type"), bun.In(types))
}
if targetAccountID != "" {
q = q.Where("? = ?", bun.Ident("notification.target_account_id"), targetAccountID)
}
@ -153,7 +164,7 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, targetAccountI
func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) db.Error {
// Capture notification IDs in a RETURNING statement.
ids := []string{}
var ids []string
q := n.conn.
NewDelete().

View file

@ -85,11 +85,11 @@ type NotificationTestSuite struct {
BunDBStandardTestSuite
}
func (suite *NotificationTestSuite) TestGetNotificationsWithSpam() {
func (suite *NotificationTestSuite) TestGetAccountNotificationsWithSpam() {
suite.spamNotifs()
testAccount := suite.testAccounts["local_account_1"]
before := time.Now()
notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
suite.NoError(err)
timeTaken := time.Since(before)
fmt.Printf("\n\n\n withSpam: got %d notifications in %s\n\n\n", len(notifications), timeTaken)
@ -100,10 +100,10 @@ func (suite *NotificationTestSuite) TestGetNotificationsWithSpam() {
}
}
func (suite *NotificationTestSuite) TestGetNotificationsWithoutSpam() {
func (suite *NotificationTestSuite) TestGetAccountNotificationsWithoutSpam() {
testAccount := suite.testAccounts["local_account_1"]
before := time.Now()
notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
suite.NoError(err)
timeTaken := time.Since(before)
fmt.Printf("\n\n\n withoutSpam: got %d notifications in %s\n\n\n", len(notifications), timeTaken)
@ -117,10 +117,10 @@ func (suite *NotificationTestSuite) TestGetNotificationsWithoutSpam() {
func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() {
suite.spamNotifs()
testAccount := suite.testAccounts["local_account_1"]
err := suite.db.DeleteNotifications(context.Background(), testAccount.ID, "")
err := suite.db.DeleteNotifications(context.Background(), nil, testAccount.ID, "")
suite.NoError(err)
notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
suite.NoError(err)
suite.NotNil(notifications)
suite.Empty(notifications)
@ -129,10 +129,10 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() {
func (suite *NotificationTestSuite) TestDeleteNotificationsWithTwoAccounts() {
suite.spamNotifs()
testAccount := suite.testAccounts["local_account_1"]
err := suite.db.DeleteNotifications(context.Background(), testAccount.ID, "")
err := suite.db.DeleteNotifications(context.Background(), nil, testAccount.ID, "")
suite.NoError(err)
notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
suite.NoError(err)
suite.NotNil(notifications)
suite.Empty(notifications)
@ -146,7 +146,7 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithTwoAccounts() {
func (suite *NotificationTestSuite) TestDeleteNotificationsOriginatingFromAccount() {
testAccount := suite.testAccounts["local_account_2"]
if err := suite.db.DeleteNotifications(context.Background(), "", testAccount.ID); err != nil {
if err := suite.db.DeleteNotifications(context.Background(), nil, "", testAccount.ID); err != nil {
suite.FailNow(err.Error())
}
@ -166,7 +166,7 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsOriginatingFromAndTar
originAccount := suite.testAccounts["local_account_2"]
targetAccount := suite.testAccounts["admin_account"]
if err := suite.db.DeleteNotifications(context.Background(), targetAccount.ID, originAccount.ID); err != nil {
if err := suite.db.DeleteNotifications(context.Background(), nil, targetAccount.ID, originAccount.ID); err != nil {
suite.FailNow(err.Error())
}

View file

@ -23,8 +23,8 @@ import (
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/uptrace/bun"
)
@ -34,603 +34,212 @@ type relationshipDB struct {
state *state.State
}
func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
// Look for a block in direction of account1->account2
block1, err := r.getBlock(ctx, account1, account2)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return false, err
}
if block1 != nil {
// account1 blocks account2
return true, nil
} else if !eitherDirection {
// Don't check for mutli-directional
return false, nil
}
// Look for a block in direction of account2->account1
block2, err := r.getBlock(ctx, account2, account1)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return false, err
}
return (block2 != nil), nil
}
func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
// Fetch block from database
block, err := r.getBlock(ctx, account1, account2)
if err != nil {
return nil, err
}
// Set the block originating account
block.Account, err = r.state.DB.GetAccountByID(ctx, block.AccountID)
if err != nil {
return nil, err
}
// Set the block target account
block.TargetAccount, err = r.state.DB.GetAccountByID(ctx, block.TargetAccountID)
if err != nil {
return nil, err
}
return block, nil
}
func (r *relationshipDB) getBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
return r.state.Caches.GTS.Block().Load("AccountID.TargetAccountID", func() (*gtsmodel.Block, error) {
var block gtsmodel.Block
q := r.conn.NewSelect().Model(&block).
Where("? = ?", bun.Ident("block.account_id"), account1).
Where("? = ?", bun.Ident("block.target_account_id"), account2)
if err := q.Scan(ctx); err != nil {
return nil, r.conn.ProcessError(err)
}
return &block, nil
}, account1, account2)
}
func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) db.Error {
return r.state.Caches.GTS.Block().Store(block, func() error {
_, err := r.conn.NewInsert().Model(block).Exec(ctx)
return r.conn.ProcessError(err)
})
}
func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) db.Error {
if _, err := r.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
Where("? = ?", bun.Ident("block.id"), id).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Drop any old value from cache by this ID
r.state.Caches.GTS.Block().Invalidate("ID", id)
return nil
}
func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) db.Error {
if _, err := r.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
Where("? = ?", bun.Ident("block.uri"), uri).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Drop any old value from cache by this URI
r.state.Caches.GTS.Block().Invalidate("URI", uri)
return nil
}
func (r *relationshipDB) DeleteBlocksByOriginAccountID(ctx context.Context, originAccountID string) db.Error {
blockIDs := []string{}
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
Column("block.id").
Where("? = ?", bun.Ident("block.account_id"), originAccountID)
if err := q.Scan(ctx, &blockIDs); err != nil {
return r.conn.ProcessError(err)
}
for _, blockID := range blockIDs {
if err := r.DeleteBlockByID(ctx, blockID); err != nil {
return err
}
}
return nil
}
func (r *relationshipDB) DeleteBlocksByTargetAccountID(ctx context.Context, targetAccountID string) db.Error {
blockIDs := []string{}
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
Column("block.id").
Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID)
if err := q.Scan(ctx, &blockIDs); err != nil {
return r.conn.ProcessError(err)
}
for _, blockID := range blockIDs {
if err := r.DeleteBlockByID(ctx, blockID); err != nil {
return err
}
}
return nil
}
func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
rel := &gtsmodel.Relationship{
ID: targetAccount,
var rel gtsmodel.Relationship
rel.ID = targetAccount
// check if the requesting follows the target
follow, err := r.GetFollow(
gtscontext.SetBarebones(ctx),
requestingAccount,
targetAccount,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, fmt.Errorf("GetRelationship: error fetching follow: %w", err)
}
// check if the requesting account follows the target account
follow := &gtsmodel.Follow{}
if err := r.conn.
NewSelect().
Model(follow).
Column("follow.show_reblogs", "follow.notify").
Where("? = ?", bun.Ident("follow.account_id"), requestingAccount).
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount).
Limit(1).
Scan(ctx); err != nil {
if err := r.conn.ProcessError(err); err != db.ErrNoEntries {
return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err)
}
// no follow exists so these are all false
rel.Following = false
rel.ShowingReblogs = false
rel.Notifying = false
} else {
if follow != nil {
// follow exists so we can fill these fields out...
rel.Following = true
rel.ShowingReblogs = *follow.ShowReblogs
rel.Notifying = *follow.Notify
}
// check if the target account follows the requesting account
followedByQ := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Column("follow.id").
Where("? = ?", bun.Ident("follow.account_id"), targetAccount).
Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount)
followedBy, err := r.conn.Exists(ctx, followedByQ)
// check if the target follows the requesting
rel.FollowedBy, err = r.IsFollowing(ctx,
targetAccount,
requestingAccount,
)
if err != nil {
return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err)
return nil, fmt.Errorf("GetRelationship: error checking followedBy: %w", err)
}
rel.FollowedBy = followedBy
// check if there's a pending following request from requesting account to target account
requestedQ := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Column("follow_request.id").
Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount).
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount)
requested, err := r.conn.Exists(ctx, requestedQ)
// check if requesting has follow requested target
rel.Requested, err = r.IsFollowRequested(ctx,
requestingAccount,
targetAccount,
)
if err != nil {
return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err)
return nil, fmt.Errorf("GetRelationship: error checking requested: %w", err)
}
rel.Requested = requested
// check if the requesting account is blocking the target account
blockA2T, err := r.getBlock(ctx, requestingAccount, targetAccount)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err)
rel.Blocking, err = r.IsBlocked(ctx, requestingAccount, targetAccount)
if err != nil {
return nil, fmt.Errorf("GetRelationship: error checking blocking: %w", err)
}
rel.Blocking = (blockA2T != nil)
// check if the requesting account is blocked by the target account
blockT2A, err := r.getBlock(ctx, targetAccount, requestingAccount)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err)
}
rel.BlockedBy = (blockT2A != nil)
return rel, nil
}
func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Column("follow.id").
Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID).
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID)
return r.conn.Exists(ctx, q)
}
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Column("follow_request.id").
Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID).
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID)
return r.conn.Exists(ctx, q)
}
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
if account1 == nil || account2 == nil {
return false, nil
}
// make sure account 1 follows account 2
f1, err := r.IsFollowing(ctx, account1, account2)
rel.BlockedBy, err = r.IsBlocked(ctx, targetAccount, requestingAccount)
if err != nil {
return false, err
return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %w", err)
}
// make sure account 2 follows account 1
f2, err := r.IsFollowing(ctx, account2, account1)
if err != nil {
return false, err
}
return f1 && f2, nil
return &rel, nil
}
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
// Get original follow request.
var followRequestID string
if err := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Column("follow_request.id").
Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
Scan(ctx, &followRequestID); err != nil {
func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
var followIDs []string
if err := newSelectFollows(r.conn, accountID).
Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
followRequest, err := r.getFollowRequest(ctx, followRequestID)
if err != nil {
return nil, r.conn.ProcessError(err)
}
// Create a new follow to 'replace'
// the original follow request with.
follow := &gtsmodel.Follow{
ID: followRequest.ID,
AccountID: originAccountID,
Account: followRequest.Account,
TargetAccountID: targetAccountID,
TargetAccount: followRequest.TargetAccount,
URI: followRequest.URI,
}
// If the follow already exists, just
// replace the URI with the new one.
if _, err := r.conn.
NewInsert().
Model(follow).
On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
Exec(ctx); err != nil {
return nil, r.conn.ProcessError(err)
}
// Delete original follow request.
if _, err := r.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
Exec(ctx); err != nil {
return nil, r.conn.ProcessError(err)
}
// Delete original follow request notification.
if err := r.deleteFollowRequestNotif(ctx, originAccountID, targetAccountID); err != nil {
return nil, err
}
// return the new follow
return follow, nil
return r.GetFollowsByIDs(ctx, followIDs)
}
func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) {
// Get original follow request.
var followRequestID string
if err := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Column("follow_request.id").
Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
Scan(ctx, &followRequestID); err != nil {
func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
var followIDs []string
if err := newSelectLocalFollows(r.conn, accountID).
Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
followRequest, err := r.getFollowRequest(ctx, followRequestID)
if err != nil {
return nil, r.conn.ProcessError(err)
}
// Delete original follow request.
if _, err := r.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
Exec(ctx); err != nil {
return nil, r.conn.ProcessError(err)
}
// Delete original follow request notification.
if err := r.deleteFollowRequestNotif(ctx, originAccountID, targetAccountID); err != nil {
return nil, err
}
// Return the now deleted follow request.
return followRequest, nil
return r.GetFollowsByIDs(ctx, followIDs)
}
func (r *relationshipDB) deleteFollowRequestNotif(ctx context.Context, originAccountID string, targetAccountID string) db.Error {
var id string
if err := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
Column("notification.id").
Where("? = ?", bun.Ident("notification.origin_account_id"), originAccountID).
Where("? = ?", bun.Ident("notification.target_account_id"), targetAccountID).
Where("? = ?", bun.Ident("notification.notification_type"), gtsmodel.NotificationFollowRequest).
Limit(1). // There should only be one!
Scan(ctx, &id); err != nil {
err = r.conn.ProcessError(err)
if errors.Is(err, db.ErrNoEntries) {
// If no entries, the notif didn't
// exist anyway so nothing to do here.
return nil
}
// Return on real error.
return err
}
return r.state.DB.DeleteNotification(ctx, id)
}
func (r *relationshipDB) getFollow(ctx context.Context, id string) (*gtsmodel.Follow, db.Error) {
follow := &gtsmodel.Follow{}
err := r.conn.
NewSelect().
Model(follow).
Where("? = ?", bun.Ident("follow.id"), id).
Scan(ctx)
if err != nil {
func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
var followIDs []string
if err := newSelectFollowers(r.conn, accountID).
Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
follow.Account, err = r.state.DB.GetAccountByID(ctx, follow.AccountID)
if err != nil {
log.Errorf(ctx, "error getting follow account %q: %v", follow.AccountID, err)
}
follow.TargetAccount, err = r.state.DB.GetAccountByID(ctx, follow.TargetAccountID)
if err != nil {
log.Errorf(ctx, "error getting follow target account %q: %v", follow.TargetAccountID, err)
}
return follow, nil
return r.GetFollowsByIDs(ctx, followIDs)
}
func (r *relationshipDB) GetLocalFollowersIDs(ctx context.Context, targetAccountID string) ([]string, db.Error) {
accountIDs := []string{}
func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
var followIDs []string
if err := newSelectLocalFollowers(r.conn, accountID).
Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
return r.GetFollowsByIDs(ctx, followIDs)
}
// Select only the account ID of each follow.
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
ColumnExpr("? AS ?", bun.Ident("follow.account_id"), bun.Ident("account_id")).
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) {
n, err := newSelectFollows(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
}
// Join on accounts table to select only
// those with NULL domain (local accounts).
q = q.
Join("JOIN ? AS ? ON ? = ?",
bun.Ident("accounts"),
bun.Ident("account"),
bun.Ident("follow.account_id"),
bun.Ident("account.id"),
func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) {
n, err := newSelectLocalFollows(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
}
func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) {
n, err := newSelectFollowers(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
}
func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) {
n, err := newSelectLocalFollowers(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
}
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
var followReqIDs []string
if err := newSelectFollowRequests(r.conn, accountID).
Scan(ctx, &followReqIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
return r.GetFollowRequestsByIDs(ctx, followReqIDs)
}
func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
var followReqIDs []string
if err := newSelectFollowRequesting(r.conn, accountID).
Scan(ctx, &followReqIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
return r.GetFollowRequestsByIDs(ctx, followReqIDs)
}
func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) {
n, err := newSelectFollowRequests(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
}
func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) {
n, err := newSelectFollowRequesting(r.conn, accountID).Count(ctx)
return n, r.conn.ProcessError(err)
}
// newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID.
func newSelectFollowRequests(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
TableExpr("?", bun.Ident("follow_requests")).
ColumnExpr("?", bun.Ident("id")).
Where("? = ?", bun.Ident("target_account_id"), accountID).
OrderExpr("? DESC", bun.Ident("updated_at"))
}
// newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID.
func newSelectFollowRequesting(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
TableExpr("?", bun.Ident("follow_requests")).
ColumnExpr("?", bun.Ident("id")).
Where("? = ?", bun.Ident("target_account_id"), accountID).
OrderExpr("? DESC", bun.Ident("updated_at"))
}
// newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID.
func newSelectFollows(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
Table("follows").
Column("id").
Where("? = ?", bun.Ident("account_id"), accountID).
OrderExpr("? DESC", bun.Ident("updated_at"))
}
// newSelectLocalFollows returns a new select query for all rows in the follows table with
// account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
func newSelectLocalFollows(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
Table("follows").
Column("id").
Where("? = ? AND ? IN (?)",
bun.Ident("account_id"),
accountID,
bun.Ident("target_account_id"),
conn.NewSelect().
Table("accounts").
Column("id").
Where("? IS NULL", bun.Ident("domain")),
).
Where("? IS NULL", bun.Ident("account.domain"))
// We don't *really* need to order these,
// but it makes it more consistent to do so.
q = q.Order("account_id DESC")
if err := q.Scan(ctx, &accountIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
return accountIDs, nil
OrderExpr("? DESC", bun.Ident("updated_at"))
}
func (r *relationshipDB) GetFollows(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.Follow, db.Error) {
ids := []string{}
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Column("follow.id").
Order("follow.updated_at DESC")
if accountID != "" {
q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
}
if targetAccountID != "" {
q = q.Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
}
if err := q.Scan(ctx, &ids); err != nil {
return nil, r.conn.ProcessError(err)
}
follows := make([]*gtsmodel.Follow, 0, len(ids))
for _, id := range ids {
follow, err := r.getFollow(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting follow %q: %v", id, err)
continue
}
follows = append(follows, follow)
}
return follows, nil
// newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID.
func newSelectFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
Table("follows").
Column("id").
Where("? = ?", bun.Ident("target_account_id"), accountID).
OrderExpr("? DESC", bun.Ident("updated_at"))
}
func (r *relationshipDB) CountFollows(ctx context.Context, accountID string, targetAccountID string) (int, db.Error) {
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Column("follow.id")
if accountID != "" {
q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
}
if targetAccountID != "" {
q = q.Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
}
return q.Count(ctx)
}
func (r *relationshipDB) getFollowRequest(ctx context.Context, id string) (*gtsmodel.FollowRequest, db.Error) {
followRequest := &gtsmodel.FollowRequest{}
err := r.conn.
NewSelect().
Model(followRequest).
Where("? = ?", bun.Ident("follow_request.id"), id).
Scan(ctx)
if err != nil {
return nil, r.conn.ProcessError(err)
}
followRequest.Account, err = r.state.DB.GetAccountByID(ctx, followRequest.AccountID)
if err != nil {
log.Errorf(ctx, "error getting follow request account %q: %v", followRequest.AccountID, err)
}
followRequest.TargetAccount, err = r.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
if err != nil {
log.Errorf(ctx, "error getting follow request target account %q: %v", followRequest.TargetAccountID, err)
}
return followRequest, nil
}
func (r *relationshipDB) GetFollowRequests(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.FollowRequest, db.Error) {
ids := []string{}
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Column("follow_request.id")
if accountID != "" {
q = q.Where("? = ?", bun.Ident("follow_request.account_id"), accountID)
}
if targetAccountID != "" {
q = q.Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID)
}
if err := q.Scan(ctx, &ids); err != nil {
return nil, r.conn.ProcessError(err)
}
followRequests := make([]*gtsmodel.FollowRequest, 0, len(ids))
for _, id := range ids {
followRequest, err := r.getFollowRequest(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting follow request %q: %v", id, err)
continue
}
followRequests = append(followRequests, followRequest)
}
return followRequests, nil
}
func (r *relationshipDB) CountFollowRequests(ctx context.Context, accountID string, targetAccountID string) (int, db.Error) {
q := r.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Column("follow_request.id").
Order("follow_request.updated_at DESC")
if accountID != "" {
q = q.Where("? = ?", bun.Ident("follow_request.account_id"), accountID)
}
if targetAccountID != "" {
q = q.Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID)
}
return q.Count(ctx)
}
func (r *relationshipDB) Unfollow(ctx context.Context, originAccountID string, targetAccountID string) (string, db.Error) {
uri := new(string)
_, err := r.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID).
Where("? = ?", bun.Ident("follow.account_id"), originAccountID).
Returning("?", bun.Ident("uri")).Exec(ctx, uri)
// Only return proper errors.
if err = r.conn.ProcessError(err); err != db.ErrNoEntries {
return *uri, err
}
return *uri, nil
}
func (r *relationshipDB) UnfollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (string, db.Error) {
uri := new(string)
_, err := r.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
Returning("?", bun.Ident("uri")).Exec(ctx, uri)
// Only return proper errors.
if err = r.conn.ProcessError(err); err != db.ErrNoEntries {
return *uri, err
}
return *uri, nil
// newSelectLocalFollowers returns a new select query for all rows in the follows table with
// target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
func newSelectLocalFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
return conn.NewSelect().
Table("follows").
Column("id").
Where("? = ? AND ? IN (?)",
bun.Ident("target_account_id"),
accountID,
bun.Ident("account_id"),
conn.NewSelect().
Table("accounts").
Column("id").
Where("? IS NULL", bun.Ident("domain")),
).
OrderExpr("? DESC", bun.Ident("updated_at"))
}

View file

@ -0,0 +1,218 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"errors"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
)
func (r *relationshipDB) IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
block, err := r.GetBlock(
gtscontext.SetBarebones(ctx),
sourceAccountID,
targetAccountID,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return false, err
}
return (block != nil), nil
}
func (r *relationshipDB) IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error) {
// Look for a block in direction of account1->account2
b1, err := r.IsBlocked(ctx, accountID1, accountID2)
if err != nil || b1 {
return true, err
}
// Look for a block in direction of account2->account1
b2, err := r.IsBlocked(ctx, accountID2, accountID1)
if err != nil || b2 {
return true, err
}
return false, nil
}
func (r *relationshipDB) GetBlockByID(ctx context.Context, id string) (*gtsmodel.Block, error) {
return r.getBlock(
ctx,
"ID",
func(block *gtsmodel.Block) error {
return r.conn.NewSelect().Model(block).
Where("? = ?", bun.Ident("block.id"), id).
Scan(ctx)
},
id,
)
}
func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmodel.Block, error) {
return r.getBlock(
ctx,
"URI",
func(block *gtsmodel.Block) error {
return r.conn.NewSelect().Model(block).
Where("? = ?", bun.Ident("block.uri"), uri).
Scan(ctx)
},
uri,
)
}
func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Block, error) {
return r.getBlock(
ctx,
"AccountID.TargetAccountID",
func(block *gtsmodel.Block) error {
return r.conn.NewSelect().Model(block).
Where("? = ?", bun.Ident("block.account_id"), sourceAccountID).
Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID).
Scan(ctx)
},
sourceAccountID,
targetAccountID,
)
}
func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) {
// Fetch block from cache with loader callback
block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) {
var block gtsmodel.Block
// Not cached! Perform database query
if err := dbQuery(&block); err != nil {
return nil, r.conn.ProcessError(err)
}
return &block, nil
}, keyParts...)
if err != nil {
// already processe
return nil, err
}
if gtscontext.Barebones(ctx) {
// Only a barebones model was requested.
return block, nil
}
// Set the block source account
block.Account, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
block.AccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting block source account: %w", err)
}
// Set the block target account
block.TargetAccount, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
block.TargetAccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting block target account: %w", err)
}
return block, nil
}
func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error {
err := r.state.Caches.GTS.Block().Store(block, func() error {
_, err := r.conn.NewInsert().Model(block).Exec(ctx)
return r.conn.ProcessError(err)
})
if err != nil {
return err
}
// Invalidate block origin account ID cached visibility.
r.state.Caches.Visibility.Invalidate("ItemID", block.AccountID)
r.state.Caches.Visibility.Invalidate("RequesterID", block.AccountID)
// Invalidate block target account ID cached visibility.
r.state.Caches.Visibility.Invalidate("ItemID", block.TargetAccountID)
r.state.Caches.Visibility.Invalidate("RequesterID", block.TargetAccountID)
return nil
}
func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {
block, err := r.GetBlockByID(gtscontext.SetBarebones(ctx), id)
if err != nil {
return err
}
return r.deleteBlock(ctx, block)
}
func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error {
block, err := r.GetBlockByURI(gtscontext.SetBarebones(ctx), uri)
if err != nil {
return err
}
return r.deleteBlock(ctx, block)
}
func (r *relationshipDB) deleteBlock(ctx context.Context, block *gtsmodel.Block) error {
if _, err := r.conn.
NewDelete().
Table("blocks").
Where("? = ?", bun.Ident("id"), block.ID).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Invalidate block from cache lookups.
r.state.Caches.GTS.Block().Invalidate("ID", block.ID)
return nil
}
func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID string) error {
var blockIDs []string
if err := r.conn.NewSelect().
Table("blocks").
ColumnExpr("?", bun.Ident("id")).
WhereOr("? = ? OR ? = ?",
bun.Ident("account_id"),
accountID,
bun.Ident("target_account_id"),
accountID,
).
Scan(ctx, &blockIDs); err != nil {
return r.conn.ProcessError(err)
}
for _, id := range blockIDs {
if err := r.DeleteBlockByID(ctx, id); err != nil {
log.Errorf(ctx, "error deleting block %q: %v", id, err)
}
}
return nil
}

View file

@ -0,0 +1,243 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"errors"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
)
func (r *relationshipDB) GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error) {
return r.getFollow(
ctx,
"ID",
func(follow *gtsmodel.Follow) error {
return r.conn.NewSelect().
Model(follow).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx)
},
id,
)
}
func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmodel.Follow, error) {
return r.getFollow(
ctx,
"URI",
func(follow *gtsmodel.Follow) error {
return r.conn.NewSelect().
Model(follow).
Where("? = ?", bun.Ident("uri"), uri).
Scan(ctx)
},
uri,
)
}
func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
return r.getFollow(
ctx,
"AccountID.TargetAccountID",
func(follow *gtsmodel.Follow) error {
return r.conn.NewSelect().
Model(follow).
Where("? = ?", bun.Ident("account_id"), sourceAccountID).
Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
Scan(ctx)
},
sourceAccountID,
targetAccountID,
)
}
func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Follow, error) {
// Preallocate slice of expected length.
follows := make([]*gtsmodel.Follow, 0, len(ids))
for _, id := range ids {
// Fetch follow model for this ID.
follow, err := r.GetFollowByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting follow %q: %v", id, err)
continue
}
// Append to return slice.
follows = append(follows, follow)
}
return follows, nil
}
func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
follow, err := r.GetFollow(
gtscontext.SetBarebones(ctx),
sourceAccountID,
targetAccountID,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return false, err
}
return (follow != nil), nil
}
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 string, accountID2 string) (bool, db.Error) {
// make sure account 1 follows account 2
f1, err := r.IsFollowing(ctx,
accountID1,
accountID2,
)
if !f1 /* f1 = false when err != nil */ {
return false, err
}
// make sure account 2 follows account 1
f2, err := r.IsFollowing(ctx,
accountID2,
accountID1,
)
if !f2 /* f2 = false when err != nil */ {
return false, err
}
return true, nil
}
func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Follow) error, keyParts ...any) (*gtsmodel.Follow, error) {
// Fetch follow from database cache with loader callback
follow, err := r.state.Caches.GTS.Follow().Load(lookup, func() (*gtsmodel.Follow, error) {
var follow gtsmodel.Follow
// Not cached! Perform database query
if err := dbQuery(&follow); err != nil {
return nil, r.conn.ProcessError(err)
}
return &follow, nil
}, keyParts...)
if err != nil {
// error already processed
return nil, err
}
if gtscontext.Barebones(ctx) {
// Only a barebones model was requested.
return follow, nil
}
// Set the follow source account
follow.Account, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
follow.AccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting follow source account: %w", err)
}
// Set the follow target account
follow.TargetAccount, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
follow.TargetAccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting follow target account: %w", err)
}
return follow, nil
}
func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error {
err := r.state.Caches.GTS.Follow().Store(follow, func() error {
_, err := r.conn.NewInsert().Model(follow).Exec(ctx)
return r.conn.ProcessError(err)
})
if err != nil {
return err
}
// Invalidate follow origin account ID cached visibility.
r.state.Caches.Visibility.Invalidate("ItemID", follow.AccountID)
r.state.Caches.Visibility.Invalidate("RequesterID", follow.AccountID)
// Invalidate follow target account ID cached visibility.
r.state.Caches.Visibility.Invalidate("ItemID", follow.TargetAccountID)
r.state.Caches.Visibility.Invalidate("RequesterID", follow.TargetAccountID)
return nil
}
func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error {
if _, err := r.conn.NewDelete().
Table("follows").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Invalidate follow from cache lookups.
r.state.Caches.GTS.Follow().Invalidate("ID", id)
return nil
}
func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) error {
if _, err := r.conn.NewDelete().
Table("follows").
Where("? = ?", bun.Ident("uri"), uri).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Invalidate follow from cache lookups.
r.state.Caches.GTS.Follow().Invalidate("URI", uri)
return nil
}
func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID string) error {
var followIDs []string
if _, err := r.conn.
NewDelete().
Table("follows").
WhereOr("? = ? OR ? = ?",
bun.Ident("account_id"),
accountID,
bun.Ident("target_account_id"),
accountID,
).
Returning("?", bun.Ident("id")).
Exec(ctx, &followIDs); err != nil {
return r.conn.ProcessError(err)
}
// Invalidate each returned ID.
for _, id := range followIDs {
r.state.Caches.GTS.Follow().Invalidate("ID", id)
}
return nil
}

View file

@ -0,0 +1,293 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"errors"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
)
func (r *relationshipDB) GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error) {
return r.getFollowRequest(
ctx,
"ID",
func(followReq *gtsmodel.FollowRequest) error {
return r.conn.NewSelect().
Model(followReq).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx)
},
id,
)
}
func (r *relationshipDB) GetFollowRequestByURI(ctx context.Context, uri string) (*gtsmodel.FollowRequest, error) {
return r.getFollowRequest(
ctx,
"URI",
func(followReq *gtsmodel.FollowRequest) error {
return r.conn.NewSelect().
Model(followReq).
Where("? = ?", bun.Ident("uri"), uri).
Scan(ctx)
},
uri,
)
}
func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) {
return r.getFollowRequest(
ctx,
"AccountID.TargetAccountID",
func(followReq *gtsmodel.FollowRequest) error {
return r.conn.NewSelect().
Model(followReq).
Where("? = ?", bun.Ident("account_id"), sourceAccountID).
Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
Scan(ctx)
},
sourceAccountID,
targetAccountID,
)
}
func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.FollowRequest, error) {
// Preallocate slice of expected length.
followReqs := make([]*gtsmodel.FollowRequest, 0, len(ids))
for _, id := range ids {
// Fetch follow request model for this ID.
followReq, err := r.GetFollowRequestByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting follow request %q: %v", id, err)
continue
}
// Append to return slice.
followReqs = append(followReqs, followReq)
}
return followReqs, nil
}
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
followReq, err := r.GetFollowRequest(
gtscontext.SetBarebones(ctx),
sourceAccountID,
targetAccountID,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return false, err
}
return (followReq != nil), nil
}
func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, dbQuery func(*gtsmodel.FollowRequest) error, keyParts ...any) (*gtsmodel.FollowRequest, error) {
// Fetch follow request from database cache with loader callback
followReq, err := r.state.Caches.GTS.FollowRequest().Load(lookup, func() (*gtsmodel.FollowRequest, error) {
var followReq gtsmodel.FollowRequest
// Not cached! Perform database query
if err := dbQuery(&followReq); err != nil {
return nil, r.conn.ProcessError(err)
}
return &followReq, nil
}, keyParts...)
if err != nil {
// error already processed
return nil, err
}
if gtscontext.Barebones(ctx) {
// Only a barebones model was requested.
return followReq, nil
}
// Set the follow request source account
followReq.Account, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
followReq.AccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting follow request source account: %w", err)
}
// Set the follow request target account
followReq.TargetAccount, err = r.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
followReq.TargetAccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting follow request target account: %w", err)
}
return followReq, nil
}
func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error {
err := r.state.Caches.GTS.FollowRequest().Store(follow, func() error {
_, err := r.conn.NewInsert().Model(follow).Exec(ctx)
return r.conn.ProcessError(err)
})
if err != nil {
return err
}
// Invalidate follow request origin account ID cached visibility.
r.state.Caches.Visibility.Invalidate("ItemID", follow.AccountID)
r.state.Caches.Visibility.Invalidate("RequesterID", follow.AccountID)
// Invalidate follow request target account ID cached visibility.
r.state.Caches.Visibility.Invalidate("ItemID", follow.TargetAccountID)
r.state.Caches.Visibility.Invalidate("RequesterID", follow.TargetAccountID)
return nil
}
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
// Get original follow request.
followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID)
if err != nil {
return nil, err
}
// Create a new follow to 'replace'
// the original follow request with.
follow := &gtsmodel.Follow{
ID: followReq.ID,
AccountID: sourceAccountID,
Account: followReq.Account,
TargetAccountID: targetAccountID,
TargetAccount: followReq.TargetAccount,
URI: followReq.URI,
}
// If the follow already exists, just
// replace the URI with the new one.
if _, err := r.conn.
NewInsert().
Model(follow).
On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
Exec(ctx); err != nil {
return nil, r.conn.ProcessError(err)
}
// Delete original follow request.
if _, err := r.conn.
NewDelete().
Table("follow_requests").
Where("? = ?", bun.Ident("id"), followReq.ID).
Exec(ctx); err != nil {
return nil, r.conn.ProcessError(err)
}
// Invalidate follow request from cache lookups.
r.state.Caches.GTS.FollowRequest().Invalidate("ID", followReq.ID)
// Delete original follow request notification
if err := r.state.DB.DeleteNotifications(ctx, []string{
string(gtsmodel.NotificationFollowRequest),
}, targetAccountID, sourceAccountID); err != nil {
return nil, err
}
return follow, nil
}
func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) db.Error {
// Get original follow request.
followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID)
if err != nil {
return err
}
// Delete original follow request.
if _, err := r.conn.
NewDelete().
Table("follow_requests").
Where("? = ?", bun.Ident("id"), followReq.ID).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Delete original follow request notification
return r.state.DB.DeleteNotifications(ctx, []string{
string(gtsmodel.NotificationFollowRequest),
}, targetAccountID, sourceAccountID)
}
func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error {
if _, err := r.conn.NewDelete().
Table("follow_requests").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Invalidate follow request from cache lookups.
r.state.Caches.GTS.FollowRequest().Invalidate("ID", id)
return nil
}
func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error {
if _, err := r.conn.NewDelete().
Table("follow_requests").
Where("? = ?", bun.Ident("uri"), uri).
Exec(ctx); err != nil {
return r.conn.ProcessError(err)
}
// Invalidate follow request from cache lookups.
r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri)
return nil
}
func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accountID string) error {
var followIDs []string
if _, err := r.conn.
NewDelete().
Table("follow_requests").
WhereOr("? = ? OR ? = ?",
bun.Ident("account_id"),
accountID,
bun.Ident("target_account_id"),
accountID,
).
Returning("?", bun.Ident("id")).
Exec(ctx, &followIDs); err != nil {
return r.conn.ProcessError(err)
}
// Invalidate each returned ID.
for _, id := range followIDs {
r.state.Caches.GTS.FollowRequest().Invalidate("ID", id)
}
return nil
}

View file

@ -19,17 +19,359 @@ package bundb_test
import (
"context"
"errors"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
)
type RelationshipTestSuite struct {
BunDBStandardTestSuite
}
func (suite *RelationshipTestSuite) TestGetBlockBy() {
t := suite.T()
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
// Sentinel error to mark avoiding a test case.
sentinelErr := errors.New("sentinel")
// isEqual checks if 2 block models are equal.
isEqual := func(b1, b2 gtsmodel.Block) bool {
// Clear populated sub-models.
b1.Account = nil
b2.Account = nil
b1.TargetAccount = nil
b2.TargetAccount = nil
// Clear database-set fields.
b1.CreatedAt = time.Time{}
b2.CreatedAt = time.Time{}
b1.UpdatedAt = time.Time{}
b2.UpdatedAt = time.Time{}
return reflect.DeepEqual(b1, b2)
}
var testBlocks []*gtsmodel.Block
for _, account1 := range suite.testAccounts {
for _, account2 := range suite.testAccounts {
if account1.ID == account2.ID {
// don't block *yourself* ...
continue
}
// Create new account block.
block := &gtsmodel.Block{
ID: id.NewULID(),
URI: "http://127.0.0.1:8080/" + id.NewULID(),
AccountID: account1.ID,
TargetAccountID: account2.ID,
}
// Attempt to place the block in database (if not already).
if err := suite.db.PutBlock(ctx, block); err != nil {
if err != db.ErrAlreadyExists {
// Unrecoverable database error.
t.Fatalf("error creating block: %v", err)
}
// Fetch existing block from database between accounts.
block, _ = suite.db.GetBlock(ctx, account1.ID, account2.ID)
continue
}
// Append generated block to test cases.
testBlocks = append(testBlocks, block)
}
}
for _, block := range testBlocks {
for lookup, dbfunc := range map[string]func() (*gtsmodel.Block, error){
"id": func() (*gtsmodel.Block, error) {
return suite.db.GetBlockByID(ctx, block.ID)
},
"uri": func() (*gtsmodel.Block, error) {
return suite.db.GetBlockByURI(ctx, block.URI)
},
"origin_target": func() (*gtsmodel.Block, error) {
return suite.db.GetBlock(ctx, block.AccountID, block.TargetAccountID)
},
} {
// Clear database caches.
suite.state.Caches.Init()
t.Logf("checking database lookup %q", lookup)
// Perform database function.
checkBlock, err := dbfunc()
if err != nil {
if err == sentinelErr {
continue
}
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
continue
}
// Check received block data.
if !isEqual(*checkBlock, *block) {
t.Errorf("block does not contain expected data: %+v", checkBlock)
continue
}
// Check that block origin account populated.
if checkBlock.Account == nil || checkBlock.Account.ID != block.AccountID {
t.Errorf("block origin account not correctly populated for: %+v", checkBlock)
continue
}
// Check that block target account populated.
if checkBlock.TargetAccount == nil || checkBlock.TargetAccount.ID != block.TargetAccountID {
t.Errorf("block target account not correctly populated for: %+v", checkBlock)
continue
}
}
}
}
func (suite *RelationshipTestSuite) TestGetFollowBy() {
t := suite.T()
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
// Sentinel error to mark avoiding a test case.
sentinelErr := errors.New("sentinel")
// isEqual checks if 2 follow models are equal.
isEqual := func(f1, f2 gtsmodel.Follow) bool {
// Clear populated sub-models.
f1.Account = nil
f2.Account = nil
f1.TargetAccount = nil
f2.TargetAccount = nil
// Clear database-set fields.
f1.CreatedAt = time.Time{}
f2.CreatedAt = time.Time{}
f1.UpdatedAt = time.Time{}
f2.UpdatedAt = time.Time{}
return reflect.DeepEqual(f1, f2)
}
var testFollows []*gtsmodel.Follow
for _, account1 := range suite.testAccounts {
for _, account2 := range suite.testAccounts {
if account1.ID == account2.ID {
// don't follow *yourself* ...
continue
}
// Create new account follow.
follow := &gtsmodel.Follow{
ID: id.NewULID(),
URI: "http://127.0.0.1:8080/" + id.NewULID(),
AccountID: account1.ID,
TargetAccountID: account2.ID,
}
// Attempt to place the follow in database (if not already).
if err := suite.db.PutFollow(ctx, follow); err != nil {
if err != db.ErrAlreadyExists {
// Unrecoverable database error.
t.Fatalf("error creating follow: %v", err)
}
// Fetch existing follow from database between accounts.
follow, _ = suite.db.GetFollow(ctx, account1.ID, account2.ID)
continue
}
// Append generated follow to test cases.
testFollows = append(testFollows, follow)
}
}
for _, follow := range testFollows {
for lookup, dbfunc := range map[string]func() (*gtsmodel.Follow, error){
"id": func() (*gtsmodel.Follow, error) {
return suite.db.GetFollowByID(ctx, follow.ID)
},
"uri": func() (*gtsmodel.Follow, error) {
return suite.db.GetFollowByURI(ctx, follow.URI)
},
"origin_target": func() (*gtsmodel.Follow, error) {
return suite.db.GetFollow(ctx, follow.AccountID, follow.TargetAccountID)
},
} {
// Clear database caches.
suite.state.Caches.Init()
t.Logf("checking database lookup %q", lookup)
// Perform database function.
checkFollow, err := dbfunc()
if err != nil {
if err == sentinelErr {
continue
}
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
continue
}
// Check received follow data.
if !isEqual(*checkFollow, *follow) {
t.Errorf("follow does not contain expected data: %+v", checkFollow)
continue
}
// Check that follow origin account populated.
if checkFollow.Account == nil || checkFollow.Account.ID != follow.AccountID {
t.Errorf("follow origin account not correctly populated for: %+v", checkFollow)
continue
}
// Check that follow target account populated.
if checkFollow.TargetAccount == nil || checkFollow.TargetAccount.ID != follow.TargetAccountID {
t.Errorf("follow target account not correctly populated for: %+v", checkFollow)
continue
}
}
}
}
func (suite *RelationshipTestSuite) TestGetFollowRequestBy() {
t := suite.T()
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
// Sentinel error to mark avoiding a test case.
sentinelErr := errors.New("sentinel")
// isEqual checks if 2 follow request models are equal.
isEqual := func(f1, f2 gtsmodel.FollowRequest) bool {
// Clear populated sub-models.
f1.Account = nil
f2.Account = nil
f1.TargetAccount = nil
f2.TargetAccount = nil
// Clear database-set fields.
f1.CreatedAt = time.Time{}
f2.CreatedAt = time.Time{}
f1.UpdatedAt = time.Time{}
f2.UpdatedAt = time.Time{}
return reflect.DeepEqual(f1, f2)
}
var testFollowReqs []*gtsmodel.FollowRequest
for _, account1 := range suite.testAccounts {
for _, account2 := range suite.testAccounts {
if account1.ID == account2.ID {
// don't follow *yourself* ...
continue
}
// Create new account follow request.
followReq := &gtsmodel.FollowRequest{
ID: id.NewULID(),
URI: "http://127.0.0.1:8080/" + id.NewULID(),
AccountID: account1.ID,
TargetAccountID: account2.ID,
}
// Attempt to place the follow in database (if not already).
if err := suite.db.PutFollowRequest(ctx, followReq); err != nil {
if err != db.ErrAlreadyExists {
// Unrecoverable database error.
t.Fatalf("error creating follow request: %v", err)
}
// Fetch existing follow request from database between accounts.
followReq, _ = suite.db.GetFollowRequest(ctx, account1.ID, account2.ID)
continue
}
// Append generated follow request to test cases.
testFollowReqs = append(testFollowReqs, followReq)
}
}
for _, followReq := range testFollowReqs {
for lookup, dbfunc := range map[string]func() (*gtsmodel.FollowRequest, error){
"id": func() (*gtsmodel.FollowRequest, error) {
return suite.db.GetFollowRequestByID(ctx, followReq.ID)
},
"uri": func() (*gtsmodel.FollowRequest, error) {
return suite.db.GetFollowRequestByURI(ctx, followReq.URI)
},
"origin_target": func() (*gtsmodel.FollowRequest, error) {
return suite.db.GetFollowRequest(ctx, followReq.AccountID, followReq.TargetAccountID)
},
} {
// Clear database caches.
suite.state.Caches.Init()
t.Logf("checking database lookup %q", lookup)
// Perform database function.
checkFollowReq, err := dbfunc()
if err != nil {
if err == sentinelErr {
continue
}
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
continue
}
// Check received follow request data.
if !isEqual(*checkFollowReq, *followReq) {
t.Errorf("follow request does not contain expected data: %+v", checkFollowReq)
continue
}
// Check that follow request origin account populated.
if checkFollowReq.Account == nil || checkFollowReq.Account.ID != followReq.AccountID {
t.Errorf("follow request origin account not correctly populated for: %+v", checkFollowReq)
continue
}
// Check that follow request target account populated.
if checkFollowReq.TargetAccount == nil || checkFollowReq.TargetAccount.ID != followReq.TargetAccountID {
t.Errorf("follow request target account not correctly populated for: %+v", checkFollowReq)
continue
}
}
}
}
func (suite *RelationshipTestSuite) TestIsBlocked() {
ctx := context.Background()
@ -37,11 +379,11 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
account2 := suite.testAccounts["local_account_2"].ID
// no blocks exist between account 1 and account 2
blocked, err := suite.db.IsBlocked(ctx, account1, account2, false)
blocked, err := suite.db.IsBlocked(ctx, account1, account2)
suite.NoError(err)
suite.False(blocked)
blocked, err = suite.db.IsBlocked(ctx, account2, account1, false)
blocked, err = suite.db.IsBlocked(ctx, account2, account1)
suite.NoError(err)
suite.False(blocked)
@ -56,45 +398,24 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
}
// account 1 now blocks account 2
blocked, err = suite.db.IsBlocked(ctx, account1, account2, false)
blocked, err = suite.db.IsBlocked(ctx, account1, account2)
suite.NoError(err)
suite.True(blocked)
// account 2 doesn't block account 1
blocked, err = suite.db.IsBlocked(ctx, account2, account1, false)
blocked, err = suite.db.IsBlocked(ctx, account2, account1)
suite.NoError(err)
suite.False(blocked)
// a block exists in either direction between the two
blocked, err = suite.db.IsBlocked(ctx, account1, account2, true)
blocked, err = suite.db.IsEitherBlocked(ctx, account1, account2)
suite.NoError(err)
suite.True(blocked)
blocked, err = suite.db.IsBlocked(ctx, account2, account1, true)
blocked, err = suite.db.IsEitherBlocked(ctx, account2, account1)
suite.NoError(err)
suite.True(blocked)
}
func (suite *RelationshipTestSuite) TestGetBlock() {
ctx := context.Background()
account1 := suite.testAccounts["local_account_1"].ID
account2 := suite.testAccounts["local_account_2"].ID
if err := suite.db.PutBlock(ctx, &gtsmodel.Block{
ID: "01G202BCSXXJZ70BHB5KCAHH8C",
URI: "http://localhost:8080/some_block_uri_1",
AccountID: account1,
TargetAccountID: account2,
}); err != nil {
suite.FailNow(err.Error())
}
block, err := suite.db.GetBlock(ctx, account1, account2)
suite.NoError(err)
suite.NotNil(block)
suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
}
func (suite *RelationshipTestSuite) TestDeleteBlockByID() {
ctx := context.Background()
@ -157,7 +478,7 @@ func (suite *RelationshipTestSuite) TestDeleteBlockByURI() {
suite.Nil(block)
}
func (suite *RelationshipTestSuite) TestDeleteBlocksByOriginAccountID() {
func (suite *RelationshipTestSuite) TestDeleteAccountBlocks() {
ctx := context.Background()
// put a block in first
@ -179,38 +500,7 @@ func (suite *RelationshipTestSuite) TestDeleteBlocksByOriginAccountID() {
suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
// delete the block by originAccountID
err = suite.db.DeleteBlocksByOriginAccountID(ctx, account1)
suite.NoError(err)
// block should be gone
block, err = suite.db.GetBlock(ctx, account1, account2)
suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(block)
}
func (suite *RelationshipTestSuite) TestDeleteBlocksByTargetAccountID() {
ctx := context.Background()
// put a block in first
account1 := suite.testAccounts["local_account_1"].ID
account2 := suite.testAccounts["local_account_2"].ID
if err := suite.db.PutBlock(ctx, &gtsmodel.Block{
ID: "01G202BCSXXJZ70BHB5KCAHH8C",
URI: "http://localhost:8080/some_block_uri_1",
AccountID: account1,
TargetAccountID: account2,
}); err != nil {
suite.FailNow(err.Error())
}
// make sure the block is in the db
block, err := suite.db.GetBlock(ctx, account1, account2)
suite.NoError(err)
suite.NotNil(block)
suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
// delete the block by targetAccountID
err = suite.db.DeleteBlocksByTargetAccountID(ctx, account2)
err = suite.db.DeleteAccountBlocks(ctx, account1)
suite.NoError(err)
// block should be gone
@ -244,7 +534,7 @@ func (suite *RelationshipTestSuite) TestGetRelationship() {
func (suite *RelationshipTestSuite) TestIsFollowingYes() {
requestingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["admin_account"]
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.True(isFollowing)
}
@ -252,7 +542,7 @@ func (suite *RelationshipTestSuite) TestIsFollowingYes() {
func (suite *RelationshipTestSuite) TestIsFollowingNo() {
requestingAccount := suite.testAccounts["admin_account"]
targetAccount := suite.testAccounts["local_account_2"]
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.False(isFollowing)
}
@ -260,7 +550,7 @@ func (suite *RelationshipTestSuite) TestIsFollowingNo() {
func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
requestingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["admin_account"]
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.True(isMutualFollowing)
}
@ -268,7 +558,7 @@ func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
func (suite *RelationshipTestSuite) TestIsMutualFollowingNo() {
requestingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["local_account_2"]
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.True(isMutualFollowing)
}
@ -306,7 +596,7 @@ func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() {
suite.Equal(followRequest.URI, follow.URI)
// Ensure notification is deleted.
notification, err := suite.db.GetNotification(ctx, followRequestNotification.ID)
notification, err := suite.db.GetNotificationByID(ctx, followRequestNotification.ID)
suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(notification)
}
@ -389,7 +679,7 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() {
TargetAccountID: targetAccount.ID,
}
if err := suite.db.Put(ctx, followRequest); err != nil {
if err := suite.db.PutFollowRequest(ctx, followRequest); err != nil {
suite.FailNow(err.Error())
}
@ -404,12 +694,11 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() {
suite.FailNow(err.Error())
}
rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
suite.NoError(err)
suite.NotNil(rejectedFollowRequest)
// Ensure notification is deleted.
notification, err := suite.db.GetNotification(ctx, followRequestNotification.ID)
notification, err := suite.db.GetNotificationByID(ctx, followRequestNotification.ID)
suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(notification)
}
@ -419,9 +708,8 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestNotExisting() {
account := suite.testAccounts["admin_account"]
targetAccount := suite.testAccounts["local_account_2"]
rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(rejectedFollowRequest)
}
func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {
@ -440,42 +728,49 @@ func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {
suite.FailNow(err.Error())
}
followRequests, err := suite.db.GetFollowRequests(ctx, "", targetAccount.ID)
followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID)
suite.NoError(err)
suite.Len(followRequests, 1)
}
func (suite *RelationshipTestSuite) TestGetAccountFollows() {
account := suite.testAccounts["local_account_1"]
follows, err := suite.db.GetFollows(context.Background(), account.ID, "")
follows, err := suite.db.GetAccountFollows(context.Background(), account.ID)
suite.NoError(err)
suite.Len(follows, 2)
}
func (suite *RelationshipTestSuite) TestCountAccountFollows() {
func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() {
account := suite.testAccounts["local_account_1"]
followsCount, err := suite.db.CountFollows(context.Background(), account.ID, "")
followsCount, err := suite.db.CountAccountLocalFollows(context.Background(), account.ID)
suite.NoError(err)
suite.Equal(2, followsCount)
}
func (suite *RelationshipTestSuite) TestGetAccountFollowedBy() {
func (suite *RelationshipTestSuite) TestCountAccountFollows() {
account := suite.testAccounts["local_account_1"]
follows, err := suite.db.GetFollows(context.Background(), "", account.ID)
followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID)
suite.NoError(err)
suite.Equal(2, followsCount)
}
func (suite *RelationshipTestSuite) TestGetAccountFollowers() {
account := suite.testAccounts["local_account_1"]
follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID)
suite.NoError(err)
suite.Len(follows, 2)
}
func (suite *RelationshipTestSuite) TestGetLocalFollowersIDs() {
func (suite *RelationshipTestSuite) TestCountAccountFollowers() {
account := suite.testAccounts["local_account_1"]
accountIDs, err := suite.db.GetLocalFollowersIDs(context.Background(), account.ID)
followsCount, err := suite.db.CountAccountFollowers(context.Background(), account.ID)
suite.NoError(err)
suite.EqualValues([]string{"01F8MH5NBDF2MV7CTC4Q5128HF", "01F8MH17FWEB39HZJ76B6VXSKF"}, accountIDs)
suite.Equal(2, followsCount)
}
func (suite *RelationshipTestSuite) TestCountAccountFollowedBy() {
func (suite *RelationshipTestSuite) TestCountAccountFollowersLocalOnly() {
account := suite.testAccounts["local_account_1"]
followsCount, err := suite.db.CountFollows(context.Background(), "", account.ID)
followsCount, err := suite.db.CountAccountLocalFollowers(context.Background(), account.ID)
suite.NoError(err)
suite.Equal(2, followsCount)
}
@ -484,18 +779,25 @@ func (suite *RelationshipTestSuite) TestUnfollowExisting() {
originAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["admin_account"]
uri, err := suite.db.Unfollow(context.Background(), originAccount.ID, targetAccount.ID)
follow, err := suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.Equal("http://localhost:8080/users/the_mighty_zork/follow/01F8PY8RHWRQZV038T4E8T9YK8", uri)
suite.NotNil(follow)
err = suite.db.DeleteFollowByID(context.Background(), follow.ID)
suite.NoError(err)
follow, err = suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID)
suite.EqualError(err, db.ErrNoEntries.Error())
suite.Nil(follow)
}
func (suite *RelationshipTestSuite) TestUnfollowNotExisting() {
originAccount := suite.testAccounts["local_account_1"]
targetAccountID := "01GTVD9N484CZ6AM90PGGNY7GQ"
uri, err := suite.db.Unfollow(context.Background(), originAccount.ID, targetAccountID)
suite.NoError(err)
suite.Empty(uri)
follow, err := suite.db.GetFollow(context.Background(), originAccount.ID, targetAccountID)
suite.EqualError(err, db.ErrNoEntries.Error())
suite.Nil(follow)
}
func (suite *RelationshipTestSuite) TestUnfollowRequestExisting() {
@ -510,22 +812,29 @@ func (suite *RelationshipTestSuite) TestUnfollowRequestExisting() {
TargetAccountID: targetAccount.ID,
}
if err := suite.db.Put(ctx, followRequest); err != nil {
if err := suite.db.PutFollowRequest(ctx, followRequest); err != nil {
suite.FailNow(err.Error())
}
uri, err := suite.db.UnfollowRequest(context.Background(), originAccount.ID, targetAccount.ID)
followRequest, err := suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.Equal("http://localhost:8080/weeeeeeeeeeeeeeeee", uri)
suite.NotNil(followRequest)
err = suite.db.DeleteFollowRequestByID(context.Background(), followRequest.ID)
suite.NoError(err)
followRequest, err = suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccount.ID)
suite.EqualError(err, db.ErrNoEntries.Error())
suite.Nil(followRequest)
}
func (suite *RelationshipTestSuite) TestUnfollowRequestNotExisting() {
originAccount := suite.testAccounts["local_account_1"]
targetAccountID := "01GTVD9N484CZ6AM90PGGNY7GQ"
uri, err := suite.db.UnfollowRequest(context.Background(), originAccount.ID, targetAccountID)
suite.NoError(err)
suite.Empty(uri)
followRequest, err := suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccountID)
suite.EqualError(err, db.ErrNoEntries.Error())
suite.Nil(followRequest)
}
func TestRelationshipTestSuite(t *testing.T) {

View file

@ -26,6 +26,7 @@ import (
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@ -41,7 +42,6 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
return s.conn.
NewSelect().
Model(status).
Relation("Attachments").
Relation("Tags").
Relation("CreatedWithApplication")
}
@ -102,81 +102,143 @@ func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*g
status, err := s.state.Caches.GTS.Status().Load(lookup, func() (*gtsmodel.Status, error) {
var status gtsmodel.Status
// Not cached! Perform database query
// Not cached! Perform database query.
if err := dbQuery(&status); err != nil {
return nil, s.conn.ProcessError(err)
}
if status.InReplyToID != "" {
// Also load in-reply-to status
status.InReplyTo = new(gtsmodel.Status)
err := s.conn.NewSelect().Model(status.InReplyTo).
Where("? = ?", bun.Ident("status.id"), status.InReplyToID).
Scan(ctx)
if err != nil {
return nil, s.conn.ProcessError(err)
}
}
if status.BoostOfID != "" {
// Also load original boosted status
status.BoostOf = new(gtsmodel.Status)
err := s.conn.NewSelect().Model(status.BoostOf).
Where("? = ?", bun.Ident("status.id"), status.BoostOfID).
Scan(ctx)
if err != nil {
return nil, s.conn.ProcessError(err)
}
}
return &status, nil
}, keyParts...)
if err != nil {
// error already processed
return nil, err
}
// Set the status author account
status.Account, err = s.state.DB.GetAccountByID(ctx, status.AccountID)
if err != nil {
return nil, fmt.Errorf("error getting status account: %w", err)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return status, nil
}
if id := status.BoostOfAccountID; id != "" {
// Set boost of status' author account
status.BoostOfAccount, err = s.state.DB.GetAccountByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("error getting boosted status account: %w", err)
}
}
if id := status.InReplyToAccountID; id != "" {
// Set in-reply-to status' author account
status.InReplyToAccount, err = s.state.DB.GetAccountByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("error getting in reply to status account: %w", err)
}
}
if len(status.EmojiIDs) > 0 {
// Fetch status emojis
status.Emojis, err = s.state.DB.GetEmojisByIDs(ctx, status.EmojiIDs)
if err != nil {
return nil, fmt.Errorf("error getting status emojis: %w", err)
}
}
if len(status.MentionIDs) > 0 {
// Fetch status mentions
status.Mentions, err = s.state.DB.GetMentions(ctx, status.MentionIDs)
if err != nil {
return nil, fmt.Errorf("error getting status mentions: %w", err)
}
// Further populate the status fields where applicable.
if err := s.PopulateStatus(ctx, status); err != nil {
return nil, err
}
return status, nil
}
func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) error {
var err error
if status.Account == nil {
// Status author is not set, fetch from database.
status.Account, err = s.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
status.AccountID,
)
if err != nil {
return fmt.Errorf("error populating status author: %w", err)
}
}
if status.InReplyToID != "" && status.InReplyTo == nil {
// Status parent is not set, fetch from database.
status.InReplyTo, err = s.GetStatusByID(
gtscontext.SetBarebones(ctx),
status.InReplyToID,
)
if err != nil {
return fmt.Errorf("error populating status parent: %w", err)
}
}
if status.InReplyToID != "" {
if status.InReplyTo == nil {
// Status parent is not set, fetch from database.
status.InReplyTo, err = s.GetStatusByID(
gtscontext.SetBarebones(ctx),
status.InReplyToID,
)
if err != nil {
return fmt.Errorf("error populating status parent: %w", err)
}
}
if status.InReplyToAccount == nil {
// Status parent author is not set, fetch from database.
status.InReplyToAccount, err = s.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
status.InReplyToAccountID,
)
if err != nil {
return fmt.Errorf("error populating status parent author: %w", err)
}
}
}
if status.BoostOfID != "" {
if status.BoostOf == nil {
// Status boost is not set, fetch from database.
status.BoostOf, err = s.GetStatusByID(
gtscontext.SetBarebones(ctx),
status.BoostOfID,
)
if err != nil {
return fmt.Errorf("error populating status boost: %w", err)
}
}
if status.BoostOfAccount == nil {
// Status boost author is not set, fetch from database.
status.BoostOfAccount, err = s.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
status.BoostOfAccountID,
)
if err != nil {
return fmt.Errorf("error populating status boost author: %w", err)
}
}
}
if !status.AttachmentsPopulated() {
// Status attachments are out-of-date with IDs, repopulate.
status.Attachments, err = s.state.DB.GetAttachmentsByIDs(
ctx, // these are already barebones
status.AttachmentIDs,
)
if err != nil {
return fmt.Errorf("error populating status attachments: %w", err)
}
}
// TODO: once we don't fetch using relations.
// if !status.TagsPopulated() {
// }
if !status.MentionsPopulated() {
// Status mentions are out-of-date with IDs, repopulate.
status.Mentions, err = s.state.DB.GetMentions(
ctx, // leave fully populated for now
status.MentionIDs,
)
if err != nil {
return fmt.Errorf("error populating status mentions: %w", err)
}
}
if !status.EmojisPopulated() {
// Status emojis are out-of-date with IDs, repopulate.
status.Emojis, err = s.state.DB.GetEmojisByIDs(
ctx, // these are already barebones
status.EmojiIDs,
)
if err != nil {
return fmt.Errorf("error populating status emojis: %w", err)
}
}
return nil
}
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
err := s.state.Caches.GTS.Status().Store(status, func() error {
// It is safe to run this database transaction within cache.Store
@ -239,12 +301,16 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
})
})
if err != nil {
// already processed
return err
}
for _, id := range status.AttachmentIDs {
// Clear updated media attachment IDs from cache
// Invalidate media attachments from cache.
//
// NOTE: this is needed due to the way in which
// we upload status attachments, and only after
// update them with a known status ID. This is
// not the case for header/avatar attachments.
s.state.Caches.GTS.Media().Invalidate("ID", id)
}
@ -322,14 +388,19 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
return err
}
// Invalidate status from database lookups.
s.state.Caches.GTS.Status().Invalidate("ID", status.ID)
for _, id := range status.AttachmentIDs {
// Clear updated media attachment IDs from cache
// Invalidate media attachments from cache.
//
// NOTE: this is needed due to the way in which
// we upload status attachments, and only after
// update them with a known status ID. This is
// not the case for header/avatar attachments.
s.state.Caches.GTS.Media().Invalidate("ID", id)
}
// Drop any old status value from cache by this ID
s.state.Caches.GTS.Status().Invalidate("ID", status.ID)
return nil
}
@ -367,8 +438,12 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
return err
}
// Drop any old value from cache by this ID
// Invalidate status from database lookups.
s.state.Caches.GTS.Status().Invalidate("ID", id)
// Invalidate status from all visibility lookups.
s.state.Caches.Visibility.Invalidate("ItemID", id)
return nil
}

View file

@ -23,6 +23,7 @@ import (
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@ -34,29 +35,82 @@ type statusFaveDB struct {
state *state.State
}
func (s *statusFaveDB) GetStatusFave(ctx context.Context, id string) (*gtsmodel.StatusFave, db.Error) {
fave := new(gtsmodel.StatusFave)
func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, db.Error) {
return s.getStatusFave(
ctx,
"AccountID.StatusID",
func(fave *gtsmodel.StatusFave) error {
return s.conn.
NewSelect().
Model(fave).
Where("? = ?", bun.Ident("account_id"), accountID).
Where("? = ?", bun.Ident("status_id"), statusID).
Scan(ctx)
},
accountID,
statusID,
)
}
err := s.conn.
NewSelect().
Model(fave).
Where("? = ?", bun.Ident("status_fave.ID"), id).
Scan(ctx)
func (s *statusFaveDB) GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, db.Error) {
return s.getStatusFave(
ctx,
"ID",
func(fave *gtsmodel.StatusFave) error {
return s.conn.
NewSelect().
Model(fave).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx)
},
id,
)
}
func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery func(*gtsmodel.StatusFave) error, keyParts ...any) (*gtsmodel.StatusFave, error) {
// Fetch status fave from database cache with loader callback
fave, err := s.state.Caches.GTS.StatusFave().Load(lookup, func() (*gtsmodel.StatusFave, error) {
var fave gtsmodel.StatusFave
// Not cached! Perform database query.
if err := dbQuery(&fave); err != nil {
return nil, s.conn.ProcessError(err)
}
return &fave, nil
}, keyParts...)
if err != nil {
return nil, s.conn.ProcessError(err)
return nil, err
}
fave.Account, err = s.state.DB.GetAccountByID(ctx, fave.AccountID)
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return fave, nil
}
// Fetch the status fave author account.
fave.Account, err = s.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
fave.AccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting status fave account %q: %w", fave.AccountID, err)
}
fave.TargetAccount, err = s.state.DB.GetAccountByID(ctx, fave.TargetAccountID)
// Fetch the status fave target account.
fave.TargetAccount, err = s.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
fave.TargetAccountID,
)
if err != nil {
return nil, fmt.Errorf("error getting status fave target account %q: %w", fave.TargetAccountID, err)
}
fave.Status, err = s.state.DB.GetStatusByID(ctx, fave.StatusID)
// Fetch the status fave target status.
fave.Status, err = s.state.DB.GetStatusByID(
gtscontext.SetBarebones(ctx),
fave.StatusID,
)
if err != nil {
return nil, fmt.Errorf("error getting status fave status %q: %w", fave.StatusID, err)
}
@ -64,38 +118,22 @@ func (s *statusFaveDB) GetStatusFave(ctx context.Context, id string) (*gtsmodel.
return fave, nil
}
func (s *statusFaveDB) GetStatusFaveByAccountID(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, db.Error) {
var id string
err := s.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
Column("status_fave.id").
Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
Where("? = ?", bun.Ident("status_fave.status_id"), statusID).
Scan(ctx, &id)
if err != nil {
return nil, s.conn.ProcessError(err)
}
return s.GetStatusFave(ctx, id)
}
func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, db.Error) {
func (s *statusFaveDB) GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, db.Error) {
ids := []string{}
if err := s.conn.
NewSelect().
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
Column("status_fave.id").
Where("? = ?", bun.Ident("status_fave.status_id"), statusID).
Table("status_faves").
Column("id").
Where("? = ?", bun.Ident("status_id"), statusID).
Scan(ctx, &ids); err != nil {
return nil, s.conn.ProcessError(err)
}
faves := make([]*gtsmodel.StatusFave, 0, len(ids))
for _, id := range ids {
fave, err := s.GetStatusFave(ctx, id)
fave, err := s.GetStatusFaveByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting status fave %q: %v", id, err)
continue
@ -107,23 +145,27 @@ func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*
return faves, nil
}
func (s *statusFaveDB) PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) db.Error {
_, err := s.conn.
NewInsert().
Model(statusFave).
Exec(ctx)
return s.conn.ProcessError(err)
func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) db.Error {
return s.state.Caches.GTS.StatusFave().Store(fave, func() error {
_, err := s.conn.
NewInsert().
Model(fave).
Exec(ctx)
return s.conn.ProcessError(err)
})
}
func (s *statusFaveDB) DeleteStatusFave(ctx context.Context, id string) db.Error {
_, err := s.conn.
func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) db.Error {
if _, err := s.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
Where("? = ?", bun.Ident("status_fave.id"), id).
Exec(ctx)
Table("status_faves").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return s.conn.ProcessError(err)
}
return s.conn.ProcessError(err)
s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
return nil
}
func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) db.Error {
@ -131,42 +173,52 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st
return errors.New("DeleteStatusFaves: one of targetAccountID or originAccountID must be set")
}
// TODO: Capture fave IDs in a RETURNING
// statement (when faves have a cache),
// + use the IDs to invalidate cache entries.
// Capture fave IDs in a RETURNING statement.
var faveIDs []string
q := s.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave"))
Table("status_faves").
Returning("?", bun.Ident("id"))
if targetAccountID != "" {
q = q.Where("? = ?", bun.Ident("status_fave.target_account_id"), targetAccountID)
q = q.Where("? = ?", bun.Ident("target_account_id"), targetAccountID)
}
if originAccountID != "" {
q = q.Where("? = ?", bun.Ident("status_fave.account_id"), originAccountID)
q = q.Where("? = ?", bun.Ident("account_id"), originAccountID)
}
if _, err := q.Exec(ctx); err != nil {
if _, err := q.Exec(ctx, &faveIDs); err != nil {
return s.conn.ProcessError(err)
}
for _, id := range faveIDs {
// Invalidate each of the returned status fave IDs.
s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
}
return nil
}
func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID string) db.Error {
// TODO: Capture fave IDs in a RETURNING
// statement (when faves have a cache),
// + use the IDs to invalidate cache entries.
// Capture fave IDs in a RETURNING statement.
var faveIDs []string
q := s.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
Where("? = ?", bun.Ident("status_fave.status_id"), statusID)
Table("status_faves").
Where("? = ?", bun.Ident("status_id"), statusID).
Returning("?", bun.Ident("id"))
if _, err := q.Exec(ctx); err != nil {
if _, err := q.Exec(ctx, &faveIDs); err != nil {
return s.conn.ProcessError(err)
}
for _, id := range faveIDs {
// Invalidate each of the returned status fave IDs.
s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
}
return nil
}

View file

@ -35,7 +35,7 @@ type StatusFaveTestSuite struct {
func (suite *StatusFaveTestSuite) TestGetStatusFaves() {
testStatus := suite.testStatuses["admin_account_status_1"]
faves, err := suite.db.GetStatusFaves(context.Background(), testStatus.ID)
faves, err := suite.db.GetStatusFavesForStatus(context.Background(), testStatus.ID)
if err != nil {
suite.FailNow(err.Error())
}
@ -51,7 +51,7 @@ func (suite *StatusFaveTestSuite) TestGetStatusFaves() {
func (suite *StatusFaveTestSuite) TestGetStatusFavesNone() {
testStatus := suite.testStatuses["admin_account_status_4"]
faves, err := suite.db.GetStatusFaves(context.Background(), testStatus.ID)
faves, err := suite.db.GetStatusFavesForStatus(context.Background(), testStatus.ID)
if err != nil {
suite.FailNow(err.Error())
}
@ -63,7 +63,7 @@ func (suite *StatusFaveTestSuite) TestGetStatusFaveByAccountID() {
testAccount := suite.testAccounts["local_account_1"]
testStatus := suite.testStatuses["admin_account_status_1"]
fave, err := suite.db.GetStatusFaveByAccountID(context.Background(), testAccount.ID, testStatus.ID)
fave, err := suite.db.GetStatusFave(context.Background(), testAccount.ID, testStatus.ID)
suite.NoError(err)
suite.NotNil(fave)
}
@ -129,17 +129,17 @@ func (suite *StatusFaveTestSuite) TestDeleteStatusFave() {
testFave := suite.testFaves["local_account_1_admin_account_status_1"]
ctx := context.Background()
if err := suite.db.DeleteStatusFave(ctx, testFave.ID); err != nil {
if err := suite.db.DeleteStatusFaveByID(ctx, testFave.ID); err != nil {
suite.FailNow(err.Error())
}
fave, err := suite.db.GetStatusFave(ctx, testFave.ID)
fave, err := suite.db.GetStatusFaveByID(ctx, testFave.ID)
suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(fave)
}
func (suite *StatusFaveTestSuite) TestDeleteStatusFaveNonExisting() {
err := suite.db.DeleteStatusFave(context.Background(), "01GVAV715K6Y2SG9ZKS9ZA8G7G")
err := suite.db.DeleteStatusFaveByID(context.Background(), "01GVAV715K6Y2SG9ZKS9ZA8G7G")
suite.NoError(err)
}

View file

@ -61,9 +61,12 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
Order("status.id DESC")
if maxID == "" {
const future = 24 * time.Hour
var err error
// don't return statuses more than five minutes in the future
maxID, err = id.NewULIDFromTime(time.Now().Add(5 * time.Minute))
// don't return statuses more than 24hr in the future
maxID, err = id.NewULIDFromTime(time.Now().Add(future))
if err != nil {
return nil, err
}
@ -138,15 +141,16 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("status.id").
Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_id")).
WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")).
WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")).
Order("status.id DESC")
if maxID == "" {
const future = 24 * time.Hour
var err error
// don't return statuses more than five minutes in the future
maxID, err = id.NewULIDFromTime(time.Now().Add(5 * time.Minute))
// don't return statuses more than 24hr in the future
maxID, err = id.NewULIDFromTime(time.Now().Add(future))
if err != nil {
return nil, err
}

View file

@ -34,15 +34,32 @@ type TimelineTestSuite struct {
}
func (suite *TimelineTestSuite) TestGetPublicTimeline() {
ctx := context.Background()
var count int
for _, status := range suite.testStatuses {
if status.Visibility == gtsmodel.VisibilityPublic &&
status.BoostOfID == "" {
count++
}
}
ctx := context.Background()
s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
suite.NoError(err)
suite.Len(s, 6)
suite.Len(s, count)
}
func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
var count int
for _, status := range suite.testStatuses {
if status.Visibility == gtsmodel.VisibilityPublic &&
status.BoostOfID == "" {
count++
}
}
ctx := context.Background()
futureStatus := getFutureStatus()
@ -53,7 +70,7 @@ func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
suite.NoError(err)
suite.NotContains(s, futureStatus)
suite.Len(s, 6)
suite.Len(s, count)
}
func (suite *TimelineTestSuite) TestGetHomeTimeline() {

View file

@ -29,6 +29,9 @@ type Media interface {
// GetAttachmentByID gets a single attachment by its ID.
GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, Error)
// GetAttachmentsByIDs fetches a list of media attachments for given IDs.
GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error)
// PutAttachment inserts the given attachment into the database.
PutAttachment(ctx context.Context, media *gtsmodel.MediaAttachment) error

View file

@ -30,4 +30,10 @@ type Mention interface {
// GetMentions gets multiple mentions.
GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, Error)
// PutMention will insert the given mention into the database.
PutMention(ctx context.Context, mention *gtsmodel.Mention) error
// DeleteMentionByID will delete mention with given ID from the database.
DeleteMentionByID(ctx context.Context, id string) error
}

View file

@ -28,14 +28,17 @@ type Notification interface {
// GetNotifications returns a slice of notifications that pertain to the given accountID.
//
// Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest).
GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
GetAccountNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
// GetNotification returns one notification according to its id.
GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, Error)
GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, Error)
// DeleteNotification deletes one notification according to its id,
// PutNotification will insert the given notification into the database.
PutNotification(ctx context.Context, notif *gtsmodel.Notification) error
// DeleteNotificationByID deletes one notification according to its id,
// and removes that notification from the in-memory cache.
DeleteNotification(ctx context.Context, id string) Error
DeleteNotificationByID(ctx context.Context, id string) Error
// DeleteNotifications mass deletes notifications targeting targetAccountID
// and/or originating from originAccountID.
@ -50,7 +53,7 @@ type Notification interface {
// originate from originAccountID will be deleted.
//
// At least one parameter must not be an empty string.
DeleteNotifications(ctx context.Context, targetAccountID string, originAccountID string) Error
DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) Error
// DeleteNotificationsForStatus deletes all notifications that relate to
// the given statusID. This function is useful when a status has been deleted,

View file

@ -25,42 +25,86 @@ import (
// Relationship contains functions for getting or modifying the relationship between two accounts.
type Relationship interface {
// IsBlocked checks whether account 1 has a block in place against account2.
// If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1.
IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, Error)
// IsBlocked checks whether source account has a block in place against target.
IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
// IsEitherBlocked checks whether there is a block in place between either of account1 and account2.
IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error)
// GetBlockByID fetches block with given ID from the database.
GetBlockByID(ctx context.Context, id string) (*gtsmodel.Block, error)
// GetBlockByURI fetches block with given AP URI from the database.
GetBlockByURI(ctx context.Context, uri string) (*gtsmodel.Block, error)
// GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't.
//
// Because this is slower than Blocked, only use it if you need the actual Block struct for some reason,
// not if you're just checking for the existence of a block.
GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, Error)
GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, error)
// PutBlock attempts to place the given account block in the database.
PutBlock(ctx context.Context, block *gtsmodel.Block) Error
PutBlock(ctx context.Context, block *gtsmodel.Block) error
// DeleteBlockByID removes block with given ID from the database.
DeleteBlockByID(ctx context.Context, id string) Error
DeleteBlockByID(ctx context.Context, id string) error
// DeleteBlockByURI removes block with given AP URI from the database.
DeleteBlockByURI(ctx context.Context, uri string) Error
DeleteBlockByURI(ctx context.Context, uri string) error
// DeleteBlocksByOriginAccountID removes any blocks with accountID equal to originAccountID.
DeleteBlocksByOriginAccountID(ctx context.Context, originAccountID string) Error
// DeleteBlocksByTargetAccountID removes any blocks with given targetAccountID.
DeleteBlocksByTargetAccountID(ctx context.Context, targetAccountID string) Error
// DeleteAccountBlocks will delete all database blocks to / from the given account ID.
DeleteAccountBlocks(ctx context.Context, accountID string) error
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
// IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// GetFollowByID fetches follow with given ID from the database.
GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error)
// IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// GetFollowByURI fetches follow with given AP URI from the database.
GetFollowByURI(ctx context.Context, uri string) (*gtsmodel.Follow, error)
// GetFollow retrieves a follow if it exists between source and target accounts.
GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error)
// GetFollowRequestByID fetches follow request with given ID from the database.
GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error)
// GetFollowRequestByURI fetches follow request with given AP URI from the database.
GetFollowRequestByURI(ctx context.Context, uri string) (*gtsmodel.FollowRequest, error)
// GetFollowRequest retrieves a follow request if it exists between source and target accounts.
GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error)
// IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
// IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
IsMutualFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
// IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
// PutFollow attempts to place the given account follow in the database.
PutFollow(ctx context.Context, follow *gtsmodel.Follow) error
// PutFollowRequest attempts to place the given account follow request in the database.
PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error
// DeleteFollowByID deletes a follow from the database with the given ID.
DeleteFollowByID(ctx context.Context, id string) error
// DeleteFollowByURI deletes a follow from the database with the given URI.
DeleteFollowByURI(ctx context.Context, uri string) error
// DeleteFollowRequestByID deletes a follow request from the database with the given ID.
DeleteFollowRequestByID(ctx context.Context, id string) error
// DeleteFollowRequestByURI deletes a follow request from the database with the given URI.
DeleteFollowRequestByURI(ctx context.Context, uri string) error
// DeleteAccountFollows will delete all database follows to / from the given account ID.
DeleteAccountFollows(ctx context.Context, accountID string) error
// DeleteAccountFollowRequests will delete all database follow requests to / from the given account ID.
DeleteAccountFollowRequests(ctx context.Context, accountID string) error
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
// In other words, it should create the follow, and delete the existing follow request.
@ -69,65 +113,41 @@ type Relationship interface {
AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
// RejectFollowRequest fetches a follow request from the database, and then deletes it.
//
// The deleted follow request will be returned so that further processing can be done on it.
RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, Error)
RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) Error
// GetFollows returns a slice of follows owned by the given accountID, and/or
// targeting the given account id.
//
// If accountID is set and targetAccountID isn't, then all follows created by
// accountID will be returned.
//
// If targetAccountID is set and accountID isn't, then all follows targeting
// targetAccountID will be returned.
//
// If both accountID and targetAccountID are set, then only 0 or 1 follows will
// be in the returned slice.
GetFollows(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.Follow, Error)
// GetAccountFollows returns a slice of follows owned by the given accountID.
GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
// GetLocalFollowersIDs returns a list of local account IDs which follow the
// targetAccountID. The returned IDs are not guaranteed to be ordered in any
// particular way, so take care.
GetLocalFollowersIDs(ctx context.Context, targetAccountID string) ([]string, Error)
// GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance.
GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
// CountFollows is like GetFollows, but just counts rather than returning.
CountFollows(ctx context.Context, accountID string, targetAccountID string) (int, Error)
// CountAccountFollows returns the amount of accounts that the given accountID is following.
CountAccountFollows(ctx context.Context, accountID string) (int, error)
// GetFollowRequests returns a slice of follows requests owned by the given
// accountID, and/or targeting the given account id.
//
// If accountID is set and targetAccountID isn't, then all requests created by
// accountID will be returned.
//
// If targetAccountID is set and accountID isn't, then all requests targeting
// targetAccountID will be returned.
//
// If both accountID and targetAccountID are set, then only 0 or 1 requests will
// be in the returned slice.
GetFollowRequests(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.FollowRequest, Error)
// CountAccountLocalFollows returns the amount of accounts that the given accountID is following, only including follows from this instance.
CountAccountLocalFollows(ctx context.Context, accountID string) (int, error)
// CountFollowRequests is like GetFollowRequests, but just counts rather than returning.
CountFollowRequests(ctx context.Context, accountID string, targetAccountID string) (int, Error)
// GetAccountFollowers fetches follows that target given accountID.
GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
// Unfollow removes a follow targeting targetAccountID and originating
// from originAccountID.
//
// If a follow was removed this way, the AP URI of the follow will be
// returned to the caller, so that further processing can take place
// if necessary.
//
// If no follow was removed this way, the returned string will be empty.
Unfollow(ctx context.Context, originAccountID string, targetAccountID string) (string, Error)
// GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance.
GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
// UnfollowRequest removes a follow request targeting targetAccountID
// and originating from originAccountID.
//
// If a follow request was removed this way, the AP URI of the follow
// request will be returned to the caller, so that further processing
// can take place if necessary.
//
// If no follow request was removed this way, the returned string will
// be empty.
UnfollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (string, Error)
// CountAccountFollowers returns the amounts that the given ID is followed by.
CountAccountFollowers(ctx context.Context, accountID string) (int, error)
// CountAccountLocalFollowers returns the amounts that the given ID is followed by, only including follows from this instance.
CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error)
// GetAccountFollowRequests returns all follow requests targeting the given account.
GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error)
// GetAccountFollowRequesting returns all follow requests originating from the given account.
GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error)
// CountAccountFollowRequests returns number of follow requests targeting the given account.
CountAccountFollowRequests(ctx context.Context, accountID string) (int, error)
// CountAccountFollowerRequests returns number of follow requests originating from the given account.
CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error)
}

View file

@ -37,6 +37,9 @@ type Status interface {
// GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error)
// PopulateStatus ensures that all sub-models of a status are populated (e.g. mentions, attachments, etc).
PopulateStatus(ctx context.Context, status *gtsmodel.Status) error
// PutStatus stores one status in the database.
PutStatus(ctx context.Context, status *gtsmodel.Status) Error

View file

@ -24,22 +24,22 @@ import (
)
type StatusFave interface {
// GetStatusFave returns one status fave with the given id.
GetStatusFave(ctx context.Context, id string) (*gtsmodel.StatusFave, Error)
// GetStatusFaveByAccountID gets one status fave created by the given
// accountID, targeting the given statusID.
GetStatusFaveByAccountID(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, Error)
GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, Error)
// GetStatusFave returns one status fave with the given id.
GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, Error)
// GetStatusFaves returns a slice of faves/likes of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
GetStatusFaves(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, Error)
GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, Error)
// PutStatusFave inserts the given statusFave into the database.
PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) Error
// DeleteStatusFave deletes one status fave with the given id.
DeleteStatusFave(ctx context.Context, id string) Error
DeleteStatusFaveByID(ctx context.Context, id string) Error
// DeleteStatusFaves mass deletes status faves targeting targetAccountID
// and/or originating from originAccountID and/or faving statusID.