mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-10-29 10:12:26 -05:00
[chore] Standardize database queries, use bun.Ident() properly (#886)
* use bun.Ident for user queries * use bun.Ident for account queries * use bun.Ident for media queries * add DeleteAccount func * remove CaseInsensitive in Where+use Ident ipv Safe * update admin db * update domain, use ident * update emoji, use ident * update instance queries, use bun.Ident * fix media * update mentions, use bun ident * update relationship + tests * use tableexpr * add test follows to bun db test suite * update notifications * updatebyprimarykey => updatebyid * fix session * prefer explicit ID to pk * fix little fucky wucky * remove workaround * use proper db func for attachment selection * update status db * add m2m entries in test rig * fix up timeline * go fmt * fix status put issue * update GetAccountStatuses
This commit is contained in:
parent
e58a6a2da3
commit
aa07750bdb
45 changed files with 1074 additions and 570 deletions
|
|
@ -21,7 +21,6 @@ package bundb
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
|
@ -56,7 +55,7 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac
|
|||
return a.cache.GetByID(id)
|
||||
},
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx)
|
||||
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
|
@ -68,7 +67,7 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
|
|||
return a.cache.GetByURI(uri)
|
||||
},
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx)
|
||||
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
|
@ -80,7 +79,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
|
|||
return a.cache.GetByURL(url)
|
||||
},
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx)
|
||||
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
|
@ -95,11 +94,11 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
|
|||
q := a.newAccountQ(account)
|
||||
|
||||
if domain != "" {
|
||||
q = q.Where("account.username = ?", username)
|
||||
q = q.Where("account.domain = ?", domain)
|
||||
q = q.Where("? = ?", bun.Ident("account.username"), username)
|
||||
q = q.Where("? = ?", bun.Ident("account.domain"), domain)
|
||||
} else {
|
||||
q = q.Where("account.username = ?", strings.ToLower(username))
|
||||
q = q.Where("account.domain IS NULL")
|
||||
q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username))
|
||||
q = q.Where("? IS NULL", bun.Ident("account.domain"))
|
||||
}
|
||||
|
||||
return q.Scan(ctx)
|
||||
|
|
@ -114,7 +113,7 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo
|
|||
return a.cache.GetByPubkeyID(id)
|
||||
},
|
||||
func(account *gtsmodel.Account) error {
|
||||
return a.newAccountQ(account).Where("account.public_key_uri = ?", id).Scan(ctx)
|
||||
return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
|
@ -169,26 +168,36 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
|
|||
if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error {
|
||||
// create links between this account and any emojis it uses
|
||||
// first clear out any old emoji links
|
||||
if _, err := tx.NewDelete().
|
||||
Model(&[]*gtsmodel.AccountToEmoji{}).
|
||||
Where("account_id = ?", account.ID).
|
||||
if _, err := tx.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")).
|
||||
Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// now populate new emoji links
|
||||
for _, i := range account.EmojiIDs {
|
||||
if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{
|
||||
AccountID: account.ID,
|
||||
EmojiID: i,
|
||||
}).Exec(ctx); err != nil {
|
||||
if _, err := tx.
|
||||
NewInsert().
|
||||
Model(>smodel.AccountToEmoji{
|
||||
AccountID: account.ID,
|
||||
EmojiID: i,
|
||||
}).Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// update the account
|
||||
_, err := tx.NewUpdate().Model(account).WherePK().Exec(ctx)
|
||||
return err
|
||||
if _, err := tx.
|
||||
NewUpdate().
|
||||
Model(account).
|
||||
Where("? = ?", bun.Ident("account.id"), account.ID).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
|
|
@ -197,6 +206,32 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
|
|||
return account, nil
|
||||
}
|
||||
|
||||
func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
|
||||
if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error {
|
||||
// clear out any emoji links
|
||||
if _, err := tx.
|
||||
NewDelete().
|
||||
TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")).
|
||||
Where("? = ?", bun.Ident("account_to_emoji.account_id"), id).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete the account
|
||||
_, err := tx.
|
||||
NewUpdate().
|
||||
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||
Where("? = ?", bun.Ident("account.id"), id).
|
||||
Exec(ctx)
|
||||
return err
|
||||
}); err != nil {
|
||||
return a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
a.cache.Invalidate(id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
|
||||
account := new(gtsmodel.Account)
|
||||
|
||||
|
|
@ -204,11 +239,11 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
|
|||
|
||||
if domain != "" {
|
||||
q = q.
|
||||
Where("account.username = ?", domain).
|
||||
Where("account.domain = ?", domain)
|
||||
Where("? = ?", bun.Ident("account.username"), domain).
|
||||
Where("? = ?", bun.Ident("account.domain"), domain)
|
||||
} else {
|
||||
q = q.
|
||||
Where("account.username = ?", config.GetHost()).
|
||||
Where("? = ?", bun.Ident("account.username"), config.GetHost()).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
||||
}
|
||||
|
||||
|
|
@ -224,10 +259,10 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string)
|
|||
q := a.conn.
|
||||
NewSelect().
|
||||
Model(status).
|
||||
Order("id DESC").
|
||||
Limit(1).
|
||||
Where("account_id = ?", accountID).
|
||||
Column("created_at")
|
||||
Column("status.created_at").
|
||||
Where("? = ?", bun.Ident("status.account_id"), accountID).
|
||||
Order("status.id DESC").
|
||||
Limit(1)
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return time.Time{}, a.conn.ProcessError(err)
|
||||
|
|
@ -240,12 +275,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
|
|||
return errors.New("one media attachment cannot be both header and avatar")
|
||||
}
|
||||
|
||||
var headerOrAVI string
|
||||
var column bun.Ident
|
||||
switch {
|
||||
case *mediaAttachment.Avatar:
|
||||
headerOrAVI = "avatar"
|
||||
column = bun.Ident("account.avatar_media_attachment_id")
|
||||
case *mediaAttachment.Header:
|
||||
headerOrAVI = "header"
|
||||
column = bun.Ident("account.header_media_attachment_id")
|
||||
default:
|
||||
return errors.New("given media attachment was neither a header nor an avatar")
|
||||
}
|
||||
|
|
@ -257,11 +292,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
|
|||
Exec(ctx); err != nil {
|
||||
return a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
if _, err := a.conn.
|
||||
NewUpdate().
|
||||
Model(>smodel.Account{}).
|
||||
Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).
|
||||
Where("id = ?", accountID).
|
||||
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||
Set("? = ?", column, mediaAttachment.ID).
|
||||
Where("? = ?", bun.Ident("account.id"), accountID).
|
||||
Exec(ctx); err != nil {
|
||||
return a.conn.ProcessError(err)
|
||||
}
|
||||
|
|
@ -284,7 +320,7 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
|
|||
if err := a.conn.
|
||||
NewSelect().
|
||||
Model(faves).
|
||||
Where("account_id = ?", accountID).
|
||||
Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
|
||||
Scan(ctx); err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
|
|
@ -295,8 +331,8 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
|
|||
func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) {
|
||||
return a.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Status{}).
|
||||
Where("account_id = ?", accountID).
|
||||
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
|
||||
Where("? = ?", bun.Ident("status.account_id"), accountID).
|
||||
Count(ctx)
|
||||
}
|
||||
|
||||
|
|
@ -305,12 +341,12 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
|
|||
|
||||
q := a.conn.
|
||||
NewSelect().
|
||||
Table("statuses").
|
||||
Column("id").
|
||||
Order("id DESC")
|
||||
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
|
||||
Column("status.id").
|
||||
Order("status.id DESC")
|
||||
|
||||
if accountID != "" {
|
||||
q = q.Where("account_id = ?", accountID)
|
||||
q = q.Where("? = ?", bun.Ident("status.account_id"), accountID)
|
||||
}
|
||||
|
||||
if limit != 0 {
|
||||
|
|
@ -321,27 +357,27 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
|
|||
// include self-replies (threads)
|
||||
whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
|
||||
return q.
|
||||
WhereOr("in_reply_to_account_id = ?", accountID).
|
||||
WhereGroup(" OR ", whereEmptyOrNull("in_reply_to_uri"))
|
||||
WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID).
|
||||
WhereGroup(" OR ", whereEmptyOrNull("status.in_reply_to_uri"))
|
||||
}
|
||||
|
||||
q = q.WhereGroup(" AND ", whereGroup)
|
||||
}
|
||||
|
||||
if excludeReblogs {
|
||||
q = q.WhereGroup(" AND ", whereEmptyOrNull("boost_of_id"))
|
||||
q = q.WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id"))
|
||||
}
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("id < ?", maxID)
|
||||
q = q.Where("? < ?", bun.Ident("status.id"), maxID)
|
||||
}
|
||||
|
||||
if minID != "" {
|
||||
q = q.Where("id > ?", minID)
|
||||
q = q.Where("? > ?", bun.Ident("status.id"), minID)
|
||||
}
|
||||
|
||||
if pinnedOnly {
|
||||
q = q.Where("pinned = ?", true)
|
||||
q = q.Where("? = ?", bun.Ident("status.pinned"), true)
|
||||
}
|
||||
|
||||
if mediaOnly {
|
||||
|
|
@ -352,15 +388,15 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
|
|||
switch a.conn.Dialect().Name() {
|
||||
case dialect.PG:
|
||||
return q.
|
||||
Where("? IS NOT NULL", bun.Ident("attachments")).
|
||||
Where("? != '{}'", bun.Ident("attachments"))
|
||||
Where("? IS NOT NULL", bun.Ident("status.attachments")).
|
||||
Where("? != '{}'", bun.Ident("status.attachments"))
|
||||
case dialect.SQLite:
|
||||
return q.
|
||||
Where("? IS NOT NULL", bun.Ident("attachments")).
|
||||
Where("? != ''", bun.Ident("attachments")).
|
||||
Where("? != 'null'", bun.Ident("attachments")).
|
||||
Where("? != '{}'", bun.Ident("attachments")).
|
||||
Where("? != '[]'", bun.Ident("attachments"))
|
||||
Where("? IS NOT NULL", bun.Ident("status.attachments")).
|
||||
Where("? != ''", bun.Ident("status.attachments")).
|
||||
Where("? != 'null'", bun.Ident("status.attachments")).
|
||||
Where("? != '{}'", bun.Ident("status.attachments")).
|
||||
Where("? != '[]'", bun.Ident("status.attachments"))
|
||||
default:
|
||||
log.Panic("db dialect was neither pg nor sqlite")
|
||||
return q
|
||||
|
|
@ -369,7 +405,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
|
|||
}
|
||||
|
||||
if publicOnly {
|
||||
q = q.Where("visibility = ?", gtsmodel.VisibilityPublic)
|
||||
q = q.Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic)
|
||||
}
|
||||
|
||||
if err := q.Scan(ctx, &statusIDs); err != nil {
|
||||
|
|
@ -384,19 +420,19 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string,
|
|||
|
||||
q := a.conn.
|
||||
NewSelect().
|
||||
Table("statuses").
|
||||
Column("id").
|
||||
Where("account_id = ?", accountID).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("in_reply_to_uri")).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")).
|
||||
Where("visibility = ?", gtsmodel.VisibilityPublic).
|
||||
Where("federated = ?", true)
|
||||
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
|
||||
Column("status.id").
|
||||
Where("? = ?", bun.Ident("status.account_id"), accountID).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")).
|
||||
Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
|
||||
Where("? = ?", bun.Ident("status.federated"), true)
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("id < ?", maxID)
|
||||
q = q.Where("? < ?", bun.Ident("status.id"), maxID)
|
||||
}
|
||||
|
||||
q = q.Limit(limit).Order("id DESC")
|
||||
q = q.Limit(limit).Order("status.id DESC")
|
||||
|
||||
if err := q.Scan(ctx, &statusIDs); err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
|
|
@ -411,16 +447,16 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI
|
|||
fq := a.conn.
|
||||
NewSelect().
|
||||
Model(&blocks).
|
||||
Where("block.account_id = ?", accountID).
|
||||
Where("? = ?", bun.Ident("block.account_id"), accountID).
|
||||
Relation("TargetAccount").
|
||||
Order("block.id DESC")
|
||||
|
||||
if maxID != "" {
|
||||
fq = fq.Where("block.id < ?", maxID)
|
||||
fq = fq.Where("? < ?", bun.Ident("block.id"), maxID)
|
||||
}
|
||||
|
||||
if sinceID != "" {
|
||||
fq = fq.Where("block.id > ?", sinceID)
|
||||
fq = fq.Where("? > ?", bun.Ident("block.id"), sinceID)
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue