mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 05:52:25 -05:00 
			
		
		
		
	[bugfix] Update poll delete/update db queries (#2361)
This commit is contained in:
		
					parent
					
						
							
								8d0c017cf2
							
						
					
				
			
			
				commit
				
					
						0b99f14d64
					
				
			
		
					 3 changed files with 98 additions and 39 deletions
				
			
		|  | @ -44,7 +44,7 @@ func init() { | ||||||
| 				Table("polls"). | 				Table("polls"). | ||||||
| 				Column("expires_at_new"). | 				Column("expires_at_new"). | ||||||
| 				Set("? = ?", bun.Ident("expires_at_new"), bun.Ident("expires_at")). | 				Set("? = ?", bun.Ident("expires_at_new"), bun.Ident("expires_at")). | ||||||
| 				Where("1"). // bun gets angry performing update over all rows | 				Where("TRUE"). // bun gets angry performing update over all rows | ||||||
| 				Exec(ctx); err != nil { | 				Exec(ctx); err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | @ -341,9 +341,12 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error | ||||||
| 
 | 
 | ||||||
| 			var poll gtsmodel.Poll | 			var poll gtsmodel.Poll | ||||||
| 
 | 
 | ||||||
| 			// Select poll counts from DB. | 			// Select current poll counts from DB, | ||||||
|  | 			// taking minimal columns needed to | ||||||
|  | 			// increment/decrement votes. | ||||||
| 			if err := tx.NewSelect(). | 			if err := tx.NewSelect(). | ||||||
| 				Model(&poll). | 				Model(&poll). | ||||||
|  | 				Column("options", "votes", "voters"). | ||||||
| 				Where("? = ?", bun.Ident("id"), vote.PollID). | 				Where("? = ?", bun.Ident("id"), vote.PollID). | ||||||
| 				Scan(ctx); err != nil { | 				Scan(ctx); err != nil { | ||||||
| 				return err | 				return err | ||||||
|  | @ -365,31 +368,35 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error | ||||||
| 
 | 
 | ||||||
| func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { | func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { | ||||||
| 	err := p.db.RunInTx(ctx, func(tx Tx) error { | 	err := p.db.RunInTx(ctx, func(tx Tx) error { | ||||||
| 		// Delete all vote in poll, | 		// Delete all votes in poll. | ||||||
| 		// returning all vote choices. | 		res, err := tx.NewDelete(). | ||||||
| 		switch _, err := tx.NewDelete(). |  | ||||||
| 			Table("poll_votes"). | 			Table("poll_votes"). | ||||||
| 			Where("? = ?", bun.Ident("poll_id"), pollID). | 			Where("? = ?", bun.Ident("poll_id"), pollID). | ||||||
| 			Exec(ctx); { | 			Exec(ctx) | ||||||
| 
 | 		if err != nil { | ||||||
| 		case err == nil: | 			// irrecoverable | ||||||
| 			// no issue. |  | ||||||
| 
 |  | ||||||
| 		case errors.Is(err, db.ErrNoEntries): |  | ||||||
| 			// no votes found, |  | ||||||
| 			// return here. |  | ||||||
| 			return nil |  | ||||||
| 
 |  | ||||||
| 		default: |  | ||||||
| 			// irrecoverable. |  | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		var poll gtsmodel.Poll | 		ra, err := res.RowsAffected() | ||||||
|  | 		if err != nil { | ||||||
|  | 			// irrecoverable | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 		// Select poll counts from DB. | 		if ra == 0 { | ||||||
|  | 			// No poll votes deleted, | ||||||
|  | 			// nothing to update. | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Select current poll counts from DB, | ||||||
|  | 		// taking minimal columns needed to | ||||||
|  | 		// increment/decrement votes. | ||||||
|  | 		var poll gtsmodel.Poll | ||||||
| 		switch err := tx.NewSelect(). | 		switch err := tx.NewSelect(). | ||||||
| 			Model(&poll). | 			Model(&poll). | ||||||
|  | 			Column("options", "votes", "voters"). | ||||||
| 			Where("? = ?", bun.Ident("id"), pollID). | 			Where("? = ?", bun.Ident("id"), pollID). | ||||||
| 			Scan(ctx); { | 			Scan(ctx); { | ||||||
| 
 | 
 | ||||||
|  | @ -410,7 +417,7 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { | ||||||
| 		poll.ResetVotes() | 		poll.ResetVotes() | ||||||
| 
 | 
 | ||||||
| 		// Finally, update the poll entry. | 		// Finally, update the poll entry. | ||||||
| 		_, err := tx.NewUpdate(). | 		_, err = tx.NewUpdate(). | ||||||
| 			Model(&poll). | 			Model(&poll). | ||||||
| 			Column("votes", "voters"). | 			Column("votes", "voters"). | ||||||
| 			Where("? = ?", bun.Ident("id"), pollID). | 			Where("? = ?", bun.Ident("id"), pollID). | ||||||
|  | @ -432,35 +439,37 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { | ||||||
| 
 | 
 | ||||||
| func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error { | func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error { | ||||||
| 	err := p.db.RunInTx(ctx, func(tx Tx) error { | 	err := p.db.RunInTx(ctx, func(tx Tx) error { | ||||||
| 		var choices []int | 		// Slice should only ever be of length | ||||||
|  | 		// 0 or 1; it's a slice of slices only | ||||||
|  | 		// because we can't LIMIT deletes to 1. | ||||||
|  | 		var choicesSl [][]int | ||||||
| 
 | 
 | ||||||
| 		// Delete vote in poll by account, | 		// Delete vote in poll by account, | ||||||
| 		// returning the ID + choices of the vote. | 		// returning the ID + choices of the vote. | ||||||
| 		switch err := tx.NewDelete(). | 		if err := tx.NewDelete(). | ||||||
| 			Table("poll_votes"). | 			Table("poll_votes"). | ||||||
| 			Where("? = ?", bun.Ident("poll_id"), pollID). | 			Where("? = ?", bun.Ident("poll_id"), pollID). | ||||||
| 			Where("? = ?", bun.Ident("account_id"), accountID). | 			Where("? = ?", bun.Ident("account_id"), accountID). | ||||||
| 			Returning("choices"). | 			Returning("?", bun.Ident("choices")). | ||||||
| 			Scan(ctx, &choices); { | 			Scan(ctx, &choicesSl); err != nil { | ||||||
| 
 |  | ||||||
| 		case err == nil: |  | ||||||
| 			// no issue. |  | ||||||
| 
 |  | ||||||
| 		case errors.Is(err, db.ErrNoEntries): |  | ||||||
| 			// no votes found, |  | ||||||
| 			// return here. |  | ||||||
| 			return nil |  | ||||||
| 
 |  | ||||||
| 		default: |  | ||||||
| 			// irrecoverable. | 			// irrecoverable. | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		var poll gtsmodel.Poll | 		if len(choicesSl) != 1 { | ||||||
|  | 			// No poll votes by this | ||||||
|  | 			// acct on this poll. | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  | 		choices := choicesSl[0] | ||||||
| 
 | 
 | ||||||
| 		// Select poll counts from DB. | 		// Select current poll counts from DB, | ||||||
|  | 		// taking minimal columns needed to | ||||||
|  | 		// increment/decrement votes. | ||||||
|  | 		var poll gtsmodel.Poll | ||||||
| 		switch err := tx.NewSelect(). | 		switch err := tx.NewSelect(). | ||||||
| 			Model(&poll). | 			Model(&poll). | ||||||
|  | 			Column("options", "votes", "voters"). | ||||||
| 			Where("? = ?", bun.Ident("id"), pollID). | 			Where("? = ?", bun.Ident("id"), pollID). | ||||||
| 			Scan(ctx); { | 			Scan(ctx); { | ||||||
| 
 | 
 | ||||||
|  | @ -468,7 +477,7 @@ func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID | ||||||
| 			// no issue. | 			// no issue. | ||||||
| 
 | 
 | ||||||
| 		case errors.Is(err, db.ErrNoEntries): | 		case errors.Is(err, db.ErrNoEntries): | ||||||
| 			// no votes found, | 			// no poll found, | ||||||
| 			// return here. | 			// return here. | ||||||
| 			return nil | 			return nil | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -26,6 +26,7 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	"github.com/stretchr/testify/suite" | 	"github.com/stretchr/testify/suite" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/id" | 	"github.com/superseriousbusiness/gotosocial/internal/id" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/util" | 	"github.com/superseriousbusiness/gotosocial/internal/util" | ||||||
|  | @ -304,15 +305,64 @@ func (suite *PollTestSuite) TestDeletePollVotes() { | ||||||
| 		suite.NoError(err) | 		suite.NoError(err) | ||||||
| 
 | 
 | ||||||
| 		// Fetch latest version of poll from database. | 		// Fetch latest version of poll from database. | ||||||
| 		poll, err = suite.db.GetPollByID(ctx, poll.ID) | 		poll, err = suite.db.GetPollByID( | ||||||
|  | 			gtscontext.SetBarebones(ctx), | ||||||
|  | 			poll.ID, | ||||||
|  | 		) | ||||||
| 		suite.NoError(err) | 		suite.NoError(err) | ||||||
| 
 | 
 | ||||||
| 		// Check that poll counts are all zero. | 		// Check that poll counts are all zero. | ||||||
| 		suite.Equal(*poll.Voters, 0) | 		suite.Equal(*poll.Voters, 0) | ||||||
| 		suite.Equal(poll.Votes, make([]int, len(poll.Options))) | 		suite.Equal(make([]int, len(poll.Options)), poll.Votes) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (suite *PollTestSuite) TestDeletePollVotesNoPoll() { | ||||||
|  | 	// Create a new context for this test. | ||||||
|  | 	ctx, cncl := context.WithCancel(context.Background()) | ||||||
|  | 	defer cncl() | ||||||
|  | 
 | ||||||
|  | 	// Try to delete votes of nonexistent poll. | ||||||
|  | 	nonPollID := "01HF6V4XWTSZWJ80JNPPDTD4DB" | ||||||
|  | 
 | ||||||
|  | 	err := suite.db.DeletePollVotes(ctx, nonPollID) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PollTestSuite) TestDeletePollVotesBy() { | ||||||
|  | 	ctx, cncl := context.WithCancel(context.Background()) | ||||||
|  | 	defer cncl() | ||||||
|  | 
 | ||||||
|  | 	for _, vote := range suite.testPollVotes { | ||||||
|  | 		// Fetch before version of pollBefore from database. | ||||||
|  | 		pollBefore, err := suite.db.GetPollByID(ctx, vote.PollID) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// Delete this poll vote. | ||||||
|  | 		err = suite.db.DeletePollVoteBy(ctx, vote.PollID, vote.AccountID) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// Fetch after version of poll from database. | ||||||
|  | 		pollAfter, err := suite.db.GetPollByID(ctx, vote.PollID) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// Voters count should be reduced by 1. | ||||||
|  | 		suite.Equal(*pollBefore.Voters-1, *pollAfter.Voters) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PollTestSuite) TestDeletePollVotesByNoAccount() { | ||||||
|  | 	ctx, cncl := context.WithCancel(context.Background()) | ||||||
|  | 	defer cncl() | ||||||
|  | 
 | ||||||
|  | 	// Try to delete a poll by nonexisting account. | ||||||
|  | 	pollID := suite.testPolls["local_account_1_status_6_poll"].ID | ||||||
|  | 	nonAccountID := "01HF6T545G1G8ZNMY1S3ZXJ608" | ||||||
|  | 
 | ||||||
|  | 	err := suite.db.DeletePollVoteBy(ctx, pollID, nonAccountID) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestPollTestSuite(t *testing.T) { | func TestPollTestSuite(t *testing.T) { | ||||||
| 	suite.Run(t, new(PollTestSuite)) | 	suite.Run(t, new(PollTestSuite)) | ||||||
| } | } | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue