Pg to bun (#148)

* start moving to bun

* changing more stuff

* more

* and yet more

* tests passing

* seems stable now

* more big changes

* small fix

* little fixes
This commit is contained in:
tobi 2021-08-25 15:34:33 +02:00 committed by GitHub
commit 2dc9fc1626
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
713 changed files with 98694 additions and 22704 deletions

View file

@ -19,6 +19,7 @@
package db
import (
"context"
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -27,40 +28,43 @@ import (
// Account contains functions related to account getting/setting/creation.
type Account interface {
// GetAccountByID returns one account with the given ID, or an error if something goes wrong.
GetAccountByID(id string) (*gtsmodel.Account, Error)
GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, Error)
// GetAccountByURI returns one account with the given URI, or an error if something goes wrong.
GetAccountByURI(uri string) (*gtsmodel.Account, Error)
GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
GetAccountByURL(uri string) (*gtsmodel.Account, Error)
GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// UpdateAccount updates one account by ID.
UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
// GetLocalAccountByUsername returns an account on this instance by its username.
GetLocalAccountByUsername(username string) (*gtsmodel.Account, Error)
GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, Error)
// GetAccountFaves fetches faves/likes created by the target accountID.
GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, Error)
GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, Error)
// GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID.
CountAccountStatuses(accountID string) (int, Error)
CountAccountStatuses(ctx context.Context, accountID string) (int, Error)
// GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
// be very memory intensive so you probably shouldn't do this!
// In case of no entries, a 'no entries' error will be returned
GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error)
GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error)
GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error)
GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error)
// GetAccountLastPosted simply gets the timestamp of the most recent post by the account.
//
// The returned time will be zero if account has never posted anything.
GetAccountLastPosted(accountID string) (time.Time, Error)
GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, Error)
// SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment.
SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error
SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error
// GetInstanceAccount returns the instance account for the given domain.
// If domain is empty, this instance account will be returned.
GetInstanceAccount(domain string) (*gtsmodel.Account, Error)
GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, Error)
}

View file

@ -19,6 +19,7 @@
package db
import (
"context"
"net"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -28,26 +29,26 @@ import (
type Admin interface {
// IsUsernameAvailable checks whether a given username is available on our domain.
// Returns an error if the username is already taken, or something went wrong in the db.
IsUsernameAvailable(username string) Error
IsUsernameAvailable(ctx context.Context, username string) (bool, Error)
// IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
// Return an error if:
// A) the email is already associated with an account
// B) we block signups from this email domain
// C) something went wrong in the db
IsEmailAvailable(email string) Error
IsEmailAvailable(ctx context.Context, email string) (bool, Error)
// NewSignup creates a new user in the database with the given parameters.
// By the time this function is called, it should be assumed that all the parameters have passed validation!
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error)
NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error)
// CreateInstanceAccount creates an account in the database with the same username as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
// This is needed for things like serving files that belong to the instance and not an individual user/account.
CreateInstanceAccount() Error
CreateInstanceAccount(ctx context.Context) Error
// CreateInstanceInstance creates an instance in the database with the same domain as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'.
// This is needed for things like serving instance information through /api/v1/instance
CreateInstanceInstance() Error
CreateInstanceInstance(ctx context.Context) Error
}

View file

@ -24,15 +24,11 @@ import "context"
type Basic interface {
// CreateTable creates a table for the given interface.
// For implementations that don't use tables, this can just return nil.
CreateTable(i interface{}) Error
CreateTable(ctx context.Context, i interface{}) Error
// DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil.
DropTable(i interface{}) Error
// RegisterTable registers a table for use in many2many relations.
// For implementations that don't use tables, or many2many relations, this can just return nil.
RegisterTable(i interface{}) Error
DropTable(ctx context.Context, i interface{}) Error
// Stop should stop and close the database connection cleanly, returning an error if this is not possible.
// If the database implementation doesn't need to be stopped, this can just return nil.
@ -45,43 +41,38 @@ type Basic interface {
// for other implementations (for example, in-memory) it might just be the key of a map.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetByID(id string, i interface{}) Error
GetByID(ctx context.Context, id string, i interface{}) Error
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
// name of the key to select from.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetWhere(where []Where, i interface{}) Error
GetWhere(ctx context.Context, where []Where, i interface{}) Error
// GetAll will try to get all entries of type i.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetAll(i interface{}) Error
GetAll(ctx context.Context, i interface{}) Error
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Put(i interface{}) Error
// Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/
// It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Upsert(i interface{}, conflictColumn string) Error
Put(ctx context.Context, i interface{}) Error
// UpdateByID updates i with id id.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
UpdateByID(id string, i interface{}) Error
UpdateByID(ctx context.Context, id string, i interface{}) Error
// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
UpdateOneByID(id string, key string, value interface{}, i interface{}) Error
UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) Error
// UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
UpdateWhere(where []Where, key string, value interface{}, i interface{}) Error
UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error
// DeleteByID removes i with id id.
// If i didn't exist anyway, then no error should be returned.
DeleteByID(id string, i interface{}) Error
DeleteByID(ctx context.Context, id string, i interface{}) Error
// DeleteWhere deletes i where key = value
// If i didn't exist anyway, then no error should be returned.
DeleteWhere(where []Where, i interface{}) Error
DeleteWhere(ctx context.Context, where []Where, i interface{}) Error
}

View file

@ -16,70 +16,90 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type accountDB struct {
config *config.Config
conn *pg.DB
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query {
return a.conn.Model(account).
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
return a.conn.
NewSelect().
Model(account).
Relation("AvatarMediaAttachment").
Relation("HeaderMediaAttachment")
}
func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("account.id = ?", id)
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
return account, err
}
func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("account.uri = ?", uri)
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
return account, err
}
func (a *accountDB) GetAccountByURL(uri string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("account.url = ?", uri)
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
return account, err
}
func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
if strings.TrimSpace(account.ID) == "" {
return nil, errors.New("account had no ID")
}
account.UpdatedAt = time.Now()
q := a.conn.
NewUpdate().
Model(account).
WherePK()
_, err := q.Exec(ctx)
err = processErrorResponse(err)
return account, err
}
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)
q := a.newAccountQ(account)
@ -90,29 +110,31 @@ func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Err
} else {
q = q.
Where("account.username = ?", domain).
Where("? IS NULL", pg.Ident("domain"))
Where("? IS NULL", bun.Ident("domain"))
}
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
return account, err
}
func (a *accountDB) GetAccountLastPosted(accountID string) (time.Time, db.Error) {
status := &gtsmodel.Status{}
func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, db.Error) {
status := new(gtsmodel.Status)
q := a.conn.Model(status).
q := a.conn.
NewSelect().
Model(status).
Order("id DESC").
Limit(1).
Where("account_id = ?", accountID).
Column("created_at")
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
return status.CreatedAt, err
}
func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
if mediaAttachment.Avatar && mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar")
}
@ -127,51 +149,66 @@ func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAtta
}
// TODO: there are probably more side effects here that need to be handled
if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
if _, err := a.conn.
NewInsert().
Model(mediaAttachment).
Exec(ctx); err != nil {
return err
}
if _, err := a.conn.Model(&gtsmodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
if _, err := a.conn.
NewUpdate().
Model(&gtsmodel.Account{}).
Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).
Where("id = ?", accountID).
Exec(ctx); err != nil {
return err
}
return nil
}
func (a *accountDB) GetLocalAccountByUsername(username string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("username = ?", username).
Where("? IS NULL", pg.Ident("domain"))
Where("? IS NULL", bun.Ident("domain"))
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
return account, err
}
func (a *accountDB) GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, db.Error) {
faves := []*gtsmodel.StatusFave{}
func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) {
faves := new([]*gtsmodel.StatusFave)
if err := a.conn.Model(&faves).
if err := a.conn.
NewSelect().
Model(faves).
Where("account_id = ?", accountID).
Select(); err != nil {
if err == pg.ErrNoRows {
return faves, nil
}
Scan(ctx); err != nil {
return nil, err
}
return faves, nil
return *faves, nil
}
func (a *accountDB) CountAccountStatuses(accountID string) (int, db.Error) {
return a.conn.Model(&gtsmodel.Status{}).Where("account_id = ?", accountID).Count()
func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) {
return a.conn.
NewSelect().
Model(&gtsmodel.Status{}).
Where("account_id = ?", accountID).
Count(ctx)
}
func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
a.log.Debugf("getting statuses for account %s", accountID)
statuses := []*gtsmodel.Status{}
q := a.conn.Model(&statuses).Order("id DESC")
q := a.conn.
NewSelect().
Model(&statuses).
Order("id DESC")
if accountID != "" {
q = q.Where("account_id = ?", accountID)
}
@ -181,27 +218,26 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli
}
if excludeReplies {
q = q.Where("? IS NULL", pg.Ident("in_reply_to_id"))
q = q.Where("? IS NULL", bun.Ident("in_reply_to_id"))
}
if pinnedOnly {
q = q.Where("pinned = ?", true)
}
if mediaOnly {
q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) {
return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil
})
}
if maxID != "" {
q = q.Where("id < ?", maxID)
}
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries
}
if mediaOnly {
q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {
return q.
WhereOr("? IS NOT NULL", bun.Ident("attachments")).
WhereOr("attachments != '{}'")
})
}
if err := q.Scan(ctx); err != nil {
return nil, err
}
@ -213,10 +249,12 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli
return statuses, nil
}
func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
blocks := []*gtsmodel.Block{}
fq := a.conn.Model(&blocks).
fq := a.conn.
NewSelect().
Model(&blocks).
Where("block.account_id = ?", accountID).
Relation("TargetAccount").
Order("block.id DESC")
@ -233,11 +271,8 @@ func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID str
fq = fq.Limit(limit)
}
err := fq.Select()
err := fq.Scan(ctx)
if err != nil {
if err == pg.ErrNoRows {
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
}

View file

@ -16,17 +16,19 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
package bundb_test
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type AccountTestSuite struct {
PGStandardTestSuite
BunDBStandardTestSuite
}
func (suite *AccountTestSuite) SetupSuite() {
@ -54,7 +56,7 @@ func (suite *AccountTestSuite) TearDownTest() {
}
func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
account, err := suite.db.GetAccountByID(suite.testAccounts["local_account_1"].ID)
account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
@ -65,6 +67,20 @@ func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
suite.NotEmpty(account.HeaderMediaAttachment.URL)
}
func (suite *AccountTestSuite) TestUpdateAccount() {
testAccount := suite.testAccounts["local_account_1"]
testAccount.DisplayName = "new display name!"
_, err := suite.db.UpdateAccount(context.Background(), testAccount)
suite.NoError(err)
updated, err := suite.db.GetAccountByID(context.Background(), testAccount.ID)
suite.NoError(err)
suite.Equal("new display name!", updated.DisplayName)
suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second)
}
func TestAccountTestSuite(t *testing.T) {
suite.Run(t, new(AccountTestSuite))
}

View file

@ -16,76 +16,76 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
"crypto/rand"
"crypto/rsa"
"database/sql"
"fmt"
"net"
"net/mail"
"strings"
"time"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
"golang.org/x/crypto/bcrypt"
)
type adminDB struct {
config *config.Config
conn *pg.DB
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (a *adminDB) IsUsernameAvailable(username string) db.Error {
// if no error we fail because it means we found something
// if error but it's not pg.ErrNoRows then we fail
// if err is pg.ErrNoRows we're good, we found nothing so continue
if err := a.conn.Model(&gtsmodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
return fmt.Errorf("username %s already in use", username)
} else if err != pg.ErrNoRows {
return fmt.Errorf("db error: %s", err)
}
return nil
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
q := a.conn.
NewSelect().
Model(&gtsmodel.Account{}).
Where("username = ?", username).
Where("domain = ?", nil)
return notExists(ctx, q)
}
func (a *adminDB) IsEmailAvailable(email string) db.Error {
func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) {
// parse the domain from the email
m, err := mail.ParseAddress(email)
if err != nil {
return fmt.Errorf("error parsing email address %s: %s", email, err)
return false, fmt.Errorf("error parsing email address %s: %s", email, err)
}
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked
if err := a.conn.Model(&gtsmodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
if err := a.conn.
NewSelect().
Model(&gtsmodel.EmailDomainBlock{}).
Where("domain = ?", domain).
Scan(ctx); err == nil {
// fail because we found something
return fmt.Errorf("email domain %s is blocked", domain)
} else if err != pg.ErrNoRows {
// fail because we got an unexpected error
return fmt.Errorf("db error: %s", err)
return false, fmt.Errorf("email domain %s is blocked", domain)
} else if err != sql.ErrNoRows {
return false, processErrorResponse(err)
}
// check if this email is associated with a user already
if err := a.conn.Model(&gtsmodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email %s already in use", email)
} else if err != pg.ErrNoRows {
// fail because we got an unexpected error
return fmt.Errorf("db error: %s", err)
}
return nil
q := a.conn.
NewSelect().
Model(&gtsmodel.User{}).
Where("email = ?", email).
WhereOr("unconfirmed_email = ?", email)
return notExists(ctx, q)
}
func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
a.log.Errorf("error creating new rsa key: %s", err)
@ -94,13 +94,12 @@ func (a *adminDB) NewSignup(username string, reason string, requireApproval bool
// if something went wrong while creating a user, we might already have an account, so check here first...
acct := &gtsmodel.Account{}
err = a.conn.Model(acct).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
err = a.conn.NewSelect().
Model(acct).
Where("username = ?", username).
Where("? IS NULL", bun.Ident("domain")).
Scan(ctx)
if err != nil {
// there's been an actual error
if err != pg.ErrNoRows {
return nil, fmt.Errorf("db error checking existence of account: %s", err)
}
// we just don't have an account yet create one
newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
newAccountID, err := id.NewRandomULID()
@ -125,7 +124,10 @@ func (a *adminDB) NewSignup(username string, reason string, requireApproval bool
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
if _, err = a.conn.Model(acct).Insert(); err != nil {
if _, err = a.conn.
NewInsert().
Model(acct).
Exec(ctx); err != nil {
return nil, err
}
}
@ -161,15 +163,33 @@ func (a *adminDB) NewSignup(username string, reason string, requireApproval bool
u.Moderator = true
}
if _, err = a.conn.Model(u).Insert(); err != nil {
if _, err = a.conn.
NewInsert().
Model(u).
Exec(ctx); err != nil {
return nil, err
}
return u, nil
}
func (a *adminDB) CreateInstanceAccount() db.Error {
func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
username := a.config.Host
// check if instance account already exists
existsQ := a.conn.
NewSelect().
Model(&gtsmodel.Account{}).
Where("username = ?", username).
Where("? IS NULL", bun.Ident("domain"))
count, err := existsQ.Count(ctx)
if err != nil && count == 1 {
a.log.Infof("instance account %s already exists", username)
return nil
} else if err != sql.ErrNoRows {
return processErrorResponse(err)
}
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
a.log.Errorf("error creating new rsa key: %s", err)
@ -198,19 +218,36 @@ func (a *adminDB) CreateInstanceAccount() db.Error {
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
inserted, err := a.conn.Model(acct).Where("username = ?", username).SelectOrInsert()
if err != nil {
insertQ := a.conn.
NewInsert().
Model(acct)
if _, err := insertQ.Exec(ctx); err != nil {
return err
}
if inserted {
a.log.Infof("created instance account %s with id %s", username, acct.ID)
} else {
a.log.Infof("instance account %s already exists with id %s", username, acct.ID)
}
a.log.Infof("instance account %s CREATED with id %s", username, acct.ID)
return nil
}
func (a *adminDB) CreateInstanceInstance() db.Error {
func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
domain := a.config.Host
// check if instance entry already exists
existsQ := a.conn.
NewSelect().
Model(&gtsmodel.Instance{}).
Where("domain = ?", domain)
count, err := existsQ.Count(ctx)
if err != nil && count == 1 {
a.log.Infof("instance instance %s already exists", domain)
return nil
} else if err != sql.ErrNoRows {
return processErrorResponse(err)
}
iID, err := id.NewRandomULID()
if err != nil {
return err
@ -218,18 +255,18 @@ func (a *adminDB) CreateInstanceInstance() db.Error {
i := &gtsmodel.Instance{
ID: iID,
Domain: a.config.Host,
Title: a.config.Host,
Domain: domain,
Title: domain,
URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host),
}
inserted, err := a.conn.Model(i).Where("domain = ?", a.config.Host).SelectOrInsert()
if err != nil {
insertQ := a.conn.
NewInsert().
Model(i)
if _, err := insertQ.Exec(ctx); err != nil {
return err
}
if inserted {
a.log.Infof("created instance instance %s with id %s", a.config.Host, i.ID)
} else {
a.log.Infof("instance instance %s already exists with id %s", a.config.Host, i.ID)
}
a.log.Infof("created instance instance %s with id %s", domain, i.ID)
return nil
}

179
internal/db/bundb/basic.go Normal file
View file

@ -0,0 +1,179 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"errors"
"strings"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/uptrace/bun"
)
type basicDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
}
func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error {
_, err := b.conn.NewInsert().Model(i).Exec(ctx)
if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
}
func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Error {
q := b.conn.
NewSelect().
Model(i).
Where("id = ?", id)
return processErrorResponse(q.Scan(ctx))
}
func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := b.conn.NewSelect().Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", bun.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
}
}
}
return processErrorResponse(q.Scan(ctx))
}
func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
q := b.conn.
NewSelect().
Model(i)
return processErrorResponse(q.Scan(ctx))
}
func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error {
q := b.conn.
NewDelete().
Model(i).
Where("id = ?", id)
_, err := q.Exec(ctx)
return processErrorResponse(err)
}
func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := b.conn.
NewDelete().
Model(i)
for _, w := range where {
q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
}
_, err := q.Exec(ctx)
return processErrorResponse(err)
}
func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error {
q := b.conn.
NewUpdate().
Model(i).
WherePK()
_, err := q.Exec(ctx)
return processErrorResponse(err)
}
func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error {
q := b.conn.NewUpdate().
Model(i).
Set("? = ?", bun.Safe(key), value).
WherePK()
_, err := q.Exec(ctx)
return processErrorResponse(err)
}
func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error {
q := b.conn.NewUpdate().Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", bun.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
}
}
}
q = q.Set("? = ?", bun.Safe(key), value)
_, err := q.Exec(ctx)
return processErrorResponse(err)
}
func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
_, err := b.conn.NewCreateTable().Model(i).IfNotExists().Exec(ctx)
return err
}
func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {
_, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx)
return processErrorResponse(err)
}
func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
return b.conn.Ping()
}
func (b *basicDB) Stop(ctx context.Context) db.Error {
b.log.Info("closing db connection")
if err := b.conn.Close(); err != nil {
// only cancel if there's a problem closing the db
return err
}
return nil
}

View file

@ -0,0 +1,68 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb_test
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type BasicTestSuite struct {
BunDBStandardTestSuite
}
func (suite *BasicTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts()
suite.testAttachments = testrig.NewTestAttachments()
suite.testStatuses = testrig.NewTestStatuses()
suite.testTags = testrig.NewTestTags()
suite.testMentions = testrig.NewTestMentions()
}
func (suite *BasicTestSuite) SetupTest() {
suite.config = testrig.NewTestConfig()
suite.db = testrig.NewTestDB()
suite.log = testrig.NewTestLog()
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
func (suite *BasicTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
}
func (suite *BasicTestSuite) TestGetAccountByID() {
testAccount := suite.testAccounts["local_account_1"]
a := &gtsmodel.Account{}
err := suite.db.GetByID(context.Background(), testAccount.ID, a)
suite.NoError(err)
}
func TestBasicTestSuite(t *testing.T) {
suite.Run(t, new(BasicTestSuite))
}

View file

@ -16,12 +16,13 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
"crypto/tls"
"crypto/x509"
"database/sql"
"encoding/pem"
"errors"
"fmt"
@ -29,14 +30,20 @@ import (
"strings"
"time"
"github.com/go-pg/pg/extra/pgdebug"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
)
const (
dbTypePostgres = "postgres"
dbTypeSqlite = "sqlite"
)
var registerTables []interface{} = []interface{}{
@ -44,8 +51,8 @@ var registerTables []interface{} = []interface{}{
&gtsmodel.StatusToTag{},
}
// postgresService satisfies the DB interface
type postgresService struct {
// bunDBService satisfies the DB interface
type bunDBService struct {
db.Account
db.Admin
db.Basic
@ -55,130 +62,115 @@ type postgresService struct {
db.Mention
db.Notification
db.Relationship
db.Session
db.Status
db.Timeline
config *config.Config
conn *pg.DB
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
for _, t := range registerTables {
// https://pg.uptrace.dev/orm/many-to-many-relation/
orm.RegisterTable(t)
}
// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
var sqldb *sql.DB
var conn *bun.DB
opts, err := derivePGOptions(c)
if err != nil {
return nil, fmt.Errorf("could not create postgres service: %s", err)
}
log.Debugf("using pg options: %+v", opts)
// create a connection
pgCtx, cancel := context.WithCancel(ctx)
conn := pg.Connect(opts).WithContext(pgCtx)
// this will break the logfmt format we normally log in,
// since we can't choose where pg outputs to and it defaults to
// stdout. So use this option with care!
if log.GetLevel() >= logrus.TraceLevel {
conn.AddQueryHook(pgdebug.DebugHook{
// Print all queries.
Verbose: true,
})
// depending on the database type we're trying to create, we need to use a different driver...
switch strings.ToLower(c.DBConfig.Type) {
case dbTypePostgres:
// POSTGRES
opts, err := deriveBunDBPGOptions(c)
if err != nil {
return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
}
sqldb = stdlib.OpenDB(*opts)
conn = bun.NewDB(sqldb, pgdialect.New())
case dbTypeSqlite:
// SQLITE
// TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite
default:
return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type))
}
// actually *begin* the connection so that we can tell if the db is there and listening
if err := conn.Ping(ctx); err != nil {
cancel()
if err := conn.Ping(); err != nil {
return nil, fmt.Errorf("db connection error: %s", err)
}
log.Info("connected to database")
// print out discovered postgres version
var version string
if _, err = conn.QueryOneContext(ctx, pg.Scan(&version), "SELECT version()"); err != nil {
cancel()
return nil, fmt.Errorf("db connection error: %s", err)
for _, t := range registerTables {
// https://bun.uptrace.dev/orm/many-to-many-relation/
conn.RegisterModel(t)
}
log.Infof("connected to postgres version: %s", version)
ps := &postgresService{
ps := &bunDBService{
Account: &accountDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Admin: &adminDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Basic: &basicDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Domain: &domainDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Instance: &instanceDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Media: &mediaDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Mention: &mentionDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Notification: &notificationDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Relationship: &relationshipDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Session: &sessionDB{
config: c,
conn: conn,
log: log,
},
Status: &statusDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Timeline: &timelineDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
config: c,
conn: conn,
log: log,
cancel: cancel,
}
// we can confidently return this useable postgres service now
// we can confidently return this useable service now
return ps, nil
}
@ -186,9 +178,9 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
HANDY STUFF
*/
// derivePGOptions takes an application config and returns either a ready-to-use *pg.Options
// deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options
// with sensible defaults, or an error if it's not satisfied by the provided config.
func derivePGOptions(c *config.Config) (*pg.Options, error) {
func deriveBunDBPGOptions(c *config.Config) (*pgx.ConnConfig, error) {
if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres {
return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type)
}
@ -266,18 +258,16 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
tlsConfig.RootCAs = certPool
}
// We can rely on the pg library we're using to set
// sensible defaults for everything we don't set here.
options := &pg.Options{
Addr: fmt.Sprintf("%s:%d", c.DBConfig.Address, c.DBConfig.Port),
User: c.DBConfig.User,
Password: c.DBConfig.Password,
Database: c.DBConfig.Database,
ApplicationName: c.ApplicationName,
TLSConfig: tlsConfig,
}
cfg, _ := pgx.ParseConfig("")
cfg.Host = c.DBConfig.Address
cfg.Port = uint16(c.DBConfig.Port)
cfg.User = c.DBConfig.User
cfg.Password = c.DBConfig.Password
cfg.TLSConfig = tlsConfig
cfg.Database = c.DBConfig.Database
cfg.PreferSimpleProtocol = true
return options, nil
return cfg, nil
}
/*
@ -286,9 +276,9 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
// TODO: move these to the type converter, it's bananas that they're here and not there
func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
ogAccount := &gtsmodel.Account{}
if err := ps.conn.Model(ogAccount).Where("id = ?", originAccountID).Select(); err != nil {
if err := ps.conn.NewSelect().Model(ogAccount).Where("id = ?", originAccountID).Scan(ctx); err != nil {
return nil, err
}
@ -333,14 +323,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori
// match username + account, case insensitive
if local {
// local user -- should have a null domain
err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("? IS NULL", pg.Ident("domain")).Select()
err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("? IS NULL", bun.Ident("domain")).Scan(ctx)
} else {
// remote user -- should have domain defined
err = ps.conn.Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", pg.Ident("username"), username).Where("LOWER(?) = LOWER(?)", pg.Ident("domain"), domain).Select()
err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("LOWER(?) = LOWER(?)", bun.Ident("domain"), domain).Scan(ctx)
}
if err != nil {
if err == pg.ErrNoRows {
if err == sql.ErrNoRows {
// no result found for this username/domain so just don't include it as a mencho and carry on about our business
ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain)
continue
@ -364,14 +354,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori
return menchies, nil
}
func (ps *postgresService) TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
newTags := []*gtsmodel.Tag{}
for _, t := range tags {
tag := &gtsmodel.Tag{}
// we can use selectorinsert here to create the new tag if it doesn't exist already
// inserted will be true if this is a new tag we just created
if err := ps.conn.Model(tag).Where("LOWER(?) = LOWER(?)", pg.Ident("name"), t).Select(); err != nil {
if err == pg.ErrNoRows {
if err := ps.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil {
if err == sql.ErrNoRows {
// tag doesn't exist yet so populate it
newID, err := id.NewRandomULID()
if err != nil {
@ -400,13 +390,13 @@ func (ps *postgresService) TagStringsToTags(tags []string, originAccountID strin
return newTags, nil
}
func (ps *postgresService) EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
newEmojis := []*gtsmodel.Emoji{}
for _, e := range emojis {
emoji := &gtsmodel.Emoji{}
err := ps.conn.Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Select()
err := ps.conn.NewSelect().Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Scan(ctx)
if err != nil {
if err == pg.ErrNoRows {
if err == sql.ErrNoRows {
// no result found for this username/domain so just don't include it as an emoji and carry on about our business
ps.log.Debugf("no emoji found with shortcode %s, skipping it", e)
continue

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
package bundb_test
import (
"github.com/sirupsen/logrus"
@ -27,7 +27,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
type PGStandardTestSuite struct {
type BunDBStandardTestSuite struct {
// standard suite interfaces
suite.Suite
config *config.Config

View file

@ -16,48 +16,46 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
"net/url"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
type domainDB struct {
config *config.Config
conn *pg.DB
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (d *domainDB) IsDomainBlocked(domain string) (bool, db.Error) {
func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) {
if domain == "" {
return false, nil
}
blocked, err := d.conn.
q := d.conn.
NewSelect().
Model(&gtsmodel.DomainBlock{}).
Where("LOWER(domain) = LOWER(?)", domain).
Exists()
Limit(1)
err = processErrorResponse(err)
return blocked, err
return exists(ctx, q)
}
func (d *domainDB) AreDomainsBlocked(domains []string) (bool, db.Error) {
func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {
// filter out any doubles
uniqueDomains := util.UniqueStrings(domains)
for _, domain := range uniqueDomains {
if blocked, err := d.IsDomainBlocked(domain); err != nil {
if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil {
return false, err
} else if blocked {
return blocked, nil
@ -68,16 +66,16 @@ func (d *domainDB) AreDomainsBlocked(domains []string) (bool, db.Error) {
return false, nil
}
func (d *domainDB) IsURIBlocked(uri *url.URL) (bool, db.Error) {
func (d *domainDB) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, db.Error) {
domain := uri.Hostname()
return d.IsDomainBlocked(domain)
return d.IsDomainBlocked(ctx, domain)
}
func (d *domainDB) AreURIsBlocked(uris []*url.URL) (bool, db.Error) {
func (d *domainDB) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, db.Error) {
domains := []string{}
for _, uri := range uris {
domains = append(domains, uri.Hostname())
}
return d.AreDomainsBlocked(domains)
return d.AreDomainsBlocked(ctx, domains)
}

View file

@ -16,43 +16,50 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type instanceDB struct {
config *config.Config
conn *pg.DB
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (i *instanceDB) CountInstanceUsers(domain string) (int, db.Error) {
q := i.conn.Model(&[]*gtsmodel.Account{})
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
q := i.conn.
NewSelect().
Model(&[]*gtsmodel.Account{})
if domain == i.config.Host {
// if the domain is *this* domain, just count where the domain field is null
q = q.Where("? IS NULL", pg.Ident("domain"))
q = q.Where("? IS NULL", bun.Ident("domain"))
} else {
q = q.Where("domain = ?", domain)
}
// don't count the instance account or suspended users
q = q.Where("username != ?", domain).Where("? IS NULL", pg.Ident("suspended_at"))
q = q.
Where("username != ?", domain).
Where("? IS NULL", bun.Ident("suspended_at"))
return q.Count()
count, err := q.Count(ctx)
return count, processErrorResponse(err)
}
func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) {
q := i.conn.Model(&[]*gtsmodel.Status{})
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
q := i.conn.
NewSelect().
Model(&[]*gtsmodel.Status{})
if domain == i.config.Host {
// if the domain is *this* domain, just count where local is true
@ -63,30 +70,39 @@ func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) {
Where("account.domain = ?", domain)
}
return q.Count()
count, err := q.Count(ctx)
return count, processErrorResponse(err)
}
func (i *instanceDB) CountInstanceDomains(domain string) (int, db.Error) {
q := i.conn.Model(&[]*gtsmodel.Instance{})
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
q := i.conn.
NewSelect().
Model(&[]*gtsmodel.Instance{})
if domain == i.config.Host {
// if the domain is *this* domain, just count other instances it knows about
// exclude domains that are blocked
q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at"))
q = q.Where("domain != ?", domain).Where("? IS NULL", bun.Ident("suspended_at"))
} else {
// TODO: implement federated domain counting properly for remote domains
return 0, nil
}
return q.Count()
count, err := q.Count(ctx)
return count, processErrorResponse(err)
}
func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
i.log.Debug("GetAccountsForInstance")
accounts := []*gtsmodel.Account{}
q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
q := i.conn.NewSelect().
Model(&accounts).
Where("domain = ?", domain).
Order("id DESC")
if maxID != "" {
q = q.Where("id < ?", maxID)
@ -96,17 +112,7 @@ func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int)
q = q.Limit(limit)
}
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries
}
return nil, err
}
err := processErrorResponse(q.Scan(ctx))
if len(accounts) == 0 {
return nil, db.ErrNoEntries
}
return accounts, nil
return accounts, err
}

View file

@ -16,38 +16,38 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type mediaDB struct {
config *config.Config
conn *pg.DB
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (m *mediaDB) newMediaQ(i interface{}) *orm.Query {
return m.conn.Model(i).
func (m *mediaDB) newMediaQ(i interface{}) *bun.SelectQuery {
return m.conn.
NewSelect().
Model(i).
Relation("Account")
}
func (m *mediaDB) GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, db.Error) {
func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, db.Error) {
attachment := &gtsmodel.MediaAttachment{}
q := m.newMediaQ(attachment).
Where("media_attachment.id = ?", id)
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
return attachment, err
}

View file

@ -16,25 +16,23 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type mentionDB struct {
config *config.Config
conn *pg.DB
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
}
@ -67,14 +65,16 @@ func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) {
return mention, true
}
func (m *mentionDB) newMentionQ(i interface{}) *orm.Query {
return m.conn.Model(i).
func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
return m.conn.
NewSelect().
Model(i).
Relation("Status").
Relation("OriginAccount").
Relation("TargetAccount")
}
func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) {
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
if mention, cached := m.mentionCached(id); cached {
return mention, nil
}
@ -84,7 +84,7 @@ func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) {
q := m.newMentionQ(mention).
Where("mention.id = ?", id)
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
if err == nil && mention != nil {
m.cacheMention(id, mention)
@ -93,11 +93,11 @@ func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) {
return mention, err
}
func (m *mentionDB) GetMentions(ids []string) ([]*gtsmodel.Mention, db.Error) {
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) {
mentions := []*gtsmodel.Mention{}
for _, i := range ids {
mention, err := m.GetMention(i)
mention, err := m.GetMention(ctx, i)
if err != nil {
return nil, processErrorResponse(err)
}

View file

@ -16,25 +16,23 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type notificationDB struct {
config *config.Config
conn *pg.DB
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
}
@ -67,14 +65,16 @@ func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification,
return notification, true
}
func (n *notificationDB) newNotificationQ(i interface{}) *orm.Query {
return n.conn.Model(i).
func (n *notificationDB) newNotificationQ(i interface{}) *bun.SelectQuery {
return n.conn.
NewSelect().
Model(i).
Relation("OriginAccount").
Relation("TargetAccount").
Relation("Status")
}
func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db.Error) {
func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
if notification, cached := n.notificationCached(id); cached {
return notification, nil
}
@ -84,7 +84,7 @@ func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db.
q := n.newNotificationQ(notification).
Where("notification.id = ?", id)
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
if err == nil && notification != nil {
n.cacheNotification(id, notification)
@ -93,10 +93,11 @@ func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db.
return notification, err
}
func (n *notificationDB) GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
// begin by selecting just the IDs
notifIDs := []*gtsmodel.Notification{}
q := n.conn.
NewSelect().
Model(&notifIDs).
Column("id").
Where("target_account_id = ?", accountID).
@ -114,7 +115,7 @@ func (n *notificationDB) GetNotifications(accountID string, limit int, maxID str
q = q.Limit(limit)
}
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
if err != nil {
return nil, err
}
@ -123,7 +124,7 @@ func (n *notificationDB) GetNotifications(accountID string, limit int, maxID str
// reason for this is that for each notif, we can instead get it from our cache if it's cached
notifications := []*gtsmodel.Notification{}
for _, notifID := range notifIDs {
notif, err := n.GetNotification(notifID.ID)
notif, err := n.GetNotification(ctx, notifID.ID)
errP := processErrorResponse(err)
if errP != nil {
return nil, errP

View file

@ -16,44 +16,49 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
"database/sql"
"fmt"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type relationshipDB struct {
config *config.Config
conn *pg.DB
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *orm.Query {
return r.conn.Model(block).
func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery {
return r.conn.
NewSelect().
Model(block).
Relation("Account").
Relation("TargetAccount")
}
func (r *relationshipDB) newFollowQ(follow interface{}) *orm.Query {
return r.conn.Model(follow).
func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery {
return r.conn.
NewSelect().
Model(follow).
Relation("Account").
Relation("TargetAccount")
}
func (r *relationshipDB) IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
q := r.conn.
NewSelect().
Model(&gtsmodel.Block{}).
Where("account_id = ?", account1).
Where("target_account_id = ?", account2)
Where("target_account_id = ?", account2).
Limit(1)
if eitherDirection {
q = q.
@ -61,30 +66,36 @@ func (r *relationshipDB) IsBlocked(account1 string, account2 string, eitherDirec
Where("account_id = ?", account2)
}
return q.Exists()
return exists(ctx, q)
}
func (r *relationshipDB) GetBlock(account1 string, account2 string) (*gtsmodel.Block, db.Error) {
func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
block := &gtsmodel.Block{}
q := r.newBlockQ(block).
Where("block.account_id = ?", account1).
Where("block.target_account_id = ?", account2)
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
return block, err
}
func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
rel := &gtsmodel.Relationship{
ID: targetAccount,
}
// check if the requesting account follows the target account
follow := &gtsmodel.Follow{}
if err := r.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
if err != pg.ErrNoRows {
if err := r.conn.
NewSelect().
Model(follow).
Where("account_id = ?", requestingAccount).
Where("target_account_id = ?", targetAccount).
Limit(1).
Scan(ctx); err != nil {
if err != sql.ErrNoRows {
// a proper error
return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
}
@ -100,75 +111,101 @@ func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount
}
// check if the target account follows the requesting account
followedBy, err := r.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
count, err := r.conn.
NewSelect().
Model(&gtsmodel.Follow{}).
Where("account_id = ?", targetAccount).
Where("target_account_id = ?", requestingAccount).
Limit(1).
Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
}
rel.FollowedBy = followedBy
rel.FollowedBy = count > 0
// check if the requesting account blocks the target account
blocking, err := r.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
count, err = r.conn.NewSelect().
Model(&gtsmodel.Block{}).
Where("account_id = ?", requestingAccount).
Where("target_account_id = ?", targetAccount).
Limit(1).
Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
}
rel.Blocking = blocking
rel.Blocking = count > 0
// check if the target account blocks the requesting account
blockedBy, err := r.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
count, err = r.conn.
NewSelect().
Model(&gtsmodel.Block{}).
Where("account_id = ?", targetAccount).
Where("target_account_id = ?", requestingAccount).
Limit(1).
Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
rel.BlockedBy = blockedBy
rel.BlockedBy = count > 0
// check if there's a pending following request from requesting account to target account
requested, err := r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
count, err = r.conn.
NewSelect().
Model(&gtsmodel.FollowRequest{}).
Where("account_id = ?", requestingAccount).
Where("target_account_id = ?", targetAccount).
Limit(1).
Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
rel.Requested = requested
rel.Requested = count > 0
return rel, nil
}
func (r *relationshipDB) IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
q := r.conn.
NewSelect().
Model(&gtsmodel.Follow{}).
Where("account_id = ?", sourceAccount.ID).
Where("target_account_id = ?", targetAccount.ID)
Where("target_account_id = ?", targetAccount.ID).
Limit(1)
return q.Exists()
return exists(ctx, q)
}
func (r *relationshipDB) IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
q := r.conn.
NewSelect().
Model(&gtsmodel.FollowRequest{}).
Where("account_id = ?", sourceAccount.ID).
Where("target_account_id = ?", targetAccount.ID)
return q.Exists()
return exists(ctx, q)
}
func (r *relationshipDB) IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
if account1 == nil || account2 == nil {
return false, nil
}
// make sure account 1 follows account 2
f1, err := r.IsFollowing(account1, account2)
f1, err := r.IsFollowing(ctx, account1, account2)
if err != nil {
return false, processErrorResponse(err)
}
// make sure account 2 follows account 1
f2, err := r.IsFollowing(account2, account1)
f2, err := r.IsFollowing(ctx, account2, account1)
if err != nil {
return false, processErrorResponse(err)
}
@ -176,14 +213,16 @@ func (r *relationshipDB) IsMutualFollowing(account1 *gtsmodel.Account, account2
return f1 && f2, nil
}
func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
// make sure the original follow request exists
fr := &gtsmodel.FollowRequest{}
if err := r.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
if err == pg.ErrMultiRows {
return nil, db.ErrNoEntries
}
return nil, err
if err := r.conn.
NewSelect().
Model(fr).
Where("account_id = ?", originAccountID).
Where("target_account_id = ?", targetAccountID).
Scan(ctx); err != nil {
return nil, processErrorResponse(err)
}
// create a new follow to 'replace' the request with
@ -195,82 +234,95 @@ func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccou
}
// if the follow already exists, just update the URI -- we don't need to do anything else
if _, err := r.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
return nil, err
if _, err := r.conn.
NewInsert().
Model(follow).
On("CONFLICT CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).
Exec(ctx); err != nil {
return nil, processErrorResponse(err)
}
// now remove the follow request
if _, err := r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
return nil, err
if _, err := r.conn.
NewDelete().
Model(&gtsmodel.FollowRequest{}).
Where("account_id = ?", originAccountID).
Where("target_account_id = ?", targetAccountID).
Exec(ctx); err != nil {
return nil, processErrorResponse(err)
}
return follow, nil
}
func (r *relationshipDB) GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
followRequests := []*gtsmodel.FollowRequest{}
q := r.newFollowQ(&followRequests).
Where("target_account_id = ?", accountID)
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
return followRequests, err
}
func (r *relationshipDB) GetAccountFollows(accountID string) ([]*gtsmodel.Follow, db.Error) {
func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, db.Error) {
follows := []*gtsmodel.Follow{}
q := r.newFollowQ(&follows).
Where("account_id = ?", accountID)
err := processErrorResponse(q.Select())
err := processErrorResponse(q.Scan(ctx))
return follows, err
}
func (r *relationshipDB) CountAccountFollows(accountID string, localOnly bool) (int, db.Error) {
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
return r.conn.
NewSelect().
Model(&[]*gtsmodel.Follow{}).
Where("account_id = ?", accountID).
Count()
Count(ctx)
}
func (r *relationshipDB) GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
follows := []*gtsmodel.Follow{}
q := r.conn.Model(&follows)
q := r.conn.
NewSelect().
Model(&follows)
if localOnly {
// for local accounts let's get where domain is null OR where domain is an empty string, just to be safe
whereGroup := func(q *pg.Query) (*pg.Query, error) {
whereGroup := func(q *bun.SelectQuery) *bun.SelectQuery {
q = q.
WhereOr("? IS NULL", pg.Ident("a.domain")).
WhereOr("? IS NULL", bun.Ident("a.domain")).
WhereOr("a.domain = ?", "")
return q, nil
return q
}
q = q.ColumnExpr("follow.*").
Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)").
Where("follow.target_account_id = ?", accountID).
WhereGroup(whereGroup)
WhereGroup(" AND ", whereGroup)
} else {
q = q.Where("target_account_id = ?", accountID)
}
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
if err := q.Scan(ctx); err != nil {
if err == sql.ErrNoRows {
return follows, nil
}
return nil, err
return nil, processErrorResponse(err)
}
return follows, nil
}
func (r *relationshipDB) CountAccountFollowedBy(accountID string, localOnly bool) (int, db.Error) {
func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
return r.conn.
NewSelect().
Model(&[]*gtsmodel.Follow{}).
Where("target_account_id = ?", accountID).
Count()
Count(ctx)
}

View file

@ -0,0 +1,85 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"crypto/rand"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/uptrace/bun"
)
type sessionDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
}
func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
rs := new(gtsmodel.RouterSession)
q := s.conn.
NewSelect().
Model(rs).
Limit(1)
_, err := q.Exec(ctx)
err = processErrorResponse(err)
return rs, err
}
func (s *sessionDB) CreateSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
auth := make([]byte, 32)
crypt := make([]byte, 32)
if _, err := rand.Read(auth); err != nil {
return nil, err
}
if _, err := rand.Read(crypt); err != nil {
return nil, err
}
rid, err := id.NewULID()
if err != nil {
return nil, err
}
rs := &gtsmodel.RouterSession{
ID: rid,
Auth: auth,
Crypt: crypt,
}
q := s.conn.
NewInsert().
Model(rs)
_, err = q.Exec(ctx)
err = processErrorResponse(err)
return rs, err
}

375
internal/db/bundb/status.go Normal file
View file

@ -0,0 +1,375 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"container/list"
"context"
"errors"
"time"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type statusDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cache cache.Cache
}
func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
if s.cache == nil {
s.cache = cache.New()
}
if err := s.cache.Store(id, status); err != nil {
s.log.Panicf("statusDB: error storing in cache: %s", err)
}
}
func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
if s.cache == nil {
s.cache = cache.New()
return nil, false
}
sI, err := s.cache.Fetch(id)
if err != nil || sI == nil {
return nil, false
}
status, ok := sI.(*gtsmodel.Status)
if !ok {
s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
}
return status, true
}
func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
return s.conn.
NewSelect().
Model(status).
Relation("Attachments").
Relation("Tags").
Relation("Mentions").
Relation("Emojis").
Relation("Account").
Relation("InReplyToAccount").
Relation("BoostOfAccount").
Relation("CreatedWithApplication")
}
func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status {
if status.InReplyToID != "" && status.InReplyTo == nil {
if inReplyTo, cached := s.statusCached(status.InReplyToID); cached {
status.InReplyTo = inReplyTo
} else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil {
status.InReplyTo = inReplyTo
}
}
if status.BoostOfID != "" && status.BoostOf == nil {
if boostOf, cached := s.statusCached(status.BoostOfID); cached {
status.BoostOf = boostOf
} else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil {
status.BoostOf = boostOf
}
}
return status
}
func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
return s.conn.
NewSelect().
Model(faves).
Relation("Account").
Relation("TargetAccount").
Relation("Status")
}
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(id); cached {
return status, nil
}
status := new(gtsmodel.Status)
q := s.newStatusQ(status).
Where("status.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
if err != nil {
return nil, err
}
if status != nil {
s.cacheStatus(id, status)
}
return s.getAttachedStatuses(ctx, status), err
}
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("LOWER(status.uri) = LOWER(?)", uri)
err := processErrorResponse(q.Scan(ctx))
if err != nil {
return nil, err
}
if status != nil {
s.cacheStatus(uri, status)
}
return s.getAttachedStatuses(ctx, status), err
}
func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("LOWER(status.url) = LOWER(?)", uri)
err := processErrorResponse(q.Scan(ctx))
if err != nil {
return nil, err
}
if status != nil {
s.cacheStatus(uri, status)
}
return s.getAttachedStatuses(ctx, status), err
}
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
transaction := func(ctx context.Context, tx bun.Tx) error {
// create links between this status and any emojis it uses
for _, i := range status.EmojiIDs {
if _, err := tx.NewInsert().Model(&gtsmodel.StatusToEmoji{
StatusID: status.ID,
EmojiID: i,
}).Exec(ctx); err != nil {
return err
}
}
// create links between this status and any tags it uses
for _, i := range status.TagIDs {
if _, err := tx.NewInsert().Model(&gtsmodel.StatusToTag{
StatusID: status.ID,
TagID: i,
}).Exec(ctx); err != nil {
return err
}
}
// change the status ID of the media attachments to the new status
for _, a := range status.Attachments {
a.StatusID = status.ID
a.UpdatedAt = time.Now()
if _, err := s.conn.NewUpdate().Model(a).
Where("id = ?", a.ID).
Exec(ctx); err != nil {
return err
}
}
_, err := tx.NewInsert().Model(status).Exec(ctx)
return err
}
return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction))
}
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
parents := []*gtsmodel.Status{}
s.statusParent(ctx, status, &parents, onlyDirect)
return parents, nil
}
func (s *statusDB) statusParent(ctx context.Context, status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
if status.InReplyToID == "" {
return
}
parentStatus, err := s.GetStatusByID(ctx, status.InReplyToID)
if err == nil {
*foundStatuses = append(*foundStatuses, parentStatus)
}
if onlyDirect {
return
}
s.statusParent(ctx, parentStatus, foundStatuses, false)
}
func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
foundStatuses := &list.List{}
foundStatuses.PushFront(status)
s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID)
children := []*gtsmodel.Status{}
for e := foundStatuses.Front(); e != nil; e = e.Next() {
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
}
// only append children, not the overall parent status
if entry.ID != status.ID {
children = append(children, entry)
}
}
return children, nil
}
func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
immediateChildren := []*gtsmodel.Status{}
q := s.conn.
NewSelect().
Model(&immediateChildren).
Where("in_reply_to_id = ?", status.ID)
if minID != "" {
q = q.Where("status.id > ?", minID)
}
if err := q.Scan(ctx); err != nil {
return
}
for _, child := range immediateChildren {
insertLoop:
for e := foundStatuses.Front(); e != nil; e = e.Next() {
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
}
if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
foundStatuses.InsertAfter(child, e)
break insertLoop
}
}
// only do one loop if we only want direct children
if onlyDirect {
return
}
s.statusChildren(ctx, child, foundStatuses, false, minID)
}
}
func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
return s.conn.NewSelect().Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count(ctx)
}
func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
return s.conn.NewSelect().Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count(ctx)
}
func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
return s.conn.NewSelect().Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count(ctx)
}
func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
NewSelect().
Model(&gtsmodel.StatusFave{}).
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
}
func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
NewSelect().
Model(&gtsmodel.Status{}).
Where("boost_of_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
}
func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
NewSelect().
Model(&gtsmodel.StatusMute{}).
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
}
func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
NewSelect().
Model(&gtsmodel.StatusBookmark{}).
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
}
func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
faves := []*gtsmodel.StatusFave{}
q := s.newFaveQ(&faves).
Where("status_id = ?", status.ID)
err := processErrorResponse(q.Scan(ctx))
return faves, err
}
func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
reblogs := []*gtsmodel.Status{}
q := s.newStatusQ(&reblogs).
Where("boost_of_id = ?", status.ID)
err := processErrorResponse(q.Scan(ctx))
return reblogs, err
}

View file

@ -16,9 +16,10 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
package bundb_test
import (
"context"
"fmt"
"testing"
"time"
@ -28,7 +29,7 @@ import (
)
type StatusTestSuite struct {
PGStandardTestSuite
BunDBStandardTestSuite
}
func (suite *StatusTestSuite) SetupSuite() {
@ -56,8 +57,9 @@ func (suite *StatusTestSuite) TearDownTest() {
}
func (suite *StatusTestSuite) TestGetStatusByID() {
status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID)
status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_1_status_1"].ID)
if err != nil {
fmt.Println(err.Error())
suite.FailNow(err.Error())
}
suite.NotNil(status)
@ -70,7 +72,7 @@ func (suite *StatusTestSuite) TestGetStatusByID() {
}
func (suite *StatusTestSuite) TestGetStatusByURI() {
status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
if err != nil {
suite.FailNow(err.Error())
}
@ -84,7 +86,7 @@ func (suite *StatusTestSuite) TestGetStatusByURI() {
}
func (suite *StatusTestSuite) TestGetStatusWithExtras() {
status, err := suite.db.GetStatusByID(suite.testStatuses["admin_account_status_1"].ID)
status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["admin_account_status_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
@ -97,7 +99,7 @@ func (suite *StatusTestSuite) TestGetStatusWithExtras() {
}
func (suite *StatusTestSuite) TestGetStatusWithMention() {
status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_2_status_5"].ID)
status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_2_status_5"].ID)
if err != nil {
suite.FailNow(err.Error())
}
@ -112,18 +114,18 @@ func (suite *StatusTestSuite) TestGetStatusWithMention() {
func (suite *StatusTestSuite) TestGetStatusTwice() {
before1 := time.Now()
_, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
_, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
suite.NoError(err)
after1 := time.Now()
duration1 := after1.Sub(before1)
fmt.Println(duration1.Nanoseconds())
fmt.Println(duration1.Milliseconds())
before2 := time.Now()
_, err = suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
_, err = suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
suite.NoError(err)
after2 := time.Now()
duration2 := after2.Sub(before2)
fmt.Println(duration2.Nanoseconds())
fmt.Println(duration2.Milliseconds())
// second retrieval should be several orders faster since it will be cached now
suite.Less(duration2, duration1)

View file

@ -16,43 +16,35 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
"database/sql"
"sort"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type timelineDB struct {
config *config.Config
conn *pg.DB
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
q := t.conn.Model(&statuses)
q := t.conn.
NewSelect().
Model(&statuses)
q = q.ColumnExpr("status.*").
// Find out who accountID follows.
Join("LEFT JOIN follows AS f ON f.target_account_id = status.account_id").
// Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows,
// OR statuses posted by accountID itself (since a user should be able to see their own statuses).
//
// This is equivalent to something like WHERE ... AND (... OR ...)
// See: https://pg.uptrace.dev/queries/#select
WhereGroup(func(q *pg.Query) (*pg.Query, error) {
q = q.WhereOr("f.account_id = ?", accountID).
WhereOr("status.account_id = ?", accountID)
return q, nil
}).
// Sort by highest ID (newest) to lowest ID (oldest)
Order("status.id DESC")
@ -81,29 +73,32 @@ func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID str
q = q.Limit(limit)
}
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries
}
return nil, err
// Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows,
// OR statuses posted by accountID itself (since a user should be able to see their own statuses).
//
// This is equivalent to something like WHERE ... AND (... OR ...)
// See: https://bun.uptrace.dev/guide/queries.html#select
whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
return q.
WhereOr("f.account_id = ?", accountID).
WhereOr("status.account_id = ?", accountID)
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries
}
q = q.WhereGroup(" AND ", whereGroup)
return statuses, nil
return statuses, processErrorResponse(q.Scan(ctx))
}
func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
q := t.conn.Model(&statuses).
q := t.conn.
NewSelect().
Model(&statuses).
Where("visibility = ?", gtsmodel.VisibilityPublic).
Where("? IS NULL", pg.Ident("in_reply_to_id")).
Where("? IS NULL", pg.Ident("in_reply_to_uri")).
Where("? IS NULL", pg.Ident("boost_of_id")).
Where("? IS NULL", bun.Ident("in_reply_to_id")).
Where("? IS NULL", bun.Ident("in_reply_to_uri")).
Where("? IS NULL", bun.Ident("boost_of_id")).
Order("status.id DESC")
if maxID != "" {
@ -126,28 +121,18 @@ func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID s
q = q.Limit(limit)
}
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries
}
return statuses, nil
return statuses, processErrorResponse(q.Scan(ctx))
}
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
faves := []*gtsmodel.StatusFave{}
fq := t.conn.Model(&faves).
fq := t.conn.
NewSelect().
Model(&faves).
Where("account_id = ?", accountID).
Order("id DESC")
@ -163,9 +148,9 @@ func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID stri
fq = fq.Limit(limit)
}
err := fq.Select()
err := fq.Scan(ctx)
if err != nil {
if err == pg.ErrNoRows {
if err == sql.ErrNoRows {
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
@ -185,9 +170,13 @@ func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID stri
}
statuses := []*gtsmodel.Status{}
err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
err = t.conn.
NewSelect().
Model(&statuses).
Where("id IN (?)", bun.In(in)).
Scan(ctx)
if err != nil {
if err == pg.ErrNoRows {
if err == sql.ErrNoRows {
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err

78
internal/db/bundb/util.go Normal file
View file

@ -0,0 +1,78 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"strings"
"database/sql"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/uptrace/bun"
)
// processErrorResponse parses the given error and returns an appropriate DBError.
func processErrorResponse(err error) db.Error {
switch err {
case nil:
return nil
case sql.ErrNoRows:
return db.ErrNoEntries
default:
if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
}
}
func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
count, err := q.Count(ctx)
exists := count != 0
err = processErrorResponse(err)
if err != nil {
if err == db.ErrNoEntries {
return false, nil
}
return false, err
}
return exists, nil
}
func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
count, err := q.Count(ctx)
notExists := count == 0
err = processErrorResponse(err)
if err != nil {
if err == db.ErrNoEntries {
return true, nil
}
return false, err
}
return notExists, nil
}

View file

@ -19,6 +19,8 @@
package db
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@ -38,6 +40,7 @@ type DB interface {
Mention
Notification
Relationship
Session
Status
Timeline
@ -52,7 +55,7 @@ type DB interface {
//
// Note: this func doesn't/shouldn't do any manipulation of the accounts in the DB, it's just for checking
// if they exist in the db and conveniently returning them if they do.
MentionStringsToMentions(targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error)
MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error)
// TagStringsToTags takes a slice of deduplicated, lowercase tags in the form "somehashtag", which have been
// used in a status. It takes the id of the account that wrote the status, and the id of the status itself, and then
@ -61,7 +64,7 @@ type DB interface {
//
// Note: this func doesn't/shouldn't do any manipulation of the tags in the DB, it's just for checking
// if they exist in the db already, and conveniently returning them, or creating new tag structs.
TagStringsToTags(tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error)
TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error)
// EmojiStringsToEmojis takes a slice of deduplicated, lowercase emojis in the form ":emojiname:", which have been
// used in a status. It takes the id of the account that wrote the status, and the id of the status itself, and then
@ -69,5 +72,5 @@ type DB interface {
//
// Note: this func doesn't/shouldn't do any manipulation of the emoji in the DB, it's just for checking
// if they exist in the db and conveniently returning them if they do.
EmojiStringsToEmojis(emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error)
EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error)
}

View file

@ -18,19 +18,22 @@
package db
import "net/url"
import (
"context"
"net/url"
)
// Domain contains DB functions related to domains and domain blocks.
type Domain interface {
// IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`).
IsDomainBlocked(domain string) (bool, Error)
IsDomainBlocked(ctx context.Context, domain string) (bool, Error)
// AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found.
AreDomainsBlocked(domains []string) (bool, Error)
AreDomainsBlocked(ctx context.Context, domains []string) (bool, Error)
// IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`).
IsURIBlocked(uri *url.URL) (bool, Error)
IsURIBlocked(ctx context.Context, uri *url.URL) (bool, Error)
// AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found.
AreURIsBlocked(uris []*url.URL) (bool, Error)
AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, Error)
}

View file

@ -18,19 +18,23 @@
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Instance contains functions for instance-level actions (counting instance users etc.).
type Instance interface {
// CountInstanceUsers returns the number of known accounts registered with the given domain.
CountInstanceUsers(domain string) (int, Error)
CountInstanceUsers(ctx context.Context, domain string) (int, Error)
// CountInstanceStatuses returns the number of known statuses posted from the given domain.
CountInstanceStatuses(domain string) (int, Error)
CountInstanceStatuses(ctx context.Context, domain string) (int, Error)
// CountInstanceDomains returns the number of known instances known that the given domain federates with.
CountInstanceDomains(domain string) (int, Error)
CountInstanceDomains(ctx context.Context, domain string) (int, Error)
// GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID.
GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)
GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)
}

View file

@ -18,10 +18,14 @@
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Media contains functions related to creating/getting/removing media attachments.
type Media interface {
// GetAttachmentByID gets a single attachment by its ID
GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, Error)
GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, Error)
}

View file

@ -18,13 +18,17 @@
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Mention contains functions for getting/creating mentions in the database.
type Mention interface {
// GetMention gets a single mention by ID
GetMention(id string) (*gtsmodel.Mention, Error)
GetMention(ctx context.Context, id string) (*gtsmodel.Mention, Error)
// GetMentions gets multiple mentions.
GetMentions(ids []string) ([]*gtsmodel.Mention, Error)
GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, Error)
}

View file

@ -18,14 +18,18 @@
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Notification contains functions for creating and getting notifications.
type Notification interface {
// GetNotifications returns a slice of notifications that pertain to the given accountID.
//
// Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest).
GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
// GetNotification returns one notification according to its id.
GetNotification(id string) (*gtsmodel.Notification, Error)
GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, Error)
}

View file

@ -1,205 +0,0 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"errors"
"fmt"
"strings"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
type basicDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (b *basicDB) Put(i interface{}) db.Error {
_, err := b.conn.Model(i).Insert(i)
if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
}
func (b *basicDB) GetByID(id string, i interface{}) db.Error {
if err := b.conn.Model(i).Where("id = ?", id).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.Error {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := b.conn.Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", pg.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
}
}
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) GetAll(i interface{}) db.Error {
if err := b.conn.Model(i).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) DeleteByID(id string, i interface{}) db.Error {
if _, err := b.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
// if there are no rows *anyway* then that's fine
// just return err if there's an actual error
if err != pg.ErrNoRows {
return err
}
}
return nil
}
func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.Error {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := b.conn.Model(i)
for _, w := range where {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
if _, err := q.Delete(); err != nil {
// if there are no rows *anyway* then that's fine
// just return err if there's an actual error
if err != pg.ErrNoRows {
return err
}
}
return nil
}
func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.Error {
if _, err := b.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) UpdateByID(id string, i interface{}) db.Error {
if _, err := b.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.Error {
_, err := b.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
return err
}
func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.Error {
q := b.conn.Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", pg.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
}
}
q = q.Set("? = ?", pg.Safe(key), value)
_, err := q.Update()
return err
}
func (b *basicDB) CreateTable(i interface{}) db.Error {
return b.conn.Model(i).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
}
func (b *basicDB) DropTable(i interface{}) db.Error {
return b.conn.Model(i).DropTable(&orm.DropTableOptions{
IfExists: true,
})
}
func (b *basicDB) RegisterTable(i interface{}) db.Error {
orm.RegisterTable(i)
return nil
}
func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
return b.conn.Ping(ctx)
}
func (b *basicDB) Stop(ctx context.Context) db.Error {
b.log.Info("closing db connection")
if err := b.conn.Close(); err != nil {
// only cancel if there's a problem closing the db
b.cancel()
return err
}
return nil
}

View file

@ -1,318 +0,0 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"container/list"
"context"
"errors"
"time"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type statusDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
}
func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
if s.cache == nil {
s.cache = cache.New()
}
if err := s.cache.Store(id, status); err != nil {
s.log.Panicf("statusDB: error storing in cache: %s", err)
}
}
func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
if s.cache == nil {
s.cache = cache.New()
return nil, false
}
sI, err := s.cache.Fetch(id)
if err != nil || sI == nil {
return nil, false
}
status, ok := sI.(*gtsmodel.Status)
if !ok {
s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
}
return status, true
}
func (s *statusDB) newStatusQ(status interface{}) *orm.Query {
return s.conn.Model(status).
Relation("Attachments").
Relation("Tags").
Relation("Mentions").
Relation("Emojis").
Relation("Account").
Relation("InReplyTo").
Relation("InReplyToAccount").
Relation("BoostOf").
Relation("BoostOfAccount").
Relation("CreatedWithApplication")
}
func (s *statusDB) newFaveQ(faves interface{}) *orm.Query {
return s.conn.Model(faves).
Relation("Account").
Relation("TargetAccount").
Relation("Status")
}
func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(id); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("status.id = ?", id)
err := processErrorResponse(q.Select())
if err == nil && status != nil {
s.cacheStatus(id, status)
}
return status, err
}
func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("LOWER(status.uri) = LOWER(?)", uri)
err := processErrorResponse(q.Select())
if err == nil && status != nil {
s.cacheStatus(uri, status)
}
return status, err
}
func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("LOWER(status.url) = LOWER(?)", uri)
err := processErrorResponse(q.Select())
if err == nil && status != nil {
s.cacheStatus(uri, status)
}
return status, err
}
func (s *statusDB) PutStatus(status *gtsmodel.Status) db.Error {
transaction := func(tx *pg.Tx) error {
// create links between this status and any emojis it uses
for _, i := range status.EmojiIDs {
if _, err := tx.Model(&gtsmodel.StatusToEmoji{
StatusID: status.ID,
EmojiID: i,
}).Insert(); err != nil {
return err
}
}
// create links between this status and any tags it uses
for _, i := range status.TagIDs {
if _, err := tx.Model(&gtsmodel.StatusToTag{
StatusID: status.ID,
TagID: i,
}).Insert(); err != nil {
return err
}
}
// change the status ID of the media attachments to the new status
for _, a := range status.Attachments {
a.StatusID = status.ID
a.UpdatedAt = time.Now()
if _, err := s.conn.Model(a).
Where("id = ?", a.ID).
Update(); err != nil {
return err
}
}
_, err := tx.Model(status).Insert()
return err
}
return processErrorResponse(s.conn.RunInTransaction(context.Background(), transaction))
}
func (s *statusDB) GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
parents := []*gtsmodel.Status{}
s.statusParent(status, &parents, onlyDirect)
return parents, nil
}
func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
if status.InReplyToID == "" {
return
}
parentStatus, err := s.GetStatusByID(status.InReplyToID)
if err == nil {
*foundStatuses = append(*foundStatuses, parentStatus)
}
if onlyDirect {
return
}
s.statusParent(parentStatus, foundStatuses, false)
}
func (s *statusDB) GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
foundStatuses := &list.List{}
foundStatuses.PushFront(status)
s.statusChildren(status, foundStatuses, onlyDirect, minID)
children := []*gtsmodel.Status{}
for e := foundStatuses.Front(); e != nil; e = e.Next() {
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
}
// only append children, not the overall parent status
if entry.ID != status.ID {
children = append(children, entry)
}
}
return children, nil
}
func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
immediateChildren := []*gtsmodel.Status{}
q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
if minID != "" {
q = q.Where("status.id > ?", minID)
}
if err := q.Select(); err != nil {
return
}
for _, child := range immediateChildren {
insertLoop:
for e := foundStatuses.Front(); e != nil; e = e.Next() {
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
}
if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
foundStatuses.InsertAfter(child, e)
break insertLoop
}
}
// only do one loop if we only want direct children
if onlyDirect {
return
}
s.statusChildren(child, foundStatuses, false, minID)
}
}
func (s *statusDB) CountStatusReplies(status *gtsmodel.Status) (int, db.Error) {
return s.conn.Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
}
func (s *statusDB) CountStatusReblogs(status *gtsmodel.Status) (int, db.Error) {
return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
}
func (s *statusDB) CountStatusFaves(status *gtsmodel.Status) (int, db.Error) {
return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
}
func (s *statusDB) IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (s *statusDB) IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (s *statusDB) IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (s *statusDB) IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (s *statusDB) GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
faves := []*gtsmodel.StatusFave{}
q := s.newFaveQ(&faves).
Where("status_id = ?", status.ID)
err := processErrorResponse(q.Select())
return faves, err
}
func (s *statusDB) GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
reblogs := []*gtsmodel.Status{}
q := s.newStatusQ(&reblogs).
Where("boost_of_id = ?", status.ID)
err := processErrorResponse(q.Select())
return reblogs, err
}

View file

@ -1,25 +0,0 @@
package pg
import (
"strings"
"github.com/go-pg/pg/v10"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
// processErrorResponse parses the given error and returns an appropriate DBError.
func processErrorResponse(err error) db.Error {
switch err {
case nil:
return nil
case pg.ErrNoRows:
return db.ErrNoEntries
case pg.ErrMultiRows:
return db.ErrMultipleEntries
default:
if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
}
}

View file

@ -18,54 +18,58 @@
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Relationship contains functions for getting or modifying the relationship between two accounts.
type Relationship interface {
// IsBlocked checks whether account 1 has a block in place against block2.
// If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1.
IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, Error)
IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, Error)
// GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't.
//
// Because this is slower than Blocked, only use it if you need the actual Block struct for some reason,
// not if you're just checking for the existence of a block.
GetBlock(account1 string, account2 string) (*gtsmodel.Block, Error)
GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, Error)
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
// IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
// In other words, it should create the follow, and delete the existing follow request.
//
// It will return the newly created follow for further processing.
AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
// GetAccountFollowRequests returns all follow requests targeting the given account.
GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, Error)
GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, Error)
// GetAccountFollows returns a slice of follows owned by the given accountID.
GetAccountFollows(accountID string) ([]*gtsmodel.Follow, Error)
GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, Error)
// CountAccountFollows returns the amount of accounts that the given accountID is following.
//
// If localOnly is set to true, then only follows from *this instance* will be returned.
CountAccountFollows(accountID string, localOnly bool) (int, Error)
CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, Error)
// GetAccountFollowedBy fetches follows that target given accountID.
//
// If localOnly is set to true, then only follows from *this instance* will be returned.
GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, Error)
GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, Error)
// CountAccountFollowedBy returns the amounts that the given ID is followed by.
CountAccountFollowedBy(accountID string, localOnly bool) (int, Error)
CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, Error)
}

31
internal/db/session.go Normal file
View file

@ -0,0 +1,31 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Session handles getting/creation of router sessions.
type Session interface {
GetSession(ctx context.Context) (*gtsmodel.RouterSession, Error)
CreateSession(ctx context.Context) (*gtsmodel.RouterSession, Error)
}

View file

@ -18,58 +18,62 @@
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses.
type Status interface {
// GetStatusByID returns one status from the database, with all rel fields populated (if possible).
GetStatusByID(id string) (*gtsmodel.Status, Error)
GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error)
// GetStatusByURI returns one status from the database, with all rel fields populated (if possible).
GetStatusByURI(uri string) (*gtsmodel.Status, Error)
GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error)
// GetStatusByURL returns one status from the database, with all rel fields populated (if possible).
GetStatusByURL(uri string) (*gtsmodel.Status, Error)
GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error)
// PutStatus stores one status in the database.
PutStatus(status *gtsmodel.Status) Error
PutStatus(ctx context.Context, status *gtsmodel.Status) Error
// CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong
CountStatusReplies(status *gtsmodel.Status) (int, Error)
CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, Error)
// CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong
CountStatusReblogs(status *gtsmodel.Status) (int, Error)
CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, Error)
// CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong
CountStatusFaves(status *gtsmodel.Status) (int, Error)
CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, Error)
// GetStatusParents gets the parent statuses of a given status.
//
// If onlyDirect is true, only the immediate parent will be returned.
GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error)
GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error)
// GetStatusChildren gets the child statuses of a given status.
//
// If onlyDirect is true, only the immediate children will be returned.
GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error)
GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error)
// IsStatusFavedBy checks if a given status has been faved by a given account ID
IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, Error)
IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID
IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, Error)
IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusMutedBy checks if a given status has been muted by a given account ID
IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, Error)
IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, Error)
IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, Error)
// GetStatusFaves returns a slice of faves/likes of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error)
GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error)
// GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, Error)
GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, Error)
}

View file

@ -18,20 +18,24 @@
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Timeline contains functionality for retrieving home/public/faved etc timelines for an account.
type Timeline interface {
// GetHomeTimeline returns a slice of statuses from accounts that are followed by the given account id.
//
// Statuses should be returned in descending order of when they were created (newest first).
GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
// GetPublicTimeline fetches the account's PUBLIC timeline -- ie., posts and replies that are public.
// It will use the given filters and try to return as many statuses as possible up to the limit.
//
// Statuses should be returned in descending order of when they were created (newest first).
GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
// GetFavedTimeline fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved.
// It will use the given filters and try to return as many statuses as possible up to the limit.
@ -40,5 +44,5 @@ type Timeline interface {
// In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created.
//
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
}