[feature] Filters v1 (#2594)

* Implement client-side v1 filters

* Exclude linter false positives

* Update test/envparsing.sh

* Fix minor Swagger, style, and Bun usage issues

* Regenerate Swagger

* De-generify filter keywords

* Remove updating filter statuses

This is an operation that the Mastodon v2 filter API doesn't actually have, because filter statuses, unlike keywords, don't have options: the only info they contain is the status ID to be filtered.

* Add a test for filter statuses specifically

* De-generify filter statuses

* Inline FilterEntry

* Use vertical style for Bun operations consistently

* Add comment on Filter DB interface

* Remove GoLand linter control comments

Our existing linters should catch these, or they don't matter very much

* Reduce memory ratio for filters
This commit is contained in:
Vyr Cossont 2024-03-06 02:15:58 -08:00 committed by GitHub
commit 61a2b91f45
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
50 changed files with 4672 additions and 52 deletions

View file

@ -62,6 +62,7 @@ type DBService struct {
db.Emoji
db.HeaderFilter
db.Instance
db.Filter
db.List
db.Marker
db.Media
@ -200,6 +201,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db,
state: state,
},
Filter: &filterDB{
db: db,
state: state,
},
List: &listDB{
db: db,
state: state,

339
internal/db/bundb/filter.go Normal file
View file

@ -0,0 +1,339 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
type filterDB struct {
db *bun.DB
state *state.State
}
func (f *filterDB) GetFilterByID(ctx context.Context, id string) (*gtsmodel.Filter, error) {
filter, err := f.state.Caches.GTS.Filter.LoadOne(
"ID",
func() (*gtsmodel.Filter, error) {
var filter gtsmodel.Filter
err := f.db.
NewSelect().
Model(&filter).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx)
return &filter, err
},
id,
)
if err != nil {
// already processed
return nil, err
}
if !gtscontext.Barebones(ctx) {
if err := f.populateFilter(ctx, filter); err != nil {
return nil, err
}
}
return filter, nil
}
func (f *filterDB) GetFiltersForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.Filter, error) {
// Fetch IDs of all filters owned by this account.
var filterIDs []string
if err := f.db.
NewSelect().
Model((*gtsmodel.Filter)(nil)).
Column("id").
Where("? = ?", bun.Ident("account_id"), accountID).
Scan(ctx, &filterIDs); err != nil {
return nil, err
}
if len(filterIDs) == 0 {
return nil, nil
}
// Get each filter by ID from the cache or DB.
uncachedFilterIDs := make([]string, 0, len(filterIDs))
filters, err := f.state.Caches.GTS.Filter.Load(
"ID",
func(load func(keyParts ...any) bool) {
for _, id := range filterIDs {
if !load(id) {
uncachedFilterIDs = append(uncachedFilterIDs, id)
}
}
},
func() ([]*gtsmodel.Filter, error) {
uncachedFilters := make([]*gtsmodel.Filter, 0, len(uncachedFilterIDs))
if err := f.db.
NewSelect().
Model(&uncachedFilters).
Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterIDs)).
Scan(ctx); err != nil {
return nil, err
}
return uncachedFilters, nil
},
)
if err != nil {
return nil, err
}
// Put the filter structs in the same order as the filter IDs.
util.OrderBy(filters, filterIDs, func(filter *gtsmodel.Filter) string { return filter.ID })
if gtscontext.Barebones(ctx) {
return filters, nil
}
// Populate the filters. Remove any that we can't populate from the return slice.
errs := gtserror.NewMultiError(len(filters))
filters = slices.DeleteFunc(filters, func(filter *gtsmodel.Filter) bool {
if err := f.populateFilter(ctx, filter); err != nil {
errs.Appendf("error populating filter %s: %w", filter.ID, err)
return true
}
return false
})
return filters, errs.Combine()
}
func (f *filterDB) populateFilter(ctx context.Context, filter *gtsmodel.Filter) error {
var err error
errs := gtserror.NewMultiError(2)
if filter.Keywords == nil {
// Filter keywords are not set, fetch from the database.
filter.Keywords, err = f.state.DB.GetFilterKeywordsForFilterID(
gtscontext.SetBarebones(ctx),
filter.ID,
)
if err != nil {
errs.Appendf("error populating filter keywords: %w", err)
}
for i := range filter.Keywords {
filter.Keywords[i].Filter = filter
}
}
if filter.Statuses == nil {
// Filter statuses are not set, fetch from the database.
filter.Statuses, err = f.state.DB.GetFilterStatusesForFilterID(
gtscontext.SetBarebones(ctx),
filter.ID,
)
if err != nil {
errs.Appendf("error populating filter statuses: %w", err)
}
for i := range filter.Statuses {
filter.Statuses[i].Filter = filter
}
}
return errs.Combine()
}
func (f *filterDB) PutFilter(ctx context.Context, filter *gtsmodel.Filter) error {
// Update database.
if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
if _, err := tx.
NewInsert().
Model(filter).
Exec(ctx); err != nil {
return err
}
if len(filter.Keywords) > 0 {
if _, err := tx.
NewInsert().
Model(&filter.Keywords).
Exec(ctx); err != nil {
return err
}
}
if len(filter.Statuses) > 0 {
if _, err := tx.
NewInsert().
Model(&filter.Statuses).
Exec(ctx); err != nil {
return err
}
}
return nil
}); err != nil {
return err
}
// Update cache.
f.state.Caches.GTS.Filter.Put(filter)
f.state.Caches.GTS.FilterKeyword.Put(filter.Keywords...)
f.state.Caches.GTS.FilterStatus.Put(filter.Statuses...)
return nil
}
func (f *filterDB) UpdateFilter(
ctx context.Context,
filter *gtsmodel.Filter,
filterColumns []string,
filterKeywordColumns []string,
deleteFilterKeywordIDs []string,
deleteFilterStatusIDs []string,
) error {
updatedAt := time.Now()
filter.UpdatedAt = updatedAt
for _, filterKeyword := range filter.Keywords {
filterKeyword.UpdatedAt = updatedAt
}
for _, filterStatus := range filter.Statuses {
filterStatus.UpdatedAt = updatedAt
}
// If we're updating by column, ensure "updated_at" is included.
if len(filterColumns) > 0 {
filterColumns = append(filterColumns, "updated_at")
}
if len(filterKeywordColumns) > 0 {
filterKeywordColumns = append(filterKeywordColumns, "updated_at")
}
// Update database.
if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
if _, err := tx.
NewUpdate().
Model(filter).
Column(filterColumns...).
Where("? = ?", bun.Ident("id"), filter.ID).
Exec(ctx); err != nil {
return err
}
if len(filter.Keywords) > 0 {
if _, err := NewUpsert(tx).
Model(&filter.Keywords).
Constraint("id").
Column(filterKeywordColumns...).
Exec(ctx); err != nil {
return err
}
}
if len(filter.Statuses) > 0 {
if _, err := tx.
NewInsert().
Ignore().
Model(&filter.Statuses).
Exec(ctx); err != nil {
return err
}
}
if len(deleteFilterKeywordIDs) > 0 {
if _, err := tx.
NewDelete().
Model((*gtsmodel.FilterKeyword)(nil)).
Where("? = (?)", bun.Ident("id"), bun.In(deleteFilterKeywordIDs)).
Exec(ctx); err != nil {
return err
}
}
if len(deleteFilterStatusIDs) > 0 {
if _, err := tx.
NewDelete().
Model((*gtsmodel.FilterStatus)(nil)).
Where("? = (?)", bun.Ident("id"), bun.In(deleteFilterStatusIDs)).
Exec(ctx); err != nil {
return err
}
}
return nil
}); err != nil {
return err
}
// Update cache.
f.state.Caches.GTS.Filter.Put(filter)
f.state.Caches.GTS.FilterKeyword.Put(filter.Keywords...)
f.state.Caches.GTS.FilterStatus.Put(filter.Statuses...)
// TODO: (Vyr) replace with cache multi-invalidate call
for _, id := range deleteFilterKeywordIDs {
f.state.Caches.GTS.FilterKeyword.Invalidate("ID", id)
}
for _, id := range deleteFilterStatusIDs {
f.state.Caches.GTS.FilterStatus.Invalidate("ID", id)
}
return nil
}
func (f *filterDB) DeleteFilterByID(ctx context.Context, id string) error {
if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Delete all keywords attached to filter.
if _, err := tx.
NewDelete().
Model((*gtsmodel.FilterKeyword)(nil)).
Where("? = ?", bun.Ident("filter_id"), id).
Exec(ctx); err != nil {
return err
}
// Delete all statuses attached to filter.
if _, err := tx.
NewDelete().
Model((*gtsmodel.FilterStatus)(nil)).
Where("? = ?", bun.Ident("filter_id"), id).
Exec(ctx); err != nil {
return err
}
// Delete the filter itself.
_, err := tx.
NewDelete().
Model((*gtsmodel.Filter)(nil)).
Where("? = ?", bun.Ident("id"), id).
Exec(ctx)
return err
}); err != nil {
return err
}
// Invalidate this filter.
f.state.Caches.GTS.Filter.Invalidate("ID", id)
// Invalidate all keywords and statuses for this filter.
f.state.Caches.GTS.FilterKeyword.Invalidate("FilterID", id)
f.state.Caches.GTS.FilterStatus.Invalidate("FilterID", id)
return nil
}

View file

@ -0,0 +1,252 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb_test
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
type FilterTestSuite struct {
BunDBStandardTestSuite
}
// TestFilterCRUD tests CRUD and read-all operations on filters.
func (suite *FilterTestSuite) TestFilterCRUD() {
t := suite.T()
// Create new example filter with attached keyword.
filter := &gtsmodel.Filter{
ID: "01HNEJNVZZVXJTRB3FX3K2B1YF",
AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47",
Title: "foss jail",
Action: gtsmodel.FilterActionWarn,
ContextHome: util.Ptr(true),
ContextPublic: util.Ptr(true),
}
filterKeyword := &gtsmodel.FilterKeyword{
ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4",
AccountID: filter.AccountID,
FilterID: filter.ID,
Keyword: "GNU/Linux",
}
filter.Keywords = []*gtsmodel.FilterKeyword{filterKeyword}
// Create new cancellable test context.
ctx := context.Background()
ctx, cncl := context.WithCancel(ctx)
defer cncl()
// Insert the example filter into db.
if err := suite.db.PutFilter(ctx, filter); err != nil {
t.Fatalf("error inserting filter: %v", err)
}
// Now fetch newly created filter.
check, err := suite.db.GetFilterByID(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching filter: %v", err)
}
// Check all expected fields match.
suite.Equal(filter.ID, check.ID)
suite.Equal(filter.AccountID, check.AccountID)
suite.Equal(filter.Title, check.Title)
suite.Equal(filter.Action, check.Action)
suite.Equal(filter.ContextHome, check.ContextHome)
suite.Equal(filter.ContextNotifications, check.ContextNotifications)
suite.Equal(filter.ContextPublic, check.ContextPublic)
suite.Equal(filter.ContextThread, check.ContextThread)
suite.Equal(filter.ContextAccount, check.ContextAccount)
suite.NotZero(check.CreatedAt)
suite.NotZero(check.UpdatedAt)
suite.Equal(len(filter.Keywords), len(check.Keywords))
suite.Equal(filter.Keywords[0].ID, check.Keywords[0].ID)
suite.Equal(filter.Keywords[0].AccountID, check.Keywords[0].AccountID)
suite.Equal(filter.Keywords[0].FilterID, check.Keywords[0].FilterID)
suite.Equal(filter.Keywords[0].Keyword, check.Keywords[0].Keyword)
suite.Equal(filter.Keywords[0].FilterID, check.Keywords[0].FilterID)
suite.NotZero(check.Keywords[0].CreatedAt)
suite.NotZero(check.Keywords[0].UpdatedAt)
suite.Equal(len(filter.Statuses), len(check.Statuses))
// Fetch all filters.
all, err := suite.db.GetFiltersForAccountID(ctx, filter.AccountID)
if err != nil {
t.Fatalf("error fetching filters: %v", err)
}
// Ensure the result contains our example filter.
suite.Len(all, 1)
suite.Equal(filter.ID, all[0].ID)
suite.Len(all[0].Keywords, 1)
suite.Equal(filter.Keywords[0].ID, all[0].Keywords[0].ID)
suite.Empty(all[0].Statuses)
// Update the filter context and add another keyword and a status.
check.ContextNotifications = util.Ptr(true)
newKeyword := &gtsmodel.FilterKeyword{
ID: "01HNEMY810E5XKWDDMN5ZRE749",
FilterID: filter.ID,
AccountID: filter.AccountID,
Keyword: "tux",
}
check.Keywords = append(check.Keywords, newKeyword)
newStatus := &gtsmodel.FilterStatus{
ID: "01HNEMYD5XE7C8HH8TNCZ76FN2",
FilterID: filter.ID,
AccountID: filter.AccountID,
StatusID: "01HNEKZW34SQZ8PSDQ0Z10NZES",
}
check.Statuses = append(check.Statuses, newStatus)
if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil {
t.Fatalf("error updating filter: %v", err)
}
// Now fetch newly updated filter.
check, err = suite.db.GetFilterByID(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching updated filter: %v", err)
}
// Ensure expected fields were modified on check filter.
suite.True(check.UpdatedAt.After(filter.UpdatedAt))
if suite.NotNil(check.ContextHome) {
suite.True(*check.ContextHome)
}
if suite.NotNil(check.ContextNotifications) {
suite.True(*check.ContextNotifications)
}
if suite.NotNil(check.ContextPublic) {
suite.True(*check.ContextPublic)
}
if suite.NotNil(check.ContextThread) {
suite.False(*check.ContextThread)
}
if suite.NotNil(check.ContextAccount) {
suite.False(*check.ContextAccount)
}
// Ensure keyword entries were added.
suite.Len(check.Keywords, 2)
checkFilterKeywordIDs := make([]string, 0, 2)
for _, checkFilterKeyword := range check.Keywords {
checkFilterKeywordIDs = append(checkFilterKeywordIDs, checkFilterKeyword.ID)
}
suite.ElementsMatch([]string{filterKeyword.ID, newKeyword.ID}, checkFilterKeywordIDs)
// Ensure status entry was added.
suite.Len(check.Statuses, 1)
checkFilterStatusIDs := make([]string, 0, 1)
for _, checkFilterStatus := range check.Statuses {
checkFilterStatusIDs = append(checkFilterStatusIDs, checkFilterStatus.ID)
}
suite.ElementsMatch([]string{newStatus.ID}, checkFilterStatusIDs)
// Update one filter keyword and delete another. Don't change the filter or the filter status.
filterKeyword.WholeWord = util.Ptr(true)
check.Keywords = []*gtsmodel.FilterKeyword{filterKeyword}
check.Statuses = nil
if err := suite.db.UpdateFilter(ctx, check, nil, nil, []string{newKeyword.ID}, nil); err != nil {
t.Fatalf("error updating filter: %v", err)
}
check, err = suite.db.GetFilterByID(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching updated filter: %v", err)
}
// Ensure expected fields were not modified.
suite.Equal(filter.Title, check.Title)
suite.Equal(gtsmodel.FilterActionWarn, check.Action)
if suite.NotNil(check.ContextHome) {
suite.True(*check.ContextHome)
}
if suite.NotNil(check.ContextNotifications) {
suite.True(*check.ContextNotifications)
}
if suite.NotNil(check.ContextPublic) {
suite.True(*check.ContextPublic)
}
if suite.NotNil(check.ContextThread) {
suite.False(*check.ContextThread)
}
if suite.NotNil(check.ContextAccount) {
suite.False(*check.ContextAccount)
}
// Ensure only changed field of keyword was modified, and other keyword was deleted.
suite.Len(check.Keywords, 1)
suite.Equal(filterKeyword.ID, check.Keywords[0].ID)
suite.Equal("GNU/Linux", check.Keywords[0].Keyword)
if suite.NotNil(check.Keywords[0].WholeWord) {
suite.True(*check.Keywords[0].WholeWord)
}
// Ensure status entry was not deleted.
suite.Len(check.Statuses, 1)
suite.Equal(newStatus.ID, check.Statuses[0].ID)
// Add another status entry for the same status ID. It should be ignored without problems.
redundantStatus := &gtsmodel.FilterStatus{
ID: "01HQXJ5Y405XZSQ67C2BSQ6HJ0",
FilterID: filter.ID,
AccountID: filter.AccountID,
StatusID: newStatus.StatusID,
}
check.Statuses = []*gtsmodel.FilterStatus{redundantStatus}
if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil {
t.Fatalf("error updating filter: %v", err)
}
check, err = suite.db.GetFilterByID(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching updated filter: %v", err)
}
// Ensure status entry was not deleted, updated, or duplicated.
suite.Len(check.Statuses, 1)
suite.Equal(newStatus.ID, check.Statuses[0].ID)
suite.Equal(newStatus.StatusID, check.Statuses[0].StatusID)
// Now delete the filter from the DB.
if err := suite.db.DeleteFilterByID(ctx, filter.ID); err != nil {
t.Fatalf("error deleting filter: %v", err)
}
// Ensure we can't refetch it.
_, err = suite.db.GetFilterByID(ctx, filter.ID)
if !errors.Is(err, db.ErrNoEntries) {
t.Fatalf("fetching deleted filter returned unexpected error: %v", err)
}
}
func TestFilterTestSuite(t *testing.T) {
suite.Run(t, new(FilterTestSuite))
}

View file

@ -0,0 +1,191 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
func (f *filterDB) GetFilterKeywordByID(ctx context.Context, id string) (*gtsmodel.FilterKeyword, error) {
filterKeyword, err := f.state.Caches.GTS.FilterKeyword.LoadOne(
"ID",
func() (*gtsmodel.FilterKeyword, error) {
var filterKeyword gtsmodel.FilterKeyword
err := f.db.
NewSelect().
Model(&filterKeyword).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx)
return &filterKeyword, err
},
id,
)
if err != nil {
return nil, err
}
if !gtscontext.Barebones(ctx) {
err = f.populateFilterKeyword(ctx, filterKeyword)
if err != nil {
return nil, err
}
}
return filterKeyword, nil
}
func (f *filterDB) populateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error {
if filterKeyword.Filter == nil {
// Filter is not set, fetch from the cache or database.
filter, err := f.state.DB.GetFilterByID(
// Don't populate the filter with all of its keywords and statuses or we'll just end up back here.
gtscontext.SetBarebones(ctx),
filterKeyword.FilterID,
)
if err != nil {
return err
}
filterKeyword.Filter = filter
}
return nil
}
func (f *filterDB) GetFilterKeywordsForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterKeyword, error) {
return f.getFilterKeywords(ctx, "filter_id", filterID)
}
func (f *filterDB) GetFilterKeywordsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterKeyword, error) {
return f.getFilterKeywords(ctx, "account_id", accountID)
}
func (f *filterDB) getFilterKeywords(ctx context.Context, idColumn string, id string) ([]*gtsmodel.FilterKeyword, error) {
var filterKeywordIDs []string
if err := f.db.
NewSelect().
Model((*gtsmodel.FilterKeyword)(nil)).
Column("id").
Where("? = ?", bun.Ident(idColumn), id).
Scan(ctx, &filterKeywordIDs); err != nil {
return nil, err
}
if len(filterKeywordIDs) == 0 {
return nil, nil
}
// Get each filter keyword by ID from the cache or DB.
uncachedFilterKeywordIDs := make([]string, 0, len(filterKeywordIDs))
filterKeywords, err := f.state.Caches.GTS.FilterKeyword.Load(
"ID",
func(load func(keyParts ...any) bool) {
for _, id := range filterKeywordIDs {
if !load(id) {
uncachedFilterKeywordIDs = append(uncachedFilterKeywordIDs, id)
}
}
},
func() ([]*gtsmodel.FilterKeyword, error) {
uncachedFilterKeywords := make([]*gtsmodel.FilterKeyword, 0, len(uncachedFilterKeywordIDs))
if err := f.db.
NewSelect().
Model(&uncachedFilterKeywords).
Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterKeywordIDs)).
Scan(ctx); err != nil {
return nil, err
}
return uncachedFilterKeywords, nil
},
)
if err != nil {
return nil, err
}
// Put the filter keyword structs in the same order as the filter keyword IDs.
util.OrderBy(filterKeywords, filterKeywordIDs, func(filterKeyword *gtsmodel.FilterKeyword) string {
return filterKeyword.ID
})
if gtscontext.Barebones(ctx) {
return filterKeywords, nil
}
// Populate the filter keywords. Remove any that we can't populate from the return slice.
errs := gtserror.NewMultiError(len(filterKeywords))
filterKeywords = slices.DeleteFunc(filterKeywords, func(filterKeyword *gtsmodel.FilterKeyword) bool {
if err := f.populateFilterKeyword(ctx, filterKeyword); err != nil {
errs.Appendf(
"error populating filter keyword %s: %w",
filterKeyword.ID,
err,
)
return true
}
return false
})
return filterKeywords, errs.Combine()
}
func (f *filterDB) PutFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error {
return f.state.Caches.GTS.FilterKeyword.Store(filterKeyword, func() error {
_, err := f.db.
NewInsert().
Model(filterKeyword).
Exec(ctx)
return err
})
}
func (f *filterDB) UpdateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword, columns ...string) error {
filterKeyword.UpdatedAt = time.Now()
if len(columns) > 0 {
columns = append(columns, "updated_at")
}
return f.state.Caches.GTS.FilterKeyword.Store(filterKeyword, func() error {
_, err := f.db.
NewUpdate().
Model(filterKeyword).
Where("? = ?", bun.Ident("id"), filterKeyword.ID).
Column(columns...).
Exec(ctx)
return err
})
}
func (f *filterDB) DeleteFilterKeywordByID(ctx context.Context, id string) error {
if _, err := f.db.
NewDelete().
Model((*gtsmodel.FilterKeyword)(nil)).
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return err
}
f.state.Caches.GTS.FilterKeyword.Invalidate("ID", id)
return nil
}

View file

@ -0,0 +1,143 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb_test
import (
"context"
"errors"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
// TestFilterKeywordCRUD tests CRUD and read-all operations on filter keywords.
func (suite *FilterTestSuite) TestFilterKeywordCRUD() {
t := suite.T()
// Create new filter.
filter := &gtsmodel.Filter{
ID: "01HNEJNVZZVXJTRB3FX3K2B1YF",
AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47",
Title: "foss jail",
Action: gtsmodel.FilterActionWarn,
ContextHome: util.Ptr(true),
ContextPublic: util.Ptr(true),
}
// Create new cancellable test context.
ctx := context.Background()
ctx, cncl := context.WithCancel(ctx)
defer cncl()
// Insert the new filter into the DB.
err := suite.db.PutFilter(ctx, filter)
if err != nil {
t.Fatalf("error inserting filter: %v", err)
}
// There should be no filter keywords yet.
all, err := suite.db.GetFilterKeywordsForAccountID(ctx, filter.AccountID)
if err != nil {
t.Fatalf("error fetching filter keywords: %v", err)
}
suite.Empty(all)
// Add a filter keyword to it.
filterKeyword := &gtsmodel.FilterKeyword{
ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4",
AccountID: filter.AccountID,
FilterID: filter.ID,
Keyword: "GNU/Linux",
}
// Insert the new filter keyword into the DB.
err = suite.db.PutFilterKeyword(ctx, filterKeyword)
if err != nil {
t.Fatalf("error inserting filter keyword: %v", err)
}
// Try to find it again and ensure it has the fields we expect.
check, err := suite.db.GetFilterKeywordByID(ctx, filterKeyword.ID)
if err != nil {
t.Fatalf("error fetching filter keyword: %v", err)
}
suite.Equal(filterKeyword.ID, check.ID)
suite.NotZero(check.CreatedAt)
suite.NotZero(check.UpdatedAt)
suite.Equal(filterKeyword.AccountID, check.AccountID)
suite.Equal(filterKeyword.FilterID, check.FilterID)
suite.Equal(filterKeyword.Keyword, check.Keyword)
suite.Equal(filterKeyword.WholeWord, check.WholeWord)
// Loading filter keywords by account ID should find the one we inserted.
all, err = suite.db.GetFilterKeywordsForAccountID(ctx, filter.AccountID)
if err != nil {
t.Fatalf("error fetching filter keywords: %v", err)
}
suite.Len(all, 1)
suite.Equal(filterKeyword.ID, all[0].ID)
// Loading filter keywords by filter ID should also find the one we inserted.
all, err = suite.db.GetFilterKeywordsForFilterID(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching filter keywords: %v", err)
}
suite.Len(all, 1)
suite.Equal(filterKeyword.ID, all[0].ID)
// Modify the filter keyword.
filterKeyword.WholeWord = util.Ptr(true)
err = suite.db.UpdateFilterKeyword(ctx, filterKeyword)
if err != nil {
t.Fatalf("error updating filter keyword: %v", err)
}
// Try to find it again and ensure it has the updated fields we expect.
check, err = suite.db.GetFilterKeywordByID(ctx, filterKeyword.ID)
if err != nil {
t.Fatalf("error fetching filter keyword: %v", err)
}
suite.Equal(filterKeyword.ID, check.ID)
suite.NotZero(check.CreatedAt)
suite.True(check.UpdatedAt.After(check.CreatedAt))
suite.Equal(filterKeyword.AccountID, check.AccountID)
suite.Equal(filterKeyword.FilterID, check.FilterID)
suite.Equal(filterKeyword.Keyword, check.Keyword)
suite.Equal(filterKeyword.WholeWord, check.WholeWord)
// Delete the filter keyword from the DB.
err = suite.db.DeleteFilterKeywordByID(ctx, filter.ID)
if err != nil {
t.Fatalf("error deleting filter keyword: %v", err)
}
// Ensure we can't refetch it.
check, err = suite.db.GetFilterKeywordByID(ctx, filter.ID)
if !errors.Is(err, db.ErrNoEntries) {
t.Fatalf("fetching deleted filter keyword returned unexpected error: %v", err)
}
suite.Nil(check)
// Ensure the filter itself is still there.
checkFilter, err := suite.db.GetFilterByID(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching filter: %v", err)
}
suite.Equal(filter.ID, checkFilter.ID)
}

View file

@ -0,0 +1,191 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"slices"
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
func (f *filterDB) GetFilterStatusByID(ctx context.Context, id string) (*gtsmodel.FilterStatus, error) {
filterStatus, err := f.state.Caches.GTS.FilterStatus.LoadOne(
"ID",
func() (*gtsmodel.FilterStatus, error) {
var filterStatus gtsmodel.FilterStatus
err := f.db.
NewSelect().
Model(&filterStatus).
Where("? = ?", bun.Ident("id"), id).
Scan(ctx)
return &filterStatus, err
},
id,
)
if err != nil {
return nil, err
}
if !gtscontext.Barebones(ctx) {
err = f.populateFilterStatus(ctx, filterStatus)
if err != nil {
return nil, err
}
}
return filterStatus, nil
}
func (f *filterDB) populateFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error {
if filterStatus.Filter == nil {
// Filter is not set, fetch from the cache or database.
filter, err := f.state.DB.GetFilterByID(
// Don't populate the filter with all of its keywords and statuses or we'll just end up back here.
gtscontext.SetBarebones(ctx),
filterStatus.FilterID,
)
if err != nil {
return err
}
filterStatus.Filter = filter
}
return nil
}
func (f *filterDB) GetFilterStatusesForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterStatus, error) {
return f.getFilterStatuses(ctx, "filter_id", filterID)
}
func (f *filterDB) GetFilterStatusesForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterStatus, error) {
return f.getFilterStatuses(ctx, "account_id", accountID)
}
func (f *filterDB) getFilterStatuses(ctx context.Context, idColumn string, id string) ([]*gtsmodel.FilterStatus, error) {
var filterStatusIDs []string
if err := f.db.
NewSelect().
Model((*gtsmodel.FilterStatus)(nil)).
Column("id").
Where("? = ?", bun.Ident(idColumn), id).
Scan(ctx, &filterStatusIDs); err != nil {
return nil, err
}
if len(filterStatusIDs) == 0 {
return nil, nil
}
// Get each filter status by ID from the cache or DB.
uncachedFilterStatusIDs := make([]string, 0, len(filterStatusIDs))
filterStatuses, err := f.state.Caches.GTS.FilterStatus.Load(
"ID",
func(load func(keyParts ...any) bool) {
for _, id := range filterStatusIDs {
if !load(id) {
uncachedFilterStatusIDs = append(uncachedFilterStatusIDs, id)
}
}
},
func() ([]*gtsmodel.FilterStatus, error) {
uncachedFilterStatuses := make([]*gtsmodel.FilterStatus, 0, len(uncachedFilterStatusIDs))
if err := f.db.
NewSelect().
Model(&uncachedFilterStatuses).
Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterStatusIDs)).
Scan(ctx); err != nil {
return nil, err
}
return uncachedFilterStatuses, nil
},
)
if err != nil {
return nil, err
}
// Put the filter status structs in the same order as the filter status IDs.
util.OrderBy(filterStatuses, filterStatusIDs, func(filterStatus *gtsmodel.FilterStatus) string {
return filterStatus.ID
})
if gtscontext.Barebones(ctx) {
return filterStatuses, nil
}
// Populate the filter statuses. Remove any that we can't populate from the return slice.
errs := gtserror.NewMultiError(len(filterStatuses))
filterStatuses = slices.DeleteFunc(filterStatuses, func(filterStatus *gtsmodel.FilterStatus) bool {
if err := f.populateFilterStatus(ctx, filterStatus); err != nil {
errs.Appendf(
"error populating filter status %s: %w",
filterStatus.ID,
err,
)
return true
}
return false
})
return filterStatuses, errs.Combine()
}
func (f *filterDB) PutFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error {
return f.state.Caches.GTS.FilterStatus.Store(filterStatus, func() error {
_, err := f.db.
NewInsert().
Model(filterStatus).
Exec(ctx)
return err
})
}
func (f *filterDB) UpdateFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus, columns ...string) error {
filterStatus.UpdatedAt = time.Now()
if len(columns) > 0 {
columns = append(columns, "updated_at")
}
return f.state.Caches.GTS.FilterStatus.Store(filterStatus, func() error {
_, err := f.db.
NewUpdate().
Model(filterStatus).
Where("? = ?", bun.Ident("id"), filterStatus.ID).
Column(columns...).
Exec(ctx)
return err
})
}
func (f *filterDB) DeleteFilterStatusByID(ctx context.Context, id string) error {
if _, err := f.db.
NewDelete().
Model((*gtsmodel.FilterStatus)(nil)).
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return err
}
f.state.Caches.GTS.FilterStatus.Invalidate("ID", id)
return nil
}

View file

@ -0,0 +1,122 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb_test
import (
"context"
"errors"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
// TestFilterStatusCRD tests CRD (no U) and read-all operations on filter statuses.
func (suite *FilterTestSuite) TestFilterStatusCRD() {
t := suite.T()
// Create new filter.
filter := &gtsmodel.Filter{
ID: "01HNEJNVZZVXJTRB3FX3K2B1YF",
AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47",
Title: "foss jail",
Action: gtsmodel.FilterActionWarn,
ContextHome: util.Ptr(true),
ContextPublic: util.Ptr(true),
}
// Create new cancellable test context.
ctx := context.Background()
ctx, cncl := context.WithCancel(ctx)
defer cncl()
// Insert the new filter into the DB.
err := suite.db.PutFilter(ctx, filter)
if err != nil {
t.Fatalf("error inserting filter: %v", err)
}
// There should be no filter statuses yet.
all, err := suite.db.GetFilterStatusesForAccountID(ctx, filter.AccountID)
if err != nil {
t.Fatalf("error fetching filter statuses: %v", err)
}
suite.Empty(all)
// Add a filter status to it.
filterStatus := &gtsmodel.FilterStatus{
ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4",
AccountID: filter.AccountID,
FilterID: filter.ID,
StatusID: "01HQXGMQ3QFXRT4GX9WNQ8KC0X",
}
// Insert the new filter status into the DB.
err = suite.db.PutFilterStatus(ctx, filterStatus)
if err != nil {
t.Fatalf("error inserting filter status: %v", err)
}
// Try to find it again and ensure it has the fields we expect.
check, err := suite.db.GetFilterStatusByID(ctx, filterStatus.ID)
if err != nil {
t.Fatalf("error fetching filter status: %v", err)
}
suite.Equal(filterStatus.ID, check.ID)
suite.NotZero(check.CreatedAt)
suite.NotZero(check.UpdatedAt)
suite.Equal(filterStatus.AccountID, check.AccountID)
suite.Equal(filterStatus.FilterID, check.FilterID)
suite.Equal(filterStatus.StatusID, check.StatusID)
// Loading filter statuses by account ID should find the one we inserted.
all, err = suite.db.GetFilterStatusesForAccountID(ctx, filter.AccountID)
if err != nil {
t.Fatalf("error fetching filter statuses: %v", err)
}
suite.Len(all, 1)
suite.Equal(filterStatus.ID, all[0].ID)
// Loading filter statuses by filter ID should also find the one we inserted.
all, err = suite.db.GetFilterStatusesForFilterID(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching filter statuses: %v", err)
}
suite.Len(all, 1)
suite.Equal(filterStatus.ID, all[0].ID)
// Delete the filter status from the DB.
err = suite.db.DeleteFilterStatusByID(ctx, filter.ID)
if err != nil {
t.Fatalf("error deleting filter status: %v", err)
}
// Ensure we can't refetch it.
check, err = suite.db.GetFilterStatusByID(ctx, filter.ID)
if !errors.Is(err, db.ErrNoEntries) {
t.Fatalf("fetching deleted filter status returned unexpected error: %v", err)
}
suite.Nil(check)
// Ensure the filter itself is still there.
checkFilter, err := suite.db.GetFilterByID(ctx, filter.ID)
if err != nil {
t.Fatalf("error fetching filter: %v", err)
}
suite.Equal(filter.ID, checkFilter.ID)
}

View file

@ -0,0 +1,97 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package migrations
import (
"context"
gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
func init() {
up := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Filter table.
if _, err := tx.
NewCreateTable().
Model(&gtsmodel.Filter{}).
IfNotExists().
Exec(ctx); err != nil {
return err
}
// Filter keyword table.
if _, err := tx.
NewCreateTable().
Model(&gtsmodel.FilterKeyword{}).
IfNotExists().
Exec(ctx); err != nil {
return err
}
// Filter status table.
if _, err := tx.
NewCreateTable().
Model(&gtsmodel.FilterStatus{}).
IfNotExists().
Exec(ctx); err != nil {
return err
}
// Add indexes to the filter tables.
for table, indexes := range map[string]map[string][]string{
"filters": {
"filters_account_id_idx": {"account_id"},
},
"filter_keywords": {
"filter_keywords_account_id_idx": {"account_id"},
"filter_keywords_filter_id_idx": {"filter_id"},
},
"filter_statuses": {
"filter_statuses_account_id_idx": {"account_id"},
"filter_statuses_filter_id_idx": {"filter_id"},
},
} {
for index, columns := range indexes {
if _, err := tx.
NewCreateIndex().
Table(table).
Index(index).
Column(columns...).
IfNotExists().
Exec(ctx); err != nil {
return err
}
}
}
return nil
})
}
down := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return nil
})
}
if err := Migrations.Register(up, down); err != nil {
panic(err)
}
}

230
internal/db/bundb/upsert.go Normal file
View file

@ -0,0 +1,230 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"database/sql"
"reflect"
"strings"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect"
)
// UpsertQuery is a wrapper around an insert query that can update if an insert fails.
// Doesn't implement the full set of Bun query methods, but we can add more if we need them.
// See https://bun.uptrace.dev/guide/query-insert.html#upsert
type UpsertQuery struct {
db bun.IDB
model interface{}
constraints []string
columns []string
}
func NewUpsert(idb bun.IDB) *UpsertQuery {
// note: passing in rawtx as conn iface so no double query-hook
// firing when passed through the bun.Tx.Query___() functions.
return &UpsertQuery{db: idb}
}
// Model sets the model or models to upsert.
func (u *UpsertQuery) Model(model interface{}) *UpsertQuery {
u.model = model
return u
}
// Constraint sets the columns or indexes that are used to check for conflicts.
// This is required.
func (u *UpsertQuery) Constraint(constraints ...string) *UpsertQuery {
u.constraints = constraints
return u
}
// Column sets the columns to update if an insert does't happen.
// If empty, all columns not being used for constraints will be updated.
// Cannot overlap with Constraint.
func (u *UpsertQuery) Column(columns ...string) *UpsertQuery {
u.columns = columns
return u
}
// insertDialect errors if we're using a dialect in which we don't know how to upsert.
func (u *UpsertQuery) insertDialect() error {
dialectName := u.db.Dialect().Name()
switch dialectName {
case dialect.PG, dialect.SQLite:
return nil
default:
// FUTURE: MySQL has its own variation on upserts, but the syntax is different.
return gtserror.Newf("UpsertQuery: upsert not supported by SQL dialect: %s", dialectName)
}
}
// insertConstraints checks that we have constraints and returns them.
func (u *UpsertQuery) insertConstraints() ([]string, error) {
if len(u.constraints) == 0 {
return nil, gtserror.New("UpsertQuery: upserts require at least one constraint column or index, none provided")
}
return u.constraints, nil
}
// insertColumns returns the non-constraint columns we'll be updating.
func (u *UpsertQuery) insertColumns(constraints []string) ([]string, error) {
// Constraints as a set.
constraintSet := make(map[string]struct{}, len(constraints))
for _, constraint := range constraints {
constraintSet[constraint] = struct{}{}
}
var columns []string
var err error
if len(u.columns) == 0 {
columns, err = u.insertColumnsDefault(constraintSet)
} else {
columns, err = u.insertColumnsSpecified(constraintSet)
}
if err != nil {
return nil, err
}
if len(columns) == 0 {
return nil, gtserror.New("UpsertQuery: there are no columns to update when upserting")
}
return columns, nil
}
// hasElem returns whether the type has an element and can call [reflect.Type.Elem] without panicking.
func hasElem(modelType reflect.Type) bool {
switch modelType.Kind() {
case reflect.Array, reflect.Chan, reflect.Map, reflect.Pointer, reflect.Slice:
return true
default:
return false
}
}
// insertColumnsDefault returns all non-constraint columns from the model schema.
func (u *UpsertQuery) insertColumnsDefault(constraintSet map[string]struct{}) ([]string, error) {
// Get underlying struct type.
modelType := reflect.TypeOf(u.model)
for hasElem(modelType) {
modelType = modelType.Elem()
}
table := u.db.Dialect().Tables().Get(modelType)
if table == nil {
return nil, gtserror.Newf("UpsertQuery: couldn't find the table schema for model: %v", u.model)
}
columns := make([]string, 0, len(u.columns))
for _, field := range table.Fields {
column := field.Name
if _, overlaps := constraintSet[column]; !overlaps {
columns = append(columns, column)
}
}
return columns, nil
}
// insertColumnsSpecified ensures constraints and specified columns to update don't overlap.
func (u *UpsertQuery) insertColumnsSpecified(constraintSet map[string]struct{}) ([]string, error) {
overlapping := make([]string, 0, min(len(u.constraints), len(u.columns)))
for _, column := range u.columns {
if _, overlaps := constraintSet[column]; overlaps {
overlapping = append(overlapping, column)
}
}
if len(overlapping) > 0 {
return nil, gtserror.Newf(
"UpsertQuery: the following columns can't be used for both constraints and columns to update: %s",
strings.Join(overlapping, ", "),
)
}
return u.columns, nil
}
// insert tries to create a Bun insert query from an upsert query.
func (u *UpsertQuery) insertQuery() (*bun.InsertQuery, error) {
var err error
err = u.insertDialect()
if err != nil {
return nil, err
}
constraints, err := u.insertConstraints()
if err != nil {
return nil, err
}
columns, err := u.insertColumns(constraints)
if err != nil {
return nil, err
}
// Build the parts of the query that need us to generate SQL.
constraintIDPlaceholders := make([]string, 0, len(constraints))
constraintIDs := make([]interface{}, 0, len(constraints))
for _, constraint := range constraints {
constraintIDPlaceholders = append(constraintIDPlaceholders, "?")
constraintIDs = append(constraintIDs, bun.Ident(constraint))
}
onSQL := "conflict (" + strings.Join(constraintIDPlaceholders, ", ") + ") do update"
setClauses := make([]string, 0, len(columns))
setIDs := make([]interface{}, 0, 2*len(columns))
for _, column := range columns {
// "excluded" is a special table that contains only the row involved in a conflict.
setClauses = append(setClauses, "? = excluded.?")
setIDs = append(setIDs, bun.Ident(column), bun.Ident(column))
}
setSQL := strings.Join(setClauses, ", ")
insertQuery := u.db.
NewInsert().
Model(u.model).
On(onSQL, constraintIDs...).
Set(setSQL, setIDs...)
return insertQuery, nil
}
// Exec builds a Bun insert query from the upsert query, and executes it.
func (u *UpsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
insertQuery, err := u.insertQuery()
if err != nil {
return nil, err
}
return insertQuery.Exec(ctx, dest...)
}
// Scan builds a Bun insert query from the upsert query, and scans it.
func (u *UpsertQuery) Scan(ctx context.Context, dest ...interface{}) error {
insertQuery, err := u.insertQuery()
if err != nil {
return err
}
return insertQuery.Scan(ctx, dest...)
}

View file

@ -32,6 +32,7 @@ type DB interface {
Emoji
HeaderFilter
Instance
Filter
List
Marker
Media

101
internal/db/filter.go Normal file
View file

@ -0,0 +1,101 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package db
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Filter contains methods for creating, reading, updating, and deleting filters and their keyword and status entries.
type Filter interface {
//<editor-fold desc="Filter methods">
// GetFilterByID gets one filter with the given id.
GetFilterByID(ctx context.Context, id string) (*gtsmodel.Filter, error)
// GetFiltersForAccountID gets all filters owned by the given accountID.
GetFiltersForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.Filter, error)
// PutFilter puts a new filter in the database, adding any attached keywords or statuses.
// It uses a transaction to ensure no partial updates.
PutFilter(ctx context.Context, filter *gtsmodel.Filter) error
// UpdateFilter updates the given filter,
// upserts any attached keywords and inserts any new statuses (existing statuses cannot be updated),
// and deletes indicated filter keywords and statuses by ID.
// It uses a transaction to ensure no partial updates.
// The column lists are optional; if not specified, all columns will be updated.
UpdateFilter(
ctx context.Context,
filter *gtsmodel.Filter,
filterColumns []string,
filterKeywordColumns []string,
deleteFilterKeywordIDs []string,
deleteFilterStatusIDs []string,
) error
// DeleteFilterByID deletes one filter with the given ID.
// It uses a transaction to ensure no partial updates.
DeleteFilterByID(ctx context.Context, id string) error
//</editor-fold>
//<editor-fold desc="Filter keyword methods">
// GetFilterKeywordByID gets one filter keyword with the given ID.
GetFilterKeywordByID(ctx context.Context, id string) (*gtsmodel.FilterKeyword, error)
// GetFilterKeywordsForFilterID gets filter keywords from the given filterID.
GetFilterKeywordsForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterKeyword, error)
// GetFilterKeywordsForAccountID gets filter keywords from the given accountID.
GetFilterKeywordsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterKeyword, error)
// PutFilterKeyword inserts a single filter keyword into the database.
PutFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error
// UpdateFilterKeyword updates the given filter keyword.
// Columns is optional, if not specified all will be updated.
UpdateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword, columns ...string) error
// DeleteFilterKeywordByID deletes one filter keyword with the given id.
DeleteFilterKeywordByID(ctx context.Context, id string) error
//</editor-fold>
//<editor-fold desc="Filter status methods">
// GetFilterStatusByID gets one filter status with the given ID.
GetFilterStatusByID(ctx context.Context, id string) (*gtsmodel.FilterStatus, error)
// GetFilterStatusesForFilterID gets filter statuses from the given filterID.
GetFilterStatusesForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterStatus, error)
// GetFilterStatusesForAccountID gets filter keywords from the given accountID.
GetFilterStatusesForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterStatus, error)
// PutFilterStatus inserts a single filter status into the database.
PutFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error
// DeleteFilterStatusByID deletes one filter status with the given id.
DeleteFilterStatusByID(ctx context.Context, id string) error
//</editor-fold>
}