diff --git a/internal/api/auth/authorize.go b/internal/api/auth/authorize.go index 151245e6c..3676fd417 100644 --- a/internal/api/auth/authorize.go +++ b/internal/api/auth/authorize.go @@ -58,7 +58,7 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { return } - user := m.userFromSession(c, s) + user := m.mustUserFromSession(c, s) if user == nil { // Error already // written. @@ -81,21 +81,21 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) { return } - redirectURI := m.stringFromSession(c, s, sessionRedirectURI) + redirectURI := m.mustStringFromSession(c, s, sessionRedirectURI) if redirectURI == "" { // Error already // written. return } - scope := m.stringFromSession(c, s, sessionScope) + scope := m.mustStringFromSession(c, s, sessionScope) if scope == "" { // Error already // written. return } - app := m.appFromSession(c, s) + app := m.mustAppFromSession(c, s) if app == nil { // Error already // written. @@ -136,35 +136,35 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) { // can be validated by the oauth2 library. s := sessions.Default(c) - responseType := m.stringFromSession(c, s, sessionResponseType) + responseType := m.mustStringFromSession(c, s, sessionResponseType) if responseType == "" { // Error already // written. return } - clientID := m.stringFromSession(c, s, sessionClientID) + clientID := m.mustStringFromSession(c, s, sessionClientID) if clientID == "" { // Error already // written. return } - redirectURI := m.stringFromSession(c, s, sessionRedirectURI) + redirectURI := m.mustStringFromSession(c, s, sessionRedirectURI) if redirectURI == "" { // Error already // written. return } - scope := m.stringFromSession(c, s, sessionScope) + scope := m.mustStringFromSession(c, s, sessionScope) if scope == "" { // Error already // written. return } - user := m.userFromSession(c, s) + user := m.mustUserFromSession(c, s) if user == nil { // Error already // written. diff --git a/internal/api/auth/oob.go b/internal/api/auth/oob.go index 8a5355ef2..c723a1cb5 100644 --- a/internal/api/auth/oob.go +++ b/internal/api/auth/oob.go @@ -37,14 +37,14 @@ func (m *Module) OOBTokenGETHandler(c *gin.Context) { return } - user := m.userFromSession(c, s) + user := m.mustUserFromSession(c, s) if user == nil { // Error already // written. return } - scope := m.stringFromSession(c, s, sessionScope) + scope := m.mustStringFromSession(c, s, sessionScope) if scope == "" { // Error already // written. diff --git a/internal/api/auth/signin.go b/internal/api/auth/signin.go index 04e8ac654..2820255db 100644 --- a/internal/api/auth/signin.go +++ b/internal/api/auth/signin.go @@ -61,7 +61,7 @@ func (m *Module) SignInGETHandler(c *gin.Context) { // // We need the internal state to know where // to redirect to. - internalState := m.stringFromSession( + internalState := m.mustStringFromSession( c, sessions.Default(c), sessionInternalState, @@ -195,7 +195,7 @@ func incorrectPassword(err error) (*gtsmodel.User, gtserror.WithCode) { func (m *Module) TwoFactorCodeGETHandler(c *gin.Context) { s := sessions.Default(c) - user := m.userFromSession(c, s) + user := m.mustUserFromSession(c, s) if user == nil { // Error already // written. @@ -226,7 +226,7 @@ func (m *Module) TwoFactorCodeGETHandler(c *gin.Context) { func (m *Module) TwoFactorCodePOSTHandler(c *gin.Context) { s := sessions.Default(c) - user := m.userFromSession(c, s) + user := m.mustUserFromSession(c, s) if user == nil { // Error already // written. diff --git a/internal/api/auth/util.go b/internal/api/auth/util.go index 39ba98c43..f1aed0bc3 100644 --- a/internal/api/auth/util.go +++ b/internal/api/auth/util.go @@ -39,7 +39,13 @@ func (m *Module) mustSaveSession(s sessions.Session) { } } -func (m *Module) userFromSession( +// mustUserFromSession returns a *gtsmodel.User by checking the +// session for a user id and fetching the user from the database. +// +// On failure, the function clears session state, writes an internal +// error to the response writer, and returns nil. Callers should always +// return immediately if receiving nil back from this function! +func (m *Module) mustUserFromSession( c *gin.Context, s sessions.Session, ) *gtsmodel.User { @@ -74,7 +80,13 @@ func (m *Module) userFromSession( return user } -func (m *Module) appFromSession( +// mustAppFromSession returns a *gtsmodel.Application by checking the +// session for an application keyid and fetching the app from the database. +// +// On failure, the function clears session state, writes an internal +// error to the response writer, and returns nil. Callers should always +// return immediately if receiving nil back from this function! +func (m *Module) mustAppFromSession( c *gin.Context, s sessions.Session, ) *gtsmodel.Application { @@ -95,7 +107,14 @@ func (m *Module) appFromSession( return app } -func (m *Module) stringFromSession( +// mustStringFromSession returns the string value +// corresponding to the given session key, if any is set. +// +// On failure (nothing set), the function clears session +// state, writes an internal error to the response writer, +// and returns nil. Callers should always return immediately +// if receiving nil back from this function! +func (m *Module) mustStringFromSession( c *gin.Context, s sessions.Session, key string,