mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-11-24 05:43:32 -06:00
Pg to bun (#148)
* start moving to bun * changing more stuff * more * and yet more * tests passing * seems stable now * more big changes * small fix * little fixes
This commit is contained in:
parent
071eca20ce
commit
2dc9fc1626
713 changed files with 98694 additions and 22704 deletions
|
|
@ -28,7 +28,7 @@ import (
|
|||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/pg"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
|
@ -104,7 +104,7 @@ func (suite *AuthTestSuite) SetupTest() {
|
|||
|
||||
log := logrus.New()
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
db, err := pg.NewPostgresService(context.Background(), suite.config, log)
|
||||
db, err := bundb.NewBunDBService(context.Background(), suite.config, log)
|
||||
if err != nil {
|
||||
logrus.Panicf("error creating database connection: %s", err)
|
||||
}
|
||||
|
|
@ -120,23 +120,23 @@ func (suite *AuthTestSuite) SetupTest() {
|
|||
}
|
||||
|
||||
for _, m := range models {
|
||||
if err := suite.db.CreateTable(m); err != nil {
|
||||
if err := suite.db.CreateTable(context.Background(), m); err != nil {
|
||||
logrus.Panicf("db connection error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
suite.oauthServer = oauth.New(suite.db, log)
|
||||
|
||||
if err := suite.db.Put(suite.testAccount); err != nil {
|
||||
if err := suite.db.Put(context.Background(), suite.testAccount); err != nil {
|
||||
logrus.Panicf("could not insert test account into db: %s", err)
|
||||
}
|
||||
if err := suite.db.Put(suite.testUser); err != nil {
|
||||
if err := suite.db.Put(context.Background(), suite.testUser); err != nil {
|
||||
logrus.Panicf("could not insert test user into db: %s", err)
|
||||
}
|
||||
if err := suite.db.Put(suite.testClient); err != nil {
|
||||
if err := suite.db.Put(context.Background(), suite.testClient); err != nil {
|
||||
logrus.Panicf("could not insert test client into db: %s", err)
|
||||
}
|
||||
if err := suite.db.Put(suite.testApplication); err != nil {
|
||||
if err := suite.db.Put(context.Background(), suite.testApplication); err != nil {
|
||||
logrus.Panicf("could not insert test application into db: %s", err)
|
||||
}
|
||||
|
||||
|
|
@ -152,7 +152,7 @@ func (suite *AuthTestSuite) TearDownTest() {
|
|||
>smodel.Application{},
|
||||
}
|
||||
for _, m := range models {
|
||||
if err := suite.db.DropTable(m); err != nil {
|
||||
if err := suite.db.DropTable(context.Background(), m); err != nil {
|
||||
logrus.Panicf("error dropping table: %s", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -70,30 +70,23 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
|
|||
c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"})
|
||||
return
|
||||
}
|
||||
app := >smodel.Application{
|
||||
ClientID: clientID,
|
||||
}
|
||||
if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
|
||||
app := >smodel.Application{}
|
||||
if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
|
||||
m.clearSession(s)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)})
|
||||
return
|
||||
}
|
||||
|
||||
// we can also use the userid of the user to fetch their username from the db to greet them nicely <3
|
||||
user := >smodel.User{
|
||||
ID: userID,
|
||||
}
|
||||
if err := m.db.GetByID(user.ID, user); err != nil {
|
||||
user := >smodel.User{}
|
||||
if err := m.db.GetByID(c.Request.Context(), user.ID, user); err != nil {
|
||||
m.clearSession(s)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
acct := >smodel.Account{
|
||||
ID: user.AccountID,
|
||||
}
|
||||
|
||||
if err := m.db.GetByID(acct.ID, acct); err != nil {
|
||||
acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
|
||||
if err != nil {
|
||||
m.clearSession(s)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
|
@ -80,13 +81,13 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
|
|||
app := >smodel.Application{
|
||||
ClientID: clientID,
|
||||
}
|
||||
if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
|
||||
if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
|
||||
m.clearSession(s)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := m.parseUserFromClaims(claims, net.IP(c.ClientIP()), app.ID)
|
||||
user, err := m.parseUserFromClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID)
|
||||
if err != nil {
|
||||
m.clearSession(s)
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
|
|
@ -103,14 +104,14 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
|
|||
c.Redirect(http.StatusFound, OauthAuthorizePath)
|
||||
}
|
||||
|
||||
func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) {
|
||||
func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) {
|
||||
if claims.Email == "" {
|
||||
return nil, errors.New("no email returned in claims")
|
||||
}
|
||||
|
||||
// see if we already have a user for this email address
|
||||
user := >smodel.User{}
|
||||
err := m.db.GetWhere([]db.Where{{Key: "email", Value: claims.Email}}, user)
|
||||
err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user)
|
||||
if err == nil {
|
||||
// we do! so we can just return it
|
||||
return user, nil
|
||||
|
|
@ -122,7 +123,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
|
|||
}
|
||||
|
||||
// maybe we have an unconfirmed user
|
||||
err = m.db.GetWhere([]db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user)
|
||||
err = m.db.GetWhere(ctx, []db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user)
|
||||
if err == nil {
|
||||
// user is unconfirmed so return an error
|
||||
return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email)
|
||||
|
|
@ -137,9 +138,13 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
|
|||
// however, because we trust the OIDC provider, we should now create a user + account with the provided claims
|
||||
|
||||
// check if the email address is available for use; if it's not there's nothing we can so
|
||||
if err := m.db.IsEmailAvailable(claims.Email); err != nil {
|
||||
emailAvailable, err := m.db.IsEmailAvailable(ctx, claims.Email)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("email %s not available: %s", claims.Email, err)
|
||||
}
|
||||
if !emailAvailable {
|
||||
return nil, fmt.Errorf("email %s in use", claims.Email)
|
||||
}
|
||||
|
||||
// now we need a username
|
||||
var username string
|
||||
|
|
@ -180,12 +185,11 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
|
|||
// note that for the first iteration, iString is still "" when the check is made, so our first choice
|
||||
// is still the raw username with no integer stuck on the end
|
||||
for i := 1; !found; i = i + 1 {
|
||||
if err := m.db.IsUsernameAvailable(username + iString); err != nil {
|
||||
if strings.Contains(err.Error(), "db error") {
|
||||
// if there's an actual db error we should return
|
||||
return nil, fmt.Errorf("error checking username availability: %s", err)
|
||||
}
|
||||
} else {
|
||||
usernameAvailable, err := m.db.IsUsernameAvailable(ctx, username+iString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if usernameAvailable {
|
||||
// no error so we've found a username that works
|
||||
found = true
|
||||
username = username + iString
|
||||
|
|
@ -209,7 +213,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
|
|||
password := uuid.NewString() + uuid.NewString()
|
||||
|
||||
// create the user! this will also create an account and store it in the database so we don't need to do that here
|
||||
user, err = m.db.NewSignup(username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin)
|
||||
user, err = m.db.NewSignup(ctx, username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating user: %s", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,15 +49,15 @@ func (m *Module) OauthTokenMiddleware(c *gin.Context) {
|
|||
|
||||
// fetch user's and account for this user id
|
||||
user := >smodel.User{}
|
||||
if err := m.db.GetByID(uid, user); err != nil || user == nil {
|
||||
if err := m.db.GetByID(c.Request.Context(), uid, user); err != nil || user == nil {
|
||||
l.Warnf("no user found for validated uid %s", uid)
|
||||
return
|
||||
}
|
||||
c.Set(oauth.SessionAuthorizedUser, user)
|
||||
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user)
|
||||
|
||||
acct := >smodel.Account{}
|
||||
if err := m.db.GetByID(user.AccountID, acct); err != nil || acct == nil {
|
||||
acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID)
|
||||
if err != nil || acct == nil {
|
||||
l.Warnf("no account found for validated user %s", uid)
|
||||
return
|
||||
}
|
||||
|
|
@ -69,7 +69,7 @@ func (m *Module) OauthTokenMiddleware(c *gin.Context) {
|
|||
if cid := ti.GetClientID(); cid != "" {
|
||||
l.Tracef("authenticated client %s with bearer token, scope is %s", cid, ti.GetScope())
|
||||
app := >smodel.Application{}
|
||||
if err := m.db.GetWhere([]db.Where{{Key: "client_id", Value: cid}}, app); err != nil {
|
||||
if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: "client_id", Value: cid}}, app); err != nil {
|
||||
l.Tracef("no app found for client %s", cid)
|
||||
}
|
||||
c.Set(oauth.SessionAuthorizedApplication, app)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
|
|
@ -74,7 +75,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) {
|
|||
}
|
||||
l.Tracef("parsed form: %+v", form)
|
||||
|
||||
userid, err := m.ValidatePassword(form.Email, form.Password)
|
||||
userid, err := m.ValidatePassword(c.Request.Context(), form.Email, form.Password)
|
||||
if err != nil {
|
||||
c.String(http.StatusForbidden, err.Error())
|
||||
m.clearSession(s)
|
||||
|
|
@ -96,7 +97,7 @@ func (m *Module) SignInPOSTHandler(c *gin.Context) {
|
|||
// The goal is to authenticate the password against the one for that email
|
||||
// address stored in the database. If OK, we return the userid (a ulid) for that user,
|
||||
// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
|
||||
func (m *Module) ValidatePassword(email string, password string) (userid string, err error) {
|
||||
func (m *Module) ValidatePassword(ctx context.Context, email string, password string) (userid string, err error) {
|
||||
l := m.log.WithField("func", "ValidatePassword")
|
||||
|
||||
// make sure an email/password was provided and bail if not
|
||||
|
|
@ -108,7 +109,7 @@ func (m *Module) ValidatePassword(email string, password string) (userid string,
|
|||
// first we select the user from the database based on email address, bail if no user found for that email
|
||||
gtsUser := >smodel.User{}
|
||||
|
||||
if err := m.db.GetWhere([]db.Where{{Key: "email", Value: email}}, gtsUser); err != nil {
|
||||
if err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: email}}, gtsUser); err != nil {
|
||||
l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
|
||||
return incorrectPassword()
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue