This commit is contained in:
tsmethurst 2021-08-24 16:54:54 +02:00
commit 526a14a92d
486 changed files with 84353 additions and 23865 deletions

View file

@ -250,7 +250,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
return err
}
a.log.Infof("instance account CREATED with id %s", username, acct.ID)
a.log.Infof("instance account %s CREATED with id %s", username, acct.ID)
return nil
}

View file

@ -175,7 +175,7 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
}
func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {
_, err := b.conn.NewDropTable().Model(i).Exec(ctx)
_, err := b.conn.NewDropTable().IfExists().Model(i).Exec(ctx)
return processErrorResponse(err)
}

View file

@ -42,21 +42,13 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db
return false, nil
}
count, err := d.conn.
q := d.conn.
NewSelect().
Model(&gtsmodel.DomainBlock{}).
Where("LOWER(domain) = LOWER(?)", domain).
Limit(1).
Count(ctx)
Limit(1)
blocked := count != 0
err = processErrorResponse(err)
if err != db.ErrNoEntries {
return false, err
}
return blocked, nil
return exists(ctx, q)
}
func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {

View file

@ -30,6 +30,8 @@ import (
"strings"
"time"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
@ -37,7 +39,6 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/driver/pgdriver"
)
var registerTables []interface{} = []interface{}{
@ -67,19 +68,13 @@ type postgresService struct {
// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
for _, t := range registerTables {
// https://pg.uptrace.dev/orm/many-to-many-relation/
bun.RegisterModel(t)
}
opts, err := derivePGOptions(c)
if err != nil {
return nil, fmt.Errorf("could not create postgres service: %s", err)
return nil, fmt.Errorf("could not create postgres options: %s", err)
}
log.Debugf("using pg options: %+v", opts)
sqldb := sql.OpenDB(pgdriver.NewConnector(opts...))
log.Debugf("using opts %+v", opts)
sqldb := stdlib.OpenDB(*opts)
conn := bun.NewDB(sqldb, pgdialect.New())
// actually *begin* the connection so that we can tell if the db is there and listening
@ -88,6 +83,12 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
}
log.Info("connected to postgres")
for _, t := range registerTables {
// https://bun.uptrace.dev/orm/many-to-many-relation/
conn.RegisterModel(t)
}
log.Info("models registered")
ps := &postgresService{
Account: &accountDB{
config: c,
@ -157,9 +158,9 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
HANDY STUFF
*/
// derivePGOptions takes an application config and returns either a ready-to-use *pg.Options
// derivePGOptions takes an application config and returns either a ready-to-use set of options
// with sensible defaults, or an error if it's not satisfied by the provided config.
func derivePGOptions(c *config.Config) ([]pgdriver.DriverOption, error) {
func derivePGOptions(c *config.Config) (*pgx.ConnConfig, error) {
if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres {
return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type)
}
@ -237,18 +238,15 @@ func derivePGOptions(c *config.Config) ([]pgdriver.DriverOption, error) {
tlsConfig.RootCAs = certPool
}
// We can rely on the pg library we're using to set
// sensible defaults for everything we don't set here.
options := []pgdriver.DriverOption{
pgdriver.WithAddr(fmt.Sprintf("%s:%d", c.DBConfig.Address, c.DBConfig.Port)),
pgdriver.WithUser(c.DBConfig.User),
pgdriver.WithPassword(c.DBConfig.Password),
pgdriver.WithDatabase(c.DBConfig.Database),
pgdriver.WithApplicationName(c.ApplicationName),
pgdriver.WithTLSConfig(tlsConfig),
}
opts, _ := pgx.ParseConfig("")
opts.Host = c.DBConfig.Address
opts.Port = uint16(c.DBConfig.Port)
opts.User = c.DBConfig.User
opts.Password = c.DBConfig.Password
opts.TLSConfig = tlsConfig
opts.PreferSimpleProtocol = true
return options, nil
return opts, nil
}
/*
@ -257,9 +255,9 @@ func derivePGOptions(c *config.Config) ([]pgdriver.DriverOption, error) {
// TODO: move these to the type converter, it's bananas that they're here and not there
func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
func (ps *postgresService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
ogAccount := &gtsmodel.Account{}
if err := ps.conn.Model(ogAccount).Where("id = ?", originAccountID).Select(); err != nil {
if err := ps.conn.NewSelect().Model(ogAccount).Where("id = ?", originAccountID).Scan(ctx); err != nil {
return nil, err
}
@ -304,14 +302,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori
// match username + account, case insensitive
if local {
// local user -- should have a null domain
err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("? IS NULL", pg.Ident("domain")).Select()
err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("? IS NULL", bun.Ident("domain")).Scan(ctx)
} else {
// remote user -- should have domain defined
err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("LOWER(?) = LOWER(?)", pg.Ident("domain"), domain).Select()
err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("LOWER(?) = LOWER(?)", bun.Ident("domain"), domain).Scan(ctx)
}
if err != nil {
if err == pg.ErrNoRows {
if err == sql.ErrNoRows {
// no result found for this username/domain so just don't include it as a mencho and carry on about our business
ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain)
continue
@ -335,14 +333,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori
return menchies, nil
}
func (ps *postgresService) TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
func (ps *postgresService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
newTags := []*gtsmodel.Tag{}
for _, t := range tags {
tag := &gtsmodel.Tag{}
// we can use selectorinsert here to create the new tag if it doesn't exist already
// inserted will be true if this is a new tag we just created
if err := ps.conn.Model(tag).Where("LOWER(?) = LOWER(?)", pg.Ident("name"), t).Select(); err != nil {
if err == pg.ErrNoRows {
if err := ps.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil {
if err == sql.ErrNoRows {
// tag doesn't exist yet so populate it
newID, err := id.NewRandomULID()
if err != nil {
@ -371,13 +369,13 @@ func (ps *postgresService) TagStringsToTags(tags []string, originAccountID strin
return newTags, nil
}
func (ps *postgresService) EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
func (ps *postgresService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
newEmojis := []*gtsmodel.Emoji{}
for _, e := range emojis {
emoji := &gtsmodel.Emoji{}
err := ps.conn.Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Select()
err := ps.conn.NewSelect().Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Scan(ctx)
if err != nil {
if err == pg.ErrNoRows {
if err == sql.ErrNoRows {
// no result found for this username/domain so just don't include it as an emoji and carry on about our business
ps.log.Debugf("no emoji found with shortcode %s, skipping it", e)
continue

View file

@ -67,16 +67,7 @@ func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account
Where("account_id = ?", account2)
}
count, err := q.Count(ctx)
blocked := count != 0
err = processErrorResponse(err)
if err != db.ErrNoEntries {
return false, err
}
return blocked, nil
return exists(ctx, q)
}
func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
@ -186,16 +177,7 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode
Where("target_account_id = ?", targetAccount.ID).
Limit(1)
count, err := q.Count(ctx)
following := count != 0
err = processErrorResponse(err)
if err != db.ErrNoEntries {
return false, err
}
return following, nil
return exists(ctx, q)
}
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
@ -209,16 +191,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g
Where("account_id = ?", sourceAccount.ID).
Where("target_account_id = ?", targetAccount.ID)
count, err := q.Count(ctx)
followRequested := count != 0
err = processErrorResponse(err)
if err != db.ErrNoEntries {
return false, err
}
return followRequested, nil
return exists(ctx, q)
}
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {

View file

@ -24,8 +24,6 @@ import (
"errors"
"time"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
@ -71,8 +69,10 @@ func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
return status, true
}
func (s *statusDB) newStatusQ(status interface{}) *orm.Query {
return s.conn.Model(status).
func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
return s.conn.
NewSelect().
Model(status).
Relation("Attachments").
Relation("Tags").
Relation("Mentions").
@ -85,14 +85,16 @@ func (s *statusDB) newStatusQ(status interface{}) *orm.Query {
Relation("CreatedWithApplication")
}
func (s *statusDB) newFaveQ(faves interface{}) *orm.Query {
return s.conn.Model(faves).
func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
return s.conn.
NewSelect().
Model(faves).
Relation("Account").
Relation("TargetAccount").
Relation("Status")
}
func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) {
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(id); cached {
return status, nil
}
@ -102,7 +104,7 @@ func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) {
q := s.newStatusQ(status).
Where("status.id = ?", id)
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
if err == nil && status != nil {
s.cacheStatus(id, status)
@ -111,7 +113,7 @@ func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) {
return status, err
}
func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) {
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
return status, nil
}
@ -121,7 +123,7 @@ func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) {
q := s.newStatusQ(status).
Where("LOWER(status.uri) = LOWER(?)", uri)
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
if err == nil && status != nil {
s.cacheStatus(uri, status)
@ -130,7 +132,7 @@ func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) {
return status, err
}
func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) {
func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
return status, nil
}
@ -140,7 +142,7 @@ func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) {
q := s.newStatusQ(status).
Where("LOWER(status.url) = LOWER(?)", uri)
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
if err == nil && status != nil {
s.cacheStatus(uri, status)
@ -149,24 +151,24 @@ func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) {
return status, err
}
func (s *statusDB) PutStatus(status *gtsmodel.Status) db.Error {
transaction := func(tx *pg.Tx) error {
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
transaction := func(ctx context.Context, tx bun.Tx) error {
// create links between this status and any emojis it uses
for _, i := range status.EmojiIDs {
if _, err := tx.Model(&gtsmodel.StatusToEmoji{
if _, err := tx.NewInsert().Model(&gtsmodel.StatusToEmoji{
StatusID: status.ID,
EmojiID: i,
}).Insert(); err != nil {
}).Exec(ctx); err != nil {
return err
}
}
// create links between this status and any tags it uses
for _, i := range status.TagIDs {
if _, err := tx.Model(&gtsmodel.StatusToTag{
if _, err := tx.NewInsert().Model(&gtsmodel.StatusToTag{
StatusID: status.ID,
TagID: i,
}).Insert(); err != nil {
}).Exec(ctx); err != nil {
return err
}
}
@ -175,33 +177,33 @@ func (s *statusDB) PutStatus(status *gtsmodel.Status) db.Error {
for _, a := range status.Attachments {
a.StatusID = status.ID
a.UpdatedAt = time.Now()
if _, err := s.conn.Model(a).
if _, err := s.conn.NewUpdate().Model(a).
Where("id = ?", a.ID).
Update(); err != nil {
Exec(ctx); err != nil {
return err
}
}
_, err := tx.Model(status).Insert()
_, err := tx.NewInsert().Model(status).Exec(ctx)
return err
}
return processErrorResponse(s.conn.RunInTransaction(context.Background(), transaction))
return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction))
}
func (s *statusDB) GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
parents := []*gtsmodel.Status{}
s.statusParent(status, &parents, onlyDirect)
s.statusParent(ctx, status, &parents, onlyDirect)
return parents, nil
}
func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
func (s *statusDB) statusParent(ctx context.Context, status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
if status.InReplyToID == "" {
return
}
parentStatus, err := s.GetStatusByID(status.InReplyToID)
parentStatus, err := s.GetStatusByID(ctx, status.InReplyToID)
if err == nil {
*foundStatuses = append(*foundStatuses, parentStatus)
}
@ -210,13 +212,13 @@ func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmo
return
}
s.statusParent(parentStatus, foundStatuses, false)
s.statusParent(ctx, parentStatus, foundStatuses, false)
}
func (s *statusDB) GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
foundStatuses := &list.List{}
foundStatuses.PushFront(status)
s.statusChildren(status, foundStatuses, onlyDirect, minID)
s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID)
children := []*gtsmodel.Status{}
for e := foundStatuses.Front(); e != nil; e = e.Next() {
@ -234,15 +236,18 @@ func (s *statusDB) GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, m
return children, nil
}
func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
immediateChildren := []*gtsmodel.Status{}
q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
q := s.conn.
NewSelect().
Model(&immediateChildren).
Where("in_reply_to_id = ?", status.ID)
if minID != "" {
q = q.Where("status.id > ?", minID)
}
if err := q.Select(); err != nil {
if err := q.Scan(ctx); err != nil {
return
}
@ -264,56 +269,78 @@ func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.L
if onlyDirect {
return
}
s.statusChildren(child, foundStatuses, false, minID)
s.statusChildren(ctx, child, foundStatuses, false, minID)
}
}
func (s *statusDB) CountStatusReplies(status *gtsmodel.Status) (int, db.Error) {
return s.conn.Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
return s.conn.NewSelect().Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count(ctx)
}
func (s *statusDB) CountStatusReblogs(status *gtsmodel.Status) (int, db.Error) {
return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
return s.conn.NewSelect().Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count(ctx)
}
func (s *statusDB) CountStatusFaves(status *gtsmodel.Status) (int, db.Error) {
return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
return s.conn.NewSelect().Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count(ctx)
}
func (s *statusDB) IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
NewSelect().
Model(&gtsmodel.StatusFave{}).
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
}
func (s *statusDB) IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
NewSelect().
Model(&gtsmodel.Status{}).
Where("boost_of_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
}
func (s *statusDB) IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
NewSelect().
Model(&gtsmodel.StatusMute{}).
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
}
func (s *statusDB) IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
NewSelect().
Model(&gtsmodel.StatusBookmark{}).
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
}
func (s *statusDB) GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
faves := []*gtsmodel.StatusFave{}
q := s.newFaveQ(&faves).
Where("status_id = ?", status.ID)
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
return faves, err
}
func (s *statusDB) GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
reblogs := []*gtsmodel.Status{}
q := s.newStatusQ(&reblogs).
Where("boost_of_id = ?", status.ID)
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
return reblogs, err
}

View file

@ -19,6 +19,7 @@
package pg_test
import (
"context"
"fmt"
"testing"
"time"
@ -56,7 +57,7 @@ func (suite *StatusTestSuite) TearDownTest() {
}
func (suite *StatusTestSuite) TestGetStatusByID() {
status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID)
status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_1_status_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
@ -67,10 +68,11 @@ func (suite *StatusTestSuite) TestGetStatusByID() {
suite.Nil(status.BoostOfAccount)
suite.Nil(status.InReplyTo)
suite.Nil(status.InReplyToAccount)
suite.log.Debug("test finished")
}
func (suite *StatusTestSuite) TestGetStatusByURI() {
status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
if err != nil {
suite.FailNow(err.Error())
}
@ -84,7 +86,7 @@ func (suite *StatusTestSuite) TestGetStatusByURI() {
}
func (suite *StatusTestSuite) TestGetStatusWithExtras() {
status, err := suite.db.GetStatusByID(suite.testStatuses["admin_account_status_1"].ID)
status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["admin_account_status_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
@ -97,7 +99,7 @@ func (suite *StatusTestSuite) TestGetStatusWithExtras() {
}
func (suite *StatusTestSuite) TestGetStatusWithMention() {
status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_2_status_5"].ID)
status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_2_status_5"].ID)
if err != nil {
suite.FailNow(err.Error())
}
@ -112,14 +114,14 @@ func (suite *StatusTestSuite) TestGetStatusWithMention() {
func (suite *StatusTestSuite) TestGetStatusTwice() {
before1 := time.Now()
_, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
_, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
suite.NoError(err)
after1 := time.Now()
duration1 := after1.Sub(before1)
fmt.Println(duration1.Nanoseconds())
before2 := time.Now()
_, err = suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
_, err = suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
suite.NoError(err)
after2 := time.Now()
duration2 := after2.Sub(before2)

View file

@ -20,9 +20,9 @@ package pg
import (
"context"
"database/sql"
"sort"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
@ -37,23 +37,27 @@ type timelineDB struct {
cancel context.CancelFunc
}
func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
q := t.conn.Model(&statuses)
q := t.conn.
NewSelect().
Model(&statuses)
// Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows,
// OR statuses posted by accountID itself (since a user should be able to see their own statuses).
//
// This is equivalent to something like WHERE ... AND (... OR ...)
// See: https://pg.uptrace.dev/queries/#select
whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
q = q.Where("f.account_id = ?", accountID).
WhereOr("status.account_id = ?", accountID)
return q
}
q = q.ColumnExpr("status.*").
// Find out who accountID follows.
Join("LEFT JOIN follows AS f ON f.target_account_id = status.account_id").
// Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows,
// OR statuses posted by accountID itself (since a user should be able to see their own statuses).
//
// This is equivalent to something like WHERE ... AND (... OR ...)
// See: https://pg.uptrace.dev/queries/#select
WhereGroup(func(q *pg.Query) (*pg.Query, error) {
q = q.WhereOr("f.account_id = ?", accountID).
WhereOr("status.account_id = ?", accountID)
return q, nil
}).
WhereGroup(" AND ", whereGroup).
// Sort by highest ID (newest) to lowest ID (oldest)
Order("status.id DESC")
@ -82,29 +86,19 @@ func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID str
q = q.Limit(limit)
}
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries
}
return statuses, nil
return statuses, processErrorResponse(q.Scan(ctx))
}
func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
q := t.conn.Model(&statuses).
q := t.conn.
NewSelect().
Model(&statuses).
Where("visibility = ?", gtsmodel.VisibilityPublic).
Where("? IS NULL", pg.Ident("in_reply_to_id")).
Where("? IS NULL", pg.Ident("in_reply_to_uri")).
Where("? IS NULL", pg.Ident("boost_of_id")).
Where("? IS NULL", bun.Ident("in_reply_to_id")).
Where("? IS NULL", bun.Ident("in_reply_to_uri")).
Where("? IS NULL", bun.Ident("boost_of_id")).
Order("status.id DESC")
if maxID != "" {
@ -127,28 +121,18 @@ func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID s
q = q.Limit(limit)
}
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries
}
return statuses, nil
return statuses, processErrorResponse(q.Scan(ctx))
}
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
faves := []*gtsmodel.StatusFave{}
fq := t.conn.Model(&faves).
fq := t.conn.
NewSelect().
Model(&faves).
Where("account_id = ?", accountID).
Order("id DESC")
@ -164,9 +148,9 @@ func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID stri
fq = fq.Limit(limit)
}
err := fq.Select()
err := fq.Scan(ctx)
if err != nil {
if err == pg.ErrNoRows {
if err == sql.ErrNoRows {
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
@ -186,9 +170,13 @@ func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID stri
}
statuses := []*gtsmodel.Status{}
err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
err = t.conn.
NewSelect().
Model(&statuses).
Where("id IN (?)", bun.In(in)).
Scan(ctx)
if err != nil {
if err == pg.ErrNoRows {
if err == sql.ErrNoRows {
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err

View file

@ -1,11 +1,30 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"strings"
"context"
"database/sql"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/uptrace/bun"
)
// processErrorResponse parses the given error and returns an appropriate DBError.
@ -16,9 +35,40 @@ func processErrorResponse(err error) db.Error {
case sql.ErrNoRows:
return db.ErrNoEntries
default:
if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
}
}
func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
count, err := q.Count(ctx)
exists := count != 0
err = processErrorResponse(err)
if err != nil {
if err == db.ErrNoEntries {
return false, nil
}
return false, err
}
return exists, nil
}
func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
count, err := q.Count(ctx)
notExists := count == 0
err = processErrorResponse(err)
if err != nil {
if err == db.ErrNoEntries {
return true, nil
}
return false, err
}
return notExists, nil
}