mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 01:52:26 -05:00 
			
		
		
		
	
		
			
	
	
		
			193 lines
		
	
	
	
		
			5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			193 lines
		
	
	
	
		
			5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
|  | // GoToSocial | ||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||
|  | // | ||
|  | // This program is free software: you can redistribute it and/or modify | ||
|  | // it under the terms of the GNU Affero General Public License as published by | ||
|  | // the Free Software Foundation, either version 3 of the License, or | ||
|  | // (at your option) any later version. | ||
|  | // | ||
|  | // This program is distributed in the hope that it will be useful, | ||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||
|  | // GNU Affero General Public License for more details. | ||
|  | // | ||
|  | // You should have received a copy of the GNU Affero General Public License | ||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||
|  | 
 | ||
|  | package middleware_test | ||
|  | 
 | ||
|  | import ( | ||
|  | 	"net/http" | ||
|  | 	"net/http/httptest" | ||
|  | 	"strconv" | ||
|  | 	"testing" | ||
|  | 	"time" | ||
|  | 
 | ||
|  | 	"github.com/gin-gonic/gin" | ||
|  | 	"github.com/stretchr/testify/suite" | ||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/middleware" | ||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/util" | ||
|  | ) | ||
|  | 
 | ||
|  | type RateLimitTestSuite struct { | ||
|  | 	suite.Suite | ||
|  | } | ||
|  | 
 | ||
|  | func (suite *RateLimitTestSuite) TestRateLimit() { | ||
|  | 	// Suppress warnings about debug mode. | ||
|  | 	gin.SetMode(gin.ReleaseMode) | ||
|  | 
 | ||
|  | 	const ( | ||
|  | 		trustedPlatform = "X-Test-IP" | ||
|  | 		rlLimit         = "X-RateLimit-Limit" | ||
|  | 		rlRemaining     = "X-RateLimit-Remaining" | ||
|  | 		rlReset         = "X-RateLimit-Reset" | ||
|  | 	) | ||
|  | 
 | ||
|  | 	type rlTest struct { | ||
|  | 		limit        int | ||
|  | 		exceptions   []string | ||
|  | 		clientIP     string | ||
|  | 		shouldPanic  bool | ||
|  | 		shouldExcept bool | ||
|  | 	} | ||
|  | 
 | ||
|  | 	for _, test := range []rlTest{ | ||
|  | 		{ | ||
|  | 			limit:        10, | ||
|  | 			exceptions:   []string{}, | ||
|  | 			clientIP:     "192.0.2.0", | ||
|  | 			shouldPanic:  false, | ||
|  | 			shouldExcept: false, | ||
|  | 		}, | ||
|  | 		{ | ||
|  | 			limit:        10, | ||
|  | 			exceptions:   []string{}, | ||
|  | 			clientIP:     "192.0.2.0", | ||
|  | 			shouldPanic:  false, | ||
|  | 			shouldExcept: false, | ||
|  | 		}, | ||
|  | 		{ | ||
|  | 			limit:        10, | ||
|  | 			exceptions:   []string{"192.0.2.0/24"}, | ||
|  | 			clientIP:     "192.0.2.0", | ||
|  | 			shouldPanic:  false, | ||
|  | 			shouldExcept: true, | ||
|  | 		}, | ||
|  | 		{ | ||
|  | 			limit:        10, | ||
|  | 			exceptions:   []string{"192.0.2.0/32"}, | ||
|  | 			clientIP:     "192.0.2.1", | ||
|  | 			shouldPanic:  false, | ||
|  | 			shouldExcept: false, | ||
|  | 		}, | ||
|  | 		{ | ||
|  | 			limit:        10, | ||
|  | 			exceptions:   []string{"Ceci n'est pas une CIDR"}, | ||
|  | 			clientIP:     "192.0.2.0", | ||
|  | 			shouldPanic:  true, | ||
|  | 			shouldExcept: false, | ||
|  | 		}, | ||
|  | 	} { | ||
|  | 		if test.shouldPanic { | ||
|  | 			// Try to trigger panic. | ||
|  | 			suite.Panics(func() { | ||
|  | 				_ = middleware.RateLimit( | ||
|  | 					test.limit, | ||
|  | 					test.exceptions, | ||
|  | 				) | ||
|  | 			}) | ||
|  | 			continue | ||
|  | 		} | ||
|  | 
 | ||
|  | 		rlMiddleware := middleware.RateLimit( | ||
|  | 			test.limit, | ||
|  | 			test.exceptions, | ||
|  | 		) | ||
|  | 
 | ||
|  | 		// Approximate time when this limiter will reset. | ||
|  | 		resetAt := time.Now().Add(5 * time.Minute) | ||
|  | 
 | ||
|  | 		// Make requests up to + | ||
|  | 		// just over the limit. | ||
|  | 		limitedAt := test.limit + 1 | ||
|  | 		for requestsCount := 1; requestsCount < limitedAt; requestsCount++ { | ||
|  | 			var ( | ||
|  | 				recorder = httptest.NewRecorder() | ||
|  | 				ctx, e   = gin.CreateTestContext(recorder) | ||
|  | 			) | ||
|  | 
 | ||
|  | 			// Instruct engine to derive | ||
|  | 			// clientIP from test header. | ||
|  | 			e.TrustedPlatform = trustedPlatform | ||
|  | 			ctx.Request = httptest.NewRequest(http.MethodGet, "/example", nil) | ||
|  | 			ctx.Request.Header.Add(trustedPlatform, test.clientIP) | ||
|  | 
 | ||
|  | 			// Call the rate limiter. | ||
|  | 			rlMiddleware(ctx) | ||
|  | 
 | ||
|  | 			// Fetch RL headers if present. | ||
|  | 			var ( | ||
|  | 				limitStr     = recorder.Header().Get(rlLimit) | ||
|  | 				remainingStr = recorder.Header().Get(rlRemaining) | ||
|  | 				resetStr     = recorder.Header().Get(rlReset) | ||
|  | 			) | ||
|  | 
 | ||
|  | 			if test.shouldExcept { | ||
|  | 				// Request should be allowed through, | ||
|  | 				// no rate-limit headers should be written. | ||
|  | 				suite.Equal(http.StatusOK, recorder.Code) | ||
|  | 				suite.Empty(limitStr) | ||
|  | 				suite.Empty(remainingStr) | ||
|  | 				suite.Empty(resetStr) | ||
|  | 				continue | ||
|  | 			} | ||
|  | 
 | ||
|  | 			suite.Equal(strconv.Itoa(test.limit), limitStr) | ||
|  | 			suite.Equal(strconv.Itoa(test.limit-requestsCount), remainingStr) | ||
|  | 
 | ||
|  | 			// Ensure reset is ISO8601, and resets at | ||
|  | 			// approximate reset time (+/- 10 seconds). | ||
|  | 			reset, err := util.ParseISO8601(resetStr) | ||
|  | 			if err != nil { | ||
|  | 				suite.FailNow("", "couldn't parse %s as ISO8601: %q", resetStr, err.Error()) | ||
|  | 			} | ||
|  | 			suite.WithinDuration(resetAt, reset, 10*time.Second) | ||
|  | 
 | ||
|  | 			if requestsCount < limitedAt { | ||
|  | 				// Request should be allowed through. | ||
|  | 				suite.Equal(http.StatusOK, recorder.Code) | ||
|  | 				continue | ||
|  | 			} | ||
|  | 
 | ||
|  | 			// Request should be denied. | ||
|  | 			suite.Equal(http.StatusTooManyRequests, recorder.Code) | ||
|  | 
 | ||
|  | 			// Make a final request with an unrelated IP to | ||
|  | 			// ensure it's only the one IP being limited. | ||
|  | 			var ( | ||
|  | 				unrelatedRecorder        = httptest.NewRecorder() | ||
|  | 				unrelatedCtx, unrelatedE = gin.CreateTestContext(unrelatedRecorder) | ||
|  | 			) | ||
|  | 
 | ||
|  | 			// Instruct engine to derive | ||
|  | 			// clientIP from test header. | ||
|  | 			unrelatedE.TrustedPlatform = trustedPlatform | ||
|  | 			unrelatedCtx.Request = httptest.NewRequest(http.MethodGet, "/example", nil) | ||
|  | 			unrelatedCtx.Request.Header.Add(trustedPlatform, "192.0.2.255") | ||
|  | 
 | ||
|  | 			// Call the rate limiter. | ||
|  | 			rlMiddleware(unrelatedCtx) | ||
|  | 
 | ||
|  | 			// Request should be allowed through. | ||
|  | 			suite.Equal(http.StatusOK, unrelatedRecorder.Code) | ||
|  | 
 | ||
|  | 		} | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | func TestRateLimitTestSuite(t *testing.T) { | ||
|  | 	suite.Run(t, new(RateLimitTestSuite)) | ||
|  | } |