[feature] add support for polls + receiving federated status edits (#2330)

This commit is contained in:
kim 2023-11-08 14:32:17 +00:00 committed by GitHub
commit e9e5dc5a40
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
84 changed files with 3992 additions and 570 deletions

View file

@ -42,7 +42,7 @@ type AccountTestSuite struct {
func (suite *AccountTestSuite) TestGetAccountStatuses() {
statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, false, false, "", "", false, false)
suite.NoError(err)
suite.Len(statuses, 5)
suite.Len(statuses, 6)
}
func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() {
@ -65,7 +65,7 @@ func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() {
if err != nil {
suite.FailNow(err.Error())
}
suite.Len(statuses, 1)
suite.Len(statuses, 2)
// try to get the last page (should be empty)
statuses, err = suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, statuses[len(statuses)-1].ID, "", false, false)
@ -76,7 +76,7 @@ func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() {
func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogs() {
statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false)
suite.NoError(err)
suite.Len(statuses, 5)
suite.Len(statuses, 6)
}
func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogsPublicOnly() {
@ -306,7 +306,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {
func (suite *AccountTestSuite) TestGetAccountLastPosted() {
lastPosted, err := suite.db.GetAccountLastPosted(context.Background(), suite.testAccounts["local_account_1"].ID, false)
suite.NoError(err)
suite.EqualValues(1653046675, lastPosted.Unix())
suite.EqualValues(1653046870, lastPosted.Unix())
}
func (suite *AccountTestSuite) TestGetAccountLastPostedWebOnly() {

View file

@ -121,7 +121,7 @@ func (suite *BasicTestSuite) TestGetAllStatuses() {
s := []*gtsmodel.Status{}
err := suite.db.GetAll(context.Background(), &s)
suite.NoError(err)
suite.Len(s, 17)
suite.Len(s, 20)
}
func (suite *BasicTestSuite) TestGetAllNotNull() {

View file

@ -71,6 +71,7 @@ type DBService struct {
db.Media
db.Mention
db.Notification
db.Poll
db.Relationship
db.Report
db.Rule
@ -203,6 +204,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db,
state: state,
},
Poll: &pollDB{
db: db,
state: state,
},
Relationship: &relationshipDB{
db: db,
state: state,

View file

@ -54,6 +54,8 @@ type BunDBStandardTestSuite struct {
testMarkers map[string]*gtsmodel.Marker
testRules map[string]*gtsmodel.Rule
testThreads map[string]*gtsmodel.Thread
testPolls map[string]*gtsmodel.Poll
testPollVotes map[string]*gtsmodel.PollVote
}
func (suite *BunDBStandardTestSuite) SetupSuite() {
@ -77,6 +79,8 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {
suite.testMarkers = testrig.NewTestMarkers()
suite.testRules = testrig.NewTestRules()
suite.testThreads = testrig.NewTestThreads()
suite.testPolls = testrig.NewTestPolls()
suite.testPollVotes = testrig.NewTestPollVotes()
}
func (suite *BunDBStandardTestSuite) SetupTest() {

View file

@ -47,13 +47,13 @@ func (suite *InstanceTestSuite) TestCountInstanceUsersRemote() {
func (suite *InstanceTestSuite) TestCountInstanceStatuses() {
count, err := suite.db.CountInstanceStatuses(context.Background(), config.GetHost())
suite.NoError(err)
suite.Equal(16, count)
suite.Equal(18, count)
}
func (suite *InstanceTestSuite) TestCountInstanceStatusesRemote() {
count, err := suite.db.CountInstanceStatuses(context.Background(), "fossbros-anonymous.io")
suite.NoError(err)
suite.Equal(1, count)
suite.Equal(2, count)
}
func (suite *InstanceTestSuite) TestCountInstanceDomains() {

View file

@ -20,10 +20,10 @@ package bundb
import (
"context"
"errors"
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@ -54,31 +54,9 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio
return nil, err
}
// Set the mention originating status.
mention.Status, err = m.state.DB.GetStatusByID(
gtscontext.SetBarebones(ctx),
mention.StatusID,
)
if err != nil {
return nil, fmt.Errorf("error populating mention status: %w", err)
}
// Set the mention origin account model.
mention.OriginAccount, err = m.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
mention.OriginAccountID,
)
if err != nil {
return nil, fmt.Errorf("error populating mention origin account: %w", err)
}
// Set the mention target account model.
mention.TargetAccount, err = m.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
mention.TargetAccountID,
)
if err != nil {
return nil, fmt.Errorf("error populating mention target account: %w", err)
// Further populate the mention fields where applicable.
if err := m.PopulateMention(ctx, mention); err != nil {
return nil, err
}
return mention, nil
@ -102,6 +80,45 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.
return mentions, nil
}
func (m *mentionDB) PopulateMention(ctx context.Context, mention *gtsmodel.Mention) (err error) {
var errs gtserror.MultiError
if mention.Status == nil {
// Set the mention originating status.
mention.Status, err = m.state.DB.GetStatusByID(
gtscontext.SetBarebones(ctx),
mention.StatusID,
)
if err != nil {
return gtserror.Newf("error populating mention status: %w", err)
}
}
if mention.OriginAccount == nil {
// Set the mention origin account model.
mention.OriginAccount, err = m.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
mention.OriginAccountID,
)
if err != nil {
return gtserror.Newf("error populating mention origin account: %w", err)
}
}
if mention.TargetAccount == nil {
// Set the mention target account model.
mention.TargetAccount, err = m.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
mention.TargetAccountID,
)
if err != nil {
return gtserror.Newf("error populating mention target account: %w", err)
}
}
return errs.Combine()
}
func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error {
return m.state.Caches.GTS.Mention().Store(mention, func() error {
_, err := m.db.NewInsert().Model(mention).Exec(ctx)

View file

@ -0,0 +1,65 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package migrations
import (
"context"
"strings"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
func init() {
up := func(ctx context.Context, db *bun.DB) error {
// Create `polls` + `poll_votes` tables.
for _, model := range []any{
&gtsmodel.Poll{},
&gtsmodel.PollVote{},
} {
_, err := db.NewCreateTable().
IfNotExists().
Model(model).
Exec(ctx)
if err != nil && !(strings.Contains(err.Error(), "already exists") || strings.Contains(err.Error(), "duplicate column name") || strings.Contains(err.Error(), "SQLSTATE 42701")) {
return err
}
}
// Add the new status `poll_id` column.
_, err := db.NewAddColumn().
Model(&gtsmodel.Status{}).
ColumnExpr("? CHAR(26)", bun.Ident("poll_id")).
Exec(ctx)
if err != nil && !(strings.Contains(err.Error(), "already exists") || strings.Contains(err.Error(), "duplicate column name") || strings.Contains(err.Error(), "SQLSTATE 42701")) {
return err
}
return nil
}
down := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return nil
})
}
if err := Migrations.Register(up, down); err != nil {
panic(err)
}
}

536
internal/db/bundb/poll.go Normal file
View file

@ -0,0 +1,536 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"errors"
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/uptrace/bun"
)
type pollDB struct {
db *DB
state *state.State
}
func (p *pollDB) GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, error) {
return p.getPoll(
ctx,
"ID",
func(poll *gtsmodel.Poll) error {
return p.db.NewSelect().
Model(poll).
Where("? = ?", bun.Ident("poll.id"), id).
Scan(ctx)
},
id,
)
}
func (p *pollDB) GetPollByStatusID(ctx context.Context, statusID string) (*gtsmodel.Poll, error) {
return p.getPoll(
ctx,
"StatusID",
func(poll *gtsmodel.Poll) error {
return p.db.NewSelect().
Model(poll).
Where("? = ?", bun.Ident("poll.status_id"), statusID).
Scan(ctx)
},
statusID,
)
}
func (p *pollDB) getPoll(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Poll) error, keyParts ...any) (*gtsmodel.Poll, error) {
// Fetch poll from database cache with loader callback
poll, err := p.state.Caches.GTS.Poll().Load(lookup, func() (*gtsmodel.Poll, error) {
var poll gtsmodel.Poll
// Not cached! Perform database query.
if err := dbQuery(&poll); err != nil {
return nil, err
}
// Ensure vote slice
// is non nil and set.
poll.CheckVotes()
return &poll, nil
}, keyParts...)
if err != nil {
return nil, err
}
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return poll, nil
}
// Further populate the poll fields where applicable.
if err := p.PopulatePoll(ctx, poll); err != nil {
return nil, err
}
return poll, nil
}
func (p *pollDB) GetOpenPolls(ctx context.Context) ([]*gtsmodel.Poll, error) {
var pollIDs []string
// Select all polls with unset `closed_at` time.
if err := p.db.NewSelect().
Table("polls").
Column("polls.id").
Join("JOIN ? ON ? = ?", bun.Ident("statuses"), bun.Ident("polls.id"), bun.Ident("statuses.poll_id")).
Where("? = true", bun.Ident("statuses.local")).
Where("? IS NULL", bun.Ident("polls.closed_at")).
Scan(ctx, &pollIDs); err != nil {
return nil, err
}
// Preallocate a slice to contain the poll models.
polls := make([]*gtsmodel.Poll, 0, len(pollIDs))
for _, id := range pollIDs {
// Attempt to fetch poll from DB.
poll, err := p.GetPollByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting poll %s: %v", id, err)
continue
}
// Append poll to return slice.
polls = append(polls, poll)
}
return polls, nil
}
func (p *pollDB) PopulatePoll(ctx context.Context, poll *gtsmodel.Poll) error {
var (
err error
errs gtserror.MultiError
)
if poll.Status == nil {
// Vote account is not set, fetch from database.
poll.Status, err = p.state.DB.GetStatusByID(
gtscontext.SetBarebones(ctx),
poll.StatusID,
)
if err != nil {
errs.Appendf("error populating poll status: %w", err)
}
}
return errs.Combine()
}
func (p *pollDB) PutPoll(ctx context.Context, poll *gtsmodel.Poll) error {
// Ensure vote slice
// is non nil and set.
poll.CheckVotes()
return p.state.Caches.GTS.Poll().Store(poll, func() error {
_, err := p.db.NewInsert().Model(poll).Exec(ctx)
return err
})
}
func (p *pollDB) UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...string) error {
// Ensure vote slice
// is non nil and set.
poll.CheckVotes()
return p.state.Caches.GTS.Poll().Store(poll, func() error {
return p.db.RunInTx(ctx, func(tx Tx) error {
// Update the status' "updated_at" field.
if _, err := tx.NewUpdate().
Table("statuses").
Where("? = ?", bun.Ident("id"), poll.StatusID).
SetColumn("updated_at", "?", time.Now()).
Exec(ctx); err != nil {
return err
}
// Finally, update poll
// columns in database.
_, err := tx.NewUpdate().
Model(poll).
Column(cols...).
Where("? = ?", bun.Ident("id"), poll.ID).
Exec(ctx)
return err
})
})
}
func (p *pollDB) DeletePollByID(ctx context.Context, id string) error {
// Delete poll by ID from database.
if _, err := p.db.NewDelete().
Table("polls").
Where("? = ?", bun.Ident("id"), id).
Exec(ctx); err != nil {
return err
}
// Invalidate poll by ID from cache.
p.state.Caches.GTS.Poll().Invalidate("ID", id)
p.state.Caches.GTS.PollVoteIDs().Invalidate(id)
return nil
}
func (p *pollDB) GetPollVoteByID(ctx context.Context, id string) (*gtsmodel.PollVote, error) {
return p.getPollVote(
ctx,
"ID",
func(vote *gtsmodel.PollVote) error {
return p.db.NewSelect().
Model(vote).
Where("? = ?", bun.Ident("poll_vote.id"), id).
Scan(ctx)
},
id,
)
}
func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID string) (*gtsmodel.PollVote, error) {
return p.getPollVote(
ctx,
"PollID.AccountID",
func(vote *gtsmodel.PollVote) error {
return p.db.NewSelect().
Model(vote).
Where("? = ?", bun.Ident("poll_vote.account_id"), accountID).
Where("? = ?", bun.Ident("poll_vote.poll_id"), pollID).
Scan(ctx)
},
pollID,
accountID,
)
}
func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.PollVote) error, keyParts ...any) (*gtsmodel.PollVote, error) {
// Fetch vote from database cache with loader callback
vote, err := p.state.Caches.GTS.PollVote().Load(lookup, func() (*gtsmodel.PollVote, error) {
var vote gtsmodel.PollVote
// Not cached! Perform database query.
if err := dbQuery(&vote); err != nil {
return nil, err
}
return &vote, nil
}, keyParts...)
if err != nil {
return nil, err
}
if gtscontext.Barebones(ctx) {
// no need to fully populate.
return vote, nil
}
// Further populate the vote fields where applicable.
if err := p.PopulatePollVote(ctx, vote); err != nil {
return nil, err
}
return vote, nil
}
func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.PollVote, error) {
voteIDs, err := p.state.Caches.GTS.PollVoteIDs().Load(pollID, func() ([]string, error) {
var voteIDs []string
// Vote IDs not in cache, perform DB query!
q := newSelectPollVotes(p.db, pollID)
if _, err := q.Exec(ctx, &voteIDs); // nocollapse
err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
}
return voteIDs, nil
})
if err != nil {
return nil, err
}
// Preallocate slice of expected length.
votes := make([]*gtsmodel.PollVote, 0, len(voteIDs))
for _, id := range voteIDs {
// Fetch poll vote model for this ID.
vote, err := p.GetPollVoteByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting poll vote %s: %v", id, err)
continue
}
// Append to return slice.
votes = append(votes, vote)
}
return votes, nil
}
func (p *pollDB) PopulatePollVote(ctx context.Context, vote *gtsmodel.PollVote) error {
var (
err error
errs gtserror.MultiError
)
if vote.Account == nil {
// Vote account is not set, fetch from database.
vote.Account, err = p.state.DB.GetAccountByID(
gtscontext.SetBarebones(ctx),
vote.AccountID,
)
if err != nil {
errs.Appendf("error populating vote account: %w", err)
}
}
if vote.Poll == nil {
// Vote poll is not set, fetch from database.
vote.Poll, err = p.GetPollByID(
gtscontext.SetBarebones(ctx),
vote.PollID,
)
if err != nil {
errs.Appendf("error populating vote poll: %w", err)
}
}
return errs.Combine()
}
func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error {
return p.state.Caches.GTS.PollVote().Store(vote, func() error {
return p.db.RunInTx(ctx, func(tx Tx) error {
// Try insert vote into database.
if _, err := tx.NewInsert().
Model(vote).
Exec(ctx); err != nil {
return err
}
var poll gtsmodel.Poll
// Select poll counts from DB.
if err := tx.NewSelect().
Model(&poll).
Where("? = ?", bun.Ident("id"), vote.PollID).
Scan(ctx); err != nil {
return err
}
// Increment poll votes for choices.
poll.IncrementVotes(vote.Choices)
// Finally, update the poll entry.
_, err := tx.NewUpdate().
Model(&poll).
Column("votes", "voters").
Where("? = ?", bun.Ident("id"), vote.PollID).
Exec(ctx)
return err
})
})
}
func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
err := p.db.RunInTx(ctx, func(tx Tx) error {
// Delete all vote in poll,
// returning all vote choices.
switch _, err := tx.NewDelete().
Table("poll_votes").
Where("? = ?", bun.Ident("poll_id"), pollID).
Exec(ctx); {
case err == nil:
// no issue.
case errors.Is(err, db.ErrNoEntries):
// no votes found,
// return here.
return nil
default:
// irrecoverable.
return err
}
var poll gtsmodel.Poll
// Select poll counts from DB.
switch err := tx.NewSelect().
Model(&poll).
Where("? = ?", bun.Ident("id"), pollID).
Scan(ctx); {
case err == nil:
// no issue.
case errors.Is(err, db.ErrNoEntries):
// no votes found,
// return here.
return nil
default:
// irrecoverable.
return err
}
// Zero all counts.
poll.ResetVotes()
// Finally, update the poll entry.
_, err := tx.NewUpdate().
Model(&poll).
Column("votes", "voters").
Where("? = ?", bun.Ident("id"), pollID).
Exec(ctx)
return err
})
if err != nil {
return err
}
// Invalidate poll vote and poll entry from caches.
p.state.Caches.GTS.Poll().Invalidate("ID", pollID)
p.state.Caches.GTS.PollVote().Invalidate("PollID", pollID)
p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID)
return nil
}
func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error {
err := p.db.RunInTx(ctx, func(tx Tx) error {
var choices []int
// Delete vote in poll by account,
// returning the ID + choices of the vote.
switch err := tx.NewDelete().
Table("poll_votes").
Where("? = ?", bun.Ident("poll_id"), pollID).
Where("? = ?", bun.Ident("account_id"), accountID).
Returning("choices").
Scan(ctx, &choices); {
case err == nil:
// no issue.
case errors.Is(err, db.ErrNoEntries):
// no votes found,
// return here.
return nil
default:
// irrecoverable.
return err
}
var poll gtsmodel.Poll
// Select poll counts from DB.
switch err := tx.NewSelect().
Model(&poll).
Where("? = ?", bun.Ident("id"), pollID).
Scan(ctx); {
case err == nil:
// no issue.
case errors.Is(err, db.ErrNoEntries):
// no votes found,
// return here.
return nil
default:
// irrecoverable.
return err
}
// Decrement votes for choices.
poll.IncrementVotes(choices)
// Finally, update the poll entry.
_, err := tx.NewUpdate().
Model(&poll).
Column("votes", "voters").
Where("? = ?", bun.Ident("id"), pollID).
Exec(ctx)
return err
})
if err != nil {
return err
}
// Invalidate poll vote and poll entry from caches.
p.state.Caches.GTS.Poll().Invalidate("ID", pollID)
p.state.Caches.GTS.PollVote().Invalidate("PollID.AccountID", pollID, accountID)
p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID)
return nil
}
func (p *pollDB) DeletePollVotesByAccountID(ctx context.Context, accountID string) error {
var pollIDs []string
// Select all polls this account
// has registered a poll vote in.
if err := p.db.NewSelect().
Table("poll_votes").
Column("poll_id").
Where("? = ?", bun.Ident("account_id"), accountID).
Scan(ctx, &pollIDs); err != nil &&
!errors.Is(err, db.ErrNoEntries) {
return err
}
for _, id := range pollIDs {
// Delete all votes by this account in each of the polls,
// this way ensures that all necessary caches are invalidated.
if err := p.DeletePollVoteBy(ctx, id, accountID); err != nil {
log.Errorf(ctx, "error deleting vote by %s in %s: %v", accountID, id, err)
}
}
return nil
}
// newSelectPollVotes returns a new select query for all rows in the poll_votes table with poll_id = pollID.
func newSelectPollVotes(db *DB, pollID string) *bun.SelectQuery {
return db.NewSelect().
TableExpr("?", bun.Ident("poll_votes")).
ColumnExpr("?", bun.Ident("id")).
Where("? = ?", bun.Ident("poll_id"), pollID).
OrderExpr("? DESC", bun.Ident("id"))
}

View file

@ -0,0 +1,318 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb_test
import (
"context"
"errors"
"math/rand"
"testing"
"time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
type PollTestSuite struct {
BunDBStandardTestSuite
}
func (suite *PollTestSuite) TestGetPollBy() {
t := suite.T()
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
// Sentinel error to mark avoiding a test case.
sentinelErr := errors.New("sentinel")
// isEqual checks if 2 poll models are equal.
isEqual := func(p1, p2 gtsmodel.Poll) bool {
// Clear populated sub-models.
p1.Status = nil
p2.Status = nil
// Localize all of the time fields.
p1.ExpiresAt = p1.ExpiresAt.Local()
p2.ExpiresAt = p2.ExpiresAt.Local()
p1.ClosedAt = p1.ClosedAt.Local()
p2.ClosedAt = p2.ClosedAt.Local()
// Perform the comparison.
return suite.Equal(p1, p2)
}
for _, poll := range suite.testPolls {
for lookup, dbfunc := range map[string]func() (*gtsmodel.Poll, error){
"id": func() (*gtsmodel.Poll, error) {
return suite.db.GetPollByID(ctx, poll.ID)
},
"status_id": func() (*gtsmodel.Poll, error) {
return suite.db.GetPollByStatusID(ctx, poll.StatusID)
},
} {
// Clear database caches.
suite.state.Caches.Init()
t.Logf("checking database lookup %q", lookup)
// Perform database function.
checkPoll, err := dbfunc()
if err != nil {
if err == sentinelErr {
continue
}
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
continue
}
// Check received account data.
if !isEqual(*checkPoll, *poll) {
t.Errorf("poll does not contain expected data: %+v", checkPoll)
continue
}
// Check that poll source status populated.
if poll.StatusID != (*checkPoll).Status.ID {
t.Errorf("poll source status not correctly populated for: %+v", poll)
continue
}
}
}
}
func (suite *PollTestSuite) TestGetPollVoteBy() {
t := suite.T()
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
// Sentinel error to mark avoiding a test case.
sentinelErr := errors.New("sentinel")
// isEqual checks if 2 poll vote models are equal.
isEqual := func(v1, v2 gtsmodel.PollVote) bool {
// Clear populated sub-models.
v1.Poll = nil
v2.Poll = nil
v1.Account = nil
v2.Account = nil
// Localize all of the time fields.
v1.CreatedAt = v1.CreatedAt.Local()
v2.CreatedAt = v2.CreatedAt.Local()
// Perform the comparison.
return suite.Equal(v1, v2)
}
for _, vote := range suite.testPollVotes {
for lookup, dbfunc := range map[string]func() (*gtsmodel.PollVote, error){
"id": func() (*gtsmodel.PollVote, error) {
return suite.db.GetPollVoteByID(ctx, vote.ID)
},
"poll_id_account_id": func() (*gtsmodel.PollVote, error) {
return suite.db.GetPollVoteBy(ctx, vote.PollID, vote.AccountID)
},
} {
// Clear database caches.
suite.state.Caches.Init()
t.Logf("checking database lookup %q", lookup)
// Perform database function.
checkVote, err := dbfunc()
if err != nil {
if err == sentinelErr {
continue
}
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
continue
}
// Check received account data.
if !isEqual(*checkVote, *vote) {
t.Errorf("poll vote does not contain expected data: %+v", checkVote)
continue
}
// Check that vote source poll populated.
if checkVote.PollID != (*checkVote).Poll.ID {
t.Errorf("vote source poll not correctly populated for: %+v", vote)
continue
}
// Check that vote author account populated.
if checkVote.AccountID != (*checkVote).Account.ID {
t.Errorf("vote author account not correctly populated for: %+v", vote)
continue
}
}
}
}
func (suite *PollTestSuite) TestUpdatePoll() {
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
for _, poll := range suite.testPolls {
// Take copy of poll.
poll := util.Ptr(*poll)
// Update the poll closed field.
poll.ClosedAt = time.Now()
// Update poll model in the database.
err := suite.db.UpdatePoll(ctx, poll)
suite.NoError(err)
// Refetch poll from database to get latest.
latest, err := suite.db.GetPollByID(ctx, poll.ID)
suite.NoError(err)
// The latest poll should have updated closedAt.
suite.Equal(poll.ClosedAt, latest.ClosedAt)
}
}
func (suite *PollTestSuite) TestPutPoll() {
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
for _, poll := range suite.testPolls {
// Delete this poll from the database.
err := suite.db.DeletePollByID(ctx, poll.ID)
suite.NoError(err)
// Ensure that afterwards we can
// enter it again into database.
err = suite.db.PutPoll(ctx, poll)
// Ensure that afterwards we can fetch poll.
_, err = suite.db.GetPollByID(ctx, poll.ID)
suite.NoError(err)
}
}
func (suite *PollTestSuite) TestPutPollVote() {
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
// randomChoices generates random vote choices in poll.
randomChoices := func(poll *gtsmodel.Poll) []int {
var max int
if *poll.Multiple {
max = len(poll.Options)
} else {
max = 1
}
count := 1 + rand.Intn(max)
choices := make([]int, count)
for i := range choices {
choices[i] = rand.Intn(len(poll.Options))
}
return choices
}
for _, poll := range suite.testPolls {
// Create a new vote to insert for poll.
vote := &gtsmodel.PollVote{
ID: id.NewULID(),
Choices: randomChoices(poll),
PollID: poll.ID,
AccountID: id.NewULID(), // random account, doesn't matter
}
// Insert this new vote into database.
err := suite.db.PutPollVote(ctx, vote)
suite.NoError(err)
// Fetch latest version of poll from database.
latest, err := suite.db.GetPollByID(ctx, poll.ID)
suite.NoError(err)
// Decr latest version choices by new vote's.
for _, choice := range vote.Choices {
latest.Votes[choice]--
}
(*latest.Voters)--
// Old poll and latest model after decr
// should have equal vote + voter counts.
suite.Equal(poll.Voters, latest.Voters)
suite.Equal(poll.Votes, latest.Votes)
}
}
func (suite *PollTestSuite) TestDeletePoll() {
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
for _, poll := range suite.testPolls {
// Delete this poll from the database.
err := suite.db.DeletePollByID(ctx, poll.ID)
suite.NoError(err)
// Ensure that afterwards we cannot fetch poll.
_, err = suite.db.GetPollByID(ctx, poll.ID)
suite.ErrorIs(err, db.ErrNoEntries)
// Or again by the status it's attached to.
_, err = suite.db.GetPollByStatusID(ctx, poll.StatusID)
suite.ErrorIs(err, db.ErrNoEntries)
}
}
func (suite *PollTestSuite) TestDeletePollVotes() {
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
for _, poll := range suite.testPolls {
// Delete votes associated with poll from database.
err := suite.db.DeletePollVotes(ctx, poll.ID)
suite.NoError(err)
// Fetch latest version of poll from database.
poll, err = suite.db.GetPollByID(ctx, poll.ID)
suite.NoError(err)
// Check that poll counts are all zero.
suite.Equal(*poll.Voters, 0)
suite.Equal(poll.Votes, make([]int, len(poll.Options)))
}
}
func TestPollTestSuite(t *testing.T) {
suite.Run(t, new(PollTestSuite))
}

View file

@ -199,7 +199,8 @@ func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID stri
// Follow IDs not in cache, perform DB query!
q := newSelectFollows(r.db, accountID)
if _, err := q.Exec(ctx, &followIDs); err != nil {
if _, err := q.Exec(ctx, &followIDs); // nocollapse
err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
}
@ -213,7 +214,8 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID
// Follow IDs not in cache, perform DB query!
q := newSelectLocalFollows(r.db, accountID)
if _, err := q.Exec(ctx, &followIDs); err != nil {
if _, err := q.Exec(ctx, &followIDs); // nocollapse
err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
}
@ -227,7 +229,8 @@ func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID st
// Follow IDs not in cache, perform DB query!
q := newSelectFollowers(r.db, accountID)
if _, err := q.Exec(ctx, &followIDs); err != nil {
if _, err := q.Exec(ctx, &followIDs); // nocollapse
err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
}
@ -241,7 +244,8 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account
// Follow IDs not in cache, perform DB query!
q := newSelectLocalFollowers(r.db, accountID)
if _, err := q.Exec(ctx, &followIDs); err != nil {
if _, err := q.Exec(ctx, &followIDs); // nocollapse
err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
}
@ -255,7 +259,8 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account
// Follow request IDs not in cache, perform DB query!
q := newSelectFollowRequests(r.db, accountID)
if _, err := q.Exec(ctx, &followReqIDs); err != nil {
if _, err := q.Exec(ctx, &followReqIDs); // nocollapse
err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
}
@ -269,7 +274,8 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco
// Follow request IDs not in cache, perform DB query!
q := newSelectFollowRequesting(r.db, accountID)
if _, err := q.Exec(ctx, &followReqIDs); err != nil {
if _, err := q.Exec(ctx, &followReqIDs); // nocollapse
err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
}
@ -283,7 +289,8 @@ func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID strin
// Block IDs not in cache, perform DB query!
q := newSelectBlocks(r.db, accountID)
if _, err := q.Exec(ctx, &blockIDs); err != nil {
if _, err := q.Exec(ctx, &blockIDs); // nocollapse
err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
}

View file

@ -154,17 +154,6 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
}
}
if status.InReplyToID != "" && status.InReplyTo == nil {
// Status parent is not set, fetch from database.
status.InReplyTo, err = s.GetStatusByID(
gtscontext.SetBarebones(ctx),
status.InReplyToID,
)
if err != nil {
errs.Appendf("error populating status parent: %w", err)
}
}
if status.InReplyToID != "" {
if status.InReplyTo == nil {
// Status parent is not set, fetch from database.
@ -213,6 +202,17 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
}
}
if status.PollID != "" && status.Poll == nil {
// Status poll is not set, fetch from database.
status.Poll, err = s.state.DB.GetPollByID(
gtscontext.SetBarebones(ctx),
status.PollID,
)
if err != nil {
errs.Appendf("error populating status poll: %w", err)
}
}
if !status.AttachmentsPopulated() {
// Status attachments are out-of-date with IDs, repopulate.
status.Attachments, err = s.state.DB.GetAttachmentsByIDs(

View file

@ -22,6 +22,7 @@ import (
"testing"
"time"
"codeberg.org/gruf/go-kv"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
@ -73,20 +74,18 @@ func getFutureStatus() *gtsmodel.Status {
func (suite *TimelineTestSuite) publicCount() int {
var publicCount int
for _, status := range suite.testStatuses {
if status.Visibility == gtsmodel.VisibilityPublic &&
status.BoostOfID == "" {
publicCount++
}
}
return publicCount
}
func (suite *TimelineTestSuite) checkStatuses(statuses []*gtsmodel.Status, maxID string, minID string, expectedLength int) {
if l := len(statuses); l != expectedLength {
suite.FailNow("", "expected %d statuses in slice, got %d", expectedLength, l)
suite.FailNowf("", "expected %d statuses in slice, got %d", expectedLength, l)
} else if l == 0 {
// Can't test empty slice.
return
@ -98,15 +97,15 @@ func (suite *TimelineTestSuite) checkStatuses(statuses []*gtsmodel.Status, maxID
id := status.ID
if id >= maxID {
suite.FailNow("", "%s greater than maxID %s", id, maxID)
suite.FailNowf("", "%s greater than maxID %s", id, maxID)
}
if id <= minID {
suite.FailNow("", "%s smaller than minID %s", id, minID)
suite.FailNowf("", "%s smaller than minID %s", id, minID)
}
if id > highest {
suite.FailNow("", "statuses in slice were not ordered highest -> lowest ID")
suite.FailNowf("", "statuses in slice were not ordered highest -> lowest ID")
}
highest = id
@ -121,6 +120,10 @@ func (suite *TimelineTestSuite) TestGetPublicTimeline() {
suite.FailNow(err.Error())
}
suite.T().Log(kv.Field{
K: "statuses", V: s,
})
suite.checkStatuses(s, id.Highest, id.Lowest, suite.publicCount())
}
@ -154,7 +157,7 @@ func (suite *TimelineTestSuite) TestGetHomeTimeline() {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, id.Lowest, 16)
suite.checkStatuses(s, id.Highest, id.Lowest, 18)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineNoFollowing() {
@ -186,7 +189,7 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineNoFollowing() {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, id.Lowest, 5)
suite.checkStatuses(s, id.Highest, id.Lowest, 6)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() {
@ -208,7 +211,7 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() {
}
suite.NotContains(s, futureStatus)
suite.checkStatuses(s, id.Highest, id.Lowest, 16)
suite.checkStatuses(s, id.Highest, id.Lowest, 18)
}
func (suite *TimelineTestSuite) TestGetHomeTimelineBackToFront() {
@ -239,8 +242,8 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineFromHighest() {
}
suite.checkStatuses(s, id.Highest, id.Lowest, 5)
suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID)
suite.Equal("01FCTA44PW9H1TB328S9AQXKDS", s[len(s)-1].ID)
suite.Equal("01HEN2RZ8BG29Y5Z9VJC73HZW7", s[0].ID)
suite.Equal("01FN3VJGFH10KR7S2PB0GFJZYG", s[len(s)-1].ID)
}
func (suite *TimelineTestSuite) TestGetListTimelineNoParams() {
@ -254,7 +257,7 @@ func (suite *TimelineTestSuite) TestGetListTimelineNoParams() {
suite.FailNow(err.Error())
}
suite.checkStatuses(s, id.Highest, id.Lowest, 11)
suite.checkStatuses(s, id.Highest, id.Lowest, 12)
}
func (suite *TimelineTestSuite) TestGetListTimelineMaxID() {
@ -269,8 +272,8 @@ func (suite *TimelineTestSuite) TestGetListTimelineMaxID() {
}
suite.checkStatuses(s, id.Highest, id.Lowest, 5)
suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID)
suite.Equal("01FCQSQ667XHJ9AV9T27SJJSX5", s[len(s)-1].ID)
suite.Equal("01HEN2PRXT0TF4YDRA64FZZRN7", s[0].ID)
suite.Equal("01FF25D5Q0DH7CHD57CTRS6WK0", s[len(s)-1].ID)
}
func (suite *TimelineTestSuite) TestGetListTimelineMinID() {