mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 00:12:25 -05:00 
			
		
		
		
	[feature] add authorization to the already-existing authentication (#365)
* add ensureUserIsAuthorizedOrRedirect to /oauth/authorize * adding authorization (email confirm, account approve, etc) to TokenCheck * revert un-needed changes to signin.go * oops what happened here * error css * add account.SuspendedAt check * remove redundant checks from oauth util Authed function * wip tests * tests passing * stop stripping useful information from ErrAlreadyExists * that feeling of scraping the dryer LINT off the screen * oops I didn't mean to get rid of this NewTestRouter function * make tests work with recorder * re-add ConfigureTemplatesWithGin to handle template path err Co-authored-by: tsmethurst <tobi.smethurst@protonmail.com>
This commit is contained in:
		
					parent
					
						
							
								5c9d20cea3
							
						
					
				
			
			
				commit
				
					
						6ed368cbeb
					
				
			
		
					 19 changed files with 424 additions and 47 deletions
				
			
		|  | @ -32,10 +32,23 @@ import ( | ||||||
| const ( | const ( | ||||||
| 	// AuthSignInPath is the API path for users to sign in through | 	// AuthSignInPath is the API path for users to sign in through | ||||||
| 	AuthSignInPath = "/auth/sign_in" | 	AuthSignInPath = "/auth/sign_in" | ||||||
|  | 
 | ||||||
|  | 	// CheckYourEmailPath users land here after registering a new account, instructs them to confirm thier email | ||||||
|  | 	CheckYourEmailPath = "/check_your_email" | ||||||
|  | 
 | ||||||
|  | 	// WaitForApprovalPath users land here after confirming thier email but before an admin approves thier account | ||||||
|  | 	// (if such is required) | ||||||
|  | 	WaitForApprovalPath = "/wait_for_approval" | ||||||
|  | 
 | ||||||
|  | 	// AccountDisabledPath users land here when thier account is suspended by an admin | ||||||
|  | 	AccountDisabledPath = "/account_disabled" | ||||||
|  | 
 | ||||||
| 	// OauthTokenPath is the API path to use for granting token requests to users with valid credentials | 	// OauthTokenPath is the API path to use for granting token requests to users with valid credentials | ||||||
| 	OauthTokenPath = "/oauth/token" | 	OauthTokenPath = "/oauth/token" | ||||||
|  | 
 | ||||||
| 	// OauthAuthorizePath is the API path for authorization requests (eg., authorize this app to act on my behalf as a user) | 	// OauthAuthorizePath is the API path for authorization requests (eg., authorize this app to act on my behalf as a user) | ||||||
| 	OauthAuthorizePath = "/oauth/authorize" | 	OauthAuthorizePath = "/oauth/authorize" | ||||||
|  | 
 | ||||||
| 	// CallbackPath is the API path for receiving callback tokens from external OIDC providers | 	// CallbackPath is the API path for receiving callback tokens from external OIDC providers | ||||||
| 	CallbackPath = oidc.CallbackPath | 	CallbackPath = oidc.CallbackPath | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -18,4 +18,96 @@ | ||||||
| 
 | 
 | ||||||
| package auth_test | package auth_test | ||||||
| 
 | 
 | ||||||
| // TODO | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 
 | ||||||
|  | 	"github.com/gin-contrib/sessions" | ||||||
|  | 	"github.com/gin-contrib/sessions/memstore" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/spf13/viper" | ||||||
|  | 	"github.com/stretchr/testify/suite" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/api/client/auth" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/oidc" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/router" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/testrig" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type AuthStandardTestSuite struct { | ||||||
|  | 	suite.Suite | ||||||
|  | 	db          db.DB | ||||||
|  | 	idp         oidc.IDP | ||||||
|  | 	oauthServer oauth.Server | ||||||
|  | 
 | ||||||
|  | 	// standard suite models | ||||||
|  | 	testTokens       map[string]*gtsmodel.Token | ||||||
|  | 	testClients      map[string]*gtsmodel.Client | ||||||
|  | 	testApplications map[string]*gtsmodel.Application | ||||||
|  | 	testUsers        map[string]*gtsmodel.User | ||||||
|  | 	testAccounts     map[string]*gtsmodel.Account | ||||||
|  | 
 | ||||||
|  | 	// module being tested | ||||||
|  | 	authModule *auth.Module | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	sessionUserID   = "userid" | ||||||
|  | 	sessionClientID = "client_id" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func (suite *AuthStandardTestSuite) SetupSuite() { | ||||||
|  | 	suite.testTokens = testrig.NewTestTokens() | ||||||
|  | 	suite.testClients = testrig.NewTestClients() | ||||||
|  | 	suite.testApplications = testrig.NewTestApplications() | ||||||
|  | 	suite.testUsers = testrig.NewTestUsers() | ||||||
|  | 	suite.testAccounts = testrig.NewTestAccounts() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *AuthStandardTestSuite) SetupTest() { | ||||||
|  | 	testrig.InitTestConfig() | ||||||
|  | 	suite.db = testrig.NewTestDB() | ||||||
|  | 	testrig.InitTestLog() | ||||||
|  | 
 | ||||||
|  | 	suite.oauthServer = testrig.NewTestOauthServer(suite.db) | ||||||
|  | 	var err error | ||||||
|  | 	suite.idp, err = oidc.NewIDP(context.Background()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  | 	suite.authModule = auth.New(suite.db, suite.oauthServer, suite.idp).(*auth.Module) | ||||||
|  | 	testrig.StandardDBSetup(suite.db, nil) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *AuthStandardTestSuite) TearDownTest() { | ||||||
|  | 	testrig.StandardDBTeardown(suite.db) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *AuthStandardTestSuite) newContext(requestMethod string, requestPath string) (*gin.Context, *httptest.ResponseRecorder) { | ||||||
|  | 	// create the recorder and gin test context | ||||||
|  | 	recorder := httptest.NewRecorder() | ||||||
|  | 	ctx, engine := gin.CreateTestContext(recorder) | ||||||
|  | 
 | ||||||
|  | 	// load templates into the engine | ||||||
|  | 	testrig.ConfigureTemplatesWithGin(engine) | ||||||
|  | 
 | ||||||
|  | 	// create the request | ||||||
|  | 	protocol := viper.GetString(config.Keys.Protocol) | ||||||
|  | 	host := viper.GetString(config.Keys.Host) | ||||||
|  | 	baseURI := fmt.Sprintf("%s://%s", protocol, host) | ||||||
|  | 	requestURI := fmt.Sprintf("%s/%s", baseURI, requestPath) | ||||||
|  | 	ctx.Request = httptest.NewRequest(requestMethod, requestURI, nil) // the endpoint we're hitting | ||||||
|  | 	ctx.Request.Header.Set("accept", "text/html") | ||||||
|  | 
 | ||||||
|  | 	// trigger the session middleware on the context | ||||||
|  | 	store := memstore.NewStore(make([]byte, 32), make([]byte, 32)) | ||||||
|  | 	store.Options(router.SessionOptions()) | ||||||
|  | 	sessionMiddleware := sessions.Sessions("gotosocial-localhost", store) | ||||||
|  | 	sessionMiddleware(ctx) | ||||||
|  | 
 | ||||||
|  | 	return ctx, recorder | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -44,7 +44,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { | ||||||
| 	s := sessions.Default(c) | 	s := sessions.Default(c) | ||||||
| 
 | 
 | ||||||
| 	if _, err := api.NegotiateAccept(c, api.HTMLAcceptHeaders...); err != nil { | 	if _, err := api.NegotiateAccept(c, api.HTMLAcceptHeaders...); err != nil { | ||||||
| 		c.JSON(http.StatusNotAcceptable, gin.H{"error": err.Error()}) | 		c.HTML(http.StatusNotAcceptable, "error.tmpl", gin.H{"error": err.Error()}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -57,7 +57,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { | ||||||
| 		if err := c.Bind(form); err != nil { | 		if err := c.Bind(form); err != nil { | ||||||
| 			l.Debugf("invalid auth form: %s", err) | 			l.Debugf("invalid auth form: %s", err) | ||||||
| 			m.clearSession(s) | 			m.clearSession(s) | ||||||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | 			c.HTML(http.StatusBadRequest, "error.tmpl", gin.H{"error": err.Error()}) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		l.Debugf("parsed auth form: %+v", form) | 		l.Debugf("parsed auth form: %+v", form) | ||||||
|  | @ -65,7 +65,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { | ||||||
| 		if err := extractAuthForm(s, form); err != nil { | 		if err := extractAuthForm(s, form); err != nil { | ||||||
| 			l.Debugf(fmt.Sprintf("error parsing form at /oauth/authorize: %s", err)) | 			l.Debugf(fmt.Sprintf("error parsing form at /oauth/authorize: %s", err)) | ||||||
| 			m.clearSession(s) | 			m.clearSession(s) | ||||||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | 			c.HTML(http.StatusBadRequest, "error.tmpl", gin.H{"error": err.Error()}) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		c.Redirect(http.StatusSeeOther, AuthSignInPath) | 		c.Redirect(http.StatusSeeOther, AuthSignInPath) | ||||||
|  | @ -75,28 +75,33 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { | ||||||
| 	// We can use the client_id on the session to retrieve info about the app associated with the client_id | 	// We can use the client_id on the session to retrieve info about the app associated with the client_id | ||||||
| 	clientID, ok := s.Get(sessionClientID).(string) | 	clientID, ok := s.Get(sessionClientID).(string) | ||||||
| 	if !ok || clientID == "" { | 	if !ok || clientID == "" { | ||||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"}) | 		c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": "no client_id found in session"}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	app := >smodel.Application{} | 	app := >smodel.Application{} | ||||||
| 	if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: clientID}}, app); err != nil { | 	if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: clientID}}, app); err != nil { | ||||||
| 		m.clearSession(s) | 		m.clearSession(s) | ||||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)}) | 		c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{ | ||||||
|  | 			"error": fmt.Sprintf("no application found for client id %s", clientID), | ||||||
|  | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// we can also use the userid of the user to fetch their username from the db to greet them nicely <3 | 	// redirect the user if they have not confirmed their email yet, thier account has not been approved yet, | ||||||
|  | 	// or thier account has been disabled. | ||||||
| 	user := >smodel.User{} | 	user := >smodel.User{} | ||||||
| 	if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil { | 	if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil { | ||||||
| 		m.clearSession(s) | 		m.clearSession(s) | ||||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | 		c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": err.Error()}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 |  | ||||||
| 	acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) | 	acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		m.clearSession(s) | 		m.clearSession(s) | ||||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | 		c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": err.Error()}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if !ensureUserIsAuthorizedOrRedirect(c, user, acct) { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -104,13 +109,13 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { | ||||||
| 	redirect, ok := s.Get(sessionRedirectURI).(string) | 	redirect, ok := s.Get(sessionRedirectURI).(string) | ||||||
| 	if !ok || redirect == "" { | 	if !ok || redirect == "" { | ||||||
| 		m.clearSession(s) | 		m.clearSession(s) | ||||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": "no redirect_uri found in session"}) | 		c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": "no redirect_uri found in session"}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	scope, ok := s.Get(sessionScope).(string) | 	scope, ok := s.Get(sessionScope).(string) | ||||||
| 	if !ok || scope == "" { | 	if !ok || scope == "" { | ||||||
| 		m.clearSession(s) | 		m.clearSession(s) | ||||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": "no scope found in session"}) | 		c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": "no scope found in session"}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -170,10 +175,28 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) { | ||||||
| 		errs = append(errs, "session missing userid") | 		errs = append(errs, "session missing userid") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// redirect the user if they have not confirmed their email yet, thier account has not been approved yet, | ||||||
|  | 	// or thier account has been disabled. | ||||||
|  | 	user := >smodel.User{} | ||||||
|  | 	if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil { | ||||||
|  | 		m.clearSession(s) | ||||||
|  | 		c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": err.Error()}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	acct, err := m.db.GetAccountByID(c.Request.Context(), user.AccountID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		m.clearSession(s) | ||||||
|  | 		c.HTML(http.StatusInternalServerError, "error.tmpl", gin.H{"error": err.Error()}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if !ensureUserIsAuthorizedOrRedirect(c, user, acct) { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	m.clearSession(s) | 	m.clearSession(s) | ||||||
| 
 | 
 | ||||||
| 	if len(errs) != 0 { | 	if len(errs) != 0 { | ||||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": strings.Join(errs, ": ")}) | 		c.HTML(http.StatusBadRequest, "error.tmpl", gin.H{"error": strings.Join(errs, ": ")}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -190,7 +213,7 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) { | ||||||
| 
 | 
 | ||||||
| 	// and proceed with authorization using the oauth2 library | 	// and proceed with authorization using the oauth2 library | ||||||
| 	if err := m.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil { | 	if err := m.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil { | ||||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | 		c.HTML(http.StatusBadRequest, "error.tmpl", gin.H{"error": err.Error()}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -216,3 +239,27 @@ func extractAuthForm(s sessions.Session, form *model.OAuthAuthorize) error { | ||||||
| 	s.Set(sessionState, uuid.NewString()) | 	s.Set(sessionState, uuid.NewString()) | ||||||
| 	return s.Save() | 	return s.Save() | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func ensureUserIsAuthorizedOrRedirect(ctx *gin.Context, user *gtsmodel.User, account *gtsmodel.Account) bool { | ||||||
|  | 	if user.ConfirmedAt.IsZero() { | ||||||
|  | 		ctx.Redirect(http.StatusSeeOther, CheckYourEmailPath) | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !user.Approved { | ||||||
|  | 		ctx.Redirect(http.StatusSeeOther, WaitForApprovalPath) | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if user.Disabled { | ||||||
|  | 		ctx.Redirect(http.StatusSeeOther, AccountDisabledPath) | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !account.SuspendedAt.IsZero() { | ||||||
|  | 		ctx.Redirect(http.StatusSeeOther, AccountDisabledPath) | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return true | ||||||
|  | } | ||||||
|  |  | ||||||
							
								
								
									
										113
									
								
								internal/api/client/auth/authorize_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										113
									
								
								internal/api/client/auth/authorize_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,113 @@ | ||||||
|  | package auth_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/http" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"codeberg.org/gruf/go-errors" | ||||||
|  | 	"github.com/gin-contrib/sessions" | ||||||
|  | 	"github.com/stretchr/testify/suite" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/api/client/auth" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type AuthAuthorizeTestSuite struct { | ||||||
|  | 	AuthStandardTestSuite | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type authorizeHandlerTestCase struct { | ||||||
|  | 	description            string | ||||||
|  | 	mutateUserAccount      func(*gtsmodel.User, *gtsmodel.Account) | ||||||
|  | 	expectedStatusCode     int | ||||||
|  | 	expectedLocationHeader string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() { | ||||||
|  | 
 | ||||||
|  | 	var tests = []authorizeHandlerTestCase{ | ||||||
|  | 		{ | ||||||
|  | 			description: "user has their email unconfirmed", | ||||||
|  | 			mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) { | ||||||
|  | 				// nothing to do, weed_lord420 already has their email unconfirmed | ||||||
|  | 			}, | ||||||
|  | 			expectedStatusCode:     http.StatusSeeOther, | ||||||
|  | 			expectedLocationHeader: auth.CheckYourEmailPath, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			description: "user has their email confirmed but is not approved", | ||||||
|  | 			mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) { | ||||||
|  | 				user.ConfirmedAt = time.Now() | ||||||
|  | 				user.Email = user.UnconfirmedEmail | ||||||
|  | 			}, | ||||||
|  | 			expectedStatusCode:     http.StatusSeeOther, | ||||||
|  | 			expectedLocationHeader: auth.WaitForApprovalPath, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			description: "user has their email confirmed and is approved, but User entity has been disabled", | ||||||
|  | 			mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) { | ||||||
|  | 				user.ConfirmedAt = time.Now() | ||||||
|  | 				user.Email = user.UnconfirmedEmail | ||||||
|  | 				user.Approved = true | ||||||
|  | 				user.Disabled = true | ||||||
|  | 			}, | ||||||
|  | 			expectedStatusCode:     http.StatusSeeOther, | ||||||
|  | 			expectedLocationHeader: auth.AccountDisabledPath, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			description: "user has their email confirmed and is approved, but Account entity has been suspended", | ||||||
|  | 			mutateUserAccount: func(user *gtsmodel.User, account *gtsmodel.Account) { | ||||||
|  | 				user.ConfirmedAt = time.Now() | ||||||
|  | 				user.Email = user.UnconfirmedEmail | ||||||
|  | 				user.Approved = true | ||||||
|  | 				user.Disabled = false | ||||||
|  | 				account.SuspendedAt = time.Now() | ||||||
|  | 			}, | ||||||
|  | 			expectedStatusCode:     http.StatusSeeOther, | ||||||
|  | 			expectedLocationHeader: auth.AccountDisabledPath, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	doTest := func(testCase authorizeHandlerTestCase) { | ||||||
|  | 		ctx, recorder := suite.newContext(http.MethodGet, auth.OauthAuthorizePath) | ||||||
|  | 
 | ||||||
|  | 		user := suite.testUsers["unconfirmed_account"] | ||||||
|  | 		account := suite.testAccounts["unconfirmed_account"] | ||||||
|  | 
 | ||||||
|  | 		testSession := sessions.Default(ctx) | ||||||
|  | 		testSession.Set(sessionUserID, user.ID) | ||||||
|  | 		testSession.Set(sessionClientID, suite.testApplications["application_1"].ClientID) | ||||||
|  | 		if err := testSession.Save(); err != nil { | ||||||
|  | 			panic(errors.WrapMsgf(err, "failed on case: %s", testCase.description)) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		testCase.mutateUserAccount(user, account) | ||||||
|  | 
 | ||||||
|  | 		testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, user.Disabled, account.SuspendedAt) | ||||||
|  | 
 | ||||||
|  | 		user.UpdatedAt = time.Now() | ||||||
|  | 		err := suite.db.UpdateByPrimaryKey(context.Background(), user) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 		_, err = suite.db.UpdateAccount(context.Background(), account) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// call the handler | ||||||
|  | 		suite.authModule.AuthorizeGETHandler(ctx) | ||||||
|  | 
 | ||||||
|  | 		// 1. we should have a redirect | ||||||
|  | 		suite.Equal(testCase.expectedStatusCode, recorder.Code, fmt.Sprintf("failed on case: %s", testCase.description)) | ||||||
|  | 
 | ||||||
|  | 		// 2. we should have a redirect to the check your email path, as this user has not confirmed their email yet. | ||||||
|  | 		suite.Equal(testCase.expectedLocationHeader, recorder.Header().Get("Location"), fmt.Sprintf("failed on case: %s", testCase.description)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, testCase := range tests { | ||||||
|  | 		doTest(testCase) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestAccountUpdateTestSuite(t *testing.T) { | ||||||
|  | 	suite.Run(t, new(AuthAuthorizeTestSuite)) | ||||||
|  | } | ||||||
|  | @ -62,6 +62,22 @@ func (m *Module) TokenCheck(c *gin.Context) { | ||||||
| 			l.Warnf("no user found for userID %s", userID) | 			l.Warnf("no user found for userID %s", userID) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  | 
 | ||||||
|  | 		if user.ConfirmedAt.IsZero() { | ||||||
|  | 			l.Warnf("authenticated user %s has never confirmed thier email address", userID) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if !user.Approved { | ||||||
|  | 			l.Warnf("authenticated user %s's account was never approved by an admin", userID) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if user.Disabled { | ||||||
|  | 			l.Warnf("authenticated user %s's account was disabled'", userID) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		c.Set(oauth.SessionAuthorizedUser, user) | 		c.Set(oauth.SessionAuthorizedUser, user) | ||||||
| 
 | 
 | ||||||
| 		// fetch account for this token | 		// fetch account for this token | ||||||
|  | @ -74,6 +90,12 @@ func (m *Module) TokenCheck(c *gin.Context) { | ||||||
| 			l.Warnf("no account found for userID %s", userID) | 			l.Warnf("no account found for userID %s", userID) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  | 
 | ||||||
|  | 		if !acct.SuspendedAt.IsZero() { | ||||||
|  | 			l.Warnf("authenticated user %s's account (accountId=%s) has been suspended", userID, user.AccountID) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		c.Set(oauth.SessionAuthorizedAccount, acct) | 		c.Set(oauth.SessionAuthorizedAccount, acct) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -19,7 +19,7 @@ func processPostgresError(err error) db.Error { | ||||||
| 	// (https://www.postgresql.org/docs/10/errcodes-appendix.html) | 	// (https://www.postgresql.org/docs/10/errcodes-appendix.html) | ||||||
| 	switch pgErr.Code { | 	switch pgErr.Code { | ||||||
| 	case "23505" /* unique_violation */ : | 	case "23505" /* unique_violation */ : | ||||||
| 		return db.ErrAlreadyExists | 		return db.NewErrAlreadyExists(pgErr.Message) | ||||||
| 	default: | 	default: | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | @ -36,7 +36,7 @@ func processSQLiteError(err error) db.Error { | ||||||
| 	// Handle supplied error code: | 	// Handle supplied error code: | ||||||
| 	switch sqliteErr.Code() { | 	switch sqliteErr.Code() { | ||||||
| 	case sqlite3.SQLITE_CONSTRAINT_UNIQUE: | 	case sqlite3.SQLITE_CONSTRAINT_UNIQUE: | ||||||
| 		return db.ErrAlreadyExists | 		return db.NewErrAlreadyExists(err.Error()) | ||||||
| 	default: | 	default: | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -28,8 +28,19 @@ var ( | ||||||
| 	ErrNoEntries Error = fmt.Errorf("no entries") | 	ErrNoEntries Error = fmt.Errorf("no entries") | ||||||
| 	// ErrMultipleEntries is returned when a caller expected ONE entry for a query, but multiples were found. | 	// ErrMultipleEntries is returned when a caller expected ONE entry for a query, but multiples were found. | ||||||
| 	ErrMultipleEntries Error = fmt.Errorf("multiple entries") | 	ErrMultipleEntries Error = fmt.Errorf("multiple entries") | ||||||
| 	// ErrAlreadyExists is returned when a caller tries to insert a database entry that already exists in the db. |  | ||||||
| 	ErrAlreadyExists Error = fmt.Errorf("already exists") |  | ||||||
| 	// ErrUnknown denotes an unknown database error. | 	// ErrUnknown denotes an unknown database error. | ||||||
| 	ErrUnknown Error = fmt.Errorf("unknown error") | 	ErrUnknown Error = fmt.Errorf("unknown error") | ||||||
| ) | ) | ||||||
|  | 
 | ||||||
|  | // ErrAlreadyExists is returned when a caller tries to insert a database entry that already exists in the db. | ||||||
|  | type ErrAlreadyExists struct { | ||||||
|  | 	message string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (e *ErrAlreadyExists) Error() string { | ||||||
|  | 	return e.message | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func NewErrAlreadyExists(msg string) error { | ||||||
|  | 	return &ErrAlreadyExists{message: msg} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -20,6 +20,7 @@ package dereferencing | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
|  | @ -60,7 +61,8 @@ func (d *deref) GetRemoteAttachment(ctx context.Context, requestingUsername stri | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := d.db.Put(ctx, a); err != nil { | 	if err := d.db.Put(ctx, a); err != nil { | ||||||
| 		if err != db.ErrAlreadyExists { | 		var alreadyExistsError *db.ErrAlreadyExists | ||||||
|  | 		if !errors.As(err, &alreadyExistsError) { | ||||||
| 			return nil, fmt.Errorf("GetRemoteAttachment: error inserting attachment: %s", err) | 			return nil, fmt.Errorf("GetRemoteAttachment: error inserting attachment: %s", err) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -231,7 +231,8 @@ func (f *federatingDB) createNote(ctx context.Context, note vocab.ActivityStream | ||||||
| 	status.ID = statusID | 	status.ID = statusID | ||||||
| 
 | 
 | ||||||
| 	if err := f.db.PutStatus(ctx, status); err != nil { | 	if err := f.db.PutStatus(ctx, status); err != nil { | ||||||
| 		if err == db.ErrAlreadyExists { | 		var alreadyExistsError *db.ErrAlreadyExists | ||||||
|  | 		if errors.As(err, &alreadyExistsError) { | ||||||
| 			// the status already exists in the database, which means we've already handled everything else, | 			// the status already exists in the database, which means we've already handled everything else, | ||||||
| 			// so we can just return nil here and be done with it. | 			// so we can just return nil here and be done with it. | ||||||
| 			return nil | 			return nil | ||||||
|  |  | ||||||
|  | @ -78,25 +78,12 @@ func Authed(c *gin.Context, requireToken bool, requireApp bool, requireUser bool | ||||||
| 		return nil, errors.New("application not supplied") | 		return nil, errors.New("application not supplied") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if requireUser { | 	if requireUser && a.User == nil { | ||||||
| 		if a.User == nil { | 		return nil, errors.New("user not supplied or not authorized") | ||||||
| 			return nil, errors.New("user not supplied") |  | ||||||
| 		} |  | ||||||
| 		if a.User.Disabled || !a.User.Approved { |  | ||||||
| 			return nil, errors.New("user disabled or not approved") |  | ||||||
| 		} |  | ||||||
| 		if a.User.Email == "" { |  | ||||||
| 			return nil, errors.New("user has no confirmed email address") |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if requireAccount { | 	if requireAccount && a.Account == nil { | ||||||
| 		if a.Account == nil { | 		return nil, errors.New("account not supplied or not authorized") | ||||||
| 			return nil, errors.New("account not supplied") |  | ||||||
| 		} |  | ||||||
| 		if !a.Account.SuspendedAt.IsZero() { |  | ||||||
| 			return nil, errors.New("account suspended") |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return a, nil | 	return a, nil | ||||||
|  |  | ||||||
|  | @ -223,9 +223,12 @@ func (p *processor) ProcessTags(ctx context.Context, form *apimodel.AdvancedStat | ||||||
| 		return fmt.Errorf("error generating hashtags from status: %s", err) | 		return fmt.Errorf("error generating hashtags from status: %s", err) | ||||||
| 	} | 	} | ||||||
| 	for _, tag := range gtsTags { | 	for _, tag := range gtsTags { | ||||||
| 		if err := p.db.Put(ctx, tag); err != nil && err != db.ErrAlreadyExists { | 		if err := p.db.Put(ctx, tag); err != nil { | ||||||
|  | 			var alreadyExistsError *db.ErrAlreadyExists | ||||||
|  | 			if !errors.As(err, &alreadyExistsError) { | ||||||
| 				return fmt.Errorf("error putting tags in db: %s", err) | 				return fmt.Errorf("error putting tags in db: %s", err) | ||||||
| 			} | 			} | ||||||
|  | 		} | ||||||
| 		tags = append(tags, tag.ID) | 		tags = append(tags, tag.ID) | ||||||
| 	} | 	} | ||||||
| 	// add full populated gts tags to the status for passing them around conveniently | 	// add full populated gts tags to the status for passing them around conveniently | ||||||
|  |  | ||||||
|  | @ -138,7 +138,7 @@ func New(ctx context.Context, db db.DB) (Router, error) { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// set template functions | 	// set template functions | ||||||
| 	loadTemplateFunctions(engine) | 	LoadTemplateFunctions(engine) | ||||||
| 
 | 
 | ||||||
| 	// load templates onto the engine | 	// load templates onto the engine | ||||||
| 	if err := loadTemplates(engine); err != nil { | 	if err := loadTemplates(engine); err != nil { | ||||||
|  |  | ||||||
|  | @ -33,8 +33,8 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // sessionOptions returns the standard set of options to use for each session. | // SessionOptions returns the standard set of options to use for each session. | ||||||
| func sessionOptions() sessions.Options { | func SessionOptions() sessions.Options { | ||||||
| 	return sessions.Options{ | 	return sessions.Options{ | ||||||
| 		Path:     "/", | 		Path:     "/", | ||||||
| 		Domain:   viper.GetString(config.Keys.Host), | 		Domain:   viper.GetString(config.Keys.Host), | ||||||
|  | @ -75,7 +75,7 @@ func useSession(ctx context.Context, sessionDB db.Session, engine *gin.Engine) e | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	store := memstore.NewStore(rs.Auth, rs.Crypt) | 	store := memstore.NewStore(rs.Auth, rs.Crypt) | ||||||
| 	store.Options(sessionOptions()) | 	store.Options(SessionOptions()) | ||||||
| 
 | 
 | ||||||
| 	sessionName, err := SessionName() | 	sessionName, err := SessionName() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  |  | ||||||
|  | @ -31,7 +31,7 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/config" | 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // loadTemplates loads html templates for use by the given engine | // LoadTemplates loads html templates for use by the given engine | ||||||
| func loadTemplates(engine *gin.Engine) error { | func loadTemplates(engine *gin.Engine) error { | ||||||
| 	cwd, err := os.Getwd() | 	cwd, err := os.Getwd() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -39,8 +39,13 @@ func loadTemplates(engine *gin.Engine) error { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	templateBaseDir := viper.GetString(config.Keys.WebTemplateBaseDir) | 	templateBaseDir := viper.GetString(config.Keys.WebTemplateBaseDir) | ||||||
| 	tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", templateBaseDir)) |  | ||||||
| 
 | 
 | ||||||
|  | 	_, err = os.Stat(filepath.Join(cwd, templateBaseDir, "index.tmpl")) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("%s doesn't seem to contain the templates; index.tmpl is missing: %s", filepath.Join(cwd, templateBaseDir), err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", templateBaseDir)) | ||||||
| 	engine.LoadHTMLGlob(tmPath) | 	engine.LoadHTMLGlob(tmPath) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  | @ -87,7 +92,7 @@ func visibilityIcon(visibility model.Visibility) template.HTML { | ||||||
| 	return template.HTML(fmt.Sprintf(`<i aria-label="Visibility: %v" class="fa fa-%v"></i>`, icon.label, icon.faIcon)) | 	return template.HTML(fmt.Sprintf(`<i aria-label="Visibility: %v" class="fa fa-%v"></i>`, icon.label, icon.faIcon)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func loadTemplateFunctions(engine *gin.Engine) { | func LoadTemplateFunctions(engine *gin.Engine) { | ||||||
| 	engine.SetFuncMap(template.FuncMap{ | 	engine.SetFuncMap(template.FuncMap{ | ||||||
| 		"noescape":       noescape, | 		"noescape":       noescape, | ||||||
| 		"oddOrEven":      oddOrEven, | 		"oddOrEven":      oddOrEven, | ||||||
|  |  | ||||||
|  | @ -20,7 +20,14 @@ package testrig | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"fmt" | ||||||
|  | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
|  | 	"runtime" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/spf13/viper" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/router" | 	"github.com/superseriousbusiness/gotosocial/internal/router" | ||||||
| ) | ) | ||||||
|  | @ -33,3 +40,26 @@ func NewTestRouter(db db.DB) router.Router { | ||||||
| 	} | 	} | ||||||
| 	return r | 	return r | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // ConfigureTemplatesWithGin will panic on any errors related to template loading during tests | ||||||
|  | func ConfigureTemplatesWithGin(engine *gin.Engine) { | ||||||
|  | 
 | ||||||
|  | 	router.LoadTemplateFunctions(engine) | ||||||
|  | 
 | ||||||
|  | 	// https://stackoverflow.com/questions/31873396/is-it-possible-to-get-the-current-root-of-package-structure-as-a-string-in-golan | ||||||
|  | 	_, runtimeCallerLocation, _, _ := runtime.Caller(0) | ||||||
|  | 	projectRoot, err := filepath.Abs(filepath.Join(filepath.Dir(runtimeCallerLocation), "../")) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	templateBaseDir := viper.GetString(config.Keys.WebTemplateBaseDir) | ||||||
|  | 
 | ||||||
|  | 	_, err = os.Stat(filepath.Join(projectRoot, templateBaseDir, "index.tmpl")) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(fmt.Errorf("%s doesn't seem to contain the templates; index.tmpl is missing: %s", filepath.Join(projectRoot, templateBaseDir), err)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	tmPath := filepath.Join(projectRoot, fmt.Sprintf("%s*", templateBaseDir)) | ||||||
|  | 	engine.LoadHTMLGlob(tmPath) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -165,6 +165,25 @@ section.login form button { | ||||||
| 			grid-column: 2; | 			grid-column: 2; | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | section.error { | ||||||
|  |   display: flex; | ||||||
|  |   flex-direction: row; | ||||||
|  |   align-items: center; | ||||||
|  | } | ||||||
|  | section.error span { | ||||||
|  |   font-size: 2em; | ||||||
|  | } | ||||||
|  | section.error pre { | ||||||
|  |   border: 1px solid #ff000080; | ||||||
|  |   margin-left: 1em; | ||||||
|  |   padding: 0 0.7em; | ||||||
|  |   border-radius: 0.5em; | ||||||
|  |   background-color: #ff000010; | ||||||
|  |   font-size: 1.3em; | ||||||
|  |   white-space: pre-wrap; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| input, select, textarea { | input, select, textarea { | ||||||
| 	border: 1px solid #fafaff; | 	border: 1px solid #fafaff; | ||||||
| 	color: #fafaff; | 	color: #fafaff; | ||||||
|  |  | ||||||
|  | @ -165,6 +165,24 @@ section.login { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | section.error { | ||||||
|  | 	display: flex; | ||||||
|  | 	flex-direction: row; | ||||||
|  | 	align-items: center; | ||||||
|  |   span { | ||||||
|  |     font-size: 2em; | ||||||
|  |   } | ||||||
|  |   pre { | ||||||
|  |     border: 1px solid #ff000080; | ||||||
|  |     margin-left: 1em; | ||||||
|  |     padding: 0 0.7em; | ||||||
|  |     border-radius: 0.5em; | ||||||
|  |     background-color: #ff000010; | ||||||
|  |     font-size: 1.3em; | ||||||
|  |     white-space: pre-wrap; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
| input, select, textarea { | input, select, textarea { | ||||||
| 	border: 1px solid $fg; | 	border: 1px solid $fg; | ||||||
| 	color: $fg; | 	color: $fg; | ||||||
|  |  | ||||||
|  | @ -2,7 +2,13 @@ | ||||||
|     <main> |     <main> | ||||||
|         <form action="/oauth/authorize" method="POST"> |         <form action="/oauth/authorize" method="POST"> | ||||||
|             <h1>Hi {{.user}}!</h1> |             <h1>Hi {{.user}}!</h1> | ||||||
|             <p>Application <b>{{.appname}}</b> {{if len .appwebsite | eq 0 | not}}({{.appwebsite}}) {{end}}would like to perform actions on your behalf, with scope <em>{{.scope}}</em>.</p> |             <p> | ||||||
|  |               Application <b>{{.appname}}</b>  | ||||||
|  |               {{if len .appwebsite | eq 0 | not}} | ||||||
|  |                 ({{.appwebsite}})  | ||||||
|  |               {{end}} | ||||||
|  |               would like to perform actions on your behalf, with scope <em>{{.scope}}</em>. | ||||||
|  |             </p> | ||||||
|             <p>The application will redirect to {{.redirect}} to continue.</p> |             <p>The application will redirect to {{.redirect}} to continue.</p> | ||||||
|             <p> |             <p> | ||||||
|                 <button |                 <button | ||||||
|  |  | ||||||
							
								
								
									
										8
									
								
								web/template/error.tmpl
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								web/template/error.tmpl
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,8 @@ | ||||||
|  | {{ template "header.tmpl" .}} | ||||||
|  |     <main> | ||||||
|  |         <section class="error"> | ||||||
|  |           <span>❌</span> <pre>{{.error}}</pre> | ||||||
|  |         </section> | ||||||
|  |          | ||||||
|  |     </main> | ||||||
|  | {{ template "footer.tmpl" .}} | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue