more big changes

This commit is contained in:
tsmethurst 2021-08-25 13:36:54 +02:00
commit 4e054233da
71 changed files with 640 additions and 405 deletions

View file

@ -36,6 +36,9 @@ type Account interface {
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// UpdateAccount updates one account by ID.
UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
// GetLocalAccountByUsername returns an account on this instance by its username.
GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, Error)

View file

@ -16,12 +16,13 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/sirupsen/logrus"
@ -35,7 +36,6 @@ type accountDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
@ -79,6 +79,25 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.
return account, err
}
func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
if strings.TrimSpace(account.ID) == "" {
return nil, errors.New("account had no ID")
}
account.UpdatedAt = time.Now()
q := a.conn.
NewUpdate().
Model(account).
WherePK()
_, err := q.Exec(ctx)
err = processErrorResponse(err)
return account, err
}
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)

View file

@ -16,18 +16,19 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
package bundb_test
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type AccountTestSuite struct {
PGStandardTestSuite
BunDBStandardTestSuite
}
func (suite *AccountTestSuite) SetupSuite() {
@ -66,6 +67,20 @@ func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
suite.NotEmpty(account.HeaderMediaAttachment.URL)
}
func (suite *AccountTestSuite) TestUpdateAccount() {
testAccount := suite.testAccounts["local_account_1"]
testAccount.DisplayName = "new display name!"
_, err := suite.db.UpdateAccount(context.Background(), testAccount)
suite.NoError(err)
updated, err := suite.db.GetAccountByID(context.Background(), testAccount.ID)
suite.NoError(err)
suite.Equal("new display name!", updated.DisplayName)
suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second)
}
func TestAccountTestSuite(t *testing.T) {
suite.Run(t, new(AccountTestSuite))
}

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -43,7 +43,6 @@ type adminDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -33,7 +33,6 @@ type basicDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error {
@ -116,7 +115,7 @@ func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.E
q := b.conn.
NewUpdate().
Model(i).
Where("id = ?", id)
WherePK()
_, err := q.Exec(ctx)
@ -127,7 +126,7 @@ func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, valu
q := b.conn.NewUpdate().
Model(i).
Set("? = ?", bun.Safe(key), value).
Where("id = ?", id)
WherePK()
_, err := q.Exec(ctx)
@ -174,7 +173,6 @@ func (b *basicDB) Stop(ctx context.Context) db.Error {
b.log.Info("closing db connection")
if err := b.conn.Close(); err != nil {
// only cancel if there's a problem closing the db
b.cancel()
return err
}
return nil

View file

@ -0,0 +1,68 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb_test
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type BasicTestSuite struct {
BunDBStandardTestSuite
}
func (suite *BasicTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts()
suite.testAttachments = testrig.NewTestAttachments()
suite.testStatuses = testrig.NewTestStatuses()
suite.testTags = testrig.NewTestTags()
suite.testMentions = testrig.NewTestMentions()
}
func (suite *BasicTestSuite) SetupTest() {
suite.config = testrig.NewTestConfig()
suite.db = testrig.NewTestDB()
suite.log = testrig.NewTestLog()
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
func (suite *BasicTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
}
func (suite *BasicTestSuite) TestGetAccountByID() {
testAccount := suite.testAccounts["local_account_1"]
a := &gtsmodel.Account{}
err := suite.db.GetByID(context.Background(), testAccount.ID, a)
suite.NoError(err)
}
func TestBasicTestSuite(t *testing.T) {
suite.Run(t, new(BasicTestSuite))
}

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -41,13 +41,18 @@ import (
"github.com/uptrace/bun/dialect/pgdialect"
)
const (
dbTypePostgres = "postgres"
dbTypeSqlite = "sqlite"
)
var registerTables []interface{} = []interface{}{
&gtsmodel.StatusToEmoji{},
&gtsmodel.StatusToTag{},
}
// postgresService satisfies the DB interface
type postgresService struct {
// bunDBService satisfies the DB interface
type bunDBService struct {
db.Account
db.Admin
db.Basic
@ -57,37 +62,49 @@ type postgresService struct {
db.Mention
db.Notification
db.Relationship
db.Session
db.Status
db.Timeline
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
var sqldb *sql.DB
var conn *bun.DB
opts, err := derivePGOptions(c)
if err != nil {
return nil, fmt.Errorf("could not create postgres options: %s", err)
// depending on the database type we're trying to create, we need to use a different driver...
switch strings.ToLower(c.DBConfig.Type) {
case dbTypePostgres:
// POSTGRES
opts, err := deriveBunDBPGOptions(c)
if err != nil {
return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
}
sqldb = stdlib.OpenDB(*opts)
conn = bun.NewDB(sqldb, pgdialect.New())
case dbTypeSqlite:
// SQLITE
// TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite
default:
return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type))
}
sqldb := stdlib.OpenDB(*opts)
conn := bun.NewDB(sqldb, pgdialect.New())
// actually *begin* the connection so that we can tell if the db is there and listening
if err := conn.Ping(); err != nil {
return nil, fmt.Errorf("db connection error: %s", err)
}
log.Info("connected to postgres")
log.Info("connected to database")
for _, t := range registerTables {
// https://bun.uptrace.dev/orm/many-to-many-relation/
conn.RegisterModel(t)
}
ps := &postgresService{
ps := &bunDBService{
Account: &accountDB{
config: c,
conn: conn,
@ -133,6 +150,11 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
conn: conn,
log: log,
},
Session: &sessionDB{
config: c,
conn: conn,
log: log,
},
Status: &statusDB{
config: c,
conn: conn,
@ -148,7 +170,7 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
log: log,
}
// we can confidently return this useable postgres service now
// we can confidently return this useable service now
return ps, nil
}
@ -156,9 +178,9 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
HANDY STUFF
*/
// derivePGOptions takes an application config and returns either a ready-to-use set of options
// deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options
// with sensible defaults, or an error if it's not satisfied by the provided config.
func derivePGOptions(c *config.Config) (*pgx.ConnConfig, error) {
func deriveBunDBPGOptions(c *config.Config) (*pgx.ConnConfig, error) {
if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres {
return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type)
}
@ -236,15 +258,16 @@ func derivePGOptions(c *config.Config) (*pgx.ConnConfig, error) {
tlsConfig.RootCAs = certPool
}
opts, _ := pgx.ParseConfig("")
opts.Host = c.DBConfig.Address
opts.Port = uint16(c.DBConfig.Port)
opts.User = c.DBConfig.User
opts.Password = c.DBConfig.Password
opts.TLSConfig = tlsConfig
opts.PreferSimpleProtocol = true
cfg, _ := pgx.ParseConfig("")
cfg.Host = c.DBConfig.Address
cfg.Port = uint16(c.DBConfig.Port)
cfg.User = c.DBConfig.User
cfg.Password = c.DBConfig.Password
cfg.TLSConfig = tlsConfig
cfg.Database = c.DBConfig.Database
cfg.PreferSimpleProtocol = true
return opts, nil
return cfg, nil
}
/*
@ -253,7 +276,7 @@ func derivePGOptions(c *config.Config) (*pgx.ConnConfig, error) {
// TODO: move these to the type converter, it's bananas that they're here and not there
func (ps *postgresService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
ogAccount := &gtsmodel.Account{}
if err := ps.conn.NewSelect().Model(ogAccount).Where("id = ?", originAccountID).Scan(ctx); err != nil {
return nil, err
@ -331,7 +354,7 @@ func (ps *postgresService) MentionStringsToMentions(ctx context.Context, targetA
return menchies, nil
}
func (ps *postgresService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
newTags := []*gtsmodel.Tag{}
for _, t := range tags {
tag := &gtsmodel.Tag{}
@ -367,7 +390,7 @@ func (ps *postgresService) TagStringsToTags(ctx context.Context, tags []string,
return newTags, nil
}
func (ps *postgresService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
newEmojis := []*gtsmodel.Emoji{}
for _, e := range emojis {
emoji := &gtsmodel.Emoji{}

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
package bundb_test
import (
"github.com/sirupsen/logrus"
@ -27,7 +27,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
type PGStandardTestSuite struct {
type BunDBStandardTestSuite struct {
// standard suite interfaces
suite.Suite
config *config.Config

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -34,7 +34,6 @@ type domainDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) {

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -32,7 +32,6 @@ type instanceDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -32,7 +32,6 @@ type mediaDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (m *mediaDB) newMediaQ(i interface{}) *bun.SelectQuery {

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -33,7 +33,6 @@ type mentionDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
}

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -33,7 +33,6 @@ type notificationDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
}

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -34,7 +34,6 @@ type relationshipDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery {

View file

@ -0,0 +1,85 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"crypto/rand"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/uptrace/bun"
)
type sessionDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
}
func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
rs := new(gtsmodel.RouterSession)
q := s.conn.
NewSelect().
Model(rs).
Limit(1)
_, err := q.Exec(ctx)
err = processErrorResponse(err)
return rs, err
}
func (s *sessionDB) CreateSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
auth := make([]byte, 32)
crypt := make([]byte, 32)
if _, err := rand.Read(auth); err != nil {
return nil, err
}
if _, err := rand.Read(crypt); err != nil {
return nil, err
}
rid, err := id.NewULID()
if err != nil {
return nil, err
}
rs := &gtsmodel.RouterSession{
ID: rid,
Auth: auth,
Crypt: crypt,
}
q := s.conn.
NewInsert().
Model(rs)
_, err = q.Exec(ctx)
err = processErrorResponse(err)
return rs, err
}

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"container/list"
@ -36,7 +36,6 @@ type statusDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
}

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
package bundb_test
import (
"context"
@ -29,7 +29,7 @@ import (
)
type StatusTestSuite struct {
PGStandardTestSuite
BunDBStandardTestSuite
}
func (suite *StatusTestSuite) SetupSuite() {

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -34,7 +34,6 @@ type timelineDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
@ -78,7 +77,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
// OR statuses posted by accountID itself (since a user should be able to see their own statuses).
//
// This is equivalent to something like WHERE ... AND (... OR ...)
// See: https://pg.uptrace.dev/queries/#select
// See: https://bun.uptrace.dev/guide/queries.html#select
whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
return q.
WhereOr("f.account_id = ?", accountID).

View file

@ -16,10 +16,11 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
"strings"
"database/sql"
@ -35,6 +36,9 @@ func processErrorResponse(err error) db.Error {
case sql.ErrNoRows:
return db.ErrNoEntries
default:
if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
}
}

View file

@ -40,6 +40,7 @@ type DB interface {
Mention
Notification
Relationship
Session
Status
Timeline

31
internal/db/session.go Normal file
View file

@ -0,0 +1,31 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Session handles getting/creation of router sessions.
type Session interface {
GetSession(ctx context.Context) (*gtsmodel.RouterSession, Error)
CreateSession(ctx context.Context) (*gtsmodel.RouterSession, Error)
}