mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-11-04 02:42:24 -06:00 
			
		
		
		
	
		
			
	
	
		
			199 lines
		
	
	
	
		
			5.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			199 lines
		
	
	
	
		
			5.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| 
								 | 
							
								package oauth2
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import (
							 | 
						||
| 
								 | 
							
									"context"
							 | 
						||
| 
								 | 
							
									"encoding/json"
							 | 
						||
| 
								 | 
							
									"errors"
							 | 
						||
| 
								 | 
							
									"fmt"
							 | 
						||
| 
								 | 
							
									"io"
							 | 
						||
| 
								 | 
							
									"net/http"
							 | 
						||
| 
								 | 
							
									"net/url"
							 | 
						||
| 
								 | 
							
									"strings"
							 | 
						||
| 
								 | 
							
									"time"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									"golang.org/x/oauth2/internal"
							 | 
						||
| 
								 | 
							
								)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Error codes a token endpoint can answer with while a device code is being
								// polled, as listed in
								// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
								const (
									// errAccessDenied ends polling in DeviceAccessToken; the error is returned.
									errAccessDenied = "access_denied"
									// errExpiredToken ends polling in DeviceAccessToken; the error is returned.
									errExpiredToken = "expired_token"
									// errAuthorizationPending lets DeviceAccessToken keep polling unchanged.
									errAuthorizationPending = "authorization_pending"
									// errSlowDown makes DeviceAccessToken add 5 seconds to its polling interval.
									errSlowDown = "slow_down"
								)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
								// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
								//
								// In JSON the lifetime travels as a relative "expires_in" number of seconds;
								// MarshalJSON and UnmarshalJSON convert between that number and the absolute
								// Expiry field.
								type DeviceAuthResponse struct {
									// DeviceCode is the code DeviceAccessToken sends as "device_code" when polling
									DeviceCode string `json:"device_code"`
									// UserCode is the code the user should enter at the verification uri
									UserCode string `json:"user_code"`
									// VerificationURI is where user should enter the user code
									VerificationURI string `json:"verification_uri"`
									// VerificationURIComplete (if populated) includes the user code in the
									// verification URI. This is typically shown to the user in non-textual
									// form, such as a QR code.
									VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
									// Expiry is when the device code and user code expire
									Expiry time.Time `json:"expires_in,omitempty"`
									// Interval is the duration in seconds that DeviceAccessToken should wait
									// between polling requests
									Interval int64 `json:"interval,omitempty"`
								}

								// MarshalJSON encodes d with Expiry written as "expires_in": the whole number
								// of seconds left until Expiry at the moment of the call. The field is omitted
								// when Expiry is zero or less than one second away.
								func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
									// Alias has the fields of DeviceAuthResponse but none of its methods, so
									// encoding it does not call MarshalJSON again.
									type Alias DeviceAuthResponse
									var expiresIn int64
									if !d.Expiry.IsZero() {
										expiresIn = int64(time.Until(d.Expiry).Seconds())
									}
									// The outer ExpiresIn is shallower than Alias.Expiry, so it is the one
									// encoding/json uses for the shared "expires_in" key.
									return json.Marshal(&struct {
										ExpiresIn int64 `json:"expires_in,omitempty"`
										*Alias
									}{
										ExpiresIn: expiresIn,
										Alias:     (*Alias)(&d),
									})
								}

								// UnmarshalJSON decodes a device authorization response into d.
								//
								// A non-zero "expires_in" sets Expiry to the current UTC time plus that many
								// seconds. If "verification_uri" is absent or empty, the misspelled
								// "verification_url" is used instead. A JSON null leaves d unchanged.
								func (d *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
									// Alias has the fields of DeviceAuthResponse but none of its methods, so
									// decoding into it does not call UnmarshalJSON again.
									type Alias DeviceAuthResponse
									aux := &struct {
										ExpiresIn int64 `json:"expires_in"`
										// workaround misspelling of verification_uri
										VerificationURL string `json:"verification_url"`
										*Alias
									}{
										Alias: (*Alias)(d),
									}
									// aux is already a pointer. Passing &aux would let a JSON null reset aux
									// to nil, and the field reads below would then panic.
									if err := json.Unmarshal(data, aux); err != nil {
										return err
									}
									if aux.ExpiresIn != 0 {
										d.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
									}
									if d.VerificationURI == "" {
										d.VerificationURI = aux.VerificationURL
									}
									return nil
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// DeviceAuth returns a device auth struct which contains a device code
							 | 
						||
| 
								 | 
							
								// and authorization information provided for users to enter on another device.
							 | 
						||
| 
								 | 
							
								func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
							 | 
						||
| 
								 | 
							
									// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
							 | 
						||
| 
								 | 
							
									v := url.Values{
							 | 
						||
| 
								 | 
							
										"client_id": {c.ClientID},
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									if len(c.Scopes) > 0 {
							 | 
						||
| 
								 | 
							
										v.Set("scope", strings.Join(c.Scopes, " "))
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									for _, opt := range opts {
							 | 
						||
| 
								 | 
							
										opt.setValue(v)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									return retrieveDeviceAuth(ctx, c, v)
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
							 | 
						||
| 
								 | 
							
									if c.Endpoint.DeviceAuthURL == "" {
							 | 
						||
| 
								 | 
							
										return nil, errors.New("endpoint missing DeviceAuthURL")
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return nil, err
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
							 | 
						||
| 
								 | 
							
									req.Header.Set("Accept", "application/json")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									t := time.Now()
							 | 
						||
| 
								 | 
							
									r, err := internal.ContextClient(ctx).Do(req)
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return nil, err
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									if code := r.StatusCode; code < 200 || code > 299 {
							 | 
						||
| 
								 | 
							
										return nil, &RetrieveError{
							 | 
						||
| 
								 | 
							
											Response: r,
							 | 
						||
| 
								 | 
							
											Body:     body,
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									da := &DeviceAuthResponse{}
							 | 
						||
| 
								 | 
							
									err = json.Unmarshal(body, &da)
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return nil, fmt.Errorf("unmarshal %s", err)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									if !da.Expiry.IsZero() {
							 | 
						||
| 
								 | 
							
										// Make a small adjustment to account for time taken by the request
							 | 
						||
| 
								 | 
							
										da.Expiry = da.Expiry.Add(-time.Since(t))
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									return da, nil
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// DeviceAccessToken polls the server to exchange a device code for a token.
							 | 
						||
| 
								 | 
							
								func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
							 | 
						||
| 
								 | 
							
									if !da.Expiry.IsZero() {
							 | 
						||
| 
								 | 
							
										var cancel context.CancelFunc
							 | 
						||
| 
								 | 
							
										ctx, cancel = context.WithDeadline(ctx, da.Expiry)
							 | 
						||
| 
								 | 
							
										defer cancel()
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
							 | 
						||
| 
								 | 
							
									v := url.Values{
							 | 
						||
| 
								 | 
							
										"client_id":   {c.ClientID},
							 | 
						||
| 
								 | 
							
										"grant_type":  {"urn:ietf:params:oauth:grant-type:device_code"},
							 | 
						||
| 
								 | 
							
										"device_code": {da.DeviceCode},
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									if len(c.Scopes) > 0 {
							 | 
						||
| 
								 | 
							
										v.Set("scope", strings.Join(c.Scopes, " "))
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									for _, opt := range opts {
							 | 
						||
| 
								 | 
							
										opt.setValue(v)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									// "If no value is provided, clients MUST use 5 as the default."
							 | 
						||
| 
								 | 
							
									// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
							 | 
						||
| 
								 | 
							
									interval := da.Interval
							 | 
						||
| 
								 | 
							
									if interval == 0 {
							 | 
						||
| 
								 | 
							
										interval = 5
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									ticker := time.NewTicker(time.Duration(interval) * time.Second)
							 | 
						||
| 
								 | 
							
									defer ticker.Stop()
							 | 
						||
| 
								 | 
							
									for {
							 | 
						||
| 
								 | 
							
										select {
							 | 
						||
| 
								 | 
							
										case <-ctx.Done():
							 | 
						||
| 
								 | 
							
											return nil, ctx.Err()
							 | 
						||
| 
								 | 
							
										case <-ticker.C:
							 | 
						||
| 
								 | 
							
											tok, err := retrieveToken(ctx, c, v)
							 | 
						||
| 
								 | 
							
											if err == nil {
							 | 
						||
| 
								 | 
							
												return tok, nil
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
											e, ok := err.(*RetrieveError)
							 | 
						||
| 
								 | 
							
											if !ok {
							 | 
						||
| 
								 | 
							
												return nil, err
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
											switch e.ErrorCode {
							 | 
						||
| 
								 | 
							
											case errSlowDown:
							 | 
						||
| 
								 | 
							
												// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
							 | 
						||
| 
								 | 
							
												// "the interval MUST be increased by 5 seconds for this and all subsequent requests"
							 | 
						||
| 
								 | 
							
												interval += 5
							 | 
						||
| 
								 | 
							
												ticker.Reset(time.Duration(interval) * time.Second)
							 | 
						||
| 
								 | 
							
											case errAuthorizationPending:
							 | 
						||
| 
								 | 
							
												// Do nothing.
							 | 
						||
| 
								 | 
							
											case errAccessDenied, errExpiredToken:
							 | 
						||
| 
								 | 
							
												fallthrough
							 | 
						||
| 
								 | 
							
											default:
							 | 
						||
| 
								 | 
							
												return tok, err
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								}
							 |