| 
									
										
										
										
											2021-08-12 21:03:24 +02:00
										 |  |  | package server | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							| 
									
										
										
										
											2025-06-05 11:29:36 +02:00
										 |  |  | 	"cmp" | 
					
						
							| 
									
										
										
										
											2021-08-12 21:03:24 +02:00
										 |  |  | 	"context" | 
					
						
							|  |  |  | 	"encoding/json" | 
					
						
							|  |  |  | 	"fmt" | 
					
						
							|  |  |  | 	"net/http" | 
					
						
							|  |  |  | 	"net/url" | 
					
						
							|  |  |  | 	"strings" | 
					
						
							|  |  |  | 	"time" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-04-25 15:15:36 +02:00
										 |  |  | 	"code.superseriousbusiness.org/oauth2/v4" | 
					
						
							|  |  |  | 	"code.superseriousbusiness.org/oauth2/v4/errors" | 
					
						
							| 
									
										
										
										
											2021-08-12 21:03:24 +02:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // NewDefaultServer create a default authorization server | 
					
						
							|  |  |  | func NewDefaultServer(manager oauth2.Manager) *Server { | 
					
						
							|  |  |  | 	return NewServer(NewConfig(), manager) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // NewServer create authorization server | 
					
						
							|  |  |  | func NewServer(cfg *Config, manager oauth2.Manager) *Server { | 
					
						
							|  |  |  | 	srv := &Server{ | 
					
						
							|  |  |  | 		Config:  cfg, | 
					
						
							|  |  |  | 		Manager: manager, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// default handler | 
					
						
							|  |  |  | 	srv.ClientInfoHandler = ClientBasicHandler | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { | 
					
						
							|  |  |  | 		return "", errors.ErrAccessDenied | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { | 
					
						
							|  |  |  | 		return "", errors.ErrAccessDenied | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return srv | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Server Provide authorization server | 
					
						
							|  |  |  | type Server struct { | 
					
						
							|  |  |  | 	Config                       *Config | 
					
						
							|  |  |  | 	Manager                      oauth2.Manager | 
					
						
							|  |  |  | 	ClientInfoHandler            ClientInfoHandler | 
					
						
							|  |  |  | 	ClientAuthorizedHandler      ClientAuthorizedHandler | 
					
						
							|  |  |  | 	ClientScopeHandler           ClientScopeHandler | 
					
						
							|  |  |  | 	UserAuthorizationHandler     UserAuthorizationHandler | 
					
						
							|  |  |  | 	PasswordAuthorizationHandler PasswordAuthorizationHandler | 
					
						
							|  |  |  | 	RefreshingValidationHandler  RefreshingValidationHandler | 
					
						
							|  |  |  | 	RefreshingScopeHandler       RefreshingScopeHandler | 
					
						
							|  |  |  | 	ResponseErrorHandler         ResponseErrorHandler | 
					
						
							|  |  |  | 	InternalErrorHandler         InternalErrorHandler | 
					
						
							|  |  |  | 	ExtensionFieldsHandler       ExtensionFieldsHandler | 
					
						
							|  |  |  | 	AccessTokenExpHandler        AccessTokenExpHandler | 
					
						
							|  |  |  | 	AuthorizeScopeHandler        AuthorizeScopeHandler | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { | 
					
						
							|  |  |  | 	if req == nil { | 
					
						
							|  |  |  | 		return err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	data, _, _ := s.GetErrorData(err) | 
					
						
							|  |  |  | 	return s.redirect(w, req, data) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { | 
					
						
							|  |  |  | 	uri, err := s.GetRedirectURI(req, data) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	w.Header().Set("Location", uri) | 
					
						
							|  |  |  | 	w.WriteHeader(302) | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (s *Server) tokenError(w http.ResponseWriter, err error) error { | 
					
						
							|  |  |  | 	data, statusCode, header := s.GetErrorData(err) | 
					
						
							|  |  |  | 	return s.token(w, data, header, statusCode) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { | 
					
						
							|  |  |  | 	w.Header().Set("Content-Type", "application/json;charset=UTF-8") | 
					
						
							|  |  |  | 	w.Header().Set("Cache-Control", "no-store") | 
					
						
							|  |  |  | 	w.Header().Set("Pragma", "no-cache") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for key := range header { | 
					
						
							|  |  |  | 		w.Header().Set(key, header.Get(key)) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	status := http.StatusOK | 
					
						
							|  |  |  | 	if len(statusCode) > 0 && statusCode[0] > 0 { | 
					
						
							|  |  |  | 		status = statusCode[0] | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	w.WriteHeader(status) | 
					
						
							|  |  |  | 	return json.NewEncoder(w).Encode(data) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // GetRedirectURI get redirect uri | 
					
						
							|  |  |  | func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { | 
					
						
							|  |  |  | 	u, err := url.Parse(req.RedirectURI) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return "", err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	q := u.Query() | 
					
						
							|  |  |  | 	if req.State != "" { | 
					
						
							|  |  |  | 		q.Set("state", req.State) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for k, v := range data { | 
					
						
							|  |  |  | 		q.Set(k, fmt.Sprint(v)) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	switch req.ResponseType { | 
					
						
							|  |  |  | 	case oauth2.Code: | 
					
						
							|  |  |  | 		u.RawQuery = q.Encode() | 
					
						
							|  |  |  | 	case oauth2.Token: | 
					
						
							|  |  |  | 		u.RawQuery = "" | 
					
						
							|  |  |  | 		fragment, err := url.QueryUnescape(q.Encode()) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return "", err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		u.Fragment = fragment | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return u.String(), nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // CheckResponseType check allows response type | 
					
						
							|  |  |  | func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { | 
					
						
							|  |  |  | 	for _, art := range s.Config.AllowedResponseTypes { | 
					
						
							|  |  |  | 		if art == rt { | 
					
						
							|  |  |  | 			return true | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return false | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // CheckCodeChallengeMethod checks for allowed code challenge method | 
					
						
							|  |  |  | func (s *Server) CheckCodeChallengeMethod(ccm oauth2.CodeChallengeMethod) bool { | 
					
						
							|  |  |  | 	for _, c := range s.Config.AllowedCodeChallengeMethods { | 
					
						
							|  |  |  | 		if c == ccm { | 
					
						
							|  |  |  | 			return true | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return false | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // ValidationAuthorizeRequest the authorization request validation | 
					
						
							|  |  |  | func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { | 
					
						
							|  |  |  | 	redirectURI := r.FormValue("redirect_uri") | 
					
						
							|  |  |  | 	clientID := r.FormValue("client_id") | 
					
						
							|  |  |  | 	if !(r.Method == "GET" || r.Method == "POST") || | 
					
						
							|  |  |  | 		clientID == "" { | 
					
						
							|  |  |  | 		return nil, errors.ErrInvalidRequest | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	resType := oauth2.ResponseType(r.FormValue("response_type")) | 
					
						
							|  |  |  | 	if resType.String() == "" { | 
					
						
							|  |  |  | 		return nil, errors.ErrUnsupportedResponseType | 
					
						
							|  |  |  | 	} else if allowed := s.CheckResponseType(resType); !allowed { | 
					
						
							|  |  |  | 		return nil, errors.ErrUnauthorizedClient | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	cc := r.FormValue("code_challenge") | 
					
						
							|  |  |  | 	if cc == "" && s.Config.ForcePKCE { | 
					
						
							|  |  |  | 		return nil, errors.ErrCodeChallengeRquired | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if cc != "" && (len(cc) < 43 || len(cc) > 128) { | 
					
						
							|  |  |  | 		return nil, errors.ErrInvalidCodeChallengeLen | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method")) | 
					
						
							|  |  |  | 	// set default | 
					
						
							|  |  |  | 	if ccm == "" { | 
					
						
							| 
									
										
										
										
											2025-06-05 11:29:36 +02:00
										 |  |  | 		ccm = cmp.Or( | 
					
						
							|  |  |  | 			s.Config.DefaultCodeChallengeMethod, | 
					
						
							|  |  |  | 			oauth2.CodeChallengePlain, | 
					
						
							|  |  |  | 		) | 
					
						
							| 
									
										
										
										
											2021-08-12 21:03:24 +02:00
										 |  |  | 	} | 
					
						
							|  |  |  | 	if ccm.String() != "" && !s.CheckCodeChallengeMethod(ccm) { | 
					
						
							|  |  |  | 		return nil, errors.ErrUnsupportedCodeChallengeMethod | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	req := &AuthorizeRequest{ | 
					
						
							|  |  |  | 		RedirectURI:         redirectURI, | 
					
						
							|  |  |  | 		ResponseType:        resType, | 
					
						
							|  |  |  | 		ClientID:            clientID, | 
					
						
							|  |  |  | 		State:               r.FormValue("state"), | 
					
						
							|  |  |  | 		Scope:               r.FormValue("scope"), | 
					
						
							|  |  |  | 		Request:             r, | 
					
						
							|  |  |  | 		CodeChallenge:       cc, | 
					
						
							|  |  |  | 		CodeChallengeMethod: ccm, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return req, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // GetAuthorizeToken get authorization token(code) | 
					
						
							|  |  |  | func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { | 
					
						
							|  |  |  | 	// check the client allows the grant type | 
					
						
							|  |  |  | 	if fn := s.ClientAuthorizedHandler; fn != nil { | 
					
						
							|  |  |  | 		gt := oauth2.AuthorizationCode | 
					
						
							|  |  |  | 		if req.ResponseType == oauth2.Token { | 
					
						
							|  |  |  | 			gt = oauth2.Implicit | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		allowed, err := fn(req.ClientID, gt) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return nil, err | 
					
						
							|  |  |  | 		} else if !allowed { | 
					
						
							|  |  |  | 			return nil, errors.ErrUnauthorizedClient | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-08 20:20:06 +01:00
										 |  |  | 	tgr := &oauth2.TokenGenerateRequest{ | 
					
						
							|  |  |  | 		ClientID:       req.ClientID, | 
					
						
							|  |  |  | 		UserID:         req.UserID, | 
					
						
							|  |  |  | 		RedirectURI:    req.RedirectURI, | 
					
						
							|  |  |  | 		Scope:          req.Scope, | 
					
						
							|  |  |  | 		AccessTokenExp: req.AccessTokenExp, | 
					
						
							|  |  |  | 		Request:        req.Request, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-12 21:03:24 +02:00
										 |  |  | 	// check the client allows the authorized scope | 
					
						
							|  |  |  | 	if fn := s.ClientScopeHandler; fn != nil { | 
					
						
							| 
									
										
										
										
											2021-09-08 20:20:06 +01:00
										 |  |  | 		allowed, err := fn(tgr) | 
					
						
							| 
									
										
										
										
											2021-08-12 21:03:24 +02:00
										 |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return nil, err | 
					
						
							|  |  |  | 		} else if !allowed { | 
					
						
							|  |  |  | 			return nil, errors.ErrInvalidScope | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-08 20:20:06 +01:00
										 |  |  | 	tgr.CodeChallenge = req.CodeChallenge | 
					
						
							|  |  |  | 	tgr.CodeChallengeMethod = req.CodeChallengeMethod | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-12 21:03:24 +02:00
										 |  |  | 	return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // GetAuthorizeData get authorization response data | 
					
						
							|  |  |  | func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { | 
					
						
							|  |  |  | 	if rt == oauth2.Code { | 
					
						
							|  |  |  | 		return map[string]interface{}{ | 
					
						
							|  |  |  | 			"code": ti.GetCode(), | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return s.GetTokenData(ti) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // HandleAuthorizeRequest the authorization request handling | 
					
						
							|  |  |  | func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { | 
					
						
							|  |  |  | 	ctx := r.Context() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	req, err := s.ValidationAuthorizeRequest(r) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return s.redirectError(w, req, err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// user authorization | 
					
						
							|  |  |  | 	userID, err := s.UserAuthorizationHandler(w, r) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return s.redirectError(w, req, err) | 
					
						
							|  |  |  | 	} else if userID == "" { | 
					
						
							|  |  |  | 		return nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	req.UserID = userID | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// specify the scope of authorization | 
					
						
							|  |  |  | 	if fn := s.AuthorizeScopeHandler; fn != nil { | 
					
						
							|  |  |  | 		scope, err := fn(w, r) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return err | 
					
						
							|  |  |  | 		} else if scope != "" { | 
					
						
							|  |  |  | 			req.Scope = scope | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// specify the expiration time of access token | 
					
						
							|  |  |  | 	if fn := s.AccessTokenExpHandler; fn != nil { | 
					
						
							|  |  |  | 		exp, err := fn(w, r) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		req.AccessTokenExp = exp | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	ti, err := s.GetAuthorizeToken(ctx, req) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return s.redirectError(w, req, err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// If the redirect URI is empty, the default domain provided by the client is used. | 
					
						
							|  |  |  | 	if req.RedirectURI == "" { | 
					
						
							|  |  |  | 		client, err := s.Manager.GetClient(ctx, req.ClientID) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		req.RedirectURI = client.GetDomain() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // ValidationTokenRequest the token request validation | 
					
						
							|  |  |  | func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { | 
					
						
							|  |  |  | 	if v := r.Method; !(v == "POST" || | 
					
						
							|  |  |  | 		(s.Config.AllowGetAccessRequest && v == "GET")) { | 
					
						
							|  |  |  | 		return "", nil, errors.ErrInvalidRequest | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	gt := oauth2.GrantType(r.FormValue("grant_type")) | 
					
						
							|  |  |  | 	if gt.String() == "" { | 
					
						
							|  |  |  | 		return "", nil, errors.ErrUnsupportedGrantType | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-02-27 16:16:58 +01:00
										 |  |  | 	if !s.CheckGrantType(gt) { | 
					
						
							|  |  |  | 		return "", nil, errors.ErrUnsupportedGrantType | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-12 21:03:24 +02:00
										 |  |  | 	clientID, clientSecret, err := s.ClientInfoHandler(r) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return "", nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	tgr := &oauth2.TokenGenerateRequest{ | 
					
						
							|  |  |  | 		ClientID:     clientID, | 
					
						
							|  |  |  | 		ClientSecret: clientSecret, | 
					
						
							|  |  |  | 		Request:      r, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	switch gt { | 
					
						
							|  |  |  | 	case oauth2.AuthorizationCode: | 
					
						
							|  |  |  | 		tgr.RedirectURI = r.FormValue("redirect_uri") | 
					
						
							|  |  |  | 		tgr.Code = r.FormValue("code") | 
					
						
							|  |  |  | 		if tgr.RedirectURI == "" || | 
					
						
							|  |  |  | 			tgr.Code == "" { | 
					
						
							|  |  |  | 			return "", nil, errors.ErrInvalidRequest | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2021-09-08 20:20:06 +01:00
										 |  |  | 		tgr.CodeVerifier = r.FormValue("code_verifier") | 
					
						
							|  |  |  | 		if s.Config.ForcePKCE && tgr.CodeVerifier == "" { | 
					
						
							|  |  |  | 			return "", nil, errors.ErrInvalidRequest | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2021-08-12 21:03:24 +02:00
										 |  |  | 	case oauth2.PasswordCredentials: | 
					
						
							|  |  |  | 		tgr.Scope = r.FormValue("scope") | 
					
						
							|  |  |  | 		username, password := r.FormValue("username"), r.FormValue("password") | 
					
						
							|  |  |  | 		if username == "" || password == "" { | 
					
						
							|  |  |  | 			return "", nil, errors.ErrInvalidRequest | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		userID, err := s.PasswordAuthorizationHandler(username, password) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return "", nil, err | 
					
						
							|  |  |  | 		} else if userID == "" { | 
					
						
							|  |  |  | 			return "", nil, errors.ErrInvalidGrant | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		tgr.UserID = userID | 
					
						
							|  |  |  | 	case oauth2.ClientCredentials: | 
					
						
							|  |  |  | 		tgr.Scope = r.FormValue("scope") | 
					
						
							| 
									
										
										
										
											2021-09-08 20:51:42 +01:00
										 |  |  | 		tgr.RedirectURI = r.FormValue("redirect_uri") | 
					
						
							| 
									
										
										
										
											2021-08-12 21:03:24 +02:00
										 |  |  | 	case oauth2.Refreshing: | 
					
						
							|  |  |  | 		tgr.Refresh = r.FormValue("refresh_token") | 
					
						
							|  |  |  | 		tgr.Scope = r.FormValue("scope") | 
					
						
							|  |  |  | 		if tgr.Refresh == "" { | 
					
						
							|  |  |  | 			return "", nil, errors.ErrInvalidRequest | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return gt, tgr, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // CheckGrantType check allows grant type | 
					
						
							|  |  |  | func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { | 
					
						
							|  |  |  | 	for _, agt := range s.Config.AllowedGrantTypes { | 
					
						
							|  |  |  | 		if agt == gt { | 
					
						
							|  |  |  | 			return true | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return false | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // GetAccessToken access token | 
					
						
							| 
									
										
										
										
											2021-09-08 20:20:06 +01:00
										 |  |  | func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, | 
					
						
							|  |  |  | 	error) { | 
					
						
							| 
									
										
										
										
											2021-08-12 21:03:24 +02:00
										 |  |  | 	if allowed := s.CheckGrantType(gt); !allowed { | 
					
						
							|  |  |  | 		return nil, errors.ErrUnauthorizedClient | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if fn := s.ClientAuthorizedHandler; fn != nil { | 
					
						
							|  |  |  | 		allowed, err := fn(tgr.ClientID, gt) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return nil, err | 
					
						
							|  |  |  | 		} else if !allowed { | 
					
						
							|  |  |  | 			return nil, errors.ErrUnauthorizedClient | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	switch gt { | 
					
						
							|  |  |  | 	case oauth2.AuthorizationCode: | 
					
						
							|  |  |  | 		ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			switch err { | 
					
						
							| 
									
										
										
										
											2021-09-08 20:20:06 +01:00
										 |  |  | 			case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge, errors.ErrMissingCodeChallenge: | 
					
						
							| 
									
										
										
										
											2021-08-12 21:03:24 +02:00
										 |  |  | 				return nil, errors.ErrInvalidGrant | 
					
						
							|  |  |  | 			case errors.ErrInvalidClient: | 
					
						
							|  |  |  | 				return nil, errors.ErrInvalidClient | 
					
						
							|  |  |  | 			default: | 
					
						
							|  |  |  | 				return nil, err | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		return ti, nil | 
					
						
							|  |  |  | 	case oauth2.PasswordCredentials, oauth2.ClientCredentials: | 
					
						
							|  |  |  | 		if fn := s.ClientScopeHandler; fn != nil { | 
					
						
							| 
									
										
										
										
											2021-09-08 20:20:06 +01:00
										 |  |  | 			allowed, err := fn(tgr) | 
					
						
							| 
									
										
										
										
											2021-08-12 21:03:24 +02:00
										 |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				return nil, err | 
					
						
							|  |  |  | 			} else if !allowed { | 
					
						
							|  |  |  | 				return nil, errors.ErrInvalidScope | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		return s.Manager.GenerateAccessToken(ctx, gt, tgr) | 
					
						
							|  |  |  | 	case oauth2.Refreshing: | 
					
						
							|  |  |  | 		// check scope | 
					
						
							| 
									
										
										
										
											2021-09-08 20:20:06 +01:00
										 |  |  | 		if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { | 
					
						
							| 
									
										
										
										
											2021-08-12 21:03:24 +02:00
										 |  |  | 			rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { | 
					
						
							|  |  |  | 					return nil, errors.ErrInvalidGrant | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 				return nil, err | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-08 20:20:06 +01:00
										 |  |  | 			allowed, err := scopeFn(tgr, rti.GetScope()) | 
					
						
							| 
									
										
										
										
											2021-08-12 21:03:24 +02:00
										 |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				return nil, err | 
					
						
							|  |  |  | 			} else if !allowed { | 
					
						
							|  |  |  | 				return nil, errors.ErrInvalidScope | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if validationFn := s.RefreshingValidationHandler; validationFn != nil { | 
					
						
							|  |  |  | 			rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { | 
					
						
							|  |  |  | 					return nil, errors.ErrInvalidGrant | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 				return nil, err | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			allowed, err := validationFn(rti) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				return nil, err | 
					
						
							|  |  |  | 			} else if !allowed { | 
					
						
							|  |  |  | 				return nil, errors.ErrInvalidScope | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		ti, err := s.Manager.RefreshAccessToken(ctx, tgr) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { | 
					
						
							|  |  |  | 				return nil, errors.ErrInvalidGrant | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			return nil, err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		return ti, nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil, errors.ErrUnsupportedGrantType | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // GetTokenData token data | 
					
						
							|  |  |  | func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { | 
					
						
							|  |  |  | 	data := map[string]interface{}{ | 
					
						
							|  |  |  | 		"access_token": ti.GetAccess(), | 
					
						
							|  |  |  | 		"token_type":   s.Config.TokenType, | 
					
						
							|  |  |  | 		"expires_in":   int64(ti.GetAccessExpiresIn() / time.Second), | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if scope := ti.GetScope(); scope != "" { | 
					
						
							|  |  |  | 		data["scope"] = scope | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if refresh := ti.GetRefresh(); refresh != "" { | 
					
						
							|  |  |  | 		data["refresh_token"] = refresh | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if fn := s.ExtensionFieldsHandler; fn != nil { | 
					
						
							|  |  |  | 		ext := fn(ti) | 
					
						
							|  |  |  | 		for k, v := range ext { | 
					
						
							|  |  |  | 			if _, ok := data[k]; ok { | 
					
						
							|  |  |  | 				continue | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			data[k] = v | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return data | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // HandleTokenRequest token request handling | 
					
						
							|  |  |  | func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { | 
					
						
							|  |  |  | 	ctx := r.Context() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	gt, tgr, err := s.ValidationTokenRequest(r) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return s.tokenError(w, err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	ti, err := s.GetAccessToken(ctx, gt, tgr) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return s.tokenError(w, err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return s.token(w, s.GetTokenData(ti), nil) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // GetErrorData get error response data | 
					
						
							|  |  |  | func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { | 
					
						
							|  |  |  | 	var re errors.Response | 
					
						
							|  |  |  | 	if v, ok := errors.Descriptions[err]; ok { | 
					
						
							|  |  |  | 		re.Error = err | 
					
						
							|  |  |  | 		re.Description = v | 
					
						
							|  |  |  | 		re.StatusCode = errors.StatusCodes[err] | 
					
						
							|  |  |  | 	} else { | 
					
						
							|  |  |  | 		if fn := s.InternalErrorHandler; fn != nil { | 
					
						
							|  |  |  | 			if v := fn(err); v != nil { | 
					
						
							|  |  |  | 				re = *v | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if re.Error == nil { | 
					
						
							|  |  |  | 			re.Error = errors.ErrServerError | 
					
						
							|  |  |  | 			re.Description = errors.Descriptions[errors.ErrServerError] | 
					
						
							|  |  |  | 			re.StatusCode = errors.StatusCodes[errors.ErrServerError] | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if fn := s.ResponseErrorHandler; fn != nil { | 
					
						
							|  |  |  | 		fn(&re) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	data := make(map[string]interface{}) | 
					
						
							|  |  |  | 	if err := re.Error; err != nil { | 
					
						
							|  |  |  | 		data["error"] = err.Error() | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if v := re.ErrorCode; v != 0 { | 
					
						
							|  |  |  | 		data["error_code"] = v | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if v := re.Description; v != "" { | 
					
						
							|  |  |  | 		data["error_description"] = v | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if v := re.URI; v != "" { | 
					
						
							|  |  |  | 		data["error_uri"] = v | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	statusCode := http.StatusInternalServerError | 
					
						
							|  |  |  | 	if v := re.StatusCode; v > 0 { | 
					
						
							|  |  |  | 		statusCode = v | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return data, statusCode, re.Header | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // BearerAuth parse bearer token | 
					
						
							|  |  |  | func (s *Server) BearerAuth(r *http.Request) (string, bool) { | 
					
						
							|  |  |  | 	auth := r.Header.Get("Authorization") | 
					
						
							|  |  |  | 	prefix := "Bearer " | 
					
						
							|  |  |  | 	token := "" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if auth != "" && strings.HasPrefix(auth, prefix) { | 
					
						
							|  |  |  | 		token = auth[len(prefix):] | 
					
						
							|  |  |  | 	} else { | 
					
						
							|  |  |  | 		token = r.FormValue("access_token") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return token, token != "" | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // ValidationBearerToken validation the bearer tokens | 
					
						
							|  |  |  | // https://tools.ietf.org/html/rfc6750 | 
					
						
							|  |  |  | func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { | 
					
						
							|  |  |  | 	ctx := r.Context() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	accessToken, ok := s.BearerAuth(r) | 
					
						
							|  |  |  | 	if !ok { | 
					
						
							|  |  |  | 		return nil, errors.ErrInvalidAccessToken | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return s.Manager.LoadAccessToken(ctx, accessToken) | 
					
						
							|  |  |  | } |