mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-11-04 05:22:25 -06:00 
			
		
		
		
	
		
			
				
	
	
		
			212 lines
		
	
	
	
		
			6.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			212 lines
		
	
	
	
		
			6.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright 2014 The Go Authors. All rights reserved.
 | 
						|
// Use of this source code is governed by a BSD-style
 | 
						|
// license that can be found in the LICENSE file.
 | 
						|
 | 
						|
package oauth2
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"fmt"
 | 
						|
	"net/http"
 | 
						|
	"net/url"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"golang.org/x/oauth2/internal"
 | 
						|
)
 | 
						|
 | 
						|
const (
	// defaultExpiryDelta is how much earlier than its stated expiration
	// time a token is treated as expired. The margin avoids late
	// expirations caused by clock differences between client and server.
	defaultExpiryDelta = 10 * time.Second
)
 | 
						|
 | 
						|
// Token holds the credentials used to authorize requests for protected
// resources on an OAuth 2.0 provider's backend.
//
// Most users of this package should not read or write Token's fields
// directly; they are exported chiefly so that related packages
// implementing derivative OAuth2 flows can use them.
type Token struct {
	// AccessToken is the credential that authorizes and authenticates
	// requests.
	AccessToken string `json:"access_token"`

	// TokenType is the kind of token. The Type method reports a
	// normalized form of this value, or "Bearer" when it is empty.
	TokenType string `json:"token_type,omitempty"`

	// RefreshToken is used by the application (rather than the user)
	// to obtain a new access token once the current one expires.
	RefreshToken string `json:"refresh_token,omitempty"`

	// Expiry is the optional expiration time of the access token.
	//
	// When zero, TokenSource implementations keep reusing the same
	// token indefinitely, and RefreshToken (or any equivalent refresh
	// mechanism for that TokenSource) is never used.
	Expiry time.Time `json:"expiry,omitempty"`

	// ExpiresIn mirrors the OAuth2 wire-format "expires_in" field: the
	// number of seconds until the token expires, measured from an
	// unspecified base time roughly equal to "now". Applications are
	// responsible for deriving Expiry from ExpiresIn when they need it.
	ExpiresIn int64 `json:"expires_in,omitempty"`

	// raw optionally carries extra metadata sent by the server when the
	// token was updated. Extra understands map[string]interface{} and
	// url.Values here.
	raw interface{}

	// expiryDelta is subtracted from Expiry to decide when the token
	// counts as expired. A zero value means defaultExpiryDelta applies.
	expiryDelta time.Duration
}
 | 
						|
 | 
						|
// Type returns t.TokenType if non-empty, else "Bearer".
 | 
						|
func (t *Token) Type() string {
 | 
						|
	if strings.EqualFold(t.TokenType, "bearer") {
 | 
						|
		return "Bearer"
 | 
						|
	}
 | 
						|
	if strings.EqualFold(t.TokenType, "mac") {
 | 
						|
		return "MAC"
 | 
						|
	}
 | 
						|
	if strings.EqualFold(t.TokenType, "basic") {
 | 
						|
		return "Basic"
 | 
						|
	}
 | 
						|
	if t.TokenType != "" {
 | 
						|
		return t.TokenType
 | 
						|
	}
 | 
						|
	return "Bearer"
 | 
						|
}
 | 
						|
 | 
						|
// SetAuthHeader sets the Authorization header to r using the access
 | 
						|
// token in t.
 | 
						|
//
 | 
						|
// This method is unnecessary when using Transport or an HTTP Client
 | 
						|
// returned by this package.
 | 
						|
func (t *Token) SetAuthHeader(r *http.Request) {
 | 
						|
	r.Header.Set("Authorization", t.Type()+" "+t.AccessToken)
 | 
						|
}
 | 
						|
 | 
						|
// WithExtra returns a new Token that's a clone of t, but using the
 | 
						|
// provided raw extra map. This is only intended for use by packages
 | 
						|
// implementing derivative OAuth2 flows.
 | 
						|
func (t *Token) WithExtra(extra interface{}) *Token {
 | 
						|
	t2 := new(Token)
 | 
						|
	*t2 = *t
 | 
						|
	t2.raw = extra
 | 
						|
	return t2
 | 
						|
}
 | 
						|
 | 
						|
// Extra returns an extra field.
 | 
						|
// Extra fields are key-value pairs returned by the server as a
 | 
						|
// part of the token retrieval response.
 | 
						|
func (t *Token) Extra(key string) interface{} {
 | 
						|
	if raw, ok := t.raw.(map[string]interface{}); ok {
 | 
						|
		return raw[key]
 | 
						|
	}
 | 
						|
 | 
						|
	vals, ok := t.raw.(url.Values)
 | 
						|
	if !ok {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
 | 
						|
	v := vals.Get(key)
 | 
						|
	switch s := strings.TrimSpace(v); strings.Count(s, ".") {
 | 
						|
	case 0: // Contains no "."; try to parse as int
 | 
						|
		if i, err := strconv.ParseInt(s, 10, 64); err == nil {
 | 
						|
			return i
 | 
						|
		}
 | 
						|
	case 1: // Contains a single "."; try to parse as float
 | 
						|
		if f, err := strconv.ParseFloat(s, 64); err == nil {
 | 
						|
			return f
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return v
 | 
						|
}
 | 
						|
 | 
						|
var (
	// timeNow supplies the current time. It is a package variable wrapping
	// time.Now, rather than a direct call, so tests can substitute a fake
	// clock.
	timeNow = time.Now
)
 | 
						|
 | 
						|
// expired reports whether the token is expired.
 | 
						|
// t must be non-nil.
 | 
						|
func (t *Token) expired() bool {
 | 
						|
	if t.Expiry.IsZero() {
 | 
						|
		return false
 | 
						|
	}
 | 
						|
 | 
						|
	expiryDelta := defaultExpiryDelta
 | 
						|
	if t.expiryDelta != 0 {
 | 
						|
		expiryDelta = t.expiryDelta
 | 
						|
	}
 | 
						|
	return t.Expiry.Round(0).Add(-expiryDelta).Before(timeNow())
 | 
						|
}
 | 
						|
 | 
						|
// Valid reports whether t is non-nil, has an AccessToken, and is not expired.
 | 
						|
func (t *Token) Valid() bool {
 | 
						|
	return t != nil && t.AccessToken != "" && !t.expired()
 | 
						|
}
 | 
						|
 | 
						|
// tokenFromInternal maps an *internal.Token struct into
 | 
						|
// a *Token struct.
 | 
						|
func tokenFromInternal(t *internal.Token) *Token {
 | 
						|
	if t == nil {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	return &Token{
 | 
						|
		AccessToken:  t.AccessToken,
 | 
						|
		TokenType:    t.TokenType,
 | 
						|
		RefreshToken: t.RefreshToken,
 | 
						|
		Expiry:       t.Expiry,
 | 
						|
		raw:          t.Raw,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// retrieveToken takes a *Config and uses that to retrieve an *internal.Token.
 | 
						|
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
 | 
						|
// with an error..
 | 
						|
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
 | 
						|
	tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get())
 | 
						|
	if err != nil {
 | 
						|
		if rErr, ok := err.(*internal.RetrieveError); ok {
 | 
						|
			return nil, (*RetrieveError)(rErr)
 | 
						|
		}
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	return tokenFromInternal(tk), nil
 | 
						|
}
 | 
						|
 | 
						|
// RetrieveError is returned when the token endpoint either responds with
// a non-2XX HTTP status code or fills in RFC 6749's 'error' parameter.
// See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
type RetrieveError struct {
	// Response is the HTTP response received from the token endpoint.
	Response *http.Response
	// Body holds what was read from Response.Body; it may be truncated.
	Body []byte
	// ErrorCode carries RFC 6749's 'error' parameter.
	ErrorCode string
	// ErrorDescription carries RFC 6749's 'error_description' parameter.
	ErrorDescription string
	// ErrorURI carries RFC 6749's 'error_uri' parameter.
	ErrorURI string
}
 | 
						|
 | 
						|
func (r *RetrieveError) Error() string {
 | 
						|
	if r.ErrorCode != "" {
 | 
						|
		s := fmt.Sprintf("oauth2: %q", r.ErrorCode)
 | 
						|
		if r.ErrorDescription != "" {
 | 
						|
			s += fmt.Sprintf(" %q", r.ErrorDescription)
 | 
						|
		}
 | 
						|
		if r.ErrorURI != "" {
 | 
						|
			s += fmt.Sprintf(" %q", r.ErrorURI)
 | 
						|
		}
 | 
						|
		return s
 | 
						|
	}
 | 
						|
	return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
 | 
						|
}
 |