mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 10:52:28 -05:00 
			
		
		
		
	
		
			
	
	
		
			199 lines
		
	
	
	
		
			5.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			199 lines
		
	
	
	
		
			5.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
|  | package oauth2 | ||
|  | 
 | ||
|  | import ( | ||
|  | 	"context" | ||
|  | 	"encoding/json" | ||
|  | 	"errors" | ||
|  | 	"fmt" | ||
|  | 	"io" | ||
|  | 	"net/http" | ||
|  | 	"net/url" | ||
|  | 	"strings" | ||
|  | 	"time" | ||
|  | 
 | ||
|  | 	"golang.org/x/oauth2/internal" | ||
|  | ) | ||
|  | 
 | ||
|  | // https://datatracker.ietf.org/doc/html/rfc8628#section-3.5 | ||
|  | const ( | ||
|  | 	errAuthorizationPending = "authorization_pending" | ||
|  | 	errSlowDown             = "slow_down" | ||
|  | 	errAccessDenied         = "access_denied" | ||
|  | 	errExpiredToken         = "expired_token" | ||
|  | ) | ||
|  | 
 | ||
|  | // DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response | ||
|  | // https://datatracker.ietf.org/doc/html/rfc8628#section-3.2 | ||
|  | type DeviceAuthResponse struct { | ||
|  | 	// DeviceCode | ||
|  | 	DeviceCode string `json:"device_code"` | ||
|  | 	// UserCode is the code the user should enter at the verification uri | ||
|  | 	UserCode string `json:"user_code"` | ||
|  | 	// VerificationURI is where user should enter the user code | ||
|  | 	VerificationURI string `json:"verification_uri"` | ||
|  | 	// VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code. | ||
|  | 	VerificationURIComplete string `json:"verification_uri_complete,omitempty"` | ||
|  | 	// Expiry is when the device code and user code expire | ||
|  | 	Expiry time.Time `json:"expires_in,omitempty"` | ||
|  | 	// Interval is the duration in seconds that Poll should wait between requests | ||
|  | 	Interval int64 `json:"interval,omitempty"` | ||
|  | } | ||
|  | 
 | ||
|  | func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) { | ||
|  | 	type Alias DeviceAuthResponse | ||
|  | 	var expiresIn int64 | ||
|  | 	if !d.Expiry.IsZero() { | ||
|  | 		expiresIn = int64(time.Until(d.Expiry).Seconds()) | ||
|  | 	} | ||
|  | 	return json.Marshal(&struct { | ||
|  | 		ExpiresIn int64 `json:"expires_in,omitempty"` | ||
|  | 		*Alias | ||
|  | 	}{ | ||
|  | 		ExpiresIn: expiresIn, | ||
|  | 		Alias:     (*Alias)(&d), | ||
|  | 	}) | ||
|  | 
 | ||
|  | } | ||
|  | 
 | ||
|  | func (c *DeviceAuthResponse) UnmarshalJSON(data []byte) error { | ||
|  | 	type Alias DeviceAuthResponse | ||
|  | 	aux := &struct { | ||
|  | 		ExpiresIn int64 `json:"expires_in"` | ||
|  | 		// workaround misspelling of verification_uri | ||
|  | 		VerificationURL string `json:"verification_url"` | ||
|  | 		*Alias | ||
|  | 	}{ | ||
|  | 		Alias: (*Alias)(c), | ||
|  | 	} | ||
|  | 	if err := json.Unmarshal(data, &aux); err != nil { | ||
|  | 		return err | ||
|  | 	} | ||
|  | 	if aux.ExpiresIn != 0 { | ||
|  | 		c.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn)) | ||
|  | 	} | ||
|  | 	if c.VerificationURI == "" { | ||
|  | 		c.VerificationURI = aux.VerificationURL | ||
|  | 	} | ||
|  | 	return nil | ||
|  | } | ||
|  | 
 | ||
|  | // DeviceAuth returns a device auth struct which contains a device code | ||
|  | // and authorization information provided for users to enter on another device. | ||
|  | func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) { | ||
|  | 	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1 | ||
|  | 	v := url.Values{ | ||
|  | 		"client_id": {c.ClientID}, | ||
|  | 	} | ||
|  | 	if len(c.Scopes) > 0 { | ||
|  | 		v.Set("scope", strings.Join(c.Scopes, " ")) | ||
|  | 	} | ||
|  | 	for _, opt := range opts { | ||
|  | 		opt.setValue(v) | ||
|  | 	} | ||
|  | 	return retrieveDeviceAuth(ctx, c, v) | ||
|  | } | ||
|  | 
 | ||
|  | func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) { | ||
|  | 	if c.Endpoint.DeviceAuthURL == "" { | ||
|  | 		return nil, errors.New("endpoint missing DeviceAuthURL") | ||
|  | 	} | ||
|  | 
 | ||
|  | 	req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode())) | ||
|  | 	if err != nil { | ||
|  | 		return nil, err | ||
|  | 	} | ||
|  | 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | ||
|  | 	req.Header.Set("Accept", "application/json") | ||
|  | 
 | ||
|  | 	t := time.Now() | ||
|  | 	r, err := internal.ContextClient(ctx).Do(req) | ||
|  | 	if err != nil { | ||
|  | 		return nil, err | ||
|  | 	} | ||
|  | 
 | ||
|  | 	body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) | ||
|  | 	if err != nil { | ||
|  | 		return nil, fmt.Errorf("oauth2: cannot auth device: %v", err) | ||
|  | 	} | ||
|  | 	if code := r.StatusCode; code < 200 || code > 299 { | ||
|  | 		return nil, &RetrieveError{ | ||
|  | 			Response: r, | ||
|  | 			Body:     body, | ||
|  | 		} | ||
|  | 	} | ||
|  | 
 | ||
|  | 	da := &DeviceAuthResponse{} | ||
|  | 	err = json.Unmarshal(body, &da) | ||
|  | 	if err != nil { | ||
|  | 		return nil, fmt.Errorf("unmarshal %s", err) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	if !da.Expiry.IsZero() { | ||
|  | 		// Make a small adjustment to account for time taken by the request | ||
|  | 		da.Expiry = da.Expiry.Add(-time.Since(t)) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	return da, nil | ||
|  | } | ||
|  | 
 | ||
|  | // DeviceAccessToken polls the server to exchange a device code for a token. | ||
|  | func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) { | ||
|  | 	if !da.Expiry.IsZero() { | ||
|  | 		var cancel context.CancelFunc | ||
|  | 		ctx, cancel = context.WithDeadline(ctx, da.Expiry) | ||
|  | 		defer cancel() | ||
|  | 	} | ||
|  | 
 | ||
|  | 	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4 | ||
|  | 	v := url.Values{ | ||
|  | 		"client_id":   {c.ClientID}, | ||
|  | 		"grant_type":  {"urn:ietf:params:oauth:grant-type:device_code"}, | ||
|  | 		"device_code": {da.DeviceCode}, | ||
|  | 	} | ||
|  | 	if len(c.Scopes) > 0 { | ||
|  | 		v.Set("scope", strings.Join(c.Scopes, " ")) | ||
|  | 	} | ||
|  | 	for _, opt := range opts { | ||
|  | 		opt.setValue(v) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	// "If no value is provided, clients MUST use 5 as the default." | ||
|  | 	// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2 | ||
|  | 	interval := da.Interval | ||
|  | 	if interval == 0 { | ||
|  | 		interval = 5 | ||
|  | 	} | ||
|  | 
 | ||
|  | 	ticker := time.NewTicker(time.Duration(interval) * time.Second) | ||
|  | 	defer ticker.Stop() | ||
|  | 	for { | ||
|  | 		select { | ||
|  | 		case <-ctx.Done(): | ||
|  | 			return nil, ctx.Err() | ||
|  | 		case <-ticker.C: | ||
|  | 			tok, err := retrieveToken(ctx, c, v) | ||
|  | 			if err == nil { | ||
|  | 				return tok, nil | ||
|  | 			} | ||
|  | 
 | ||
|  | 			e, ok := err.(*RetrieveError) | ||
|  | 			if !ok { | ||
|  | 				return nil, err | ||
|  | 			} | ||
|  | 			switch e.ErrorCode { | ||
|  | 			case errSlowDown: | ||
|  | 				// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5 | ||
|  | 				// "the interval MUST be increased by 5 seconds for this and all subsequent requests" | ||
|  | 				interval += 5 | ||
|  | 				ticker.Reset(time.Duration(interval) * time.Second) | ||
|  | 			case errAuthorizationPending: | ||
|  | 				// Do nothing. | ||
|  | 			case errAccessDenied, errExpiredToken: | ||
|  | 				fallthrough | ||
|  | 			default: | ||
|  | 				return tok, err | ||
|  | 			} | ||
|  | 		} | ||
|  | 	} | ||
|  | } |