package server import ( "cmp" "context" "encoding/json" "fmt" "net/http" "net/url" "time" "code.superseriousbusiness.org/oauth2/v4" "code.superseriousbusiness.org/oauth2/v4/errors" ) // NewDefaultServer create a default authorization server func NewDefaultServer(manager oauth2.Manager) *Server { return NewServer(NewConfig(), manager) } // NewServer create authorization server func NewServer(cfg *Config, manager oauth2.Manager) *Server { srv := &Server{ Config: cfg, Manager: manager, } // default handlers srv.ClientInfoHandler = ClientBasicHandler srv.RefreshTokenResolveHandler = RefreshTokenFormResolveHandler srv.AccessTokenResolveHandler = AccessTokenDefaultResolveHandler srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { return "", errors.ErrAccessDenied } srv.PasswordAuthorizationHandler = func(ctx context.Context, clientID, username, password string) (string, error) { return "", errors.ErrAccessDenied } return srv } // Server Provide authorization server type Server struct { Config *Config Manager oauth2.Manager ClientInfoHandler ClientInfoHandler ClientAuthorizedHandler ClientAuthorizedHandler ClientScopeHandler ClientScopeHandler UserAuthorizationHandler UserAuthorizationHandler PasswordAuthorizationHandler PasswordAuthorizationHandler RefreshingValidationHandler RefreshingValidationHandler PreRedirectErrorHandler PreRedirectErrorHandler RefreshingScopeHandler RefreshingScopeHandler ResponseErrorHandler ResponseErrorHandler InternalErrorHandler InternalErrorHandler ExtensionFieldsHandler ExtensionFieldsHandler AccessTokenExpHandler AccessTokenExpHandler AuthorizeScopeHandler AuthorizeScopeHandler ResponseTokenHandler ResponseTokenHandler RefreshTokenResolveHandler RefreshTokenResolveHandler AccessTokenResolveHandler AccessTokenResolveHandler } func (s *Server) handleError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { if fn := s.PreRedirectErrorHandler; fn != nil { return fn(w, req, err) } return s.redirectError(w, req, err) } func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { if req == nil { return err } data, _, _ := s.GetErrorData(err) return s.redirect(w, req, data) } func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { uri, err := s.GetRedirectURI(req, data) if err != nil { return err } w.Header().Set("Location", uri) w.WriteHeader(302) return nil } func (s *Server) tokenError(w http.ResponseWriter, err error) error { data, statusCode, header := s.GetErrorData(err) return s.token(w, data, header, statusCode) } func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { if fn := s.ResponseTokenHandler; fn != nil { return fn(w, data, header, statusCode...) } w.Header().Set("Content-Type", "application/json;charset=UTF-8") w.Header().Set("Cache-Control", "no-store") w.Header().Set("Pragma", "no-cache") for key := range header { w.Header().Set(key, header.Get(key)) } status := http.StatusOK if len(statusCode) > 0 && statusCode[0] > 0 { status = statusCode[0] } w.WriteHeader(status) return json.NewEncoder(w).Encode(data) } // GetRedirectURI get redirect uri func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { u, err := url.Parse(req.RedirectURI) if err != nil { return "", err } q := u.Query() if req.State != "" { q.Set("state", req.State) } for k, v := range data { q.Set(k, fmt.Sprint(v)) } switch req.ResponseType { case oauth2.Code: u.RawQuery = q.Encode() case oauth2.Token: u.RawQuery = "" fragment, err := url.QueryUnescape(q.Encode()) if err != nil { return "", err } u.Fragment = fragment } return u.String(), nil } // CheckResponseType check allows response type func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { for _, art := range s.Config.AllowedResponseTypes { if art == rt { return true } } return false } // CheckCodeChallengeMethod checks for allowed code challenge method func (s *Server) CheckCodeChallengeMethod(ccm oauth2.CodeChallengeMethod) bool { for _, c := range s.Config.AllowedCodeChallengeMethods { if c == ccm { return true } } return false } // ValidationAuthorizeRequest the authorization request validation func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { redirectURI := r.FormValue("redirect_uri") clientID := r.FormValue("client_id") if !(r.Method == "GET" || r.Method == "POST") || clientID == "" { return nil, errors.ErrInvalidRequest } resType := oauth2.ResponseType(r.FormValue("response_type")) if resType.String() == "" { return nil, errors.ErrUnsupportedResponseType } else if allowed := s.CheckResponseType(resType); !allowed { return nil, errors.ErrUnauthorizedClient } cc := r.FormValue("code_challenge") if cc == "" && s.Config.ForcePKCE { return nil, errors.ErrCodeChallengeRquired } if cc != "" && (len(cc) < 43 || len(cc) > 128) { return nil, errors.ErrInvalidCodeChallengeLen } ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method")) // set default if ccm == "" { ccm = cmp.Or( s.Config.DefaultCodeChallengeMethod, oauth2.CodeChallengePlain, ) } if ccm != "" && !s.CheckCodeChallengeMethod(ccm) { return nil, errors.ErrUnsupportedCodeChallengeMethod } req := &AuthorizeRequest{ RedirectURI: redirectURI, ResponseType: resType, ClientID: clientID, State: r.FormValue("state"), Scope: r.FormValue("scope"), Request: r, CodeChallenge: cc, CodeChallengeMethod: ccm, } return req, nil } // GetAuthorizeToken get authorization token(code) func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { // check the client allows the grant type if fn := s.ClientAuthorizedHandler; fn != nil { gt := oauth2.AuthorizationCode if req.ResponseType == oauth2.Token { gt = oauth2.Implicit } allowed, err := fn(req.ClientID, gt) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrUnauthorizedClient } } tgr := &oauth2.TokenGenerateRequest{ ClientID: req.ClientID, UserID: req.UserID, RedirectURI: req.RedirectURI, Scope: req.Scope, AccessTokenExp: req.AccessTokenExp, Request: req.Request, } // check the client allows the authorized scope if fn := s.ClientScopeHandler; fn != nil { allowed, err := fn(tgr) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrInvalidScope } } tgr.CodeChallenge = req.CodeChallenge tgr.CodeChallengeMethod = req.CodeChallengeMethod return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) } // GetAuthorizeData get authorization response data func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { if rt == oauth2.Code { return map[string]interface{}{ "code": ti.GetCode(), } } return s.GetTokenData(ti) } // HandleAuthorizeRequest the authorization request handling func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() req, err := s.ValidationAuthorizeRequest(r) if err != nil { return s.handleError(w, req, err) } // user authorization userID, err := s.UserAuthorizationHandler(w, r) if err != nil { return s.handleError(w, req, err) } else if userID == "" { return nil } req.UserID = userID // specify the scope of authorization if fn := s.AuthorizeScopeHandler; fn != nil { scope, err := fn(w, r) if err != nil { return err } else if scope != "" { req.Scope = scope } } // specify the expiration time of access token if fn := s.AccessTokenExpHandler; fn != nil { exp, err := fn(w, r) if err != nil { return err } req.AccessTokenExp = exp } ti, err := s.GetAuthorizeToken(ctx, req) if err != nil { return s.handleError(w, req, err) } // If the redirect URI is empty, the default domain provided by the client is used. if req.RedirectURI == "" { client, err := s.Manager.GetClient(ctx, req.ClientID) if err != nil { return err } req.RedirectURI = client.GetDomain() } return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) } // ValidationTokenRequest the token request validation func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { if v := r.Method; !(v == "POST" || (s.Config.AllowGetAccessRequest && v == "GET")) { return "", nil, errors.ErrInvalidRequest } gt := oauth2.GrantType(r.FormValue("grant_type")) if gt.String() == "" { return "", nil, errors.ErrUnsupportedGrantType } if !s.CheckGrantType(gt) { return "", nil, errors.ErrUnsupportedGrantType } clientID, clientSecret, err := s.ClientInfoHandler(r) if err != nil { return "", nil, err } tgr := &oauth2.TokenGenerateRequest{ ClientID: clientID, ClientSecret: clientSecret, Request: r, } switch gt { case oauth2.AuthorizationCode: tgr.RedirectURI = r.FormValue("redirect_uri") tgr.Code = r.FormValue("code") if tgr.RedirectURI == "" || tgr.Code == "" { return "", nil, errors.ErrInvalidRequest } tgr.CodeVerifier = r.FormValue("code_verifier") if s.Config.ForcePKCE && tgr.CodeVerifier == "" { return "", nil, errors.ErrInvalidRequest } case oauth2.PasswordCredentials: tgr.Scope = r.FormValue("scope") username, password := r.FormValue("username"), r.FormValue("password") if username == "" || password == "" { return "", nil, errors.ErrInvalidRequest } userID, err := s.PasswordAuthorizationHandler(r.Context(), clientID, username, password) if err != nil { return "", nil, err } else if userID == "" { return "", nil, errors.ErrInvalidGrant } tgr.UserID = userID case oauth2.ClientCredentials: tgr.Scope = r.FormValue("scope") tgr.RedirectURI = r.FormValue("redirect_uri") case oauth2.Refreshing: tgr.Refresh, err = s.RefreshTokenResolveHandler(r) tgr.Scope = r.FormValue("scope") if err != nil { return "", nil, err } } return gt, tgr, nil } // CheckGrantType check allows grant type func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { for _, agt := range s.Config.AllowedGrantTypes { if agt == gt { return true } } return false } // GetAccessToken access token func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { if allowed := s.CheckGrantType(gt); !allowed { return nil, errors.ErrUnauthorizedClient } if fn := s.ClientAuthorizedHandler; fn != nil { allowed, err := fn(tgr.ClientID, gt) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrUnauthorizedClient } } switch gt { case oauth2.AuthorizationCode: ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) if err != nil { switch err { case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge, errors.ErrMissingCodeChallenge: return nil, errors.ErrInvalidGrant case errors.ErrInvalidClient: return nil, errors.ErrInvalidClient default: return nil, err } } return ti, nil case oauth2.PasswordCredentials, oauth2.ClientCredentials: if fn := s.ClientScopeHandler; fn != nil { allowed, err := fn(tgr) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrInvalidScope } } return s.Manager.GenerateAccessToken(ctx, gt, tgr) case oauth2.Refreshing: // check scope if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) if err != nil { if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { return nil, errors.ErrInvalidGrant } return nil, err } allowed, err := scopeFn(tgr, rti.GetScope()) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrInvalidScope } } if validationFn := s.RefreshingValidationHandler; validationFn != nil { rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) if err != nil { if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { return nil, errors.ErrInvalidGrant } return nil, err } allowed, err := validationFn(rti) if err != nil { return nil, err } else if !allowed { return nil, errors.ErrInvalidScope } } ti, err := s.Manager.RefreshAccessToken(ctx, tgr) if err != nil { if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { return nil, errors.ErrInvalidGrant } return nil, err } return ti, nil } return nil, errors.ErrUnsupportedGrantType } // GetTokenData token data func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { data := map[string]interface{}{ "access_token": ti.GetAccess(), "token_type": s.Config.TokenType, "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), } if scope := ti.GetScope(); scope != "" { data["scope"] = scope } if refresh := ti.GetRefresh(); refresh != "" { data["refresh_token"] = refresh } if fn := s.ExtensionFieldsHandler; fn != nil { ext := fn(ti) for k, v := range ext { if _, ok := data[k]; ok { continue } data[k] = v } } return data } // HandleTokenRequest token request handling func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() gt, tgr, err := s.ValidationTokenRequest(r) if err != nil { return s.tokenError(w, err) } ti, err := s.GetAccessToken(ctx, gt, tgr) if err != nil { return s.tokenError(w, err) } return s.token(w, s.GetTokenData(ti), nil) } // GetErrorData get error response data func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { var re errors.Response if v, ok := errors.Descriptions[err]; ok { re.Error = err re.Description = v re.StatusCode = errors.StatusCodes[err] } else { if fn := s.InternalErrorHandler; fn != nil { if v := fn(err); v != nil { re = *v } } if re.Error == nil { re.Error = errors.ErrServerError re.Description = errors.Descriptions[errors.ErrServerError] re.StatusCode = errors.StatusCodes[errors.ErrServerError] } } if fn := s.ResponseErrorHandler; fn != nil { fn(&re) } data := make(map[string]interface{}) if err := re.Error; err != nil { data["error"] = err.Error() } if v := re.ErrorCode; v != 0 { data["error_code"] = v } if v := re.Description; v != "" { data["error_description"] = v } if v := re.URI; v != "" { data["error_uri"] = v } statusCode := http.StatusInternalServerError if v := re.StatusCode; v > 0 { statusCode = v } return data, statusCode, re.Header } // ValidationBearerToken validation the bearer tokens // https://tools.ietf.org/html/rfc6750 func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { ctx := r.Context() accessToken, ok := s.AccessTokenResolveHandler(r) if !ok { return nil, errors.ErrInvalidAccessToken } return s.Manager.LoadAccessToken(ctx, accessToken) }