mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-10-29 10:32:25 -05:00
Add SQLite support, fix un-thread-safe DB caches, small performance f… (#172)
* Add SQLite support, fix un-thread-safe DB caches, small performance fixes Signed-off-by: kim (grufwub) <grufwub@gmail.com> * add SQLite licenses to README Signed-off-by: kim (grufwub) <grufwub@gmail.com> * appease the linter, and fix my dumbass-ery Signed-off-by: kim (grufwub) <grufwub@gmail.com> * make requested changes Signed-off-by: kim (grufwub) <grufwub@gmail.com> * add back comment Signed-off-by: kim (grufwub) <grufwub@gmail.com>
This commit is contained in:
parent
53507ac2a3
commit
ed46224573
730 changed files with 2239881 additions and 3669 deletions
BIN
internal/api/client/account/sqlite-test.db
Normal file
BIN
internal/api/client/account/sqlite-test.db
Normal file
Binary file not shown.
BIN
internal/api/client/fileserver/sqlite-test.db
Normal file
BIN
internal/api/client/fileserver/sqlite-test.db
Normal file
Binary file not shown.
BIN
internal/api/client/media/sqlite-test.db
Normal file
BIN
internal/api/client/media/sqlite-test.db
Normal file
Binary file not shown.
BIN
internal/api/client/status/sqlite-test.db
Normal file
BIN
internal/api/client/status/sqlite-test.db
Normal file
Binary file not shown.
BIN
internal/api/s2s/user/sqlite-test.db
Normal file
BIN
internal/api/s2s/user/sqlite-test.db
Normal file
Binary file not shown.
106
internal/cache/status.go
vendored
Normal file
106
internal/cache/status.go
vendored
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/ReneKroon/ttlcache"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
)
|
||||
|
||||
// statusCache is a wrapper around ttlcache.Cache to provide URL and URI lookups for gtsmodel.Status
|
||||
type StatusCache struct {
|
||||
cache *ttlcache.Cache // map of IDs -> cached statuses
|
||||
urls map[string]string // map of status URLs -> IDs
|
||||
uris map[string]string // map of status URIs -> IDs
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// newStatusCache returns a new instantiated statusCache object
|
||||
func NewStatusCache() *StatusCache {
|
||||
c := StatusCache{
|
||||
cache: ttlcache.NewCache(),
|
||||
urls: make(map[string]string, 100),
|
||||
uris: make(map[string]string, 100),
|
||||
mutex: sync.Mutex{},
|
||||
}
|
||||
|
||||
// Set callback to purge lookup maps on expiration
|
||||
c.cache.SetExpirationCallback(func(key string, value interface{}) {
|
||||
status := value.(*gtsmodel.Status)
|
||||
|
||||
c.mutex.Lock()
|
||||
delete(c.urls, status.URL)
|
||||
delete(c.uris, status.URI)
|
||||
c.mutex.Unlock()
|
||||
})
|
||||
|
||||
return &c
|
||||
}
|
||||
|
||||
// GetByID attempts to fetch a status from the cache by its ID
|
||||
func (c *StatusCache) GetByID(id string) (*gtsmodel.Status, bool) {
|
||||
c.mutex.Lock()
|
||||
status, ok := c.getByID(id)
|
||||
c.mutex.Unlock()
|
||||
return status, ok
|
||||
}
|
||||
|
||||
// GetByURL attempts to fetch a status from the cache by its URL
|
||||
func (c *StatusCache) GetByURL(url string) (*gtsmodel.Status, bool) {
|
||||
// Perform safe ID lookup
|
||||
c.mutex.Lock()
|
||||
id, ok := c.urls[url]
|
||||
|
||||
// Not found, unlock early
|
||||
if !ok {
|
||||
c.mutex.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Attempt status lookup
|
||||
status, ok := c.getByID(id)
|
||||
c.mutex.Unlock()
|
||||
return status, ok
|
||||
}
|
||||
|
||||
// GetByURI attempts to fetch a status from the cache by its URI
|
||||
func (c *StatusCache) GetByURI(uri string) (*gtsmodel.Status, bool) {
|
||||
// Perform safe ID lookup
|
||||
c.mutex.Lock()
|
||||
id, ok := c.uris[uri]
|
||||
|
||||
// Not found, unlock early
|
||||
if !ok {
|
||||
c.mutex.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Attempt status lookup
|
||||
status, ok := c.getByID(id)
|
||||
c.mutex.Unlock()
|
||||
return status, ok
|
||||
}
|
||||
|
||||
// getByID performs an unsafe (no mutex locks) lookup of status by ID
|
||||
func (c *StatusCache) getByID(id string) (*gtsmodel.Status, bool) {
|
||||
v, ok := c.cache.Get(id)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return v.(*gtsmodel.Status), true
|
||||
}
|
||||
|
||||
// Put places a status in the cache
|
||||
func (c *StatusCache) Put(status *gtsmodel.Status) {
|
||||
if status == nil || status.ID == "" ||
|
||||
status.URL == "" ||
|
||||
status.URI == "" {
|
||||
panic("invalid status")
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
c.cache.Set(status.ID, status)
|
||||
c.urls[status.URL] = status.ID
|
||||
c.uris[status.URI] = status.ID
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
41
internal/cache/status_test.go
vendored
Normal file
41
internal/cache/status_test.go
vendored
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
package cache_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/cache"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
)
|
||||
|
||||
func TestStatusCache(t *testing.T) {
|
||||
cache := cache.NewStatusCache()
|
||||
|
||||
// Attempt to place a status
|
||||
status := gtsmodel.Status{
|
||||
ID: "id",
|
||||
URI: "uri",
|
||||
URL: "url",
|
||||
}
|
||||
cache.Put(&status)
|
||||
|
||||
var ok bool
|
||||
var check *gtsmodel.Status
|
||||
|
||||
// Check we can retrieve
|
||||
check, ok = cache.GetByID(status.ID)
|
||||
if !ok || !statusIs(&status, check) {
|
||||
t.Fatal("Could not find expected status")
|
||||
}
|
||||
check, ok = cache.GetByURI(status.URI)
|
||||
if !ok || !statusIs(&status, check) {
|
||||
t.Fatal("Could not find expected status")
|
||||
}
|
||||
check, ok = cache.GetByURL(status.URL)
|
||||
if !ok || !statusIs(&status, check) {
|
||||
t.Fatal("Could not find expected status")
|
||||
}
|
||||
}
|
||||
|
||||
func statusIs(status1, status2 *gtsmodel.Status) bool {
|
||||
return status1.ID == status2.ID && status1.URI == status2.URI && status1.URL == status2.URL
|
||||
}
|
||||
|
|
@ -25,7 +25,6 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
|
|
@ -34,8 +33,7 @@ import (
|
|||
|
||||
type accountDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
log *logrus.Logger
|
||||
conn *DBConn
|
||||
}
|
||||
|
||||
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
|
||||
|
|
@ -52,9 +50,11 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac
|
|||
q := a.newAccountQ(account).
|
||||
Where("account.id = ?", id)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
return account, err
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
|
||||
|
|
@ -63,9 +63,11 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
|
|||
q := a.newAccountQ(account).
|
||||
Where("account.uri = ?", uri)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
return account, err
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
|
||||
|
|
@ -74,9 +76,11 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.
|
|||
q := a.newAccountQ(account).
|
||||
Where("account.url = ?", uri)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
return account, err
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
|
||||
|
|
@ -92,10 +96,10 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
|
|||
WherePK()
|
||||
|
||||
_, err := q.Exec(ctx)
|
||||
|
||||
err = processErrorResponse(err)
|
||||
|
||||
return account, err
|
||||
if err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
|
||||
|
|
@ -113,9 +117,11 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
|
|||
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
||||
}
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
return account, err
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, db.Error) {
|
||||
|
|
@ -129,9 +135,11 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string)
|
|||
Where("account_id = ?", accountID).
|
||||
Column("created_at")
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
return status.CreatedAt, err
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return time.Time{}, a.conn.ProcessError(err)
|
||||
}
|
||||
return status.CreatedAt, nil
|
||||
}
|
||||
|
||||
func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
|
||||
|
|
@ -153,17 +161,17 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen
|
|||
NewInsert().
|
||||
Model(mediaAttachment).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
return a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
if _, err := a.conn.
|
||||
NewUpdate().
|
||||
Model(>smodel.Account{}).
|
||||
Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).
|
||||
Where("id = ?", accountID).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
return a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -174,9 +182,11 @@ func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username stri
|
|||
Where("username = ?", username).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
return account, err
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) {
|
||||
|
|
@ -187,8 +197,9 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g
|
|||
Model(faves).
|
||||
Where("account_id = ?", accountID).
|
||||
Scan(ctx); err != nil {
|
||||
return nil, err
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return *faves, nil
|
||||
}
|
||||
|
||||
|
|
@ -201,7 +212,6 @@ func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string)
|
|||
}
|
||||
|
||||
func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
|
||||
a.log.Debugf("getting statuses for account %s", accountID)
|
||||
statuses := []*gtsmodel.Status{}
|
||||
|
||||
q := a.conn.
|
||||
|
|
@ -238,14 +248,13 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li
|
|||
}
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return nil, err
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
if len(statuses) == 0 {
|
||||
return nil, db.ErrNoEntries
|
||||
}
|
||||
|
||||
a.log.Debugf("returning statuses for account %s", accountID)
|
||||
return statuses, nil
|
||||
}
|
||||
|
||||
|
|
@ -273,7 +282,7 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI
|
|||
|
||||
err := fq.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
return nil, "", "", a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
if len(blocks) == 0 {
|
||||
|
|
|
|||
|
|
@ -29,20 +29,17 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||
"github.com/uptrace/bun"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type adminDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
log *logrus.Logger
|
||||
conn *DBConn
|
||||
}
|
||||
|
||||
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
|
||||
|
|
@ -52,7 +49,7 @@ func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (boo
|
|||
Where("username = ?", username).
|
||||
Where("domain = ?", nil)
|
||||
|
||||
return notExists(ctx, q)
|
||||
return a.conn.NotExists(ctx, q)
|
||||
}
|
||||
|
||||
func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) {
|
||||
|
|
@ -72,7 +69,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
|
|||
// fail because we found something
|
||||
return false, fmt.Errorf("email domain %s is blocked", domain)
|
||||
} else if err != sql.ErrNoRows {
|
||||
return false, processErrorResponse(err)
|
||||
return false, a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// check if this email is associated with a user already
|
||||
|
|
@ -82,13 +79,13 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
|
|||
Where("email = ?", email).
|
||||
WhereOr("unconfirmed_email = ?", email)
|
||||
|
||||
return notExists(ctx, q)
|
||||
return a.conn.NotExists(ctx, q)
|
||||
}
|
||||
|
||||
func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
a.log.Errorf("error creating new rsa key: %s", err)
|
||||
a.conn.log.Errorf("error creating new rsa key: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -128,7 +125,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
|
|||
NewInsert().
|
||||
Model(acct).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, err
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -167,7 +164,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
|
|||
NewInsert().
|
||||
Model(u).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, err
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
|
|
@ -184,15 +181,15 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
|
|||
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
||||
count, err := existsQ.Count(ctx)
|
||||
if err != nil && count == 1 {
|
||||
a.log.Infof("instance account %s already exists", username)
|
||||
a.conn.log.Infof("instance account %s already exists", username)
|
||||
return nil
|
||||
} else if err != sql.ErrNoRows {
|
||||
return processErrorResponse(err)
|
||||
return a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
a.log.Errorf("error creating new rsa key: %s", err)
|
||||
a.conn.log.Errorf("error creating new rsa key: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -224,10 +221,10 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
|
|||
Model(acct)
|
||||
|
||||
if _, err := insertQ.Exec(ctx); err != nil {
|
||||
return err
|
||||
return a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
a.log.Infof("instance account %s CREATED with id %s", username, acct.ID)
|
||||
a.conn.log.Infof("instance account %s CREATED with id %s", username, acct.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -240,12 +237,12 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
|
|||
Model(>smodel.Instance{}).
|
||||
Where("domain = ?", domain)
|
||||
|
||||
exists, err := exists(ctx, q)
|
||||
exists, err := a.conn.Exists(ctx, q)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
a.log.Infof("instance entry already exists")
|
||||
a.conn.log.Infof("instance entry already exists")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -266,10 +263,10 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
|
|||
Model(i)
|
||||
|
||||
_, err = insertQ.Exec(ctx)
|
||||
err = processErrorResponse(err)
|
||||
|
||||
if err == nil {
|
||||
a.log.Infof("created instance instance %s with id %s", domain, i.ID)
|
||||
if err != nil {
|
||||
return a.conn.ProcessError(err)
|
||||
}
|
||||
return err
|
||||
|
||||
a.conn.log.Infof("created instance instance %s with id %s", domain, i.ID)
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,9 +21,7 @@ package bundb
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/uptrace/bun"
|
||||
|
|
@ -31,16 +29,12 @@ import (
|
|||
|
||||
type basicDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
log *logrus.Logger
|
||||
conn *DBConn
|
||||
}
|
||||
|
||||
func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error {
|
||||
_, err := b.conn.NewInsert().Model(i).Exec(ctx)
|
||||
if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
|
||||
return db.ErrAlreadyExists
|
||||
}
|
||||
return err
|
||||
return b.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Error {
|
||||
|
|
@ -49,7 +43,8 @@ func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Erro
|
|||
Model(i).
|
||||
Where("id = ?", id)
|
||||
|
||||
return processErrorResponse(q.Scan(ctx))
|
||||
err := q.Scan(ctx)
|
||||
return b.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
|
||||
|
|
@ -59,7 +54,6 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{})
|
|||
|
||||
q := b.conn.NewSelect().Model(i)
|
||||
for _, w := range where {
|
||||
|
||||
if w.Value == nil {
|
||||
q = q.Where("? IS NULL", bun.Ident(w.Key))
|
||||
} else {
|
||||
|
|
@ -71,7 +65,8 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{})
|
|||
}
|
||||
}
|
||||
|
||||
return processErrorResponse(q.Scan(ctx))
|
||||
err := q.Scan(ctx)
|
||||
return b.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
|
||||
|
|
@ -79,7 +74,8 @@ func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
|
|||
NewSelect().
|
||||
Model(i)
|
||||
|
||||
return processErrorResponse(q.Scan(ctx))
|
||||
err := q.Scan(ctx)
|
||||
return b.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error {
|
||||
|
|
@ -89,8 +85,7 @@ func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.E
|
|||
Where("id = ?", id)
|
||||
|
||||
_, err := q.Exec(ctx)
|
||||
|
||||
return processErrorResponse(err)
|
||||
return b.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
|
||||
|
|
@ -107,8 +102,7 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface
|
|||
}
|
||||
|
||||
_, err := q.Exec(ctx)
|
||||
|
||||
return processErrorResponse(err)
|
||||
return b.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error {
|
||||
|
|
@ -118,8 +112,7 @@ func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.E
|
|||
WherePK()
|
||||
|
||||
_, err := q.Exec(ctx)
|
||||
|
||||
return processErrorResponse(err)
|
||||
return b.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error {
|
||||
|
|
@ -129,8 +122,7 @@ func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, valu
|
|||
WherePK()
|
||||
|
||||
_, err := q.Exec(ctx)
|
||||
|
||||
return processErrorResponse(err)
|
||||
return b.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error {
|
||||
|
|
@ -151,8 +143,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string,
|
|||
q = q.Set("? = ?", bun.Safe(key), value)
|
||||
|
||||
_, err := q.Exec(ctx)
|
||||
|
||||
return processErrorResponse(err)
|
||||
return b.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
|
||||
|
|
@ -162,7 +153,7 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
|
|||
|
||||
func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {
|
||||
_, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx)
|
||||
return processErrorResponse(err)
|
||||
return b.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
|
||||
|
|
@ -170,10 +161,6 @@ func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
|
|||
}
|
||||
|
||||
func (b *basicDB) Stop(ctx context.Context) db.Error {
|
||||
b.log.Info("closing db connection")
|
||||
if err := b.conn.Close(); err != nil {
|
||||
// only cancel if there's a problem closing the db
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
b.conn.log.Info("closing db connection")
|
||||
return b.conn.Close()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,15 +30,19 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ReneKroon/ttlcache"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v4/stdlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/cache"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/uptrace/bun"
|
||||
"github.com/uptrace/bun/dialect/pgdialect"
|
||||
"github.com/uptrace/bun/dialect/sqlitedialect"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -66,15 +70,14 @@ type bunDBService struct {
|
|||
db.Status
|
||||
db.Timeline
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
log *logrus.Logger
|
||||
conn *DBConn
|
||||
}
|
||||
|
||||
// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
|
||||
// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
|
||||
func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
|
||||
var sqldb *sql.DB
|
||||
var conn *bun.DB
|
||||
var conn *DBConn
|
||||
|
||||
// depending on the database type we're trying to create, we need to use a different driver...
|
||||
switch strings.ToLower(c.DBConfig.Type) {
|
||||
|
|
@ -85,10 +88,24 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
|
|||
return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
|
||||
}
|
||||
sqldb = stdlib.OpenDB(*opts)
|
||||
conn = bun.NewDB(sqldb, pgdialect.New())
|
||||
conn = WrapDBConn(bun.NewDB(sqldb, pgdialect.New()), log)
|
||||
case dbTypeSqlite:
|
||||
// SQLITE
|
||||
// TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite
|
||||
var err error
|
||||
sqldb, err = sql.Open("sqlite", c.DBConfig.Address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not open sqlite db: %s", err)
|
||||
}
|
||||
conn = WrapDBConn(bun.NewDB(sqldb, sqlitedialect.New()), log)
|
||||
|
||||
if strings.HasPrefix(strings.TrimPrefix(c.DBConfig.Address, "file:"), ":memory:") {
|
||||
log.Warn("sqlite in-memory database should only be used for debugging")
|
||||
|
||||
// don't close connections on disconnect -- otherwise
|
||||
// the SQLite database will be deleted when there
|
||||
// are no active connections
|
||||
sqldb.SetConnMaxLifetime(0)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type))
|
||||
}
|
||||
|
|
@ -108,66 +125,56 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
|
|||
Account: &accountDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
},
|
||||
Admin: &adminDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
},
|
||||
Basic: &basicDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
},
|
||||
Domain: &domainDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
},
|
||||
Instance: &instanceDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
},
|
||||
Media: &mediaDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
},
|
||||
Mention: &mentionDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
cache: ttlcache.NewCache(),
|
||||
},
|
||||
Notification: ¬ificationDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
cache: ttlcache.NewCache(),
|
||||
},
|
||||
Relationship: &relationshipDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
},
|
||||
Session: &sessionDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
},
|
||||
Status: &statusDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
cache: cache.NewStatusCache(),
|
||||
},
|
||||
Timeline: &timelineDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
},
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
}
|
||||
|
||||
// we can confidently return this useable service now
|
||||
|
|
@ -332,7 +339,7 @@ func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAcco
|
|||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
// no result found for this username/domain so just don't include it as a mencho and carry on about our business
|
||||
ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain)
|
||||
ps.conn.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain)
|
||||
continue
|
||||
}
|
||||
// a serious error has happened so bail
|
||||
|
|
@ -398,7 +405,7 @@ func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []strin
|
|||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
// no result found for this username/domain so just don't include it as an emoji and carry on about our business
|
||||
ps.log.Debugf("no emoji found with shortcode %s, skipping it", e)
|
||||
ps.conn.log.Debugf("no emoji found with shortcode %s, skipping it", e)
|
||||
continue
|
||||
}
|
||||
// a serious error has happened so bail
|
||||
|
|
|
|||
72
internal/db/bundb/conn.go
Normal file
72
internal/db/bundb/conn.go
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
package bundb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/uptrace/bun"
|
||||
"github.com/uptrace/bun/dialect"
|
||||
)
|
||||
|
||||
// dbConn wrapps a bun.DB conn to provide SQL-type specific additional functionality
|
||||
type DBConn struct {
|
||||
errProc func(error) db.Error // errProc is the SQL-type specific error processor
|
||||
log *logrus.Logger // log is the logger passed with this DBConn
|
||||
*bun.DB // DB is the underlying bun.DB connection
|
||||
}
|
||||
|
||||
// WrapDBConn @TODO
|
||||
func WrapDBConn(dbConn *bun.DB, log *logrus.Logger) *DBConn {
|
||||
var errProc func(error) db.Error
|
||||
switch dbConn.Dialect().Name() {
|
||||
case dialect.PG:
|
||||
errProc = processPostgresError
|
||||
case dialect.SQLite:
|
||||
errProc = processSQLiteError
|
||||
default:
|
||||
panic("unknown dialect name: " + dbConn.Dialect().Name().String())
|
||||
}
|
||||
return &DBConn{
|
||||
errProc: errProc,
|
||||
log: log,
|
||||
DB: dbConn,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessError processes an error to replace any known values with our own db.Error types,
|
||||
// making it easier to catch specific situations (e.g. no rows, already exists, etc)
|
||||
func (conn *DBConn) ProcessError(err error) db.Error {
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case err == sql.ErrNoRows:
|
||||
return db.ErrNoEntries
|
||||
default:
|
||||
return conn.errProc(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors
|
||||
func (conn *DBConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
|
||||
// Get the select query result
|
||||
count, err := query.Count(ctx)
|
||||
|
||||
// Process error as our own and check if it exists
|
||||
switch err := conn.ProcessError(err); err {
|
||||
case nil:
|
||||
return (count != 0), nil
|
||||
case db.ErrNoEntries:
|
||||
return false, nil
|
||||
default:
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
// NotExists is the functional opposite of conn.Exists()
|
||||
func (conn *DBConn) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
|
||||
// Simply inverse of conn.exists()
|
||||
exists, err := conn.Exists(ctx, query)
|
||||
return !exists, err
|
||||
}
|
||||
|
|
@ -22,18 +22,15 @@ import (
|
|||
"context"
|
||||
"net/url"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type domainDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
log *logrus.Logger
|
||||
conn *DBConn
|
||||
}
|
||||
|
||||
func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) {
|
||||
|
|
@ -47,7 +44,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db
|
|||
Where("LOWER(domain) = LOWER(?)", domain).
|
||||
Limit(1)
|
||||
|
||||
return exists(ctx, q)
|
||||
return d.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {
|
||||
|
|
|
|||
43
internal/db/bundb/errors.go
Normal file
43
internal/db/bundb/errors.go
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
package bundb
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"modernc.org/sqlite"
|
||||
sqlite3 "modernc.org/sqlite/lib"
|
||||
)
|
||||
|
||||
// processPostgresError processes an error, replacing any postgres specific errors with our own error type
|
||||
func processPostgresError(err error) db.Error {
|
||||
// Attempt to cast as postgres
|
||||
pgErr, ok := err.(*pgconn.PgError)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle supplied error code:
|
||||
// (https://www.postgresql.org/docs/10/errcodes-appendix.html)
|
||||
switch pgErr.Code {
|
||||
case "23505" /* unique_violation */ :
|
||||
return db.ErrAlreadyExists
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// processSQLiteError processes an error, replacing any sqlite specific errors with our own error type
|
||||
func processSQLiteError(err error) db.Error {
|
||||
// Attempt to cast as sqlite
|
||||
sqliteErr, ok := err.(*sqlite.Error)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle supplied error code:
|
||||
switch sqliteErr.Code() {
|
||||
case sqlite3.SQLITE_CONSTRAINT_UNIQUE:
|
||||
return db.ErrAlreadyExists
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
@ -21,7 +21,6 @@ package bundb
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
|
|
@ -30,8 +29,7 @@ import (
|
|||
|
||||
type instanceDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
log *logrus.Logger
|
||||
conn *DBConn
|
||||
}
|
||||
|
||||
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
|
||||
|
|
@ -49,8 +47,10 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int
|
|||
}
|
||||
|
||||
count, err := q.Count(ctx)
|
||||
|
||||
return count, processErrorResponse(err)
|
||||
if err != nil {
|
||||
return 0, i.conn.ProcessError(err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
|
||||
|
|
@ -68,8 +68,10 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (
|
|||
}
|
||||
|
||||
count, err := q.Count(ctx)
|
||||
|
||||
return count, processErrorResponse(err)
|
||||
if err != nil {
|
||||
return 0, i.conn.ProcessError(err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
|
||||
|
|
@ -89,12 +91,14 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i
|
|||
}
|
||||
|
||||
count, err := q.Count(ctx)
|
||||
|
||||
return count, processErrorResponse(err)
|
||||
if err != nil {
|
||||
return 0, i.conn.ProcessError(err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
|
||||
i.log.Debug("GetAccountsForInstance")
|
||||
i.conn.log.Debug("GetAccountsForInstance")
|
||||
|
||||
accounts := []*gtsmodel.Account{}
|
||||
|
||||
|
|
@ -111,7 +115,9 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max
|
|||
q = q.Limit(limit)
|
||||
}
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
return accounts, err
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, i.conn.ProcessError(err)
|
||||
}
|
||||
return accounts, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ package bundb
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
|
|
@ -30,8 +29,7 @@ import (
|
|||
|
||||
type mediaDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
log *logrus.Logger
|
||||
conn *DBConn
|
||||
}
|
||||
|
||||
func (m *mediaDB) newMediaQ(i interface{}) *bun.SelectQuery {
|
||||
|
|
@ -47,7 +45,9 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
|
|||
q := m.newMediaQ(attachment).
|
||||
Where("media_attachment.id = ?", id)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
return attachment, err
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, m.conn.ProcessError(err)
|
||||
}
|
||||
return attachment, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,8 +21,7 @@ package bundb
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/cache"
|
||||
"github.com/ReneKroon/ttlcache"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
|
|
@ -31,38 +30,8 @@ import (
|
|||
|
||||
type mentionDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
log *logrus.Logger
|
||||
cache cache.Cache
|
||||
}
|
||||
|
||||
func (m *mentionDB) cacheMention(id string, mention *gtsmodel.Mention) {
|
||||
if m.cache == nil {
|
||||
m.cache = cache.New()
|
||||
}
|
||||
|
||||
if err := m.cache.Store(id, mention); err != nil {
|
||||
m.log.Panicf("mentionDB: error storing in cache: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) {
|
||||
if m.cache == nil {
|
||||
m.cache = cache.New()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
mI, err := m.cache.Fetch(id)
|
||||
if err != nil || mI == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
mention, ok := mI.(*gtsmodel.Mention)
|
||||
if !ok {
|
||||
m.log.Panicf("mentionDB: cached interface with key %s was not a mention", id)
|
||||
}
|
||||
|
||||
return mention, true
|
||||
conn *DBConn
|
||||
cache *ttlcache.Cache
|
||||
}
|
||||
|
||||
func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
|
||||
|
|
@ -74,33 +43,57 @@ func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
|
|||
Relation("TargetAccount")
|
||||
}
|
||||
|
||||
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
|
||||
if mention, cached := m.mentionCached(id); cached {
|
||||
return mention, nil
|
||||
func (m *mentionDB) getMentionCached(id string) (*gtsmodel.Mention, bool) {
|
||||
v, ok := m.cache.Get(id)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return v.(*gtsmodel.Mention), true
|
||||
}
|
||||
|
||||
func (m *mentionDB) putMentionCache(mention *gtsmodel.Mention) {
|
||||
m.cache.Set(mention.ID, mention)
|
||||
}
|
||||
|
||||
func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
|
||||
mention := >smodel.Mention{}
|
||||
|
||||
q := m.newMentionQ(mention).
|
||||
Where("mention.id = ?", id)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
if err == nil && mention != nil {
|
||||
m.cacheMention(id, mention)
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, m.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return mention, err
|
||||
m.putMentionCache(mention)
|
||||
return mention, nil
|
||||
}
|
||||
|
||||
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
|
||||
if mention, cached := m.getMentionCached(id); cached {
|
||||
return mention, nil
|
||||
}
|
||||
return m.getMentionDB(ctx, id)
|
||||
}
|
||||
|
||||
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) {
|
||||
mentions := []*gtsmodel.Mention{}
|
||||
mentions := make([]*gtsmodel.Mention, 0, len(ids))
|
||||
|
||||
for _, i := range ids {
|
||||
mention, err := m.GetMention(ctx, i)
|
||||
if err != nil {
|
||||
return nil, processErrorResponse(err)
|
||||
for _, id := range ids {
|
||||
// Attempt fetch from cache
|
||||
mention, cached := m.getMentionCached(id)
|
||||
if cached {
|
||||
mentions = append(mentions, mention)
|
||||
}
|
||||
|
||||
// Attempt fetch from DB
|
||||
mention, err := m.getMentionDB(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Append mention
|
||||
mentions = append(mentions, mention)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -21,8 +21,7 @@ package bundb
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/cache"
|
||||
"github.com/ReneKroon/ttlcache"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
|
|
@ -31,38 +30,8 @@ import (
|
|||
|
||||
type notificationDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
log *logrus.Logger
|
||||
cache cache.Cache
|
||||
}
|
||||
|
||||
func (n *notificationDB) cacheNotification(id string, notification *gtsmodel.Notification) {
|
||||
if n.cache == nil {
|
||||
n.cache = cache.New()
|
||||
}
|
||||
|
||||
if err := n.cache.Store(id, notification); err != nil {
|
||||
n.log.Panicf("notificationDB: error storing in cache: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification, bool) {
|
||||
if n.cache == nil {
|
||||
n.cache = cache.New()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
nI, err := n.cache.Fetch(id)
|
||||
if err != nil || nI == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
notification, ok := nI.(*gtsmodel.Notification)
|
||||
if !ok {
|
||||
n.log.Panicf("notificationDB: cached interface with key %s was not a notification", id)
|
||||
}
|
||||
|
||||
return notification, true
|
||||
conn *DBConn
|
||||
cache *ttlcache.Cache
|
||||
}
|
||||
|
||||
func (n *notificationDB) newNotificationQ(i interface{}) *bun.SelectQuery {
|
||||
|
|
@ -75,30 +44,30 @@ func (n *notificationDB) newNotificationQ(i interface{}) *bun.SelectQuery {
|
|||
}
|
||||
|
||||
func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
|
||||
if notification, cached := n.notificationCached(id); cached {
|
||||
if notification, cached := n.getNotificationCache(id); cached {
|
||||
return notification, nil
|
||||
}
|
||||
|
||||
notification := >smodel.Notification{}
|
||||
|
||||
q := n.newNotificationQ(notification).
|
||||
Where("notification.id = ?", id)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
if err == nil && notification != nil {
|
||||
n.cacheNotification(id, notification)
|
||||
notif := >smodel.Notification{}
|
||||
err := n.getNotificationDB(ctx, id, notif)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return notification, err
|
||||
return notif, nil
|
||||
}
|
||||
|
||||
func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
|
||||
// begin by selecting just the IDs
|
||||
notifIDs := []*gtsmodel.Notification{}
|
||||
// Ensure reasonable
|
||||
if limit < 0 {
|
||||
limit = 0
|
||||
}
|
||||
|
||||
// Make a guess for slice size
|
||||
notifications := make([]*gtsmodel.Notification, 0, limit)
|
||||
|
||||
q := n.conn.
|
||||
NewSelect().
|
||||
Model(¬ifIDs).
|
||||
Model(¬ifications).
|
||||
Column("id").
|
||||
Where("target_account_id = ?", accountID).
|
||||
Order("id DESC")
|
||||
|
|
@ -115,22 +84,52 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
|
|||
q = q.Limit(limit)
|
||||
}
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, n.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// now we have the IDs, select the notifs one by one
|
||||
// reason for this is that for each notif, we can instead get it from our cache if it's cached
|
||||
notifications := []*gtsmodel.Notification{}
|
||||
for _, notifID := range notifIDs {
|
||||
notif, err := n.GetNotification(ctx, notifID.ID)
|
||||
errP := processErrorResponse(err)
|
||||
if errP != nil {
|
||||
return nil, errP
|
||||
for i, notif := range notifications {
|
||||
// Check cache for notification
|
||||
nn, cached := n.getNotificationCache(notif.ID)
|
||||
if cached {
|
||||
notifications[i] = nn
|
||||
continue
|
||||
}
|
||||
|
||||
// Check DB for notification
|
||||
err := n.getNotificationDB(ctx, notif.ID, notif)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
notifications = append(notifications, notif)
|
||||
}
|
||||
|
||||
return notifications, nil
|
||||
}
|
||||
|
||||
func (n *notificationDB) getNotificationCache(id string) (*gtsmodel.Notification, bool) {
|
||||
v, ok := n.cache.Get(id)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return v.(*gtsmodel.Notification), true
|
||||
}
|
||||
|
||||
func (n *notificationDB) putNotificationCache(notif *gtsmodel.Notification) {
|
||||
n.cache.Set(notif.ID, notif)
|
||||
}
|
||||
|
||||
func (n *notificationDB) getNotificationDB(ctx context.Context, id string, dst *gtsmodel.Notification) error {
|
||||
q := n.newNotificationQ(dst).
|
||||
Where("notification.id = ?", id)
|
||||
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return n.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
n.putNotificationCache(dst)
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ import (
|
|||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
|
|
@ -32,8 +31,7 @@ import (
|
|||
|
||||
type relationshipDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
log *logrus.Logger
|
||||
conn *DBConn
|
||||
}
|
||||
|
||||
func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery {
|
||||
|
|
@ -66,7 +64,7 @@ func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account
|
|||
Where("account_id = ?", account2)
|
||||
}
|
||||
|
||||
return exists(ctx, q)
|
||||
return r.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
|
||||
|
|
@ -76,9 +74,11 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2
|
|||
Where("block.account_id = ?", account1).
|
||||
Where("block.target_account_id = ?", account2)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
return block, err
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
return block, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
|
||||
|
|
@ -176,7 +176,7 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode
|
|||
Where("target_account_id = ?", targetAccount.ID).
|
||||
Limit(1)
|
||||
|
||||
return exists(ctx, q)
|
||||
return r.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
|
||||
|
|
@ -190,7 +190,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g
|
|||
Where("account_id = ?", sourceAccount.ID).
|
||||
Where("target_account_id = ?", targetAccount.ID)
|
||||
|
||||
return exists(ctx, q)
|
||||
return r.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
|
||||
|
|
@ -201,13 +201,13 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod
|
|||
// make sure account 1 follows account 2
|
||||
f1, err := r.IsFollowing(ctx, account1, account2)
|
||||
if err != nil {
|
||||
return false, processErrorResponse(err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
// make sure account 2 follows account 1
|
||||
f2, err := r.IsFollowing(ctx, account2, account1)
|
||||
if err != nil {
|
||||
return false, processErrorResponse(err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
return f1 && f2, nil
|
||||
|
|
@ -222,7 +222,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
|
|||
Where("account_id = ?", originAccountID).
|
||||
Where("target_account_id = ?", targetAccountID).
|
||||
Scan(ctx); err != nil {
|
||||
return nil, processErrorResponse(err)
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// create a new follow to 'replace' the request with
|
||||
|
|
@ -239,7 +239,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
|
|||
Model(follow).
|
||||
On("CONFLICT ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, processErrorResponse(err)
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// now remove the follow request
|
||||
|
|
@ -249,7 +249,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
|
|||
Where("account_id = ?", originAccountID).
|
||||
Where("target_account_id = ?", targetAccountID).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, processErrorResponse(err)
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return follow, nil
|
||||
|
|
@ -261,9 +261,11 @@ func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID
|
|||
q := r.newFollowQ(&followRequests).
|
||||
Where("target_account_id = ?", accountID)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
return followRequests, err
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
return followRequests, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, db.Error) {
|
||||
|
|
@ -272,9 +274,11 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string
|
|||
q := r.newFollowQ(&follows).
|
||||
Where("account_id = ?", accountID)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
return follows, err
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
return follows, nil
|
||||
}
|
||||
|
||||
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
|
||||
|
|
@ -286,7 +290,6 @@ func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID stri
|
|||
}
|
||||
|
||||
func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
|
||||
|
||||
follows := []*gtsmodel.Follow{}
|
||||
|
||||
q := r.conn.
|
||||
|
|
@ -302,11 +305,9 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
|
|||
q = q.Where("target_account_id = ?", accountID)
|
||||
}
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return follows, nil
|
||||
}
|
||||
return nil, processErrorResponse(err)
|
||||
err := q.Scan(ctx)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
return follows, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,22 +23,19 @@ import (
|
|||
"crypto/rand"
|
||||
"errors"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type sessionDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
log *logrus.Logger
|
||||
conn *DBConn
|
||||
}
|
||||
|
||||
func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
|
||||
rss := []*gtsmodel.RouterSession{}
|
||||
rss := make([]*gtsmodel.RouterSession, 0, 1)
|
||||
|
||||
_, err := s.conn.
|
||||
NewSelect().
|
||||
|
|
@ -47,7 +44,7 @@ func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db
|
|||
Order("id DESC").
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, processErrorResponse(err)
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
if len(rss) <= 0 {
|
||||
|
|
@ -92,8 +89,8 @@ func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession,
|
|||
Model(rs)
|
||||
|
||||
_, err = q.Exec(ctx)
|
||||
|
||||
err = processErrorResponse(err)
|
||||
|
||||
return rs, err
|
||||
if err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
return rs, nil
|
||||
}
|
||||
|
|
|
|||
BIN
internal/db/bundb/sqlite-test.db
Normal file
BIN
internal/db/bundb/sqlite-test.db
Normal file
Binary file not shown.
|
|
@ -24,7 +24,6 @@ import (
|
|||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/cache"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
|
|
@ -34,38 +33,8 @@ import (
|
|||
|
||||
type statusDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
log *logrus.Logger
|
||||
cache cache.Cache
|
||||
}
|
||||
|
||||
func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
|
||||
if s.cache == nil {
|
||||
s.cache = cache.New()
|
||||
}
|
||||
|
||||
if err := s.cache.Store(id, status); err != nil {
|
||||
s.log.Panicf("statusDB: error storing in cache: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
|
||||
if s.cache == nil {
|
||||
s.cache = cache.New()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
sI, err := s.cache.Fetch(id)
|
||||
if err != nil || sI == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
status, ok := sI.(*gtsmodel.Status)
|
||||
if !ok {
|
||||
s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
|
||||
}
|
||||
|
||||
return status, true
|
||||
conn *DBConn
|
||||
cache *cache.StatusCache
|
||||
}
|
||||
|
||||
func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
|
||||
|
|
@ -84,7 +53,9 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
|
|||
|
||||
func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status {
|
||||
if status.InReplyToID != "" && status.InReplyTo == nil {
|
||||
if inReplyTo, cached := s.statusCached(status.InReplyToID); cached {
|
||||
// TODO: do we want to keep this possibly recursive strategy?
|
||||
|
||||
if inReplyTo, cached := s.cache.GetByID(status.InReplyToID); cached {
|
||||
status.InReplyTo = inReplyTo
|
||||
} else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil {
|
||||
status.InReplyTo = inReplyTo
|
||||
|
|
@ -92,7 +63,9 @@ func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Sta
|
|||
}
|
||||
|
||||
if status.BoostOfID != "" && status.BoostOf == nil {
|
||||
if boostOf, cached := s.statusCached(status.BoostOfID); cached {
|
||||
// TODO: do we want to keep this possibly recursive strategy?
|
||||
|
||||
if boostOf, cached := s.cache.GetByID(status.BoostOfID); cached {
|
||||
status.BoostOf = boostOf
|
||||
} else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil {
|
||||
status.BoostOf = boostOf
|
||||
|
|
@ -112,29 +85,26 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
|
|||
}
|
||||
|
||||
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
|
||||
if status, cached := s.statusCached(id); cached {
|
||||
if status, cached := s.cache.GetByID(id); cached {
|
||||
return status, nil
|
||||
}
|
||||
|
||||
status := new(gtsmodel.Status)
|
||||
status := >smodel.Status{}
|
||||
|
||||
q := s.newStatusQ(status).
|
||||
Where("status.id = ?", id)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
if status != nil {
|
||||
s.cacheStatus(id, status)
|
||||
}
|
||||
|
||||
return s.getAttachedStatuses(ctx, status), err
|
||||
s.cache.Put(status)
|
||||
return s.getAttachedStatuses(ctx, status), nil
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
|
||||
if status, cached := s.statusCached(uri); cached {
|
||||
if status, cached := s.cache.GetByURI(uri); cached {
|
||||
return status, nil
|
||||
}
|
||||
|
||||
|
|
@ -143,38 +113,32 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St
|
|||
q := s.newStatusQ(status).
|
||||
Where("LOWER(status.uri) = LOWER(?)", uri)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
if status != nil {
|
||||
s.cacheStatus(uri, status)
|
||||
}
|
||||
|
||||
return s.getAttachedStatuses(ctx, status), err
|
||||
s.cache.Put(status)
|
||||
return s.getAttachedStatuses(ctx, status), nil
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
|
||||
if status, cached := s.statusCached(uri); cached {
|
||||
func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) {
|
||||
if status, cached := s.cache.GetByURL(url); cached {
|
||||
return status, nil
|
||||
}
|
||||
|
||||
status := >smodel.Status{}
|
||||
|
||||
q := s.newStatusQ(status).
|
||||
Where("LOWER(status.url) = LOWER(?)", uri)
|
||||
Where("LOWER(status.url) = LOWER(?)", url)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
if status != nil {
|
||||
s.cacheStatus(uri, status)
|
||||
}
|
||||
|
||||
return s.getAttachedStatuses(ctx, status), err
|
||||
s.cache.Put(status)
|
||||
return s.getAttachedStatuses(ctx, status), nil
|
||||
}
|
||||
|
||||
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
|
||||
|
|
@ -213,14 +177,12 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
|
|||
_, err := tx.NewInsert().Model(status).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction))
|
||||
return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction))
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
|
||||
parents := []*gtsmodel.Status{}
|
||||
s.statusParent(ctx, status, &parents, onlyDirect)
|
||||
|
||||
return parents, nil
|
||||
}
|
||||
|
||||
|
|
@ -318,7 +280,7 @@ func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status,
|
|||
Where("status_id = ?", status.ID).
|
||||
Where("account_id = ?", accountID)
|
||||
|
||||
return exists(ctx, q)
|
||||
return s.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
|
|
@ -328,7 +290,7 @@ func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Sta
|
|||
Where("boost_of_id = ?", status.ID).
|
||||
Where("account_id = ?", accountID)
|
||||
|
||||
return exists(ctx, q)
|
||||
return s.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
|
|
@ -338,7 +300,7 @@ func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status,
|
|||
Where("status_id = ?", status.ID).
|
||||
Where("account_id = ?", accountID)
|
||||
|
||||
return exists(ctx, q)
|
||||
return s.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
|
|
@ -348,7 +310,7 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St
|
|||
Where("status_id = ?", status.ID).
|
||||
Where("account_id = ?", accountID)
|
||||
|
||||
return exists(ctx, q)
|
||||
return s.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
|
||||
|
|
@ -357,8 +319,11 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status)
|
|||
q := s.newFaveQ(&faves).
|
||||
Where("status_id = ?", status.ID)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
return faves, err
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
return faves, nil
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
|
||||
|
|
@ -367,6 +332,9 @@ func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status
|
|||
q := s.newStatusQ(&reblogs).
|
||||
Where("boost_of_id = ?", status.ID)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
return reblogs, err
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, s.conn.ProcessError(err)
|
||||
}
|
||||
return reblogs, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,7 +59,6 @@ func (suite *StatusTestSuite) TearDownTest() {
|
|||
func (suite *StatusTestSuite) TestGetStatusByID() {
|
||||
status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_1_status_1"].ID)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
suite.NotNil(status)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ import (
|
|||
"database/sql"
|
||||
"sort"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
|
|
@ -32,12 +31,18 @@ import (
|
|||
|
||||
type timelineDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
log *logrus.Logger
|
||||
conn *DBConn
|
||||
}
|
||||
|
||||
func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
|
||||
statuses := []*gtsmodel.Status{}
|
||||
// Ensure reasonable
|
||||
if limit < 0 {
|
||||
limit = 0
|
||||
}
|
||||
|
||||
// Make educated guess for slice size
|
||||
statuses := make([]*gtsmodel.Status, 0, limit)
|
||||
|
||||
q := t.conn.
|
||||
NewSelect().
|
||||
Model(&statuses)
|
||||
|
|
@ -86,11 +91,21 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
|
|||
|
||||
q = q.WhereGroup(" AND ", whereGroup)
|
||||
|
||||
return statuses, processErrorResponse(q.Scan(ctx))
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, t.conn.ProcessError(err)
|
||||
}
|
||||
return statuses, nil
|
||||
}
|
||||
|
||||
func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
|
||||
statuses := []*gtsmodel.Status{}
|
||||
// Ensure reasonable
|
||||
if limit < 0 {
|
||||
limit = 0
|
||||
}
|
||||
|
||||
// Make educated guess for slice size
|
||||
statuses := make([]*gtsmodel.Status, 0, limit)
|
||||
|
||||
q := t.conn.
|
||||
NewSelect().
|
||||
|
|
@ -121,14 +136,23 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, ma
|
|||
q = q.Limit(limit)
|
||||
}
|
||||
|
||||
return statuses, processErrorResponse(q.Scan(ctx))
|
||||
err := q.Scan(ctx)
|
||||
if err != nil {
|
||||
return nil, t.conn.ProcessError(err)
|
||||
}
|
||||
return statuses, nil
|
||||
}
|
||||
|
||||
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
|
||||
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
|
||||
func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
|
||||
// Ensure reasonable
|
||||
if limit < 0 {
|
||||
limit = 0
|
||||
}
|
||||
|
||||
faves := []*gtsmodel.StatusFave{}
|
||||
// Make educated guess for slice size
|
||||
faves := make([]*gtsmodel.StatusFave, 0, limit)
|
||||
|
||||
fq := t.conn.
|
||||
NewSelect().
|
||||
|
|
@ -160,26 +184,23 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max
|
|||
return nil, "", "", db.ErrNoEntries
|
||||
}
|
||||
|
||||
// map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID
|
||||
statusesFavesMap := map[string]string{}
|
||||
|
||||
in := []string{}
|
||||
// map[statusID]faveID -- we need this to sort statuses by fave ID rather than status ID
|
||||
statusesFavesMap := make(map[string]string, len(faves))
|
||||
statusIDs := make([]string, 0, len(faves))
|
||||
for _, f := range faves {
|
||||
statusesFavesMap[f.StatusID] = f.ID
|
||||
in = append(in, f.StatusID)
|
||||
statusIDs = append(statusIDs, f.StatusID)
|
||||
}
|
||||
|
||||
statuses := []*gtsmodel.Status{}
|
||||
statuses := make([]*gtsmodel.Status, 0, len(statusIDs))
|
||||
|
||||
err = t.conn.
|
||||
NewSelect().
|
||||
Model(&statuses).
|
||||
Where("id IN (?)", bun.In(in)).
|
||||
Where("id IN (?)", bun.In(statusIDs)).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, "", "", db.ErrNoEntries
|
||||
}
|
||||
return nil, "", "", err
|
||||
return nil, "", "", t.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
if len(statuses) == 0 {
|
||||
|
|
|
|||
|
|
@ -19,64 +19,9 @@
|
|||
package bundb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"database/sql"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
// processErrorResponse parses the given error and returns an appropriate DBError.
|
||||
func processErrorResponse(err error) db.Error {
|
||||
switch err {
|
||||
case nil:
|
||||
return nil
|
||||
case sql.ErrNoRows:
|
||||
return db.ErrNoEntries
|
||||
default:
|
||||
if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
|
||||
return db.ErrAlreadyExists
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
|
||||
count, err := q.Count(ctx)
|
||||
|
||||
exists := count != 0
|
||||
|
||||
err = processErrorResponse(err)
|
||||
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
|
||||
count, err := q.Count(ctx)
|
||||
|
||||
notExists := count == 0
|
||||
|
||||
err = processErrorResponse(err)
|
||||
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
return notExists, nil
|
||||
}
|
||||
|
||||
// whereEmptyOrNull is a convenience function to return a bun WhereGroup that specifies
|
||||
// that the given column should be EITHER an empty string OR null.
|
||||
//
|
||||
|
|
|
|||
BIN
internal/federation/dereferencing/sqlite-test.db
Normal file
BIN
internal/federation/dereferencing/sqlite-test.db
Normal file
Binary file not shown.
BIN
internal/federation/sqlite-test.db
Normal file
BIN
internal/federation/sqlite-test.db
Normal file
Binary file not shown.
BIN
internal/oauth/sqlite-test.db
Normal file
BIN
internal/oauth/sqlite-test.db
Normal file
Binary file not shown.
BIN
internal/processing/status/sqlite-test.db
Normal file
BIN
internal/processing/status/sqlite-test.db
Normal file
Binary file not shown.
BIN
internal/text/sqlite-test.db
Normal file
BIN
internal/text/sqlite-test.db
Normal file
Binary file not shown.
BIN
internal/timeline/sqlite-test.db
Normal file
BIN
internal/timeline/sqlite-test.db
Normal file
Binary file not shown.
BIN
internal/typeutils/sqlite-test.db
Normal file
BIN
internal/typeutils/sqlite-test.db
Normal file
Binary file not shown.
Loading…
Add table
Add a link
Reference in a new issue