mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-30 23:22:26 -05:00 
			
		
		
		
	[performance] improved request batching (removes need for queueing) (#1687)
* revamp http client to not limit requests, instead use sender worker Signed-off-by: kim <grufwub@gmail.com> * remove separate sender worker pool, spawn 2*GOMAXPROCS batch senders each time, no need for transport cache sweeping Signed-off-by: kim <grufwub@gmail.com> * improve batch senders to keep popping recipients until remote URL found Signed-off-by: kim <grufwub@gmail.com> * fix recipient looping issue Signed-off-by: kim <grufwub@gmail.com> * fix missing mutex unlock Signed-off-by: kim <grufwub@gmail.com> * move request id ctx key to gtscontext, finish filling out more code comments, add basic support for not logging client IP Signed-off-by: kim <grufwub@gmail.com> * slight code reformatting Signed-off-by: kim <grufwub@gmail.com> * a whitespace Signed-off-by: kim <grufwub@gmail.com> * remove unused code Signed-off-by: kim <grufwub@gmail.com> * add missing license headers Signed-off-by: kim <grufwub@gmail.com> * fix request backoff calculation Signed-off-by: kim <grufwub@gmail.com> --------- Signed-off-by: kim <grufwub@gmail.com>
This commit is contained in:
		
					parent
					
						
							
								6b4f6dc755
							
						
					
				
			
			
				commit
				
					
						6a29c5ffd4
					
				
			
		
					 24 changed files with 431 additions and 493 deletions
				
			
		
							
								
								
									
										2
									
								
								go.mod
									
										
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
										
									
									
									
								
							|  | @ -19,7 +19,6 @@ require ( | ||||||
| 	github.com/abema/go-mp4 v0.10.1 | 	github.com/abema/go-mp4 v0.10.1 | ||||||
| 	github.com/buckket/go-blurhash v1.1.0 | 	github.com/buckket/go-blurhash v1.1.0 | ||||||
| 	github.com/coreos/go-oidc/v3 v3.5.0 | 	github.com/coreos/go-oidc/v3 v3.5.0 | ||||||
| 	github.com/cornelk/hashmap v1.0.8 |  | ||||||
| 	github.com/disintegration/imaging v1.6.2 | 	github.com/disintegration/imaging v1.6.2 | ||||||
| 	github.com/gin-contrib/cors v1.4.0 | 	github.com/gin-contrib/cors v1.4.0 | ||||||
| 	github.com/gin-contrib/gzip v0.0.6 | 	github.com/gin-contrib/gzip v0.0.6 | ||||||
|  | @ -82,6 +81,7 @@ require ( | ||||||
| 	github.com/cilium/ebpf v0.9.1 // indirect | 	github.com/cilium/ebpf v0.9.1 // indirect | ||||||
| 	github.com/containerd/cgroups/v3 v3.0.1 // indirect | 	github.com/containerd/cgroups/v3 v3.0.1 // indirect | ||||||
| 	github.com/coreos/go-systemd/v22 v22.3.2 // indirect | 	github.com/coreos/go-systemd/v22 v22.3.2 // indirect | ||||||
|  | 	github.com/cornelk/hashmap v1.0.8 // indirect | ||||||
| 	github.com/davecgh/go-spew v1.1.1 // indirect | 	github.com/davecgh/go-spew v1.1.1 // indirect | ||||||
| 	github.com/docker/go-units v0.4.0 // indirect | 	github.com/docker/go-units v0.4.0 // indirect | ||||||
| 	github.com/dsoprea/go-exif/v3 v3.0.0-20210625224831-a6301f85c82b // indirect | 	github.com/dsoprea/go-exif/v3 v3.0.0-20210625224831-a6301f85c82b // indirect | ||||||
|  |  | ||||||
|  | @ -24,9 +24,9 @@ import ( | ||||||
| 	"codeberg.org/gruf/go-kv" | 	"codeberg.org/gruf/go-kv" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/middleware" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // TODO: add more templated html pages here for different error types | // TODO: add more templated html pages here for different error types | ||||||
|  | @ -51,7 +51,7 @@ func NotFoundHandler(c *gin.Context, instanceGet func(ctx context.Context) (*api | ||||||
| 
 | 
 | ||||||
| 		c.HTML(http.StatusNotFound, "404.tmpl", gin.H{ | 		c.HTML(http.StatusNotFound, "404.tmpl", gin.H{ | ||||||
| 			"instance":  instance, | 			"instance":  instance, | ||||||
| 			"requestID": middleware.RequestID(ctx), | 			"requestID": gtscontext.RequestID(ctx), | ||||||
| 		}) | 		}) | ||||||
| 	default: | 	default: | ||||||
| 		c.JSON(http.StatusNotFound, gin.H{ | 		c.JSON(http.StatusNotFound, gin.H{ | ||||||
|  | @ -76,7 +76,7 @@ func genericErrorHandler(c *gin.Context, instanceGet func(ctx context.Context) ( | ||||||
| 			"instance":  instance, | 			"instance":  instance, | ||||||
| 			"code":      errWithCode.Code(), | 			"code":      errWithCode.Code(), | ||||||
| 			"error":     errWithCode.Safe(), | 			"error":     errWithCode.Safe(), | ||||||
| 			"requestID": middleware.RequestID(ctx), | 			"requestID": gtscontext.RequestID(ctx), | ||||||
| 		}) | 		}) | ||||||
| 	default: | 	default: | ||||||
| 		c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) | 		c.JSON(errWithCode.Code(), gin.H{"error": errWithCode.Safe()}) | ||||||
|  |  | ||||||
|  | @ -34,10 +34,10 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/activity/streams/vocab" | 	"github.com/superseriousbusiness/activity/streams/vocab" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/ap" | 	"github.com/superseriousbusiness/gotosocial/internal/ap" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/config" | 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| /* | /* | ||||||
|  | @ -216,7 +216,7 @@ func (f *federator) AuthenticateFederatedRequest(ctx context.Context, requestedU | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		log.Tracef(ctx, "proceeding with dereference for uncached public key %s", requestingPublicKeyID) | 		log.Tracef(ctx, "proceeding with dereference for uncached public key %s", requestingPublicKeyID) | ||||||
| 		trans, err := f.transportController.NewTransportForUsername(transport.WithFastfail(ctx), requestedUsername) | 		trans, err := f.transportController.NewTransportForUsername(gtscontext.SetFastFail(ctx), requestedUsername) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			errWithCode := gtserror.NewErrorInternalError(fmt.Errorf("error creating transport for %s: %s", requestedUsername, err)) | 			errWithCode := gtserror.NewErrorInternalError(fmt.Errorf("error creating transport for %s: %s", requestedUsername, err)) | ||||||
| 			log.Debug(ctx, errWithCode) | 			log.Debug(ctx, errWithCode) | ||||||
|  |  | ||||||
|  | @ -29,10 +29,10 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/activity/streams/vocab" | 	"github.com/superseriousbusiness/activity/streams/vocab" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/ap" | 	"github.com/superseriousbusiness/gotosocial/internal/ap" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/uris" | 	"github.com/superseriousbusiness/gotosocial/internal/uris" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/util" | 	"github.com/superseriousbusiness/gotosocial/internal/util" | ||||||
| ) | ) | ||||||
|  | @ -191,9 +191,8 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr | ||||||
| 			return ctx, false, err | 			return ctx, false, err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// We don't yet have an entry for | 		// we don't have an entry for this instance yet so dereference it | ||||||
| 		// the instance, go dereference it. | 		instance, err := f.GetRemoteInstance(gtscontext.SetFastFail(ctx), username, &url.URL{ | ||||||
| 		instance, err := f.GetRemoteInstance(transport.WithFastfail(ctx), username, &url.URL{ |  | ||||||
| 			Scheme: publicKeyOwnerURI.Scheme, | 			Scheme: publicKeyOwnerURI.Scheme, | ||||||
| 			Host:   publicKeyOwnerURI.Host, | 			Host:   publicKeyOwnerURI.Host, | ||||||
| 		}) | 		}) | ||||||
|  | @ -212,7 +211,7 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr | ||||||
| 	// dereference the remote account (or just get it | 	// dereference the remote account (or just get it | ||||||
| 	// from the db if we already have it). | 	// from the db if we already have it). | ||||||
| 	requestingAccount, err := f.GetAccountByURI( | 	requestingAccount, err := f.GetAccountByURI( | ||||||
| 		transport.WithFastfail(ctx), username, publicKeyOwnerURI, false, | 		gtscontext.SetFastFail(ctx), username, publicKeyOwnerURI, false, | ||||||
| 	) | 	) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		if gtserror.StatusCode(err) == http.StatusGone { | 		if gtserror.StatusCode(err) == http.StatusGone { | ||||||
|  |  | ||||||
|  | @ -17,7 +17,9 @@ | ||||||
| 
 | 
 | ||||||
| package gtscontext | package gtscontext | ||||||
| 
 | 
 | ||||||
| import "context" | import ( | ||||||
|  | 	"context" | ||||||
|  | ) | ||||||
| 
 | 
 | ||||||
| // package private context key type. | // package private context key type. | ||||||
| type ctxkey uint | type ctxkey uint | ||||||
|  | @ -26,8 +28,54 @@ const ( | ||||||
| 	// context keys. | 	// context keys. | ||||||
| 	_ ctxkey = iota | 	_ ctxkey = iota | ||||||
| 	barebonesKey | 	barebonesKey | ||||||
|  | 	fastFailKey | ||||||
|  | 	pubKeyIDKey | ||||||
|  | 	requestIDKey | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // RequestID returns the request ID associated with context. This value will usually | ||||||
|  | // be set by the request ID middleware handler, either pulling an existing supplied | ||||||
|  | // value from request headers, or generating a unique new entry. This is useful for | ||||||
|  | // tying together log entries associated with an original incoming request. | ||||||
|  | func RequestID(ctx context.Context) string { | ||||||
|  | 	id, _ := ctx.Value(requestIDKey).(string) | ||||||
|  | 	return id | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SetRequestID stores the given request ID value and returns the wrapped | ||||||
|  | // context. See RequestID() for further information on the request ID value. | ||||||
|  | func SetRequestID(ctx context.Context, id string) context.Context { | ||||||
|  | 	return context.WithValue(ctx, requestIDKey, id) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // PublicKeyID returns the public key ID (URI) associated with context. This | ||||||
|  | // value is useful for logging situations in which a given public key URI is | ||||||
|  | // relevant, e.g. for outgoing requests being signed by the given key. | ||||||
|  | func PublicKeyID(ctx context.Context) string { | ||||||
|  | 	id, _ := ctx.Value(pubKeyIDKey).(string) | ||||||
|  | 	return id | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SetPublicKeyID stores the given public key ID value and returns the wrapped | ||||||
|  | // context. See PublicKeyID() for further information on the public key ID value. | ||||||
|  | func SetPublicKeyID(ctx context.Context, id string) context.Context { | ||||||
|  | 	return context.WithValue(ctx, pubKeyIDKey, id) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // IsFastFail returns whether the "fastfail" context key has been set. This | ||||||
|  | // can be used to indicate to an http client, for example, that the result | ||||||
|  | // of an outgoing request is time sensitive and so not to bother with retries. | ||||||
|  | func IsFastfail(ctx context.Context) bool { | ||||||
|  | 	_, ok := ctx.Value(fastFailKey).(struct{}) | ||||||
|  | 	return ok | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SetFastFail sets the "fastfail" context flag and returns this wrapped context. | ||||||
|  | // See IsFastFail() for further information on the "fastfail" context flag. | ||||||
|  | func SetFastFail(ctx context.Context) context.Context { | ||||||
|  | 	return context.WithValue(ctx, fastFailKey, struct{}{}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // Barebones returns whether the "barebones" context key has been set. This | // Barebones returns whether the "barebones" context key has been set. This | ||||||
| // can be used to indicate to the database, for example, that only a barebones | // can be used to indicate to the database, for example, that only a barebones | ||||||
| // model need be returned, Allowing it to skip populating sub models. | // model need be returned, Allowing it to skip populating sub models. | ||||||
|  | @ -37,7 +85,7 @@ func Barebones(ctx context.Context) bool { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // SetBarebones sets the "barebones" context flag and returns this wrapped context. | // SetBarebones sets the "barebones" context flag and returns this wrapped context. | ||||||
| // See Barebones() for further information on the "barebones" context flag.. | // See Barebones() for further information on the "barebones" context flag. | ||||||
| func SetBarebones(ctx context.Context) context.Context { | func SetBarebones(ctx context.Context) context.Context { | ||||||
| 	return context.WithValue(ctx, barebonesKey, struct{}{}) | 	return context.WithValue(ctx, barebonesKey, struct{}{}) | ||||||
| } | } | ||||||
|  |  | ||||||
							
								
								
									
										44
									
								
								internal/gtscontext/log_hooks.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								internal/gtscontext/log_hooks.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,44 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package gtscontext | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 
 | ||||||
|  | 	"codeberg.org/gruf/go-kv" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func init() { | ||||||
|  | 	// Add our required logging hooks on application initialization. | ||||||
|  | 	// | ||||||
|  | 	// Request ID middleware hook. | ||||||
|  | 	log.Hook(func(ctx context.Context, kvs []kv.Field) []kv.Field { | ||||||
|  | 		if id := RequestID(ctx); id != "" { | ||||||
|  | 			return append(kvs, kv.Field{K: "requestID", V: id}) | ||||||
|  | 		} | ||||||
|  | 		return kvs | ||||||
|  | 	}) | ||||||
|  | 	// Client IP middleware hook. | ||||||
|  | 	log.Hook(func(ctx context.Context, kvs []kv.Field) []kv.Field { | ||||||
|  | 		if id := PublicKeyID(ctx); id != "" { | ||||||
|  | 			return append(kvs, kv.Field{K: "pubKeyID", V: id}) | ||||||
|  | 		} | ||||||
|  | 		return kvs | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | @ -18,31 +18,39 @@ | ||||||
| package httpclient | package httpclient | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"crypto/x509" | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
| 	"runtime" | 	"runtime" | ||||||
|  | 	"strconv" | ||||||
|  | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"codeberg.org/gruf/go-bytesize" | 	"codeberg.org/gruf/go-bytesize" | ||||||
|  | 	"codeberg.org/gruf/go-byteutil" | ||||||
|  | 	"codeberg.org/gruf/go-cache/v3" | ||||||
|  | 	errorsv2 "codeberg.org/gruf/go-errors/v2" | ||||||
| 	"codeberg.org/gruf/go-kv" | 	"codeberg.org/gruf/go-kv" | ||||||
| 	"github.com/cornelk/hashmap" | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // ErrInvalidRequest is returned if a given HTTP request is invalid and cannot be performed. | var ( | ||||||
| var ErrInvalidRequest = errors.New("invalid http request") | 	// ErrInvalidNetwork is returned if the request would not be performed over TCP | ||||||
|  | 	ErrInvalidNetwork = errors.New("invalid network type") | ||||||
| 
 | 
 | ||||||
| // ErrInvalidNetwork is returned if the request would not be performed over TCP | 	// ErrReservedAddr is returned if a dialed address resolves to an IP within a blocked or reserved net. | ||||||
| var ErrInvalidNetwork = errors.New("invalid network type") | 	ErrReservedAddr = errors.New("dial within blocked / reserved IP range") | ||||||
| 
 | 
 | ||||||
| // ErrReservedAddr is returned if a dialed address resolves to an IP within a blocked or reserved net. | 	// ErrBodyTooLarge is returned when a received response body is above predefined limit (default 40MB). | ||||||
| var ErrReservedAddr = errors.New("dial within blocked / reserved IP range") | 	ErrBodyTooLarge = errors.New("body size too large") | ||||||
| 
 | ) | ||||||
| // ErrBodyTooLarge is returned when a received response body is above predefined limit (default 40MB). |  | ||||||
| var ErrBodyTooLarge = errors.New("body size too large") |  | ||||||
| 
 | 
 | ||||||
| // Config provides configuration details for setting up a new | // Config provides configuration details for setting up a new | ||||||
| // instance of httpclient.Client{}. Within are a subset of the | // instance of httpclient.Client{}. Within are a subset of the | ||||||
|  | @ -83,13 +91,10 @@ type Config struct { | ||||||
| //     cases to protect against forged / unknown content-lengths | //     cases to protect against forged / unknown content-lengths | ||||||
| //   - protection from server side request forgery (SSRF) by only dialing | //   - protection from server side request forgery (SSRF) by only dialing | ||||||
| //     out to known public IP prefixes, configurable with allows/blocks | //     out to known public IP prefixes, configurable with allows/blocks | ||||||
| //   - limit number of concurrent requests, else blocking until a slot |  | ||||||
| //     is available (context channels still respected) |  | ||||||
| type Client struct { | type Client struct { | ||||||
| 	client http.Client | 	client   http.Client | ||||||
| 	queue  *hashmap.Map[string, chan struct{}] | 	badHosts cache.Cache[string, struct{}] | ||||||
| 	bmax   int64 // max response body size | 	bodyMax  int64 | ||||||
| 	cmax   int   // max open conns per host |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // New returns a new instance of Client initialized using configuration. | // New returns a new instance of Client initialized using configuration. | ||||||
|  | @ -109,28 +114,26 @@ func New(cfg Config) *Client { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if cfg.MaxIdleConns <= 0 { | 	if cfg.MaxIdleConns <= 0 { | ||||||
| 		// By default base this value on MaxOpenConns | 		// By default base this value on MaxOpenConns. | ||||||
| 		cfg.MaxIdleConns = cfg.MaxOpenConnsPerHost * 10 | 		cfg.MaxIdleConns = cfg.MaxOpenConnsPerHost * 10 | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if cfg.MaxBodySize <= 0 { | 	if cfg.MaxBodySize <= 0 { | ||||||
| 		// By default set this to a reasonable 40MB | 		// By default set this to a reasonable 40MB. | ||||||
| 		cfg.MaxBodySize = int64(40 * bytesize.MiB) | 		cfg.MaxBodySize = int64(40 * bytesize.MiB) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Protect dialer with IP range sanitizer | 	// Protect dialer with IP range sanitizer. | ||||||
| 	d.Control = (&sanitizer{ | 	d.Control = (&sanitizer{ | ||||||
| 		allow: cfg.AllowRanges, | 		allow: cfg.AllowRanges, | ||||||
| 		block: cfg.BlockRanges, | 		block: cfg.BlockRanges, | ||||||
| 	}).Sanitize | 	}).Sanitize | ||||||
| 
 | 
 | ||||||
| 	// Prepare client fields | 	// Prepare client fields. | ||||||
| 	c.client.Timeout = cfg.Timeout | 	c.client.Timeout = cfg.Timeout | ||||||
| 	c.cmax = cfg.MaxOpenConnsPerHost | 	c.bodyMax = cfg.MaxBodySize | ||||||
| 	c.bmax = cfg.MaxBodySize |  | ||||||
| 	c.queue = hashmap.New[string, chan struct{}]() |  | ||||||
| 
 | 
 | ||||||
| 	// Set underlying HTTP client roundtripper | 	// Set underlying HTTP client roundtripper. | ||||||
| 	c.client.Transport = &http.Transport{ | 	c.client.Transport = &http.Transport{ | ||||||
| 		Proxy:                 http.ProxyFromEnvironment, | 		Proxy:                 http.ProxyFromEnvironment, | ||||||
| 		ForceAttemptHTTP2:     true, | 		ForceAttemptHTTP2:     true, | ||||||
|  | @ -144,90 +147,185 @@ func New(cfg Config) *Client { | ||||||
| 		DisableCompression:    cfg.DisableCompression, | 		DisableCompression:    cfg.DisableCompression, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Initiate outgoing bad hosts lookup cache. | ||||||
|  | 	c.badHosts = cache.New[string, struct{}](0, 1000, 0) | ||||||
|  | 	c.badHosts.SetTTL(15*time.Minute, false) | ||||||
|  | 	if !c.badHosts.Start(time.Minute) { | ||||||
|  | 		log.Panic(nil, "failed to start transport controller cache") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	return &c | 	return &c | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Do will perform given request when an available slot in the queue is available, | // Do ... | ||||||
| // and block until this time. For returned values, this follows the same semantics | func (c *Client) Do(r *http.Request) (*http.Response, error) { | ||||||
| // as the standard http.Client{}.Do() implementation except that response body will | 	return c.DoSigned(r, func(r *http.Request) error { | ||||||
| // be wrapped by an io.LimitReader() to limit response body sizes. | 		return nil // no request signing | ||||||
| func (c *Client) Do(req *http.Request) (*http.Response, error) { | 	}) | ||||||
| 	// Ensure this is a valid request | } | ||||||
| 	if err := ValidateRequest(req); err != nil { | 
 | ||||||
| 		return nil, err | // DoSigned ... | ||||||
|  | func (c *Client) DoSigned(r *http.Request, sign SignFunc) (*http.Response, error) { | ||||||
|  | 	const ( | ||||||
|  | 		// max no. attempts. | ||||||
|  | 		maxRetries = 5 | ||||||
|  | 
 | ||||||
|  | 		// starting backoff duration. | ||||||
|  | 		baseBackoff = 2 * time.Second | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	// Get request hostname. | ||||||
|  | 	host := r.URL.Hostname() | ||||||
|  | 
 | ||||||
|  | 	// Check whether request should fast fail. | ||||||
|  | 	fastFail := gtscontext.IsFastfail(r.Context()) | ||||||
|  | 	if !fastFail { | ||||||
|  | 		// Check if recently reached max retries for this host | ||||||
|  | 		// so we don't bother with a retry-backoff loop. The only | ||||||
|  | 		// errors that are retried upon are server failure and | ||||||
|  | 		// domain resolution type errors, so this cached result | ||||||
|  | 		// indicates this server is likely having issues. | ||||||
|  | 		fastFail = c.badHosts.Has(host) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Get host's wait queue | 	// Start a log entry for this request | ||||||
| 	wait := c.wait(req.Host) | 	l := log.WithContext(r.Context()). | ||||||
|  | 		WithFields(kv.Fields{ | ||||||
|  | 			{"method", r.Method}, | ||||||
|  | 			{"url", r.URL.String()}, | ||||||
|  | 		}...) | ||||||
| 
 | 
 | ||||||
| 	var ok bool | 	for i := 0; i < maxRetries; i++ { | ||||||
|  | 		var backoff time.Duration | ||||||
| 
 | 
 | ||||||
| 	select { | 		// Reset signing header fields | ||||||
| 	// Quickly try grab a spot | 		now := time.Now().UTC() | ||||||
| 	case wait <- struct{}{}: | 		r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT") | ||||||
| 		// it's our turn! | 		r.Header.Del("Signature") | ||||||
| 		ok = true | 		r.Header.Del("Digest") | ||||||
| 
 | 
 | ||||||
| 		// NOTE: | 		// Rewind body reader and content-length if set. | ||||||
| 		// Ideally here we would set the slot release to happen either | 		if rc, ok := r.Body.(*byteutil.ReadNopCloser); ok { | ||||||
| 		// on error return, or via callback from the response body closer. | 			r.ContentLength = int64(rc.Len()) | ||||||
| 		// However when implementing this, there appear deadlocks between | 			rc.Rewind() | ||||||
| 		// the channel queue here and the media manager worker pool. So | 		} | ||||||
| 		// currently we only place a limit on connections dialing out, but |  | ||||||
| 		// there may still be more connections open than len(c.queue) given |  | ||||||
| 		// that connections may not be closed until response body is closed. |  | ||||||
| 		// The current implementation will reduce the viability of denial of |  | ||||||
| 		// service attacks, but if there are future issues heed this advice :] |  | ||||||
| 		defer func() { <-wait }() |  | ||||||
| 	default: |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	if !ok { | 		// Sign the outgoing request. | ||||||
| 		// No spot acquired, log warning | 		if err := sign(r); err != nil { | ||||||
| 		log.WithContext(req.Context()). | 			return nil, err | ||||||
| 			WithFields(kv.Fields{ | 		} | ||||||
| 				{K: "queue", V: len(wait)}, | 
 | ||||||
| 				{K: "method", V: req.Method}, | 		l.Infof("performing request") | ||||||
| 				{K: "host", V: req.Host}, | 
 | ||||||
| 				{K: "uri", V: req.URL.RequestURI()}, | 		// Perform the request. | ||||||
| 			}...).Warn("full request queue") | 		rsp, err := c.do(r) | ||||||
|  | 		if err == nil { //nolint:gocritic | ||||||
|  | 
 | ||||||
|  | 			// TooManyRequest means we need to slow | ||||||
|  | 			// down and retry our request. Codes over | ||||||
|  | 			// 500 generally indicate temp. outages. | ||||||
|  | 			if code := rsp.StatusCode; code < 500 && | ||||||
|  | 				code != http.StatusTooManyRequests { | ||||||
|  | 				return rsp, nil | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Generate error from status code for logging | ||||||
|  | 			err = errors.New(`http response "` + rsp.Status + `"`) | ||||||
|  | 
 | ||||||
|  | 			// Search for a provided "Retry-After" header value. | ||||||
|  | 			if after := rsp.Header.Get("Retry-After"); after != "" { | ||||||
|  | 
 | ||||||
|  | 				if u, _ := strconv.ParseUint(after, 10, 32); u != 0 { | ||||||
|  | 					// An integer number of backoff seconds was provided. | ||||||
|  | 					backoff = time.Duration(u) * time.Second | ||||||
|  | 				} else if at, _ := http.ParseTime(after); !at.Before(now) { | ||||||
|  | 					// An HTTP formatted future date-time was provided. | ||||||
|  | 					backoff = at.Sub(now) | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				// Don't let their provided backoff exceed our max. | ||||||
|  | 				if max := baseBackoff * maxRetries; backoff > max { | ||||||
|  | 					backoff = max | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 		} else if errorsv2.Is(err, | ||||||
|  | 			context.DeadlineExceeded, | ||||||
|  | 			context.Canceled, | ||||||
|  | 			ErrBodyTooLarge, | ||||||
|  | 			ErrReservedAddr, | ||||||
|  | 		) { | ||||||
|  | 			// Return on non-retryable errors | ||||||
|  | 			return nil, err | ||||||
|  | 		} else if strings.Contains(err.Error(), "stopped after 10 redirects") { | ||||||
|  | 			// Don't bother if net/http returned after too many redirects | ||||||
|  | 			return nil, err | ||||||
|  | 		} else if errors.As(err, &x509.UnknownAuthorityError{}) { | ||||||
|  | 			// Unknown authority errors we do NOT recover from | ||||||
|  | 			return nil, err | ||||||
|  | 		} else if dnserr := (*net.DNSError)(nil); // nocollapse | ||||||
|  | 		errors.As(err, &dnserr) && dnserr.IsNotFound { | ||||||
|  | 			// DNS lookup failure, this domain does not exist | ||||||
|  | 			return nil, gtserror.SetNotFound(err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if fastFail { | ||||||
|  | 			// on fast-fail, don't bother backoff/retry | ||||||
|  | 			return nil, fmt.Errorf("%w (fast fail)", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if backoff == 0 { | ||||||
|  | 			// No retry-after found, set our predefined | ||||||
|  | 			// backoff according to a multiplier of 2^n. | ||||||
|  | 			backoff = baseBackoff * 1 << (i + 1) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		l.Errorf("backing off for %s after http request error: %v", backoff, err) | ||||||
| 
 | 
 | ||||||
| 		select { | 		select { | ||||||
| 		case <-req.Context().Done(): | 		// Request ctx cancelled | ||||||
| 			// the request was canceled before we | 		case <-r.Context().Done(): | ||||||
| 			// got to our turn: no need to release | 			return nil, r.Context().Err() | ||||||
| 			return nil, req.Context().Err() | 
 | ||||||
| 		case wait <- struct{}{}: | 		// Backoff for some time | ||||||
| 			defer func() { <-wait }() | 		case <-time.After(backoff): | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Perform the HTTP request | 	// Add "bad" entry for this host. | ||||||
|  | 	c.badHosts.Set(host, struct{}{}) | ||||||
|  | 
 | ||||||
|  | 	return nil, errors.New("transport reached max retries") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // do ... | ||||||
|  | func (c *Client) do(req *http.Request) (*http.Response, error) { | ||||||
|  | 	// Perform the HTTP request. | ||||||
| 	rsp, err := c.client.Do(req) | 	rsp, err := c.client.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Check response body not too large | 	// Check response body not too large. | ||||||
| 	if rsp.ContentLength > c.bmax { | 	if rsp.ContentLength > c.bodyMax { | ||||||
| 		return nil, ErrBodyTooLarge | 		return nil, ErrBodyTooLarge | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Seperate the body implementers | 	// Seperate the body implementers. | ||||||
| 	rbody := (io.Reader)(rsp.Body) | 	rbody := (io.Reader)(rsp.Body) | ||||||
| 	cbody := (io.Closer)(rsp.Body) | 	cbody := (io.Closer)(rsp.Body) | ||||||
| 
 | 
 | ||||||
| 	var limit int64 | 	var limit int64 | ||||||
| 
 | 
 | ||||||
| 	if limit = rsp.ContentLength; limit < 0 { | 	if limit = rsp.ContentLength; limit < 0 { | ||||||
| 		// If unknown, use max as reader limit | 		// If unknown, use max as reader limit. | ||||||
| 		limit = c.bmax | 		limit = c.bodyMax | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Don't trust them, limit body reads | 	// Don't trust them, limit body reads. | ||||||
| 	rbody = io.LimitReader(rbody, limit) | 	rbody = io.LimitReader(rbody, limit) | ||||||
| 
 | 
 | ||||||
| 	// Wrap body with limit | 	// Wrap body with limit. | ||||||
| 	rsp.Body = &struct { | 	rsp.Body = &struct { | ||||||
| 		io.Reader | 		io.Reader | ||||||
| 		io.Closer | 		io.Closer | ||||||
|  | @ -235,17 +333,3 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { | ||||||
| 
 | 
 | ||||||
| 	return rsp, nil | 	return rsp, nil | ||||||
| } | } | ||||||
| 
 |  | ||||||
| // wait acquires the 'wait' queue for the given host string, or allocates new. |  | ||||||
| func (c *Client) wait(host string) chan struct{} { |  | ||||||
| 	// Look for an existing queue |  | ||||||
| 	queue, ok := c.queue.Get(host) |  | ||||||
| 	if ok { |  | ||||||
| 		return queue |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Allocate a new host queue (or return a sneaky existing one). |  | ||||||
| 	queue, _ = c.queue.GetOrInsert(host, make(chan struct{}, c.cmax)) |  | ||||||
| 
 |  | ||||||
| 	return queue |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -48,14 +48,6 @@ var bodies = []string{ | ||||||
| 	"body with\r\nnewlines", | 	"body with\r\nnewlines", | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Note: |  | ||||||
| // There is no test for the .MaxOpenConns implementation |  | ||||||
| // in the httpclient.Client{}, due to the difficult to test |  | ||||||
| // this. The block is only held for the actual dial out to |  | ||||||
| // the connection, so the usual test of blocking and holding |  | ||||||
| // open this queue slot to check we can't open another isn't |  | ||||||
| // an easy test here. |  | ||||||
| 
 |  | ||||||
| func TestHTTPClientSmallBody(t *testing.T) { | func TestHTTPClientSmallBody(t *testing.T) { | ||||||
| 	for _, body := range bodies { | 	for _, body := range bodies { | ||||||
| 		_TestHTTPClientWithBody(t, []byte(body), int(^uint16(0))) | 		_TestHTTPClientWithBody(t, []byte(body), int(^uint16(0))) | ||||||
|  |  | ||||||
|  | @ -1,62 +0,0 @@ | ||||||
| // GoToSocial |  | ||||||
| // Copyright (C) GoToSocial Authors admin@gotosocial.org |  | ||||||
| // SPDX-License-Identifier: AGPL-3.0-or-later |  | ||||||
| // |  | ||||||
| // This program is free software: you can redistribute it and/or modify |  | ||||||
| // it under the terms of the GNU Affero General Public License as published by |  | ||||||
| // the Free Software Foundation, either version 3 of the License, or |  | ||||||
| // (at your option) any later version. |  | ||||||
| // |  | ||||||
| // This program is distributed in the hope that it will be useful, |  | ||||||
| // but WITHOUT ANY WARRANTY; without even the implied warranty of |  | ||||||
| // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the |  | ||||||
| // GNU Affero General Public License for more details. |  | ||||||
| // |  | ||||||
| // You should have received a copy of the GNU Affero General Public License |  | ||||||
| // along with this program.  If not, see <http://www.gnu.org/licenses/>. |  | ||||||
| 
 |  | ||||||
| package httpclient |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 	"net/http" |  | ||||||
| 	"strings" |  | ||||||
| 
 |  | ||||||
| 	"golang.org/x/net/http/httpguts" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| // ValidateRequest performs the same request validation logic found in the default |  | ||||||
| // net/http.Transport{}.roundTrip() function, but pulls it out into this separate |  | ||||||
| // function allowing validation errors to be wrapped under a single error type. |  | ||||||
| func ValidateRequest(r *http.Request) error { |  | ||||||
| 	switch { |  | ||||||
| 	case r.URL == nil: |  | ||||||
| 		return fmt.Errorf("%w: nil url", ErrInvalidRequest) |  | ||||||
| 	case r.Header == nil: |  | ||||||
| 		return fmt.Errorf("%w: nil header", ErrInvalidRequest) |  | ||||||
| 	case r.URL.Host == "": |  | ||||||
| 		return fmt.Errorf("%w: empty url host", ErrInvalidRequest) |  | ||||||
| 	case r.URL.Scheme != "http" && r.URL.Scheme != "https": |  | ||||||
| 		return fmt.Errorf("%w: unsupported protocol %q", ErrInvalidRequest, r.URL.Scheme) |  | ||||||
| 	case strings.IndexFunc(r.Method, func(r rune) bool { return !httpguts.IsTokenRune(r) }) != -1: |  | ||||||
| 		return fmt.Errorf("%w: invalid method %q", ErrInvalidRequest, r.Method) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for key, values := range r.Header { |  | ||||||
| 		// Check field key name is valid |  | ||||||
| 		if !httpguts.ValidHeaderFieldName(key) { |  | ||||||
| 			return fmt.Errorf("%w: invalid header field name %q", ErrInvalidRequest, key) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// Check each field value is valid |  | ||||||
| 		for i := 0; i < len(values); i++ { |  | ||||||
| 			if !httpguts.ValidHeaderFieldValue(values[i]) { |  | ||||||
| 				return fmt.Errorf("%w: invalid header field value %q", ErrInvalidRequest, values[i]) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// ps. kim wrote this |  | ||||||
| 
 |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  | @ -15,19 +15,14 @@ | ||||||
| // You should have received a copy of the GNU Affero General Public License | // You should have received a copy of the GNU Affero General Public License | ||||||
| // along with this program.  If not, see <http://www.gnu.org/licenses/>. | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
| 
 | 
 | ||||||
| package transport_test | package httpclient | ||||||
| 
 | 
 | ||||||
| import ( | import "net/http" | ||||||
| 	"context" |  | ||||||
| 	"testing" |  | ||||||
| 
 | 
 | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" | // SignFunc is a function signature that provides request signing. | ||||||
| ) | type SignFunc func(r *http.Request) error | ||||||
| 
 | 
 | ||||||
| func TestFastFailContext(t *testing.T) { | type SigningClient interface { | ||||||
| 	ctx := context.Background() | 	Do(r *http.Request) (*http.Response, error) | ||||||
| 	ctx = transport.WithFastfail(ctx) | 	DoSigned(r *http.Request, sign SignFunc) (*http.Response, error) | ||||||
| 	if !transport.IsFastfail(ctx) { |  | ||||||
| 		t.Fatal("failed to set fast-fail context key") |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  | @ -34,7 +34,7 @@ import ( | ||||||
| func Logger() gin.HandlerFunc { | func Logger() gin.HandlerFunc { | ||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
| 		// Initialize the logging fields | 		// Initialize the logging fields | ||||||
| 		fields := make(kv.Fields, 6, 7) | 		fields := make(kv.Fields, 5, 7) | ||||||
| 
 | 
 | ||||||
| 		// Determine pre-handler time | 		// Determine pre-handler time | ||||||
| 		before := time.Now() | 		before := time.Now() | ||||||
|  | @ -68,11 +68,18 @@ func Logger() gin.HandlerFunc { | ||||||
| 
 | 
 | ||||||
| 			// Set request logging fields | 			// Set request logging fields | ||||||
| 			fields[0] = kv.Field{"latency", time.Since(before)} | 			fields[0] = kv.Field{"latency", time.Since(before)} | ||||||
| 			fields[1] = kv.Field{"clientIP", c.ClientIP()} | 			fields[1] = kv.Field{"userAgent", c.Request.UserAgent()} | ||||||
| 			fields[2] = kv.Field{"userAgent", c.Request.UserAgent()} | 			fields[2] = kv.Field{"method", c.Request.Method} | ||||||
| 			fields[3] = kv.Field{"method", c.Request.Method} | 			fields[3] = kv.Field{"statusCode", code} | ||||||
| 			fields[4] = kv.Field{"statusCode", code} | 			fields[4] = kv.Field{"path", path} | ||||||
| 			fields[5] = kv.Field{"path", path} | 			if includeClientIP := true; includeClientIP { | ||||||
|  | 				// TODO: make this configurable. | ||||||
|  | 				// | ||||||
|  | 				// Include clientIP if enabled. | ||||||
|  | 				fields = append(fields, kv.Field{ | ||||||
|  | 					"clientIP", c.ClientIP(), | ||||||
|  | 				}) | ||||||
|  | 			} | ||||||
| 
 | 
 | ||||||
| 			// Create log entry with fields | 			// Create log entry with fields | ||||||
| 			l := log.WithContext(c.Request.Context()). | 			l := log.WithContext(c.Request.Context()). | ||||||
|  |  | ||||||
|  | @ -19,7 +19,6 @@ package middleware | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bufio" | 	"bufio" | ||||||
| 	"context" |  | ||||||
| 	"crypto/rand" | 	"crypto/rand" | ||||||
| 	"encoding/base32" | 	"encoding/base32" | ||||||
| 	"encoding/binary" | 	"encoding/binary" | ||||||
|  | @ -27,17 +26,11 @@ import ( | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"codeberg.org/gruf/go-kv" |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type ctxType string |  | ||||||
| 
 |  | ||||||
| var ( | var ( | ||||||
| 	// ridCtxKey is the key underwhich we store request IDs in a context. |  | ||||||
| 	ridCtxKey ctxType = "id" |  | ||||||
| 
 |  | ||||||
| 	// crand provides buffered reads of random input. | 	// crand provides buffered reads of random input. | ||||||
| 	crand = bufio.NewReader(rand.Reader) | 	crand = bufio.NewReader(rand.Reader) | ||||||
| 	mrand sync.Mutex | 	mrand sync.Mutex | ||||||
|  | @ -69,22 +62,8 @@ func generateID() string { | ||||||
| 	return base32enc.EncodeToString(b) | 	return base32enc.EncodeToString(b) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RequestID fetches the stored request ID from context. |  | ||||||
| func RequestID(ctx context.Context) string { |  | ||||||
| 	id, _ := ctx.Value(ridCtxKey).(string) |  | ||||||
| 	return id |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // AddRequestID returns a gin middleware which adds a unique ID to each request (both response header and context). | // AddRequestID returns a gin middleware which adds a unique ID to each request (both response header and context). | ||||||
| func AddRequestID(header string) gin.HandlerFunc { | func AddRequestID(header string) gin.HandlerFunc { | ||||||
| 	log.Hook(func(ctx context.Context, kvs []kv.Field) []kv.Field { |  | ||||||
| 		if id, _ := ctx.Value(ridCtxKey).(string); id != "" { |  | ||||||
| 			// Add stored request ID to log entry fields. |  | ||||||
| 			return append(kvs, kv.Field{K: "requestID", V: id}) |  | ||||||
| 		} |  | ||||||
| 		return kvs |  | ||||||
| 	}) |  | ||||||
| 
 |  | ||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
| 		// Look for existing ID. | 		// Look for existing ID. | ||||||
| 		id := c.GetHeader(header) | 		id := c.GetHeader(header) | ||||||
|  | @ -100,8 +79,8 @@ func AddRequestID(header string) gin.HandlerFunc { | ||||||
| 			c.Request.Header.Set(header, id) | 			c.Request.Header.Set(header, id) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// Store request ID in new request ctx and set new gin request obj. | 		// Store request ID in new request context and set on gin ctx. | ||||||
| 		ctx := context.WithValue(c.Request.Context(), ridCtxKey, id) | 		ctx := gtscontext.SetRequestID(c.Request.Context(), id) | ||||||
| 		c.Request = c.Request.WithContext(ctx) | 		c.Request = c.Request.WithContext(ctx) | ||||||
| 
 | 
 | ||||||
| 		// Set the request ID in the rsp header. | 		// Set the request ID in the rsp header. | ||||||
|  |  | ||||||
|  | @ -25,9 +25,9 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Get processes the given request for account information. | // Get processes the given request for account information. | ||||||
|  | @ -96,7 +96,7 @@ func (p *Processor) getFor(ctx context.Context, requestingAccount *gtsmodel.Acco | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		a, err := p.federator.GetAccountByURI( | 		a, err := p.federator.GetAccountByURI( | ||||||
| 			transport.WithFastfail(ctx), requestingAccount.Username, targetAccountURI, true, | 			gtscontext.SetFastFail(ctx), requestingAccount.Username, targetAccountURI, true, | ||||||
| 		) | 		) | ||||||
| 		if err == nil { | 		if err == nil { | ||||||
| 			targetAccount = a | 			targetAccount = a | ||||||
|  |  | ||||||
|  | @ -22,9 +22,9 @@ import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func (p *Processor) authenticate(ctx context.Context, requestedUsername string) (requestedAccount, requestingAccount *gtsmodel.Account, errWithCode gtserror.WithCode) { | func (p *Processor) authenticate(ctx context.Context, requestedUsername string) (requestedAccount, requestingAccount *gtsmodel.Account, errWithCode gtserror.WithCode) { | ||||||
|  | @ -40,7 +40,7 @@ func (p *Processor) authenticate(ctx context.Context, requestedUsername string) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if requestingAccount, err = p.federator.GetAccountByURI(transport.WithFastfail(ctx), requestedUsername, requestingAccountURI, false); err != nil { | 	if requestingAccount, err = p.federator.GetAccountByURI(gtscontext.SetFastFail(ctx), requestedUsername, requestingAccountURI, false); err != nil { | ||||||
| 		errWithCode = gtserror.NewErrorUnauthorized(err) | 		errWithCode = gtserror.NewErrorUnauthorized(err) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -24,8 +24,8 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	"github.com/superseriousbusiness/activity/streams" | 	"github.com/superseriousbusiness/activity/streams" | ||||||
| 	"github.com/superseriousbusiness/activity/streams/vocab" | 	"github.com/superseriousbusiness/activity/streams/vocab" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/uris" | 	"github.com/superseriousbusiness/gotosocial/internal/uris" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -56,7 +56,7 @@ func (p *Processor) UserGet(ctx context.Context, requestedUsername string, reque | ||||||
| 		// if we're not already handshaking/dereferencing a remote account, dereference it now | 		// if we're not already handshaking/dereferencing a remote account, dereference it now | ||||||
| 		if !p.federator.Handshaking(requestedUsername, requestingAccountURI) { | 		if !p.federator.Handshaking(requestedUsername, requestingAccountURI) { | ||||||
| 			requestingAccount, err := p.federator.GetAccountByURI( | 			requestingAccount, err := p.federator.GetAccountByURI( | ||||||
| 				transport.WithFastfail(ctx), requestedUsername, requestingAccountURI, false, | 				gtscontext.SetFastFail(ctx), requestedUsername, requestingAccountURI, false, | ||||||
| 			) | 			) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return nil, gtserror.NewErrorUnauthorized(err) | 				return nil, gtserror.NewErrorUnauthorized(err) | ||||||
|  |  | ||||||
|  | @ -25,10 +25,10 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/media" | 	"github.com/superseriousbusiness/gotosocial/internal/media" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/uris" | 	"github.com/superseriousbusiness/gotosocial/internal/uris" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -157,7 +157,7 @@ func (p *Processor) getAttachmentContent(ctx context.Context, requestingAccount | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return nil, 0, err | 				return nil, 0, err | ||||||
| 			} | 			} | ||||||
| 			return t.DereferenceMedia(transport.WithFastfail(innerCtx), remoteMediaIRI) | 			return t.DereferenceMedia(gtscontext.SetFastFail(innerCtx), remoteMediaIRI) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// Start recaching this media with the prepared data function. | 		// Start recaching this media with the prepared data function. | ||||||
|  |  | ||||||
|  | @ -30,11 +30,11 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/config" | 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing" | 	"github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/util" | 	"github.com/superseriousbusiness/gotosocial/internal/util" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -226,14 +226,14 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *Processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL) (*gtsmodel.Status, error) { | func (p *Processor) searchStatusByURI(ctx context.Context, authed *oauth.Auth, uri *url.URL) (*gtsmodel.Status, error) { | ||||||
| 	status, statusable, err := p.federator.GetStatus(transport.WithFastfail(ctx), authed.Account.Username, uri, true, true) | 	status, statusable, err := p.federator.GetStatus(gtscontext.SetFastFail(ctx), authed.Account.Username, uri, true, true) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if !*status.Local && statusable != nil { | 	if !*status.Local && statusable != nil { | ||||||
| 		// Attempt to dereference the status thread while we are here | 		// Attempt to dereference the status thread while we are here | ||||||
| 		p.federator.DereferenceThread(transport.WithFastfail(ctx), authed.Account.Username, uri, status, statusable) | 		p.federator.DereferenceThread(gtscontext.SetFastFail(ctx), authed.Account.Username, uri, status, statusable) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return status, nil | 	return status, nil | ||||||
|  | @ -268,7 +268,7 @@ func (p *Processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return p.federator.GetAccountByURI( | 	return p.federator.GetAccountByURI( | ||||||
| 		transport.WithFastfail(ctx), | 		gtscontext.SetFastFail(ctx), | ||||||
| 		authed.Account.Username, | 		authed.Account.Username, | ||||||
| 		uri, false, | 		uri, false, | ||||||
| 	) | 	) | ||||||
|  | @ -295,7 +295,7 @@ func (p *Processor) searchAccountByUsernameDomain(ctx context.Context, authed *o | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return p.federator.GetAccountByUsernameDomain( | 	return p.federator.GetAccountByUsernameDomain( | ||||||
| 		transport.WithFastfail(ctx), | 		gtscontext.SetFastFail(ctx), | ||||||
| 		authed.Account.Username, | 		authed.Account.Username, | ||||||
| 		username, domain, false, | 		username, domain, false, | ||||||
| 	) | 	) | ||||||
|  |  | ||||||
|  | @ -24,9 +24,9 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/config" | 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/federation" | 	"github.com/superseriousbusiness/gotosocial/internal/federation" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/id" | 	"github.com/superseriousbusiness/gotosocial/internal/id" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/util" | 	"github.com/superseriousbusiness/gotosocial/internal/util" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -58,7 +58,7 @@ func GetParseMentionFunc(dbConn db.DB, federator federation.Federator) gtsmodel. | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			remoteAccount, err := federator.GetAccountByUsernameDomain( | 			remoteAccount, err := federator.GetAccountByUsernameDomain( | ||||||
| 				transport.WithFastfail(ctx), | 				gtscontext.SetFastFail(ctx), | ||||||
| 				requestingUsername, | 				requestingUsername, | ||||||
| 				username, | 				username, | ||||||
| 				domain, | 				domain, | ||||||
|  |  | ||||||
|  | @ -1,42 +0,0 @@ | ||||||
| // GoToSocial |  | ||||||
| // Copyright (C) GoToSocial Authors admin@gotosocial.org |  | ||||||
| // SPDX-License-Identifier: AGPL-3.0-or-later |  | ||||||
| // |  | ||||||
| // This program is free software: you can redistribute it and/or modify |  | ||||||
| // it under the terms of the GNU Affero General Public License as published by |  | ||||||
| // the Free Software Foundation, either version 3 of the License, or |  | ||||||
| // (at your option) any later version. |  | ||||||
| // |  | ||||||
| // This program is distributed in the hope that it will be useful, |  | ||||||
| // but WITHOUT ANY WARRANTY; without even the implied warranty of |  | ||||||
| // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the |  | ||||||
| // GNU Affero General Public License for more details. |  | ||||||
| // |  | ||||||
| // You should have received a copy of the GNU Affero General Public License |  | ||||||
| // along with this program.  If not, see <http://www.gnu.org/licenses/>. |  | ||||||
| 
 |  | ||||||
| package transport |  | ||||||
| 
 |  | ||||||
| import "context" |  | ||||||
| 
 |  | ||||||
| // ctxkey is our own unique context key type to prevent setting outside package. |  | ||||||
| type ctxkey string |  | ||||||
| 
 |  | ||||||
| // fastfailkey is our unique context key to indicate fast-fail is enabled. |  | ||||||
| var fastfailkey = ctxkey("ff") |  | ||||||
| 
 |  | ||||||
| // WithFastfail returns a Context which indicates that any http requests made |  | ||||||
| // with it should return after the first failed attempt, instead of retrying. |  | ||||||
| // |  | ||||||
| // This can be used to fail quickly when you're making an outgoing http request |  | ||||||
| // inside the context of an incoming http request, and you want to be able to |  | ||||||
| // provide a snappy response to the user, instead of retrying + backing off. |  | ||||||
| func WithFastfail(parent context.Context) context.Context { |  | ||||||
| 	return context.WithValue(parent, fastfailkey, struct{}{}) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // IsFastfail returns true if the given context was created by WithFastfail. |  | ||||||
| func IsFastfail(ctx context.Context) bool { |  | ||||||
| 	_, ok := ctx.Value(fastfailkey).(struct{}) |  | ||||||
| 	return ok |  | ||||||
| } |  | ||||||
|  | @ -24,7 +24,7 @@ import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"time" | 	"runtime" | ||||||
| 
 | 
 | ||||||
| 	"codeberg.org/gruf/go-byteutil" | 	"codeberg.org/gruf/go-byteutil" | ||||||
| 	"codeberg.org/gruf/go-cache/v3" | 	"codeberg.org/gruf/go-cache/v3" | ||||||
|  | @ -32,7 +32,7 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/activity/streams" | 	"github.com/superseriousbusiness/activity/streams" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/config" | 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb" | 	"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | 	"github.com/superseriousbusiness/gotosocial/internal/httpclient" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/state" | 	"github.com/superseriousbusiness/gotosocial/internal/state" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -49,14 +49,14 @@ type controller struct { | ||||||
| 	state     *state.State | 	state     *state.State | ||||||
| 	fedDB     federatingdb.DB | 	fedDB     federatingdb.DB | ||||||
| 	clock     pub.Clock | 	clock     pub.Clock | ||||||
| 	client    pub.HttpClient | 	client    httpclient.SigningClient | ||||||
| 	trspCache cache.Cache[string, *transport] | 	trspCache cache.Cache[string, *transport] | ||||||
| 	badHosts  cache.Cache[string, struct{}] |  | ||||||
| 	userAgent string | 	userAgent string | ||||||
|  | 	senders   int // no. concurrent batch delivery routines. | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // NewController returns an implementation of the Controller interface for creating new transports | // NewController returns an implementation of the Controller interface for creating new transports | ||||||
| func NewController(state *state.State, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller { | func NewController(state *state.State, federatingDB federatingdb.DB, clock pub.Clock, client httpclient.SigningClient) Controller { | ||||||
| 	applicationName := config.GetApplicationName() | 	applicationName := config.GetApplicationName() | ||||||
| 	host := config.GetHost() | 	host := config.GetHost() | ||||||
| 	proto := config.GetProtocol() | 	proto := config.GetProtocol() | ||||||
|  | @ -68,20 +68,8 @@ func NewController(state *state.State, federatingDB federatingdb.DB, clock pub.C | ||||||
| 		clock:     clock, | 		clock:     clock, | ||||||
| 		client:    client, | 		client:    client, | ||||||
| 		trspCache: cache.New[string, *transport](0, 100, 0), | 		trspCache: cache.New[string, *transport](0, 100, 0), | ||||||
| 		badHosts:  cache.New[string, struct{}](0, 1000, 0), |  | ||||||
| 		userAgent: fmt.Sprintf("%s (+%s://%s) gotosocial/%s", applicationName, proto, host, version), | 		userAgent: fmt.Sprintf("%s (+%s://%s) gotosocial/%s", applicationName, proto, host, version), | ||||||
| 	} | 		senders:   runtime.GOMAXPROCS(0), // on batch delivery, only ever send GOMAXPROCS at a time. | ||||||
| 
 |  | ||||||
| 	// Transport cache has TTL=1hr freq=1min |  | ||||||
| 	c.trspCache.SetTTL(time.Hour, false) |  | ||||||
| 	if !c.trspCache.Start(time.Minute) { |  | ||||||
| 		log.Panic(nil, "failed to start transport controller cache") |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Bad hosts cache has TTL=15min freq=1min |  | ||||||
| 	c.badHosts.SetTTL(15*time.Minute, false) |  | ||||||
| 	if !c.badHosts.Start(time.Minute) { |  | ||||||
| 		log.Panic(nil, "failed to start transport controller cache") |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return c | 	return c | ||||||
|  |  | ||||||
|  | @ -22,7 +22,6 @@ import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"strings" |  | ||||||
| 	"sync" | 	"sync" | ||||||
| 
 | 
 | ||||||
| 	"codeberg.org/gruf/go-byteutil" | 	"codeberg.org/gruf/go-byteutil" | ||||||
|  | @ -32,54 +31,90 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func (t *transport) BatchDeliver(ctx context.Context, b []byte, recipients []*url.URL) error { | func (t *transport) BatchDeliver(ctx context.Context, b []byte, recipients []*url.URL) error { | ||||||
| 	// concurrently deliver to recipients; for each delivery, buffer the error if it fails | 	var ( | ||||||
| 	wg := sync.WaitGroup{} | 		// errs accumulates errors received during | ||||||
| 	errCh := make(chan error, len(recipients)) | 		// attempted delivery by deliverer routines. | ||||||
| 	for _, recipient := range recipients { | 		errs gtserror.MultiError | ||||||
| 		wg.Add(1) | 
 | ||||||
| 		go func(r *url.URL) { | 		// wait blocks until all sender | ||||||
| 			defer wg.Done() | 		// routines have returned. | ||||||
| 			if err := t.Deliver(ctx, b, r); err != nil { | 		wait sync.WaitGroup | ||||||
| 				errCh <- err | 
 | ||||||
|  | 		// mutex protects 'recipients' and | ||||||
|  | 		// 'errs' for concurrent access. | ||||||
|  | 		mutex sync.Mutex | ||||||
|  | 
 | ||||||
|  | 		// Get current instance host info. | ||||||
|  | 		domain = config.GetAccountDomain() | ||||||
|  | 		host   = config.GetHost() | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	// Block on expect no. senders. | ||||||
|  | 	wait.Add(t.controller.senders) | ||||||
|  | 
 | ||||||
|  | 	for i := 0; i < t.controller.senders; i++ { | ||||||
|  | 		go func() { | ||||||
|  | 			// Mark returned. | ||||||
|  | 			defer wait.Done() | ||||||
|  | 
 | ||||||
|  | 			for { | ||||||
|  | 				// Acquire lock. | ||||||
|  | 				mutex.Lock() | ||||||
|  | 
 | ||||||
|  | 				if len(recipients) == 0 { | ||||||
|  | 					// Reached end. | ||||||
|  | 					mutex.Unlock() | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				// Pop next recipient. | ||||||
|  | 				i := len(recipients) - 1 | ||||||
|  | 				to := recipients[i] | ||||||
|  | 				recipients = recipients[:i] | ||||||
|  | 
 | ||||||
|  | 				// Done with lock. | ||||||
|  | 				mutex.Unlock() | ||||||
|  | 
 | ||||||
|  | 				// Skip delivery to recipient if it is "us". | ||||||
|  | 				if to.Host == host || to.Host == domain { | ||||||
|  | 					continue | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				// Attempt to deliver data to recipient. | ||||||
|  | 				if err := t.deliver(ctx, b, to); err != nil { | ||||||
|  | 					mutex.Lock() // safely append err to accumulator. | ||||||
|  | 					errs.Appendf("error delivering to %s: %v", to, err) | ||||||
|  | 					mutex.Unlock() | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 		}(recipient) | 		}() | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// wait until all deliveries have succeeded or failed | 	// Wait for finish. | ||||||
| 	wg.Wait() | 	wait.Wait() | ||||||
| 
 | 
 | ||||||
| 	// receive any buffered errors | 	// Return combined err. | ||||||
| 	errs := make([]string, 0, len(errCh)) | 	return errs.Combine() | ||||||
| outer: |  | ||||||
| 	for { |  | ||||||
| 		select { |  | ||||||
| 		case e := <-errCh: |  | ||||||
| 			errs = append(errs, e.Error()) |  | ||||||
| 		default: |  | ||||||
| 			break outer |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if len(errs) > 0 { |  | ||||||
| 		return fmt.Errorf("BatchDeliver: at least one failure: %s", strings.Join(errs, "; ")) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error { | func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error { | ||||||
| 	// if the 'to' host is our own, just skip this delivery since we by definition already have the message! | 	// if 'to' host is our own, skip as we don't need to deliver to ourselves... | ||||||
| 	if to.Host == config.GetHost() || to.Host == config.GetAccountDomain() { | 	if to.Host == config.GetHost() || to.Host == config.GetAccountDomain() { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	urlStr := to.String() | 	// Deliver data to recipient. | ||||||
|  | 	return t.deliver(ctx, b, to) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (t *transport) deliver(ctx context.Context, b []byte, to *url.URL) error { | ||||||
|  | 	url := to.String() | ||||||
| 
 | 
 | ||||||
| 	// Use rewindable bytes reader for body. | 	// Use rewindable bytes reader for body. | ||||||
| 	var body byteutil.ReadNopCloser | 	var body byteutil.ReadNopCloser | ||||||
| 	body.Reset(b) | 	body.Reset(b) | ||||||
| 
 | 
 | ||||||
| 	req, err := http.NewRequestWithContext(ctx, "POST", urlStr, &body) | 	req, err := http.NewRequestWithContext(ctx, "POST", url, &body) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | @ -88,16 +123,16 @@ func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error { | ||||||
| 	req.Header.Add("Accept-Charset", "utf-8") | 	req.Header.Add("Accept-Charset", "utf-8") | ||||||
| 	req.Header.Set("Host", to.Host) | 	req.Header.Set("Host", to.Host) | ||||||
| 
 | 
 | ||||||
| 	resp, err := t.POST(req, b) | 	rsp, err := t.POST(req, b) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	defer resp.Body.Close() | 	defer rsp.Body.Close() | ||||||
| 
 | 
 | ||||||
| 	if code := resp.StatusCode; code != http.StatusOK && | 	if code := rsp.StatusCode; code != http.StatusOK && | ||||||
| 		code != http.StatusCreated && code != http.StatusAccepted { | 		code != http.StatusCreated && code != http.StatusAccepted { | ||||||
| 		err := fmt.Errorf("POST request to %s failed: %s", urlStr, resp.Status) | 		err := fmt.Errorf("POST request to %s failed: %s", url, rsp.Status) | ||||||
| 		return gtserror.WithStatusCode(err, resp.StatusCode) | 		return gtserror.WithStatusCode(err, rsp.StatusCode) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  |  | ||||||
|  | @ -20,26 +20,17 @@ package transport | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"crypto" | 	"crypto" | ||||||
| 	"crypto/x509" |  | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" |  | ||||||
| 	"io" | 	"io" | ||||||
| 	"net" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"strconv" |  | ||||||
| 	"strings" |  | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"codeberg.org/gruf/go-byteutil" |  | ||||||
| 	errorsv2 "codeberg.org/gruf/go-errors/v2" |  | ||||||
| 	"codeberg.org/gruf/go-kv" |  | ||||||
| 	"github.com/go-fed/httpsig" | 	"github.com/go-fed/httpsig" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/httpclient" | 	"github.com/superseriousbusiness/gotosocial/internal/httpclient" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Transport implements the pub.Transport interface with some additional functionality for fetching remote media. | // Transport implements the pub.Transport interface with some additional functionality for fetching remote media. | ||||||
|  | @ -78,7 +69,7 @@ type Transport interface { | ||||||
| 	Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) | 	Finger(ctx context.Context, targetUsername string, targetDomain string) ([]byte, error) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // transport implements the Transport interface | // transport implements the Transport interface. | ||||||
| type transport struct { | type transport struct { | ||||||
| 	controller *controller | 	controller *controller | ||||||
| 	pubKeyID   string | 	pubKeyID   string | ||||||
|  | @ -95,9 +86,11 @@ func (t *transport) GET(r *http.Request) (*http.Response, error) { | ||||||
| 	if r.Method != http.MethodGet { | 	if r.Method != http.MethodGet { | ||||||
| 		return nil, errors.New("must be GET request") | 		return nil, errors.New("must be GET request") | ||||||
| 	} | 	} | ||||||
| 	return t.do(r, func(r *http.Request) error { | 	ctx := r.Context() // extract, set pubkey ID. | ||||||
| 		return t.signGET(r) | 	ctx = gtscontext.SetPublicKeyID(ctx, t.pubKeyID) | ||||||
| 	}) | 	r = r.WithContext(ctx) // replace request ctx. | ||||||
|  | 	r.Header.Set("User-Agent", t.controller.userAgent) | ||||||
|  | 	return t.controller.client.DoSigned(r, t.signGET()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // POST will perform given http request using transport client, retrying on certain preset errors. | // POST will perform given http request using transport client, retrying on certain preset errors. | ||||||
|  | @ -105,161 +98,31 @@ func (t *transport) POST(r *http.Request, body []byte) (*http.Response, error) { | ||||||
| 	if r.Method != http.MethodPost { | 	if r.Method != http.MethodPost { | ||||||
| 		return nil, errors.New("must be POST request") | 		return nil, errors.New("must be POST request") | ||||||
| 	} | 	} | ||||||
| 	return t.do(r, func(r *http.Request) error { | 	ctx := r.Context() // extract, set pubkey ID. | ||||||
| 		return t.signPOST(r, body) | 	ctx = gtscontext.SetPublicKeyID(ctx, t.pubKeyID) | ||||||
| 	}) | 	r = r.WithContext(ctx) // replace request ctx. | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (t *transport) do(r *http.Request, signer func(*http.Request) error) (*http.Response, error) { |  | ||||||
| 	const ( |  | ||||||
| 		// max no. attempts |  | ||||||
| 		maxRetries = 5 |  | ||||||
| 
 |  | ||||||
| 		// starting backoff duration. |  | ||||||
| 		baseBackoff = 2 * time.Second |  | ||||||
| 	) |  | ||||||
| 
 |  | ||||||
| 	// Get request hostname |  | ||||||
| 	host := r.URL.Hostname() |  | ||||||
| 
 |  | ||||||
| 	// Check whether request should fast fail, we check this |  | ||||||
| 	// before loop as each context.Value() requires mutex lock. |  | ||||||
| 	fastFail := IsFastfail(r.Context()) |  | ||||||
| 	if !fastFail { |  | ||||||
| 		// Check if recently reached max retries for this host |  | ||||||
| 		// so we don't bother with a retry-backoff loop. The only |  | ||||||
| 		// errors that are retried upon are server failure and |  | ||||||
| 		// domain resolution type errors, so this cached result |  | ||||||
| 		// indicates this server is likely having issues. |  | ||||||
| 		fastFail = t.controller.badHosts.Has(host) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Start a log entry for this request |  | ||||||
| 	l := log.WithContext(r.Context()). |  | ||||||
| 		WithFields(kv.Fields{ |  | ||||||
| 			{"pubKeyID", t.pubKeyID}, |  | ||||||
| 			{"method", r.Method}, |  | ||||||
| 			{"url", r.URL.String()}, |  | ||||||
| 		}...) |  | ||||||
| 
 |  | ||||||
| 	r.Header.Set("User-Agent", t.controller.userAgent) | 	r.Header.Set("User-Agent", t.controller.userAgent) | ||||||
| 
 | 	return t.controller.client.DoSigned(r, t.signPOST(body)) | ||||||
| 	for i := 0; i < maxRetries; i++ { |  | ||||||
| 		var backoff time.Duration |  | ||||||
| 
 |  | ||||||
| 		// Reset signing header fields |  | ||||||
| 		now := t.controller.clock.Now().UTC() |  | ||||||
| 		r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT") |  | ||||||
| 		r.Header.Del("Signature") |  | ||||||
| 		r.Header.Del("Digest") |  | ||||||
| 
 |  | ||||||
| 		// Rewind body reader and content-length if set. |  | ||||||
| 		if rc, ok := r.Body.(*byteutil.ReadNopCloser); ok { |  | ||||||
| 			r.ContentLength = int64(rc.Len()) |  | ||||||
| 			rc.Rewind() |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// Perform request signing |  | ||||||
| 		if err := signer(r); err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		l.Infof("performing request") |  | ||||||
| 
 |  | ||||||
| 		// Attempt to perform request |  | ||||||
| 		rsp, err := t.controller.client.Do(r) |  | ||||||
| 		if err == nil { //nolint:gocritic |  | ||||||
| 			// TooManyRequest means we need to slow |  | ||||||
| 			// down and retry our request. Codes over |  | ||||||
| 			// 500 generally indicate temp. outages. |  | ||||||
| 			if code := rsp.StatusCode; code < 500 && |  | ||||||
| 				code != http.StatusTooManyRequests { |  | ||||||
| 				return rsp, nil |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			// Generate error from status code for logging |  | ||||||
| 			err = errors.New(`http response "` + rsp.Status + `"`) |  | ||||||
| 
 |  | ||||||
| 			// Search for a provided "Retry-After" header value. |  | ||||||
| 			if after := rsp.Header.Get("Retry-After"); after != "" { |  | ||||||
| 
 |  | ||||||
| 				if u, _ := strconv.ParseUint(after, 10, 32); u != 0 { |  | ||||||
| 					// An integer number of backoff seconds was provided. |  | ||||||
| 					backoff = time.Duration(u) * time.Second |  | ||||||
| 				} else if at, _ := http.ParseTime(after); !at.Before(now) { |  | ||||||
| 					// An HTTP formatted future date-time was provided. |  | ||||||
| 					backoff = at.Sub(now) |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				// Don't let their provided backoff exceed our max. |  | ||||||
| 				if max := baseBackoff * maxRetries; backoff > max { |  | ||||||
| 					backoff = max |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 		} else if errorsv2.Is(err, |  | ||||||
| 			context.DeadlineExceeded, |  | ||||||
| 			context.Canceled, |  | ||||||
| 			httpclient.ErrInvalidRequest, |  | ||||||
| 			httpclient.ErrBodyTooLarge, |  | ||||||
| 			httpclient.ErrReservedAddr, |  | ||||||
| 		) { |  | ||||||
| 			// Return on non-retryable errors |  | ||||||
| 			return nil, err |  | ||||||
| 		} else if strings.Contains(err.Error(), "stopped after 10 redirects") { |  | ||||||
| 			// Don't bother if net/http returned after too many redirects |  | ||||||
| 			return nil, err |  | ||||||
| 		} else if errors.As(err, &x509.UnknownAuthorityError{}) { |  | ||||||
| 			// Unknown authority errors we do NOT recover from |  | ||||||
| 			return nil, err |  | ||||||
| 		} else if dnserr := (*net.DNSError)(nil); // nocollapse |  | ||||||
| 		errors.As(err, &dnserr) && dnserr.IsNotFound { |  | ||||||
| 			// DNS lookup failure, this domain does not exist |  | ||||||
| 			return nil, gtserror.SetNotFound(err) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if fastFail { |  | ||||||
| 			// on fast-fail, don't bother backoff/retry |  | ||||||
| 			return nil, fmt.Errorf("%w (fast fail)", err) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if backoff == 0 { |  | ||||||
| 			// No retry-after found, set our predefined backoff. |  | ||||||
| 			backoff = time.Duration(i) * baseBackoff |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		l.Errorf("backing off for %s after http request error: %v", backoff, err) |  | ||||||
| 
 |  | ||||||
| 		select { |  | ||||||
| 		// Request ctx cancelled |  | ||||||
| 		case <-r.Context().Done(): |  | ||||||
| 			return nil, r.Context().Err() |  | ||||||
| 
 |  | ||||||
| 		// Backoff for some time |  | ||||||
| 		case <-time.After(backoff): |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Add "bad" entry for this host. |  | ||||||
| 	t.controller.badHosts.Set(host, struct{}{}) |  | ||||||
| 
 |  | ||||||
| 	return nil, errors.New("transport reached max retries") |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // signGET will safely sign an HTTP GET request. | // signGET will safely sign an HTTP GET request. | ||||||
| func (t *transport) signGET(r *http.Request) (err error) { | func (t *transport) signGET() httpclient.SignFunc { | ||||||
| 	t.safesign(func() { | 	return func(r *http.Request) (err error) { | ||||||
| 		err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil) | 		t.safesign(func() { | ||||||
| 	}) | 			err = t.getSigner.SignRequest(t.privkey, t.pubKeyID, r, nil) | ||||||
| 	return | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // signPOST will safely sign an HTTP POST request for given body. | // signPOST will safely sign an HTTP POST request for given body. | ||||||
| func (t *transport) signPOST(r *http.Request, body []byte) (err error) { | func (t *transport) signPOST(body []byte) httpclient.SignFunc { | ||||||
| 	t.safesign(func() { | 	return func(r *http.Request) (err error) { | ||||||
| 		err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body) | 		t.safesign(func() { | ||||||
| 	}) | 			err = t.postSigner.SignRequest(t.privkey, t.pubKeyID, r, body) | ||||||
| 	return | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // safesign will perform sign function within mutex protection, | // safesign will perform sign function within mutex protection, | ||||||
|  |  | ||||||
|  | @ -31,8 +31,12 @@ type Workers struct { | ||||||
| 	// Main task scheduler instance. | 	// Main task scheduler instance. | ||||||
| 	Scheduler sched.Scheduler | 	Scheduler sched.Scheduler | ||||||
| 
 | 
 | ||||||
| 	// ClientAPI / federator worker pools. | 	// ClientAPI provides a worker pool that handles both | ||||||
|  | 	// incoming client actions, and our own side-effects. | ||||||
| 	ClientAPI runners.WorkerPool | 	ClientAPI runners.WorkerPool | ||||||
|  | 
 | ||||||
|  | 	// Federator provides a worker pool that handles both | ||||||
|  | 	// incoming federated actions, and our own side-effects. | ||||||
| 	Federator runners.WorkerPool | 	Federator runners.WorkerPool | ||||||
| 
 | 
 | ||||||
| 	// Enqueue functions for clientAPI / federator worker pools, | 	// Enqueue functions for clientAPI / federator worker pools, | ||||||
|  |  | ||||||
|  | @ -26,12 +26,12 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| 
 | 
 | ||||||
| 	"github.com/superseriousbusiness/activity/pub" |  | ||||||
| 	"github.com/superseriousbusiness/activity/streams" | 	"github.com/superseriousbusiness/activity/streams" | ||||||
| 	"github.com/superseriousbusiness/activity/streams/vocab" | 	"github.com/superseriousbusiness/activity/streams/vocab" | ||||||
| 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/federation" | 	"github.com/superseriousbusiness/gotosocial/internal/federation" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/httpclient" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/state" | 	"github.com/superseriousbusiness/gotosocial/internal/state" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" | 	"github.com/superseriousbusiness/gotosocial/internal/transport" | ||||||
|  | @ -51,7 +51,7 @@ const ( | ||||||
| // Unlike the other test interfaces provided in this package, you'll probably want to call this function | // Unlike the other test interfaces provided in this package, you'll probably want to call this function | ||||||
| // PER TEST rather than per suite, so that the do function can be set on a test by test (or even more granular) | // PER TEST rather than per suite, so that the do function can be set on a test by test (or even more granular) | ||||||
| // basis. | // basis. | ||||||
| func NewTestTransportController(state *state.State, client pub.HttpClient) transport.Controller { | func NewTestTransportController(state *state.State, client httpclient.SigningClient) transport.Controller { | ||||||
| 	return transport.NewController(state, NewTestFederatingDB(state), &federation.Clock{}, client) | 	return transport.NewController(state, NewTestFederatingDB(state), &federation.Clock{}, client) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -225,6 +225,10 @@ func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) { | ||||||
| 	return m.do(req) | 	return m.do(req) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (m *MockHTTPClient) DoSigned(req *http.Request, sign httpclient.SignFunc) (*http.Response, error) { | ||||||
|  | 	return m.do(req) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func HostMetaResponse(req *http.Request) (responseCode int, responseBytes []byte, responseContentType string, responseContentLength int) { | func HostMetaResponse(req *http.Request) (responseCode int, responseBytes []byte, responseContentType string, responseContentLength int) { | ||||||
| 	var hm *apimodel.HostMeta | 	var hm *apimodel.HostMeta | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue