mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-12-05 13:08:07 -06:00
[feature] add support for polls + receiving federated status edits (#2330)
This commit is contained in:
parent
7204ccedc3
commit
e9e5dc5a40
84 changed files with 3992 additions and 570 deletions
|
|
@ -42,7 +42,7 @@ type AccountTestSuite struct {
|
|||
func (suite *AccountTestSuite) TestGetAccountStatuses() {
|
||||
statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, false, false, "", "", false, false)
|
||||
suite.NoError(err)
|
||||
suite.Len(statuses, 5)
|
||||
suite.Len(statuses, 6)
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() {
|
||||
|
|
@ -65,7 +65,7 @@ func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() {
|
|||
if err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
suite.Len(statuses, 1)
|
||||
suite.Len(statuses, 2)
|
||||
|
||||
// try to get the last page (should be empty)
|
||||
statuses, err = suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, statuses[len(statuses)-1].ID, "", false, false)
|
||||
|
|
@ -76,7 +76,7 @@ func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() {
|
|||
func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogs() {
|
||||
statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false)
|
||||
suite.NoError(err)
|
||||
suite.Len(statuses, 5)
|
||||
suite.Len(statuses, 6)
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogsPublicOnly() {
|
||||
|
|
@ -306,7 +306,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() {
|
|||
func (suite *AccountTestSuite) TestGetAccountLastPosted() {
|
||||
lastPosted, err := suite.db.GetAccountLastPosted(context.Background(), suite.testAccounts["local_account_1"].ID, false)
|
||||
suite.NoError(err)
|
||||
suite.EqualValues(1653046675, lastPosted.Unix())
|
||||
suite.EqualValues(1653046870, lastPosted.Unix())
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestGetAccountLastPostedWebOnly() {
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ func (suite *BasicTestSuite) TestGetAllStatuses() {
|
|||
s := []*gtsmodel.Status{}
|
||||
err := suite.db.GetAll(context.Background(), &s)
|
||||
suite.NoError(err)
|
||||
suite.Len(s, 17)
|
||||
suite.Len(s, 20)
|
||||
}
|
||||
|
||||
func (suite *BasicTestSuite) TestGetAllNotNull() {
|
||||
|
|
|
|||
|
|
@ -71,6 +71,7 @@ type DBService struct {
|
|||
db.Media
|
||||
db.Mention
|
||||
db.Notification
|
||||
db.Poll
|
||||
db.Relationship
|
||||
db.Report
|
||||
db.Rule
|
||||
|
|
@ -203,6 +204,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
|
|||
db: db,
|
||||
state: state,
|
||||
},
|
||||
Poll: &pollDB{
|
||||
db: db,
|
||||
state: state,
|
||||
},
|
||||
Relationship: &relationshipDB{
|
||||
db: db,
|
||||
state: state,
|
||||
|
|
|
|||
|
|
@ -54,6 +54,8 @@ type BunDBStandardTestSuite struct {
|
|||
testMarkers map[string]*gtsmodel.Marker
|
||||
testRules map[string]*gtsmodel.Rule
|
||||
testThreads map[string]*gtsmodel.Thread
|
||||
testPolls map[string]*gtsmodel.Poll
|
||||
testPollVotes map[string]*gtsmodel.PollVote
|
||||
}
|
||||
|
||||
func (suite *BunDBStandardTestSuite) SetupSuite() {
|
||||
|
|
@ -77,6 +79,8 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {
|
|||
suite.testMarkers = testrig.NewTestMarkers()
|
||||
suite.testRules = testrig.NewTestRules()
|
||||
suite.testThreads = testrig.NewTestThreads()
|
||||
suite.testPolls = testrig.NewTestPolls()
|
||||
suite.testPollVotes = testrig.NewTestPollVotes()
|
||||
}
|
||||
|
||||
func (suite *BunDBStandardTestSuite) SetupTest() {
|
||||
|
|
|
|||
|
|
@ -47,13 +47,13 @@ func (suite *InstanceTestSuite) TestCountInstanceUsersRemote() {
|
|||
func (suite *InstanceTestSuite) TestCountInstanceStatuses() {
|
||||
count, err := suite.db.CountInstanceStatuses(context.Background(), config.GetHost())
|
||||
suite.NoError(err)
|
||||
suite.Equal(16, count)
|
||||
suite.Equal(18, count)
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestCountInstanceStatusesRemote() {
|
||||
count, err := suite.db.CountInstanceStatuses(context.Background(), "fossbros-anonymous.io")
|
||||
suite.NoError(err)
|
||||
suite.Equal(1, count)
|
||||
suite.Equal(2, count)
|
||||
}
|
||||
|
||||
func (suite *InstanceTestSuite) TestCountInstanceDomains() {
|
||||
|
|
|
|||
|
|
@ -20,10 +20,10 @@ package bundb
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
|
|
@ -54,31 +54,9 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Set the mention originating status.
|
||||
mention.Status, err = m.state.DB.GetStatusByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
mention.StatusID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating mention status: %w", err)
|
||||
}
|
||||
|
||||
// Set the mention origin account model.
|
||||
mention.OriginAccount, err = m.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
mention.OriginAccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating mention origin account: %w", err)
|
||||
}
|
||||
|
||||
// Set the mention target account model.
|
||||
mention.TargetAccount, err = m.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
mention.TargetAccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating mention target account: %w", err)
|
||||
// Further populate the mention fields where applicable.
|
||||
if err := m.PopulateMention(ctx, mention); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mention, nil
|
||||
|
|
@ -102,6 +80,45 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.
|
|||
return mentions, nil
|
||||
}
|
||||
|
||||
func (m *mentionDB) PopulateMention(ctx context.Context, mention *gtsmodel.Mention) (err error) {
|
||||
var errs gtserror.MultiError
|
||||
|
||||
if mention.Status == nil {
|
||||
// Set the mention originating status.
|
||||
mention.Status, err = m.state.DB.GetStatusByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
mention.StatusID,
|
||||
)
|
||||
if err != nil {
|
||||
return gtserror.Newf("error populating mention status: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if mention.OriginAccount == nil {
|
||||
// Set the mention origin account model.
|
||||
mention.OriginAccount, err = m.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
mention.OriginAccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return gtserror.Newf("error populating mention origin account: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if mention.TargetAccount == nil {
|
||||
// Set the mention target account model.
|
||||
mention.TargetAccount, err = m.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
mention.TargetAccountID,
|
||||
)
|
||||
if err != nil {
|
||||
return gtserror.Newf("error populating mention target account: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return errs.Combine()
|
||||
}
|
||||
|
||||
func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error {
|
||||
return m.state.Caches.GTS.Mention().Store(mention, func() error {
|
||||
_, err := m.db.NewInsert().Model(mention).Exec(ctx)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,65 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func init() {
|
||||
up := func(ctx context.Context, db *bun.DB) error {
|
||||
// Create `polls` + `poll_votes` tables.
|
||||
for _, model := range []any{
|
||||
>smodel.Poll{},
|
||||
>smodel.PollVote{},
|
||||
} {
|
||||
_, err := db.NewCreateTable().
|
||||
IfNotExists().
|
||||
Model(model).
|
||||
Exec(ctx)
|
||||
if err != nil && !(strings.Contains(err.Error(), "already exists") || strings.Contains(err.Error(), "duplicate column name") || strings.Contains(err.Error(), "SQLSTATE 42701")) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Add the new status `poll_id` column.
|
||||
_, err := db.NewAddColumn().
|
||||
Model(>smodel.Status{}).
|
||||
ColumnExpr("? CHAR(26)", bun.Ident("poll_id")).
|
||||
Exec(ctx)
|
||||
if err != nil && !(strings.Contains(err.Error(), "already exists") || strings.Contains(err.Error(), "duplicate column name") || strings.Contains(err.Error(), "SQLSTATE 42701")) {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
down := func(ctx context.Context, db *bun.DB) error {
|
||||
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := Migrations.Register(up, down); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
536
internal/db/bundb/poll.go
Normal file
536
internal/db/bundb/poll.go
Normal file
|
|
@ -0,0 +1,536 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package bundb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type pollDB struct {
|
||||
db *DB
|
||||
state *state.State
|
||||
}
|
||||
|
||||
func (p *pollDB) GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, error) {
|
||||
return p.getPoll(
|
||||
ctx,
|
||||
"ID",
|
||||
func(poll *gtsmodel.Poll) error {
|
||||
return p.db.NewSelect().
|
||||
Model(poll).
|
||||
Where("? = ?", bun.Ident("poll.id"), id).
|
||||
Scan(ctx)
|
||||
},
|
||||
id,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *pollDB) GetPollByStatusID(ctx context.Context, statusID string) (*gtsmodel.Poll, error) {
|
||||
return p.getPoll(
|
||||
ctx,
|
||||
"StatusID",
|
||||
func(poll *gtsmodel.Poll) error {
|
||||
return p.db.NewSelect().
|
||||
Model(poll).
|
||||
Where("? = ?", bun.Ident("poll.status_id"), statusID).
|
||||
Scan(ctx)
|
||||
},
|
||||
statusID,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *pollDB) getPoll(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Poll) error, keyParts ...any) (*gtsmodel.Poll, error) {
|
||||
// Fetch poll from database cache with loader callback
|
||||
poll, err := p.state.Caches.GTS.Poll().Load(lookup, func() (*gtsmodel.Poll, error) {
|
||||
var poll gtsmodel.Poll
|
||||
|
||||
// Not cached! Perform database query.
|
||||
if err := dbQuery(&poll); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure vote slice
|
||||
// is non nil and set.
|
||||
poll.CheckVotes()
|
||||
|
||||
return &poll, nil
|
||||
}, keyParts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gtscontext.Barebones(ctx) {
|
||||
// no need to fully populate.
|
||||
return poll, nil
|
||||
}
|
||||
|
||||
// Further populate the poll fields where applicable.
|
||||
if err := p.PopulatePoll(ctx, poll); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return poll, nil
|
||||
}
|
||||
|
||||
func (p *pollDB) GetOpenPolls(ctx context.Context) ([]*gtsmodel.Poll, error) {
|
||||
var pollIDs []string
|
||||
|
||||
// Select all polls with unset `closed_at` time.
|
||||
if err := p.db.NewSelect().
|
||||
Table("polls").
|
||||
Column("polls.id").
|
||||
Join("JOIN ? ON ? = ?", bun.Ident("statuses"), bun.Ident("polls.id"), bun.Ident("statuses.poll_id")).
|
||||
Where("? = true", bun.Ident("statuses.local")).
|
||||
Where("? IS NULL", bun.Ident("polls.closed_at")).
|
||||
Scan(ctx, &pollIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Preallocate a slice to contain the poll models.
|
||||
polls := make([]*gtsmodel.Poll, 0, len(pollIDs))
|
||||
|
||||
for _, id := range pollIDs {
|
||||
// Attempt to fetch poll from DB.
|
||||
poll, err := p.GetPollByID(ctx, id)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting poll %s: %v", id, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Append poll to return slice.
|
||||
polls = append(polls, poll)
|
||||
}
|
||||
|
||||
return polls, nil
|
||||
}
|
||||
|
||||
func (p *pollDB) PopulatePoll(ctx context.Context, poll *gtsmodel.Poll) error {
|
||||
var (
|
||||
err error
|
||||
errs gtserror.MultiError
|
||||
)
|
||||
|
||||
if poll.Status == nil {
|
||||
// Vote account is not set, fetch from database.
|
||||
poll.Status, err = p.state.DB.GetStatusByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
poll.StatusID,
|
||||
)
|
||||
if err != nil {
|
||||
errs.Appendf("error populating poll status: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return errs.Combine()
|
||||
}
|
||||
|
||||
func (p *pollDB) PutPoll(ctx context.Context, poll *gtsmodel.Poll) error {
|
||||
// Ensure vote slice
|
||||
// is non nil and set.
|
||||
poll.CheckVotes()
|
||||
|
||||
return p.state.Caches.GTS.Poll().Store(poll, func() error {
|
||||
_, err := p.db.NewInsert().Model(poll).Exec(ctx)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (p *pollDB) UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...string) error {
|
||||
// Ensure vote slice
|
||||
// is non nil and set.
|
||||
poll.CheckVotes()
|
||||
|
||||
return p.state.Caches.GTS.Poll().Store(poll, func() error {
|
||||
return p.db.RunInTx(ctx, func(tx Tx) error {
|
||||
// Update the status' "updated_at" field.
|
||||
if _, err := tx.NewUpdate().
|
||||
Table("statuses").
|
||||
Where("? = ?", bun.Ident("id"), poll.StatusID).
|
||||
SetColumn("updated_at", "?", time.Now()).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, update poll
|
||||
// columns in database.
|
||||
_, err := tx.NewUpdate().
|
||||
Model(poll).
|
||||
Column(cols...).
|
||||
Where("? = ?", bun.Ident("id"), poll.ID).
|
||||
Exec(ctx)
|
||||
return err
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (p *pollDB) DeletePollByID(ctx context.Context, id string) error {
|
||||
// Delete poll by ID from database.
|
||||
if _, err := p.db.NewDelete().
|
||||
Table("polls").
|
||||
Where("? = ?", bun.Ident("id"), id).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Invalidate poll by ID from cache.
|
||||
p.state.Caches.GTS.Poll().Invalidate("ID", id)
|
||||
p.state.Caches.GTS.PollVoteIDs().Invalidate(id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pollDB) GetPollVoteByID(ctx context.Context, id string) (*gtsmodel.PollVote, error) {
|
||||
return p.getPollVote(
|
||||
ctx,
|
||||
"ID",
|
||||
func(vote *gtsmodel.PollVote) error {
|
||||
return p.db.NewSelect().
|
||||
Model(vote).
|
||||
Where("? = ?", bun.Ident("poll_vote.id"), id).
|
||||
Scan(ctx)
|
||||
},
|
||||
id,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID string) (*gtsmodel.PollVote, error) {
|
||||
return p.getPollVote(
|
||||
ctx,
|
||||
"PollID.AccountID",
|
||||
func(vote *gtsmodel.PollVote) error {
|
||||
return p.db.NewSelect().
|
||||
Model(vote).
|
||||
Where("? = ?", bun.Ident("poll_vote.account_id"), accountID).
|
||||
Where("? = ?", bun.Ident("poll_vote.poll_id"), pollID).
|
||||
Scan(ctx)
|
||||
},
|
||||
pollID,
|
||||
accountID,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.PollVote) error, keyParts ...any) (*gtsmodel.PollVote, error) {
|
||||
// Fetch vote from database cache with loader callback
|
||||
vote, err := p.state.Caches.GTS.PollVote().Load(lookup, func() (*gtsmodel.PollVote, error) {
|
||||
var vote gtsmodel.PollVote
|
||||
|
||||
// Not cached! Perform database query.
|
||||
if err := dbQuery(&vote); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &vote, nil
|
||||
}, keyParts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gtscontext.Barebones(ctx) {
|
||||
// no need to fully populate.
|
||||
return vote, nil
|
||||
}
|
||||
|
||||
// Further populate the vote fields where applicable.
|
||||
if err := p.PopulatePollVote(ctx, vote); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return vote, nil
|
||||
}
|
||||
|
||||
func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.PollVote, error) {
|
||||
voteIDs, err := p.state.Caches.GTS.PollVoteIDs().Load(pollID, func() ([]string, error) {
|
||||
var voteIDs []string
|
||||
|
||||
// Vote IDs not in cache, perform DB query!
|
||||
q := newSelectPollVotes(p.db, pollID)
|
||||
if _, err := q.Exec(ctx, &voteIDs); // nocollapse
|
||||
err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return voteIDs, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Preallocate slice of expected length.
|
||||
votes := make([]*gtsmodel.PollVote, 0, len(voteIDs))
|
||||
|
||||
for _, id := range voteIDs {
|
||||
// Fetch poll vote model for this ID.
|
||||
vote, err := p.GetPollVoteByID(ctx, id)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting poll vote %s: %v", id, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Append to return slice.
|
||||
votes = append(votes, vote)
|
||||
}
|
||||
|
||||
return votes, nil
|
||||
}
|
||||
|
||||
func (p *pollDB) PopulatePollVote(ctx context.Context, vote *gtsmodel.PollVote) error {
|
||||
var (
|
||||
err error
|
||||
errs gtserror.MultiError
|
||||
)
|
||||
|
||||
if vote.Account == nil {
|
||||
// Vote account is not set, fetch from database.
|
||||
vote.Account, err = p.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
vote.AccountID,
|
||||
)
|
||||
if err != nil {
|
||||
errs.Appendf("error populating vote account: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if vote.Poll == nil {
|
||||
// Vote poll is not set, fetch from database.
|
||||
vote.Poll, err = p.GetPollByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
vote.PollID,
|
||||
)
|
||||
if err != nil {
|
||||
errs.Appendf("error populating vote poll: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return errs.Combine()
|
||||
}
|
||||
|
||||
func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error {
|
||||
return p.state.Caches.GTS.PollVote().Store(vote, func() error {
|
||||
return p.db.RunInTx(ctx, func(tx Tx) error {
|
||||
// Try insert vote into database.
|
||||
if _, err := tx.NewInsert().
|
||||
Model(vote).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var poll gtsmodel.Poll
|
||||
|
||||
// Select poll counts from DB.
|
||||
if err := tx.NewSelect().
|
||||
Model(&poll).
|
||||
Where("? = ?", bun.Ident("id"), vote.PollID).
|
||||
Scan(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Increment poll votes for choices.
|
||||
poll.IncrementVotes(vote.Choices)
|
||||
|
||||
// Finally, update the poll entry.
|
||||
_, err := tx.NewUpdate().
|
||||
Model(&poll).
|
||||
Column("votes", "voters").
|
||||
Where("? = ?", bun.Ident("id"), vote.PollID).
|
||||
Exec(ctx)
|
||||
return err
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
|
||||
err := p.db.RunInTx(ctx, func(tx Tx) error {
|
||||
// Delete all vote in poll,
|
||||
// returning all vote choices.
|
||||
switch _, err := tx.NewDelete().
|
||||
Table("poll_votes").
|
||||
Where("? = ?", bun.Ident("poll_id"), pollID).
|
||||
Exec(ctx); {
|
||||
|
||||
case err == nil:
|
||||
// no issue.
|
||||
|
||||
case errors.Is(err, db.ErrNoEntries):
|
||||
// no votes found,
|
||||
// return here.
|
||||
return nil
|
||||
|
||||
default:
|
||||
// irrecoverable.
|
||||
return err
|
||||
}
|
||||
|
||||
var poll gtsmodel.Poll
|
||||
|
||||
// Select poll counts from DB.
|
||||
switch err := tx.NewSelect().
|
||||
Model(&poll).
|
||||
Where("? = ?", bun.Ident("id"), pollID).
|
||||
Scan(ctx); {
|
||||
|
||||
case err == nil:
|
||||
// no issue.
|
||||
|
||||
case errors.Is(err, db.ErrNoEntries):
|
||||
// no votes found,
|
||||
// return here.
|
||||
return nil
|
||||
|
||||
default:
|
||||
// irrecoverable.
|
||||
return err
|
||||
}
|
||||
|
||||
// Zero all counts.
|
||||
poll.ResetVotes()
|
||||
|
||||
// Finally, update the poll entry.
|
||||
_, err := tx.NewUpdate().
|
||||
Model(&poll).
|
||||
Column("votes", "voters").
|
||||
Where("? = ?", bun.Ident("id"), pollID).
|
||||
Exec(ctx)
|
||||
return err
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Invalidate poll vote and poll entry from caches.
|
||||
p.state.Caches.GTS.Poll().Invalidate("ID", pollID)
|
||||
p.state.Caches.GTS.PollVote().Invalidate("PollID", pollID)
|
||||
p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error {
|
||||
err := p.db.RunInTx(ctx, func(tx Tx) error {
|
||||
var choices []int
|
||||
|
||||
// Delete vote in poll by account,
|
||||
// returning the ID + choices of the vote.
|
||||
switch err := tx.NewDelete().
|
||||
Table("poll_votes").
|
||||
Where("? = ?", bun.Ident("poll_id"), pollID).
|
||||
Where("? = ?", bun.Ident("account_id"), accountID).
|
||||
Returning("choices").
|
||||
Scan(ctx, &choices); {
|
||||
|
||||
case err == nil:
|
||||
// no issue.
|
||||
|
||||
case errors.Is(err, db.ErrNoEntries):
|
||||
// no votes found,
|
||||
// return here.
|
||||
return nil
|
||||
|
||||
default:
|
||||
// irrecoverable.
|
||||
return err
|
||||
}
|
||||
|
||||
var poll gtsmodel.Poll
|
||||
|
||||
// Select poll counts from DB.
|
||||
switch err := tx.NewSelect().
|
||||
Model(&poll).
|
||||
Where("? = ?", bun.Ident("id"), pollID).
|
||||
Scan(ctx); {
|
||||
|
||||
case err == nil:
|
||||
// no issue.
|
||||
|
||||
case errors.Is(err, db.ErrNoEntries):
|
||||
// no votes found,
|
||||
// return here.
|
||||
return nil
|
||||
|
||||
default:
|
||||
// irrecoverable.
|
||||
return err
|
||||
}
|
||||
|
||||
// Decrement votes for choices.
|
||||
poll.IncrementVotes(choices)
|
||||
|
||||
// Finally, update the poll entry.
|
||||
_, err := tx.NewUpdate().
|
||||
Model(&poll).
|
||||
Column("votes", "voters").
|
||||
Where("? = ?", bun.Ident("id"), pollID).
|
||||
Exec(ctx)
|
||||
return err
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Invalidate poll vote and poll entry from caches.
|
||||
p.state.Caches.GTS.Poll().Invalidate("ID", pollID)
|
||||
p.state.Caches.GTS.PollVote().Invalidate("PollID.AccountID", pollID, accountID)
|
||||
p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pollDB) DeletePollVotesByAccountID(ctx context.Context, accountID string) error {
|
||||
var pollIDs []string
|
||||
|
||||
// Select all polls this account
|
||||
// has registered a poll vote in.
|
||||
if err := p.db.NewSelect().
|
||||
Table("poll_votes").
|
||||
Column("poll_id").
|
||||
Where("? = ?", bun.Ident("account_id"), accountID).
|
||||
Scan(ctx, &pollIDs); err != nil &&
|
||||
!errors.Is(err, db.ErrNoEntries) {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, id := range pollIDs {
|
||||
// Delete all votes by this account in each of the polls,
|
||||
// this way ensures that all necessary caches are invalidated.
|
||||
if err := p.DeletePollVoteBy(ctx, id, accountID); err != nil {
|
||||
log.Errorf(ctx, "error deleting vote by %s in %s: %v", accountID, id, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// newSelectPollVotes returns a new select query for all rows in the poll_votes table with poll_id = pollID.
|
||||
func newSelectPollVotes(db *DB, pollID string) *bun.SelectQuery {
|
||||
return db.NewSelect().
|
||||
TableExpr("?", bun.Ident("poll_votes")).
|
||||
ColumnExpr("?", bun.Ident("id")).
|
||||
Where("? = ?", bun.Ident("poll_id"), pollID).
|
||||
OrderExpr("? DESC", bun.Ident("id"))
|
||||
}
|
||||
318
internal/db/bundb/poll_test.go
Normal file
318
internal/db/bundb/poll_test.go
Normal file
|
|
@ -0,0 +1,318 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package bundb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||
)
|
||||
|
||||
type PollTestSuite struct {
|
||||
BunDBStandardTestSuite
|
||||
}
|
||||
|
||||
func (suite *PollTestSuite) TestGetPollBy() {
|
||||
t := suite.T()
|
||||
|
||||
// Create a new context for this test.
|
||||
ctx, cncl := context.WithCancel(context.Background())
|
||||
defer cncl()
|
||||
|
||||
// Sentinel error to mark avoiding a test case.
|
||||
sentinelErr := errors.New("sentinel")
|
||||
|
||||
// isEqual checks if 2 poll models are equal.
|
||||
isEqual := func(p1, p2 gtsmodel.Poll) bool {
|
||||
// Clear populated sub-models.
|
||||
p1.Status = nil
|
||||
p2.Status = nil
|
||||
|
||||
// Localize all of the time fields.
|
||||
p1.ExpiresAt = p1.ExpiresAt.Local()
|
||||
p2.ExpiresAt = p2.ExpiresAt.Local()
|
||||
p1.ClosedAt = p1.ClosedAt.Local()
|
||||
p2.ClosedAt = p2.ClosedAt.Local()
|
||||
|
||||
// Perform the comparison.
|
||||
return suite.Equal(p1, p2)
|
||||
}
|
||||
|
||||
for _, poll := range suite.testPolls {
|
||||
for lookup, dbfunc := range map[string]func() (*gtsmodel.Poll, error){
|
||||
"id": func() (*gtsmodel.Poll, error) {
|
||||
return suite.db.GetPollByID(ctx, poll.ID)
|
||||
},
|
||||
|
||||
"status_id": func() (*gtsmodel.Poll, error) {
|
||||
return suite.db.GetPollByStatusID(ctx, poll.StatusID)
|
||||
},
|
||||
} {
|
||||
|
||||
// Clear database caches.
|
||||
suite.state.Caches.Init()
|
||||
|
||||
t.Logf("checking database lookup %q", lookup)
|
||||
|
||||
// Perform database function.
|
||||
checkPoll, err := dbfunc()
|
||||
if err != nil {
|
||||
if err == sentinelErr {
|
||||
continue
|
||||
}
|
||||
|
||||
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check received account data.
|
||||
if !isEqual(*checkPoll, *poll) {
|
||||
t.Errorf("poll does not contain expected data: %+v", checkPoll)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check that poll source status populated.
|
||||
if poll.StatusID != (*checkPoll).Status.ID {
|
||||
t.Errorf("poll source status not correctly populated for: %+v", poll)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *PollTestSuite) TestGetPollVoteBy() {
|
||||
t := suite.T()
|
||||
|
||||
// Create a new context for this test.
|
||||
ctx, cncl := context.WithCancel(context.Background())
|
||||
defer cncl()
|
||||
|
||||
// Sentinel error to mark avoiding a test case.
|
||||
sentinelErr := errors.New("sentinel")
|
||||
|
||||
// isEqual checks if 2 poll vote models are equal.
|
||||
isEqual := func(v1, v2 gtsmodel.PollVote) bool {
|
||||
// Clear populated sub-models.
|
||||
v1.Poll = nil
|
||||
v2.Poll = nil
|
||||
v1.Account = nil
|
||||
v2.Account = nil
|
||||
|
||||
// Localize all of the time fields.
|
||||
v1.CreatedAt = v1.CreatedAt.Local()
|
||||
v2.CreatedAt = v2.CreatedAt.Local()
|
||||
|
||||
// Perform the comparison.
|
||||
return suite.Equal(v1, v2)
|
||||
}
|
||||
|
||||
for _, vote := range suite.testPollVotes {
|
||||
for lookup, dbfunc := range map[string]func() (*gtsmodel.PollVote, error){
|
||||
"id": func() (*gtsmodel.PollVote, error) {
|
||||
return suite.db.GetPollVoteByID(ctx, vote.ID)
|
||||
},
|
||||
|
||||
"poll_id_account_id": func() (*gtsmodel.PollVote, error) {
|
||||
return suite.db.GetPollVoteBy(ctx, vote.PollID, vote.AccountID)
|
||||
},
|
||||
} {
|
||||
|
||||
// Clear database caches.
|
||||
suite.state.Caches.Init()
|
||||
|
||||
t.Logf("checking database lookup %q", lookup)
|
||||
|
||||
// Perform database function.
|
||||
checkVote, err := dbfunc()
|
||||
if err != nil {
|
||||
if err == sentinelErr {
|
||||
continue
|
||||
}
|
||||
|
||||
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check received account data.
|
||||
if !isEqual(*checkVote, *vote) {
|
||||
t.Errorf("poll vote does not contain expected data: %+v", checkVote)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check that vote source poll populated.
|
||||
if checkVote.PollID != (*checkVote).Poll.ID {
|
||||
t.Errorf("vote source poll not correctly populated for: %+v", vote)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check that vote author account populated.
|
||||
if checkVote.AccountID != (*checkVote).Account.ID {
|
||||
t.Errorf("vote author account not correctly populated for: %+v", vote)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *PollTestSuite) TestUpdatePoll() {
|
||||
// Create a new context for this test.
|
||||
ctx, cncl := context.WithCancel(context.Background())
|
||||
defer cncl()
|
||||
|
||||
for _, poll := range suite.testPolls {
|
||||
// Take copy of poll.
|
||||
poll := util.Ptr(*poll)
|
||||
|
||||
// Update the poll closed field.
|
||||
poll.ClosedAt = time.Now()
|
||||
|
||||
// Update poll model in the database.
|
||||
err := suite.db.UpdatePoll(ctx, poll)
|
||||
suite.NoError(err)
|
||||
|
||||
// Refetch poll from database to get latest.
|
||||
latest, err := suite.db.GetPollByID(ctx, poll.ID)
|
||||
suite.NoError(err)
|
||||
|
||||
// The latest poll should have updated closedAt.
|
||||
suite.Equal(poll.ClosedAt, latest.ClosedAt)
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *PollTestSuite) TestPutPoll() {
|
||||
// Create a new context for this test.
|
||||
ctx, cncl := context.WithCancel(context.Background())
|
||||
defer cncl()
|
||||
|
||||
for _, poll := range suite.testPolls {
|
||||
// Delete this poll from the database.
|
||||
err := suite.db.DeletePollByID(ctx, poll.ID)
|
||||
suite.NoError(err)
|
||||
|
||||
// Ensure that afterwards we can
|
||||
// enter it again into database.
|
||||
err = suite.db.PutPoll(ctx, poll)
|
||||
|
||||
// Ensure that afterwards we can fetch poll.
|
||||
_, err = suite.db.GetPollByID(ctx, poll.ID)
|
||||
suite.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *PollTestSuite) TestPutPollVote() {
|
||||
// Create a new context for this test.
|
||||
ctx, cncl := context.WithCancel(context.Background())
|
||||
defer cncl()
|
||||
|
||||
// randomChoices generates random vote choices in poll.
|
||||
randomChoices := func(poll *gtsmodel.Poll) []int {
|
||||
var max int
|
||||
if *poll.Multiple {
|
||||
max = len(poll.Options)
|
||||
} else {
|
||||
max = 1
|
||||
}
|
||||
count := 1 + rand.Intn(max)
|
||||
choices := make([]int, count)
|
||||
for i := range choices {
|
||||
choices[i] = rand.Intn(len(poll.Options))
|
||||
}
|
||||
return choices
|
||||
}
|
||||
|
||||
for _, poll := range suite.testPolls {
|
||||
// Create a new vote to insert for poll.
|
||||
vote := >smodel.PollVote{
|
||||
ID: id.NewULID(),
|
||||
Choices: randomChoices(poll),
|
||||
PollID: poll.ID,
|
||||
AccountID: id.NewULID(), // random account, doesn't matter
|
||||
}
|
||||
|
||||
// Insert this new vote into database.
|
||||
err := suite.db.PutPollVote(ctx, vote)
|
||||
suite.NoError(err)
|
||||
|
||||
// Fetch latest version of poll from database.
|
||||
latest, err := suite.db.GetPollByID(ctx, poll.ID)
|
||||
suite.NoError(err)
|
||||
|
||||
// Decr latest version choices by new vote's.
|
||||
for _, choice := range vote.Choices {
|
||||
latest.Votes[choice]--
|
||||
}
|
||||
(*latest.Voters)--
|
||||
|
||||
// Old poll and latest model after decr
|
||||
// should have equal vote + voter counts.
|
||||
suite.Equal(poll.Voters, latest.Voters)
|
||||
suite.Equal(poll.Votes, latest.Votes)
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *PollTestSuite) TestDeletePoll() {
|
||||
// Create a new context for this test.
|
||||
ctx, cncl := context.WithCancel(context.Background())
|
||||
defer cncl()
|
||||
|
||||
for _, poll := range suite.testPolls {
|
||||
// Delete this poll from the database.
|
||||
err := suite.db.DeletePollByID(ctx, poll.ID)
|
||||
suite.NoError(err)
|
||||
|
||||
// Ensure that afterwards we cannot fetch poll.
|
||||
_, err = suite.db.GetPollByID(ctx, poll.ID)
|
||||
suite.ErrorIs(err, db.ErrNoEntries)
|
||||
|
||||
// Or again by the status it's attached to.
|
||||
_, err = suite.db.GetPollByStatusID(ctx, poll.StatusID)
|
||||
suite.ErrorIs(err, db.ErrNoEntries)
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *PollTestSuite) TestDeletePollVotes() {
|
||||
// Create a new context for this test.
|
||||
ctx, cncl := context.WithCancel(context.Background())
|
||||
defer cncl()
|
||||
|
||||
for _, poll := range suite.testPolls {
|
||||
// Delete votes associated with poll from database.
|
||||
err := suite.db.DeletePollVotes(ctx, poll.ID)
|
||||
suite.NoError(err)
|
||||
|
||||
// Fetch latest version of poll from database.
|
||||
poll, err = suite.db.GetPollByID(ctx, poll.ID)
|
||||
suite.NoError(err)
|
||||
|
||||
// Check that poll counts are all zero.
|
||||
suite.Equal(*poll.Voters, 0)
|
||||
suite.Equal(poll.Votes, make([]int, len(poll.Options)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPollTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(PollTestSuite))
|
||||
}
|
||||
|
|
@ -199,7 +199,8 @@ func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID stri
|
|||
|
||||
// Follow IDs not in cache, perform DB query!
|
||||
q := newSelectFollows(r.db, accountID)
|
||||
if _, err := q.Exec(ctx, &followIDs); err != nil {
|
||||
if _, err := q.Exec(ctx, &followIDs); // nocollapse
|
||||
err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -213,7 +214,8 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID
|
|||
|
||||
// Follow IDs not in cache, perform DB query!
|
||||
q := newSelectLocalFollows(r.db, accountID)
|
||||
if _, err := q.Exec(ctx, &followIDs); err != nil {
|
||||
if _, err := q.Exec(ctx, &followIDs); // nocollapse
|
||||
err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -227,7 +229,8 @@ func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID st
|
|||
|
||||
// Follow IDs not in cache, perform DB query!
|
||||
q := newSelectFollowers(r.db, accountID)
|
||||
if _, err := q.Exec(ctx, &followIDs); err != nil {
|
||||
if _, err := q.Exec(ctx, &followIDs); // nocollapse
|
||||
err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -241,7 +244,8 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account
|
|||
|
||||
// Follow IDs not in cache, perform DB query!
|
||||
q := newSelectLocalFollowers(r.db, accountID)
|
||||
if _, err := q.Exec(ctx, &followIDs); err != nil {
|
||||
if _, err := q.Exec(ctx, &followIDs); // nocollapse
|
||||
err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -255,7 +259,8 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account
|
|||
|
||||
// Follow request IDs not in cache, perform DB query!
|
||||
q := newSelectFollowRequests(r.db, accountID)
|
||||
if _, err := q.Exec(ctx, &followReqIDs); err != nil {
|
||||
if _, err := q.Exec(ctx, &followReqIDs); // nocollapse
|
||||
err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -269,7 +274,8 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco
|
|||
|
||||
// Follow request IDs not in cache, perform DB query!
|
||||
q := newSelectFollowRequesting(r.db, accountID)
|
||||
if _, err := q.Exec(ctx, &followReqIDs); err != nil {
|
||||
if _, err := q.Exec(ctx, &followReqIDs); // nocollapse
|
||||
err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -283,7 +289,8 @@ func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID strin
|
|||
|
||||
// Block IDs not in cache, perform DB query!
|
||||
q := newSelectBlocks(r.db, accountID)
|
||||
if _, err := q.Exec(ctx, &blockIDs); err != nil {
|
||||
if _, err := q.Exec(ctx, &blockIDs); // nocollapse
|
||||
err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -154,17 +154,6 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
|
|||
}
|
||||
}
|
||||
|
||||
if status.InReplyToID != "" && status.InReplyTo == nil {
|
||||
// Status parent is not set, fetch from database.
|
||||
status.InReplyTo, err = s.GetStatusByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
status.InReplyToID,
|
||||
)
|
||||
if err != nil {
|
||||
errs.Appendf("error populating status parent: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if status.InReplyToID != "" {
|
||||
if status.InReplyTo == nil {
|
||||
// Status parent is not set, fetch from database.
|
||||
|
|
@ -213,6 +202,17 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
|
|||
}
|
||||
}
|
||||
|
||||
if status.PollID != "" && status.Poll == nil {
|
||||
// Status poll is not set, fetch from database.
|
||||
status.Poll, err = s.state.DB.GetPollByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
status.PollID,
|
||||
)
|
||||
if err != nil {
|
||||
errs.Appendf("error populating status poll: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !status.AttachmentsPopulated() {
|
||||
// Status attachments are out-of-date with IDs, repopulate.
|
||||
status.Attachments, err = s.state.DB.GetAttachmentsByIDs(
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"codeberg.org/gruf/go-kv"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/ap"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
|
||||
|
|
@ -73,20 +74,18 @@ func getFutureStatus() *gtsmodel.Status {
|
|||
|
||||
func (suite *TimelineTestSuite) publicCount() int {
|
||||
var publicCount int
|
||||
|
||||
for _, status := range suite.testStatuses {
|
||||
if status.Visibility == gtsmodel.VisibilityPublic &&
|
||||
status.BoostOfID == "" {
|
||||
publicCount++
|
||||
}
|
||||
}
|
||||
|
||||
return publicCount
|
||||
}
|
||||
|
||||
func (suite *TimelineTestSuite) checkStatuses(statuses []*gtsmodel.Status, maxID string, minID string, expectedLength int) {
|
||||
if l := len(statuses); l != expectedLength {
|
||||
suite.FailNow("", "expected %d statuses in slice, got %d", expectedLength, l)
|
||||
suite.FailNowf("", "expected %d statuses in slice, got %d", expectedLength, l)
|
||||
} else if l == 0 {
|
||||
// Can't test empty slice.
|
||||
return
|
||||
|
|
@ -98,15 +97,15 @@ func (suite *TimelineTestSuite) checkStatuses(statuses []*gtsmodel.Status, maxID
|
|||
id := status.ID
|
||||
|
||||
if id >= maxID {
|
||||
suite.FailNow("", "%s greater than maxID %s", id, maxID)
|
||||
suite.FailNowf("", "%s greater than maxID %s", id, maxID)
|
||||
}
|
||||
|
||||
if id <= minID {
|
||||
suite.FailNow("", "%s smaller than minID %s", id, minID)
|
||||
suite.FailNowf("", "%s smaller than minID %s", id, minID)
|
||||
}
|
||||
|
||||
if id > highest {
|
||||
suite.FailNow("", "statuses in slice were not ordered highest -> lowest ID")
|
||||
suite.FailNowf("", "statuses in slice were not ordered highest -> lowest ID")
|
||||
}
|
||||
|
||||
highest = id
|
||||
|
|
@ -121,6 +120,10 @@ func (suite *TimelineTestSuite) TestGetPublicTimeline() {
|
|||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
suite.T().Log(kv.Field{
|
||||
K: "statuses", V: s,
|
||||
})
|
||||
|
||||
suite.checkStatuses(s, id.Highest, id.Lowest, suite.publicCount())
|
||||
}
|
||||
|
||||
|
|
@ -154,7 +157,7 @@ func (suite *TimelineTestSuite) TestGetHomeTimeline() {
|
|||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
suite.checkStatuses(s, id.Highest, id.Lowest, 16)
|
||||
suite.checkStatuses(s, id.Highest, id.Lowest, 18)
|
||||
}
|
||||
|
||||
func (suite *TimelineTestSuite) TestGetHomeTimelineNoFollowing() {
|
||||
|
|
@ -186,7 +189,7 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineNoFollowing() {
|
|||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
suite.checkStatuses(s, id.Highest, id.Lowest, 5)
|
||||
suite.checkStatuses(s, id.Highest, id.Lowest, 6)
|
||||
}
|
||||
|
||||
func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() {
|
||||
|
|
@ -208,7 +211,7 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() {
|
|||
}
|
||||
|
||||
suite.NotContains(s, futureStatus)
|
||||
suite.checkStatuses(s, id.Highest, id.Lowest, 16)
|
||||
suite.checkStatuses(s, id.Highest, id.Lowest, 18)
|
||||
}
|
||||
|
||||
func (suite *TimelineTestSuite) TestGetHomeTimelineBackToFront() {
|
||||
|
|
@ -239,8 +242,8 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineFromHighest() {
|
|||
}
|
||||
|
||||
suite.checkStatuses(s, id.Highest, id.Lowest, 5)
|
||||
suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID)
|
||||
suite.Equal("01FCTA44PW9H1TB328S9AQXKDS", s[len(s)-1].ID)
|
||||
suite.Equal("01HEN2RZ8BG29Y5Z9VJC73HZW7", s[0].ID)
|
||||
suite.Equal("01FN3VJGFH10KR7S2PB0GFJZYG", s[len(s)-1].ID)
|
||||
}
|
||||
|
||||
func (suite *TimelineTestSuite) TestGetListTimelineNoParams() {
|
||||
|
|
@ -254,7 +257,7 @@ func (suite *TimelineTestSuite) TestGetListTimelineNoParams() {
|
|||
suite.FailNow(err.Error())
|
||||
}
|
||||
|
||||
suite.checkStatuses(s, id.Highest, id.Lowest, 11)
|
||||
suite.checkStatuses(s, id.Highest, id.Lowest, 12)
|
||||
}
|
||||
|
||||
func (suite *TimelineTestSuite) TestGetListTimelineMaxID() {
|
||||
|
|
@ -269,8 +272,8 @@ func (suite *TimelineTestSuite) TestGetListTimelineMaxID() {
|
|||
}
|
||||
|
||||
suite.checkStatuses(s, id.Highest, id.Lowest, 5)
|
||||
suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID)
|
||||
suite.Equal("01FCQSQ667XHJ9AV9T27SJJSX5", s[len(s)-1].ID)
|
||||
suite.Equal("01HEN2PRXT0TF4YDRA64FZZRN7", s[0].ID)
|
||||
suite.Equal("01FF25D5Q0DH7CHD57CTRS6WK0", s[len(s)-1].ID)
|
||||
}
|
||||
|
||||
func (suite *TimelineTestSuite) TestGetListTimelineMinID() {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue