Pg to bun (#148)

* start moving to bun

* changing more stuff

* more

* and yet more

* tests passing

* seems stable now

* more big changes

* small fix

* little fixes
This commit is contained in:
tobi 2021-08-25 15:34:33 +02:00 committed by GitHub
commit 2dc9fc1626
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
713 changed files with 98694 additions and 22704 deletions

View file

@ -101,7 +101,7 @@ func (m *Module) AccountCreatePOSTHandler(c *gin.Context) {
form.IP = signUpIP
ti, err := m.processor.AccountCreate(authed, form)
ti, err := m.processor.AccountCreate(c.Request.Context(), authed, form)
if err != nil {
l.Errorf("internal server error while creating new account: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})

View file

@ -70,7 +70,7 @@ func (m *Module) AccountGETHandler(c *gin.Context) {
return
}
acctInfo, err := m.processor.AccountGet(authed, targetAcctID)
acctInfo, err := m.processor.AccountGet(c.Request.Context(), authed, targetAcctID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
return

View file

@ -122,7 +122,7 @@ func (m *Module) AccountUpdateCredentialsPATCHHandler(c *gin.Context) {
return
}
acctSensitive, err := m.processor.AccountUpdate(authed, form)
acctSensitive, err := m.processor.AccountUpdate(c.Request.Context(), authed, form)
if err != nil {
l.Debugf("could not update account: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})

View file

@ -79,7 +79,7 @@ func (suite *AccountUpdateTestSuite) TestAccountUpdateCredentialsPATCHHandler()
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"])
ctx.Set(oauth.SessionAuthorizedToken, oauth.TokenToOauthToken(suite.testTokens["local_account_1"]))
ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"]))
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"])
ctx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:8080/%s", account.UpdateCredentialsPath), bytes.NewReader(requestBody.Bytes())) // the endpoint we're hitting

View file

@ -59,7 +59,7 @@ func (m *Module) AccountVerifyGETHandler(c *gin.Context) {
return
}
acctSensitive, err := m.processor.AccountGet(authed, authed.Account.ID)
acctSensitive, err := m.processor.AccountGet(c.Request.Context(), authed, authed.Account.ID)
if err != nil {
l.Debugf("error getting account from processor: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"})

View file

@ -72,7 +72,7 @@ func (m *Module) AccountBlockPOSTHandler(c *gin.Context) {
return
}
relationship, errWithCode := m.processor.AccountBlockCreate(authed, targetAcctID)
relationship, errWithCode := m.processor.AccountBlockCreate(c.Request.Context(), authed, targetAcctID)
if errWithCode != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return

View file

@ -99,7 +99,7 @@ func (m *Module) AccountFollowPOSTHandler(c *gin.Context) {
}
form.ID = targetAcctID
relationship, errWithCode := m.processor.AccountFollowCreate(authed, form)
relationship, errWithCode := m.processor.AccountFollowCreate(c.Request.Context(), authed, form)
if errWithCode != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return

View file

@ -74,7 +74,7 @@ func (m *Module) AccountFollowersGETHandler(c *gin.Context) {
return
}
followers, errWithCode := m.processor.AccountFollowersGet(authed, targetAcctID)
followers, errWithCode := m.processor.AccountFollowersGet(c.Request.Context(), authed, targetAcctID)
if errWithCode != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return

View file

@ -74,7 +74,7 @@ func (m *Module) AccountFollowingGETHandler(c *gin.Context) {
return
}
following, errWithCode := m.processor.AccountFollowingGet(authed, targetAcctID)
following, errWithCode := m.processor.AccountFollowingGet(c.Request.Context(), authed, targetAcctID)
if errWithCode != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return

View file

@ -71,7 +71,7 @@ func (m *Module) AccountRelationshipsGETHandler(c *gin.Context) {
relationships := []model.Relationship{}
for _, targetAccountID := range targetAccountIDs {
r, errWithCode := m.processor.AccountRelationshipGet(authed, targetAccountID)
r, errWithCode := m.processor.AccountRelationshipGet(c.Request.Context(), authed, targetAccountID)
if err != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return

View file

@ -166,7 +166,7 @@ func (m *Module) AccountStatusesGETHandler(c *gin.Context) {
mediaOnly = i
}
statuses, errWithCode := m.processor.AccountStatusesGet(authed, targetAcctID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
statuses, errWithCode := m.processor.AccountStatusesGet(c.Request.Context(), authed, targetAcctID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
if errWithCode != nil {
l.Debugf("error from processor account statuses get: %s", errWithCode)
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})

View file

@ -72,7 +72,7 @@ func (m *Module) AccountUnblockPOSTHandler(c *gin.Context) {
return
}
relationship, errWithCode := m.processor.AccountBlockRemove(authed, targetAcctID)
relationship, errWithCode := m.processor.AccountBlockRemove(c.Request.Context(), authed, targetAcctID)
if errWithCode != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return

View file

@ -75,7 +75,7 @@ func (m *Module) AccountUnfollowPOSTHandler(c *gin.Context) {
return
}
relationship, errWithCode := m.processor.AccountFollowRemove(authed, targetAcctID)
relationship, errWithCode := m.processor.AccountFollowRemove(c.Request.Context(), authed, targetAcctID)
if errWithCode != nil {
l.Debug(errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})

View file

@ -141,7 +141,7 @@ func (m *Module) DomainBlocksPOSTHandler(c *gin.Context) {
if imp {
// we're importing multiple blocks
domainBlocks, err := m.processor.AdminDomainBlocksImport(authed, form)
domainBlocks, err := m.processor.AdminDomainBlocksImport(c.Request.Context(), authed, form)
if err != nil {
l.Debugf("error importing domain blocks: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@ -150,7 +150,7 @@ func (m *Module) DomainBlocksPOSTHandler(c *gin.Context) {
c.JSON(http.StatusOK, domainBlocks)
} else {
// we're just creating one block
domainBlock, err := m.processor.AdminDomainBlockCreate(authed, form)
domainBlock, err := m.processor.AdminDomainBlockCreate(c.Request.Context(), authed, form)
if err != nil {
l.Debugf("error creating domain block: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})

View file

@ -68,7 +68,7 @@ func (m *Module) DomainBlockDELETEHandler(c *gin.Context) {
return
}
domainBlock, errWithCode := m.processor.AdminDomainBlockDelete(authed, domainBlockID)
domainBlock, errWithCode := m.processor.AdminDomainBlockDelete(c.Request.Context(), authed, domainBlockID)
if errWithCode != nil {
l.Debugf("error deleting domain block: %s", errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})

View file

@ -81,7 +81,7 @@ func (m *Module) DomainBlockGETHandler(c *gin.Context) {
export = i
}
domainBlock, err := m.processor.AdminDomainBlockGet(authed, domainBlockID, export)
domainBlock, err := m.processor.AdminDomainBlockGet(c.Request.Context(), authed, domainBlockID, export)
if err != nil {
l.Debugf("error getting domain block: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})

View file

@ -81,7 +81,7 @@ func (m *Module) DomainBlocksGETHandler(c *gin.Context) {
export = i
}
domainBlocks, err := m.processor.AdminDomainBlocksGet(authed, export)
domainBlocks, err := m.processor.AdminDomainBlocksGet(c.Request.Context(), authed, export)
if err != nil {
l.Debugf("error getting domain blocks: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})

View file

@ -111,7 +111,7 @@ func (m *Module) emojiCreatePOSTHandler(c *gin.Context) {
return
}
mastoEmoji, err := m.processor.AdminEmojiCreate(authed, form)
mastoEmoji, err := m.processor.AdminEmojiCreate(c.Request.Context(), authed, form)
if err != nil {
l.Debugf("error creating emoji: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})

View file

@ -101,7 +101,7 @@ func (m *Module) AppsPOSTHandler(c *gin.Context) {
return
}
mastoApp, err := m.processor.AppCreate(authed, form)
mastoApp, err := m.processor.AppCreate(c.Request.Context(), authed, form)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return

View file

@ -28,7 +28,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/pg"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"golang.org/x/crypto/bcrypt"
@ -104,7 +104,7 @@ func (suite *AuthTestSuite) SetupTest() {
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
db, err := pg.NewPostgresService(context.Background(), suite.config, log)
db, err := bundb.NewBunDBService(context.Background(), suite.config, log)
if err != nil {
logrus.Panicf("error creating database connection: %s", err)
}
@ -120,23 +120,23 @@ func (suite *AuthTestSuite) SetupTest() {
}
for _, m := range models {
if err := suite.db.CreateTable(m); err != nil {
if err := suite.db.CreateTable(context.Background(), m); err != nil {
logrus.Panicf("db connection error: %s", err)
}
}
suite.oauthServer = oauth.New(suite.db, log)
if err := suite.db.Put(suite.testAccount); err != nil {
if err := suite.db.Put(context.Background(), suite.testAccount); err != nil {
logrus.Panicf("could not insert test account into db: %s", err)
}
if err := suite.db.Put(suite.testUser); err != nil {
if err := suite.db.Put(context.Background(), suite.testUser); err != nil {
logrus.Panicf("could not insert test user into db: %s", err)
}
if err := suite.db.Put(suite.testClient); err != nil {
if err := suite.db.Put(context.Background(), suite.testClient); err != nil {
logrus.Panicf("could not insert test client into db: %s", err)
}
if err := suite.db.Put(suite.testApplication); err != nil {
if err := suite.db.Put(context.Background(), suite.testApplication); err != nil {
logrus.Panicf("could not insert test application into db: %s", err)
}
@ -152,7 +152,7 @@ func (suite *AuthTestSuite) TearDownTest() {
&gtsmodel.Application{},
}
for _, m := range models {
if err := suite.db.DropTable(m); err != nil {
if err := suite.db.DropTable(context.Background(), m); err != nil {
logrus.Panicf("error dropping table: %s", err)
}
}

View file

@ -70,30 +70,23 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"})
return
}
app := &gtsmodel.Application{
ClientID: clientID,
}
if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
app := &gtsmodel.Application{}
if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)})
return
}
// we can also use the userid of the user to fetch their username from the db to greet them nicely <3
user := &gtsmodel.User{
ID: userID,
}
if err := m.db.GetByID(user.ID, user); err != nil {
user := &gtsmodel.User{}
if err := m.db.GetByID(c.Request.Context(), user.ID, user); err != nil {
m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
acct := &gtsmodel.Account{
ID: user.AccountID,
}
if err := m.db.GetByID(acct.ID, acct); err != nil {
acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
if err != nil {
m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return

View file

@ -19,6 +19,7 @@
package auth
import (
"context"
"errors"
"fmt"
"net"
@ -80,13 +81,13 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
app := &gtsmodel.Application{
ClientID: clientID,
}
if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
m.clearSession(s)
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)})
return
}
user, err := m.parseUserFromClaims(claims, net.IP(c.ClientIP()), app.ID)
user, err := m.parseUserFromClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID)
if err != nil {
m.clearSession(s)
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
@ -103,14 +104,14 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
c.Redirect(http.StatusFound, OauthAuthorizePath)
}
func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) {
func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) {
if claims.Email == "" {
return nil, errors.New("no email returned in claims")
}
// see if we already have a user for this email address
user := &gtsmodel.User{}
err := m.db.GetWhere([]db.Where{{Key: "email", Value: claims.Email}}, user)
err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user)
if err == nil {
// we do! so we can just return it
return user, nil
@ -122,7 +123,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
}
// maybe we have an unconfirmed user
err = m.db.GetWhere([]db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user)
err = m.db.GetWhere(ctx, []db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user)
if err == nil {
// user is unconfirmed so return an error
return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email)
@ -137,9 +138,13 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
// however, because we trust the OIDC provider, we should now create a user + account with the provided claims
// check if the email address is available for use; if it's not there's nothing we can so
if err := m.db.IsEmailAvailable(claims.Email); err != nil {
emailAvailable, err := m.db.IsEmailAvailable(ctx, claims.Email)
if err != nil {
return nil, fmt.Errorf("email %s not available: %s", claims.Email, err)
}
if !emailAvailable {
return nil, fmt.Errorf("email %s in use", claims.Email)
}
// now we need a username
var username string
@ -180,12 +185,11 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
// note that for the first iteration, iString is still "" when the check is made, so our first choice
// is still the raw username with no integer stuck on the end
for i := 1; !found; i = i + 1 {
if err := m.db.IsUsernameAvailable(username + iString); err != nil {
if strings.Contains(err.Error(), "db error") {
// if there's an actual db error we should return
return nil, fmt.Errorf("error checking username availability: %s", err)
}
} else {
usernameAvailable, err := m.db.IsUsernameAvailable(ctx, username+iString)
if err != nil {
return nil, err
}
if usernameAvailable {
// no error so we've found a username that works
found = true
username = username + iString
@ -209,7 +213,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
password := uuid.NewString() + uuid.NewString()
// create the user! this will also create an account and store it in the database so we don't need to do that here
user, err = m.db.NewSignup(username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin)
user, err = m.db.NewSignup(ctx, username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin)
if err != nil {
return nil, fmt.Errorf("error creating user: %s", err)
}

View file

@ -49,15 +49,15 @@ func (m *Module) OauthTokenMiddleware(c *gin.Context) {
// fetch user's and account for this user id
user := &gtsmodel.User{}
if err := m.db.GetByID(uid, user); err != nil || user == nil {
if err := m.db.GetByID(c.Request.Context(), uid, user); err != nil || user == nil {
l.Warnf("no user found for validated uid %s", uid)
return
}
c.Set(oauth.SessionAuthorizedUser, user)
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user)
acct := &gtsmodel.Account{}
if err := m.db.GetByID(user.AccountID, acct); err != nil || acct == nil {
acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
if err != nil || acct == nil {
l.Warnf("no account found for validated user %s", uid)
return
}
@ -69,7 +69,7 @@ func (m *Module) OauthTokenMiddleware(c *gin.Context) {
if cid := ti.GetClientID(); cid != "" {
l.Tracef("authenticated client %s with bearer token, scope is %s", cid, ti.GetScope())
app := &gtsmodel.Application{}
if err := m.db.GetWhere([]db.Where{{Key: "client_id", Value: cid}}, app); err != nil {
if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: "client_id", Value: cid}}, app); err != nil {
l.Tracef("no app found for client %s", cid)
}
c.Set(oauth.SessionAuthorizedApplication, app)

View file

@ -19,6 +19,7 @@
package auth
import (
"context"
"errors"
"net/http"
@ -74,7 +75,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) {
}
l.Tracef("parsed form: %+v", form)
userid, err := m.ValidatePassword(form.Email, form.Password)
userid, err := m.ValidatePassword(c.Request.Context(), form.Email, form.Password)
if err != nil {
c.String(http.StatusForbidden, err.Error())
m.clearSession(s)
@ -96,7 +97,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) {
// The goal is to authenticate the password against the one for that email
// address stored in the database. If OK, we return the userid (a ulid) for that user,
// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
func (m *Module) ValidatePassword(email string, password string) (userid string, err error) {
func (m *Module) ValidatePassword(ctx context.Context, email string, password string) (userid string, err error) {
l := m.log.WithField("func", "ValidatePassword")
// make sure an email/password was provided and bail if not
@ -108,7 +109,7 @@ func (m *Module) ValidatePassword(email string, password string) (userid string,
// first we select the user from the database based on email address, bail if no user found for that email
gtsUser := &gtsmodel.User{}
if err := m.db.GetWhere([]db.Where{{Key: "email", Value: email}}, gtsUser); err != nil {
if err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: email}}, gtsUser); err != nil {
l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
return incorrectPassword()
}

View file

@ -117,7 +117,7 @@ func (m *Module) BlocksGETHandler(c *gin.Context) {
limit = int(i)
}
resp, errWithCode := m.processor.BlocksGet(authed, maxID, sinceID, limit)
resp, errWithCode := m.processor.BlocksGet(c.Request.Context(), authed, maxID, sinceID, limit)
if errWithCode != nil {
l.Debugf("error from processor BlocksGet: %s", errWithCode)
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})

View file

@ -43,7 +43,7 @@ func (m *Module) FavouritesGETHandler(c *gin.Context) {
limit = int(i)
}
resp, errWithCode := m.processor.FavedTimelineGet(authed, maxID, minID, limit)
resp, errWithCode := m.processor.FavedTimelineGet(c.Request.Context(), authed, maxID, minID, limit)
if errWithCode != nil {
l.Debugf("error from processor FavedTimelineGet: %s", errWithCode)
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})

View file

@ -25,8 +25,6 @@ import (
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
)
@ -66,17 +64,3 @@ func (m *FileServer) Route(s router.Router) error {
s.AttachHandler(http.MethodGet, fmt.Sprintf("%s/:%s/:%s/:%s/:%s", m.storageBase, AccountIDKey, MediaTypeKey, MediaSizeKey, FileNameKey), m.ServeFile)
return nil
}
// CreateTables populates necessary tables in the given DB
func (m *FileServer) CreateTables(db db.DB) error {
models := []interface{}{
&gtsmodel.MediaAttachment{},
}
for _, m := range models {
if err := db.CreateTable(m); err != nil {
return fmt.Errorf("error creating table: %s", err)
}
}
return nil
}

View file

@ -78,7 +78,7 @@ func (m *FileServer) ServeFile(c *gin.Context) {
return
}
content, err := m.processor.FileGet(authed, &model.GetContentRequestForm{
content, err := m.processor.FileGet(c.Request.Context(), authed, &model.GetContentRequestForm{
AccountID: accountID,
MediaType: mediaType,
MediaSize: mediaSize,

View file

@ -48,7 +48,7 @@ func (m *Module) FollowRequestAcceptPOSTHandler(c *gin.Context) {
return
}
r, errWithCode := m.processor.FollowRequestAccept(authed, originAccountID)
r, errWithCode := m.processor.FollowRequestAccept(c.Request.Context(), authed, originAccountID)
if errWithCode != nil {
l.Debug(errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})

View file

@ -41,7 +41,7 @@ func (m *Module) FollowRequestGETHandler(c *gin.Context) {
return
}
accts, errWithCode := m.processor.FollowRequestsGet(authed)
accts, errWithCode := m.processor.FollowRequestsGet(c.Request.Context(), authed)
if errWithCode != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return

View file

@ -31,7 +31,7 @@ import (
func (m *Module) InstanceInformationGETHandler(c *gin.Context) {
l := m.log.WithField("func", "InstanceInformationGETHandler")
instance, err := m.processor.InstanceGet(m.config.Host)
instance, err := m.processor.InstanceGet(c.Request.Context(), m.config.Host)
if err != nil {
l.Debugf("error getting instance from processor: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"})

View file

@ -116,7 +116,7 @@ func (m *Module) InstanceUpdatePATCHHandler(c *gin.Context) {
return
}
i, errWithCode := m.processor.InstancePatch(form)
i, errWithCode := m.processor.InstancePatch(c.Request.Context(), form)
if errWithCode != nil {
l.Debugf("error with instance patch request: %s", errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})

View file

@ -19,14 +19,11 @@
package media
import (
"fmt"
"net/http"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/processing"
"github.com/superseriousbusiness/gotosocial/internal/router"
)
@ -63,17 +60,3 @@ func (m *Module) Route(s router.Router) error {
s.AttachHandler(http.MethodPut, BasePathWithID, m.MediaPUTHandler)
return nil
}
// CreateTables populates necessary tables in the given DB
func (m *Module) CreateTables(db db.DB) error {
models := []interface{}{
&gtsmodel.MediaAttachment{},
}
for _, m := range models {
if err := db.CreateTable(m); err != nil {
return fmt.Errorf("error creating table: %s", err)
}
}
return nil
}

View file

@ -108,7 +108,7 @@ func (m *Module) MediaCreatePOSTHandler(c *gin.Context) {
}
l.Debug("calling processor media create func")
mastoAttachment, err := m.processor.MediaCreate(authed, form)
mastoAttachment, err := m.processor.MediaCreate(c.Request.Context(), authed, form)
if err != nil {
l.Debugf("error creating attachment: %s", err)
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})

View file

@ -121,7 +121,7 @@ func (suite *MediaCreateTestSuite) TestStatusCreatePOSTImageHandlerSuccessful()
// set up the context for the request
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])

View file

@ -75,7 +75,7 @@ func (m *Module) MediaGETHandler(c *gin.Context) {
return
}
attachment, errWithCode := m.processor.MediaGet(authed, attachmentID)
attachment, errWithCode := m.processor.MediaGet(c.Request.Context(), authed, attachmentID)
if errWithCode != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return

View file

@ -122,7 +122,7 @@ func (m *Module) MediaPUTHandler(c *gin.Context) {
return
}
attachment, errWithCode := m.processor.MediaUpdate(authed, attachmentID, &form)
attachment, errWithCode := m.processor.MediaUpdate(c.Request.Context(), authed, attachmentID, &form)
if errWithCode != nil {
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})
return

View file

@ -68,7 +68,7 @@ func (m *Module) NotificationsGETHandler(c *gin.Context) {
sinceID = sinceIDString
}
notifs, errWithCode := m.processor.NotificationsGet(authed, limit, maxID, sinceID)
notifs, errWithCode := m.processor.NotificationsGet(c.Request.Context(), authed, limit, maxID, sinceID)
if errWithCode != nil {
l.Debugf("error processing notifications get: %s", errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})

View file

@ -164,7 +164,7 @@ func (m *Module) SearchGETHandler(c *gin.Context) {
Following: following,
}
results, errWithCode := m.processor.SearchGet(authed, searchQuery)
results, errWithCode := m.processor.SearchGet(c.Request.Context(), authed, searchQuery)
if errWithCode != nil {
l.Debugf("error searching: %s", errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})

View file

@ -87,7 +87,7 @@ func (m *Module) StatusBoostPOSTHandler(c *gin.Context) {
return
}
mastoStatus, errWithCode := m.processor.StatusBoost(authed, targetStatusID)
mastoStatus, errWithCode := m.processor.StatusBoost(c.Request.Context(), authed, targetStatusID)
if errWithCode != nil {
l.Debugf("error processing status boost: %s", errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})

View file

@ -67,7 +67,7 @@ func (suite *StatusBoostTestSuite) TearDownTest() {
func (suite *StatusBoostTestSuite) TestPostBoost() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["admin_account_status_1"]
@ -133,7 +133,7 @@ func (suite *StatusBoostTestSuite) TestPostBoost() {
func (suite *StatusBoostTestSuite) TestPostUnboostable() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["local_account_2_status_4"]
@ -171,7 +171,7 @@ func (suite *StatusBoostTestSuite) TestPostUnboostable() {
func (suite *StatusBoostTestSuite) TestPostNotVisible() {
t := suite.testTokens["local_account_2"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["local_account_1_status_3"] // this is a mutual only status and these accounts aren't mutuals

View file

@ -84,7 +84,7 @@ func (m *Module) StatusBoostedByGETHandler(c *gin.Context) {
return
}
mastoAccounts, err := m.processor.StatusBoostedBy(authed, targetStatusID)
mastoAccounts, err := m.processor.StatusBoostedBy(c.Request.Context(), authed, targetStatusID)
if err != nil {
l.Debugf("error processing status boosted by request: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"})

View file

@ -86,7 +86,7 @@ func (m *Module) StatusContextGETHandler(c *gin.Context) {
return
}
statusContext, errWithCode := m.processor.StatusGetContext(authed, targetStatusID)
statusContext, errWithCode := m.processor.StatusGetContext(c.Request.Context(), authed, targetStatusID)
if errWithCode != nil {
l.Debugf("error getting status context: %s", errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})

View file

@ -101,7 +101,7 @@ func (m *Module) StatusCreatePOSTHandler(c *gin.Context) {
return
}
mastoStatus, err := m.processor.StatusCreate(authed, form)
mastoStatus, err := m.processor.StatusCreate(c.Request.Context(), authed, form)
if err != nil {
l.Debugf("error processing status create: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"})

View file

@ -19,6 +19,7 @@
package status_test
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
@ -82,7 +83,7 @@ https://docs.gotosocial.org/en/latest/user_guide/posts/#links
func (suite *StatusCreateTestSuite) TestPostNewStatus() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@ -128,7 +129,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatus() {
}, statusReply.Tags[0])
gtsTag := &gtsmodel.Tag{}
err = suite.db.GetWhere([]db.Where{{Key: "name", Value: "helloworld"}}, gtsTag)
err = suite.db.GetWhere(context.Background(), []db.Where{{Key: "name", Value: "helloworld"}}, gtsTag)
assert.NoError(suite.T(), err)
assert.Equal(suite.T(), statusReply.Account.ID, gtsTag.FirstSeenFromAccountID)
}
@ -136,7 +137,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatus() {
func (suite *StatusCreateTestSuite) TestPostAnotherNewStatus() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@ -171,7 +172,7 @@ func (suite *StatusCreateTestSuite) TestPostAnotherNewStatus() {
func (suite *StatusCreateTestSuite) TestPostNewStatusWithEmoji() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@ -212,7 +213,7 @@ func (suite *StatusCreateTestSuite) TestPostNewStatusWithEmoji() {
// Try to reply to a status that doesn't exist
func (suite *StatusCreateTestSuite) TestReplyToNonexistentStatus() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@ -243,7 +244,7 @@ func (suite *StatusCreateTestSuite) TestReplyToNonexistentStatus() {
// Post a reply to the status of a local user that allows replies.
func (suite *StatusCreateTestSuite) TestReplyToLocalStatus() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
// setup
recorder := httptest.NewRecorder()
@ -283,7 +284,7 @@ func (suite *StatusCreateTestSuite) TestReplyToLocalStatus() {
// Take a media file which is currently not associated with a status, and attach it to a new status.
func (suite *StatusCreateTestSuite) TestAttachNewMediaSuccess() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
attachment := suite.testAttachments["local_account_1_unattached_1"]
@ -322,12 +323,11 @@ func (suite *StatusCreateTestSuite) TestAttachNewMediaSuccess() {
assert.Len(suite.T(), statusResponse.MediaAttachments, 1)
// get the updated media attachment from the database
gtsAttachment := &gtsmodel.MediaAttachment{}
err = suite.db.GetByID(statusResponse.MediaAttachments[0].ID, gtsAttachment)
gtsAttachment, err := suite.db.GetAttachmentByID(context.Background(), statusResponse.MediaAttachments[0].ID)
assert.NoError(suite.T(), err)
// convert it to a masto attachment
gtsAttachmentAsMasto, err := suite.tc.AttachmentToMasto(gtsAttachment)
gtsAttachmentAsMasto, err := suite.tc.AttachmentToMasto(context.Background(), gtsAttachment)
assert.NoError(suite.T(), err)
// compare it with what we have now

View file

@ -86,7 +86,7 @@ func (m *Module) StatusDELETEHandler(c *gin.Context) {
return
}
mastoStatus, err := m.processor.StatusDelete(authed, targetStatusID)
mastoStatus, err := m.processor.StatusDelete(c.Request.Context(), authed, targetStatusID)
if err != nil {
l.Debugf("error processing status delete: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"})

View file

@ -83,7 +83,7 @@ func (m *Module) StatusFavePOSTHandler(c *gin.Context) {
return
}
mastoStatus, err := m.processor.StatusFave(authed, targetStatusID)
mastoStatus, err := m.processor.StatusFave(c.Request.Context(), authed, targetStatusID)
if err != nil {
l.Debugf("error processing status fave: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"})

View file

@ -71,7 +71,7 @@ func (suite *StatusFaveTestSuite) TearDownTest() {
func (suite *StatusFaveTestSuite) TestPostFave() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["admin_account_status_2"]
@ -119,7 +119,7 @@ func (suite *StatusFaveTestSuite) TestPostFave() {
func (suite *StatusFaveTestSuite) TestPostUnfaveable() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["local_account_2_status_3"] // this one is unlikeable and unreplyable

View file

@ -84,7 +84,7 @@ func (m *Module) StatusFavedByGETHandler(c *gin.Context) {
return
}
mastoAccounts, err := m.processor.StatusFavedBy(authed, targetStatusID)
mastoAccounts, err := m.processor.StatusFavedBy(c.Request.Context(), authed, targetStatusID)
if err != nil {
l.Debugf("error processing status faved by request: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"})

View file

@ -69,7 +69,7 @@ func (suite *StatusFavedByTestSuite) TearDownTest() {
func (suite *StatusFavedByTestSuite) TestGetFavedBy() {
t := suite.testTokens["local_account_2"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
targetStatus := suite.testStatuses["admin_account_status_1"] // this status is faved by local_account_1

View file

@ -83,7 +83,7 @@ func (m *Module) StatusGETHandler(c *gin.Context) {
return
}
mastoStatus, err := m.processor.StatusGet(authed, targetStatusID)
mastoStatus, err := m.processor.StatusGet(c.Request.Context(), authed, targetStatusID)
if err != nil {
l.Debugf("error processing status get: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"})

View file

@ -84,7 +84,7 @@ func (m *Module) StatusUnboostPOSTHandler(c *gin.Context) {
return
}
mastoStatus, errWithCode := m.processor.StatusUnboost(authed, targetStatusID)
mastoStatus, errWithCode := m.processor.StatusUnboost(c.Request.Context(), authed, targetStatusID)
if errWithCode != nil {
l.Debugf("error processing status unboost: %s", errWithCode.Error())
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})

View file

@ -83,7 +83,7 @@ func (m *Module) StatusUnfavePOSTHandler(c *gin.Context) {
return
}
mastoStatus, err := m.processor.StatusUnfave(authed, targetStatusID)
mastoStatus, err := m.processor.StatusUnfave(c.Request.Context(), authed, targetStatusID)
if err != nil {
l.Debugf("error processing status unfave: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": "bad request"})

View file

@ -71,7 +71,7 @@ func (suite *StatusUnfaveTestSuite) TearDownTest() {
func (suite *StatusUnfaveTestSuite) TestPostUnfave() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
// this is the status we wanna unfave: in the testrig it's already faved by this account
targetStatus := suite.testStatuses["admin_account_status_1"]
@ -120,7 +120,7 @@ func (suite *StatusUnfaveTestSuite) TestPostUnfave() {
func (suite *StatusUnfaveTestSuite) TestPostAlreadyNotFaved() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.TokenToOauthToken(t)
oauthToken := oauth.DBTokenToToken(t)
// this is the status we wanna unfave: in the testrig it's not faved by this account
targetStatus := suite.testStatuses["admin_account_status_2"]

View file

@ -122,7 +122,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
}
// make sure a valid token has been provided and obtain the associated account
account, err := m.processor.AuthorizeStreamingRequest(accessToken)
account, err := m.processor.AuthorizeStreamingRequest(c.Request.Context(), accessToken)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "could not authorize with given token"})
return
@ -147,7 +147,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
defer conn.Close() // whatever happens, when we leave this function we want to close the websocket connection
// inform the processor that we have a new connection and want a stream for it
stream, errWithCode := m.processor.OpenStreamForAccount(account, streamType)
stream, errWithCode := m.processor.OpenStreamForAccount(c.Request.Context(), account, streamType)
if errWithCode != nil {
c.JSON(errWithCode.Code(), errWithCode.Safe())
return

View file

@ -153,7 +153,7 @@ func (m *Module) HomeTimelineGETHandler(c *gin.Context) {
local = i
}
resp, errWithCode := m.processor.HomeTimelineGet(authed, maxID, sinceID, minID, limit, local)
resp, errWithCode := m.processor.HomeTimelineGet(c.Request.Context(), authed, maxID, sinceID, minID, limit, local)
if errWithCode != nil {
l.Debugf("error from processor HomeTimelineGet: %s", errWithCode)
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})

View file

@ -153,7 +153,7 @@ func (m *Module) PublicTimelineGETHandler(c *gin.Context) {
local = i
}
resp, errWithCode := m.processor.PublicTimelineGet(authed, maxID, sinceID, minID, limit, local)
resp, errWithCode := m.processor.PublicTimelineGet(c.Request.Context(), authed, maxID, sinceID, minID, limit, local)
if errWithCode != nil {
l.Debugf("error from processor PublicTimelineGet: %s", errWithCode)
c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()})