mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-10-29 18:52:24 -05:00
Grand test fixup (#138)
* start fixing up tests * fix up tests + automate with drone * fiddle with linting * messing about with drone.yml * some more fiddling * hmmm * add cache * add vendor directory * verbose * ci updates * update some little things * update sig
This commit is contained in:
parent
329a5e8144
commit
98263a7de6
2677 changed files with 1090869 additions and 219 deletions
500
vendor/github.com/superseriousbusiness/oauth2/v4/manage/manager.go
generated
vendored
Normal file
500
vendor/github.com/superseriousbusiness/oauth2/v4/manage/manager.go
generated
vendored
Normal file
|
|
@ -0,0 +1,500 @@
|
|||
package manage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/superseriousbusiness/oauth2/v4"
|
||||
"github.com/superseriousbusiness/oauth2/v4/errors"
|
||||
"github.com/superseriousbusiness/oauth2/v4/generates"
|
||||
"github.com/superseriousbusiness/oauth2/v4/models"
|
||||
)
|
||||
|
||||
// NewDefaultManager create to default authorization management instance
|
||||
func NewDefaultManager() *Manager {
|
||||
m := NewManager()
|
||||
// default implementation
|
||||
m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate())
|
||||
m.MapAccessGenerate(generates.NewAccessGenerate())
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// NewManager create to authorization management instance
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
gtcfg: make(map[oauth2.GrantType]*Config),
|
||||
validateURI: DefaultValidateURI,
|
||||
}
|
||||
}
|
||||
|
||||
// Manager provide authorization management
|
||||
type Manager struct {
|
||||
codeExp time.Duration
|
||||
gtcfg map[oauth2.GrantType]*Config
|
||||
rcfg *RefreshingConfig
|
||||
validateURI ValidateURIHandler
|
||||
authorizeGenerate oauth2.AuthorizeGenerate
|
||||
accessGenerate oauth2.AccessGenerate
|
||||
tokenStore oauth2.TokenStore
|
||||
clientStore oauth2.ClientStore
|
||||
}
|
||||
|
||||
// get grant type config
|
||||
func (m *Manager) grantConfig(gt oauth2.GrantType) *Config {
|
||||
if c, ok := m.gtcfg[gt]; ok && c != nil {
|
||||
return c
|
||||
}
|
||||
switch gt {
|
||||
case oauth2.AuthorizationCode:
|
||||
return DefaultAuthorizeCodeTokenCfg
|
||||
case oauth2.Implicit:
|
||||
return DefaultImplicitTokenCfg
|
||||
case oauth2.PasswordCredentials:
|
||||
return DefaultPasswordTokenCfg
|
||||
case oauth2.ClientCredentials:
|
||||
return DefaultClientTokenCfg
|
||||
}
|
||||
return &Config{}
|
||||
}
|
||||
|
||||
// SetAuthorizeCodeExp set the authorization code expiration time
|
||||
func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) {
|
||||
m.codeExp = exp
|
||||
}
|
||||
|
||||
// SetAuthorizeCodeTokenCfg set the authorization code grant token config
|
||||
func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) {
|
||||
m.gtcfg[oauth2.AuthorizationCode] = cfg
|
||||
}
|
||||
|
||||
// SetImplicitTokenCfg set the implicit grant token config
|
||||
func (m *Manager) SetImplicitTokenCfg(cfg *Config) {
|
||||
m.gtcfg[oauth2.Implicit] = cfg
|
||||
}
|
||||
|
||||
// SetPasswordTokenCfg set the password grant token config
|
||||
func (m *Manager) SetPasswordTokenCfg(cfg *Config) {
|
||||
m.gtcfg[oauth2.PasswordCredentials] = cfg
|
||||
}
|
||||
|
||||
// SetClientTokenCfg set the client grant token config
|
||||
func (m *Manager) SetClientTokenCfg(cfg *Config) {
|
||||
m.gtcfg[oauth2.ClientCredentials] = cfg
|
||||
}
|
||||
|
||||
// SetRefreshTokenCfg set the refreshing token config
|
||||
func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) {
|
||||
m.rcfg = cfg
|
||||
}
|
||||
|
||||
// SetValidateURIHandler set the validates that RedirectURI is contained in baseURI
|
||||
func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) {
|
||||
m.validateURI = handler
|
||||
}
|
||||
|
||||
// MapAuthorizeGenerate mapping the authorize code generate interface
|
||||
func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) {
|
||||
m.authorizeGenerate = gen
|
||||
}
|
||||
|
||||
// MapAccessGenerate mapping the access token generate interface
|
||||
func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) {
|
||||
m.accessGenerate = gen
|
||||
}
|
||||
|
||||
// MapClientStorage mapping the client store interface
|
||||
func (m *Manager) MapClientStorage(stor oauth2.ClientStore) {
|
||||
m.clientStore = stor
|
||||
}
|
||||
|
||||
// MustClientStorage mandatory mapping the client store interface
|
||||
func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) {
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
m.clientStore = stor
|
||||
}
|
||||
|
||||
// MapTokenStorage mapping the token store interface
|
||||
func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) {
|
||||
m.tokenStore = stor
|
||||
}
|
||||
|
||||
// MustTokenStorage mandatory mapping the token store interface
|
||||
func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
m.tokenStore = stor
|
||||
}
|
||||
|
||||
// GetClient get the client information
|
||||
func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) {
|
||||
cli, err = m.clientStore.GetByID(ctx, clientID)
|
||||
if err != nil {
|
||||
return
|
||||
} else if cli == nil {
|
||||
err = errors.ErrInvalidClient
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GenerateAuthToken generate the authorization token(code)
|
||||
func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
|
||||
cli, err := m.GetClient(ctx, tgr.ClientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if tgr.RedirectURI != "" {
|
||||
if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
ti := models.NewToken()
|
||||
ti.SetClientID(tgr.ClientID)
|
||||
ti.SetUserID(tgr.UserID)
|
||||
ti.SetRedirectURI(tgr.RedirectURI)
|
||||
ti.SetScope(tgr.Scope)
|
||||
|
||||
createAt := time.Now()
|
||||
td := &oauth2.GenerateBasic{
|
||||
Client: cli,
|
||||
UserID: tgr.UserID,
|
||||
CreateAt: createAt,
|
||||
TokenInfo: ti,
|
||||
Request: tgr.Request,
|
||||
}
|
||||
switch rt {
|
||||
case oauth2.Code:
|
||||
codeExp := m.codeExp
|
||||
if codeExp == 0 {
|
||||
codeExp = DefaultCodeExp
|
||||
}
|
||||
ti.SetCodeCreateAt(createAt)
|
||||
ti.SetCodeExpiresIn(codeExp)
|
||||
if exp := tgr.AccessTokenExp; exp > 0 {
|
||||
ti.SetAccessExpiresIn(exp)
|
||||
}
|
||||
if tgr.CodeChallenge != "" {
|
||||
ti.SetCodeChallenge(tgr.CodeChallenge)
|
||||
ti.SetCodeChallengeMethod(tgr.CodeChallengeMethod)
|
||||
}
|
||||
|
||||
tv, err := m.authorizeGenerate.Token(ctx, td)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ti.SetCode(tv)
|
||||
case oauth2.Token:
|
||||
// set access token expires
|
||||
icfg := m.grantConfig(oauth2.Implicit)
|
||||
aexp := icfg.AccessTokenExp
|
||||
if exp := tgr.AccessTokenExp; exp > 0 {
|
||||
aexp = exp
|
||||
}
|
||||
ti.SetAccessCreateAt(createAt)
|
||||
ti.SetAccessExpiresIn(aexp)
|
||||
|
||||
if icfg.IsGenerateRefresh {
|
||||
ti.SetRefreshCreateAt(createAt)
|
||||
ti.SetRefreshExpiresIn(icfg.RefreshTokenExp)
|
||||
}
|
||||
|
||||
tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ti.SetAccess(tv)
|
||||
|
||||
if rv != "" {
|
||||
ti.SetRefresh(rv)
|
||||
}
|
||||
}
|
||||
|
||||
err = m.tokenStore.Create(ctx, ti)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ti, nil
|
||||
}
|
||||
|
||||
// get authorization code data
|
||||
func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
|
||||
ti, err := m.tokenStore.GetByCode(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) {
|
||||
err = errors.ErrInvalidAuthorizeCode
|
||||
return nil, errors.ErrInvalidAuthorizeCode
|
||||
}
|
||||
return ti, nil
|
||||
}
|
||||
|
||||
// delete authorization code data
|
||||
func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error {
|
||||
return m.tokenStore.RemoveByCode(ctx, code)
|
||||
}
|
||||
|
||||
// get and delete authorization code data
|
||||
func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
|
||||
code := tgr.Code
|
||||
ti, err := m.getAuthorizationCode(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if ti.GetClientID() != tgr.ClientID {
|
||||
return nil, errors.ErrInvalidAuthorizeCode
|
||||
} else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI {
|
||||
return nil, errors.ErrInvalidAuthorizeCode
|
||||
}
|
||||
|
||||
err = m.delAuthorizationCode(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ti, nil
|
||||
}
|
||||
|
||||
func (m *Manager) validateCodeChallenge(ti oauth2.TokenInfo, ver string) error {
|
||||
cc := ti.GetCodeChallenge()
|
||||
// early return
|
||||
if cc == "" && ver == "" {
|
||||
return nil
|
||||
}
|
||||
if cc == "" {
|
||||
return errors.ErrMissingCodeVerifier
|
||||
}
|
||||
if ver == "" {
|
||||
return errors.ErrMissingCodeVerifier
|
||||
}
|
||||
ccm := ti.GetCodeChallengeMethod()
|
||||
if ccm.String() == "" {
|
||||
ccm = oauth2.CodeChallengePlain
|
||||
}
|
||||
if !ccm.Validate(cc, ver) {
|
||||
return errors.ErrInvalidCodeChallenge
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateAccessToken generate the access token
|
||||
func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
|
||||
cli, err := m.GetClient(ctx, tgr.ClientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok {
|
||||
if !cliPass.VerifyPassword(tgr.ClientSecret) {
|
||||
return nil, errors.ErrInvalidClient
|
||||
}
|
||||
} else if len(cli.GetSecret()) > 0 && tgr.ClientSecret != cli.GetSecret() {
|
||||
return nil, errors.ErrInvalidClient
|
||||
}
|
||||
if tgr.RedirectURI != "" {
|
||||
if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if gt == oauth2.AuthorizationCode {
|
||||
ti, err := m.getAndDelAuthorizationCode(ctx, tgr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.validateCodeChallenge(ti, tgr.CodeVerifier); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tgr.UserID = ti.GetUserID()
|
||||
tgr.Scope = ti.GetScope()
|
||||
if exp := ti.GetAccessExpiresIn(); exp > 0 {
|
||||
tgr.AccessTokenExp = exp
|
||||
}
|
||||
}
|
||||
|
||||
ti := models.NewToken()
|
||||
ti.SetClientID(tgr.ClientID)
|
||||
ti.SetUserID(tgr.UserID)
|
||||
ti.SetRedirectURI(tgr.RedirectURI)
|
||||
ti.SetScope(tgr.Scope)
|
||||
|
||||
createAt := time.Now()
|
||||
ti.SetAccessCreateAt(createAt)
|
||||
|
||||
// set access token expires
|
||||
gcfg := m.grantConfig(gt)
|
||||
aexp := gcfg.AccessTokenExp
|
||||
if exp := tgr.AccessTokenExp; exp > 0 {
|
||||
aexp = exp
|
||||
}
|
||||
ti.SetAccessExpiresIn(aexp)
|
||||
if gcfg.IsGenerateRefresh {
|
||||
ti.SetRefreshCreateAt(createAt)
|
||||
ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp)
|
||||
}
|
||||
|
||||
td := &oauth2.GenerateBasic{
|
||||
Client: cli,
|
||||
UserID: tgr.UserID,
|
||||
CreateAt: createAt,
|
||||
TokenInfo: ti,
|
||||
Request: tgr.Request,
|
||||
}
|
||||
|
||||
av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ti.SetAccess(av)
|
||||
|
||||
if rv != "" {
|
||||
ti.SetRefresh(rv)
|
||||
}
|
||||
|
||||
err = m.tokenStore.Create(ctx, ti)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ti, nil
|
||||
}
|
||||
|
||||
// RefreshAccessToken refreshing an access token
|
||||
func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
|
||||
cli, err := m.GetClient(ctx, tgr.ClientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if tgr.ClientSecret != cli.GetSecret() {
|
||||
return nil, errors.ErrInvalidClient
|
||||
}
|
||||
|
||||
ti, err := m.LoadRefreshToken(ctx, tgr.Refresh)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if ti.GetClientID() != tgr.ClientID {
|
||||
return nil, errors.ErrInvalidRefreshToken
|
||||
}
|
||||
|
||||
oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh()
|
||||
|
||||
td := &oauth2.GenerateBasic{
|
||||
Client: cli,
|
||||
UserID: ti.GetUserID(),
|
||||
CreateAt: time.Now(),
|
||||
TokenInfo: ti,
|
||||
Request: tgr.Request,
|
||||
}
|
||||
|
||||
rcfg := DefaultRefreshTokenCfg
|
||||
if v := m.rcfg; v != nil {
|
||||
rcfg = v
|
||||
}
|
||||
|
||||
ti.SetAccessCreateAt(td.CreateAt)
|
||||
if v := rcfg.AccessTokenExp; v > 0 {
|
||||
ti.SetAccessExpiresIn(v)
|
||||
}
|
||||
|
||||
if v := rcfg.RefreshTokenExp; v > 0 {
|
||||
ti.SetRefreshExpiresIn(v)
|
||||
}
|
||||
|
||||
if rcfg.IsResetRefreshTime {
|
||||
ti.SetRefreshCreateAt(td.CreateAt)
|
||||
}
|
||||
|
||||
if scope := tgr.Scope; scope != "" {
|
||||
ti.SetScope(scope)
|
||||
}
|
||||
|
||||
tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ti.SetAccess(tv)
|
||||
if rv != "" {
|
||||
ti.SetRefresh(rv)
|
||||
}
|
||||
|
||||
if err := m.tokenStore.Create(ctx, ti); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if rcfg.IsRemoveAccess {
|
||||
// remove the old access token
|
||||
if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if rcfg.IsRemoveRefreshing && rv != "" {
|
||||
// remove the old refresh token
|
||||
if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if rv == "" {
|
||||
ti.SetRefresh("")
|
||||
ti.SetRefreshCreateAt(time.Now())
|
||||
ti.SetRefreshExpiresIn(0)
|
||||
}
|
||||
|
||||
return ti, nil
|
||||
}
|
||||
|
||||
// RemoveAccessToken use the access token to delete the token information
|
||||
func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error {
|
||||
if access == "" {
|
||||
return errors.ErrInvalidAccessToken
|
||||
}
|
||||
return m.tokenStore.RemoveByAccess(ctx, access)
|
||||
}
|
||||
|
||||
// RemoveRefreshToken use the refresh token to delete the token information
|
||||
func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error {
|
||||
if refresh == "" {
|
||||
return errors.ErrInvalidAccessToken
|
||||
}
|
||||
return m.tokenStore.RemoveByRefresh(ctx, refresh)
|
||||
}
|
||||
|
||||
// LoadAccessToken according to the access token for corresponding token information
|
||||
func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) {
|
||||
if access == "" {
|
||||
return nil, errors.ErrInvalidAccessToken
|
||||
}
|
||||
|
||||
ct := time.Now()
|
||||
ti, err := m.tokenStore.GetByAccess(ctx, access)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if ti == nil || ti.GetAccess() != access {
|
||||
return nil, errors.ErrInvalidAccessToken
|
||||
} else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 &&
|
||||
ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) {
|
||||
return nil, errors.ErrExpiredRefreshToken
|
||||
} else if ti.GetAccessExpiresIn() != 0 &&
|
||||
ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) {
|
||||
return nil, errors.ErrExpiredAccessToken
|
||||
}
|
||||
return ti, nil
|
||||
}
|
||||
|
||||
// LoadRefreshToken according to the refresh token for corresponding token information
|
||||
func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
|
||||
if refresh == "" {
|
||||
return nil, errors.ErrInvalidRefreshToken
|
||||
}
|
||||
|
||||
ti, err := m.tokenStore.GetByRefresh(ctx, refresh)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if ti == nil || ti.GetRefresh() != refresh {
|
||||
return nil, errors.ErrInvalidRefreshToken
|
||||
} else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire
|
||||
ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) {
|
||||
return nil, errors.ErrExpiredRefreshToken
|
||||
}
|
||||
return ti, nil
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue