mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-12-10 10:08:08 -06:00
[feature] request blocking by http headers (#2409)
This commit is contained in:
parent
07bd848028
commit
8ebb7775a3
36 changed files with 2561 additions and 81 deletions
|
|
@ -35,6 +35,10 @@ const (
|
|||
DomainAllowsPath = BasePath + "/domain_allows"
|
||||
DomainAllowsPathWithID = DomainAllowsPath + "/:" + IDKey
|
||||
DomainKeysExpirePath = BasePath + "/domain_keys_expire"
|
||||
HeaderAllowsPath = BasePath + "/header_allows"
|
||||
HeaderAllowsPathWithID = HeaderAllowsPath + "/:" + IDKey
|
||||
HeaderBlocksPath = BasePath + "/header_blocks"
|
||||
HeaderBlocksPathWithID = HeaderAllowsPath + "/:" + IDKey
|
||||
AccountsPath = BasePath + "/accounts"
|
||||
AccountsPathWithID = AccountsPath + "/:" + IDKey
|
||||
AccountsActionPath = AccountsPathWithID + "/action"
|
||||
|
|
@ -95,6 +99,16 @@ func (m *Module) Route(attachHandler func(method string, path string, f ...gin.H
|
|||
attachHandler(http.MethodGet, DomainAllowsPathWithID, m.DomainAllowGETHandler)
|
||||
attachHandler(http.MethodDelete, DomainAllowsPathWithID, m.DomainAllowDELETEHandler)
|
||||
|
||||
// header filtering administration routes
|
||||
attachHandler(http.MethodGet, HeaderAllowsPathWithID, m.HeaderFilterAllowGET)
|
||||
attachHandler(http.MethodGet, HeaderBlocksPathWithID, m.HeaderFilterBlockGET)
|
||||
attachHandler(http.MethodGet, HeaderAllowsPath, m.HeaderFilterAllowsGET)
|
||||
attachHandler(http.MethodGet, HeaderBlocksPath, m.HeaderFilterBlocksGET)
|
||||
attachHandler(http.MethodPost, HeaderAllowsPath, m.HeaderFilterAllowPOST)
|
||||
attachHandler(http.MethodPost, HeaderBlocksPath, m.HeaderFilterBlockPOST)
|
||||
attachHandler(http.MethodDelete, HeaderAllowsPathWithID, m.HeaderFilterAllowDELETE)
|
||||
attachHandler(http.MethodDelete, HeaderBlocksPathWithID, m.HeaderFilterBlockDELETE)
|
||||
|
||||
// domain maintenance stuff
|
||||
attachHandler(http.MethodPost, DomainKeysExpirePath, m.DomainKeysExpirePOSTHandler)
|
||||
|
||||
|
|
|
|||
173
internal/api/client/admin/headerfilter.go
Normal file
173
internal/api/client/admin/headerfilter.go
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
)
|
||||
|
||||
// getHeaderFilter is a gin handler function that returns details of an HTTP header filter with provided ID, using given get function.
|
||||
func (m *Module) getHeaderFilter(c *gin.Context, get func(context.Context, string) (*apimodel.HeaderFilter, gtserror.WithCode)) {
|
||||
authed, err := oauth.Authed(c, true, true, true, true)
|
||||
if err != nil {
|
||||
errWithCode := gtserror.NewErrorUnauthorized(err, err.Error())
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
if !*authed.User.Admin {
|
||||
const text = "user not an admin"
|
||||
errWithCode := gtserror.NewErrorForbidden(errors.New(text), text)
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
|
||||
errWithCode := gtserror.NewErrorNotAcceptable(err, err.Error())
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
filterID, errWithCode := apiutil.ParseID(c.Param("ID"))
|
||||
if errWithCode != nil {
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
filter, errWithCode := get(c.Request.Context(), filterID)
|
||||
if errWithCode != nil {
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
apiutil.JSON(c, http.StatusOK, filter)
|
||||
}
|
||||
|
||||
// getHeaderFilters is a gin handler function that returns details of all HTTP header filters using given get function.
|
||||
func (m *Module) getHeaderFilters(c *gin.Context, get func(context.Context) ([]*apimodel.HeaderFilter, gtserror.WithCode)) {
|
||||
authed, err := oauth.Authed(c, true, true, true, true)
|
||||
if err != nil {
|
||||
errWithCode := gtserror.NewErrorUnauthorized(err, err.Error())
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
if !*authed.User.Admin {
|
||||
const text = "user not an admin"
|
||||
errWithCode := gtserror.NewErrorForbidden(errors.New(text), text)
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
|
||||
errWithCode := gtserror.NewErrorNotAcceptable(err, err.Error())
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
filters, errWithCode := get(c.Request.Context())
|
||||
if errWithCode != nil {
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
apiutil.JSON(c, http.StatusOK, filters)
|
||||
}
|
||||
|
||||
// createHeaderFilter is a gin handler function that creates a HTTP header filter entry using provided form data, passing to given create function.
|
||||
func (m *Module) createHeaderFilter(c *gin.Context, create func(context.Context, *gtsmodel.Account, *apimodel.HeaderFilterRequest) (*apimodel.HeaderFilter, gtserror.WithCode)) {
|
||||
authed, err := oauth.Authed(c, true, true, true, true)
|
||||
if err != nil {
|
||||
errWithCode := gtserror.NewErrorUnauthorized(err, err.Error())
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
if !*authed.User.Admin {
|
||||
const text = "user not an admin"
|
||||
errWithCode := gtserror.NewErrorForbidden(errors.New(text), text)
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil {
|
||||
errWithCode := gtserror.NewErrorNotAcceptable(err, err.Error())
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
var form apimodel.HeaderFilterRequest
|
||||
|
||||
if err := c.ShouldBind(&form); err != nil {
|
||||
errWithCode := gtserror.NewErrorBadRequest(err, err.Error())
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
filter, errWithCode := create(
|
||||
c.Request.Context(),
|
||||
authed.Account,
|
||||
&form,
|
||||
)
|
||||
if errWithCode != nil {
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
apiutil.JSON(c, http.StatusOK, filter)
|
||||
}
|
||||
|
||||
// deleteHeaderFilter is a gin handler function that deletes an HTTP header filter with provided ID, using given delete function.
|
||||
func (m *Module) deleteHeaderFilter(c *gin.Context, delete func(context.Context, string) gtserror.WithCode) {
|
||||
authed, err := oauth.Authed(c, true, true, true, true)
|
||||
if err != nil {
|
||||
errWithCode := gtserror.NewErrorUnauthorized(err, err.Error())
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
if !*authed.User.Admin {
|
||||
const text = "user not an admin"
|
||||
errWithCode := gtserror.NewErrorForbidden(errors.New(text), text)
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
filterID, errWithCode := apiutil.ParseID(c.Param("ID"))
|
||||
if errWithCode != nil {
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
errWithCode = delete(c.Request.Context(), filterID)
|
||||
if errWithCode != nil {
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusAccepted)
|
||||
}
|
||||
102
internal/api/client/admin/headerfilter_create.go
Normal file
102
internal/api/client/admin/headerfilter_create.go
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// HeaderFilterAllowPOST swagger:operation POST /api/v1/admin/header_allows headerFilterAllowCreate
|
||||
//
|
||||
// Create new "allow" HTTP request header filter.
|
||||
//
|
||||
// The parameters can also be given in the body of the request, as JSON, if the content-type is set to 'application/json'.
|
||||
// The parameters can also be given in the body of the request, as XML, if the content-type is set to 'application/xml'.
|
||||
//
|
||||
// ---
|
||||
// tags:
|
||||
// - admin
|
||||
//
|
||||
// consumes:
|
||||
// - application/json
|
||||
// - application/xml
|
||||
// - application/x-www-form-urlencoded
|
||||
//
|
||||
// produces:
|
||||
// - application/json
|
||||
//
|
||||
// security:
|
||||
// - OAuth2 Bearer:
|
||||
// - admin
|
||||
//
|
||||
// responses:
|
||||
// '200':
|
||||
// description: The newly created "allow" header filter.
|
||||
// schema:
|
||||
// "$ref": "#/definitions/headerFilter"
|
||||
// '400':
|
||||
// description: bad request
|
||||
// '401':
|
||||
// description: unauthorized
|
||||
// '403':
|
||||
// description: forbidden
|
||||
// '500':
|
||||
// description: internal server error
|
||||
func (m *Module) HeaderFilterAllowPOST(c *gin.Context) {
|
||||
m.createHeaderFilter(c, m.processor.Admin().CreateAllowHeaderFilter)
|
||||
}
|
||||
|
||||
// HeaderFilterBlockPOST swagger:operation POST /api/v1/admin/header_blocks headerFilterBlockCreate
|
||||
//
|
||||
// Create new "block" HTTP request header filter.
|
||||
//
|
||||
// The parameters can also be given in the body of the request, as JSON, if the content-type is set to 'application/json'.
|
||||
// The parameters can also be given in the body of the request, as XML, if the content-type is set to 'application/xml'.
|
||||
//
|
||||
// ---
|
||||
// tags:
|
||||
// - admin
|
||||
//
|
||||
// consumes:
|
||||
// - application/json
|
||||
// - application/xml
|
||||
// - application/x-www-form-urlencoded
|
||||
//
|
||||
// produces:
|
||||
// - application/json
|
||||
//
|
||||
// security:
|
||||
// - OAuth2 Bearer:
|
||||
// - admin
|
||||
//
|
||||
// responses:
|
||||
// '200':
|
||||
// description: The newly created "block" header filter.
|
||||
// schema:
|
||||
// "$ref": "#/definitions/headerFilter"
|
||||
// '400':
|
||||
// description: bad request
|
||||
// '401':
|
||||
// description: unauthorized
|
||||
// '403':
|
||||
// description: forbidden
|
||||
// '500':
|
||||
// description: internal server error
|
||||
func (m *Module) HeaderFilterBlockPOST(c *gin.Context) {
|
||||
m.createHeaderFilter(c, m.processor.Admin().CreateBlockHeaderFilter)
|
||||
}
|
||||
96
internal/api/client/admin/headerfilter_delete.go
Normal file
96
internal/api/client/admin/headerfilter_delete.go
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// HeaderFilterAllowDELETE swagger:operation DELETE /api/v1/admin/header_allows/{id} headerFilterAllowDelete
|
||||
//
|
||||
// Delete the "allow" header filter with the given ID.
|
||||
//
|
||||
// ---
|
||||
// tags:
|
||||
// - admin
|
||||
//
|
||||
// parameters:
|
||||
// -
|
||||
// name: id
|
||||
// type: string
|
||||
// description: Target header filter ID.
|
||||
// in: path
|
||||
// required: true
|
||||
//
|
||||
// security:
|
||||
// - OAuth2 Bearer:
|
||||
// - admin
|
||||
//
|
||||
// responses:
|
||||
// '202':
|
||||
// description: Accepted
|
||||
// '400':
|
||||
// description: bad request
|
||||
// '401':
|
||||
// description: unauthorized
|
||||
// '403':
|
||||
// description: forbidden
|
||||
// '404':
|
||||
// description: not found
|
||||
// '500':
|
||||
// description: internal server error
|
||||
func (m *Module) HeaderFilterAllowDELETE(c *gin.Context) {
|
||||
m.deleteHeaderFilter(c, m.processor.Admin().DeleteAllowHeaderFilter)
|
||||
}
|
||||
|
||||
// HeaderFilterBlockDELETE swagger:operation DELETE /api/v1/admin/header_blocks/{id} headerFilterBlockDelete
|
||||
//
|
||||
// Delete the "block" header filter with the given ID.
|
||||
//
|
||||
// ---
|
||||
// tags:
|
||||
// - admin
|
||||
//
|
||||
// parameters:
|
||||
// -
|
||||
// name: id
|
||||
// type: string
|
||||
// description: Target header filter ID.
|
||||
// in: path
|
||||
// required: true
|
||||
//
|
||||
// security:
|
||||
// - OAuth2 Bearer:
|
||||
// - admin
|
||||
//
|
||||
// responses:
|
||||
// '202':
|
||||
// description: Accepted
|
||||
// '400':
|
||||
// description: bad request
|
||||
// '401':
|
||||
// description: unauthorized
|
||||
// '403':
|
||||
// description: forbidden
|
||||
// '404':
|
||||
// description: not found
|
||||
// '500':
|
||||
// description: internal server error
|
||||
func (m *Module) HeaderFilterBlockDELETE(c *gin.Context) {
|
||||
m.deleteHeaderFilter(c, m.processor.Admin().DeleteAllowHeaderFilter)
|
||||
}
|
||||
164
internal/api/client/admin/headerfilter_get.go
Normal file
164
internal/api/client/admin/headerfilter_get.go
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package admin
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
// HeaderFilterAllowGET swagger:operation GET /api/v1/admin/header_allows/{id} headerFilterAllowGet
|
||||
//
|
||||
// Get "allow" header filter with the given ID.
|
||||
//
|
||||
// ---
|
||||
// tags:
|
||||
// - admin
|
||||
//
|
||||
// parameters:
|
||||
// -
|
||||
// name: id
|
||||
// type: string
|
||||
// description: Target header filter ID.
|
||||
// in: path
|
||||
// required: true
|
||||
//
|
||||
// security:
|
||||
// - OAuth2 Bearer:
|
||||
// - admin
|
||||
//
|
||||
// responses:
|
||||
// '200':
|
||||
// description: The requested "allow" header filter.
|
||||
// schema:
|
||||
// "$ref": "#/definitions/headerFilter"
|
||||
// '400':
|
||||
// description: bad request
|
||||
// '401':
|
||||
// description: unauthorized
|
||||
// '403':
|
||||
// description: forbidden
|
||||
// '404':
|
||||
// description: not found
|
||||
// '500':
|
||||
// description: internal server error
|
||||
func (m *Module) HeaderFilterAllowGET(c *gin.Context) {
|
||||
m.getHeaderFilter(c, m.processor.Admin().GetAllowHeaderFilter)
|
||||
}
|
||||
|
||||
// HeaderFilterBlockGET swagger:operation GET /api/v1/admin/header_blocks/{id} headerFilterBlockGet
|
||||
//
|
||||
// Get "block" header filter with the given ID.
|
||||
//
|
||||
// ---
|
||||
// tags:
|
||||
// - admin
|
||||
//
|
||||
// parameters:
|
||||
// -
|
||||
// name: id
|
||||
// type: string
|
||||
// description: Target header filter ID.
|
||||
// in: path
|
||||
// required: true
|
||||
//
|
||||
// security:
|
||||
// - OAuth2 Bearer:
|
||||
// - admin
|
||||
//
|
||||
// responses:
|
||||
// '200':
|
||||
// description: The requested "block" header filter.
|
||||
// schema:
|
||||
// "$ref": "#/definitions/headerFilter"
|
||||
// '400':
|
||||
// description: bad request
|
||||
// '401':
|
||||
// description: unauthorized
|
||||
// '403':
|
||||
// description: forbidden
|
||||
// '404':
|
||||
// description: not found
|
||||
// '500':
|
||||
// description: internal server error
|
||||
func (m *Module) HeaderFilterBlockGET(c *gin.Context) {
|
||||
m.getHeaderFilter(c, m.processor.Admin().GetBlockHeaderFilter)
|
||||
}
|
||||
|
||||
// HeaderFilterAllowsGET swagger:operation GET /api/v1/admin/header_allows headerFilterAllowsGet
|
||||
//
|
||||
// Get all "allow" header filters currently in place.
|
||||
//
|
||||
// ---
|
||||
// tags:
|
||||
// - admin
|
||||
//
|
||||
// security:
|
||||
// - OAuth2 Bearer:
|
||||
// - admin
|
||||
//
|
||||
// responses:
|
||||
// '200':
|
||||
// description: All "allow" header filters currently in place.
|
||||
// schema:
|
||||
// type: array
|
||||
// items:
|
||||
// "$ref": "#/definitions/headerFilter"
|
||||
// '400':
|
||||
// description: bad request
|
||||
// '401':
|
||||
// description: unauthorized
|
||||
// '403':
|
||||
// description: forbidden
|
||||
// '404':
|
||||
// description: not found
|
||||
// '500':
|
||||
// description: internal server error
|
||||
func (m *Module) HeaderFilterAllowsGET(c *gin.Context) {
|
||||
m.getHeaderFilters(c, m.processor.Admin().GetAllowHeaderFilters)
|
||||
}
|
||||
|
||||
// HeaderFilterBlocksGET swagger:operation GET /api/v1/admin/header_blocks headerFilterBlocksGet
|
||||
//
|
||||
// Get all "allow" header filters currently in place.
|
||||
//
|
||||
// ---
|
||||
// tags:
|
||||
// - admin
|
||||
//
|
||||
// security:
|
||||
// - OAuth2 Bearer:
|
||||
// - admin
|
||||
//
|
||||
// responses:
|
||||
// '200':
|
||||
// description: All "block" header filters currently in place.
|
||||
// schema:
|
||||
// type: array
|
||||
// items:
|
||||
// "$ref": "#/definitions/headerFilter"
|
||||
// '400':
|
||||
// description: bad request
|
||||
// '401':
|
||||
// description: unauthorized
|
||||
// '403':
|
||||
// description: forbidden
|
||||
// '404':
|
||||
// description: not found
|
||||
// '500':
|
||||
// description: internal server error
|
||||
func (m *Module) HeaderFilterBlocksGET(c *gin.Context) {
|
||||
m.getHeaderFilters(c, m.processor.Admin().GetBlockHeaderFilters)
|
||||
}
|
||||
55
internal/api/model/headerfilter.go
Normal file
55
internal/api/model/headerfilter.go
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package model
|
||||
|
||||
// HeaderFilter represents a regex value filter applied to one particular HTTP header (allow / block).
|
||||
type HeaderFilter struct {
|
||||
// The ID of the header filter.
|
||||
// example: 01FBW21XJA09XYX51KV5JVBW0F
|
||||
// readonly: true
|
||||
ID string `json:"id"`
|
||||
|
||||
// The HTTP header to match against.
|
||||
// example: User-Agent
|
||||
Header string `json:"header"`
|
||||
|
||||
// The header value matching regular expression.
|
||||
// example: .*Firefox.*
|
||||
Regex string `json:"regex"`
|
||||
|
||||
// The ID of the admin account that created this header filter.
|
||||
// example: 01FBW2758ZB6PBR200YPDDJK4C
|
||||
// readonly: true
|
||||
CreatedBy string `json:"created_by"`
|
||||
|
||||
// Time at which the header filter was created (ISO 8601 Datetime).
|
||||
// example: 2021-07-30T09:20:25+00:00
|
||||
// readonly: true
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// HeaderFilterRequest is the form submitted as a POST to create a new header filter entry (allow / block).
|
||||
//
|
||||
// swagger:model headerFilterCreateRequest
|
||||
type HeaderFilterRequest struct {
|
||||
// The HTTP header to match against (e.g. User-Agent).
|
||||
Header string `form:"header" json:"header" xml:"header"`
|
||||
|
||||
// The header value matching regular expression.
|
||||
Regex string `form:"regex" json:"regex" xml:"regex"`
|
||||
}
|
||||
|
|
@ -39,14 +39,17 @@ var (
|
|||
StatusAcceptedJSON = mustJSON(map[string]string{
|
||||
"status": http.StatusText(http.StatusAccepted),
|
||||
})
|
||||
StatusForbiddenJSON = mustJSON(map[string]string{
|
||||
"status": http.StatusText(http.StatusForbidden),
|
||||
})
|
||||
StatusInternalServerErrorJSON = mustJSON(map[string]string{
|
||||
"status": http.StatusText(http.StatusInternalServerError),
|
||||
})
|
||||
ErrorCapacityExceeded = mustJSON(map[string]string{
|
||||
"error": "server capacity exceeded!",
|
||||
"error": "server capacity exceeded",
|
||||
})
|
||||
ErrorRateLimitReached = mustJSON(map[string]string{
|
||||
"error": "rate limit reached!",
|
||||
ErrorRateLimited = mustJSON(map[string]string{
|
||||
"error": "rate limit reached",
|
||||
})
|
||||
EmptyJSONObject = mustJSON("{}")
|
||||
EmptyJSONArray = mustJSON("[]")
|
||||
|
|
|
|||
22
internal/cache/cache.go
vendored
22
internal/cache/cache.go
vendored
|
|
@ -18,21 +18,26 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"github.com/superseriousbusiness/gotosocial/internal/cache/headerfilter"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
)
|
||||
|
||||
type Caches struct {
|
||||
// GTS provides access to the collection of gtsmodel object caches.
|
||||
// (used by the database).
|
||||
// GTS provides access to the collection of
|
||||
// gtsmodel object caches. (used by the database).
|
||||
GTS GTSCaches
|
||||
|
||||
// AP provides access to the collection of ActivityPub object caches.
|
||||
// (planned to be used by the typeconverter).
|
||||
AP APCaches
|
||||
// AllowHeaderFilters provides access to
|
||||
// the allow []headerfilter.Filter cache.
|
||||
AllowHeaderFilters headerfilter.Cache
|
||||
|
||||
// Visibility provides access to the item visibility cache.
|
||||
// (used by the visibility filter).
|
||||
// BlockHeaderFilters provides access to
|
||||
// the block []headerfilter.Filter cache.
|
||||
BlockHeaderFilters headerfilter.Cache
|
||||
|
||||
// Visibility provides access to the item visibility
|
||||
// cache. (used by the visibility filter).
|
||||
Visibility VisibilityCache
|
||||
|
||||
// prevent pass-by-value.
|
||||
|
|
@ -45,7 +50,6 @@ func (c *Caches) Init() {
|
|||
log.Infof(nil, "init: %p", c)
|
||||
|
||||
c.GTS.Init()
|
||||
c.AP.Init()
|
||||
c.Visibility.Init()
|
||||
|
||||
// Setup cache invalidate hooks.
|
||||
|
|
@ -58,7 +62,6 @@ func (c *Caches) Start() {
|
|||
log.Infof(nil, "start: %p", c)
|
||||
|
||||
c.GTS.Start()
|
||||
c.AP.Start()
|
||||
c.Visibility.Start()
|
||||
}
|
||||
|
||||
|
|
@ -67,7 +70,6 @@ func (c *Caches) Stop() {
|
|||
log.Infof(nil, "stop: %p", c)
|
||||
|
||||
c.GTS.Stop()
|
||||
c.AP.Stop()
|
||||
c.Visibility.Stop()
|
||||
}
|
||||
|
||||
|
|
|
|||
32
internal/cache/domain/domain.go
vendored
32
internal/cache/domain/domain.go
vendored
|
|
@ -21,7 +21,6 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
|
@ -37,17 +36,17 @@ import (
|
|||
// The .Clear() function can be used to invalidate the cache,
|
||||
// e.g. when an entry is added / deleted from the database.
|
||||
type Cache struct {
|
||||
// atomically updated ptr value to the
|
||||
// current domain cache radix trie.
|
||||
rootptr unsafe.Pointer
|
||||
rootptr atomic.Pointer[root]
|
||||
}
|
||||
|
||||
// Matches checks whether domain matches an entry in the cache.
|
||||
// If the cache is not currently loaded, then the provided load
|
||||
// function is used to hydrate it.
|
||||
func (c *Cache) Matches(domain string, load func() ([]string, error)) (bool, error) {
|
||||
// Load the current root pointer value.
|
||||
ptr := atomic.LoadPointer(&c.rootptr)
|
||||
// Load the current
|
||||
// root pointer value.
|
||||
ptr := c.rootptr.Load()
|
||||
|
||||
if ptr == nil {
|
||||
// Cache is not hydrated.
|
||||
|
|
@ -60,35 +59,32 @@ func (c *Cache) Matches(domain string, load func() ([]string, error)) (bool, err
|
|||
|
||||
// Allocate new radix trie
|
||||
// node to store matches.
|
||||
root := new(root)
|
||||
ptr = new(root)
|
||||
|
||||
// Add each domain to the trie.
|
||||
for _, domain := range domains {
|
||||
root.Add(domain)
|
||||
ptr.Add(domain)
|
||||
}
|
||||
|
||||
// Sort the trie.
|
||||
root.Sort()
|
||||
ptr.Sort()
|
||||
|
||||
// Store the new node ptr.
|
||||
ptr = unsafe.Pointer(root)
|
||||
atomic.StorePointer(&c.rootptr, ptr)
|
||||
// Store new node ptr.
|
||||
c.rootptr.Store(ptr)
|
||||
}
|
||||
|
||||
// Look for a match in the trie node.
|
||||
return (*root)(ptr).Match(domain), nil
|
||||
// Look for match in trie node.
|
||||
return ptr.Match(domain), nil
|
||||
}
|
||||
|
||||
// Clear will drop the currently loaded domain list,
|
||||
// triggering a reload on next call to .Matches().
|
||||
func (c *Cache) Clear() {
|
||||
atomic.StorePointer(&c.rootptr, nil)
|
||||
}
|
||||
func (c *Cache) Clear() { c.rootptr.Store(nil) }
|
||||
|
||||
// String returns a string representation of stored domains in cache.
|
||||
func (c *Cache) String() string {
|
||||
if ptr := atomic.LoadPointer(&c.rootptr); ptr != nil {
|
||||
return (*root)(ptr).String()
|
||||
if ptr := c.rootptr.Load(); ptr != nil {
|
||||
return ptr.String()
|
||||
}
|
||||
return "<empty>"
|
||||
}
|
||||
|
|
|
|||
105
internal/cache/headerfilter/filter.go
vendored
Normal file
105
internal/cache/headerfilter/filter.go
vendored
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package headerfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/headerfilter"
|
||||
)
|
||||
|
||||
// Cache provides a means of caching headerfilter.Filters in
|
||||
// memory to reduce load on an underlying storage mechanism.
|
||||
type Cache struct {
|
||||
// current cached header filters slice.
|
||||
ptr atomic.Pointer[headerfilter.Filters]
|
||||
}
|
||||
|
||||
// RegularMatch performs .RegularMatch() on cached headerfilter.Filters, loading using callback if necessary.
|
||||
func (c *Cache) RegularMatch(h http.Header, load func() ([]*gtsmodel.HeaderFilter, error)) (string, string, error) {
|
||||
// Load ptr value.
|
||||
ptr := c.ptr.Load()
|
||||
|
||||
if ptr == nil {
|
||||
// Cache is not hydrated.
|
||||
// Load filters from callback.
|
||||
filters, err := loadFilters(load)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Store the new
|
||||
// header filters.
|
||||
ptr = &filters
|
||||
c.ptr.Store(ptr)
|
||||
}
|
||||
|
||||
// Deref and perform match.
|
||||
return ptr.RegularMatch(h)
|
||||
}
|
||||
|
||||
// InverseMatch performs .InverseMatch() on cached headerfilter.Filters, loading using callback if necessary.
|
||||
func (c *Cache) InverseMatch(h http.Header, load func() ([]*gtsmodel.HeaderFilter, error)) (string, string, error) {
|
||||
// Load ptr value.
|
||||
ptr := c.ptr.Load()
|
||||
|
||||
if ptr == nil {
|
||||
// Cache is not hydrated.
|
||||
// Load filters from callback.
|
||||
filters, err := loadFilters(load)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Store the new
|
||||
// header filters.
|
||||
ptr = &filters
|
||||
c.ptr.Store(ptr)
|
||||
}
|
||||
|
||||
// Deref and perform match.
|
||||
return ptr.InverseMatch(h)
|
||||
}
|
||||
|
||||
// Clear will drop the currently loaded filters,
|
||||
// triggering a reload on next call to ._Match().
|
||||
func (c *Cache) Clear() { c.ptr.Store(nil) }
|
||||
|
||||
// loadFilters will load filters from given load callback, creating and parsing raw filters.
|
||||
func loadFilters(load func() ([]*gtsmodel.HeaderFilter, error)) (headerfilter.Filters, error) {
|
||||
// Load filters from callback.
|
||||
hdrFilters, err := load()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reloading cache: %w", err)
|
||||
}
|
||||
|
||||
// Allocate new header filter slice to store expressions.
|
||||
filters := make(headerfilter.Filters, 0, len(hdrFilters))
|
||||
|
||||
// Add all raw expression to filter slice.
|
||||
for _, filter := range hdrFilters {
|
||||
if err := filters.Append(filter.Header, filter.Regex); err != nil {
|
||||
return nil, fmt.Errorf("error appending exprs: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return filters, nil
|
||||
}
|
||||
|
|
@ -163,6 +163,7 @@ type Configuration struct {
|
|||
AdvancedThrottlingRetryAfter time.Duration `name:"advanced-throttling-retry-after" usage:"Retry-After duration response to send for throttled requests."`
|
||||
AdvancedSenderMultiplier int `name:"advanced-sender-multiplier" usage:"Multiplier to use per cpu for batching outgoing fedi messages. 0 or less turns batching off (not recommended)."`
|
||||
AdvancedCSPExtraURIs []string `name:"advanced-csp-extra-uris" usage:"Additional URIs to allow when building content-security-policy for media + images."`
|
||||
AdvancedHeaderFilterMode string `name:"advanced-header-filter-mode" usage:"Set incoming request header filtering mode."`
|
||||
|
||||
// HTTPClient configuration vars.
|
||||
HTTPClient HTTPClientConfiguration `name:"http-client"`
|
||||
|
|
|
|||
|
|
@ -17,10 +17,16 @@
|
|||
|
||||
package config
|
||||
|
||||
// Instance federation mode determines how this
|
||||
// instance federates with others (if at all).
|
||||
const (
|
||||
// Instance federation mode determines how this
|
||||
// instance federates with others (if at all).
|
||||
InstanceFederationModeBlocklist = "blocklist"
|
||||
InstanceFederationModeAllowlist = "allowlist"
|
||||
InstanceFederationModeDefault = InstanceFederationModeBlocklist
|
||||
|
||||
// Request header filter mode determines how
|
||||
// this instance will perform request filtering.
|
||||
RequestHeaderFilterModeAllow = "allow"
|
||||
RequestHeaderFilterModeBlock = "block"
|
||||
RequestHeaderFilterModeDisabled = ""
|
||||
)
|
||||
|
|
|
|||
|
|
@ -135,6 +135,7 @@ var Defaults = Configuration{
|
|||
AdvancedThrottlingRetryAfter: time.Second * 30,
|
||||
AdvancedSenderMultiplier: 2, // 2 senders per CPU
|
||||
AdvancedCSPExtraURIs: []string{},
|
||||
AdvancedHeaderFilterMode: RequestHeaderFilterModeDisabled,
|
||||
|
||||
Cache: CacheConfiguration{
|
||||
// Rough memory target that the total
|
||||
|
|
|
|||
|
|
@ -158,6 +158,7 @@ func (s *ConfigState) AddServerFlags(cmd *cobra.Command) {
|
|||
cmd.Flags().Duration(AdvancedThrottlingRetryAfterFlag(), cfg.AdvancedThrottlingRetryAfter, fieldtag("AdvancedThrottlingRetryAfter", "usage"))
|
||||
cmd.Flags().Int(AdvancedSenderMultiplierFlag(), cfg.AdvancedSenderMultiplier, fieldtag("AdvancedSenderMultiplier", "usage"))
|
||||
cmd.Flags().StringSlice(AdvancedCSPExtraURIsFlag(), cfg.AdvancedCSPExtraURIs, fieldtag("AdvancedCSPExtraURIs", "usage"))
|
||||
cmd.Flags().String(AdvancedHeaderFilterModeFlag(), cfg.AdvancedHeaderFilterMode, fieldtag("AdvancedHeaderFilterMode", "usage"))
|
||||
|
||||
cmd.Flags().String(RequestIDHeaderFlag(), cfg.RequestIDHeader, fieldtag("RequestIDHeader", "usage"))
|
||||
})
|
||||
|
|
|
|||
|
|
@ -2600,6 +2600,31 @@ func GetAdvancedCSPExtraURIs() []string { return global.GetAdvancedCSPExtraURIs(
|
|||
// SetAdvancedCSPExtraURIs safely sets the value for global configuration 'AdvancedCSPExtraURIs' field
|
||||
func SetAdvancedCSPExtraURIs(v []string) { global.SetAdvancedCSPExtraURIs(v) }
|
||||
|
||||
// GetAdvancedHeaderFilterMode safely fetches the Configuration value for state's 'AdvancedHeaderFilterMode' field
|
||||
func (st *ConfigState) GetAdvancedHeaderFilterMode() (v string) {
|
||||
st.mutex.RLock()
|
||||
v = st.config.AdvancedHeaderFilterMode
|
||||
st.mutex.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
// SetAdvancedHeaderFilterMode safely sets the Configuration value for state's 'AdvancedHeaderFilterMode' field
|
||||
func (st *ConfigState) SetAdvancedHeaderFilterMode(v string) {
|
||||
st.mutex.Lock()
|
||||
defer st.mutex.Unlock()
|
||||
st.config.AdvancedHeaderFilterMode = v
|
||||
st.reloadToViper()
|
||||
}
|
||||
|
||||
// AdvancedHeaderFilterModeFlag returns the flag name for the 'AdvancedHeaderFilterMode' field
|
||||
func AdvancedHeaderFilterModeFlag() string { return "advanced-header-filter-mode" }
|
||||
|
||||
// GetAdvancedHeaderFilterMode safely fetches the value for global configuration 'AdvancedHeaderFilterMode' field
|
||||
func GetAdvancedHeaderFilterMode() string { return global.GetAdvancedHeaderFilterMode() }
|
||||
|
||||
// SetAdvancedHeaderFilterMode safely sets the value for global configuration 'AdvancedHeaderFilterMode' field
|
||||
func SetAdvancedHeaderFilterMode(v string) { global.SetAdvancedHeaderFilterMode(v) }
|
||||
|
||||
// GetHTTPClientAllowIPs safely fetches the Configuration value for state's 'HTTPClient.AllowIPs' field
|
||||
func (st *ConfigState) GetHTTPClientAllowIPs() (v []string) {
|
||||
st.mutex.RLock()
|
||||
|
|
|
|||
|
|
@ -25,10 +25,6 @@ type Basic interface {
|
|||
// For implementations that don't use tables, this can just return nil.
|
||||
CreateTable(ctx context.Context, i interface{}) error
|
||||
|
||||
// CreateAllTables creates *all* tables necessary for the running of GoToSocial.
|
||||
// Because it uses the 'if not exists' parameter it is safe to run against a GtS that's already been initialized.
|
||||
CreateAllTables(ctx context.Context) error
|
||||
|
||||
// DropTable drops the table for the given interface.
|
||||
// For implementations that don't use tables, this can just return nil.
|
||||
DropTable(ctx context.Context, i interface{}) error
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ import (
|
|||
"errors"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
|
@ -120,39 +119,6 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (b *basicDB) CreateAllTables(ctx context.Context) error {
|
||||
models := []interface{}{
|
||||
>smodel.Account{},
|
||||
>smodel.Application{},
|
||||
>smodel.Block{},
|
||||
>smodel.DomainBlock{},
|
||||
>smodel.EmailDomainBlock{},
|
||||
>smodel.Follow{},
|
||||
>smodel.FollowRequest{},
|
||||
>smodel.MediaAttachment{},
|
||||
>smodel.Mention{},
|
||||
>smodel.Status{},
|
||||
>smodel.StatusToEmoji{},
|
||||
>smodel.StatusFave{},
|
||||
>smodel.StatusBookmark{},
|
||||
>smodel.ThreadMute{},
|
||||
>smodel.Tag{},
|
||||
>smodel.User{},
|
||||
>smodel.Emoji{},
|
||||
>smodel.Instance{},
|
||||
>smodel.Notification{},
|
||||
>smodel.RouterSession{},
|
||||
>smodel.Token{},
|
||||
>smodel.Client{},
|
||||
}
|
||||
for _, i := range models {
|
||||
if err := b.CreateTable(ctx, i); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *basicDB) DropTable(ctx context.Context, i interface{}) error {
|
||||
_, err := b.db.NewDropTable().Model(i).IfExists().Exec(ctx)
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@ type DBService struct {
|
|||
db.Basic
|
||||
db.Domain
|
||||
db.Emoji
|
||||
db.HeaderFilter
|
||||
db.Instance
|
||||
db.List
|
||||
db.Marker
|
||||
|
|
@ -193,6 +194,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
|
|||
db: db,
|
||||
state: state,
|
||||
},
|
||||
HeaderFilter: &headerFilterDB{
|
||||
db: db,
|
||||
state: state,
|
||||
},
|
||||
Instance: &instanceDB{
|
||||
db: db,
|
||||
state: state,
|
||||
|
|
|
|||
207
internal/db/bundb/headerfilter.go
Normal file
207
internal/db/bundb/headerfilter.go
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package bundb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type headerFilterDB struct {
|
||||
db *DB
|
||||
state *state.State
|
||||
}
|
||||
|
||||
func (h *headerFilterDB) AllowHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) {
|
||||
return h.state.Caches.AllowHeaderFilters.RegularMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
|
||||
return h.GetAllowHeaderFilters(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *headerFilterDB) AllowHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) {
|
||||
return h.state.Caches.AllowHeaderFilters.InverseMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
|
||||
return h.GetAllowHeaderFilters(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *headerFilterDB) BlockHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) {
|
||||
return h.state.Caches.BlockHeaderFilters.RegularMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
|
||||
return h.GetBlockHeaderFilters(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *headerFilterDB) BlockHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) {
|
||||
return h.state.Caches.BlockHeaderFilters.InverseMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
|
||||
return h.GetBlockHeaderFilters(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *headerFilterDB) GetAllowHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) {
|
||||
filter := new(gtsmodel.HeaderFilterAllow)
|
||||
if err := h.db.NewSelect().
|
||||
Model(filter).
|
||||
Where("? = ?", bun.Ident("id"), id).
|
||||
Scan(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fromAllowFilter(filter), nil
|
||||
}
|
||||
|
||||
func (h *headerFilterDB) GetBlockHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) {
|
||||
filter := new(gtsmodel.HeaderFilterBlock)
|
||||
if err := h.db.NewSelect().
|
||||
Model(filter).
|
||||
Where("? = ?", bun.Ident("id"), id).
|
||||
Scan(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fromBlockFilter(filter), nil
|
||||
}
|
||||
|
||||
func (h *headerFilterDB) GetAllowHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) {
|
||||
var filters []*gtsmodel.HeaderFilterAllow
|
||||
err := h.db.NewSelect().
|
||||
Model(&filters).
|
||||
Scan(ctx, &filters)
|
||||
return fromAllowFilters(filters), err
|
||||
}
|
||||
|
||||
func (h *headerFilterDB) GetBlockHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) {
|
||||
var filters []*gtsmodel.HeaderFilterBlock
|
||||
err := h.db.NewSelect().
|
||||
Model(&filters).
|
||||
Scan(ctx, &filters)
|
||||
return fromBlockFilters(filters), err
|
||||
}
|
||||
|
||||
func (h *headerFilterDB) PutAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error {
|
||||
if _, err := h.db.NewInsert().
|
||||
Model(toAllowFilter(filter)).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
h.state.Caches.AllowHeaderFilters.Clear()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *headerFilterDB) PutBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error {
|
||||
if _, err := h.db.NewInsert().
|
||||
Model(toBlockFilter(filter)).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
h.state.Caches.BlockHeaderFilters.Clear()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *headerFilterDB) UpdateAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error {
|
||||
filter.UpdatedAt = time.Now()
|
||||
if len(cols) > 0 {
|
||||
// If we're updating by column,
|
||||
// ensure "updated_at" is included.
|
||||
cols = append(cols, "updated_at")
|
||||
}
|
||||
if _, err := h.db.NewUpdate().
|
||||
Model(toAllowFilter(filter)).
|
||||
Column(cols...).
|
||||
Where("? = ?", bun.Ident("id"), filter.ID).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
h.state.Caches.AllowHeaderFilters.Clear()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *headerFilterDB) UpdateBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error {
|
||||
filter.UpdatedAt = time.Now()
|
||||
if len(cols) > 0 {
|
||||
// If we're updating by column,
|
||||
// ensure "updated_at" is included.
|
||||
cols = append(cols, "updated_at")
|
||||
}
|
||||
if _, err := h.db.NewUpdate().
|
||||
Model(toBlockFilter(filter)).
|
||||
Column(cols...).
|
||||
Where("? = ?", bun.Ident("id"), filter.ID).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
h.state.Caches.BlockHeaderFilters.Clear()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *headerFilterDB) DeleteAllowHeaderFilter(ctx context.Context, id string) error {
|
||||
if _, err := h.db.NewDelete().
|
||||
Table("header_filter_allows").
|
||||
Where("? = ?", bun.Ident("id"), id).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
h.state.Caches.AllowHeaderFilters.Clear()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *headerFilterDB) DeleteBlockHeaderFilter(ctx context.Context, id string) error {
|
||||
if _, err := h.db.NewDelete().
|
||||
Table("header_filter_blocks").
|
||||
Where("? = ?", bun.Ident("id"), id).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
h.state.Caches.BlockHeaderFilters.Clear()
|
||||
return nil
|
||||
}
|
||||
|
||||
// NOTE:
|
||||
// all of the below unsafe cast functions
|
||||
// are only possible because HeaderFilterAllow{},
|
||||
// HeaderFilterBlock{}, HeaderFilter{} while
|
||||
// different types in source, have exactly the
|
||||
// same size and layout in memory. the unsafe
|
||||
// cast simply changes the type associated with
|
||||
// that block of memory.
|
||||
|
||||
func toAllowFilter(filter *gtsmodel.HeaderFilter) *gtsmodel.HeaderFilterAllow {
|
||||
return (*gtsmodel.HeaderFilterAllow)(unsafe.Pointer(filter))
|
||||
}
|
||||
|
||||
func toBlockFilter(filter *gtsmodel.HeaderFilter) *gtsmodel.HeaderFilterBlock {
|
||||
return (*gtsmodel.HeaderFilterBlock)(unsafe.Pointer(filter))
|
||||
}
|
||||
|
||||
func fromAllowFilter(filter *gtsmodel.HeaderFilterAllow) *gtsmodel.HeaderFilter {
|
||||
return (*gtsmodel.HeaderFilter)(unsafe.Pointer(filter))
|
||||
}
|
||||
|
||||
func fromBlockFilter(filter *gtsmodel.HeaderFilterBlock) *gtsmodel.HeaderFilter {
|
||||
return (*gtsmodel.HeaderFilter)(unsafe.Pointer(filter))
|
||||
}
|
||||
|
||||
func fromAllowFilters(filters []*gtsmodel.HeaderFilterAllow) []*gtsmodel.HeaderFilter {
|
||||
return *(*[]*gtsmodel.HeaderFilter)(unsafe.Pointer(&filters))
|
||||
}
|
||||
|
||||
func fromBlockFilters(filters []*gtsmodel.HeaderFilterBlock) []*gtsmodel.HeaderFilter {
|
||||
return *(*[]*gtsmodel.HeaderFilter)(unsafe.Pointer(&filters))
|
||||
}
|
||||
125
internal/db/bundb/headerfilter_test.go
Normal file
125
internal/db/bundb/headerfilter_test.go
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package bundb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
)
|
||||
|
||||
type HeaderFilterTestSuite struct {
|
||||
BunDBStandardTestSuite
|
||||
}
|
||||
|
||||
func (suite *HeaderFilterTestSuite) TestAllowHeaderFilterGetPutUpdateDelete() {
|
||||
suite.testHeaderFilterGetPutUpdateDelete(
|
||||
suite.db.GetAllowHeaderFilter,
|
||||
suite.db.GetAllowHeaderFilters,
|
||||
suite.db.PutAllowHeaderFilter,
|
||||
suite.db.UpdateAllowHeaderFilter,
|
||||
suite.db.DeleteAllowHeaderFilter,
|
||||
)
|
||||
}
|
||||
|
||||
func (suite *HeaderFilterTestSuite) TestBlockHeaderFilterGetPutUpdateDelete() {
|
||||
suite.testHeaderFilterGetPutUpdateDelete(
|
||||
suite.db.GetBlockHeaderFilter,
|
||||
suite.db.GetBlockHeaderFilters,
|
||||
suite.db.PutBlockHeaderFilter,
|
||||
suite.db.UpdateBlockHeaderFilter,
|
||||
suite.db.DeleteBlockHeaderFilter,
|
||||
)
|
||||
}
|
||||
|
||||
func (suite *HeaderFilterTestSuite) testHeaderFilterGetPutUpdateDelete(
|
||||
get func(context.Context, string) (*gtsmodel.HeaderFilter, error),
|
||||
getAll func(context.Context) ([]*gtsmodel.HeaderFilter, error),
|
||||
put func(context.Context, *gtsmodel.HeaderFilter) error,
|
||||
update func(context.Context, *gtsmodel.HeaderFilter, ...string) error,
|
||||
delete func(context.Context, string) error,
|
||||
) {
|
||||
t := suite.T()
|
||||
|
||||
// Create new example header filter.
|
||||
filter := gtsmodel.HeaderFilter{
|
||||
ID: "some unique id",
|
||||
Header: "Http-Header-Key",
|
||||
Regex: ".*",
|
||||
AuthorID: "some unique author id",
|
||||
}
|
||||
|
||||
// Create new cancellable test context.
|
||||
ctx := context.Background()
|
||||
ctx, cncl := context.WithCancel(ctx)
|
||||
defer cncl()
|
||||
|
||||
// Insert the example header filter into db.
|
||||
if err := put(ctx, &filter); err != nil {
|
||||
t.Fatalf("error inserting header filter: %v", err)
|
||||
}
|
||||
|
||||
// Now fetch newly created filter.
|
||||
check, err := get(ctx, filter.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("error fetching header filter: %v", err)
|
||||
}
|
||||
|
||||
// Check all expected fields match.
|
||||
suite.Equal(filter.ID, check.ID)
|
||||
suite.Equal(filter.Header, check.Header)
|
||||
suite.Equal(filter.Regex, check.Regex)
|
||||
suite.Equal(filter.AuthorID, check.AuthorID)
|
||||
|
||||
// Fetch all header filters.
|
||||
all, err := getAll(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("error fetching header filters: %v", err)
|
||||
}
|
||||
|
||||
// Ensure contains example.
|
||||
suite.Equal(len(all), 1)
|
||||
suite.Equal(all[0].ID, filter.ID)
|
||||
|
||||
// Update the header filter regex value.
|
||||
check.Regex = "new regex value"
|
||||
if err := update(ctx, check); err != nil {
|
||||
t.Fatalf("error updating header filter: %v", err)
|
||||
}
|
||||
|
||||
// Ensure 'updated_at' was updated on check model.
|
||||
suite.True(check.UpdatedAt.After(filter.UpdatedAt))
|
||||
|
||||
// Now delete the header filter from db.
|
||||
if err := delete(ctx, filter.ID); err != nil {
|
||||
t.Fatalf("error deleting header filter: %v", err)
|
||||
}
|
||||
|
||||
// Ensure we can't refetch it.
|
||||
_, err = get(ctx, filter.ID)
|
||||
if err != db.ErrNoEntries {
|
||||
t.Fatalf("deleted header filter returned unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeaderFilterTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(HeaderFilterTestSuite))
|
||||
}
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func init() {
|
||||
up := func(ctx context.Context, db *bun.DB) error {
|
||||
for _, model := range []any{
|
||||
>smodel.HeaderFilterAllow{},
|
||||
>smodel.HeaderFilterBlock{},
|
||||
} {
|
||||
_, err := db.NewCreateTable().
|
||||
IfNotExists().
|
||||
Model(model).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
down := func(ctx context.Context, db *bun.DB) error {
|
||||
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := Migrations.Register(up, down); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
|
@ -30,6 +30,7 @@ type DB interface {
|
|||
Basic
|
||||
Domain
|
||||
Emoji
|
||||
HeaderFilter
|
||||
Instance
|
||||
List
|
||||
Marker
|
||||
|
|
|
|||
73
internal/db/headerfilter.go
Normal file
73
internal/db/headerfilter.go
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
)
|
||||
|
||||
type HeaderFilter interface {
|
||||
// AllowHeaderRegularMatch performs an headerfilter.Filter.RegularMatch() on cached allow header filters.
|
||||
// (Note: the actual matching code can be found under ./internal/headerfilter/ ).
|
||||
AllowHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error)
|
||||
|
||||
// AllowHeaderInverseMatch performs an headerfilter.Filter.InverseMatch() on cached allow header filters.
|
||||
// (Note: the actual matching code can be found under ./internal/headerfilter/ ).
|
||||
AllowHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error)
|
||||
|
||||
// BlockHeaderRegularMatch performs an headerfilter.Filter.RegularMatch() on cached block header filters.
|
||||
// (Note: the actual matching code can be found under ./internal/headerfilter/ ).
|
||||
BlockHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error)
|
||||
|
||||
// BlockHeaderInverseMatch performs an headerfilter.Filter.InverseMatch() on cached block header filters.
|
||||
// (Note: the actual matching code can be found under ./internal/headerfilter/ ).
|
||||
BlockHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error)
|
||||
|
||||
// GetAllowHeaderFilter fetches the allow header filter with ID from the database.
|
||||
GetAllowHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error)
|
||||
|
||||
// GetBlockHeaderFilter fetches the block header filter with ID from the database.
|
||||
GetBlockHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error)
|
||||
|
||||
// GetAllowHeaderFilters fetches all allow header filters from the database.
|
||||
GetAllowHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error)
|
||||
|
||||
// GetBlockHeaderFilters fetches all block header filters from the database.
|
||||
GetBlockHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error)
|
||||
|
||||
// PutAllowHeaderFilter inserts the given allow header filter into the database.
|
||||
PutAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error
|
||||
|
||||
// PutBlockHeaderFilter inserts the given block header filter into the database.
|
||||
PutBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error
|
||||
|
||||
// UpdateAllowHeaderFilter updates the given allow header filter in the database, only updating given columns if provided.
|
||||
UpdateAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error
|
||||
|
||||
// UpdateBlockHeaderFilter updates the given block header filter in the database, only updating given columns if provided.
|
||||
UpdateBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error
|
||||
|
||||
// DeleteAllowHeaderFilter deletes the allow header filter with ID from the database.
|
||||
DeleteAllowHeaderFilter(ctx context.Context, id string) error
|
||||
|
||||
// DeleteBlockHeaderFilter deletes the block header filter with ID from the database.
|
||||
DeleteBlockHeaderFilter(ctx context.Context, id string) error
|
||||
}
|
||||
54
internal/gtsmodel/headerfilter.go
Normal file
54
internal/gtsmodel/headerfilter.go
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package gtsmodel
|
||||
|
||||
import (
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Note that since all of the below calculations are
|
||||
// constant, these should be optimized out of builds.
|
||||
const filterSz = unsafe.Sizeof(HeaderFilter{})
|
||||
if unsafe.Sizeof(HeaderFilterAllow{}) != filterSz {
|
||||
panic("HeaderFilterAllow{} needs to have the same in-memory size / layout as HeaderFilter{}")
|
||||
}
|
||||
if unsafe.Sizeof(HeaderFilterBlock{}) != filterSz {
|
||||
panic("HeaderFilterBlock{} needs to have the same in-memory size / layout as HeaderFilter{}")
|
||||
}
|
||||
}
|
||||
|
||||
// HeaderFilterAllow represents an allow HTTP header filter in the database.
|
||||
type HeaderFilterAllow struct{ HeaderFilter }
|
||||
|
||||
// HeaderFilterBlock represents a block HTTP header filter in the database.
|
||||
type HeaderFilterBlock struct{ HeaderFilter }
|
||||
|
||||
// HeaderFilter represents an HTTP request filter in
|
||||
// the database, with a header to match against, value
|
||||
// matching regex, and details about its creation.
|
||||
type HeaderFilter struct {
|
||||
ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // ID of this item in the database
|
||||
Header string `bun:",nullzero,notnull,unique:header_regex"` // Request header this filter pertains to
|
||||
Regex string `bun:",nullzero,notnull,unique:header_regex"` // Request header value matching regular expression
|
||||
AuthorID string `bun:"type:CHAR(26),nullzero,notnull"` // Account ID of the creator of this filter
|
||||
Author *Account `bun:"-"` // Account corresponding to AuthorID
|
||||
CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created
|
||||
UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated
|
||||
}
|
||||
136
internal/headerfilter/filter.go
Normal file
136
internal/headerfilter/filter.go
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package headerfilter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// Maximum header value size before we return
|
||||
// an instant negative match. They shouldn't
|
||||
// go beyond this size in most cases anywho.
|
||||
const MaxHeaderValue = 1024
|
||||
|
||||
// ErrLargeHeaderValue is returned on attempting to match on a value > MaxHeaderValue.
|
||||
var ErrLargeHeaderValue = errors.New("header value too large")
|
||||
|
||||
// Filters represents a set of http.Header regular
|
||||
// expression filters built-in statistic tracking.
|
||||
type Filters []headerfilter
|
||||
|
||||
type headerfilter struct {
|
||||
// key is the header key to match against
|
||||
// in canonical textproto mime header format.
|
||||
key string
|
||||
|
||||
// exprs contains regular expressions to
|
||||
// match values against for this header key.
|
||||
exprs []*regexp.Regexp
|
||||
}
|
||||
|
||||
// Append will add new header filter expression under given header key.
|
||||
func (fs *Filters) Append(key string, expr string) error {
|
||||
var filter *headerfilter
|
||||
|
||||
// Ensure in canonical mime header format.
|
||||
key = textproto.CanonicalMIMEHeaderKey(key)
|
||||
|
||||
// Look for existing filter
|
||||
// with key in filter slice.
|
||||
for i := range *fs {
|
||||
if (*fs)[i].key == key {
|
||||
filter = &((*fs)[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if filter == nil {
|
||||
// No existing filter found, create new.
|
||||
|
||||
// Append new header filter to slice.
|
||||
(*fs) = append((*fs), headerfilter{})
|
||||
|
||||
// Then take ptr to this new filter
|
||||
// at the last index in the slice.
|
||||
filter = &((*fs)[len((*fs))-1])
|
||||
|
||||
// Setup new key.
|
||||
filter.key = key
|
||||
}
|
||||
|
||||
// Compile regular expression.
|
||||
reg, err := regexp.Compile(expr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error compiling regexp %q: %w", expr, err)
|
||||
}
|
||||
|
||||
// Append regular expression to filter.
|
||||
filter.exprs = append(filter.exprs, reg)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegularMatch returns whether any values in http header
|
||||
// matches any of the receiving filter regular expressions.
|
||||
// This returns the matched header key, and matching regexp.
|
||||
func (fs Filters) RegularMatch(h http.Header) (string, string, error) {
|
||||
for _, filter := range fs {
|
||||
for _, value := range h[filter.key] {
|
||||
// Don't perform match on large values
|
||||
// to mitigate denial of service attacks.
|
||||
if len(value) > MaxHeaderValue {
|
||||
return "", "", ErrLargeHeaderValue
|
||||
}
|
||||
|
||||
// Compare against regular exprs.
|
||||
for _, expr := range filter.exprs {
|
||||
if expr.MatchString(value) {
|
||||
return filter.key, expr.String(), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
// InverseMatch returns whether any values in http header do
|
||||
// NOT match any of the receiving filter regular expressions.
|
||||
// This returns the matched header key, and matching regexp.
|
||||
func (fs Filters) InverseMatch(h http.Header) (string, string, error) {
|
||||
for _, filter := range fs {
|
||||
for _, value := range h[filter.key] {
|
||||
// Don't perform match on large values
|
||||
// to mitigate denial of service attacks.
|
||||
if len(value) > MaxHeaderValue {
|
||||
return "", "", ErrLargeHeaderValue
|
||||
}
|
||||
|
||||
// Compare against regular exprs.
|
||||
for _, expr := range filter.exprs {
|
||||
if !expr.MatchString(value) {
|
||||
return filter.key, expr.String(), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", "", nil
|
||||
}
|
||||
251
internal/middleware/headerfilter.go
Normal file
251
internal/middleware/headerfilter.go
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/headerfilter"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
)
|
||||
|
||||
var (
|
||||
allowMatches = matchstats{m: make(map[string]uint64)}
|
||||
blockMatches = matchstats{m: make(map[string]uint64)}
|
||||
)
|
||||
|
||||
// matchstats is a simple statistics
|
||||
// counter for header filter matches.
|
||||
// TODO: replace with otel.
|
||||
type matchstats struct {
|
||||
m map[string]uint64
|
||||
l sync.Mutex
|
||||
}
|
||||
|
||||
func (m *matchstats) Add(hdr, regex string) {
|
||||
m.l.Lock()
|
||||
key := hdr + ":" + regex
|
||||
m.m[key]++
|
||||
m.l.Unlock()
|
||||
}
|
||||
|
||||
// HeaderFilter returns a gin middleware handler that provides HTTP
|
||||
// request blocking (filtering) based on database allow / block filters.
|
||||
func HeaderFilter(state *state.State) gin.HandlerFunc {
|
||||
switch mode := config.GetAdvancedHeaderFilterMode(); mode {
|
||||
case config.RequestHeaderFilterModeDisabled:
|
||||
return func(ctx *gin.Context) {}
|
||||
|
||||
case config.RequestHeaderFilterModeAllow:
|
||||
return headerFilterAllowMode(state)
|
||||
|
||||
case config.RequestHeaderFilterModeBlock:
|
||||
return headerFilterBlockMode(state)
|
||||
|
||||
default:
|
||||
panic("unrecognized filter mode: " + mode)
|
||||
}
|
||||
}
|
||||
|
||||
func headerFilterAllowMode(state *state.State) func(c *gin.Context) {
|
||||
_ = *state //nolint
|
||||
// Allowlist mode: explicit block takes
|
||||
// precedence over explicit allow.
|
||||
//
|
||||
// Headers that have neither block
|
||||
// or allow entries are blocked.
|
||||
return func(c *gin.Context) {
|
||||
|
||||
// Check if header is explicitly blocked.
|
||||
block, err := isHeaderBlocked(state, c)
|
||||
if err != nil {
|
||||
respondInternalServerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if block {
|
||||
respondBlocked(c)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if header is missing explicit allow.
|
||||
notAllow, err := isHeaderNotAllowed(state, c)
|
||||
if err != nil {
|
||||
respondInternalServerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if notAllow {
|
||||
respondBlocked(c)
|
||||
return
|
||||
}
|
||||
|
||||
// Allowed!
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func headerFilterBlockMode(state *state.State) func(c *gin.Context) {
|
||||
_ = *state //nolint
|
||||
// Blocklist/default mode: explicit allow
|
||||
// takes precedence over explicit block.
|
||||
//
|
||||
// Headers that have neither block
|
||||
// or allow entries are allowed.
|
||||
return func(c *gin.Context) {
|
||||
|
||||
// Check if header is explicitly allowed.
|
||||
allow, err := isHeaderAllowed(state, c)
|
||||
if err != nil {
|
||||
respondInternalServerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !allow {
|
||||
// Check if header is explicitly blocked.
|
||||
block, err := isHeaderBlocked(state, c)
|
||||
if err != nil {
|
||||
respondInternalServerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if block {
|
||||
respondBlocked(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Allowed!
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func isHeaderBlocked(state *state.State, c *gin.Context) (bool, error) {
|
||||
var (
|
||||
ctx = c.Request.Context()
|
||||
hdr = c.Request.Header
|
||||
)
|
||||
|
||||
// Perform an explicit is-blocked check on request header.
|
||||
key, expr, err := state.DB.BlockHeaderRegularMatch(ctx, hdr)
|
||||
switch err {
|
||||
case nil:
|
||||
break
|
||||
|
||||
case headerfilter.ErrLargeHeaderValue:
|
||||
log.Warn(ctx, "large header value")
|
||||
key = "*" // block large headers
|
||||
|
||||
default:
|
||||
err := gtserror.Newf("error checking header: %w", err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
if key != "" {
|
||||
if expr != "" {
|
||||
// Increment block matches stat.
|
||||
// TODO: replace expvar with build
|
||||
// taggable metrics types in State{}.
|
||||
blockMatches.Add(key, expr)
|
||||
}
|
||||
|
||||
// A header was matched against!
|
||||
// i.e. this request is blocked.
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func isHeaderAllowed(state *state.State, c *gin.Context) (bool, error) {
|
||||
var (
|
||||
ctx = c.Request.Context()
|
||||
hdr = c.Request.Header
|
||||
)
|
||||
|
||||
// Perform an explicit is-allowed check on request header.
|
||||
key, expr, err := state.DB.AllowHeaderRegularMatch(ctx, hdr)
|
||||
switch err {
|
||||
case nil:
|
||||
break
|
||||
|
||||
case headerfilter.ErrLargeHeaderValue:
|
||||
log.Warn(ctx, "large header value")
|
||||
key = "" // block large headers
|
||||
|
||||
default:
|
||||
err := gtserror.Newf("error checking header: %w", err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
if key != "" {
|
||||
if expr != "" {
|
||||
// Increment allow matches stat.
|
||||
// TODO: replace expvar with build
|
||||
// taggable metrics types in State{}.
|
||||
allowMatches.Add(key, expr)
|
||||
}
|
||||
|
||||
// A header was matched against!
|
||||
// i.e. this request is allowed.
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func isHeaderNotAllowed(state *state.State, c *gin.Context) (bool, error) {
|
||||
var (
|
||||
ctx = c.Request.Context()
|
||||
hdr = c.Request.Header
|
||||
)
|
||||
|
||||
// Perform an explicit is-NOT-allowed check on request header.
|
||||
key, expr, err := state.DB.AllowHeaderInverseMatch(ctx, hdr)
|
||||
switch err {
|
||||
case nil:
|
||||
break
|
||||
|
||||
case headerfilter.ErrLargeHeaderValue:
|
||||
log.Warn(ctx, "large header value")
|
||||
key = "*" // block large headers
|
||||
|
||||
default:
|
||||
err := gtserror.Newf("error checking header: %w", err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
if key != "" {
|
||||
if expr != "" {
|
||||
// Increment allow matches stat.
|
||||
// TODO: replace expvar with build
|
||||
// taggable metrics types in State{}.
|
||||
allowMatches.Add(key, expr)
|
||||
}
|
||||
|
||||
// A header was matched against!
|
||||
// i.e. request is NOT allowed.
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
299
internal/middleware/headerfilter_test.go
Normal file
299
internal/middleware/headerfilter_test.go
Normal file
|
|
@ -0,0 +1,299 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/headerfilter"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/middleware"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
func TestHeaderFilter(t *testing.T) {
|
||||
testrig.InitTestLog()
|
||||
testrig.InitTestConfig()
|
||||
|
||||
for _, test := range []struct {
|
||||
mode string
|
||||
allow []filter
|
||||
block []filter
|
||||
input http.Header
|
||||
expect bool
|
||||
}{
|
||||
{
|
||||
// Allow mode with expected 200 OK.
|
||||
mode: config.RequestHeaderFilterModeAllow,
|
||||
allow: []filter{
|
||||
{"User-Agent", ".*Firefox.*"},
|
||||
},
|
||||
block: []filter{},
|
||||
input: http.Header{
|
||||
"User-Agent": []string{"Firefox v169.42; Extra Tracking Info"},
|
||||
},
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
// Allow mode with expected 403 Forbidden.
|
||||
mode: config.RequestHeaderFilterModeAllow,
|
||||
allow: []filter{
|
||||
{"User-Agent", ".*Firefox.*"},
|
||||
},
|
||||
block: []filter{},
|
||||
input: http.Header{
|
||||
"User-Agent": []string{"Chromium v169.42; Extra Tracking Info"},
|
||||
},
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
// Allow mode with too long header value expecting 403 Forbidden.
|
||||
mode: config.RequestHeaderFilterModeAllow,
|
||||
allow: []filter{
|
||||
{"User-Agent", ".*"},
|
||||
},
|
||||
block: []filter{},
|
||||
input: http.Header{
|
||||
"User-Agent": []string{func() string {
|
||||
var buf strings.Builder
|
||||
for i := 0; i < headerfilter.MaxHeaderValue+1; i++ {
|
||||
buf.WriteByte(' ')
|
||||
}
|
||||
return buf.String()
|
||||
}()},
|
||||
},
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
// Allow mode with explicit block expecting 403 Forbidden.
|
||||
mode: config.RequestHeaderFilterModeAllow,
|
||||
allow: []filter{
|
||||
{"User-Agent", ".*Firefox.*"},
|
||||
},
|
||||
block: []filter{
|
||||
{"User-Agent", ".*Firefox v169\\.42.*"},
|
||||
},
|
||||
input: http.Header{
|
||||
"User-Agent": []string{"Firefox v169.42; Extra Tracking Info"},
|
||||
},
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
// Block mode with an expected 403 Forbidden.
|
||||
mode: config.RequestHeaderFilterModeBlock,
|
||||
allow: []filter{},
|
||||
block: []filter{
|
||||
{"User-Agent", ".*Firefox.*"},
|
||||
},
|
||||
input: http.Header{
|
||||
"User-Agent": []string{"Firefox v169.42; Extra Tracking Info"},
|
||||
},
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
// Block mode with an expected 200 OK.
|
||||
mode: config.RequestHeaderFilterModeBlock,
|
||||
allow: []filter{},
|
||||
block: []filter{
|
||||
{"User-Agent", ".*Firefox.*"},
|
||||
},
|
||||
input: http.Header{
|
||||
"User-Agent": []string{"Chromium v169.42; Extra Tracking Info"},
|
||||
},
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
// Block mode with too long header value expecting 403 Forbidden.
|
||||
mode: config.RequestHeaderFilterModeBlock,
|
||||
allow: []filter{},
|
||||
block: []filter{
|
||||
{"User-Agent", "none"},
|
||||
},
|
||||
input: http.Header{
|
||||
"User-Agent": []string{func() string {
|
||||
var buf strings.Builder
|
||||
for i := 0; i < headerfilter.MaxHeaderValue+1; i++ {
|
||||
buf.WriteByte(' ')
|
||||
}
|
||||
return buf.String()
|
||||
}()},
|
||||
},
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
// Block mode with explicit allow expecting 200 OK.
|
||||
mode: config.RequestHeaderFilterModeBlock,
|
||||
allow: []filter{
|
||||
{"User-Agent", ".*Firefox.*"},
|
||||
},
|
||||
block: []filter{
|
||||
{"User-Agent", ".*Firefox v169\\.42.*"},
|
||||
},
|
||||
input: http.Header{
|
||||
"User-Agent": []string{"Firefox v169.42; Extra Tracking Info"},
|
||||
},
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
// Disabled mode with an expected 200 OK.
|
||||
mode: config.RequestHeaderFilterModeDisabled,
|
||||
allow: []filter{
|
||||
{"Key1", "only-this"},
|
||||
{"Key2", "only-this"},
|
||||
{"Key3", "only-this"},
|
||||
},
|
||||
block: []filter{
|
||||
{"Key1", "Value"},
|
||||
{"Key2", "Value"},
|
||||
{"Key3", "Value"},
|
||||
},
|
||||
input: http.Header{
|
||||
"Key1": []string{"Value"},
|
||||
"Key2": []string{"Value"},
|
||||
"Key3": []string{"Value"},
|
||||
},
|
||||
expect: true,
|
||||
},
|
||||
} {
|
||||
// Generate a unique name for this test case.
|
||||
name := fmt.Sprintf("%s allow=%v block=%v => expect=%v",
|
||||
test.mode,
|
||||
test.allow,
|
||||
test.block,
|
||||
test.expect,
|
||||
)
|
||||
|
||||
// Update header filter mode to test case.
|
||||
config.SetAdvancedHeaderFilterMode(test.mode)
|
||||
|
||||
// Run this particular test case.
|
||||
ok := t.Run(name, func(t *testing.T) {
|
||||
testHeaderFilter(t,
|
||||
test.allow,
|
||||
test.block,
|
||||
test.input,
|
||||
test.expect,
|
||||
)
|
||||
})
|
||||
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testHeaderFilter(t *testing.T, allow, block []filter, input http.Header, expect bool) {
|
||||
var err error
|
||||
|
||||
// Create test context with cancel.
|
||||
ctx := context.Background()
|
||||
ctx, cncl := context.WithCancel(ctx)
|
||||
defer cncl()
|
||||
|
||||
// Initialize caches.
|
||||
var state state.State
|
||||
state.Caches.Init()
|
||||
|
||||
// Create new database instance with test config.
|
||||
state.DB, err = bundb.NewBunDBService(ctx, &state)
|
||||
if err != nil {
|
||||
t.Fatalf("error opening database: %v", err)
|
||||
}
|
||||
|
||||
// Insert all allow filters into DB.
|
||||
for _, filter := range allow {
|
||||
filter := >smodel.HeaderFilter{
|
||||
ID: id.NewULID(),
|
||||
Header: filter.header,
|
||||
Regex: filter.regex,
|
||||
AuthorID: "admin-id",
|
||||
Author: nil,
|
||||
}
|
||||
|
||||
if err := state.DB.PutAllowHeaderFilter(ctx, filter); err != nil {
|
||||
t.Fatalf("error inserting allow filter into database: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Insert all block filters into DB.
|
||||
for _, filter := range block {
|
||||
filter := >smodel.HeaderFilter{
|
||||
ID: id.NewULID(),
|
||||
Header: filter.header,
|
||||
Regex: filter.regex,
|
||||
AuthorID: "admin-id",
|
||||
Author: nil,
|
||||
}
|
||||
|
||||
if err := state.DB.PutBlockHeaderFilter(ctx, filter); err != nil {
|
||||
t.Fatalf("error inserting block filter into database: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Gin test http engine
|
||||
// (used for ctx init).
|
||||
e := gin.New()
|
||||
|
||||
// Create new filter middleware to test against.
|
||||
middleware := middleware.HeaderFilter(&state)
|
||||
e.Use(middleware)
|
||||
|
||||
// Set the empty gin handler (always returns okay).
|
||||
e.Handle("GET", "/", func(ctx *gin.Context) { ctx.Status(200) })
|
||||
|
||||
// Prepare a gin test context.
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
// Set input headers.
|
||||
r.Header = input
|
||||
|
||||
// Pass req through
|
||||
// engine handler.
|
||||
e.ServeHTTP(rw, r)
|
||||
|
||||
// Get http result.
|
||||
res := rw.Result()
|
||||
|
||||
switch {
|
||||
case expect && res.StatusCode != http.StatusOK:
|
||||
t.Errorf("unexpected response (should allow): %s", res.Status)
|
||||
|
||||
case !expect && res.StatusCode != http.StatusForbidden:
|
||||
t.Errorf("unexpected response (should block): %s", res.Status)
|
||||
}
|
||||
}
|
||||
|
||||
type filter struct {
|
||||
header string
|
||||
regex string
|
||||
}
|
||||
|
||||
func (hf *filter) String() string {
|
||||
return fmt.Sprintf("%s=%q", hf.header, hf.regex)
|
||||
}
|
||||
|
|
@ -146,7 +146,7 @@ func RateLimit(limit int, exceptions []string) gin.HandlerFunc {
|
|||
apiutil.Data(c,
|
||||
http.StatusTooManyRequests,
|
||||
apiutil.AppJSON,
|
||||
apiutil.ErrorRateLimitReached,
|
||||
apiutil.ErrorRateLimited,
|
||||
)
|
||||
c.Abort()
|
||||
return
|
||||
|
|
|
|||
|
|
@ -18,21 +18,22 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
|
||||
)
|
||||
|
||||
// UserAgent returns a gin middleware which aborts requests with
|
||||
// empty user agent strings, returning code 418 - I'm a teapot.
|
||||
func UserAgent() gin.HandlerFunc {
|
||||
// todo: make this configurable
|
||||
var rsp = []byte(`{"error": "I'm a teapot: no user-agent sent with request"}`)
|
||||
return func(c *gin.Context) {
|
||||
if ua := c.Request.UserAgent(); ua == "" {
|
||||
code := http.StatusTeapot
|
||||
err := errors.New(http.StatusText(code) + ": no user-agent sent with request")
|
||||
c.AbortWithStatusJSON(code, gin.H{"error": err.Error()})
|
||||
apiutil.Data(c,
|
||||
http.StatusTeapot, apiutil.AppJSON, rsp)
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
51
internal/middleware/util.go
Normal file
51
internal/middleware/util.go
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
|
||||
)
|
||||
|
||||
// respondBlocked responds to the given gin context with
|
||||
// status forbidden, and a generic forbidden JSON response,
|
||||
// finally aborting the gin handler chain.
|
||||
func respondBlocked(c *gin.Context) {
|
||||
apiutil.Data(c,
|
||||
http.StatusForbidden,
|
||||
apiutil.AppJSON,
|
||||
apiutil.StatusForbiddenJSON,
|
||||
)
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// respondInternalServerError responds to the given gin context
|
||||
// with status internal server error, a generic internal server
|
||||
// error JSON response, sets the given error on the gin context
|
||||
// for later logging, finally aborting the gin handler chain.
|
||||
func respondInternalServerError(c *gin.Context, err error) {
|
||||
apiutil.Data(c,
|
||||
http.StatusInternalServerError,
|
||||
apiutil.AppJSON,
|
||||
apiutil.StatusInternalServerErrorJSON,
|
||||
)
|
||||
_ = c.Error(err)
|
||||
c.Abort()
|
||||
}
|
||||
215
internal/processing/admin/headerfilter.go
Normal file
215
internal/processing/admin/headerfilter.go
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/textproto"
|
||||
"regexp"
|
||||
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/headerfilter"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||
)
|
||||
|
||||
// GetAllowHeaderFilter fetches allow HTTP header filter with provided ID from the database.
|
||||
func (p *Processor) GetAllowHeaderFilter(ctx context.Context, id string) (*apimodel.HeaderFilter, gtserror.WithCode) {
|
||||
return p.getHeaderFilter(ctx, id, p.state.DB.GetAllowHeaderFilter)
|
||||
}
|
||||
|
||||
// GetBlockHeaderFilter fetches block HTTP header filter with provided ID from the database.
|
||||
func (p *Processor) GetBlockHeaderFilter(ctx context.Context, id string) (*apimodel.HeaderFilter, gtserror.WithCode) {
|
||||
return p.getHeaderFilter(ctx, id, p.state.DB.GetBlockHeaderFilter)
|
||||
}
|
||||
|
||||
// GetAllowHeaderFilters fetches all allow HTTP header filters stored in the database.
|
||||
func (p *Processor) GetAllowHeaderFilters(ctx context.Context) ([]*apimodel.HeaderFilter, gtserror.WithCode) {
|
||||
return p.getHeaderFilters(ctx, p.state.DB.GetAllowHeaderFilters)
|
||||
}
|
||||
|
||||
// GetBlockHeaderFilters fetches all block HTTP header filters stored in the database.
|
||||
func (p *Processor) GetBlockHeaderFilters(ctx context.Context) ([]*apimodel.HeaderFilter, gtserror.WithCode) {
|
||||
return p.getHeaderFilters(ctx, p.state.DB.GetBlockHeaderFilters)
|
||||
}
|
||||
|
||||
// CreateAllowHeaderFilter inserts the incoming allow HTTP header filter into the database, marking as authored by provided admin account.
|
||||
func (p *Processor) CreateAllowHeaderFilter(ctx context.Context, admin *gtsmodel.Account, request *apimodel.HeaderFilterRequest) (*apimodel.HeaderFilter, gtserror.WithCode) {
|
||||
return p.createHeaderFilter(ctx, admin, request, p.state.DB.PutAllowHeaderFilter)
|
||||
}
|
||||
|
||||
// CreateBlockHeaderFilter inserts the incoming block HTTP header filter into the database, marking as authored by provided admin account.
|
||||
func (p *Processor) CreateBlockHeaderFilter(ctx context.Context, admin *gtsmodel.Account, request *apimodel.HeaderFilterRequest) (*apimodel.HeaderFilter, gtserror.WithCode) {
|
||||
return p.createHeaderFilter(ctx, admin, request, p.state.DB.PutBlockHeaderFilter)
|
||||
}
|
||||
|
||||
// DeleteAllowHeaderFilter deletes the allowing HTTP header filter with provided ID from the database.
|
||||
func (p *Processor) DeleteAllowHeaderFilter(ctx context.Context, id string) gtserror.WithCode {
|
||||
return p.deleteHeaderFilter(ctx, id, p.state.DB.DeleteAllowHeaderFilter)
|
||||
}
|
||||
|
||||
// DeleteBlockHeaderFilter deletes the blocking HTTP header filter with provided ID from the database.
|
||||
func (p *Processor) DeleteBlockHeaderFilter(ctx context.Context, id string) gtserror.WithCode {
|
||||
return p.deleteHeaderFilter(ctx, id, p.state.DB.DeleteBlockHeaderFilter)
|
||||
}
|
||||
|
||||
// getHeaderFilter fetches an HTTP header filter with
|
||||
// provided ID, using given get function, converting the
|
||||
// resulting filter to returnable frontend API model.
|
||||
func (p *Processor) getHeaderFilter(
|
||||
ctx context.Context,
|
||||
id string,
|
||||
get func(context.Context, string) (*gtsmodel.HeaderFilter, error),
|
||||
) (
|
||||
*apimodel.HeaderFilter,
|
||||
gtserror.WithCode,
|
||||
) {
|
||||
// Select filter by ID from db.
|
||||
filter, err := get(ctx, id)
|
||||
|
||||
switch {
|
||||
// Successfully found.
|
||||
case err == nil:
|
||||
return toAPIHeaderFilter(filter), nil
|
||||
|
||||
// Filter does not exist with ID.
|
||||
case errors.Is(err, db.ErrNoEntries):
|
||||
const text = "filter not found"
|
||||
return nil, gtserror.NewErrorNotFound(errors.New(text), text)
|
||||
|
||||
// Any other error type.
|
||||
default:
|
||||
err := gtserror.Newf("error selecting from database: %w", err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// getHeaderFilters fetches all HTTP header filters
|
||||
// using given get function, converting the resulting
|
||||
// filters to returnable frontend API models.
|
||||
func (p *Processor) getHeaderFilters(
|
||||
ctx context.Context,
|
||||
get func(context.Context) ([]*gtsmodel.HeaderFilter, error),
|
||||
) (
|
||||
[]*apimodel.HeaderFilter,
|
||||
gtserror.WithCode,
|
||||
) {
|
||||
// Select all filters from DB.
|
||||
filters, err := get(ctx)
|
||||
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
// Only handle errors other than not-found types.
|
||||
err := gtserror.Newf("error selecting from database: %w", err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
// Convert passed header filters to apimodel filters.
|
||||
apiFilters := make([]*apimodel.HeaderFilter, len(filters))
|
||||
for i := range filters {
|
||||
apiFilters[i] = toAPIHeaderFilter(filters[i])
|
||||
}
|
||||
|
||||
return apiFilters, nil
|
||||
}
|
||||
|
||||
// createHeaderFilter inserts the given HTTP header
|
||||
// filter into database, marking as authored by the
|
||||
// provided admin, using the given insert function.
|
||||
func (p *Processor) createHeaderFilter(
|
||||
ctx context.Context,
|
||||
admin *gtsmodel.Account,
|
||||
request *apimodel.HeaderFilterRequest,
|
||||
insert func(context.Context, *gtsmodel.HeaderFilter) error,
|
||||
) (
|
||||
*apimodel.HeaderFilter,
|
||||
gtserror.WithCode,
|
||||
) {
|
||||
// Convert header key to canonical mime header format.
|
||||
request.Header = textproto.CanonicalMIMEHeaderKey(request.Header)
|
||||
|
||||
// Validate incoming header filter.
|
||||
if errWithCode := validateHeaderFilter(
|
||||
request.Header,
|
||||
request.Regex,
|
||||
); errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
// Create new database model with ID.
|
||||
var filter gtsmodel.HeaderFilter
|
||||
filter.ID = id.NewULID()
|
||||
filter.Header = request.Header
|
||||
filter.Regex = request.Regex
|
||||
filter.AuthorID = admin.ID
|
||||
filter.Author = admin
|
||||
|
||||
// Insert new header filter into the database.
|
||||
if err := insert(ctx, &filter); err != nil {
|
||||
err := gtserror.Newf("error inserting into database: %w", err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
// Finally return API model response.
|
||||
return toAPIHeaderFilter(&filter), nil
|
||||
}
|
||||
|
||||
// deleteHeaderFilter deletes the HTTP header filter
|
||||
// with provided ID, using the given delete function.
|
||||
func (p *Processor) deleteHeaderFilter(
|
||||
ctx context.Context,
|
||||
id string,
|
||||
delete func(context.Context, string) error,
|
||||
) gtserror.WithCode {
|
||||
if err := delete(ctx, id); err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
err := gtserror.Newf("error deleting from database: %w", err)
|
||||
return gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// toAPIFilter performs a simple conversion of database model HeaderFilter to API model.
|
||||
func toAPIHeaderFilter(filter *gtsmodel.HeaderFilter) *apimodel.HeaderFilter {
|
||||
return &apimodel.HeaderFilter{
|
||||
ID: filter.ID,
|
||||
Header: filter.Header,
|
||||
Regex: filter.Regex,
|
||||
CreatedBy: filter.AuthorID,
|
||||
CreatedAt: util.FormatISO8601(filter.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// validateHeaderFilter validates incoming filter's header key, and regular expression.
|
||||
func validateHeaderFilter(header, regex string) gtserror.WithCode {
|
||||
// Check header validity (within our own bound checks).
|
||||
if header == "" || len(header) > headerfilter.MaxHeaderValue {
|
||||
const text = "invalid request header key (empty or too long)"
|
||||
return gtserror.NewErrorBadRequest(errors.New(text), text)
|
||||
}
|
||||
|
||||
// Ensure this is compilable regex.
|
||||
_, err := regexp.Compile(regex)
|
||||
if err != nil {
|
||||
return gtserror.NewErrorBadRequest(err, err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue