more big changes

This commit is contained in:
tsmethurst 2021-08-25 13:36:54 +02:00
commit 4e054233da
71 changed files with 640 additions and 405 deletions

8
go.sum
View file

@ -446,10 +446,10 @@ github.com/ugorji/go v1.2.6/go.mod h1:anCg0y61KIhDlPZmnH+so+RQbysYVyDko0IMgJv0Nn
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
github.com/ugorji/go/codec v1.2.6 h1:7kbGefxLoDBuYXOms4yD7223OpNMMPNPZxXk5TvFcyQ=
github.com/ugorji/go/codec v1.2.6/go.mod h1:V6TCNZ4PHqoHGFZuSG1W8nrCzzdgA2DozYxWFFpvxTw=
github.com/uptrace/bun v1.0.0-rc.1 h1:yMjz/JdlgYRema5mk59URzjAV7HpOV58Fj3KID34K9U=
github.com/uptrace/bun v1.0.0-rc.1/go.mod h1:aL6D9vPw8DXaTQTwGrEPtUderBYXx7ShUmPfnxnqscw=
github.com/uptrace/bun/dialect/pgdialect v1.0.0-rc.1 h1:4rxcO4+x8r2xyCnduH9fhhkjhbtzLHFl3vmx2goVgio=
github.com/uptrace/bun/dialect/pgdialect v1.0.0-rc.1/go.mod h1:x48KjNeAGTSq4WNIOvtfzAOFjPG/V/Gx+SdwO5ksimU=
github.com/uptrace/bun v0.4.3 h1:x6bjDqwjxwM/9Q1eauhkznuvTrz/rLiCK2p4tT63sAE=
github.com/uptrace/bun v0.4.3/go.mod h1:aL6D9vPw8DXaTQTwGrEPtUderBYXx7ShUmPfnxnqscw=
github.com/uptrace/bun/dialect/pgdialect v0.4.3 h1:lM2IUKpU99110chKkupw3oTfXiOKpB0hTJIe6frqQDo=
github.com/uptrace/bun/dialect/pgdialect v0.4.3/go.mod h1:BaNvWejl32oKUhwpFkw/eNcWldzIlVY4nfw/sNul0s8=
github.com/urfave/cli/v2 v2.3.0 h1:qph92Y649prgesehzOrQjdWyxFOp/QVM+6imKHad91M=
github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=

View file

@ -79,7 +79,7 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandler()
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"])
ctx.Set(oauth.SessionAuthorizedToken, oauth.TokenToOauthToken(suite.testTokens["local_account_1"]))
ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"]))
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"])
ctx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:8080/%s", account.UpdateCredentialsPath), bytes.NewReader(requestBody.Bytes())) // the endpoint we're hitting

View file

@ -28,7 +28,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/pg"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"golang.org/x/crypto/bcrypt"
@ -104,7 +104,7 @@ func (suite *AuthTestSuite) SetupTest() {
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
db, err := pg.NewPostgresService(context.Background(), suite.config, log)
db, err := bundb.NewBunDBService(context.Background(), suite.config, log)
if err != nil {
logrus.Panicf("error creating database connection: %s", err)
}

View file

@ -80,20 +80,15 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
}
// we can also use the userid of the user to fetch their username from the db to greet them nicely <3
user := &gtsmodel.User{
ID: userID,
}
user := &gtsmodel.User{}
if err := m.db.GetByID(c.Request.Context(), user.ID, user); err != nil {
m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
acct := &gtsmodel.Account{
ID: user.AccountID,
}
if err := m.db.GetByID(c.Request.Context(), acct.ID, acct); err != nil {
acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
if err != nil {
m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return

View file

@ -56,8 +56,8 @@ func (m *Module) OauthTokenMiddleware(c *gin.Context) {
c.Set(oauth.SessionAuthorizedUser, user)
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user)
acct := &gtsmodel.Account{}
if err := m.db.GetByID(c.Request.Context(), user.AccountID, acct); err != nil || acct == nil {
acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
if err != nil || acct == nil {
l.Warnf("no account found for validated user %s", uid)
return
}

View file

@ -121,7 +121,7 @@ func (suite *MediaCreateTestSuite) TestStatusCreatePOSTImageHandlerSuccessful()
// set up the context for the request
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])

View file

@ -67,7 +67,7 @@ func (suite *StatusBoostTestSuite) TearDownTest() {
func (suite *StatusBoostTestSuite) TestPostBoost() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["admin_account_status_1"]
@ -133,7 +133,7 @@ func (suite *StatusBoostTestSuite) TestPostBoost() {
func (suite *StatusBoostTestSuite) TestPostUnboostable() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["local_account_2_status_4"]
@ -171,7 +171,7 @@ func (suite *StatusBoostTestSuite) TestPostUnboostable() {
func (suite *StatusBoostTestSuite) TestPostNotVisible() {
t := suite.testTokens["local_account_2"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["local_account_1_status_3"] // this is a mutual only status and these accounts aren't mutuals

View file

@ -83,7 +83,7 @@ https://docs.gotosocial.org/en/latest/user_guide/posts/#links
func (suite *StatusCreateTestSuite) TestPostNewStatus() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@ -137,7 +137,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatus() {
func (suite *StatusCreateTestSuite) TestPostAnotherNewStatus() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@ -172,7 +172,7 @@ func (suite *StatusCreateTestSuite) TestPostAnotherNewStatus() {
func (suite *StatusCreateTestSuite) TestPostNewStatusWithEmoji() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@ -213,7 +213,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatusWithEmoji() {
// Try to reply to a status that doesn't exist
func (suite *StatusCreateTestSuite) TestReplyToNonexistentStatus() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@ -244,7 +244,7 @@ func (suite *StatusCreateTestSuite) TestReplyToNonexistentStatus() {
// Post a reply to the status of a local user that allows replies.
func (suite *StatusCreateTestSuite) TestReplyToLocalStatus() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@ -284,7 +284,7 @@ func (suite *StatusCreateTestSuite) TestReplyToLocalStatus() {
// Take a media file which is currently not associated with a status, and attach it to a new status.
func (suite *StatusCreateTestSuite) TestAttachNewMediaSuccess() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
attachment := suite.testAttachments["local_account_1_unattached_1"]
@ -323,8 +323,7 @@ func (suite *StatusCreateTestSuite) TestAttachNewMediaSuccess() {
assert.Len(suite.T(), statusResponse.MediaAttachments, 1)
// get the updated media attachment from the database
gtsAttachment := &gtsmodel.MediaAttachment{}
err = suite.db.GetByID(context.Background(), statusResponse.MediaAttachments[0].ID, gtsAttachment)
gtsAttachment, err := suite.db.GetAttachmentByID(context.Background(), statusResponse.MediaAttachments[0].ID)
assert.NoError(suite.T(), err)
// convert it to a masto attachment

View file

@ -71,7 +71,7 @@ func (suite *StatusFaveTestSuite) TearDownTest() {
func (suite *StatusFaveTestSuite) TestPostFave() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["admin_account_status_2"]
@ -119,7 +119,7 @@ func (suite *StatusFaveTestSuite) TestPostFave() {
func (suite *StatusFaveTestSuite) TestPostUnfaveable() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["local_account_2_status_3"] // this one is unlikeable and unreplyable

View file

@ -69,7 +69,7 @@ func (suite *StatusFavedByTestSuite) TearDownTest() {
func (suite *StatusFavedByTestSuite) TestGetFavedBy() {
t := suite.testTokens["local_account_2"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["admin_account_status_1"] // this status is faved by local_account_1

View file

@ -71,7 +71,7 @@ func (suite *StatusUnfaveTestSuite) TearDownTest() {
func (suite *StatusUnfaveTestSuite) TestPostUnfave() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
// this is the status we wanna unfave: in the testrig it's already faved by this account
targetStatus := suite.testStatuses["admin_account_status_1"]
@ -120,7 +120,7 @@ func (suite *StatusUnfaveTestSuite) TestPostUnfave() {
func (suite *StatusUnfaveTestSuite) TestPostAlreadyNotFaved() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
// this is the status we wanna unfave: in the testrig it's not faved by this account
targetStatus := suite.testStatuses["admin_account_status_2"]

View file

@ -37,7 +37,7 @@ type cache struct {
// New returns a new in-memory cache.
func New() Cache {
c := ttlcache.NewCache()
c.SetTTL(30 * time.Second)
c.SetTTL(5 * time.Minute)
cache := &cache{
c: c,
}

View file

@ -28,7 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/cliactions"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/pg"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
"golang.org/x/crypto/bcrypt"
@ -36,7 +36,7 @@ import (
// Create creates a new account in the database using the provided flags.
var Create cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
dbConn, err := pg.NewPostgresService(ctx, c, log)
dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
@ -75,7 +75,7 @@ var Create cliactions.GTSAction = func(ctx context.Context, c *config.Config, lo
// Confirm sets a user to Approved, sets Email to the current UnconfirmedEmail value, and sets ConfirmedAt to now.
var Confirm cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
dbConn, err := pg.NewPostgresService(ctx, c, log)
dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
@ -110,7 +110,7 @@ var Confirm cliactions.GTSAction = func(ctx context.Context, c *config.Config, l
// Promote sets a user to admin.
var Promote cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
dbConn, err := pg.NewPostgresService(ctx, c, log)
dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
@ -142,7 +142,7 @@ var Promote cliactions.GTSAction = func(ctx context.Context, c *config.Config, l
// Demote sets admin on a user to false.
var Demote cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
dbConn, err := pg.NewPostgresService(ctx, c, log)
dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
@ -174,7 +174,7 @@ var Demote cliactions.GTSAction = func(ctx context.Context, c *config.Config, lo
// Disable sets Disabled to true on a user.
var Disable cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
dbConn, err := pg.NewPostgresService(ctx, c, log)
dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
@ -212,7 +212,7 @@ var Suspend cliactions.GTSAction = func(ctx context.Context, c *config.Config, l
// Password sets the password of target account.
var Password cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
dbConn, err := pg.NewPostgresService(ctx, c, log)
dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}

View file

@ -35,7 +35,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/blob"
"github.com/superseriousbusiness/gotosocial/internal/cliactions"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db/pg"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
"github.com/superseriousbusiness/gotosocial/internal/gotosocial"
@ -79,7 +79,7 @@ var models []interface{} = []interface{}{
// Start creates and starts a gotosocial server
var Start cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
dbService, err := pg.NewPostgresService(ctx, c, log)
dbService, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}

View file

@ -36,6 +36,9 @@ type Account interface {
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error)
// UpdateAccount updates one account by ID.
UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
// GetLocalAccountByUsername returns an account on this instance by its username.
GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, Error)

View file

@ -16,12 +16,13 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/sirupsen/logrus"
@ -35,7 +36,6 @@ type accountDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
@ -79,6 +79,25 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.
return account, err
}
func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
if strings.TrimSpace(account.ID) == "" {
return nil, errors.New("account had no ID")
}
account.UpdatedAt = time.Now()
q := a.conn.
NewUpdate().
Model(account).
WherePK()
_, err := q.Exec(ctx)
err = processErrorResponse(err)
return account, err
}
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)

View file

@ -16,18 +16,19 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
package bundb_test
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type AccountTestSuite struct {
PGStandardTestSuite
BunDBStandardTestSuite
}
func (suite *AccountTestSuite) SetupSuite() {
@ -66,6 +67,20 @@ func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
suite.NotEmpty(account.HeaderMediaAttachment.URL)
}
func (suite *AccountTestSuite) TestUpdateAccount() {
testAccount := suite.testAccounts["local_account_1"]
testAccount.DisplayName = "new display name!"
_, err := suite.db.UpdateAccount(context.Background(), testAccount)
suite.NoError(err)
updated, err := suite.db.GetAccountByID(context.Background(), testAccount.ID)
suite.NoError(err)
suite.Equal("new display name!", updated.DisplayName)
suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second)
}
func TestAccountTestSuite(t *testing.T) {
suite.Run(t, new(AccountTestSuite))
}

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -43,7 +43,6 @@ type adminDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -33,7 +33,6 @@ type basicDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error {
@ -116,7 +115,7 @@ func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.E
q := b.conn.
NewUpdate().
Model(i).
Where("id = ?", id)
WherePK()
_, err := q.Exec(ctx)
@ -127,7 +126,7 @@ func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, valu
q := b.conn.NewUpdate().
Model(i).
Set("? = ?", bun.Safe(key), value).
Where("id = ?", id)
WherePK()
_, err := q.Exec(ctx)
@ -174,7 +173,6 @@ func (b *basicDB) Stop(ctx context.Context) db.Error {
b.log.Info("closing db connection")
if err := b.conn.Close(); err != nil {
// only cancel if there's a problem closing the db
b.cancel()
return err
}
return nil

View file

@ -0,0 +1,68 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb_test
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type BasicTestSuite struct {
BunDBStandardTestSuite
}
func (suite *BasicTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts()
suite.testAttachments = testrig.NewTestAttachments()
suite.testStatuses = testrig.NewTestStatuses()
suite.testTags = testrig.NewTestTags()
suite.testMentions = testrig.NewTestMentions()
}
func (suite *BasicTestSuite) SetupTest() {
suite.config = testrig.NewTestConfig()
suite.db = testrig.NewTestDB()
suite.log = testrig.NewTestLog()
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
func (suite *BasicTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
}
func (suite *BasicTestSuite) TestGetAccountByID() {
testAccount := suite.testAccounts["local_account_1"]
a := &gtsmodel.Account{}
err := suite.db.GetByID(context.Background(), testAccount.ID, a)
suite.NoError(err)
}
func TestBasicTestSuite(t *testing.T) {
suite.Run(t, new(BasicTestSuite))
}

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -41,13 +41,18 @@ import (
"github.com/uptrace/bun/dialect/pgdialect"
)
const (
dbTypePostgres = "postgres"
dbTypeSqlite = "sqlite"
)
var registerTables []interface{} = []interface{}{
&gtsmodel.StatusToEmoji{},
&gtsmodel.StatusToTag{},
}
// postgresService satisfies the DB interface
type postgresService struct {
// bunDBService satisfies the DB interface
type bunDBService struct {
db.Account
db.Admin
db.Basic
@ -57,37 +62,49 @@ type postgresService struct {
db.Mention
db.Notification
db.Relationship
db.Session
db.Status
db.Timeline
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
var sqldb *sql.DB
var conn *bun.DB
opts, err := derivePGOptions(c)
if err != nil {
return nil, fmt.Errorf("could not create postgres options: %s", err)
// depending on the database type we're trying to create, we need to use a different driver...
switch strings.ToLower(c.DBConfig.Type) {
case dbTypePostgres:
// POSTGRES
opts, err := deriveBunDBPGOptions(c)
if err != nil {
return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
}
sqldb = stdlib.OpenDB(*opts)
conn = bun.NewDB(sqldb, pgdialect.New())
case dbTypeSqlite:
// SQLITE
// TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite
default:
return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type))
}
sqldb := stdlib.OpenDB(*opts)
conn := bun.NewDB(sqldb, pgdialect.New())
// actually *begin* the connection so that we can tell if the db is there and listening
if err := conn.Ping(); err != nil {
return nil, fmt.Errorf("db connection error: %s", err)
}
log.Info("connected to postgres")
log.Info("connected to database")
for _, t := range registerTables {
// https://bun.uptrace.dev/orm/many-to-many-relation/
conn.RegisterModel(t)
}
ps := &postgresService{
ps := &bunDBService{
Account: &accountDB{
config: c,
conn: conn,
@ -133,6 +150,11 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
conn: conn,
log: log,
},
Session: &sessionDB{
config: c,
conn: conn,
log: log,
},
Status: &statusDB{
config: c,
conn: conn,
@ -148,7 +170,7 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
log: log,
}
// we can confidently return this useable postgres service now
// we can confidently return this useable service now
return ps, nil
}
@ -156,9 +178,9 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
HANDY STUFF
*/
// derivePGOptions takes an application config and returns either a ready-to-use set of options
// deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options
// with sensible defaults, or an error if it's not satisfied by the provided config.
func derivePGOptions(c *config.Config) (*pgx.ConnConfig, error) {
func deriveBunDBPGOptions(c *config.Config) (*pgx.ConnConfig, error) {
if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres {
return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type)
}
@ -236,15 +258,16 @@ func derivePGOptions(c *config.Config) (*pgx.ConnConfig, error) {
tlsConfig.RootCAs = certPool
}
opts, _ := pgx.ParseConfig("")
opts.Host = c.DBConfig.Address
opts.Port = uint16(c.DBConfig.Port)
opts.User = c.DBConfig.User
opts.Password = c.DBConfig.Password
opts.TLSConfig = tlsConfig
opts.PreferSimpleProtocol = true
cfg, _ := pgx.ParseConfig("")
cfg.Host = c.DBConfig.Address
cfg.Port = uint16(c.DBConfig.Port)
cfg.User = c.DBConfig.User
cfg.Password = c.DBConfig.Password
cfg.TLSConfig = tlsConfig
cfg.Database = c.DBConfig.Database
cfg.PreferSimpleProtocol = true
return opts, nil
return cfg, nil
}
/*
@ -253,7 +276,7 @@ func derivePGOptions(c *config.Config) (*pgx.ConnConfig, error) {
// TODO: move these to the type converter, it's bananas that they're here and not there
func (ps *postgresService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
ogAccount := &gtsmodel.Account{}
if err := ps.conn.NewSelect().Model(ogAccount).Where("id = ?", originAccountID).Scan(ctx); err != nil {
return nil, err
@ -331,7 +354,7 @@ func (ps *postgresService) MentionStringsToMentions(ctx context.Context, targetA
return menchies, nil
}
func (ps *postgresService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
newTags := []*gtsmodel.Tag{}
for _, t := range tags {
tag := &gtsmodel.Tag{}
@ -367,7 +390,7 @@ func (ps *postgresService) TagStringsToTags(ctx context.Context, tags []string,
return newTags, nil
}
func (ps *postgresService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
newEmojis := []*gtsmodel.Emoji{}
for _, e := range emojis {
emoji := &gtsmodel.Emoji{}

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
package bundb_test
import (
"github.com/sirupsen/logrus"
@ -27,7 +27,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
type PGStandardTestSuite struct {
type BunDBStandardTestSuite struct {
// standard suite interfaces
suite.Suite
config *config.Config

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -34,7 +34,6 @@ type domainDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) {

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -32,7 +32,6 @@ type instanceDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -32,7 +32,6 @@ type mediaDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (m *mediaDB) newMediaQ(i interface{}) *bun.SelectQuery {

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -33,7 +33,6 @@ type mentionDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
}

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -33,7 +33,6 @@ type notificationDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
}

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -34,7 +34,6 @@ type relationshipDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery {

View file

@ -0,0 +1,85 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"crypto/rand"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/uptrace/bun"
)
type sessionDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
}
func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
rs := new(gtsmodel.RouterSession)
q := s.conn.
NewSelect().
Model(rs).
Limit(1)
_, err := q.Exec(ctx)
err = processErrorResponse(err)
return rs, err
}
func (s *sessionDB) CreateSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
auth := make([]byte, 32)
crypt := make([]byte, 32)
if _, err := rand.Read(auth); err != nil {
return nil, err
}
if _, err := rand.Read(crypt); err != nil {
return nil, err
}
rid, err := id.NewULID()
if err != nil {
return nil, err
}
rs := &gtsmodel.RouterSession{
ID: rid,
Auth: auth,
Crypt: crypt,
}
q := s.conn.
NewInsert().
Model(rs)
_, err = q.Exec(ctx)
err = processErrorResponse(err)
return rs, err
}

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"container/list"
@ -36,7 +36,6 @@ type statusDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
}

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
package bundb_test
import (
"context"
@ -29,7 +29,7 @@ import (
)
type StatusTestSuite struct {
PGStandardTestSuite
BunDBStandardTestSuite
}
func (suite *StatusTestSuite) SetupSuite() {

View file

@ -16,7 +16,7 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
@ -34,7 +34,6 @@ type timelineDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
@ -78,7 +77,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
// OR statuses posted by accountID itself (since a user should be able to see their own statuses).
//
// This is equivalent to something like WHERE ... AND (... OR ...)
// See: https://pg.uptrace.dev/queries/#select
// See: https://bun.uptrace.dev/guide/queries.html#select
whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
return q.
WhereOr("f.account_id = ?", accountID).

View file

@ -16,10 +16,11 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package bundb
import (
"context"
"strings"
"database/sql"
@ -35,6 +36,9 @@ func processErrorResponse(err error) db.Error {
case sql.ErrNoRows:
return db.ErrNoEntries
default:
if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
}
}

View file

@ -40,6 +40,7 @@ type DB interface {
Mention
Notification
Relationship
Session
Status
Timeline

31
internal/db/session.go Normal file
View file

@ -0,0 +1,31 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Session handles getting/creation of router sessions.
type Session interface {
GetSession(ctx context.Context) (*gtsmodel.RouterSession, Error)
CreateSession(ctx context.Context) (*gtsmodel.RouterSession, Error)
}

View file

@ -24,6 +24,7 @@ import (
"errors"
"fmt"
"net/url"
"strings"
"github.com/go-fed/activity/streams"
"github.com/go-fed/activity/streams/vocab"
@ -34,18 +35,33 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/transport"
)
func instanceAccount(account *gtsmodel.Account) bool {
return strings.EqualFold(account.Username, account.Domain) ||
account.FollowersURI == "" ||
account.FollowingURI == "" ||
(account.Username == "internal.fetch" && strings.Contains(account.Note, "internal service actor"))
}
// EnrichRemoteAccount takes an account that's already been inserted into the database in a minimal form,
// and populates it with additional fields, media, etc.
//
// EnrichRemoteAccount is mostly useful for calling after an account has been initially created by
// the federatingDB's Create function, or during the federated authorization flow.
func (d *deref) EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) {
// if we're dealing with an instance account, we don't need to update anything
if instanceAccount(account) {
return account, nil
}
if err := d.PopulateAccountFields(ctx, account, username, false); err != nil {
return nil, err
}
if err := d.db.UpdateByID(ctx, account.ID, account); err != nil {
return nil, fmt.Errorf("EnrichRemoteAccount: error updating account: %s", err)
var err error
account, err = d.db.UpdateAccount(ctx, account)
if err != nil {
d.log.Errorf("EnrichRemoteAccount: error updating account: %s", err)
}
return account, nil
@ -108,8 +124,9 @@ func (d *deref) GetRemoteAccount(ctx context.Context, username string, remoteAcc
return nil, new, fmt.Errorf("FullyDereferenceAccount: error populating further account fields: %s", err)
}
if err := d.db.UpdateByID(ctx, gtsAccount.ID, gtsAccount); err != nil {
return nil, new, fmt.Errorf("FullyDereferenceAccount: error updating existing account: %s", err)
gtsAccount, err = d.db.UpdateAccount(ctx, gtsAccount)
if err != nil {
return nil, false, fmt.Errorf("EnrichRemoteAccount: error updating account: %s", err)
}
}

View file

@ -25,7 +25,6 @@ import (
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
@ -87,8 +86,8 @@ func (d *deref) iterateAncestors(ctx context.Context, username string, statusIRI
return err
}
status := &gtsmodel.Status{}
if err := d.db.GetByID(ctx, id, status); err != nil {
status, err := d.db.GetStatusByID(ctx, id)
if err != nil {
return err
}

View file

@ -87,7 +87,7 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {
a, err := f.db.GetAccountByURI(ctx, id.String())
if err == nil {
// it's an account
l.Debugf("uri is for an account with id: %s", s.ID)
l.Debugf("uri is for an account with id: %s", a.ID)
if err := f.db.DeleteByID(ctx, a.ID, &gtsmodel.Account{}); err != nil {
return fmt.Errorf("DELETE: err deleting account: %s", err)
}

View file

@ -51,13 +51,17 @@ func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (follow
followers = streams.NewActivityStreamsCollection()
items := streams.NewActivityStreamsItemsProperty()
for _, follow := range acctFollowers {
gtsFollower := &gtsmodel.Account{}
if err := f.db.GetByID(ctx, follow.AccountID, gtsFollower); err != nil {
return nil, fmt.Errorf("FOLLOWERS: db error getting account id %s: %s", follow.AccountID, err)
if follow.Account == nil {
followAccount, err := f.db.GetAccountByID(ctx, follow.AccountID)
if err != nil {
return nil, fmt.Errorf("FOLLOWERS: db error getting account id %s: %s", follow.AccountID, err)
}
follow.Account = followAccount
}
uri, err := url.Parse(gtsFollower.URI)
uri, err := url.Parse(follow.Account.URI)
if err != nil {
return nil, fmt.Errorf("FOLLOWERS: error parsing %s as url: %s", gtsFollower.URI, err)
return nil, fmt.Errorf("FOLLOWERS: error parsing %s as url: %s", follow.Account.URI, err)
}
items.AppendIRI(uri)
}

View file

@ -64,13 +64,17 @@ func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (follow
following = streams.NewActivityStreamsCollection()
items := streams.NewActivityStreamsItemsProperty()
for _, follow := range acctFollowing {
gtsFollowing := &gtsmodel.Account{}
if err := f.db.GetByID(ctx, follow.AccountID, gtsFollowing); err != nil {
return nil, fmt.Errorf("FOLLOWING: db error getting account id %s: %s", follow.AccountID, err)
if follow.Account == nil {
followAccount, err := f.db.GetAccountByID(ctx, follow.AccountID)
if err != nil {
return nil, fmt.Errorf("FOLLOWING: db error getting account id %s: %s", follow.AccountID, err)
}
follow.Account = followAccount
}
uri, err := url.Parse(gtsFollowing.URI)
uri, err := url.Parse(follow.Account.URI)
if err != nil {
return nil, fmt.Errorf("FOLLOWING: error parsing %s as url: %s", gtsFollowing.URI, err)
return nil, fmt.Errorf("FOLLOWING: error parsing %s as url: %s", follow.Account.URI, err)
}
items.AppendIRI(uri)
}

View file

@ -152,7 +152,8 @@ func (f *federatingDB) Update(ctx context.Context, asType vocab.Type) error {
}
updatedAcct.ID = requestingAcct.ID // set this here so the db will update properly instead of trying to PUT this and getting constraint issues
if err := f.db.UpdateByID(ctx, requestingAcct.ID, updatedAcct); err != nil {
updatedAcct, err = f.db.UpdateAccount(ctx, updatedAcct)
if err != nil {
return fmt.Errorf("UPDATE: database error inserting updated account: %s", err)
}

View file

@ -18,8 +18,8 @@
// Package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database.
// These types should never be serialized and/or sent out via public APIs, as they contain sensitive information.
// The annotation used on these structs is for handling them via the go-pg ORM (hence why they're in this db subdir).
// See here for more info on go-pg model annotations: https://pg.uptrace.dev/models/
// The annotation used on these structs is for handling them via the bun-db ORM.
// See here for more info on bun model annotations: https://bun.uptrace.dev/guide/models.html
package gtsmodel
import (
@ -34,9 +34,9 @@ type Account struct {
*/
// id of this account in the local database
ID string `bun:"type:CHAR(26),pk,notnull,unique"`
ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"`
// Username of the account, should just be a string of [a-z0-9_]. Can be added to domain to create the full username in the form ``[username]@[domain]`` eg., ``user_96@example.org``
Username string `bun:",notnull,unique:userdomain"` // username and domain should be unique *with* each other
Username string `bun:",notnull,unique:userdomain,nullzero"` // username and domain should be unique *with* each other
// Domain of the account, will be null if this is a local account, otherwise something like ``example.org`` or ``mastodon.social``. Should be unique with username.
Domain string `bun:",unique:userdomain,nullzero"` // username and domain should be unique *with* each other
@ -93,21 +93,21 @@ type Account struct {
*/
// What is the activitypub URI for this account discovered by webfinger?
URI string `bun:",unique"`
URI string `bun:",unique,nullzero"`
// At which URL can we see the user account in a web browser?
URL string `bun:",unique"`
URL string `bun:",unique,nullzero"`
// Last time this account was located using the webfinger API.
LastWebfingeredAt time.Time `bun:",nullzero"`
// Address of this account's activitypub inbox, for sending activity to
InboxURI string `bun:",unique"`
InboxURI string `bun:",unique,nullzero"`
// Address of this account's activitypub outbox
OutboxURI string `bun:",unique"`
OutboxURI string `bun:",unique,nullzero"`
// URI for getting the following list of this account
FollowingURI string `bun:",unique"`
FollowingURI string `bun:",unique,nullzero"`
// URI for getting the followers list of this account
FollowersURI string `bun:",unique"`
FollowersURI string `bun:",unique,nullzero"`
// URL for getting the featured collection list of this account
FeaturedCollectionURI string `bun:",unique"`
FeaturedCollectionURI string `bun:",unique,nullzero"`
// What type of activitypub actor is this account?
ActorType string
// This account is associated with x account id
@ -137,7 +137,7 @@ type Account struct {
// Should we hide this account's collections?
HideCollections bool
// id of the database entry that caused this account to become suspended -- can be an account ID or a domain block ID
SuspensionOrigin string `bun:"type:CHAR(26)"`
SuspensionOrigin string `bun:"type:CHAR(26),nullzero"`
}
// Field represents a key value field on an account, for things like pronouns, website, etc.

View file

@ -73,5 +73,5 @@ type Emoji struct {
// Is this emoji visible in the admin emoji picker?
VisibleInPicker bool `bun:",notnull,default:true"`
// In which emoji category is this emoji visible?
CategoryID string `bun:"type:CHAR(26),nullzero"`
CategoryID string `bun:"type:CHAR(26),nullzero"`
}

View file

@ -29,15 +29,15 @@ type FollowRequest struct {
// When was this follow request last updated?
UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// Who does this follow request originate from?
AccountID string `bun:"type:CHAR(26),unique:frsrctarget,notnull"`
Account Account `bun:"rel:belongs-to"`
AccountID string `bun:"type:CHAR(26),unique:frsrctarget,notnull"`
Account *Account `bun:"rel:belongs-to"`
// Who is the target of this follow request?
TargetAccountID string `bun:"type:CHAR(26),unique:frsrctarget,notnull"`
TargetAccount Account `bun:"rel:belongs-to"`
TargetAccountID string `bun:"type:CHAR(26),unique:frsrctarget,notnull"`
TargetAccount *Account `bun:"rel:belongs-to"`
// Does this follow also want to see reblogs and not just posts?
ShowReblogs bool `bun:"default:true"`
// What is the activitypub URI of this follow request?
URI string `bun:",unique"`
URI string `bun:",unique,nullzero"`
// does the following account want to be notified when the followed account posts?
Notify bool
}

View file

@ -27,9 +27,9 @@ type Status struct {
// id of the status in the database
ID string `bun:"type:CHAR(26),pk,notnull"`
// uri at which this status is reachable
URI string `bun:",unique"`
URI string `bun:",unique,nullzero"`
// web url for viewing this status
URL string `bun:",unique"`
URL string `bun:",unique,nullzero"`
// the html-formatted content of this status
Content string
// Database IDs of any media attachments associated with this status
@ -45,9 +45,9 @@ type Status struct {
EmojiIDs []string `bun:"emojis,array"`
Emojis []*Emoji `bun:"attached_emojis,m2m:status_to_emojis"` // https://bun.uptrace.dev/guide/relations.html#many-to-many-relation
// when was this status created?
CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
CreatedAt time.Time `bun:",notnull,default:current_timestamp"`
// when was this status updated?
UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
UpdatedAt time.Time `bun:",notnull,default:current_timestamp"`
// is this status from a local account?
Local bool
// which account posted this status?
@ -93,17 +93,17 @@ type Status struct {
// StatusToTag is an intermediate struct to facilitate the many2many relationship between a status and one or more tags.
type StatusToTag struct {
StatusID string `bun:"type:CHAR(26),unique:statustag"`
StatusID string `bun:"type:CHAR(26),unique:statustag,nullzero"`
Status *Status `bun:"rel:belongs-to"`
TagID string `bun:"type:CHAR(26),unique:statustag"`
TagID string `bun:"type:CHAR(26),unique:statustag,nullzero"`
Tag *Tag `bun:"rel:belongs-to"`
}
// StatusToEmoji is an intermediate struct to facilitate the many2many relationship between a status and one or more emojis.
type StatusToEmoji struct {
StatusID string `bun:"type:CHAR(26),unique:statusemoji"`
StatusID string `bun:"type:CHAR(26),unique:statusemoji,nullzero"`
Status *Status `bun:"rel:belongs-to"`
EmojiID string `bun:"type:CHAR(26),unique:statusemoji"`
EmojiID string `bun:"type:CHAR(26),unique:statusemoji,nullzero"`
Emoji *Emoji `bun:"rel:belongs-to"`
}

View file

@ -33,9 +33,9 @@ type User struct {
// id of this user in the local database; the end-user will never need to know this, it's strictly internal
ID string `bun:"type:CHAR(26),pk,notnull,unique"`
// confirmed email address for this user, this should be unique -- only one email address registered per instance, multiple users per email are not supported
Email string `bun:"default:null,unique"`
Email string `bun:"default:null,unique,nullzero"`
// The id of the local gtsmodel.Account entry for this user, if it exists (unconfirmed users don't have an account yet)
AccountID string `bun:"type:CHAR(26),unique"`
AccountID string `bun:"type:CHAR(26),unique,nullzero"`
Account *Account `bun:"rel:belongs-to"`
// The encrypted password of this user, generated using https://pkg.go.dev/golang.org/x/crypto/bcrypt#GenerateFromPassword. A salt is included so we're safe against 🌈 tables
EncryptedPassword string `bun:",notnull"`

View file

@ -39,9 +39,7 @@ func NewClientStore(db db.Basic) oauth2.ClientStore {
}
func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
poc := &Client{
ID: clientID,
}
poc := &Client{}
if err := cs.db.GetByID(ctx, clientID, poc); err != nil {
return nil, err
}
@ -67,7 +65,7 @@ func (cs *clientStore) Delete(ctx context.Context, id string) error {
// Client is a handy little wrapper for typical oauth client details
type Client struct {
ID string `pg:"type:CHAR(26),pk,notnull"`
ID string `bun:"type:CHAR(26),pk,notnull"`
Secret string
Domain string
UserID string

View file

@ -43,13 +43,13 @@ type tokenStore struct {
// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
// the tokens in the DB once per minute and deletes any that have expired.
func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2.TokenStore {
pts := &tokenStore{
ts := &tokenStore{
db: db,
log: log,
}
// set the token store to clean out expired tokens once per minute, or return if we're done
go func(ctx context.Context, pts *tokenStore, log *logrus.Logger) {
go func(ctx context.Context, ts *tokenStore, log *logrus.Logger) {
cleanloop:
for {
select {
@ -58,32 +58,32 @@ func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2.
break cleanloop
case <-time.After(1 * time.Minute):
log.Trace("sweeping out old oauth entries broom broom")
if err := pts.sweep(ctx); err != nil {
if err := ts.sweep(ctx); err != nil {
log.Errorf("error while sweeping oauth entries: %s", err)
}
}
}
}(ctx, pts, log)
return pts
}(ctx, ts, log)
return ts
}
// sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so.
func (pts *tokenStore) sweep(ctx context.Context) error {
func (ts *tokenStore) sweep(ctx context.Context) error {
// select *all* tokens from the db
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
tokens := new([]*Token)
if err := pts.db.GetAll(ctx, tokens); err != nil {
if err := ts.db.GetAll(ctx, tokens); err != nil {
return err
}
// iterate through and remove expired tokens
now := time.Now()
for _, pgt := range *tokens {
for _, dbt := range *tokens {
// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So:
// we only want to check if a token expired before now if the expiry time is *not zero*;
// ie., if it's been explicity set.
if !pgt.CodeExpiresAt.IsZero() && pgt.CodeExpiresAt.Before(now) || !pgt.RefreshExpiresAt.IsZero() && pgt.RefreshExpiresAt.Before(now) || !pgt.AccessExpiresAt.IsZero() && pgt.AccessExpiresAt.Before(now) {
if err := pts.db.DeleteByID(ctx, pgt.ID, pgt); err != nil {
if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) {
if err := ts.db.DeleteByID(ctx, dbt.ID, dbt); err != nil {
return err
}
}
@ -94,92 +94,92 @@ func (pts *tokenStore) sweep(ctx context.Context) error {
// Create creates and store the new token information.
// For the original implementation, see https://github.com/superseriousbusiness/oauth2/blob/master/store/token.go#L34
func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
t, ok := info.(*models.Token)
if !ok {
return errors.New("info param was not a models.Token")
}
pgt := TokenToPGToken(t)
if pgt.ID == "" {
pgtID, err := id.NewRandomULID()
dbt := TokenToDBToken(t)
if dbt.ID == "" {
dbtID, err := id.NewRandomULID()
if err != nil {
return err
}
pgt.ID = pgtID
dbt.ID = dbtID
}
if err := pts.db.Put(ctx, pgt); err != nil {
if err := ts.db.Put(ctx, dbt); err != nil {
return fmt.Errorf("error in tokenstore create: %s", err)
}
return nil
}
// RemoveByCode deletes a token from the DB based on the Code field
func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
return pts.db.DeleteWhere(ctx, []db.Where{{Key: "code", Value: code}}, &Token{})
func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
return ts.db.DeleteWhere(ctx, []db.Where{{Key: "code", Value: code}}, &Token{})
}
// RemoveByAccess deletes a token from the DB based on the Access field
func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
return pts.db.DeleteWhere(ctx, []db.Where{{Key: "access", Value: access}}, &Token{})
func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
return ts.db.DeleteWhere(ctx, []db.Where{{Key: "access", Value: access}}, &Token{})
}
// RemoveByRefresh deletes a token from the DB based on the Refresh field
func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
return pts.db.DeleteWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, &Token{})
func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
return ts.db.DeleteWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, &Token{})
}
// GetByCode selects a token from the DB based on the Code field
func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
if code == "" {
return nil, nil
}
pgt := &Token{
dbt := &Token{
Code: code,
}
if err := pts.db.GetWhere(ctx, []db.Where{{Key: "code", Value: code}}, pgt); err != nil {
if err := ts.db.GetWhere(ctx, []db.Where{{Key: "code", Value: code}}, dbt); err != nil {
return nil, err
}
return TokenToOauthToken(pgt), nil
return DBTokenToToken(dbt), nil
}
// GetByAccess selects a token from the DB based on the Access field
func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
if access == "" {
return nil, nil
}
pgt := &Token{
dbt := &Token{
Access: access,
}
if err := pts.db.GetWhere(ctx, []db.Where{{Key: "access", Value: access}}, pgt); err != nil {
if err := ts.db.GetWhere(ctx, []db.Where{{Key: "access", Value: access}}, dbt); err != nil {
return nil, err
}
return TokenToOauthToken(pgt), nil
return DBTokenToToken(dbt), nil
}
// GetByRefresh selects a token from the DB based on the Refresh field
func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
if refresh == "" {
return nil, nil
}
pgt := &Token{
dbt := &Token{
Refresh: refresh,
}
if err := pts.db.GetWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, pgt); err != nil {
if err := ts.db.GetWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, dbt); err != nil {
return nil, err
}
return TokenToOauthToken(pgt), nil
return DBTokenToToken(dbt), nil
}
/*
The following models are basically helpers for the postgres token store implementation, they should only be used internally.
The following models are basically helpers for the token store implementation, they should only be used internally.
*/
// Token is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt.
//
// Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined,
// and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and
// and tokens with expired TTLs are automatically removed. Since some databases don't have that feature, it's easier to set an expiry time and
// then periodically sweep out tokens when that time has passed.
//
// Note that this struct does *not* satisfy the token interface shown here: https://github.com/superseriousbusiness/oauth2/blob/master/model.go#L22
@ -187,26 +187,26 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2
// As such, manual translation is always required between Token and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
// and pgTokenToOauthToken can be used for that.
type Token struct {
ID string `pg:"type:CHAR(26),pk,notnull"`
ID string `bun:"type:CHAR(26),pk,notnull"`
ClientID string
UserID string
RedirectURI string
Scope string
Code string `pg:"default:'',pk"`
Code string `bun:"default:'',pk"`
CodeChallenge string
CodeChallengeMethod string
CodeCreateAt time.Time `pg:"type:timestamp"`
CodeExpiresAt time.Time `pg:"type:timestamp"`
Access string `pg:"default:'',pk"`
AccessCreateAt time.Time `pg:"type:timestamp"`
AccessExpiresAt time.Time `pg:"type:timestamp"`
Refresh string `pg:"default:'',pk"`
RefreshCreateAt time.Time `pg:"type:timestamp"`
RefreshExpiresAt time.Time `pg:"type:timestamp"`
CodeCreateAt time.Time `bun:",nullzero"`
CodeExpiresAt time.Time `bun:",nullzero"`
Access string `bun:"default:'',pk"`
AccessCreateAt time.Time `bun:",nullzero"`
AccessExpiresAt time.Time `bun:",nullzero"`
Refresh string `bun:"default:'',pk"`
RefreshCreateAt time.Time `bun:",nullzero"`
RefreshExpiresAt time.Time `bun:",nullzero"`
}
// TokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres
func TokenToPGToken(tkn *models.Token) *Token {
// TokenToDBToken is a lil util function that takes a gotosocial token and gives back a token for inserting into a database.
func TokenToDBToken(tkn *models.Token) *Token {
now := time.Now()
// For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's
@ -247,40 +247,40 @@ func TokenToPGToken(tkn *models.Token) *Token {
}
}
// TokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token
func TokenToOauthToken(pgt *Token) *models.Token {
// DBTokenToToken is a lil util function that takes a database token and gives back a gotosocial token
func DBTokenToToken(dbt *Token) *models.Token {
now := time.Now()
var codeExpiresIn time.Duration
if !pgt.CodeExpiresAt.IsZero() {
codeExpiresIn = pgt.CodeExpiresAt.Sub(now)
if !dbt.CodeExpiresAt.IsZero() {
codeExpiresIn = dbt.CodeExpiresAt.Sub(now)
}
var accessExpiresIn time.Duration
if !pgt.AccessExpiresAt.IsZero() {
accessExpiresIn = pgt.AccessExpiresAt.Sub(now)
if !dbt.AccessExpiresAt.IsZero() {
accessExpiresIn = dbt.AccessExpiresAt.Sub(now)
}
var refreshExpiresIn time.Duration
if !pgt.RefreshExpiresAt.IsZero() {
refreshExpiresIn = pgt.RefreshExpiresAt.Sub(now)
if !dbt.RefreshExpiresAt.IsZero() {
refreshExpiresIn = dbt.RefreshExpiresAt.Sub(now)
}
return &models.Token{
ClientID: pgt.ClientID,
UserID: pgt.UserID,
RedirectURI: pgt.RedirectURI,
Scope: pgt.Scope,
Code: pgt.Code,
CodeChallenge: pgt.CodeChallenge,
CodeChallengeMethod: pgt.CodeChallengeMethod,
CodeCreateAt: pgt.CodeCreateAt,
ClientID: dbt.ClientID,
UserID: dbt.UserID,
RedirectURI: dbt.RedirectURI,
Scope: dbt.Scope,
Code: dbt.Code,
CodeChallenge: dbt.CodeChallenge,
CodeChallengeMethod: dbt.CodeChallengeMethod,
CodeCreateAt: dbt.CodeCreateAt,
CodeExpiresIn: codeExpiresIn,
Access: pgt.Access,
AccessCreateAt: pgt.AccessCreateAt,
Access: dbt.Access,
AccessCreateAt: dbt.AccessCreateAt,
AccessExpiresIn: accessExpiresIn,
Refresh: pgt.Refresh,
RefreshCreateAt: pgt.RefreshCreateAt,
Refresh: dbt.Refresh,
RefreshCreateAt: dbt.RefreshCreateAt,
RefreshExpiresIn: refreshExpiresIn,
}
}

View file

@ -177,17 +177,21 @@ selectStatusesLoop:
}
for _, b := range boosts {
oa := &gtsmodel.Account{}
if err := p.db.GetByID(ctx, b.AccountID, oa); err == nil {
l.Debug("putting boost undo in the client api channel")
p.fromClientAPI <- gtsmodel.FromClientAPI{
APObjectType: gtsmodel.ActivityStreamsAnnounce,
APActivityType: gtsmodel.ActivityStreamsUndo,
GTSModel: s,
OriginAccount: oa,
TargetAccount: account,
if b.Account == nil {
bAccount, err := p.db.GetAccountByID(ctx, b.AccountID)
if err != nil {
continue
}
b.Account = bAccount
}
l.Debug("putting boost undo in the client api channel")
p.fromClientAPI <- gtsmodel.FromClientAPI{
APObjectType: gtsmodel.ActivityStreamsAnnounce,
APActivityType: gtsmodel.ActivityStreamsUndo,
GTSModel: s,
OriginAccount: b.Account,
TargetAccount: account,
}
if err := p.db.DeleteByID(ctx, b.ID, b); err != nil {
@ -267,7 +271,8 @@ selectStatusesLoop:
account.SuspendedAt = time.Now()
account.SuspensionOrigin = origin
if err := p.db.UpdateByID(ctx, account.ID, account); err != nil {
account, err := p.db.UpdateAccount(ctx, account)
if err != nil {
return err
}

View file

@ -29,8 +29,8 @@ import (
)
func (p *processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) {
targetAccount := &gtsmodel.Account{}
if err := p.db.GetByID(ctx, targetAccountID, targetAccount); err != nil {
targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID)
if err != nil {
if err == db.ErrNoEntries {
return nil, errors.New("account not found")
}
@ -38,7 +38,6 @@ func (p *processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account
}
var blocked bool
var err error
if requestingAccount != nil {
blocked, err = p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true)
if err != nil {

View file

@ -39,8 +39,8 @@ func (p *processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
}
// make sure the target account actually exists in our db
targetAcct := &gtsmodel.Account{}
if err := p.db.GetByID(ctx, targetAccountID, targetAcct); err != nil {
targetAcct, err := p.db.GetAccountByID(ctx, targetAccountID)
if err != nil {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("AccountFollowRemove: account %s not found in the db: %s", targetAccountID, err))
}

View file

@ -117,8 +117,8 @@ func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, form
}
// fetch the account with all updated values set
updatedAccount := &gtsmodel.Account{}
if err := p.db.GetByID(ctx, account.ID, updatedAccount); err != nil {
updatedAccount, err := p.db.GetAccountByID(ctx, account.ID)
if err != nil {
return nil, fmt.Errorf("could not fetch updated account %s: %s", account.ID, err)
}

View file

@ -38,11 +38,15 @@ func (p *processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]
accts := []apimodel.Account{}
for _, fr := range frs {
acct := &gtsmodel.Account{}
if err := p.db.GetByID(ctx, fr.AccountID, acct); err != nil {
return nil, gtserror.NewErrorInternalError(err)
if fr.Account == nil {
frAcct, err := p.db.GetAccountByID(ctx, fr.AccountID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
fr.Account = frAcct
}
mastoAcct, err := p.tc.AccountToMastoPublic(ctx, acct)
mastoAcct, err := p.tc.AccountToMastoPublic(ctx, fr.Account)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@ -57,22 +61,28 @@ func (p *processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a
return nil, gtserror.NewErrorNotFound(err)
}
originAccount := &gtsmodel.Account{}
if err := p.db.GetByID(ctx, follow.AccountID, originAccount); err != nil {
return nil, gtserror.NewErrorInternalError(err)
if follow.Account == nil {
followAccount, err := p.db.GetAccountByID(ctx, follow.AccountID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
follow.Account = followAccount
}
targetAccount := &gtsmodel.Account{}
if err := p.db.GetByID(ctx, follow.TargetAccountID, targetAccount); err != nil {
return nil, gtserror.NewErrorInternalError(err)
if follow.TargetAccount == nil {
followTargetAccount, err := p.db.GetAccountByID(ctx, follow.TargetAccountID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
follow.TargetAccount = followTargetAccount
}
p.fromClientAPI <- gtsmodel.FromClientAPI{
APObjectType: gtsmodel.ActivityStreamsFollow,
APActivityType: gtsmodel.ActivityStreamsAccept,
GTSModel: follow,
OriginAccount: originAccount,
TargetAccount: targetAccount,
OriginAccount: follow.Account,
TargetAccount: follow.TargetAccount,
}
gtsR, err := p.db.GetRelationship(ctx, auth.Account.ID, accountID)

View file

@ -238,11 +238,11 @@ func (p *processor) processFromClientAPI(ctx context.Context, clientMsg gtsmodel
func (p *processor) federateStatus(ctx context.Context, status *gtsmodel.Status) error {
if status.Account == nil {
a := &gtsmodel.Account{}
if err := p.db.GetByID(ctx, status.AccountID, a); err != nil {
statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID)
if err != nil {
return fmt.Errorf("federateStatus: error fetching status author account: %s", err)
}
status.Account = a
status.Account = statusAccount
}
// do nothing if this isn't our status
@ -266,11 +266,11 @@ func (p *processor) federateStatus(ctx context.Context, status *gtsmodel.Status)
func (p *processor) federateStatusDelete(ctx context.Context, status *gtsmodel.Status) error {
if status.Account == nil {
a := &gtsmodel.Account{}
if err := p.db.GetByID(ctx, status.AccountID, a); err != nil {
return fmt.Errorf("federateStatus: error fetching status author account: %s", err)
statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID)
if err != nil {
return fmt.Errorf("federateStatusDelete: error fetching status author account: %s", err)
}
status.Account = a
status.Account = statusAccount
}
// do nothing if this isn't our status
@ -558,19 +558,19 @@ func (p *processor) federateAccountUpdate(ctx context.Context, updatedAccount *g
func (p *processor) federateBlock(ctx context.Context, block *gtsmodel.Block) error {
if block.Account == nil {
a := &gtsmodel.Account{}
if err := p.db.GetByID(ctx, block.AccountID, a); err != nil {
blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID)
if err != nil {
return fmt.Errorf("federateBlock: error getting block account from database: %s", err)
}
block.Account = a
block.Account = blockAccount
}
if block.TargetAccount == nil {
a := &gtsmodel.Account{}
if err := p.db.GetByID(ctx, block.TargetAccountID, a); err != nil {
blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID)
if err != nil {
return fmt.Errorf("federateBlock: error getting block target account from database: %s", err)
}
block.TargetAccount = a
block.TargetAccount = blockTargetAccount
}
// if both accounts are local there's nothing to do here
@ -594,19 +594,19 @@ func (p *processor) federateBlock(ctx context.Context, block *gtsmodel.Block) er
func (p *processor) federateUnblock(ctx context.Context, block *gtsmodel.Block) error {
if block.Account == nil {
a := &gtsmodel.Account{}
if err := p.db.GetByID(ctx, block.AccountID, a); err != nil {
blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID)
if err != nil {
return fmt.Errorf("federateUnblock: error getting block account from database: %s", err)
}
block.Account = a
block.Account = blockAccount
}
if block.TargetAccount == nil {
a := &gtsmodel.Account{}
if err := p.db.GetByID(ctx, block.TargetAccountID, a); err != nil {
blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID)
if err != nil {
return fmt.Errorf("federateUnblock: error getting block target account from database: %s", err)
}
block.TargetAccount = a
block.TargetAccount = blockTargetAccount
}
// if both accounts are local there's nothing to do here

View file

@ -7,12 +7,11 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (p *processor) Delete(ctx context.Context, mediaAttachmentID string) gtserror.WithCode {
a := &gtsmodel.MediaAttachment{}
if err := p.db.GetByID(ctx, mediaAttachmentID, a); err != nil {
attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID)
if err != nil {
if err == db.ErrNoEntries {
// attachment already gone
return nil
@ -24,21 +23,21 @@ func (p *processor) Delete(ctx context.Context, mediaAttachmentID string) gtserr
errs := []string{}
// delete the thumbnail from storage
if a.Thumbnail.Path != "" {
if err := p.storage.RemoveFileAt(a.Thumbnail.Path); err != nil {
errs = append(errs, fmt.Sprintf("remove thumbnail at path %s: %s", a.Thumbnail.Path, err))
if attachment.Thumbnail.Path != "" {
if err := p.storage.RemoveFileAt(attachment.Thumbnail.Path); err != nil {
errs = append(errs, fmt.Sprintf("remove thumbnail at path %s: %s", attachment.Thumbnail.Path, err))
}
}
// delete the file from storage
if a.File.Path != "" {
if err := p.storage.RemoveFileAt(a.File.Path); err != nil {
errs = append(errs, fmt.Sprintf("remove file at path %s: %s", a.File.Path, err))
if attachment.File.Path != "" {
if err := p.storage.RemoveFileAt(attachment.File.Path); err != nil {
errs = append(errs, fmt.Sprintf("remove file at path %s: %s", attachment.File.Path, err))
}
}
// delete the attachment
if err := p.db.DeleteByID(ctx, mediaAttachmentID, a); err != nil {
if err := p.db.DeleteByID(ctx, mediaAttachmentID, attachment); err != nil {
if err != db.ErrNoEntries {
errs = append(errs, fmt.Sprintf("remove attachment: %s", err))
}

View file

@ -48,8 +48,8 @@ func (p *processor) GetFile(ctx context.Context, account *gtsmodel.Account, form
wantedMediaID := spl[0]
// get the account that owns the media and make sure it's not suspended
acct := &gtsmodel.Account{}
if err := p.db.GetByID(ctx, form.AccountID, acct); err != nil {
acct, err := p.db.GetAccountByID(ctx, form.AccountID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("account with id %s could not be selected from the db: %s", form.AccountID, err))
}
if !acct.SuspendedAt.IsZero() {
@ -91,8 +91,8 @@ func (p *processor) GetFile(ctx context.Context, account *gtsmodel.Account, form
return nil, gtserror.NewErrorNotFound(fmt.Errorf("media size %s not recognized for emoji", mediaSize))
}
case media.Attachment, media.Header, media.Avatar:
a := &gtsmodel.MediaAttachment{}
if err := p.db.GetByID(ctx, wantedMediaID, a); err != nil {
a, err := p.db.GetAttachmentByID(ctx, wantedMediaID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("attachment %s could not be taken from the db: %s", wantedMediaID, err))
}
if a.AccountID != form.AccountID {

View file

@ -30,8 +30,8 @@ import (
)
func (p *processor) GetMedia(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) {
attachment := &gtsmodel.MediaAttachment{}
if err := p.db.GetByID(ctx, mediaAttachmentID, attachment); err != nil {
attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID)
if err != nil {
if err == db.ErrNoEntries {
// attachment doesn't exist
return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db"))

View file

@ -31,8 +31,8 @@ import (
)
func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) {
attachment := &gtsmodel.MediaAttachment{}
if err := p.db.GetByID(ctx, mediaAttachmentID, attachment); err != nil {
attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID)
if err != nil {
if err == db.ErrNoEntries {
// attachment doesn't exist
return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db"))

View file

@ -24,8 +24,8 @@ func (p *processor) AuthorizeStreamingRequest(ctx context.Context, accessToken s
return nil, fmt.Errorf("AuthorizeStreamingRequest: no user found for validated uid %s", uid)
}
acct := &gtsmodel.Account{}
if err := p.db.GetByID(ctx, user.AccountID, acct); err != nil || acct == nil {
acct, err := p.db.GetAccountByID(ctx, user.AccountID)
if err != nil || acct == nil {
return nil, fmt.Errorf("AuthorizeStreamingRequest: no account retrieved for user with id %s", uid)
}

View file

@ -20,7 +20,6 @@ package router
import (
"context"
"crypto/rand"
"errors"
"fmt"
"net/http"
@ -30,8 +29,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
)
// SessionOptions returns the standard set of options to use for each session.
@ -46,34 +43,23 @@ func SessionOptions(cfg *config.Config) sessions.Options {
}
}
func useSession(ctx context.Context, cfg *config.Config, dbService db.DB, engine *gin.Engine) error {
func useSession(ctx context.Context, cfg *config.Config, sessionDB db.Session, engine *gin.Engine) error {
// check if we have a saved router session already
routerSessions := []*gtsmodel.RouterSession{}
if err := dbService.GetAll(ctx, &routerSessions); err != nil {
rs, err := sessionDB.GetSession(ctx)
if err != nil {
if err != db.ErrNoEntries {
// proper error occurred
return err
}
}
var rs *gtsmodel.RouterSession
if len(routerSessions) == 1 {
// we have a router session stored
rs = routerSessions[0]
} else if len(routerSessions) == 0 {
// we have no router sessions so we need to create a new one
var err error
rs, err = routerSession(ctx, dbService)
// no session saved so create a new one
rs, err = sessionDB.CreateSession(ctx)
if err != nil {
return fmt.Errorf("error creating new router session: %s", err)
return err
}
} else {
// we should only have one router session stored ever
return errors.New("we had more than one router session in the db")
}
if rs == nil {
return errors.New("error getting or creating router session: router session was nil")
return errors.New("router session was nil")
}
store := memstore.NewStore(rs.Auth, rs.Crypt)
@ -82,34 +68,3 @@ func useSession(ctx context.Context, cfg *config.Config, dbService db.DB, engine
engine.Use(sessions.Sessions(sessionName, store))
return nil
}
// routerSession generates a new router session with random auth and crypt bytes,
// puts it in the database for persistence, and returns it for use.
func routerSession(ctx context.Context, dbService db.DB) (*gtsmodel.RouterSession, error) {
auth := make([]byte, 32)
crypt := make([]byte, 32)
if _, err := rand.Read(auth); err != nil {
return nil, err
}
if _, err := rand.Read(crypt); err != nil {
return nil, err
}
rid, err := id.NewULID()
if err != nil {
return nil, err
}
rs := &gtsmodel.RouterSession{
ID: rid,
Auth: auth,
Crypt: crypt,
}
if err := dbService.Put(ctx, rs); err != nil {
return nil, err
}
return rs, nil
}

View file

@ -98,8 +98,8 @@ func (f *formatter) ReplaceMentions(ctx context.Context, in string, mentions []*
// got it from the mention
targetAccount = menchie.OriginAccount
} else {
a := &gtsmodel.Account{}
if err := f.db.GetByID(ctx, menchie.TargetAccountID, a); err == nil {
a, err := f.db.GetAccountByID(ctx, menchie.TargetAccountID)
if err == nil {
// got it from the db
targetAccount = a
} else {

View file

@ -54,7 +54,8 @@ func (t *timeline) Get(ctx context.Context, amount int, maxID string, sinceID st
if prepareNext {
// already cache the next query to speed up scrolling
go func() {
if err := t.prepareNextQuery(ctx, amount, nextMaxID, "", ""); err != nil {
// use context.Background() because we don't want the query to abort when the request finishes
if err := t.prepareNextQuery(context.Background(), amount, nextMaxID, "", ""); err != nil {
l.Errorf("error preparing next query: %s", err)
}
}()
@ -73,7 +74,8 @@ func (t *timeline) Get(ctx context.Context, amount int, maxID string, sinceID st
if prepareNext {
// already cache the next query to speed up scrolling
go func() {
if err := t.prepareNextQuery(ctx, amount, nextMaxID, "", ""); err != nil {
// use context.Background() because we don't want the query to abort when the request finishes
if err := t.prepareNextQuery(context.Background(), amount, nextMaxID, "", ""); err != nil {
l.Errorf("error preparing next query: %s", err)
}
}()

View file

@ -214,9 +214,12 @@ func (c *converter) AccountToAS(ctx context.Context, a *gtsmodel.Account) (vocab
// icon
// Used as profile avatar.
if a.AvatarMediaAttachmentID != "" {
avatar := &gtsmodel.MediaAttachment{}
if err := c.db.GetByID(ctx, a.AvatarMediaAttachmentID, avatar); err != nil {
return nil, err
if a.AvatarMediaAttachment == nil {
avatar := &gtsmodel.MediaAttachment{}
if err := c.db.GetByID(ctx, a.AvatarMediaAttachmentID, avatar); err != nil {
return nil, err
}
a.AvatarMediaAttachment = avatar
}
iconProperty := streams.NewActivityStreamsIconProperty()
@ -224,11 +227,11 @@ func (c *converter) AccountToAS(ctx context.Context, a *gtsmodel.Account) (vocab
iconImage := streams.NewActivityStreamsImage()
mediaType := streams.NewActivityStreamsMediaTypeProperty()
mediaType.Set(avatar.File.ContentType)
mediaType.Set(a.AvatarMediaAttachment.File.ContentType)
iconImage.SetActivityStreamsMediaType(mediaType)
avatarURLProperty := streams.NewActivityStreamsUrlProperty()
avatarURL, err := url.Parse(avatar.URL)
avatarURL, err := url.Parse(a.AvatarMediaAttachment.URL)
if err != nil {
return nil, err
}
@ -242,9 +245,12 @@ func (c *converter) AccountToAS(ctx context.Context, a *gtsmodel.Account) (vocab
// image
// Used as profile header.
if a.HeaderMediaAttachmentID != "" {
header := &gtsmodel.MediaAttachment{}
if err := c.db.GetByID(ctx, a.HeaderMediaAttachmentID, header); err != nil {
return nil, err
if a.HeaderMediaAttachment == nil {
header := &gtsmodel.MediaAttachment{}
if err := c.db.GetByID(ctx, a.HeaderMediaAttachmentID, header); err != nil {
return nil, err
}
a.HeaderMediaAttachment = header
}
headerProperty := streams.NewActivityStreamsImageProperty()
@ -252,11 +258,11 @@ func (c *converter) AccountToAS(ctx context.Context, a *gtsmodel.Account) (vocab
headerImage := streams.NewActivityStreamsImage()
mediaType := streams.NewActivityStreamsMediaTypeProperty()
mediaType.Set(header.File.ContentType)
mediaType.Set(a.HeaderMediaAttachment.File.ContentType)
headerImage.SetActivityStreamsMediaType(mediaType)
headerURLProperty := streams.NewActivityStreamsUrlProperty()
headerURL, err := url.Parse(header.URL)
headerURL, err := url.Parse(a.HeaderMediaAttachment.URL)
if err != nil {
return nil, err
}

View file

@ -63,6 +63,10 @@ func (c *converter) AccountToMastoSensitive(ctx context.Context, a *gtsmodel.Acc
}
func (c *converter) AccountToMastoPublic(ctx context.Context, a *gtsmodel.Account) (*model.Account, error) {
if a == nil {
return nil, fmt.Errorf("given account was nil")
}
// first check if we have this account in our frontEnd cache
if accountI, err := c.frontendCache.Fetch(a.ID); err == nil {
if account, ok := accountI.(*model.Account); ok {
@ -266,27 +270,30 @@ func (c *converter) AttachmentToMasto(ctx context.Context, a *gtsmodel.MediaAtta
}
func (c *converter) MentionToMasto(ctx context.Context, m *gtsmodel.Mention) (model.Mention, error) {
target := &gtsmodel.Account{}
if err := c.db.GetByID(ctx, m.TargetAccountID, target); err != nil {
return model.Mention{}, err
if m.TargetAccount == nil {
targetAccount, err := c.db.GetAccountByID(ctx, m.TargetAccountID)
if err != nil {
return model.Mention{}, err
}
m.TargetAccount = targetAccount
}
var local bool
if target.Domain == "" {
if m.TargetAccount.Domain == "" {
local = true
}
var acct string
if local {
acct = target.Username
acct = m.TargetAccount.Username
} else {
acct = fmt.Sprintf("%s@%s", target.Username, target.Domain)
acct = fmt.Sprintf("%s@%s", m.TargetAccount.Username, m.TargetAccount.Domain)
}
return model.Mention{
ID: target.ID,
Username: target.Username,
URL: target.URL,
ID: m.TargetAccount.ID,
Username: m.TargetAccount.Username,
URL: m.TargetAccount.URL,
Acct: acct,
}, nil
}
@ -302,11 +309,9 @@ func (c *converter) EmojiToMasto(ctx context.Context, e *gtsmodel.Emoji) (model.
}
func (c *converter) TagToMasto(ctx context.Context, t *gtsmodel.Tag) (model.Tag, error) {
tagURL := fmt.Sprintf("%s://%s/tags/%s", c.config.Protocol, c.config.Host, t.Name)
return model.Tag{
Name: t.Name,
URL: tagURL, // we don't serve URLs with collections of tagged statuses (FOR NOW) so this is purely for mastodon compatibility ¯\_(ツ)_/¯
URL: t.URL,
}, nil
}
@ -331,8 +336,8 @@ func (c *converter) StatusToMasto(ctx context.Context, s *gtsmodel.Status, reque
// the boosted status might have been set on this struct already so check first before doing db calls
if s.BoostOf == nil {
// it's not set so fetch it from the db
bs := &gtsmodel.Status{}
if err := c.db.GetByID(ctx, s.BoostOfID, bs); err != nil {
bs, err := c.db.GetStatusByID(ctx, s.BoostOfID)
if err != nil {
return nil, fmt.Errorf("error getting boosted status with id %s: %s", s.BoostOfID, err)
}
s.BoostOf = bs
@ -341,8 +346,8 @@ func (c *converter) StatusToMasto(ctx context.Context, s *gtsmodel.Status, reque
// the boosted account might have been set on this struct already or passed as a param so check first before doing db calls
if s.BoostOfAccount == nil {
// it's not set so fetch it from the db
ba := &gtsmodel.Account{}
if err := c.db.GetByID(ctx, s.BoostOf.AccountID, ba); err != nil {
ba, err := c.db.GetAccountByID(ctx, s.BoostOf.AccountID)
if err != nil {
return nil, fmt.Errorf("error getting boosted account %s from status with id %s: %s", s.BoostOf.AccountID, s.BoostOfID, err)
}
s.BoostOfAccount = ba
@ -368,8 +373,8 @@ func (c *converter) StatusToMasto(ctx context.Context, s *gtsmodel.Status, reque
}
if s.Account == nil {
a := &gtsmodel.Account{}
if err := c.db.GetByID(ctx, s.AccountID, a); err != nil {
a, err := c.db.GetAccountByID(ctx, s.AccountID)
if err != nil {
return nil, fmt.Errorf("error getting status author: %s", err)
}
s.Account = a
@ -394,14 +399,14 @@ func (c *converter) StatusToMasto(ctx context.Context, s *gtsmodel.Status, reque
// the status doesn't have gts attachments on it, but it does have attachment IDs
// in this case, we need to pull the gts attachments from the db to convert them into masto ones
} else {
for _, a := range s.AttachmentIDs {
gtsAttachment := &gtsmodel.MediaAttachment{}
if err := c.db.GetByID(ctx, a, gtsAttachment); err != nil {
return nil, fmt.Errorf("error getting attachment with id %s: %s", a, err)
for _, aID := range s.AttachmentIDs {
gtsAttachment, err := c.db.GetAttachmentByID(ctx, aID)
if err != nil {
return nil, fmt.Errorf("error getting attachment with id %s: %s", aID, err)
}
mastoAttachment, err := c.AttachmentToMasto(ctx, gtsAttachment)
if err != nil {
return nil, fmt.Errorf("error converting attachment with id %s: %s", a, err)
return nil, fmt.Errorf("error converting attachment with id %s: %s", aID, err)
}
mastoAttachments = append(mastoAttachments, mastoAttachment)
}
@ -421,10 +426,10 @@ func (c *converter) StatusToMasto(ctx context.Context, s *gtsmodel.Status, reque
// the status doesn't have gts mentions on it, but it does have mention IDs
// in this case, we need to pull the gts mentions from the db to convert them into masto ones
} else {
for _, m := range s.MentionIDs {
gtsMention := &gtsmodel.Mention{}
if err := c.db.GetByID(ctx, m, gtsMention); err != nil {
return nil, fmt.Errorf("error getting mention with id %s: %s", m, err)
for _, mID := range s.MentionIDs {
gtsMention, err := c.db.GetMention(ctx, mID)
if err != nil {
return nil, fmt.Errorf("error getting mention with id %s: %s", mID, err)
}
mastoMention, err := c.MentionToMasto(ctx, gtsMention)
if err != nil {
@ -603,13 +608,16 @@ func (c *converter) InstanceToMasto(ctx context.Context, i *gtsmodel.Instance) (
// contact account is optional but let's try to get it
if i.ContactAccountID != "" {
ia := &gtsmodel.Account{}
if err := c.db.GetByID(ctx, i.ContactAccountID, ia); err == nil {
ma, err := c.AccountToMastoPublic(ctx, ia)
if i.ContactAccount == nil {
contactAccount, err := c.db.GetAccountByID(ctx, i.ContactAccountID)
if err == nil {
mi.ContactAccount = ma
i.ContactAccount = contactAccount
}
}
ma, err := c.AccountToMastoPublic(ctx, i.ContactAccount)
if err == nil {
mi.ContactAccount = ma
}
}
return mi, nil

View file

@ -24,7 +24,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/pg"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
@ -68,7 +68,7 @@ func NewTestDB() db.DB {
l := logrus.New()
l.SetLevel(logrus.TraceLevel)
testDB, err := pg.NewPostgresService(context.Background(), config, l)
testDB, err := bundb.NewBunDBService(context.Background(), config, l)
if err != nil {
logrus.Panic(err)
}

View file

@ -11,10 +11,6 @@
[![Documentation](https://img.shields.io/badge/bun-documentation-informational)](https://bun.uptrace.dev/)
[![Chat](https://discordapp.com/api/guilds/752070105847955518/widget.png)](https://discord.gg/rWtp5Aj)
Status: API freeze (release candidate). Note that all sub-packages (mainly extra/\* packages) are
not part of the API freeze and are developed independently. You can think of them as 3-rd party
packages.
Main features are:
- Works with [PostgreSQL](https://bun.uptrace.dev/guide/drivers.html#postgresql),

View file

@ -4,4 +4,4 @@ go 1.16
replace github.com/uptrace/bun => ../..
require github.com/uptrace/bun v1.0.0-rc.1
require github.com/uptrace/bun v0.4.3

View file

@ -20,4 +20,4 @@ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBc
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View file

@ -89,10 +89,10 @@ func (f Formatter) AppendQuery(dst []byte, query string, args ...interface{}) []
func (f Formatter) append(dst []byte, p *parser.Parser, args []interface{}) []byte {
var namedArgs NamedArgAppender
if len(args) == 1 {
if v, ok := args[0].(NamedArgAppender); ok {
namedArgs = v
} else if v, ok := newStructArgs(f, args[0]); ok {
namedArgs = v
var ok bool
namedArgs, ok = args[0].(NamedArgAppender)
if !ok {
namedArgs, _ = newStructArgs(f, args[0])
}
}

View file

@ -2,5 +2,5 @@ package bun
// Version is the current release version.
func Version() string {
return "1.0.0-rc.1"
return "0.4.3"
}

4
vendor/modules.txt vendored
View file

@ -389,7 +389,7 @@ github.com/tdewolff/parse/v2/strconv
github.com/tmthrgd/go-hex
# github.com/ugorji/go/codec v1.2.6
github.com/ugorji/go/codec
# github.com/uptrace/bun v1.0.0-rc.1
# github.com/uptrace/bun v0.4.3
## explicit
github.com/uptrace/bun
github.com/uptrace/bun/dialect
@ -400,7 +400,7 @@ github.com/uptrace/bun/internal
github.com/uptrace/bun/internal/parser
github.com/uptrace/bun/internal/tagparser
github.com/uptrace/bun/schema
# github.com/uptrace/bun/dialect/pgdialect v1.0.0-rc.1
# github.com/uptrace/bun/dialect/pgdialect v0.4.3
## explicit
github.com/uptrace/bun/dialect/pgdialect
# github.com/urfave/cli/v2 v2.3.0