diff --git a/go.sum b/go.sum index 56a215cd3..42d24ad8c 100644 --- a/go.sum +++ b/go.sum @@ -446,10 +446,10 @@ github.com/ugorji/go v1.2.6/go.mod h1:anCg0y61KIhDlPZmnH+so+RQbysYVyDko0IMgJv0Nn github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/ugorji/go/codec v1.2.6 h1:7kbGefxLoDBuYXOms4yD7223OpNMMPNPZxXk5TvFcyQ= github.com/ugorji/go/codec v1.2.6/go.mod h1:V6TCNZ4PHqoHGFZuSG1W8nrCzzdgA2DozYxWFFpvxTw= -github.com/uptrace/bun v1.0.0-rc.1 h1:yMjz/JdlgYRema5mk59URzjAV7HpOV58Fj3KID34K9U= -github.com/uptrace/bun v1.0.0-rc.1/go.mod h1:aL6D9vPw8DXaTQTwGrEPtUderBYXx7ShUmPfnxnqscw= -github.com/uptrace/bun/dialect/pgdialect v1.0.0-rc.1 h1:4rxcO4+x8r2xyCnduH9fhhkjhbtzLHFl3vmx2goVgio= -github.com/uptrace/bun/dialect/pgdialect v1.0.0-rc.1/go.mod h1:x48KjNeAGTSq4WNIOvtfzAOFjPG/V/Gx+SdwO5ksimU= +github.com/uptrace/bun v0.4.3 h1:x6bjDqwjxwM/9Q1eauhkznuvTrz/rLiCK2p4tT63sAE= +github.com/uptrace/bun v0.4.3/go.mod h1:aL6D9vPw8DXaTQTwGrEPtUderBYXx7ShUmPfnxnqscw= +github.com/uptrace/bun/dialect/pgdialect v0.4.3 h1:lM2IUKpU99110chKkupw3oTfXiOKpB0hTJIe6frqQDo= +github.com/uptrace/bun/dialect/pgdialect v0.4.3/go.mod h1:BaNvWejl32oKUhwpFkw/eNcWldzIlVY4nfw/sNul0s8= github.com/urfave/cli/v2 v2.3.0 h1:qph92Y649prgesehzOrQjdWyxFOp/QVM+6imKHad91M= github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= diff --git a/internal/api/client/account/accountupdate_test.go b/internal/api/client/account/accountupdate_test.go index 349429625..8fc31171b 100644 --- a/internal/api/client/account/accountupdate_test.go +++ b/internal/api/client/account/accountupdate_test.go @@ -79,7 +79,7 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandler() recorder := httptest.NewRecorder() ctx, _ := gin.CreateTestContext(recorder) ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) - ctx.Set(oauth.SessionAuthorizedToken, oauth.TokenToOauthToken(suite.testTokens["local_account_1"])) + ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) ctx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:8080/%s", account.UpdateCredentialsPath), bytes.NewReader(requestBody.Bytes())) // the endpoint we're hitting diff --git a/internal/api/client/auth/auth_test.go b/internal/api/client/auth/auth_test.go index d1aedb6d8..3d5170f31 100644 --- a/internal/api/client/auth/auth_test.go +++ b/internal/api/client/auth/auth_test.go @@ -28,7 +28,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/db/pg" + "github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/oauth" "golang.org/x/crypto/bcrypt" @@ -104,7 +104,7 @@ func (suite *AuthTestSuite) SetupTest() { log := logrus.New() log.SetLevel(logrus.TraceLevel) - db, err := pg.NewPostgresService(context.Background(), suite.config, log) + db, err := bundb.NewBunDBService(context.Background(), suite.config, log) if err != nil { logrus.Panicf("error creating database connection: %s", err) } diff --git a/internal/api/client/auth/authorize.go b/internal/api/client/auth/authorize.go index 34e53f5ae..d67b39ed5 100644 --- a/internal/api/client/auth/authorize.go +++ b/internal/api/client/auth/authorize.go @@ -80,20 +80,15 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { } // we can also use the userid of the user to fetch their username from the db to greet them nicely <3 - user := >smodel.User{ - ID: userID, - } + user := >smodel.User{} if err := m.db.GetByID(c.Request.Context(), user.ID, user); err != nil { m.clearSession(s) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - acct := >smodel.Account{ - ID: user.AccountID, - } - - if err := m.db.GetByID(c.Request.Context(), acct.ID, acct); err != nil { + acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) + if err != nil { m.clearSession(s) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return diff --git a/internal/api/client/auth/middleware.go b/internal/api/client/auth/middleware.go index c1995ca92..3599c7048 100644 --- a/internal/api/client/auth/middleware.go +++ b/internal/api/client/auth/middleware.go @@ -56,8 +56,8 @@ func (m *Module) OauthTokenMiddleware(c *gin.Context) { c.Set(oauth.SessionAuthorizedUser, user) l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user) - acct := >smodel.Account{} - if err := m.db.GetByID(c.Request.Context(), user.AccountID, acct); err != nil || acct == nil { + acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) + if err != nil || acct == nil { l.Warnf("no account found for validated user %s", uid) return } diff --git a/internal/api/client/media/mediacreate_test.go b/internal/api/client/media/mediacreate_test.go index 5c48a4381..8433786e4 100644 --- a/internal/api/client/media/mediacreate_test.go +++ b/internal/api/client/media/mediacreate_test.go @@ -121,7 +121,7 @@ func (suite *MediaCreateTestSuite) TestStatusCreatePOSTImageHandlerSuccessful() // set up the context for the request t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) recorder := httptest.NewRecorder() ctx, _ := gin.CreateTestContext(recorder) ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) diff --git a/internal/api/client/status/statusboost_test.go b/internal/api/client/status/statusboost_test.go index fbe267fac..4157bde38 100644 --- a/internal/api/client/status/statusboost_test.go +++ b/internal/api/client/status/statusboost_test.go @@ -67,7 +67,7 @@ func (suite *StatusBoostTestSuite) TearDownTest() { func (suite *StatusBoostTestSuite) TestPostBoost() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["admin_account_status_1"] @@ -133,7 +133,7 @@ func (suite *StatusBoostTestSuite) TestPostBoost() { func (suite *StatusBoostTestSuite) TestPostUnboostable() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["local_account_2_status_4"] @@ -171,7 +171,7 @@ func (suite *StatusBoostTestSuite) TestPostUnboostable() { func (suite *StatusBoostTestSuite) TestPostNotVisible() { t := suite.testTokens["local_account_2"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["local_account_1_status_3"] // this is a mutual only status and these accounts aren't mutuals diff --git a/internal/api/client/status/statuscreate_test.go b/internal/api/client/status/statuscreate_test.go index 097f268f2..060d96dad 100644 --- a/internal/api/client/status/statuscreate_test.go +++ b/internal/api/client/status/statuscreate_test.go @@ -83,7 +83,7 @@ https://docs.gotosocial.org/en/latest/user_guide/posts/#links func (suite *StatusCreateTestSuite) TestPostNewStatus() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // setup recorder := httptest.NewRecorder() @@ -137,7 +137,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatus() { func (suite *StatusCreateTestSuite) TestPostAnotherNewStatus() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // setup recorder := httptest.NewRecorder() @@ -172,7 +172,7 @@ func (suite *StatusCreateTestSuite) TestPostAnotherNewStatus() { func (suite *StatusCreateTestSuite) TestPostNewStatusWithEmoji() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // setup recorder := httptest.NewRecorder() @@ -213,7 +213,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatusWithEmoji() { // Try to reply to a status that doesn't exist func (suite *StatusCreateTestSuite) TestReplyToNonexistentStatus() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // setup recorder := httptest.NewRecorder() @@ -244,7 +244,7 @@ func (suite *StatusCreateTestSuite) TestReplyToNonexistentStatus() { // Post a reply to the status of a local user that allows replies. func (suite *StatusCreateTestSuite) TestReplyToLocalStatus() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // setup recorder := httptest.NewRecorder() @@ -284,7 +284,7 @@ func (suite *StatusCreateTestSuite) TestReplyToLocalStatus() { // Take a media file which is currently not associated with a status, and attach it to a new status. func (suite *StatusCreateTestSuite) TestAttachNewMediaSuccess() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) attachment := suite.testAttachments["local_account_1_unattached_1"] @@ -323,8 +323,7 @@ func (suite *StatusCreateTestSuite) TestAttachNewMediaSuccess() { assert.Len(suite.T(), statusResponse.MediaAttachments, 1) // get the updated media attachment from the database - gtsAttachment := >smodel.MediaAttachment{} - err = suite.db.GetByID(context.Background(), statusResponse.MediaAttachments[0].ID, gtsAttachment) + gtsAttachment, err := suite.db.GetAttachmentByID(context.Background(), statusResponse.MediaAttachments[0].ID) assert.NoError(suite.T(), err) // convert it to a masto attachment diff --git a/internal/api/client/status/statusfave_test.go b/internal/api/client/status/statusfave_test.go index 0f44b5e90..2f7a2c596 100644 --- a/internal/api/client/status/statusfave_test.go +++ b/internal/api/client/status/statusfave_test.go @@ -71,7 +71,7 @@ func (suite *StatusFaveTestSuite) TearDownTest() { func (suite *StatusFaveTestSuite) TestPostFave() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["admin_account_status_2"] @@ -119,7 +119,7 @@ func (suite *StatusFaveTestSuite) TestPostFave() { func (suite *StatusFaveTestSuite) TestPostUnfaveable() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["local_account_2_status_3"] // this one is unlikeable and unreplyable diff --git a/internal/api/client/status/statusfavedby_test.go b/internal/api/client/status/statusfavedby_test.go index 22a549b30..7475f1e69 100644 --- a/internal/api/client/status/statusfavedby_test.go +++ b/internal/api/client/status/statusfavedby_test.go @@ -69,7 +69,7 @@ func (suite *StatusFavedByTestSuite) TearDownTest() { func (suite *StatusFavedByTestSuite) TestGetFavedBy() { t := suite.testTokens["local_account_2"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) targetStatus := suite.testStatuses["admin_account_status_1"] // this status is faved by local_account_1 diff --git a/internal/api/client/status/statusunfave_test.go b/internal/api/client/status/statusunfave_test.go index a5f267f4c..9e7ea8f82 100644 --- a/internal/api/client/status/statusunfave_test.go +++ b/internal/api/client/status/statusunfave_test.go @@ -71,7 +71,7 @@ func (suite *StatusUnfaveTestSuite) TearDownTest() { func (suite *StatusUnfaveTestSuite) TestPostUnfave() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // this is the status we wanna unfave: in the testrig it's already faved by this account targetStatus := suite.testStatuses["admin_account_status_1"] @@ -120,7 +120,7 @@ func (suite *StatusUnfaveTestSuite) TestPostUnfave() { func (suite *StatusUnfaveTestSuite) TestPostAlreadyNotFaved() { t := suite.testTokens["local_account_1"] - oauthToken := oauth.TokenToOauthToken(t) + oauthToken := oauth.DBTokenToToken(t) // this is the status we wanna unfave: in the testrig it's not faved by this account targetStatus := suite.testStatuses["admin_account_status_2"] diff --git a/internal/cache/cache.go b/internal/cache/cache.go index eb3744cfe..ce4aad04d 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -37,7 +37,7 @@ type cache struct { // New returns a new in-memory cache. func New() Cache { c := ttlcache.NewCache() - c.SetTTL(30 * time.Second) + c.SetTTL(5 * time.Minute) cache := &cache{ c: c, } diff --git a/internal/cliactions/admin/account/account.go b/internal/cliactions/admin/account/account.go index 82058fe25..46998ec6a 100644 --- a/internal/cliactions/admin/account/account.go +++ b/internal/cliactions/admin/account/account.go @@ -28,7 +28,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/cliactions" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/db/pg" + "github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/util" "golang.org/x/crypto/bcrypt" @@ -36,7 +36,7 @@ import ( // Create creates a new account in the database using the provided flags. var Create cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { - dbConn, err := pg.NewPostgresService(ctx, c, log) + dbConn, err := bundb.NewBunDBService(ctx, c, log) if err != nil { return fmt.Errorf("error creating dbservice: %s", err) } @@ -75,7 +75,7 @@ var Create cliactions.GTSAction = func(ctx context.Context, c *config.Config, lo // Confirm sets a user to Approved, sets Email to the current UnconfirmedEmail value, and sets ConfirmedAt to now. var Confirm cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { - dbConn, err := pg.NewPostgresService(ctx, c, log) + dbConn, err := bundb.NewBunDBService(ctx, c, log) if err != nil { return fmt.Errorf("error creating dbservice: %s", err) } @@ -110,7 +110,7 @@ var Confirm cliactions.GTSAction = func(ctx context.Context, c *config.Config, l // Promote sets a user to admin. var Promote cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { - dbConn, err := pg.NewPostgresService(ctx, c, log) + dbConn, err := bundb.NewBunDBService(ctx, c, log) if err != nil { return fmt.Errorf("error creating dbservice: %s", err) } @@ -142,7 +142,7 @@ var Promote cliactions.GTSAction = func(ctx context.Context, c *config.Config, l // Demote sets admin on a user to false. var Demote cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { - dbConn, err := pg.NewPostgresService(ctx, c, log) + dbConn, err := bundb.NewBunDBService(ctx, c, log) if err != nil { return fmt.Errorf("error creating dbservice: %s", err) } @@ -174,7 +174,7 @@ var Demote cliactions.GTSAction = func(ctx context.Context, c *config.Config, lo // Disable sets Disabled to true on a user. var Disable cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { - dbConn, err := pg.NewPostgresService(ctx, c, log) + dbConn, err := bundb.NewBunDBService(ctx, c, log) if err != nil { return fmt.Errorf("error creating dbservice: %s", err) } @@ -212,7 +212,7 @@ var Suspend cliactions.GTSAction = func(ctx context.Context, c *config.Config, l // Password sets the password of target account. var Password cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { - dbConn, err := pg.NewPostgresService(ctx, c, log) + dbConn, err := bundb.NewBunDBService(ctx, c, log) if err != nil { return fmt.Errorf("error creating dbservice: %s", err) } diff --git a/internal/cliactions/server/server.go b/internal/cliactions/server/server.go index be44c00dc..877a9d397 100644 --- a/internal/cliactions/server/server.go +++ b/internal/cliactions/server/server.go @@ -35,7 +35,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/blob" "github.com/superseriousbusiness/gotosocial/internal/cliactions" "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db/pg" + "github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb" "github.com/superseriousbusiness/gotosocial/internal/gotosocial" @@ -79,7 +79,7 @@ var models []interface{} = []interface{}{ // Start creates and starts a gotosocial server var Start cliactions.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { - dbService, err := pg.NewPostgresService(ctx, c, log) + dbService, err := bundb.NewBunDBService(ctx, c, log) if err != nil { return fmt.Errorf("error creating dbservice: %s", err) } diff --git a/internal/db/account.go b/internal/db/account.go index 61d97bf8c..058a89859 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -36,6 +36,9 @@ type Account interface { // GetAccountByURL returns one account with the given URL, or an error if something goes wrong. GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error) + // UpdateAccount updates one account by ID. + UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) + // GetLocalAccountByUsername returns an account on this instance by its username. GetLocalAccountByUsername(ctx context.Context, username string) (*gtsmodel.Account, Error) diff --git a/internal/db/pg/account.go b/internal/db/bundb/account.go similarity index 94% rename from internal/db/pg/account.go rename to internal/db/bundb/account.go index 73b71cf11..7ebb79a15 100644 --- a/internal/db/pg/account.go +++ b/internal/db/bundb/account.go @@ -16,12 +16,13 @@ along with this program. If not, see . */ -package pg +package bundb import ( "context" "errors" "fmt" + "strings" "time" "github.com/sirupsen/logrus" @@ -35,7 +36,6 @@ type accountDB struct { config *config.Config conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery { @@ -79,6 +79,25 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel. return account, err } +func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, db.Error) { + if strings.TrimSpace(account.ID) == "" { + return nil, errors.New("account had no ID") + } + + account.UpdatedAt = time.Now() + + q := a.conn. + NewUpdate(). + Model(account). + WherePK() + + _, err := q.Exec(ctx) + + err = processErrorResponse(err) + + return account, err +} + func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { account := new(gtsmodel.Account) diff --git a/internal/db/pg/account_test.go b/internal/db/bundb/account_test.go similarity index 80% rename from internal/db/pg/account_test.go rename to internal/db/bundb/account_test.go index df4d244bf..7174b781d 100644 --- a/internal/db/pg/account_test.go +++ b/internal/db/bundb/account_test.go @@ -16,18 +16,19 @@ along with this program. If not, see . */ -package pg_test +package bundb_test import ( "context" "testing" + "time" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/testrig" ) type AccountTestSuite struct { - PGStandardTestSuite + BunDBStandardTestSuite } func (suite *AccountTestSuite) SetupSuite() { @@ -66,6 +67,20 @@ func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() { suite.NotEmpty(account.HeaderMediaAttachment.URL) } +func (suite *AccountTestSuite) TestUpdateAccount() { + testAccount := suite.testAccounts["local_account_1"] + + testAccount.DisplayName = "new display name!" + + _, err := suite.db.UpdateAccount(context.Background(), testAccount) + suite.NoError(err) + + updated, err := suite.db.GetAccountByID(context.Background(), testAccount.ID) + suite.NoError(err) + suite.Equal("new display name!", updated.DisplayName) + suite.WithinDuration(time.Now(), updated.UpdatedAt, 5*time.Second) +} + func TestAccountTestSuite(t *testing.T) { suite.Run(t, new(AccountTestSuite)) } diff --git a/internal/db/pg/admin.go b/internal/db/bundb/admin.go similarity index 99% rename from internal/db/pg/admin.go rename to internal/db/bundb/admin.go index e9fd01b11..67a1e8a0d 100644 --- a/internal/db/pg/admin.go +++ b/internal/db/bundb/admin.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -package pg +package bundb import ( "context" @@ -43,7 +43,6 @@ type adminDB struct { config *config.Config conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { diff --git a/internal/db/pg/basic.go b/internal/db/bundb/basic.go similarity index 97% rename from internal/db/pg/basic.go rename to internal/db/bundb/basic.go index dd91acb34..983b6b810 100644 --- a/internal/db/pg/basic.go +++ b/internal/db/bundb/basic.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -package pg +package bundb import ( "context" @@ -33,7 +33,6 @@ type basicDB struct { config *config.Config conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } func (b *basicDB) Put(ctx context.Context, i interface{}) db.Error { @@ -116,7 +115,7 @@ func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.E q := b.conn. NewUpdate(). Model(i). - Where("id = ?", id) + WherePK() _, err := q.Exec(ctx) @@ -127,7 +126,7 @@ func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, valu q := b.conn.NewUpdate(). Model(i). Set("? = ?", bun.Safe(key), value). - Where("id = ?", id) + WherePK() _, err := q.Exec(ctx) @@ -174,7 +173,6 @@ func (b *basicDB) Stop(ctx context.Context) db.Error { b.log.Info("closing db connection") if err := b.conn.Close(); err != nil { // only cancel if there's a problem closing the db - b.cancel() return err } return nil diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go new file mode 100644 index 000000000..9189618c9 --- /dev/null +++ b/internal/db/bundb/basic_test.go @@ -0,0 +1,68 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package bundb_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type BasicTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *BasicTestSuite) SetupSuite() { + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() + suite.testAttachments = testrig.NewTestAttachments() + suite.testStatuses = testrig.NewTestStatuses() + suite.testTags = testrig.NewTestTags() + suite.testMentions = testrig.NewTestMentions() +} + +func (suite *BasicTestSuite) SetupTest() { + suite.config = testrig.NewTestConfig() + suite.db = testrig.NewTestDB() + suite.log = testrig.NewTestLog() + + testrig.StandardDBSetup(suite.db, suite.testAccounts) +} + +func (suite *BasicTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) +} + +func (suite *BasicTestSuite) TestGetAccountByID() { + testAccount := suite.testAccounts["local_account_1"] + + a := >smodel.Account{} + err := suite.db.GetByID(context.Background(), testAccount.ID, a) + suite.NoError(err) +} + +func TestBasicTestSuite(t *testing.T) { + suite.Run(t, new(BasicTestSuite)) +} diff --git a/internal/db/pg/pg.go b/internal/db/bundb/bundb.go similarity index 81% rename from internal/db/pg/pg.go rename to internal/db/bundb/bundb.go index 5d1a1d01a..49ed09cbd 100644 --- a/internal/db/pg/pg.go +++ b/internal/db/bundb/bundb.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -package pg +package bundb import ( "context" @@ -41,13 +41,18 @@ import ( "github.com/uptrace/bun/dialect/pgdialect" ) +const ( + dbTypePostgres = "postgres" + dbTypeSqlite = "sqlite" +) + var registerTables []interface{} = []interface{}{ >smodel.StatusToEmoji{}, >smodel.StatusToTag{}, } -// postgresService satisfies the DB interface -type postgresService struct { +// bunDBService satisfies the DB interface +type bunDBService struct { db.Account db.Admin db.Basic @@ -57,37 +62,49 @@ type postgresService struct { db.Mention db.Notification db.Relationship + db.Session db.Status db.Timeline config *config.Config conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } -// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface. -// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection. -func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) { +// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface. +// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection. +func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) { + var sqldb *sql.DB + var conn *bun.DB - opts, err := derivePGOptions(c) - if err != nil { - return nil, fmt.Errorf("could not create postgres options: %s", err) + // depending on the database type we're trying to create, we need to use a different driver... + switch strings.ToLower(c.DBConfig.Type) { + case dbTypePostgres: + // POSTGRES + opts, err := deriveBunDBPGOptions(c) + if err != nil { + return nil, fmt.Errorf("could not create bundb postgres options: %s", err) + } + sqldb = stdlib.OpenDB(*opts) + conn = bun.NewDB(sqldb, pgdialect.New()) + case dbTypeSqlite: + // SQLITE + // TODO: https://bun.uptrace.dev/guide/drivers.html#sqlite + default: + return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type)) } - sqldb := stdlib.OpenDB(*opts) - conn := bun.NewDB(sqldb, pgdialect.New()) // actually *begin* the connection so that we can tell if the db is there and listening if err := conn.Ping(); err != nil { return nil, fmt.Errorf("db connection error: %s", err) } - log.Info("connected to postgres") + log.Info("connected to database") for _, t := range registerTables { // https://bun.uptrace.dev/orm/many-to-many-relation/ conn.RegisterModel(t) } - ps := &postgresService{ + ps := &bunDBService{ Account: &accountDB{ config: c, conn: conn, @@ -133,6 +150,11 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge conn: conn, log: log, }, + Session: &sessionDB{ + config: c, + conn: conn, + log: log, + }, Status: &statusDB{ config: c, conn: conn, @@ -148,7 +170,7 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge log: log, } - // we can confidently return this useable postgres service now + // we can confidently return this useable service now return ps, nil } @@ -156,9 +178,9 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge HANDY STUFF */ -// derivePGOptions takes an application config and returns either a ready-to-use set of options +// deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options // with sensible defaults, or an error if it's not satisfied by the provided config. -func derivePGOptions(c *config.Config) (*pgx.ConnConfig, error) { +func deriveBunDBPGOptions(c *config.Config) (*pgx.ConnConfig, error) { if strings.ToUpper(c.DBConfig.Type) != db.DBTypePostgres { return nil, fmt.Errorf("expected db type of %s but got %s", db.DBTypePostgres, c.DBConfig.Type) } @@ -236,15 +258,16 @@ func derivePGOptions(c *config.Config) (*pgx.ConnConfig, error) { tlsConfig.RootCAs = certPool } - opts, _ := pgx.ParseConfig("") - opts.Host = c.DBConfig.Address - opts.Port = uint16(c.DBConfig.Port) - opts.User = c.DBConfig.User - opts.Password = c.DBConfig.Password - opts.TLSConfig = tlsConfig - opts.PreferSimpleProtocol = true + cfg, _ := pgx.ParseConfig("") + cfg.Host = c.DBConfig.Address + cfg.Port = uint16(c.DBConfig.Port) + cfg.User = c.DBConfig.User + cfg.Password = c.DBConfig.Password + cfg.TLSConfig = tlsConfig + cfg.Database = c.DBConfig.Database + cfg.PreferSimpleProtocol = true - return opts, nil + return cfg, nil } /* @@ -253,7 +276,7 @@ func derivePGOptions(c *config.Config) (*pgx.ConnConfig, error) { // TODO: move these to the type converter, it's bananas that they're here and not there -func (ps *postgresService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) { +func (ps *bunDBService) MentionStringsToMentions(ctx context.Context, targetAccounts []string, originAccountID string, statusID string) ([]*gtsmodel.Mention, error) { ogAccount := >smodel.Account{} if err := ps.conn.NewSelect().Model(ogAccount).Where("id = ?", originAccountID).Scan(ctx); err != nil { return nil, err @@ -331,7 +354,7 @@ func (ps *postgresService) MentionStringsToMentions(ctx context.Context, targetA return menchies, nil } -func (ps *postgresService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) { +func (ps *bunDBService) TagStringsToTags(ctx context.Context, tags []string, originAccountID string, statusID string) ([]*gtsmodel.Tag, error) { newTags := []*gtsmodel.Tag{} for _, t := range tags { tag := >smodel.Tag{} @@ -367,7 +390,7 @@ func (ps *postgresService) TagStringsToTags(ctx context.Context, tags []string, return newTags, nil } -func (ps *postgresService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) { +func (ps *bunDBService) EmojiStringsToEmojis(ctx context.Context, emojis []string, originAccountID string, statusID string) ([]*gtsmodel.Emoji, error) { newEmojis := []*gtsmodel.Emoji{} for _, e := range emojis { emoji := >smodel.Emoji{} diff --git a/internal/db/pg/pg_test.go b/internal/db/bundb/bundb_test.go similarity index 96% rename from internal/db/pg/pg_test.go rename to internal/db/bundb/bundb_test.go index c1e10abdf..b789375af 100644 --- a/internal/db/pg/pg_test.go +++ b/internal/db/bundb/bundb_test.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -package pg_test +package bundb_test import ( "github.com/sirupsen/logrus" @@ -27,7 +27,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -type PGStandardTestSuite struct { +type BunDBStandardTestSuite struct { // standard suite interfaces suite.Suite config *config.Config diff --git a/internal/db/pg/domain.go b/internal/db/bundb/domain.go similarity index 98% rename from internal/db/pg/domain.go rename to internal/db/bundb/domain.go index 4a9db09d6..6aa2b8ffe 100644 --- a/internal/db/pg/domain.go +++ b/internal/db/bundb/domain.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -package pg +package bundb import ( "context" @@ -34,7 +34,6 @@ type domainDB struct { config *config.Config conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) { diff --git a/internal/db/pg/instance.go b/internal/db/bundb/instance.go similarity index 98% rename from internal/db/pg/instance.go rename to internal/db/bundb/instance.go index 946c4c441..f9364346e 100644 --- a/internal/db/pg/instance.go +++ b/internal/db/bundb/instance.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -package pg +package bundb import ( "context" @@ -32,7 +32,6 @@ type instanceDB struct { config *config.Config conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) { diff --git a/internal/db/pg/media.go b/internal/db/bundb/media.go similarity index 97% rename from internal/db/pg/media.go rename to internal/db/bundb/media.go index dea26b8de..04e55ca62 100644 --- a/internal/db/pg/media.go +++ b/internal/db/bundb/media.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -package pg +package bundb import ( "context" @@ -32,7 +32,6 @@ type mediaDB struct { config *config.Config conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } func (m *mediaDB) newMediaQ(i interface{}) *bun.SelectQuery { diff --git a/internal/db/pg/mention.go b/internal/db/bundb/mention.go similarity index 98% rename from internal/db/pg/mention.go rename to internal/db/bundb/mention.go index 5f61b93ec..a444f9b5f 100644 --- a/internal/db/pg/mention.go +++ b/internal/db/bundb/mention.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -package pg +package bundb import ( "context" @@ -33,7 +33,6 @@ type mentionDB struct { config *config.Config conn *bun.DB log *logrus.Logger - cancel context.CancelFunc cache cache.Cache } diff --git a/internal/db/pg/notification.go b/internal/db/bundb/notification.go similarity index 98% rename from internal/db/pg/notification.go rename to internal/db/bundb/notification.go index 497bfb056..1c30837ec 100644 --- a/internal/db/pg/notification.go +++ b/internal/db/bundb/notification.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -package pg +package bundb import ( "context" @@ -33,7 +33,6 @@ type notificationDB struct { config *config.Config conn *bun.DB log *logrus.Logger - cancel context.CancelFunc cache cache.Cache } diff --git a/internal/db/pg/relationship.go b/internal/db/bundb/relationship.go similarity index 99% rename from internal/db/pg/relationship.go rename to internal/db/bundb/relationship.go index f78179476..ccc604baf 100644 --- a/internal/db/pg/relationship.go +++ b/internal/db/bundb/relationship.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -package pg +package bundb import ( "context" @@ -34,7 +34,6 @@ type relationshipDB struct { config *config.Config conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *bun.SelectQuery { diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go new file mode 100644 index 000000000..87e20673d --- /dev/null +++ b/internal/db/bundb/session.go @@ -0,0 +1,85 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package bundb + +import ( + "context" + "crypto/rand" + + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/uptrace/bun" +) + +type sessionDB struct { + config *config.Config + conn *bun.DB + log *logrus.Logger +} + +func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { + rs := new(gtsmodel.RouterSession) + + q := s.conn. + NewSelect(). + Model(rs). + Limit(1) + + _, err := q.Exec(ctx) + + err = processErrorResponse(err) + + return rs, err +} + +func (s *sessionDB) CreateSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { + auth := make([]byte, 32) + crypt := make([]byte, 32) + + if _, err := rand.Read(auth); err != nil { + return nil, err + } + if _, err := rand.Read(crypt); err != nil { + return nil, err + } + + rid, err := id.NewULID() + if err != nil { + return nil, err + } + + rs := >smodel.RouterSession{ + ID: rid, + Auth: auth, + Crypt: crypt, + } + + q := s.conn. + NewInsert(). + Model(rs) + + _, err = q.Exec(ctx) + + err = processErrorResponse(err) + + return rs, err +} diff --git a/internal/db/pg/status.go b/internal/db/bundb/status.go similarity index 99% rename from internal/db/pg/status.go rename to internal/db/bundb/status.go index e4609a116..da8d8ca41 100644 --- a/internal/db/pg/status.go +++ b/internal/db/bundb/status.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -package pg +package bundb import ( "container/list" @@ -36,7 +36,6 @@ type statusDB struct { config *config.Config conn *bun.DB log *logrus.Logger - cancel context.CancelFunc cache cache.Cache } diff --git a/internal/db/pg/status_test.go b/internal/db/bundb/status_test.go similarity index 98% rename from internal/db/pg/status_test.go rename to internal/db/bundb/status_test.go index e3b9f1867..513000577 100644 --- a/internal/db/pg/status_test.go +++ b/internal/db/bundb/status_test.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -package pg_test +package bundb_test import ( "context" @@ -29,7 +29,7 @@ import ( ) type StatusTestSuite struct { - PGStandardTestSuite + BunDBStandardTestSuite } func (suite *StatusTestSuite) SetupSuite() { diff --git a/internal/db/pg/timeline.go b/internal/db/bundb/timeline.go similarity index 98% rename from internal/db/pg/timeline.go rename to internal/db/bundb/timeline.go index 0059f8319..b62ad4c50 100644 --- a/internal/db/pg/timeline.go +++ b/internal/db/bundb/timeline.go @@ -16,7 +16,7 @@ along with this program. If not, see . */ -package pg +package bundb import ( "context" @@ -34,7 +34,6 @@ type timelineDB struct { config *config.Config conn *bun.DB log *logrus.Logger - cancel context.CancelFunc } func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { @@ -78,7 +77,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI // OR statuses posted by accountID itself (since a user should be able to see their own statuses). // // This is equivalent to something like WHERE ... AND (... OR ...) - // See: https://pg.uptrace.dev/queries/#select + // See: https://bun.uptrace.dev/guide/queries.html#select whereGroup := func(*bun.SelectQuery) *bun.SelectQuery { return q. WhereOr("f.account_id = ?", accountID). diff --git a/internal/db/pg/util.go b/internal/db/bundb/util.go similarity index 91% rename from internal/db/pg/util.go rename to internal/db/bundb/util.go index 90e784c3b..115d18de2 100644 --- a/internal/db/pg/util.go +++ b/internal/db/bundb/util.go @@ -16,10 +16,11 @@ along with this program. If not, see . */ -package pg +package bundb import ( "context" + "strings" "database/sql" @@ -35,6 +36,9 @@ func processErrorResponse(err error) db.Error { case sql.ErrNoRows: return db.ErrNoEntries default: + if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { + return db.ErrAlreadyExists + } return err } } diff --git a/internal/db/db.go b/internal/db/db.go index 71ac887bb..ec94fcfe7 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -40,6 +40,7 @@ type DB interface { Mention Notification Relationship + Session Status Timeline diff --git a/internal/db/session.go b/internal/db/session.go new file mode 100644 index 000000000..ae13dccce --- /dev/null +++ b/internal/db/session.go @@ -0,0 +1,31 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package db + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// Session handles getting/creation of router sessions. +type Session interface { + GetSession(ctx context.Context) (*gtsmodel.RouterSession, Error) + CreateSession(ctx context.Context) (*gtsmodel.RouterSession, Error) +} diff --git a/internal/federation/dereferencing/account.go b/internal/federation/dereferencing/account.go index dbecd49f6..2eee0645d 100644 --- a/internal/federation/dereferencing/account.go +++ b/internal/federation/dereferencing/account.go @@ -24,6 +24,7 @@ import ( "errors" "fmt" "net/url" + "strings" "github.com/go-fed/activity/streams" "github.com/go-fed/activity/streams/vocab" @@ -34,18 +35,33 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/transport" ) +func instanceAccount(account *gtsmodel.Account) bool { + return strings.EqualFold(account.Username, account.Domain) || + account.FollowersURI == "" || + account.FollowingURI == "" || + (account.Username == "internal.fetch" && strings.Contains(account.Note, "internal service actor")) +} + // EnrichRemoteAccount takes an account that's already been inserted into the database in a minimal form, // and populates it with additional fields, media, etc. // // EnrichRemoteAccount is mostly useful for calling after an account has been initially created by // the federatingDB's Create function, or during the federated authorization flow. func (d *deref) EnrichRemoteAccount(ctx context.Context, username string, account *gtsmodel.Account) (*gtsmodel.Account, error) { + + // if we're dealing with an instance account, we don't need to update anything + if instanceAccount(account) { + return account, nil + } + if err := d.PopulateAccountFields(ctx, account, username, false); err != nil { return nil, err } - if err := d.db.UpdateByID(ctx, account.ID, account); err != nil { - return nil, fmt.Errorf("EnrichRemoteAccount: error updating account: %s", err) + var err error + account, err = d.db.UpdateAccount(ctx, account) + if err != nil { + d.log.Errorf("EnrichRemoteAccount: error updating account: %s", err) } return account, nil @@ -108,8 +124,9 @@ func (d *deref) GetRemoteAccount(ctx context.Context, username string, remoteAcc return nil, new, fmt.Errorf("FullyDereferenceAccount: error populating further account fields: %s", err) } - if err := d.db.UpdateByID(ctx, gtsAccount.ID, gtsAccount); err != nil { - return nil, new, fmt.Errorf("FullyDereferenceAccount: error updating existing account: %s", err) + gtsAccount, err = d.db.UpdateAccount(ctx, gtsAccount) + if err != nil { + return nil, false, fmt.Errorf("EnrichRemoteAccount: error updating account: %s", err) } } diff --git a/internal/federation/dereferencing/thread.go b/internal/federation/dereferencing/thread.go index 328a1c4ee..f9dd9aa09 100644 --- a/internal/federation/dereferencing/thread.go +++ b/internal/federation/dereferencing/thread.go @@ -25,7 +25,6 @@ import ( "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/ap" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/util" ) @@ -87,8 +86,8 @@ func (d *deref) iterateAncestors(ctx context.Context, username string, statusIRI return err } - status := >smodel.Status{} - if err := d.db.GetByID(ctx, id, status); err != nil { + status, err := d.db.GetStatusByID(ctx, id) + if err != nil { return err } diff --git a/internal/federation/federatingdb/delete.go b/internal/federation/federatingdb/delete.go index 94dfdf71f..11b818168 100644 --- a/internal/federation/federatingdb/delete.go +++ b/internal/federation/federatingdb/delete.go @@ -87,7 +87,7 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error { a, err := f.db.GetAccountByURI(ctx, id.String()) if err == nil { // it's an account - l.Debugf("uri is for an account with id: %s", s.ID) + l.Debugf("uri is for an account with id: %s", a.ID) if err := f.db.DeleteByID(ctx, a.ID, >smodel.Account{}); err != nil { return fmt.Errorf("DELETE: err deleting account: %s", err) } diff --git a/internal/federation/federatingdb/followers.go b/internal/federation/federatingdb/followers.go index e0923453f..c7f636a12 100644 --- a/internal/federation/federatingdb/followers.go +++ b/internal/federation/federatingdb/followers.go @@ -51,13 +51,17 @@ func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (follow followers = streams.NewActivityStreamsCollection() items := streams.NewActivityStreamsItemsProperty() for _, follow := range acctFollowers { - gtsFollower := >smodel.Account{} - if err := f.db.GetByID(ctx, follow.AccountID, gtsFollower); err != nil { - return nil, fmt.Errorf("FOLLOWERS: db error getting account id %s: %s", follow.AccountID, err) + if follow.Account == nil { + followAccount, err := f.db.GetAccountByID(ctx, follow.AccountID) + if err != nil { + return nil, fmt.Errorf("FOLLOWERS: db error getting account id %s: %s", follow.AccountID, err) + } + follow.Account = followAccount } - uri, err := url.Parse(gtsFollower.URI) + + uri, err := url.Parse(follow.Account.URI) if err != nil { - return nil, fmt.Errorf("FOLLOWERS: error parsing %s as url: %s", gtsFollower.URI, err) + return nil, fmt.Errorf("FOLLOWERS: error parsing %s as url: %s", follow.Account.URI, err) } items.AppendIRI(uri) } diff --git a/internal/federation/federatingdb/following.go b/internal/federation/federatingdb/following.go index 963ba63e4..9d5c0693c 100644 --- a/internal/federation/federatingdb/following.go +++ b/internal/federation/federatingdb/following.go @@ -64,13 +64,17 @@ func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (follow following = streams.NewActivityStreamsCollection() items := streams.NewActivityStreamsItemsProperty() for _, follow := range acctFollowing { - gtsFollowing := >smodel.Account{} - if err := f.db.GetByID(ctx, follow.AccountID, gtsFollowing); err != nil { - return nil, fmt.Errorf("FOLLOWING: db error getting account id %s: %s", follow.AccountID, err) + if follow.Account == nil { + followAccount, err := f.db.GetAccountByID(ctx, follow.AccountID) + if err != nil { + return nil, fmt.Errorf("FOLLOWING: db error getting account id %s: %s", follow.AccountID, err) + } + follow.Account = followAccount } - uri, err := url.Parse(gtsFollowing.URI) + + uri, err := url.Parse(follow.Account.URI) if err != nil { - return nil, fmt.Errorf("FOLLOWING: error parsing %s as url: %s", gtsFollowing.URI, err) + return nil, fmt.Errorf("FOLLOWING: error parsing %s as url: %s", follow.Account.URI, err) } items.AppendIRI(uri) } diff --git a/internal/federation/federatingdb/update.go b/internal/federation/federatingdb/update.go index 324509ebd..e9dfe5315 100644 --- a/internal/federation/federatingdb/update.go +++ b/internal/federation/federatingdb/update.go @@ -152,7 +152,8 @@ func (f *federatingDB) Update(ctx context.Context, asType vocab.Type) error { } updatedAcct.ID = requestingAcct.ID // set this here so the db will update properly instead of trying to PUT this and getting constraint issues - if err := f.db.UpdateByID(ctx, requestingAcct.ID, updatedAcct); err != nil { + updatedAcct, err = f.db.UpdateAccount(ctx, updatedAcct) + if err != nil { return fmt.Errorf("UPDATE: database error inserting updated account: %s", err) } diff --git a/internal/gtsmodel/account.go b/internal/gtsmodel/account.go index 8eeab1812..98d2dcfc9 100644 --- a/internal/gtsmodel/account.go +++ b/internal/gtsmodel/account.go @@ -18,8 +18,8 @@ // Package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database. // These types should never be serialized and/or sent out via public APIs, as they contain sensitive information. -// The annotation used on these structs is for handling them via the go-pg ORM (hence why they're in this db subdir). -// See here for more info on go-pg model annotations: https://pg.uptrace.dev/models/ +// The annotation used on these structs is for handling them via the bun-db ORM. +// See here for more info on bun model annotations: https://bun.uptrace.dev/guide/models.html package gtsmodel import ( @@ -34,9 +34,9 @@ type Account struct { */ // id of this account in the local database - ID string `bun:"type:CHAR(26),pk,notnull,unique"` + ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // Username of the account, should just be a string of [a-z0-9_]. Can be added to domain to create the full username in the form ``[username]@[domain]`` eg., ``user_96@example.org`` - Username string `bun:",notnull,unique:userdomain"` // username and domain should be unique *with* each other + Username string `bun:",notnull,unique:userdomain,nullzero"` // username and domain should be unique *with* each other // Domain of the account, will be null if this is a local account, otherwise something like ``example.org`` or ``mastodon.social``. Should be unique with username. Domain string `bun:",unique:userdomain,nullzero"` // username and domain should be unique *with* each other @@ -93,21 +93,21 @@ type Account struct { */ // What is the activitypub URI for this account discovered by webfinger? - URI string `bun:",unique"` + URI string `bun:",unique,nullzero"` // At which URL can we see the user account in a web browser? - URL string `bun:",unique"` + URL string `bun:",unique,nullzero"` // Last time this account was located using the webfinger API. LastWebfingeredAt time.Time `bun:",nullzero"` // Address of this account's activitypub inbox, for sending activity to - InboxURI string `bun:",unique"` + InboxURI string `bun:",unique,nullzero"` // Address of this account's activitypub outbox - OutboxURI string `bun:",unique"` + OutboxURI string `bun:",unique,nullzero"` // URI for getting the following list of this account - FollowingURI string `bun:",unique"` + FollowingURI string `bun:",unique,nullzero"` // URI for getting the followers list of this account - FollowersURI string `bun:",unique"` + FollowersURI string `bun:",unique,nullzero"` // URL for getting the featured collection list of this account - FeaturedCollectionURI string `bun:",unique"` + FeaturedCollectionURI string `bun:",unique,nullzero"` // What type of activitypub actor is this account? ActorType string // This account is associated with x account id @@ -137,7 +137,7 @@ type Account struct { // Should we hide this account's collections? HideCollections bool // id of the database entry that caused this account to become suspended -- can be an account ID or a domain block ID - SuspensionOrigin string `bun:"type:CHAR(26)"` + SuspensionOrigin string `bun:"type:CHAR(26),nullzero"` } // Field represents a key value field on an account, for things like pronouns, website, etc. diff --git a/internal/gtsmodel/emoji.go b/internal/gtsmodel/emoji.go index 549951ddd..3b02c14e7 100644 --- a/internal/gtsmodel/emoji.go +++ b/internal/gtsmodel/emoji.go @@ -73,5 +73,5 @@ type Emoji struct { // Is this emoji visible in the admin emoji picker? VisibleInPicker bool `bun:",notnull,default:true"` // In which emoji category is this emoji visible? - CategoryID string `bun:"type:CHAR(26),nullzero"` + CategoryID string `bun:"type:CHAR(26),nullzero"` } diff --git a/internal/gtsmodel/followrequest.go b/internal/gtsmodel/followrequest.go index 4a98525be..5a6cb5e02 100644 --- a/internal/gtsmodel/followrequest.go +++ b/internal/gtsmodel/followrequest.go @@ -29,15 +29,15 @@ type FollowRequest struct { // When was this follow request last updated? UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` // Who does this follow request originate from? - AccountID string `bun:"type:CHAR(26),unique:frsrctarget,notnull"` - Account Account `bun:"rel:belongs-to"` + AccountID string `bun:"type:CHAR(26),unique:frsrctarget,notnull"` + Account *Account `bun:"rel:belongs-to"` // Who is the target of this follow request? - TargetAccountID string `bun:"type:CHAR(26),unique:frsrctarget,notnull"` - TargetAccount Account `bun:"rel:belongs-to"` + TargetAccountID string `bun:"type:CHAR(26),unique:frsrctarget,notnull"` + TargetAccount *Account `bun:"rel:belongs-to"` // Does this follow also want to see reblogs and not just posts? ShowReblogs bool `bun:"default:true"` // What is the activitypub URI of this follow request? - URI string `bun:",unique"` + URI string `bun:",unique,nullzero"` // does the following account want to be notified when the followed account posts? Notify bool } diff --git a/internal/gtsmodel/status.go b/internal/gtsmodel/status.go index b4772e96b..c9766a7f2 100644 --- a/internal/gtsmodel/status.go +++ b/internal/gtsmodel/status.go @@ -27,9 +27,9 @@ type Status struct { // id of the status in the database ID string `bun:"type:CHAR(26),pk,notnull"` // uri at which this status is reachable - URI string `bun:",unique"` + URI string `bun:",unique,nullzero"` // web url for viewing this status - URL string `bun:",unique"` + URL string `bun:",unique,nullzero"` // the html-formatted content of this status Content string // Database IDs of any media attachments associated with this status @@ -45,9 +45,9 @@ type Status struct { EmojiIDs []string `bun:"emojis,array"` Emojis []*Emoji `bun:"attached_emojis,m2m:status_to_emojis"` // https://bun.uptrace.dev/guide/relations.html#many-to-many-relation // when was this status created? - CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` + CreatedAt time.Time `bun:",notnull,default:current_timestamp"` // when was this status updated? - UpdatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` + UpdatedAt time.Time `bun:",notnull,default:current_timestamp"` // is this status from a local account? Local bool // which account posted this status? @@ -93,17 +93,17 @@ type Status struct { // StatusToTag is an intermediate struct to facilitate the many2many relationship between a status and one or more tags. type StatusToTag struct { - StatusID string `bun:"type:CHAR(26),unique:statustag"` + StatusID string `bun:"type:CHAR(26),unique:statustag,nullzero"` Status *Status `bun:"rel:belongs-to"` - TagID string `bun:"type:CHAR(26),unique:statustag"` + TagID string `bun:"type:CHAR(26),unique:statustag,nullzero"` Tag *Tag `bun:"rel:belongs-to"` } // StatusToEmoji is an intermediate struct to facilitate the many2many relationship between a status and one or more emojis. type StatusToEmoji struct { - StatusID string `bun:"type:CHAR(26),unique:statusemoji"` + StatusID string `bun:"type:CHAR(26),unique:statusemoji,nullzero"` Status *Status `bun:"rel:belongs-to"` - EmojiID string `bun:"type:CHAR(26),unique:statusemoji"` + EmojiID string `bun:"type:CHAR(26),unique:statusemoji,nullzero"` Emoji *Emoji `bun:"rel:belongs-to"` } diff --git a/internal/gtsmodel/user.go b/internal/gtsmodel/user.go index 026fac9fc..f439be439 100644 --- a/internal/gtsmodel/user.go +++ b/internal/gtsmodel/user.go @@ -33,9 +33,9 @@ type User struct { // id of this user in the local database; the end-user will never need to know this, it's strictly internal ID string `bun:"type:CHAR(26),pk,notnull,unique"` // confirmed email address for this user, this should be unique -- only one email address registered per instance, multiple users per email are not supported - Email string `bun:"default:null,unique"` + Email string `bun:"default:null,unique,nullzero"` // The id of the local gtsmodel.Account entry for this user, if it exists (unconfirmed users don't have an account yet) - AccountID string `bun:"type:CHAR(26),unique"` + AccountID string `bun:"type:CHAR(26),unique,nullzero"` Account *Account `bun:"rel:belongs-to"` // The encrypted password of this user, generated using https://pkg.go.dev/golang.org/x/crypto/bcrypt#GenerateFromPassword. A salt is included so we're safe against 🌈 tables EncryptedPassword string `bun:",notnull"` diff --git a/internal/oauth/clientstore.go b/internal/oauth/clientstore.go index ad8dbd91e..a642f6cfa 100644 --- a/internal/oauth/clientstore.go +++ b/internal/oauth/clientstore.go @@ -39,9 +39,7 @@ func NewClientStore(db db.Basic) oauth2.ClientStore { } func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { - poc := &Client{ - ID: clientID, - } + poc := &Client{} if err := cs.db.GetByID(ctx, clientID, poc); err != nil { return nil, err } @@ -67,7 +65,7 @@ func (cs *clientStore) Delete(ctx context.Context, id string) error { // Client is a handy little wrapper for typical oauth client details type Client struct { - ID string `pg:"type:CHAR(26),pk,notnull"` + ID string `bun:"type:CHAR(26),pk,notnull"` Secret string Domain string UserID string diff --git a/internal/oauth/tokenstore.go b/internal/oauth/tokenstore.go index 33b4f7d00..264678ff5 100644 --- a/internal/oauth/tokenstore.go +++ b/internal/oauth/tokenstore.go @@ -43,13 +43,13 @@ type tokenStore struct { // In order to allow tokens to 'expire', it will also set off a goroutine that iterates through // the tokens in the DB once per minute and deletes any that have expired. func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2.TokenStore { - pts := &tokenStore{ + ts := &tokenStore{ db: db, log: log, } // set the token store to clean out expired tokens once per minute, or return if we're done - go func(ctx context.Context, pts *tokenStore, log *logrus.Logger) { + go func(ctx context.Context, ts *tokenStore, log *logrus.Logger) { cleanloop: for { select { @@ -58,32 +58,32 @@ func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2. break cleanloop case <-time.After(1 * time.Minute): log.Trace("sweeping out old oauth entries broom broom") - if err := pts.sweep(ctx); err != nil { + if err := ts.sweep(ctx); err != nil { log.Errorf("error while sweeping oauth entries: %s", err) } } } - }(ctx, pts, log) - return pts + }(ctx, ts, log) + return ts } // sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so. -func (pts *tokenStore) sweep(ctx context.Context) error { +func (ts *tokenStore) sweep(ctx context.Context) error { // select *all* tokens from the db // todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. tokens := new([]*Token) - if err := pts.db.GetAll(ctx, tokens); err != nil { + if err := ts.db.GetAll(ctx, tokens); err != nil { return err } // iterate through and remove expired tokens now := time.Now() - for _, pgt := range *tokens { + for _, dbt := range *tokens { // The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So: // we only want to check if a token expired before now if the expiry time is *not zero*; // ie., if it's been explicity set. - if !pgt.CodeExpiresAt.IsZero() && pgt.CodeExpiresAt.Before(now) || !pgt.RefreshExpiresAt.IsZero() && pgt.RefreshExpiresAt.Before(now) || !pgt.AccessExpiresAt.IsZero() && pgt.AccessExpiresAt.Before(now) { - if err := pts.db.DeleteByID(ctx, pgt.ID, pgt); err != nil { + if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) { + if err := ts.db.DeleteByID(ctx, dbt.ID, dbt); err != nil { return err } } @@ -94,92 +94,92 @@ func (pts *tokenStore) sweep(ctx context.Context) error { // Create creates and store the new token information. // For the original implementation, see https://github.com/superseriousbusiness/oauth2/blob/master/store/token.go#L34 -func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { +func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { t, ok := info.(*models.Token) if !ok { return errors.New("info param was not a models.Token") } - pgt := TokenToPGToken(t) - if pgt.ID == "" { - pgtID, err := id.NewRandomULID() + dbt := TokenToDBToken(t) + if dbt.ID == "" { + dbtID, err := id.NewRandomULID() if err != nil { return err } - pgt.ID = pgtID + dbt.ID = dbtID } - if err := pts.db.Put(ctx, pgt); err != nil { + if err := ts.db.Put(ctx, dbt); err != nil { return fmt.Errorf("error in tokenstore create: %s", err) } return nil } // RemoveByCode deletes a token from the DB based on the Code field -func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error { - return pts.db.DeleteWhere(ctx, []db.Where{{Key: "code", Value: code}}, &Token{}) +func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error { + return ts.db.DeleteWhere(ctx, []db.Where{{Key: "code", Value: code}}, &Token{}) } // RemoveByAccess deletes a token from the DB based on the Access field -func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { - return pts.db.DeleteWhere(ctx, []db.Where{{Key: "access", Value: access}}, &Token{}) +func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { + return ts.db.DeleteWhere(ctx, []db.Where{{Key: "access", Value: access}}, &Token{}) } // RemoveByRefresh deletes a token from the DB based on the Refresh field -func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { - return pts.db.DeleteWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, &Token{}) +func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { + return ts.db.DeleteWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, &Token{}) } // GetByCode selects a token from the DB based on the Code field -func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { +func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { if code == "" { return nil, nil } - pgt := &Token{ + dbt := &Token{ Code: code, } - if err := pts.db.GetWhere(ctx, []db.Where{{Key: "code", Value: code}}, pgt); err != nil { + if err := ts.db.GetWhere(ctx, []db.Where{{Key: "code", Value: code}}, dbt); err != nil { return nil, err } - return TokenToOauthToken(pgt), nil + return DBTokenToToken(dbt), nil } // GetByAccess selects a token from the DB based on the Access field -func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { +func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { if access == "" { return nil, nil } - pgt := &Token{ + dbt := &Token{ Access: access, } - if err := pts.db.GetWhere(ctx, []db.Where{{Key: "access", Value: access}}, pgt); err != nil { + if err := ts.db.GetWhere(ctx, []db.Where{{Key: "access", Value: access}}, dbt); err != nil { return nil, err } - return TokenToOauthToken(pgt), nil + return DBTokenToToken(dbt), nil } // GetByRefresh selects a token from the DB based on the Refresh field -func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { +func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { if refresh == "" { return nil, nil } - pgt := &Token{ + dbt := &Token{ Refresh: refresh, } - if err := pts.db.GetWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, pgt); err != nil { + if err := ts.db.GetWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, dbt); err != nil { return nil, err } - return TokenToOauthToken(pgt), nil + return DBTokenToToken(dbt), nil } /* - The following models are basically helpers for the postgres token store implementation, they should only be used internally. + The following models are basically helpers for the token store implementation, they should only be used internally. */ // Token is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt. // // Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined, -// and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and +// and tokens with expired TTLs are automatically removed. Since some databases don't have that feature, it's easier to set an expiry time and // then periodically sweep out tokens when that time has passed. // // Note that this struct does *not* satisfy the token interface shown here: https://github.com/superseriousbusiness/oauth2/blob/master/model.go#L22 @@ -187,26 +187,26 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2 // As such, manual translation is always required between Token and the gotosocial *model.Token. The helper functions oauthTokenToPGToken // and pgTokenToOauthToken can be used for that. type Token struct { - ID string `pg:"type:CHAR(26),pk,notnull"` + ID string `bun:"type:CHAR(26),pk,notnull"` ClientID string UserID string RedirectURI string Scope string - Code string `pg:"default:'',pk"` + Code string `bun:"default:'',pk"` CodeChallenge string CodeChallengeMethod string - CodeCreateAt time.Time `pg:"type:timestamp"` - CodeExpiresAt time.Time `pg:"type:timestamp"` - Access string `pg:"default:'',pk"` - AccessCreateAt time.Time `pg:"type:timestamp"` - AccessExpiresAt time.Time `pg:"type:timestamp"` - Refresh string `pg:"default:'',pk"` - RefreshCreateAt time.Time `pg:"type:timestamp"` - RefreshExpiresAt time.Time `pg:"type:timestamp"` + CodeCreateAt time.Time `bun:",nullzero"` + CodeExpiresAt time.Time `bun:",nullzero"` + Access string `bun:"default:'',pk"` + AccessCreateAt time.Time `bun:",nullzero"` + AccessExpiresAt time.Time `bun:",nullzero"` + Refresh string `bun:"default:'',pk"` + RefreshCreateAt time.Time `bun:",nullzero"` + RefreshExpiresAt time.Time `bun:",nullzero"` } -// TokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres -func TokenToPGToken(tkn *models.Token) *Token { +// TokenToDBToken is a lil util function that takes a gotosocial token and gives back a token for inserting into a database. +func TokenToDBToken(tkn *models.Token) *Token { now := time.Now() // For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's @@ -247,40 +247,40 @@ func TokenToPGToken(tkn *models.Token) *Token { } } -// TokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token -func TokenToOauthToken(pgt *Token) *models.Token { +// DBTokenToToken is a lil util function that takes a database token and gives back a gotosocial token +func DBTokenToToken(dbt *Token) *models.Token { now := time.Now() var codeExpiresIn time.Duration - if !pgt.CodeExpiresAt.IsZero() { - codeExpiresIn = pgt.CodeExpiresAt.Sub(now) + if !dbt.CodeExpiresAt.IsZero() { + codeExpiresIn = dbt.CodeExpiresAt.Sub(now) } var accessExpiresIn time.Duration - if !pgt.AccessExpiresAt.IsZero() { - accessExpiresIn = pgt.AccessExpiresAt.Sub(now) + if !dbt.AccessExpiresAt.IsZero() { + accessExpiresIn = dbt.AccessExpiresAt.Sub(now) } var refreshExpiresIn time.Duration - if !pgt.RefreshExpiresAt.IsZero() { - refreshExpiresIn = pgt.RefreshExpiresAt.Sub(now) + if !dbt.RefreshExpiresAt.IsZero() { + refreshExpiresIn = dbt.RefreshExpiresAt.Sub(now) } return &models.Token{ - ClientID: pgt.ClientID, - UserID: pgt.UserID, - RedirectURI: pgt.RedirectURI, - Scope: pgt.Scope, - Code: pgt.Code, - CodeChallenge: pgt.CodeChallenge, - CodeChallengeMethod: pgt.CodeChallengeMethod, - CodeCreateAt: pgt.CodeCreateAt, + ClientID: dbt.ClientID, + UserID: dbt.UserID, + RedirectURI: dbt.RedirectURI, + Scope: dbt.Scope, + Code: dbt.Code, + CodeChallenge: dbt.CodeChallenge, + CodeChallengeMethod: dbt.CodeChallengeMethod, + CodeCreateAt: dbt.CodeCreateAt, CodeExpiresIn: codeExpiresIn, - Access: pgt.Access, - AccessCreateAt: pgt.AccessCreateAt, + Access: dbt.Access, + AccessCreateAt: dbt.AccessCreateAt, AccessExpiresIn: accessExpiresIn, - Refresh: pgt.Refresh, - RefreshCreateAt: pgt.RefreshCreateAt, + Refresh: dbt.Refresh, + RefreshCreateAt: dbt.RefreshCreateAt, RefreshExpiresIn: refreshExpiresIn, } } diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go index a0758c846..d97af4d2e 100644 --- a/internal/processing/account/delete.go +++ b/internal/processing/account/delete.go @@ -177,17 +177,21 @@ selectStatusesLoop: } for _, b := range boosts { - oa := >smodel.Account{} - if err := p.db.GetByID(ctx, b.AccountID, oa); err == nil { - - l.Debug("putting boost undo in the client api channel") - p.fromClientAPI <- gtsmodel.FromClientAPI{ - APObjectType: gtsmodel.ActivityStreamsAnnounce, - APActivityType: gtsmodel.ActivityStreamsUndo, - GTSModel: s, - OriginAccount: oa, - TargetAccount: account, + if b.Account == nil { + bAccount, err := p.db.GetAccountByID(ctx, b.AccountID) + if err != nil { + continue } + b.Account = bAccount + } + + l.Debug("putting boost undo in the client api channel") + p.fromClientAPI <- gtsmodel.FromClientAPI{ + APObjectType: gtsmodel.ActivityStreamsAnnounce, + APActivityType: gtsmodel.ActivityStreamsUndo, + GTSModel: s, + OriginAccount: b.Account, + TargetAccount: account, } if err := p.db.DeleteByID(ctx, b.ID, b); err != nil { @@ -267,7 +271,8 @@ selectStatusesLoop: account.SuspendedAt = time.Now() account.SuspensionOrigin = origin - if err := p.db.UpdateByID(ctx, account.ID, account); err != nil { + account, err := p.db.UpdateAccount(ctx, account) + if err != nil { return err } diff --git a/internal/processing/account/get.go b/internal/processing/account/get.go index 39ae1c376..5f039127c 100644 --- a/internal/processing/account/get.go +++ b/internal/processing/account/get.go @@ -29,8 +29,8 @@ import ( ) func (p *processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) { - targetAccount := >smodel.Account{} - if err := p.db.GetByID(ctx, targetAccountID, targetAccount); err != nil { + targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID) + if err != nil { if err == db.ErrNoEntries { return nil, errors.New("account not found") } @@ -38,7 +38,6 @@ func (p *processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account } var blocked bool - var err error if requestingAccount != nil { blocked, err = p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true) if err != nil { diff --git a/internal/processing/account/removefollow.go b/internal/processing/account/removefollow.go index c271de79f..6186c550f 100644 --- a/internal/processing/account/removefollow.go +++ b/internal/processing/account/removefollow.go @@ -39,8 +39,8 @@ func (p *processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode } // make sure the target account actually exists in our db - targetAcct := >smodel.Account{} - if err := p.db.GetByID(ctx, targetAccountID, targetAcct); err != nil { + targetAcct, err := p.db.GetAccountByID(ctx, targetAccountID) + if err != nil { if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(fmt.Errorf("AccountFollowRemove: account %s not found in the db: %s", targetAccountID, err)) } diff --git a/internal/processing/account/update.go b/internal/processing/account/update.go index 339542760..99ccbf5a0 100644 --- a/internal/processing/account/update.go +++ b/internal/processing/account/update.go @@ -117,8 +117,8 @@ func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, form } // fetch the account with all updated values set - updatedAccount := >smodel.Account{} - if err := p.db.GetByID(ctx, account.ID, updatedAccount); err != nil { + updatedAccount, err := p.db.GetAccountByID(ctx, account.ID) + if err != nil { return nil, fmt.Errorf("could not fetch updated account %s: %s", account.ID, err) } diff --git a/internal/processing/followrequest.go b/internal/processing/followrequest.go index da31d14fc..3dd6432e2 100644 --- a/internal/processing/followrequest.go +++ b/internal/processing/followrequest.go @@ -38,11 +38,15 @@ func (p *processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([] accts := []apimodel.Account{} for _, fr := range frs { - acct := >smodel.Account{} - if err := p.db.GetByID(ctx, fr.AccountID, acct); err != nil { - return nil, gtserror.NewErrorInternalError(err) + if fr.Account == nil { + frAcct, err := p.db.GetAccountByID(ctx, fr.AccountID) + if err != nil { + return nil, gtserror.NewErrorInternalError(err) + } + fr.Account = frAcct } - mastoAcct, err := p.tc.AccountToMastoPublic(ctx, acct) + + mastoAcct, err := p.tc.AccountToMastoPublic(ctx, fr.Account) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -57,22 +61,28 @@ func (p *processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a return nil, gtserror.NewErrorNotFound(err) } - originAccount := >smodel.Account{} - if err := p.db.GetByID(ctx, follow.AccountID, originAccount); err != nil { - return nil, gtserror.NewErrorInternalError(err) + if follow.Account == nil { + followAccount, err := p.db.GetAccountByID(ctx, follow.AccountID) + if err != nil { + return nil, gtserror.NewErrorInternalError(err) + } + follow.Account = followAccount } - targetAccount := >smodel.Account{} - if err := p.db.GetByID(ctx, follow.TargetAccountID, targetAccount); err != nil { - return nil, gtserror.NewErrorInternalError(err) + if follow.TargetAccount == nil { + followTargetAccount, err := p.db.GetAccountByID(ctx, follow.TargetAccountID) + if err != nil { + return nil, gtserror.NewErrorInternalError(err) + } + follow.TargetAccount = followTargetAccount } p.fromClientAPI <- gtsmodel.FromClientAPI{ APObjectType: gtsmodel.ActivityStreamsFollow, APActivityType: gtsmodel.ActivityStreamsAccept, GTSModel: follow, - OriginAccount: originAccount, - TargetAccount: targetAccount, + OriginAccount: follow.Account, + TargetAccount: follow.TargetAccount, } gtsR, err := p.db.GetRelationship(ctx, auth.Account.ID, accountID) diff --git a/internal/processing/fromclientapi.go b/internal/processing/fromclientapi.go index 2e0f2c0c4..a6ea0068b 100644 --- a/internal/processing/fromclientapi.go +++ b/internal/processing/fromclientapi.go @@ -238,11 +238,11 @@ func (p *processor) processFromClientAPI(ctx context.Context, clientMsg gtsmodel func (p *processor) federateStatus(ctx context.Context, status *gtsmodel.Status) error { if status.Account == nil { - a := >smodel.Account{} - if err := p.db.GetByID(ctx, status.AccountID, a); err != nil { + statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID) + if err != nil { return fmt.Errorf("federateStatus: error fetching status author account: %s", err) } - status.Account = a + status.Account = statusAccount } // do nothing if this isn't our status @@ -266,11 +266,11 @@ func (p *processor) federateStatus(ctx context.Context, status *gtsmodel.Status) func (p *processor) federateStatusDelete(ctx context.Context, status *gtsmodel.Status) error { if status.Account == nil { - a := >smodel.Account{} - if err := p.db.GetByID(ctx, status.AccountID, a); err != nil { - return fmt.Errorf("federateStatus: error fetching status author account: %s", err) + statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID) + if err != nil { + return fmt.Errorf("federateStatusDelete: error fetching status author account: %s", err) } - status.Account = a + status.Account = statusAccount } // do nothing if this isn't our status @@ -558,19 +558,19 @@ func (p *processor) federateAccountUpdate(ctx context.Context, updatedAccount *g func (p *processor) federateBlock(ctx context.Context, block *gtsmodel.Block) error { if block.Account == nil { - a := >smodel.Account{} - if err := p.db.GetByID(ctx, block.AccountID, a); err != nil { + blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID) + if err != nil { return fmt.Errorf("federateBlock: error getting block account from database: %s", err) } - block.Account = a + block.Account = blockAccount } if block.TargetAccount == nil { - a := >smodel.Account{} - if err := p.db.GetByID(ctx, block.TargetAccountID, a); err != nil { + blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID) + if err != nil { return fmt.Errorf("federateBlock: error getting block target account from database: %s", err) } - block.TargetAccount = a + block.TargetAccount = blockTargetAccount } // if both accounts are local there's nothing to do here @@ -594,19 +594,19 @@ func (p *processor) federateBlock(ctx context.Context, block *gtsmodel.Block) er func (p *processor) federateUnblock(ctx context.Context, block *gtsmodel.Block) error { if block.Account == nil { - a := >smodel.Account{} - if err := p.db.GetByID(ctx, block.AccountID, a); err != nil { + blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID) + if err != nil { return fmt.Errorf("federateUnblock: error getting block account from database: %s", err) } - block.Account = a + block.Account = blockAccount } if block.TargetAccount == nil { - a := >smodel.Account{} - if err := p.db.GetByID(ctx, block.TargetAccountID, a); err != nil { + blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID) + if err != nil { return fmt.Errorf("federateUnblock: error getting block target account from database: %s", err) } - block.TargetAccount = a + block.TargetAccount = blockTargetAccount } // if both accounts are local there's nothing to do here diff --git a/internal/processing/media/delete.go b/internal/processing/media/delete.go index 878425800..281ddba03 100644 --- a/internal/processing/media/delete.go +++ b/internal/processing/media/delete.go @@ -7,12 +7,11 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) func (p *processor) Delete(ctx context.Context, mediaAttachmentID string) gtserror.WithCode { - a := >smodel.MediaAttachment{} - if err := p.db.GetByID(ctx, mediaAttachmentID, a); err != nil { + attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) + if err != nil { if err == db.ErrNoEntries { // attachment already gone return nil @@ -24,21 +23,21 @@ func (p *processor) Delete(ctx context.Context, mediaAttachmentID string) gtserr errs := []string{} // delete the thumbnail from storage - if a.Thumbnail.Path != "" { - if err := p.storage.RemoveFileAt(a.Thumbnail.Path); err != nil { - errs = append(errs, fmt.Sprintf("remove thumbnail at path %s: %s", a.Thumbnail.Path, err)) + if attachment.Thumbnail.Path != "" { + if err := p.storage.RemoveFileAt(attachment.Thumbnail.Path); err != nil { + errs = append(errs, fmt.Sprintf("remove thumbnail at path %s: %s", attachment.Thumbnail.Path, err)) } } // delete the file from storage - if a.File.Path != "" { - if err := p.storage.RemoveFileAt(a.File.Path); err != nil { - errs = append(errs, fmt.Sprintf("remove file at path %s: %s", a.File.Path, err)) + if attachment.File.Path != "" { + if err := p.storage.RemoveFileAt(attachment.File.Path); err != nil { + errs = append(errs, fmt.Sprintf("remove file at path %s: %s", attachment.File.Path, err)) } } // delete the attachment - if err := p.db.DeleteByID(ctx, mediaAttachmentID, a); err != nil { + if err := p.db.DeleteByID(ctx, mediaAttachmentID, attachment); err != nil { if err != db.ErrNoEntries { errs = append(errs, fmt.Sprintf("remove attachment: %s", err)) } diff --git a/internal/processing/media/getfile.go b/internal/processing/media/getfile.go index 6666c6423..c9c9b556d 100644 --- a/internal/processing/media/getfile.go +++ b/internal/processing/media/getfile.go @@ -48,8 +48,8 @@ func (p *processor) GetFile(ctx context.Context, account *gtsmodel.Account, form wantedMediaID := spl[0] // get the account that owns the media and make sure it's not suspended - acct := >smodel.Account{} - if err := p.db.GetByID(ctx, form.AccountID, acct); err != nil { + acct, err := p.db.GetAccountByID(ctx, form.AccountID) + if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("account with id %s could not be selected from the db: %s", form.AccountID, err)) } if !acct.SuspendedAt.IsZero() { @@ -91,8 +91,8 @@ func (p *processor) GetFile(ctx context.Context, account *gtsmodel.Account, form return nil, gtserror.NewErrorNotFound(fmt.Errorf("media size %s not recognized for emoji", mediaSize)) } case media.Attachment, media.Header, media.Avatar: - a := >smodel.MediaAttachment{} - if err := p.db.GetByID(ctx, wantedMediaID, a); err != nil { + a, err := p.db.GetAttachmentByID(ctx, wantedMediaID) + if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("attachment %s could not be taken from the db: %s", wantedMediaID, err)) } if a.AccountID != form.AccountID { diff --git a/internal/processing/media/getmedia.go b/internal/processing/media/getmedia.go index 48b7b351a..91608e90d 100644 --- a/internal/processing/media/getmedia.go +++ b/internal/processing/media/getmedia.go @@ -30,8 +30,8 @@ import ( ) func (p *processor) GetMedia(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) { - attachment := >smodel.MediaAttachment{} - if err := p.db.GetByID(ctx, mediaAttachmentID, attachment); err != nil { + attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) + if err != nil { if err == db.ErrNoEntries { // attachment doesn't exist return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db")) diff --git a/internal/processing/media/update.go b/internal/processing/media/update.go index 5402fd075..6f15f2ace 100644 --- a/internal/processing/media/update.go +++ b/internal/processing/media/update.go @@ -31,8 +31,8 @@ import ( ) func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) { - attachment := >smodel.MediaAttachment{} - if err := p.db.GetByID(ctx, mediaAttachmentID, attachment); err != nil { + attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) + if err != nil { if err == db.ErrNoEntries { // attachment doesn't exist return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db")) diff --git a/internal/processing/streaming/authorize.go b/internal/processing/streaming/authorize.go index 543145e2a..f938a0c0c 100644 --- a/internal/processing/streaming/authorize.go +++ b/internal/processing/streaming/authorize.go @@ -24,8 +24,8 @@ func (p *processor) AuthorizeStreamingRequest(ctx context.Context, accessToken s return nil, fmt.Errorf("AuthorizeStreamingRequest: no user found for validated uid %s", uid) } - acct := >smodel.Account{} - if err := p.db.GetByID(ctx, user.AccountID, acct); err != nil || acct == nil { + acct, err := p.db.GetAccountByID(ctx, user.AccountID) + if err != nil || acct == nil { return nil, fmt.Errorf("AuthorizeStreamingRequest: no account retrieved for user with id %s", uid) } diff --git a/internal/router/session.go b/internal/router/session.go index f336521d2..4359a8a60 100644 --- a/internal/router/session.go +++ b/internal/router/session.go @@ -20,7 +20,6 @@ package router import ( "context" - "crypto/rand" "errors" "fmt" "net/http" @@ -30,8 +29,6 @@ import ( "github.com/gin-gonic/gin" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/id" ) // SessionOptions returns the standard set of options to use for each session. @@ -46,34 +43,23 @@ func SessionOptions(cfg *config.Config) sessions.Options { } } -func useSession(ctx context.Context, cfg *config.Config, dbService db.DB, engine *gin.Engine) error { +func useSession(ctx context.Context, cfg *config.Config, sessionDB db.Session, engine *gin.Engine) error { // check if we have a saved router session already - routerSessions := []*gtsmodel.RouterSession{} - if err := dbService.GetAll(ctx, &routerSessions); err != nil { + rs, err := sessionDB.GetSession(ctx) + if err != nil { if err != db.ErrNoEntries { // proper error occurred return err } - } - - var rs *gtsmodel.RouterSession - if len(routerSessions) == 1 { - // we have a router session stored - rs = routerSessions[0] - } else if len(routerSessions) == 0 { - // we have no router sessions so we need to create a new one - var err error - rs, err = routerSession(ctx, dbService) + // no session saved so create a new one + rs, err = sessionDB.CreateSession(ctx) if err != nil { - return fmt.Errorf("error creating new router session: %s", err) + return err } - } else { - // we should only have one router session stored ever - return errors.New("we had more than one router session in the db") } if rs == nil { - return errors.New("error getting or creating router session: router session was nil") + return errors.New("router session was nil") } store := memstore.NewStore(rs.Auth, rs.Crypt) @@ -82,34 +68,3 @@ func useSession(ctx context.Context, cfg *config.Config, dbService db.DB, engine engine.Use(sessions.Sessions(sessionName, store)) return nil } - -// routerSession generates a new router session with random auth and crypt bytes, -// puts it in the database for persistence, and returns it for use. -func routerSession(ctx context.Context, dbService db.DB) (*gtsmodel.RouterSession, error) { - auth := make([]byte, 32) - crypt := make([]byte, 32) - - if _, err := rand.Read(auth); err != nil { - return nil, err - } - if _, err := rand.Read(crypt); err != nil { - return nil, err - } - - rid, err := id.NewULID() - if err != nil { - return nil, err - } - - rs := >smodel.RouterSession{ - ID: rid, - Auth: auth, - Crypt: crypt, - } - - if err := dbService.Put(ctx, rs); err != nil { - return nil, err - } - - return rs, nil -} diff --git a/internal/text/common.go b/internal/text/common.go index ecf7a7a98..a8d585a09 100644 --- a/internal/text/common.go +++ b/internal/text/common.go @@ -98,8 +98,8 @@ func (f *formatter) ReplaceMentions(ctx context.Context, in string, mentions []* // got it from the mention targetAccount = menchie.OriginAccount } else { - a := >smodel.Account{} - if err := f.db.GetByID(ctx, menchie.TargetAccountID, a); err == nil { + a, err := f.db.GetAccountByID(ctx, menchie.TargetAccountID) + if err == nil { // got it from the db targetAccount = a } else { diff --git a/internal/timeline/get.go b/internal/timeline/get.go index c22935550..a00613dc0 100644 --- a/internal/timeline/get.go +++ b/internal/timeline/get.go @@ -54,7 +54,8 @@ func (t *timeline) Get(ctx context.Context, amount int, maxID string, sinceID st if prepareNext { // already cache the next query to speed up scrolling go func() { - if err := t.prepareNextQuery(ctx, amount, nextMaxID, "", ""); err != nil { + // use context.Background() because we don't want the query to abort when the request finishes + if err := t.prepareNextQuery(context.Background(), amount, nextMaxID, "", ""); err != nil { l.Errorf("error preparing next query: %s", err) } }() @@ -73,7 +74,8 @@ func (t *timeline) Get(ctx context.Context, amount int, maxID string, sinceID st if prepareNext { // already cache the next query to speed up scrolling go func() { - if err := t.prepareNextQuery(ctx, amount, nextMaxID, "", ""); err != nil { + // use context.Background() because we don't want the query to abort when the request finishes + if err := t.prepareNextQuery(context.Background(), amount, nextMaxID, "", ""); err != nil { l.Errorf("error preparing next query: %s", err) } }() diff --git a/internal/typeutils/internaltoas.go b/internal/typeutils/internaltoas.go index abba3ea35..14ed094c5 100644 --- a/internal/typeutils/internaltoas.go +++ b/internal/typeutils/internaltoas.go @@ -214,9 +214,12 @@ func (c *converter) AccountToAS(ctx context.Context, a *gtsmodel.Account) (vocab // icon // Used as profile avatar. if a.AvatarMediaAttachmentID != "" { - avatar := >smodel.MediaAttachment{} - if err := c.db.GetByID(ctx, a.AvatarMediaAttachmentID, avatar); err != nil { - return nil, err + if a.AvatarMediaAttachment == nil { + avatar := >smodel.MediaAttachment{} + if err := c.db.GetByID(ctx, a.AvatarMediaAttachmentID, avatar); err != nil { + return nil, err + } + a.AvatarMediaAttachment = avatar } iconProperty := streams.NewActivityStreamsIconProperty() @@ -224,11 +227,11 @@ func (c *converter) AccountToAS(ctx context.Context, a *gtsmodel.Account) (vocab iconImage := streams.NewActivityStreamsImage() mediaType := streams.NewActivityStreamsMediaTypeProperty() - mediaType.Set(avatar.File.ContentType) + mediaType.Set(a.AvatarMediaAttachment.File.ContentType) iconImage.SetActivityStreamsMediaType(mediaType) avatarURLProperty := streams.NewActivityStreamsUrlProperty() - avatarURL, err := url.Parse(avatar.URL) + avatarURL, err := url.Parse(a.AvatarMediaAttachment.URL) if err != nil { return nil, err } @@ -242,9 +245,12 @@ func (c *converter) AccountToAS(ctx context.Context, a *gtsmodel.Account) (vocab // image // Used as profile header. if a.HeaderMediaAttachmentID != "" { - header := >smodel.MediaAttachment{} - if err := c.db.GetByID(ctx, a.HeaderMediaAttachmentID, header); err != nil { - return nil, err + if a.HeaderMediaAttachment == nil { + header := >smodel.MediaAttachment{} + if err := c.db.GetByID(ctx, a.HeaderMediaAttachmentID, header); err != nil { + return nil, err + } + a.HeaderMediaAttachment = header } headerProperty := streams.NewActivityStreamsImageProperty() @@ -252,11 +258,11 @@ func (c *converter) AccountToAS(ctx context.Context, a *gtsmodel.Account) (vocab headerImage := streams.NewActivityStreamsImage() mediaType := streams.NewActivityStreamsMediaTypeProperty() - mediaType.Set(header.File.ContentType) + mediaType.Set(a.HeaderMediaAttachment.File.ContentType) headerImage.SetActivityStreamsMediaType(mediaType) headerURLProperty := streams.NewActivityStreamsUrlProperty() - headerURL, err := url.Parse(header.URL) + headerURL, err := url.Parse(a.HeaderMediaAttachment.URL) if err != nil { return nil, err } diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go index 52e394698..89da9eb01 100644 --- a/internal/typeutils/internaltofrontend.go +++ b/internal/typeutils/internaltofrontend.go @@ -63,6 +63,10 @@ func (c *converter) AccountToMastoSensitive(ctx context.Context, a *gtsmodel.Acc } func (c *converter) AccountToMastoPublic(ctx context.Context, a *gtsmodel.Account) (*model.Account, error) { + if a == nil { + return nil, fmt.Errorf("given account was nil") + } + // first check if we have this account in our frontEnd cache if accountI, err := c.frontendCache.Fetch(a.ID); err == nil { if account, ok := accountI.(*model.Account); ok { @@ -266,27 +270,30 @@ func (c *converter) AttachmentToMasto(ctx context.Context, a *gtsmodel.MediaAtta } func (c *converter) MentionToMasto(ctx context.Context, m *gtsmodel.Mention) (model.Mention, error) { - target := >smodel.Account{} - if err := c.db.GetByID(ctx, m.TargetAccountID, target); err != nil { - return model.Mention{}, err + if m.TargetAccount == nil { + targetAccount, err := c.db.GetAccountByID(ctx, m.TargetAccountID) + if err != nil { + return model.Mention{}, err + } + m.TargetAccount = targetAccount } var local bool - if target.Domain == "" { + if m.TargetAccount.Domain == "" { local = true } var acct string if local { - acct = target.Username + acct = m.TargetAccount.Username } else { - acct = fmt.Sprintf("%s@%s", target.Username, target.Domain) + acct = fmt.Sprintf("%s@%s", m.TargetAccount.Username, m.TargetAccount.Domain) } return model.Mention{ - ID: target.ID, - Username: target.Username, - URL: target.URL, + ID: m.TargetAccount.ID, + Username: m.TargetAccount.Username, + URL: m.TargetAccount.URL, Acct: acct, }, nil } @@ -302,11 +309,9 @@ func (c *converter) EmojiToMasto(ctx context.Context, e *gtsmodel.Emoji) (model. } func (c *converter) TagToMasto(ctx context.Context, t *gtsmodel.Tag) (model.Tag, error) { - tagURL := fmt.Sprintf("%s://%s/tags/%s", c.config.Protocol, c.config.Host, t.Name) - return model.Tag{ Name: t.Name, - URL: tagURL, // we don't serve URLs with collections of tagged statuses (FOR NOW) so this is purely for mastodon compatibility ¯\_(ツ)_/¯ + URL: t.URL, }, nil } @@ -331,8 +336,8 @@ func (c *converter) StatusToMasto(ctx context.Context, s *gtsmodel.Status, reque // the boosted status might have been set on this struct already so check first before doing db calls if s.BoostOf == nil { // it's not set so fetch it from the db - bs := >smodel.Status{} - if err := c.db.GetByID(ctx, s.BoostOfID, bs); err != nil { + bs, err := c.db.GetStatusByID(ctx, s.BoostOfID) + if err != nil { return nil, fmt.Errorf("error getting boosted status with id %s: %s", s.BoostOfID, err) } s.BoostOf = bs @@ -341,8 +346,8 @@ func (c *converter) StatusToMasto(ctx context.Context, s *gtsmodel.Status, reque // the boosted account might have been set on this struct already or passed as a param so check first before doing db calls if s.BoostOfAccount == nil { // it's not set so fetch it from the db - ba := >smodel.Account{} - if err := c.db.GetByID(ctx, s.BoostOf.AccountID, ba); err != nil { + ba, err := c.db.GetAccountByID(ctx, s.BoostOf.AccountID) + if err != nil { return nil, fmt.Errorf("error getting boosted account %s from status with id %s: %s", s.BoostOf.AccountID, s.BoostOfID, err) } s.BoostOfAccount = ba @@ -368,8 +373,8 @@ func (c *converter) StatusToMasto(ctx context.Context, s *gtsmodel.Status, reque } if s.Account == nil { - a := >smodel.Account{} - if err := c.db.GetByID(ctx, s.AccountID, a); err != nil { + a, err := c.db.GetAccountByID(ctx, s.AccountID) + if err != nil { return nil, fmt.Errorf("error getting status author: %s", err) } s.Account = a @@ -394,14 +399,14 @@ func (c *converter) StatusToMasto(ctx context.Context, s *gtsmodel.Status, reque // the status doesn't have gts attachments on it, but it does have attachment IDs // in this case, we need to pull the gts attachments from the db to convert them into masto ones } else { - for _, a := range s.AttachmentIDs { - gtsAttachment := >smodel.MediaAttachment{} - if err := c.db.GetByID(ctx, a, gtsAttachment); err != nil { - return nil, fmt.Errorf("error getting attachment with id %s: %s", a, err) + for _, aID := range s.AttachmentIDs { + gtsAttachment, err := c.db.GetAttachmentByID(ctx, aID) + if err != nil { + return nil, fmt.Errorf("error getting attachment with id %s: %s", aID, err) } mastoAttachment, err := c.AttachmentToMasto(ctx, gtsAttachment) if err != nil { - return nil, fmt.Errorf("error converting attachment with id %s: %s", a, err) + return nil, fmt.Errorf("error converting attachment with id %s: %s", aID, err) } mastoAttachments = append(mastoAttachments, mastoAttachment) } @@ -421,10 +426,10 @@ func (c *converter) StatusToMasto(ctx context.Context, s *gtsmodel.Status, reque // the status doesn't have gts mentions on it, but it does have mention IDs // in this case, we need to pull the gts mentions from the db to convert them into masto ones } else { - for _, m := range s.MentionIDs { - gtsMention := >smodel.Mention{} - if err := c.db.GetByID(ctx, m, gtsMention); err != nil { - return nil, fmt.Errorf("error getting mention with id %s: %s", m, err) + for _, mID := range s.MentionIDs { + gtsMention, err := c.db.GetMention(ctx, mID) + if err != nil { + return nil, fmt.Errorf("error getting mention with id %s: %s", mID, err) } mastoMention, err := c.MentionToMasto(ctx, gtsMention) if err != nil { @@ -603,13 +608,16 @@ func (c *converter) InstanceToMasto(ctx context.Context, i *gtsmodel.Instance) ( // contact account is optional but let's try to get it if i.ContactAccountID != "" { - ia := >smodel.Account{} - if err := c.db.GetByID(ctx, i.ContactAccountID, ia); err == nil { - ma, err := c.AccountToMastoPublic(ctx, ia) + if i.ContactAccount == nil { + contactAccount, err := c.db.GetAccountByID(ctx, i.ContactAccountID) if err == nil { - mi.ContactAccount = ma + i.ContactAccount = contactAccount } } + ma, err := c.AccountToMastoPublic(ctx, i.ContactAccount) + if err == nil { + mi.ContactAccount = ma + } } return mi, nil diff --git a/testrig/db.go b/testrig/db.go index 3901bea77..b0e2afd04 100644 --- a/testrig/db.go +++ b/testrig/db.go @@ -24,7 +24,7 @@ import ( "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/db/pg" + "github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/oauth" ) @@ -68,7 +68,7 @@ func NewTestDB() db.DB { l := logrus.New() l.SetLevel(logrus.TraceLevel) - testDB, err := pg.NewPostgresService(context.Background(), config, l) + testDB, err := bundb.NewBunDBService(context.Background(), config, l) if err != nil { logrus.Panic(err) } diff --git a/vendor/github.com/uptrace/bun/README.md b/vendor/github.com/uptrace/bun/README.md index bf57440da..e7cc77a60 100644 --- a/vendor/github.com/uptrace/bun/README.md +++ b/vendor/github.com/uptrace/bun/README.md @@ -11,10 +11,6 @@ [![Documentation](https://img.shields.io/badge/bun-documentation-informational)](https://bun.uptrace.dev/) [![Chat](https://discordapp.com/api/guilds/752070105847955518/widget.png)](https://discord.gg/rWtp5Aj) -Status: API freeze (release candidate). Note that all sub-packages (mainly extra/\* packages) are -not part of the API freeze and are developed independently. You can think of them as 3-rd party -packages. - Main features are: - Works with [PostgreSQL](https://bun.uptrace.dev/guide/drivers.html#postgresql), diff --git a/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod index c47d595d5..0cad1ce5b 100644 --- a/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod +++ b/vendor/github.com/uptrace/bun/dialect/pgdialect/go.mod @@ -4,4 +4,4 @@ go 1.16 replace github.com/uptrace/bun => ../.. -require github.com/uptrace/bun v1.0.0-rc.1 +require github.com/uptrace/bun v0.4.3 diff --git a/vendor/github.com/uptrace/bun/go.sum b/vendor/github.com/uptrace/bun/go.sum index 1f5ad5409..3bf0a4a3f 100644 --- a/vendor/github.com/uptrace/bun/go.sum +++ b/vendor/github.com/uptrace/bun/go.sum @@ -20,4 +20,4 @@ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBc gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= \ No newline at end of file +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/vendor/github.com/uptrace/bun/schema/formatter.go b/vendor/github.com/uptrace/bun/schema/formatter.go index 45a246307..7b26fbaca 100644 --- a/vendor/github.com/uptrace/bun/schema/formatter.go +++ b/vendor/github.com/uptrace/bun/schema/formatter.go @@ -89,10 +89,10 @@ func (f Formatter) AppendQuery(dst []byte, query string, args ...interface{}) [] func (f Formatter) append(dst []byte, p *parser.Parser, args []interface{}) []byte { var namedArgs NamedArgAppender if len(args) == 1 { - if v, ok := args[0].(NamedArgAppender); ok { - namedArgs = v - } else if v, ok := newStructArgs(f, args[0]); ok { - namedArgs = v + var ok bool + namedArgs, ok = args[0].(NamedArgAppender) + if !ok { + namedArgs, _ = newStructArgs(f, args[0]) } } diff --git a/vendor/github.com/uptrace/bun/version.go b/vendor/github.com/uptrace/bun/version.go index 460909509..1baf9a39c 100644 --- a/vendor/github.com/uptrace/bun/version.go +++ b/vendor/github.com/uptrace/bun/version.go @@ -2,5 +2,5 @@ package bun // Version is the current release version. func Version() string { - return "1.0.0-rc.1" + return "0.4.3" } diff --git a/vendor/modules.txt b/vendor/modules.txt index 34ff7e957..a2edd4b90 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -389,7 +389,7 @@ github.com/tdewolff/parse/v2/strconv github.com/tmthrgd/go-hex # github.com/ugorji/go/codec v1.2.6 github.com/ugorji/go/codec -# github.com/uptrace/bun v1.0.0-rc.1 +# github.com/uptrace/bun v0.4.3 ## explicit github.com/uptrace/bun github.com/uptrace/bun/dialect @@ -400,7 +400,7 @@ github.com/uptrace/bun/internal github.com/uptrace/bun/internal/parser github.com/uptrace/bun/internal/tagparser github.com/uptrace/bun/schema -# github.com/uptrace/bun/dialect/pgdialect v1.0.0-rc.1 +# github.com/uptrace/bun/dialect/pgdialect v0.4.3 ## explicit github.com/uptrace/bun/dialect/pgdialect # github.com/urfave/cli/v2 v2.3.0