[feature] request blocking by http headers (#2409)

This commit is contained in:
kim 2023-12-18 14:18:25 +00:00 committed by GitHub
commit 8ebb7775a3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
36 changed files with 2561 additions and 81 deletions

View file

@ -25,10 +25,6 @@ type Basic interface {
// For implementations that don't use tables, this can just return nil.
CreateTable(ctx context.Context, i interface{}) error
// CreateAllTables creates *all* tables necessary for the running of GoToSocial.
// Because it uses the 'if not exists' parameter it is safe to run against a GtS that's already been initialized.
CreateAllTables(ctx context.Context) error
// DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil.
DropTable(ctx context.Context, i interface{}) error

View file

@ -22,7 +22,6 @@ import (
"errors"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/uptrace/bun"
)
@ -120,39 +119,6 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) error {
return err
}
func (b *basicDB) CreateAllTables(ctx context.Context) error {
models := []interface{}{
&gtsmodel.Account{},
&gtsmodel.Application{},
&gtsmodel.Block{},
&gtsmodel.DomainBlock{},
&gtsmodel.EmailDomainBlock{},
&gtsmodel.Follow{},
&gtsmodel.FollowRequest{},
&gtsmodel.MediaAttachment{},
&gtsmodel.Mention{},
&gtsmodel.Status{},
&gtsmodel.StatusToEmoji{},
&gtsmodel.StatusFave{},
&gtsmodel.StatusBookmark{},
&gtsmodel.ThreadMute{},
&gtsmodel.Tag{},
&gtsmodel.User{},
&gtsmodel.Emoji{},
&gtsmodel.Instance{},
&gtsmodel.Notification{},
&gtsmodel.RouterSession{},
&gtsmodel.Token{},
&gtsmodel.Client{},
}
for _, i := range models {
if err := b.CreateTable(ctx, i); err != nil {
return err
}
}
return nil
}
func (b *basicDB) DropTable(ctx context.Context, i interface{}) error {
_, err := b.db.NewDropTable().Model(i).IfExists().Exec(ctx)
return err

View file

@ -67,6 +67,7 @@ type DBService struct {
db.Basic
db.Domain
db.Emoji
db.HeaderFilter
db.Instance
db.List
db.Marker
@ -193,6 +194,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db,
state: state,
},
HeaderFilter: &headerFilterDB{
db: db,
state: state,
},
Instance: &instanceDB{
db: db,
state: state,

View file

@ -0,0 +1,207 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"net/http"
"time"
"unsafe"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/uptrace/bun"
)
type headerFilterDB struct {
db *DB
state *state.State
}
func (h *headerFilterDB) AllowHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) {
return h.state.Caches.AllowHeaderFilters.RegularMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
return h.GetAllowHeaderFilters(ctx)
})
}
func (h *headerFilterDB) AllowHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) {
return h.state.Caches.AllowHeaderFilters.InverseMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
return h.GetAllowHeaderFilters(ctx)
})
}
func (h *headerFilterDB) BlockHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error) {
return h.state.Caches.BlockHeaderFilters.RegularMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
return h.GetBlockHeaderFilters(ctx)
})
}
func (h *headerFilterDB) BlockHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error) {
return h.state.Caches.BlockHeaderFilters.InverseMatch(hdr, func() ([]*gtsmodel.HeaderFilter, error) {
return h.GetBlockHeaderFilters(ctx)
})
}
func (h *headerFilterDB) GetAllowHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) {
filter := new(gtsmodel.HeaderFilterAllow)
if err := h.db.NewSelect().
Model(filter).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx); err != nil {
return nil, err
}
return fromAllowFilter(filter), nil
}
func (h *headerFilterDB) GetBlockHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error) {
filter := new(gtsmodel.HeaderFilterBlock)
if err := h.db.NewSelect().
Model(filter).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx); err != nil {
return nil, err
}
return fromBlockFilter(filter), nil
}
func (h *headerFilterDB) GetAllowHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) {
var filters []*gtsmodel.HeaderFilterAllow
err := h.db.NewSelect().
Model(&filters).
Scan(ctx, &filters)
return fromAllowFilters(filters), err
}
func (h *headerFilterDB) GetBlockHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error) {
var filters []*gtsmodel.HeaderFilterBlock
err := h.db.NewSelect().
Model(&filters).
Scan(ctx, &filters)
return fromBlockFilters(filters), err
}
func (h *headerFilterDB) PutAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error {
if _, err := h.db.NewInsert().
Model(toAllowFilter(filter)).
Exec(ctx); err != nil {
return err
}
h.state.Caches.AllowHeaderFilters.Clear()
return nil
}
func (h *headerFilterDB) PutBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error {
if _, err := h.db.NewInsert().
Model(toBlockFilter(filter)).
Exec(ctx); err != nil {
return err
}
h.state.Caches.BlockHeaderFilters.Clear()
return nil
}
func (h *headerFilterDB) UpdateAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error {
filter.UpdatedAt = time.Now()
if len(cols) > 0 {
// If we're updating by column,
// ensure "updated_at" is included.
cols = append(cols, "updated_at")
}
if _, err := h.db.NewUpdate().
Model(toAllowFilter(filter)).
Column(cols...).
Where("? = ?", bun.Ident("id"), filter.ID).
Exec(ctx); err != nil {
return err
}
h.state.Caches.AllowHeaderFilters.Clear()
return nil
}
func (h *headerFilterDB) UpdateBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error {
filter.UpdatedAt = time.Now()
if len(cols) > 0 {
// If we're updating by column,
// ensure "updated_at" is included.
cols = append(cols, "updated_at")
}
if _, err := h.db.NewUpdate().
Model(toBlockFilter(filter)).
Column(cols...).
Where("? = ?", bun.Ident("id"), filter.ID).
Exec(ctx); err != nil {
return err
}
h.state.Caches.BlockHeaderFilters.Clear()
return nil
}
func (h *headerFilterDB) DeleteAllowHeaderFilter(ctx context.Context, id string) error {
if _, err := h.db.NewDelete().
Table("header_filter_allows").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return err
}
h.state.Caches.AllowHeaderFilters.Clear()
return nil
}
func (h *headerFilterDB) DeleteBlockHeaderFilter(ctx context.Context, id string) error {
if _, err := h.db.NewDelete().
Table("header_filter_blocks").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return err
}
h.state.Caches.BlockHeaderFilters.Clear()
return nil
}
// NOTE:
// all of the below unsafe cast functions
// are only possible because HeaderFilterAllow{},
// HeaderFilterBlock{}, HeaderFilter{} while
// different types in source, have exactly the
// same size and layout in memory. the unsafe
// cast simply changes the type associated with
// that block of memory.
func toAllowFilter(filter *gtsmodel.HeaderFilter) *gtsmodel.HeaderFilterAllow {
return (*gtsmodel.HeaderFilterAllow)(unsafe.Pointer(filter))
}
func toBlockFilter(filter *gtsmodel.HeaderFilter) *gtsmodel.HeaderFilterBlock {
return (*gtsmodel.HeaderFilterBlock)(unsafe.Pointer(filter))
}
func fromAllowFilter(filter *gtsmodel.HeaderFilterAllow) *gtsmodel.HeaderFilter {
return (*gtsmodel.HeaderFilter)(unsafe.Pointer(filter))
}
func fromBlockFilter(filter *gtsmodel.HeaderFilterBlock) *gtsmodel.HeaderFilter {
return (*gtsmodel.HeaderFilter)(unsafe.Pointer(filter))
}
func fromAllowFilters(filters []*gtsmodel.HeaderFilterAllow) []*gtsmodel.HeaderFilter {
return *(*[]*gtsmodel.HeaderFilter)(unsafe.Pointer(&filters))
}
func fromBlockFilters(filters []*gtsmodel.HeaderFilterBlock) []*gtsmodel.HeaderFilter {
return *(*[]*gtsmodel.HeaderFilter)(unsafe.Pointer(&filters))
}

View file

@ -0,0 +1,125 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb_test
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type HeaderFilterTestSuite struct {
BunDBStandardTestSuite
}
func (suite *HeaderFilterTestSuite) TestAllowHeaderFilterGetPutUpdateDelete() {
suite.testHeaderFilterGetPutUpdateDelete(
suite.db.GetAllowHeaderFilter,
suite.db.GetAllowHeaderFilters,
suite.db.PutAllowHeaderFilter,
suite.db.UpdateAllowHeaderFilter,
suite.db.DeleteAllowHeaderFilter,
)
}
func (suite *HeaderFilterTestSuite) TestBlockHeaderFilterGetPutUpdateDelete() {
suite.testHeaderFilterGetPutUpdateDelete(
suite.db.GetBlockHeaderFilter,
suite.db.GetBlockHeaderFilters,
suite.db.PutBlockHeaderFilter,
suite.db.UpdateBlockHeaderFilter,
suite.db.DeleteBlockHeaderFilter,
)
}
func (suite *HeaderFilterTestSuite) testHeaderFilterGetPutUpdateDelete(
get func(context.Context, string) (*gtsmodel.HeaderFilter, error),
getAll func(context.Context) ([]*gtsmodel.HeaderFilter, error),
put func(context.Context, *gtsmodel.HeaderFilter) error,
update func(context.Context, *gtsmodel.HeaderFilter, ...string) error,
delete func(context.Context, string) error,
) {
t := suite.T()
// Create new example header filter.
filter := gtsmodel.HeaderFilter{
ID: "some unique id",
Header: "Http-Header-Key",
Regex: ".*",
AuthorID: "some unique author id",
}
// Create new cancellable test context.
ctx := context.Background()
ctx, cncl := context.WithCancel(ctx)
defer cncl()
// Insert the example header filter into db.
if err := put(ctx, &filter); err != nil {
t.Fatalf("error inserting header filter: %v", err)
}
// Now fetch newly created filter.
check, err := get(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching header filter: %v", err)
}
// Check all expected fields match.
suite.Equal(filter.ID, check.ID)
suite.Equal(filter.Header, check.Header)
suite.Equal(filter.Regex, check.Regex)
suite.Equal(filter.AuthorID, check.AuthorID)
// Fetch all header filters.
all, err := getAll(ctx)
if err != nil {
t.Fatalf("error fetching header filters: %v", err)
}
// Ensure contains example.
suite.Equal(len(all), 1)
suite.Equal(all[0].ID, filter.ID)
// Update the header filter regex value.
check.Regex = "new regex value"
if err := update(ctx, check); err != nil {
t.Fatalf("error updating header filter: %v", err)
}
// Ensure 'updated_at' was updated on check model.
suite.True(check.UpdatedAt.After(filter.UpdatedAt))
// Now delete the header filter from db.
if err := delete(ctx, filter.ID); err != nil {
t.Fatalf("error deleting header filter: %v", err)
}
// Ensure we can't refetch it.
_, err = get(ctx, filter.ID)
if err != db.ErrNoEntries {
t.Fatalf("deleted header filter returned unexpected error: %v", err)
}
}
func TestHeaderFilterTestSuite(t *testing.T) {
suite.Run(t, new(HeaderFilterTestSuite))
}

View file

@ -0,0 +1,54 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package migrations
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
func init() {
up := func(ctx context.Context, db *bun.DB) error {
for _, model := range []any{
&gtsmodel.HeaderFilterAllow{},
&gtsmodel.HeaderFilterBlock{},
} {
_, err := db.NewCreateTable().
IfNotExists().
Model(model).
Exec(ctx)
if err != nil {
return err
}
}
return nil
}
down := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return nil
})
}
if err := Migrations.Register(up, down); err != nil {
panic(err)
}
}

View file

@ -30,6 +30,7 @@ type DB interface {
Basic
Domain
Emoji
HeaderFilter
Instance
List
Marker

View file

@ -0,0 +1,73 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package db
import (
"context"
"net/http"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type HeaderFilter interface {
// AllowHeaderRegularMatch performs an headerfilter.Filter.RegularMatch() on cached allow header filters.
// (Note: the actual matching code can be found under ./internal/headerfilter/ ).
AllowHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error)
// AllowHeaderInverseMatch performs an headerfilter.Filter.InverseMatch() on cached allow header filters.
// (Note: the actual matching code can be found under ./internal/headerfilter/ ).
AllowHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error)
// BlockHeaderRegularMatch performs an headerfilter.Filter.RegularMatch() on cached block header filters.
// (Note: the actual matching code can be found under ./internal/headerfilter/ ).
BlockHeaderRegularMatch(ctx context.Context, hdr http.Header) (string, string, error)
// BlockHeaderInverseMatch performs an headerfilter.Filter.InverseMatch() on cached block header filters.
// (Note: the actual matching code can be found under ./internal/headerfilter/ ).
BlockHeaderInverseMatch(ctx context.Context, hdr http.Header) (string, string, error)
// GetAllowHeaderFilter fetches the allow header filter with ID from the database.
GetAllowHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error)
// GetBlockHeaderFilter fetches the block header filter with ID from the database.
GetBlockHeaderFilter(ctx context.Context, id string) (*gtsmodel.HeaderFilter, error)
// GetAllowHeaderFilters fetches all allow header filters from the database.
GetAllowHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error)
// GetBlockHeaderFilters fetches all block header filters from the database.
GetBlockHeaderFilters(ctx context.Context) ([]*gtsmodel.HeaderFilter, error)
// PutAllowHeaderFilter inserts the given allow header filter into the database.
PutAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error
// PutBlockHeaderFilter inserts the given block header filter into the database.
PutBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter) error
// UpdateAllowHeaderFilter updates the given allow header filter in the database, only updating given columns if provided.
UpdateAllowHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error
// UpdateBlockHeaderFilter updates the given block header filter in the database, only updating given columns if provided.
UpdateBlockHeaderFilter(ctx context.Context, filter *gtsmodel.HeaderFilter, cols ...string) error
// DeleteAllowHeaderFilter deletes the allow header filter with ID from the database.
DeleteAllowHeaderFilter(ctx context.Context, id string) error
// DeleteBlockHeaderFilter deletes the block header filter with ID from the database.
DeleteBlockHeaderFilter(ctx context.Context, id string) error
}