| 
									
										
										
										
											2024-04-11 10:45:35 +01:00
										 |  |  | package delivery_test | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"fmt" | 
					
						
							|  |  |  | 	"io" | 
					
						
							|  |  |  | 	"math/rand" | 
					
						
							|  |  |  | 	"net" | 
					
						
							|  |  |  | 	"net/http" | 
					
						
							|  |  |  | 	"strconv" | 
					
						
							|  |  |  | 	"strings" | 
					
						
							|  |  |  | 	"testing" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"codeberg.org/gruf/go-byteutil" | 
					
						
							|  |  |  | 	"github.com/superseriousbusiness/gotosocial/internal/config" | 
					
						
							|  |  |  | 	"github.com/superseriousbusiness/gotosocial/internal/httpclient" | 
					
						
							|  |  |  | 	"github.com/superseriousbusiness/gotosocial/internal/queue" | 
					
						
							|  |  |  | 	"github.com/superseriousbusiness/gotosocial/internal/transport/delivery" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func TestDeliveryWorkerPool(t *testing.T) { | 
					
						
							|  |  |  | 	for _, i := range []int{1, 2, 4, 8, 16, 32} { | 
					
						
							|  |  |  | 		t.Run("size="+strconv.Itoa(i), func(t *testing.T) { | 
					
						
							|  |  |  | 			testDeliveryWorkerPool(t, i, generateInput(100*i)) | 
					
						
							|  |  |  | 		}) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func testDeliveryWorkerPool(t *testing.T, sz int, input []*testrequest) { | 
					
						
							|  |  |  | 	wp := new(delivery.WorkerPool) | 
					
						
							|  |  |  | 	wp.Init(httpclient.New(httpclient.Config{ | 
					
						
							|  |  |  | 		AllowRanges: config.MustParseIPPrefixes([]string{ | 
					
						
							|  |  |  | 			"127.0.0.0/8", | 
					
						
							|  |  |  | 		}), | 
					
						
							|  |  |  | 	})) | 
					
						
							| 
									
										
										
										
											2024-04-26 13:50:46 +01:00
										 |  |  | 	wp.Start(sz) | 
					
						
							| 
									
										
										
										
											2024-04-11 10:45:35 +01:00
										 |  |  | 	defer wp.Stop() | 
					
						
							|  |  |  | 	test(t, &wp.Queue, input) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func test( | 
					
						
							|  |  |  | 	t *testing.T, | 
					
						
							|  |  |  | 	queue *queue.StructQueue[*delivery.Delivery], | 
					
						
							|  |  |  | 	input []*testrequest, | 
					
						
							|  |  |  | ) { | 
					
						
							|  |  |  | 	expect := make(chan *testrequest) | 
					
						
							|  |  |  | 	errors := make(chan error) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Prepare an HTTP test handler that ensures expected delivery is received. | 
					
						
							|  |  |  | 	handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { | 
					
						
							|  |  |  | 		errors <- (<-expect).Equal(r) | 
					
						
							|  |  |  | 	}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Start new HTTP test server listener. | 
					
						
							|  |  |  | 	l, err := net.Listen("tcp", "127.0.0.1:0") | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		t.Fatal(err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	defer l.Close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Start the HTTP server. | 
					
						
							|  |  |  | 	// | 
					
						
							|  |  |  | 	// specifically not using httptest.Server{} here as httptest | 
					
						
							|  |  |  | 	// links that server with its own http.Client{}, whereas we're | 
					
						
							|  |  |  | 	// using an httpclient.Client{} (well, delivery routine is). | 
					
						
							|  |  |  | 	srv := new(http.Server) | 
					
						
							|  |  |  | 	srv.Addr = "http://" + l.Addr().String() | 
					
						
							|  |  |  | 	srv.Handler = handler | 
					
						
							|  |  |  | 	go srv.Serve(l) | 
					
						
							|  |  |  | 	defer srv.Close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Range over test input. | 
					
						
							|  |  |  | 	for _, test := range input { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// Generate req for input. | 
					
						
							|  |  |  | 		req := test.Generate(srv.Addr) | 
					
						
							|  |  |  | 		r := httpclient.WrapRequest(req) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// Wrap the request in delivery. | 
					
						
							|  |  |  | 		dlv := new(delivery.Delivery) | 
					
						
							|  |  |  | 		dlv.Request = r | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// Enqueue delivery! | 
					
						
							|  |  |  | 		queue.Push(dlv) | 
					
						
							|  |  |  | 		expect <- test | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// Wait for errors from handler. | 
					
						
							|  |  |  | 		if err := <-errors; err != nil { | 
					
						
							|  |  |  | 			t.Error(err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type testrequest struct { | 
					
						
							|  |  |  | 	method string | 
					
						
							|  |  |  | 	uri    string | 
					
						
							|  |  |  | 	body   []byte | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // generateInput generates 'n' many testrequest cases. | 
					
						
							|  |  |  | func generateInput(n int) []*testrequest { | 
					
						
							|  |  |  | 	tests := make([]*testrequest, n) | 
					
						
							|  |  |  | 	for i := range tests { | 
					
						
							|  |  |  | 		tests[i] = new(testrequest) | 
					
						
							|  |  |  | 		tests[i].method = randomMethod() | 
					
						
							|  |  |  | 		tests[i].uri = randomURI() | 
					
						
							|  |  |  | 		tests[i].body = randomBody(tests[i].method) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return tests | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | var methods = []string{ | 
					
						
							|  |  |  | 	http.MethodConnect, | 
					
						
							|  |  |  | 	http.MethodDelete, | 
					
						
							|  |  |  | 	http.MethodGet, | 
					
						
							|  |  |  | 	http.MethodHead, | 
					
						
							|  |  |  | 	http.MethodOptions, | 
					
						
							|  |  |  | 	http.MethodPatch, | 
					
						
							|  |  |  | 	http.MethodPost, | 
					
						
							|  |  |  | 	http.MethodPut, | 
					
						
							|  |  |  | 	http.MethodTrace, | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // randomMethod generates a random http method. | 
					
						
							|  |  |  | func randomMethod() string { | 
					
						
							|  |  |  | 	return methods[rand.Intn(len(methods))] | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // randomURI generates a random http uri. | 
					
						
							|  |  |  | func randomURI() string { | 
					
						
							|  |  |  | 	n := rand.Intn(5) | 
					
						
							|  |  |  | 	p := make([]string, n) | 
					
						
							|  |  |  | 	for i := range p { | 
					
						
							|  |  |  | 		p[i] = strconv.Itoa(rand.Int()) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return "/" + strings.Join(p, "/") | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // randomBody generates a random http body DEPENDING on method. | 
					
						
							|  |  |  | func randomBody(method string) []byte { | 
					
						
							|  |  |  | 	if requiresBody(method) { | 
					
						
							|  |  |  | 		return []byte(method + " " + randomURI()) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // requiresBody returns whether method requires body. | 
					
						
							|  |  |  | func requiresBody(method string) bool { | 
					
						
							|  |  |  | 	switch method { | 
					
						
							|  |  |  | 	case http.MethodPatch, | 
					
						
							|  |  |  | 		http.MethodPost, | 
					
						
							|  |  |  | 		http.MethodPut: | 
					
						
							|  |  |  | 		return true | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		return false | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Generate will generate a real http.Request{} from test data. | 
					
						
							|  |  |  | func (t *testrequest) Generate(addr string) *http.Request { | 
					
						
							|  |  |  | 	var body io.ReadCloser | 
					
						
							|  |  |  | 	if t.body != nil { | 
					
						
							|  |  |  | 		var b byteutil.ReadNopCloser | 
					
						
							|  |  |  | 		b.Reset(t.body) | 
					
						
							|  |  |  | 		body = &b | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	req, err := http.NewRequest(t.method, addr+t.uri, body) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		panic(err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return req | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Equal checks if request matches receiving test request. | 
					
						
							|  |  |  | func (t *testrequest) Equal(r *http.Request) error { | 
					
						
							|  |  |  | 	// Ensure methods match. | 
					
						
							|  |  |  | 	if t.method != r.Method { | 
					
						
							|  |  |  | 		return fmt.Errorf("differing request methods: t=%q r=%q", t.method, r.Method) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Ensure request URIs match. | 
					
						
							|  |  |  | 	if t.uri != r.URL.RequestURI() { | 
					
						
							|  |  |  | 		return fmt.Errorf("differing request urls: t=%q r=%q", t.uri, r.URL.RequestURI()) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Ensure body cases match. | 
					
						
							|  |  |  | 	if requiresBody(t.method) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// Read request into memory. | 
					
						
							|  |  |  | 		b, err := io.ReadAll(r.Body) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return fmt.Errorf("error reading request body: %v", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// Compare the request bodies. | 
					
						
							|  |  |  | 		st := strings.TrimSpace(string(t.body)) | 
					
						
							|  |  |  | 		sr := strings.TrimSpace(string(b)) | 
					
						
							|  |  |  | 		if st != sr { | 
					
						
							|  |  |  | 			return fmt.Errorf("differing request bodies: t=%q r=%q", st, sr) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } |