Pg to bun (#148)

* start moving to bun

* changing more stuff

* more

* and yet more

* tests passing

* seems stable now

* more big changes

* small fix

* little fixes
This commit is contained in:
tobi 2021-08-25 15:34:33 +02:00 committed by GitHub
commit 2dc9fc1626
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
713 changed files with 98694 additions and 22704 deletions

View file

@ -0,0 +1,291 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type accountDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
}
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
return a.conn.
NewSelect().
Model(account).
Relation("AvatarMediaAttachment").
Relation("HeaderMediaAttachment")
}
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("account.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
return account, err
}
func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("account.uri = ?", uri)
err := processErrorResponse(q.Scan(ctx))
return account, err
}
func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("account.url = ?", uri)
err := processErrorResponse(q.Scan(ctx))
return account, err
}
func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
if strings.TrimSpace(account.ID) == "" {
return nil, errors.New("account had no ID")
}
account.UpdatedAt = time.Now()
q := a.conn.
NewUpdate().
Model(account).
WherePK()
_, err := q.Exec(ctx)
err = processErrorResponse(err)
return account, err
}
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)
q := a.newAccountQ(account)
if domain == "" {
q = q.
Where("account.username = ?", domain).
Where("account.domain = ?", domain)
} else {
q = q.
Where("account.username = ?", domain).
Where("? IS NULL", bun.Ident("domain"))
}
err := processErrorResponse(q.Scan(ctx))
return account, err
}
func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) (time.Time, db.Error) {
status := new(gtsmodel.Status)
q := a.conn.
NewSelect().
Model(status).
Order("id DESC").
Limit(1).
Where("account_id = ?", accountID).
Column("created_at")
err := processErrorResponse(q.Scan(ctx))
return status.CreatedAt, err
}
func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
if mediaAttachment.Avatar && mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar")
}
var headerOrAVI string
if mediaAttachment.Avatar {
headerOrAVI = "avatar"
} else if mediaAttachment.Header {
headerOrAVI = "header"
} else {
return errors.New("given media attachment was neither a header nor an avatar")
}
// TODO: there are probably more side effects here that need to be handled
if _, err := a.conn.
NewInsert().
Model(mediaAttachment).
Exec(ctx); err != nil {
return err
}
if _, err := a.conn.
NewUpdate().
Model(&gtsmodel.Account{}).
Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).
Where("id = ?", accountID).
Exec(ctx); err != nil {
return err
}
return nil
}
func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)
q := a.newAccountQ(account).
Where("username = ?", username).
Where("? IS NULL", bun.Ident("domain"))
err := processErrorResponse(q.Scan(ctx))
return account, err
}
func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*gtsmodel.StatusFave, db.Error) {
faves := new([]*gtsmodel.StatusFave)
if err := a.conn.
NewSelect().
Model(faves).
Where("account_id = ?", accountID).
Scan(ctx); err != nil {
return nil, err
}
return *faves, nil
}
func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) {
return a.conn.
NewSelect().
Model(&gtsmodel.Status{}).
Where("account_id = ?", accountID).
Count(ctx)
}
func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
a.log.Debugf("getting statuses for account %s", accountID)
statuses := []*gtsmodel.Status{}
q := a.conn.
NewSelect().
Model(&statuses).
Order("id DESC")
if accountID != "" {
q = q.Where("account_id = ?", accountID)
}
if limit != 0 {
q = q.Limit(limit)
}
if excludeReplies {
q = q.Where("? IS NULL", bun.Ident("in_reply_to_id"))
}
if pinnedOnly {
q = q.Where("pinned = ?", true)
}
if maxID != "" {
q = q.Where("id < ?", maxID)
}
if mediaOnly {
q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {
return q.
WhereOr("? IS NOT NULL", bun.Ident("attachments")).
WhereOr("attachments != '{}'")
})
}
if err := q.Scan(ctx); err != nil {
return nil, err
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries
}
a.log.Debugf("returning statuses for account %s", accountID)
return statuses, nil
}
func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
blocks := []*gtsmodel.Block{}
fq := a.conn.
NewSelect().
Model(&blocks).
Where("block.account_id = ?", accountID).
Relation("TargetAccount").
Order("block.id DESC")
if maxID != "" {
fq = fq.Where("block.id < ?", maxID)
}
if sinceID != "" {
fq = fq.Where("block.id > ?", sinceID)
}
if limit > 0 {
fq = fq.Limit(limit)
}
err := fq.Scan(ctx)
if err != nil {
return nil, "", "", err
}
if len(blocks) == 0 {
return nil, "", "", db.ErrNoEntries
}
accounts := []*gtsmodel.Account{}
for _, b := range blocks {
accounts = append(accounts, b.TargetAccount)
}
nextMaxID := blocks[len(blocks)-1].ID
prevMinID := blocks[0].ID
return accounts, nextMaxID, prevMinID, nil
}

View file

@ -0,0 +1,86 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb_test
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type AccountTestSuite struct {
BunDBStandardTestSuite
}
func (suite *AccountTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts()
suite.testAttachments = testrig.NewTestAttachments()
suite.testStatuses = testrig.NewTestStatuses()
suite.testTags = testrig.NewTestTags()
suite.testMentions = testrig.NewTestMentions()
}
func (suite *AccountTestSuite) SetupTest() {
suite.config = testrig.NewTestConfig()
suite.db = testrig.NewTestDB()
suite.log = testrig.NewTestLog()
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
func (suite *AccountTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
}
func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(account)
suite.NotNil(account.AvatarMediaAttachment)
suite.NotEmpty(account.AvatarMediaAttachment.URL)
suite.NotNil(account.HeaderMediaAttachment)
suite.NotEmpty(account.HeaderMediaAttachment.URL)
}
func (suite *AccountTestSuite) TestUpdateAccount() {
testAccount := suite.testAccounts["local_account_1"]
testAccount.DisplayName = "new display name!"
_, err := suite.db.UpdateAccount(context.Background(), testAccount)
suite.NoError(err)
updated, err := suite.db.GetAccountByID(context.Background(), testAccount.ID)
suite.NoError(err)
suite.Equal("new display name!", updated.DisplayName)
suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second)
}
func TestAccountTestSuite(t *testing.T) {
suite.Run(t, new(AccountTestSuite))
}

272
internal/db/bundb/admin.go Normal file
View file

@ -0,0 +1,272 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"crypto/rand"
"crypto/rsa"
"database/sql"
"fmt"
"net"
"net/mail"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
"golang.org/x/crypto/bcrypt"
)
type adminDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
}
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
q := a.conn.
NewSelect().
Model(&gtsmodel.Account{}).
Where("username = ?", username).
Where("domain = ?", nil)
return notExists(ctx, q)
}
func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) {
// parse the domain from the email
m, err := mail.ParseAddress(email)
if err != nil {
return false, fmt.Errorf("error parsing email address %s: %s", email, err)
}
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked
if err := a.conn.
NewSelect().
Model(&gtsmodel.EmailDomainBlock{}).
Where("domain = ?", domain).
Scan(ctx); err == nil {
// fail because we found something
return false, fmt.Errorf("email domain %s is blocked", domain)
} else if err != sql.ErrNoRows {
return false, processErrorResponse(err)
}
// check if this email is associated with a user already
q := a.conn.
NewSelect().
Model(&gtsmodel.User{}).
Where("email = ?", email).
WhereOr("unconfirmed_email = ?", email)
return notExists(ctx, q)
}
func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
a.log.Errorf("error creating new rsa key: %s", err)
return nil, err
}
// if something went wrong while creating a user, we might already have an account, so check here first...
acct := &gtsmodel.Account{}
err = a.conn.NewSelect().
Model(acct).
Where("username = ?", username).
Where("? IS NULL", bun.Ident("domain")).
Scan(ctx)
if err != nil {
// we just don't have an account yet create one
newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
newAccountID, err := id.NewRandomULID()
if err != nil {
return nil, err
}
acct = &gtsmodel.Account{
ID: newAccountID,
Username: username,
DisplayName: username,
Reason: reason,
URL: newAccountURIs.UserURL,
PrivateKey: key,
PublicKey: &key.PublicKey,
PublicKeyURI: newAccountURIs.PublicKeyURI,
ActorType: gtsmodel.ActivityStreamsPerson,
URI: newAccountURIs.UserURI,
InboxURI: newAccountURIs.InboxURI,
OutboxURI: newAccountURIs.OutboxURI,
FollowersURI: newAccountURIs.FollowersURI,
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
if _, err = a.conn.
NewInsert().
Model(acct).
Exec(ctx); err != nil {
return nil, err
}
}
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("error hashing password: %s", err)
}
newUserID, err := id.NewRandomULID()
if err != nil {
return nil, err
}
u := &gtsmodel.User{
ID: newUserID,
AccountID: acct.ID,
EncryptedPassword: string(pw),
SignUpIP: signUpIP.To4(),
Locale: locale,
UnconfirmedEmail: email,
CreatedByApplicationID: appID,
Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user
}
if emailVerified {
u.ConfirmedAt = time.Now()
u.Email = email
}
if admin {
u.Admin = true
u.Moderator = true
}
if _, err = a.conn.
NewInsert().
Model(u).
Exec(ctx); err != nil {
return nil, err
}
return u, nil
}
func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
username := a.config.Host
// check if instance account already exists
existsQ := a.conn.
NewSelect().
Model(&gtsmodel.Account{}).
Where("username = ?", username).
Where("? IS NULL", bun.Ident("domain"))
count, err := existsQ.Count(ctx)
if err != nil && count == 1 {
a.log.Infof("instance account %s already exists", username)
return nil
} else if err != sql.ErrNoRows {
return processErrorResponse(err)
}
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
a.log.Errorf("error creating new rsa key: %s", err)
return err
}
aID, err := id.NewRandomULID()
if err != nil {
return err
}
newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
acct := &gtsmodel.Account{
ID: aID,
Username: a.config.Host,
DisplayName: username,
URL: newAccountURIs.UserURL,
PrivateKey: key,
PublicKey: &key.PublicKey,
PublicKeyURI: newAccountURIs.PublicKeyURI,
ActorType: gtsmodel.ActivityStreamsPerson,
URI: newAccountURIs.UserURI,
InboxURI: newAccountURIs.InboxURI,
OutboxURI: newAccountURIs.OutboxURI,
FollowersURI: newAccountURIs.FollowersURI,
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
insertQ := a.conn.
NewInsert().
Model(acct)
if _, err := insertQ.Exec(ctx); err != nil {
return err
}
a.log.Infof("instance account %s CREATED with id %s", username, acct.ID)
return nil
}
func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
domain := a.config.Host
// check if instance entry already exists
existsQ := a.conn.
NewSelect().
Model(&gtsmodel.Instance{}).
Where("domain = ?", domain)
count, err := existsQ.Count(ctx)
if err != nil && count == 1 {
a.log.Infof("instance instance %s already exists", domain)
return nil
} else if err != sql.ErrNoRows {
return processErrorResponse(err)
}
iID, err := id.NewRandomULID()
if err != nil {
return err
}
i := &gtsmodel.Instance{
ID: iID,
Domain: domain,
Title: domain,
URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host),
}
insertQ := a.conn.
NewInsert().
Model(i)
if _, err := insertQ.Exec(ctx); err != nil {
return err
}
a.log.Infof("created instance instance %s with id %s", domain, i.ID)
return nil
}

179
internal/db/bundb/basic.go Normal file
View file

@ -0,0 +1,179 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"errors"
"strings"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/uptrace/bun"
)
type basicDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
}
func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error {
_, err := b.conn.NewInsert().Model(i).Exec(ctx)
if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
}
func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Error {
q := b.conn.
NewSelect().
Model(i).
Where("id = ?", id)
return processErrorResponse(q.Scan(ctx))
}
func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := b.conn.NewSelect().Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", bun.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
}
}
}
return processErrorResponse(q.Scan(ctx))
}
func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
q := b.conn.
NewSelect().
Model(i)
return processErrorResponse(q.Scan(ctx))
}
func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error {
q := b.conn.
NewDelete().
Model(i).
Where("id = ?", id)
_, err := q.Exec(ctx)
return processErrorResponse(err)
}
func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := b.conn.
NewDelete().
Model(i)
for _, w := range where {
q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
}
_, err := q.Exec(ctx)
return processErrorResponse(err)
}
func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error {
q := b.conn.
NewUpdate().
Model(i).
WherePK()
_, err := q.Exec(ctx)
return processErrorResponse(err)
}
func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error {
q := b.conn.NewUpdate().
Model(i).
Set("? = ?", bun.Safe(key), value).
WherePK()
_, err := q.Exec(ctx)
return processErrorResponse(err)
}
func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error {
q := b.conn.NewUpdate().Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", bun.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", bun.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", bun.Safe(w.Key), w.Value)
}
}
}
q = q.Set("? = ?", bun.Safe(key), value)
_, err := q.Exec(ctx)
return processErrorResponse(err)
}
func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
_, err := b.conn.NewCreateTable().Model(i).IfNotExists().Exec(ctx)
return err
}
func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {
_, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx)
return processErrorResponse(err)
}
func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
return b.conn.Ping()
}
func (b *basicDB) Stop(ctx context.Context) db.Error {
b.log.Info("closing db connection")
if err := b.conn.Close(); err != nil {
// only cancel if there's a problem closing the db
return err
}
return nil
}

View file

@ -0,0 +1,68 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb_test
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type BasicTestSuite struct {
BunDBStandardTestSuite
}
func (suite *BasicTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts()
suite.testAttachments = testrig.NewTestAttachments()
suite.testStatuses = testrig.NewTestStatuses()
suite.testTags = testrig.NewTestTags()
suite.testMentions = testrig.NewTestMentions()
}
func (suite *BasicTestSuite) SetupTest() {
suite.config = testrig.NewTestConfig()
suite.db = testrig.NewTestDB()
suite.log = testrig.NewTestLog()
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
func (suite *BasicTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
}
func (suite *BasicTestSuite) TestGetAccountByID() {
testAccount := suite.testAccounts["local_account_1"]
a := &gtsmodel.Account{}
err := suite.db.GetByID(context.Background(), testAccount.ID, a)
suite.NoError(err)
}
func TestBasicTestSuite(t *testing.T) {
suite.Run(t, new(BasicTestSuite))
}

410
internal/db/bundb/bundb.go Normal file
View file

@ -0,0 +1,410 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"crypto/tls"
"crypto/x509"
"database/sql"
"encoding/pem"
"errors"
"fmt"
"os"
"strings"
"time"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
)
const (
dbTypePostgres = "postgres"
dbTypeSqlite = "sqlite"
)
var registerTables []interface{} = []interface{}{
&gtsmodel.StatusToEmoji{},
&gtsmodel.StatusToTag{},
}
// bunDBService satisfies the DB interface
type bunDBService struct {
db.Account
db.Admin
db.Basic
db.Domain
db.Instance
db.Media
db.Mention
db.Notification
db.Relationship
db.Session
db.Status
db.Timeline
config *config.Config
conn *bun.DB
log *logrus.Logger
}
// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
var sqldb *sql.DB
var conn *bun.DB
// depending on the database type we're trying to create, we need to use a different driver...
switch strings.ToLower(c.DBConfig.Type) {
case dbTypePostgres:
// POSTGRES
opts, err := deriveBunDBPGOptions(c)
if err != nil {
return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
}
sqldb = stdlib.OpenDB(*opts)
conn = bun.NewDB(sqldb, pgdialect.New())
case dbTypeSqlite:
// SQLITE
// TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite
default:
return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type))
}
// actually *begin* the connection so that we can tell if the db is there and listening
if err := conn.Ping(); err != nil {
return nil, fmt.Errorf("db connection error: %s", err)
}
log.Info("connected to database")
for _, t := range registerTables {
// https://bun.uptrace.dev/orm/many-to-many-relation/
conn.RegisterModel(t)
}
ps := &bunDBService{
Account: &accountDB{
config: c,
conn: conn,
log: log,
},
Admin: &adminDB{
config: c,
conn: conn,
log: log,
},
Basic: &basicDB{
config: c,
conn: conn,
log: log,
},
Domain: &domainDB{
config: c,
conn: conn,
log: log,
},
Instance: &instanceDB{
config: c,
conn: conn,
log: log,
},
Media: &mediaDB{
config: c,
conn: conn,
log: log,
},
Mention: &mentionDB{
config: c,
conn: conn,
log: log,
},
Notification: &notificationDB{
config: c,
conn: conn,
log: log,
},
Relationship: &relationshipDB{
config: c,
conn: conn,
log: log,
},
Session: &sessionDB{
config: c,
conn: conn,
log: log,
},
Status: &statusDB{
config: c,
conn: conn,
log: log,
},
Timeline: &timelineDB{
config: c,
conn: conn,
log: log,
},
config: c,
conn: conn,
log: log,
}
// we can confidently return this useable service now
return ps, nil
}
/*
HANDY STUFF
*/
// deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options
// with sensible defaults, or an error if it's not satisfied by the provided config.
func deriveBunDBPGOptions(c *config.Config) (*pgx.ConnConfig, error) {
if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres {
return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type)
}
// validate port
if c.DBConfig.Port == 0 {
return nil, errors.New("no port set")
}
// validate address
if c.DBConfig.Address == "" {
return nil, errors.New("no address set")
}
// validate username
if c.DBConfig.User == "" {
return nil, errors.New("no user set")
}
// validate that there's a password
if c.DBConfig.Password == "" {
return nil, errors.New("no password set")
}
// validate database
if c.DBConfig.Database == "" {
return nil, errors.New("no database set")
}
var tlsConfig *tls.Config
switch c.DBConfig.TLSMode {
case config.DBTLSModeDisable, config.DBTLSModeUnset:
break // nothing to do
case config.DBTLSModeEnable:
tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
case config.DBTLSModeRequire:
tlsConfig = &tls.Config{
InsecureSkipVerify: false,
ServerName: c.DBConfig.Address,
}
}
if tlsConfig != nil && c.DBConfig.TLSCACert != "" {
// load the system cert pool first -- we'll append the given CA cert to this
certPool, err := x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("error fetching system CA cert pool: %s", err)
}
// open the file itself and make sure there's something in it
caCertBytes, err := os.ReadFile(c.DBConfig.TLSCACert)
if err != nil {
return nil, fmt.Errorf("error opening CA certificate at %s: %s", c.DBConfig.TLSCACert, err)
}
if len(caCertBytes) == 0 {
return nil, fmt.Errorf("ca cert at %s was empty", c.DBConfig.TLSCACert)
}
// make sure we have a PEM block
caPem, _ := pem.Decode(caCertBytes)
if caPem == nil {
return nil, fmt.Errorf("could not parse cert at %s into PEM", c.DBConfig.TLSCACert)
}
// parse the PEM block into the certificate
caCert, err := x509.ParseCertificate(caPem.Bytes)
if err != nil {
return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", c.DBConfig.TLSCACert, err)
}
// we're happy, add it to the existing pool and then use this pool in our tls config
certPool.AddCert(caCert)
tlsConfig.RootCAs = certPool
}
cfg, _ := pgx.ParseConfig("")
cfg.Host = c.DBConfig.Address
cfg.Port = uint16(c.DBConfig.Port)
cfg.User = c.DBConfig.User
cfg.Password = c.DBConfig.Password
cfg.TLSConfig = tlsConfig
cfg.Database = c.DBConfig.Database
cfg.PreferSimpleProtocol = true
return cfg, nil
}
/*
CONVERSION FUNCTIONS
*/
// TODO: move these to the type converter, it's bananas that they're here and not there
func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
ogAccount := &gtsmodel.Account{}
if err := ps.conn.NewSelect().Model(ogAccount).Where("id = ?", originAccountID).Scan(ctx); err != nil {
return nil, err
}
menchies := []*gtsmodel.Mention{}
for _, a := range targetAccounts {
// A mentioned account looks like "@test@example.org" or just "@test" for a local account
// -- we can guarantee this from the regex that targetAccounts should have been derived from.
// But we still need to do a bit of fiddling to get what we need here -- the username and domain (if given).
// 1. trim off the first @
t := strings.TrimPrefix(a, "@")
// 2. split the username and domain
s := strings.Split(t, "@")
// 3. if it's length 1 it's a local account, length 2 means remote, anything else means something is wrong
var local bool
switch len(s) {
case 1:
local = true
case 2:
local = false
default:
return nil, fmt.Errorf("mentioned account format '%s' was not valid", a)
}
var username, domain string
username = s[0]
if !local {
domain = s[1]
}
// 4. check we now have a proper username and domain
if username == "" || (!local && domain == "") {
return nil, fmt.Errorf("username or domain for '%s' was nil", a)
}
// okay we're good now, we can start pulling accounts out of the database
mentionedAccount := &gtsmodel.Account{}
var err error
// match username + account, case insensitive
if local {
// local user -- should have a null domain
err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("? IS NULL", bun.Ident("domain")).Scan(ctx)
} else {
// remote user -- should have domain defined
err = ps.conn.NewSelect().Model(mentionedAccount).Where("LOWER(?) = LOWER(?)", bun.Ident("username"), username).Where("LOWER(?) = LOWER(?)", bun.Ident("domain"), domain).Scan(ctx)
}
if err != nil {
if err == sql.ErrNoRows {
// no result found for this username/domain so just don't include it as a mencho and carry on about our business
ps.log.Debugf("no account found with username '%s' and domain '%s', skipping it", username, domain)
continue
}
// a serious error has happened so bail
return nil, fmt.Errorf("error getting account with username '%s' and domain '%s': %s", username, domain, err)
}
// id, createdAt and updatedAt will be populated by the db, so we have everything we need!
menchies = append(menchies, &gtsmodel.Mention{
StatusID: statusID,
OriginAccountID: ogAccount.ID,
OriginAccountURI: ogAccount.URI,
TargetAccountID: mentionedAccount.ID,
NameString: a,
TargetAccountURI: mentionedAccount.URI,
TargetAccountURL: mentionedAccount.URL,
OriginAccount: mentionedAccount,
})
}
return menchies, nil
}
func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
newTags := []*gtsmodel.Tag{}
for _, t := range tags {
tag := &gtsmodel.Tag{}
// we can use selectorinsert here to create the new tag if it doesn't exist already
// inserted will be true if this is a new tag we just created
if err := ps.conn.NewSelect().Model(tag).Where("LOWER(?) = LOWER(?)", bun.Ident("name"), t).Scan(ctx); err != nil {
if err == sql.ErrNoRows {
// tag doesn't exist yet so populate it
newID, err := id.NewRandomULID()
if err != nil {
return nil, err
}
tag.ID = newID
tag.URL = fmt.Sprintf("%s://%s/tags/%s", ps.config.Protocol, ps.config.Host, t)
tag.Name = t
tag.FirstSeenFromAccountID = originAccountID
tag.CreatedAt = time.Now()
tag.UpdatedAt = time.Now()
tag.Useable = true
tag.Listable = true
} else {
return nil, fmt.Errorf("error getting tag with name %s: %s", t, err)
}
}
// bail already if the tag isn't useable
if !tag.Useable {
continue
}
tag.LastStatusAt = time.Now()
newTags = append(newTags, tag)
}
return newTags, nil
}
func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
newEmojis := []*gtsmodel.Emoji{}
for _, e := range emojis {
emoji := &gtsmodel.Emoji{}
err := ps.conn.NewSelect().Model(emoji).Where("shortcode = ?", e).Where("visible_in_picker = true").Where("disabled = false").Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
// no result found for this username/domain so just don't include it as an emoji and carry on about our business
ps.log.Debugf("no emoji found with shortcode %s, skipping it", e)
continue
}
// a serious error has happened so bail
return nil, fmt.Errorf("error getting emoji with shortcode %s: %s", e, err)
}
newEmojis = append(newEmojis, emoji)
}
return newEmojis, nil
}

View file

@ -0,0 +1,47 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb_test
import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
type BunDBStandardTestSuite struct {
// standard suite interfaces
suite.Suite
config *config.Config
db db.DB
log *logrus.Logger
// standard suite models
testTokens map[string]*oauth.Token
testClients map[string]*oauth.Client
testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account
testAttachments map[string]*gtsmodel.MediaAttachment
testStatuses map[string]*gtsmodel.Status
testTags map[string]*gtsmodel.Tag
testMentions map[string]*gtsmodel.Mention
}

View file

@ -0,0 +1,81 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"net/url"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
type domainDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
}
func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) {
if domain == "" {
return false, nil
}
q := d.conn.
NewSelect().
Model(&gtsmodel.DomainBlock{}).
Where("LOWER(domain) = LOWER(?)", domain).
Limit(1)
return exists(ctx, q)
}
func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {
// filter out any doubles
uniqueDomains := util.UniqueStrings(domains)
for _, domain := range uniqueDomains {
if blocked, err := d.IsDomainBlocked(ctx, domain); err != nil {
return false, err
} else if blocked {
return blocked, nil
}
}
// no blocks found
return false, nil
}
func (d *domainDB) IsURIBlocked(ctx context.Context, uri *url.URL) (bool, db.Error) {
domain := uri.Hostname()
return d.IsDomainBlocked(ctx, domain)
}
func (d *domainDB) AreURIsBlocked(ctx context.Context, uris []*url.URL) (bool, db.Error) {
domains := []string{}
for _, uri := range uris {
domains = append(domains, uri.Hostname())
}
return d.AreDomainsBlocked(ctx, domains)
}

View file

@ -0,0 +1,118 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type instanceDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
}
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
q := i.conn.
NewSelect().
Model(&[]*gtsmodel.Account{})
if domain == i.config.Host {
// if the domain is *this* domain, just count where the domain field is null
q = q.Where("? IS NULL", bun.Ident("domain"))
} else {
q = q.Where("domain = ?", domain)
}
// don't count the instance account or suspended users
q = q.
Where("username != ?", domain).
Where("? IS NULL", bun.Ident("suspended_at"))
count, err := q.Count(ctx)
return count, processErrorResponse(err)
}
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
q := i.conn.
NewSelect().
Model(&[]*gtsmodel.Status{})
if domain == i.config.Host {
// if the domain is *this* domain, just count where local is true
q = q.Where("local = ?", true)
} else {
// join on the domain of the account
q = q.Join("JOIN accounts AS account ON account.id = status.account_id").
Where("account.domain = ?", domain)
}
count, err := q.Count(ctx)
return count, processErrorResponse(err)
}
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
q := i.conn.
NewSelect().
Model(&[]*gtsmodel.Instance{})
if domain == i.config.Host {
// if the domain is *this* domain, just count other instances it knows about
// exclude domains that are blocked
q = q.Where("domain != ?", domain).Where("? IS NULL", bun.Ident("suspended_at"))
} else {
// TODO: implement federated domain counting properly for remote domains
return 0, nil
}
count, err := q.Count(ctx)
return count, processErrorResponse(err)
}
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
i.log.Debug("GetAccountsForInstance")
accounts := []*gtsmodel.Account{}
q := i.conn.NewSelect().
Model(&accounts).
Where("domain = ?", domain).
Order("id DESC")
if maxID != "" {
q = q.Where("id < ?", maxID)
}
if limit > 0 {
q = q.Limit(limit)
}
err := processErrorResponse(q.Scan(ctx))
return accounts, err
}

View file

@ -0,0 +1,53 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type mediaDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
}
func (m *mediaDB) newMediaQ(i interface{}) *bun.SelectQuery {
return m.conn.
NewSelect().
Model(i).
Relation("Account")
}
func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, db.Error) {
attachment := &gtsmodel.MediaAttachment{}
q := m.newMediaQ(attachment).
Where("media_attachment.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
return attachment, err
}

View file

@ -0,0 +1,108 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type mentionDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cache cache.Cache
}
func (m *mentionDB) cacheMention(id string, mention *gtsmodel.Mention) {
if m.cache == nil {
m.cache = cache.New()
}
if err := m.cache.Store(id, mention); err != nil {
m.log.Panicf("mentionDB: error storing in cache: %s", err)
}
}
func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) {
if m.cache == nil {
m.cache = cache.New()
return nil, false
}
mI, err := m.cache.Fetch(id)
if err != nil || mI == nil {
return nil, false
}
mention, ok := mI.(*gtsmodel.Mention)
if !ok {
m.log.Panicf("mentionDB: cached interface with key %s was not a mention", id)
}
return mention, true
}
func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
return m.conn.
NewSelect().
Model(i).
Relation("Status").
Relation("OriginAccount").
Relation("TargetAccount")
}
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
if mention, cached := m.mentionCached(id); cached {
return mention, nil
}
mention := &gtsmodel.Mention{}
q := m.newMentionQ(mention).
Where("mention.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
if err == nil && mention != nil {
m.cacheMention(id, mention)
}
return mention, err
}
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) {
mentions := []*gtsmodel.Mention{}
for _, i := range ids {
mention, err := m.GetMention(ctx, i)
if err != nil {
return nil, processErrorResponse(err)
}
mentions = append(mentions, mention)
}
return mentions, nil
}

View file

@ -0,0 +1,136 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type notificationDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cache cache.Cache
}
func (n *notificationDB) cacheNotification(id string, notification *gtsmodel.Notification) {
if n.cache == nil {
n.cache = cache.New()
}
if err := n.cache.Store(id, notification); err != nil {
n.log.Panicf("notificationDB: error storing in cache: %s", err)
}
}
func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification, bool) {
if n.cache == nil {
n.cache = cache.New()
return nil, false
}
nI, err := n.cache.Fetch(id)
if err != nil || nI == nil {
return nil, false
}
notification, ok := nI.(*gtsmodel.Notification)
if !ok {
n.log.Panicf("notificationDB: cached interface with key %s was not a notification", id)
}
return notification, true
}
func (n *notificationDB) newNotificationQ(i interface{}) *bun.SelectQuery {
return n.conn.
NewSelect().
Model(i).
Relation("OriginAccount").
Relation("TargetAccount").
Relation("Status")
}
func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
if notification, cached := n.notificationCached(id); cached {
return notification, nil
}
notification := &gtsmodel.Notification{}
q := n.newNotificationQ(notification).
Where("notification.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
if err == nil && notification != nil {
n.cacheNotification(id, notification)
}
return notification, err
}
func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
// begin by selecting just the IDs
notifIDs := []*gtsmodel.Notification{}
q := n.conn.
NewSelect().
Model(&notifIDs).
Column("id").
Where("target_account_id = ?", accountID).
Order("id DESC")
if maxID != "" {
q = q.Where("id < ?", maxID)
}
if sinceID != "" {
q = q.Where("id > ?", sinceID)
}
if limit != 0 {
q = q.Limit(limit)
}
err := processErrorResponse(q.Scan(ctx))
if err != nil {
return nil, err
}
// now we have the IDs, select the notifs one by one
// reason for this is that for each notif, we can instead get it from our cache if it's cached
notifications := []*gtsmodel.Notification{}
for _, notifID := range notifIDs {
notif, err := n.GetNotification(ctx, notifID.ID)
errP := processErrorResponse(err)
if errP != nil {
return nil, errP
}
notifications = append(notifications, notif)
}
return notifications, nil
}

View file

@ -0,0 +1,328 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"database/sql"
"fmt"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type relationshipDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
}
func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery {
return r.conn.
NewSelect().
Model(block).
Relation("Account").
Relation("TargetAccount")
}
func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery {
return r.conn.
NewSelect().
Model(follow).
Relation("Account").
Relation("TargetAccount")
}
func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
q := r.conn.
NewSelect().
Model(&gtsmodel.Block{}).
Where("account_id = ?", account1).
Where("target_account_id = ?", account2).
Limit(1)
if eitherDirection {
q = q.
WhereOr("target_account_id = ?", account1).
Where("account_id = ?", account2)
}
return exists(ctx, q)
}
func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
block := &gtsmodel.Block{}
q := r.newBlockQ(block).
Where("block.account_id = ?", account1).
Where("block.target_account_id = ?", account2)
err := processErrorResponse(q.Scan(ctx))
return block, err
}
func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
rel := &gtsmodel.Relationship{
ID: targetAccount,
}
// check if the requesting account follows the target account
follow := &gtsmodel.Follow{}
if err := r.conn.
NewSelect().
Model(follow).
Where("account_id = ?", requestingAccount).
Where("target_account_id = ?", targetAccount).
Limit(1).
Scan(ctx); err != nil {
if err != sql.ErrNoRows {
// a proper error
return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
}
// no follow exists so these are all false
rel.Following = false
rel.ShowingReblogs = false
rel.Notifying = false
} else {
// follow exists so we can fill these fields out...
rel.Following = true
rel.ShowingReblogs = follow.ShowReblogs
rel.Notifying = follow.Notify
}
// check if the target account follows the requesting account
count, err := r.conn.
NewSelect().
Model(&gtsmodel.Follow{}).
Where("account_id = ?", targetAccount).
Where("target_account_id = ?", requestingAccount).
Limit(1).
Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
}
rel.FollowedBy = count > 0
// check if the requesting account blocks the target account
count, err = r.conn.NewSelect().
Model(&gtsmodel.Block{}).
Where("account_id = ?", requestingAccount).
Where("target_account_id = ?", targetAccount).
Limit(1).
Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
}
rel.Blocking = count > 0
// check if the target account blocks the requesting account
count, err = r.conn.
NewSelect().
Model(&gtsmodel.Block{}).
Where("account_id = ?", targetAccount).
Where("target_account_id = ?", requestingAccount).
Limit(1).
Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
rel.BlockedBy = count > 0
// check if there's a pending following request from requesting account to target account
count, err = r.conn.
NewSelect().
Model(&gtsmodel.FollowRequest{}).
Where("account_id = ?", requestingAccount).
Where("target_account_id = ?", targetAccount).
Limit(1).
Count(ctx)
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
rel.Requested = count > 0
return rel, nil
}
func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
q := r.conn.
NewSelect().
Model(&gtsmodel.Follow{}).
Where("account_id = ?", sourceAccount.ID).
Where("target_account_id = ?", targetAccount.ID).
Limit(1)
return exists(ctx, q)
}
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
q := r.conn.
NewSelect().
Model(&gtsmodel.FollowRequest{}).
Where("account_id = ?", sourceAccount.ID).
Where("target_account_id = ?", targetAccount.ID)
return exists(ctx, q)
}
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
if account1 == nil || account2 == nil {
return false, nil
}
// make sure account 1 follows account 2
f1, err := r.IsFollowing(ctx, account1, account2)
if err != nil {
return false, processErrorResponse(err)
}
// make sure account 2 follows account 1
f2, err := r.IsFollowing(ctx, account2, account1)
if err != nil {
return false, processErrorResponse(err)
}
return f1 && f2, nil
}
func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
// make sure the original follow request exists
fr := &gtsmodel.FollowRequest{}
if err := r.conn.
NewSelect().
Model(fr).
Where("account_id = ?", originAccountID).
Where("target_account_id = ?", targetAccountID).
Scan(ctx); err != nil {
return nil, processErrorResponse(err)
}
// create a new follow to 'replace' the request with
follow := &gtsmodel.Follow{
ID: fr.ID,
AccountID: originAccountID,
TargetAccountID: targetAccountID,
URI: fr.URI,
}
// if the follow already exists, just update the URI -- we don't need to do anything else
if _, err := r.conn.
NewInsert().
Model(follow).
On("CONFLICT CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).
Exec(ctx); err != nil {
return nil, processErrorResponse(err)
}
// now remove the follow request
if _, err := r.conn.
NewDelete().
Model(&gtsmodel.FollowRequest{}).
Where("account_id = ?", originAccountID).
Where("target_account_id = ?", targetAccountID).
Exec(ctx); err != nil {
return nil, processErrorResponse(err)
}
return follow, nil
}
func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
followRequests := []*gtsmodel.FollowRequest{}
q := r.newFollowQ(&followRequests).
Where("target_account_id = ?", accountID)
err := processErrorResponse(q.Scan(ctx))
return followRequests, err
}
func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, db.Error) {
follows := []*gtsmodel.Follow{}
q := r.newFollowQ(&follows).
Where("account_id = ?", accountID)
err := processErrorResponse(q.Scan(ctx))
return follows, err
}
func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
return r.conn.
NewSelect().
Model(&[]*gtsmodel.Follow{}).
Where("account_id = ?", accountID).
Count(ctx)
}
func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
follows := []*gtsmodel.Follow{}
q := r.conn.
NewSelect().
Model(&follows)
if localOnly {
// for local accounts let's get where domain is null OR where domain is an empty string, just to be safe
whereGroup := func(q *bun.SelectQuery) *bun.SelectQuery {
q = q.
WhereOr("? IS NULL", bun.Ident("a.domain")).
WhereOr("a.domain = ?", "")
return q
}
q = q.ColumnExpr("follow.*").
Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)").
Where("follow.target_account_id = ?", accountID).
WhereGroup(" AND ", whereGroup)
} else {
q = q.Where("target_account_id = ?", accountID)
}
if err := q.Scan(ctx); err != nil {
if err == sql.ErrNoRows {
return follows, nil
}
return nil, processErrorResponse(err)
}
return follows, nil
}
func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) {
return r.conn.
NewSelect().
Model(&[]*gtsmodel.Follow{}).
Where("target_account_id = ?", accountID).
Count(ctx)
}

View file

@ -0,0 +1,85 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"crypto/rand"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/uptrace/bun"
)
type sessionDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
}
func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
rs := new(gtsmodel.RouterSession)
q := s.conn.
NewSelect().
Model(rs).
Limit(1)
_, err := q.Exec(ctx)
err = processErrorResponse(err)
return rs, err
}
func (s *sessionDB) CreateSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
auth := make([]byte, 32)
crypt := make([]byte, 32)
if _, err := rand.Read(auth); err != nil {
return nil, err
}
if _, err := rand.Read(crypt); err != nil {
return nil, err
}
rid, err := id.NewULID()
if err != nil {
return nil, err
}
rs := &gtsmodel.RouterSession{
ID: rid,
Auth: auth,
Crypt: crypt,
}
q := s.conn.
NewInsert().
Model(rs)
_, err = q.Exec(ctx)
err = processErrorResponse(err)
return rs, err
}

375
internal/db/bundb/status.go Normal file
View file

@ -0,0 +1,375 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"container/list"
"context"
"errors"
"time"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type statusDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cache cache.Cache
}
func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
if s.cache == nil {
s.cache = cache.New()
}
if err := s.cache.Store(id, status); err != nil {
s.log.Panicf("statusDB: error storing in cache: %s", err)
}
}
func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
if s.cache == nil {
s.cache = cache.New()
return nil, false
}
sI, err := s.cache.Fetch(id)
if err != nil || sI == nil {
return nil, false
}
status, ok := sI.(*gtsmodel.Status)
if !ok {
s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
}
return status, true
}
func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
return s.conn.
NewSelect().
Model(status).
Relation("Attachments").
Relation("Tags").
Relation("Mentions").
Relation("Emojis").
Relation("Account").
Relation("InReplyToAccount").
Relation("BoostOfAccount").
Relation("CreatedWithApplication")
}
func (s *statusDB) getAttachedStatuses(ctx context.Context, status *gtsmodel.Status) *gtsmodel.Status {
if status.InReplyToID != "" && status.InReplyTo == nil {
if inReplyTo, cached := s.statusCached(status.InReplyToID); cached {
status.InReplyTo = inReplyTo
} else if inReplyTo, err := s.GetStatusByID(ctx, status.InReplyToID); err == nil {
status.InReplyTo = inReplyTo
}
}
if status.BoostOfID != "" && status.BoostOf == nil {
if boostOf, cached := s.statusCached(status.BoostOfID); cached {
status.BoostOf = boostOf
} else if boostOf, err := s.GetStatusByID(ctx, status.BoostOfID); err == nil {
status.BoostOf = boostOf
}
}
return status
}
func (s *statusDB) newFaveQ(faves interface{}) *bun.SelectQuery {
return s.conn.
NewSelect().
Model(faves).
Relation("Account").
Relation("TargetAccount").
Relation("Status")
}
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(id); cached {
return status, nil
}
status := new(gtsmodel.Status)
q := s.newStatusQ(status).
Where("status.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
if err != nil {
return nil, err
}
if status != nil {
s.cacheStatus(id, status)
}
return s.getAttachedStatuses(ctx, status), err
}
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("LOWER(status.uri) = LOWER(?)", uri)
err := processErrorResponse(q.Scan(ctx))
if err != nil {
return nil, err
}
if status != nil {
s.cacheStatus(uri, status)
}
return s.getAttachedStatuses(ctx, status), err
}
func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("LOWER(status.url) = LOWER(?)", uri)
err := processErrorResponse(q.Scan(ctx))
if err != nil {
return nil, err
}
if status != nil {
s.cacheStatus(uri, status)
}
return s.getAttachedStatuses(ctx, status), err
}
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
transaction := func(ctx context.Context, tx bun.Tx) error {
// create links between this status and any emojis it uses
for _, i := range status.EmojiIDs {
if _, err := tx.NewInsert().Model(&gtsmodel.StatusToEmoji{
StatusID: status.ID,
EmojiID: i,
}).Exec(ctx); err != nil {
return err
}
}
// create links between this status and any tags it uses
for _, i := range status.TagIDs {
if _, err := tx.NewInsert().Model(&gtsmodel.StatusToTag{
StatusID: status.ID,
TagID: i,
}).Exec(ctx); err != nil {
return err
}
}
// change the status ID of the media attachments to the new status
for _, a := range status.Attachments {
a.StatusID = status.ID
a.UpdatedAt = time.Now()
if _, err := s.conn.NewUpdate().Model(a).
Where("id = ?", a.ID).
Exec(ctx); err != nil {
return err
}
}
_, err := tx.NewInsert().Model(status).Exec(ctx)
return err
}
return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction))
}
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
parents := []*gtsmodel.Status{}
s.statusParent(ctx, status, &parents, onlyDirect)
return parents, nil
}
func (s *statusDB) statusParent(ctx context.Context, status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
if status.InReplyToID == "" {
return
}
parentStatus, err := s.GetStatusByID(ctx, status.InReplyToID)
if err == nil {
*foundStatuses = append(*foundStatuses, parentStatus)
}
if onlyDirect {
return
}
s.statusParent(ctx, parentStatus, foundStatuses, false)
}
func (s *statusDB) GetStatusChildren(ctx context.Context, status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
foundStatuses := &list.List{}
foundStatuses.PushFront(status)
s.statusChildren(ctx, status, foundStatuses, onlyDirect, minID)
children := []*gtsmodel.Status{}
for e := foundStatuses.Front(); e != nil; e = e.Next() {
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
}
// only append children, not the overall parent status
if entry.ID != status.ID {
children = append(children, entry)
}
}
return children, nil
}
func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
immediateChildren := []*gtsmodel.Status{}
q := s.conn.
NewSelect().
Model(&immediateChildren).
Where("in_reply_to_id = ?", status.ID)
if minID != "" {
q = q.Where("status.id > ?", minID)
}
if err := q.Scan(ctx); err != nil {
return
}
for _, child := range immediateChildren {
insertLoop:
for e := foundStatuses.Front(); e != nil; e = e.Next() {
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
}
if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
foundStatuses.InsertAfter(child, e)
break insertLoop
}
}
// only do one loop if we only want direct children
if onlyDirect {
return
}
s.statusChildren(ctx, child, foundStatuses, false, minID)
}
}
func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
return s.conn.NewSelect().Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count(ctx)
}
func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
return s.conn.NewSelect().Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count(ctx)
}
func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) {
return s.conn.NewSelect().Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count(ctx)
}
func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
NewSelect().
Model(&gtsmodel.StatusFave{}).
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
}
func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
NewSelect().
Model(&gtsmodel.Status{}).
Where("boost_of_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
}
func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
NewSelect().
Model(&gtsmodel.StatusMute{}).
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
}
func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
q := s.conn.
NewSelect().
Model(&gtsmodel.StatusBookmark{}).
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
}
func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
faves := []*gtsmodel.StatusFave{}
q := s.newFaveQ(&faves).
Where("status_id = ?", status.ID)
err := processErrorResponse(q.Scan(ctx))
return faves, err
}
func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
reblogs := []*gtsmodel.Status{}
q := s.newStatusQ(&reblogs).
Where("boost_of_id = ?", status.ID)
err := processErrorResponse(q.Scan(ctx))
return reblogs, err
}

View file

@ -0,0 +1,136 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb_test
import (
"context"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type StatusTestSuite struct {
BunDBStandardTestSuite
}
func (suite *StatusTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts()
suite.testAttachments = testrig.NewTestAttachments()
suite.testStatuses = testrig.NewTestStatuses()
suite.testTags = testrig.NewTestTags()
suite.testMentions = testrig.NewTestMentions()
}
func (suite *StatusTestSuite) SetupTest() {
suite.config = testrig.NewTestConfig()
suite.db = testrig.NewTestDB()
suite.log = testrig.NewTestLog()
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
func (suite *StatusTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
}
func (suite *StatusTestSuite) TestGetStatusByID() {
status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_1_status_1"].ID)
if err != nil {
fmt.Println(err.Error())
suite.FailNow(err.Error())
}
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.Nil(status.BoostOf)
suite.Nil(status.BoostOfAccount)
suite.Nil(status.InReplyTo)
suite.Nil(status.InReplyToAccount)
}
func (suite *StatusTestSuite) TestGetStatusByURI() {
status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.Nil(status.BoostOf)
suite.Nil(status.BoostOfAccount)
suite.Nil(status.InReplyTo)
suite.Nil(status.InReplyToAccount)
}
func (suite *StatusTestSuite) TestGetStatusWithExtras() {
status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["admin_account_status_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.NotEmpty(status.Tags)
suite.NotEmpty(status.Attachments)
suite.NotEmpty(status.Emojis)
}
func (suite *StatusTestSuite) TestGetStatusWithMention() {
status, err := suite.db.GetStatusByID(context.Background(), suite.testStatuses["local_account_2_status_5"].ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.NotEmpty(status.Mentions)
suite.NotEmpty(status.MentionIDs)
suite.NotNil(status.InReplyTo)
suite.NotNil(status.InReplyToAccount)
}
func (suite *StatusTestSuite) TestGetStatusTwice() {
before1 := time.Now()
_, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
suite.NoError(err)
after1 := time.Now()
duration1 := after1.Sub(before1)
fmt.Println(duration1.Milliseconds())
before2 := time.Now()
_, err = suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI)
suite.NoError(err)
after2 := time.Now()
duration2 := after2.Sub(before2)
fmt.Println(duration2.Milliseconds())
// second retrieval should be several orders faster since it will be cached now
suite.Less(duration2, duration1)
}
func TestStatusTestSuite(t *testing.T) {
suite.Run(t, new(StatusTestSuite))
}

View file

@ -0,0 +1,199 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"database/sql"
"sort"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type timelineDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
}
func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
q := t.conn.
NewSelect().
Model(&statuses)
q = q.ColumnExpr("status.*").
// Find out who accountID follows.
Join("LEFT JOIN follows AS f ON f.target_account_id = status.account_id").
// Sort by highest ID (newest) to lowest ID (oldest)
Order("status.id DESC")
if maxID != "" {
// return only statuses LOWER (ie., older) than maxID
q = q.Where("status.id < ?", maxID)
}
if sinceID != "" {
// return only statuses HIGHER (ie., newer) than sinceID
q = q.Where("status.id > ?", sinceID)
}
if minID != "" {
// return only statuses HIGHER (ie., newer) than minID
q = q.Where("status.id > ?", minID)
}
if local {
// return only statuses posted by local account havers
q = q.Where("status.local = ?", local)
}
if limit > 0 {
// limit amount of statuses returned
q = q.Limit(limit)
}
// Use a WhereGroup here to specify that we want EITHER statuses posted by accounts that accountID follows,
// OR statuses posted by accountID itself (since a user should be able to see their own statuses).
//
// This is equivalent to something like WHERE ... AND (... OR ...)
// See: https://bun.uptrace.dev/guide/queries.html#select
whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
return q.
WhereOr("f.account_id = ?", accountID).
WhereOr("status.account_id = ?", accountID)
}
q = q.WhereGroup(" AND ", whereGroup)
return statuses, processErrorResponse(q.Scan(ctx))
}
func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
q := t.conn.
NewSelect().
Model(&statuses).
Where("visibility = ?", gtsmodel.VisibilityPublic).
Where("? IS NULL", bun.Ident("in_reply_to_id")).
Where("? IS NULL", bun.Ident("in_reply_to_uri")).
Where("? IS NULL", bun.Ident("boost_of_id")).
Order("status.id DESC")
if maxID != "" {
q = q.Where("status.id < ?", maxID)
}
if sinceID != "" {
q = q.Where("status.id > ?", sinceID)
}
if minID != "" {
q = q.Where("status.id > ?", minID)
}
if local {
q = q.Where("status.local = ?", local)
}
if limit > 0 {
q = q.Limit(limit)
}
return statuses, processErrorResponse(q.Scan(ctx))
}
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
faves := []*gtsmodel.StatusFave{}
fq := t.conn.
NewSelect().
Model(&faves).
Where("account_id = ?", accountID).
Order("id DESC")
if maxID != "" {
fq = fq.Where("id < ?", maxID)
}
if minID != "" {
fq = fq.Where("id > ?", minID)
}
if limit > 0 {
fq = fq.Limit(limit)
}
err := fq.Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
}
if len(faves) == 0 {
return nil, "", "", db.ErrNoEntries
}
// map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID
statusesFavesMap := map[string]string{}
in := []string{}
for _, f := range faves {
statusesFavesMap[f.StatusID] = f.ID
in = append(in, f.StatusID)
}
statuses := []*gtsmodel.Status{}
err = t.conn.
NewSelect().
Model(&statuses).
Where("id IN (?)", bun.In(in)).
Scan(ctx)
if err != nil {
if err == sql.ErrNoRows {
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
}
if len(statuses) == 0 {
return nil, "", "", db.ErrNoEntries
}
// arrange statuses by fave ID
sort.Slice(statuses, func(i int, j int) bool {
statusI := statuses[i]
statusJ := statuses[j]
return statusesFavesMap[statusI.ID] < statusesFavesMap[statusJ.ID]
})
nextMaxID := faves[len(faves)-1].ID
prevMinID := faves[0].ID
return statuses, nextMaxID, prevMinID, nil
}

78
internal/db/bundb/util.go Normal file
View file

@ -0,0 +1,78 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"strings"
"database/sql"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/uptrace/bun"
)
// processErrorResponse parses the given error and returns an appropriate DBError.
func processErrorResponse(err error) db.Error {
switch err {
case nil:
return nil
case sql.ErrNoRows:
return db.ErrNoEntries
default:
if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
}
}
func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
count, err := q.Count(ctx)
exists := count != 0
err = processErrorResponse(err)
if err != nil {
if err == db.ErrNoEntries {
return false, nil
}
return false, err
}
return exists, nil
}
func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
count, err := q.Count(ctx)
notExists := count == 0
err = processErrorResponse(err)
if err != nil {
if err == db.ErrNoEntries {
return true, nil
}
return false, err
}
return notExists, nil
}