Add SQLite support, fix un-thread-safe DB caches, small performance f… (#172)

* Add SQLite support, fix un-thread-safe DB caches, small performance fixes

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* add SQLite licenses to README

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* appease the linter, and fix my dumbass-ery

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* make requested changes

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* add back comment

Signed-off-by: kim (grufwub) <grufwub@gmail.com>
This commit is contained in:
kim 2021-08-29 15:41:41 +01:00 committed by GitHub
commit ed46224573
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
730 changed files with 2239881 additions and 3669 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

106
internal/cache/status.go vendored Normal file
View file

@ -0,0 +1,106 @@
package cache
import (
"sync"
"github.com/ReneKroon/ttlcache"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// statusCache is a wrapper around ttlcache.Cache to provide URL and URI lookups for gtsmodel.Status
type StatusCache struct {
cache *ttlcache.Cache // map of IDs -> cached statuses
urls map[string]string // map of status URLs -> IDs
uris map[string]string // map of status URIs -> IDs
mutex sync.Mutex
}
// newStatusCache returns a new instantiated statusCache object
func NewStatusCache() *StatusCache {
c := StatusCache{
cache: ttlcache.NewCache(),
urls: make(map[string]string, 100),
uris: make(map[string]string, 100),
mutex: sync.Mutex{},
}
// Set callback to purge lookup maps on expiration
c.cache.SetExpirationCallback(func(key string, value interface{}) {
status := value.(*gtsmodel.Status)
c.mutex.Lock()
delete(c.urls, status.URL)
delete(c.uris, status.URI)
c.mutex.Unlock()
})
return &c
}
// GetByID attempts to fetch a status from the cache by its ID
func (c *StatusCache) GetByID(id string) (*gtsmodel.Status, bool) {
c.mutex.Lock()
status, ok := c.getByID(id)
c.mutex.Unlock()
return status, ok
}
// GetByURL attempts to fetch a status from the cache by its URL
func (c *StatusCache) GetByURL(url string) (*gtsmodel.Status, bool) {
// Perform safe ID lookup
c.mutex.Lock()
id, ok := c.urls[url]
// Not found, unlock early
if !ok {
c.mutex.Unlock()
return nil, false
}
// Attempt status lookup
status, ok := c.getByID(id)
c.mutex.Unlock()
return status, ok
}
// GetByURI attempts to fetch a status from the cache by its URI
func (c *StatusCache) GetByURI(uri string) (*gtsmodel.Status, bool) {
// Perform safe ID lookup
c.mutex.Lock()
id, ok := c.uris[uri]
// Not found, unlock early
if !ok {
c.mutex.Unlock()
return nil, false
}
// Attempt status lookup
status, ok := c.getByID(id)
c.mutex.Unlock()
return status, ok
}
// getByID performs an unsafe (no mutex locks) lookup of status by ID
func (c *StatusCache) getByID(id string) (*gtsmodel.Status, bool) {
v, ok := c.cache.Get(id)
if !ok {
return nil, false
}
return v.(*gtsmodel.Status), true
}
// Put places a status in the cache
func (c *StatusCache) Put(status *gtsmodel.Status) {
if status == nil || status.ID == "" ||
status.URL == "" ||
status.URI == "" {
panic("invalid status")
}
c.mutex.Lock()
c.cache.Set(status.ID, status)
c.urls[status.URL] = status.ID
c.uris[status.URI] = status.ID
c.mutex.Unlock()
}

41
internal/cache/status_test.go vendored Normal file
View file

@ -0,0 +1,41 @@
package cache_test
import (
"testing"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func TestStatusCache(t *testing.T) {
cache := cache.NewStatusCache()
// Attempt to place a status
status := gtsmodel.Status{
ID: "id",
URI: "uri",
URL: "url",
}
cache.Put(&status)
var ok bool
var check *gtsmodel.Status
// Check we can retrieve
check, ok = cache.GetByID(status.ID)
if !ok || !statusIs(&status, check) {
t.Fatal("Could not find expected status")
}
check, ok = cache.GetByURI(status.URI)
if !ok || !statusIs(&status, check) {
t.Fatal("Could not find expected status")
}
check, ok = cache.GetByURL(status.URL)
if !ok || !statusIs(&status, check) {
t.Fatal("Could not find expected status")
}
}
func statusIs(status1, status2 *gtsmodel.Status) bool {
return status1.ID == status2.ID && status1.URI == status2.URI && status1.URL == status2.URL
}

View file

@ -25,7 +25,6 @@ import (
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -34,8 +33,7 @@ import (
type accountDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
@ -52,9 +50,11 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac
q := a.newAccountQ(account).
Where("account.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
return account, err
err := q.Scan(ctx)
if err != nil {
return nil, a.conn.ProcessError(err)
}
return account, nil
}
func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
@ -63,9 +63,11 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
q := a.newAccountQ(account).
Where("account.uri = ?", uri)
err := processErrorResponse(q.Scan(ctx))
return account, err
err := q.Scan(ctx)
if err != nil {
return nil, a.conn.ProcessError(err)
}
return account, nil
}
func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
@ -74,9 +76,11 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.
q := a.newAccountQ(account).
Where("account.url = ?", uri)
err := processErrorResponse(q.Scan(ctx))
return account, err
err := q.Scan(ctx)
if err != nil {
return nil, a.conn.ProcessError(err)
}
return account, nil
}
func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
@ -92,10 +96,10 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
WherePK()
_, err := q.Exec(ctx)
err = processErrorResponse(err)
return account, err
if err != nil {
return nil, a.conn.ProcessError(err)
}
return account, nil
}
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
@ -113,9 +117,11 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
WhereGroup(" AND ", whereEmptyOrNull("domain"))
}
err := processErrorResponse(q.Scan(ctx))
return account, err
err := q.Scan(ctx)
if err != nil {
return nil, a.conn.ProcessError(err)
}
return account, nil
}
func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, db.Error) {
@ -129,9 +135,11 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string)
Where("account_id = ?", accountID).
Column("created_at")
err := processErrorResponse(q.Scan(ctx))
return status.CreatedAt, err
err := q.Scan(ctx)
if err != nil {
return time.Time{}, a.conn.ProcessError(err)
}
return status.CreatedAt, nil
}
func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
@ -153,17 +161,17 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
NewInsert().
Model(mediaAttachment).
Exec(ctx); err != nil {
return err
return a.conn.ProcessError(err)
}
if _, err := a.conn.
NewUpdate().
Model(&gtsmodel.Account{}).
Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).
Where("id = ?", accountID).
Exec(ctx); err != nil {
return err
return a.conn.ProcessError(err)
}
return nil
}
@ -174,9 +182,11 @@ func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username stri
Where("username = ?", username).
WhereGroup(" AND ", whereEmptyOrNull("domain"))
err := processErrorResponse(q.Scan(ctx))
return account, err
err := q.Scan(ctx)
if err != nil {
return nil, a.conn.ProcessError(err)
}
return account, nil
}
func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) {
@ -187,8 +197,9 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
Model(faves).
Where("account_id = ?", accountID).
Scan(ctx); err != nil {
return nil, err
return nil, a.conn.ProcessError(err)
}
return *faves, nil
}
@ -201,7 +212,6 @@ func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string)
}
func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
a.log.Debugf("getting statuses for account %s", accountID)
statuses := []*gtsmodel.Status{}
q := a.conn.
@ -238,14 +248,13 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
}
if err := q.Scan(ctx); err != nil {
return nil, err
return nil, a.conn.ProcessError(err)
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries
}
a.log.Debugf("returning statuses for account %s", accountID)
return statuses, nil
}
@ -273,7 +282,7 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI
err := fq.Scan(ctx)
if err != nil {
return nil, "", "", err
return nil, "", "", a.conn.ProcessError(err)
}
if len(blocks) == 0 {

View file

@ -29,20 +29,17 @@ import (
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
"golang.org/x/crypto/bcrypt"
)
type adminDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
@ -52,7 +49,7 @@ func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (boo
Where("username = ?", username).
Where("domain = ?", nil)
return notExists(ctx, q)
return a.conn.NotExists(ctx, q)
}
func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) {
@ -72,7 +69,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
// fail because we found something
return false, fmt.Errorf("email domain %s is blocked", domain)
} else if err != sql.ErrNoRows {
return false, processErrorResponse(err)
return false, a.conn.ProcessError(err)
}
// check if this email is associated with a user already
@ -82,13 +79,13 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
Where("email = ?", email).
WhereOr("unconfirmed_email = ?", email)
return notExists(ctx, q)
return a.conn.NotExists(ctx, q)
}
func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
a.log.Errorf("error creating new rsa key: %s", err)
a.conn.log.Errorf("error creating new rsa key: %s", err)
return nil, err
}
@ -128,7 +125,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
NewInsert().
Model(acct).
Exec(ctx); err != nil {
return nil, err
return nil, a.conn.ProcessError(err)
}
}
@ -167,7 +164,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
NewInsert().
Model(u).
Exec(ctx); err != nil {
return nil, err
return nil, a.conn.ProcessError(err)
}
return u, nil
@ -184,15 +181,15 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
WhereGroup(" AND ", whereEmptyOrNull("domain"))
count, err := existsQ.Count(ctx)
if err != nil && count == 1 {
a.log.Infof("instance account %s already exists", username)
a.conn.log.Infof("instance account %s already exists", username)
return nil
} else if err != sql.ErrNoRows {
return processErrorResponse(err)
return a.conn.ProcessError(err)
}
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
a.log.Errorf("error creating new rsa key: %s", err)
a.conn.log.Errorf("error creating new rsa key: %s", err)
return err
}
@ -224,10 +221,10 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
Model(acct)
if _, err := insertQ.Exec(ctx); err != nil {
return err
return a.conn.ProcessError(err)
}
a.log.Infof("instance account %s CREATED with id %s", username, acct.ID)
a.conn.log.Infof("instance account %s CREATED with id %s", username, acct.ID)
return nil
}
@ -240,12 +237,12 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
Model(&gtsmodel.Instance{}).
Where("domain = ?", domain)
exists, err := exists(ctx, q)
exists, err := a.conn.Exists(ctx, q)
if err != nil {
return err
}
if exists {
a.log.Infof("instance entry already exists")
a.conn.log.Infof("instance entry already exists")
return nil
}
@ -266,10 +263,10 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
Model(i)
_, err = insertQ.Exec(ctx)
err = processErrorResponse(err)
if err == nil {
a.log.Infof("created instance instance %s with id %s", domain, i.ID)
if err != nil {
return a.conn.ProcessError(err)
}
return err
a.conn.log.Infof("created instance instance %s with id %s", domain, i.ID)
return nil
}

View file

@ -21,9 +21,7 @@ package bundb
import (
"context"
"errors"
"strings"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/uptrace/bun"
@ -31,16 +29,12 @@ import (
type basicDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error {
_, err := b.conn.NewInsert().Model(i).Exec(ctx)
if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
return b.conn.ProcessError(err)
}
func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Error {
@ -49,7 +43,8 @@ func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Erro
Model(i).
Where("id = ?", id)
return processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
return b.conn.ProcessError(err)
}
func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
@ -59,7 +54,6 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{})
q := b.conn.NewSelect().Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", bun.Ident(w.Key))
} else {
@ -71,7 +65,8 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{})
}
}
return processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
return b.conn.ProcessError(err)
}
func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
@ -79,7 +74,8 @@ func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
NewSelect().
Model(i)
return processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
return b.conn.ProcessError(err)
}
func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error {
@ -89,8 +85,7 @@ func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.E
Where("id = ?", id)
_, err := q.Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
@ -107,8 +102,7 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface
}
_, err := q.Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error {
@ -118,8 +112,7 @@ func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.E
WherePK()
_, err := q.Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error {
@ -129,8 +122,7 @@ func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, valu
WherePK()
_, err := q.Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error {
@ -151,8 +143,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string,
q = q.Set("? = ?", bun.Safe(key), value)
_, err := q.Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
@ -162,7 +153,7 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {
_, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
@ -170,10 +161,6 @@ func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
}
func (b *basicDB) Stop(ctx context.Context) db.Error {
b.log.Info("closing db connection")
if err := b.conn.Close(); err != nil {
// only cancel if there's a problem closing the db
return err
}
return nil
b.conn.log.Info("closing db connection")
return b.conn.Close()
}

View file

@ -30,15 +30,19 @@ import (
"strings"
"time"
"github.com/ReneKroon/ttlcache"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/dialect/sqlitedialect"
_ "modernc.org/sqlite"
)
const (
@ -66,15 +70,14 @@ type bunDBService struct {
db.Status
db.Timeline
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
var sqldb *sql.DB
var conn *bun.DB
var conn *DBConn
// depending on the database type we're trying to create, we need to use a different driver...
switch strings.ToLower(c.DBConfig.Type) {
@ -85,10 +88,24 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
}
sqldb = stdlib.OpenDB(*opts)
conn = bun.NewDB(sqldb, pgdialect.New())
conn = WrapDBConn(bun.NewDB(sqldb, pgdialect.New()), log)
case dbTypeSqlite:
// SQLITE
// TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite
var err error
sqldb, err = sql.Open("sqlite", c.DBConfig.Address)
if err != nil {
return nil, fmt.Errorf("could not open sqlite db: %s", err)
}
conn = WrapDBConn(bun.NewDB(sqldb, sqlitedialect.New()), log)
if strings.HasPrefix(strings.TrimPrefix(c.DBConfig.Address, "file:"), ":memory:") {
log.Warn("sqlite in-memory database should only be used for debugging")
// don't close connections on disconnect -- otherwise
// the SQLite database will be deleted when there
// are no active connections
sqldb.SetConnMaxLifetime(0)
}
default:
return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type))
}
@ -108,66 +125,56 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
Account: &accountDB{
config: c,
conn: conn,
log: log,
},
Admin: &adminDB{
config: c,
conn: conn,
log: log,
},
Basic: &basicDB{
config: c,
conn: conn,
log: log,
},
Domain: &domainDB{
config: c,
conn: conn,
log: log,
},
Instance: &instanceDB{
config: c,
conn: conn,
log: log,
},
Media: &mediaDB{
config: c,
conn: conn,
log: log,
},
Mention: &mentionDB{
config: c,
conn: conn,
log: log,
cache: ttlcache.NewCache(),
},
Notification: &notificationDB{
config: c,
conn: conn,
log: log,
cache: ttlcache.NewCache(),
},
Relationship: &relationshipDB{
config: c,
conn: conn,
log: log,
},
Session: &sessionDB{
config: c,
conn: conn,
log: log,
},
Status: &statusDB{
config: c,
conn: conn,
log: log,
cache: cache.NewStatusCache(),
},
Timeline: &timelineDB{
config: c,
conn: conn,
log: log,
},
config: c,
conn: conn,
log: log,
}
// we can confidently return this useable service now
@ -332,7 +339,7 @@ func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAcco
if err != nil {
if err == sql.ErrNoRows {
// no result found for this username/domain so just don't include it as a mencho and carry on about our business
ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain)
ps.conn.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain)
continue
}
// a serious error has happened so bail
@ -398,7 +405,7 @@ func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []strin
if err != nil {
if err == sql.ErrNoRows {
// no result found for this username/domain so just don't include it as an emoji and carry on about our business
ps.log.Debugf("no emoji found with shortcode %s, skipping it", e)
ps.conn.log.Debugf("no emoji found with shortcode %s, skipping it", e)
continue
}
// a serious error has happened so bail

72
internal/db/bundb/conn.go Normal file
View file

@ -0,0 +1,72 @@
package bundb
import (
"context"
"database/sql"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect"
)
// dbConn wrapps a bun.DB conn to provide SQL-type specific additional functionality
type DBConn struct {
errProc func(error) db.Error // errProc is the SQL-type specific error processor
log *logrus.Logger // log is the logger passed with this DBConn
*bun.DB // DB is the underlying bun.DB connection
}
// WrapDBConn @TODO
func WrapDBConn(dbConn *bun.DB, log *logrus.Logger) *DBConn {
var errProc func(error) db.Error
switch dbConn.Dialect().Name() {
case dialect.PG:
errProc = processPostgresError
case dialect.SQLite:
errProc = processSQLiteError
default:
panic("unknown dialect name: " + dbConn.Dialect().Name().String())
}
return &DBConn{
errProc: errProc,
log: log,
DB: dbConn,
}
}
// ProcessError processes an error to replace any known values with our own db.Error types,
// making it easier to catch specific situations (e.g. no rows, already exists, etc)
func (conn *DBConn) ProcessError(err error) db.Error {
switch {
case err == nil:
return nil
case err == sql.ErrNoRows:
return db.ErrNoEntries
default:
return conn.errProc(err)
}
}
// Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors
func (conn *DBConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
// Get the select query result
count, err := query.Count(ctx)
// Process error as our own and check if it exists
switch err := conn.ProcessError(err); err {
case nil:
return (count != 0), nil
case db.ErrNoEntries:
return false, nil
default:
return false, err
}
}
// NotExists is the functional opposite of conn.Exists()
func (conn *DBConn) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
// Simply inverse of conn.exists()
exists, err := conn.Exists(ctx, query)
return !exists, err
}

View file

@ -22,18 +22,15 @@ import (
"context"
"net/url"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
type domainDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) {
@ -47,7 +44,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db
Where("LOWER(domain) = LOWER(?)", domain).
Limit(1)
return exists(ctx, q)
return d.conn.Exists(ctx, q)
}
func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {

View file

@ -0,0 +1,43 @@
package bundb
import (
"github.com/jackc/pgconn"
"github.com/superseriousbusiness/gotosocial/internal/db"
"modernc.org/sqlite"
sqlite3 "modernc.org/sqlite/lib"
)
// processPostgresError processes an error, replacing any postgres specific errors with our own error type
func processPostgresError(err error) db.Error {
// Attempt to cast as postgres
pgErr, ok := err.(*pgconn.PgError)
if !ok {
return err
}
// Handle supplied error code:
// (https://www.postgresql.org/docs/10/errcodes-appendix.html)
switch pgErr.Code {
case "23505" /* unique_violation */ :
return db.ErrAlreadyExists
default:
return err
}
}
// processSQLiteError processes an error, replacing any sqlite specific errors with our own error type
func processSQLiteError(err error) db.Error {
// Attempt to cast as sqlite
sqliteErr, ok := err.(*sqlite.Error)
if !ok {
return err
}
// Handle supplied error code:
switch sqliteErr.Code() {
case sqlite3.SQLITE_CONSTRAINT_UNIQUE:
return db.ErrAlreadyExists
default:
return err
}
}

View file

@ -21,7 +21,6 @@ package bundb
import (
"context"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -30,8 +29,7 @@ import (
type instanceDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
@ -49,8 +47,10 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int
}
count, err := q.Count(ctx)
return count, processErrorResponse(err)
if err != nil {
return 0, i.conn.ProcessError(err)
}
return count, nil
}
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
@ -68,8 +68,10 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (
}
count, err := q.Count(ctx)
return count, processErrorResponse(err)
if err != nil {
return 0, i.conn.ProcessError(err)
}
return count, nil
}
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
@ -89,12 +91,14 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i
}
count, err := q.Count(ctx)
return count, processErrorResponse(err)
if err != nil {
return 0, i.conn.ProcessError(err)
}
return count, nil
}
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
i.log.Debug("GetAccountsForInstance")
i.conn.log.Debug("GetAccountsForInstance")
accounts := []*gtsmodel.Account{}
@ -111,7 +115,9 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max
q = q.Limit(limit)
}
err := processErrorResponse(q.Scan(ctx))
return accounts, err
err := q.Scan(ctx)
if err != nil {
return nil, i.conn.ProcessError(err)
}
return accounts, nil
}

View file

@ -21,7 +21,6 @@ package bundb
import (
"context"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -30,8 +29,7 @@ import (
type mediaDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (m *mediaDB) newMediaQ(i interface{}) *bun.SelectQuery {
@ -47,7 +45,9 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
q := m.newMediaQ(attachment).
Where("media_attachment.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
return attachment, err
err := q.Scan(ctx)
if err != nil {
return nil, m.conn.ProcessError(err)
}
return attachment, nil
}

View file

@ -21,8 +21,7 @@ package bundb
import (
"context"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/ReneKroon/ttlcache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -31,38 +30,8 @@ import (
type mentionDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cache cache.Cache
}
func (m *mentionDB) cacheMention(id string, mention *gtsmodel.Mention) {
if m.cache == nil {
m.cache = cache.New()
}
if err := m.cache.Store(id, mention); err != nil {
m.log.Panicf("mentionDB: error storing in cache: %s", err)
}
}
func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) {
if m.cache == nil {
m.cache = cache.New()
return nil, false
}
mI, err := m.cache.Fetch(id)
if err != nil || mI == nil {
return nil, false
}
mention, ok := mI.(*gtsmodel.Mention)
if !ok {
m.log.Panicf("mentionDB: cached interface with key %s was not a mention", id)
}
return mention, true
conn *DBConn
cache *ttlcache.Cache
}
func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
@ -74,33 +43,57 @@ func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
Relation("TargetAccount")
}
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
if mention, cached := m.mentionCached(id); cached {
return mention, nil
func (m *mentionDB) getMentionCached(id string) (*gtsmodel.Mention, bool) {
v, ok := m.cache.Get(id)
if !ok {
return nil, false
}
return v.(*gtsmodel.Mention), true
}
func (m *mentionDB) putMentionCache(mention *gtsmodel.Mention) {
m.cache.Set(mention.ID, mention)
}
func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
mention := &gtsmodel.Mention{}
q := m.newMentionQ(mention).
Where("mention.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
if err == nil && mention != nil {
m.cacheMention(id, mention)
err := q.Scan(ctx)
if err != nil {
return nil, m.conn.ProcessError(err)
}
return mention, err
m.putMentionCache(mention)
return mention, nil
}
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
if mention, cached := m.getMentionCached(id); cached {
return mention, nil
}
return m.getMentionDB(ctx, id)
}
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) {
mentions := []*gtsmodel.Mention{}
mentions := make([]*gtsmodel.Mention, 0, len(ids))
for _, i := range ids {
mention, err := m.GetMention(ctx, i)
if err != nil {
return nil, processErrorResponse(err)
for _, id := range ids {
// Attempt fetch from cache
mention, cached := m.getMentionCached(id)
if cached {
mentions = append(mentions, mention)
}
// Attempt fetch from DB
mention, err := m.getMentionDB(ctx, id)
if err != nil {
return nil, err
}
// Append mention
mentions = append(mentions, mention)
}

View file

@ -21,8 +21,7 @@ package bundb
import (
"context"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/ReneKroon/ttlcache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -31,38 +30,8 @@ import (
type notificationDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cache cache.Cache
}
func (n *notificationDB) cacheNotification(id string, notification *gtsmodel.Notification) {
if n.cache == nil {
n.cache = cache.New()
}
if err := n.cache.Store(id, notification); err != nil {
n.log.Panicf("notificationDB: error storing in cache: %s", err)
}
}
func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification, bool) {
if n.cache == nil {
n.cache = cache.New()
return nil, false
}
nI, err := n.cache.Fetch(id)
if err != nil || nI == nil {
return nil, false
}
notification, ok := nI.(*gtsmodel.Notification)
if !ok {
n.log.Panicf("notificationDB: cached interface with key %s was not a notification", id)
}
return notification, true
conn *DBConn
cache *ttlcache.Cache
}
func (n *notificationDB) newNotificationQ(i interface{}) *bun.SelectQuery {
@ -75,30 +44,30 @@ func (n *notificationDB) newNotificationQ(i interface{}) *bun.SelectQuery {
}
func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
if notification, cached := n.notificationCached(id); cached {
if notification, cached := n.getNotificationCache(id); cached {
return notification, nil
}
notification := &gtsmodel.Notification{}
q := n.newNotificationQ(notification).
Where("notification.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
if err == nil && notification != nil {
n.cacheNotification(id, notification)
notif := &gtsmodel.Notification{}
err := n.getNotificationDB(ctx, id, notif)
if err != nil {
return nil, err
}
return notification, err
return notif, nil
}
func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
// begin by selecting just the IDs
notifIDs := []*gtsmodel.Notification{}
// Ensure reasonable
if limit < 0 {
limit = 0
}
// Make a guess for slice size
notifications := make([]*gtsmodel.Notification, 0, limit)
q := n.conn.
NewSelect().
Model(&notifIDs).
Model(&notifications).
Column("id").
Where("target_account_id = ?", accountID).
Order("id DESC")
@ -115,22 +84,52 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
q = q.Limit(limit)
}
err := processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
if err != nil {
return nil, err
return nil, n.conn.ProcessError(err)
}
// now we have the IDs, select the notifs one by one
// reason for this is that for each notif, we can instead get it from our cache if it's cached
notifications := []*gtsmodel.Notification{}
for _, notifID := range notifIDs {
notif, err := n.GetNotification(ctx, notifID.ID)
errP := processErrorResponse(err)
if errP != nil {
return nil, errP
for i, notif := range notifications {
// Check cache for notification
nn, cached := n.getNotificationCache(notif.ID)
if cached {
notifications[i] = nn
continue
}
// Check DB for notification
err := n.getNotificationDB(ctx, notif.ID, notif)
if err != nil {
return nil, err
}
notifications = append(notifications, notif)
}
return notifications, nil
}
func (n *notificationDB) getNotificationCache(id string) (*gtsmodel.Notification, bool) {
v, ok := n.cache.Get(id)
if !ok {
return nil, false
}
return v.(*gtsmodel.Notification), true
}
func (n *notificationDB) putNotificationCache(notif *gtsmodel.Notification) {
n.cache.Set(notif.ID, notif)
}
func (n *notificationDB) getNotificationDB(ctx context.Context, id string, dst *gtsmodel.Notification) error {
q := n.newNotificationQ(dst).
Where("notification.id = ?", id)
err := q.Scan(ctx)
if err != nil {
return n.conn.ProcessError(err)
}
n.putNotificationCache(dst)
return nil
}

View file

@ -23,7 +23,6 @@ import (
"database/sql"
"fmt"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -32,8 +31,7 @@ import (
type relationshipDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery {
@ -66,7 +64,7 @@ func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account
Where("account_id = ?", account2)
}
return exists(ctx, q)
return r.conn.Exists(ctx, q)
}
func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
@ -76,9 +74,11 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2
Where("block.account_id = ?", account1).
Where("block.target_account_id = ?", account2)
err := processErrorResponse(q.Scan(ctx))
return block, err
err := q.Scan(ctx)
if err != nil {
return nil, r.conn.ProcessError(err)
}
return block, nil
}
func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
@ -176,7 +176,7 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode
Where("target_account_id = ?", targetAccount.ID).
Limit(1)
return exists(ctx, q)
return r.conn.Exists(ctx, q)
}
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
@ -190,7 +190,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g
Where("account_id = ?", sourceAccount.ID).
Where("target_account_id = ?", targetAccount.ID)
return exists(ctx, q)
return r.conn.Exists(ctx, q)
}
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
@ -201,13 +201,13 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod
// make sure account 1 follows account 2
f1, err := r.IsFollowing(ctx, account1, account2)
if err != nil {
return false, processErrorResponse(err)
return false, err
}
// make sure account 2 follows account 1
f2, err := r.IsFollowing(ctx, account2, account1)
if err != nil {
return false, processErrorResponse(err)
return false, err
}
return f1 && f2, nil
@ -222,7 +222,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
Where("account_id = ?", originAccountID).
Where("target_account_id = ?", targetAccountID).
Scan(ctx); err != nil {
return nil, processErrorResponse(err)
return nil, r.conn.ProcessError(err)
}
// create a new follow to 'replace' the request with
@ -239,7 +239,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
Model(follow).
On("CONFLICT ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).
Exec(ctx); err != nil {
return nil, processErrorResponse(err)
return nil, r.conn.ProcessError(err)
}
// now remove the follow request
@ -249,7 +249,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
Where("account_id = ?", originAccountID).
Where("target_account_id = ?", targetAccountID).
Exec(ctx); err != nil {
return nil, processErrorResponse(err)
return nil, r.conn.ProcessError(err)
}
return follow, nil
@ -261,9 +261,11 @@ func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID
q := r.newFollowQ(&followRequests).
Where("target_account_id = ?", accountID)
err := processErrorResponse(q.Scan(ctx))
return followRequests, err
err := q.Scan(ctx)
if err != nil {
return nil, r.conn.ProcessError(err)
}
return followRequests, nil
}
func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, db.Error) {
@ -272,9 +274,11 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string
q := r.newFollowQ(&follows).
Where("account_id = ?", accountID)
err := processErrorResponse(q.Scan(ctx))
return follows, err
err := q.Scan(ctx)
if err != nil {
return nil, r.conn.ProcessError(err)
}
return follows, nil
}
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
@ -286,7 +290,6 @@ func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID stri
}
func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
follows := []*gtsmodel.Follow{}
q := r.conn.
@ -302,11 +305,9 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
q = q.Where("target_account_id = ?", accountID)
}
if err := q.Scan(ctx); err != nil {
if err == sql.ErrNoRows {
return follows, nil
}
return nil, processErrorResponse(err)
err := q.Scan(ctx)
if err != nil && err != sql.ErrNoRows {
return nil, r.conn.ProcessError(err)
}
return follows, nil
}

View file

@ -23,22 +23,19 @@ import (
"crypto/rand"
"errors"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/uptrace/bun"
)
type sessionDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
rss := []*gtsmodel.RouterSession{}
rss := make([]*gtsmodel.RouterSession, 0, 1)
_, err := s.conn.
NewSelect().
@ -47,7 +44,7 @@ func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db
Order("id DESC").
Exec(ctx)
if err != nil {
return nil, processErrorResponse(err)
return nil, s.conn.ProcessError(err)
}
if len(rss) <= 0 {
@ -92,8 +89,8 @@ func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession,
Model(rs)
_, err = q.Exec(ctx)
err = processErrorResponse(err)
return rs, err
if err != nil {
return nil, s.conn.ProcessError(err)
}
return rs, nil
}

Binary file not shown.

View file

@ -24,7 +24,6 @@ import (
"errors"
"time"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
@ -34,38 +33,8 @@ import (
type statusDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cache cache.Cache
}
func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
if s.cache == nil {
s.cache = cache.New()
}
if err := s.cache.Store(id, status); err != nil {
s.log.Panicf("statusDB: error storing in cache: %s", err)
}
}
func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
if s.cache == nil {
s.cache = cache.New()
return nil, false
}
sI, err := s.cache.Fetch(id)
if err != nil || sI == nil {
return nil, false
}
status, ok := sI.(*gtsmodel.Status)
if !ok {
s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
}
return status, true
conn *DBConn
cache *cache.StatusCache
}
func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
@ -84,7 +53,9 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status {
if status.InReplyToID != "" && status.InReplyTo == nil {
if inReplyTo, cached := s.statusCached(status.InReplyToID); cached {
// TODO: do we want to keep this possibly recursive strategy?
if inReplyTo, cached := s.cache.GetByID(status.InReplyToID); cached {
status.InReplyTo = inReplyTo
} else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil {
status.InReplyTo = inReplyTo
@ -92,7 +63,9 @@ func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Sta
}
if status.BoostOfID != "" && status.BoostOf == nil {
if boostOf, cached := s.statusCached(status.BoostOfID); cached {
// TODO: do we want to keep this possibly recursive strategy?
if boostOf, cached := s.cache.GetByID(status.BoostOfID); cached {
status.BoostOf = boostOf
} else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil {
status.BoostOf = boostOf
@ -112,29 +85,26 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
}
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(id); cached {
if status, cached := s.cache.GetByID(id); cached {
return status, nil
}
status := new(gtsmodel.Status)
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("status.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
if err != nil {
return nil, err
return nil, s.conn.ProcessError(err)
}
if status != nil {
s.cacheStatus(id, status)
}
return s.getAttachedStatuses(ctx, status), err
s.cache.Put(status)
return s.getAttachedStatuses(ctx, status), nil
}
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
if status, cached := s.cache.GetByURI(uri); cached {
return status, nil
}
@ -143,38 +113,32 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St
q := s.newStatusQ(status).
Where("LOWER(status.uri) = LOWER(?)", uri)
err := processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
if err != nil {
return nil, err
return nil, s.conn.ProcessError(err)
}
if status != nil {
s.cacheStatus(uri, status)
}
return s.getAttachedStatuses(ctx, status), err
s.cache.Put(status)
return s.getAttachedStatuses(ctx, status), nil
}
func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) {
if status, cached := s.cache.GetByURL(url); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("LOWER(status.url) = LOWER(?)", uri)
Where("LOWER(status.url) = LOWER(?)", url)
err := processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
if err != nil {
return nil, err
return nil, s.conn.ProcessError(err)
}
if status != nil {
s.cacheStatus(uri, status)
}
return s.getAttachedStatuses(ctx, status), err
s.cache.Put(status)
return s.getAttachedStatuses(ctx, status), nil
}
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
@ -213,14 +177,12 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
_, err := tx.NewInsert().Model(status).Exec(ctx)
return err
}
return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction))
return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction))
}
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
parents := []*gtsmodel.Status{}
s.statusParent(ctx, status, &parents, onlyDirect)
return parents, nil
}
@ -318,7 +280,7 @@ func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status,
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
return s.conn.Exists(ctx, q)
}
func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
@ -328,7 +290,7 @@ func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Sta
Where("boost_of_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
return s.conn.Exists(ctx, q)
}
func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
@ -338,7 +300,7 @@ func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status,
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
return s.conn.Exists(ctx, q)
}
func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
@ -348,7 +310,7 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
return s.conn.Exists(ctx, q)
}
func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
@ -357,8 +319,11 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status)
q := s.newFaveQ(&faves).
Where("status_id = ?", status.ID)
err := processErrorResponse(q.Scan(ctx))
return faves, err
err := q.Scan(ctx)
if err != nil {
return nil, s.conn.ProcessError(err)
}
return faves, nil
}
func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
@ -367,6 +332,9 @@ func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status
q := s.newStatusQ(&reblogs).
Where("boost_of_id = ?", status.ID)
err := processErrorResponse(q.Scan(ctx))
return reblogs, err
err := q.Scan(ctx)
if err != nil {
return nil, s.conn.ProcessError(err)
}
return reblogs, nil
}

View file

@ -59,7 +59,6 @@ func (suite *StatusTestSuite) TearDownTest() {
func (suite *StatusTestSuite) TestGetStatusByID() {
status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_1_status_1"].ID)
if err != nil {
fmt.Println(err.Error())
suite.FailNow(err.Error())
}
suite.NotNil(status)

View file

@ -23,7 +23,6 @@ import (
"database/sql"
"sort"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -32,12 +31,18 @@ import (
type timelineDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
conn *DBConn
}
func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
// Ensure reasonable
if limit < 0 {
limit = 0
}
// Make educated guess for slice size
statuses := make([]*gtsmodel.Status, 0, limit)
q := t.conn.
NewSelect().
Model(&statuses)
@ -86,11 +91,21 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
q = q.WhereGroup(" AND ", whereGroup)
return statuses, processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
if err != nil {
return nil, t.conn.ProcessError(err)
}
return statuses, nil
}
func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
// Ensure reasonable
if limit < 0 {
limit = 0
}
// Make educated guess for slice size
statuses := make([]*gtsmodel.Status, 0, limit)
q := t.conn.
NewSelect().
@ -121,14 +136,23 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, ma
q = q.Limit(limit)
}
return statuses, processErrorResponse(q.Scan(ctx))
err := q.Scan(ctx)
if err != nil {
return nil, t.conn.ProcessError(err)
}
return statuses, nil
}
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
// Ensure reasonable
if limit < 0 {
limit = 0
}
faves := []*gtsmodel.StatusFave{}
// Make educated guess for slice size
faves := make([]*gtsmodel.StatusFave, 0, limit)
fq := t.conn.
NewSelect().
@ -160,26 +184,23 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max
return nil, "", "", db.ErrNoEntries
}
// map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID
statusesFavesMap := map[string]string{}
in := []string{}
// map[statusID]faveID -- we need this to sort statuses by fave ID rather than status ID
statusesFavesMap := make(map[string]string, len(faves))
statusIDs := make([]string, 0, len(faves))
for _, f := range faves {
statusesFavesMap[f.StatusID] = f.ID
in = append(in, f.StatusID)
statusIDs = append(statusIDs, f.StatusID)
}
statuses := []*gtsmodel.Status{}
statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
err = t.conn.
NewSelect().
Model(&statuses).
Where("id IN (?)", bun.In(in)).
Where("id IN (?)", bun.In(statusIDs)).
Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
return nil, "", "", t.conn.ProcessError(err)
}
if len(statuses) == 0 {

View file

@ -19,64 +19,9 @@
package bundb
import (
"context"
"strings"
"database/sql"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/uptrace/bun"
)
// processErrorResponse parses the given error and returns an appropriate DBError.
func processErrorResponse(err error) db.Error {
switch err {
case nil:
return nil
case sql.ErrNoRows:
return db.ErrNoEntries
default:
if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
}
}
func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
count, err := q.Count(ctx)
exists := count != 0
err = processErrorResponse(err)
if err != nil {
if err == db.ErrNoEntries {
return false, nil
}
return false, err
}
return exists, nil
}
func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
count, err := q.Count(ctx)
notExists := count == 0
err = processErrorResponse(err)
if err != nil {
if err == db.ErrNoEntries {
return true, nil
}
return false, err
}
return notExists, nil
}
// whereEmptyOrNull is a convenience function to return a bun WhereGroup that specifies
// that the given column should be EITHER an empty string OR null.
//

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.