Improve GetRemoteStatus and db.GetStatus() logic (#174)

* only fetch status parents / children if explicitly requested when dereferencing

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* Remove recursive DB GetStatus logic, don't fetch parent unless requested

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* StatusCache copies status so there are no thread-safety issues with modified status objects

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* remove sqlite test files

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* fix bugs introduced by previous commit

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* fix missing continue on error in loop

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* use our own RunInTx implementation (possible fix for nested tx error)

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* fix cast statement to work with SQLite

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* be less strict about valid status in cache

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* add cache=shared ALWAYS for SQLite db instances

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* Fix EnrichRemoteAccount when updating account fails

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* add nolint tag

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* ensure file: prefixes the filename in sqlite addr

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* add an account cache, add status author account from db

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* Fix incompatible SQLite query

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* *actually* use the new getAccount() function in accountsDB

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* update cache tests to use test suite

Signed-off-by: kim (grufwub) <grufwub@gmail.com>

* add RelationshipTestSuite, add tests for methods with changed SQL

Signed-off-by: kim (grufwub) <grufwub@gmail.com>
This commit is contained in:
kim 2021-09-01 10:08:21 +01:00 committed by GitHub
commit 7d193de25f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
36 changed files with 660 additions and 234 deletions

View file

@ -25,6 +25,7 @@ import (
"strings"
"time"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -34,6 +35,7 @@ import (
type accountDB struct {
config *config.Config
conn *DBConn
cache *cache.AccountCache
}
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
@ -45,60 +47,80 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
}
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("account.id = ?", id)
err := q.Scan(ctx)
if err != nil {
return nil, a.conn.ProcessError(err)
}
return account, nil
return a.getAccount(
ctx,
func() (*gtsmodel.Account, bool) {
return a.cache.GetByID(id)
},
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx)
},
)
}
func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("account.uri = ?", uri)
err := q.Scan(ctx)
if err != nil {
return nil, a.conn.ProcessError(err)
}
return account, nil
return a.getAccount(
ctx,
func() (*gtsmodel.Account, bool) {
return a.cache.GetByURI(uri)
},
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx)
},
)
}
func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)
func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
func() (*gtsmodel.Account, bool) {
return a.cache.GetByURL(url)
},
func(account *gtsmodel.Account) error {
return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx)
},
)
}
q := a.newAccountQ(account).
Where("account.url = ?", uri)
func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) {
// Attempt to fetch cached account
account, cached := cacheGet()
err := q.Scan(ctx)
if err != nil {
return nil, a.conn.ProcessError(err)
if !cached {
account = &gtsmodel.Account{}
// Not cached! Perform database query
err := dbQuery(account)
if err != nil {
return nil, a.conn.ProcessError(err)
}
// Place in the cache
a.cache.Put(account)
}
return account, nil
}
func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
if strings.TrimSpace(account.ID) == "" {
// TODO: we should not need this check here
return nil, errors.New("account had no ID")
}
// Update the account's last-used
account.UpdatedAt = time.Now()
q := a.conn.
NewUpdate().
Model(account).
WherePK()
_, err := q.Exec(ctx)
// Update the account model in the DB
_, err := a.conn.NewUpdate().Model(account).WherePK().Exec(ctx)
if err != nil {
return nil, a.conn.ProcessError(err)
}
// Place updated account in cache
// (this will replace existing, i.e. invalidating)
a.cache.Put(account)
return account, nil
}

View file

@ -91,6 +91,15 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
conn = WrapDBConn(bun.NewDB(sqldb, pgdialect.New()), log)
case dbTypeSqlite:
// SQLITE
// Drop anything fancy from DB address
c.DBConfig.Address = strings.Split(c.DBConfig.Address, "?")[0]
c.DBConfig.Address = strings.TrimPrefix(c.DBConfig.Address, "file:")
// Append our own SQLite preferences
c.DBConfig.Address = "file:" + c.DBConfig.Address + "?cache=shared"
// Open new DB instance
var err error
sqldb, err = sql.Open("sqlite", c.DBConfig.Address)
if err != nil {
@ -98,7 +107,7 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
}
conn = WrapDBConn(bun.NewDB(sqldb, sqlitedialect.New()), log)
if strings.HasPrefix(strings.TrimPrefix(c.DBConfig.Address, "file:"), ":memory:") {
if c.DBConfig.Address == "file::memory:?cache=shared" {
log.Warn("sqlite in-memory database should only be used for debugging")
// don't close connections on disconnect -- otherwise
@ -121,11 +130,10 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
conn.RegisterModel(t)
}
accounts := &accountDB{config: c, conn: conn, cache: cache.NewAccountCache()}
ps := &bunDBService{
Account: &accountDB{
config: c,
conn: conn,
},
Account: accounts,
Admin: &adminDB{
config: c,
conn: conn,
@ -165,9 +173,10 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
conn: conn,
},
Status: &statusDB{
config: c,
conn: conn,
cache: cache.NewStatusCache(),
config: c,
conn: conn,
cache: cache.NewStatusCache(),
accounts: accounts,
},
Timeline: &timelineDB{
config: c,

View file

@ -12,6 +12,8 @@ import (
// DBConn wraps a bun.DB conn to provide SQL-type specific additional functionality
type DBConn struct {
// TODO: move *Config here, no need to be in each struct type
errProc func(error) db.Error // errProc is the SQL-type specific error processor
log *logrus.Logger // log is the logger passed with this DBConn
*bun.DB // DB is the underlying bun.DB connection
@ -35,6 +37,24 @@ func WrapDBConn(dbConn *bun.DB, log *logrus.Logger) *DBConn {
}
}
// RunInTx executes fn inside a newly begun database transaction.
//
// The transaction is committed when fn returns nil. It is rolled back when fn
// returns an error or panics, so it is never left open on an abnormal exit.
// Any error from beginning the transaction, from fn, or from the commit is
// passed through conn.ProcessError before being returned.
func (conn *DBConn) RunInTx(ctx context.Context, fn func(bun.Tx) error) db.Error {
	// Acquire a new transaction
	tx, err := conn.BeginTx(ctx, nil)
	if err != nil {
		return conn.ProcessError(err)
	}

	// Roll back on every exit path that did not reach Commit, including a
	// panic raised inside fn. The rollback error is deliberately ignored:
	// the caller is already receiving the error (or panic) that caused it.
	committed := false
	defer func() {
		if !committed {
			tx.Rollback() //nolint
		}
	}()

	// Perform supplied transaction
	if err = fn(tx); err != nil {
		return conn.ProcessError(err)
	}

	// Finally, commit transaction. The flag is set regardless of the commit
	// result: once Commit has been attempted no rollback is issued.
	err = tx.Commit()
	committed = true
	return conn.ProcessError(err)
}
// ProcessError processes an error to replace any known values with our own db.Error types,
// making it easier to catch specific situations (e.g. no rows, already exists, etc)
func (conn *DBConn) ProcessError(err error) db.Error {

View file

@ -237,7 +237,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
if _, err := r.conn.
NewInsert().
Model(follow).
On("CONFLICT ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).
On("CONFLICT (account_id,target_account_id) DO UPDATE set uri = ?", follow.URI).
Exec(ctx); err != nil {
return nil, r.conn.ProcessError(err)
}
@ -298,7 +298,7 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
if localOnly {
q = q.ColumnExpr("follow.*").
Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)").
Join("JOIN accounts AS a ON follow.account_id = CAST(a.id as TEXT)").
Where("follow.target_account_id = ?", accountID).
WhereGroup(" AND ", whereEmptyOrNull("a.domain"))
} else {

View file

@ -0,0 +1,124 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb_test
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/testrig"
)
// RelationshipTestSuite is the test suite for the relationship database
// functions. It embeds BunDBStandardTestSuite and adds no fields of its own.
type RelationshipTestSuite struct {
BunDBStandardTestSuite
}
// SetupSuite fills the suite's fixture fields (tokens, clients, applications,
// users, accounts, attachments, statuses, tags and mentions) from the
// corresponding testrig constructors.
// NOTE(review): presumably invoked once by testify before any test in the
// suite runs — confirm against the testify suite documentation.
func (suite *RelationshipTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts()
suite.testAttachments = testrig.NewTestAttachments()
suite.testStatuses = testrig.NewTestStatuses()
suite.testTags = testrig.NewTestTags()
suite.testMentions = testrig.NewTestMentions()
}
// SetupTest creates a new test config, test database and test logger, then
// runs testrig.StandardDBSetup on that database with the suite's test accounts.
func (suite *RelationshipTestSuite) SetupTest() {
suite.config = testrig.NewTestConfig()
suite.db = testrig.NewTestDB()
suite.log = testrig.NewTestLog()
// the database must be assigned before it is handed to the standard setup
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
// TearDownTest runs testrig.StandardDBTeardown on the suite's database.
func (suite *RelationshipTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
}
// TestIsBlocked is not implemented yet; it marks itself as skipped.
func (suite *RelationshipTestSuite) TestIsBlocked() {
	suite.T().Skip("TODO: implement")
}
// TestGetBlock is not implemented yet; it marks itself as skipped.
func (suite *RelationshipTestSuite) TestGetBlock() {
	suite.T().Skip("TODO: implement")
}
// TestGetRelationship is not implemented yet; it marks itself as skipped.
func (suite *RelationshipTestSuite) TestGetRelationship() {
	suite.T().Skip("TODO: implement")
}
// TestIsFollowing is not implemented yet; it marks itself as skipped.
func (suite *RelationshipTestSuite) TestIsFollowing() {
	suite.T().Skip("TODO: implement")
}
// TestIsMutualFollowing is not implemented yet; it marks itself as skipped.
func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
	suite.T().Skip("TODO: implement")
}
// AcceptFollowRequest calls the database's AcceptFollowRequest once per test
// account, passing the literal ID "NON-EXISTENT-ID" as the other account, and
// fails the test on any error other than db.ErrNoEntries.
//
// NOTE(review): this method has no "Test" prefix, so testify's suite runner
// will presumably not execute it — confirm whether that is intentional.
func (suite *RelationshipTestSuite) AcceptFollowRequest() {
	for _, account := range suite.testAccounts {
		_, err := suite.db.AcceptFollowRequest(context.Background(), account.ID, "NON-EXISTENT-ID")
		if err != nil && !errors.Is(err, db.ErrNoEntries) {
			// Fail does not expand format verbs in its first argument, so
			// the error is supplied as the accompanying message instead.
			suite.Suite.Fail("error accepting follow request", err)
		}
	}
}
// GetAccountFollowRequests is not implemented yet; it marks itself as skipped.
// NOTE(review): no "Test" prefix, so testify will presumably never call it — confirm.
func (suite *RelationshipTestSuite) GetAccountFollowRequests() {
	suite.T().Skip("TODO: implement")
}
// GetAccountFollows is not implemented yet; it marks itself as skipped.
// NOTE(review): no "Test" prefix, so testify will presumably never call it — confirm.
func (suite *RelationshipTestSuite) GetAccountFollows() {
	suite.T().Skip("TODO: implement")
}
// CountAccountFollows is not implemented yet; it marks itself as skipped.
// NOTE(review): no "Test" prefix, so testify will presumably never call it — confirm.
func (suite *RelationshipTestSuite) CountAccountFollows() {
	suite.T().Skip("TODO: implement")
}
// GetAccountFollowedBy calls the database's GetAccountFollowedBy for every
// test account, once with localOnly false and once with localOnly true, and
// fails the test if either call returns an error. Returned followers are not
// inspected.
//
// NOTE(review): this method has no "Test" prefix, so testify's suite runner
// will presumably not execute it — confirm whether that is intentional.
func (suite *RelationshipTestSuite) GetAccountFollowedBy() {
	// TODO: more comprehensive tests here
	for _, account := range suite.testAccounts {
		// Fail does not expand format verbs in its first argument, so the
		// error is supplied as the accompanying message instead.
		if _, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, false); err != nil {
			suite.Suite.Fail("error checking accounts followed by", err)
		}

		if _, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, true); err != nil {
			suite.Suite.Fail("error checking localOnly accounts followed by", err)
		}
	}
}
// CountAccountFollowedBy is not implemented yet; it marks itself as skipped.
// NOTE(review): no "Test" prefix, so testify will presumably never call it — confirm.
func (suite *RelationshipTestSuite) CountAccountFollowedBy() {
	suite.T().Skip("TODO: implement")
}
func TestRelationshipTestSuite(t *testing.T) {
suite.Run(t, new(RelationshipTestSuite))
}

Binary file not shown.

View file

@ -21,7 +21,6 @@ package bundb
import (
"container/list"
"context"
"errors"
"time"
"github.com/superseriousbusiness/gotosocial/internal/cache"
@ -35,6 +34,11 @@ type statusDB struct {
config *config.Config
conn *DBConn
cache *cache.StatusCache
// TODO: keep method definitions in same place but instead have receiver
// all point to one single "db" type, so they can all share methods
// and caches where necessary
accounts *accountDB
}
func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
@ -51,30 +55,6 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
Relation("CreatedWithApplication")
}
func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status {
if status.InReplyToID != "" && status.InReplyTo == nil {
// TODO: do we want to keep this possibly recursive strategy?
if inReplyTo, cached := s.cache.GetByID(status.InReplyToID); cached {
status.InReplyTo = inReplyTo
} else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil {
status.InReplyTo = inReplyTo
}
}
if status.BoostOfID != "" && status.BoostOf == nil {
// TODO: do we want to keep this possibly recursive strategy?
if boostOf, cached := s.cache.GetByID(status.BoostOfID); cached {
status.BoostOf = boostOf
} else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil {
status.BoostOf = boostOf
}
}
return status
}
func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
return s.conn.
NewSelect().
@ -85,64 +65,79 @@ func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
}
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
if status, cached := s.cache.GetByID(id); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("status.id = ?", id)
err := q.Scan(ctx)
if err != nil {
return nil, s.conn.ProcessError(err)
}
s.cache.Put(status)
return s.getAttachedStatuses(ctx, status), nil
return s.getStatus(
ctx,
func() (*gtsmodel.Status, bool) {
return s.cache.GetByID(id)
},
func(status *gtsmodel.Status) error {
return s.newStatusQ(status).Where("status.id = ?", id).Scan(ctx)
},
)
}
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.cache.GetByURI(uri); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("LOWER(status.uri) = LOWER(?)", uri)
err := q.Scan(ctx)
if err != nil {
return nil, s.conn.ProcessError(err)
}
s.cache.Put(status)
return s.getAttachedStatuses(ctx, status), nil
return s.getStatus(
ctx,
func() (*gtsmodel.Status, bool) {
return s.cache.GetByURI(uri)
},
func(status *gtsmodel.Status) error {
return s.newStatusQ(status).Where("LOWER(status.uri) = LOWER(?)", uri).Scan(ctx)
},
)
}
func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.Status, db.Error) {
if status, cached := s.cache.GetByURL(url); cached {
return status, nil
return s.getStatus(
ctx,
func() (*gtsmodel.Status, bool) {
return s.cache.GetByURL(url)
},
func(status *gtsmodel.Status) error {
return s.newStatusQ(status).Where("LOWER(status.url) = LOWER(?)", url).Scan(ctx)
},
)
}
func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Status, bool), dbQuery func(*gtsmodel.Status) error) (*gtsmodel.Status, db.Error) {
// Attempt to fetch cached status
status, cached := cacheGet()
if !cached {
status = &gtsmodel.Status{}
// Not cached! Perform database query
err := dbQuery(status)
if err != nil {
return nil, s.conn.ProcessError(err)
}
// If there is boosted, fetch from DB also
if status.BoostOfID != "" {
boostOf, err := s.GetStatusByID(ctx, status.BoostOfID)
if err == nil {
status.BoostOf = boostOf
}
}
// Place in the cache
s.cache.Put(status)
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("LOWER(status.url) = LOWER(?)", url)
err := q.Scan(ctx)
// Set the status author account
author, err := s.accounts.GetAccountByID(ctx, status.AccountID)
if err != nil {
return nil, s.conn.ProcessError(err)
return nil, err
}
s.cache.Put(status)
return s.getAttachedStatuses(ctx, status), nil
// Return the prepared status
status.Account = author
return status, nil
}
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
transaction := func(ctx context.Context, tx bun.Tx) error {
return s.conn.RunInTx(ctx, func(tx bun.Tx) error {
// create links between this status and any emojis it uses
for _, i := range status.EmojiIDs {
if _, err := tx.NewInsert().Model(&gtsmodel.StatusToEmoji{
@ -174,10 +169,10 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
}
}
// Finally, insert the status
_, err := tx.NewInsert().Model(status).Exec(ctx)
return err
}
return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction))
})
}
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
@ -210,12 +205,8 @@ func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Statu
children := []*gtsmodel.Status{}
for e := foundStatuses.Front(); e != nil; e = e.Next() {
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
}
// only append children, not the overall parent status
entry := e.Value.(*gtsmodel.Status)
if entry.ID != status.ID {
children = append(children, entry)
}
@ -242,11 +233,7 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status,
for _, child := range immediateChildren {
insertLoop:
for e := foundStatuses.Front(); e != nil; e = e.Next() {
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
}
entry := e.Value.(*gtsmodel.Status)
if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
foundStatuses.InsertAfter(child, e)
break insertLoop

View file

@ -105,10 +105,9 @@ func (suite *StatusTestSuite) TestGetStatusWithMention() {
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.NotEmpty(status.Mentions)
suite.NotEmpty(status.MentionIDs)
suite.NotNil(status.InReplyTo)
suite.NotNil(status.InReplyToAccount)
suite.NotEmpty(status.InReplyToID)
suite.NotEmpty(status.InReplyToAccountID)
}
func (suite *StatusTestSuite) TestGetStatusTwice() {

View file

@ -26,13 +26,13 @@ import (
// Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses.
type Status interface {
// GetStatusByID returns one status from the database, with all rel fields populated (if possible).
// GetStatusByID returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error)
// GetStatusByURI returns one status from the database, with all rel fields populated (if possible).
// GetStatusByURI returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error)
// GetStatusByURL returns one status from the database, with all rel fields populated (if possible).
// GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error)
// PutStatus stores one status in the database.