mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2026-01-06 00:33:15 -06:00
more
This commit is contained in:
parent
0f7cfa75f3
commit
526a14a92d
486 changed files with 84353 additions and 23865 deletions
|
|
@ -250,7 +250,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
|
|||
return err
|
||||
}
|
||||
|
||||
a.log.Infof("instance account CREATED with id %s", username, acct.ID)
|
||||
a.log.Infof("instance account %s CREATED with id %s", username, acct.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -175,7 +175,7 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
|
|||
}
|
||||
|
||||
func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {
|
||||
_, err := b.conn.NewDropTable().Model(i).Exec(ctx)
|
||||
_, err := b.conn.NewDropTable().IfExists().Model(i).Exec(ctx)
|
||||
return processErrorResponse(err)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -42,21 +42,13 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db
|
|||
return false, nil
|
||||
}
|
||||
|
||||
count, err := d.conn.
|
||||
q := d.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.DomainBlock{}).
|
||||
Where("LOWER(domain) = LOWER(?)", domain).
|
||||
Limit(1).
|
||||
Count(ctx)
|
||||
Limit(1)
|
||||
|
||||
blocked := count != 0
|
||||
err = processErrorResponse(err)
|
||||
|
||||
if err != db.ErrNoEntries {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return blocked, nil
|
||||
return exists(ctx, q)
|
||||
}
|
||||
|
||||
func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v4/stdlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
|
|
@ -37,7 +39,6 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/uptrace/bun"
|
||||
"github.com/uptrace/bun/dialect/pgdialect"
|
||||
"github.com/uptrace/bun/driver/pgdriver"
|
||||
)
|
||||
|
||||
var registerTables []interface{} = []interface{}{
|
||||
|
|
@ -67,19 +68,13 @@ type postgresService struct {
|
|||
// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
|
||||
// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
|
||||
func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
|
||||
for _, t := range registerTables {
|
||||
// https://pg.uptrace.dev/orm/many-to-many-relation/
|
||||
bun.RegisterModel(t)
|
||||
}
|
||||
|
||||
opts, err := derivePGOptions(c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create postgres service: %s", err)
|
||||
return nil, fmt.Errorf("could not create postgres options: %s", err)
|
||||
}
|
||||
log.Debugf("using pg options: %+v", opts)
|
||||
|
||||
sqldb := sql.OpenDB(pgdriver.NewConnector(opts...))
|
||||
|
||||
log.Debugf("using opts %+v", opts)
|
||||
sqldb := stdlib.OpenDB(*opts)
|
||||
conn := bun.NewDB(sqldb, pgdialect.New())
|
||||
|
||||
// actually *begin* the connection so that we can tell if the db is there and listening
|
||||
|
|
@ -88,6 +83,12 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
|
|||
}
|
||||
log.Info("connected to postgres")
|
||||
|
||||
for _, t := range registerTables {
|
||||
// https://bun.uptrace.dev/orm/many-to-many-relation/
|
||||
conn.RegisterModel(t)
|
||||
}
|
||||
log.Info("models registered")
|
||||
|
||||
ps := &postgresService{
|
||||
Account: &accountDB{
|
||||
config: c,
|
||||
|
|
@ -157,9 +158,9 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
|
|||
HANDY STUFF
|
||||
*/
|
||||
|
||||
// derivePGOptions takes an application config and returns either a ready-to-use *pg.Options
|
||||
// derivePGOptions takes an application config and returns either a ready-to-use set of options
|
||||
// with sensible defaults, or an error if it's not satisfied by the provided config.
|
||||
func derivePGOptions(c *config.Config) ([]pgdriver.DriverOption, error) {
|
||||
func derivePGOptions(c *config.Config) (*pgx.ConnConfig, error) {
|
||||
if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres {
|
||||
return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type)
|
||||
}
|
||||
|
|
@ -237,18 +238,15 @@ func derivePGOptions(c *config.Config) ([]pgdriver.DriverOption, error) {
|
|||
tlsConfig.RootCAs = certPool
|
||||
}
|
||||
|
||||
// We can rely on the pg library we're using to set
|
||||
// sensible defaults for everything we don't set here.
|
||||
options := []pgdriver.DriverOption{
|
||||
pgdriver.WithAddr(fmt.Sprintf("%s:%d", c.DBConfig.Address, c.DBConfig.Port)),
|
||||
pgdriver.WithUser(c.DBConfig.User),
|
||||
pgdriver.WithPassword(c.DBConfig.Password),
|
||||
pgdriver.WithDatabase(c.DBConfig.Database),
|
||||
pgdriver.WithApplicationName(c.ApplicationName),
|
||||
pgdriver.WithTLSConfig(tlsConfig),
|
||||
}
|
||||
opts, _ := pgx.ParseConfig("")
|
||||
opts.Host = c.DBConfig.Address
|
||||
opts.Port = uint16(c.DBConfig.Port)
|
||||
opts.User = c.DBConfig.User
|
||||
opts.Password = c.DBConfig.Password
|
||||
opts.TLSConfig = tlsConfig
|
||||
opts.PreferSimpleProtocol = true
|
||||
|
||||
return options, nil
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
@ -257,9 +255,9 @@ func derivePGOptions(c *config.Config) ([]pgdriver.DriverOption, error) {
|
|||
|
||||
// TODO: move these to the type converter, it's bananas that they're here and not there
|
||||
|
||||
func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
|
||||
func (ps *postgresService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
|
||||
ogAccount := >smodel.Account{}
|
||||
if err := ps.conn.Model(ogAccount).Where("id = ?", originAccountID).Select(); err != nil {
|
||||
if err := ps.conn.NewSelect().Model(ogAccount).Where("id = ?", originAccountID).Scan(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -304,14 +302,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori
|
|||
// match username + account, case insensitive
|
||||
if local {
|
||||
// local user -- should have a null domain
|
||||
err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("? IS NULL", pg.Ident("domain")).Select()
|
||||
err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("? IS NULL", bun.Ident("domain")).Scan(ctx)
|
||||
} else {
|
||||
// remote user -- should have domain defined
|
||||
err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("LOWER(?) = LOWER(?)", pg.Ident("domain"), domain).Select()
|
||||
err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("LOWER(?) = LOWER(?)", bun.Ident("domain"), domain).Scan(ctx)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
if err == sql.ErrNoRows {
|
||||
// no result found for this username/domain so just don't include it as a mencho and carry on about our business
|
||||
ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain)
|
||||
continue
|
||||
|
|
@ -335,14 +333,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori
|
|||
return menchies, nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
|
||||
func (ps *postgresService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
|
||||
newTags := []*gtsmodel.Tag{}
|
||||
for _, t := range tags {
|
||||
tag := >smodel.Tag{}
|
||||
// we can use selectorinsert here to create the new tag if it doesn't exist already
|
||||
// inserted will be true if this is a new tag we just created
|
||||
if err := ps.conn.Model(tag).Where("LOWER(?) = LOWER(?)", pg.Ident("name"), t).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
if err := ps.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
// tag doesn't exist yet so populate it
|
||||
newID, err := id.NewRandomULID()
|
||||
if err != nil {
|
||||
|
|
@ -371,13 +369,13 @@ func (ps *postgresService) TagStringsToTags(tags []string, originAccountID strin
|
|||
return newTags, nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
|
||||
func (ps *postgresService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
|
||||
newEmojis := []*gtsmodel.Emoji{}
|
||||
for _, e := range emojis {
|
||||
emoji := >smodel.Emoji{}
|
||||
err := ps.conn.Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Select()
|
||||
err := ps.conn.NewSelect().Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Scan(ctx)
|
||||
if err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
if err == sql.ErrNoRows {
|
||||
// no result found for this username/domain so just don't include it as an emoji and carry on about our business
|
||||
ps.log.Debugf("no emoji found with shortcode %s, skipping it", e)
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -67,16 +67,7 @@ func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account
|
|||
Where("account_id = ?", account2)
|
||||
}
|
||||
|
||||
count, err := q.Count(ctx)
|
||||
|
||||
blocked := count != 0
|
||||
err = processErrorResponse(err)
|
||||
|
||||
if err != db.ErrNoEntries {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return blocked, nil
|
||||
return exists(ctx, q)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
|
||||
|
|
@ -186,16 +177,7 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode
|
|||
Where("target_account_id = ?", targetAccount.ID).
|
||||
Limit(1)
|
||||
|
||||
count, err := q.Count(ctx)
|
||||
|
||||
following := count != 0
|
||||
err = processErrorResponse(err)
|
||||
|
||||
if err != db.ErrNoEntries {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return following, nil
|
||||
return exists(ctx, q)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
|
||||
|
|
@ -209,16 +191,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g
|
|||
Where("account_id = ?", sourceAccount.ID).
|
||||
Where("target_account_id = ?", targetAccount.ID)
|
||||
|
||||
count, err := q.Count(ctx)
|
||||
|
||||
followRequested := count != 0
|
||||
err = processErrorResponse(err)
|
||||
|
||||
if err != db.ErrNoEntries {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return followRequested, nil
|
||||
return exists(ctx, q)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
|
||||
|
|
|
|||
|
|
@ -24,8 +24,6 @@ import (
|
|||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/cache"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
|
|
@ -71,8 +69,10 @@ func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
|
|||
return status, true
|
||||
}
|
||||
|
||||
func (s *statusDB) newStatusQ(status interface{}) *orm.Query {
|
||||
return s.conn.Model(status).
|
||||
func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
|
||||
return s.conn.
|
||||
NewSelect().
|
||||
Model(status).
|
||||
Relation("Attachments").
|
||||
Relation("Tags").
|
||||
Relation("Mentions").
|
||||
|
|
@ -85,14 +85,16 @@ func (s *statusDB) newStatusQ(status interface{}) *orm.Query {
|
|||
Relation("CreatedWithApplication")
|
||||
}
|
||||
|
||||
func (s *statusDB) newFaveQ(faves interface{}) *orm.Query {
|
||||
return s.conn.Model(faves).
|
||||
func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
|
||||
return s.conn.
|
||||
NewSelect().
|
||||
Model(faves).
|
||||
Relation("Account").
|
||||
Relation("TargetAccount").
|
||||
Relation("Status")
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) {
|
||||
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
|
||||
if status, cached := s.statusCached(id); cached {
|
||||
return status, nil
|
||||
}
|
||||
|
|
@ -102,7 +104,7 @@ func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) {
|
|||
q := s.newStatusQ(status).
|
||||
Where("status.id = ?", id)
|
||||
|
||||
err := processErrorResponse(q.Select())
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
if err == nil && status != nil {
|
||||
s.cacheStatus(id, status)
|
||||
|
|
@ -111,7 +113,7 @@ func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) {
|
|||
return status, err
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) {
|
||||
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
|
||||
if status, cached := s.statusCached(uri); cached {
|
||||
return status, nil
|
||||
}
|
||||
|
|
@ -121,7 +123,7 @@ func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) {
|
|||
q := s.newStatusQ(status).
|
||||
Where("LOWER(status.uri) = LOWER(?)", uri)
|
||||
|
||||
err := processErrorResponse(q.Select())
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
if err == nil && status != nil {
|
||||
s.cacheStatus(uri, status)
|
||||
|
|
@ -130,7 +132,7 @@ func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) {
|
|||
return status, err
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) {
|
||||
func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
|
||||
if status, cached := s.statusCached(uri); cached {
|
||||
return status, nil
|
||||
}
|
||||
|
|
@ -140,7 +142,7 @@ func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) {
|
|||
q := s.newStatusQ(status).
|
||||
Where("LOWER(status.url) = LOWER(?)", uri)
|
||||
|
||||
err := processErrorResponse(q.Select())
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
if err == nil && status != nil {
|
||||
s.cacheStatus(uri, status)
|
||||
|
|
@ -149,24 +151,24 @@ func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) {
|
|||
return status, err
|
||||
}
|
||||
|
||||
func (s *statusDB) PutStatus(status *gtsmodel.Status) db.Error {
|
||||
transaction := func(tx *pg.Tx) error {
|
||||
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
|
||||
transaction := func(ctx context.Context, tx bun.Tx) error {
|
||||
// create links between this status and any emojis it uses
|
||||
for _, i := range status.EmojiIDs {
|
||||
if _, err := tx.Model(>smodel.StatusToEmoji{
|
||||
if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{
|
||||
StatusID: status.ID,
|
||||
EmojiID: i,
|
||||
}).Insert(); err != nil {
|
||||
}).Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// create links between this status and any tags it uses
|
||||
for _, i := range status.TagIDs {
|
||||
if _, err := tx.Model(>smodel.StatusToTag{
|
||||
if _, err := tx.NewInsert().Model(>smodel.StatusToTag{
|
||||
StatusID: status.ID,
|
||||
TagID: i,
|
||||
}).Insert(); err != nil {
|
||||
}).Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
@ -175,33 +177,33 @@ func (s *statusDB) PutStatus(status *gtsmodel.Status) db.Error {
|
|||
for _, a := range status.Attachments {
|
||||
a.StatusID = status.ID
|
||||
a.UpdatedAt = time.Now()
|
||||
if _, err := s.conn.Model(a).
|
||||
if _, err := s.conn.NewUpdate().Model(a).
|
||||
Where("id = ?", a.ID).
|
||||
Update(); err != nil {
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
_, err := tx.Model(status).Insert()
|
||||
_, err := tx.NewInsert().Model(status).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
return processErrorResponse(s.conn.RunInTransaction(context.Background(), transaction))
|
||||
return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction))
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
|
||||
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
|
||||
parents := []*gtsmodel.Status{}
|
||||
s.statusParent(status, &parents, onlyDirect)
|
||||
s.statusParent(ctx, status, &parents, onlyDirect)
|
||||
|
||||
return parents, nil
|
||||
}
|
||||
|
||||
func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
|
||||
func (s *statusDB) statusParent(ctx context.Context, status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
|
||||
if status.InReplyToID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
parentStatus, err := s.GetStatusByID(status.InReplyToID)
|
||||
parentStatus, err := s.GetStatusByID(ctx, status.InReplyToID)
|
||||
if err == nil {
|
||||
*foundStatuses = append(*foundStatuses, parentStatus)
|
||||
}
|
||||
|
|
@ -210,13 +212,13 @@ func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmo
|
|||
return
|
||||
}
|
||||
|
||||
s.statusParent(parentStatus, foundStatuses, false)
|
||||
s.statusParent(ctx, parentStatus, foundStatuses, false)
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
|
||||
func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
|
||||
foundStatuses := &list.List{}
|
||||
foundStatuses.PushFront(status)
|
||||
s.statusChildren(status, foundStatuses, onlyDirect, minID)
|
||||
s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID)
|
||||
|
||||
children := []*gtsmodel.Status{}
|
||||
for e := foundStatuses.Front(); e != nil; e = e.Next() {
|
||||
|
|
@ -234,15 +236,18 @@ func (s *statusDB) GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, m
|
|||
return children, nil
|
||||
}
|
||||
|
||||
func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
|
||||
func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
|
||||
immediateChildren := []*gtsmodel.Status{}
|
||||
|
||||
q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
|
||||
q := s.conn.
|
||||
NewSelect().
|
||||
Model(&immediateChildren).
|
||||
Where("in_reply_to_id = ?", status.ID)
|
||||
if minID != "" {
|
||||
q = q.Where("status.id > ?", minID)
|
||||
}
|
||||
|
||||
if err := q.Select(); err != nil {
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -264,56 +269,78 @@ func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.L
|
|||
if onlyDirect {
|
||||
return
|
||||
}
|
||||
s.statusChildren(child, foundStatuses, false, minID)
|
||||
s.statusChildren(ctx, child, foundStatuses, false, minID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *statusDB) CountStatusReplies(status *gtsmodel.Status) (int, db.Error) {
|
||||
return s.conn.Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
|
||||
func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
|
||||
return s.conn.NewSelect().Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count(ctx)
|
||||
}
|
||||
|
||||
func (s *statusDB) CountStatusReblogs(status *gtsmodel.Status) (int, db.Error) {
|
||||
return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
|
||||
func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
|
||||
return s.conn.NewSelect().Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count(ctx)
|
||||
}
|
||||
|
||||
func (s *statusDB) CountStatusFaves(status *gtsmodel.Status) (int, db.Error) {
|
||||
return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
|
||||
func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
|
||||
return s.conn.NewSelect().Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count(ctx)
|
||||
}
|
||||
|
||||
func (s *statusDB) IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
|
||||
func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
q := s.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.StatusFave{}).
|
||||
Where("status_id = ?", status.ID).
|
||||
Where("account_id = ?", accountID)
|
||||
|
||||
return exists(ctx, q)
|
||||
}
|
||||
|
||||
func (s *statusDB) IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
|
||||
func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
q := s.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.Status{}).
|
||||
Where("boost_of_id = ?", status.ID).
|
||||
Where("account_id = ?", accountID)
|
||||
|
||||
return exists(ctx, q)
|
||||
}
|
||||
|
||||
func (s *statusDB) IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
return s.conn.Model(>smodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
|
||||
func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
q := s.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.StatusMute{}).
|
||||
Where("status_id = ?", status.ID).
|
||||
Where("account_id = ?", accountID)
|
||||
|
||||
return exists(ctx, q)
|
||||
}
|
||||
|
||||
func (s *statusDB) IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
return s.conn.Model(>smodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
|
||||
func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
q := s.conn.
|
||||
NewSelect().
|
||||
Model(>smodel.StatusBookmark{}).
|
||||
Where("status_id = ?", status.ID).
|
||||
Where("account_id = ?", accountID)
|
||||
|
||||
return exists(ctx, q)
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
|
||||
func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
|
||||
faves := []*gtsmodel.StatusFave{}
|
||||
|
||||
q := s.newFaveQ(&faves).
|
||||
Where("status_id = ?", status.ID)
|
||||
|
||||
err := processErrorResponse(q.Select())
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
return faves, err
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
|
||||
func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
|
||||
reblogs := []*gtsmodel.Status{}
|
||||
|
||||
q := s.newStatusQ(&reblogs).
|
||||
Where("boost_of_id = ?", status.ID)
|
||||
|
||||
err := processErrorResponse(q.Select())
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
return reblogs, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
package pg_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -56,7 +57,7 @@ func (suite *StatusTestSuite) TearDownTest() {
|
|||
}
|
||||
|
||||
func (suite *StatusTestSuite) TestGetStatusByID() {
|
||||
status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID)
|
||||
status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_1_status_1"].ID)
|
||||
if err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
|
@ -67,10 +68,11 @@ func (suite *StatusTestSuite) TestGetStatusByID() {
|
|||
suite.Nil(status.BoostOfAccount)
|
||||
suite.Nil(status.InReplyTo)
|
||||
suite.Nil(status.InReplyToAccount)
|
||||
suite.log.Debug("test finished")
|
||||
}
|
||||
|
||||
func (suite *StatusTestSuite) TestGetStatusByURI() {
|
||||
status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
|
||||
status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
|
||||
if err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
|
@ -84,7 +86,7 @@ func (suite *StatusTestSuite) TestGetStatusByURI() {
|
|||
}
|
||||
|
||||
func (suite *StatusTestSuite) TestGetStatusWithExtras() {
|
||||
status, err := suite.db.GetStatusByID(suite.testStatuses["admin_account_status_1"].ID)
|
||||
status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["admin_account_status_1"].ID)
|
||||
if err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
|
@ -97,7 +99,7 @@ func (suite *StatusTestSuite) TestGetStatusWithExtras() {
|
|||
}
|
||||
|
||||
func (suite *StatusTestSuite) TestGetStatusWithMention() {
|
||||
status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_2_status_5"].ID)
|
||||
status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_2_status_5"].ID)
|
||||
if err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
|
@ -112,14 +114,14 @@ func (suite *StatusTestSuite) TestGetStatusWithMention() {
|
|||
|
||||
func (suite *StatusTestSuite) TestGetStatusTwice() {
|
||||
before1 := time.Now()
|
||||
_, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
|
||||
_, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
|
||||
suite.NoError(err)
|
||||
after1 := time.Now()
|
||||
duration1 := after1.Sub(before1)
|
||||
fmt.Println(duration1.Nanoseconds())
|
||||
|
||||
before2 := time.Now()
|
||||
_, err = suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
|
||||
_, err = suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
|
||||
suite.NoError(err)
|
||||
after2 := time.Now()
|
||||
duration2 := after2.Sub(before2)
|
||||
|
|
|
|||
|
|
@ -20,9 +20,9 @@ package pg
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sort"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
|
|
@ -37,23 +37,27 @@ type timelineDB struct {
|
|||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
|
||||
func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
|
||||
statuses := []*gtsmodel.Status{}
|
||||
q := t.conn.Model(&statuses)
|
||||
q := t.conn.
|
||||
NewSelect().
|
||||
Model(&statuses)
|
||||
|
||||
// Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows,
|
||||
// OR statuses posted by accountID itself (since a user should be able to see their own statuses).
|
||||
//
|
||||
// This is equivalent to something like WHERE ... AND (... OR ...)
|
||||
// See: https://pg.uptrace.dev/queries/#select
|
||||
whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
|
||||
q = q.Where("f.account_id = ?", accountID).
|
||||
WhereOr("status.account_id = ?", accountID)
|
||||
return q
|
||||
}
|
||||
|
||||
q = q.ColumnExpr("status.*").
|
||||
// Find out who accountID follows.
|
||||
Join("LEFT JOIN follows AS f ON f.target_account_id = status.account_id").
|
||||
// Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows,
|
||||
// OR statuses posted by accountID itself (since a user should be able to see their own statuses).
|
||||
//
|
||||
// This is equivalent to something like WHERE ... AND (... OR ...)
|
||||
// See: https://pg.uptrace.dev/queries/#select
|
||||
WhereGroup(func(q *pg.Query) (*pg.Query, error) {
|
||||
q = q.WhereOr("f.account_id = ?", accountID).
|
||||
WhereOr("status.account_id = ?", accountID)
|
||||
return q, nil
|
||||
}).
|
||||
WhereGroup(" AND ", whereGroup).
|
||||
// Sort by highest ID (newest) to lowest ID (oldest)
|
||||
Order("status.id DESC")
|
||||
|
||||
|
|
@ -82,29 +86,19 @@ func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID str
|
|||
q = q.Limit(limit)
|
||||
}
|
||||
|
||||
err := q.Select()
|
||||
if err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return nil, db.ErrNoEntries
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(statuses) == 0 {
|
||||
return nil, db.ErrNoEntries
|
||||
}
|
||||
|
||||
return statuses, nil
|
||||
return statuses, processErrorResponse(q.Scan(ctx))
|
||||
}
|
||||
|
||||
func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
|
||||
func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
|
||||
statuses := []*gtsmodel.Status{}
|
||||
|
||||
q := t.conn.Model(&statuses).
|
||||
q := t.conn.
|
||||
NewSelect().
|
||||
Model(&statuses).
|
||||
Where("visibility = ?", gtsmodel.VisibilityPublic).
|
||||
Where("? IS NULL", pg.Ident("in_reply_to_id")).
|
||||
Where("? IS NULL", pg.Ident("in_reply_to_uri")).
|
||||
Where("? IS NULL", pg.Ident("boost_of_id")).
|
||||
Where("? IS NULL", bun.Ident("in_reply_to_id")).
|
||||
Where("? IS NULL", bun.Ident("in_reply_to_uri")).
|
||||
Where("? IS NULL", bun.Ident("boost_of_id")).
|
||||
Order("status.id DESC")
|
||||
|
||||
if maxID != "" {
|
||||
|
|
@ -127,28 +121,18 @@ func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID s
|
|||
q = q.Limit(limit)
|
||||
}
|
||||
|
||||
err := q.Select()
|
||||
if err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return nil, db.ErrNoEntries
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(statuses) == 0 {
|
||||
return nil, db.ErrNoEntries
|
||||
}
|
||||
|
||||
return statuses, nil
|
||||
return statuses, processErrorResponse(q.Scan(ctx))
|
||||
}
|
||||
|
||||
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
|
||||
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
|
||||
func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
|
||||
func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
|
||||
|
||||
faves := []*gtsmodel.StatusFave{}
|
||||
|
||||
fq := t.conn.Model(&faves).
|
||||
fq := t.conn.
|
||||
NewSelect().
|
||||
Model(&faves).
|
||||
Where("account_id = ?", accountID).
|
||||
Order("id DESC")
|
||||
|
||||
|
|
@ -164,9 +148,9 @@ func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID stri
|
|||
fq = fq.Limit(limit)
|
||||
}
|
||||
|
||||
err := fq.Select()
|
||||
err := fq.Scan(ctx)
|
||||
if err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, "", "", db.ErrNoEntries
|
||||
}
|
||||
return nil, "", "", err
|
||||
|
|
@ -186,9 +170,13 @@ func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID stri
|
|||
}
|
||||
|
||||
statuses := []*gtsmodel.Status{}
|
||||
err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
|
||||
err = t.conn.
|
||||
NewSelect().
|
||||
Model(&statuses).
|
||||
Where("id IN (?)", bun.In(in)).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, "", "", db.ErrNoEntries
|
||||
}
|
||||
return nil, "", "", err
|
||||
|
|
|
|||
|
|
@ -1,11 +1,30 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package pg
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"context"
|
||||
|
||||
"database/sql"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
// processErrorResponse parses the given error and returns an appropriate DBError.
|
||||
|
|
@ -16,9 +35,40 @@ func processErrorResponse(err error) db.Error {
|
|||
case sql.ErrNoRows:
|
||||
return db.ErrNoEntries
|
||||
default:
|
||||
if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
|
||||
return db.ErrAlreadyExists
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
|
||||
count, err := q.Count(ctx)
|
||||
|
||||
exists := count != 0
|
||||
|
||||
err = processErrorResponse(err)
|
||||
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
|
||||
count, err := q.Count(ctx)
|
||||
|
||||
notExists := count == 0
|
||||
|
||||
err = processErrorResponse(err)
|
||||
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
return notExists, nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue