Add SQLite support, fix un-thread-safe DB caches, small performance f… (#172)

* Add SQLite support, fix un-thread-safe DB caches, small performance fixes

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* add SQLite licenses to README

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* appease the linter, and fix my dumbass-ery

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* make requested changes

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* add back comment

Signed-off-by: kim (grufwub) <grufwub@gmail.com>
This commit is contained in:
kim 2021-08-29 15:41:41 +01:00 committed by GitHub
commit ed46224573
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
730 changed files with 2239881 additions and 3669 deletions

View file

@ -25,7 +25,6 @@ import (
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -34,8 +33,7 @@ import (
type accountDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
@ -52,9 +50,11 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac
q := a.newAccountQ(account).
Where("account.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
return account, err
err := q.Scan(ctx)
if err != nil {
return nil, a.conn.ProcessError(err)
}
return account, nil
}
func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
@ -63,9 +63,11 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
q := a.newAccountQ(account).
Where("account.uri = ?", uri)
err := processErrorResponse(q.Scan(ctx))
return account, err
err := q.Scan(ctx)
if err != nil {
return nil, a.conn.ProcessError(err)
}
return account, nil
}
func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
@ -74,9 +76,11 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.
q := a.newAccountQ(account).
Where("account.url = ?", uri)
err := processErrorResponse(q.Scan(ctx))
return account, err
err := q.Scan(ctx)
if err != nil {
return nil, a.conn.ProcessError(err)
}
return account, nil
}
func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
@ -92,10 +96,10 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
WherePK()
_, err := q.Exec(ctx)
err = processErrorResponse(err)
return account, err
if err != nil {
return nil, a.conn.ProcessError(err)
}
return account, nil
}
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
@ -113,9 +117,11 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
WhereGroup(" AND ", whereEmptyOrNull("domain"))
}
err := processErrorResponse(q.Scan(ctx))
return account, err
err := q.Scan(ctx)
if err != nil {
return nil, a.conn.ProcessError(err)
}
return account, nil
}
func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, db.Error) {
@ -129,9 +135,11 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string)
Where("account_id = ?", accountID).
Column("created_at")
err := processErrorResponse(q.Scan(ctx))
return status.CreatedAt, err
err := q.Scan(ctx)
if err != nil {
return time.Time{}, a.conn.ProcessError(err)
}
return status.CreatedAt, nil
}
func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
@ -153,17 +161,17 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
NewInsert().
Model(mediaAttachment).
Exec(ctx); err != nil {
return err
return a.conn.ProcessError(err)
}
if _, err := a.conn.
NewUpdate().
Model(&gtsmodel.Account{}).
Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).
Where("id = ?", accountID).
Exec(ctx); err != nil {
return err
return a.conn.ProcessError(err)
}
return nil
}
@ -174,9 +182,11 @@ func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username stri
Where("username = ?", username).
WhereGroup(" AND ", whereEmptyOrNull("domain"))
err := processErrorResponse(q.Scan(ctx))
return account, err
err := q.Scan(ctx)
if err != nil {
return nil, a.conn.ProcessError(err)
}
return account, nil
}
func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) {
@ -187,8 +197,9 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
Model(faves).
Where("account_id = ?", accountID).
Scan(ctx); err != nil {
return nil, err
return nil, a.conn.ProcessError(err)
}
return *faves, nil
}
@ -201,7 +212,6 @@ func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string)
}
func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
a.log.Debugf("getting statuses for account %s", accountID)
statuses := []*gtsmodel.Status{}
q := a.conn.
@ -238,14 +248,13 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
}
if err := q.Scan(ctx); err != nil {
return nil, err
return nil, a.conn.ProcessError(err)
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries
}
a.log.Debugf("returning statuses for account %s", accountID)
return statuses, nil
}
@ -273,7 +282,7 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI
err := fq.Scan(ctx)
if err != nil {
return nil, "", "", err
return nil, "", "", a.conn.ProcessError(err)
}
if len(blocks) == 0 {

View file

@ -29,20 +29,17 @@ import (
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
"golang.org/x/crypto/bcrypt"
)
type adminDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
@ -52,7 +49,7 @@ func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (boo
Where("username = ?", username).
Where("domain = ?", nil)
return notExists(ctx, q)
return a.conn.NotExists(ctx, q)
}
func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) {
@ -72,7 +69,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
// fail because we found something
return false, fmt.Errorf("email domain %s is blocked", domain)
} else if err != sql.ErrNoRows {
return false, processErrorResponse(err)
return false, a.conn.ProcessError(err)
}
// check if this email is associated with a user already
@ -82,13 +79,13 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
Where("email = ?", email).
WhereOr("unconfirmed_email = ?", email)
return notExists(ctx, q)
return a.conn.NotExists(ctx, q)
}
func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
a.log.Errorf("error creating new rsa key: %s", err)
a.conn.log.Errorf("error creating new rsa key: %s", err)
return nil, err
}
@ -128,7 +125,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
NewInsert().
Model(acct).
Exec(ctx); err != nil {
return nil, err
return nil, a.conn.ProcessError(err)
}
}
@ -167,7 +164,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
NewInsert().
Model(u).
Exec(ctx); err != nil {
return nil, err
return nil, a.conn.ProcessError(err)
}
return u, nil
@ -184,15 +181,15 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
WhereGroup(" AND ", whereEmptyOrNull("domain"))
count, err := existsQ.Count(ctx)
if err != nil && count == 1 {
a.log.Infof("instance account %s already exists", username)
a.conn.log.Infof("instance account %s already exists", username)
return nil
} else if err != sql.ErrNoRows {
return processErrorResponse(err)
return a.conn.ProcessError(err)
}
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
a.log.Errorf("error creating new rsa key: %s", err)
a.conn.log.Errorf("error creating new rsa key: %s", err)
return err
}
@ -224,10 +221,10 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
Model(acct)
if _, err := insertQ.Exec(ctx); err != nil {
return err
return a.conn.ProcessError(err)
}
a.log.Infof("instance account %s CREATED with id %s", username, acct.ID)
a.conn.log.Infof("instance account %s CREATED with id %s", username, acct.ID)
return nil
}
@ -240,12 +237,12 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
Model(&gtsmodel.Instance{}).
Where("domain = ?", domain)
exists, err := exists(ctx, q)
exists, err := a.conn.Exists(ctx, q)
if err != nil {
return err
}
if exists {
a.log.Infof("instance entry already exists")
a.conn.log.Infof("instance entry already exists")
return nil
}
@ -266,10 +263,10 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
Model(i)
_, err = insertQ.Exec(ctx)
err = processErrorResponse(err)
if err == nil {
a.log.Infof("created instance instance %s with id %s", domain, i.ID)
if err != nil {
return a.conn.ProcessError(err)
}
return err
a.conn.log.Infof("created instance instance %s with id %s", domain, i.ID)
return nil
}

View file

@ -21,9 +21,7 @@ package bundb
import (
"context"
"errors"
"strings"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/uptrace/bun"
@ -31,16 +29,12 @@ import (
type basicDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error {
_, err := b.conn.NewInsert().Model(i).Exec(ctx)
if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
return b.conn.ProcessError(err)
}
func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Error {
@ -49,7 +43,8 @@ func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Erro
Model(i).
Where("id = ?", id)
return processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
return b.conn.ProcessError(err)
}
func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
@ -59,7 +54,6 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{})
q := b.conn.NewSelect().Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", bun.Ident(w.Key))
} else {
@ -71,7 +65,8 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{})
}
}
return processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
return b.conn.ProcessError(err)
}
func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
@ -79,7 +74,8 @@ func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
NewSelect().
Model(i)
return processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
return b.conn.ProcessError(err)
}
func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error {
@ -89,8 +85,7 @@ func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.E
Where("id = ?", id)
_, err := q.Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
@ -107,8 +102,7 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface
}
_, err := q.Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error {
@ -118,8 +112,7 @@ func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.E
WherePK()
_, err := q.Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error {
@ -129,8 +122,7 @@ func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, valu
WherePK()
_, err := q.Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error {
@ -151,8 +143,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string,
q = q.Set("? = ?", bun.Safe(key), value)
_, err := q.Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
@ -162,7 +153,7 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {
_, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
@ -170,10 +161,6 @@ func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
}
func (b *basicDB) Stop(ctx context.Context) db.Error {
b.log.Info("closing db connection")
if err := b.conn.Close(); err != nil {
// only cancel if there's a problem closing the db
return err
}
return nil
b.conn.log.Info("closing db connection")
return b.conn.Close()
}

View file

@ -30,15 +30,19 @@ import (
"strings"
"time"
"github.com/ReneKroon/ttlcache"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/dialect/sqlitedialect"
_ "modernc.org/sqlite"
)
const (
@ -66,15 +70,14 @@ type bunDBService struct {
db.Status
db.Timeline
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
var sqldb *sql.DB
var conn *bun.DB
var conn *DBConn
// depending on the database type we're trying to create, we need to use a different driver...
switch strings.ToLower(c.DBConfig.Type) {
@ -85,10 +88,24 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
}
sqldb = stdlib.OpenDB(*opts)
conn = bun.NewDB(sqldb, pgdialect.New())
conn = WrapDBConn(bun.NewDB(sqldb, pgdialect.New()), log)
case dbTypeSqlite:
// SQLITE
// TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite
var err error
sqldb, err = sql.Open("sqlite", c.DBConfig.Address)
if err != nil {
return nil, fmt.Errorf("could not open sqlite db: %s", err)
}
conn = WrapDBConn(bun.NewDB(sqldb, sqlitedialect.New()), log)
if strings.HasPrefix(strings.TrimPrefix(c.DBConfig.Address, "file:"), ":memory:") {
log.Warn("sqlite in-memory database should only be used for debugging")
// don't close connections on disconnect -- otherwise
// the SQLite database will be deleted when there
// are no active connections
sqldb.SetConnMaxLifetime(0)
}
default:
return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type))
}
@ -108,66 +125,56 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
Account: &accountDB{
config: c,
conn: conn,
log: log,
},
Admin: &adminDB{
config: c,
conn: conn,
log: log,
},
Basic: &basicDB{
config: c,
conn: conn,
log: log,
},
Domain: &domainDB{
config: c,
conn: conn,
log: log,
},
Instance: &instanceDB{
config: c,
conn: conn,
log: log,
},
Media: &mediaDB{
config: c,
conn: conn,
log: log,
},
Mention: &mentionDB{
config: c,
conn: conn,
log: log,
cache: ttlcache.NewCache(),
},
Notification: &notificationDB{
config: c,
conn: conn,
log: log,
cache: ttlcache.NewCache(),
},
Relationship: &relationshipDB{
config: c,
conn: conn,
log: log,
},
Session: &sessionDB{
config: c,
conn: conn,
log: log,
},
Status: &statusDB{
config: c,
conn: conn,
log: log,
cache: cache.NewStatusCache(),
},
Timeline: &timelineDB{
config: c,
conn: conn,
log: log,
},
config: c,
conn: conn,
log: log,
}
// we can confidently return this useable service now
@ -332,7 +339,7 @@ func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAcco
if err != nil {
if err == sql.ErrNoRows {
// no result found for this username/domain so just don't include it as a mencho and carry on about our business
ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain)
ps.conn.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain)
continue
}
// a serious error has happened so bail
@ -398,7 +405,7 @@ func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []strin
if err != nil {
if err == sql.ErrNoRows {
// no result found for this username/domain so just don't include it as an emoji and carry on about our business
ps.log.Debugf("no emoji found with shortcode %s, skipping it", e)
ps.conn.log.Debugf("no emoji found with shortcode %s, skipping it", e)
continue
}
// a serious error has happened so bail

72
internal/db/bundb/conn.go Normal file
View file

@ -0,0 +1,72 @@
package bundb
import (
"context"
"database/sql"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect"
)
// dbConn wrapps a bun.DB conn to provide SQL-type specific additional functionality
type DBConn struct {
errProc func(error) db.Error // errProc is the SQL-type specific error processor
log *logrus.Logger // log is the logger passed with this DBConn
*bun.DB // DB is the underlying bun.DB connection
}
// WrapDBConn @TODO
func WrapDBConn(dbConn *bun.DB, log *logrus.Logger) *DBConn {
var errProc func(error) db.Error
switch dbConn.Dialect().Name() {
case dialect.PG:
errProc = processPostgresError
case dialect.SQLite:
errProc = processSQLiteError
default:
panic("unknown dialect name: " + dbConn.Dialect().Name().String())
}
return &DBConn{
errProc: errProc,
log: log,
DB: dbConn,
}
}
// ProcessError processes an error to replace any known values with our own db.Error types,
// making it easier to catch specific situations (e.g. no rows, already exists, etc)
func (conn *DBConn) ProcessError(err error) db.Error {
switch {
case err == nil:
return nil
case err == sql.ErrNoRows:
return db.ErrNoEntries
default:
return conn.errProc(err)
}
}
// Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors
func (conn *DBConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
// Get the select query result
count, err := query.Count(ctx)
// Process error as our own and check if it exists
switch err := conn.ProcessError(err); err {
case nil:
return (count != 0), nil
case db.ErrNoEntries:
return false, nil
default:
return false, err
}
}
// NotExists is the functional opposite of conn.Exists()
func (conn *DBConn) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
// Simply inverse of conn.exists()
exists, err := conn.Exists(ctx, query)
return !exists, err
}

View file

@ -22,18 +22,15 @@ import (
"context"
"net/url"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
type domainDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) {
@ -47,7 +44,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db
Where("LOWER(domain) = LOWER(?)", domain).
Limit(1)
return exists(ctx, q)
return d.conn.Exists(ctx, q)
}
func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {

View file

@ -0,0 +1,43 @@
package bundb
import (
"github.com/jackc/pgconn"
"github.com/superseriousbusiness/gotosocial/internal/db"
"modernc.org/sqlite"
sqlite3 "modernc.org/sqlite/lib"
)
// processPostgresError processes an error, replacing any postgres specific errors with our own error type
func processPostgresError(err error) db.Error {
// Attempt to cast as postgres
pgErr, ok := err.(*pgconn.PgError)
if !ok {
return err
}
// Handle supplied error code:
// (https://www.postgresql.org/docs/10/errcodes-appendix.html)
switch pgErr.Code {
case "23505" /* unique_violation */ :
return db.ErrAlreadyExists
default:
return err
}
}
// processSQLiteError processes an error, replacing any sqlite specific errors with our own error type
func processSQLiteError(err error) db.Error {
// Attempt to cast as sqlite
sqliteErr, ok := err.(*sqlite.Error)
if !ok {
return err
}
// Handle supplied error code:
switch sqliteErr.Code() {
case sqlite3.SQLITE_CONSTRAINT_UNIQUE:
return db.ErrAlreadyExists
default:
return err
}
}

View file

@ -21,7 +21,6 @@ package bundb
import (
"context"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -30,8 +29,7 @@ import (
type instanceDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
@ -49,8 +47,10 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int
}
count, err := q.Count(ctx)
return count, processErrorResponse(err)
if err != nil {
return 0, i.conn.ProcessError(err)
}
return count, nil
}
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
@ -68,8 +68,10 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (
}
count, err := q.Count(ctx)
return count, processErrorResponse(err)
if err != nil {
return 0, i.conn.ProcessError(err)
}
return count, nil
}
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
@ -89,12 +91,14 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i
}
count, err := q.Count(ctx)
return count, processErrorResponse(err)
if err != nil {
return 0, i.conn.ProcessError(err)
}
return count, nil
}
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
i.log.Debug("GetAccountsForInstance")
i.conn.log.Debug("GetAccountsForInstance")
accounts := []*gtsmodel.Account{}
@ -111,7 +115,9 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max
q = q.Limit(limit)
}
err := processErrorResponse(q.Scan(ctx))
return accounts, err
err := q.Scan(ctx)
if err != nil {
return nil, i.conn.ProcessError(err)
}
return accounts, nil
}

View file

@ -21,7 +21,6 @@ package bundb
import (
"context"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -30,8 +29,7 @@ import (
type mediaDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (m *mediaDB) newMediaQ(i interface{}) *bun.SelectQuery {
@ -47,7 +45,9 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
q := m.newMediaQ(attachment).
Where("media_attachment.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
return attachment, err
err := q.Scan(ctx)
if err != nil {
return nil, m.conn.ProcessError(err)
}
return attachment, nil
}

View file

@ -21,8 +21,7 @@ package bundb
import (
"context"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/ReneKroon/ttlcache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -31,38 +30,8 @@ import (
type mentionDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cache cache.Cache
}
func (m *mentionDB) cacheMention(id string, mention *gtsmodel.Mention) {
if m.cache == nil {
m.cache = cache.New()
}
if err := m.cache.Store(id, mention); err != nil {
m.log.Panicf("mentionDB: error storing in cache: %s", err)
}
}
func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) {
if m.cache == nil {
m.cache = cache.New()
return nil, false
}
mI, err := m.cache.Fetch(id)
if err != nil || mI == nil {
return nil, false
}
mention, ok := mI.(*gtsmodel.Mention)
if !ok {
m.log.Panicf("mentionDB: cached interface with key %s was not a mention", id)
}
return mention, true
conn *DBConn
cache *ttlcache.Cache
}
func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
@ -74,33 +43,57 @@ func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
Relation("TargetAccount")
}
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
if mention, cached := m.mentionCached(id); cached {
return mention, nil
func (m *mentionDB) getMentionCached(id string) (*gtsmodel.Mention, bool) {
v, ok := m.cache.Get(id)
if !ok {
return nil, false
}
return v.(*gtsmodel.Mention), true
}
func (m *mentionDB) putMentionCache(mention *gtsmodel.Mention) {
m.cache.Set(mention.ID, mention)
}
func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
mention := &gtsmodel.Mention{}
q := m.newMentionQ(mention).
Where("mention.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
if err == nil && mention != nil {
m.cacheMention(id, mention)
err := q.Scan(ctx)
if err != nil {
return nil, m.conn.ProcessError(err)
}
return mention, err
m.putMentionCache(mention)
return mention, nil
}
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
if mention, cached := m.getMentionCached(id); cached {
return mention, nil
}
return m.getMentionDB(ctx, id)
}
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) {
mentions := []*gtsmodel.Mention{}
mentions := make([]*gtsmodel.Mention, 0, len(ids))
for _, i := range ids {
mention, err := m.GetMention(ctx, i)
if err != nil {
return nil, processErrorResponse(err)
for _, id := range ids {
// Attempt fetch from cache
mention, cached := m.getMentionCached(id)
if cached {
mentions = append(mentions, mention)
}
// Attempt fetch from DB
mention, err := m.getMentionDB(ctx, id)
if err != nil {
return nil, err
}
// Append mention
mentions = append(mentions, mention)
}

View file

@ -21,8 +21,7 @@ package bundb
import (
"context"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/ReneKroon/ttlcache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -31,38 +30,8 @@ import (
type notificationDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cache cache.Cache
}
func (n *notificationDB) cacheNotification(id string, notification *gtsmodel.Notification) {
if n.cache == nil {
n.cache = cache.New()
}
if err := n.cache.Store(id, notification); err != nil {
n.log.Panicf("notificationDB: error storing in cache: %s", err)
}
}
func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification, bool) {
if n.cache == nil {
n.cache = cache.New()
return nil, false
}
nI, err := n.cache.Fetch(id)
if err != nil || nI == nil {
return nil, false
}
notification, ok := nI.(*gtsmodel.Notification)
if !ok {
n.log.Panicf("notificationDB: cached interface with key %s was not a notification", id)
}
return notification, true
conn *DBConn
cache *ttlcache.Cache
}
func (n *notificationDB) newNotificationQ(i interface{}) *bun.SelectQuery {
@ -75,30 +44,30 @@ func (n *notificationDB) newNotificationQ(i interface{}) *bun.SelectQuery {
}
func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
if notification, cached := n.notificationCached(id); cached {
if notification, cached := n.getNotificationCache(id); cached {
return notification, nil
}
notification := &gtsmodel.Notification{}
q := n.newNotificationQ(notification).
Where("notification.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
if err == nil && notification != nil {
n.cacheNotification(id, notification)
notif := &gtsmodel.Notification{}
err := n.getNotificationDB(ctx, id, notif)
if err != nil {
return nil, err
}
return notification, err
return notif, nil
}
func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
// begin by selecting just the IDs
notifIDs := []*gtsmodel.Notification{}
// Ensure reasonable
if limit < 0 {
limit = 0
}
// Make a guess for slice size
notifications := make([]*gtsmodel.Notification, 0, limit)
q := n.conn.
NewSelect().
Model(&notifIDs).
Model(&notifications).
Column("id").
Where("target_account_id = ?", accountID).
Order("id DESC")
@ -115,22 +84,52 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
q = q.Limit(limit)
}
err := processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
if err != nil {
return nil, err
return nil, n.conn.ProcessError(err)
}
// now we have the IDs, select the notifs one by one
// reason for this is that for each notif, we can instead get it from our cache if it's cached
notifications := []*gtsmodel.Notification{}
for _, notifID := range notifIDs {
notif, err := n.GetNotification(ctx, notifID.ID)
errP := processErrorResponse(err)
if errP != nil {
return nil, errP
for i, notif := range notifications {
// Check cache for notification
nn, cached := n.getNotificationCache(notif.ID)
if cached {
notifications[i] = nn
continue
}
// Check DB for notification
err := n.getNotificationDB(ctx, notif.ID, notif)
if err != nil {
return nil, err
}
notifications = append(notifications, notif)
}
return notifications, nil
}
func (n *notificationDB) getNotificationCache(id string) (*gtsmodel.Notification, bool) {
v, ok := n.cache.Get(id)
if !ok {
return nil, false
}
return v.(*gtsmodel.Notification), true
}
func (n *notificationDB) putNotificationCache(notif *gtsmodel.Notification) {
n.cache.Set(notif.ID, notif)
}
func (n *notificationDB) getNotificationDB(ctx context.Context, id string, dst *gtsmodel.Notification) error {
q := n.newNotificationQ(dst).
Where("notification.id = ?", id)
err := q.Scan(ctx)
if err != nil {
return n.conn.ProcessError(err)
}
n.putNotificationCache(dst)
return nil
}

View file

@ -23,7 +23,6 @@ import (
"database/sql"
"fmt"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -32,8 +31,7 @@ import (
type relationshipDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery {
@ -66,7 +64,7 @@ func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account
Where("account_id = ?", account2)
}
return exists(ctx, q)
return r.conn.Exists(ctx, q)
}
func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
@ -76,9 +74,11 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2
Where("block.account_id = ?", account1).
Where("block.target_account_id = ?", account2)
err := processErrorResponse(q.Scan(ctx))
return block, err
err := q.Scan(ctx)
if err != nil {
return nil, r.conn.ProcessError(err)
}
return block, nil
}
func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
@ -176,7 +176,7 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode
Where("target_account_id = ?", targetAccount.ID).
Limit(1)
return exists(ctx, q)
return r.conn.Exists(ctx, q)
}
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
@ -190,7 +190,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g
Where("account_id = ?", sourceAccount.ID).
Where("target_account_id = ?", targetAccount.ID)
return exists(ctx, q)
return r.conn.Exists(ctx, q)
}
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
@ -201,13 +201,13 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod
// make sure account 1 follows account 2
f1, err := r.IsFollowing(ctx, account1, account2)
if err != nil {
return false, processErrorResponse(err)
return false, err
}
// make sure account 2 follows account 1
f2, err := r.IsFollowing(ctx, account2, account1)
if err != nil {
return false, processErrorResponse(err)
return false, err
}
return f1 && f2, nil
@ -222,7 +222,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
Where("account_id = ?", originAccountID).
Where("target_account_id = ?", targetAccountID).
Scan(ctx); err != nil {
return nil, processErrorResponse(err)
return nil, r.conn.ProcessError(err)
}
// create a new follow to 'replace' the request with
@ -239,7 +239,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
Model(follow).
On("CONFLICT ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).
Exec(ctx); err != nil {
return nil, processErrorResponse(err)
return nil, r.conn.ProcessError(err)
}
// now remove the follow request
@ -249,7 +249,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
Where("account_id = ?", originAccountID).
Where("target_account_id = ?", targetAccountID).
Exec(ctx); err != nil {
return nil, processErrorResponse(err)
return nil, r.conn.ProcessError(err)
}
return follow, nil
@ -261,9 +261,11 @@ func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID
q := r.newFollowQ(&followRequests).
Where("target_account_id = ?", accountID)
err := processErrorResponse(q.Scan(ctx))
return followRequests, err
err := q.Scan(ctx)
if err != nil {
return nil, r.conn.ProcessError(err)
}
return followRequests, nil
}
func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, db.Error) {
@ -272,9 +274,11 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string
q := r.newFollowQ(&follows).
Where("account_id = ?", accountID)
err := processErrorResponse(q.Scan(ctx))
return follows, err
err := q.Scan(ctx)
if err != nil {
return nil, r.conn.ProcessError(err)
}
return follows, nil
}
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
@ -286,7 +290,6 @@ func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID stri
}
func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
follows := []*gtsmodel.Follow{}
q := r.conn.
@ -302,11 +305,9 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
q = q.Where("target_account_id = ?", accountID)
}
if err := q.Scan(ctx); err != nil {
if err == sql.ErrNoRows {
return follows, nil
}
return nil, processErrorResponse(err)
err := q.Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, r.conn.ProcessError(err)
}
return follows, nil
}

View file

@ -23,22 +23,19 @@ import (
"crypto/rand"
"errors"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/uptrace/bun"
)
type sessionDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
rss := []*gtsmodel.RouterSession{}
rss := make([]*gtsmodel.RouterSession, 0, 1)
_, err := s.conn.
NewSelect().
@ -47,7 +44,7 @@ func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db
Order("id DESC").
Exec(ctx)
if err != nil {
return nil, processErrorResponse(err)
return nil, s.conn.ProcessError(err)
}
if len(rss) <= 0 {
@ -92,8 +89,8 @@ func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession,
Model(rs)
_, err = q.Exec(ctx)
err = processErrorResponse(err)
return rs, err
if err != nil {
return nil, s.conn.ProcessError(err)
}
return rs, nil
}

Binary file not shown.

View file

@ -24,7 +24,6 @@ import (
"errors"
"time"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
@ -34,38 +33,8 @@ import (
type statusDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cache cache.Cache
}
func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
if s.cache == nil {
s.cache = cache.New()
}
if err := s.cache.Store(id, status); err != nil {
s.log.Panicf("statusDB: error storing in cache: %s", err)
}
}
func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
if s.cache == nil {
s.cache = cache.New()
return nil, false
}
sI, err := s.cache.Fetch(id)
if err != nil || sI == nil {
return nil, false
}
status, ok := sI.(*gtsmodel.Status)
if !ok {
s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
}
return status, true
conn *DBConn
cache *cache.StatusCache
}
func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
@ -84,7 +53,9 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status {
if status.InReplyToID != "" && status.InReplyTo == nil {
if inReplyTo, cached := s.statusCached(status.InReplyToID); cached {
// TODO: do we want to keep this possibly recursive strategy?
if inReplyTo, cached := s.cache.GetByID(status.InReplyToID); cached {
status.InReplyTo = inReplyTo
} else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil {
status.InReplyTo = inReplyTo
@ -92,7 +63,9 @@ func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Sta
}
if status.BoostOfID != "" && status.BoostOf == nil {
if boostOf, cached := s.statusCached(status.BoostOfID); cached {
// TODO: do we want to keep this possibly recursive strategy?
if boostOf, cached := s.cache.GetByID(status.BoostOfID); cached {
status.BoostOf = boostOf
} else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil {
status.BoostOf = boostOf
@ -112,29 +85,26 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
}
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(id); cached {
if status, cached := s.cache.GetByID(id); cached {
return status, nil
}
status := new(gtsmodel.Status)
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("status.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
if err != nil {
return nil, err
return nil, s.conn.ProcessError(err)
}
if status != nil {
s.cacheStatus(id, status)
}
return s.getAttachedStatuses(ctx, status), err
s.cache.Put(status)
return s.getAttachedStatuses(ctx, status), nil
}
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
if status, cached := s.cache.GetByURI(uri); cached {
return status, nil
}
@ -143,38 +113,32 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St
q := s.newStatusQ(status).
Where("LOWER(status.uri) = LOWER(?)", uri)
err := processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
if err != nil {
return nil, err
return nil, s.conn.ProcessError(err)
}
if status != nil {
s.cacheStatus(uri, status)
}
return s.getAttachedStatuses(ctx, status), err
s.cache.Put(status)
return s.getAttachedStatuses(ctx, status), nil
}
func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) {
if status, cached := s.cache.GetByURL(url); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("LOWER(status.url) = LOWER(?)", uri)
Where("LOWER(status.url) = LOWER(?)", url)
err := processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
if err != nil {
return nil, err
return nil, s.conn.ProcessError(err)
}
if status != nil {
s.cacheStatus(uri, status)
}
return s.getAttachedStatuses(ctx, status), err
s.cache.Put(status)
return s.getAttachedStatuses(ctx, status), nil
}
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
@ -213,14 +177,12 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
_, err := tx.NewInsert().Model(status).Exec(ctx)
return err
}
return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction))
return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction))
}
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
parents := []*gtsmodel.Status{}
s.statusParent(ctx, status, &parents, onlyDirect)
return parents, nil
}
@ -318,7 +280,7 @@ func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status,
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
return s.conn.Exists(ctx, q)
}
func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
@ -328,7 +290,7 @@ func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Sta
Where("boost_of_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
return s.conn.Exists(ctx, q)
}
func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
@ -338,7 +300,7 @@ func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status,
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
return s.conn.Exists(ctx, q)
}
func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
@ -348,7 +310,7 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
return s.conn.Exists(ctx, q)
}
func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
@ -357,8 +319,11 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status)
q := s.newFaveQ(&faves).
Where("status_id = ?", status.ID)
err := processErrorResponse(q.Scan(ctx))
return faves, err
err := q.Scan(ctx)
if err != nil {
return nil, s.conn.ProcessError(err)
}
return faves, nil
}
func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
@ -367,6 +332,9 @@ func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status
q := s.newStatusQ(&reblogs).
Where("boost_of_id = ?", status.ID)
err := processErrorResponse(q.Scan(ctx))
return reblogs, err
err := q.Scan(ctx)
if err != nil {
return nil, s.conn.ProcessError(err)
}
return reblogs, nil
}

View file

@ -59,7 +59,6 @@ func (suite *StatusTestSuite) TearDownTest() {
func (suite *StatusTestSuite) TestGetStatusByID() {
status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_1_status_1"].ID)
if err != nil {
fmt.Println(err.Error())
suite.FailNow(err.Error())
}
suite.NotNil(status)

View file

@ -23,7 +23,6 @@ import (
"database/sql"
"sort"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -32,12 +31,18 @@ import (
type timelineDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
// Ensure reasonable
if limit < 0 {
limit = 0
}
// Make educated guess for slice size
statuses := make([]*gtsmodel.Status, 0, limit)
q := t.conn.
NewSelect().
Model(&statuses)
@ -86,11 +91,21 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
q = q.WhereGroup(" AND ", whereGroup)
return statuses, processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
if err != nil {
return nil, t.conn.ProcessError(err)
}
return statuses, nil
}
func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
// Ensure reasonable
if limit < 0 {
limit = 0
}
// Make educated guess for slice size
statuses := make([]*gtsmodel.Status, 0, limit)
q := t.conn.
NewSelect().
@ -121,14 +136,23 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, ma
q = q.Limit(limit)
}
return statuses, processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
if err != nil {
return nil, t.conn.ProcessError(err)
}
return statuses, nil
}
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
// Ensure reasonable
if limit < 0 {
limit = 0
}
faves := []*gtsmodel.StatusFave{}
// Make educated guess for slice size
faves := make([]*gtsmodel.StatusFave, 0, limit)
fq := t.conn.
NewSelect().
@ -160,26 +184,23 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max
return nil, "", "", db.ErrNoEntries
}
// map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID
statusesFavesMap := map[string]string{}
in := []string{}
// map[statusID]faveID -- we need this to sort statuses by fave ID rather than status ID
statusesFavesMap := make(map[string]string, len(faves))
statusIDs := make([]string, 0, len(faves))
for _, f := range faves {
statusesFavesMap[f.StatusID] = f.ID
in = append(in, f.StatusID)
statusIDs = append(statusIDs, f.StatusID)
}
statuses := []*gtsmodel.Status{}
statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
err = t.conn.
NewSelect().
Model(&statuses).
Where("id IN (?)", bun.In(in)).
Where("id IN (?)", bun.In(statusIDs)).
Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
return nil, "", "", t.conn.ProcessError(err)
}
if len(statuses) == 0 {

View file

@ -19,64 +19,9 @@
package bundb
import (
"context"
"strings"
"database/sql"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/uptrace/bun"
)
// processErrorResponse parses the given error and returns an appropriate DBError.
func processErrorResponse(err error) db.Error {
switch err {
case nil:
return nil
case sql.ErrNoRows:
return db.ErrNoEntries
default:
if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
}
}
func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
count, err := q.Count(ctx)
exists := count != 0
err = processErrorResponse(err)
if err != nil {
if err == db.ErrNoEntries {
return false, nil
}
return false, err
}
return exists, nil
}
func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
count, err := q.Count(ctx)
notExists := count == 0
err = processErrorResponse(err)
if err != nil {
if err == db.ErrNoEntries {
return true, nil
}
return false, err
}
return notExists, nil
}
// whereEmptyOrNull is a convenience function to return a bun WhereGroup that specifies
// that the given column should be EITHER an empty string OR null.
//