[feature] request blocking by http headers (#2409)

This commit is contained in:
kim 2023-12-18 14:18:25 +00:00 committed by GitHub
commit 8ebb7775a3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
36 changed files with 2561 additions and 81 deletions

View file

@ -35,6 +35,10 @@ const (
DomainAllowsPath = BasePath + "/domain_allows"
DomainAllowsPathWithID = DomainAllowsPath + "/:" + IDKey
DomainKeysExpirePath = BasePath + "/domain_keys_expire"
HeaderAllowsPath = BasePath + "/header_allows"
HeaderAllowsPathWithID = HeaderAllowsPath + "/:" + IDKey
HeaderBlocksPath = BasePath + "/header_blocks"
HeaderBlocksPathWithID = HeaderAllowsPath + "/:" + IDKey
AccountsPath = BasePath + "/accounts"
AccountsPathWithID = AccountsPath + "/:" + IDKey
AccountsActionPath = AccountsPathWithID + "/action"
@ -95,6 +99,16 @@ func (m *Module) Route(attachHandler func(method string, path string, f ...gin.H
attachHandler(http.MethodGet, DomainAllowsPathWithID, m.DomainAllowGETHandler)
attachHandler(http.MethodDelete, DomainAllowsPathWithID, m.DomainAllowDELETEHandler)
// header filtering administration routes
attachHandler(http.MethodGet, HeaderAllowsPathWithID, m.HeaderFilterAllowGET)
attachHandler(http.MethodGet, HeaderBlocksPathWithID, m.HeaderFilterBlockGET)
attachHandler(http.MethodGet, HeaderAllowsPath, m.HeaderFilterAllowsGET)
attachHandler(http.MethodGet, HeaderBlocksPath, m.HeaderFilterBlocksGET)
attachHandler(http.MethodPost, HeaderAllowsPath, m.HeaderFilterAllowPOST)
attachHandler(http.MethodPost, HeaderBlocksPath, m.HeaderFilterBlockPOST)
attachHandler(http.MethodDelete, HeaderAllowsPathWithID, m.HeaderFilterAllowDELETE)
attachHandler(http.MethodDelete, HeaderBlocksPathWithID, m.HeaderFilterBlockDELETE)
// domain maintenance stuff
attachHandler(http.MethodPost, DomainKeysExpirePath, m.DomainKeysExpirePOSTHandler)

View file

@ -0,0 +1,173 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package admin
import (
"context"
"errors"
"net/http"
"github.com/gin-gonic/gin"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
// getHeaderFilter is a gin handler function that returns details of an HTTP header filter with provided ID, using given get function.
func (m *Module) getHeaderFilter(c *gin.Context, get func(context.Context, string) (*apimodel.HeaderFilter, gtserror.WithCode)) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
errWithCode := gtserror.NewErrorUnauthorized(err, err.Error())
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
if !*authed.User.Admin {
const text = "user not an admin"
errWithCode := gtserror.NewErrorForbidden(errors.New(text), text)
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
errWithCode := gtserror.NewErrorNotAcceptable(err, err.Error())
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
filterID, errWithCode := apiutil.ParseID(c.Param("ID"))
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
filter, errWithCode := get(c.Request.Context(), filterID)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
apiutil.JSON(c, http.StatusOK, filter)
}
// getHeaderFilters is a gin handler function that returns details of all HTTP header filters using given get function.
func (m *Module) getHeaderFilters(c *gin.Context, get func(context.Context) ([]*apimodel.HeaderFilter, gtserror.WithCode)) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
errWithCode := gtserror.NewErrorUnauthorized(err, err.Error())
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
if !*authed.User.Admin {
const text = "user not an admin"
errWithCode := gtserror.NewErrorForbidden(errors.New(text), text)
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
errWithCode := gtserror.NewErrorNotAcceptable(err, err.Error())
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
filters, errWithCode := get(c.Request.Context())
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
apiutil.JSON(c, http.StatusOK, filters)
}
// createHeaderFilter is a gin handler function that creates a HTTP header filter entry using provided form data, passing to given create function.
func (m *Module) createHeaderFilter(c *gin.Context, create func(context.Context, *gtsmodel.Account, *apimodel.HeaderFilterRequest) (*apimodel.HeaderFilter, gtserror.WithCode)) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
errWithCode := gtserror.NewErrorUnauthorized(err, err.Error())
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
if !*authed.User.Admin {
const text = "user not an admin"
errWithCode := gtserror.NewErrorForbidden(errors.New(text), text)
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
errWithCode := gtserror.NewErrorNotAcceptable(err, err.Error())
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
var form apimodel.HeaderFilterRequest
if err := c.ShouldBind(&form); err != nil {
errWithCode := gtserror.NewErrorBadRequest(err, err.Error())
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
filter, errWithCode := create(
c.Request.Context(),
authed.Account,
&form,
)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
apiutil.JSON(c, http.StatusOK, filter)
}
// deleteHeaderFilter is a gin handler function that deletes an HTTP header filter with provided ID, using given delete function.
func (m *Module) deleteHeaderFilter(c *gin.Context, delete func(context.Context, string) gtserror.WithCode) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
errWithCode := gtserror.NewErrorUnauthorized(err, err.Error())
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
if !*authed.User.Admin {
const text = "user not an admin"
errWithCode := gtserror.NewErrorForbidden(errors.New(text), text)
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
filterID, errWithCode := apiutil.ParseID(c.Param("ID"))
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
errWithCode = delete(c.Request.Context(), filterID)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
c.Status(http.StatusAccepted)
}

View file

@ -0,0 +1,102 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package admin
import (
"github.com/gin-gonic/gin"
)
// HeaderFilterAllowPOST swagger:operation POST /api/v1/admin/header_allows headerFilterAllowCreate
//
// Create new "allow" HTTP request header filter.
//
// The parameters can also be given in the body of the request, as JSON, if the content-type is set to 'application/json'.
// The parameters can also be given in the body of the request, as XML, if the content-type is set to 'application/xml'.
//
// ---
// tags:
// - admin
//
// consumes:
// - application/json
// - application/xml
// - application/x-www-form-urlencoded
//
// produces:
// - application/json
//
// security:
// - OAuth2 Bearer:
// - admin
//
// responses:
// '200':
// description: The newly created "allow" header filter.
// schema:
// "$ref": "#/definitions/headerFilter"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '403':
// description: forbidden
// '500':
// description: internal server error
func (m *Module) HeaderFilterAllowPOST(c *gin.Context) {
m.createHeaderFilter(c, m.processor.Admin().CreateAllowHeaderFilter)
}
// HeaderFilterBlockPOST swagger:operation POST /api/v1/admin/header_blocks headerFilterBlockCreate
//
// Create new "block" HTTP request header filter.
//
// The parameters can also be given in the body of the request, as JSON, if the content-type is set to 'application/json'.
// The parameters can also be given in the body of the request, as XML, if the content-type is set to 'application/xml'.
//
// ---
// tags:
// - admin
//
// consumes:
// - application/json
// - application/xml
// - application/x-www-form-urlencoded
//
// produces:
// - application/json
//
// security:
// - OAuth2 Bearer:
// - admin
//
// responses:
// '200':
// description: The newly created "block" header filter.
// schema:
// "$ref": "#/definitions/headerFilter"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '403':
// description: forbidden
// '500':
// description: internal server error
func (m *Module) HeaderFilterBlockPOST(c *gin.Context) {
m.createHeaderFilter(c, m.processor.Admin().CreateBlockHeaderFilter)
}

View file

@ -0,0 +1,96 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package admin
import (
"github.com/gin-gonic/gin"
)
// HeaderFilterAllowDELETE swagger:operation DELETE /api/v1/admin/header_allows/{id} headerFilterAllowDelete
//
// Delete the "allow" header filter with the given ID.
//
// ---
// tags:
// - admin
//
// parameters:
// -
// name: id
// type: string
// description: Target header filter ID.
// in: path
// required: true
//
// security:
// - OAuth2 Bearer:
// - admin
//
// responses:
// '202':
// description: Accepted
// '400':
// description: bad request
// '401':
// description: unauthorized
// '403':
// description: forbidden
// '404':
// description: not found
// '500':
// description: internal server error
func (m *Module) HeaderFilterAllowDELETE(c *gin.Context) {
m.deleteHeaderFilter(c, m.processor.Admin().DeleteAllowHeaderFilter)
}
// HeaderFilterBlockDELETE swagger:operation DELETE /api/v1/admin/header_blocks/{id} headerFilterBlockDelete
//
// Delete the "block" header filter with the given ID.
//
// ---
// tags:
// - admin
//
// parameters:
// -
// name: id
// type: string
// description: Target header filter ID.
// in: path
// required: true
//
// security:
// - OAuth2 Bearer:
// - admin
//
// responses:
// '202':
// description: Accepted
// '400':
// description: bad request
// '401':
// description: unauthorized
// '403':
// description: forbidden
// '404':
// description: not found
// '500':
// description: internal server error
func (m *Module) HeaderFilterBlockDELETE(c *gin.Context) {
m.deleteHeaderFilter(c, m.processor.Admin().DeleteAllowHeaderFilter)
}

View file

@ -0,0 +1,164 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package admin
import "github.com/gin-gonic/gin"
// HeaderFilterAllowGET swagger:operation GET /api/v1/admin/header_allows/{id} headerFilterAllowGet
//
// Get "allow" header filter with the given ID.
//
// ---
// tags:
// - admin
//
// parameters:
// -
// name: id
// type: string
// description: Target header filter ID.
// in: path
// required: true
//
// security:
// - OAuth2 Bearer:
// - admin
//
// responses:
// '200':
// description: The requested "allow" header filter.
// schema:
// "$ref": "#/definitions/headerFilter"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '403':
// description: forbidden
// '404':
// description: not found
// '500':
// description: internal server error
func (m *Module) HeaderFilterAllowGET(c *gin.Context) {
m.getHeaderFilter(c, m.processor.Admin().GetAllowHeaderFilter)
}
// HeaderFilterBlockGET swagger:operation GET /api/v1/admin/header_blocks/{id} headerFilterBlockGet
//
// Get "block" header filter with the given ID.
//
// ---
// tags:
// - admin
//
// parameters:
// -
// name: id
// type: string
// description: Target header filter ID.
// in: path
// required: true
//
// security:
// - OAuth2 Bearer:
// - admin
//
// responses:
// '200':
// description: The requested "block" header filter.
// schema:
// "$ref": "#/definitions/headerFilter"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '403':
// description: forbidden
// '404':
// description: not found
// '500':
// description: internal server error
func (m *Module) HeaderFilterBlockGET(c *gin.Context) {
m.getHeaderFilter(c, m.processor.Admin().GetBlockHeaderFilter)
}
// HeaderFilterAllowsGET swagger:operation GET /api/v1/admin/header_allows headerFilterAllowsGet
//
// Get all "allow" header filters currently in place.
//
// ---
// tags:
// - admin
//
// security:
// - OAuth2 Bearer:
// - admin
//
// responses:
// '200':
// description: All "allow" header filters currently in place.
// schema:
// type: array
// items:
// "$ref": "#/definitions/headerFilter"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '403':
// description: forbidden
// '404':
// description: not found
// '500':
// description: internal server error
func (m *Module) HeaderFilterAllowsGET(c *gin.Context) {
m.getHeaderFilters(c, m.processor.Admin().GetAllowHeaderFilters)
}
// HeaderFilterBlocksGET swagger:operation GET /api/v1/admin/header_blocks headerFilterBlocksGet
//
// Get all "allow" header filters currently in place.
//
// ---
// tags:
// - admin
//
// security:
// - OAuth2 Bearer:
// - admin
//
// responses:
// '200':
// description: All "block" header filters currently in place.
// schema:
// type: array
// items:
// "$ref": "#/definitions/headerFilter"
// '400':
// description: bad request
// '401':
// description: unauthorized
// '403':
// description: forbidden
// '404':
// description: not found
// '500':
// description: internal server error
func (m *Module) HeaderFilterBlocksGET(c *gin.Context) {
m.getHeaderFilters(c, m.processor.Admin().GetBlockHeaderFilters)
}

View file

@ -0,0 +1,55 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package model
// HeaderFilter represents a regex value filter applied to one particular HTTP header (allow / block).
type HeaderFilter struct {
// The ID of the header filter.
// example: 01FBW21XJA09XYX51KV5JVBW0F
// readonly: true
ID string `json:"id"`
// The HTTP header to match against.
// example: User-Agent
Header string `json:"header"`
// The header value matching regular expression.
// example: .*Firefox.*
Regex string `json:"regex"`
// The ID of the admin account that created this header filter.
// example: 01FBW2758ZB6PBR200YPDDJK4C
// readonly: true
CreatedBy string `json:"created_by"`
// Time at which the header filter was created (ISO 8601 Datetime).
// example: 2021-07-30T09:20:25+00:00
// readonly: true
CreatedAt string `json:"created_at"`
}
// HeaderFilterRequest is the form submitted as a POST to create a new header filter entry (allow / block).
//
// swagger:model headerFilterCreateRequest
type HeaderFilterRequest struct {
// The HTTP header to match against (e.g. User-Agent).
Header string `form:"header" json:"header" xml:"header"`
// The header value matching regular expression.
Regex string `form:"regex" json:"regex" xml:"regex"`
}

View file

@ -39,14 +39,17 @@ var (
StatusAcceptedJSON = mustJSON(map[string]string{
"status": http.StatusText(http.StatusAccepted),
})
StatusForbiddenJSON = mustJSON(map[string]string{
"status": http.StatusText(http.StatusForbidden),
})
StatusInternalServerErrorJSON = mustJSON(map[string]string{
"status": http.StatusText(http.StatusInternalServerError),
})
ErrorCapacityExceeded = mustJSON(map[string]string{
"error": "server capacity exceeded!",
"error": "server capacity exceeded",
})
ErrorRateLimitReached = mustJSON(map[string]string{
"error": "rate limit reached!",
ErrorRateLimited = mustJSON(map[string]string{
"error": "rate limit reached",
})
EmptyJSONObject = mustJSON("{}")
EmptyJSONArray = mustJSON("[]")

View file

@ -18,21 +18,26 @@
package cache
import (
"github.com/superseriousbusiness/gotosocial/internal/cache/headerfilter"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
type Caches struct {
// GTS provides access to the collection of gtsmodel object caches.
// (used by the database).
// GTS provides access to the collection of
// gtsmodel object caches. (used by the database).
GTS GTSCaches
// AP provides access to the collection of ActivityPub object caches.
// (planned to be used by the typeconverter).
AP APCaches
// AllowHeaderFilters provides access to
// the allow []headerfilter.Filter cache.
AllowHeaderFilters headerfilter.Cache
// Visibility provides access to the item visibility cache.
// (used by the visibility filter).
// BlockHeaderFilters provides access to
// the block []headerfilter.Filter cache.
BlockHeaderFilters headerfilter.Cache
// Visibility provides access to the item visibility
// cache. (used by the visibility filter).
Visibility VisibilityCache
// prevent pass-by-value.
@ -45,7 +50,6 @@ func (c *Caches) Init() {
log.Infof(nil, "init: %p", c)
c.GTS.Init()
c.AP.Init()
c.Visibility.Init()
// Setup cache invalidate hooks.
@ -58,7 +62,6 @@ func (c *Caches) Start() {
log.Infof(nil, "start: %p", c)
c.GTS.Start()
c.AP.Start()
c.Visibility.Start()
}
@ -67,7 +70,6 @@ func (c *Caches) Stop() {
log.Infof(nil, "stop: %p", c)
c.GTS.Stop()
c.AP.Stop()
c.Visibility.Stop()
}

View file

@ -21,7 +21,6 @@ import (
"fmt"
"strings"
"sync/atomic"
"unsafe"
"golang.org/x/exp/slices"
)
@ -37,17 +36,17 @@ import (
// The .Clear() function can be used to invalidate the cache,
// e.g. when an entry is added / deleted from the database.
type Cache struct {
// atomically updated ptr value to the
// current domain cache radix trie.
rootptr unsafe.Pointer
rootptr atomic.Pointer[root]
}
// Matches checks whether domain matches an entry in the cache.
// If the cache is not currently loaded, then the provided load
// function is used to hydrate it.
func (c *Cache) Matches(domain string, load func() ([]string, error)) (bool, error) {
// Load the current root pointer value.
ptr := atomic.LoadPointer(&c.rootptr)
// Load the current
// root pointer value.
ptr := c.rootptr.Load()
if ptr == nil {
// Cache is not hydrated.
@ -60,35 +59,32 @@ func (c *Cache) Matches(domain string, load func() ([]string, error)) (bool, err
// Allocate new radix trie
// node to store matches.
root := new(root)
ptr = new(root)
// Add each domain to the trie.
for _, domain := range domains {
root.Add(domain)
ptr.Add(domain)
}
// Sort the trie.
root.Sort()
ptr.Sort()
// Store the new node ptr.
ptr = unsafe.Pointer(root)
atomic.StorePointer(&c.rootptr, ptr)
// Store new node ptr.
c.rootptr.Store(ptr)
}
// Look for a match in the trie node.
return (*root)(ptr).Match(domain), nil
// Look for match in trie node.
return ptr.Match(domain), nil
}
// Clear will drop the currently loaded domain list,
// triggering a reload on next call to .Matches().
func (c *Cache) Clear() {
atomic.StorePointer(&c.rootptr, nil)
}
func (c *Cache) Clear() { c.rootptr.Store(nil) }
// String returns a string representation of stored domains in cache.
func (c *Cache) String() string {
if ptr := atomic.LoadPointer(&c.rootptr); ptr != nil {
return (*root)(ptr).String()
if ptr := c.rootptr.Load(); ptr != nil {
return ptr.String()
}
return "<empty>"
}

105
internal/cache/headerfilter/filter.go vendored Normal file
View file

@ -0,0 +1,105 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package headerfilter
import (
"fmt"
"net/http"
"sync/atomic"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/headerfilter"
)
// Cache provides a means of caching headerfilter.Filters in
// memory to reduce load on an underlying storage mechanism.
type Cache struct {
// current cached header filters slice.
ptr atomic.Pointer[headerfilter.Filters]
}
// RegularMatch performs .RegularMatch() on cached headerfilter.Filters, loading using callback if necessary.
func (c *Cache) RegularMatch(h http.Header, load func() ([]*gtsmodel.HeaderFilter, error)) (string, string, error) {
// Load ptr value.
ptr := c.ptr.Load()
if ptr == nil {
// Cache is not hydrated.
// Load filters from callback.
filters, err := loadFilters(load)
if err != nil {
return "", "", err
}
// Store the new
// header filters.
ptr = &filters
c.ptr.Store(ptr)
}
// Deref and perform match.
return ptr.RegularMatch(h)
}
// InverseMatch performs .InverseMatch() on cached headerfilter.Filters, loading using callback if necessary.
func (c *Cache) InverseMatch(h http.Header, load func() ([]*gtsmodel.HeaderFilter, error)) (string, string, error) {
// Load ptr value.
ptr := c.ptr.Load()
if ptr == nil {
// Cache is not hydrated.
// Load filters from callback.
filters, err := loadFilters(load)
if err != nil {
return "", "", err
}
// Store the new
// header filters.
ptr = &filters
c.ptr.Store(ptr)
}
// Deref and perform match.
return ptr.InverseMatch(h)
}
// Clear will drop the currently loaded filters,
// triggering a reload on next call to ._Match().
func (c *Cache) Clear() { c.ptr.Store(nil) }
// loadFilters will load filters from given load callback, creating and parsing raw filters.
func loadFilters(load func() ([]*gtsmodel.HeaderFilter, error)) (headerfilter.Filters, error) {
// Load filters from callback.
hdrFilters, err := load()
if err != nil {
return nil, fmt.Errorf("error reloading cache: %w", err)
}
// Allocate new header filter slice to store expressions.
filters := make(headerfilter.Filters, 0, len(hdrFilters))
// Add all raw expression to filter slice.
for _, filter := range hdrFilters {
if err := filters.Append(filter.Header, filter.Regex); err != nil {
return nil, fmt.Errorf("error appending exprs: %w", err)
}
}
return filters, nil
}

View file

@ -163,6 +163,7 @@ type Configuration struct {
AdvancedThrottlingRetryAfter time.Duration `name:"advanced-throttling-retry-after" usage:"Retry-After duration response to send for throttled requests."`
AdvancedSenderMultiplier int `name:"advanced-sender-multiplier" usage:"Multiplier to use per cpu for batching outgoing fedi messages. 0 or less turns batching off (not recommended)."`
AdvancedCSPExtraURIs []string `name:"advanced-csp-extra-uris" usage:"Additional URIs to allow when building content-security-policy for media + images."`
AdvancedHeaderFilterMode string `name:"advanced-header-filter-mode" usage:"Set incoming request header filtering mode."`
// HTTPClient configuration vars.
HTTPClient HTTPClientConfiguration `name:"http-client"`

View file

@ -17,10 +17,16 @@
package config
// Instance federation mode determines how this
// instance federates with others (if at all).
const (
// Instance federation mode determines how this
// instance federates with others (if at all).
InstanceFederationModeBlocklist = "blocklist"
InstanceFederationModeAllowlist = "allowlist"
InstanceFederationModeDefault = InstanceFederationModeBlocklist
// Request header filter mode determines how
// this instance will perform request filtering.
RequestHeaderFilterModeAllow = "allow"
RequestHeaderFilterModeBlock = "block"
RequestHeaderFilterModeDisabled = ""
)

View file

@ -135,6 +135,7 @@ var Defaults = Configuration{
AdvancedThrottlingRetryAfter: time.Second * 30,
AdvancedSenderMultiplier: 2, // 2 senders per CPU
AdvancedCSPExtraURIs: []string{},
AdvancedHeaderFilterMode: RequestHeaderFilterModeDisabled,
Cache: CacheConfiguration{
// Rough memory target that the total

View file

@ -158,6 +158,7 @@ func (s *ConfigState) AddServerFlags(cmd *cobra.Command) {
cmd.Flags().Duration(AdvancedThrottlingRetryAfterFlag(), cfg.AdvancedThrottlingRetryAfter, fieldtag("AdvancedThrottlingRetryAfter", "usage"))
cmd.Flags().Int(AdvancedSenderMultiplierFlag(), cfg.AdvancedSenderMultiplier, fieldtag("AdvancedSenderMultiplier", "usage"))
cmd.Flags().StringSlice(AdvancedCSPExtraURIsFlag(), cfg.AdvancedCSPExtraURIs, fieldtag("AdvancedCSPExtraURIs", "usage"))
cmd.Flags().String(AdvancedHeaderFilterModeFlag(), cfg.AdvancedHeaderFilterMode, fieldtag("AdvancedHeaderFilterMode", "usage"))
cmd.Flags().String(RequestIDHeaderFlag(), cfg.RequestIDHeader, fieldtag("RequestIDHeader", "usage"))
})

View file

@ -2600,6 +2600,31 @@ func GetAdvancedCSPExtraURIs() []string { return global.GetAdvancedCSPExtraURIs(
// SetAdvancedCSPExtraURIs safely sets the value for global configuration 'AdvancedCSPExtraURIs' field
func SetAdvancedCSPExtraURIs(v []string) { global.SetAdvancedCSPExtraURIs(v) }
// GetAdvancedHeaderFilterMode safely fetches the Configuration value for state's 'AdvancedHeaderFilterMode' field
func (st *ConfigState) GetAdvancedHeaderFilterMode() (v string) {
st.mutex.RLock()
v = st.config.AdvancedHeaderFilterMode
st.mutex.RUnlock()
return
}
// SetAdvancedHeaderFilterMode safely sets the Configuration value for state's 'AdvancedHeaderFilterMode' field
func (st *ConfigState) SetAdvancedHeaderFilterMode(v string) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.AdvancedHeaderFilterMode = v
st.reloadToViper()
}
// AdvancedHeaderFilterModeFlag returns the flag name for the 'AdvancedHeaderFilterMode' field
func AdvancedHeaderFilterModeFlag() string { return "advanced-header-filter-mode" }
// GetAdvancedHeaderFilterMode safely fetches the value for global configuration 'AdvancedHeaderFilterMode' field
func GetAdvancedHeaderFilterMode() string { return global.GetAdvancedHeaderFilterMode() }
// SetAdvancedHeaderFilterMode safely sets the value for global configuration 'AdvancedHeaderFilterMode' field
func SetAdvancedHeaderFilterMode(v string) { global.SetAdvancedHeaderFilterMode(v) }
// GetHTTPClientAllowIPs safely fetches the Configuration value for state's 'HTTPClient.AllowIPs' field
func (st *ConfigState) GetHTTPClientAllowIPs() (v []string) {
st.mutex.RLock()

View file

@ -25,10 +25,6 @@ type Basic interface {
// For implementations that don't use tables, this can just return nil.
CreateTable(ctx context.Context, i interface{}) error
// CreateAllTables creates *all* tables necessary for the running of GoToSocial.
// Because it uses the 'if not exists' parameter it is safe to run against a GtS that's already been initialized.
CreateAllTables(ctx context.Context) error
// DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil.
DropTable(ctx context.Context, i interface{}) error

View file

@ -22,7 +22,6 @@ import (
"errors"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
)
@ -120,39 +119,6 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) error {
return err
}
func (b *basicDB) CreateAllTables(ctx context.Context) error {
models := []interface{}{
&gtsmodel.Account{},
&gtsmodel.Application{},
&gtsmodel.Block{},
&gtsmodel.DomainBlock{},
&gtsmodel.EmailDomainBlock{},
&gtsmodel.Follow{},
&gtsmodel.FollowRequest{},
&gtsmodel.MediaAttachment{},
&gtsmodel.Mention{},
&gtsmodel.Status{},
&gtsmodel.StatusToEmoji{},
&gtsmodel.StatusFave{},
&gtsmodel.StatusBookmark{},
&gtsmodel.ThreadMute{},
&gtsmodel.Tag{},
&gtsmodel.User{},
&gtsmodel.Emoji{},
&gtsmodel.Instance{},
&gtsmodel.Notification{},
&gtsmodel.RouterSession{},
&gtsmodel.Token{},
&gtsmodel.Client{},
}
for _, i := range models {
if err := b.CreateTable(ctx, i); err != nil {
return err
}
}
return nil
}
func (b *basicDB) DropTable(ctx context.Context, i interface{}) error {
_, err := b.db.NewDropTable().Model(i).IfExists().Exec(ctx)
return err

View file

@ -67,6 +67,7 @@ type DBService struct {
db.Basic
db.Domain
db.Emoji
db.HeaderFilter
db.Instance
db.List
db.Marker
@ -193,6 +194,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db,
state: state,
},
HeaderFilter: &headerFilterDB{
db: db,
state: state,
},
Instance: &instanceDB{
db: db,
state: state,

View file

@ -0,0 +1,207 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"net/http"
"time"
"unsafe"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/uptrace/bun"
)
type headerFilterDB struct {
db *DB
state *state.State
}
func (h *headerFilterDB) AllowHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) {
return h.state.Caches.AllowHeaderFilters.RegularMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
return h.GetAllowHeaderFilters(ctx)
})
}
func (h *headerFilterDB) AllowHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) {
return h.state.Caches.AllowHeaderFilters.InverseMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
return h.GetAllowHeaderFilters(ctx)
})
}
func (h *headerFilterDB) BlockHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) {
return h.state.Caches.BlockHeaderFilters.RegularMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
return h.GetBlockHeaderFilters(ctx)
})
}
func (h *headerFilterDB) BlockHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) {
return h.state.Caches.BlockHeaderFilters.InverseMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
return h.GetBlockHeaderFilters(ctx)
})
}
func (h *headerFilterDB) GetAllowHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) {
filter := new(gtsmodel.HeaderFilterAllow)
if err := h.db.NewSelect().
Model(filter).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx); err != nil {
return nil, err
}
return fromAllowFilter(filter), nil
}
func (h *headerFilterDB) GetBlockHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) {
filter := new(gtsmodel.HeaderFilterBlock)
if err := h.db.NewSelect().
Model(filter).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx); err != nil {
return nil, err
}
return fromBlockFilter(filter), nil
}
func (h *headerFilterDB) GetAllowHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) {
var filters []*gtsmodel.HeaderFilterAllow
err := h.db.NewSelect().
Model(&filters).
Scan(ctx, &filters)
return fromAllowFilters(filters), err
}
func (h *headerFilterDB) GetBlockHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) {
var filters []*gtsmodel.HeaderFilterBlock
err := h.db.NewSelect().
Model(&filters).
Scan(ctx, &filters)
return fromBlockFilters(filters), err
}
func (h *headerFilterDB) PutAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error {
if _, err := h.db.NewInsert().
Model(toAllowFilter(filter)).
Exec(ctx); err != nil {
return err
}
h.state.Caches.AllowHeaderFilters.Clear()
return nil
}
func (h *headerFilterDB) PutBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error {
if _, err := h.db.NewInsert().
Model(toBlockFilter(filter)).
Exec(ctx); err != nil {
return err
}
h.state.Caches.BlockHeaderFilters.Clear()
return nil
}
func (h *headerFilterDB) UpdateAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error {
filter.UpdatedAt = time.Now()
if len(cols) > 0 {
// If we're updating by column,
// ensure "updated_at" is included.
cols = append(cols, "updated_at")
}
if _, err := h.db.NewUpdate().
Model(toAllowFilter(filter)).
Column(cols...).
Where("? = ?", bun.Ident("id"), filter.ID).
Exec(ctx); err != nil {
return err
}
h.state.Caches.AllowHeaderFilters.Clear()
return nil
}
func (h *headerFilterDB) UpdateBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error {
filter.UpdatedAt = time.Now()
if len(cols) > 0 {
// If we're updating by column,
// ensure "updated_at" is included.
cols = append(cols, "updated_at")
}
if _, err := h.db.NewUpdate().
Model(toBlockFilter(filter)).
Column(cols...).
Where("? = ?", bun.Ident("id"), filter.ID).
Exec(ctx); err != nil {
return err
}
h.state.Caches.BlockHeaderFilters.Clear()
return nil
}
func (h *headerFilterDB) DeleteAllowHeaderFilter(ctx context.Context, id string) error {
if _, err := h.db.NewDelete().
Table("header_filter_allows").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return err
}
h.state.Caches.AllowHeaderFilters.Clear()
return nil
}
func (h *headerFilterDB) DeleteBlockHeaderFilter(ctx context.Context, id string) error {
if _, err := h.db.NewDelete().
Table("header_filter_blocks").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return err
}
h.state.Caches.BlockHeaderFilters.Clear()
return nil
}
// NOTE:
// all of the below unsafe cast functions
// are only possible because HeaderFilterAllow{},
// HeaderFilterBlock{}, HeaderFilter{} while
// different types in source, have exactly the
// same size and layout in memory. the unsafe
// cast simply changes the type associated with
// that block of memory.
func toAllowFilter(filter *gtsmodel.HeaderFilter) *gtsmodel.HeaderFilterAllow {
return (*gtsmodel.HeaderFilterAllow)(unsafe.Pointer(filter))
}
func toBlockFilter(filter *gtsmodel.HeaderFilter) *gtsmodel.HeaderFilterBlock {
return (*gtsmodel.HeaderFilterBlock)(unsafe.Pointer(filter))
}
func fromAllowFilter(filter *gtsmodel.HeaderFilterAllow) *gtsmodel.HeaderFilter {
return (*gtsmodel.HeaderFilter)(unsafe.Pointer(filter))
}
func fromBlockFilter(filter *gtsmodel.HeaderFilterBlock) *gtsmodel.HeaderFilter {
return (*gtsmodel.HeaderFilter)(unsafe.Pointer(filter))
}
func fromAllowFilters(filters []*gtsmodel.HeaderFilterAllow) []*gtsmodel.HeaderFilter {
return *(*[]*gtsmodel.HeaderFilter)(unsafe.Pointer(&filters))
}
func fromBlockFilters(filters []*gtsmodel.HeaderFilterBlock) []*gtsmodel.HeaderFilter {
return *(*[]*gtsmodel.HeaderFilter)(unsafe.Pointer(&filters))
}

View file

@ -0,0 +1,125 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb_test
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type HeaderFilterTestSuite struct {
BunDBStandardTestSuite
}
func (suite *HeaderFilterTestSuite) TestAllowHeaderFilterGetPutUpdateDelete() {
suite.testHeaderFilterGetPutUpdateDelete(
suite.db.GetAllowHeaderFilter,
suite.db.GetAllowHeaderFilters,
suite.db.PutAllowHeaderFilter,
suite.db.UpdateAllowHeaderFilter,
suite.db.DeleteAllowHeaderFilter,
)
}
func (suite *HeaderFilterTestSuite) TestBlockHeaderFilterGetPutUpdateDelete() {
suite.testHeaderFilterGetPutUpdateDelete(
suite.db.GetBlockHeaderFilter,
suite.db.GetBlockHeaderFilters,
suite.db.PutBlockHeaderFilter,
suite.db.UpdateBlockHeaderFilter,
suite.db.DeleteBlockHeaderFilter,
)
}
func (suite *HeaderFilterTestSuite) testHeaderFilterGetPutUpdateDelete(
get func(context.Context, string) (*gtsmodel.HeaderFilter, error),
getAll func(context.Context) ([]*gtsmodel.HeaderFilter, error),
put func(context.Context, *gtsmodel.HeaderFilter) error,
update func(context.Context, *gtsmodel.HeaderFilter, ...string) error,
delete func(context.Context, string) error,
) {
t := suite.T()
// Create new example header filter.
filter := gtsmodel.HeaderFilter{
ID: "some unique id",
Header: "Http-Header-Key",
Regex: ".*",
AuthorID: "some unique author id",
}
// Create new cancellable test context.
ctx := context.Background()
ctx, cncl := context.WithCancel(ctx)
defer cncl()
// Insert the example header filter into db.
if err := put(ctx, &filter); err != nil {
t.Fatalf("error inserting header filter: %v", err)
}
// Now fetch newly created filter.
check, err := get(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching header filter: %v", err)
}
// Check all expected fields match.
suite.Equal(filter.ID, check.ID)
suite.Equal(filter.Header, check.Header)
suite.Equal(filter.Regex, check.Regex)
suite.Equal(filter.AuthorID, check.AuthorID)
// Fetch all header filters.
all, err := getAll(ctx)
if err != nil {
t.Fatalf("error fetching header filters: %v", err)
}
// Ensure contains example.
suite.Equal(len(all), 1)
suite.Equal(all[0].ID, filter.ID)
// Update the header filter regex value.
check.Regex = "new regex value"
if err := update(ctx, check); err != nil {
t.Fatalf("error updating header filter: %v", err)
}
// Ensure 'updated_at' was updated on check model.
suite.True(check.UpdatedAt.After(filter.UpdatedAt))
// Now delete the header filter from db.
if err := delete(ctx, filter.ID); err != nil {
t.Fatalf("error deleting header filter: %v", err)
}
// Ensure we can't refetch it.
_, err = get(ctx, filter.ID)
if err != db.ErrNoEntries {
t.Fatalf("deleted header filter returned unexpected error: %v", err)
}
}
func TestHeaderFilterTestSuite(t *testing.T) {
suite.Run(t, new(HeaderFilterTestSuite))
}

View file

@ -0,0 +1,54 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package migrations
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
func init() {
up := func(ctx context.Context, db *bun.DB) error {
for _, model := range []any{
&gtsmodel.HeaderFilterAllow{},
&gtsmodel.HeaderFilterBlock{},
} {
_, err := db.NewCreateTable().
IfNotExists().
Model(model).
Exec(ctx)
if err != nil {
return err
}
}
return nil
}
down := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return nil
})
}
if err := Migrations.Register(up, down); err != nil {
panic(err)
}
}

View file

@ -30,6 +30,7 @@ type DB interface {
Basic
Domain
Emoji
HeaderFilter
Instance
List
Marker

View file

@ -0,0 +1,73 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package db
import (
"context"
"net/http"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type HeaderFilter interface {
// AllowHeaderRegularMatch performs an headerfilter.Filter.RegularMatch() on cached allow header filters.
// (Note: the actual matching code can be found under ./internal/headerfilter/ ).
AllowHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error)
// AllowHeaderInverseMatch performs an headerfilter.Filter.InverseMatch() on cached allow header filters.
// (Note: the actual matching code can be found under ./internal/headerfilter/ ).
AllowHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error)
// BlockHeaderRegularMatch performs an headerfilter.Filter.RegularMatch() on cached block header filters.
// (Note: the actual matching code can be found under ./internal/headerfilter/ ).
BlockHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error)
// BlockHeaderInverseMatch performs an headerfilter.Filter.InverseMatch() on cached block header filters.
// (Note: the actual matching code can be found under ./internal/headerfilter/ ).
BlockHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error)
// GetAllowHeaderFilter fetches the allow header filter with ID from the database.
GetAllowHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error)
// GetBlockHeaderFilter fetches the block header filter with ID from the database.
GetBlockHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error)
// GetAllowHeaderFilters fetches all allow header filters from the database.
GetAllowHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error)
// GetBlockHeaderFilters fetches all block header filters from the database.
GetBlockHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error)
// PutAllowHeaderFilter inserts the given allow header filter into the database.
PutAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error
// PutBlockHeaderFilter inserts the given block header filter into the database.
PutBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error
// UpdateAllowHeaderFilter updates the given allow header filter in the database, only updating given columns if provided.
UpdateAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error
// UpdateBlockHeaderFilter updates the given block header filter in the database, only updating given columns if provided.
UpdateBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error
// DeleteAllowHeaderFilter deletes the allow header filter with ID from the database.
DeleteAllowHeaderFilter(ctx context.Context, id string) error
// DeleteBlockHeaderFilter deletes the block header filter with ID from the database.
DeleteBlockHeaderFilter(ctx context.Context, id string) error
}

View file

@ -0,0 +1,54 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package gtsmodel
import (
"time"
"unsafe"
)
func init() {
// Note that since all of the below calculations are
// constant, these should be optimized out of builds.
const filterSz = unsafe.Sizeof(HeaderFilter{})
if unsafe.Sizeof(HeaderFilterAllow{}) != filterSz {
panic("HeaderFilterAllow{} needs to have the same in-memory size / layout as HeaderFilter{}")
}
if unsafe.Sizeof(HeaderFilterBlock{}) != filterSz {
panic("HeaderFilterBlock{} needs to have the same in-memory size / layout as HeaderFilter{}")
}
}
// HeaderFilterAllow represents an allow HTTP header filter in the database.
type HeaderFilterAllow struct{ HeaderFilter }
// HeaderFilterBlock represents a block HTTP header filter in the database.
type HeaderFilterBlock struct{ HeaderFilter }
// HeaderFilter represents an HTTP request filter in
// the database, with a header to match against, value
// matching regex, and details about its creation.
type HeaderFilter struct {
ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // ID of this item in the database
Header string `bun:",nullzero,notnull,unique:header_regex"` // Request header this filter pertains to
Regex string `bun:",nullzero,notnull,unique:header_regex"` // Request header value matching regular expression
AuthorID string `bun:"type:CHAR(26),nullzero,notnull"` // Account ID of the creator of this filter
Author *Account `bun:"-"` // Account corresponding to AuthorID
CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created
UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated
}

View file

@ -0,0 +1,136 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package headerfilter
import (
"errors"
"fmt"
"net/http"
"net/textproto"
"regexp"
)
// Maximum header value size before we return
// an instant negative match. They shouldn't
// go beyond this size in most cases anywho.
const MaxHeaderValue = 1024
// ErrLargeHeaderValue is returned on attempting to match on a value > MaxHeaderValue.
var ErrLargeHeaderValue = errors.New("header value too large")
// Filters represents a set of http.Header regular
// expression filters built-in statistic tracking.
type Filters []headerfilter
type headerfilter struct {
// key is the header key to match against
// in canonical textproto mime header format.
key string
// exprs contains regular expressions to
// match values against for this header key.
exprs []*regexp.Regexp
}
// Append will add new header filter expression under given header key.
func (fs *Filters) Append(key string, expr string) error {
var filter *headerfilter
// Ensure in canonical mime header format.
key = textproto.CanonicalMIMEHeaderKey(key)
// Look for existing filter
// with key in filter slice.
for i := range *fs {
if (*fs)[i].key == key {
filter = &((*fs)[i])
break
}
}
if filter == nil {
// No existing filter found, create new.
// Append new header filter to slice.
(*fs) = append((*fs), headerfilter{})
// Then take ptr to this new filter
// at the last index in the slice.
filter = &((*fs)[len((*fs))-1])
// Setup new key.
filter.key = key
}
// Compile regular expression.
reg, err := regexp.Compile(expr)
if err != nil {
return fmt.Errorf("error compiling regexp %q: %w", expr, err)
}
// Append regular expression to filter.
filter.exprs = append(filter.exprs, reg)
return nil
}
// RegularMatch returns whether any values in http header
// matches any of the receiving filter regular expressions.
// This returns the matched header key, and matching regexp.
func (fs Filters) RegularMatch(h http.Header) (string, string, error) {
for _, filter := range fs {
for _, value := range h[filter.key] {
// Don't perform match on large values
// to mitigate denial of service attacks.
if len(value) > MaxHeaderValue {
return "", "", ErrLargeHeaderValue
}
// Compare against regular exprs.
for _, expr := range filter.exprs {
if expr.MatchString(value) {
return filter.key, expr.String(), nil
}
}
}
}
return "", "", nil
}
// InverseMatch returns whether any values in http header do
// NOT match any of the receiving filter regular expressions.
// This returns the matched header key, and matching regexp.
func (fs Filters) InverseMatch(h http.Header) (string, string, error) {
for _, filter := range fs {
for _, value := range h[filter.key] {
// Don't perform match on large values
// to mitigate denial of service attacks.
if len(value) > MaxHeaderValue {
return "", "", ErrLargeHeaderValue
}
// Compare against regular exprs.
for _, expr := range filter.exprs {
if !expr.MatchString(value) {
return filter.key, expr.String(), nil
}
}
}
}
return "", "", nil
}

View file

@ -0,0 +1,251 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package middleware
import (
"sync"
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/headerfilter"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
)
var (
allowMatches = matchstats{m: make(map[string]uint64)}
blockMatches = matchstats{m: make(map[string]uint64)}
)
// matchstats is a simple statistics
// counter for header filter matches.
// TODO: replace with otel.
type matchstats struct {
m map[string]uint64
l sync.Mutex
}
func (m *matchstats) Add(hdr, regex string) {
m.l.Lock()
key := hdr + ":" + regex
m.m[key]++
m.l.Unlock()
}
// HeaderFilter returns a gin middleware handler that provides HTTP
// request blocking (filtering) based on database allow / block filters.
func HeaderFilter(state *state.State) gin.HandlerFunc {
switch mode := config.GetAdvancedHeaderFilterMode(); mode {
case config.RequestHeaderFilterModeDisabled:
return func(ctx *gin.Context) {}
case config.RequestHeaderFilterModeAllow:
return headerFilterAllowMode(state)
case config.RequestHeaderFilterModeBlock:
return headerFilterBlockMode(state)
default:
panic("unrecognized filter mode: " + mode)
}
}
func headerFilterAllowMode(state *state.State) func(c *gin.Context) {
_ = *state //nolint
// Allowlist mode: explicit block takes
// precedence over explicit allow.
//
// Headers that have neither block
// or allow entries are blocked.
return func(c *gin.Context) {
// Check if header is explicitly blocked.
block, err := isHeaderBlocked(state, c)
if err != nil {
respondInternalServerError(c, err)
return
}
if block {
respondBlocked(c)
return
}
// Check if header is missing explicit allow.
notAllow, err := isHeaderNotAllowed(state, c)
if err != nil {
respondInternalServerError(c, err)
return
}
if notAllow {
respondBlocked(c)
return
}
// Allowed!
c.Next()
}
}
func headerFilterBlockMode(state *state.State) func(c *gin.Context) {
_ = *state //nolint
// Blocklist/default mode: explicit allow
// takes precedence over explicit block.
//
// Headers that have neither block
// or allow entries are allowed.
return func(c *gin.Context) {
// Check if header is explicitly allowed.
allow, err := isHeaderAllowed(state, c)
if err != nil {
respondInternalServerError(c, err)
return
}
if !allow {
// Check if header is explicitly blocked.
block, err := isHeaderBlocked(state, c)
if err != nil {
respondInternalServerError(c, err)
return
}
if block {
respondBlocked(c)
return
}
}
// Allowed!
c.Next()
}
}
func isHeaderBlocked(state *state.State, c *gin.Context) (bool, error) {
var (
ctx = c.Request.Context()
hdr = c.Request.Header
)
// Perform an explicit is-blocked check on request header.
key, expr, err := state.DB.BlockHeaderRegularMatch(ctx, hdr)
switch err {
case nil:
break
case headerfilter.ErrLargeHeaderValue:
log.Warn(ctx, "large header value")
key = "*" // block large headers
default:
err := gtserror.Newf("error checking header: %w", err)
return false, err
}
if key != "" {
if expr != "" {
// Increment block matches stat.
// TODO: replace expvar with build
// taggable metrics types in State{}.
blockMatches.Add(key, expr)
}
// A header was matched against!
// i.e. this request is blocked.
return true, nil
}
return false, nil
}
func isHeaderAllowed(state *state.State, c *gin.Context) (bool, error) {
var (
ctx = c.Request.Context()
hdr = c.Request.Header
)
// Perform an explicit is-allowed check on request header.
key, expr, err := state.DB.AllowHeaderRegularMatch(ctx, hdr)
switch err {
case nil:
break
case headerfilter.ErrLargeHeaderValue:
log.Warn(ctx, "large header value")
key = "" // block large headers
default:
err := gtserror.Newf("error checking header: %w", err)
return false, err
}
if key != "" {
if expr != "" {
// Increment allow matches stat.
// TODO: replace expvar with build
// taggable metrics types in State{}.
allowMatches.Add(key, expr)
}
// A header was matched against!
// i.e. this request is allowed.
return true, nil
}
return false, nil
}
func isHeaderNotAllowed(state *state.State, c *gin.Context) (bool, error) {
var (
ctx = c.Request.Context()
hdr = c.Request.Header
)
// Perform an explicit is-NOT-allowed check on request header.
key, expr, err := state.DB.AllowHeaderInverseMatch(ctx, hdr)
switch err {
case nil:
break
case headerfilter.ErrLargeHeaderValue:
log.Warn(ctx, "large header value")
key = "*" // block large headers
default:
err := gtserror.Newf("error checking header: %w", err)
return false, err
}
if key != "" {
if expr != "" {
// Increment allow matches stat.
// TODO: replace expvar with build
// taggable metrics types in State{}.
allowMatches.Add(key, expr)
}
// A header was matched against!
// i.e. request is NOT allowed.
return true, nil
}
return false, nil
}

View file

@ -0,0 +1,299 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package middleware_test
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/headerfilter"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/middleware"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/testrig"
)
func TestHeaderFilter(t *testing.T) {
testrig.InitTestLog()
testrig.InitTestConfig()
for _, test := range []struct {
mode string
allow []filter
block []filter
input http.Header
expect bool
}{
{
// Allow mode with expected 200 OK.
mode: config.RequestHeaderFilterModeAllow,
allow: []filter{
{"User-Agent", ".*Firefox.*"},
},
block: []filter{},
input: http.Header{
"User-Agent": []string{"Firefox v169.42; Extra Tracking Info"},
},
expect: true,
},
{
// Allow mode with expected 403 Forbidden.
mode: config.RequestHeaderFilterModeAllow,
allow: []filter{
{"User-Agent", ".*Firefox.*"},
},
block: []filter{},
input: http.Header{
"User-Agent": []string{"Chromium v169.42; Extra Tracking Info"},
},
expect: false,
},
{
// Allow mode with too long header value expecting 403 Forbidden.
mode: config.RequestHeaderFilterModeAllow,
allow: []filter{
{"User-Agent", ".*"},
},
block: []filter{},
input: http.Header{
"User-Agent": []string{func() string {
var buf strings.Builder
for i := 0; i < headerfilter.MaxHeaderValue+1; i++ {
buf.WriteByte(' ')
}
return buf.String()
}()},
},
expect: false,
},
{
// Allow mode with explicit block expecting 403 Forbidden.
mode: config.RequestHeaderFilterModeAllow,
allow: []filter{
{"User-Agent", ".*Firefox.*"},
},
block: []filter{
{"User-Agent", ".*Firefox v169\\.42.*"},
},
input: http.Header{
"User-Agent": []string{"Firefox v169.42; Extra Tracking Info"},
},
expect: false,
},
{
// Block mode with an expected 403 Forbidden.
mode: config.RequestHeaderFilterModeBlock,
allow: []filter{},
block: []filter{
{"User-Agent", ".*Firefox.*"},
},
input: http.Header{
"User-Agent": []string{"Firefox v169.42; Extra Tracking Info"},
},
expect: false,
},
{
// Block mode with an expected 200 OK.
mode: config.RequestHeaderFilterModeBlock,
allow: []filter{},
block: []filter{
{"User-Agent", ".*Firefox.*"},
},
input: http.Header{
"User-Agent": []string{"Chromium v169.42; Extra Tracking Info"},
},
expect: true,
},
{
// Block mode with too long header value expecting 403 Forbidden.
mode: config.RequestHeaderFilterModeBlock,
allow: []filter{},
block: []filter{
{"User-Agent", "none"},
},
input: http.Header{
"User-Agent": []string{func() string {
var buf strings.Builder
for i := 0; i < headerfilter.MaxHeaderValue+1; i++ {
buf.WriteByte(' ')
}
return buf.String()
}()},
},
expect: false,
},
{
// Block mode with explicit allow expecting 200 OK.
mode: config.RequestHeaderFilterModeBlock,
allow: []filter{
{"User-Agent", ".*Firefox.*"},
},
block: []filter{
{"User-Agent", ".*Firefox v169\\.42.*"},
},
input: http.Header{
"User-Agent": []string{"Firefox v169.42; Extra Tracking Info"},
},
expect: true,
},
{
// Disabled mode with an expected 200 OK.
mode: config.RequestHeaderFilterModeDisabled,
allow: []filter{
{"Key1", "only-this"},
{"Key2", "only-this"},
{"Key3", "only-this"},
},
block: []filter{
{"Key1", "Value"},
{"Key2", "Value"},
{"Key3", "Value"},
},
input: http.Header{
"Key1": []string{"Value"},
"Key2": []string{"Value"},
"Key3": []string{"Value"},
},
expect: true,
},
} {
// Generate a unique name for this test case.
name := fmt.Sprintf("%s allow=%v block=%v => expect=%v",
test.mode,
test.allow,
test.block,
test.expect,
)
// Update header filter mode to test case.
config.SetAdvancedHeaderFilterMode(test.mode)
// Run this particular test case.
ok := t.Run(name, func(t *testing.T) {
testHeaderFilter(t,
test.allow,
test.block,
test.input,
test.expect,
)
})
if !ok {
return
}
}
}
func testHeaderFilter(t *testing.T, allow, block []filter, input http.Header, expect bool) {
var err error
// Create test context with cancel.
ctx := context.Background()
ctx, cncl := context.WithCancel(ctx)
defer cncl()
// Initialize caches.
var state state.State
state.Caches.Init()
// Create new database instance with test config.
state.DB, err = bundb.NewBunDBService(ctx, &state)
if err != nil {
t.Fatalf("error opening database: %v", err)
}
// Insert all allow filters into DB.
for _, filter := range allow {
filter := &gtsmodel.HeaderFilter{
ID: id.NewULID(),
Header: filter.header,
Regex: filter.regex,
AuthorID: "admin-id",
Author: nil,
}
if err := state.DB.PutAllowHeaderFilter(ctx, filter); err != nil {
t.Fatalf("error inserting allow filter into database: %v", err)
}
}
// Insert all block filters into DB.
for _, filter := range block {
filter := &gtsmodel.HeaderFilter{
ID: id.NewULID(),
Header: filter.header,
Regex: filter.regex,
AuthorID: "admin-id",
Author: nil,
}
if err := state.DB.PutBlockHeaderFilter(ctx, filter); err != nil {
t.Fatalf("error inserting block filter into database: %v", err)
}
}
// Gin test http engine
// (used for ctx init).
e := gin.New()
// Create new filter middleware to test against.
middleware := middleware.HeaderFilter(&state)
e.Use(middleware)
// Set the empty gin handler (always returns okay).
e.Handle("GET", "/", func(ctx *gin.Context) { ctx.Status(200) })
// Prepare a gin test context.
r := httptest.NewRequest("GET", "/", nil)
rw := httptest.NewRecorder()
// Set input headers.
r.Header = input
// Pass req through
// engine handler.
e.ServeHTTP(rw, r)
// Get http result.
res := rw.Result()
switch {
case expect && res.StatusCode != http.StatusOK:
t.Errorf("unexpected response (should allow): %s", res.Status)
case !expect && res.StatusCode != http.StatusForbidden:
t.Errorf("unexpected response (should block): %s", res.Status)
}
}
type filter struct {
header string
regex string
}
func (hf *filter) String() string {
return fmt.Sprintf("%s=%q", hf.header, hf.regex)
}

View file

@ -146,7 +146,7 @@ func RateLimit(limit int, exceptions []string) gin.HandlerFunc {
apiutil.Data(c,
http.StatusTooManyRequests,
apiutil.AppJSON,
apiutil.ErrorRateLimitReached,
apiutil.ErrorRateLimited,
)
c.Abort()
return

View file

@ -18,21 +18,22 @@
package middleware
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
)
// UserAgent returns a gin middleware which aborts requests with
// empty user agent strings, returning code 418 - I'm a teapot.
func UserAgent() gin.HandlerFunc {
// todo: make this configurable
var rsp = []byte(`{"error": "I'm a teapot: no user-agent sent with request"}`)
return func(c *gin.Context) {
if ua := c.Request.UserAgent(); ua == "" {
code := http.StatusTeapot
err := errors.New(http.StatusText(code) + ": no user-agent sent with request")
c.AbortWithStatusJSON(code, gin.H{"error": err.Error()})
apiutil.Data(c,
http.StatusTeapot, apiutil.AppJSON, rsp)
c.Abort()
}
}
}

View file

@ -0,0 +1,51 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
)
// respondBlocked responds to the given gin context with
// status forbidden, and a generic forbidden JSON response,
// finally aborting the gin handler chain.
func respondBlocked(c *gin.Context) {
apiutil.Data(c,
http.StatusForbidden,
apiutil.AppJSON,
apiutil.StatusForbiddenJSON,
)
c.Abort()
}
// respondInternalServerError responds to the given gin context
// with status internal server error, a generic internal server
// error JSON response, sets the given error on the gin context
// for later logging, finally aborting the gin handler chain.
func respondInternalServerError(c *gin.Context, err error) {
apiutil.Data(c,
http.StatusInternalServerError,
apiutil.AppJSON,
apiutil.StatusInternalServerErrorJSON,
)
_ = c.Error(err)
c.Abort()
}

View file

@ -0,0 +1,215 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package admin
import (
"context"
"errors"
"net/textproto"
"regexp"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/headerfilter"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
// GetAllowHeaderFilter fetches allow HTTP header filter with provided ID from the database.
func (p *Processor) GetAllowHeaderFilter(ctx context.Context, id string) (*apimodel.HeaderFilter, gtserror.WithCode) {
return p.getHeaderFilter(ctx, id, p.state.DB.GetAllowHeaderFilter)
}
// GetBlockHeaderFilter fetches block HTTP header filter with provided ID from the database.
func (p *Processor) GetBlockHeaderFilter(ctx context.Context, id string) (*apimodel.HeaderFilter, gtserror.WithCode) {
return p.getHeaderFilter(ctx, id, p.state.DB.GetBlockHeaderFilter)
}
// GetAllowHeaderFilters fetches all allow HTTP header filters stored in the database.
func (p *Processor) GetAllowHeaderFilters(ctx context.Context) ([]*apimodel.HeaderFilter, gtserror.WithCode) {
return p.getHeaderFilters(ctx, p.state.DB.GetAllowHeaderFilters)
}
// GetBlockHeaderFilters fetches all block HTTP header filters stored in the database.
func (p *Processor) GetBlockHeaderFilters(ctx context.Context) ([]*apimodel.HeaderFilter, gtserror.WithCode) {
return p.getHeaderFilters(ctx, p.state.DB.GetBlockHeaderFilters)
}
// CreateAllowHeaderFilter inserts the incoming allow HTTP header filter into the database, marking as authored by provided admin account.
func (p *Processor) CreateAllowHeaderFilter(ctx context.Context, admin *gtsmodel.Account, request *apimodel.HeaderFilterRequest) (*apimodel.HeaderFilter, gtserror.WithCode) {
return p.createHeaderFilter(ctx, admin, request, p.state.DB.PutAllowHeaderFilter)
}
// CreateBlockHeaderFilter inserts the incoming block HTTP header filter into the database, marking as authored by provided admin account.
func (p *Processor) CreateBlockHeaderFilter(ctx context.Context, admin *gtsmodel.Account, request *apimodel.HeaderFilterRequest) (*apimodel.HeaderFilter, gtserror.WithCode) {
return p.createHeaderFilter(ctx, admin, request, p.state.DB.PutBlockHeaderFilter)
}
// DeleteAllowHeaderFilter deletes the allowing HTTP header filter with provided ID from the database.
func (p *Processor) DeleteAllowHeaderFilter(ctx context.Context, id string) gtserror.WithCode {
return p.deleteHeaderFilter(ctx, id, p.state.DB.DeleteAllowHeaderFilter)
}
// DeleteBlockHeaderFilter deletes the blocking HTTP header filter with provided ID from the database.
func (p *Processor) DeleteBlockHeaderFilter(ctx context.Context, id string) gtserror.WithCode {
return p.deleteHeaderFilter(ctx, id, p.state.DB.DeleteBlockHeaderFilter)
}
// getHeaderFilter fetches an HTTP header filter with
// provided ID, using given get function, converting the
// resulting filter to returnable frontend API model.
func (p *Processor) getHeaderFilter(
ctx context.Context,
id string,
get func(context.Context, string) (*gtsmodel.HeaderFilter, error),
) (
*apimodel.HeaderFilter,
gtserror.WithCode,
) {
// Select filter by ID from db.
filter, err := get(ctx, id)
switch {
// Successfully found.
case err == nil:
return toAPIHeaderFilter(filter), nil
// Filter does not exist with ID.
case errors.Is(err, db.ErrNoEntries):
const text = "filter not found"
return nil, gtserror.NewErrorNotFound(errors.New(text), text)
// Any other error type.
default:
err := gtserror.Newf("error selecting from database: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
}
// getHeaderFilters fetches all HTTP header filters
// using given get function, converting the resulting
// filters to returnable frontend API models.
func (p *Processor) getHeaderFilters(
ctx context.Context,
get func(context.Context) ([]*gtsmodel.HeaderFilter, error),
) (
[]*apimodel.HeaderFilter,
gtserror.WithCode,
) {
// Select all filters from DB.
filters, err := get(ctx)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
// Only handle errors other than not-found types.
err := gtserror.Newf("error selecting from database: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
// Convert passed header filters to apimodel filters.
apiFilters := make([]*apimodel.HeaderFilter, len(filters))
for i := range filters {
apiFilters[i] = toAPIHeaderFilter(filters[i])
}
return apiFilters, nil
}
// createHeaderFilter inserts the given HTTP header
// filter into database, marking as authored by the
// provided admin, using the given insert function.
func (p *Processor) createHeaderFilter(
ctx context.Context,
admin *gtsmodel.Account,
request *apimodel.HeaderFilterRequest,
insert func(context.Context, *gtsmodel.HeaderFilter) error,
) (
*apimodel.HeaderFilter,
gtserror.WithCode,
) {
// Convert header key to canonical mime header format.
request.Header = textproto.CanonicalMIMEHeaderKey(request.Header)
// Validate incoming header filter.
if errWithCode := validateHeaderFilter(
request.Header,
request.Regex,
); errWithCode != nil {
return nil, errWithCode
}
// Create new database model with ID.
var filter gtsmodel.HeaderFilter
filter.ID = id.NewULID()
filter.Header = request.Header
filter.Regex = request.Regex
filter.AuthorID = admin.ID
filter.Author = admin
// Insert new header filter into the database.
if err := insert(ctx, &filter); err != nil {
err := gtserror.Newf("error inserting into database: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
// Finally return API model response.
return toAPIHeaderFilter(&filter), nil
}
// deleteHeaderFilter deletes the HTTP header filter
// with provided ID, using the given delete function.
func (p *Processor) deleteHeaderFilter(
ctx context.Context,
id string,
delete func(context.Context, string) error,
) gtserror.WithCode {
if err := delete(ctx, id); err != nil && !errors.Is(err, db.ErrNoEntries) {
err := gtserror.Newf("error deleting from database: %w", err)
return gtserror.NewErrorInternalError(err)
}
return nil
}
// toAPIFilter performs a simple conversion of database model HeaderFilter to API model.
func toAPIHeaderFilter(filter *gtsmodel.HeaderFilter) *apimodel.HeaderFilter {
return &apimodel.HeaderFilter{
ID: filter.ID,
Header: filter.Header,
Regex: filter.Regex,
CreatedBy: filter.AuthorID,
CreatedAt: util.FormatISO8601(filter.CreatedAt),
}
}
// validateHeaderFilter validates incoming filter's header key, and regular expression.
func validateHeaderFilter(header, regex string) gtserror.WithCode {
// Check header validity (within our own bound checks).
if header == "" || len(header) > headerfilter.MaxHeaderValue {
const text = "invalid request header key (empty or too long)"
return gtserror.NewErrorBadRequest(errors.New(text), text)
}
// Ensure this is compilable regex.
_, err := regexp.Compile(regex)
if err != nil {
return gtserror.NewErrorBadRequest(err, err.Error())
}
return nil
}