mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-30 20:32:26 -05:00 
			
		
		
		
	[performance] improved request batching (removes need for queueing) (#1687)
* revamp http client to not limit requests, instead use sender worker Signed-off-by: kim <grufwub@gmail.com> * remove separate sender worker pool, spawn 2*GOMAXPROCS batch senders each time, no need for transport cache sweeping Signed-off-by: kim <grufwub@gmail.com> * improve batch senders to keep popping recipients until remote URL found Signed-off-by: kim <grufwub@gmail.com> * fix recipient looping issue Signed-off-by: kim <grufwub@gmail.com> * fix missing mutex unlock Signed-off-by: kim <grufwub@gmail.com> * move request id ctx key to gtscontext, finish filling out more code comments, add basic support for not logging client IP Signed-off-by: kim <grufwub@gmail.com> * slight code reformatting Signed-off-by: kim <grufwub@gmail.com> * a whitespace Signed-off-by: kim <grufwub@gmail.com> * remove unused code Signed-off-by: kim <grufwub@gmail.com> * add missing license headers Signed-off-by: kim <grufwub@gmail.com> * fix request backoff calculation Signed-off-by: kim <grufwub@gmail.com> --------- Signed-off-by: kim <grufwub@gmail.com>
This commit is contained in:
		
					parent
					
						
							
								6b4f6dc755
							
						
					
				
			
			
				commit
				
					
						6a29c5ffd4
					
				
			
		
					 24 changed files with 431 additions and 493 deletions
				
			
		
							
								
								
									
										2
									
								
								go.mod
									
										
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
										
									
									
									
								
							|  | @ -19,7 +19,6 @@ require ( | |||
| 	github.com/abema/go-mp4 v0.10.1 | ||||
| 	github.com/buckket/go-blurhash v1.1.0 | ||||
| 	github.com/coreos/go-oidc/v3 v3.5.0 | ||||
| 	github.com/cornelk/hashmap v1.0.8 | ||||
| 	github.com/disintegration/imaging v1.6.2 | ||||
| 	github.com/gin-contrib/cors v1.4.0 | ||||
| 	github.com/gin-contrib/gzip v0.0.6 | ||||
|  | @ -82,6 +81,7 @@ require ( | |||
| 	github.com/cilium/ebpf v0.9.1 // indirect | ||||
| 	github.com/containerd/cgroups/v3 v3.0.1 // indirect | ||||
| 	github.com/coreos/go-systemd/v22 v22.3.2 // indirect | ||||
| 	github.com/cornelk/hashmap v1.0.8 // indirect | ||||
| 	github.com/davecgh/go-spew v1.1.1 // indirect | ||||
| 	github.com/docker/go-units v0.4.0 // indirect | ||||
| 	github.com/dsoprea/go-exif/v3 v3.0.0-20210625224831-a6301f85c82b // indirect | ||||
|  |  | |||
|  | @ -24,9 +24,9 @@ import ( | |||
| 	"codeberg.org/gruf/go-kv" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/middleware" | ||||
| ) | ||||
| 
 | ||||
| // TODO: add more templated html pages here for different error types | ||||
|  | @ -51,7 +51,7 @@ func NotFoundHandler(c *gin.Context, instanceGet func(ctx context.Context) (*api | |||
| 
 | ||||
| 		c.HTML(http.StatusNotFound, "404.tmpl", gin.H{ | ||||
| 			"instance":  instance, | ||||
| 			"requestID": middleware.RequestID(ctx), | ||||
| 			"requestID": gtscontext.RequestID(ctx), | ||||
| 		}) | ||||
| 	default: | ||||
| 		c.JSON(http.StatusNotFound, gin.H{ | ||||
|  | @ -76,7 +76,7 @@ func genericErrorHandler(c *gin.Context, instanceGet func(ctx context.Context) ( | |||
| 			"instance":  instance, | ||||
| 			"code":      errWithCode.Code(), | ||||
| 			"error":     errWithCode.Safe(), | ||||
| 			"requestID": middleware.RequestID(ctx), | ||||
| 			"requestID": gtscontext.RequestID(ctx), | ||||
| 		}) | ||||
| 	default: | ||||
| 		c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) | ||||
|  |  | |||
|  | @ -34,10 +34,10 @@ import ( | |||
| 	"github.com/superseriousbusiness/activity/streams/vocab" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/ap" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" | ||||
| ) | ||||
| 
 | ||||
| /* | ||||
|  | @ -216,7 +216,7 @@ func (f *federator) AuthenticateFederatedRequest(ctx context.Context, requestedU | |||
| 		} | ||||
| 
 | ||||
| 		log.Tracef(ctx, "proceeding with dereference for uncached public key %s", requestingPublicKeyID) | ||||
| 		trans, err := f.transportController.NewTransportForUsername(transport.WithFastfail(ctx), requestedUsername) | ||||
| 		trans, err := f.transportController.NewTransportForUsername(gtscontext.SetFastFail(ctx), requestedUsername) | ||||
| 		if err != nil { | ||||
| 			errWithCode := gtserror.NewErrorInternalError(fmt.Errorf("error creating transport for %s: %s", requestedUsername, err)) | ||||
| 			log.Debug(ctx, errWithCode) | ||||
|  |  | |||
|  | @ -29,10 +29,10 @@ import ( | |||
| 	"github.com/superseriousbusiness/activity/streams/vocab" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/ap" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/uris" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/util" | ||||
| ) | ||||
|  | @ -191,9 +191,8 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr | |||
| 			return ctx, false, err | ||||
| 		} | ||||
| 
 | ||||
| 		// We don't yet have an entry for | ||||
| 		// the instance, go dereference it. | ||||
| 		instance, err := f.GetRemoteInstance(transport.WithFastfail(ctx), username, &url.URL{ | ||||
| 		// we don't have an entry for this instance yet so dereference it | ||||
| 		instance, err := f.GetRemoteInstance(gtscontext.SetFastFail(ctx), username, &url.URL{ | ||||
| 			Scheme: publicKeyOwnerURI.Scheme, | ||||
| 			Host:   publicKeyOwnerURI.Host, | ||||
| 		}) | ||||
|  | @ -212,7 +211,7 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr | |||
| 	// dereference the remote account (or just get it | ||||
| 	// from the db if we already have it). | ||||
| 	requestingAccount, err := f.GetAccountByURI( | ||||
| 		transport.WithFastfail(ctx), username, publicKeyOwnerURI, false, | ||||
| 		gtscontext.SetFastFail(ctx), username, publicKeyOwnerURI, false, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		if gtserror.StatusCode(err) == http.StatusGone { | ||||
|  |  | |||
|  | @ -17,7 +17,9 @@ | |||
| 
 | ||||
| package gtscontext | ||||
| 
 | ||||
| import "context" | ||||
| import ( | ||||
| 	"context" | ||||
| ) | ||||
| 
 | ||||
| // package private context key type. | ||||
| type ctxkey uint | ||||
|  | @ -26,8 +28,54 @@ const ( | |||
| 	// context keys. | ||||
| 	_ ctxkey = iota | ||||
| 	barebonesKey | ||||
| 	fastFailKey | ||||
| 	pubKeyIDKey | ||||
| 	requestIDKey | ||||
| ) | ||||
| 
 | ||||
| // RequestID returns the request ID associated with context. This value will usually | ||||
| // be set by the request ID middleware handler, either pulling an existing supplied | ||||
| // value from request headers, or generating a unique new entry. This is useful for | ||||
| // tying together log entries associated with an original incoming request. | ||||
| func RequestID(ctx context.Context) string { | ||||
| 	id, _ := ctx.Value(requestIDKey).(string) | ||||
| 	return id | ||||
| } | ||||
| 
 | ||||
| // SetRequestID stores the given request ID value and returns the wrapped | ||||
| // context. See RequestID() for further information on the request ID value. | ||||
| func SetRequestID(ctx context.Context, id string) context.Context { | ||||
| 	return context.WithValue(ctx, requestIDKey, id) | ||||
| } | ||||
| 
 | ||||
| // PublicKeyID returns the public key ID (URI) associated with context. This | ||||
| // value is useful for logging situations in which a given public key URI is | ||||
| // relevant, e.g. for outgoing requests being signed by the given key. | ||||
| func PublicKeyID(ctx context.Context) string { | ||||
| 	id, _ := ctx.Value(pubKeyIDKey).(string) | ||||
| 	return id | ||||
| } | ||||
| 
 | ||||
| // SetPublicKeyID stores the given public key ID value and returns the wrapped | ||||
| // context. See PublicKeyID() for further information on the public key ID value. | ||||
| func SetPublicKeyID(ctx context.Context, id string) context.Context { | ||||
| 	return context.WithValue(ctx, pubKeyIDKey, id) | ||||
| } | ||||
| 
 | ||||
| // IsFastFail returns whether the "fastfail" context key has been set. This | ||||
| // can be used to indicate to an http client, for example, that the result | ||||
| // of an outgoing request is time sensitive and so not to bother with retries. | ||||
| func IsFastfail(ctx context.Context) bool { | ||||
| 	_, ok := ctx.Value(fastFailKey).(struct{}) | ||||
| 	return ok | ||||
| } | ||||
| 
 | ||||
| // SetFastFail sets the "fastfail" context flag and returns this wrapped context. | ||||
| // See IsFastFail() for further information on the "fastfail" context flag. | ||||
| func SetFastFail(ctx context.Context) context.Context { | ||||
| 	return context.WithValue(ctx, fastFailKey, struct{}{}) | ||||
| } | ||||
| 
 | ||||
| // Barebones returns whether the "barebones" context key has been set. This | ||||
| // can be used to indicate to the database, for example, that only a barebones | ||||
| // model need be returned, Allowing it to skip populating sub models. | ||||
|  | @ -37,7 +85,7 @@ func Barebones(ctx context.Context) bool { | |||
| } | ||||
| 
 | ||||
| // SetBarebones sets the "barebones" context flag and returns this wrapped context. | ||||
| // See Barebones() for further information on the "barebones" context flag.. | ||||
| // See Barebones() for further information on the "barebones" context flag. | ||||
| func SetBarebones(ctx context.Context) context.Context { | ||||
| 	return context.WithValue(ctx, barebonesKey, struct{}{}) | ||||
| } | ||||
|  |  | |||
							
								
								
									
										44
									
								
								internal/gtscontext/log_hooks.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								internal/gtscontext/log_hooks.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,44 @@ | |||
| // GoToSocial | ||||
| // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||
| // SPDX-License-Identifier: AGPL-3.0-or-later | ||||
| // | ||||
| // This program is free software: you can redistribute it and/or modify | ||||
| // it under the terms of the GNU Affero General Public License as published by | ||||
| // the Free Software Foundation, either version 3 of the License, or | ||||
| // (at your option) any later version. | ||||
| // | ||||
| // This program is distributed in the hope that it will be useful, | ||||
| // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
| // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
| // GNU Affero General Public License for more details. | ||||
| // | ||||
| // You should have received a copy of the GNU Affero General Public License | ||||
| // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
| 
 | ||||
| package gtscontext | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 
 | ||||
| 	"codeberg.org/gruf/go-kv" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||
| ) | ||||
| 
 | ||||
| func init() { | ||||
| 	// Add our required logging hooks on application initialization. | ||||
| 	// | ||||
| 	// Request ID middleware hook. | ||||
| 	log.Hook(func(ctx context.Context, kvs []kv.Field) []kv.Field { | ||||
| 		if id := RequestID(ctx); id != "" { | ||||
| 			return append(kvs, kv.Field{K: "requestID", V: id}) | ||||
| 		} | ||||
| 		return kvs | ||||
| 	}) | ||||
| 	// Client IP middleware hook. | ||||
| 	log.Hook(func(ctx context.Context, kvs []kv.Field) []kv.Field { | ||||
| 		if id := PublicKeyID(ctx); id != "" { | ||||
| 			return append(kvs, kv.Field{K: "pubKeyID", V: id}) | ||||
| 		} | ||||
| 		return kvs | ||||
| 	}) | ||||
| } | ||||
|  | @ -18,31 +18,39 @@ | |||
| package httpclient | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/x509" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/netip" | ||||
| 	"runtime" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"codeberg.org/gruf/go-bytesize" | ||||
| 	"codeberg.org/gruf/go-byteutil" | ||||
| 	"codeberg.org/gruf/go-cache/v3" | ||||
| 	errorsv2 "codeberg.org/gruf/go-errors/v2" | ||||
| 	"codeberg.org/gruf/go-kv" | ||||
| 	"github.com/cornelk/hashmap" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||
| ) | ||||
| 
 | ||||
| // ErrInvalidRequest is returned if a given HTTP request is invalid and cannot be performed. | ||||
| var ErrInvalidRequest = errors.New("invalid http request") | ||||
| var ( | ||||
| 	// ErrInvalidNetwork is returned if the request would not be performed over TCP | ||||
| 	ErrInvalidNetwork = errors.New("invalid network type") | ||||
| 
 | ||||
| // ErrInvalidNetwork is returned if the request would not be performed over TCP | ||||
| var ErrInvalidNetwork = errors.New("invalid network type") | ||||
| 	// ErrReservedAddr is returned if a dialed address resolves to an IP within a blocked or reserved net. | ||||
| 	ErrReservedAddr = errors.New("dial within blocked / reserved IP range") | ||||
| 
 | ||||
| // ErrReservedAddr is returned if a dialed address resolves to an IP within a blocked or reserved net. | ||||
| var ErrReservedAddr = errors.New("dial within blocked / reserved IP range") | ||||
| 
 | ||||
| // ErrBodyTooLarge is returned when a received response body is above predefined limit (default 40MB). | ||||
| var ErrBodyTooLarge = errors.New("body size too large") | ||||
| 	// ErrBodyTooLarge is returned when a received response body is above predefined limit (default 40MB). | ||||
| 	ErrBodyTooLarge = errors.New("body size too large") | ||||
| ) | ||||
| 
 | ||||
| // Config provides configuration details for setting up a new | ||||
| // instance of httpclient.Client{}. Within are a subset of the | ||||
|  | @ -83,13 +91,10 @@ type Config struct { | |||
| //     cases to protect against forged / unknown content-lengths | ||||
| //   - protection from server side request forgery (SSRF) by only dialing | ||||
| //     out to known public IP prefixes, configurable with allows/blocks | ||||
| //   - limit number of concurrent requests, else blocking until a slot | ||||
| //     is available (context channels still respected) | ||||
| type Client struct { | ||||
| 	client   http.Client | ||||
| 	queue  *hashmap.Map[string, chan struct{}] | ||||
| 	bmax   int64 // max response body size | ||||
| 	cmax   int   // max open conns per host | ||||
| 	badHosts cache.Cache[string, struct{}] | ||||
| 	bodyMax  int64 | ||||
| } | ||||
| 
 | ||||
| // New returns a new instance of Client initialized using configuration. | ||||
|  | @ -109,28 +114,26 @@ func New(cfg Config) *Client { | |||
| 	} | ||||
| 
 | ||||
| 	if cfg.MaxIdleConns <= 0 { | ||||
| 		// By default base this value on MaxOpenConns | ||||
| 		// By default base this value on MaxOpenConns. | ||||
| 		cfg.MaxIdleConns = cfg.MaxOpenConnsPerHost * 10 | ||||
| 	} | ||||
| 
 | ||||
| 	if cfg.MaxBodySize <= 0 { | ||||
| 		// By default set this to a reasonable 40MB | ||||
| 		// By default set this to a reasonable 40MB. | ||||
| 		cfg.MaxBodySize = int64(40 * bytesize.MiB) | ||||
| 	} | ||||
| 
 | ||||
| 	// Protect dialer with IP range sanitizer | ||||
| 	// Protect dialer with IP range sanitizer. | ||||
| 	d.Control = (&sanitizer{ | ||||
| 		allow: cfg.AllowRanges, | ||||
| 		block: cfg.BlockRanges, | ||||
| 	}).Sanitize | ||||
| 
 | ||||
| 	// Prepare client fields | ||||
| 	// Prepare client fields. | ||||
| 	c.client.Timeout = cfg.Timeout | ||||
| 	c.cmax = cfg.MaxOpenConnsPerHost | ||||
| 	c.bmax = cfg.MaxBodySize | ||||
| 	c.queue = hashmap.New[string, chan struct{}]() | ||||
| 	c.bodyMax = cfg.MaxBodySize | ||||
| 
 | ||||
| 	// Set underlying HTTP client roundtripper | ||||
| 	// Set underlying HTTP client roundtripper. | ||||
| 	c.client.Transport = &http.Transport{ | ||||
| 		Proxy:                 http.ProxyFromEnvironment, | ||||
| 		ForceAttemptHTTP2:     true, | ||||
|  | @ -144,90 +147,185 @@ func New(cfg Config) *Client { | |||
| 		DisableCompression:    cfg.DisableCompression, | ||||
| 	} | ||||
| 
 | ||||
| 	// Initiate outgoing bad hosts lookup cache. | ||||
| 	c.badHosts = cache.New[string, struct{}](0, 1000, 0) | ||||
| 	c.badHosts.SetTTL(15*time.Minute, false) | ||||
| 	if !c.badHosts.Start(time.Minute) { | ||||
| 		log.Panic(nil, "failed to start transport controller cache") | ||||
| 	} | ||||
| 
 | ||||
| 	return &c | ||||
| } | ||||
| 
 | ||||
| // Do will perform given request when an available slot in the queue is available, | ||||
| // and block until this time. For returned values, this follows the same semantics | ||||
| // as the standard http.Client{}.Do() implementation except that response body will | ||||
| // be wrapped by an io.LimitReader() to limit response body sizes. | ||||
| func (c *Client) Do(req *http.Request) (*http.Response, error) { | ||||
| 	// Ensure this is a valid request | ||||
| 	if err := ValidateRequest(req); err != nil { | ||||
| // Do ... | ||||
| func (c *Client) Do(r *http.Request) (*http.Response, error) { | ||||
| 	return c.DoSigned(r, func(r *http.Request) error { | ||||
| 		return nil // no request signing | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| // DoSigned ... | ||||
| func (c *Client) DoSigned(r *http.Request, sign SignFunc) (*http.Response, error) { | ||||
| 	const ( | ||||
| 		// max no. attempts. | ||||
| 		maxRetries = 5 | ||||
| 
 | ||||
| 		// starting backoff duration. | ||||
| 		baseBackoff = 2 * time.Second | ||||
| 	) | ||||
| 
 | ||||
| 	// Get request hostname. | ||||
| 	host := r.URL.Hostname() | ||||
| 
 | ||||
| 	// Check whether request should fast fail. | ||||
| 	fastFail := gtscontext.IsFastfail(r.Context()) | ||||
| 	if !fastFail { | ||||
| 		// Check if recently reached max retries for this host | ||||
| 		// so we don't bother with a retry-backoff loop. The only | ||||
| 		// errors that are retried upon are server failure and | ||||
| 		// domain resolution type errors, so this cached result | ||||
| 		// indicates this server is likely having issues. | ||||
| 		fastFail = c.badHosts.Has(host) | ||||
| 	} | ||||
| 
 | ||||
| 	// Start a log entry for this request | ||||
| 	l := log.WithContext(r.Context()). | ||||
| 		WithFields(kv.Fields{ | ||||
| 			{"method", r.Method}, | ||||
| 			{"url", r.URL.String()}, | ||||
| 		}...) | ||||
| 
 | ||||
| 	for i := 0; i < maxRetries; i++ { | ||||
| 		var backoff time.Duration | ||||
| 
 | ||||
| 		// Reset signing header fields | ||||
| 		now := time.Now().UTC() | ||||
| 		r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT") | ||||
| 		r.Header.Del("Signature") | ||||
| 		r.Header.Del("Digest") | ||||
| 
 | ||||
| 		// Rewind body reader and content-length if set. | ||||
| 		if rc, ok := r.Body.(*byteutil.ReadNopCloser); ok { | ||||
| 			r.ContentLength = int64(rc.Len()) | ||||
| 			rc.Rewind() | ||||
| 		} | ||||
| 
 | ||||
| 		// Sign the outgoing request. | ||||
| 		if err := sign(r); err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 
 | ||||
| 	// Get host's wait queue | ||||
| 	wait := c.wait(req.Host) | ||||
| 		l.Infof("performing request") | ||||
| 
 | ||||
| 	var ok bool | ||||
| 		// Perform the request. | ||||
| 		rsp, err := c.do(r) | ||||
| 		if err == nil { //nolint:gocritic | ||||
| 
 | ||||
| 			// TooManyRequest means we need to slow | ||||
| 			// down and retry our request. Codes over | ||||
| 			// 500 generally indicate temp. outages. | ||||
| 			if code := rsp.StatusCode; code < 500 && | ||||
| 				code != http.StatusTooManyRequests { | ||||
| 				return rsp, nil | ||||
| 			} | ||||
| 
 | ||||
| 			// Generate error from status code for logging | ||||
| 			err = errors.New(`http response "` + rsp.Status + `"`) | ||||
| 
 | ||||
| 			// Search for a provided "Retry-After" header value. | ||||
| 			if after := rsp.Header.Get("Retry-After"); after != "" { | ||||
| 
 | ||||
| 				if u, _ := strconv.ParseUint(after, 10, 32); u != 0 { | ||||
| 					// An integer number of backoff seconds was provided. | ||||
| 					backoff = time.Duration(u) * time.Second | ||||
| 				} else if at, _ := http.ParseTime(after); !at.Before(now) { | ||||
| 					// An HTTP formatted future date-time was provided. | ||||
| 					backoff = at.Sub(now) | ||||
| 				} | ||||
| 
 | ||||
| 				// Don't let their provided backoff exceed our max. | ||||
| 				if max := baseBackoff * maxRetries; backoff > max { | ||||
| 					backoff = max | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 		} else if errorsv2.Is(err, | ||||
| 			context.DeadlineExceeded, | ||||
| 			context.Canceled, | ||||
| 			ErrBodyTooLarge, | ||||
| 			ErrReservedAddr, | ||||
| 		) { | ||||
| 			// Return on non-retryable errors | ||||
| 			return nil, err | ||||
| 		} else if strings.Contains(err.Error(), "stopped after 10 redirects") { | ||||
| 			// Don't bother if net/http returned after too many redirects | ||||
| 			return nil, err | ||||
| 		} else if errors.As(err, &x509.UnknownAuthorityError{}) { | ||||
| 			// Unknown authority errors we do NOT recover from | ||||
| 			return nil, err | ||||
| 		} else if dnserr := (*net.DNSError)(nil); // nocollapse | ||||
| 		errors.As(err, &dnserr) && dnserr.IsNotFound { | ||||
| 			// DNS lookup failure, this domain does not exist | ||||
| 			return nil, gtserror.SetNotFound(err) | ||||
| 		} | ||||
| 
 | ||||
| 		if fastFail { | ||||
| 			// on fast-fail, don't bother backoff/retry | ||||
| 			return nil, fmt.Errorf("%w (fast fail)", err) | ||||
| 		} | ||||
| 
 | ||||
| 		if backoff == 0 { | ||||
| 			// No retry-after found, set our predefined | ||||
| 			// backoff according to a multiplier of 2^n. | ||||
| 			backoff = baseBackoff * 1 << (i + 1) | ||||
| 		} | ||||
| 
 | ||||
| 		l.Errorf("backing off for %s after http request error: %v", backoff, err) | ||||
| 
 | ||||
| 		select { | ||||
| 	// Quickly try grab a spot | ||||
| 	case wait <- struct{}{}: | ||||
| 		// it's our turn! | ||||
| 		ok = true | ||||
| 		// Request ctx cancelled | ||||
| 		case <-r.Context().Done(): | ||||
| 			return nil, r.Context().Err() | ||||
| 
 | ||||
| 		// NOTE: | ||||
| 		// Ideally here we would set the slot release to happen either | ||||
| 		// on error return, or via callback from the response body closer. | ||||
| 		// However when implementing this, there appear deadlocks between | ||||
| 		// the channel queue here and the media manager worker pool. So | ||||
| 		// currently we only place a limit on connections dialing out, but | ||||
| 		// there may still be more connections open than len(c.queue) given | ||||
| 		// that connections may not be closed until response body is closed. | ||||
| 		// The current implementation will reduce the viability of denial of | ||||
| 		// service attacks, but if there are future issues heed this advice :] | ||||
| 		defer func() { <-wait }() | ||||
| 	default: | ||||
| 	} | ||||
| 
 | ||||
| 	if !ok { | ||||
| 		// No spot acquired, log warning | ||||
| 		log.WithContext(req.Context()). | ||||
| 			WithFields(kv.Fields{ | ||||
| 				{K: "queue", V: len(wait)}, | ||||
| 				{K: "method", V: req.Method}, | ||||
| 				{K: "host", V: req.Host}, | ||||
| 				{K: "uri", V: req.URL.RequestURI()}, | ||||
| 			}...).Warn("full request queue") | ||||
| 
 | ||||
| 		select { | ||||
| 		case <-req.Context().Done(): | ||||
| 			// the request was canceled before we | ||||
| 			// got to our turn: no need to release | ||||
| 			return nil, req.Context().Err() | ||||
| 		case wait <- struct{}{}: | ||||
| 			defer func() { <-wait }() | ||||
| 		// Backoff for some time | ||||
| 		case <-time.After(backoff): | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// Perform the HTTP request | ||||
| 	// Add "bad" entry for this host. | ||||
| 	c.badHosts.Set(host, struct{}{}) | ||||
| 
 | ||||
| 	return nil, errors.New("transport reached max retries") | ||||
| } | ||||
| 
 | ||||
| // do ... | ||||
| func (c *Client) do(req *http.Request) (*http.Response, error) { | ||||
| 	// Perform the HTTP request. | ||||
| 	rsp, err := c.client.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	// Check response body not too large | ||||
| 	if rsp.ContentLength > c.bmax { | ||||
| 	// Check response body not too large. | ||||
| 	if rsp.ContentLength > c.bodyMax { | ||||
| 		return nil, ErrBodyTooLarge | ||||
| 	} | ||||
| 
 | ||||
| 	// Seperate the body implementers | ||||
| 	// Seperate the body implementers. | ||||
| 	rbody := (io.Reader)(rsp.Body) | ||||
| 	cbody := (io.Closer)(rsp.Body) | ||||
| 
 | ||||
| 	var limit int64 | ||||
| 
 | ||||
| 	if limit = rsp.ContentLength; limit < 0 { | ||||
| 		// If unknown, use max as reader limit | ||||
| 		limit = c.bmax | ||||
| 		// If unknown, use max as reader limit. | ||||
| 		limit = c.bodyMax | ||||
| 	} | ||||
| 
 | ||||
| 	// Don't trust them, limit body reads | ||||
| 	// Don't trust them, limit body reads. | ||||
| 	rbody = io.LimitReader(rbody, limit) | ||||
| 
 | ||||
| 	// Wrap body with limit | ||||
| 	// Wrap body with limit. | ||||
| 	rsp.Body = &struct { | ||||
| 		io.Reader | ||||
| 		io.Closer | ||||
|  | @ -235,17 +333,3 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { | |||
| 
 | ||||
| 	return rsp, nil | ||||
| } | ||||
| 
 | ||||
| // wait acquires the 'wait' queue for the given host string, or allocates new. | ||||
| func (c *Client) wait(host string) chan struct{} { | ||||
| 	// Look for an existing queue | ||||
| 	queue, ok := c.queue.Get(host) | ||||
| 	if ok { | ||||
| 		return queue | ||||
| 	} | ||||
| 
 | ||||
| 	// Allocate a new host queue (or return a sneaky existing one). | ||||
| 	queue, _ = c.queue.GetOrInsert(host, make(chan struct{}, c.cmax)) | ||||
| 
 | ||||
| 	return queue | ||||
| } | ||||
|  |  | |||
|  | @ -48,14 +48,6 @@ var bodies = []string{ | |||
| 	"body with\r\nnewlines", | ||||
| } | ||||
| 
 | ||||
| // Note: | ||||
| // There is no test for the .MaxOpenConns implementation | ||||
| // in the httpclient.Client{}, due to the difficult to test | ||||
| // this. The block is only held for the actual dial out to | ||||
| // the connection, so the usual test of blocking and holding | ||||
| // open this queue slot to check we can't open another isn't | ||||
| // an easy test here. | ||||
| 
 | ||||
| func TestHTTPClientSmallBody(t *testing.T) { | ||||
| 	for _, body := range bodies { | ||||
| 		_TestHTTPClientWithBody(t, []byte(body), int(^uint16(0))) | ||||
|  |  | |||
|  | @ -1,62 +0,0 @@ | |||
| // GoToSocial | ||||
| // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||
| // SPDX-License-Identifier: AGPL-3.0-or-later | ||||
| // | ||||
| // This program is free software: you can redistribute it and/or modify | ||||
| // it under the terms of the GNU Affero General Public License as published by | ||||
| // the Free Software Foundation, either version 3 of the License, or | ||||
| // (at your option) any later version. | ||||
| // | ||||
| // This program is distributed in the hope that it will be useful, | ||||
| // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
| // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
| // GNU Affero General Public License for more details. | ||||
| // | ||||
| // You should have received a copy of the GNU Affero General Public License | ||||
| // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
| 
 | ||||
| package httpclient | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"golang.org/x/net/http/httpguts" | ||||
| ) | ||||
| 
 | ||||
| // ValidateRequest performs the same request validation logic found in the default | ||||
| // net/http.Transport{}.roundTrip() function, but pulls it out into this separate | ||||
| // function allowing validation errors to be wrapped under a single error type. | ||||
| func ValidateRequest(r *http.Request) error { | ||||
| 	switch { | ||||
| 	case r.URL == nil: | ||||
| 		return fmt.Errorf("%w: nil url", ErrInvalidRequest) | ||||
| 	case r.Header == nil: | ||||
| 		return fmt.Errorf("%w: nil header", ErrInvalidRequest) | ||||
| 	case r.URL.Host == "": | ||||
| 		return fmt.Errorf("%w: empty url host", ErrInvalidRequest) | ||||
| 	case r.URL.Scheme != "http" && r.URL.Scheme != "https": | ||||
| 		return fmt.Errorf("%w: unsupported protocol %q", ErrInvalidRequest, r.URL.Scheme) | ||||
| 	case strings.IndexFunc(r.Method, func(r rune) bool { return !httpguts.IsTokenRune(r) }) != -1: | ||||
| 		return fmt.Errorf("%w: invalid method %q", ErrInvalidRequest, r.Method) | ||||
| 	} | ||||
| 
 | ||||
| 	for key, values := range r.Header { | ||||
| 		// Check field key name is valid | ||||
| 		if !httpguts.ValidHeaderFieldName(key) { | ||||
| 			return fmt.Errorf("%w: invalid header field name %q", ErrInvalidRequest, key) | ||||
| 		} | ||||
| 
 | ||||
| 		// Check each field value is valid | ||||
| 		for i := 0; i < len(values); i++ { | ||||
| 			if !httpguts.ValidHeaderFieldValue(values[i]) { | ||||
| 				return fmt.Errorf("%w: invalid header field value %q", ErrInvalidRequest, values[i]) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// ps. kim wrote this | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
|  | @ -15,19 +15,14 @@ | |||
| // You should have received a copy of the GNU Affero General Public License | ||||
| // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
| 
 | ||||
| package transport_test | ||||
| package httpclient | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"testing" | ||||
| import "net/http" | ||||
| 
 | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" | ||||
| ) | ||||
| // SignFunc is a function signature that provides request signing. | ||||
| type SignFunc func(r *http.Request) error | ||||
| 
 | ||||
| func TestFastFailContext(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	ctx = transport.WithFastfail(ctx) | ||||
| 	if !transport.IsFastfail(ctx) { | ||||
| 		t.Fatal("failed to set fast-fail context key") | ||||
| 	} | ||||
| type SigningClient interface { | ||||
| 	Do(r *http.Request) (*http.Response, error) | ||||
| 	DoSigned(r *http.Request, sign SignFunc) (*http.Response, error) | ||||
| } | ||||
|  | @ -34,7 +34,7 @@ import ( | |||
| func Logger() gin.HandlerFunc { | ||||
| 	return func(c *gin.Context) { | ||||
| 		// Initialize the logging fields | ||||
| 		fields := make(kv.Fields, 6, 7) | ||||
| 		fields := make(kv.Fields, 5, 7) | ||||
| 
 | ||||
| 		// Determine pre-handler time | ||||
| 		before := time.Now() | ||||
|  | @ -68,11 +68,18 @@ func Logger() gin.HandlerFunc { | |||
| 
 | ||||
| 			// Set request logging fields | ||||
| 			fields[0] = kv.Field{"latency", time.Since(before)} | ||||
| 			fields[1] = kv.Field{"clientIP", c.ClientIP()} | ||||
| 			fields[2] = kv.Field{"userAgent", c.Request.UserAgent()} | ||||
| 			fields[3] = kv.Field{"method", c.Request.Method} | ||||
| 			fields[4] = kv.Field{"statusCode", code} | ||||
| 			fields[5] = kv.Field{"path", path} | ||||
| 			fields[1] = kv.Field{"userAgent", c.Request.UserAgent()} | ||||
| 			fields[2] = kv.Field{"method", c.Request.Method} | ||||
| 			fields[3] = kv.Field{"statusCode", code} | ||||
| 			fields[4] = kv.Field{"path", path} | ||||
| 			if includeClientIP := true; includeClientIP { | ||||
| 				// TODO: make this configurable. | ||||
| 				// | ||||
| 				// Include clientIP if enabled. | ||||
| 				fields = append(fields, kv.Field{ | ||||
| 					"clientIP", c.ClientIP(), | ||||
| 				}) | ||||
| 			} | ||||
| 
 | ||||
| 			// Create log entry with fields | ||||
| 			l := log.WithContext(c.Request.Context()). | ||||
|  |  | |||
|  | @ -19,7 +19,6 @@ package middleware | |||
| 
 | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"context" | ||||
| 	"crypto/rand" | ||||
| 	"encoding/base32" | ||||
| 	"encoding/binary" | ||||
|  | @ -27,17 +26,11 @@ import ( | |||
| 	"sync" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"codeberg.org/gruf/go-kv" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||
| ) | ||||
| 
 | ||||
| type ctxType string | ||||
| 
 | ||||
| var ( | ||||
| 	// ridCtxKey is the key underwhich we store request IDs in a context. | ||||
| 	ridCtxKey ctxType = "id" | ||||
| 
 | ||||
| 	// crand provides buffered reads of random input. | ||||
| 	crand = bufio.NewReader(rand.Reader) | ||||
| 	mrand sync.Mutex | ||||
|  | @ -69,22 +62,8 @@ func generateID() string { | |||
| 	return base32enc.EncodeToString(b) | ||||
| } | ||||
| 
 | ||||
| // RequestID fetches the stored request ID from context. | ||||
| func RequestID(ctx context.Context) string { | ||||
| 	id, _ := ctx.Value(ridCtxKey).(string) | ||||
| 	return id | ||||
| } | ||||
| 
 | ||||
| // AddRequestID returns a gin middleware which adds a unique ID to each request (both response header and context). | ||||
| func AddRequestID(header string) gin.HandlerFunc { | ||||
| 	log.Hook(func(ctx context.Context, kvs []kv.Field) []kv.Field { | ||||
| 		if id, _ := ctx.Value(ridCtxKey).(string); id != "" { | ||||
| 			// Add stored request ID to log entry fields. | ||||
| 			return append(kvs, kv.Field{K: "requestID", V: id}) | ||||
| 		} | ||||
| 		return kvs | ||||
| 	}) | ||||
| 
 | ||||
| 	return func(c *gin.Context) { | ||||
| 		// Look for existing ID. | ||||
| 		id := c.GetHeader(header) | ||||
|  | @ -100,8 +79,8 @@ func AddRequestID(header string) gin.HandlerFunc { | |||
| 			c.Request.Header.Set(header, id) | ||||
| 		} | ||||
| 
 | ||||
| 		// Store request ID in new request ctx and set new gin request obj. | ||||
| 		ctx := context.WithValue(c.Request.Context(), ridCtxKey, id) | ||||
| 		// Store request ID in new request context and set on gin ctx. | ||||
| 		ctx := gtscontext.SetRequestID(c.Request.Context(), id) | ||||
| 		c.Request = c.Request.WithContext(ctx) | ||||
| 
 | ||||
| 		// Set the request ID in the rsp header. | ||||
|  |  | |||
|  | @ -25,9 +25,9 @@ import ( | |||
| 
 | ||||
| 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" | ||||
| ) | ||||
| 
 | ||||
| // Get processes the given request for account information. | ||||
|  | @ -96,7 +96,7 @@ func (p *Processor) getFor(ctx context.Context, requestingAccount *gtsmodel.Acco | |||
| 		} | ||||
| 
 | ||||
| 		a, err := p.federator.GetAccountByURI( | ||||
| 			transport.WithFastfail(ctx), requestingAccount.Username, targetAccountURI, true, | ||||
| 			gtscontext.SetFastFail(ctx), requestingAccount.Username, targetAccountURI, true, | ||||
| 		) | ||||
| 		if err == nil { | ||||
| 			targetAccount = a | ||||
|  |  | |||
|  | @ -22,9 +22,9 @@ import ( | |||
| 	"fmt" | ||||
| 	"net/url" | ||||
| 
 | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" | ||||
| ) | ||||
| 
 | ||||
| func (p *Processor) authenticate(ctx context.Context, requestedUsername string) (requestedAccount, requestingAccount *gtsmodel.Account, errWithCode gtserror.WithCode) { | ||||
|  | @ -40,7 +40,7 @@ func (p *Processor) authenticate(ctx context.Context, requestedUsername string) | |||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if requestingAccount, err = p.federator.GetAccountByURI(transport.WithFastfail(ctx), requestedUsername, requestingAccountURI, false); err != nil { | ||||
| 	if requestingAccount, err = p.federator.GetAccountByURI(gtscontext.SetFastFail(ctx), requestedUsername, requestingAccountURI, false); err != nil { | ||||
| 		errWithCode = gtserror.NewErrorUnauthorized(err) | ||||
| 		return | ||||
| 	} | ||||
|  |  | |||
|  | @ -24,8 +24,8 @@ import ( | |||
| 
 | ||||
| 	"github.com/superseriousbusiness/activity/streams" | ||||
| 	"github.com/superseriousbusiness/activity/streams/vocab" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/uris" | ||||
| ) | ||||
| 
 | ||||
|  | @ -56,7 +56,7 @@ func (p *Processor) UserGet(ctx context.Context, requestedUsername string, reque | |||
| 		// if we're not already handshaking/dereferencing a remote account, dereference it now | ||||
| 		if !p.federator.Handshaking(requestedUsername, requestingAccountURI) { | ||||
| 			requestingAccount, err := p.federator.GetAccountByURI( | ||||
| 				transport.WithFastfail(ctx), requestedUsername, requestingAccountURI, false, | ||||
| 				gtscontext.SetFastFail(ctx), requestedUsername, requestingAccountURI, false, | ||||
| 			) | ||||
| 			if err != nil { | ||||
| 				return nil, gtserror.NewErrorUnauthorized(err) | ||||
|  |  | |||
|  | @ -25,10 +25,10 @@ import ( | |||
| 	"strings" | ||||
| 
 | ||||
| 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/media" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/uris" | ||||
| ) | ||||
| 
 | ||||
|  | @ -157,7 +157,7 @@ func (p *Processor) getAttachmentContent(ctx context.Context, requestingAccount | |||
| 			if err != nil { | ||||
| 				return nil, 0, err | ||||
| 			} | ||||
| 			return t.DereferenceMedia(transport.WithFastfail(innerCtx), remoteMediaIRI) | ||||
| 			return t.DereferenceMedia(gtscontext.SetFastFail(innerCtx), remoteMediaIRI) | ||||
| 		} | ||||
| 
 | ||||
| 		// Start recaching this media with the prepared data function. | ||||
|  |  | |||
|  | @ -30,11 +30,11 @@ import ( | |||
| 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/util" | ||||
| ) | ||||
| 
 | ||||
|  | @ -226,14 +226,14 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a | |||
| } | ||||
| 
 | ||||
| func (p *Processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL) (*gtsmodel.Status, error) { | ||||
| 	status, statusable, err := p.federator.GetStatus(transport.WithFastfail(ctx), authed.Account.Username, uri, true, true) | ||||
| 	status, statusable, err := p.federator.GetStatus(gtscontext.SetFastFail(ctx), authed.Account.Username, uri, true, true) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	if !*status.Local && statusable != nil { | ||||
| 		// Attempt to dereference the status thread while we are here | ||||
| 		p.federator.DereferenceThread(transport.WithFastfail(ctx), authed.Account.Username, uri, status, statusable) | ||||
| 		p.federator.DereferenceThread(gtscontext.SetFastFail(ctx), authed.Account.Username, uri, status, statusable) | ||||
| 	} | ||||
| 
 | ||||
| 	return status, nil | ||||
|  | @ -268,7 +268,7 @@ func (p *Processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth, | |||
| 	} | ||||
| 
 | ||||
| 	return p.federator.GetAccountByURI( | ||||
| 		transport.WithFastfail(ctx), | ||||
| 		gtscontext.SetFastFail(ctx), | ||||
| 		authed.Account.Username, | ||||
| 		uri, false, | ||||
| 	) | ||||
|  | @ -295,7 +295,7 @@ func (p *Processor) searchAccountByUsernameDomain(ctx context.Context, authed *o | |||
| 	} | ||||
| 
 | ||||
| 	return p.federator.GetAccountByUsernameDomain( | ||||
| 		transport.WithFastfail(ctx), | ||||
| 		gtscontext.SetFastFail(ctx), | ||||
| 		authed.Account.Username, | ||||
| 		username, domain, false, | ||||
| 	) | ||||
|  |  | |||
|  | @ -24,9 +24,9 @@ import ( | |||
| 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/federation" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/id" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/util" | ||||
| ) | ||||
| 
 | ||||
|  | @ -58,7 +58,7 @@ func GetParseMentionFunc(dbConn db.DB, federator federation.Federator) gtsmodel. | |||
| 			} | ||||
| 
 | ||||
| 			remoteAccount, err := federator.GetAccountByUsernameDomain( | ||||
| 				transport.WithFastfail(ctx), | ||||
| 				gtscontext.SetFastFail(ctx), | ||||
| 				requestingUsername, | ||||
| 				username, | ||||
| 				domain, | ||||
|  |  | |||
|  | @ -1,42 +0,0 @@ | |||
| // GoToSocial | ||||
| // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||
| // SPDX-License-Identifier: AGPL-3.0-or-later | ||||
| // | ||||
| // This program is free software: you can redistribute it and/or modify | ||||
| // it under the terms of the GNU Affero General Public License as published by | ||||
| // the Free Software Foundation, either version 3 of the License, or | ||||
| // (at your option) any later version. | ||||
| // | ||||
| // This program is distributed in the hope that it will be useful, | ||||
| // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
| // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
| // GNU Affero General Public License for more details. | ||||
| // | ||||
| // You should have received a copy of the GNU Affero General Public License | ||||
| // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
| 
 | ||||
| package transport | ||||
| 
 | ||||
| import "context" | ||||
| 
 | ||||
| // ctxkey is our own unique context key type to prevent setting outside package. | ||||
| type ctxkey string | ||||
| 
 | ||||
| // fastfailkey is our unique context key to indicate fast-fail is enabled. | ||||
| var fastfailkey = ctxkey("ff") | ||||
| 
 | ||||
| // WithFastfail returns a Context which indicates that any http requests made | ||||
| // with it should return after the first failed attempt, instead of retrying. | ||||
| // | ||||
| // This can be used to fail quickly when you're making an outgoing http request | ||||
| // inside the context of an incoming http request, and you want to be able to | ||||
| // provide a snappy response to the user, instead of retrying + backing off. | ||||
| func WithFastfail(parent context.Context) context.Context { | ||||
| 	return context.WithValue(parent, fastfailkey, struct{}{}) | ||||
| } | ||||
| 
 | ||||
| // IsFastfail returns true if the given context was created by WithFastfail. | ||||
| func IsFastfail(ctx context.Context) bool { | ||||
| 	_, ok := ctx.Value(fastfailkey).(struct{}) | ||||
| 	return ok | ||||
| } | ||||
|  | @ -24,7 +24,7 @@ import ( | |||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net/url" | ||||
| 	"time" | ||||
| 	"runtime" | ||||
| 
 | ||||
| 	"codeberg.org/gruf/go-byteutil" | ||||
| 	"codeberg.org/gruf/go-cache/v3" | ||||
|  | @ -32,7 +32,7 @@ import ( | |||
| 	"github.com/superseriousbusiness/activity/streams" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/httpclient" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/state" | ||||
| ) | ||||
| 
 | ||||
|  | @ -49,14 +49,14 @@ type controller struct { | |||
| 	state     *state.State | ||||
| 	fedDB     federatingdb.DB | ||||
| 	clock     pub.Clock | ||||
| 	client    pub.HttpClient | ||||
| 	client    httpclient.SigningClient | ||||
| 	trspCache cache.Cache[string, *transport] | ||||
| 	badHosts  cache.Cache[string, struct{}] | ||||
| 	userAgent string | ||||
| 	senders   int // no. concurrent batch delivery routines. | ||||
| } | ||||
| 
 | ||||
| // NewController returns an implementation of the Controller interface for creating new transports | ||||
| func NewController(state *state.State, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller { | ||||
| func NewController(state *state.State, federatingDB federatingdb.DB, clock pub.Clock, client httpclient.SigningClient) Controller { | ||||
| 	applicationName := config.GetApplicationName() | ||||
| 	host := config.GetHost() | ||||
| 	proto := config.GetProtocol() | ||||
|  | @ -68,20 +68,8 @@ func NewController(state *state.State, federatingDB federatingdb.DB, clock pub.C | |||
| 		clock:     clock, | ||||
| 		client:    client, | ||||
| 		trspCache: cache.New[string, *transport](0, 100, 0), | ||||
| 		badHosts:  cache.New[string, struct{}](0, 1000, 0), | ||||
| 		userAgent: fmt.Sprintf("%s (+%s://%s) gotosocial/%s", applicationName, proto, host, version), | ||||
| 	} | ||||
| 
 | ||||
| 	// Transport cache has TTL=1hr freq=1min | ||||
| 	c.trspCache.SetTTL(time.Hour, false) | ||||
| 	if !c.trspCache.Start(time.Minute) { | ||||
| 		log.Panic(nil, "failed to start transport controller cache") | ||||
| 	} | ||||
| 
 | ||||
| 	// Bad hosts cache has TTL=15min freq=1min | ||||
| 	c.badHosts.SetTTL(15*time.Minute, false) | ||||
| 	if !c.badHosts.Start(time.Minute) { | ||||
| 		log.Panic(nil, "failed to start transport controller cache") | ||||
| 		senders:   runtime.GOMAXPROCS(0), // on batch delivery, only ever send GOMAXPROCS at a time. | ||||
| 	} | ||||
| 
 | ||||
| 	return c | ||||
|  |  | |||
|  | @ -22,7 +22,6 @@ import ( | |||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 
 | ||||
| 	"codeberg.org/gruf/go-byteutil" | ||||
|  | @ -32,54 +31,90 @@ import ( | |||
| ) | ||||
| 
 | ||||
| func (t *transport) BatchDeliver(ctx context.Context, b []byte, recipients []*url.URL) error { | ||||
| 	// concurrently deliver to recipients; for each delivery, buffer the error if it fails | ||||
| 	wg := sync.WaitGroup{} | ||||
| 	errCh := make(chan error, len(recipients)) | ||||
| 	for _, recipient := range recipients { | ||||
| 		wg.Add(1) | ||||
| 		go func(r *url.URL) { | ||||
| 			defer wg.Done() | ||||
| 			if err := t.Deliver(ctx, b, r); err != nil { | ||||
| 				errCh <- err | ||||
| 			} | ||||
| 		}(recipient) | ||||
| 	} | ||||
| 	var ( | ||||
| 		// errs accumulates errors received during | ||||
| 		// attempted delivery by deliverer routines. | ||||
| 		errs gtserror.MultiError | ||||
| 
 | ||||
| 	// wait until all deliveries have succeeded or failed | ||||
| 	wg.Wait() | ||||
| 		// wait blocks until all sender | ||||
| 		// routines have returned. | ||||
| 		wait sync.WaitGroup | ||||
| 
 | ||||
| 		// mutex protects 'recipients' and | ||||
| 		// 'errs' for concurrent access. | ||||
| 		mutex sync.Mutex | ||||
| 
 | ||||
| 		// Get current instance host info. | ||||
| 		domain = config.GetAccountDomain() | ||||
| 		host   = config.GetHost() | ||||
| 	) | ||||
| 
 | ||||
| 	// Block on expect no. senders. | ||||
| 	wait.Add(t.controller.senders) | ||||
| 
 | ||||
| 	for i := 0; i < t.controller.senders; i++ { | ||||
| 		go func() { | ||||
| 			// Mark returned. | ||||
| 			defer wait.Done() | ||||
| 
 | ||||
| 	// receive any buffered errors | ||||
| 	errs := make([]string, 0, len(errCh)) | ||||
| outer: | ||||
| 			for { | ||||
| 		select { | ||||
| 		case e := <-errCh: | ||||
| 			errs = append(errs, e.Error()) | ||||
| 		default: | ||||
| 			break outer | ||||
| 		} | ||||
| 				// Acquire lock. | ||||
| 				mutex.Lock() | ||||
| 
 | ||||
| 				if len(recipients) == 0 { | ||||
| 					// Reached end. | ||||
| 					mutex.Unlock() | ||||
| 					return | ||||
| 				} | ||||
| 
 | ||||
| 	if len(errs) > 0 { | ||||
| 		return fmt.Errorf("BatchDeliver: at least one failure: %s", strings.Join(errs, "; ")) | ||||
| 				// Pop next recipient. | ||||
| 				i := len(recipients) - 1 | ||||
| 				to := recipients[i] | ||||
| 				recipients = recipients[:i] | ||||
| 
 | ||||
| 				// Done with lock. | ||||
| 				mutex.Unlock() | ||||
| 
 | ||||
| 				// Skip delivery to recipient if it is "us". | ||||
| 				if to.Host == host || to.Host == domain { | ||||
| 					continue | ||||
| 				} | ||||
| 
 | ||||
| 	return nil | ||||
| 				// Attempt to deliver data to recipient. | ||||
| 				if err := t.deliver(ctx, b, to); err != nil { | ||||
| 					mutex.Lock() // safely append err to accumulator. | ||||
| 					errs.Appendf("error delivering to %s: %v", to, err) | ||||
| 					mutex.Unlock() | ||||
| 				} | ||||
| 			} | ||||
| 		}() | ||||
| 	} | ||||
| 
 | ||||
| 	// Wait for finish. | ||||
| 	wait.Wait() | ||||
| 
 | ||||
| 	// Return combined err. | ||||
| 	return errs.Combine() | ||||
| } | ||||
| 
 | ||||
| func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error { | ||||
| 	// if the 'to' host is our own, just skip this delivery since we by definition already have the message! | ||||
| 	// if 'to' host is our own, skip as we don't need to deliver to ourselves... | ||||
| 	if to.Host == config.GetHost() || to.Host == config.GetAccountDomain() { | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	urlStr := to.String() | ||||
| 	// Deliver data to recipient. | ||||
| 	return t.deliver(ctx, b, to) | ||||
| } | ||||
| 
 | ||||
| func (t *transport) deliver(ctx context.Context, b []byte, to *url.URL) error { | ||||
| 	url := to.String() | ||||
| 
 | ||||
| 	// Use rewindable bytes reader for body. | ||||
| 	var body byteutil.ReadNopCloser | ||||
| 	body.Reset(b) | ||||
| 
 | ||||
| 	req, err := http.NewRequestWithContext(ctx, "POST", urlStr, &body) | ||||
| 	req, err := http.NewRequestWithContext(ctx, "POST", url, &body) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | @ -88,16 +123,16 @@ func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error { | |||
| 	req.Header.Add("Accept-Charset", "utf-8") | ||||
| 	req.Header.Set("Host", to.Host) | ||||
| 
 | ||||
| 	resp, err := t.POST(req, b) | ||||
| 	rsp, err := t.POST(req, b) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
| 	defer rsp.Body.Close() | ||||
| 
 | ||||
| 	if code := resp.StatusCode; code != http.StatusOK && | ||||
| 	if code := rsp.StatusCode; code != http.StatusOK && | ||||
| 		code != http.StatusCreated && code != http.StatusAccepted { | ||||
| 		err := fmt.Errorf("POST request to %s failed: %s", urlStr, resp.Status) | ||||
| 		return gtserror.WithStatusCode(err, resp.StatusCode) | ||||
| 		err := fmt.Errorf("POST request to %s failed: %s", url, rsp.Status) | ||||
| 		return gtserror.WithStatusCode(err, rsp.StatusCode) | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
|  |  | |||
|  | @ -20,26 +20,17 @@ package transport | |||
| import ( | ||||
| 	"context" | ||||
| 	"crypto" | ||||
| 	"crypto/x509" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"codeberg.org/gruf/go-byteutil" | ||||
| 	errorsv2 "codeberg.org/gruf/go-errors/v2" | ||||
| 	"codeberg.org/gruf/go-kv" | ||||
| 	"github.com/go-fed/httpsig" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/httpclient" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||
| ) | ||||
| 
 | ||||
| // Transport implements the pub.Transport interface with some additional functionality for fetching remote media. | ||||
|  | @ -78,7 +69,7 @@ type Transport interface { | |||
| 	Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) | ||||
| } | ||||
| 
 | ||||
| // transport implements the Transport interface | ||||
| // transport implements the Transport interface. | ||||
| type transport struct { | ||||
| 	controller *controller | ||||
| 	pubKeyID   string | ||||
|  | @ -95,9 +86,11 @@ func (t *transport) GET(r *http.Request) (*http.Response, error) { | |||
| 	if r.Method != http.MethodGet { | ||||
| 		return nil, errors.New("must be GET request") | ||||
| 	} | ||||
| 	return t.do(r, func(r *http.Request) error { | ||||
| 		return t.signGET(r) | ||||
| 	}) | ||||
| 	ctx := r.Context() // extract, set pubkey ID. | ||||
| 	ctx = gtscontext.SetPublicKeyID(ctx, t.pubKeyID) | ||||
| 	r = r.WithContext(ctx) // replace request ctx. | ||||
| 	r.Header.Set("User-Agent", t.controller.userAgent) | ||||
| 	return t.controller.client.DoSigned(r, t.signGET()) | ||||
| } | ||||
| 
 | ||||
| // POST will perform given http request using transport client, retrying on certain preset errors. | ||||
|  | @ -105,161 +98,31 @@ func (t *transport) POST(r *http.Request, body []byte) (*http.Response, error) { | |||
| 	if r.Method != http.MethodPost { | ||||
| 		return nil, errors.New("must be POST request") | ||||
| 	} | ||||
| 	return t.do(r, func(r *http.Request) error { | ||||
| 		return t.signPOST(r, body) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (t *transport) do(r *http.Request, signer func(*http.Request) error) (*http.Response, error) { | ||||
| 	const ( | ||||
| 		// max no. attempts | ||||
| 		maxRetries = 5 | ||||
| 
 | ||||
| 		// starting backoff duration. | ||||
| 		baseBackoff = 2 * time.Second | ||||
| 	) | ||||
| 
 | ||||
| 	// Get request hostname | ||||
| 	host := r.URL.Hostname() | ||||
| 
 | ||||
| 	// Check whether request should fast fail, we check this | ||||
| 	// before loop as each context.Value() requires mutex lock. | ||||
| 	fastFail := IsFastfail(r.Context()) | ||||
| 	if !fastFail { | ||||
| 		// Check if recently reached max retries for this host | ||||
| 		// so we don't bother with a retry-backoff loop. The only | ||||
| 		// errors that are retried upon are server failure and | ||||
| 		// domain resolution type errors, so this cached result | ||||
| 		// indicates this server is likely having issues. | ||||
| 		fastFail = t.controller.badHosts.Has(host) | ||||
| 	} | ||||
| 
 | ||||
| 	// Start a log entry for this request | ||||
| 	l := log.WithContext(r.Context()). | ||||
| 		WithFields(kv.Fields{ | ||||
| 			{"pubKeyID", t.pubKeyID}, | ||||
| 			{"method", r.Method}, | ||||
| 			{"url", r.URL.String()}, | ||||
| 		}...) | ||||
| 
 | ||||
| 	ctx := r.Context() // extract, set pubkey ID. | ||||
| 	ctx = gtscontext.SetPublicKeyID(ctx, t.pubKeyID) | ||||
| 	r = r.WithContext(ctx) // replace request ctx. | ||||
| 	r.Header.Set("User-Agent", t.controller.userAgent) | ||||
| 
 | ||||
| 	for i := 0; i < maxRetries; i++ { | ||||
| 		var backoff time.Duration | ||||
| 
 | ||||
| 		// Reset signing header fields | ||||
| 		now := t.controller.clock.Now().UTC() | ||||
| 		r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT") | ||||
| 		r.Header.Del("Signature") | ||||
| 		r.Header.Del("Digest") | ||||
| 
 | ||||
| 		// Rewind body reader and content-length if set. | ||||
| 		if rc, ok := r.Body.(*byteutil.ReadNopCloser); ok { | ||||
| 			r.ContentLength = int64(rc.Len()) | ||||
| 			rc.Rewind() | ||||
| 		} | ||||
| 
 | ||||
| 		// Perform request signing | ||||
| 		if err := signer(r); err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 
 | ||||
| 		l.Infof("performing request") | ||||
| 
 | ||||
| 		// Attempt to perform request | ||||
| 		rsp, err := t.controller.client.Do(r) | ||||
| 		if err == nil { //nolint:gocritic | ||||
| 			// TooManyRequest means we need to slow | ||||
| 			// down and retry our request. Codes over | ||||
| 			// 500 generally indicate temp. outages. | ||||
| 			if code := rsp.StatusCode; code < 500 && | ||||
| 				code != http.StatusTooManyRequests { | ||||
| 				return rsp, nil | ||||
| 			} | ||||
| 
 | ||||
| 			// Generate error from status code for logging | ||||
| 			err = errors.New(`http response "` + rsp.Status + `"`) | ||||
| 
 | ||||
| 			// Search for a provided "Retry-After" header value. | ||||
| 			if after := rsp.Header.Get("Retry-After"); after != "" { | ||||
| 
 | ||||
| 				if u, _ := strconv.ParseUint(after, 10, 32); u != 0 { | ||||
| 					// An integer number of backoff seconds was provided. | ||||
| 					backoff = time.Duration(u) * time.Second | ||||
| 				} else if at, _ := http.ParseTime(after); !at.Before(now) { | ||||
| 					// An HTTP formatted future date-time was provided. | ||||
| 					backoff = at.Sub(now) | ||||
| 				} | ||||
| 
 | ||||
| 				// Don't let their provided backoff exceed our max. | ||||
| 				if max := baseBackoff * maxRetries; backoff > max { | ||||
| 					backoff = max | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 		} else if errorsv2.Is(err, | ||||
| 			context.DeadlineExceeded, | ||||
| 			context.Canceled, | ||||
| 			httpclient.ErrInvalidRequest, | ||||
| 			httpclient.ErrBodyTooLarge, | ||||
| 			httpclient.ErrReservedAddr, | ||||
| 		) { | ||||
| 			// Return on non-retryable errors | ||||
| 			return nil, err | ||||
| 		} else if strings.Contains(err.Error(), "stopped after 10 redirects") { | ||||
| 			// Don't bother if net/http returned after too many redirects | ||||
| 			return nil, err | ||||
| 		} else if errors.As(err, &x509.UnknownAuthorityError{}) { | ||||
| 			// Unknown authority errors we do NOT recover from | ||||
| 			return nil, err | ||||
| 		} else if dnserr := (*net.DNSError)(nil); // nocollapse | ||||
| 		errors.As(err, &dnserr) && dnserr.IsNotFound { | ||||
| 			// DNS lookup failure, this domain does not exist | ||||
| 			return nil, gtserror.SetNotFound(err) | ||||
| 		} | ||||
| 
 | ||||
| 		if fastFail { | ||||
| 			// on fast-fail, don't bother backoff/retry | ||||
| 			return nil, fmt.Errorf("%w (fast fail)", err) | ||||
| 		} | ||||
| 
 | ||||
| 		if backoff == 0 { | ||||
| 			// No retry-after found, set our predefined backoff. | ||||
| 			backoff = time.Duration(i) * baseBackoff | ||||
| 		} | ||||
| 
 | ||||
| 		l.Errorf("backing off for %s after http request error: %v", backoff, err) | ||||
| 
 | ||||
| 		select { | ||||
| 		// Request ctx cancelled | ||||
| 		case <-r.Context().Done(): | ||||
| 			return nil, r.Context().Err() | ||||
| 
 | ||||
| 		// Backoff for some time | ||||
| 		case <-time.After(backoff): | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// Add "bad" entry for this host. | ||||
| 	t.controller.badHosts.Set(host, struct{}{}) | ||||
| 
 | ||||
| 	return nil, errors.New("transport reached max retries") | ||||
| 	return t.controller.client.DoSigned(r, t.signPOST(body)) | ||||
| } | ||||
| 
 | ||||
| // signGET will safely sign an HTTP GET request. | ||||
| func (t *transport) signGET(r *http.Request) (err error) { | ||||
| func (t *transport) signGET() httpclient.SignFunc { | ||||
| 	return func(r *http.Request) (err error) { | ||||
| 		t.safesign(func() { | ||||
| 			err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil) | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // signPOST will safely sign an HTTP POST request for given body. | ||||
| func (t *transport) signPOST(r *http.Request, body []byte) (err error) { | ||||
| func (t *transport) signPOST(body []byte) httpclient.SignFunc { | ||||
| 	return func(r *http.Request) (err error) { | ||||
| 		t.safesign(func() { | ||||
| 			err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body) | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // safesign will perform sign function within mutex protection, | ||||
|  |  | |||
|  | @ -31,8 +31,12 @@ type Workers struct { | |||
| 	// Main task scheduler instance. | ||||
| 	Scheduler sched.Scheduler | ||||
| 
 | ||||
| 	// ClientAPI / federator worker pools. | ||||
| 	// ClientAPI provides a worker pool that handles both | ||||
| 	// incoming client actions, and our own side-effects. | ||||
| 	ClientAPI runners.WorkerPool | ||||
| 
 | ||||
| 	// Federator provides a worker pool that handles both | ||||
| 	// incoming federated actions, and our own side-effects. | ||||
| 	Federator runners.WorkerPool | ||||
| 
 | ||||
| 	// Enqueue functions for clientAPI / federator worker pools, | ||||
|  |  | |||
|  | @ -26,12 +26,12 @@ import ( | |||
| 	"strings" | ||||
| 	"sync" | ||||
| 
 | ||||
| 	"github.com/superseriousbusiness/activity/pub" | ||||
| 	"github.com/superseriousbusiness/activity/streams" | ||||
| 	"github.com/superseriousbusiness/activity/streams/vocab" | ||||
| 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/federation" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/httpclient" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/state" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" | ||||
|  | @ -51,7 +51,7 @@ const ( | |||
| // Unlike the other test interfaces provided in this package, you'll probably want to call this function | ||||
| // PER TEST rather than per suite, so that the do function can be set on a test by test (or even more granular) | ||||
| // basis. | ||||
| func NewTestTransportController(state *state.State, client pub.HttpClient) transport.Controller { | ||||
| func NewTestTransportController(state *state.State, client httpclient.SigningClient) transport.Controller { | ||||
| 	return transport.NewController(state, NewTestFederatingDB(state), &federation.Clock{}, client) | ||||
| } | ||||
| 
 | ||||
|  | @ -225,6 +225,10 @@ func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) { | |||
| 	return m.do(req) | ||||
| } | ||||
| 
 | ||||
| func (m *MockHTTPClient) DoSigned(req *http.Request, sign httpclient.SignFunc) (*http.Response, error) { | ||||
| 	return m.do(req) | ||||
| } | ||||
| 
 | ||||
| func HostMetaResponse(req *http.Request) (responseCode int, responseBytes []byte, responseContentType string, responseContentLength int) { | ||||
| 	var hm *apimodel.HostMeta | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue