mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-11-03 23:02:24 -06:00 
			
		
		
		
	
		
			
				
	
	
		
			198 lines
		
	
	
	
		
			5.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			198 lines
		
	
	
	
		
			5.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package oauth2
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"encoding/json"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"io"
 | 
						|
	"net/http"
 | 
						|
	"net/url"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"golang.org/x/oauth2/internal"
 | 
						|
)
 | 
						|
 | 
						|
// Error codes a token endpoint may return while a device is polling for an
// access token, per https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
const (
	// errAccessDenied: the end user denied the authorization request.
	errAccessDenied = "access_denied"
	// errAuthorizationPending: the user has not finished authorizing yet; keep polling.
	errAuthorizationPending = "authorization_pending"
	// errExpiredToken: the device_code has expired.
	errExpiredToken = "expired_token"
	// errSlowDown: keep polling, but increase the interval by 5 seconds.
	errSlowDown = "slow_down"
)
 | 
						|
 | 
						|
// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
//
// On the wire the expiry is the relative "expires_in" field (seconds);
// MarshalJSON and UnmarshalJSON convert between that and the absolute Expiry.
type DeviceAuthResponse struct {
	// DeviceCode is the device verification code the client presents to the
	// token endpoint when polling for an access token.
	DeviceCode string `json:"device_code"`
	// UserCode is the code the user should enter at the verification uri
	UserCode string `json:"user_code"`
	// VerificationURI is where user should enter the user code
	VerificationURI string `json:"verification_uri"`
	// VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
	VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
	// Expiry is when the device code and user code expire. Its JSON form is
	// produced and consumed by MarshalJSON/UnmarshalJSON as "expires_in".
	Expiry time.Time `json:"expires_in,omitempty"`
	// Interval is the duration in seconds that the client should wait between
	// polling requests.
	Interval int64 `json:"interval,omitempty"`
}

// MarshalJSON encodes d with Expiry replaced by "expires_in", the number of
// whole seconds remaining until Expiry. The field is omitted when that number
// is 0, which includes the case where Expiry is the zero time.
func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
	// Alias has the same fields but none of the methods, so marshaling it does
	// not recurse back into this method.
	type Alias DeviceAuthResponse
	var expiresIn int64
	if !d.Expiry.IsZero() {
		expiresIn = int64(time.Until(d.Expiry).Seconds())
	}
	// ExpiresIn sits at a shallower depth than the embedded Expiry field that
	// shares the "expires_in" tag, so encoding/json emits ExpiresIn and drops Expiry.
	return json.Marshal(&struct {
		ExpiresIn int64 `json:"expires_in,omitempty"`
		*Alias
	}{
		ExpiresIn: expiresIn,
		Alias:     (*Alias)(&d),
	})
}

// UnmarshalJSON decodes a device authorization response into d.
//
// It converts the relative "expires_in" field into an absolute UTC Expiry.
// It also accepts the misspelled "verification_url" as a fallback for
// "verification_uri". A JSON null leaves d unchanged.
func (d *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
	type Alias DeviceAuthResponse
	aux := &struct {
		ExpiresIn int64 `json:"expires_in"`
		// workaround misspelling of verification_uri
		VerificationURL string `json:"verification_url"`
		*Alias
	}{
		Alias: (*Alias)(d),
	}
	// Decode into aux itself, not &aux. Given a pointer-to-pointer, a JSON
	// null would set aux to nil and the field reads below would panic.
	if err := json.Unmarshal(data, aux); err != nil {
		return err
	}
	if aux.ExpiresIn != 0 {
		d.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
	}
	if d.VerificationURI == "" {
		d.VerificationURI = aux.VerificationURL
	}
	return nil
}
 | 
						|
 | 
						|
// DeviceAuth returns a device auth struct which contains a device code
 | 
						|
// and authorization information provided for users to enter on another device.
 | 
						|
func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
 | 
						|
	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
 | 
						|
	v := url.Values{
 | 
						|
		"client_id": {c.ClientID},
 | 
						|
	}
 | 
						|
	if len(c.Scopes) > 0 {
 | 
						|
		v.Set("scope", strings.Join(c.Scopes, " "))
 | 
						|
	}
 | 
						|
	for _, opt := range opts {
 | 
						|
		opt.setValue(v)
 | 
						|
	}
 | 
						|
	return retrieveDeviceAuth(ctx, c, v)
 | 
						|
}
 | 
						|
 | 
						|
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
 | 
						|
	if c.Endpoint.DeviceAuthURL == "" {
 | 
						|
		return nil, errors.New("endpoint missing DeviceAuthURL")
 | 
						|
	}
 | 
						|
 | 
						|
	req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
						|
	req.Header.Set("Accept", "application/json")
 | 
						|
 | 
						|
	t := time.Now()
 | 
						|
	r, err := internal.ContextClient(ctx).Do(req)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
 | 
						|
	}
 | 
						|
	if code := r.StatusCode; code < 200 || code > 299 {
 | 
						|
		return nil, &RetrieveError{
 | 
						|
			Response: r,
 | 
						|
			Body:     body,
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	da := &DeviceAuthResponse{}
 | 
						|
	err = json.Unmarshal(body, &da)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("unmarshal %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if !da.Expiry.IsZero() {
 | 
						|
		// Make a small adjustment to account for time taken by the request
 | 
						|
		da.Expiry = da.Expiry.Add(-time.Since(t))
 | 
						|
	}
 | 
						|
 | 
						|
	return da, nil
 | 
						|
}
 | 
						|
 | 
						|
// DeviceAccessToken polls the server to exchange a device code for a token.
 | 
						|
func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
 | 
						|
	if !da.Expiry.IsZero() {
 | 
						|
		var cancel context.CancelFunc
 | 
						|
		ctx, cancel = context.WithDeadline(ctx, da.Expiry)
 | 
						|
		defer cancel()
 | 
						|
	}
 | 
						|
 | 
						|
	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
 | 
						|
	v := url.Values{
 | 
						|
		"client_id":   {c.ClientID},
 | 
						|
		"grant_type":  {"urn:ietf:params:oauth:grant-type:device_code"},
 | 
						|
		"device_code": {da.DeviceCode},
 | 
						|
	}
 | 
						|
	if len(c.Scopes) > 0 {
 | 
						|
		v.Set("scope", strings.Join(c.Scopes, " "))
 | 
						|
	}
 | 
						|
	for _, opt := range opts {
 | 
						|
		opt.setValue(v)
 | 
						|
	}
 | 
						|
 | 
						|
	// "If no value is provided, clients MUST use 5 as the default."
 | 
						|
	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
 | 
						|
	interval := da.Interval
 | 
						|
	if interval == 0 {
 | 
						|
		interval = 5
 | 
						|
	}
 | 
						|
 | 
						|
	ticker := time.NewTicker(time.Duration(interval) * time.Second)
 | 
						|
	defer ticker.Stop()
 | 
						|
	for {
 | 
						|
		select {
 | 
						|
		case <-ctx.Done():
 | 
						|
			return nil, ctx.Err()
 | 
						|
		case <-ticker.C:
 | 
						|
			tok, err := retrieveToken(ctx, c, v)
 | 
						|
			if err == nil {
 | 
						|
				return tok, nil
 | 
						|
			}
 | 
						|
 | 
						|
			e, ok := err.(*RetrieveError)
 | 
						|
			if !ok {
 | 
						|
				return nil, err
 | 
						|
			}
 | 
						|
			switch e.ErrorCode {
 | 
						|
			case errSlowDown:
 | 
						|
				// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
 | 
						|
				// "the interval MUST be increased by 5 seconds for this and all subsequent requests"
 | 
						|
				interval += 5
 | 
						|
				ticker.Reset(time.Duration(interval) * time.Second)
 | 
						|
			case errAuthorizationPending:
 | 
						|
				// Do nothing.
 | 
						|
			case errAccessDenied, errExpiredToken:
 | 
						|
				fallthrough
 | 
						|
			default:
 | 
						|
				return tok, err
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 |