mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-11-13 05:17:30 -06:00
[feature] add rate limit middleware (#741)
* feat: add rate limit middleware * chore: update vendor dir * chore: update readme with new dependency * chore: add rate limit infos to swagger.md file * refactor: add ipv6 mask limiter option Add IPv6 CIDR /64 mask * refactor: increase rate limit to 1000 Address https://github.com/superseriousbusiness/gotosocial/pull/741#discussion_r945584800 Co-authored-by: tobi <31960611+tsmethurst@users.noreply.github.com>
This commit is contained in:
parent
daec9ab10e
commit
bee8458a2d
43 changed files with 4692 additions and 443 deletions
65
vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/middleware.go
generated
vendored
Normal file
65
vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/middleware.go
generated
vendored
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
package gin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ulule/limiter/v3"
|
||||
)
|
||||
|
||||
// Middleware is the middleware for gin.
|
||||
type Middleware struct {
|
||||
Limiter *limiter.Limiter
|
||||
OnError ErrorHandler
|
||||
OnLimitReached LimitReachedHandler
|
||||
KeyGetter KeyGetter
|
||||
ExcludedKey func(string) bool
|
||||
}
|
||||
|
||||
// NewMiddleware return a new instance of a gin middleware.
|
||||
func NewMiddleware(limiter *limiter.Limiter, options ...Option) gin.HandlerFunc {
|
||||
middleware := &Middleware{
|
||||
Limiter: limiter,
|
||||
OnError: DefaultErrorHandler,
|
||||
OnLimitReached: DefaultLimitReachedHandler,
|
||||
KeyGetter: DefaultKeyGetter,
|
||||
ExcludedKey: nil,
|
||||
}
|
||||
|
||||
for _, option := range options {
|
||||
option.apply(middleware)
|
||||
}
|
||||
|
||||
return func(ctx *gin.Context) {
|
||||
middleware.Handle(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle gin request.
|
||||
func (middleware *Middleware) Handle(c *gin.Context) {
|
||||
key := middleware.KeyGetter(c)
|
||||
if middleware.ExcludedKey != nil && middleware.ExcludedKey(key) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
context, err := middleware.Limiter.Get(c, key)
|
||||
if err != nil {
|
||||
middleware.OnError(c, err)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("X-RateLimit-Limit", strconv.FormatInt(context.Limit, 10))
|
||||
c.Header("X-RateLimit-Remaining", strconv.FormatInt(context.Remaining, 10))
|
||||
c.Header("X-RateLimit-Reset", strconv.FormatInt(context.Reset, 10))
|
||||
|
||||
if context.Reached {
|
||||
middleware.OnLimitReached(c)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
71
vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/options.go
generated
vendored
Normal file
71
vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/options.go
generated
vendored
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
package gin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Option is used to define Middleware configuration.
|
||||
type Option interface {
|
||||
apply(*Middleware)
|
||||
}
|
||||
|
||||
type option func(*Middleware)
|
||||
|
||||
func (o option) apply(middleware *Middleware) {
|
||||
o(middleware)
|
||||
}
|
||||
|
||||
// ErrorHandler is an handler used to inform when an error has occurred.
|
||||
type ErrorHandler func(c *gin.Context, err error)
|
||||
|
||||
// WithErrorHandler will configure the Middleware to use the given ErrorHandler.
|
||||
func WithErrorHandler(handler ErrorHandler) Option {
|
||||
return option(func(middleware *Middleware) {
|
||||
middleware.OnError = handler
|
||||
})
|
||||
}
|
||||
|
||||
// DefaultErrorHandler is the default ErrorHandler used by a new Middleware.
|
||||
func DefaultErrorHandler(c *gin.Context, err error) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// LimitReachedHandler is an handler used to inform when the limit has exceeded.
|
||||
type LimitReachedHandler func(c *gin.Context)
|
||||
|
||||
// WithLimitReachedHandler will configure the Middleware to use the given LimitReachedHandler.
|
||||
func WithLimitReachedHandler(handler LimitReachedHandler) Option {
|
||||
return option(func(middleware *Middleware) {
|
||||
middleware.OnLimitReached = handler
|
||||
})
|
||||
}
|
||||
|
||||
// DefaultLimitReachedHandler is the default LimitReachedHandler used by a new Middleware.
|
||||
func DefaultLimitReachedHandler(c *gin.Context) {
|
||||
c.String(http.StatusTooManyRequests, "Limit exceeded")
|
||||
}
|
||||
|
||||
// KeyGetter will define the rate limiter key given the gin Context.
|
||||
type KeyGetter func(c *gin.Context) string
|
||||
|
||||
// WithKeyGetter will configure the Middleware to use the given KeyGetter.
|
||||
func WithKeyGetter(handler KeyGetter) Option {
|
||||
return option(func(middleware *Middleware) {
|
||||
middleware.KeyGetter = handler
|
||||
})
|
||||
}
|
||||
|
||||
// DefaultKeyGetter is the default KeyGetter used by a new Middleware.
|
||||
// It returns the Client IP address.
|
||||
func DefaultKeyGetter(c *gin.Context) string {
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
||||
// WithExcludedKey will configure the Middleware to ignore key(s) using the given function.
|
||||
func WithExcludedKey(handler func(string) bool) Option {
|
||||
return option(func(middleware *Middleware) {
|
||||
middleware.ExcludedKey = handler
|
||||
})
|
||||
}
|
||||
28
vendor/github.com/ulule/limiter/v3/drivers/store/common/context.go
generated
vendored
Normal file
28
vendor/github.com/ulule/limiter/v3/drivers/store/common/context.go
generated
vendored
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/ulule/limiter/v3"
|
||||
)
|
||||
|
||||
// GetContextFromState generate a new limiter.Context from given state.
|
||||
func GetContextFromState(now time.Time, rate limiter.Rate, expiration time.Time, count int64) limiter.Context {
|
||||
limit := rate.Limit
|
||||
remaining := int64(0)
|
||||
reached := true
|
||||
|
||||
if count <= limit {
|
||||
remaining = limit - count
|
||||
reached = false
|
||||
}
|
||||
|
||||
reset := expiration.Unix()
|
||||
|
||||
return limiter.Context{
|
||||
Limit: limit,
|
||||
Remaining: remaining,
|
||||
Reset: reset,
|
||||
Reached: reached,
|
||||
}
|
||||
}
|
||||
240
vendor/github.com/ulule/limiter/v3/drivers/store/memory/cache.go
generated
vendored
Normal file
240
vendor/github.com/ulule/limiter/v3/drivers/store/memory/cache.go
generated
vendored
Normal file
|
|
@ -0,0 +1,240 @@
|
|||
package memory
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Forked from https://github.com/patrickmn/go-cache
|
||||
|
||||
// CacheWrapper is used to ensure that the underlying cleaner goroutine used to clean expired keys will not prevent
|
||||
// Cache from being garbage collected.
|
||||
type CacheWrapper struct {
|
||||
*Cache
|
||||
}
|
||||
|
||||
// A cleaner will periodically delete expired keys from cache.
|
||||
type cleaner struct {
|
||||
interval time.Duration
|
||||
stop chan bool
|
||||
}
|
||||
|
||||
// Run will periodically delete expired keys from given cache until GC notify that it should stop.
|
||||
func (cleaner *cleaner) Run(cache *Cache) {
|
||||
ticker := time.NewTicker(cleaner.interval)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cache.Clean()
|
||||
case <-cleaner.stop:
|
||||
ticker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// stopCleaner is a callback from GC used to stop cleaner goroutine.
|
||||
func stopCleaner(wrapper *CacheWrapper) {
|
||||
wrapper.cleaner.stop <- true
|
||||
wrapper.cleaner = nil
|
||||
}
|
||||
|
||||
// startCleaner will start a cleaner goroutine for given cache.
|
||||
func startCleaner(cache *Cache, interval time.Duration) {
|
||||
cleaner := &cleaner{
|
||||
interval: interval,
|
||||
stop: make(chan bool),
|
||||
}
|
||||
|
||||
cache.cleaner = cleaner
|
||||
go cleaner.Run(cache)
|
||||
}
|
||||
|
||||
// Counter is a simple counter with an expiration.
|
||||
type Counter struct {
|
||||
mutex sync.RWMutex
|
||||
value int64
|
||||
expiration int64
|
||||
}
|
||||
|
||||
// Value returns the counter current value.
|
||||
func (counter *Counter) Value() int64 {
|
||||
counter.mutex.RLock()
|
||||
defer counter.mutex.RUnlock()
|
||||
return counter.value
|
||||
}
|
||||
|
||||
// Expiration returns the counter expiration.
|
||||
func (counter *Counter) Expiration() int64 {
|
||||
counter.mutex.RLock()
|
||||
defer counter.mutex.RUnlock()
|
||||
return counter.expiration
|
||||
}
|
||||
|
||||
// Expired returns true if the counter has expired.
|
||||
func (counter *Counter) Expired() bool {
|
||||
counter.mutex.RLock()
|
||||
defer counter.mutex.RUnlock()
|
||||
|
||||
return counter.expiration == 0 || time.Now().UnixNano() > counter.expiration
|
||||
}
|
||||
|
||||
// Load returns the value and the expiration of this counter.
|
||||
// If the counter is expired, it will use the given expiration.
|
||||
func (counter *Counter) Load(expiration int64) (int64, int64) {
|
||||
counter.mutex.RLock()
|
||||
defer counter.mutex.RUnlock()
|
||||
|
||||
if counter.expiration == 0 || time.Now().UnixNano() > counter.expiration {
|
||||
return 0, expiration
|
||||
}
|
||||
|
||||
return counter.value, counter.expiration
|
||||
}
|
||||
|
||||
// Increment increments given value on this counter.
|
||||
// If the counter is expired, it will use the given expiration.
|
||||
// It returns its current value and expiration.
|
||||
func (counter *Counter) Increment(value int64, expiration int64) (int64, int64) {
|
||||
counter.mutex.Lock()
|
||||
defer counter.mutex.Unlock()
|
||||
|
||||
if counter.expiration == 0 || time.Now().UnixNano() > counter.expiration {
|
||||
counter.value = value
|
||||
counter.expiration = expiration
|
||||
return counter.value, counter.expiration
|
||||
}
|
||||
|
||||
counter.value += value
|
||||
return counter.value, counter.expiration
|
||||
}
|
||||
|
||||
// Cache contains a collection of counters.
|
||||
type Cache struct {
|
||||
counters sync.Map
|
||||
cleaner *cleaner
|
||||
}
|
||||
|
||||
// NewCache returns a new cache.
|
||||
func NewCache(cleanInterval time.Duration) *CacheWrapper {
|
||||
|
||||
cache := &Cache{}
|
||||
wrapper := &CacheWrapper{Cache: cache}
|
||||
|
||||
if cleanInterval > 0 {
|
||||
startCleaner(cache, cleanInterval)
|
||||
runtime.SetFinalizer(wrapper, stopCleaner)
|
||||
}
|
||||
|
||||
return wrapper
|
||||
}
|
||||
|
||||
// LoadOrStore returns the existing counter for the key if present.
|
||||
// Otherwise, it stores and returns the given counter.
|
||||
// The loaded result is true if the counter was loaded, false if stored.
|
||||
func (cache *Cache) LoadOrStore(key string, counter *Counter) (*Counter, bool) {
|
||||
val, loaded := cache.counters.LoadOrStore(key, counter)
|
||||
if val == nil {
|
||||
return counter, false
|
||||
}
|
||||
|
||||
actual := val.(*Counter)
|
||||
return actual, loaded
|
||||
}
|
||||
|
||||
// Load returns the counter stored in the map for a key, or nil if no counter is present.
|
||||
// The ok result indicates whether counter was found in the map.
|
||||
func (cache *Cache) Load(key string) (*Counter, bool) {
|
||||
val, ok := cache.counters.Load(key)
|
||||
if val == nil || !ok {
|
||||
return nil, false
|
||||
}
|
||||
actual := val.(*Counter)
|
||||
return actual, true
|
||||
}
|
||||
|
||||
// Store sets the counter for a key.
|
||||
func (cache *Cache) Store(key string, counter *Counter) {
|
||||
cache.counters.Store(key, counter)
|
||||
}
|
||||
|
||||
// Delete deletes the value for a key.
|
||||
func (cache *Cache) Delete(key string) {
|
||||
cache.counters.Delete(key)
|
||||
}
|
||||
|
||||
// Range calls handler sequentially for each key and value present in the cache.
|
||||
// If handler returns false, range stops the iteration.
|
||||
func (cache *Cache) Range(handler func(key string, counter *Counter)) {
|
||||
cache.counters.Range(func(k interface{}, v interface{}) bool {
|
||||
if v == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
key := k.(string)
|
||||
counter := v.(*Counter)
|
||||
|
||||
handler(key, counter)
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Increment increments given value on key.
|
||||
// If key is undefined or expired, it will create it.
|
||||
func (cache *Cache) Increment(key string, value int64, duration time.Duration) (int64, time.Time) {
|
||||
expiration := time.Now().Add(duration).UnixNano()
|
||||
|
||||
// If counter is in cache, try to load it first.
|
||||
counter, loaded := cache.Load(key)
|
||||
if loaded {
|
||||
value, expiration = counter.Increment(value, expiration)
|
||||
return value, time.Unix(0, expiration)
|
||||
}
|
||||
|
||||
// If it's not in cache, try to atomically create it.
|
||||
// We do that in two step to reduce memory allocation.
|
||||
counter, loaded = cache.LoadOrStore(key, &Counter{
|
||||
mutex: sync.RWMutex{},
|
||||
value: value,
|
||||
expiration: expiration,
|
||||
})
|
||||
if loaded {
|
||||
value, expiration = counter.Increment(value, expiration)
|
||||
return value, time.Unix(0, expiration)
|
||||
}
|
||||
|
||||
// Otherwise, it has been created, return given value.
|
||||
return value, time.Unix(0, expiration)
|
||||
}
|
||||
|
||||
// Get returns key's value and expiration.
|
||||
func (cache *Cache) Get(key string, duration time.Duration) (int64, time.Time) {
|
||||
expiration := time.Now().Add(duration).UnixNano()
|
||||
|
||||
counter, ok := cache.Load(key)
|
||||
if !ok {
|
||||
return 0, time.Unix(0, expiration)
|
||||
}
|
||||
|
||||
value, expiration := counter.Load(expiration)
|
||||
return value, time.Unix(0, expiration)
|
||||
}
|
||||
|
||||
// Clean will deleted any expired keys.
|
||||
func (cache *Cache) Clean() {
|
||||
cache.Range(func(key string, counter *Counter) {
|
||||
if counter.Expired() {
|
||||
cache.Delete(key)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Reset changes the key's value and resets the expiration.
|
||||
func (cache *Cache) Reset(key string, duration time.Duration) (int64, time.Time) {
|
||||
cache.Delete(key)
|
||||
|
||||
expiration := time.Now().Add(duration).UnixNano()
|
||||
return 0, time.Unix(0, expiration)
|
||||
}
|
||||
82
vendor/github.com/ulule/limiter/v3/drivers/store/memory/store.go
generated
vendored
Normal file
82
vendor/github.com/ulule/limiter/v3/drivers/store/memory/store.go
generated
vendored
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/ulule/limiter/v3"
|
||||
"github.com/ulule/limiter/v3/drivers/store/common"
|
||||
"github.com/ulule/limiter/v3/internal/bytebuffer"
|
||||
)
|
||||
|
||||
// Store is the in-memory store.
|
||||
type Store struct {
|
||||
// Prefix used for the key.
|
||||
Prefix string
|
||||
// cache used to store values in-memory.
|
||||
cache *CacheWrapper
|
||||
}
|
||||
|
||||
// NewStore creates a new instance of memory store with defaults.
|
||||
func NewStore() limiter.Store {
|
||||
return NewStoreWithOptions(limiter.StoreOptions{
|
||||
Prefix: limiter.DefaultPrefix,
|
||||
CleanUpInterval: limiter.DefaultCleanUpInterval,
|
||||
})
|
||||
}
|
||||
|
||||
// NewStoreWithOptions creates a new instance of memory store with options.
|
||||
func NewStoreWithOptions(options limiter.StoreOptions) limiter.Store {
|
||||
return &Store{
|
||||
Prefix: options.Prefix,
|
||||
cache: NewCache(options.CleanUpInterval),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the limit for given identifier.
|
||||
func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
|
||||
buffer := bytebuffer.New()
|
||||
defer buffer.Close()
|
||||
buffer.Concat(store.Prefix, ":", key)
|
||||
|
||||
count, expiration := store.cache.Increment(buffer.String(), 1, rate.Period)
|
||||
|
||||
lctx := common.GetContextFromState(time.Now(), rate, expiration, count)
|
||||
return lctx, nil
|
||||
}
|
||||
|
||||
// Increment increments the limit by given count & returns the new limit value for given identifier.
|
||||
func (store *Store) Increment(ctx context.Context, key string, count int64, rate limiter.Rate) (limiter.Context, error) {
|
||||
buffer := bytebuffer.New()
|
||||
defer buffer.Close()
|
||||
buffer.Concat(store.Prefix, ":", key)
|
||||
|
||||
newCount, expiration := store.cache.Increment(buffer.String(), count, rate.Period)
|
||||
|
||||
lctx := common.GetContextFromState(time.Now(), rate, expiration, newCount)
|
||||
return lctx, nil
|
||||
}
|
||||
|
||||
// Peek returns the limit for given identifier, without modification on current values.
|
||||
func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
|
||||
buffer := bytebuffer.New()
|
||||
defer buffer.Close()
|
||||
buffer.Concat(store.Prefix, ":", key)
|
||||
|
||||
count, expiration := store.cache.Get(buffer.String(), rate.Period)
|
||||
|
||||
lctx := common.GetContextFromState(time.Now(), rate, expiration, count)
|
||||
return lctx, nil
|
||||
}
|
||||
|
||||
// Reset returns the limit for given identifier.
|
||||
func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
|
||||
buffer := bytebuffer.New()
|
||||
defer buffer.Close()
|
||||
buffer.Concat(store.Prefix, ":", key)
|
||||
|
||||
count, expiration := store.cache.Reset(buffer.String(), rate.Period)
|
||||
|
||||
lctx := common.GetContextFromState(time.Now(), rate, expiration, count)
|
||||
return lctx, nil
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue