diff --git a/go.sum b/go.sum
index 56a215cd3..42d24ad8c 100644
--- a/go.sum
+++ b/go.sum
@@ -446,10 +446,10 @@ github.com/ugorji/go v1.2.6/go.mod h1:anCg0y61KIhDlPZmnH+so+RQbysYVyDko0IMgJv0Nn
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
github.com/ugorji/go/codec v1.2.6 h1:7kbGefxLoDBuYXOms4yD7223OpNMMPNPZxXk5TvFcyQ=
github.com/ugorji/go/codec v1.2.6/go.mod h1:V6TCNZ4PHqoHGFZuSG1W8nrCzzdgA2DozYxWFFpvxTw=
-github.com/uptrace/bun v1.0.0-rc.1 h1:yMjz/JdlgYRema5mk59URzjAV7HpOV58Fj3KID34K9U=
-github.com/uptrace/bun v1.0.0-rc.1/go.mod h1:aL6D9vPw8DXaTQTwGrEPtUderBYXx7ShUmPfnxnqscw=
-github.com/uptrace/bun/dialect/pgdialect v1.0.0-rc.1 h1:4rxcO4+x8r2xyCnduH9fhhkjhbtzLHFl3vmx2goVgio=
-github.com/uptrace/bun/dialect/pgdialect v1.0.0-rc.1/go.mod h1:x48KjNeAGTSq4WNIOvtfzAOFjPG/V/Gx+SdwO5ksimU=
+github.com/uptrace/bun v0.4.3 h1:x6bjDqwjxwM/9Q1eauhkznuvTrz/rLiCK2p4tT63sAE=
+github.com/uptrace/bun v0.4.3/go.mod h1:aL6D9vPw8DXaTQTwGrEPtUderBYXx7ShUmPfnxnqscw=
+github.com/uptrace/bun/dialect/pgdialect v0.4.3 h1:lM2IUKpU99110chKkupw3oTfXiOKpB0hTJIe6frqQDo=
+github.com/uptrace/bun/dialect/pgdialect v0.4.3/go.mod h1:BaNvWejl32oKUhwpFkw/eNcWldzIlVY4nfw/sNul0s8=
github.com/urfave/cli/v2 v2.3.0 h1:qph92Y649prgesehzOrQjdWyxFOp/QVM+6imKHad91M=
github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
diff --git a/internal/api/client/account/accountupdate_test.go b/internal/api/client/account/accountupdate_test.go
index 349429625..8fc31171b 100644
--- a/internal/api/client/account/accountupdate_test.go
+++ b/internal/api/client/account/accountupdate_test.go
@@ -79,7 +79,7 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandler()
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"])
- ctx.Set(oauth.SessionAuthorizedToken, oauth.TokenToOauthToken(suite.testTokens["local_account_1"]))
+ ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"]))
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"])
ctx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:8080/%s", account.UpdateCredentialsPath), bytes.NewReader(requestBody.Bytes())) // the endpoint we're hitting
diff --git a/internal/api/client/auth/auth_test.go b/internal/api/client/auth/auth_test.go
index d1aedb6d8..3d5170f31 100644
--- a/internal/api/client/auth/auth_test.go
+++ b/internal/api/client/auth/auth_test.go
@@ -28,7 +28,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/db/pg"
+ "github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"golang.org/x/crypto/bcrypt"
@@ -104,7 +104,7 @@ func (suite *AuthTestSuite) SetupTest() {
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
- db, err := pg.NewPostgresService(context.Background(), suite.config, log)
+ db, err := bundb.NewBunDBService(context.Background(), suite.config, log)
if err != nil {
logrus.Panicf("error creating database connection: %s", err)
}
diff --git a/internal/api/client/auth/authorize.go b/internal/api/client/auth/authorize.go
index 34e53f5ae..d67b39ed5 100644
--- a/internal/api/client/auth/authorize.go
+++ b/internal/api/client/auth/authorize.go
@@ -80,20 +80,15 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
}
// we can also use the userid of the user to fetch their username from the db to greet them nicely <3
- user := >smodel.User{
- ID: userID,
- }
+ user := >smodel.User{}
if err := m.db.GetByID(c.Request.Context(), user.ID, user); err != nil {
m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
- acct := >smodel.Account{
- ID: user.AccountID,
- }
-
- if err := m.db.GetByID(c.Request.Context(), acct.ID, acct); err != nil {
+ acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
+ if err != nil {
m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
diff --git a/internal/api/client/auth/middleware.go b/internal/api/client/auth/middleware.go
index c1995ca92..3599c7048 100644
--- a/internal/api/client/auth/middleware.go
+++ b/internal/api/client/auth/middleware.go
@@ -56,8 +56,8 @@ func (m *Module) OauthTokenMiddleware(c *gin.Context) {
c.Set(oauth.SessionAuthorizedUser, user)
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user)
- acct := >smodel.Account{}
- if err := m.db.GetByID(c.Request.Context(), user.AccountID, acct); err != nil || acct == nil {
+ acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
+ if err != nil || acct == nil {
l.Warnf("no account found for validated user %s", uid)
return
}
diff --git a/internal/api/client/media/mediacreate_test.go b/internal/api/client/media/mediacreate_test.go
index 5c48a4381..8433786e4 100644
--- a/internal/api/client/media/mediacreate_test.go
+++ b/internal/api/client/media/mediacreate_test.go
@@ -121,7 +121,7 @@ func (suite *MediaCreateTestSuite) TestStatusCreatePOSTImageHandlerSuccessful()
// set up the context for the request
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
diff --git a/internal/api/client/status/statusboost_test.go b/internal/api/client/status/statusboost_test.go
index fbe267fac..4157bde38 100644
--- a/internal/api/client/status/statusboost_test.go
+++ b/internal/api/client/status/statusboost_test.go
@@ -67,7 +67,7 @@ func (suite *StatusBoostTestSuite) TearDownTest() {
func (suite *StatusBoostTestSuite) TestPostBoost() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["admin_account_status_1"]
@@ -133,7 +133,7 @@ func (suite *StatusBoostTestSuite) TestPostBoost() {
func (suite *StatusBoostTestSuite) TestPostUnboostable() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["local_account_2_status_4"]
@@ -171,7 +171,7 @@ func (suite *StatusBoostTestSuite) TestPostUnboostable() {
func (suite *StatusBoostTestSuite) TestPostNotVisible() {
t := suite.testTokens["local_account_2"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["local_account_1_status_3"] // this is a mutual only status and these accounts aren't mutuals
diff --git a/internal/api/client/status/statuscreate_test.go b/internal/api/client/status/statuscreate_test.go
index 097f268f2..060d96dad 100644
--- a/internal/api/client/status/statuscreate_test.go
+++ b/internal/api/client/status/statuscreate_test.go
@@ -83,7 +83,7 @@ https://docs.gotosocial.org/en/latest/user_guide/posts/#links
func (suite *StatusCreateTestSuite) TestPostNewStatus() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@@ -137,7 +137,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatus() {
func (suite *StatusCreateTestSuite) TestPostAnotherNewStatus() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@@ -172,7 +172,7 @@ func (suite *StatusCreateTestSuite) TestPostAnotherNewStatus() {
func (suite *StatusCreateTestSuite) TestPostNewStatusWithEmoji() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@@ -213,7 +213,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatusWithEmoji() {
// Try to reply to a status that doesn't exist
func (suite *StatusCreateTestSuite) TestReplyToNonexistentStatus() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@@ -244,7 +244,7 @@ func (suite *StatusCreateTestSuite) TestReplyToNonexistentStatus() {
// Post a reply to the status of a local user that allows replies.
func (suite *StatusCreateTestSuite) TestReplyToLocalStatus() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@@ -284,7 +284,7 @@ func (suite *StatusCreateTestSuite) TestReplyToLocalStatus() {
// Take a media file which is currently not associated with a status, and attach it to a new status.
func (suite *StatusCreateTestSuite) TestAttachNewMediaSuccess() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
attachment := suite.testAttachments["local_account_1_unattached_1"]
@@ -323,8 +323,7 @@ func (suite *StatusCreateTestSuite) TestAttachNewMediaSuccess() {
assert.Len(suite.T(), statusResponse.MediaAttachments, 1)
// get the updated media attachment from the database
- gtsAttachment := >smodel.MediaAttachment{}
- err = suite.db.GetByID(context.Background(), statusResponse.MediaAttachments[0].ID, gtsAttachment)
+ gtsAttachment, err := suite.db.GetAttachmentByID(context.Background(), statusResponse.MediaAttachments[0].ID)
assert.NoError(suite.T(), err)
// convert it to a masto attachment
diff --git a/internal/api/client/status/statusfave_test.go b/internal/api/client/status/statusfave_test.go
index 0f44b5e90..2f7a2c596 100644
--- a/internal/api/client/status/statusfave_test.go
+++ b/internal/api/client/status/statusfave_test.go
@@ -71,7 +71,7 @@ func (suite *StatusFaveTestSuite) TearDownTest() {
func (suite *StatusFaveTestSuite) TestPostFave() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["admin_account_status_2"]
@@ -119,7 +119,7 @@ func (suite *StatusFaveTestSuite) TestPostFave() {
func (suite *StatusFaveTestSuite) TestPostUnfaveable() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["local_account_2_status_3"] // this one is unlikeable and unreplyable
diff --git a/internal/api/client/status/statusfavedby_test.go b/internal/api/client/status/statusfavedby_test.go
index 22a549b30..7475f1e69 100644
--- a/internal/api/client/status/statusfavedby_test.go
+++ b/internal/api/client/status/statusfavedby_test.go
@@ -69,7 +69,7 @@ func (suite *StatusFavedByTestSuite) TearDownTest() {
func (suite *StatusFavedByTestSuite) TestGetFavedBy() {
t := suite.testTokens["local_account_2"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["admin_account_status_1"] // this status is faved by local_account_1
diff --git a/internal/api/client/status/statusunfave_test.go b/internal/api/client/status/statusunfave_test.go
index a5f267f4c..9e7ea8f82 100644
--- a/internal/api/client/status/statusunfave_test.go
+++ b/internal/api/client/status/statusunfave_test.go
@@ -71,7 +71,7 @@ func (suite *StatusUnfaveTestSuite) TearDownTest() {
func (suite *StatusUnfaveTestSuite) TestPostUnfave() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
// this is the status we wanna unfave: in the testrig it's already faved by this account
targetStatus := suite.testStatuses["admin_account_status_1"]
@@ -120,7 +120,7 @@ func (suite *StatusUnfaveTestSuite) TestPostUnfave() {
func (suite *StatusUnfaveTestSuite) TestPostAlreadyNotFaved() {
t := suite.testTokens["local_account_1"]
- oauthToken := oauth.TokenToOauthToken(t)
+ oauthToken := oauth.DBTokenToToken(t)
// this is the status we wanna unfave: in the testrig it's not faved by this account
targetStatus := suite.testStatuses["admin_account_status_2"]
diff --git a/internal/cache/cache.go b/internal/cache/cache.go
index eb3744cfe..ce4aad04d 100644
--- a/internal/cache/cache.go
+++ b/internal/cache/cache.go
@@ -37,7 +37,7 @@ type cache struct {
// New returns a new in-memory cache.
func New() Cache {
c := ttlcache.NewCache()
- c.SetTTL(30 * time.Second)
+ c.SetTTL(5 * time.Minute)
cache := &cache{
c: c,
}
diff --git a/internal/cliactions/admin/account/account.go b/internal/cliactions/admin/account/account.go
index 82058fe25..46998ec6a 100644
--- a/internal/cliactions/admin/account/account.go
+++ b/internal/cliactions/admin/account/account.go
@@ -28,7 +28,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/cliactions"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/db/pg"
+ "github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
"golang.org/x/crypto/bcrypt"
@@ -36,7 +36,7 @@ import (
// Create creates a new account in the database using the provided flags.
var Create cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
- dbConn, err := pg.NewPostgresService(ctx, c, log)
+ dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
@@ -75,7 +75,7 @@ var Create cliactions.GTSAction = func(ctx context.Context, c *config.Config, lo
// Confirm sets a user to Approved, sets Email to the current UnconfirmedEmail value, and sets ConfirmedAt to now.
var Confirm cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
- dbConn, err := pg.NewPostgresService(ctx, c, log)
+ dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
@@ -110,7 +110,7 @@ var Confirm cliactions.GTSAction = func(ctx context.Context, c *config.Config, l
// Promote sets a user to admin.
var Promote cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
- dbConn, err := pg.NewPostgresService(ctx, c, log)
+ dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
@@ -142,7 +142,7 @@ var Promote cliactions.GTSAction = func(ctx context.Context, c *config.Config, l
// Demote sets admin on a user to false.
var Demote cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
- dbConn, err := pg.NewPostgresService(ctx, c, log)
+ dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
@@ -174,7 +174,7 @@ var Demote cliactions.GTSAction = func(ctx context.Context, c *config.Config, lo
// Disable sets Disabled to true on a user.
var Disable cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
- dbConn, err := pg.NewPostgresService(ctx, c, log)
+ dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
@@ -212,7 +212,7 @@ var Suspend cliactions.GTSAction = func(ctx context.Context, c *config.Config, l
// Password sets the password of target account.
var Password cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
- dbConn, err := pg.NewPostgresService(ctx, c, log)
+ dbConn, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
diff --git a/internal/cliactions/server/server.go b/internal/cliactions/server/server.go
index be44c00dc..877a9d397 100644
--- a/internal/cliactions/server/server.go
+++ b/internal/cliactions/server/server.go
@@ -35,7 +35,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/blob"
"github.com/superseriousbusiness/gotosocial/internal/cliactions"
"github.com/superseriousbusiness/gotosocial/internal/config"
- "github.com/superseriousbusiness/gotosocial/internal/db/pg"
+ "github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
"github.com/superseriousbusiness/gotosocial/internal/gotosocial"
@@ -79,7 +79,7 @@ var models []interface{} = []interface{}{
// Start creates and starts a gotosocial server
var Start cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
- dbService, err := pg.NewPostgresService(ctx, c, log)
+ dbService, err := bundb.NewBunDBService(ctx, c, log)
if err != nil {
return fmt.Errorf("error creating dbservice: %s", err)
}
diff --git a/internal/db/account.go b/internal/db/account.go
index 61d97bf8c..058a89859 100644
--- a/internal/db/account.go
+++ b/internal/db/account.go
@@ -36,6 +36,9 @@ type Account interface {
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error)
+ // UpdateAccount updates one account by ID.
+ UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)
+
// GetLocalAccountByUsername returns an account on this instance by its username.
GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, Error)
diff --git a/internal/db/pg/account.go b/internal/db/bundb/account.go
similarity index 94%
rename from internal/db/pg/account.go
rename to internal/db/bundb/account.go
index 73b71cf11..7ebb79a15 100644
--- a/internal/db/pg/account.go
+++ b/internal/db/bundb/account.go
@@ -16,12 +16,13 @@
along with this program. If not, see .
*/
-package pg
+package bundb
import (
"context"
"errors"
"fmt"
+ "strings"
"time"
"github.com/sirupsen/logrus"
@@ -35,7 +36,6 @@ type accountDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
@@ -79,6 +79,25 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.
return account, err
}
+func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) {
+ if strings.TrimSpace(account.ID) == "" {
+ return nil, errors.New("account had no ID")
+ }
+
+ account.UpdatedAt = time.Now()
+
+ q := a.conn.
+ NewUpdate().
+ Model(account).
+ WherePK()
+
+ _, err := q.Exec(ctx)
+
+ err = processErrorResponse(err)
+
+ return account, err
+}
+
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
account := new(gtsmodel.Account)
diff --git a/internal/db/pg/account_test.go b/internal/db/bundb/account_test.go
similarity index 80%
rename from internal/db/pg/account_test.go
rename to internal/db/bundb/account_test.go
index df4d244bf..7174b781d 100644
--- a/internal/db/pg/account_test.go
+++ b/internal/db/bundb/account_test.go
@@ -16,18 +16,19 @@
along with this program. If not, see .
*/
-package pg_test
+package bundb_test
import (
"context"
"testing"
+ "time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type AccountTestSuite struct {
- PGStandardTestSuite
+ BunDBStandardTestSuite
}
func (suite *AccountTestSuite) SetupSuite() {
@@ -66,6 +67,20 @@ func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
suite.NotEmpty(account.HeaderMediaAttachment.URL)
}
+func (suite *AccountTestSuite) TestUpdateAccount() {
+ testAccount := suite.testAccounts["local_account_1"]
+
+ testAccount.DisplayName = "new display name!"
+
+ _, err := suite.db.UpdateAccount(context.Background(), testAccount)
+ suite.NoError(err)
+
+ updated, err := suite.db.GetAccountByID(context.Background(), testAccount.ID)
+ suite.NoError(err)
+ suite.Equal("new display name!", updated.DisplayName)
+ suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second)
+}
+
func TestAccountTestSuite(t *testing.T) {
suite.Run(t, new(AccountTestSuite))
}
diff --git a/internal/db/pg/admin.go b/internal/db/bundb/admin.go
similarity index 99%
rename from internal/db/pg/admin.go
rename to internal/db/bundb/admin.go
index e9fd01b11..67a1e8a0d 100644
--- a/internal/db/pg/admin.go
+++ b/internal/db/bundb/admin.go
@@ -16,7 +16,7 @@
along with this program. If not, see .
*/
-package pg
+package bundb
import (
"context"
@@ -43,7 +43,6 @@ type adminDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
diff --git a/internal/db/pg/basic.go b/internal/db/bundb/basic.go
similarity index 97%
rename from internal/db/pg/basic.go
rename to internal/db/bundb/basic.go
index dd91acb34..983b6b810 100644
--- a/internal/db/pg/basic.go
+++ b/internal/db/bundb/basic.go
@@ -16,7 +16,7 @@
along with this program. If not, see .
*/
-package pg
+package bundb
import (
"context"
@@ -33,7 +33,6 @@ type basicDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error {
@@ -116,7 +115,7 @@ func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.E
q := b.conn.
NewUpdate().
Model(i).
- Where("id = ?", id)
+ WherePK()
_, err := q.Exec(ctx)
@@ -127,7 +126,7 @@ func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, valu
q := b.conn.NewUpdate().
Model(i).
Set("? = ?", bun.Safe(key), value).
- Where("id = ?", id)
+ WherePK()
_, err := q.Exec(ctx)
@@ -174,7 +173,6 @@ func (b *basicDB) Stop(ctx context.Context) db.Error {
b.log.Info("closing db connection")
if err := b.conn.Close(); err != nil {
// only cancel if there's a problem closing the db
- b.cancel()
return err
}
return nil
diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go
new file mode 100644
index 000000000..9189618c9
--- /dev/null
+++ b/internal/db/bundb/basic_test.go
@@ -0,0 +1,68 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type BasicTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *BasicTestSuite) SetupSuite() {
+ suite.testTokens = testrig.NewTestTokens()
+ suite.testClients = testrig.NewTestClients()
+ suite.testApplications = testrig.NewTestApplications()
+ suite.testUsers = testrig.NewTestUsers()
+ suite.testAccounts = testrig.NewTestAccounts()
+ suite.testAttachments = testrig.NewTestAttachments()
+ suite.testStatuses = testrig.NewTestStatuses()
+ suite.testTags = testrig.NewTestTags()
+ suite.testMentions = testrig.NewTestMentions()
+}
+
+func (suite *BasicTestSuite) SetupTest() {
+ suite.config = testrig.NewTestConfig()
+ suite.db = testrig.NewTestDB()
+ suite.log = testrig.NewTestLog()
+
+ testrig.StandardDBSetup(suite.db, suite.testAccounts)
+}
+
+func (suite *BasicTestSuite) TearDownTest() {
+ testrig.StandardDBTeardown(suite.db)
+}
+
+func (suite *BasicTestSuite) TestGetAccountByID() {
+ testAccount := suite.testAccounts["local_account_1"]
+
+ a := >smodel.Account{}
+ err := suite.db.GetByID(context.Background(), testAccount.ID, a)
+ suite.NoError(err)
+}
+
+func TestBasicTestSuite(t *testing.T) {
+ suite.Run(t, new(BasicTestSuite))
+}
diff --git a/internal/db/pg/pg.go b/internal/db/bundb/bundb.go
similarity index 81%
rename from internal/db/pg/pg.go
rename to internal/db/bundb/bundb.go
index 5d1a1d01a..49ed09cbd 100644
--- a/internal/db/pg/pg.go
+++ b/internal/db/bundb/bundb.go
@@ -16,7 +16,7 @@
along with this program. If not, see .
*/
-package pg
+package bundb
import (
"context"
@@ -41,13 +41,18 @@ import (
"github.com/uptrace/bun/dialect/pgdialect"
)
+const (
+ dbTypePostgres = "postgres"
+ dbTypeSqlite = "sqlite"
+)
+
var registerTables []interface{} = []interface{}{
>smodel.StatusToEmoji{},
>smodel.StatusToTag{},
}
-// postgresService satisfies the DB interface
-type postgresService struct {
+// bunDBService satisfies the DB interface
+type bunDBService struct {
db.Account
db.Admin
db.Basic
@@ -57,37 +62,49 @@ type postgresService struct {
db.Mention
db.Notification
db.Relationship
+ db.Session
db.Status
db.Timeline
config *config.Config
conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
-// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
-// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
-func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
+// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
+// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
+func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
+ var sqldb *sql.DB
+ var conn *bun.DB
- opts, err := derivePGOptions(c)
- if err != nil {
- return nil, fmt.Errorf("could not create postgres options: %s", err)
+ // depending on the database type we're trying to create, we need to use a different driver...
+ switch strings.ToLower(c.DBConfig.Type) {
+ case dbTypePostgres:
+ // POSTGRES
+ opts, err := deriveBunDBPGOptions(c)
+ if err != nil {
+ return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
+ }
+ sqldb = stdlib.OpenDB(*opts)
+ conn = bun.NewDB(sqldb, pgdialect.New())
+ case dbTypeSqlite:
+ // SQLITE
+ // TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite
+ default:
+ return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type))
}
- sqldb := stdlib.OpenDB(*opts)
- conn := bun.NewDB(sqldb, pgdialect.New())
// actually *begin* the connection so that we can tell if the db is there and listening
if err := conn.Ping(); err != nil {
return nil, fmt.Errorf("db connection error: %s", err)
}
- log.Info("connected to postgres")
+ log.Info("connected to database")
for _, t := range registerTables {
// https://bun.uptrace.dev/orm/many-to-many-relation/
conn.RegisterModel(t)
}
- ps := &postgresService{
+ ps := &bunDBService{
Account: &accountDB{
config: c,
conn: conn,
@@ -133,6 +150,11 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
conn: conn,
log: log,
},
+ Session: &sessionDB{
+ config: c,
+ conn: conn,
+ log: log,
+ },
Status: &statusDB{
config: c,
conn: conn,
@@ -148,7 +170,7 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
log: log,
}
- // we can confidently return this useable postgres service now
+ // we can confidently return this useable service now
return ps, nil
}
@@ -156,9 +178,9 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
HANDY STUFF
*/
-// derivePGOptions takes an application config and returns either a ready-to-use set of options
+// deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options
// with sensible defaults, or an error if it's not satisfied by the provided config.
-func derivePGOptions(c *config.Config) (*pgx.ConnConfig, error) {
+func deriveBunDBPGOptions(c *config.Config) (*pgx.ConnConfig, error) {
if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres {
return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type)
}
@@ -236,15 +258,16 @@ func derivePGOptions(c *config.Config) (*pgx.ConnConfig, error) {
tlsConfig.RootCAs = certPool
}
- opts, _ := pgx.ParseConfig("")
- opts.Host = c.DBConfig.Address
- opts.Port = uint16(c.DBConfig.Port)
- opts.User = c.DBConfig.User
- opts.Password = c.DBConfig.Password
- opts.TLSConfig = tlsConfig
- opts.PreferSimpleProtocol = true
+ cfg, _ := pgx.ParseConfig("")
+ cfg.Host = c.DBConfig.Address
+ cfg.Port = uint16(c.DBConfig.Port)
+ cfg.User = c.DBConfig.User
+ cfg.Password = c.DBConfig.Password
+ cfg.TLSConfig = tlsConfig
+ cfg.Database = c.DBConfig.Database
+ cfg.PreferSimpleProtocol = true
- return opts, nil
+ return cfg, nil
}
/*
@@ -253,7 +276,7 @@ func derivePGOptions(c *config.Config) (*pgx.ConnConfig, error) {
// TODO: move these to the type converter, it's bananas that they're here and not there
-func (ps *postgresService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
+func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) {
ogAccount := >smodel.Account{}
if err := ps.conn.NewSelect().Model(ogAccount).Where("id = ?", originAccountID).Scan(ctx); err != nil {
return nil, err
@@ -331,7 +354,7 @@ func (ps *postgresService) MentionStringsToMentions(ctx context.Context, targetA
return menchies, nil
}
-func (ps *postgresService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
+func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) {
newTags := []*gtsmodel.Tag{}
for _, t := range tags {
tag := >smodel.Tag{}
@@ -367,7 +390,7 @@ func (ps *postgresService) TagStringsToTags(ctx context.Context, tags []string,
return newTags, nil
}
-func (ps *postgresService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
+func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) {
newEmojis := []*gtsmodel.Emoji{}
for _, e := range emojis {
emoji := >smodel.Emoji{}
diff --git a/internal/db/pg/pg_test.go b/internal/db/bundb/bundb_test.go
similarity index 96%
rename from internal/db/pg/pg_test.go
rename to internal/db/bundb/bundb_test.go
index c1e10abdf..b789375af 100644
--- a/internal/db/pg/pg_test.go
+++ b/internal/db/bundb/bundb_test.go
@@ -16,7 +16,7 @@
along with this program. If not, see .
*/
-package pg_test
+package bundb_test
import (
"github.com/sirupsen/logrus"
@@ -27,7 +27,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
-type PGStandardTestSuite struct {
+type BunDBStandardTestSuite struct {
// standard suite interfaces
suite.Suite
config *config.Config
diff --git a/internal/db/pg/domain.go b/internal/db/bundb/domain.go
similarity index 98%
rename from internal/db/pg/domain.go
rename to internal/db/bundb/domain.go
index 4a9db09d6..6aa2b8ffe 100644
--- a/internal/db/pg/domain.go
+++ b/internal/db/bundb/domain.go
@@ -16,7 +16,7 @@
along with this program. If not, see .
*/
-package pg
+package bundb
import (
"context"
@@ -34,7 +34,6 @@ type domainDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) {
diff --git a/internal/db/pg/instance.go b/internal/db/bundb/instance.go
similarity index 98%
rename from internal/db/pg/instance.go
rename to internal/db/bundb/instance.go
index 946c4c441..f9364346e 100644
--- a/internal/db/pg/instance.go
+++ b/internal/db/bundb/instance.go
@@ -16,7 +16,7 @@
along with this program. If not, see .
*/
-package pg
+package bundb
import (
"context"
@@ -32,7 +32,6 @@ type instanceDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
diff --git a/internal/db/pg/media.go b/internal/db/bundb/media.go
similarity index 97%
rename from internal/db/pg/media.go
rename to internal/db/bundb/media.go
index dea26b8de..04e55ca62 100644
--- a/internal/db/pg/media.go
+++ b/internal/db/bundb/media.go
@@ -16,7 +16,7 @@
along with this program. If not, see .
*/
-package pg
+package bundb
import (
"context"
@@ -32,7 +32,6 @@ type mediaDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
func (m *mediaDB) newMediaQ(i interface{}) *bun.SelectQuery {
diff --git a/internal/db/pg/mention.go b/internal/db/bundb/mention.go
similarity index 98%
rename from internal/db/pg/mention.go
rename to internal/db/bundb/mention.go
index 5f61b93ec..a444f9b5f 100644
--- a/internal/db/pg/mention.go
+++ b/internal/db/bundb/mention.go
@@ -16,7 +16,7 @@
along with this program. If not, see .
*/
-package pg
+package bundb
import (
"context"
@@ -33,7 +33,6 @@ type mentionDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
cache cache.Cache
}
diff --git a/internal/db/pg/notification.go b/internal/db/bundb/notification.go
similarity index 98%
rename from internal/db/pg/notification.go
rename to internal/db/bundb/notification.go
index 497bfb056..1c30837ec 100644
--- a/internal/db/pg/notification.go
+++ b/internal/db/bundb/notification.go
@@ -16,7 +16,7 @@
along with this program. If not, see .
*/
-package pg
+package bundb
import (
"context"
@@ -33,7 +33,6 @@ type notificationDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
cache cache.Cache
}
diff --git a/internal/db/pg/relationship.go b/internal/db/bundb/relationship.go
similarity index 99%
rename from internal/db/pg/relationship.go
rename to internal/db/bundb/relationship.go
index f78179476..ccc604baf 100644
--- a/internal/db/pg/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -16,7 +16,7 @@
along with this program. If not, see .
*/
-package pg
+package bundb
import (
"context"
@@ -34,7 +34,6 @@ type relationshipDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery {
diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go
new file mode 100644
index 000000000..87e20673d
--- /dev/null
+++ b/internal/db/bundb/session.go
@@ -0,0 +1,85 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package bundb
+
+import (
+ "context"
+ "crypto/rand"
+
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/uptrace/bun"
+)
+
+type sessionDB struct {
+ config *config.Config
+ conn *bun.DB
+ log *logrus.Logger
+}
+
+func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
+ rs := new(gtsmodel.RouterSession)
+
+ q := s.conn.
+ NewSelect().
+ Model(rs).
+ Limit(1)
+
+ _, err := q.Exec(ctx)
+
+ err = processErrorResponse(err)
+
+ return rs, err
+}
+
+func (s *sessionDB) CreateSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) {
+ auth := make([]byte, 32)
+ crypt := make([]byte, 32)
+
+ if _, err := rand.Read(auth); err != nil {
+ return nil, err
+ }
+ if _, err := rand.Read(crypt); err != nil {
+ return nil, err
+ }
+
+ rid, err := id.NewULID()
+ if err != nil {
+ return nil, err
+ }
+
+ rs := >smodel.RouterSession{
+ ID: rid,
+ Auth: auth,
+ Crypt: crypt,
+ }
+
+ q := s.conn.
+ NewInsert().
+ Model(rs)
+
+ _, err = q.Exec(ctx)
+
+ err = processErrorResponse(err)
+
+ return rs, err
+}
diff --git a/internal/db/pg/status.go b/internal/db/bundb/status.go
similarity index 99%
rename from internal/db/pg/status.go
rename to internal/db/bundb/status.go
index e4609a116..da8d8ca41 100644
--- a/internal/db/pg/status.go
+++ b/internal/db/bundb/status.go
@@ -16,7 +16,7 @@
along with this program. If not, see .
*/
-package pg
+package bundb
import (
"container/list"
@@ -36,7 +36,6 @@ type statusDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
cache cache.Cache
}
diff --git a/internal/db/pg/status_test.go b/internal/db/bundb/status_test.go
similarity index 98%
rename from internal/db/pg/status_test.go
rename to internal/db/bundb/status_test.go
index e3b9f1867..513000577 100644
--- a/internal/db/pg/status_test.go
+++ b/internal/db/bundb/status_test.go
@@ -16,7 +16,7 @@
along with this program. If not, see .
*/
-package pg_test
+package bundb_test
import (
"context"
@@ -29,7 +29,7 @@ import (
)
type StatusTestSuite struct {
- PGStandardTestSuite
+ BunDBStandardTestSuite
}
func (suite *StatusTestSuite) SetupSuite() {
diff --git a/internal/db/pg/timeline.go b/internal/db/bundb/timeline.go
similarity index 98%
rename from internal/db/pg/timeline.go
rename to internal/db/bundb/timeline.go
index 0059f8319..b62ad4c50 100644
--- a/internal/db/pg/timeline.go
+++ b/internal/db/bundb/timeline.go
@@ -16,7 +16,7 @@
along with this program. If not, see .
*/
-package pg
+package bundb
import (
"context"
@@ -34,7 +34,6 @@ type timelineDB struct {
config *config.Config
conn *bun.DB
log *logrus.Logger
- cancel context.CancelFunc
}
func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
@@ -78,7 +77,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
// OR statuses posted by accountID itself (since a user should be able to see their own statuses).
//
// This is equivalent to something like WHERE ... AND (... OR ...)
- // See: https://pg.uptrace.dev/queries/#select
+ // See: https://bun.uptrace.dev/guide/queries.html#select
whereGroup := func(*bun.SelectQuery) *bun.SelectQuery {
return q.
WhereOr("f.account_id = ?", accountID).
diff --git a/internal/db/pg/util.go b/internal/db/bundb/util.go
similarity index 91%
rename from internal/db/pg/util.go
rename to internal/db/bundb/util.go
index 90e784c3b..115d18de2 100644
--- a/internal/db/pg/util.go
+++ b/internal/db/bundb/util.go
@@ -16,10 +16,11 @@
along with this program. If not, see .
*/
-package pg
+package bundb
import (
"context"
+ "strings"
"database/sql"
@@ -35,6 +36,9 @@ func processErrorResponse(err error) db.Error {
case sql.ErrNoRows:
return db.ErrNoEntries
default:
+ if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
+ return db.ErrAlreadyExists
+ }
return err
}
}
diff --git a/internal/db/db.go b/internal/db/db.go
index 71ac887bb..ec94fcfe7 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -40,6 +40,7 @@ type DB interface {
Mention
Notification
Relationship
+ Session
Status
Timeline
diff --git a/internal/db/session.go b/internal/db/session.go
new file mode 100644
index 000000000..ae13dccce
--- /dev/null
+++ b/internal/db/session.go
@@ -0,0 +1,31 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package db
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+// Session handles getting/creation of router sessions.
+type Session interface {
+ GetSession(ctx context.Context) (*gtsmodel.RouterSession, Error)
+ CreateSession(ctx context.Context) (*gtsmodel.RouterSession, Error)
+}
diff --git a/internal/federation/dereferencing/account.go b/internal/federation/dereferencing/account.go
index dbecd49f6..2eee0645d 100644
--- a/internal/federation/dereferencing/account.go
+++ b/internal/federation/dereferencing/account.go
@@ -24,6 +24,7 @@ import (
"errors"
"fmt"
"net/url"
+ "strings"
"github.com/go-fed/activity/streams"
"github.com/go-fed/activity/streams/vocab"
@@ -34,18 +35,33 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/transport"
)
+func instanceAccount(account *gtsmodel.Account) bool {
+ return strings.EqualFold(account.Username, account.Domain) ||
+ account.FollowersURI == "" ||
+ account.FollowingURI == "" ||
+ (account.Username == "internal.fetch" && strings.Contains(account.Note, "internal service actor"))
+}
+
// EnrichRemoteAccount takes an account that's already been inserted into the database in a minimal form,
// and populates it with additional fields, media, etc.
//
// EnrichRemoteAccount is mostly useful for calling after an account has been initially created by
// the federatingDB's Create function, or during the federated authorization flow.
func (d *deref) EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) {
+
+ // if we're dealing with an instance account, we don't need to update anything
+ if instanceAccount(account) {
+ return account, nil
+ }
+
if err := d.PopulateAccountFields(ctx, account, username, false); err != nil {
return nil, err
}
- if err := d.db.UpdateByID(ctx, account.ID, account); err != nil {
- return nil, fmt.Errorf("EnrichRemoteAccount: error updating account: %s", err)
+ var err error
+ account, err = d.db.UpdateAccount(ctx, account)
+ if err != nil {
+ d.log.Errorf("EnrichRemoteAccount: error updating account: %s", err)
}
return account, nil
@@ -108,8 +124,9 @@ func (d *deref) GetRemoteAccount(ctx context.Context, username string, remoteAcc
return nil, new, fmt.Errorf("FullyDereferenceAccount: error populating further account fields: %s", err)
}
- if err := d.db.UpdateByID(ctx, gtsAccount.ID, gtsAccount); err != nil {
- return nil, new, fmt.Errorf("FullyDereferenceAccount: error updating existing account: %s", err)
+ gtsAccount, err = d.db.UpdateAccount(ctx, gtsAccount)
+ if err != nil {
+ return nil, false, fmt.Errorf("EnrichRemoteAccount: error updating account: %s", err)
}
}
diff --git a/internal/federation/dereferencing/thread.go b/internal/federation/dereferencing/thread.go
index 328a1c4ee..f9dd9aa09 100644
--- a/internal/federation/dereferencing/thread.go
+++ b/internal/federation/dereferencing/thread.go
@@ -25,7 +25,6 @@ import (
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/ap"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
@@ -87,8 +86,8 @@ func (d *deref) iterateAncestors(ctx context.Context, username string, statusIRI
return err
}
- status := >smodel.Status{}
- if err := d.db.GetByID(ctx, id, status); err != nil {
+ status, err := d.db.GetStatusByID(ctx, id)
+ if err != nil {
return err
}
diff --git a/internal/federation/federatingdb/delete.go b/internal/federation/federatingdb/delete.go
index 94dfdf71f..11b818168 100644
--- a/internal/federation/federatingdb/delete.go
+++ b/internal/federation/federatingdb/delete.go
@@ -87,7 +87,7 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {
a, err := f.db.GetAccountByURI(ctx, id.String())
if err == nil {
// it's an account
- l.Debugf("uri is for an account with id: %s", s.ID)
+ l.Debugf("uri is for an account with id: %s", a.ID)
if err := f.db.DeleteByID(ctx, a.ID, >smodel.Account{}); err != nil {
return fmt.Errorf("DELETE: err deleting account: %s", err)
}
diff --git a/internal/federation/federatingdb/followers.go b/internal/federation/federatingdb/followers.go
index e0923453f..c7f636a12 100644
--- a/internal/federation/federatingdb/followers.go
+++ b/internal/federation/federatingdb/followers.go
@@ -51,13 +51,17 @@ func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (follow
followers = streams.NewActivityStreamsCollection()
items := streams.NewActivityStreamsItemsProperty()
for _, follow := range acctFollowers {
- gtsFollower := >smodel.Account{}
- if err := f.db.GetByID(ctx, follow.AccountID, gtsFollower); err != nil {
- return nil, fmt.Errorf("FOLLOWERS: db error getting account id %s: %s", follow.AccountID, err)
+ if follow.Account == nil {
+ followAccount, err := f.db.GetAccountByID(ctx, follow.AccountID)
+ if err != nil {
+ return nil, fmt.Errorf("FOLLOWERS: db error getting account id %s: %s", follow.AccountID, err)
+ }
+ follow.Account = followAccount
}
- uri, err := url.Parse(gtsFollower.URI)
+
+ uri, err := url.Parse(follow.Account.URI)
if err != nil {
- return nil, fmt.Errorf("FOLLOWERS: error parsing %s as url: %s", gtsFollower.URI, err)
+ return nil, fmt.Errorf("FOLLOWERS: error parsing %s as url: %s", follow.Account.URI, err)
}
items.AppendIRI(uri)
}
diff --git a/internal/federation/federatingdb/following.go b/internal/federation/federatingdb/following.go
index 963ba63e4..9d5c0693c 100644
--- a/internal/federation/federatingdb/following.go
+++ b/internal/federation/federatingdb/following.go
@@ -64,13 +64,17 @@ func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (follow
following = streams.NewActivityStreamsCollection()
items := streams.NewActivityStreamsItemsProperty()
for _, follow := range acctFollowing {
- gtsFollowing := >smodel.Account{}
- if err := f.db.GetByID(ctx, follow.AccountID, gtsFollowing); err != nil {
- return nil, fmt.Errorf("FOLLOWING: db error getting account id %s: %s", follow.AccountID, err)
+ if follow.Account == nil {
+ followAccount, err := f.db.GetAccountByID(ctx, follow.AccountID)
+ if err != nil {
+ return nil, fmt.Errorf("FOLLOWING: db error getting account id %s: %s", follow.AccountID, err)
+ }
+ follow.Account = followAccount
}
- uri, err := url.Parse(gtsFollowing.URI)
+
+ uri, err := url.Parse(follow.Account.URI)
if err != nil {
- return nil, fmt.Errorf("FOLLOWING: error parsing %s as url: %s", gtsFollowing.URI, err)
+ return nil, fmt.Errorf("FOLLOWING: error parsing %s as url: %s", follow.Account.URI, err)
}
items.AppendIRI(uri)
}
diff --git a/internal/federation/federatingdb/update.go b/internal/federation/federatingdb/update.go
index 324509ebd..e9dfe5315 100644
--- a/internal/federation/federatingdb/update.go
+++ b/internal/federation/federatingdb/update.go
@@ -152,7 +152,8 @@ func (f *federatingDB) Update(ctx context.Context, asType vocab.Type) error {
}
updatedAcct.ID = requestingAcct.ID // set this here so the db will update properly instead of trying to PUT this and getting constraint issues
- if err := f.db.UpdateByID(ctx, requestingAcct.ID, updatedAcct); err != nil {
+ updatedAcct, err = f.db.UpdateAccount(ctx, updatedAcct)
+ if err != nil {
return fmt.Errorf("UPDATE: database error inserting updated account: %s", err)
}
diff --git a/internal/gtsmodel/account.go b/internal/gtsmodel/account.go
index 8eeab1812..98d2dcfc9 100644
--- a/internal/gtsmodel/account.go
+++ b/internal/gtsmodel/account.go
@@ -18,8 +18,8 @@
// Package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database.
// These types should never be serialized and/or sent out via public APIs, as they contain sensitive information.
-// The annotation used on these structs is for handling them via the go-pg ORM (hence why they're in this db subdir).
-// See here for more info on go-pg model annotations: https://pg.uptrace.dev/models/
+// The annotation used on these structs is for handling them via the bun-db ORM.
+// See here for more info on bun model annotations: https://bun.uptrace.dev/guide/models.html
package gtsmodel
import (
@@ -34,9 +34,9 @@ type Account struct {
*/
// id of this account in the local database
- ID string `bun:"type:CHAR(26),pk,notnull,unique"`
+ ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"`
// Username of the account, should just be a string of [a-z0-9_]. Can be added to domain to create the full username in the form ``[username]@[domain]`` eg., ``user_96@example.org``
- Username string `bun:",notnull,unique:userdomain"` // username and domain should be unique *with* each other
+ Username string `bun:",notnull,unique:userdomain,nullzero"` // username and domain should be unique *with* each other
// Domain of the account, will be null if this is a local account, otherwise something like ``example.org`` or ``mastodon.social``. Should be unique with username.
Domain string `bun:",unique:userdomain,nullzero"` // username and domain should be unique *with* each other
@@ -93,21 +93,21 @@ type Account struct {
*/
// What is the activitypub URI for this account discovered by webfinger?
- URI string `bun:",unique"`
+ URI string `bun:",unique,nullzero"`
// At which URL can we see the user account in a web browser?
- URL string `bun:",unique"`
+ URL string `bun:",unique,nullzero"`
// Last time this account was located using the webfinger API.
LastWebfingeredAt time.Time `bun:",nullzero"`
// Address of this account's activitypub inbox, for sending activity to
- InboxURI string `bun:",unique"`
+ InboxURI string `bun:",unique,nullzero"`
// Address of this account's activitypub outbox
- OutboxURI string `bun:",unique"`
+ OutboxURI string `bun:",unique,nullzero"`
// URI for getting the following list of this account
- FollowingURI string `bun:",unique"`
+ FollowingURI string `bun:",unique,nullzero"`
// URI for getting the followers list of this account
- FollowersURI string `bun:",unique"`
+ FollowersURI string `bun:",unique,nullzero"`
// URL for getting the featured collection list of this account
- FeaturedCollectionURI string `bun:",unique"`
+ FeaturedCollectionURI string `bun:",unique,nullzero"`
// What type of activitypub actor is this account?
ActorType string
// This account is associated with x account id
@@ -137,7 +137,7 @@ type Account struct {
// Should we hide this account's collections?
HideCollections bool
// id of the database entry that caused this account to become suspended -- can be an account ID or a domain block ID
- SuspensionOrigin string `bun:"type:CHAR(26)"`
+ SuspensionOrigin string `bun:"type:CHAR(26),nullzero"`
}
// Field represents a key value field on an account, for things like pronouns, website, etc.
diff --git a/internal/gtsmodel/emoji.go b/internal/gtsmodel/emoji.go
index 549951ddd..3b02c14e7 100644
--- a/internal/gtsmodel/emoji.go
+++ b/internal/gtsmodel/emoji.go
@@ -73,5 +73,5 @@ type Emoji struct {
// Is this emoji visible in the admin emoji picker?
VisibleInPicker bool `bun:",notnull,default:true"`
// In which emoji category is this emoji visible?
- CategoryID string `bun:"type:CHAR(26),nullzero"`
+ CategoryID string `bun:"type:CHAR(26),nullzero"`
}
diff --git a/internal/gtsmodel/followrequest.go b/internal/gtsmodel/followrequest.go
index 4a98525be..5a6cb5e02 100644
--- a/internal/gtsmodel/followrequest.go
+++ b/internal/gtsmodel/followrequest.go
@@ -29,15 +29,15 @@ type FollowRequest struct {
// When was this follow request last updated?
UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
// Who does this follow request originate from?
- AccountID string `bun:"type:CHAR(26),unique:frsrctarget,notnull"`
- Account Account `bun:"rel:belongs-to"`
+ AccountID string `bun:"type:CHAR(26),unique:frsrctarget,notnull"`
+ Account *Account `bun:"rel:belongs-to"`
// Who is the target of this follow request?
- TargetAccountID string `bun:"type:CHAR(26),unique:frsrctarget,notnull"`
- TargetAccount Account `bun:"rel:belongs-to"`
+ TargetAccountID string `bun:"type:CHAR(26),unique:frsrctarget,notnull"`
+ TargetAccount *Account `bun:"rel:belongs-to"`
// Does this follow also want to see reblogs and not just posts?
ShowReblogs bool `bun:"default:true"`
// What is the activitypub URI of this follow request?
- URI string `bun:",unique"`
+ URI string `bun:",unique,nullzero"`
// does the following account want to be notified when the followed account posts?
Notify bool
}
diff --git a/internal/gtsmodel/status.go b/internal/gtsmodel/status.go
index b4772e96b..c9766a7f2 100644
--- a/internal/gtsmodel/status.go
+++ b/internal/gtsmodel/status.go
@@ -27,9 +27,9 @@ type Status struct {
// id of the status in the database
ID string `bun:"type:CHAR(26),pk,notnull"`
// uri at which this status is reachable
- URI string `bun:",unique"`
+ URI string `bun:",unique,nullzero"`
// web url for viewing this status
- URL string `bun:",unique"`
+ URL string `bun:",unique,nullzero"`
// the html-formatted content of this status
Content string
// Database IDs of any media attachments associated with this status
@@ -45,9 +45,9 @@ type Status struct {
EmojiIDs []string `bun:"emojis,array"`
Emojis []*Emoji `bun:"attached_emojis,m2m:status_to_emojis"` // https://bun.uptrace.dev/guide/relations.html#many-to-many-relation
// when was this status created?
- CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
+ CreatedAt time.Time `bun:",notnull,default:current_timestamp"`
// when was this status updated?
- UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
+ UpdatedAt time.Time `bun:",notnull,default:current_timestamp"`
// is this status from a local account?
Local bool
// which account posted this status?
@@ -93,17 +93,17 @@ type Status struct {
// StatusToTag is an intermediate struct to facilitate the many2many relationship between a status and one or more tags.
type StatusToTag struct {
- StatusID string `bun:"type:CHAR(26),unique:statustag"`
+ StatusID string `bun:"type:CHAR(26),unique:statustag,nullzero"`
Status *Status `bun:"rel:belongs-to"`
- TagID string `bun:"type:CHAR(26),unique:statustag"`
+ TagID string `bun:"type:CHAR(26),unique:statustag,nullzero"`
Tag *Tag `bun:"rel:belongs-to"`
}
// StatusToEmoji is an intermediate struct to facilitate the many2many relationship between a status and one or more emojis.
type StatusToEmoji struct {
- StatusID string `bun:"type:CHAR(26),unique:statusemoji"`
+ StatusID string `bun:"type:CHAR(26),unique:statusemoji,nullzero"`
Status *Status `bun:"rel:belongs-to"`
- EmojiID string `bun:"type:CHAR(26),unique:statusemoji"`
+ EmojiID string `bun:"type:CHAR(26),unique:statusemoji,nullzero"`
Emoji *Emoji `bun:"rel:belongs-to"`
}
diff --git a/internal/gtsmodel/user.go b/internal/gtsmodel/user.go
index 026fac9fc..f439be439 100644
--- a/internal/gtsmodel/user.go
+++ b/internal/gtsmodel/user.go
@@ -33,9 +33,9 @@ type User struct {
// id of this user in the local database; the end-user will never need to know this, it's strictly internal
ID string `bun:"type:CHAR(26),pk,notnull,unique"`
// confirmed email address for this user, this should be unique -- only one email address registered per instance, multiple users per email are not supported
- Email string `bun:"default:null,unique"`
+ Email string `bun:"default:null,unique,nullzero"`
// The id of the local gtsmodel.Account entry for this user, if it exists (unconfirmed users don't have an account yet)
- AccountID string `bun:"type:CHAR(26),unique"`
+ AccountID string `bun:"type:CHAR(26),unique,nullzero"`
Account *Account `bun:"rel:belongs-to"`
// The encrypted password of this user, generated using https://pkg.go.dev/golang.org/x/crypto/bcrypt#GenerateFromPassword. A salt is included so we're safe against 🌈 tables
EncryptedPassword string `bun:",notnull"`
diff --git a/internal/oauth/clientstore.go b/internal/oauth/clientstore.go
index ad8dbd91e..a642f6cfa 100644
--- a/internal/oauth/clientstore.go
+++ b/internal/oauth/clientstore.go
@@ -39,9 +39,7 @@ func NewClientStore(db db.Basic) oauth2.ClientStore {
}
func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
- poc := &Client{
- ID: clientID,
- }
+ poc := &Client{}
if err := cs.db.GetByID(ctx, clientID, poc); err != nil {
return nil, err
}
@@ -67,7 +65,7 @@ func (cs *clientStore) Delete(ctx context.Context, id string) error {
// Client is a handy little wrapper for typical oauth client details
type Client struct {
- ID string `pg:"type:CHAR(26),pk,notnull"`
+ ID string `bun:"type:CHAR(26),pk,notnull"`
Secret string
Domain string
UserID string
diff --git a/internal/oauth/tokenstore.go b/internal/oauth/tokenstore.go
index 33b4f7d00..264678ff5 100644
--- a/internal/oauth/tokenstore.go
+++ b/internal/oauth/tokenstore.go
@@ -43,13 +43,13 @@ type tokenStore struct {
// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
// the tokens in the DB once per minute and deletes any that have expired.
func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2.TokenStore {
- pts := &tokenStore{
+ ts := &tokenStore{
db: db,
log: log,
}
// set the token store to clean out expired tokens once per minute, or return if we're done
- go func(ctx context.Context, pts *tokenStore, log *logrus.Logger) {
+ go func(ctx context.Context, ts *tokenStore, log *logrus.Logger) {
cleanloop:
for {
select {
@@ -58,32 +58,32 @@ func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2.
break cleanloop
case <-time.After(1 * time.Minute):
log.Trace("sweeping out old oauth entries broom broom")
- if err := pts.sweep(ctx); err != nil {
+ if err := ts.sweep(ctx); err != nil {
log.Errorf("error while sweeping oauth entries: %s", err)
}
}
}
- }(ctx, pts, log)
- return pts
+ }(ctx, ts, log)
+ return ts
}
// sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so.
-func (pts *tokenStore) sweep(ctx context.Context) error {
+func (ts *tokenStore) sweep(ctx context.Context) error {
// select *all* tokens from the db
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
tokens := new([]*Token)
- if err := pts.db.GetAll(ctx, tokens); err != nil {
+ if err := ts.db.GetAll(ctx, tokens); err != nil {
return err
}
// iterate through and remove expired tokens
now := time.Now()
- for _, pgt := range *tokens {
+ for _, dbt := range *tokens {
// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So:
// we only want to check if a token expired before now if the expiry time is *not zero*;
// ie., if it's been explicity set.
- if !pgt.CodeExpiresAt.IsZero() && pgt.CodeExpiresAt.Before(now) || !pgt.RefreshExpiresAt.IsZero() && pgt.RefreshExpiresAt.Before(now) || !pgt.AccessExpiresAt.IsZero() && pgt.AccessExpiresAt.Before(now) {
- if err := pts.db.DeleteByID(ctx, pgt.ID, pgt); err != nil {
+ if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) {
+ if err := ts.db.DeleteByID(ctx, dbt.ID, dbt); err != nil {
return err
}
}
@@ -94,92 +94,92 @@ func (pts *tokenStore) sweep(ctx context.Context) error {
// Create creates and store the new token information.
// For the original implementation, see https://github.com/superseriousbusiness/oauth2/blob/master/store/token.go#L34
-func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
+func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
t, ok := info.(*models.Token)
if !ok {
return errors.New("info param was not a models.Token")
}
- pgt := TokenToPGToken(t)
- if pgt.ID == "" {
- pgtID, err := id.NewRandomULID()
+ dbt := TokenToDBToken(t)
+ if dbt.ID == "" {
+ dbtID, err := id.NewRandomULID()
if err != nil {
return err
}
- pgt.ID = pgtID
+ dbt.ID = dbtID
}
- if err := pts.db.Put(ctx, pgt); err != nil {
+ if err := ts.db.Put(ctx, dbt); err != nil {
return fmt.Errorf("error in tokenstore create: %s", err)
}
return nil
}
// RemoveByCode deletes a token from the DB based on the Code field
-func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
- return pts.db.DeleteWhere(ctx, []db.Where{{Key: "code", Value: code}}, &Token{})
+func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
+ return ts.db.DeleteWhere(ctx, []db.Where{{Key: "code", Value: code}}, &Token{})
}
// RemoveByAccess deletes a token from the DB based on the Access field
-func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
- return pts.db.DeleteWhere(ctx, []db.Where{{Key: "access", Value: access}}, &Token{})
+func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
+ return ts.db.DeleteWhere(ctx, []db.Where{{Key: "access", Value: access}}, &Token{})
}
// RemoveByRefresh deletes a token from the DB based on the Refresh field
-func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
- return pts.db.DeleteWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, &Token{})
+func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
+ return ts.db.DeleteWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, &Token{})
}
// GetByCode selects a token from the DB based on the Code field
-func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
+func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
if code == "" {
return nil, nil
}
- pgt := &Token{
+ dbt := &Token{
Code: code,
}
- if err := pts.db.GetWhere(ctx, []db.Where{{Key: "code", Value: code}}, pgt); err != nil {
+ if err := ts.db.GetWhere(ctx, []db.Where{{Key: "code", Value: code}}, dbt); err != nil {
return nil, err
}
- return TokenToOauthToken(pgt), nil
+ return DBTokenToToken(dbt), nil
}
// GetByAccess selects a token from the DB based on the Access field
-func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
+func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
if access == "" {
return nil, nil
}
- pgt := &Token{
+ dbt := &Token{
Access: access,
}
- if err := pts.db.GetWhere(ctx, []db.Where{{Key: "access", Value: access}}, pgt); err != nil {
+ if err := ts.db.GetWhere(ctx, []db.Where{{Key: "access", Value: access}}, dbt); err != nil {
return nil, err
}
- return TokenToOauthToken(pgt), nil
+ return DBTokenToToken(dbt), nil
}
// GetByRefresh selects a token from the DB based on the Refresh field
-func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
+func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
if refresh == "" {
return nil, nil
}
- pgt := &Token{
+ dbt := &Token{
Refresh: refresh,
}
- if err := pts.db.GetWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, pgt); err != nil {
+ if err := ts.db.GetWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, dbt); err != nil {
return nil, err
}
- return TokenToOauthToken(pgt), nil
+ return DBTokenToToken(dbt), nil
}
/*
- The following models are basically helpers for the postgres token store implementation, they should only be used internally.
+ The following models are basically helpers for the token store implementation, they should only be used internally.
*/
// Token is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt.
//
// Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined,
-// and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and
+// and tokens with expired TTLs are automatically removed. Since some databases don't have that feature, it's easier to set an expiry time and
// then periodically sweep out tokens when that time has passed.
//
// Note that this struct does *not* satisfy the token interface shown here: https://github.com/superseriousbusiness/oauth2/blob/master/model.go#L22
@@ -187,26 +187,26 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2
// As such, manual translation is always required between Token and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
// and pgTokenToOauthToken can be used for that.
type Token struct {
- ID string `pg:"type:CHAR(26),pk,notnull"`
+ ID string `bun:"type:CHAR(26),pk,notnull"`
ClientID string
UserID string
RedirectURI string
Scope string
- Code string `pg:"default:'',pk"`
+ Code string `bun:"default:'',pk"`
CodeChallenge string
CodeChallengeMethod string
- CodeCreateAt time.Time `pg:"type:timestamp"`
- CodeExpiresAt time.Time `pg:"type:timestamp"`
- Access string `pg:"default:'',pk"`
- AccessCreateAt time.Time `pg:"type:timestamp"`
- AccessExpiresAt time.Time `pg:"type:timestamp"`
- Refresh string `pg:"default:'',pk"`
- RefreshCreateAt time.Time `pg:"type:timestamp"`
- RefreshExpiresAt time.Time `pg:"type:timestamp"`
+ CodeCreateAt time.Time `bun:",nullzero"`
+ CodeExpiresAt time.Time `bun:",nullzero"`
+ Access string `bun:"default:'',pk"`
+ AccessCreateAt time.Time `bun:",nullzero"`
+ AccessExpiresAt time.Time `bun:",nullzero"`
+ Refresh string `bun:"default:'',pk"`
+ RefreshCreateAt time.Time `bun:",nullzero"`
+ RefreshExpiresAt time.Time `bun:",nullzero"`
}
-// TokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres
-func TokenToPGToken(tkn *models.Token) *Token {
+// TokenToDBToken is a lil util function that takes a gotosocial token and gives back a token for inserting into a database.
+func TokenToDBToken(tkn *models.Token) *Token {
now := time.Now()
// For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's
@@ -247,40 +247,40 @@ func TokenToPGToken(tkn *models.Token) *Token {
}
}
-// TokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token
-func TokenToOauthToken(pgt *Token) *models.Token {
+// DBTokenToToken is a lil util function that takes a database token and gives back a gotosocial token
+func DBTokenToToken(dbt *Token) *models.Token {
now := time.Now()
var codeExpiresIn time.Duration
- if !pgt.CodeExpiresAt.IsZero() {
- codeExpiresIn = pgt.CodeExpiresAt.Sub(now)
+ if !dbt.CodeExpiresAt.IsZero() {
+ codeExpiresIn = dbt.CodeExpiresAt.Sub(now)
}
var accessExpiresIn time.Duration
- if !pgt.AccessExpiresAt.IsZero() {
- accessExpiresIn = pgt.AccessExpiresAt.Sub(now)
+ if !dbt.AccessExpiresAt.IsZero() {
+ accessExpiresIn = dbt.AccessExpiresAt.Sub(now)
}
var refreshExpiresIn time.Duration
- if !pgt.RefreshExpiresAt.IsZero() {
- refreshExpiresIn = pgt.RefreshExpiresAt.Sub(now)
+ if !dbt.RefreshExpiresAt.IsZero() {
+ refreshExpiresIn = dbt.RefreshExpiresAt.Sub(now)
}
return &models.Token{
- ClientID: pgt.ClientID,
- UserID: pgt.UserID,
- RedirectURI: pgt.RedirectURI,
- Scope: pgt.Scope,
- Code: pgt.Code,
- CodeChallenge: pgt.CodeChallenge,
- CodeChallengeMethod: pgt.CodeChallengeMethod,
- CodeCreateAt: pgt.CodeCreateAt,
+ ClientID: dbt.ClientID,
+ UserID: dbt.UserID,
+ RedirectURI: dbt.RedirectURI,
+ Scope: dbt.Scope,
+ Code: dbt.Code,
+ CodeChallenge: dbt.CodeChallenge,
+ CodeChallengeMethod: dbt.CodeChallengeMethod,
+ CodeCreateAt: dbt.CodeCreateAt,
CodeExpiresIn: codeExpiresIn,
- Access: pgt.Access,
- AccessCreateAt: pgt.AccessCreateAt,
+ Access: dbt.Access,
+ AccessCreateAt: dbt.AccessCreateAt,
AccessExpiresIn: accessExpiresIn,
- Refresh: pgt.Refresh,
- RefreshCreateAt: pgt.RefreshCreateAt,
+ Refresh: dbt.Refresh,
+ RefreshCreateAt: dbt.RefreshCreateAt,
RefreshExpiresIn: refreshExpiresIn,
}
}
diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go
index a0758c846..d97af4d2e 100644
--- a/internal/processing/account/delete.go
+++ b/internal/processing/account/delete.go
@@ -177,17 +177,21 @@ selectStatusesLoop:
}
for _, b := range boosts {
- oa := >smodel.Account{}
- if err := p.db.GetByID(ctx, b.AccountID, oa); err == nil {
-
- l.Debug("putting boost undo in the client api channel")
- p.fromClientAPI <- gtsmodel.FromClientAPI{
- APObjectType: gtsmodel.ActivityStreamsAnnounce,
- APActivityType: gtsmodel.ActivityStreamsUndo,
- GTSModel: s,
- OriginAccount: oa,
- TargetAccount: account,
+ if b.Account == nil {
+ bAccount, err := p.db.GetAccountByID(ctx, b.AccountID)
+ if err != nil {
+ continue
}
+ b.Account = bAccount
+ }
+
+ l.Debug("putting boost undo in the client api channel")
+ p.fromClientAPI <- gtsmodel.FromClientAPI{
+ APObjectType: gtsmodel.ActivityStreamsAnnounce,
+ APActivityType: gtsmodel.ActivityStreamsUndo,
+ GTSModel: s,
+ OriginAccount: b.Account,
+ TargetAccount: account,
}
if err := p.db.DeleteByID(ctx, b.ID, b); err != nil {
@@ -267,7 +271,8 @@ selectStatusesLoop:
account.SuspendedAt = time.Now()
account.SuspensionOrigin = origin
- if err := p.db.UpdateByID(ctx, account.ID, account); err != nil {
+ account, err := p.db.UpdateAccount(ctx, account)
+ if err != nil {
return err
}
diff --git a/internal/processing/account/get.go b/internal/processing/account/get.go
index 39ae1c376..5f039127c 100644
--- a/internal/processing/account/get.go
+++ b/internal/processing/account/get.go
@@ -29,8 +29,8 @@ import (
)
func (p *processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) {
- targetAccount := >smodel.Account{}
- if err := p.db.GetByID(ctx, targetAccountID, targetAccount); err != nil {
+ targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID)
+ if err != nil {
if err == db.ErrNoEntries {
return nil, errors.New("account not found")
}
@@ -38,7 +38,6 @@ func (p *processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account
}
var blocked bool
- var err error
if requestingAccount != nil {
blocked, err = p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true)
if err != nil {
diff --git a/internal/processing/account/removefollow.go b/internal/processing/account/removefollow.go
index c271de79f..6186c550f 100644
--- a/internal/processing/account/removefollow.go
+++ b/internal/processing/account/removefollow.go
@@ -39,8 +39,8 @@ func (p *processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
}
// make sure the target account actually exists in our db
- targetAcct := >smodel.Account{}
- if err := p.db.GetByID(ctx, targetAccountID, targetAcct); err != nil {
+ targetAcct, err := p.db.GetAccountByID(ctx, targetAccountID)
+ if err != nil {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("AccountFollowRemove: account %s not found in the db: %s", targetAccountID, err))
}
diff --git a/internal/processing/account/update.go b/internal/processing/account/update.go
index 339542760..99ccbf5a0 100644
--- a/internal/processing/account/update.go
+++ b/internal/processing/account/update.go
@@ -117,8 +117,8 @@ func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, form
}
// fetch the account with all updated values set
- updatedAccount := >smodel.Account{}
- if err := p.db.GetByID(ctx, account.ID, updatedAccount); err != nil {
+ updatedAccount, err := p.db.GetAccountByID(ctx, account.ID)
+ if err != nil {
return nil, fmt.Errorf("could not fetch updated account %s: %s", account.ID, err)
}
diff --git a/internal/processing/followrequest.go b/internal/processing/followrequest.go
index da31d14fc..3dd6432e2 100644
--- a/internal/processing/followrequest.go
+++ b/internal/processing/followrequest.go
@@ -38,11 +38,15 @@ func (p *processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]
accts := []apimodel.Account{}
for _, fr := range frs {
- acct := >smodel.Account{}
- if err := p.db.GetByID(ctx, fr.AccountID, acct); err != nil {
- return nil, gtserror.NewErrorInternalError(err)
+ if fr.Account == nil {
+ frAcct, err := p.db.GetAccountByID(ctx, fr.AccountID)
+ if err != nil {
+ return nil, gtserror.NewErrorInternalError(err)
+ }
+ fr.Account = frAcct
}
- mastoAcct, err := p.tc.AccountToMastoPublic(ctx, acct)
+
+ mastoAcct, err := p.tc.AccountToMastoPublic(ctx, fr.Account)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -57,22 +61,28 @@ func (p *processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a
return nil, gtserror.NewErrorNotFound(err)
}
- originAccount := >smodel.Account{}
- if err := p.db.GetByID(ctx, follow.AccountID, originAccount); err != nil {
- return nil, gtserror.NewErrorInternalError(err)
+ if follow.Account == nil {
+ followAccount, err := p.db.GetAccountByID(ctx, follow.AccountID)
+ if err != nil {
+ return nil, gtserror.NewErrorInternalError(err)
+ }
+ follow.Account = followAccount
}
- targetAccount := >smodel.Account{}
- if err := p.db.GetByID(ctx, follow.TargetAccountID, targetAccount); err != nil {
- return nil, gtserror.NewErrorInternalError(err)
+ if follow.TargetAccount == nil {
+ followTargetAccount, err := p.db.GetAccountByID(ctx, follow.TargetAccountID)
+ if err != nil {
+ return nil, gtserror.NewErrorInternalError(err)
+ }
+ follow.TargetAccount = followTargetAccount
}
p.fromClientAPI <- gtsmodel.FromClientAPI{
APObjectType: gtsmodel.ActivityStreamsFollow,
APActivityType: gtsmodel.ActivityStreamsAccept,
GTSModel: follow,
- OriginAccount: originAccount,
- TargetAccount: targetAccount,
+ OriginAccount: follow.Account,
+ TargetAccount: follow.TargetAccount,
}
gtsR, err := p.db.GetRelationship(ctx, auth.Account.ID, accountID)
diff --git a/internal/processing/fromclientapi.go b/internal/processing/fromclientapi.go
index 2e0f2c0c4..a6ea0068b 100644
--- a/internal/processing/fromclientapi.go
+++ b/internal/processing/fromclientapi.go
@@ -238,11 +238,11 @@ func (p *processor) processFromClientAPI(ctx context.Context, clientMsg gtsmodel
func (p *processor) federateStatus(ctx context.Context, status *gtsmodel.Status) error {
if status.Account == nil {
- a := >smodel.Account{}
- if err := p.db.GetByID(ctx, status.AccountID, a); err != nil {
+ statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID)
+ if err != nil {
return fmt.Errorf("federateStatus: error fetching status author account: %s", err)
}
- status.Account = a
+ status.Account = statusAccount
}
// do nothing if this isn't our status
@@ -266,11 +266,11 @@ func (p *processor) federateStatus(ctx context.Context, status *gtsmodel.Status)
func (p *processor) federateStatusDelete(ctx context.Context, status *gtsmodel.Status) error {
if status.Account == nil {
- a := >smodel.Account{}
- if err := p.db.GetByID(ctx, status.AccountID, a); err != nil {
- return fmt.Errorf("federateStatus: error fetching status author account: %s", err)
+ statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID)
+ if err != nil {
+ return fmt.Errorf("federateStatusDelete: error fetching status author account: %s", err)
}
- status.Account = a
+ status.Account = statusAccount
}
// do nothing if this isn't our status
@@ -558,19 +558,19 @@ func (p *processor) federateAccountUpdate(ctx context.Context, updatedAccount *g
func (p *processor) federateBlock(ctx context.Context, block *gtsmodel.Block) error {
if block.Account == nil {
- a := >smodel.Account{}
- if err := p.db.GetByID(ctx, block.AccountID, a); err != nil {
+ blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID)
+ if err != nil {
return fmt.Errorf("federateBlock: error getting block account from database: %s", err)
}
- block.Account = a
+ block.Account = blockAccount
}
if block.TargetAccount == nil {
- a := >smodel.Account{}
- if err := p.db.GetByID(ctx, block.TargetAccountID, a); err != nil {
+ blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID)
+ if err != nil {
return fmt.Errorf("federateBlock: error getting block target account from database: %s", err)
}
- block.TargetAccount = a
+ block.TargetAccount = blockTargetAccount
}
// if both accounts are local there's nothing to do here
@@ -594,19 +594,19 @@ func (p *processor) federateBlock(ctx context.Context, block *gtsmodel.Block) er
func (p *processor) federateUnblock(ctx context.Context, block *gtsmodel.Block) error {
if block.Account == nil {
- a := >smodel.Account{}
- if err := p.db.GetByID(ctx, block.AccountID, a); err != nil {
+ blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID)
+ if err != nil {
return fmt.Errorf("federateUnblock: error getting block account from database: %s", err)
}
- block.Account = a
+ block.Account = blockAccount
}
if block.TargetAccount == nil {
- a := >smodel.Account{}
- if err := p.db.GetByID(ctx, block.TargetAccountID, a); err != nil {
+ blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID)
+ if err != nil {
return fmt.Errorf("federateUnblock: error getting block target account from database: %s", err)
}
- block.TargetAccount = a
+ block.TargetAccount = blockTargetAccount
}
// if both accounts are local there's nothing to do here
diff --git a/internal/processing/media/delete.go b/internal/processing/media/delete.go
index 878425800..281ddba03 100644
--- a/internal/processing/media/delete.go
+++ b/internal/processing/media/delete.go
@@ -7,12 +7,11 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (p *processor) Delete(ctx context.Context, mediaAttachmentID string) gtserror.WithCode {
- a := >smodel.MediaAttachment{}
- if err := p.db.GetByID(ctx, mediaAttachmentID, a); err != nil {
+ attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID)
+ if err != nil {
if err == db.ErrNoEntries {
// attachment already gone
return nil
@@ -24,21 +23,21 @@ func (p *processor) Delete(ctx context.Context, mediaAttachmentID string) gtserr
errs := []string{}
// delete the thumbnail from storage
- if a.Thumbnail.Path != "" {
- if err := p.storage.RemoveFileAt(a.Thumbnail.Path); err != nil {
- errs = append(errs, fmt.Sprintf("remove thumbnail at path %s: %s", a.Thumbnail.Path, err))
+ if attachment.Thumbnail.Path != "" {
+ if err := p.storage.RemoveFileAt(attachment.Thumbnail.Path); err != nil {
+ errs = append(errs, fmt.Sprintf("remove thumbnail at path %s: %s", attachment.Thumbnail.Path, err))
}
}
// delete the file from storage
- if a.File.Path != "" {
- if err := p.storage.RemoveFileAt(a.File.Path); err != nil {
- errs = append(errs, fmt.Sprintf("remove file at path %s: %s", a.File.Path, err))
+ if attachment.File.Path != "" {
+ if err := p.storage.RemoveFileAt(attachment.File.Path); err != nil {
+ errs = append(errs, fmt.Sprintf("remove file at path %s: %s", attachment.File.Path, err))
}
}
// delete the attachment
- if err := p.db.DeleteByID(ctx, mediaAttachmentID, a); err != nil {
+ if err := p.db.DeleteByID(ctx, mediaAttachmentID, attachment); err != nil {
if err != db.ErrNoEntries {
errs = append(errs, fmt.Sprintf("remove attachment: %s", err))
}
diff --git a/internal/processing/media/getfile.go b/internal/processing/media/getfile.go
index 6666c6423..c9c9b556d 100644
--- a/internal/processing/media/getfile.go
+++ b/internal/processing/media/getfile.go
@@ -48,8 +48,8 @@ func (p *processor) GetFile(ctx context.Context, account *gtsmodel.Account, form
wantedMediaID := spl[0]
// get the account that owns the media and make sure it's not suspended
- acct := >smodel.Account{}
- if err := p.db.GetByID(ctx, form.AccountID, acct); err != nil {
+ acct, err := p.db.GetAccountByID(ctx, form.AccountID)
+ if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("account with id %s could not be selected from the db: %s", form.AccountID, err))
}
if !acct.SuspendedAt.IsZero() {
@@ -91,8 +91,8 @@ func (p *processor) GetFile(ctx context.Context, account *gtsmodel.Account, form
return nil, gtserror.NewErrorNotFound(fmt.Errorf("media size %s not recognized for emoji", mediaSize))
}
case media.Attachment, media.Header, media.Avatar:
- a := >smodel.MediaAttachment{}
- if err := p.db.GetByID(ctx, wantedMediaID, a); err != nil {
+ a, err := p.db.GetAttachmentByID(ctx, wantedMediaID)
+ if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("attachment %s could not be taken from the db: %s", wantedMediaID, err))
}
if a.AccountID != form.AccountID {
diff --git a/internal/processing/media/getmedia.go b/internal/processing/media/getmedia.go
index 48b7b351a..91608e90d 100644
--- a/internal/processing/media/getmedia.go
+++ b/internal/processing/media/getmedia.go
@@ -30,8 +30,8 @@ import (
)
func (p *processor) GetMedia(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) {
- attachment := >smodel.MediaAttachment{}
- if err := p.db.GetByID(ctx, mediaAttachmentID, attachment); err != nil {
+ attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID)
+ if err != nil {
if err == db.ErrNoEntries {
// attachment doesn't exist
return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db"))
diff --git a/internal/processing/media/update.go b/internal/processing/media/update.go
index 5402fd075..6f15f2ace 100644
--- a/internal/processing/media/update.go
+++ b/internal/processing/media/update.go
@@ -31,8 +31,8 @@ import (
)
func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) {
- attachment := >smodel.MediaAttachment{}
- if err := p.db.GetByID(ctx, mediaAttachmentID, attachment); err != nil {
+ attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID)
+ if err != nil {
if err == db.ErrNoEntries {
// attachment doesn't exist
return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db"))
diff --git a/internal/processing/streaming/authorize.go b/internal/processing/streaming/authorize.go
index 543145e2a..f938a0c0c 100644
--- a/internal/processing/streaming/authorize.go
+++ b/internal/processing/streaming/authorize.go
@@ -24,8 +24,8 @@ func (p *processor) AuthorizeStreamingRequest(ctx context.Context, accessToken s
return nil, fmt.Errorf("AuthorizeStreamingRequest: no user found for validated uid %s", uid)
}
- acct := >smodel.Account{}
- if err := p.db.GetByID(ctx, user.AccountID, acct); err != nil || acct == nil {
+ acct, err := p.db.GetAccountByID(ctx, user.AccountID)
+ if err != nil || acct == nil {
return nil, fmt.Errorf("AuthorizeStreamingRequest: no account retrieved for user with id %s", uid)
}
diff --git a/internal/router/session.go b/internal/router/session.go
index f336521d2..4359a8a60 100644
--- a/internal/router/session.go
+++ b/internal/router/session.go
@@ -20,7 +20,6 @@ package router
import (
"context"
- "crypto/rand"
"errors"
"fmt"
"net/http"
@@ -30,8 +29,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/id"
)
// SessionOptions returns the standard set of options to use for each session.
@@ -46,34 +43,23 @@ func SessionOptions(cfg *config.Config) sessions.Options {
}
}
-func useSession(ctx context.Context, cfg *config.Config, dbService db.DB, engine *gin.Engine) error {
+func useSession(ctx context.Context, cfg *config.Config, sessionDB db.Session, engine *gin.Engine) error {
// check if we have a saved router session already
- routerSessions := []*gtsmodel.RouterSession{}
- if err := dbService.GetAll(ctx, &routerSessions); err != nil {
+ rs, err := sessionDB.GetSession(ctx)
+ if err != nil {
if err != db.ErrNoEntries {
// proper error occurred
return err
}
- }
-
- var rs *gtsmodel.RouterSession
- if len(routerSessions) == 1 {
- // we have a router session stored
- rs = routerSessions[0]
- } else if len(routerSessions) == 0 {
- // we have no router sessions so we need to create a new one
- var err error
- rs, err = routerSession(ctx, dbService)
+ // no session saved so create a new one
+ rs, err = sessionDB.CreateSession(ctx)
if err != nil {
- return fmt.Errorf("error creating new router session: %s", err)
+ return err
}
- } else {
- // we should only have one router session stored ever
- return errors.New("we had more than one router session in the db")
}
if rs == nil {
- return errors.New("error getting or creating router session: router session was nil")
+ return errors.New("router session was nil")
}
store := memstore.NewStore(rs.Auth, rs.Crypt)
@@ -82,34 +68,3 @@ func useSession(ctx context.Context, cfg *config.Config, dbService db.DB, engine
engine.Use(sessions.Sessions(sessionName, store))
return nil
}
-
-// routerSession generates a new router session with random auth and crypt bytes,
-// puts it in the database for persistence, and returns it for use.
-func routerSession(ctx context.Context, dbService db.DB) (*gtsmodel.RouterSession, error) {
- auth := make([]byte, 32)
- crypt := make([]byte, 32)
-
- if _, err := rand.Read(auth); err != nil {
- return nil, err
- }
- if _, err := rand.Read(crypt); err != nil {
- return nil, err
- }
-
- rid, err := id.NewULID()
- if err != nil {
- return nil, err
- }
-
- rs := >smodel.RouterSession{
- ID: rid,
- Auth: auth,
- Crypt: crypt,
- }
-
- if err := dbService.Put(ctx, rs); err != nil {
- return nil, err
- }
-
- return rs, nil
-}
diff --git a/internal/text/common.go b/internal/text/common.go
index ecf7a7a98..a8d585a09 100644
--- a/internal/text/common.go
+++ b/internal/text/common.go
@@ -98,8 +98,8 @@ func (f *formatter) ReplaceMentions(ctx context.Context, in string, mentions []*
// got it from the mention
targetAccount = menchie.OriginAccount
} else {
- a := >smodel.Account{}
- if err := f.db.GetByID(ctx, menchie.TargetAccountID, a); err == nil {
+ a, err := f.db.GetAccountByID(ctx, menchie.TargetAccountID)
+ if err == nil {
// got it from the db
targetAccount = a
} else {
diff --git a/internal/timeline/get.go b/internal/timeline/get.go
index c22935550..a00613dc0 100644
--- a/internal/timeline/get.go
+++ b/internal/timeline/get.go
@@ -54,7 +54,8 @@ func (t *timeline) Get(ctx context.Context, amount int, maxID string, sinceID st
if prepareNext {
// already cache the next query to speed up scrolling
go func() {
- if err := t.prepareNextQuery(ctx, amount, nextMaxID, "", ""); err != nil {
+ // use context.Background() because we don't want the query to abort when the request finishes
+ if err := t.prepareNextQuery(context.Background(), amount, nextMaxID, "", ""); err != nil {
l.Errorf("error preparing next query: %s", err)
}
}()
@@ -73,7 +74,8 @@ func (t *timeline) Get(ctx context.Context, amount int, maxID string, sinceID st
if prepareNext {
// already cache the next query to speed up scrolling
go func() {
- if err := t.prepareNextQuery(ctx, amount, nextMaxID, "", ""); err != nil {
+ // use context.Background() because we don't want the query to abort when the request finishes
+ if err := t.prepareNextQuery(context.Background(), amount, nextMaxID, "", ""); err != nil {
l.Errorf("error preparing next query: %s", err)
}
}()
diff --git a/internal/typeutils/internaltoas.go b/internal/typeutils/internaltoas.go
index abba3ea35..14ed094c5 100644
--- a/internal/typeutils/internaltoas.go
+++ b/internal/typeutils/internaltoas.go
@@ -214,9 +214,12 @@ func (c *converter) AccountToAS(ctx context.Context, a *gtsmodel.Account) (vocab
// icon
// Used as profile avatar.
if a.AvatarMediaAttachmentID != "" {
- avatar := >smodel.MediaAttachment{}
- if err := c.db.GetByID(ctx, a.AvatarMediaAttachmentID, avatar); err != nil {
- return nil, err
+ if a.AvatarMediaAttachment == nil {
+ avatar := >smodel.MediaAttachment{}
+ if err := c.db.GetByID(ctx, a.AvatarMediaAttachmentID, avatar); err != nil {
+ return nil, err
+ }
+ a.AvatarMediaAttachment = avatar
}
iconProperty := streams.NewActivityStreamsIconProperty()
@@ -224,11 +227,11 @@ func (c *converter) AccountToAS(ctx context.Context, a *gtsmodel.Account) (vocab
iconImage := streams.NewActivityStreamsImage()
mediaType := streams.NewActivityStreamsMediaTypeProperty()
- mediaType.Set(avatar.File.ContentType)
+ mediaType.Set(a.AvatarMediaAttachment.File.ContentType)
iconImage.SetActivityStreamsMediaType(mediaType)
avatarURLProperty := streams.NewActivityStreamsUrlProperty()
- avatarURL, err := url.Parse(avatar.URL)
+ avatarURL, err := url.Parse(a.AvatarMediaAttachment.URL)
if err != nil {
return nil, err
}
@@ -242,9 +245,12 @@ func (c *converter) AccountToAS(ctx context.Context, a *gtsmodel.Account) (vocab
// image
// Used as profile header.
if a.HeaderMediaAttachmentID != "" {
- header := >smodel.MediaAttachment{}
- if err := c.db.GetByID(ctx, a.HeaderMediaAttachmentID, header); err != nil {
- return nil, err
+ if a.HeaderMediaAttachment == nil {
+ header := >smodel.MediaAttachment{}
+ if err := c.db.GetByID(ctx, a.HeaderMediaAttachmentID, header); err != nil {
+ return nil, err
+ }
+ a.HeaderMediaAttachment = header
}
headerProperty := streams.NewActivityStreamsImageProperty()
@@ -252,11 +258,11 @@ func (c *converter) AccountToAS(ctx context.Context, a *gtsmodel.Account) (vocab
headerImage := streams.NewActivityStreamsImage()
mediaType := streams.NewActivityStreamsMediaTypeProperty()
- mediaType.Set(header.File.ContentType)
+ mediaType.Set(a.HeaderMediaAttachment.File.ContentType)
headerImage.SetActivityStreamsMediaType(mediaType)
headerURLProperty := streams.NewActivityStreamsUrlProperty()
- headerURL, err := url.Parse(header.URL)
+ headerURL, err := url.Parse(a.HeaderMediaAttachment.URL)
if err != nil {
return nil, err
}
diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go
index 52e394698..89da9eb01 100644
--- a/internal/typeutils/internaltofrontend.go
+++ b/internal/typeutils/internaltofrontend.go
@@ -63,6 +63,10 @@ func (c *converter) AccountToMastoSensitive(ctx context.Context, a *gtsmodel.Acc
}
func (c *converter) AccountToMastoPublic(ctx context.Context, a *gtsmodel.Account) (*model.Account, error) {
+ if a == nil {
+ return nil, fmt.Errorf("given account was nil")
+ }
+
// first check if we have this account in our frontEnd cache
if accountI, err := c.frontendCache.Fetch(a.ID); err == nil {
if account, ok := accountI.(*model.Account); ok {
@@ -266,27 +270,30 @@ func (c *converter) AttachmentToMasto(ctx context.Context, a *gtsmodel.MediaAtta
}
func (c *converter) MentionToMasto(ctx context.Context, m *gtsmodel.Mention) (model.Mention, error) {
- target := >smodel.Account{}
- if err := c.db.GetByID(ctx, m.TargetAccountID, target); err != nil {
- return model.Mention{}, err
+ if m.TargetAccount == nil {
+ targetAccount, err := c.db.GetAccountByID(ctx, m.TargetAccountID)
+ if err != nil {
+ return model.Mention{}, err
+ }
+ m.TargetAccount = targetAccount
}
var local bool
- if target.Domain == "" {
+ if m.TargetAccount.Domain == "" {
local = true
}
var acct string
if local {
- acct = target.Username
+ acct = m.TargetAccount.Username
} else {
- acct = fmt.Sprintf("%s@%s", target.Username, target.Domain)
+ acct = fmt.Sprintf("%s@%s", m.TargetAccount.Username, m.TargetAccount.Domain)
}
return model.Mention{
- ID: target.ID,
- Username: target.Username,
- URL: target.URL,
+ ID: m.TargetAccount.ID,
+ Username: m.TargetAccount.Username,
+ URL: m.TargetAccount.URL,
Acct: acct,
}, nil
}
@@ -302,11 +309,9 @@ func (c *converter) EmojiToMasto(ctx context.Context, e *gtsmodel.Emoji) (model.
}
func (c *converter) TagToMasto(ctx context.Context, t *gtsmodel.Tag) (model.Tag, error) {
- tagURL := fmt.Sprintf("%s://%s/tags/%s", c.config.Protocol, c.config.Host, t.Name)
-
return model.Tag{
Name: t.Name,
- URL: tagURL, // we don't serve URLs with collections of tagged statuses (FOR NOW) so this is purely for mastodon compatibility ¯\_(ツ)_/¯
+ URL: t.URL,
}, nil
}
@@ -331,8 +336,8 @@ func (c *converter) StatusToMasto(ctx context.Context, s *gtsmodel.Status, reque
// the boosted status might have been set on this struct already so check first before doing db calls
if s.BoostOf == nil {
// it's not set so fetch it from the db
- bs := >smodel.Status{}
- if err := c.db.GetByID(ctx, s.BoostOfID, bs); err != nil {
+ bs, err := c.db.GetStatusByID(ctx, s.BoostOfID)
+ if err != nil {
return nil, fmt.Errorf("error getting boosted status with id %s: %s", s.BoostOfID, err)
}
s.BoostOf = bs
@@ -341,8 +346,8 @@ func (c *converter) StatusToMasto(ctx context.Context, s *gtsmodel.Status, reque
// the boosted account might have been set on this struct already or passed as a param so check first before doing db calls
if s.BoostOfAccount == nil {
// it's not set so fetch it from the db
- ba := >smodel.Account{}
- if err := c.db.GetByID(ctx, s.BoostOf.AccountID, ba); err != nil {
+ ba, err := c.db.GetAccountByID(ctx, s.BoostOf.AccountID)
+ if err != nil {
return nil, fmt.Errorf("error getting boosted account %s from status with id %s: %s", s.BoostOf.AccountID, s.BoostOfID, err)
}
s.BoostOfAccount = ba
@@ -368,8 +373,8 @@ func (c *converter) StatusToMasto(ctx context.Context, s *gtsmodel.Status, reque
}
if s.Account == nil {
- a := >smodel.Account{}
- if err := c.db.GetByID(ctx, s.AccountID, a); err != nil {
+ a, err := c.db.GetAccountByID(ctx, s.AccountID)
+ if err != nil {
return nil, fmt.Errorf("error getting status author: %s", err)
}
s.Account = a
@@ -394,14 +399,14 @@ func (c *converter) StatusToMasto(ctx context.Context, s *gtsmodel.Status, reque
// the status doesn't have gts attachments on it, but it does have attachment IDs
// in this case, we need to pull the gts attachments from the db to convert them into masto ones
} else {
- for _, a := range s.AttachmentIDs {
- gtsAttachment := >smodel.MediaAttachment{}
- if err := c.db.GetByID(ctx, a, gtsAttachment); err != nil {
- return nil, fmt.Errorf("error getting attachment with id %s: %s", a, err)
+ for _, aID := range s.AttachmentIDs {
+ gtsAttachment, err := c.db.GetAttachmentByID(ctx, aID)
+ if err != nil {
+ return nil, fmt.Errorf("error getting attachment with id %s: %s", aID, err)
}
mastoAttachment, err := c.AttachmentToMasto(ctx, gtsAttachment)
if err != nil {
- return nil, fmt.Errorf("error converting attachment with id %s: %s", a, err)
+ return nil, fmt.Errorf("error converting attachment with id %s: %s", aID, err)
}
mastoAttachments = append(mastoAttachments, mastoAttachment)
}
@@ -421,10 +426,10 @@ func (c *converter) StatusToMasto(ctx context.Context, s *gtsmodel.Status, reque
// the status doesn't have gts mentions on it, but it does have mention IDs
// in this case, we need to pull the gts mentions from the db to convert them into masto ones
} else {
- for _, m := range s.MentionIDs {
- gtsMention := >smodel.Mention{}
- if err := c.db.GetByID(ctx, m, gtsMention); err != nil {
- return nil, fmt.Errorf("error getting mention with id %s: %s", m, err)
+ for _, mID := range s.MentionIDs {
+ gtsMention, err := c.db.GetMention(ctx, mID)
+ if err != nil {
+ return nil, fmt.Errorf("error getting mention with id %s: %s", mID, err)
}
mastoMention, err := c.MentionToMasto(ctx, gtsMention)
if err != nil {
@@ -603,13 +608,16 @@ func (c *converter) InstanceToMasto(ctx context.Context, i *gtsmodel.Instance) (
// contact account is optional but let's try to get it
if i.ContactAccountID != "" {
- ia := >smodel.Account{}
- if err := c.db.GetByID(ctx, i.ContactAccountID, ia); err == nil {
- ma, err := c.AccountToMastoPublic(ctx, ia)
+ if i.ContactAccount == nil {
+ contactAccount, err := c.db.GetAccountByID(ctx, i.ContactAccountID)
if err == nil {
- mi.ContactAccount = ma
+ i.ContactAccount = contactAccount
}
}
+ ma, err := c.AccountToMastoPublic(ctx, i.ContactAccount)
+ if err == nil {
+ mi.ContactAccount = ma
+ }
}
return mi, nil
diff --git a/testrig/db.go b/testrig/db.go
index 3901bea77..b0e2afd04 100644
--- a/testrig/db.go
+++ b/testrig/db.go
@@ -24,7 +24,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/db/pg"
+ "github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
@@ -68,7 +68,7 @@ func NewTestDB() db.DB {
l := logrus.New()
l.SetLevel(logrus.TraceLevel)
- testDB, err := pg.NewPostgresService(context.Background(), config, l)
+ testDB, err := bundb.NewBunDBService(context.Background(), config, l)
if err != nil {
logrus.Panic(err)
}
diff --git a/vendor/github.com/uptrace/bun/README.md b/vendor/github.com/uptrace/bun/README.md
index bf57440da..e7cc77a60 100644
--- a/vendor/github.com/uptrace/bun/README.md
+++ b/vendor/github.com/uptrace/bun/README.md
@@ -11,10 +11,6 @@
[](https://bun.uptrace.dev/)
[](https://discord.gg/rWtp5Aj)
-Status: API freeze (release candidate). Note that all sub-packages (mainly extra/\* packages) are
-not part of the API freeze and are developed independently. You can think of them as 3-rd party
-packages.
-
Main features are:
- Works with [PostgreSQL](https://bun.uptrace.dev/guide/drivers.html#postgresql),
diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod
index c47d595d5..0cad1ce5b 100644
--- a/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod
+++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod
@@ -4,4 +4,4 @@ go 1.16
replace github.com/uptrace/bun => ../..
-require github.com/uptrace/bun v1.0.0-rc.1
+require github.com/uptrace/bun v0.4.3
diff --git a/vendor/github.com/uptrace/bun/go.sum b/vendor/github.com/uptrace/bun/go.sum
index 1f5ad5409..3bf0a4a3f 100644
--- a/vendor/github.com/uptrace/bun/go.sum
+++ b/vendor/github.com/uptrace/bun/go.sum
@@ -20,4 +20,4 @@ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBc
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
-gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
\ No newline at end of file
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/vendor/github.com/uptrace/bun/schema/formatter.go b/vendor/github.com/uptrace/bun/schema/formatter.go
index 45a246307..7b26fbaca 100644
--- a/vendor/github.com/uptrace/bun/schema/formatter.go
+++ b/vendor/github.com/uptrace/bun/schema/formatter.go
@@ -89,10 +89,10 @@ func (f Formatter) AppendQuery(dst []byte, query string, args ...interface{}) []
func (f Formatter) append(dst []byte, p *parser.Parser, args []interface{}) []byte {
var namedArgs NamedArgAppender
if len(args) == 1 {
- if v, ok := args[0].(NamedArgAppender); ok {
- namedArgs = v
- } else if v, ok := newStructArgs(f, args[0]); ok {
- namedArgs = v
+ var ok bool
+ namedArgs, ok = args[0].(NamedArgAppender)
+ if !ok {
+ namedArgs, _ = newStructArgs(f, args[0])
}
}
diff --git a/vendor/github.com/uptrace/bun/version.go b/vendor/github.com/uptrace/bun/version.go
index 460909509..1baf9a39c 100644
--- a/vendor/github.com/uptrace/bun/version.go
+++ b/vendor/github.com/uptrace/bun/version.go
@@ -2,5 +2,5 @@ package bun
// Version is the current release version.
func Version() string {
- return "1.0.0-rc.1"
+ return "0.4.3"
}
diff --git a/vendor/modules.txt b/vendor/modules.txt
index 34ff7e957..a2edd4b90 100644
--- a/vendor/modules.txt
+++ b/vendor/modules.txt
@@ -389,7 +389,7 @@ github.com/tdewolff/parse/v2/strconv
github.com/tmthrgd/go-hex
# github.com/ugorji/go/codec v1.2.6
github.com/ugorji/go/codec
-# github.com/uptrace/bun v1.0.0-rc.1
+# github.com/uptrace/bun v0.4.3
## explicit
github.com/uptrace/bun
github.com/uptrace/bun/dialect
@@ -400,7 +400,7 @@ github.com/uptrace/bun/internal
github.com/uptrace/bun/internal/parser
github.com/uptrace/bun/internal/tagparser
github.com/uptrace/bun/schema
-# github.com/uptrace/bun/dialect/pgdialect v1.0.0-rc.1
+# github.com/uptrace/bun/dialect/pgdialect v0.4.3
## explicit
github.com/uptrace/bun/dialect/pgdialect
# github.com/urfave/cli/v2 v2.3.0