mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-30 21:12:24 -05:00 
			
		
		
		
	[feature] add support for polls + receiving federated status edits (#2330)
This commit is contained in:
		
					parent
					
						
							
								7204ccedc3
							
						
					
				
			
			
				commit
				
					
						e9e5dc5a40
					
				
			
		
					 84 changed files with 3992 additions and 570 deletions
				
			
		|  | @ -200,6 +200,11 @@ var Start action.GTSAction = func(ctx context.Context) error { | ||||||
| 	state.Workers.ProcessFromClientAPI = processor.Workers().ProcessFromClientAPI | 	state.Workers.ProcessFromClientAPI = processor.Workers().ProcessFromClientAPI | ||||||
| 	state.Workers.ProcessFromFediAPI = processor.Workers().ProcessFromFediAPI | 	state.Workers.ProcessFromFediAPI = processor.Workers().ProcessFromFediAPI | ||||||
| 
 | 
 | ||||||
|  | 	// Schedule tasks for all existing poll expiries. | ||||||
|  | 	if err := processor.Polls().ScheduleAll(ctx); err != nil { | ||||||
|  | 		return fmt.Errorf("error scheduling poll expiries: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	/* | 	/* | ||||||
| 		HTTP router initialization | 		HTTP router initialization | ||||||
| 	*/ | 	*/ | ||||||
|  |  | ||||||
|  | @ -1416,6 +1416,10 @@ definitions: | ||||||
|                 description: 'Statistics about the instance: number of posts, accounts, etc.' |                 description: 'Statistics about the instance: number of posts, accounts, etc.' | ||||||
|                 type: object |                 type: object | ||||||
|                 x-go-name: Stats |                 x-go-name: Stats | ||||||
|  |             terms: | ||||||
|  |                 description: Terms and conditions for accounts on this instance. | ||||||
|  |                 type: string | ||||||
|  |                 x-go-name: Terms | ||||||
|             thumbnail: |             thumbnail: | ||||||
|                 description: URL of the instance avatar/banner image. |                 description: URL of the instance avatar/banner image. | ||||||
|                 example: https://example.org/files/instance/thumbnail.jpeg |                 example: https://example.org/files/instance/thumbnail.jpeg | ||||||
|  | @ -1533,6 +1537,10 @@ definitions: | ||||||
|                 example: https://github.com/superseriousbusiness/gotosocial |                 example: https://github.com/superseriousbusiness/gotosocial | ||||||
|                 type: string |                 type: string | ||||||
|                 x-go-name: SourceURL |                 x-go-name: SourceURL | ||||||
|  |             terms: | ||||||
|  |                 description: Terms and conditions for accounts on this instance. | ||||||
|  |                 type: string | ||||||
|  |                 x-go-name: Terms | ||||||
|             thumbnail: |             thumbnail: | ||||||
|                 $ref: '#/definitions/instanceV2Thumbnail' |                 $ref: '#/definitions/instanceV2Thumbnail' | ||||||
|             title: |             title: | ||||||
|  | @ -1993,7 +2001,7 @@ definitions: | ||||||
|                 type: boolean |                 type: boolean | ||||||
|                 x-go-name: Expired |                 x-go-name: Expired | ||||||
|             expires_at: |             expires_at: | ||||||
|                 description: When the poll ends. (ISO 8601 Datetime), or null if the poll does not end |                 description: When the poll ends. (ISO 8601 Datetime). | ||||||
|                 type: string |                 type: string | ||||||
|                 x-go-name: ExpiresAt |                 x-go-name: ExpiresAt | ||||||
|             id: |             id: | ||||||
|  | @ -2008,7 +2016,7 @@ definitions: | ||||||
|             options: |             options: | ||||||
|                 description: Possible answers for the poll. |                 description: Possible answers for the poll. | ||||||
|                 items: |                 items: | ||||||
|                     $ref: '#/definitions/pollOptions' |                     $ref: '#/definitions/pollOption' | ||||||
|                 type: array |                 type: array | ||||||
|                 x-go-name: Options |                 x-go-name: Options | ||||||
|             own_votes: |             own_votes: | ||||||
|  | @ -2023,7 +2031,7 @@ definitions: | ||||||
|                 type: boolean |                 type: boolean | ||||||
|                 x-go-name: Voted |                 x-go-name: Voted | ||||||
|             voters_count: |             voters_count: | ||||||
|                 description: How many unique accounts have voted on a multiple-choice poll. Null if multiple is false. |                 description: How many unique accounts have voted on a multiple-choice poll. | ||||||
|                 format: int64 |                 format: int64 | ||||||
|                 type: integer |                 type: integer | ||||||
|                 x-go-name: VotersCount |                 x-go-name: VotersCount | ||||||
|  | @ -2036,22 +2044,20 @@ definitions: | ||||||
|         type: object |         type: object | ||||||
|         x-go-name: Poll |         x-go-name: Poll | ||||||
|         x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model |         x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model | ||||||
|     pollOptions: |     pollOption: | ||||||
|         properties: |         properties: | ||||||
|             title: |             title: | ||||||
|                 description: The text value of the poll option. String. |                 description: The text value of the poll option. String. | ||||||
|                 type: string |                 type: string | ||||||
|                 x-go-name: Title |                 x-go-name: Title | ||||||
|             votes_count: |             votes_count: | ||||||
|                 description: |- |                 description: The number of received votes for this option. | ||||||
|                     The number of received votes for this option. |  | ||||||
|                     Number, or null if results are not published yet. |  | ||||||
|                 format: int64 |                 format: int64 | ||||||
|                 type: integer |                 type: integer | ||||||
|                 x-go-name: VotesCount |                 x-go-name: VotesCount | ||||||
|         title: PollOptions represents the current vote counts for different poll options. |         title: PollOption represents the current vote counts for different poll options. | ||||||
|         type: object |         type: object | ||||||
|         x-go-name: PollOptions |         x-go-name: PollOption | ||||||
|         x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model |         x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model | ||||||
|     report: |     report: | ||||||
|         properties: |         properties: | ||||||
|  | @ -5986,6 +5992,76 @@ paths: | ||||||
|             summary: Clear/delete all notifications for currently authorized user. |             summary: Clear/delete all notifications for currently authorized user. | ||||||
|             tags: |             tags: | ||||||
|                 - notifications |                 - notifications | ||||||
|  |     /api/v1/polls/{id}: | ||||||
|  |         get: | ||||||
|  |             operationId: poll | ||||||
|  |             parameters: | ||||||
|  |                 - description: Target poll ID. | ||||||
|  |                   in: path | ||||||
|  |                   name: id | ||||||
|  |                   required: true | ||||||
|  |                   type: string | ||||||
|  |             produces: | ||||||
|  |                 - application/json | ||||||
|  |             responses: | ||||||
|  |                 "200": | ||||||
|  |                     description: The requested poll. | ||||||
|  |                     schema: | ||||||
|  |                         $ref: '#/definitions/poll' | ||||||
|  |                 "400": | ||||||
|  |                     description: bad request | ||||||
|  |                 "401": | ||||||
|  |                     description: unauthorized | ||||||
|  |                 "403": | ||||||
|  |                     description: forbidden | ||||||
|  |                 "404": | ||||||
|  |                     description: not found | ||||||
|  |                 "406": | ||||||
|  |                     description: not acceptable | ||||||
|  |                 "500": | ||||||
|  |                     description: internal server error | ||||||
|  |             security: | ||||||
|  |                 - OAuth2 Bearer: | ||||||
|  |                     - read:statuses | ||||||
|  |             summary: View poll with given ID. | ||||||
|  |             tags: | ||||||
|  |                 - polls | ||||||
|  |     /api/v1/polls/{id}/vote: | ||||||
|  |         post: | ||||||
|  |             operationId: poll | ||||||
|  |             parameters: | ||||||
|  |                 - description: Target poll ID. | ||||||
|  |                   in: path | ||||||
|  |                   name: id | ||||||
|  |                   required: true | ||||||
|  |                   type: string | ||||||
|  |             produces: | ||||||
|  |                 - application/json | ||||||
|  |             responses: | ||||||
|  |                 "200": | ||||||
|  |                     description: The updated poll with user vote choices. | ||||||
|  |                     schema: | ||||||
|  |                         $ref: '#/definitions/poll' | ||||||
|  |                 "400": | ||||||
|  |                     description: bad request | ||||||
|  |                 "401": | ||||||
|  |                     description: unauthorized | ||||||
|  |                 "403": | ||||||
|  |                     description: forbidden | ||||||
|  |                 "404": | ||||||
|  |                     description: not found | ||||||
|  |                 "406": | ||||||
|  |                     description: not acceptable | ||||||
|  |                 "422": | ||||||
|  |                     description: unprocessable entity | ||||||
|  |                 "500": | ||||||
|  |                     description: internal server error | ||||||
|  |             security: | ||||||
|  |                 - OAuth2 Bearer: | ||||||
|  |                     - write:statuses | ||||||
|  |             summary: Vote with choices in the given poll. | ||||||
|  |             tags: | ||||||
|  |                 - polls | ||||||
|     /api/v1/preferences: |     /api/v1/preferences: | ||||||
|         get: |         get: | ||||||
|             description: |- |             description: |- | ||||||
|  |  | ||||||
|  | @ -22,6 +22,7 @@ import ( | ||||||
| 	"crypto/rsa" | 	"crypto/rsa" | ||||||
| 	"crypto/x509" | 	"crypto/x509" | ||||||
| 	"encoding/pem" | 	"encoding/pem" | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | @ -1112,6 +1113,91 @@ func ExtractSharedInbox(withEndpoints WithEndpoints) *url.URL { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ExtractPoll extracts a placeholder Poll from Pollable interface, with available options and flags populated. | ||||||
|  | func ExtractPoll(poll Pollable) (*gtsmodel.Poll, error) { | ||||||
|  | 	var closed time.Time | ||||||
|  | 
 | ||||||
|  | 	// Extract the options (votes if any) and 'multiple choice' flag. | ||||||
|  | 	options, votes, multi, err := ExtractPollOptions(poll) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Check if counts have been hidden from us. | ||||||
|  | 	hideCounts := len(options) != len(votes) | ||||||
|  | 	if hideCounts { | ||||||
|  | 
 | ||||||
|  | 		// Zero out all votes. | ||||||
|  | 		for i := range votes { | ||||||
|  | 			votes[i] = 0 | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Extract the poll end time. | ||||||
|  | 	endTime := GetEndTime(poll) | ||||||
|  | 	if endTime.IsZero() { | ||||||
|  | 		return nil, errors.New("no poll end time specified") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Extract the poll closed time. | ||||||
|  | 	closedSlice := GetClosed(poll) | ||||||
|  | 	if len(closedSlice) == 1 { | ||||||
|  | 		closed = closedSlice[0] | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Extract the number of voters. | ||||||
|  | 	voters := GetVotersCount(poll) | ||||||
|  | 
 | ||||||
|  | 	return >smodel.Poll{ | ||||||
|  | 		Options:    options, | ||||||
|  | 		Multiple:   &multi, | ||||||
|  | 		HideCounts: &hideCounts, | ||||||
|  | 		Votes:      votes, | ||||||
|  | 		Voters:     &voters, | ||||||
|  | 		ExpiresAt:  endTime, | ||||||
|  | 		ClosedAt:   closed, | ||||||
|  | 	}, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ExtractPollOptions extracts poll option name strings, and the 'multiple choice flag' property value from Pollable. | ||||||
|  | func ExtractPollOptions(poll Pollable) (names []string, votes []int, multi bool, err error) { | ||||||
|  | 	var errs gtserror.MultiError | ||||||
|  | 
 | ||||||
|  | 	// Iterate the oneOf property and gather poll single-choice options. | ||||||
|  | 	IterateOneOf(poll, func(iter vocab.ActivityStreamsOneOfPropertyIterator) { | ||||||
|  | 		name, count, err := extractPollOption(iter.GetType()) | ||||||
|  | 		if err != nil { | ||||||
|  | 			errs.Append(err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		names = append(names, name) | ||||||
|  | 		if count != nil { | ||||||
|  | 			votes = append(votes, *count) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	if len(names) > 0 || len(errs) > 0 { | ||||||
|  | 		return names, votes, false, errs.Combine() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Iterate the anyOf property and gather poll multi-choice options. | ||||||
|  | 	IterateAnyOf(poll, func(iter vocab.ActivityStreamsAnyOfPropertyIterator) { | ||||||
|  | 		name, count, err := extractPollOption(iter.GetType()) | ||||||
|  | 		if err != nil { | ||||||
|  | 			errs.Append(err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		names = append(names, name) | ||||||
|  | 		if count != nil { | ||||||
|  | 			votes = append(votes, *count) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	if len(names) > 0 || len(errs) > 0 { | ||||||
|  | 		return names, votes, true, errs.Combine() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil, nil, false, errors.New("poll without options") | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // IterateOneOf will attempt to extract oneOf property from given interface, and passes each iterated item to function. | // IterateOneOf will attempt to extract oneOf property from given interface, and passes each iterated item to function. | ||||||
| func IterateOneOf(withOneOf WithOneOf, foreach func(vocab.ActivityStreamsOneOfPropertyIterator)) { | func IterateOneOf(withOneOf WithOneOf, foreach func(vocab.ActivityStreamsOneOfPropertyIterator)) { | ||||||
| 	if foreach == nil { | 	if foreach == nil { | ||||||
|  | @ -1158,6 +1244,41 @@ func IterateAnyOf(withAnyOf WithAnyOf, foreach func(vocab.ActivityStreamsAnyOfPr | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // extractPollOption extracts a usable poll option name from vocab.Type, or error. | ||||||
|  | func extractPollOption(t vocab.Type) (name string, votes *int, err error) { | ||||||
|  | 	// Check fulfills PollOptionable type | ||||||
|  | 	// (this accounts for nil input type). | ||||||
|  | 	optionable, ok := t.(PollOptionable) | ||||||
|  | 	if !ok { | ||||||
|  | 		return "", nil, fmt.Errorf("incorrect option type: %T", t) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Extract PollOption from interface. | ||||||
|  | 	name = ExtractName(optionable) | ||||||
|  | 	if name == "" { | ||||||
|  | 		return "", nil, errors.New("empty option name") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Check PollOptionable for attached 'replies' property. | ||||||
|  | 	repliesProp := optionable.GetActivityStreamsReplies() | ||||||
|  | 	if repliesProp != nil { | ||||||
|  | 
 | ||||||
|  | 		// Get repliesProp as the AS collection type it should be. | ||||||
|  | 		collection := repliesProp.GetActivityStreamsCollection() | ||||||
|  | 		if collection != nil { | ||||||
|  | 
 | ||||||
|  | 			// Extract integer value from the collection 'totalItems' property. | ||||||
|  | 			totalItemsProp := collection.GetActivityStreamsTotalItems() | ||||||
|  | 			if totalItemsProp != nil { | ||||||
|  | 				i := totalItemsProp.Get() | ||||||
|  | 				votes = &i | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return name, votes, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // isPublic checks if at least one entry in the given | // isPublic checks if at least one entry in the given | ||||||
| // uris slice equals the activitystreams public uri. | // uris slice equals the activitystreams public uri. | ||||||
| func isPublic(uris []*url.URL) bool { | func isPublic(uris []*url.URL) bool { | ||||||
|  |  | ||||||
|  | @ -24,6 +24,7 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	"github.com/superseriousbusiness/activity/streams" | 	"github.com/superseriousbusiness/activity/streams" | ||||||
| 	"github.com/superseriousbusiness/activity/streams/vocab" | 	"github.com/superseriousbusiness/activity/streams/vocab" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // MustGet performs the given 'Get$Property(with) (T, error)' signature function, panicking on error. | // MustGet performs the given 'Get$Property(with) (T, error)' signature function, panicking on error. | ||||||
|  | @ -36,12 +37,12 @@ import ( | ||||||
| // } | // } | ||||||
| 
 | 
 | ||||||
| // MustSet performs the given 'Set$Property(with, T) error' signature function, panicking on error. | // MustSet performs the given 'Set$Property(with, T) error' signature function, panicking on error. | ||||||
| // func MustSet[W, T any](fn func(W, T) error, with W, value T) { | func MustSet[W, T any](fn func(W, T) error, with W, value T) { | ||||||
| // 	err := fn(with, value) | 	err := fn(with, value) | ||||||
| // 	if err != nil { | 	if err != nil { | ||||||
| // 		panicfAt(3, "error setting property on %T: %w", with, err) | 		panicfAt(3, "error setting property on %T: %w", with, err) | ||||||
| // 	} | 	} | ||||||
| // } | } | ||||||
| 
 | 
 | ||||||
| // AppendSet performs the given 'Append$Property(with, ...T) error' signature function, panicking on error. | // AppendSet performs the given 'Append$Property(with, ...T) error' signature function, panicking on error. | ||||||
| // func MustAppend[W, T any](fn func(W, ...T) error, with W, values ...T) { | // func MustAppend[W, T any](fn func(W, ...T) error, with W, values ...T) { | ||||||
|  | @ -320,6 +321,6 @@ func appendIRIs[T TypeOrIRI](getProp func() Property[T], iri ...*url.URL) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // panicfAt panics with a call to gtserror.NewfAt() with given args (+1 to calldepth). | // panicfAt panics with a call to gtserror.NewfAt() with given args (+1 to calldepth). | ||||||
| // func panicfAt(calldepth int, msg string, args ...any) { | func panicfAt(calldepth int, msg string, args ...any) { | ||||||
| // 	panic(gtserror.NewfAt(calldepth+1, msg, args...)) | 	panic(gtserror.NewfAt(calldepth+1, msg, args...)) | ||||||
| // } | } | ||||||
|  |  | ||||||
|  | @ -25,6 +25,7 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"testing" | 	"testing" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/stretchr/testify/suite" | 	"github.com/stretchr/testify/suite" | ||||||
|  | @ -101,12 +102,6 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() { | ||||||
| 	signedRequest := derefRequests["foss_satan_dereference_zork_outbox_first"] | 	signedRequest := derefRequests["foss_satan_dereference_zork_outbox_first"] | ||||||
| 	targetAccount := suite.testAccounts["local_account_1"] | 	targetAccount := suite.testAccounts["local_account_1"] | ||||||
| 
 | 
 | ||||||
| 	tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) |  | ||||||
| 	federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager) |  | ||||||
| 	emailSender := testrig.NewEmailSender("../../../../web/template/", nil) |  | ||||||
| 	processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) |  | ||||||
| 	userModule := users.New(processor) |  | ||||||
| 
 |  | ||||||
| 	// setup request | 	// setup request | ||||||
| 	recorder := httptest.NewRecorder() | 	recorder := httptest.NewRecorder() | ||||||
| 	ctx, _ := testrig.CreateGinTestContext(recorder, nil) | 	ctx, _ := testrig.CreateGinTestContext(recorder, nil) | ||||||
|  | @ -128,7 +123,7 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// trigger the function being tested | 	// trigger the function being tested | ||||||
| 	userModule.OutboxGETHandler(ctx) | 	suite.userModule.OutboxGETHandler(ctx) | ||||||
| 
 | 
 | ||||||
| 	// check response | 	// check response | ||||||
| 	suite.EqualValues(http.StatusOK, recorder.Code) | 	suite.EqualValues(http.StatusOK, recorder.Code) | ||||||
|  | @ -137,6 +132,7 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() { | ||||||
| 	defer result.Body.Close() | 	defer result.Body.Close() | ||||||
| 	b, err := ioutil.ReadAll(result.Body) | 	b, err := ioutil.ReadAll(result.Body) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
|  | 	b = checkDropPublished(suite.T(), b, "orderedItems") | ||||||
| 	dst := new(bytes.Buffer) | 	dst := new(bytes.Buffer) | ||||||
| 	err = json.Indent(dst, b, "", "  ") | 	err = json.Indent(dst, b, "", "  ") | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
|  | @ -147,9 +143,8 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() { | ||||||
|   "orderedItems": { |   "orderedItems": { | ||||||
|     "actor": "http://localhost:8080/users/the_mighty_zork", |     "actor": "http://localhost:8080/users/the_mighty_zork", | ||||||
|     "cc": "http://localhost:8080/users/the_mighty_zork/followers", |     "cc": "http://localhost:8080/users/the_mighty_zork/followers", | ||||||
|     "id": "http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/activity", |     "id": "http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/activity#Create", | ||||||
|     "object": "http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY", |     "object": "http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY", | ||||||
|     "published": "2021-10-20T10:40:37Z", |  | ||||||
|     "to": "https://www.w3.org/ns/activitystreams#Public", |     "to": "https://www.w3.org/ns/activitystreams#Public", | ||||||
|     "type": "Create" |     "type": "Create" | ||||||
|   }, |   }, | ||||||
|  | @ -175,12 +170,6 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() { | ||||||
| 	signedRequest := derefRequests["foss_satan_dereference_zork_outbox_next"] | 	signedRequest := derefRequests["foss_satan_dereference_zork_outbox_next"] | ||||||
| 	targetAccount := suite.testAccounts["local_account_1"] | 	targetAccount := suite.testAccounts["local_account_1"] | ||||||
| 
 | 
 | ||||||
| 	tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) |  | ||||||
| 	federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager) |  | ||||||
| 	emailSender := testrig.NewEmailSender("../../../../web/template/", nil) |  | ||||||
| 	processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) |  | ||||||
| 	userModule := users.New(processor) |  | ||||||
| 
 |  | ||||||
| 	// setup request | 	// setup request | ||||||
| 	recorder := httptest.NewRecorder() | 	recorder := httptest.NewRecorder() | ||||||
| 	ctx, _ := testrig.CreateGinTestContext(recorder, nil) | 	ctx, _ := testrig.CreateGinTestContext(recorder, nil) | ||||||
|  | @ -206,7 +195,7 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// trigger the function being tested | 	// trigger the function being tested | ||||||
| 	userModule.OutboxGETHandler(ctx) | 	suite.userModule.OutboxGETHandler(ctx) | ||||||
| 
 | 
 | ||||||
| 	// check response | 	// check response | ||||||
| 	suite.EqualValues(http.StatusOK, recorder.Code) | 	suite.EqualValues(http.StatusOK, recorder.Code) | ||||||
|  | @ -240,3 +229,30 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() { | ||||||
| func TestOutboxGetTestSuite(t *testing.T) { | func TestOutboxGetTestSuite(t *testing.T) { | ||||||
| 	suite.Run(t, new(OutboxGetTestSuite)) | 	suite.Run(t, new(OutboxGetTestSuite)) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // checkDropPublished checks the published field at given key position for formatting, and drops from the JSON. | ||||||
|  | // This is useful because the published property is usually set to the current time string (which is difficult to test). | ||||||
|  | func checkDropPublished(t *testing.T, b []byte, at ...string) []byte { | ||||||
|  | 	m := make(map[string]any) | ||||||
|  | 	if err := json.Unmarshal(b, &m); err != nil { | ||||||
|  | 		t.Fatalf("error unmarshaling json into map: %v", err) | ||||||
|  | 	} | ||||||
|  | 	mm := m | ||||||
|  | 	for _, key := range at { | ||||||
|  | 		switch vt := mm[key].(type) { | ||||||
|  | 		case map[string]any: | ||||||
|  | 			mm = vt | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if s, ok := mm["published"].(string); !ok { | ||||||
|  | 		t.Fatal("missing published data on json") | ||||||
|  | 	} else if _, err := time.Parse(time.RFC3339, s); err != nil { | ||||||
|  | 		t.Fatalf("error parsing published time: %v", err) | ||||||
|  | 	} | ||||||
|  | 	delete(mm, "published") | ||||||
|  | 	b, err := json.Marshal(m) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("error remarshaling json: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return b | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -101,12 +101,6 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() { | ||||||
| 	targetAccount := suite.testAccounts["local_account_1"] | 	targetAccount := suite.testAccounts["local_account_1"] | ||||||
| 	targetStatus := suite.testStatuses["local_account_1_status_1"] | 	targetStatus := suite.testStatuses["local_account_1_status_1"] | ||||||
| 
 | 
 | ||||||
| 	tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) |  | ||||||
| 	federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager) |  | ||||||
| 	emailSender := testrig.NewEmailSender("../../../../web/template/", nil) |  | ||||||
| 	processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) |  | ||||||
| 	userModule := users.New(processor) |  | ||||||
| 
 |  | ||||||
| 	// setup request | 	// setup request | ||||||
| 	recorder := httptest.NewRecorder() | 	recorder := httptest.NewRecorder() | ||||||
| 	ctx, _ := testrig.CreateGinTestContext(recorder, nil) | 	ctx, _ := testrig.CreateGinTestContext(recorder, nil) | ||||||
|  | @ -132,7 +126,7 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// trigger the function being tested | 	// trigger the function being tested | ||||||
| 	userModule.StatusRepliesGETHandler(ctx) | 	suite.userModule.StatusRepliesGETHandler(ctx) | ||||||
| 
 | 
 | ||||||
| 	// check response | 	// check response | ||||||
| 	suite.EqualValues(http.StatusOK, recorder.Code) | 	suite.EqualValues(http.StatusOK, recorder.Code) | ||||||
|  | @ -165,12 +159,6 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() { | ||||||
| 	targetAccount := suite.testAccounts["local_account_1"] | 	targetAccount := suite.testAccounts["local_account_1"] | ||||||
| 	targetStatus := suite.testStatuses["local_account_1_status_1"] | 	targetStatus := suite.testStatuses["local_account_1_status_1"] | ||||||
| 
 | 
 | ||||||
| 	tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) |  | ||||||
| 	federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager) |  | ||||||
| 	emailSender := testrig.NewEmailSender("../../../../web/template/", nil) |  | ||||||
| 	processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) |  | ||||||
| 	userModule := users.New(processor) |  | ||||||
| 
 |  | ||||||
| 	// setup request | 	// setup request | ||||||
| 	recorder := httptest.NewRecorder() | 	recorder := httptest.NewRecorder() | ||||||
| 	ctx, _ := testrig.CreateGinTestContext(recorder, nil) | 	ctx, _ := testrig.CreateGinTestContext(recorder, nil) | ||||||
|  | @ -196,7 +184,7 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// trigger the function being tested | 	// trigger the function being tested | ||||||
| 	userModule.StatusRepliesGETHandler(ctx) | 	suite.userModule.StatusRepliesGETHandler(ctx) | ||||||
| 
 | 
 | ||||||
| 	// check response | 	// check response | ||||||
| 	suite.EqualValues(http.StatusOK, recorder.Code) | 	suite.EqualValues(http.StatusOK, recorder.Code) | ||||||
|  |  | ||||||
|  | @ -91,11 +91,14 @@ func (suite *AuthStandardTestSuite) SetupTest() { | ||||||
| 	suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil) | 	suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil) | ||||||
| 	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) | 	suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) | ||||||
| 	suite.authModule = auth.New(suite.db, suite.processor, suite.idp) | 	suite.authModule = auth.New(suite.db, suite.processor, suite.idp) | ||||||
|  | 
 | ||||||
| 	testrig.StandardDBSetup(suite.db, suite.testAccounts) | 	testrig.StandardDBSetup(suite.db, suite.testAccounts) | ||||||
|  | 	testrig.StartWorkers(&suite.state) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *AuthStandardTestSuite) TearDownTest() { | func (suite *AuthStandardTestSuite) TearDownTest() { | ||||||
| 	testrig.StandardDBTeardown(suite.db) | 	testrig.StandardDBTeardown(suite.db) | ||||||
|  | 	testrig.StopWorkers(&suite.state) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *AuthStandardTestSuite) newContext(requestMethod string, requestPath string, requestBody []byte, bodyContentType string) (*gin.Context, *httptest.ResponseRecorder) { | func (suite *AuthStandardTestSuite) newContext(requestMethod string, requestPath string, requestBody []byte, bodyContentType string) (*gin.Context, *httptest.ResponseRecorder) { | ||||||
|  |  | ||||||
|  | @ -36,6 +36,7 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/api/client/markers" | 	"github.com/superseriousbusiness/gotosocial/internal/api/client/markers" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/api/client/media" | 	"github.com/superseriousbusiness/gotosocial/internal/api/client/media" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/api/client/notifications" | 	"github.com/superseriousbusiness/gotosocial/internal/api/client/notifications" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/api/client/polls" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/api/client/preferences" | 	"github.com/superseriousbusiness/gotosocial/internal/api/client/preferences" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/api/client/reports" | 	"github.com/superseriousbusiness/gotosocial/internal/api/client/reports" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/api/client/search" | 	"github.com/superseriousbusiness/gotosocial/internal/api/client/search" | ||||||
|  | @ -68,6 +69,7 @@ type Client struct { | ||||||
| 	markers        *markers.Module        // api/v1/markers | 	markers        *markers.Module        // api/v1/markers | ||||||
| 	media          *media.Module          // api/v1/media, api/v2/media | 	media          *media.Module          // api/v1/media, api/v2/media | ||||||
| 	notifications  *notifications.Module  // api/v1/notifications | 	notifications  *notifications.Module  // api/v1/notifications | ||||||
|  | 	polls          *polls.Module          // api/v1/polls | ||||||
| 	preferences    *preferences.Module    // api/v1/preferences | 	preferences    *preferences.Module    // api/v1/preferences | ||||||
| 	reports        *reports.Module        // api/v1/reports | 	reports        *reports.Module        // api/v1/reports | ||||||
| 	search         *search.Module         // api/v1/search, api/v2/search | 	search         *search.Module         // api/v1/search, api/v2/search | ||||||
|  | @ -109,6 +111,7 @@ func (c *Client) Route(r router.Router, m ...gin.HandlerFunc) { | ||||||
| 	c.markers.Route(h) | 	c.markers.Route(h) | ||||||
| 	c.media.Route(h) | 	c.media.Route(h) | ||||||
| 	c.notifications.Route(h) | 	c.notifications.Route(h) | ||||||
|  | 	c.polls.Route(h) | ||||||
| 	c.preferences.Route(h) | 	c.preferences.Route(h) | ||||||
| 	c.reports.Route(h) | 	c.reports.Route(h) | ||||||
| 	c.search.Route(h) | 	c.search.Route(h) | ||||||
|  | @ -138,6 +141,7 @@ func NewClient(db db.DB, p *processing.Processor) *Client { | ||||||
| 		markers:        markers.New(p), | 		markers:        markers.New(p), | ||||||
| 		media:          media.New(p), | 		media:          media.New(p), | ||||||
| 		notifications:  notifications.New(p), | 		notifications:  notifications.New(p), | ||||||
|  | 		polls:          polls.New(p), | ||||||
| 		preferences:    preferences.New(p), | 		preferences:    preferences.New(p), | ||||||
| 		reports:        reports.New(p), | 		reports:        reports.New(p), | ||||||
| 		search:         search.New(p), | 		search:         search.New(p), | ||||||
|  |  | ||||||
|  | @ -79,7 +79,7 @@ func (suite *AccountVerifyTestSuite) TestAccountVerifyGet() { | ||||||
| 	suite.Equal("http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg", apimodelAccount.HeaderStatic) | 	suite.Equal("http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg", apimodelAccount.HeaderStatic) | ||||||
| 	suite.Equal(2, apimodelAccount.FollowersCount) | 	suite.Equal(2, apimodelAccount.FollowersCount) | ||||||
| 	suite.Equal(2, apimodelAccount.FollowingCount) | 	suite.Equal(2, apimodelAccount.FollowingCount) | ||||||
| 	suite.Equal(5, apimodelAccount.StatusesCount) | 	suite.Equal(6, apimodelAccount.StatusesCount) | ||||||
| 	suite.EqualValues(gtsmodel.VisibilityPublic, apimodelAccount.Source.Privacy) | 	suite.EqualValues(gtsmodel.VisibilityPublic, apimodelAccount.Source.Privacy) | ||||||
| 	suite.Equal(testAccount.Language, apimodelAccount.Source.Language) | 	suite.Equal(testAccount.Language, apimodelAccount.Source.Language) | ||||||
| 	suite.Equal(testAccount.NoteRaw, apimodelAccount.Source.Note) | 	suite.Equal(testAccount.NoteRaw, apimodelAccount.Source.Note) | ||||||
|  |  | ||||||
|  | @ -180,8 +180,8 @@ func (suite *ReportsGetTestSuite) TestReportsGetAll() { | ||||||
|         "header_static": "http://localhost:8080/assets/default_header.png", |         "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|         "followers_count": 0, |         "followers_count": 0, | ||||||
|         "following_count": 0, |         "following_count": 0, | ||||||
|         "statuses_count": 1, |         "statuses_count": 2, | ||||||
|         "last_status_at": "2021-09-20T10:40:37.000Z", |         "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|         "emojis": [], |         "emojis": [], | ||||||
|         "fields": [] |         "fields": [] | ||||||
|       } |       } | ||||||
|  | @ -221,8 +221,8 @@ func (suite *ReportsGetTestSuite) TestReportsGetAll() { | ||||||
|         "header_static": "http://localhost:8080/assets/default_header.png", |         "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|         "followers_count": 1, |         "followers_count": 1, | ||||||
|         "following_count": 1, |         "following_count": 1, | ||||||
|         "statuses_count": 7, |         "statuses_count": 8, | ||||||
|         "last_status_at": "2021-10-20T10:40:37.000Z", |         "last_status_at": "2021-07-28T08:40:37.000Z", | ||||||
|         "emojis": [], |         "emojis": [], | ||||||
|         "fields": [ |         "fields": [ | ||||||
|           { |           { | ||||||
|  | @ -382,8 +382,8 @@ func (suite *ReportsGetTestSuite) TestReportsGetAll() { | ||||||
|         "header_static": "http://localhost:8080/assets/default_header.png", |         "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|         "followers_count": 1, |         "followers_count": 1, | ||||||
|         "following_count": 1, |         "following_count": 1, | ||||||
|         "statuses_count": 7, |         "statuses_count": 8, | ||||||
|         "last_status_at": "2021-10-20T10:40:37.000Z", |         "last_status_at": "2021-07-28T08:40:37.000Z", | ||||||
|         "emojis": [], |         "emojis": [], | ||||||
|         "fields": [ |         "fields": [ | ||||||
|           { |           { | ||||||
|  | @ -438,8 +438,8 @@ func (suite *ReportsGetTestSuite) TestReportsGetAll() { | ||||||
|         "header_static": "http://localhost:8080/assets/default_header.png", |         "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|         "followers_count": 0, |         "followers_count": 0, | ||||||
|         "following_count": 0, |         "following_count": 0, | ||||||
|         "statuses_count": 1, |         "statuses_count": 2, | ||||||
|         "last_status_at": "2021-09-20T10:40:37.000Z", |         "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|         "emojis": [], |         "emojis": [], | ||||||
|         "fields": [] |         "fields": [] | ||||||
|       } |       } | ||||||
|  | @ -485,8 +485,8 @@ func (suite *ReportsGetTestSuite) TestReportsGetAll() { | ||||||
|           "header_static": "http://localhost:8080/assets/default_header.png", |           "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|           "followers_count": 0, |           "followers_count": 0, | ||||||
|           "following_count": 0, |           "following_count": 0, | ||||||
|           "statuses_count": 1, |           "statuses_count": 2, | ||||||
|           "last_status_at": "2021-09-20T10:40:37.000Z", |           "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|           "emojis": [], |           "emojis": [], | ||||||
|           "fields": [] |           "fields": [] | ||||||
|         }, |         }, | ||||||
|  | @ -603,8 +603,8 @@ func (suite *ReportsGetTestSuite) TestReportsGetCreatedByAccount() { | ||||||
|         "header_static": "http://localhost:8080/assets/default_header.png", |         "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|         "followers_count": 1, |         "followers_count": 1, | ||||||
|         "following_count": 1, |         "following_count": 1, | ||||||
|         "statuses_count": 7, |         "statuses_count": 8, | ||||||
|         "last_status_at": "2021-10-20T10:40:37.000Z", |         "last_status_at": "2021-07-28T08:40:37.000Z", | ||||||
|         "emojis": [], |         "emojis": [], | ||||||
|         "fields": [ |         "fields": [ | ||||||
|           { |           { | ||||||
|  | @ -659,8 +659,8 @@ func (suite *ReportsGetTestSuite) TestReportsGetCreatedByAccount() { | ||||||
|         "header_static": "http://localhost:8080/assets/default_header.png", |         "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|         "followers_count": 0, |         "followers_count": 0, | ||||||
|         "following_count": 0, |         "following_count": 0, | ||||||
|         "statuses_count": 1, |         "statuses_count": 2, | ||||||
|         "last_status_at": "2021-09-20T10:40:37.000Z", |         "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|         "emojis": [], |         "emojis": [], | ||||||
|         "fields": [] |         "fields": [] | ||||||
|       } |       } | ||||||
|  | @ -706,8 +706,8 @@ func (suite *ReportsGetTestSuite) TestReportsGetCreatedByAccount() { | ||||||
|           "header_static": "http://localhost:8080/assets/default_header.png", |           "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|           "followers_count": 0, |           "followers_count": 0, | ||||||
|           "following_count": 0, |           "following_count": 0, | ||||||
|           "statuses_count": 1, |           "statuses_count": 2, | ||||||
|           "last_status_at": "2021-09-20T10:40:37.000Z", |           "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|           "emojis": [], |           "emojis": [], | ||||||
|           "fields": [] |           "fields": [] | ||||||
|         }, |         }, | ||||||
|  | @ -824,8 +824,8 @@ func (suite *ReportsGetTestSuite) TestReportsGetTargetAccount() { | ||||||
|         "header_static": "http://localhost:8080/assets/default_header.png", |         "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|         "followers_count": 1, |         "followers_count": 1, | ||||||
|         "following_count": 1, |         "following_count": 1, | ||||||
|         "statuses_count": 7, |         "statuses_count": 8, | ||||||
|         "last_status_at": "2021-10-20T10:40:37.000Z", |         "last_status_at": "2021-07-28T08:40:37.000Z", | ||||||
|         "emojis": [], |         "emojis": [], | ||||||
|         "fields": [ |         "fields": [ | ||||||
|           { |           { | ||||||
|  | @ -880,8 +880,8 @@ func (suite *ReportsGetTestSuite) TestReportsGetTargetAccount() { | ||||||
|         "header_static": "http://localhost:8080/assets/default_header.png", |         "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|         "followers_count": 0, |         "followers_count": 0, | ||||||
|         "following_count": 0, |         "following_count": 0, | ||||||
|         "statuses_count": 1, |         "statuses_count": 2, | ||||||
|         "last_status_at": "2021-09-20T10:40:37.000Z", |         "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|         "emojis": [], |         "emojis": [], | ||||||
|         "fields": [] |         "fields": [] | ||||||
|       } |       } | ||||||
|  | @ -927,8 +927,8 @@ func (suite *ReportsGetTestSuite) TestReportsGetTargetAccount() { | ||||||
|           "header_static": "http://localhost:8080/assets/default_header.png", |           "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|           "followers_count": 0, |           "followers_count": 0, | ||||||
|           "following_count": 0, |           "following_count": 0, | ||||||
|           "statuses_count": 1, |           "statuses_count": 2, | ||||||
|           "last_status_at": "2021-09-20T10:40:37.000Z", |           "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|           "emojis": [], |           "emojis": [], | ||||||
|           "fields": [] |           "fields": [] | ||||||
|         }, |         }, | ||||||
|  |  | ||||||
|  | @ -130,7 +130,7 @@ func (suite *InstancePatchTestSuite) TestInstancePatch1() { | ||||||
|   }, |   }, | ||||||
|   "stats": { |   "stats": { | ||||||
|     "domain_count": 2, |     "domain_count": 2, | ||||||
|     "status_count": 16, |     "status_count": 18, | ||||||
|     "user_count": 4 |     "user_count": 4 | ||||||
|   }, |   }, | ||||||
|   "thumbnail": "http://localhost:8080/assets/logo.png", |   "thumbnail": "http://localhost:8080/assets/logo.png", | ||||||
|  | @ -244,7 +244,7 @@ func (suite *InstancePatchTestSuite) TestInstancePatch2() { | ||||||
|   }, |   }, | ||||||
|   "stats": { |   "stats": { | ||||||
|     "domain_count": 2, |     "domain_count": 2, | ||||||
|     "status_count": 16, |     "status_count": 18, | ||||||
|     "user_count": 4 |     "user_count": 4 | ||||||
|   }, |   }, | ||||||
|   "thumbnail": "http://localhost:8080/assets/logo.png", |   "thumbnail": "http://localhost:8080/assets/logo.png", | ||||||
|  | @ -358,7 +358,7 @@ func (suite *InstancePatchTestSuite) TestInstancePatch3() { | ||||||
|   }, |   }, | ||||||
|   "stats": { |   "stats": { | ||||||
|     "domain_count": 2, |     "domain_count": 2, | ||||||
|     "status_count": 16, |     "status_count": 18, | ||||||
|     "user_count": 4 |     "user_count": 4 | ||||||
|   }, |   }, | ||||||
|   "thumbnail": "http://localhost:8080/assets/logo.png", |   "thumbnail": "http://localhost:8080/assets/logo.png", | ||||||
|  | @ -523,7 +523,7 @@ func (suite *InstancePatchTestSuite) TestInstancePatch6() { | ||||||
|   }, |   }, | ||||||
|   "stats": { |   "stats": { | ||||||
|     "domain_count": 2, |     "domain_count": 2, | ||||||
|     "status_count": 16, |     "status_count": 18, | ||||||
|     "user_count": 4 |     "user_count": 4 | ||||||
|   }, |   }, | ||||||
|   "thumbnail": "http://localhost:8080/assets/logo.png", |   "thumbnail": "http://localhost:8080/assets/logo.png", | ||||||
|  | @ -659,7 +659,7 @@ func (suite *InstancePatchTestSuite) TestInstancePatch8() { | ||||||
|   }, |   }, | ||||||
|   "stats": { |   "stats": { | ||||||
|     "domain_count": 2, |     "domain_count": 2, | ||||||
|     "status_count": 16, |     "status_count": 18, | ||||||
|     "user_count": 4 |     "user_count": 4 | ||||||
|   }, |   }, | ||||||
|   "thumbnail": "http://localhost:8080/fileserver/01AY6P665V14JJR0AFVRT7311Y/attachment/original/`+instanceAccount.AvatarMediaAttachment.ID+`.gif",`+` |   "thumbnail": "http://localhost:8080/fileserver/01AY6P665V14JJR0AFVRT7311Y/attachment/original/`+instanceAccount.AvatarMediaAttachment.ID+`.gif",`+` | ||||||
|  | @ -810,7 +810,7 @@ func (suite *InstancePatchTestSuite) TestInstancePatch9() { | ||||||
|   }, |   }, | ||||||
|   "stats": { |   "stats": { | ||||||
|     "domain_count": 2, |     "domain_count": 2, | ||||||
|     "status_count": 16, |     "status_count": 18, | ||||||
|     "user_count": 4 |     "user_count": 4 | ||||||
|   }, |   }, | ||||||
|   "thumbnail": "http://localhost:8080/assets/logo.png", |   "thumbnail": "http://localhost:8080/assets/logo.png", | ||||||
|  |  | ||||||
							
								
								
									
										48
									
								
								internal/api/client/polls/polls.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								internal/api/client/polls/polls.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,48 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package polls | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net/http" | ||||||
|  | 
 | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/api/util" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/processing" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	IDKey           = "id"                                 // IDKey is the key for poll IDs | ||||||
|  | 	BasePath        = "/:" + util.APIVersionKey + "/polls" // BasePath is the base API path for making poll requests through v1 or v2 of the api (for mastodon API compatibility) | ||||||
|  | 	PollWithID      = BasePath + "/:" + IDKey              // | ||||||
|  | 	PollVotesWithID = BasePath + "/:" + IDKey + "/votes"   // | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Module struct { | ||||||
|  | 	processor *processing.Processor | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func New(processor *processing.Processor) *Module { | ||||||
|  | 	return &Module{ | ||||||
|  | 		processor: processor, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) { | ||||||
|  | 	attachHandler(http.MethodGet, PollWithID, m.PollGETHandler) | ||||||
|  | 	attachHandler(http.MethodPost, PollVotesWithID, m.PollVotePOSTHandler) | ||||||
|  | } | ||||||
							
								
								
									
										100
									
								
								internal/api/client/polls/polls_get.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								internal/api/client/polls/polls_get.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,100 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package polls | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net/http" | ||||||
|  | 
 | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // PollGETHandler swagger:operation GET /api/v1/polls/{id} poll | ||||||
|  | // | ||||||
|  | // View poll with given ID. | ||||||
|  | // | ||||||
|  | //	--- | ||||||
|  | //	tags: | ||||||
|  | //	- polls | ||||||
|  | // | ||||||
|  | //	produces: | ||||||
|  | //	- application/json | ||||||
|  | // | ||||||
|  | //	parameters: | ||||||
|  | //	- | ||||||
|  | //		name: id | ||||||
|  | //		type: string | ||||||
|  | //		description: Target poll ID. | ||||||
|  | //		in: path | ||||||
|  | //		required: true | ||||||
|  | // | ||||||
|  | //	security: | ||||||
|  | //	- OAuth2 Bearer: | ||||||
|  | //		- read:statuses | ||||||
|  | // | ||||||
|  | //	responses: | ||||||
|  | //		'200': | ||||||
|  | //			description: "The requested poll." | ||||||
|  | //			schema: | ||||||
|  | //				"$ref": "#/definitions/poll" | ||||||
|  | //		'400': | ||||||
|  | //			description: bad request | ||||||
|  | //		'401': | ||||||
|  | //			description: unauthorized | ||||||
|  | //		'403': | ||||||
|  | //			description: forbidden | ||||||
|  | //		'404': | ||||||
|  | //			description: not found | ||||||
|  | //		'406': | ||||||
|  | //			description: not acceptable | ||||||
|  | //		'500': | ||||||
|  | //			description: internal server error | ||||||
|  | func (m *Module) PollGETHandler(c *gin.Context) { | ||||||
|  | 	authed, err := oauth.Authed(c, true, true, true, true) | ||||||
|  | 	if err != nil { | ||||||
|  | 		errWithCode := gtserror.NewErrorUnauthorized(err, err.Error()) | ||||||
|  | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { | ||||||
|  | 		errWithCode := gtserror.NewErrorNotAcceptable(err, err.Error()) | ||||||
|  | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	pollID, errWithCode := apiutil.ParseID(c.Param(IDKey)) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	poll, errWithCode := m.processor.Polls().PollGet( | ||||||
|  | 		c.Request.Context(), | ||||||
|  | 		authed.Account, | ||||||
|  | 		pollID, | ||||||
|  | 	) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	c.JSON(http.StatusOK, poll) | ||||||
|  | } | ||||||
							
								
								
									
										112
									
								
								internal/api/client/polls/polls_vote.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										112
									
								
								internal/api/client/polls/polls_vote.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,112 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package polls | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net/http" | ||||||
|  | 
 | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
|  | 	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // PollVotePOSTHandler swagger:operation POST /api/v1/polls/{id}/vote poll | ||||||
|  | // | ||||||
|  | // Vote with choices in the given poll. | ||||||
|  | // | ||||||
|  | //	--- | ||||||
|  | //	tags: | ||||||
|  | //	- polls | ||||||
|  | // | ||||||
|  | //	produces: | ||||||
|  | //	- application/json | ||||||
|  | // | ||||||
|  | //	parameters: | ||||||
|  | //	- | ||||||
|  | //		name: id | ||||||
|  | //		type: string | ||||||
|  | //		description: Target poll ID. | ||||||
|  | //		in: path | ||||||
|  | //		required: true | ||||||
|  | // | ||||||
|  | //	security: | ||||||
|  | //	- OAuth2 Bearer: | ||||||
|  | //		- write:statuses | ||||||
|  | // | ||||||
|  | //	responses: | ||||||
|  | //		'200': | ||||||
|  | //			description: "The updated poll with user vote choices." | ||||||
|  | //			schema: | ||||||
|  | //				"$ref": "#/definitions/poll" | ||||||
|  | //		'400': | ||||||
|  | //			description: bad request | ||||||
|  | //		'401': | ||||||
|  | //			description: unauthorized | ||||||
|  | //		'403': | ||||||
|  | //			description: forbidden | ||||||
|  | //		'404': | ||||||
|  | //			description: not found | ||||||
|  | //		'406': | ||||||
|  | //			description: not acceptable | ||||||
|  | //		'422': | ||||||
|  | //			description: unprocessable entity | ||||||
|  | //		'500': | ||||||
|  | //			description: internal server error | ||||||
|  | func (m *Module) PollVotePOSTHandler(c *gin.Context) { | ||||||
|  | 	authed, err := oauth.Authed(c, true, true, true, true) | ||||||
|  | 	if err != nil { | ||||||
|  | 		errWithCode := gtserror.NewErrorUnauthorized(err, err.Error()) | ||||||
|  | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { | ||||||
|  | 		errWithCode := gtserror.NewErrorNotAcceptable(err, err.Error()) | ||||||
|  | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	pollID, errWithCode := apiutil.ParseID(c.Param(IDKey)) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var form apimodel.PollVoteRequest | ||||||
|  | 
 | ||||||
|  | 	if err := c.ShouldBind(&form); err != nil { | ||||||
|  | 		errWithCode := gtserror.NewErrorBadRequest(err, err.Error()) | ||||||
|  | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	poll, errWithCode := m.processor.Polls().PollVote( | ||||||
|  | 		c.Request.Context(), | ||||||
|  | 		authed.Account, | ||||||
|  | 		pollID, | ||||||
|  | 		form.Choices, | ||||||
|  | 	) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	c.JSON(http.StatusOK, poll) | ||||||
|  | } | ||||||
|  | @ -129,8 +129,8 @@ func (suite *ReportGetTestSuite) TestGetReport1() { | ||||||
|     "header_static": "http://localhost:8080/assets/default_header.png", |     "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|     "followers_count": 0, |     "followers_count": 0, | ||||||
|     "following_count": 0, |     "following_count": 0, | ||||||
|     "statuses_count": 1, |     "statuses_count": 2, | ||||||
|     "last_status_at": "2021-09-20T10:40:37.000Z", |     "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|     "emojis": [], |     "emojis": [], | ||||||
|     "fields": [] |     "fields": [] | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  | @ -154,8 +154,8 @@ func (suite *ReportsGetTestSuite) TestGetReports() { | ||||||
|       "header_static": "http://localhost:8080/assets/default_header.png", |       "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|       "followers_count": 0, |       "followers_count": 0, | ||||||
|       "following_count": 0, |       "following_count": 0, | ||||||
|       "statuses_count": 1, |       "statuses_count": 2, | ||||||
|       "last_status_at": "2021-09-20T10:40:37.000Z", |       "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|       "emojis": [], |       "emojis": [], | ||||||
|       "fields": [] |       "fields": [] | ||||||
|     } |     } | ||||||
|  | @ -244,8 +244,8 @@ func (suite *ReportsGetTestSuite) TestGetReports4() { | ||||||
|       "header_static": "http://localhost:8080/assets/default_header.png", |       "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|       "followers_count": 0, |       "followers_count": 0, | ||||||
|       "following_count": 0, |       "following_count": 0, | ||||||
|       "statuses_count": 1, |       "statuses_count": 2, | ||||||
|       "last_status_at": "2021-09-20T10:40:37.000Z", |       "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|       "emojis": [], |       "emojis": [], | ||||||
|       "fields": [] |       "fields": [] | ||||||
|     } |     } | ||||||
|  | @ -318,8 +318,8 @@ func (suite *ReportsGetTestSuite) TestGetReports6() { | ||||||
|       "header_static": "http://localhost:8080/assets/default_header.png", |       "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|       "followers_count": 0, |       "followers_count": 0, | ||||||
|       "following_count": 0, |       "following_count": 0, | ||||||
|       "statuses_count": 1, |       "statuses_count": 2, | ||||||
|       "last_status_at": "2021-09-20T10:40:37.000Z", |       "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|       "emojis": [], |       "emojis": [], | ||||||
|       "fields": [] |       "fields": [] | ||||||
|     } |     } | ||||||
|  | @ -376,8 +376,8 @@ func (suite *ReportsGetTestSuite) TestGetReports7() { | ||||||
|       "header_static": "http://localhost:8080/assets/default_header.png", |       "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|       "followers_count": 0, |       "followers_count": 0, | ||||||
|       "following_count": 0, |       "following_count": 0, | ||||||
|       "statuses_count": 1, |       "statuses_count": 2, | ||||||
|       "last_status_at": "2021-09-20T10:40:37.000Z", |       "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|       "emojis": [], |       "emojis": [], | ||||||
|       "fields": [] |       "fields": [] | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  | @ -877,7 +877,7 @@ func (suite *SearchGetTestSuite) TestSearchAAny() { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	suite.Len(searchResult.Accounts, 5) | 	suite.Len(searchResult.Accounts, 5) | ||||||
| 	suite.Len(searchResult.Statuses, 4) | 	suite.Len(searchResult.Statuses, 5) | ||||||
| 	suite.Len(searchResult.Hashtags, 0) | 	suite.Len(searchResult.Hashtags, 0) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -918,7 +918,7 @@ func (suite *SearchGetTestSuite) TestSearchAAnyFollowingOnly() { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	suite.Len(searchResult.Accounts, 2) | 	suite.Len(searchResult.Accounts, 2) | ||||||
| 	suite.Len(searchResult.Statuses, 4) | 	suite.Len(searchResult.Statuses, 5) | ||||||
| 	suite.Len(searchResult.Hashtags, 0) | 	suite.Len(searchResult.Hashtags, 0) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -959,7 +959,7 @@ func (suite *SearchGetTestSuite) TestSearchAStatuses() { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	suite.Len(searchResult.Accounts, 0) | 	suite.Len(searchResult.Accounts, 0) | ||||||
| 	suite.Len(searchResult.Statuses, 4) | 	suite.Len(searchResult.Statuses, 5) | ||||||
| 	suite.Len(searchResult.Hashtags, 0) | 	suite.Len(searchResult.Hashtags, 0) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -103,7 +103,12 @@ func (m *Module) StatusCreatePOSTHandler(c *gin.Context) { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	apiStatus, errWithCode := m.processor.Status().Create(c.Request.Context(), authed.Account, authed.Application, form) | 	apiStatus, errWithCode := m.processor.Status().Create( | ||||||
|  | 		c.Request.Context(), | ||||||
|  | 		authed.Account, | ||||||
|  | 		authed.Application, | ||||||
|  | 		form, | ||||||
|  | 	) | ||||||
| 	if errWithCode != nil { | 	if errWithCode != nil { | ||||||
| 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
| 		return | 		return | ||||||
|  | @ -144,7 +149,7 @@ func validateNormalizeCreateStatus(form *apimodel.AdvancedStatusCreateForm) erro | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if form.Poll != nil { | 	if form.Poll != nil { | ||||||
| 		if form.Poll.Options == nil { | 		if len(form.Poll.Options) == 0 { | ||||||
| 			return errors.New("poll with no options") | 			return errors.New("poll with no options") | ||||||
| 		} | 		} | ||||||
| 		if len(form.Poll.Options) > maxPollOptions { | 		if len(form.Poll.Options) > maxPollOptions { | ||||||
|  |  | ||||||
|  | @ -130,8 +130,8 @@ func (suite *StatusMuteTestSuite) TestMuteUnmuteStatus() { | ||||||
|     "header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg", |     "header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg", | ||||||
|     "followers_count": 2, |     "followers_count": 2, | ||||||
|     "following_count": 2, |     "following_count": 2, | ||||||
|     "statuses_count": 5, |     "statuses_count": 6, | ||||||
|     "last_status_at": "2022-05-20T11:37:55.000Z", |     "last_status_at": "2022-05-20T11:41:10.000Z", | ||||||
|     "emojis": [], |     "emojis": [], | ||||||
|     "fields": [], |     "fields": [], | ||||||
|     "enable_rss": true, |     "enable_rss": true, | ||||||
|  | @ -193,8 +193,8 @@ func (suite *StatusMuteTestSuite) TestMuteUnmuteStatus() { | ||||||
|     "header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg", |     "header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg", | ||||||
|     "followers_count": 2, |     "followers_count": 2, | ||||||
|     "following_count": 2, |     "following_count": 2, | ||||||
|     "statuses_count": 5, |     "statuses_count": 6, | ||||||
|     "last_status_at": "2022-05-20T11:37:55.000Z", |     "last_status_at": "2022-05-20T11:41:10.000Z", | ||||||
|     "emojis": [], |     "emojis": [], | ||||||
|     "fields": [], |     "fields": [], | ||||||
|     "enable_rss": true, |     "enable_rss": true, | ||||||
|  |  | ||||||
|  | @ -107,7 +107,7 @@ func (suite *StatusUnpinTestSuite) TestUnpinStatusNotFound() { | ||||||
| 	// Unpin a pinned followers-only status owned by another account. | 	// Unpin a pinned followers-only status owned by another account. | ||||||
| 	targetStatus := suite.testStatuses["local_account_2_status_7"] | 	targetStatus := suite.testStatuses["local_account_2_status_7"] | ||||||
| 
 | 
 | ||||||
| 	if _, err := suite.createUnpin(http.StatusNotFound, `{"error":"Not Found"}`, targetStatus.ID); err != nil { | 	if _, err := suite.createUnpin(http.StatusNotFound, `{"error":"Not Found: target status not found"}`, targetStatus.ID); err != nil { | ||||||
| 		suite.FailNow(err.Error()) | 		suite.FailNow(err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -24,35 +24,44 @@ type Poll struct { | ||||||
| 	// The ID of the poll in the database. | 	// The ID of the poll in the database. | ||||||
| 	// example: 01FBYKMD1KBMJ0W6JF1YZ3VY5D | 	// example: 01FBYKMD1KBMJ0W6JF1YZ3VY5D | ||||||
| 	ID string `json:"id"` | 	ID string `json:"id"` | ||||||
| 	// When the poll ends. (ISO 8601 Datetime), or null if the poll does not end | 
 | ||||||
| 	ExpiresAt string `json:"expires_at,omitempty"` | 	// When the poll ends. (ISO 8601 Datetime). | ||||||
|  | 	ExpiresAt string `json:"expires_at"` | ||||||
|  | 
 | ||||||
| 	// Is the poll currently expired? | 	// Is the poll currently expired? | ||||||
| 	Expired bool `json:"expired"` | 	Expired bool `json:"expired"` | ||||||
|  | 
 | ||||||
| 	// Does the poll allow multiple-choice answers? | 	// Does the poll allow multiple-choice answers? | ||||||
| 	Multiple bool `json:"multiple"` | 	Multiple bool `json:"multiple"` | ||||||
|  | 
 | ||||||
| 	// How many votes have been received. | 	// How many votes have been received. | ||||||
| 	VotesCount int `json:"votes_count"` | 	VotesCount int `json:"votes_count"` | ||||||
| 	// How many unique accounts have voted on a multiple-choice poll. Null if multiple is false. | 
 | ||||||
| 	VotersCount int `json:"voters_count,omitempty"` | 	// How many unique accounts have voted on a multiple-choice poll. | ||||||
|  | 	VotersCount int `json:"voters_count"` | ||||||
|  | 
 | ||||||
| 	// When called with a user token, has the authorized user voted? | 	// When called with a user token, has the authorized user voted? | ||||||
| 	Voted bool `json:"voted,omitempty"` | 	Voted bool `json:"voted,omitempty"` | ||||||
|  | 
 | ||||||
| 	// When called with a user token, which options has the authorized user chosen? Contains an array of index values for options. | 	// When called with a user token, which options has the authorized user chosen? Contains an array of index values for options. | ||||||
| 	OwnVotes []int `json:"own_votes,omitempty"` | 	OwnVotes []int `json:"own_votes,omitempty"` | ||||||
|  | 
 | ||||||
| 	// Possible answers for the poll. | 	// Possible answers for the poll. | ||||||
| 	Options []PollOptions `json:"options"` | 	Options []PollOption `json:"options"` | ||||||
|  | 
 | ||||||
| 	// Custom emoji to be used for rendering poll options. | 	// Custom emoji to be used for rendering poll options. | ||||||
| 	Emojis []Emoji `json:"emojis"` | 	Emojis []Emoji `json:"emojis"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // PollOptions represents the current vote counts for different poll options. | // PollOption represents the current vote counts for different poll options. | ||||||
| // | // | ||||||
| // swagger:model pollOptions | // swagger:model pollOption | ||||||
| type PollOptions struct { | type PollOption struct { | ||||||
| 	// The text value of the poll option. String. | 	// The text value of the poll option. String. | ||||||
| 	Title string `json:"title"` | 	Title string `json:"title"` | ||||||
|  | 
 | ||||||
| 	// The number of received votes for this option. | 	// The number of received votes for this option. | ||||||
| 	// Number, or null if results are not published yet. | 	VotesCount int `json:"votes_count"` | ||||||
| 	VotesCount int `json:"votes_count,omitempty"` |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // PollRequest models a request to create a poll. | // PollRequest models a request to create a poll. | ||||||
|  | @ -63,11 +72,23 @@ type PollRequest struct { | ||||||
| 	// If provided, media_ids cannot be used, and poll[expires_in] must be provided. | 	// If provided, media_ids cannot be used, and poll[expires_in] must be provided. | ||||||
| 	// name: poll[options] | 	// name: poll[options] | ||||||
| 	Options []string `form:"options" json:"options" xml:"options"` | 	Options []string `form:"options" json:"options" xml:"options"` | ||||||
|  | 
 | ||||||
| 	// Duration the poll should be open, in seconds. | 	// Duration the poll should be open, in seconds. | ||||||
| 	// If provided, media_ids cannot be used, and poll[options] must be provided. | 	// If provided, media_ids cannot be used, and poll[options] must be provided. | ||||||
| 	ExpiresIn int `form:"expires_in" json:"expires_in" xml:"expires_in"` | 	ExpiresIn int `form:"expires_in" json:"expires_in" xml:"expires_in"` | ||||||
|  | 
 | ||||||
| 	// Allow multiple choices on this poll. | 	// Allow multiple choices on this poll. | ||||||
| 	Multiple bool `form:"multiple" json:"multiple" xml:"multiple"` | 	Multiple bool `form:"multiple" json:"multiple" xml:"multiple"` | ||||||
|  | 
 | ||||||
| 	// Hide vote counts until the poll ends. | 	// Hide vote counts until the poll ends. | ||||||
| 	HideTotals bool `form:"hide_totals" json:"hide_totals" xml:"hide_totals"` | 	HideTotals bool `form:"hide_totals" json:"hide_totals" xml:"hide_totals"` | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // PollVoteRequest models a request to vote in a poll. | ||||||
|  | // | ||||||
|  | // swagger:parameters pollVote | ||||||
|  | type PollVoteRequest struct { | ||||||
|  | 	// Choices contains poll vote choice indices. Note that form | ||||||
|  | 	// uses a different key than the JSON, i.e. the '[]' suffix. | ||||||
|  | 	Choices []int `form:"choices[]" json:"choices" xml:"choices"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -40,7 +40,7 @@ import ( | ||||||
| // 404 header and footer. | // 404 header and footer. | ||||||
| // | // | ||||||
| // If an error is returned by InstanceGet, the function will panic. | // If an error is returned by InstanceGet, the function will panic. | ||||||
| func NotFoundHandler(c *gin.Context, instanceGet func(ctx context.Context) (*apimodel.InstanceV1, gtserror.WithCode), accept string) { | func NotFoundHandler(c *gin.Context, instanceGet func(ctx context.Context) (*apimodel.InstanceV1, gtserror.WithCode), accept string, errWithCode gtserror.WithCode) { | ||||||
| 	switch accept { | 	switch accept { | ||||||
| 	case string(TextHTML): | 	case string(TextHTML): | ||||||
| 		ctx := c.Request.Context() | 		ctx := c.Request.Context() | ||||||
|  | @ -54,9 +54,7 @@ func NotFoundHandler(c *gin.Context, instanceGet func(ctx context.Context) (*api | ||||||
| 			"requestID": gtscontext.RequestID(ctx), | 			"requestID": gtscontext.RequestID(ctx), | ||||||
| 		}) | 		}) | ||||||
| 	default: | 	default: | ||||||
| 		c.JSON(http.StatusNotFound, gin.H{ | 		c.JSON(http.StatusNotFound, gin.H{"error": errWithCode.Safe()}) | ||||||
| 			"error": http.StatusText(http.StatusNotFound), |  | ||||||
| 		}) |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -122,7 +120,7 @@ func ErrorHandler(c *gin.Context, errWithCode gtserror.WithCode, instanceGet fun | ||||||
| 
 | 
 | ||||||
| 	if errWithCode.Code() == http.StatusNotFound { | 	if errWithCode.Code() == http.StatusNotFound { | ||||||
| 		// Use our special not found handler with useful status text. | 		// Use our special not found handler with useful status text. | ||||||
| 		NotFoundHandler(c, instanceGet, accept) | 		NotFoundHandler(c, instanceGet, accept, errWithCode) | ||||||
| 	} else { | 	} else { | ||||||
| 		genericErrorHandler(c, instanceGet, accept, errWithCode) | 		genericErrorHandler(c, instanceGet, accept, errWithCode) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -83,8 +83,11 @@ func (suite *WebfingerGetTestSuite) funkifyAccountDomain(host string, accountDom | ||||||
| 	// to new host + account domain. | 	// to new host + account domain. | ||||||
| 	config.SetHost(host) | 	config.SetHost(host) | ||||||
| 	config.SetAccountDomain(accountDomain) | 	config.SetAccountDomain(accountDomain) | ||||||
|  | 	testrig.StopWorkers(&suite.state) | ||||||
|  | 	testrig.StartWorkers(&suite.state) | ||||||
| 	suite.processor = processing.NewProcessor(cleaner.New(&suite.state), suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(&suite.state), &suite.state, suite.emailSender) | 	suite.processor = processing.NewProcessor(cleaner.New(&suite.state), suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(&suite.state), &suite.state, suite.emailSender) | ||||||
| 	suite.webfingerModule = webfinger.New(suite.processor) | 	suite.webfingerModule = webfinger.New(suite.processor) | ||||||
|  | 	testrig.StartWorkers(&suite.state) | ||||||
| 
 | 
 | ||||||
| 	// Generate a new account for the | 	// Generate a new account for the | ||||||
| 	// tester, which uses the new host. | 	// tester, which uses the new host. | ||||||
|  |  | ||||||
							
								
								
									
										22
									
								
								internal/cache/cache.go
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										22
									
								
								internal/cache/cache.go
									
										
									
									
										vendored
									
									
								
							|  | @ -183,6 +183,22 @@ func (c *Caches) setuphooks() { | ||||||
| 		} | 		} | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
|  | 	c.GTS.Poll().SetInvalidateCallback(func(poll *gtsmodel.Poll) { | ||||||
|  | 		// Invalidate all cached votes of this poll. | ||||||
|  | 		c.GTS.PollVote().Invalidate("PollID", poll.ID) | ||||||
|  | 
 | ||||||
|  | 		// Invalidate cache of poll vote IDs. | ||||||
|  | 		c.GTS.PollVoteIDs().Invalidate(poll.ID) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	c.GTS.PollVote().SetInvalidateCallback(func(vote *gtsmodel.PollVote) { | ||||||
|  | 		// Invalidate cached poll (contains no. votes). | ||||||
|  | 		c.GTS.Poll().Invalidate("ID", vote.PollID) | ||||||
|  | 
 | ||||||
|  | 		// Invalidate cache of poll vote IDs. | ||||||
|  | 		c.GTS.PollVoteIDs().Invalidate(vote.PollID) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
| 	c.GTS.Status().SetInvalidateCallback(func(status *gtsmodel.Status) { | 	c.GTS.Status().SetInvalidateCallback(func(status *gtsmodel.Status) { | ||||||
| 		// Invalidate status ID cached visibility. | 		// Invalidate status ID cached visibility. | ||||||
| 		c.Visibility.Invalidate("ItemID", status.ID) | 		c.Visibility.Invalidate("ItemID", status.ID) | ||||||
|  | @ -206,6 +222,11 @@ func (c *Caches) setuphooks() { | ||||||
| 			// Invalidate in reply to ID list of original status. | 			// Invalidate in reply to ID list of original status. | ||||||
| 			c.GTS.InReplyToIDs().Invalidate(status.InReplyToID) | 			c.GTS.InReplyToIDs().Invalidate(status.InReplyToID) | ||||||
| 		} | 		} | ||||||
|  | 
 | ||||||
|  | 		if status.PollID != "" { | ||||||
|  | 			// Invalidate cache of attached poll ID. | ||||||
|  | 			c.GTS.Poll().Invalidate("ID", status.PollID) | ||||||
|  | 		} | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	c.GTS.StatusFave().SetInvalidateCallback(func(fave *gtsmodel.StatusFave) { | 	c.GTS.StatusFave().SetInvalidateCallback(func(fave *gtsmodel.StatusFave) { | ||||||
|  | @ -244,6 +265,7 @@ func (c *Caches) Sweep(threshold float64) { | ||||||
| 	c.GTS.Media().Trim(threshold) | 	c.GTS.Media().Trim(threshold) | ||||||
| 	c.GTS.Mention().Trim(threshold) | 	c.GTS.Mention().Trim(threshold) | ||||||
| 	c.GTS.Notification().Trim(threshold) | 	c.GTS.Notification().Trim(threshold) | ||||||
|  | 	c.GTS.Poll().Trim(threshold) | ||||||
| 	c.GTS.Report().Trim(threshold) | 	c.GTS.Report().Trim(threshold) | ||||||
| 	c.GTS.Status().Trim(threshold) | 	c.GTS.Status().Trim(threshold) | ||||||
| 	c.GTS.StatusFave().Trim(threshold) | 	c.GTS.StatusFave().Trim(threshold) | ||||||
|  |  | ||||||
							
								
								
									
										98
									
								
								internal/cache/gts.go
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										98
									
								
								internal/cache/gts.go
									
										
									
									
										vendored
									
									
								
							|  | @ -52,6 +52,9 @@ type GTSCaches struct { | ||||||
| 	media            *result.Cache[*gtsmodel.MediaAttachment] | 	media            *result.Cache[*gtsmodel.MediaAttachment] | ||||||
| 	mention          *result.Cache[*gtsmodel.Mention] | 	mention          *result.Cache[*gtsmodel.Mention] | ||||||
| 	notification     *result.Cache[*gtsmodel.Notification] | 	notification     *result.Cache[*gtsmodel.Notification] | ||||||
|  | 	poll             *result.Cache[*gtsmodel.Poll] | ||||||
|  | 	pollVote         *result.Cache[*gtsmodel.PollVote] | ||||||
|  | 	pollVoteIDs      *SliceCache[string] | ||||||
| 	report           *result.Cache[*gtsmodel.Report] | 	report           *result.Cache[*gtsmodel.Report] | ||||||
| 	status           *result.Cache[*gtsmodel.Status] | 	status           *result.Cache[*gtsmodel.Status] | ||||||
| 	statusFave       *result.Cache[*gtsmodel.StatusFave] | 	statusFave       *result.Cache[*gtsmodel.StatusFave] | ||||||
|  | @ -90,6 +93,9 @@ func (c *GTSCaches) Init() { | ||||||
| 	c.initMedia() | 	c.initMedia() | ||||||
| 	c.initMention() | 	c.initMention() | ||||||
| 	c.initNotification() | 	c.initNotification() | ||||||
|  | 	c.initPoll() | ||||||
|  | 	c.initPollVote() | ||||||
|  | 	c.initPollVoteIDs() | ||||||
| 	c.initReport() | 	c.initReport() | ||||||
| 	c.initStatus() | 	c.initStatus() | ||||||
| 	c.initStatusFave() | 	c.initStatusFave() | ||||||
|  | @ -231,6 +237,21 @@ func (c *GTSCaches) Notification() *result.Cache[*gtsmodel.Notification] { | ||||||
| 	return c.notification | 	return c.notification | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Poll provides access to the gtsmodel Poll database cache. | ||||||
|  | func (c *GTSCaches) Poll() *result.Cache[*gtsmodel.Poll] { | ||||||
|  | 	return c.poll | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // PollVote provides access to the gtsmodel PollVote database cache. | ||||||
|  | func (c *GTSCaches) PollVote() *result.Cache[*gtsmodel.PollVote] { | ||||||
|  | 	return c.pollVote | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // PollVoteIDs provides access to the poll vote IDs list database cache. | ||||||
|  | func (c *GTSCaches) PollVoteIDs() *SliceCache[string] { | ||||||
|  | 	return c.pollVoteIDs | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // Report provides access to the gtsmodel Report database cache. | // Report provides access to the gtsmodel Report database cache. | ||||||
| func (c *GTSCaches) Report() *result.Cache[*gtsmodel.Report] { | func (c *GTSCaches) Report() *result.Cache[*gtsmodel.Report] { | ||||||
| 	return c.report | 	return c.report | ||||||
|  | @ -246,26 +267,26 @@ func (c *GTSCaches) StatusFave() *result.Cache[*gtsmodel.StatusFave] { | ||||||
| 	return c.statusFave | 	return c.statusFave | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Tag provides access to the gtsmodel Tag database cache. |  | ||||||
| func (c *GTSCaches) Tag() *result.Cache[*gtsmodel.Tag] { |  | ||||||
| 	return c.tag |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // ThreadMute provides access to the gtsmodel ThreadMute database cache. |  | ||||||
| func (c *GTSCaches) ThreadMute() *result.Cache[*gtsmodel.ThreadMute] { |  | ||||||
| 	return c.threadMute |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // StatusFaveIDs provides access to the status fave IDs list database cache. | // StatusFaveIDs provides access to the status fave IDs list database cache. | ||||||
| func (c *GTSCaches) StatusFaveIDs() *SliceCache[string] { | func (c *GTSCaches) StatusFaveIDs() *SliceCache[string] { | ||||||
| 	return c.statusFaveIDs | 	return c.statusFaveIDs | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Tag provides access to the gtsmodel Tag database cache. | ||||||
|  | func (c *GTSCaches) Tag() *result.Cache[*gtsmodel.Tag] { | ||||||
|  | 	return c.tag | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // Tombstone provides access to the gtsmodel Tombstone database cache. | // Tombstone provides access to the gtsmodel Tombstone database cache. | ||||||
| func (c *GTSCaches) Tombstone() *result.Cache[*gtsmodel.Tombstone] { | func (c *GTSCaches) Tombstone() *result.Cache[*gtsmodel.Tombstone] { | ||||||
| 	return c.tombstone | 	return c.tombstone | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ThreadMute provides access to the gtsmodel ThreadMute database cache. | ||||||
|  | func (c *GTSCaches) ThreadMute() *result.Cache[*gtsmodel.ThreadMute] { | ||||||
|  | 	return c.threadMute | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // User provides access to the gtsmodel User database cache. | // User provides access to the gtsmodel User database cache. | ||||||
| func (c *GTSCaches) User() *result.Cache[*gtsmodel.User] { | func (c *GTSCaches) User() *result.Cache[*gtsmodel.User] { | ||||||
| 	return c.user | 	return c.user | ||||||
|  | @ -685,6 +706,63 @@ func (c *GTSCaches) initNotification() { | ||||||
| 	c.notification.IgnoreErrors(ignoreErrors) | 	c.notification.IgnoreErrors(ignoreErrors) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (c *GTSCaches) initPoll() { | ||||||
|  | 	// Calculate maximum cache size. | ||||||
|  | 	cap := calculateResultCacheMax( | ||||||
|  | 		sizeofPoll(), // model in-mem size. | ||||||
|  | 		config.GetCachePollMemRatio(), | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	log.Infof(nil, "cache size = %d", cap) | ||||||
|  | 
 | ||||||
|  | 	c.poll = result.New([]result.Lookup{ | ||||||
|  | 		{Name: "ID"}, | ||||||
|  | 		{Name: "StatusID"}, | ||||||
|  | 	}, func(p1 *gtsmodel.Poll) *gtsmodel.Poll { | ||||||
|  | 		p2 := new(gtsmodel.Poll) | ||||||
|  | 		*p2 = *p1 | ||||||
|  | 		return p2 | ||||||
|  | 	}, cap) | ||||||
|  | 
 | ||||||
|  | 	c.poll.IgnoreErrors(ignoreErrors) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *GTSCaches) initPollVote() { | ||||||
|  | 	// Calculate maximum cache size. | ||||||
|  | 	cap := calculateResultCacheMax( | ||||||
|  | 		sizeofPollVote(), // model in-mem size. | ||||||
|  | 		config.GetCachePollVoteMemRatio(), | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	log.Infof(nil, "cache size = %d", cap) | ||||||
|  | 
 | ||||||
|  | 	c.pollVote = result.New([]result.Lookup{ | ||||||
|  | 		{Name: "ID"}, | ||||||
|  | 		{Name: "PollID.AccountID"}, | ||||||
|  | 		{Name: "PollID", Multi: true}, | ||||||
|  | 	}, func(v1 *gtsmodel.PollVote) *gtsmodel.PollVote { | ||||||
|  | 		v2 := new(gtsmodel.PollVote) | ||||||
|  | 		*v2 = *v1 | ||||||
|  | 		return v2 | ||||||
|  | 	}, cap) | ||||||
|  | 
 | ||||||
|  | 	c.pollVote.IgnoreErrors(ignoreErrors) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *GTSCaches) initPollVoteIDs() { | ||||||
|  | 	// Calculate maximum cache size. | ||||||
|  | 	cap := calculateSliceCacheMax( | ||||||
|  | 		config.GetCachePollVoteIDsMemRatio(), | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	log.Infof(nil, "cache size = %d", cap) | ||||||
|  | 
 | ||||||
|  | 	c.pollVoteIDs = &SliceCache[string]{Cache: simple.New[string, []string]( | ||||||
|  | 		0, | ||||||
|  | 		cap, | ||||||
|  | 	)} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (c *GTSCaches) initReport() { | func (c *GTSCaches) initReport() { | ||||||
| 	// Calculate maximum cache size. | 	// Calculate maximum cache size. | ||||||
| 	cap := calculateResultCacheMax( | 	cap := calculateResultCacheMax( | ||||||
|  |  | ||||||
							
								
								
									
										23
									
								
								internal/cache/size.go
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										23
									
								
								internal/cache/size.go
									
										
									
									
										vendored
									
									
								
							|  | @ -189,6 +189,8 @@ func totalOfRatios() float64 { | ||||||
| 		config.GetCacheMediaMemRatio() + | 		config.GetCacheMediaMemRatio() + | ||||||
| 		config.GetCacheMentionMemRatio() + | 		config.GetCacheMentionMemRatio() + | ||||||
| 		config.GetCacheNotificationMemRatio() + | 		config.GetCacheNotificationMemRatio() + | ||||||
|  | 		config.GetCachePollMemRatio() + | ||||||
|  | 		config.GetCachePollVoteMemRatio() + | ||||||
| 		config.GetCacheReportMemRatio() + | 		config.GetCacheReportMemRatio() + | ||||||
| 		config.GetCacheStatusMemRatio() + | 		config.GetCacheStatusMemRatio() + | ||||||
| 		config.GetCacheStatusFaveMemRatio() + | 		config.GetCacheStatusFaveMemRatio() + | ||||||
|  | @ -438,6 +440,27 @@ func sizeofNotification() uintptr { | ||||||
| 	})) | 	})) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func sizeofPoll() uintptr { | ||||||
|  | 	return uintptr(size.Of(>smodel.Poll{ | ||||||
|  | 		ID:         exampleID, | ||||||
|  | 		Multiple:   func() *bool { ok := false; return &ok }(), | ||||||
|  | 		HideCounts: func() *bool { ok := false; return &ok }(), | ||||||
|  | 		Options:    []string{exampleTextSmall, exampleTextSmall, exampleTextSmall, exampleTextSmall}, | ||||||
|  | 		StatusID:   exampleID, | ||||||
|  | 		ExpiresAt:  exampleTime, | ||||||
|  | 	})) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func sizeofPollVote() uintptr { | ||||||
|  | 	return uintptr(size.Of(>smodel.PollVote{ | ||||||
|  | 		ID:        exampleID, | ||||||
|  | 		Choices:   []int{69, 420, 1337}, | ||||||
|  | 		AccountID: exampleID, | ||||||
|  | 		PollID:    exampleID, | ||||||
|  | 		CreatedAt: exampleTime, | ||||||
|  | 	})) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func sizeofReport() uintptr { | func sizeofReport() uintptr { | ||||||
| 	return uintptr(size.Of(>smodel.Report{ | 	return uintptr(size.Of(>smodel.Report{ | ||||||
| 		ID:                     exampleID, | 		ID:                     exampleID, | ||||||
|  |  | ||||||
|  | @ -204,6 +204,9 @@ type CacheConfiguration struct { | ||||||
| 	MediaMemRatio            float64       `name:"media-mem-ratio"` | 	MediaMemRatio            float64       `name:"media-mem-ratio"` | ||||||
| 	MentionMemRatio          float64       `name:"mention-mem-ratio"` | 	MentionMemRatio          float64       `name:"mention-mem-ratio"` | ||||||
| 	NotificationMemRatio     float64       `name:"notification-mem-ratio"` | 	NotificationMemRatio     float64       `name:"notification-mem-ratio"` | ||||||
|  | 	PollMemRatio             float64       `name:"poll-mem-ratio"` | ||||||
|  | 	PollVoteMemRatio         float64       `name:"poll-vote-mem-ratio"` | ||||||
|  | 	PollVoteIDsMemRatio      float64       `name:"poll-vote-ids-mem-ratio"` | ||||||
| 	ReportMemRatio           float64       `name:"report-mem-ratio"` | 	ReportMemRatio           float64       `name:"report-mem-ratio"` | ||||||
| 	StatusMemRatio           float64       `name:"status-mem-ratio"` | 	StatusMemRatio           float64       `name:"status-mem-ratio"` | ||||||
| 	StatusFaveMemRatio       float64       `name:"status-fave-mem-ratio"` | 	StatusFaveMemRatio       float64       `name:"status-fave-mem-ratio"` | ||||||
|  |  | ||||||
|  | @ -171,6 +171,9 @@ var Defaults = Configuration{ | ||||||
| 		MediaMemRatio:            4, | 		MediaMemRatio:            4, | ||||||
| 		MentionMemRatio:          2, | 		MentionMemRatio:          2, | ||||||
| 		NotificationMemRatio:     2, | 		NotificationMemRatio:     2, | ||||||
|  | 		PollMemRatio:             1, | ||||||
|  | 		PollVoteMemRatio:         2, | ||||||
|  | 		PollVoteIDsMemRatio:      2, | ||||||
| 		ReportMemRatio:           1, | 		ReportMemRatio:           1, | ||||||
| 		StatusMemRatio:           5, | 		StatusMemRatio:           5, | ||||||
| 		StatusFaveMemRatio:       2, | 		StatusFaveMemRatio:       2, | ||||||
|  |  | ||||||
|  | @ -70,7 +70,7 @@ func main() { | ||||||
| 	fmt.Fprint(output, ")\n\n") | 	fmt.Fprint(output, ")\n\n") | ||||||
| 	generateFields(output, nil, reflect.TypeOf(config.Configuration{})) | 	generateFields(output, nil, reflect.TypeOf(config.Configuration{})) | ||||||
| 	_ = output.Close() | 	_ = output.Close() | ||||||
| 	_ = exec.Command("gofumports", "-w", out).Run() | 	_ = exec.Command("gofumpt", "-w", out).Run() | ||||||
| 
 | 
 | ||||||
| 	// The plan here is that eventually we might be able | 	// The plan here is that eventually we might be able | ||||||
| 	// to generate an example configuration from struct tags | 	// to generate an example configuration from struct tags | ||||||
|  |  | ||||||
|  | @ -3099,6 +3099,81 @@ func GetCacheNotificationMemRatio() float64 { return global.GetCacheNotification | ||||||
| // SetCacheNotificationMemRatio safely sets the value for global configuration 'Cache.NotificationMemRatio' field | // SetCacheNotificationMemRatio safely sets the value for global configuration 'Cache.NotificationMemRatio' field | ||||||
| func SetCacheNotificationMemRatio(v float64) { global.SetCacheNotificationMemRatio(v) } | func SetCacheNotificationMemRatio(v float64) { global.SetCacheNotificationMemRatio(v) } | ||||||
| 
 | 
 | ||||||
|  | // GetCachePollMemRatio safely fetches the Configuration value for state's 'Cache.PollMemRatio' field | ||||||
|  | func (st *ConfigState) GetCachePollMemRatio() (v float64) { | ||||||
|  | 	st.mutex.RLock() | ||||||
|  | 	v = st.config.Cache.PollMemRatio | ||||||
|  | 	st.mutex.RUnlock() | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SetCachePollMemRatio safely sets the Configuration value for state's 'Cache.PollMemRatio' field | ||||||
|  | func (st *ConfigState) SetCachePollMemRatio(v float64) { | ||||||
|  | 	st.mutex.Lock() | ||||||
|  | 	defer st.mutex.Unlock() | ||||||
|  | 	st.config.Cache.PollMemRatio = v | ||||||
|  | 	st.reloadToViper() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CachePollMemRatioFlag returns the flag name for the 'Cache.PollMemRatio' field | ||||||
|  | func CachePollMemRatioFlag() string { return "cache-poll-mem-ratio" } | ||||||
|  | 
 | ||||||
|  | // GetCachePollMemRatio safely fetches the value for global configuration 'Cache.PollMemRatio' field | ||||||
|  | func GetCachePollMemRatio() float64 { return global.GetCachePollMemRatio() } | ||||||
|  | 
 | ||||||
|  | // SetCachePollMemRatio safely sets the value for global configuration 'Cache.PollMemRatio' field | ||||||
|  | func SetCachePollMemRatio(v float64) { global.SetCachePollMemRatio(v) } | ||||||
|  | 
 | ||||||
|  | // GetCachePollVoteMemRatio safely fetches the Configuration value for state's 'Cache.PollVoteMemRatio' field | ||||||
|  | func (st *ConfigState) GetCachePollVoteMemRatio() (v float64) { | ||||||
|  | 	st.mutex.RLock() | ||||||
|  | 	v = st.config.Cache.PollVoteMemRatio | ||||||
|  | 	st.mutex.RUnlock() | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SetCachePollVoteMemRatio safely sets the Configuration value for state's 'Cache.PollVoteMemRatio' field | ||||||
|  | func (st *ConfigState) SetCachePollVoteMemRatio(v float64) { | ||||||
|  | 	st.mutex.Lock() | ||||||
|  | 	defer st.mutex.Unlock() | ||||||
|  | 	st.config.Cache.PollVoteMemRatio = v | ||||||
|  | 	st.reloadToViper() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CachePollVoteMemRatioFlag returns the flag name for the 'Cache.PollVoteMemRatio' field | ||||||
|  | func CachePollVoteMemRatioFlag() string { return "cache-poll-vote-mem-ratio" } | ||||||
|  | 
 | ||||||
|  | // GetCachePollVoteMemRatio safely fetches the value for global configuration 'Cache.PollVoteMemRatio' field | ||||||
|  | func GetCachePollVoteMemRatio() float64 { return global.GetCachePollVoteMemRatio() } | ||||||
|  | 
 | ||||||
|  | // SetCachePollVoteMemRatio safely sets the value for global configuration 'Cache.PollVoteMemRatio' field | ||||||
|  | func SetCachePollVoteMemRatio(v float64) { global.SetCachePollVoteMemRatio(v) } | ||||||
|  | 
 | ||||||
|  | // GetCachePollVoteIDsMemRatio safely fetches the Configuration value for state's 'Cache.PollVoteIDsMemRatio' field | ||||||
|  | func (st *ConfigState) GetCachePollVoteIDsMemRatio() (v float64) { | ||||||
|  | 	st.mutex.RLock() | ||||||
|  | 	v = st.config.Cache.PollVoteIDsMemRatio | ||||||
|  | 	st.mutex.RUnlock() | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SetCachePollVoteIDsMemRatio safely sets the Configuration value for state's 'Cache.PollVoteIDsMemRatio' field | ||||||
|  | func (st *ConfigState) SetCachePollVoteIDsMemRatio(v float64) { | ||||||
|  | 	st.mutex.Lock() | ||||||
|  | 	defer st.mutex.Unlock() | ||||||
|  | 	st.config.Cache.PollVoteIDsMemRatio = v | ||||||
|  | 	st.reloadToViper() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CachePollVoteIDsMemRatioFlag returns the flag name for the 'Cache.PollVoteIDsMemRatio' field | ||||||
|  | func CachePollVoteIDsMemRatioFlag() string { return "cache-poll-vote-ids-mem-ratio" } | ||||||
|  | 
 | ||||||
|  | // GetCachePollVoteIDsMemRatio safely fetches the value for global configuration 'Cache.PollVoteIDsMemRatio' field | ||||||
|  | func GetCachePollVoteIDsMemRatio() float64 { return global.GetCachePollVoteIDsMemRatio() } | ||||||
|  | 
 | ||||||
|  | // SetCachePollVoteIDsMemRatio safely sets the value for global configuration 'Cache.PollVoteIDsMemRatio' field | ||||||
|  | func SetCachePollVoteIDsMemRatio(v float64) { global.SetCachePollVoteIDsMemRatio(v) } | ||||||
|  | 
 | ||||||
| // GetCacheReportMemRatio safely fetches the Configuration value for state's 'Cache.ReportMemRatio' field | // GetCacheReportMemRatio safely fetches the Configuration value for state's 'Cache.ReportMemRatio' field | ||||||
| func (st *ConfigState) GetCacheReportMemRatio() (v float64) { | func (st *ConfigState) GetCacheReportMemRatio() (v float64) { | ||||||
| 	st.mutex.RLock() | 	st.mutex.RLock() | ||||||
|  |  | ||||||
|  | @ -42,7 +42,7 @@ type AccountTestSuite struct { | ||||||
| func (suite *AccountTestSuite) TestGetAccountStatuses() { | func (suite *AccountTestSuite) TestGetAccountStatuses() { | ||||||
| 	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, false, false, "", "", false, false) | 	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, false, false, "", "", false, false) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	suite.Len(statuses, 5) | 	suite.Len(statuses, 6) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() { | func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() { | ||||||
|  | @ -65,7 +65,7 @@ func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() { | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		suite.FailNow(err.Error()) | 		suite.FailNow(err.Error()) | ||||||
| 	} | 	} | ||||||
| 	suite.Len(statuses, 1) | 	suite.Len(statuses, 2) | ||||||
| 
 | 
 | ||||||
| 	// try to get the last page (should be empty) | 	// try to get the last page (should be empty) | ||||||
| 	statuses, err = suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, statuses[len(statuses)-1].ID, "", false, false) | 	statuses, err = suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, statuses[len(statuses)-1].ID, "", false, false) | ||||||
|  | @ -76,7 +76,7 @@ func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() { | ||||||
| func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogs() { | func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogs() { | ||||||
| 	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false) | 	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	suite.Len(statuses, 5) | 	suite.Len(statuses, 6) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogsPublicOnly() { | func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogsPublicOnly() { | ||||||
|  | @ -306,7 +306,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() { | ||||||
| func (suite *AccountTestSuite) TestGetAccountLastPosted() { | func (suite *AccountTestSuite) TestGetAccountLastPosted() { | ||||||
| 	lastPosted, err := suite.db.GetAccountLastPosted(context.Background(), suite.testAccounts["local_account_1"].ID, false) | 	lastPosted, err := suite.db.GetAccountLastPosted(context.Background(), suite.testAccounts["local_account_1"].ID, false) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	suite.EqualValues(1653046675, lastPosted.Unix()) | 	suite.EqualValues(1653046870, lastPosted.Unix()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *AccountTestSuite) TestGetAccountLastPostedWebOnly() { | func (suite *AccountTestSuite) TestGetAccountLastPostedWebOnly() { | ||||||
|  |  | ||||||
|  | @ -121,7 +121,7 @@ func (suite *BasicTestSuite) TestGetAllStatuses() { | ||||||
| 	s := []*gtsmodel.Status{} | 	s := []*gtsmodel.Status{} | ||||||
| 	err := suite.db.GetAll(context.Background(), &s) | 	err := suite.db.GetAll(context.Background(), &s) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	suite.Len(s, 17) | 	suite.Len(s, 20) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *BasicTestSuite) TestGetAllNotNull() { | func (suite *BasicTestSuite) TestGetAllNotNull() { | ||||||
|  |  | ||||||
|  | @ -71,6 +71,7 @@ type DBService struct { | ||||||
| 	db.Media | 	db.Media | ||||||
| 	db.Mention | 	db.Mention | ||||||
| 	db.Notification | 	db.Notification | ||||||
|  | 	db.Poll | ||||||
| 	db.Relationship | 	db.Relationship | ||||||
| 	db.Report | 	db.Report | ||||||
| 	db.Rule | 	db.Rule | ||||||
|  | @ -203,6 +204,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { | ||||||
| 			db:    db, | 			db:    db, | ||||||
| 			state: state, | 			state: state, | ||||||
| 		}, | 		}, | ||||||
|  | 		Poll: &pollDB{ | ||||||
|  | 			db:    db, | ||||||
|  | 			state: state, | ||||||
|  | 		}, | ||||||
| 		Relationship: &relationshipDB{ | 		Relationship: &relationshipDB{ | ||||||
| 			db:    db, | 			db:    db, | ||||||
| 			state: state, | 			state: state, | ||||||
|  |  | ||||||
|  | @ -54,6 +54,8 @@ type BunDBStandardTestSuite struct { | ||||||
| 	testMarkers      map[string]*gtsmodel.Marker | 	testMarkers      map[string]*gtsmodel.Marker | ||||||
| 	testRules        map[string]*gtsmodel.Rule | 	testRules        map[string]*gtsmodel.Rule | ||||||
| 	testThreads      map[string]*gtsmodel.Thread | 	testThreads      map[string]*gtsmodel.Thread | ||||||
|  | 	testPolls        map[string]*gtsmodel.Poll | ||||||
|  | 	testPollVotes    map[string]*gtsmodel.PollVote | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *BunDBStandardTestSuite) SetupSuite() { | func (suite *BunDBStandardTestSuite) SetupSuite() { | ||||||
|  | @ -77,6 +79,8 @@ func (suite *BunDBStandardTestSuite) SetupSuite() { | ||||||
| 	suite.testMarkers = testrig.NewTestMarkers() | 	suite.testMarkers = testrig.NewTestMarkers() | ||||||
| 	suite.testRules = testrig.NewTestRules() | 	suite.testRules = testrig.NewTestRules() | ||||||
| 	suite.testThreads = testrig.NewTestThreads() | 	suite.testThreads = testrig.NewTestThreads() | ||||||
|  | 	suite.testPolls = testrig.NewTestPolls() | ||||||
|  | 	suite.testPollVotes = testrig.NewTestPollVotes() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *BunDBStandardTestSuite) SetupTest() { | func (suite *BunDBStandardTestSuite) SetupTest() { | ||||||
|  |  | ||||||
|  | @ -47,13 +47,13 @@ func (suite *InstanceTestSuite) TestCountInstanceUsersRemote() { | ||||||
| func (suite *InstanceTestSuite) TestCountInstanceStatuses() { | func (suite *InstanceTestSuite) TestCountInstanceStatuses() { | ||||||
| 	count, err := suite.db.CountInstanceStatuses(context.Background(), config.GetHost()) | 	count, err := suite.db.CountInstanceStatuses(context.Background(), config.GetHost()) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	suite.Equal(16, count) | 	suite.Equal(18, count) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *InstanceTestSuite) TestCountInstanceStatusesRemote() { | func (suite *InstanceTestSuite) TestCountInstanceStatusesRemote() { | ||||||
| 	count, err := suite.db.CountInstanceStatuses(context.Background(), "fossbros-anonymous.io") | 	count, err := suite.db.CountInstanceStatuses(context.Background(), "fossbros-anonymous.io") | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	suite.Equal(1, count) | 	suite.Equal(2, count) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *InstanceTestSuite) TestCountInstanceDomains() { | func (suite *InstanceTestSuite) TestCountInstanceDomains() { | ||||||
|  |  | ||||||
|  | @ -20,10 +20,10 @@ package bundb | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" |  | ||||||
| 
 | 
 | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/state" | 	"github.com/superseriousbusiness/gotosocial/internal/state" | ||||||
|  | @ -54,31 +54,9 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Set the mention originating status. | 	// Further populate the mention fields where applicable. | ||||||
| 	mention.Status, err = m.state.DB.GetStatusByID( | 	if err := m.PopulateMention(ctx, mention); err != nil { | ||||||
| 		gtscontext.SetBarebones(ctx), | 		return nil, err | ||||||
| 		mention.StatusID, |  | ||||||
| 	) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, fmt.Errorf("error populating mention status: %w", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Set the mention origin account model. |  | ||||||
| 	mention.OriginAccount, err = m.state.DB.GetAccountByID( |  | ||||||
| 		gtscontext.SetBarebones(ctx), |  | ||||||
| 		mention.OriginAccountID, |  | ||||||
| 	) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, fmt.Errorf("error populating mention origin account: %w", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Set the mention target account model. |  | ||||||
| 	mention.TargetAccount, err = m.state.DB.GetAccountByID( |  | ||||||
| 		gtscontext.SetBarebones(ctx), |  | ||||||
| 		mention.TargetAccountID, |  | ||||||
| 	) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, fmt.Errorf("error populating mention target account: %w", err) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return mention, nil | 	return mention, nil | ||||||
|  | @ -102,6 +80,45 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel. | ||||||
| 	return mentions, nil | 	return mentions, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (m *mentionDB) PopulateMention(ctx context.Context, mention *gtsmodel.Mention) (err error) { | ||||||
|  | 	var errs gtserror.MultiError | ||||||
|  | 
 | ||||||
|  | 	if mention.Status == nil { | ||||||
|  | 		// Set the mention originating status. | ||||||
|  | 		mention.Status, err = m.state.DB.GetStatusByID( | ||||||
|  | 			gtscontext.SetBarebones(ctx), | ||||||
|  | 			mention.StatusID, | ||||||
|  | 		) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return gtserror.Newf("error populating mention status: %w", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if mention.OriginAccount == nil { | ||||||
|  | 		// Set the mention origin account model. | ||||||
|  | 		mention.OriginAccount, err = m.state.DB.GetAccountByID( | ||||||
|  | 			gtscontext.SetBarebones(ctx), | ||||||
|  | 			mention.OriginAccountID, | ||||||
|  | 		) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return gtserror.Newf("error populating mention origin account: %w", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if mention.TargetAccount == nil { | ||||||
|  | 		// Set the mention target account model. | ||||||
|  | 		mention.TargetAccount, err = m.state.DB.GetAccountByID( | ||||||
|  | 			gtscontext.SetBarebones(ctx), | ||||||
|  | 			mention.TargetAccountID, | ||||||
|  | 		) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return gtserror.Newf("error populating mention target account: %w", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return errs.Combine() | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error { | func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error { | ||||||
| 	return m.state.Caches.GTS.Mention().Store(mention, func() error { | 	return m.state.Caches.GTS.Mention().Store(mention, func() error { | ||||||
| 		_, err := m.db.NewInsert().Model(mention).Exec(ctx) | 		_, err := m.db.NewInsert().Model(mention).Exec(ctx) | ||||||
|  |  | ||||||
|  | @ -0,0 +1,65 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package migrations | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/uptrace/bun" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func init() { | ||||||
|  | 	up := func(ctx context.Context, db *bun.DB) error { | ||||||
|  | 		// Create `polls` + `poll_votes` tables. | ||||||
|  | 		for _, model := range []any{ | ||||||
|  | 			>smodel.Poll{}, | ||||||
|  | 			>smodel.PollVote{}, | ||||||
|  | 		} { | ||||||
|  | 			_, err := db.NewCreateTable(). | ||||||
|  | 				IfNotExists(). | ||||||
|  | 				Model(model). | ||||||
|  | 				Exec(ctx) | ||||||
|  | 			if err != nil && !(strings.Contains(err.Error(), "already exists") || strings.Contains(err.Error(), "duplicate column name") || strings.Contains(err.Error(), "SQLSTATE 42701")) { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Add the new status `poll_id` column. | ||||||
|  | 		_, err := db.NewAddColumn(). | ||||||
|  | 			Model(>smodel.Status{}). | ||||||
|  | 			ColumnExpr("? CHAR(26)", bun.Ident("poll_id")). | ||||||
|  | 			Exec(ctx) | ||||||
|  | 		if err != nil && !(strings.Contains(err.Error(), "already exists") || strings.Contains(err.Error(), "duplicate column name") || strings.Contains(err.Error(), "SQLSTATE 42701")) { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	down := func(ctx context.Context, db *bun.DB) error { | ||||||
|  | 		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { | ||||||
|  | 			return nil | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := Migrations.Register(up, down); err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										536
									
								
								internal/db/bundb/poll.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										536
									
								
								internal/db/bundb/poll.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,536 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package bundb | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/state" | ||||||
|  | 	"github.com/uptrace/bun" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type pollDB struct { | ||||||
|  | 	db    *DB | ||||||
|  | 	state *state.State | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *pollDB) GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, error) { | ||||||
|  | 	return p.getPoll( | ||||||
|  | 		ctx, | ||||||
|  | 		"ID", | ||||||
|  | 		func(poll *gtsmodel.Poll) error { | ||||||
|  | 			return p.db.NewSelect(). | ||||||
|  | 				Model(poll). | ||||||
|  | 				Where("? = ?", bun.Ident("poll.id"), id). | ||||||
|  | 				Scan(ctx) | ||||||
|  | 		}, | ||||||
|  | 		id, | ||||||
|  | 	) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *pollDB) GetPollByStatusID(ctx context.Context, statusID string) (*gtsmodel.Poll, error) { | ||||||
|  | 	return p.getPoll( | ||||||
|  | 		ctx, | ||||||
|  | 		"StatusID", | ||||||
|  | 		func(poll *gtsmodel.Poll) error { | ||||||
|  | 			return p.db.NewSelect(). | ||||||
|  | 				Model(poll). | ||||||
|  | 				Where("? = ?", bun.Ident("poll.status_id"), statusID). | ||||||
|  | 				Scan(ctx) | ||||||
|  | 		}, | ||||||
|  | 		statusID, | ||||||
|  | 	) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *pollDB) getPoll(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Poll) error, keyParts ...any) (*gtsmodel.Poll, error) { | ||||||
|  | 	// Fetch poll from database cache with loader callback | ||||||
|  | 	poll, err := p.state.Caches.GTS.Poll().Load(lookup, func() (*gtsmodel.Poll, error) { | ||||||
|  | 		var poll gtsmodel.Poll | ||||||
|  | 
 | ||||||
|  | 		// Not cached! Perform database query. | ||||||
|  | 		if err := dbQuery(&poll); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Ensure vote slice | ||||||
|  | 		// is non nil and set. | ||||||
|  | 		poll.CheckVotes() | ||||||
|  | 
 | ||||||
|  | 		return &poll, nil | ||||||
|  | 	}, keyParts...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if gtscontext.Barebones(ctx) { | ||||||
|  | 		// no need to fully populate. | ||||||
|  | 		return poll, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Further populate the poll fields where applicable. | ||||||
|  | 	if err := p.PopulatePoll(ctx, poll); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return poll, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *pollDB) GetOpenPolls(ctx context.Context) ([]*gtsmodel.Poll, error) { | ||||||
|  | 	var pollIDs []string | ||||||
|  | 
 | ||||||
|  | 	// Select all polls with unset `closed_at` time. | ||||||
|  | 	if err := p.db.NewSelect(). | ||||||
|  | 		Table("polls"). | ||||||
|  | 		Column("polls.id"). | ||||||
|  | 		Join("JOIN ? ON ? = ?", bun.Ident("statuses"), bun.Ident("polls.id"), bun.Ident("statuses.poll_id")). | ||||||
|  | 		Where("? = true", bun.Ident("statuses.local")). | ||||||
|  | 		Where("? IS NULL", bun.Ident("polls.closed_at")). | ||||||
|  | 		Scan(ctx, &pollIDs); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Preallocate a slice to contain the poll models. | ||||||
|  | 	polls := make([]*gtsmodel.Poll, 0, len(pollIDs)) | ||||||
|  | 
 | ||||||
|  | 	for _, id := range pollIDs { | ||||||
|  | 		// Attempt to fetch poll from DB. | ||||||
|  | 		poll, err := p.GetPollByID(ctx, id) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Errorf(ctx, "error getting poll %s: %v", id, err) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Append poll to return slice. | ||||||
|  | 		polls = append(polls, poll) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return polls, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *pollDB) PopulatePoll(ctx context.Context, poll *gtsmodel.Poll) error { | ||||||
|  | 	var ( | ||||||
|  | 		err  error | ||||||
|  | 		errs gtserror.MultiError | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	if poll.Status == nil { | ||||||
|  | 		// Vote account is not set, fetch from database. | ||||||
|  | 		poll.Status, err = p.state.DB.GetStatusByID( | ||||||
|  | 			gtscontext.SetBarebones(ctx), | ||||||
|  | 			poll.StatusID, | ||||||
|  | 		) | ||||||
|  | 		if err != nil { | ||||||
|  | 			errs.Appendf("error populating poll status: %w", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return errs.Combine() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *pollDB) PutPoll(ctx context.Context, poll *gtsmodel.Poll) error { | ||||||
|  | 	// Ensure vote slice | ||||||
|  | 	// is non nil and set. | ||||||
|  | 	poll.CheckVotes() | ||||||
|  | 
 | ||||||
|  | 	return p.state.Caches.GTS.Poll().Store(poll, func() error { | ||||||
|  | 		_, err := p.db.NewInsert().Model(poll).Exec(ctx) | ||||||
|  | 		return err | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *pollDB) UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...string) error { | ||||||
|  | 	// Ensure vote slice | ||||||
|  | 	// is non nil and set. | ||||||
|  | 	poll.CheckVotes() | ||||||
|  | 
 | ||||||
|  | 	return p.state.Caches.GTS.Poll().Store(poll, func() error { | ||||||
|  | 		return p.db.RunInTx(ctx, func(tx Tx) error { | ||||||
|  | 			// Update the status' "updated_at" field. | ||||||
|  | 			if _, err := tx.NewUpdate(). | ||||||
|  | 				Table("statuses"). | ||||||
|  | 				Where("? = ?", bun.Ident("id"), poll.StatusID). | ||||||
|  | 				SetColumn("updated_at", "?", time.Now()). | ||||||
|  | 				Exec(ctx); err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Finally, update poll | ||||||
|  | 			// columns in database. | ||||||
|  | 			_, err := tx.NewUpdate(). | ||||||
|  | 				Model(poll). | ||||||
|  | 				Column(cols...). | ||||||
|  | 				Where("? = ?", bun.Ident("id"), poll.ID). | ||||||
|  | 				Exec(ctx) | ||||||
|  | 			return err | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *pollDB) DeletePollByID(ctx context.Context, id string) error { | ||||||
|  | 	// Delete poll by ID from database. | ||||||
|  | 	if _, err := p.db.NewDelete(). | ||||||
|  | 		Table("polls"). | ||||||
|  | 		Where("? = ?", bun.Ident("id"), id). | ||||||
|  | 		Exec(ctx); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Invalidate poll by ID from cache. | ||||||
|  | 	p.state.Caches.GTS.Poll().Invalidate("ID", id) | ||||||
|  | 	p.state.Caches.GTS.PollVoteIDs().Invalidate(id) | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *pollDB) GetPollVoteByID(ctx context.Context, id string) (*gtsmodel.PollVote, error) { | ||||||
|  | 	return p.getPollVote( | ||||||
|  | 		ctx, | ||||||
|  | 		"ID", | ||||||
|  | 		func(vote *gtsmodel.PollVote) error { | ||||||
|  | 			return p.db.NewSelect(). | ||||||
|  | 				Model(vote). | ||||||
|  | 				Where("? = ?", bun.Ident("poll_vote.id"), id). | ||||||
|  | 				Scan(ctx) | ||||||
|  | 		}, | ||||||
|  | 		id, | ||||||
|  | 	) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *pollDB) GetPollVoteBy(ctx context.Context, pollID string, accountID string) (*gtsmodel.PollVote, error) { | ||||||
|  | 	return p.getPollVote( | ||||||
|  | 		ctx, | ||||||
|  | 		"PollID.AccountID", | ||||||
|  | 		func(vote *gtsmodel.PollVote) error { | ||||||
|  | 			return p.db.NewSelect(). | ||||||
|  | 				Model(vote). | ||||||
|  | 				Where("? = ?", bun.Ident("poll_vote.account_id"), accountID). | ||||||
|  | 				Where("? = ?", bun.Ident("poll_vote.poll_id"), pollID). | ||||||
|  | 				Scan(ctx) | ||||||
|  | 		}, | ||||||
|  | 		pollID, | ||||||
|  | 		accountID, | ||||||
|  | 	) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *pollDB) getPollVote(ctx context.Context, lookup string, dbQuery func(*gtsmodel.PollVote) error, keyParts ...any) (*gtsmodel.PollVote, error) { | ||||||
|  | 	// Fetch vote from database cache with loader callback | ||||||
|  | 	vote, err := p.state.Caches.GTS.PollVote().Load(lookup, func() (*gtsmodel.PollVote, error) { | ||||||
|  | 		var vote gtsmodel.PollVote | ||||||
|  | 
 | ||||||
|  | 		// Not cached! Perform database query. | ||||||
|  | 		if err := dbQuery(&vote); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return &vote, nil | ||||||
|  | 	}, keyParts...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if gtscontext.Barebones(ctx) { | ||||||
|  | 		// no need to fully populate. | ||||||
|  | 		return vote, nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Further populate the vote fields where applicable. | ||||||
|  | 	if err := p.PopulatePollVote(ctx, vote); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return vote, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *pollDB) GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.PollVote, error) { | ||||||
|  | 	voteIDs, err := p.state.Caches.GTS.PollVoteIDs().Load(pollID, func() ([]string, error) { | ||||||
|  | 		var voteIDs []string | ||||||
|  | 
 | ||||||
|  | 		// Vote IDs not in cache, perform DB query! | ||||||
|  | 		q := newSelectPollVotes(p.db, pollID) | ||||||
|  | 		if _, err := q.Exec(ctx, &voteIDs); // nocollapse | ||||||
|  | 		err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return voteIDs, nil | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Preallocate slice of expected length. | ||||||
|  | 	votes := make([]*gtsmodel.PollVote, 0, len(voteIDs)) | ||||||
|  | 
 | ||||||
|  | 	for _, id := range voteIDs { | ||||||
|  | 		// Fetch poll vote model for this ID. | ||||||
|  | 		vote, err := p.GetPollVoteByID(ctx, id) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Errorf(ctx, "error getting poll vote %s: %v", id, err) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Append to return slice. | ||||||
|  | 		votes = append(votes, vote) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return votes, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *pollDB) PopulatePollVote(ctx context.Context, vote *gtsmodel.PollVote) error { | ||||||
|  | 	var ( | ||||||
|  | 		err  error | ||||||
|  | 		errs gtserror.MultiError | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	if vote.Account == nil { | ||||||
|  | 		// Vote account is not set, fetch from database. | ||||||
|  | 		vote.Account, err = p.state.DB.GetAccountByID( | ||||||
|  | 			gtscontext.SetBarebones(ctx), | ||||||
|  | 			vote.AccountID, | ||||||
|  | 		) | ||||||
|  | 		if err != nil { | ||||||
|  | 			errs.Appendf("error populating vote account: %w", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if vote.Poll == nil { | ||||||
|  | 		// Vote poll is not set, fetch from database. | ||||||
|  | 		vote.Poll, err = p.GetPollByID( | ||||||
|  | 			gtscontext.SetBarebones(ctx), | ||||||
|  | 			vote.PollID, | ||||||
|  | 		) | ||||||
|  | 		if err != nil { | ||||||
|  | 			errs.Appendf("error populating vote poll: %w", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return errs.Combine() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error { | ||||||
|  | 	return p.state.Caches.GTS.PollVote().Store(vote, func() error { | ||||||
|  | 		return p.db.RunInTx(ctx, func(tx Tx) error { | ||||||
|  | 			// Try insert vote into database. | ||||||
|  | 			if _, err := tx.NewInsert(). | ||||||
|  | 				Model(vote). | ||||||
|  | 				Exec(ctx); err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			var poll gtsmodel.Poll | ||||||
|  | 
 | ||||||
|  | 			// Select poll counts from DB. | ||||||
|  | 			if err := tx.NewSelect(). | ||||||
|  | 				Model(&poll). | ||||||
|  | 				Where("? = ?", bun.Ident("id"), vote.PollID). | ||||||
|  | 				Scan(ctx); err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Increment poll votes for choices. | ||||||
|  | 			poll.IncrementVotes(vote.Choices) | ||||||
|  | 
 | ||||||
|  | 			// Finally, update the poll entry. | ||||||
|  | 			_, err := tx.NewUpdate(). | ||||||
|  | 				Model(&poll). | ||||||
|  | 				Column("votes", "voters"). | ||||||
|  | 				Where("? = ?", bun.Ident("id"), vote.PollID). | ||||||
|  | 				Exec(ctx) | ||||||
|  | 			return err | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { | ||||||
|  | 	err := p.db.RunInTx(ctx, func(tx Tx) error { | ||||||
|  | 		// Delete all vote in poll, | ||||||
|  | 		// returning all vote choices. | ||||||
|  | 		switch _, err := tx.NewDelete(). | ||||||
|  | 			Table("poll_votes"). | ||||||
|  | 			Where("? = ?", bun.Ident("poll_id"), pollID). | ||||||
|  | 			Exec(ctx); { | ||||||
|  | 
 | ||||||
|  | 		case err == nil: | ||||||
|  | 			// no issue. | ||||||
|  | 
 | ||||||
|  | 		case errors.Is(err, db.ErrNoEntries): | ||||||
|  | 			// no votes found, | ||||||
|  | 			// return here. | ||||||
|  | 			return nil | ||||||
|  | 
 | ||||||
|  | 		default: | ||||||
|  | 			// irrecoverable. | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		var poll gtsmodel.Poll | ||||||
|  | 
 | ||||||
|  | 		// Select poll counts from DB. | ||||||
|  | 		switch err := tx.NewSelect(). | ||||||
|  | 			Model(&poll). | ||||||
|  | 			Where("? = ?", bun.Ident("id"), pollID). | ||||||
|  | 			Scan(ctx); { | ||||||
|  | 
 | ||||||
|  | 		case err == nil: | ||||||
|  | 			// no issue. | ||||||
|  | 
 | ||||||
|  | 		case errors.Is(err, db.ErrNoEntries): | ||||||
|  | 			// no votes found, | ||||||
|  | 			// return here. | ||||||
|  | 			return nil | ||||||
|  | 
 | ||||||
|  | 		default: | ||||||
|  | 			// irrecoverable. | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Zero all counts. | ||||||
|  | 		poll.ResetVotes() | ||||||
|  | 
 | ||||||
|  | 		// Finally, update the poll entry. | ||||||
|  | 		_, err := tx.NewUpdate(). | ||||||
|  | 			Model(&poll). | ||||||
|  | 			Column("votes", "voters"). | ||||||
|  | 			Where("? = ?", bun.Ident("id"), pollID). | ||||||
|  | 			Exec(ctx) | ||||||
|  | 		return err | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Invalidate poll vote and poll entry from caches. | ||||||
|  | 	p.state.Caches.GTS.Poll().Invalidate("ID", pollID) | ||||||
|  | 	p.state.Caches.GTS.PollVote().Invalidate("PollID", pollID) | ||||||
|  | 	p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID) | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error { | ||||||
|  | 	err := p.db.RunInTx(ctx, func(tx Tx) error { | ||||||
|  | 		var choices []int | ||||||
|  | 
 | ||||||
|  | 		// Delete vote in poll by account, | ||||||
|  | 		// returning the ID + choices of the vote. | ||||||
|  | 		switch err := tx.NewDelete(). | ||||||
|  | 			Table("poll_votes"). | ||||||
|  | 			Where("? = ?", bun.Ident("poll_id"), pollID). | ||||||
|  | 			Where("? = ?", bun.Ident("account_id"), accountID). | ||||||
|  | 			Returning("choices"). | ||||||
|  | 			Scan(ctx, &choices); { | ||||||
|  | 
 | ||||||
|  | 		case err == nil: | ||||||
|  | 			// no issue. | ||||||
|  | 
 | ||||||
|  | 		case errors.Is(err, db.ErrNoEntries): | ||||||
|  | 			// no votes found, | ||||||
|  | 			// return here. | ||||||
|  | 			return nil | ||||||
|  | 
 | ||||||
|  | 		default: | ||||||
|  | 			// irrecoverable. | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		var poll gtsmodel.Poll | ||||||
|  | 
 | ||||||
|  | 		// Select poll counts from DB. | ||||||
|  | 		switch err := tx.NewSelect(). | ||||||
|  | 			Model(&poll). | ||||||
|  | 			Where("? = ?", bun.Ident("id"), pollID). | ||||||
|  | 			Scan(ctx); { | ||||||
|  | 
 | ||||||
|  | 		case err == nil: | ||||||
|  | 			// no issue. | ||||||
|  | 
 | ||||||
|  | 		case errors.Is(err, db.ErrNoEntries): | ||||||
|  | 			// no votes found, | ||||||
|  | 			// return here. | ||||||
|  | 			return nil | ||||||
|  | 
 | ||||||
|  | 		default: | ||||||
|  | 			// irrecoverable. | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Decrement votes for choices. | ||||||
|  | 		poll.IncrementVotes(choices) | ||||||
|  | 
 | ||||||
|  | 		// Finally, update the poll entry. | ||||||
|  | 		_, err := tx.NewUpdate(). | ||||||
|  | 			Model(&poll). | ||||||
|  | 			Column("votes", "voters"). | ||||||
|  | 			Where("? = ?", bun.Ident("id"), pollID). | ||||||
|  | 			Exec(ctx) | ||||||
|  | 		return err | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Invalidate poll vote and poll entry from caches. | ||||||
|  | 	p.state.Caches.GTS.Poll().Invalidate("ID", pollID) | ||||||
|  | 	p.state.Caches.GTS.PollVote().Invalidate("PollID.AccountID", pollID, accountID) | ||||||
|  | 	p.state.Caches.GTS.PollVoteIDs().Invalidate(pollID) | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *pollDB) DeletePollVotesByAccountID(ctx context.Context, accountID string) error { | ||||||
|  | 	var pollIDs []string | ||||||
|  | 
 | ||||||
|  | 	// Select all polls this account | ||||||
|  | 	// has registered a poll vote in. | ||||||
|  | 	if err := p.db.NewSelect(). | ||||||
|  | 		Table("poll_votes"). | ||||||
|  | 		Column("poll_id"). | ||||||
|  | 		Where("? = ?", bun.Ident("account_id"), accountID). | ||||||
|  | 		Scan(ctx, &pollIDs); err != nil && | ||||||
|  | 		!errors.Is(err, db.ErrNoEntries) { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, id := range pollIDs { | ||||||
|  | 		// Delete all votes by this account in each of the polls, | ||||||
|  | 		// this way ensures that all necessary caches are invalidated. | ||||||
|  | 		if err := p.DeletePollVoteBy(ctx, id, accountID); err != nil { | ||||||
|  | 			log.Errorf(ctx, "error deleting vote by %s in %s: %v", accountID, id, err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // newSelectPollVotes returns a new select query for all rows in the poll_votes table with poll_id = pollID. | ||||||
|  | func newSelectPollVotes(db *DB, pollID string) *bun.SelectQuery { | ||||||
|  | 	return db.NewSelect(). | ||||||
|  | 		TableExpr("?", bun.Ident("poll_votes")). | ||||||
|  | 		ColumnExpr("?", bun.Ident("id")). | ||||||
|  | 		Where("? = ?", bun.Ident("poll_id"), pollID). | ||||||
|  | 		OrderExpr("? DESC", bun.Ident("id")) | ||||||
|  | } | ||||||
							
								
								
									
										318
									
								
								internal/db/bundb/poll_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										318
									
								
								internal/db/bundb/poll_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,318 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package bundb_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 	"math/rand" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/stretchr/testify/suite" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/id" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/util" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type PollTestSuite struct { | ||||||
|  | 	BunDBStandardTestSuite | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PollTestSuite) TestGetPollBy() { | ||||||
|  | 	t := suite.T() | ||||||
|  | 
 | ||||||
|  | 	// Create a new context for this test. | ||||||
|  | 	ctx, cncl := context.WithCancel(context.Background()) | ||||||
|  | 	defer cncl() | ||||||
|  | 
 | ||||||
|  | 	// Sentinel error to mark avoiding a test case. | ||||||
|  | 	sentinelErr := errors.New("sentinel") | ||||||
|  | 
 | ||||||
|  | 	// isEqual checks if 2 poll models are equal. | ||||||
|  | 	isEqual := func(p1, p2 gtsmodel.Poll) bool { | ||||||
|  | 		// Clear populated sub-models. | ||||||
|  | 		p1.Status = nil | ||||||
|  | 		p2.Status = nil | ||||||
|  | 
 | ||||||
|  | 		// Localize all of the time fields. | ||||||
|  | 		p1.ExpiresAt = p1.ExpiresAt.Local() | ||||||
|  | 		p2.ExpiresAt = p2.ExpiresAt.Local() | ||||||
|  | 		p1.ClosedAt = p1.ClosedAt.Local() | ||||||
|  | 		p2.ClosedAt = p2.ClosedAt.Local() | ||||||
|  | 
 | ||||||
|  | 		// Perform the comparison. | ||||||
|  | 		return suite.Equal(p1, p2) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, poll := range suite.testPolls { | ||||||
|  | 		for lookup, dbfunc := range map[string]func() (*gtsmodel.Poll, error){ | ||||||
|  | 			"id": func() (*gtsmodel.Poll, error) { | ||||||
|  | 				return suite.db.GetPollByID(ctx, poll.ID) | ||||||
|  | 			}, | ||||||
|  | 
 | ||||||
|  | 			"status_id": func() (*gtsmodel.Poll, error) { | ||||||
|  | 				return suite.db.GetPollByStatusID(ctx, poll.StatusID) | ||||||
|  | 			}, | ||||||
|  | 		} { | ||||||
|  | 
 | ||||||
|  | 			// Clear database caches. | ||||||
|  | 			suite.state.Caches.Init() | ||||||
|  | 
 | ||||||
|  | 			t.Logf("checking database lookup %q", lookup) | ||||||
|  | 
 | ||||||
|  | 			// Perform database function. | ||||||
|  | 			checkPoll, err := dbfunc() | ||||||
|  | 			if err != nil { | ||||||
|  | 				if err == sentinelErr { | ||||||
|  | 					continue | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				t.Errorf("error encountered for database lookup %q: %v", lookup, err) | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Check received account data. | ||||||
|  | 			if !isEqual(*checkPoll, *poll) { | ||||||
|  | 				t.Errorf("poll does not contain expected data: %+v", checkPoll) | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Check that poll source status populated. | ||||||
|  | 			if poll.StatusID != (*checkPoll).Status.ID { | ||||||
|  | 				t.Errorf("poll source status not correctly populated for: %+v", poll) | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PollTestSuite) TestGetPollVoteBy() { | ||||||
|  | 	t := suite.T() | ||||||
|  | 
 | ||||||
|  | 	// Create a new context for this test. | ||||||
|  | 	ctx, cncl := context.WithCancel(context.Background()) | ||||||
|  | 	defer cncl() | ||||||
|  | 
 | ||||||
|  | 	// Sentinel error to mark avoiding a test case. | ||||||
|  | 	sentinelErr := errors.New("sentinel") | ||||||
|  | 
 | ||||||
|  | 	// isEqual checks if 2 poll vote models are equal. | ||||||
|  | 	isEqual := func(v1, v2 gtsmodel.PollVote) bool { | ||||||
|  | 		// Clear populated sub-models. | ||||||
|  | 		v1.Poll = nil | ||||||
|  | 		v2.Poll = nil | ||||||
|  | 		v1.Account = nil | ||||||
|  | 		v2.Account = nil | ||||||
|  | 
 | ||||||
|  | 		// Localize all of the time fields. | ||||||
|  | 		v1.CreatedAt = v1.CreatedAt.Local() | ||||||
|  | 		v2.CreatedAt = v2.CreatedAt.Local() | ||||||
|  | 
 | ||||||
|  | 		// Perform the comparison. | ||||||
|  | 		return suite.Equal(v1, v2) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, vote := range suite.testPollVotes { | ||||||
|  | 		for lookup, dbfunc := range map[string]func() (*gtsmodel.PollVote, error){ | ||||||
|  | 			"id": func() (*gtsmodel.PollVote, error) { | ||||||
|  | 				return suite.db.GetPollVoteByID(ctx, vote.ID) | ||||||
|  | 			}, | ||||||
|  | 
 | ||||||
|  | 			"poll_id_account_id": func() (*gtsmodel.PollVote, error) { | ||||||
|  | 				return suite.db.GetPollVoteBy(ctx, vote.PollID, vote.AccountID) | ||||||
|  | 			}, | ||||||
|  | 		} { | ||||||
|  | 
 | ||||||
|  | 			// Clear database caches. | ||||||
|  | 			suite.state.Caches.Init() | ||||||
|  | 
 | ||||||
|  | 			t.Logf("checking database lookup %q", lookup) | ||||||
|  | 
 | ||||||
|  | 			// Perform database function. | ||||||
|  | 			checkVote, err := dbfunc() | ||||||
|  | 			if err != nil { | ||||||
|  | 				if err == sentinelErr { | ||||||
|  | 					continue | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				t.Errorf("error encountered for database lookup %q: %v", lookup, err) | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Check received account data. | ||||||
|  | 			if !isEqual(*checkVote, *vote) { | ||||||
|  | 				t.Errorf("poll vote does not contain expected data: %+v", checkVote) | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Check that vote source poll populated. | ||||||
|  | 			if checkVote.PollID != (*checkVote).Poll.ID { | ||||||
|  | 				t.Errorf("vote source poll not correctly populated for: %+v", vote) | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Check that vote author account populated. | ||||||
|  | 			if checkVote.AccountID != (*checkVote).Account.ID { | ||||||
|  | 				t.Errorf("vote author account not correctly populated for: %+v", vote) | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PollTestSuite) TestUpdatePoll() { | ||||||
|  | 	// Create a new context for this test. | ||||||
|  | 	ctx, cncl := context.WithCancel(context.Background()) | ||||||
|  | 	defer cncl() | ||||||
|  | 
 | ||||||
|  | 	for _, poll := range suite.testPolls { | ||||||
|  | 		// Take copy of poll. | ||||||
|  | 		poll := util.Ptr(*poll) | ||||||
|  | 
 | ||||||
|  | 		// Update the poll closed field. | ||||||
|  | 		poll.ClosedAt = time.Now() | ||||||
|  | 
 | ||||||
|  | 		// Update poll model in the database. | ||||||
|  | 		err := suite.db.UpdatePoll(ctx, poll) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// Refetch poll from database to get latest. | ||||||
|  | 		latest, err := suite.db.GetPollByID(ctx, poll.ID) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// The latest poll should have updated closedAt. | ||||||
|  | 		suite.Equal(poll.ClosedAt, latest.ClosedAt) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PollTestSuite) TestPutPoll() { | ||||||
|  | 	// Create a new context for this test. | ||||||
|  | 	ctx, cncl := context.WithCancel(context.Background()) | ||||||
|  | 	defer cncl() | ||||||
|  | 
 | ||||||
|  | 	for _, poll := range suite.testPolls { | ||||||
|  | 		// Delete this poll from the database. | ||||||
|  | 		err := suite.db.DeletePollByID(ctx, poll.ID) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// Ensure that afterwards we can | ||||||
|  | 		// enter it again into database. | ||||||
|  | 		err = suite.db.PutPoll(ctx, poll) | ||||||
|  | 
 | ||||||
|  | 		// Ensure that afterwards we can fetch poll. | ||||||
|  | 		_, err = suite.db.GetPollByID(ctx, poll.ID) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PollTestSuite) TestPutPollVote() { | ||||||
|  | 	// Create a new context for this test. | ||||||
|  | 	ctx, cncl := context.WithCancel(context.Background()) | ||||||
|  | 	defer cncl() | ||||||
|  | 
 | ||||||
|  | 	// randomChoices generates random vote choices in poll. | ||||||
|  | 	randomChoices := func(poll *gtsmodel.Poll) []int { | ||||||
|  | 		var max int | ||||||
|  | 		if *poll.Multiple { | ||||||
|  | 			max = len(poll.Options) | ||||||
|  | 		} else { | ||||||
|  | 			max = 1 | ||||||
|  | 		} | ||||||
|  | 		count := 1 + rand.Intn(max) | ||||||
|  | 		choices := make([]int, count) | ||||||
|  | 		for i := range choices { | ||||||
|  | 			choices[i] = rand.Intn(len(poll.Options)) | ||||||
|  | 		} | ||||||
|  | 		return choices | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, poll := range suite.testPolls { | ||||||
|  | 		// Create a new vote to insert for poll. | ||||||
|  | 		vote := >smodel.PollVote{ | ||||||
|  | 			ID:        id.NewULID(), | ||||||
|  | 			Choices:   randomChoices(poll), | ||||||
|  | 			PollID:    poll.ID, | ||||||
|  | 			AccountID: id.NewULID(), // random account, doesn't matter | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Insert this new vote into database. | ||||||
|  | 		err := suite.db.PutPollVote(ctx, vote) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// Fetch latest version of poll from database. | ||||||
|  | 		latest, err := suite.db.GetPollByID(ctx, poll.ID) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// Decr latest version choices by new vote's. | ||||||
|  | 		for _, choice := range vote.Choices { | ||||||
|  | 			latest.Votes[choice]-- | ||||||
|  | 		} | ||||||
|  | 		(*latest.Voters)-- | ||||||
|  | 
 | ||||||
|  | 		// Old poll and latest model after decr | ||||||
|  | 		// should have equal vote + voter counts. | ||||||
|  | 		suite.Equal(poll.Voters, latest.Voters) | ||||||
|  | 		suite.Equal(poll.Votes, latest.Votes) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PollTestSuite) TestDeletePoll() { | ||||||
|  | 	// Create a new context for this test. | ||||||
|  | 	ctx, cncl := context.WithCancel(context.Background()) | ||||||
|  | 	defer cncl() | ||||||
|  | 
 | ||||||
|  | 	for _, poll := range suite.testPolls { | ||||||
|  | 		// Delete this poll from the database. | ||||||
|  | 		err := suite.db.DeletePollByID(ctx, poll.ID) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// Ensure that afterwards we cannot fetch poll. | ||||||
|  | 		_, err = suite.db.GetPollByID(ctx, poll.ID) | ||||||
|  | 		suite.ErrorIs(err, db.ErrNoEntries) | ||||||
|  | 
 | ||||||
|  | 		// Or again by the status it's attached to. | ||||||
|  | 		_, err = suite.db.GetPollByStatusID(ctx, poll.StatusID) | ||||||
|  | 		suite.ErrorIs(err, db.ErrNoEntries) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PollTestSuite) TestDeletePollVotes() { | ||||||
|  | 	// Create a new context for this test. | ||||||
|  | 	ctx, cncl := context.WithCancel(context.Background()) | ||||||
|  | 	defer cncl() | ||||||
|  | 
 | ||||||
|  | 	for _, poll := range suite.testPolls { | ||||||
|  | 		// Delete votes associated with poll from database. | ||||||
|  | 		err := suite.db.DeletePollVotes(ctx, poll.ID) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// Fetch latest version of poll from database. | ||||||
|  | 		poll, err = suite.db.GetPollByID(ctx, poll.ID) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// Check that poll counts are all zero. | ||||||
|  | 		suite.Equal(*poll.Voters, 0) | ||||||
|  | 		suite.Equal(poll.Votes, make([]int, len(poll.Options))) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestPollTestSuite(t *testing.T) { | ||||||
|  | 	suite.Run(t, new(PollTestSuite)) | ||||||
|  | } | ||||||
|  | @ -199,7 +199,8 @@ func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID stri | ||||||
| 
 | 
 | ||||||
| 		// Follow IDs not in cache, perform DB query! | 		// Follow IDs not in cache, perform DB query! | ||||||
| 		q := newSelectFollows(r.db, accountID) | 		q := newSelectFollows(r.db, accountID) | ||||||
| 		if _, err := q.Exec(ctx, &followIDs); err != nil { | 		if _, err := q.Exec(ctx, &followIDs); // nocollapse | ||||||
|  | 		err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -213,7 +214,8 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID | ||||||
| 
 | 
 | ||||||
| 		// Follow IDs not in cache, perform DB query! | 		// Follow IDs not in cache, perform DB query! | ||||||
| 		q := newSelectLocalFollows(r.db, accountID) | 		q := newSelectLocalFollows(r.db, accountID) | ||||||
| 		if _, err := q.Exec(ctx, &followIDs); err != nil { | 		if _, err := q.Exec(ctx, &followIDs); // nocollapse | ||||||
|  | 		err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -227,7 +229,8 @@ func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID st | ||||||
| 
 | 
 | ||||||
| 		// Follow IDs not in cache, perform DB query! | 		// Follow IDs not in cache, perform DB query! | ||||||
| 		q := newSelectFollowers(r.db, accountID) | 		q := newSelectFollowers(r.db, accountID) | ||||||
| 		if _, err := q.Exec(ctx, &followIDs); err != nil { | 		if _, err := q.Exec(ctx, &followIDs); // nocollapse | ||||||
|  | 		err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -241,7 +244,8 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account | ||||||
| 
 | 
 | ||||||
| 		// Follow IDs not in cache, perform DB query! | 		// Follow IDs not in cache, perform DB query! | ||||||
| 		q := newSelectLocalFollowers(r.db, accountID) | 		q := newSelectLocalFollowers(r.db, accountID) | ||||||
| 		if _, err := q.Exec(ctx, &followIDs); err != nil { | 		if _, err := q.Exec(ctx, &followIDs); // nocollapse | ||||||
|  | 		err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -255,7 +259,8 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account | ||||||
| 
 | 
 | ||||||
| 		// Follow request IDs not in cache, perform DB query! | 		// Follow request IDs not in cache, perform DB query! | ||||||
| 		q := newSelectFollowRequests(r.db, accountID) | 		q := newSelectFollowRequests(r.db, accountID) | ||||||
| 		if _, err := q.Exec(ctx, &followReqIDs); err != nil { | 		if _, err := q.Exec(ctx, &followReqIDs); // nocollapse | ||||||
|  | 		err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -269,7 +274,8 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco | ||||||
| 
 | 
 | ||||||
| 		// Follow request IDs not in cache, perform DB query! | 		// Follow request IDs not in cache, perform DB query! | ||||||
| 		q := newSelectFollowRequesting(r.db, accountID) | 		q := newSelectFollowRequesting(r.db, accountID) | ||||||
| 		if _, err := q.Exec(ctx, &followReqIDs); err != nil { | 		if _, err := q.Exec(ctx, &followReqIDs); // nocollapse | ||||||
|  | 		err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -283,7 +289,8 @@ func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID strin | ||||||
| 
 | 
 | ||||||
| 		// Block IDs not in cache, perform DB query! | 		// Block IDs not in cache, perform DB query! | ||||||
| 		q := newSelectBlocks(r.db, accountID) | 		q := newSelectBlocks(r.db, accountID) | ||||||
| 		if _, err := q.Exec(ctx, &blockIDs); err != nil { | 		if _, err := q.Exec(ctx, &blockIDs); // nocollapse | ||||||
|  | 		err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -154,17 +154,6 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if status.InReplyToID != "" && status.InReplyTo == nil { |  | ||||||
| 		// Status parent is not set, fetch from database. |  | ||||||
| 		status.InReplyTo, err = s.GetStatusByID( |  | ||||||
| 			gtscontext.SetBarebones(ctx), |  | ||||||
| 			status.InReplyToID, |  | ||||||
| 		) |  | ||||||
| 		if err != nil { |  | ||||||
| 			errs.Appendf("error populating status parent: %w", err) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if status.InReplyToID != "" { | 	if status.InReplyToID != "" { | ||||||
| 		if status.InReplyTo == nil { | 		if status.InReplyTo == nil { | ||||||
| 			// Status parent is not set, fetch from database. | 			// Status parent is not set, fetch from database. | ||||||
|  | @ -213,6 +202,17 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	if status.PollID != "" && status.Poll == nil { | ||||||
|  | 		// Status poll is not set, fetch from database. | ||||||
|  | 		status.Poll, err = s.state.DB.GetPollByID( | ||||||
|  | 			gtscontext.SetBarebones(ctx), | ||||||
|  | 			status.PollID, | ||||||
|  | 		) | ||||||
|  | 		if err != nil { | ||||||
|  | 			errs.Appendf("error populating status poll: %w", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	if !status.AttachmentsPopulated() { | 	if !status.AttachmentsPopulated() { | ||||||
| 		// Status attachments are out-of-date with IDs, repopulate. | 		// Status attachments are out-of-date with IDs, repopulate. | ||||||
| 		status.Attachments, err = s.state.DB.GetAttachmentsByIDs( | 		status.Attachments, err = s.state.DB.GetAttachmentsByIDs( | ||||||
|  |  | ||||||
|  | @ -22,6 +22,7 @@ import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"codeberg.org/gruf/go-kv" | ||||||
| 	"github.com/stretchr/testify/suite" | 	"github.com/stretchr/testify/suite" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/ap" | 	"github.com/superseriousbusiness/gotosocial/internal/ap" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
|  | @ -73,20 +74,18 @@ func getFutureStatus() *gtsmodel.Status { | ||||||
| 
 | 
 | ||||||
| func (suite *TimelineTestSuite) publicCount() int { | func (suite *TimelineTestSuite) publicCount() int { | ||||||
| 	var publicCount int | 	var publicCount int | ||||||
| 
 |  | ||||||
| 	for _, status := range suite.testStatuses { | 	for _, status := range suite.testStatuses { | ||||||
| 		if status.Visibility == gtsmodel.VisibilityPublic && | 		if status.Visibility == gtsmodel.VisibilityPublic && | ||||||
| 			status.BoostOfID == "" { | 			status.BoostOfID == "" { | ||||||
| 			publicCount++ | 			publicCount++ | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 |  | ||||||
| 	return publicCount | 	return publicCount | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *TimelineTestSuite) checkStatuses(statuses []*gtsmodel.Status, maxID string, minID string, expectedLength int) { | func (suite *TimelineTestSuite) checkStatuses(statuses []*gtsmodel.Status, maxID string, minID string, expectedLength int) { | ||||||
| 	if l := len(statuses); l != expectedLength { | 	if l := len(statuses); l != expectedLength { | ||||||
| 		suite.FailNow("", "expected %d statuses in slice, got %d", expectedLength, l) | 		suite.FailNowf("", "expected %d statuses in slice, got %d", expectedLength, l) | ||||||
| 	} else if l == 0 { | 	} else if l == 0 { | ||||||
| 		// Can't test empty slice. | 		// Can't test empty slice. | ||||||
| 		return | 		return | ||||||
|  | @ -98,15 +97,15 @@ func (suite *TimelineTestSuite) checkStatuses(statuses []*gtsmodel.Status, maxID | ||||||
| 		id := status.ID | 		id := status.ID | ||||||
| 
 | 
 | ||||||
| 		if id >= maxID { | 		if id >= maxID { | ||||||
| 			suite.FailNow("", "%s greater than maxID %s", id, maxID) | 			suite.FailNowf("", "%s greater than maxID %s", id, maxID) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if id <= minID { | 		if id <= minID { | ||||||
| 			suite.FailNow("", "%s smaller than minID %s", id, minID) | 			suite.FailNowf("", "%s smaller than minID %s", id, minID) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if id > highest { | 		if id > highest { | ||||||
| 			suite.FailNow("", "statuses in slice were not ordered highest -> lowest ID") | 			suite.FailNowf("", "statuses in slice were not ordered highest -> lowest ID") | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		highest = id | 		highest = id | ||||||
|  | @ -121,6 +120,10 @@ func (suite *TimelineTestSuite) TestGetPublicTimeline() { | ||||||
| 		suite.FailNow(err.Error()) | 		suite.FailNow(err.Error()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	suite.T().Log(kv.Field{ | ||||||
|  | 		K: "statuses", V: s, | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
| 	suite.checkStatuses(s, id.Highest, id.Lowest, suite.publicCount()) | 	suite.checkStatuses(s, id.Highest, id.Lowest, suite.publicCount()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -154,7 +157,7 @@ func (suite *TimelineTestSuite) TestGetHomeTimeline() { | ||||||
| 		suite.FailNow(err.Error()) | 		suite.FailNow(err.Error()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	suite.checkStatuses(s, id.Highest, id.Lowest, 16) | 	suite.checkStatuses(s, id.Highest, id.Lowest, 18) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *TimelineTestSuite) TestGetHomeTimelineNoFollowing() { | func (suite *TimelineTestSuite) TestGetHomeTimelineNoFollowing() { | ||||||
|  | @ -186,7 +189,7 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineNoFollowing() { | ||||||
| 		suite.FailNow(err.Error()) | 		suite.FailNow(err.Error()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	suite.checkStatuses(s, id.Highest, id.Lowest, 5) | 	suite.checkStatuses(s, id.Highest, id.Lowest, 6) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() { | func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() { | ||||||
|  | @ -208,7 +211,7 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineWithFutureStatus() { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	suite.NotContains(s, futureStatus) | 	suite.NotContains(s, futureStatus) | ||||||
| 	suite.checkStatuses(s, id.Highest, id.Lowest, 16) | 	suite.checkStatuses(s, id.Highest, id.Lowest, 18) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *TimelineTestSuite) TestGetHomeTimelineBackToFront() { | func (suite *TimelineTestSuite) TestGetHomeTimelineBackToFront() { | ||||||
|  | @ -239,8 +242,8 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineFromHighest() { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	suite.checkStatuses(s, id.Highest, id.Lowest, 5) | 	suite.checkStatuses(s, id.Highest, id.Lowest, 5) | ||||||
| 	suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID) | 	suite.Equal("01HEN2RZ8BG29Y5Z9VJC73HZW7", s[0].ID) | ||||||
| 	suite.Equal("01FCTA44PW9H1TB328S9AQXKDS", s[len(s)-1].ID) | 	suite.Equal("01FN3VJGFH10KR7S2PB0GFJZYG", s[len(s)-1].ID) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *TimelineTestSuite) TestGetListTimelineNoParams() { | func (suite *TimelineTestSuite) TestGetListTimelineNoParams() { | ||||||
|  | @ -254,7 +257,7 @@ func (suite *TimelineTestSuite) TestGetListTimelineNoParams() { | ||||||
| 		suite.FailNow(err.Error()) | 		suite.FailNow(err.Error()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	suite.checkStatuses(s, id.Highest, id.Lowest, 11) | 	suite.checkStatuses(s, id.Highest, id.Lowest, 12) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *TimelineTestSuite) TestGetListTimelineMaxID() { | func (suite *TimelineTestSuite) TestGetListTimelineMaxID() { | ||||||
|  | @ -269,8 +272,8 @@ func (suite *TimelineTestSuite) TestGetListTimelineMaxID() { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	suite.checkStatuses(s, id.Highest, id.Lowest, 5) | 	suite.checkStatuses(s, id.Highest, id.Lowest, 5) | ||||||
| 	suite.Equal("01G36SF3V6Y6V5BF9P4R7PQG7G", s[0].ID) | 	suite.Equal("01HEN2PRXT0TF4YDRA64FZZRN7", s[0].ID) | ||||||
| 	suite.Equal("01FCQSQ667XHJ9AV9T27SJJSX5", s[len(s)-1].ID) | 	suite.Equal("01FF25D5Q0DH7CHD57CTRS6WK0", s[len(s)-1].ID) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *TimelineTestSuite) TestGetListTimelineMinID() { | func (suite *TimelineTestSuite) TestGetListTimelineMinID() { | ||||||
|  |  | ||||||
|  | @ -36,6 +36,7 @@ type DB interface { | ||||||
| 	Media | 	Media | ||||||
| 	Mention | 	Mention | ||||||
| 	Notification | 	Notification | ||||||
|  | 	Poll | ||||||
| 	Relationship | 	Relationship | ||||||
| 	Report | 	Report | ||||||
| 	Rule | 	Rule | ||||||
|  |  | ||||||
|  | @ -31,6 +31,9 @@ type Mention interface { | ||||||
| 	// GetMentions gets multiple mentions. | 	// GetMentions gets multiple mentions. | ||||||
| 	GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, error) | 	GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, error) | ||||||
| 
 | 
 | ||||||
|  | 	// PopulateMention ensures that all sub-models of a mention are populated (e.g. accounts). | ||||||
|  | 	PopulateMention(ctx context.Context, mention *gtsmodel.Mention) error | ||||||
|  | 
 | ||||||
| 	// PutMention will insert the given mention into the database. | 	// PutMention will insert the given mention into the database. | ||||||
| 	PutMention(ctx context.Context, mention *gtsmodel.Mention) error | 	PutMention(ctx context.Context, mention *gtsmodel.Mention) error | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										71
									
								
								internal/db/poll.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								internal/db/poll.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,71 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package db | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 
 | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Poll interface { | ||||||
|  | 	// GetPollByID fetches the Poll with given ID from the database. | ||||||
|  | 	GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, error) | ||||||
|  | 
 | ||||||
|  | 	// GetPollByStatusID fetches the Poll with given status ID column value from the database. | ||||||
|  | 	GetPollByStatusID(ctx context.Context, statusID string) (*gtsmodel.Poll, error) | ||||||
|  | 
 | ||||||
|  | 	// GetOpenPolls fetches all local Polls in the database with an unset `closed_at` column. | ||||||
|  | 	GetOpenPolls(ctx context.Context) ([]*gtsmodel.Poll, error) | ||||||
|  | 
 | ||||||
|  | 	// PopulatePoll ensures the given Poll is fully populated with all other related database models. | ||||||
|  | 	PopulatePoll(ctx context.Context, poll *gtsmodel.Poll) error | ||||||
|  | 
 | ||||||
|  | 	// PutPoll puts the given Poll in the database. | ||||||
|  | 	PutPoll(ctx context.Context, poll *gtsmodel.Poll) error | ||||||
|  | 
 | ||||||
|  | 	// UpdatePoll updates the Poll in the database, only on selected columns if provided (else, all). | ||||||
|  | 	UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...string) error | ||||||
|  | 
 | ||||||
|  | 	// DeletePollByID deletes the Poll with given ID from the database. | ||||||
|  | 	DeletePollByID(ctx context.Context, id string) error | ||||||
|  | 
 | ||||||
|  | 	// GetPollVoteByID gets the PollVote with given ID from the database. | ||||||
|  | 	GetPollVoteByID(ctx context.Context, id string) (*gtsmodel.PollVote, error) | ||||||
|  | 
 | ||||||
|  | 	// GetPollVotesBy fetches the PollVote in Poll with ID, by account ID, from the database. | ||||||
|  | 	GetPollVoteBy(ctx context.Context, pollID string, accountID string) (*gtsmodel.PollVote, error) | ||||||
|  | 
 | ||||||
|  | 	// GetPollVotes fetches all PollVotes in Poll with ID, from the database. | ||||||
|  | 	GetPollVotes(ctx context.Context, pollID string) ([]*gtsmodel.PollVote, error) | ||||||
|  | 
 | ||||||
|  | 	// PopulatePollVote ensures the given PollVote is fully populated with all other related database models. | ||||||
|  | 	PopulatePollVote(ctx context.Context, votes *gtsmodel.PollVote) error | ||||||
|  | 
 | ||||||
|  | 	// PutPollVote puts the given PollVote in the database. | ||||||
|  | 	PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error | ||||||
|  | 
 | ||||||
|  | 	// DeletePollVotes deletes all PollVotes in Poll with given ID from the database. | ||||||
|  | 	DeletePollVotes(ctx context.Context, pollID string) error | ||||||
|  | 
 | ||||||
|  | 	// DeletePollVoteBy deletes the PollVote in Poll with ID, by account ID, from the database. | ||||||
|  | 	DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error | ||||||
|  | 
 | ||||||
|  | 	// DeletePollVotesByAccountID deletes all PollVotes in all Polls, by account ID, from the database. | ||||||
|  | 	DeletePollVotesByAccountID(ctx context.Context, accountID string) error | ||||||
|  | } | ||||||
|  | @ -365,12 +365,13 @@ func (d *Dereferencer) enrichStatus( | ||||||
| 
 | 
 | ||||||
| 	// Use existing status ID. | 	// Use existing status ID. | ||||||
| 	latestStatus.ID = status.ID | 	latestStatus.ID = status.ID | ||||||
| 
 |  | ||||||
| 	if latestStatus.ID == "" { | 	if latestStatus.ID == "" { | ||||||
|  | 
 | ||||||
| 		// Generate new status ID from the provided creation date. | 		// Generate new status ID from the provided creation date. | ||||||
| 		latestStatus.ID, err = id.NewULIDFromTime(latestStatus.CreatedAt) | 		latestStatus.ID, err = id.NewULIDFromTime(latestStatus.CreatedAt) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, nil, gtserror.Newf("invalid created at date: %w", err) | 			log.Errorf(ctx, "invalid created at date (falling back to 'now'): %v", err) | ||||||
|  | 			latestStatus.ID = id.NewULID() // just use "now" | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -379,6 +380,11 @@ func (d *Dereferencer) enrichStatus( | ||||||
| 	latestStatus.FetchedAt = time.Now() | 	latestStatus.FetchedAt = time.Now() | ||||||
| 	latestStatus.Local = status.Local | 	latestStatus.Local = status.Local | ||||||
| 
 | 
 | ||||||
|  | 	// Ensure the status' poll remains consistent, else reset the poll. | ||||||
|  | 	if err := d.fetchStatusPoll(ctx, status, latestStatus); err != nil { | ||||||
|  | 		return nil, nil, gtserror.Newf("error populating poll for status %s: %w", uri, err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Ensure the status' mentions are populated, and pass in existing to check for changes. | 	// Ensure the status' mentions are populated, and pass in existing to check for changes. | ||||||
| 	if err := d.fetchStatusMentions(ctx, requestUser, status, latestStatus); err != nil { | 	if err := d.fetchStatusMentions(ctx, requestUser, status, latestStatus); err != nil { | ||||||
| 		return nil, nil, gtserror.Newf("error populating mentions for status %s: %w", uri, err) | 		return nil, nil, gtserror.Newf("error populating mentions for status %s: %w", uri, err) | ||||||
|  | @ -533,7 +539,7 @@ func (d *Dereferencer) fetchStatusMentions(ctx context.Context, requestUser stri | ||||||
| 		//       support for edited status revision history. | 		//       support for edited status revision history. | ||||||
| 		mention.ID, err = id.NewULIDFromTime(status.CreatedAt) | 		mention.ID, err = id.NewULIDFromTime(status.CreatedAt) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Errorf(ctx, "invalid created at date: %v", err) | 			log.Errorf(ctx, "invalid created at date (falling back to 'now'): %v", err) | ||||||
| 			mention.ID = id.NewULID() // just use "now" | 			mention.ID = id.NewULID() // just use "now" | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -681,6 +687,101 @@ func (d *Dereferencer) fetchStatusTags(ctx context.Context, existing, status *gt | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (d *Dereferencer) fetchStatusPoll(ctx context.Context, existing, status *gtsmodel.Status) error { | ||||||
|  | 	var ( | ||||||
|  | 		// insertStatusPoll generates ID and inserts the poll attached to status into the database. | ||||||
|  | 		insertStatusPoll = func(ctx context.Context, status *gtsmodel.Status) error { | ||||||
|  | 			var err error | ||||||
|  | 
 | ||||||
|  | 			// Generate new ID for poll from the status CreatedAt. | ||||||
|  | 			// TODO: update this to use "edited_at" when we add | ||||||
|  | 			//       support for edited status revision history. | ||||||
|  | 			status.Poll.ID, err = id.NewULIDFromTime(status.CreatedAt) | ||||||
|  | 			if err != nil { | ||||||
|  | 				log.Errorf(ctx, "invalid created at date (falling back to 'now'): %v", err) | ||||||
|  | 				status.Poll.ID = id.NewULID() // just use "now" | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Update the status<->poll links. | ||||||
|  | 			status.PollID = status.Poll.ID | ||||||
|  | 			status.Poll.StatusID = status.ID | ||||||
|  | 			status.Poll.Status = status | ||||||
|  | 
 | ||||||
|  | 			// Insert this latest poll into the database. | ||||||
|  | 			err = d.state.DB.PutPoll(ctx, status.Poll) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return gtserror.Newf("error putting in database: %w", err) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// deleteStatusPoll deletes the poll with ID, and all attached votes, from the database. | ||||||
|  | 		deleteStatusPoll = func(ctx context.Context, pollID string) error { | ||||||
|  | 			if err := d.state.DB.DeletePollByID(ctx, pollID); err != nil { | ||||||
|  | 				return gtserror.Newf("error deleting existing poll from database: %w", err) | ||||||
|  | 			} | ||||||
|  | 			if err := d.state.DB.DeletePollVotes(ctx, pollID); err != nil { | ||||||
|  | 				return gtserror.Newf("error deleting existing votes from database: %w", err) | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	switch { | ||||||
|  | 	case existing.Poll == nil && status.Poll == nil: | ||||||
|  | 		// no poll before or after, nothing to do. | ||||||
|  | 		return nil | ||||||
|  | 
 | ||||||
|  | 	case existing.Poll == nil && status.Poll != nil: | ||||||
|  | 		// no previous poll, insert new poll! | ||||||
|  | 		return insertStatusPoll(ctx, status) | ||||||
|  | 
 | ||||||
|  | 	case /*existing.Poll != nil &&*/ status.Poll == nil: | ||||||
|  | 		// existing poll has been deleted, remove this. | ||||||
|  | 		return deleteStatusPoll(ctx, existing.PollID) | ||||||
|  | 
 | ||||||
|  | 	case /*existing.Poll != nil && status.Poll != nil && */ | ||||||
|  | 		!slices.Equal(existing.Poll.Options, status.Poll.Options) || | ||||||
|  | 			!existing.Poll.ExpiresAt.Equal(status.Poll.ExpiresAt): | ||||||
|  | 		// poll has changed since original, delete and reinsert new. | ||||||
|  | 		if err := deleteStatusPoll(ctx, existing.PollID); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		return insertStatusPoll(ctx, status) | ||||||
|  | 
 | ||||||
|  | 	case /*existing.Poll != nil && status.Poll != nil && */ | ||||||
|  | 		!existing.Poll.ClosedAt.Equal(status.Poll.ClosedAt) || | ||||||
|  | 			!slices.Equal(existing.Poll.Votes, status.Poll.Votes) || | ||||||
|  | 			existing.Poll.Voters != status.Poll.Voters: | ||||||
|  | 		// Since we last saw it, the poll has updated! | ||||||
|  | 		// Whether that be stats, or close time. | ||||||
|  | 		poll := existing.Poll | ||||||
|  | 		poll.Closing = (!poll.Closed() && status.Poll.Closed()) | ||||||
|  | 		poll.ClosedAt = status.Poll.ClosedAt | ||||||
|  | 		poll.Voters = status.Poll.Voters | ||||||
|  | 		poll.Votes = status.Poll.Votes | ||||||
|  | 
 | ||||||
|  | 		// Update poll model in the database (specifically only the possible changed columns). | ||||||
|  | 		if err := d.state.DB.UpdatePoll(ctx, poll, "closed_at", "voters", "votes"); err != nil { | ||||||
|  | 			return gtserror.Newf("error updating poll: %w", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Update poll on status. | ||||||
|  | 		status.PollID = poll.ID | ||||||
|  | 		status.Poll = poll | ||||||
|  | 		return nil | ||||||
|  | 
 | ||||||
|  | 	default: | ||||||
|  | 		// latest and existing | ||||||
|  | 		// polls are up to date. | ||||||
|  | 		poll := existing.Poll | ||||||
|  | 		status.PollID = poll.ID | ||||||
|  | 		status.Poll = poll | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (d *Dereferencer) fetchStatusAttachments(ctx context.Context, tsport transport.Transport, existing, status *gtsmodel.Status) error { | func (d *Dereferencer) fetchStatusAttachments(ctx context.Context, tsport transport.Transport, existing, status *gtsmodel.Status) error { | ||||||
| 	// Allocate new slice to take the yet-to-be fetched attachment IDs. | 	// Allocate new slice to take the yet-to-be fetched attachment IDs. | ||||||
| 	status.AttachmentIDs = make([]string, len(status.Attachments)) | 	status.AttachmentIDs = make([]string, len(status.Attachments)) | ||||||
|  |  | ||||||
|  | @ -28,6 +28,7 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/ap" | 	"github.com/superseriousbusiness/gotosocial/internal/ap" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/config" | 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/id" | 	"github.com/superseriousbusiness/gotosocial/internal/id" | ||||||
|  | @ -141,6 +142,22 @@ func (f *federatingDB) activityCreate( | ||||||
| 	// Extract objects from create activity. | 	// Extract objects from create activity. | ||||||
| 	objects := ap.ExtractObjects(create) | 	objects := ap.ExtractObjects(create) | ||||||
| 
 | 
 | ||||||
|  | 	// Extract PollOptionables (votes!) from objects slice. | ||||||
|  | 	optionables, objects := ap.ExtractPollOptionables(objects) | ||||||
|  | 
 | ||||||
|  | 	if len(optionables) > 0 { | ||||||
|  | 		// Handle provided poll vote(s) creation, this can | ||||||
|  | 		// be for single or multiple votes in the same poll. | ||||||
|  | 		err := f.createPollOptionables(ctx, | ||||||
|  | 			receivingAccount, | ||||||
|  | 			requestingAccount, | ||||||
|  | 			optionables, | ||||||
|  | 		) | ||||||
|  | 		if err != nil { | ||||||
|  | 			errs.Appendf("error creating poll vote(s): %w", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Extract Statusables from objects slice (this must be | 	// Extract Statusables from objects slice (this must be | ||||||
| 	// done AFTER extracting options due to how AS typing works). | 	// done AFTER extracting options due to how AS typing works). | ||||||
| 	statusables, objects := ap.ExtractStatusables(objects) | 	statusables, objects := ap.ExtractStatusables(objects) | ||||||
|  | @ -169,6 +186,112 @@ func (f *federatingDB) activityCreate( | ||||||
| 	return errs.Combine() | 	return errs.Combine() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // createPollOptionable handles a Create activity for a PollOptionable. | ||||||
|  | // This function doesn't handle database insertion, only validation checks | ||||||
|  | // before passing off to a worker for asynchronous processing. | ||||||
|  | func (f *federatingDB) createPollOptionables( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	receiver *gtsmodel.Account, | ||||||
|  | 	requester *gtsmodel.Account, | ||||||
|  | 	options []ap.PollOptionable, | ||||||
|  | ) error { | ||||||
|  | 	var ( | ||||||
|  | 		// the origin Status w/ Poll the vote | ||||||
|  | 		// options are in. This gets set on first | ||||||
|  | 		// iteration, relevant checks performed | ||||||
|  | 		// then re-used in each further iteration. | ||||||
|  | 		inReplyTo *gtsmodel.Status | ||||||
|  | 
 | ||||||
|  | 		// the resulting slices of Poll.Option | ||||||
|  | 		// choice indices passed into the new | ||||||
|  | 		// created PollVote object. | ||||||
|  | 		choices []int | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	for _, option := range options { | ||||||
|  | 		// Extract the "inReplyTo" property. | ||||||
|  | 		inReplyToURIs := ap.GetInReplyTo(option) | ||||||
|  | 		if len(inReplyToURIs) != 1 { | ||||||
|  | 			return gtserror.Newf("invalid inReplyTo property length: %d", len(inReplyToURIs)) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Stringify the inReplyTo URI. | ||||||
|  | 		statusURI := inReplyToURIs[0].String() | ||||||
|  | 
 | ||||||
|  | 		if inReplyTo == nil { | ||||||
|  | 			var err error | ||||||
|  | 
 | ||||||
|  | 			// This is the first object in the activity slice, | ||||||
|  | 			// check database for the poll source status by URI. | ||||||
|  | 			inReplyTo, err = f.state.DB.GetStatusByURI(ctx, statusURI) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return gtserror.Newf("error getting poll source from database %s: %w", statusURI, err) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			switch { | ||||||
|  | 			// The origin status isn't a poll? | ||||||
|  | 			case inReplyTo.PollID == "": | ||||||
|  | 				return gtserror.Newf("poll vote in status %s without poll", statusURI) | ||||||
|  | 
 | ||||||
|  | 			// We don't own the poll ... | ||||||
|  | 			case !*inReplyTo.Local: | ||||||
|  | 				return gtserror.Newf("poll vote in remote status %s", statusURI) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Check whether user has already vote in this poll. | ||||||
|  | 			// (we only check this for the first object, as multiple | ||||||
|  | 			// may be sent in response to a multiple-choice poll). | ||||||
|  | 			vote, err := f.state.DB.GetPollVoteBy( | ||||||
|  | 				gtscontext.SetBarebones(ctx), | ||||||
|  | 				inReplyTo.PollID, | ||||||
|  | 				requester.ID, | ||||||
|  | 			) | ||||||
|  | 			if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
|  | 				return gtserror.Newf("error getting status %s poll votes from database: %w", statusURI, err) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if vote != nil { | ||||||
|  | 				log.Warnf(ctx, "%s has already voted in poll %s", requester.URI, statusURI) | ||||||
|  | 				return nil // this is a useful warning for admins to report to us from logs | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if statusURI != inReplyTo.URI { | ||||||
|  | 			// All activity votes should be to the same poll per activity. | ||||||
|  | 			return gtserror.New("votes to multiple polls in single activity") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Extract the poll option name. | ||||||
|  | 		name := ap.ExtractName(option) | ||||||
|  | 
 | ||||||
|  | 		// Check that this is a valid option name. | ||||||
|  | 		choice := inReplyTo.Poll.GetChoice(name) | ||||||
|  | 		if choice == -1 { | ||||||
|  | 			return gtserror.Newf("poll vote in status %s invalid: %s", statusURI, name) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Append the option index to choices. | ||||||
|  | 		choices = append(choices, choice) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Enqueue message to the fedi API worker with poll vote(s). | ||||||
|  | 	f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ | ||||||
|  | 		APActivityType: ap.ActivityCreate, | ||||||
|  | 		APObjectType:   ap.ActivityQuestion, | ||||||
|  | 		GTSModel: >smodel.PollVote{ | ||||||
|  | 			ID:        id.NewULID(), | ||||||
|  | 			Choices:   choices, | ||||||
|  | 			AccountID: requester.ID, | ||||||
|  | 			Account:   requester, | ||||||
|  | 			PollID:    inReplyTo.PollID, | ||||||
|  | 			Poll:      inReplyTo.Poll, | ||||||
|  | 		}, | ||||||
|  | 		ReceivingAccount: receiver, | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // createStatusable handles a Create activity for a Statusable. | // createStatusable handles a Create activity for a Statusable. | ||||||
| // This function won't insert anything in the database yet, | // This function won't insert anything in the database yet, | ||||||
| // but will pass the Statusable (if appropriate) through to | // but will pass the Statusable (if appropriate) through to | ||||||
|  |  | ||||||
							
								
								
									
										121
									
								
								internal/gtsmodel/poll.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										121
									
								
								internal/gtsmodel/poll.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,121 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package gtsmodel | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Poll represents an attached (to) Status poll, i.e. a questionaire. Can be remote / local. | ||||||
|  | type Poll struct { | ||||||
|  | 	ID         string    `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // Unique identity string. | ||||||
|  | 	Multiple   *bool     `bun:",nullzero,notnull,default:false"`          // Is this a multiple choice poll? i.e. can you vote on multiple options. | ||||||
|  | 	HideCounts *bool     `bun:",nullzero,notnull,default:false"`          // Hides vote counts until poll ends. | ||||||
|  | 	Options    []string  `bun:",nullzero,notnull"`                        // The available options for this poll. | ||||||
|  | 	Votes      []int     `bun:",nullzero,notnull"`                        // Vote counts per choice. | ||||||
|  | 	Voters     *int      `bun:",nullzero,notnull"`                        // Total no. voters count. | ||||||
|  | 	StatusID   string    `bun:"type:CHAR(26),nullzero,notnull,unique"`    // Status ID of which this Poll is attached to. | ||||||
|  | 	Status     *Status   `bun:"-"`                                        // The related Status for StatusID (not always set). | ||||||
|  | 	ExpiresAt  time.Time `bun:"type:timestamptz,nullzero,notnull"`        // The expiry date of this Poll. | ||||||
|  | 	ClosedAt   time.Time `bun:"type:timestamptz,nullzero"`                // The closure date of this poll, will be zerotime until set. | ||||||
|  | 	Closing    bool      `bun:"-"`                                        // An ephemeral field only set on Polls in the middle of closing. | ||||||
|  | 	// no creation date, use attached Status.CreatedAt. | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetChoice returns the option index with name. | ||||||
|  | func (p *Poll) GetChoice(name string) int { | ||||||
|  | 	for i, option := range p.Options { | ||||||
|  | 		if strings.EqualFold(option, name) { | ||||||
|  | 			return i | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return -1 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Expired returns whether the Poll is expired (i.e. date is BEFORE now). | ||||||
|  | func (p *Poll) Expired() bool { | ||||||
|  | 	return !p.ExpiresAt.IsZero() && | ||||||
|  | 		time.Now().After(p.ExpiresAt) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Closed returns whether the Poll is closed (i.e. date is set and BEFORE now). | ||||||
|  | func (p *Poll) Closed() bool { | ||||||
|  | 	return !p.ClosedAt.IsZero() && | ||||||
|  | 		time.Now().After(p.ClosedAt) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // IncrementVotes increments Poll vote and voter counts for given choices. | ||||||
|  | func (p *Poll) IncrementVotes(choices []int) { | ||||||
|  | 	if len(choices) == 0 { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	p.CheckVotes() | ||||||
|  | 	for _, choice := range choices { | ||||||
|  | 		p.Votes[choice]++ | ||||||
|  | 	} | ||||||
|  | 	(*p.Voters)++ | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // DecrementVotes decrements Poll vote and voter counts for given choices. | ||||||
|  | func (p *Poll) DecrementVotes(choices []int) { | ||||||
|  | 	if len(choices) == 0 { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	p.CheckVotes() | ||||||
|  | 	for _, choice := range choices { | ||||||
|  | 		if p.Votes[choice] != 0 { | ||||||
|  | 			p.Votes[choice]-- | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if (*p.Voters) != 0 { | ||||||
|  | 		(*p.Voters)-- | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ResetVotes resets all stored vote counts. | ||||||
|  | func (p *Poll) ResetVotes() { | ||||||
|  | 	p.Votes = make([]int, len(p.Options)) | ||||||
|  | 	p.Voters = new(int) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CheckVotes ensures that the Poll.Votes slice is not nil, | ||||||
|  | // else initializing an int slice len+cap equal to Poll.Options. | ||||||
|  | // Note this should not be needed anywhere other than the | ||||||
|  | // database and the processor. | ||||||
|  | func (p *Poll) CheckVotes() { | ||||||
|  | 	if p.Votes == nil { | ||||||
|  | 		p.Votes = make([]int, len(p.Options)) | ||||||
|  | 	} | ||||||
|  | 	if p.Voters == nil { | ||||||
|  | 		p.Voters = new(int) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // PollVote represents a single instance of vote(s) in a Poll by an account. | ||||||
|  | // If the Poll is single-choice, len(.Choices) = 1, if multiple-choice then | ||||||
|  | // len(.Choices) >= 1. Can be remote or local. | ||||||
|  | type PollVote struct { | ||||||
|  | 	ID        string    `bun:"type:CHAR(26),pk,nullzero,notnull,unique"`                    // Unique identity string. | ||||||
|  | 	Choices   []int     `bun:",nullzero,notnull"`                                           // The Poll's option indices of which these are votes for. | ||||||
|  | 	AccountID string    `bun:"type:CHAR(26),nullzero,notnull,unique:in_poll_by_account"`    // Account ID from which this vote originated. | ||||||
|  | 	Account   *Account  `bun:"-"`                                                           // The related Account for AccountID (not always set). | ||||||
|  | 	PollID    string    `bun:"type:CHAR(26),nullzero,notnull,unique:in_poll_by_account"`    // Poll ID of which this is a vote in. | ||||||
|  | 	Poll      *Poll     `bun:"-"`                                                           // The related Poll for PollID (not always set). | ||||||
|  | 	CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // The creation date of this PollVote. | ||||||
|  | } | ||||||
|  | @ -54,6 +54,8 @@ type Status struct { | ||||||
| 	BoostOf                  *Status            `bun:"-"`                                                           // status that corresponds to boostOfID | 	BoostOf                  *Status            `bun:"-"`                                                           // status that corresponds to boostOfID | ||||||
| 	BoostOfAccount           *Account           `bun:"rel:belongs-to"`                                              // account that corresponds to boostOfAccountID | 	BoostOfAccount           *Account           `bun:"rel:belongs-to"`                                              // account that corresponds to boostOfAccountID | ||||||
| 	ThreadID                 string             `bun:"type:CHAR(26),nullzero"`                                      // id of the thread to which this status belongs; only set for remote statuses if a local account is involved at some point in the thread, otherwise null | 	ThreadID                 string             `bun:"type:CHAR(26),nullzero"`                                      // id of the thread to which this status belongs; only set for remote statuses if a local account is involved at some point in the thread, otherwise null | ||||||
|  | 	PollID                   string             `bun:"type:CHAR(26),nullzero"`                                      // | ||||||
|  | 	Poll                     *Poll              `bun:"-"`                                                           // | ||||||
| 	ContentWarning           string             `bun:",nullzero"`                                                   // cw string for this status | 	ContentWarning           string             `bun:",nullzero"`                                                   // cw string for this status | ||||||
| 	Visibility               Visibility         `bun:",nullzero,notnull"`                                           // visibility entry for this status | 	Visibility               Visibility         `bun:",nullzero,notnull"`                                           // visibility entry for this status | ||||||
| 	Sensitive                *bool              `bun:",nullzero,notnull,default:false"`                             // mark the status as sensitive? | 	Sensitive                *bool              `bun:",nullzero,notnull,default:false"`                             // mark the status as sensitive? | ||||||
|  |  | ||||||
|  | @ -387,12 +387,12 @@ statusLoop: | ||||||
| func (p *Processor) deleteAccountNotifications(ctx context.Context, account *gtsmodel.Account) error { | func (p *Processor) deleteAccountNotifications(ctx context.Context, account *gtsmodel.Account) error { | ||||||
| 	// Delete all notifications of all types targeting given account. | 	// Delete all notifications of all types targeting given account. | ||||||
| 	if err := p.state.DB.DeleteNotifications(ctx, nil, account.ID, ""); err != nil && !errors.Is(err, db.ErrNoEntries) { | 	if err := p.state.DB.DeleteNotifications(ctx, nil, account.ID, ""); err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		return err | 		return gtserror.Newf("error deleting notifications targeting account: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Delete all notifications of all types originating from given account. | 	// Delete all notifications of all types originating from given account. | ||||||
| 	if err := p.state.DB.DeleteNotifications(ctx, nil, "", account.ID); err != nil && !errors.Is(err, db.ErrNoEntries) { | 	if err := p.state.DB.DeleteNotifications(ctx, nil, "", account.ID); err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		return err | 		return gtserror.Newf("error deleting notifications by account: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -402,29 +402,35 @@ func (p *Processor) deleteAccountPeripheral(ctx context.Context, account *gtsmod | ||||||
| 	// Delete all bookmarks owned by given account. | 	// Delete all bookmarks owned by given account. | ||||||
| 	if err := p.state.DB.DeleteStatusBookmarks(ctx, account.ID, ""); // nocollapse | 	if err := p.state.DB.DeleteStatusBookmarks(ctx, account.ID, ""); // nocollapse | ||||||
| 	err != nil && !errors.Is(err, db.ErrNoEntries) { | 	err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		return err | 		return gtserror.Newf("error deleting bookmarks by account: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Delete all bookmarks targeting given account. | 	// Delete all bookmarks targeting given account. | ||||||
| 	if err := p.state.DB.DeleteStatusBookmarks(ctx, "", account.ID); // nocollapse | 	if err := p.state.DB.DeleteStatusBookmarks(ctx, "", account.ID); // nocollapse | ||||||
| 	err != nil && !errors.Is(err, db.ErrNoEntries) { | 	err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		return err | 		return gtserror.Newf("error deleting bookmarks targeting account: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Delete all faves owned by given account. | 	// Delete all faves owned by given account. | ||||||
| 	if err := p.state.DB.DeleteStatusFaves(ctx, account.ID, ""); // nocollapse | 	if err := p.state.DB.DeleteStatusFaves(ctx, account.ID, ""); // nocollapse | ||||||
| 	err != nil && !errors.Is(err, db.ErrNoEntries) { | 	err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		return err | 		return gtserror.Newf("error deleting faves by account: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Delete all faves targeting given account. | 	// Delete all faves targeting given account. | ||||||
| 	if err := p.state.DB.DeleteStatusFaves(ctx, "", account.ID); // nocollapse | 	if err := p.state.DB.DeleteStatusFaves(ctx, "", account.ID); // nocollapse | ||||||
| 	err != nil && !errors.Is(err, db.ErrNoEntries) { | 	err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		return err | 		return gtserror.Newf("error deleting faves targeting account: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// TODO: add status mutes here when they're implemented. | 	// TODO: add status mutes here when they're implemented. | ||||||
| 
 | 
 | ||||||
|  | 	// Delete all poll votes owned by given account. | ||||||
|  | 	if err := p.state.DB.DeletePollVotesByAccountID(ctx, account.ID); // nocollapse | ||||||
|  | 	err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
|  | 		return gtserror.Newf("error deleting poll votes by account: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -47,8 +47,11 @@ func (p *Processor) GetTargetAccountBy( | ||||||
| 
 | 
 | ||||||
| 	if target == nil { | 	if target == nil { | ||||||
| 		// DB loader could not find account in database. | 		// DB loader could not find account in database. | ||||||
| 		err := errors.New("target account not found") | 		const text = "target account not found" | ||||||
| 		return nil, false, gtserror.NewErrorNotFound(err) | 		return nil, false, gtserror.NewErrorNotFound( | ||||||
|  | 			errors.New(text), | ||||||
|  | 			text, | ||||||
|  | 		) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Check whether target account is visible to requesting account. | 	// Check whether target account is visible to requesting account. | ||||||
|  | @ -106,8 +109,11 @@ func (p *Processor) GetVisibleTargetAccount( | ||||||
| 
 | 
 | ||||||
| 	if !visible { | 	if !visible { | ||||||
| 		// Pretend account doesn't exist if not visible. | 		// Pretend account doesn't exist if not visible. | ||||||
| 		err := errors.New("target account not found") | 		const text = "target account not found" | ||||||
| 		return nil, gtserror.NewErrorNotFound(err) | 		return nil, gtserror.NewErrorNotFound( | ||||||
|  | 			errors.New(text), | ||||||
|  | 			text, | ||||||
|  | 		) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return target, nil | 	return target, nil | ||||||
|  |  | ||||||
|  | @ -47,8 +47,11 @@ func (p *Processor) GetTargetStatusBy( | ||||||
| 
 | 
 | ||||||
| 	if target == nil { | 	if target == nil { | ||||||
| 		// DB loader could not find status in database. | 		// DB loader could not find status in database. | ||||||
| 		err := errors.New("target status not found") | 		const text = "target status not found" | ||||||
| 		return nil, false, gtserror.NewErrorNotFound(err) | 		return nil, false, gtserror.NewErrorNotFound( | ||||||
|  | 			errors.New(text), | ||||||
|  | 			text, | ||||||
|  | 		) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Check whether target status is visible to requesting account. | 	// Check whether target status is visible to requesting account. | ||||||
|  | @ -106,8 +109,11 @@ func (p *Processor) GetVisibleTargetStatus( | ||||||
| 
 | 
 | ||||||
| 	if !visible { | 	if !visible { | ||||||
| 		// Target should not be seen by requester. | 		// Target should not be seen by requester. | ||||||
| 		err := errors.New("target status not found") | 		const text = "target status not found" | ||||||
| 		return nil, gtserror.NewErrorNotFound(err) | 		return nil, gtserror.NewErrorNotFound( | ||||||
|  | 			errors.New(text), | ||||||
|  | 			text, | ||||||
|  | 		) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return target, nil | 	return target, nil | ||||||
|  |  | ||||||
|  | @ -56,12 +56,12 @@ func (p *Processor) StatusGet(ctx context.Context, requestedUsername string, req | ||||||
| 		return nil, gtserror.NewErrorNotFound(err) | 		return nil, gtserror.NewErrorNotFound(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	asStatus, err := p.converter.StatusToAS(ctx, status) | 	statusable, err := p.converter.StatusToAS(ctx, status) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	data, err := ap.Serialize(asStatus) | 	data, err := ap.Serialize(statusable) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
							
								
								
									
										126
									
								
								internal/processing/polls/expiry.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										126
									
								
								internal/processing/polls/expiry.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,126 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package polls | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/ap" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/messages" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func (p *Processor) ScheduleAll(ctx context.Context) error { | ||||||
|  | 	// Fetch all open polls from the database (barebones models are enough). | ||||||
|  | 	polls, err := p.state.DB.GetOpenPolls(gtscontext.SetBarebones(ctx)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return gtserror.Newf("error getting open polls from db: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var errs gtserror.MultiError | ||||||
|  | 
 | ||||||
|  | 	for _, poll := range polls { | ||||||
|  | 		// Schedule each of the polls and catch any errors. | ||||||
|  | 		if err := p.ScheduleExpiry(ctx, poll); err != nil { | ||||||
|  | 			errs.Append(err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return errs.Combine() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *Processor) ScheduleExpiry(ctx context.Context, poll *gtsmodel.Poll) error { | ||||||
|  | 	// Ensure has a valid expiry. | ||||||
|  | 	if !poll.ClosedAt.IsZero() { | ||||||
|  | 		return gtserror.Newf("poll %s already expired", poll.ID) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Add the given poll to the scheduler. | ||||||
|  | 	ok := p.state.Workers.Scheduler.AddOnce( | ||||||
|  | 		poll.ID, | ||||||
|  | 		poll.ExpiresAt, | ||||||
|  | 		p.onExpiry(poll.ID), | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	if !ok { | ||||||
|  | 		// Failed to add the poll to the scheduler, either it was | ||||||
|  | 		// starting / stopping or there already exists a task for poll. | ||||||
|  | 		return gtserror.Newf("failed adding poll %s to scheduler", poll.ID) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	atStr := poll.ExpiresAt.Local().Format("Jan _2 2006 15:04:05") | ||||||
|  | 	log.Infof(ctx, "scheduled poll expiry for %s at '%s'", poll.ID, atStr) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // onExpiry returns a callback function to be used by the scheduler when the given poll expires. | ||||||
|  | func (p *Processor) onExpiry(pollID string) func(context.Context, time.Time) { | ||||||
|  | 	return func(ctx context.Context, now time.Time) { | ||||||
|  | 		// Get the latest version of poll from database. | ||||||
|  | 		poll, err := p.state.DB.GetPollByID(ctx, pollID) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Errorf(ctx, "error getting poll %s from db: %v", pollID, err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if !poll.ClosedAt.IsZero() { | ||||||
|  | 			// Expiry handler has already been run for this poll. | ||||||
|  | 			log.Errorf(ctx, "poll %s already closed", pollID) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Extract status and | ||||||
|  | 		// set its Poll field. | ||||||
|  | 		status := poll.Status | ||||||
|  | 		status.Poll = poll | ||||||
|  | 
 | ||||||
|  | 		// Ensure the status is fully populated (we need the account) | ||||||
|  | 		if err := p.state.DB.PopulateStatus(ctx, status); err != nil { | ||||||
|  | 			log.Errorf(ctx, "error populating poll %s status: %v", pollID, err) | ||||||
|  | 
 | ||||||
|  | 			if status.Account == nil { | ||||||
|  | 				// cannot continue without | ||||||
|  | 				// status account author. | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Set "closed" time. | ||||||
|  | 		poll.ClosedAt = now | ||||||
|  | 		poll.Closing = true | ||||||
|  | 
 | ||||||
|  | 		// Update the Poll to mark it as closed in the database. | ||||||
|  | 		if err := p.state.DB.UpdatePoll(ctx, poll, "closed_at"); err != nil { | ||||||
|  | 			log.Errorf(ctx, "error updating poll %s in db: %v", pollID, err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Enqueue a status update operation to the client API worker, | ||||||
|  | 		// this will asynchronously send an update with the Poll close time. | ||||||
|  | 		p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ | ||||||
|  | 			APActivityType: ap.ActivityUpdate, | ||||||
|  | 			APObjectType:   ap.ObjectNote, | ||||||
|  | 			GTSModel:       status, | ||||||
|  | 			OriginAccount:  status.Account, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										37
									
								
								internal/processing/polls/get.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								internal/processing/polls/get.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,37 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package polls | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 
 | ||||||
|  | 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func (p *Processor) PollGet(ctx context.Context, requester *gtsmodel.Account, pollID string) (*apimodel.Poll, gtserror.WithCode) { | ||||||
|  | 	// Get (+ check visibility of) requested poll with ID. | ||||||
|  | 	poll, errWithCode := p.getTargetPoll(ctx, requester, pollID) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		return nil, errWithCode | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Return converted API model poll. | ||||||
|  | 	return p.toAPIPoll(ctx, requester, poll) | ||||||
|  | } | ||||||
							
								
								
									
										91
									
								
								internal/processing/polls/poll.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								internal/processing/polls/poll.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,91 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package polls | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 
 | ||||||
|  | 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/processing/common" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/state" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/typeutils" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Processor struct { | ||||||
|  | 	// common processor logic | ||||||
|  | 	c *common.Processor | ||||||
|  | 
 | ||||||
|  | 	state     *state.State | ||||||
|  | 	converter *typeutils.Converter | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func New(common *common.Processor, state *state.State, converter *typeutils.Converter) Processor { | ||||||
|  | 	return Processor{ | ||||||
|  | 		c:         common, | ||||||
|  | 		state:     state, | ||||||
|  | 		converter: converter, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getTargetPoll fetches a target poll ID for requesting account, taking visibility of the poll's originating status into account. | ||||||
|  | func (p *Processor) getTargetPoll(ctx context.Context, requestingAccount *gtsmodel.Account, targetID string) (*gtsmodel.Poll, gtserror.WithCode) { | ||||||
|  | 	// Load the requested poll with ID. | ||||||
|  | 	// (barebones as we fetch status below) | ||||||
|  | 	poll, err := p.state.DB.GetPollByID( | ||||||
|  | 		gtscontext.SetBarebones(ctx), | ||||||
|  | 		targetID, | ||||||
|  | 	) | ||||||
|  | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
|  | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if poll == nil { | ||||||
|  | 		// No poll could be found for given ID. | ||||||
|  | 		const text = "target poll not found" | ||||||
|  | 		return nil, gtserror.NewErrorNotFound( | ||||||
|  | 			errors.New(text), | ||||||
|  | 			text, | ||||||
|  | 		) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Check that we can see + fetch the originating status for requesting account. | ||||||
|  | 	status, errWithCode := p.c.GetVisibleTargetStatus(ctx, requestingAccount, poll.StatusID) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		return nil, errWithCode | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Update poll status. | ||||||
|  | 	poll.Status = status | ||||||
|  | 
 | ||||||
|  | 	return poll, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // toAPIPoll converrts a given Poll to frontend API model, returning an appropriate error with HTTP code on failure. | ||||||
|  | func (p *Processor) toAPIPoll(ctx context.Context, requester *gtsmodel.Account, poll *gtsmodel.Poll) (*apimodel.Poll, gtserror.WithCode) { | ||||||
|  | 	apiPoll, err := p.converter.PollToAPIPoll(ctx, requester, poll) | ||||||
|  | 	if err != nil { | ||||||
|  | 		err := gtserror.Newf("error converting to api model: %w", err) | ||||||
|  | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
|  | 	} | ||||||
|  | 	return apiPoll, nil | ||||||
|  | } | ||||||
							
								
								
									
										234
									
								
								internal/processing/polls/poll_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										234
									
								
								internal/processing/polls/poll_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,234 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package polls_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"math/rand" | ||||||
|  | 	"net/http" | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"github.com/stretchr/testify/suite" | ||||||
|  | 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/media" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/processing/common" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/processing/polls" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/state" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/typeutils" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/visibility" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/testrig" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type PollTestSuite struct { | ||||||
|  | 	suite.Suite | ||||||
|  | 	state  state.State | ||||||
|  | 	filter *visibility.Filter | ||||||
|  | 	polls  polls.Processor | ||||||
|  | 
 | ||||||
|  | 	testAccounts map[string]*gtsmodel.Account | ||||||
|  | 	testPolls    map[string]*gtsmodel.Poll | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PollTestSuite) SetupTest() { | ||||||
|  | 	testrig.InitTestConfig() | ||||||
|  | 	testrig.InitTestLog() | ||||||
|  | 	suite.state.Caches.Init() | ||||||
|  | 	testrig.StartWorkers(&suite.state) | ||||||
|  | 	testrig.NewTestDB(&suite.state) | ||||||
|  | 	converter := typeutils.NewConverter(&suite.state) | ||||||
|  | 	controller := testrig.NewTestTransportController(&suite.state, nil) | ||||||
|  | 	mediaMgr := media.NewManager(&suite.state) | ||||||
|  | 	federator := testrig.NewTestFederator(&suite.state, controller, mediaMgr) | ||||||
|  | 	suite.filter = visibility.NewFilter(&suite.state) | ||||||
|  | 	common := common.New(&suite.state, converter, federator, suite.filter) | ||||||
|  | 	suite.polls = polls.New(&common, &suite.state, converter) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PollTestSuite) TearDownTest() { | ||||||
|  | 	testrig.StopWorkers(&suite.state) | ||||||
|  | 	testrig.StandardDBTeardown(suite.state.DB) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PollTestSuite) TestPollGet() { | ||||||
|  | 	// Create a new context for this test. | ||||||
|  | 	ctx, cncl := context.WithCancel(context.Background()) | ||||||
|  | 	defer cncl() | ||||||
|  | 
 | ||||||
|  | 	// Perform test for all requester + poll combos. | ||||||
|  | 	for _, account := range suite.testAccounts { | ||||||
|  | 		for _, poll := range suite.testPolls { | ||||||
|  | 			suite.testPollGet(ctx, account, poll) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PollTestSuite) testPollGet(ctx context.Context, requester *gtsmodel.Account, poll *gtsmodel.Poll) { | ||||||
|  | 	// Ensure poll model is fully populated before anything. | ||||||
|  | 	if err := suite.state.DB.PopulatePoll(ctx, poll); err != nil { | ||||||
|  | 		suite.T().Fatalf("error populating poll: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var check func(*apimodel.Poll, gtserror.WithCode) bool | ||||||
|  | 
 | ||||||
|  | 	switch { | ||||||
|  | 	case !pollIsVisible(suite.filter, ctx, requester, poll): | ||||||
|  | 		// Poll should not be visible to requester, this should | ||||||
|  | 		// return an error code 404 (to prevent info leak). | ||||||
|  | 		check = func(poll *apimodel.Poll, err gtserror.WithCode) bool { | ||||||
|  | 			return poll == nil && err.Code() == http.StatusNotFound | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 	default: | ||||||
|  | 		// All other cases should succeed! i.e. no error and poll returned. | ||||||
|  | 		check = func(poll *apimodel.Poll, err gtserror.WithCode) bool { | ||||||
|  | 			return poll != nil && err == nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Perform the poll vote and check the expected response. | ||||||
|  | 	if !check(suite.polls.PollGet(ctx, requester, poll.ID)) { | ||||||
|  | 		suite.T().Errorf("unexpected response for poll get by %s", requester.DisplayName) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PollTestSuite) TestPollVote() { | ||||||
|  | 	// Create a new context for this test. | ||||||
|  | 	ctx, cncl := context.WithCancel(context.Background()) | ||||||
|  | 	defer cncl() | ||||||
|  | 
 | ||||||
|  | 	// randomChoices generates random vote choices in poll. | ||||||
|  | 	randomChoices := func(poll *gtsmodel.Poll) []int { | ||||||
|  | 		var max int | ||||||
|  | 		if *poll.Multiple { | ||||||
|  | 			max = len(poll.Options) | ||||||
|  | 		} else { | ||||||
|  | 			max = 1 | ||||||
|  | 		} | ||||||
|  | 		count := 1 + rand.Intn(max) | ||||||
|  | 		choices := make([]int, count) | ||||||
|  | 		for i := range choices { | ||||||
|  | 			choices[i] = rand.Intn(len(poll.Options)) | ||||||
|  | 		} | ||||||
|  | 		return choices | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Perform test for all requester + poll combos. | ||||||
|  | 	for _, account := range suite.testAccounts { | ||||||
|  | 		for _, poll := range suite.testPolls { | ||||||
|  | 			// Generate some valid choices and test. | ||||||
|  | 			choices := randomChoices(poll) | ||||||
|  | 			suite.testPollVote(ctx, | ||||||
|  | 				account, | ||||||
|  | 				poll, | ||||||
|  | 				choices, | ||||||
|  | 			) | ||||||
|  | 
 | ||||||
|  | 			// Test with empty choices. | ||||||
|  | 			suite.testPollVote(ctx, | ||||||
|  | 				account, | ||||||
|  | 				poll, | ||||||
|  | 				nil, | ||||||
|  | 			) | ||||||
|  | 
 | ||||||
|  | 			// Test with out of range choice. | ||||||
|  | 			suite.testPollVote(ctx, | ||||||
|  | 				account, | ||||||
|  | 				poll, | ||||||
|  | 				[]int{len(poll.Options)}, | ||||||
|  | 			) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PollTestSuite) testPollVote(ctx context.Context, requester *gtsmodel.Account, poll *gtsmodel.Poll, choices []int) { | ||||||
|  | 	// Ensure poll model is fully populated before anything. | ||||||
|  | 	if err := suite.state.DB.PopulatePoll(ctx, poll); err != nil { | ||||||
|  | 		suite.T().Fatalf("error populating poll: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var check func(*apimodel.Poll, gtserror.WithCode) bool | ||||||
|  | 
 | ||||||
|  | 	switch { | ||||||
|  | 	case !poll.ClosedAt.IsZero(): | ||||||
|  | 		// Poll is already closed, i.e. no new votes allowed! | ||||||
|  | 		// This should return an error 422 (unprocessable entity). | ||||||
|  | 		check = func(poll *apimodel.Poll, err gtserror.WithCode) bool { | ||||||
|  | 			return poll == nil && err.Code() == http.StatusUnprocessableEntity | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 	case !voteChoicesAreValid(poll, choices): | ||||||
|  | 		// These are invalid vote choices, this should return | ||||||
|  | 		// an error code 400 to indicate invalid request data. | ||||||
|  | 		check = func(poll *apimodel.Poll, err gtserror.WithCode) bool { | ||||||
|  | 			return poll == nil && err.Code() == http.StatusBadRequest | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 	case poll.Status.AccountID == requester.ID: | ||||||
|  | 		// Immediately we know that poll owner cannot vote in | ||||||
|  | 		// their own poll. this should return an error 422. | ||||||
|  | 		check = func(poll *apimodel.Poll, err gtserror.WithCode) bool { | ||||||
|  | 			return poll == nil && err.Code() == http.StatusUnprocessableEntity | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 	case !pollIsVisible(suite.filter, ctx, requester, poll): | ||||||
|  | 		// Poll should not be visible to requester, this should | ||||||
|  | 		// return an error code 404 (to prevent info leak). | ||||||
|  | 		check = func(poll *apimodel.Poll, err gtserror.WithCode) bool { | ||||||
|  | 			return poll == nil && err.Code() == http.StatusNotFound | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 	default: | ||||||
|  | 		// All other cases should succeed! i.e. no error and poll returned. | ||||||
|  | 		check = func(poll *apimodel.Poll, err gtserror.WithCode) bool { | ||||||
|  | 			return poll != nil && err == nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Perform the poll vote and check the expected response. | ||||||
|  | 	if !check(suite.polls.PollVote(ctx, requester, poll.ID, choices)) { | ||||||
|  | 		suite.T().Errorf("unexpected response for poll vote by %s with %v", requester.DisplayName, choices) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // voteChoicesAreValid is a utility function to check whether choices are valid for poll. | ||||||
|  | func voteChoicesAreValid(poll *gtsmodel.Poll, choices []int) bool { | ||||||
|  | 	if len(choices) == 0 || !*poll.Multiple && len(choices) > 1 { | ||||||
|  | 		// Invalid number of vote choices. | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	for _, choice := range choices { | ||||||
|  | 		if choice < 0 || choice >= len(poll.Options) { | ||||||
|  | 			// Choice index out of range. | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return true | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // pollIsVisible is a short-hand function to return only a single boolean value for a visibility check on poll source status to account. | ||||||
|  | func pollIsVisible(filter *visibility.Filter, ctx context.Context, to *gtsmodel.Account, poll *gtsmodel.Poll) bool { | ||||||
|  | 	visible, _ := filter.StatusVisible(ctx, to, poll.Status) | ||||||
|  | 	return visible | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestPollTestSuite(t *testing.T) { | ||||||
|  | 	suite.Run(t, new(PollTestSuite)) | ||||||
|  | } | ||||||
							
								
								
									
										108
									
								
								internal/processing/polls/vote.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										108
									
								
								internal/processing/polls/vote.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,108 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package polls | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 
 | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/ap" | ||||||
|  | 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/id" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/messages" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func (p *Processor) PollVote(ctx context.Context, requester *gtsmodel.Account, pollID string, choices []int) (*apimodel.Poll, gtserror.WithCode) { | ||||||
|  | 	// Get (+ check visibility of) requested poll with ID. | ||||||
|  | 	poll, errWithCode := p.getTargetPoll(ctx, requester, pollID) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		return nil, errWithCode | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	switch { | ||||||
|  | 	// Poll author isn't allowed to vote in their own poll. | ||||||
|  | 	case requester.ID == poll.Status.AccountID: | ||||||
|  | 		const text = "you can't vote in your own poll" | ||||||
|  | 		return nil, gtserror.NewErrorUnprocessableEntity(errors.New(text), text) | ||||||
|  | 
 | ||||||
|  | 	// Poll has already closed, no more voting! | ||||||
|  | 	case !poll.ClosedAt.IsZero(): | ||||||
|  | 		const text = "poll already closed" | ||||||
|  | 		return nil, gtserror.NewErrorUnprocessableEntity(errors.New(text), text) | ||||||
|  | 
 | ||||||
|  | 	// No choices given, or multiple given for single-choice poll. | ||||||
|  | 	case len(choices) == 0 || (!*poll.Multiple && len(choices) > 1): | ||||||
|  | 		const text = "invalid number of choices for poll" | ||||||
|  | 		return nil, gtserror.NewErrorBadRequest(errors.New(text), text) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, choice := range choices { | ||||||
|  | 		if choice < 0 || choice >= len(poll.Options) { | ||||||
|  | 			// This is an invalid choice (index out of range). | ||||||
|  | 			const text = "invalid option index for poll" | ||||||
|  | 			return nil, gtserror.NewErrorBadRequest(errors.New(text), text) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Wrap the choices in a PollVote model. | ||||||
|  | 	vote := >smodel.PollVote{ | ||||||
|  | 		ID:        id.NewULID(), | ||||||
|  | 		Choices:   choices, | ||||||
|  | 		AccountID: requester.ID, | ||||||
|  | 		Account:   requester, | ||||||
|  | 		PollID:    pollID, | ||||||
|  | 		Poll:      poll, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Insert the new poll votes into the database. | ||||||
|  | 	err := p.state.DB.PutPollVote(ctx, vote) | ||||||
|  | 	switch { | ||||||
|  | 
 | ||||||
|  | 	case err == nil: | ||||||
|  | 		// no issue. | ||||||
|  | 
 | ||||||
|  | 	case errors.Is(err, db.ErrAlreadyExists): | ||||||
|  | 		// Users cannot vote multiple *times* (not choices). | ||||||
|  | 		const text = "you have already voted in poll" | ||||||
|  | 		return nil, gtserror.NewErrorUnprocessableEntity(err, text) | ||||||
|  | 
 | ||||||
|  | 	default: | ||||||
|  | 		// Any other irrecoverable database error. | ||||||
|  | 		err := gtserror.Newf("error inserting poll vote: %w", err) | ||||||
|  | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Enqueue worker task to handle side-effects of user poll vote(s). | ||||||
|  | 	p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ | ||||||
|  | 		APActivityType: ap.ActivityCreate, | ||||||
|  | 		APObjectType:   ap.ActivityQuestion, | ||||||
|  | 		GTSModel:       vote, // the vote choices | ||||||
|  | 		OriginAccount:  requester, | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	// Before returning the converted poll model, | ||||||
|  | 	// increment the vote counts on our local copy | ||||||
|  | 	// to get latest, instead of another db query. | ||||||
|  | 	poll.IncrementVotes(choices) | ||||||
|  | 
 | ||||||
|  | 	// Return converted API model poll. | ||||||
|  | 	return p.toAPIPoll(ctx, requester, poll) | ||||||
|  | } | ||||||
|  | @ -30,6 +30,7 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing/list" | 	"github.com/superseriousbusiness/gotosocial/internal/processing/list" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing/markers" | 	"github.com/superseriousbusiness/gotosocial/internal/processing/markers" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing/media" | 	"github.com/superseriousbusiness/gotosocial/internal/processing/media" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/processing/polls" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing/report" | 	"github.com/superseriousbusiness/gotosocial/internal/processing/report" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing/search" | 	"github.com/superseriousbusiness/gotosocial/internal/processing/search" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing/status" | 	"github.com/superseriousbusiness/gotosocial/internal/processing/status" | ||||||
|  | @ -64,6 +65,7 @@ type Processor struct { | ||||||
| 	list     list.Processor | 	list     list.Processor | ||||||
| 	markers  markers.Processor | 	markers  markers.Processor | ||||||
| 	media    media.Processor | 	media    media.Processor | ||||||
|  | 	polls    polls.Processor | ||||||
| 	report   report.Processor | 	report   report.Processor | ||||||
| 	search   search.Processor | 	search   search.Processor | ||||||
| 	status   status.Processor | 	status   status.Processor | ||||||
|  | @ -97,6 +99,10 @@ func (p *Processor) Media() *media.Processor { | ||||||
| 	return &p.media | 	return &p.media | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (p *Processor) Polls() *polls.Processor { | ||||||
|  | 	return &p.polls | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (p *Processor) Report() *report.Processor { | func (p *Processor) Report() *report.Processor { | ||||||
| 	return &p.report | 	return &p.report | ||||||
| } | } | ||||||
|  | @ -151,23 +157,22 @@ func NewProcessor( | ||||||
| 	// Start with sub processors that will | 	// Start with sub processors that will | ||||||
| 	// be required by the workers processor. | 	// be required by the workers processor. | ||||||
| 	commonProcessor := common.New(state, converter, federator, filter) | 	commonProcessor := common.New(state, converter, federator, filter) | ||||||
| 	accountProcessor := account.New(&commonProcessor, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc) | 	processor.account = account.New(&commonProcessor, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc) | ||||||
| 	mediaProcessor := media.New(state, converter, mediaManager, federator.TransportController()) | 	processor.media = media.New(state, converter, mediaManager, federator.TransportController()) | ||||||
| 	streamProcessor := stream.New(state, oauthServer) | 	processor.stream = stream.New(state, oauthServer) | ||||||
| 
 | 
 | ||||||
| 	// Instantiate the rest of the sub | 	// Instantiate the rest of the sub | ||||||
| 	// processors + pin them to this struct. | 	// processors + pin them to this struct. | ||||||
| 	processor.account = accountProcessor | 	processor.account = account.New(&commonProcessor, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc) | ||||||
| 	processor.admin = admin.New(state, cleaner, converter, mediaManager, federator.TransportController(), emailSender) | 	processor.admin = admin.New(state, cleaner, converter, mediaManager, federator.TransportController(), emailSender) | ||||||
| 	processor.fedi = fedi.New(state, converter, federator, filter) | 	processor.fedi = fedi.New(state, converter, federator, filter) | ||||||
| 	processor.list = list.New(state, converter) | 	processor.list = list.New(state, converter) | ||||||
| 	processor.markers = markers.New(state, converter) | 	processor.markers = markers.New(state, converter) | ||||||
| 	processor.media = mediaProcessor | 	processor.polls = polls.New(&commonProcessor, state, converter) | ||||||
| 	processor.report = report.New(state, converter) | 	processor.report = report.New(state, converter) | ||||||
| 	processor.timeline = timeline.New(state, converter, filter) | 	processor.timeline = timeline.New(state, converter, filter) | ||||||
| 	processor.search = search.New(state, federator, converter, filter) | 	processor.search = search.New(state, federator, converter, filter) | ||||||
| 	processor.status = status.New(&commonProcessor, state, federator, converter, filter, parseMentionFunc) | 	processor.status = status.New(state, &commonProcessor, &processor.polls, federator, converter, filter, parseMentionFunc) | ||||||
| 	processor.stream = streamProcessor |  | ||||||
| 	processor.user = user.New(state, emailSender) | 	processor.user = user.New(state, emailSender) | ||||||
| 
 | 
 | ||||||
| 	// Workers processor handles asynchronous | 	// Workers processor handles asynchronous | ||||||
|  | @ -179,9 +184,9 @@ func NewProcessor( | ||||||
| 		converter, | 		converter, | ||||||
| 		filter, | 		filter, | ||||||
| 		emailSender, | 		emailSender, | ||||||
| 		&accountProcessor, | 		&processor.account, | ||||||
| 		&mediaProcessor, | 		&processor.media, | ||||||
| 		&streamProcessor, | 		&processor.stream, | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	return processor | 	return processor | ||||||
|  |  | ||||||
|  | @ -66,6 +66,26 @@ func (p *Processor) Create(ctx context.Context, requestingAccount *gtsmodel.Acco | ||||||
| 		Text:                     form.Status, | 		Text:                     form.Status, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	if form.Poll != nil { | ||||||
|  | 		// Update the status AS type to "Question". | ||||||
|  | 		status.ActivityStreamsType = ap.ActivityQuestion | ||||||
|  | 
 | ||||||
|  | 		// Create new poll for status from form. | ||||||
|  | 		secs := time.Duration(form.Poll.ExpiresIn) | ||||||
|  | 		status.Poll = >smodel.Poll{ | ||||||
|  | 			ID:         id.NewULID(), | ||||||
|  | 			Multiple:   &form.Poll.Multiple, | ||||||
|  | 			HideCounts: &form.Poll.HideTotals, | ||||||
|  | 			Options:    form.Poll.Options, | ||||||
|  | 			StatusID:   statusID, | ||||||
|  | 			Status:     status, | ||||||
|  | 			ExpiresAt:  now.Add(secs * time.Second), | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Set poll ID on the status. | ||||||
|  | 		status.PollID = status.Poll.ID | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	if errWithCode := p.processReplyToID(ctx, form, requestingAccount.ID, status); errWithCode != nil { | 	if errWithCode := p.processReplyToID(ctx, form, requestingAccount.ID, status); errWithCode != nil { | ||||||
| 		return nil, errWithCode | 		return nil, errWithCode | ||||||
| 	} | 	} | ||||||
|  | @ -90,6 +110,14 @@ func (p *Processor) Create(ctx context.Context, requestingAccount *gtsmodel.Acco | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	if status.Poll != nil { | ||||||
|  | 		// Try to insert the new status poll in the database. | ||||||
|  | 		if err := p.state.DB.PutPoll(ctx, status.Poll); err != nil { | ||||||
|  | 			err := gtserror.Newf("error inserting poll in db: %w", err) | ||||||
|  | 			return nil, gtserror.NewErrorInternalError(err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Insert this new status in the database. | 	// Insert this new status in the database. | ||||||
| 	if err := p.state.DB.PutStatus(ctx, status); err != nil { | 	if err := p.state.DB.PutStatus(ctx, status); err != nil { | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
|  | @ -103,6 +131,15 @@ func (p *Processor) Create(ctx context.Context, requestingAccount *gtsmodel.Acco | ||||||
| 		OriginAccount:  requestingAccount, | 		OriginAccount:  requestingAccount, | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
|  | 	if status.Poll != nil { | ||||||
|  | 		// Now that the status is inserted, and side effects queued, | ||||||
|  | 		// attempt to schedule an expiry handler for the status poll. | ||||||
|  | 		if err := p.polls.ScheduleExpiry(ctx, status.Poll); err != nil { | ||||||
|  | 			err := gtserror.Newf("error scheduling poll expiry: %w", err) | ||||||
|  | 			return nil, gtserror.NewErrorInternalError(err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	return p.c.GetAPIStatus(ctx, requestingAccount, status) | 	return p.c.GetAPIStatus(ctx, requestingAccount, status) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -370,6 +407,18 @@ func (p *Processor) processContent(ctx context.Context, parseMention gtsmodel.Pa | ||||||
| 	status.ContentWarning = warningRes.HTML | 	status.ContentWarning = warningRes.HTML | ||||||
| 	status.Emojis = append(status.Emojis, warningRes.Emojis...) | 	status.Emojis = append(status.Emojis, warningRes.Emojis...) | ||||||
| 
 | 
 | ||||||
|  | 	if status.Poll != nil { | ||||||
|  | 		for i := range status.Poll.Options { | ||||||
|  | 			// Sanitize each option title name and format. | ||||||
|  | 			option := text.SanitizeToPlaintext(status.Poll.Options[i]) | ||||||
|  | 			optionRes := formatInput(format, option) | ||||||
|  | 
 | ||||||
|  | 			// Collect each formatted result. | ||||||
|  | 			status.Poll.Options[i] = optionRes.HTML | ||||||
|  | 			status.Emojis = append(status.Emojis, optionRes.Emojis...) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Gather all the database IDs from each of the gathered status mentions, tags, and emojis. | 	// Gather all the database IDs from each of the gathered status mentions, tags, and emojis. | ||||||
| 	status.MentionIDs = gatherIDs(status.Mentions, func(mention *gtsmodel.Mention) string { return mention.ID }) | 	status.MentionIDs = gatherIDs(status.Mentions, func(mention *gtsmodel.Mention) string { return mention.ID }) | ||||||
| 	status.TagIDs = gatherIDs(status.Tags, func(tag *gtsmodel.Tag) string { return tag.ID }) | 	status.TagIDs = gatherIDs(status.Tags, func(tag *gtsmodel.Tag) string { return tag.ID }) | ||||||
|  |  | ||||||
|  | @ -21,6 +21,7 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/federation" | 	"github.com/superseriousbusiness/gotosocial/internal/federation" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing/common" | 	"github.com/superseriousbusiness/gotosocial/internal/processing/common" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/processing/polls" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/state" | 	"github.com/superseriousbusiness/gotosocial/internal/state" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/text" | 	"github.com/superseriousbusiness/gotosocial/internal/text" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/typeutils" | 	"github.com/superseriousbusiness/gotosocial/internal/typeutils" | ||||||
|  | @ -28,7 +29,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type Processor struct { | type Processor struct { | ||||||
| 	// common processor logic | 	// embedded common logic | ||||||
| 	c *common.Processor | 	c *common.Processor | ||||||
| 
 | 
 | ||||||
| 	state        *state.State | 	state        *state.State | ||||||
|  | @ -37,12 +38,16 @@ type Processor struct { | ||||||
| 	filter       *visibility.Filter | 	filter       *visibility.Filter | ||||||
| 	formatter    *text.Formatter | 	formatter    *text.Formatter | ||||||
| 	parseMention gtsmodel.ParseMentionFunc | 	parseMention gtsmodel.ParseMentionFunc | ||||||
|  | 
 | ||||||
|  | 	// other processors | ||||||
|  | 	polls *polls.Processor | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // New returns a new status processor. | // New returns a new status processor. | ||||||
| func New( | func New( | ||||||
| 	common *common.Processor, |  | ||||||
| 	state *state.State, | 	state *state.State, | ||||||
|  | 	common *common.Processor, | ||||||
|  | 	polls *polls.Processor, | ||||||
| 	federator *federation.Federator, | 	federator *federation.Federator, | ||||||
| 	converter *typeutils.Converter, | 	converter *typeutils.Converter, | ||||||
| 	filter *visibility.Filter, | 	filter *visibility.Filter, | ||||||
|  | @ -56,5 +61,6 @@ func New( | ||||||
| 		filter:       filter, | 		filter:       filter, | ||||||
| 		formatter:    text.NewFormatter(state.DB), | 		formatter:    text.NewFormatter(state.DB), | ||||||
| 		parseMention: parseMention, | 		parseMention: parseMention, | ||||||
|  | 		polls:        polls, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -25,6 +25,7 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/media" | 	"github.com/superseriousbusiness/gotosocial/internal/media" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing" | 	"github.com/superseriousbusiness/gotosocial/internal/processing" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing/common" | 	"github.com/superseriousbusiness/gotosocial/internal/processing/common" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/processing/polls" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing/status" | 	"github.com/superseriousbusiness/gotosocial/internal/processing/status" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/state" | 	"github.com/superseriousbusiness/gotosocial/internal/state" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/storage" | 	"github.com/superseriousbusiness/gotosocial/internal/storage" | ||||||
|  | @ -96,8 +97,8 @@ func (suite *StatusStandardTestSuite) SetupTest() { | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	common := common.New(&suite.state, suite.typeConverter, suite.federator, filter) | 	common := common.New(&suite.state, suite.typeConverter, suite.federator, filter) | ||||||
| 
 | 	polls := polls.New(&common, &suite.state, suite.typeConverter) | ||||||
| 	suite.status = status.New(&common, &suite.state, suite.federator, suite.typeConverter, filter, processing.GetParseMentionFunc(suite.db, suite.federator)) | 	suite.status = status.New(&suite.state, &common, &polls, suite.federator, suite.typeConverter, filter, processing.GetParseMentionFunc(suite.db, suite.federator)) | ||||||
| 
 | 
 | ||||||
| 	testrig.StandardDBSetup(suite.db, suite.testAccounts) | 	testrig.StandardDBSetup(suite.db, suite.testAccounts) | ||||||
| 	testrig.StandardStorageSetup(suite.storage, "../../../testrig/media") | 	testrig.StandardStorageSetup(suite.storage, "../../../testrig/media") | ||||||
|  |  | ||||||
|  | @ -77,8 +77,8 @@ func (suite *NotificationTestSuite) TestStreamNotification() { | ||||||
|     "header_static": "http://localhost:8080/assets/default_header.png", |     "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|     "followers_count": 0, |     "followers_count": 0, | ||||||
|     "following_count": 0, |     "following_count": 0, | ||||||
|     "statuses_count": 1, |     "statuses_count": 2, | ||||||
|     "last_status_at": "2021-09-20T10:40:37.000Z", |     "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|     "emojis": [], |     "emojis": [], | ||||||
|     "fields": [] |     "fields": [] | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  | @ -158,26 +158,52 @@ func (f *federate) CreateStatus(ctx context.Context, status *gtsmodel.Status) er | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Convert status to ActivityStreams Statusable implementing type. | 	// Convert status to AS Statusable implementing type. | ||||||
| 	statusable, err := f.converter.StatusToAS(ctx, status) | 	statusable, err := f.converter.StatusToAS(ctx, status) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return gtserror.Newf("error converting status to Statusable: %w", err) | 		return gtserror.Newf("error converting status to Statusable: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Use ActivityStreams Statusable type as Object of Create. | 	// Send a Create activity with Statusable via the Actor's outbox. | ||||||
| 	create, err := f.converter.WrapStatusableInCreate(statusable, false) | 	create := typeutils.WrapStatusableInCreate(statusable, false) | ||||||
| 	if err != nil { | 	if _, err := f.FederatingActor().Send(ctx, outboxIRI, create); err != nil { | ||||||
| 		return gtserror.Newf("error wrapping Statusable in Create: %w", err) | 		return gtserror.Newf("error sending Create activity via outbox %s: %w", outboxIRI, err) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| 	// Send the Create via the Actor's outbox. | func (f *federate) CreatePollVote(ctx context.Context, poll *gtsmodel.Poll, vote *gtsmodel.PollVote) error { | ||||||
| 	if _, err := f.FederatingActor().Send( | 	// Extract status from poll. | ||||||
| 		ctx, outboxIRI, create, | 	status := poll.Status | ||||||
| 	); err != nil { | 
 | ||||||
| 		return gtserror.Newf( | 	// Do nothing if the status | ||||||
| 			"error sending activity %T via outbox %s: %w", | 	// shouldn't be federated. | ||||||
| 			create, outboxIRI, err, | 	if !*status.Federated { | ||||||
| 		) | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Do nothing if this is | ||||||
|  | 	// a vote in our status. | ||||||
|  | 	if *status.Local { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Parse the outbox URI of the poll vote author. | ||||||
|  | 	outboxIRI, err := parseURI(vote.Account.OutboxURI) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Convert votes to AS PollOptionable implementing type. | ||||||
|  | 	notes, err := f.converter.PollVoteToASOptions(ctx, vote) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return gtserror.Newf("error converting to notes: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Send a Create activity with PollOptionables via the Actor's outbox. | ||||||
|  | 	create := typeutils.WrapPollOptionablesInCreate(notes...) | ||||||
|  | 	if _, err := f.FederatingActor().Send(ctx, outboxIRI, create); err != nil { | ||||||
|  | 		return gtserror.Newf("error sending Create activity via outbox %s: %w", outboxIRI, err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -256,13 +282,8 @@ func (f *federate) UpdateStatus(ctx context.Context, status *gtsmodel.Status) er | ||||||
| 		return gtserror.Newf("error converting status to Statusable: %w", err) | 		return gtserror.Newf("error converting status to Statusable: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Use ActivityStreams Statusable type as Object of Update. | 	// Send an Update activity with Statusable via the Actor's outbox. | ||||||
| 	update, err := f.converter.WrapStatusableInUpdate(statusable, false) | 	update := typeutils.WrapStatusableInUpdate(statusable, false) | ||||||
| 	if err != nil { |  | ||||||
| 		return gtserror.Newf("error wrapping Statusable in Update: %w", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Send the Update activity with Statusable via the Actor's outbox. |  | ||||||
| 	if _, err := f.FederatingActor().Send(ctx, outboxIRI, update); err != nil { | 	if _, err := f.FederatingActor().Send(ctx, outboxIRI, update); err != nil { | ||||||
| 		return gtserror.Newf("error sending Update activity via outbox %s: %w", outboxIRI, err) | 		return gtserror.Newf("error sending Update activity via outbox %s: %w", outboxIRI, err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -93,6 +93,13 @@ func (p *Processor) ProcessFromClientAPI(ctx context.Context, cMsg messages.From | ||||||
| 		case ap.ObjectNote: | 		case ap.ObjectNote: | ||||||
| 			return p.clientAPI.CreateStatus(ctx, cMsg) | 			return p.clientAPI.CreateStatus(ctx, cMsg) | ||||||
| 
 | 
 | ||||||
|  | 		// CREATE QUESTION | ||||||
|  | 		// (note we don't handle poll *votes* as AS | ||||||
|  | 		// question type when federating (just notes), | ||||||
|  | 		// but it makes for a nicer type switch here. | ||||||
|  | 		case ap.ActivityQuestion: | ||||||
|  | 			return p.clientAPI.CreatePollVote(ctx, cMsg) | ||||||
|  | 
 | ||||||
| 		// CREATE FOLLOW (request) | 		// CREATE FOLLOW (request) | ||||||
| 		case ap.ActivityFollow: | 		case ap.ActivityFollow: | ||||||
| 			return p.clientAPI.CreateFollowReq(ctx, cMsg) | 			return p.clientAPI.CreateFollowReq(ctx, cMsg) | ||||||
|  | @ -189,7 +196,7 @@ func (p *Processor) ProcessFromClientAPI(ctx context.Context, cMsg messages.From | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return gtserror.Newf("unhandled: %s %s", cMsg.APActivityType, cMsg.APObjectType) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *clientAPI) CreateAccount(ctx context.Context, cMsg messages.FromClientAPI) error { | func (p *clientAPI) CreateAccount(ctx context.Context, cMsg messages.FromClientAPI) error { | ||||||
|  | @ -205,7 +212,7 @@ func (p *clientAPI) CreateAccount(ctx context.Context, cMsg messages.FromClientA | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.surface.emailPleaseConfirm(ctx, user, account.Username); err != nil { | 	if err := p.surface.emailPleaseConfirm(ctx, user, account.Username); err != nil { | ||||||
| 		return gtserror.Newf("error emailing %s: %w", account.Username, err) | 		log.Errorf(ctx, "error emailing confirm: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -218,7 +225,7 @@ func (p *clientAPI) CreateStatus(ctx context.Context, cMsg messages.FromClientAP | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.surface.timelineAndNotifyStatus(ctx, status); err != nil { | 	if err := p.surface.timelineAndNotifyStatus(ctx, status); err != nil { | ||||||
| 		return gtserror.Newf("error timelining status: %w", err) | 		log.Errorf(ctx, "error timelining and notifying status: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if status.InReplyToID != "" { | 	if status.InReplyToID != "" { | ||||||
|  | @ -228,7 +235,48 @@ func (p *clientAPI) CreateStatus(ctx context.Context, cMsg messages.FromClientAP | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.federate.CreateStatus(ctx, status); err != nil { | 	if err := p.federate.CreateStatus(ctx, status); err != nil { | ||||||
| 		return gtserror.Newf("error federating status: %w", err) | 		log.Errorf(ctx, "error federating status: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *clientAPI) CreatePollVote(ctx context.Context, cMsg messages.FromClientAPI) error { | ||||||
|  | 	// Cast the create poll vote attached to message. | ||||||
|  | 	vote, ok := cMsg.GTSModel.(*gtsmodel.PollVote) | ||||||
|  | 	if !ok { | ||||||
|  | 		return gtserror.Newf("cannot cast %T -> *gtsmodel.Pollvote", cMsg.GTSModel) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Ensure the vote is fully populated in order to get original poll. | ||||||
|  | 	if err := p.state.DB.PopulatePollVote(ctx, vote); err != nil { | ||||||
|  | 		return gtserror.Newf("error populating poll vote from db: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Ensure the poll on the vote is fully populated to get origin status. | ||||||
|  | 	if err := p.state.DB.PopulatePoll(ctx, vote.Poll); err != nil { | ||||||
|  | 		return gtserror.Newf("error populating poll from db: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get the origin status, | ||||||
|  | 	// (also set the poll on it). | ||||||
|  | 	status := vote.Poll.Status | ||||||
|  | 	status.Poll = vote.Poll | ||||||
|  | 
 | ||||||
|  | 	// Interaction counts changed on the source status, uncache from timelines. | ||||||
|  | 	p.surface.invalidateStatusFromTimelines(ctx, vote.Poll.StatusID) | ||||||
|  | 
 | ||||||
|  | 	if *status.Local { | ||||||
|  | 		// These are poll votes in a local status, we only need to | ||||||
|  | 		// federate the updated status model with latest vote counts. | ||||||
|  | 		if err := p.federate.UpdateStatus(ctx, status); err != nil { | ||||||
|  | 			log.Errorf(ctx, "error federating status update: %v", err) | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		// These are votes in a remote poll, federate to origin the new poll vote(s). | ||||||
|  | 		if err := p.federate.CreatePollVote(ctx, vote.Poll, vote); err != nil { | ||||||
|  | 			log.Errorf(ctx, "error federating poll vote: %v", err) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -241,14 +289,17 @@ func (p *clientAPI) CreateFollowReq(ctx context.Context, cMsg messages.FromClien | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.surface.notifyFollowRequest(ctx, followRequest); err != nil { | 	if err := p.surface.notifyFollowRequest(ctx, followRequest); err != nil { | ||||||
| 		return gtserror.Newf("error notifying follow request: %w", err) | 		log.Errorf(ctx, "error notifying follow request: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Convert the follow request to follow model (requests are sent as follows). | ||||||
|  | 	follow := p.converter.FollowRequestToFollow(ctx, followRequest) | ||||||
|  | 
 | ||||||
| 	if err := p.federate.Follow( | 	if err := p.federate.Follow( | ||||||
| 		ctx, | 		ctx, | ||||||
| 		p.converter.FollowRequestToFollow(ctx, followRequest), | 		follow, | ||||||
| 	); err != nil { | 	); err != nil { | ||||||
| 		return gtserror.Newf("error federating follow: %w", err) | 		log.Errorf(ctx, "error federating follow request: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -266,7 +317,7 @@ func (p *clientAPI) CreateLike(ctx context.Context, cMsg messages.FromClientAPI) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.surface.notifyFave(ctx, fave); err != nil { | 	if err := p.surface.notifyFave(ctx, fave); err != nil { | ||||||
| 		return gtserror.Newf("error notifying fave: %w", err) | 		log.Errorf(ctx, "error notifying fave: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Interaction counts changed on the faved status; | 	// Interaction counts changed on the faved status; | ||||||
|  | @ -274,7 +325,7 @@ func (p *clientAPI) CreateLike(ctx context.Context, cMsg messages.FromClientAPI) | ||||||
| 	p.surface.invalidateStatusFromTimelines(ctx, fave.StatusID) | 	p.surface.invalidateStatusFromTimelines(ctx, fave.StatusID) | ||||||
| 
 | 
 | ||||||
| 	if err := p.federate.Like(ctx, fave); err != nil { | 	if err := p.federate.Like(ctx, fave); err != nil { | ||||||
| 		return gtserror.Newf("error federating like: %w", err) | 		log.Errorf(ctx, "error federating like: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -288,12 +339,12 @@ func (p *clientAPI) CreateAnnounce(ctx context.Context, cMsg messages.FromClient | ||||||
| 
 | 
 | ||||||
| 	// Timeline and notify the boost wrapper status. | 	// Timeline and notify the boost wrapper status. | ||||||
| 	if err := p.surface.timelineAndNotifyStatus(ctx, boost); err != nil { | 	if err := p.surface.timelineAndNotifyStatus(ctx, boost); err != nil { | ||||||
| 		return gtserror.Newf("error timelining boost: %w", err) | 		log.Errorf(ctx, "error timelining and notifying status: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Notify the boost target account. | 	// Notify the boost target account. | ||||||
| 	if err := p.surface.notifyAnnounce(ctx, boost); err != nil { | 	if err := p.surface.notifyAnnounce(ctx, boost); err != nil { | ||||||
| 		return gtserror.Newf("error notifying boost: %w", err) | 		log.Errorf(ctx, "error notifying boost: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Interaction counts changed on the boosted status; | 	// Interaction counts changed on the boosted status; | ||||||
|  | @ -301,7 +352,7 @@ func (p *clientAPI) CreateAnnounce(ctx context.Context, cMsg messages.FromClient | ||||||
| 	p.surface.invalidateStatusFromTimelines(ctx, boost.BoostOfID) | 	p.surface.invalidateStatusFromTimelines(ctx, boost.BoostOfID) | ||||||
| 
 | 
 | ||||||
| 	if err := p.federate.Announce(ctx, boost); err != nil { | 	if err := p.federate.Announce(ctx, boost); err != nil { | ||||||
| 		return gtserror.Newf("error federating announce: %w", err) | 		log.Errorf(ctx, "error federating announce: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -335,7 +386,7 @@ func (p *clientAPI) CreateBlock(ctx context.Context, cMsg messages.FromClientAPI | ||||||
| 	// TODO: same with bookmarks? | 	// TODO: same with bookmarks? | ||||||
| 
 | 
 | ||||||
| 	if err := p.federate.Block(ctx, block); err != nil { | 	if err := p.federate.Block(ctx, block); err != nil { | ||||||
| 		return gtserror.Newf("error federating block: %w", err) | 		log.Errorf(ctx, "error federating block: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -350,7 +401,19 @@ func (p *clientAPI) UpdateStatus(ctx context.Context, cMsg messages.FromClientAP | ||||||
| 
 | 
 | ||||||
| 	// Federate the updated status changes out remotely. | 	// Federate the updated status changes out remotely. | ||||||
| 	if err := p.federate.UpdateStatus(ctx, status); err != nil { | 	if err := p.federate.UpdateStatus(ctx, status); err != nil { | ||||||
| 		return gtserror.Newf("error federating status update: %w", err) | 		log.Errorf(ctx, "error federating status update: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Status representation has changed, invalidate from timelines. | ||||||
|  | 	p.surface.invalidateStatusFromTimelines(ctx, status.ID) | ||||||
|  | 
 | ||||||
|  | 	if status.Poll != nil && status.Poll.Closing { | ||||||
|  | 
 | ||||||
|  | 		// If the latest status has a newly closed poll, at least compared | ||||||
|  | 		// to the existing version, then notify poll close to all voters. | ||||||
|  | 		if err := p.surface.notifyPollClose(ctx, status); err != nil { | ||||||
|  | 			log.Errorf(ctx, "error notifying poll close: %v", err) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -363,7 +426,7 @@ func (p *clientAPI) UpdateAccount(ctx context.Context, cMsg messages.FromClientA | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.federate.UpdateAccount(ctx, account); err != nil { | 	if err := p.federate.UpdateAccount(ctx, account); err != nil { | ||||||
| 		return gtserror.Newf("error federating account update: %w", err) | 		log.Errorf(ctx, "error federating account update: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -382,7 +445,7 @@ func (p *clientAPI) UpdateReport(ctx context.Context, cMsg messages.FromClientAP | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.surface.emailReportClosed(ctx, report); err != nil { | 	if err := p.surface.emailReportClosed(ctx, report); err != nil { | ||||||
| 		return gtserror.Newf("error sending report closed email: %w", err) | 		log.Errorf(ctx, "error emailing report closed: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -395,11 +458,11 @@ func (p *clientAPI) AcceptFollow(ctx context.Context, cMsg messages.FromClientAP | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.surface.notifyFollow(ctx, follow); err != nil { | 	if err := p.surface.notifyFollow(ctx, follow); err != nil { | ||||||
| 		return gtserror.Newf("error notifying follow: %w", err) | 		log.Errorf(ctx, "error notifying follow: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.federate.AcceptFollow(ctx, follow); err != nil { | 	if err := p.federate.AcceptFollow(ctx, follow); err != nil { | ||||||
| 		return gtserror.Newf("error federating follow request accept: %w", err) | 		log.Errorf(ctx, "error federating follow accept: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -415,7 +478,7 @@ func (p *clientAPI) RejectFollowRequest(ctx context.Context, cMsg messages.FromC | ||||||
| 		ctx, | 		ctx, | ||||||
| 		p.converter.FollowRequestToFollow(ctx, followReq), | 		p.converter.FollowRequestToFollow(ctx, followReq), | ||||||
| 	); err != nil { | 	); err != nil { | ||||||
| 		return gtserror.Newf("error federating reject follow: %w", err) | 		log.Errorf(ctx, "error federating follow reject: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -428,7 +491,7 @@ func (p *clientAPI) UndoFollow(ctx context.Context, cMsg messages.FromClientAPI) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.federate.UndoFollow(ctx, follow); err != nil { | 	if err := p.federate.UndoFollow(ctx, follow); err != nil { | ||||||
| 		return gtserror.Newf("error federating undo follow: %w", err) | 		log.Errorf(ctx, "error federating follow undo: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -441,7 +504,7 @@ func (p *clientAPI) UndoBlock(ctx context.Context, cMsg messages.FromClientAPI) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.federate.UndoBlock(ctx, block); err != nil { | 	if err := p.federate.UndoBlock(ctx, block); err != nil { | ||||||
| 		return gtserror.Newf("error federating undo block: %w", err) | 		log.Errorf(ctx, "error federating block undo: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -458,7 +521,7 @@ func (p *clientAPI) UndoFave(ctx context.Context, cMsg messages.FromClientAPI) e | ||||||
| 	p.surface.invalidateStatusFromTimelines(ctx, statusFave.StatusID) | 	p.surface.invalidateStatusFromTimelines(ctx, statusFave.StatusID) | ||||||
| 
 | 
 | ||||||
| 	if err := p.federate.UndoLike(ctx, statusFave); err != nil { | 	if err := p.federate.UndoLike(ctx, statusFave); err != nil { | ||||||
| 		return gtserror.Newf("error federating undo like: %w", err) | 		log.Errorf(ctx, "error federating like undo: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -475,7 +538,7 @@ func (p *clientAPI) UndoAnnounce(ctx context.Context, cMsg messages.FromClientAP | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.surface.deleteStatusFromTimelines(ctx, status.ID); err != nil { | 	if err := p.surface.deleteStatusFromTimelines(ctx, status.ID); err != nil { | ||||||
| 		return gtserror.Newf("error removing status from timelines: %w", err) | 		log.Errorf(ctx, "error removing timelined status: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Interaction counts changed on the boosted status; | 	// Interaction counts changed on the boosted status; | ||||||
|  | @ -483,7 +546,7 @@ func (p *clientAPI) UndoAnnounce(ctx context.Context, cMsg messages.FromClientAP | ||||||
| 	p.surface.invalidateStatusFromTimelines(ctx, status.BoostOfID) | 	p.surface.invalidateStatusFromTimelines(ctx, status.BoostOfID) | ||||||
| 
 | 
 | ||||||
| 	if err := p.federate.UndoAnnounce(ctx, status); err != nil { | 	if err := p.federate.UndoAnnounce(ctx, status); err != nil { | ||||||
| 		return gtserror.Newf("error federating undo announce: %w", err) | 		log.Errorf(ctx, "error federating announce undo: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -509,7 +572,7 @@ func (p *clientAPI) DeleteStatus(ctx context.Context, cMsg messages.FromClientAP | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.wipeStatus(ctx, status, deleteAttachments); err != nil { | 	if err := p.wipeStatus(ctx, status, deleteAttachments); err != nil { | ||||||
| 		return gtserror.Newf("error wiping status: %w", err) | 		log.Errorf(ctx, "error wiping status: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if status.InReplyToID != "" { | 	if status.InReplyToID != "" { | ||||||
|  | @ -519,7 +582,7 @@ func (p *clientAPI) DeleteStatus(ctx context.Context, cMsg messages.FromClientAP | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.federate.DeleteStatus(ctx, status); err != nil { | 	if err := p.federate.DeleteStatus(ctx, status); err != nil { | ||||||
| 		return gtserror.Newf("error federating status delete: %w", err) | 		log.Errorf(ctx, "error federating status delete: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -543,11 +606,11 @@ func (p *clientAPI) DeleteAccount(ctx context.Context, cMsg messages.FromClientA | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.federate.DeleteAccount(ctx, cMsg.TargetAccount); err != nil { | 	if err := p.federate.DeleteAccount(ctx, cMsg.TargetAccount); err != nil { | ||||||
| 		return gtserror.Newf("error federating account delete: %w", err) | 		log.Errorf(ctx, "error federating account delete: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.account.Delete(ctx, cMsg.TargetAccount, originID); err != nil { | 	if err := p.account.Delete(ctx, cMsg.TargetAccount, originID); err != nil { | ||||||
| 		return gtserror.Newf("error deleting account: %w", err) | 		log.Errorf(ctx, "error deleting account: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -563,12 +626,12 @@ func (p *clientAPI) ReportAccount(ctx context.Context, cMsg messages.FromClientA | ||||||
| 	// remote instance if desired. | 	// remote instance if desired. | ||||||
| 	if *report.Forwarded { | 	if *report.Forwarded { | ||||||
| 		if err := p.federate.Flag(ctx, report); err != nil { | 		if err := p.federate.Flag(ctx, report); err != nil { | ||||||
| 			return gtserror.Newf("error federating report: %w", err) | 			log.Errorf(ctx, "error federating flag: %v", err) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.surface.emailReportOpened(ctx, report); err != nil { | 	if err := p.surface.emailReportOpened(ctx, report); err != nil { | ||||||
| 		return gtserror.Newf("error sending report opened email: %w", err) | 		log.Errorf(ctx, "error emailing report opened: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  |  | ||||||
|  | @ -114,6 +114,10 @@ func (p *Processor) ProcessFromFediAPI(ctx context.Context, fMsg messages.FromFe | ||||||
| 		// CREATE FLAG/REPORT | 		// CREATE FLAG/REPORT | ||||||
| 		case ap.ActivityFlag: | 		case ap.ActivityFlag: | ||||||
| 			return p.fediAPI.CreateFlag(ctx, fMsg) | 			return p.fediAPI.CreateFlag(ctx, fMsg) | ||||||
|  | 
 | ||||||
|  | 		// CREATE QUESTION | ||||||
|  | 		case ap.ActivityQuestion: | ||||||
|  | 			return p.fediAPI.CreatePollVote(ctx, fMsg) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 	// UPDATE SOMETHING | 	// UPDATE SOMETHING | ||||||
|  | @ -170,7 +174,7 @@ func (p *fediAPI) CreateStatus(ctx context.Context, fMsg messages.FromFediAPI) e | ||||||
| 		// Both situations we need to parse account URI to fetch it. | 		// Both situations we need to parse account URI to fetch it. | ||||||
| 		accountURI, err := url.Parse(status.AccountURI) | 		accountURI, err := url.Parse(status.AccountURI) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return gtserror.Newf("error parsing account uri: %w", err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// Ensure that account for this status has been deref'd. | 		// Ensure that account for this status has been deref'd. | ||||||
|  | @ -180,7 +184,7 @@ func (p *fediAPI) CreateStatus(ctx context.Context, fMsg messages.FromFediAPI) e | ||||||
| 			accountURI, | 			accountURI, | ||||||
| 		) | 		) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return gtserror.Newf("error getting account by uri: %w", err) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -192,7 +196,48 @@ func (p *fediAPI) CreateStatus(ctx context.Context, fMsg messages.FromFediAPI) e | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.surface.timelineAndNotifyStatus(ctx, status); err != nil { | 	if err := p.surface.timelineAndNotifyStatus(ctx, status); err != nil { | ||||||
| 		return gtserror.Newf("error timelining status: %w", err) | 		log.Errorf(ctx, "error timelining and notifying status: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *fediAPI) CreatePollVote(ctx context.Context, fMsg messages.FromFediAPI) error { | ||||||
|  | 	// Cast poll vote type from the worker message. | ||||||
|  | 	vote, ok := fMsg.GTSModel.(*gtsmodel.PollVote) | ||||||
|  | 	if !ok { | ||||||
|  | 		return gtserror.Newf("cannot cast %T -> *gtsmodel.PollVote", fMsg.GTSModel) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Insert the new poll vote in the database. | ||||||
|  | 	if err := p.state.DB.PutPollVote(ctx, vote); err != nil { | ||||||
|  | 		return gtserror.Newf("error inserting poll vote in db: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Ensure the poll vote is fully populated at this point. | ||||||
|  | 	if err := p.state.DB.PopulatePollVote(ctx, vote); err != nil { | ||||||
|  | 		return gtserror.Newf("error populating poll vote from db: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Ensure the poll on the vote is fully populated to get origin status. | ||||||
|  | 	if err := p.state.DB.PopulatePoll(ctx, vote.Poll); err != nil { | ||||||
|  | 		return gtserror.Newf("error populating poll from db: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get the origin status, | ||||||
|  | 	// (also set the poll on it). | ||||||
|  | 	status := vote.Poll.Status | ||||||
|  | 	status.Poll = vote.Poll | ||||||
|  | 
 | ||||||
|  | 	// Interaction counts changed on the source status, uncache from timelines. | ||||||
|  | 	p.surface.invalidateStatusFromTimelines(ctx, vote.Poll.StatusID) | ||||||
|  | 
 | ||||||
|  | 	if *status.Local { | ||||||
|  | 		// These were poll votes in a local status, we need to | ||||||
|  | 		// federate the updated status model with latest vote counts. | ||||||
|  | 		if err := p.federate.UpdateStatus(ctx, status); err != nil { | ||||||
|  | 			log.Errorf(ctx, "error federating status update: %v", err) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -269,12 +314,10 @@ func (p *fediAPI) CreateFollowReq(ctx context.Context, fMsg messages.FromFediAPI | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if *followRequest.TargetAccount.Locked { | 	if *followRequest.TargetAccount.Locked { | ||||||
| 		// Account on our instance is locked: | 		// Account on our instance is locked: just notify the follow request. | ||||||
| 		// just notify the follow request. |  | ||||||
| 		if err := p.surface.notifyFollowRequest(ctx, followRequest); err != nil { | 		if err := p.surface.notifyFollowRequest(ctx, followRequest); err != nil { | ||||||
| 			return gtserror.Newf("error notifying follow request: %w", err) | 			log.Errorf(ctx, "error notifying follow request: %v", err) | ||||||
| 		} | 		} | ||||||
| 
 |  | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -291,11 +334,11 @@ func (p *fediAPI) CreateFollowReq(ctx context.Context, fMsg messages.FromFediAPI | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.federate.AcceptFollow(ctx, follow); err != nil { | 	if err := p.federate.AcceptFollow(ctx, follow); err != nil { | ||||||
| 		return gtserror.Newf("error federating accept follow request: %w", err) | 		log.Errorf(ctx, "error federating follow request accept: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.surface.notifyFollow(ctx, follow); err != nil { | 	if err := p.surface.notifyFollow(ctx, follow); err != nil { | ||||||
| 		return gtserror.Newf("error notifying follow: %w", err) | 		log.Errorf(ctx, "error notifying follow: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -313,7 +356,7 @@ func (p *fediAPI) CreateLike(ctx context.Context, fMsg messages.FromFediAPI) err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.surface.notifyFave(ctx, fave); err != nil { | 	if err := p.surface.notifyFave(ctx, fave); err != nil { | ||||||
| 		return gtserror.Newf("error notifying fave: %w", err) | 		log.Errorf(ctx, "error notifying fave: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Interaction counts changed on the faved status; | 	// Interaction counts changed on the faved status; | ||||||
|  | @ -354,11 +397,11 @@ func (p *fediAPI) CreateAnnounce(ctx context.Context, fMsg messages.FromFediAPI) | ||||||
| 
 | 
 | ||||||
| 	// Timeline and notify the announce. | 	// Timeline and notify the announce. | ||||||
| 	if err := p.surface.timelineAndNotifyStatus(ctx, status); err != nil { | 	if err := p.surface.timelineAndNotifyStatus(ctx, status); err != nil { | ||||||
| 		return gtserror.Newf("error timelining status: %w", err) | 		log.Errorf(ctx, "error timelining and notifying status: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.surface.notifyAnnounce(ctx, status); err != nil { | 	if err := p.surface.notifyAnnounce(ctx, status); err != nil { | ||||||
| 		return gtserror.Newf("error notifying status: %w", err) | 		log.Errorf(ctx, "error notifying announce: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Interaction counts changed on the boosted status; | 	// Interaction counts changed on the boosted status; | ||||||
|  | @ -382,7 +425,7 @@ func (p *fediAPI) CreateBlock(ctx context.Context, fMsg messages.FromFediAPI) er | ||||||
| 		block.AccountID, | 		block.AccountID, | ||||||
| 		block.TargetAccountID, | 		block.TargetAccountID, | ||||||
| 	); err != nil { | 	); err != nil { | ||||||
| 		return gtserror.Newf("%w", err) | 		log.Errorf(ctx, "error wiping items from block -> target's home timeline: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.state.Timelines.Home.WipeItemsFromAccountID( | 	if err := p.state.Timelines.Home.WipeItemsFromAccountID( | ||||||
|  | @ -390,7 +433,7 @@ func (p *fediAPI) CreateBlock(ctx context.Context, fMsg messages.FromFediAPI) er | ||||||
| 		block.TargetAccountID, | 		block.TargetAccountID, | ||||||
| 		block.AccountID, | 		block.AccountID, | ||||||
| 	); err != nil { | 	); err != nil { | ||||||
| 		return gtserror.Newf("%w", err) | 		log.Errorf(ctx, "error wiping items from target -> block's home timeline: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Now list timelines. | 	// Now list timelines. | ||||||
|  | @ -399,7 +442,7 @@ func (p *fediAPI) CreateBlock(ctx context.Context, fMsg messages.FromFediAPI) er | ||||||
| 		block.AccountID, | 		block.AccountID, | ||||||
| 		block.TargetAccountID, | 		block.TargetAccountID, | ||||||
| 	); err != nil { | 	); err != nil { | ||||||
| 		return gtserror.Newf("%w", err) | 		log.Errorf(ctx, "error wiping items from block -> target's list timeline(s): %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.state.Timelines.List.WipeItemsFromAccountID( | 	if err := p.state.Timelines.List.WipeItemsFromAccountID( | ||||||
|  | @ -407,7 +450,7 @@ func (p *fediAPI) CreateBlock(ctx context.Context, fMsg messages.FromFediAPI) er | ||||||
| 		block.TargetAccountID, | 		block.TargetAccountID, | ||||||
| 		block.AccountID, | 		block.AccountID, | ||||||
| 	); err != nil { | 	); err != nil { | ||||||
| 		return gtserror.Newf("%w", err) | 		log.Errorf(ctx, "error wiping items from target -> block's list timeline(s): %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Remove any follows that existed between blocker + blockee. | 	// Remove any follows that existed between blocker + blockee. | ||||||
|  | @ -416,10 +459,7 @@ func (p *fediAPI) CreateBlock(ctx context.Context, fMsg messages.FromFediAPI) er | ||||||
| 		block.AccountID, | 		block.AccountID, | ||||||
| 		block.TargetAccountID, | 		block.TargetAccountID, | ||||||
| 	); err != nil { | 	); err != nil { | ||||||
| 		return gtserror.Newf( | 		log.Errorf(ctx, "error deleting follow from block -> target: %v", err) | ||||||
| 			"db error deleting follow from %s targeting %s: %w", |  | ||||||
| 			block.AccountID, block.TargetAccountID, err, |  | ||||||
| 		) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.state.DB.DeleteFollow( | 	if err := p.state.DB.DeleteFollow( | ||||||
|  | @ -427,10 +467,7 @@ func (p *fediAPI) CreateBlock(ctx context.Context, fMsg messages.FromFediAPI) er | ||||||
| 		block.TargetAccountID, | 		block.TargetAccountID, | ||||||
| 		block.AccountID, | 		block.AccountID, | ||||||
| 	); err != nil { | 	); err != nil { | ||||||
| 		return gtserror.Newf( | 		log.Errorf(ctx, "error deleting follow from target -> block: %v", err) | ||||||
| 			"db error deleting follow from %s targeting %s: %w", |  | ||||||
| 			block.TargetAccountID, block.AccountID, err, |  | ||||||
| 		) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Remove any follow requests that existed between blocker + blockee. | 	// Remove any follow requests that existed between blocker + blockee. | ||||||
|  | @ -439,10 +476,7 @@ func (p *fediAPI) CreateBlock(ctx context.Context, fMsg messages.FromFediAPI) er | ||||||
| 		block.AccountID, | 		block.AccountID, | ||||||
| 		block.TargetAccountID, | 		block.TargetAccountID, | ||||||
| 	); err != nil { | 	); err != nil { | ||||||
| 		return gtserror.Newf( | 		log.Errorf(ctx, "error deleting follow request from block -> target: %v", err) | ||||||
| 			"db error deleting follow request from %s targeting %s: %w", |  | ||||||
| 			block.AccountID, block.TargetAccountID, err, |  | ||||||
| 		) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.state.DB.DeleteFollowRequest( | 	if err := p.state.DB.DeleteFollowRequest( | ||||||
|  | @ -450,10 +484,7 @@ func (p *fediAPI) CreateBlock(ctx context.Context, fMsg messages.FromFediAPI) er | ||||||
| 		block.TargetAccountID, | 		block.TargetAccountID, | ||||||
| 		block.AccountID, | 		block.AccountID, | ||||||
| 	); err != nil { | 	); err != nil { | ||||||
| 		return gtserror.Newf( | 		log.Errorf(ctx, "error deleting follow request from target -> block: %v", err) | ||||||
| 			"db error deleting follow request from %s targeting %s: %w", |  | ||||||
| 			block.TargetAccountID, block.AccountID, err, |  | ||||||
| 		) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -469,7 +500,7 @@ func (p *fediAPI) CreateFlag(ctx context.Context, fMsg messages.FromFediAPI) err | ||||||
| 	// - notify admins by dm / notification | 	// - notify admins by dm / notification | ||||||
| 
 | 
 | ||||||
| 	if err := p.surface.emailReportOpened(ctx, incomingReport); err != nil { | 	if err := p.surface.emailReportOpened(ctx, incomingReport); err != nil { | ||||||
| 		return gtserror.Newf("error sending report opened email: %w", err) | 		log.Errorf(ctx, "error emailing report opened: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -497,7 +528,7 @@ func (p *fediAPI) UpdateAccount(ctx context.Context, fMsg messages.FromFediAPI) | ||||||
| 		true, // Force refresh. | 		true, // Force refresh. | ||||||
| 	) | 	) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return gtserror.Newf("error refreshing updated account: %w", err) | 		log.Errorf(ctx, "error refreshing account: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -514,7 +545,7 @@ func (p *fediAPI) UpdateStatus(ctx context.Context, fMsg messages.FromFediAPI) e | ||||||
| 	apStatus, _ := fMsg.APObjectModel.(ap.Statusable) | 	apStatus, _ := fMsg.APObjectModel.(ap.Statusable) | ||||||
| 
 | 
 | ||||||
| 	// Fetch up-to-date attach status attachments, etc. | 	// Fetch up-to-date attach status attachments, etc. | ||||||
| 	_, statusable, err := p.federate.RefreshStatus( | 	status, _, err := p.federate.RefreshStatus( | ||||||
| 		ctx, | 		ctx, | ||||||
| 		fMsg.ReceivingAccount.Username, | 		fMsg.ReceivingAccount.Username, | ||||||
| 		existing, | 		existing, | ||||||
|  | @ -522,12 +553,19 @@ func (p *fediAPI) UpdateStatus(ctx context.Context, fMsg messages.FromFediAPI) e | ||||||
| 		true, | 		true, | ||||||
| 	) | 	) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return gtserror.Newf("error refreshing updated status: %w", err) | 		log.Errorf(ctx, "error refreshing status: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if statusable != nil { |  | ||||||
| 	// Status representation was refetched, uncache from timelines. | 	// Status representation was refetched, uncache from timelines. | ||||||
| 		p.surface.invalidateStatusFromTimelines(ctx, existing.ID) | 	p.surface.invalidateStatusFromTimelines(ctx, status.ID) | ||||||
|  | 
 | ||||||
|  | 	if status.Poll != nil && status.Poll.Closing { | ||||||
|  | 
 | ||||||
|  | 		// If the latest status has a newly closed poll, at least compared | ||||||
|  | 		// to the existing version, then notify poll close to all voters. | ||||||
|  | 		if err := p.surface.notifyPollClose(ctx, status); err != nil { | ||||||
|  | 			log.Errorf(ctx, "error sending poll notification: %v", err) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -545,7 +583,7 @@ func (p *fediAPI) DeleteStatus(ctx context.Context, fMsg messages.FromFediAPI) e | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.wipeStatus(ctx, status, deleteAttachments); err != nil { | 	if err := p.wipeStatus(ctx, status, deleteAttachments); err != nil { | ||||||
| 		return gtserror.Newf("error wiping status: %w", err) | 		log.Errorf(ctx, "error wiping status: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if status.InReplyToID != "" { | 	if status.InReplyToID != "" { | ||||||
|  | @ -564,7 +602,7 @@ func (p *fediAPI) DeleteAccount(ctx context.Context, fMsg messages.FromFediAPI) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.account.Delete(ctx, account, account.ID); err != nil { | 	if err := p.account.Delete(ctx, account, account.ID); err != nil { | ||||||
| 		return gtserror.Newf("error deleting account: %w", err) | 		log.Errorf(ctx, "error deleting account: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  |  | ||||||
|  | @ -347,8 +347,15 @@ func (suite *FromFediAPITestSuite) TestProcessAccountDelete() { | ||||||
| 		suite.FailNow("timeout waiting for statuses to be deleted") | 		suite.FailNow("timeout waiting for statuses to be deleted") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	dbAccount, err := suite.db.GetAccountByID(ctx, deletedAccount.ID) | 	var dbAccount *gtsmodel.Account | ||||||
| 	suite.NoError(err) | 
 | ||||||
|  | 	// account data should be zeroed. | ||||||
|  | 	if !testrig.WaitFor(func() bool { | ||||||
|  | 		dbAccount, err = suite.db.GetAccountByID(ctx, deletedAccount.ID) | ||||||
|  | 		return err == nil && dbAccount.DisplayName == "" | ||||||
|  | 	}) { | ||||||
|  | 		suite.FailNow("timeout waiting for statuses to be deleted") | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	suite.Empty(dbAccount.Note) | 	suite.Empty(dbAccount.Note) | ||||||
| 	suite.Empty(dbAccount.DisplayName) | 	suite.Empty(dbAccount.DisplayName) | ||||||
|  |  | ||||||
|  | @ -35,12 +35,25 @@ func (s *surface) notifyMentions( | ||||||
| 	ctx context.Context, | 	ctx context.Context, | ||||||
| 	status *gtsmodel.Status, | 	status *gtsmodel.Status, | ||||||
| ) error { | ) error { | ||||||
| 	var ( | 	var errs gtserror.MultiError | ||||||
| 		mentions = status.Mentions | 
 | ||||||
| 		errs     = gtserror.NewMultiError(len(mentions)) | 	for _, mention := range status.Mentions { | ||||||
| 	) | 		// Set status on the mention (stops | ||||||
|  | 		// the below function populating it). | ||||||
|  | 		mention.Status = status | ||||||
|  | 
 | ||||||
|  | 		// Beforehand, ensure the passed mention is fully populated. | ||||||
|  | 		if err := s.state.DB.PopulateMention(ctx, mention); err != nil { | ||||||
|  | 			errs.Appendf("error populating mention %s: %w", mention.ID, err) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if mention.TargetAccount.IsRemote() { | ||||||
|  | 			// no need to notify | ||||||
|  | 			// remote accounts. | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 	for _, mention := range mentions { |  | ||||||
| 		// Ensure thread not muted | 		// Ensure thread not muted | ||||||
| 		// by mentioned account. | 		// by mentioned account. | ||||||
| 		muted, err := s.state.DB.IsThreadMutedByAccount( | 		muted, err := s.state.DB.IsThreadMutedByAccount( | ||||||
|  | @ -48,9 +61,8 @@ func (s *surface) notifyMentions( | ||||||
| 			status.ThreadID, | 			status.ThreadID, | ||||||
| 			mention.TargetAccountID, | 			mention.TargetAccountID, | ||||||
| 		) | 		) | ||||||
| 
 |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			errs.Append(err) | 			errs.Appendf("error checking status thread mute %s: %w", status.ThreadID, err) | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -61,14 +73,16 @@ func (s *surface) notifyMentions( | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if err := s.notify( | 		// notify mentioned | ||||||
| 			ctx, | 		// by status author. | ||||||
|  | 		if err := s.notify(ctx, | ||||||
| 			gtsmodel.NotificationMention, | 			gtsmodel.NotificationMention, | ||||||
| 			mention.TargetAccountID, | 			mention.TargetAccount, | ||||||
| 			mention.OriginAccountID, | 			mention.OriginAccount, | ||||||
| 			mention.StatusID, | 			mention.StatusID, | ||||||
| 		); err != nil { | 		); err != nil { | ||||||
| 			errs.Append(err) | 			errs.Appendf("error notifying mention target %s: %w", mention.TargetAccountID, err) | ||||||
|  | 			continue | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -79,15 +93,30 @@ func (s *surface) notifyMentions( | ||||||
| // follow request that they have a new follow request. | // follow request that they have a new follow request. | ||||||
| func (s *surface) notifyFollowRequest( | func (s *surface) notifyFollowRequest( | ||||||
| 	ctx context.Context, | 	ctx context.Context, | ||||||
| 	followRequest *gtsmodel.FollowRequest, | 	followReq *gtsmodel.FollowRequest, | ||||||
| ) error { | ) error { | ||||||
| 	return s.notify( | 	// Beforehand, ensure the passed follow request is fully populated. | ||||||
| 		ctx, | 	if err := s.state.DB.PopulateFollowRequest(ctx, followReq); err != nil { | ||||||
|  | 		return gtserror.Newf("error populating follow request %s: %w", followReq.ID, err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if followReq.TargetAccount.IsRemote() { | ||||||
|  | 		// no need to notify | ||||||
|  | 		// remote accounts. | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Now notify the follow request itself. | ||||||
|  | 	if err := s.notify(ctx, | ||||||
| 		gtsmodel.NotificationFollowRequest, | 		gtsmodel.NotificationFollowRequest, | ||||||
| 		followRequest.TargetAccountID, | 		followReq.TargetAccount, | ||||||
| 		followRequest.AccountID, | 		followReq.Account, | ||||||
| 		"", | 		"", | ||||||
| 	) | 	); err != nil { | ||||||
|  | 		return gtserror.Newf("error notifying follow target %s: %w", followReq.TargetAccountID, err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // notifyFollow notifies the target of the given follow that | // notifyFollow notifies the target of the given follow that | ||||||
|  | @ -98,6 +127,17 @@ func (s *surface) notifyFollow( | ||||||
| 	ctx context.Context, | 	ctx context.Context, | ||||||
| 	follow *gtsmodel.Follow, | 	follow *gtsmodel.Follow, | ||||||
| ) error { | ) error { | ||||||
|  | 	// Beforehand, ensure the passed follow is fully populated. | ||||||
|  | 	if err := s.state.DB.PopulateFollow(ctx, follow); err != nil { | ||||||
|  | 		return gtserror.Newf("error populating follow %s: %w", follow.ID, err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if follow.TargetAccount.IsRemote() { | ||||||
|  | 		// no need to notify | ||||||
|  | 		// remote accounts. | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Check if previous follow req notif exists. | 	// Check if previous follow req notif exists. | ||||||
| 	prevNotif, err := s.state.DB.GetNotification( | 	prevNotif, err := s.state.DB.GetNotification( | ||||||
| 		gtscontext.SetBarebones(ctx), | 		gtscontext.SetBarebones(ctx), | ||||||
|  | @ -107,24 +147,28 @@ func (s *surface) notifyFollow( | ||||||
| 		"", | 		"", | ||||||
| 	) | 	) | ||||||
| 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		return gtserror.Newf("db error checking for previous follow request notification: %w", err) | 		return gtserror.Newf("error getting notification: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if prevNotif != nil { | 	if prevNotif != nil { | ||||||
| 		// Previous notif existed, delete it. | 		// Previous follow request notif existed, delete it before creating new. | ||||||
| 		if err := s.state.DB.DeleteNotificationByID(ctx, prevNotif.ID); err != nil { | 		if err := s.state.DB.DeleteNotificationByID(ctx, prevNotif.ID); // nocollapse | ||||||
| 			return gtserror.Newf("db error removing previous follow request notification %s: %w", prevNotif.ID, err) | 		err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
|  | 			return gtserror.Newf("error deleting notification %s: %w", prevNotif.ID, err) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Now notify the follow itself. | 	// Now notify the follow itself. | ||||||
| 	return s.notify( | 	if err := s.notify(ctx, | ||||||
| 		ctx, |  | ||||||
| 		gtsmodel.NotificationFollow, | 		gtsmodel.NotificationFollow, | ||||||
| 		follow.TargetAccountID, | 		follow.TargetAccount, | ||||||
| 		follow.AccountID, | 		follow.Account, | ||||||
| 		"", | 		"", | ||||||
| 	) | 	); err != nil { | ||||||
|  | 		return gtserror.Newf("error notifying follow target %s: %w", follow.TargetAccountID, err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // notifyFave notifies the target of the given | // notifyFave notifies the target of the given | ||||||
|  | @ -138,6 +182,17 @@ func (s *surface) notifyFave( | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Beforehand, ensure the passed status fave is fully populated. | ||||||
|  | 	if err := s.state.DB.PopulateStatusFave(ctx, fave); err != nil { | ||||||
|  | 		return gtserror.Newf("error populating fave %s: %w", fave.ID, err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if fave.TargetAccount.IsRemote() { | ||||||
|  | 		// no need to notify | ||||||
|  | 		// remote accounts. | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Ensure favee hasn't | 	// Ensure favee hasn't | ||||||
| 	// muted the thread. | 	// muted the thread. | ||||||
| 	muted, err := s.state.DB.IsThreadMutedByAccount( | 	muted, err := s.state.DB.IsThreadMutedByAccount( | ||||||
|  | @ -145,24 +200,28 @@ func (s *surface) notifyFave( | ||||||
| 		fave.Status.ThreadID, | 		fave.Status.ThreadID, | ||||||
| 		fave.TargetAccountID, | 		fave.TargetAccountID, | ||||||
| 	) | 	) | ||||||
| 
 |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return gtserror.Newf("error checking status thread mute %s: %w", fave.StatusID, err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if muted { | 	if muted { | ||||||
| 		// Boostee doesn't want | 		// Favee doesn't want | ||||||
| 		// notifs for this thread. | 		// notifs for this thread. | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return s.notify( | 	// notify status author | ||||||
| 		ctx, | 	// of fave by account. | ||||||
|  | 	if err := s.notify(ctx, | ||||||
| 		gtsmodel.NotificationFave, | 		gtsmodel.NotificationFave, | ||||||
| 		fave.TargetAccountID, | 		fave.TargetAccount, | ||||||
| 		fave.AccountID, | 		fave.Account, | ||||||
| 		fave.StatusID, | 		fave.StatusID, | ||||||
| 	) | 	); err != nil { | ||||||
|  | 		return gtserror.Newf("error notifying status author %s: %w", fave.TargetAccountID, err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // notifyAnnounce notifies the status boost target | // notifyAnnounce notifies the status boost target | ||||||
|  | @ -176,14 +235,19 @@ func (s *surface) notifyAnnounce( | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if status.BoostOf == nil { | 	if status.BoostOfAccountID == status.AccountID { | ||||||
| 		// No boosted status | 		// Self-boost, nothing to do. | ||||||
| 		// set, nothing to do. |  | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if status.BoostOfAccountID == status.AccountID { | 	// Beforehand, ensure the passed status is fully populated. | ||||||
| 		// Self-boost, nothing to do. | 	if err := s.state.DB.PopulateStatus(ctx, status); err != nil { | ||||||
|  | 		return gtserror.Newf("error populating status %s: %w", status.ID, err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if status.BoostOfAccount.IsRemote() { | ||||||
|  | 		// no need to notify | ||||||
|  | 		// remote accounts. | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -196,7 +260,7 @@ func (s *surface) notifyAnnounce( | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return gtserror.Newf("error checking status thread mute %s: %w", status.BoostOfID, err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if muted { | 	if muted { | ||||||
|  | @ -205,13 +269,68 @@ func (s *surface) notifyAnnounce( | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return s.notify( | 	// notify status author | ||||||
| 		ctx, | 	// of boost by account. | ||||||
|  | 	if err := s.notify(ctx, | ||||||
| 		gtsmodel.NotificationReblog, | 		gtsmodel.NotificationReblog, | ||||||
| 		status.BoostOfAccountID, | 		status.BoostOfAccount, | ||||||
| 		status.AccountID, | 		status.Account, | ||||||
| 		status.ID, | 		status.ID, | ||||||
| 	) | 	); err != nil { | ||||||
|  | 		return gtserror.Newf("error notifying status author %s: %w", status.BoostOfAccountID, err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *surface) notifyPollClose(ctx context.Context, status *gtsmodel.Status) error { | ||||||
|  | 	// Beforehand, ensure the passed status is fully populated. | ||||||
|  | 	if err := s.state.DB.PopulateStatus(ctx, status); err != nil { | ||||||
|  | 		return gtserror.Newf("error populating status %s: %w", status.ID, err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Fetch all votes in the attached status poll. | ||||||
|  | 	votes, err := s.state.DB.GetPollVotes(ctx, status.PollID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return gtserror.Newf("error getting poll %s votes: %w", status.PollID, err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var errs gtserror.MultiError | ||||||
|  | 
 | ||||||
|  | 	if status.Account.IsLocal() { | ||||||
|  | 		// Send a notification to the status | ||||||
|  | 		// author that their poll has closed! | ||||||
|  | 		if err := s.notify(ctx, | ||||||
|  | 			gtsmodel.NotificationPoll, | ||||||
|  | 			status.Account, | ||||||
|  | 			status.Account, | ||||||
|  | 			status.ID, | ||||||
|  | 		); err != nil { | ||||||
|  | 			errs.Appendf("error notifying poll author: %w", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, vote := range votes { | ||||||
|  | 		if vote.Account.IsRemote() { | ||||||
|  | 			// no need to notify | ||||||
|  | 			// remote accounts. | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// notify voter that | ||||||
|  | 		// poll has been closed. | ||||||
|  | 		if err := s.notify(ctx, | ||||||
|  | 			gtsmodel.NotificationMention, | ||||||
|  | 			vote.Account, | ||||||
|  | 			status.Account, | ||||||
|  | 			status.ID, | ||||||
|  | 		); err != nil { | ||||||
|  | 			errs.Appendf("error notifying poll voter %s: %w", vote.AccountID, err) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return errs.Combine() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // notify creates, inserts, and streams a new | // notify creates, inserts, and streams a new | ||||||
|  | @ -228,17 +347,12 @@ func (s *surface) notifyAnnounce( | ||||||
| func (s *surface) notify( | func (s *surface) notify( | ||||||
| 	ctx context.Context, | 	ctx context.Context, | ||||||
| 	notificationType gtsmodel.NotificationType, | 	notificationType gtsmodel.NotificationType, | ||||||
| 	targetAccountID string, | 	targetAccount *gtsmodel.Account, | ||||||
| 	originAccountID string, | 	originAccount *gtsmodel.Account, | ||||||
| 	statusID string, | 	statusID string, | ||||||
| ) error { | ) error { | ||||||
| 	targetAccount, err := s.state.DB.GetAccountByID(ctx, targetAccountID) | 	if targetAccount.IsRemote() { | ||||||
| 	if err != nil { | 		// nothing to do. | ||||||
| 		return gtserror.Newf("error getting target account %s: %w", targetAccountID, err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if !targetAccount.IsLocal() { |  | ||||||
| 		// Nothing to do. |  | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -247,8 +361,8 @@ func (s *surface) notify( | ||||||
| 	if _, err := s.state.DB.GetNotification( | 	if _, err := s.state.DB.GetNotification( | ||||||
| 		gtscontext.SetBarebones(ctx), | 		gtscontext.SetBarebones(ctx), | ||||||
| 		notificationType, | 		notificationType, | ||||||
| 		targetAccountID, | 		targetAccount.ID, | ||||||
| 		originAccountID, | 		originAccount.ID, | ||||||
| 		statusID, | 		statusID, | ||||||
| 	); err == nil { | 	); err == nil { | ||||||
| 		// Notification exists; | 		// Notification exists; | ||||||
|  | @ -264,8 +378,10 @@ func (s *surface) notify( | ||||||
| 	notif := >smodel.Notification{ | 	notif := >smodel.Notification{ | ||||||
| 		ID:               id.NewULID(), | 		ID:               id.NewULID(), | ||||||
| 		NotificationType: notificationType, | 		NotificationType: notificationType, | ||||||
| 		TargetAccountID:  targetAccountID, | 		TargetAccountID:  targetAccount.ID, | ||||||
| 		OriginAccountID:  originAccountID, | 		TargetAccount:    targetAccount, | ||||||
|  | 		OriginAccountID:  originAccount.ID, | ||||||
|  | 		OriginAccount:    originAccount, | ||||||
| 		StatusID:         statusID, | 		StatusID:         statusID, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -85,7 +85,7 @@ func (s *surface) timelineAndNotifyStatusForFollowers( | ||||||
| 	follows []*gtsmodel.Follow, | 	follows []*gtsmodel.Follow, | ||||||
| ) error { | ) error { | ||||||
| 	var ( | 	var ( | ||||||
| 		errs  = new(gtserror.MultiError) | 		errs  gtserror.MultiError | ||||||
| 		boost = status.BoostOfID != "" | 		boost = status.BoostOfID != "" | ||||||
| 		reply = status.InReplyToURI != "" | 		reply = status.InReplyToURI != "" | ||||||
| 	) | 	) | ||||||
|  | @ -117,7 +117,7 @@ func (s *surface) timelineAndNotifyStatusForFollowers( | ||||||
| 			ctx, | 			ctx, | ||||||
| 			status, | 			status, | ||||||
| 			follow, | 			follow, | ||||||
| 			errs, | 			&errs, | ||||||
| 		) | 		) | ||||||
| 
 | 
 | ||||||
| 		// Add status to home timeline for owner | 		// Add status to home timeline for owner | ||||||
|  | @ -160,11 +160,10 @@ func (s *surface) timelineAndNotifyStatusForFollowers( | ||||||
| 		//   - This is a top-level post (not a reply or boost). | 		//   - This is a top-level post (not a reply or boost). | ||||||
| 		// | 		// | ||||||
| 		// That means we can officially notify this one. | 		// That means we can officially notify this one. | ||||||
| 		if err := s.notify( | 		if err := s.notify(ctx, | ||||||
| 			ctx, |  | ||||||
| 			gtsmodel.NotificationStatus, | 			gtsmodel.NotificationStatus, | ||||||
| 			follow.AccountID, | 			follow.Account, | ||||||
| 			status.AccountID, | 			status.Account, | ||||||
| 			status.ID, | 			status.ID, | ||||||
| 		); err != nil { | 		); err != nil { | ||||||
| 			errs.Appendf("error notifying account %s about new status: %w", follow.AccountID, err) | 			errs.Appendf("error notifying account %s about new status: %w", follow.AccountID, err) | ||||||
|  |  | ||||||
|  | @ -85,6 +85,21 @@ func wipeStatusF(state *state.State, media *media.Processor, surface *surface) w | ||||||
| 			errs.Appendf("error deleting status faves: %w", err) | 			errs.Appendf("error deleting status faves: %w", err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		if pollID := statusToDelete.PollID; pollID != "" { | ||||||
|  | 			// Delete this poll by ID from the database. | ||||||
|  | 			if err := state.DB.DeletePollByID(ctx, pollID); err != nil { | ||||||
|  | 				errs.Appendf("error deleting status poll: %w", err) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Delete any poll votes pointing to this poll ID. | ||||||
|  | 			if err := state.DB.DeletePollVotes(ctx, pollID); err != nil { | ||||||
|  | 				errs.Appendf("error deleting status poll votes: %w", err) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Cancel any scheduled expiry task for poll. | ||||||
|  | 			_ = state.Workers.Scheduler.Cancel(pollID) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		// delete all boosts for this status + remove them from timelines | 		// delete all boosts for this status + remove them from timelines | ||||||
| 		boosts, err := state.DB.GetStatusBoosts( | 		boosts, err := state.DB.GetStatusBoosts( | ||||||
| 			// we MUST set a barebones context here, | 			// we MUST set a barebones context here, | ||||||
|  |  | ||||||
|  | @ -26,18 +26,16 @@ import ( | ||||||
| 	"codeberg.org/gruf/go-sched" | 	"codeberg.org/gruf/go-sched" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Scheduler wraps an underlying task scheduler | // Scheduler wraps an underlying scheduler to provide | ||||||
| // to provide concurrency safe tracking by 'id' | // task tracking by unique string identifiers, so jobs | ||||||
| // strings in order to provide easy cancellation. | // may be cancelled with only an identifier. | ||||||
| type Scheduler struct { | type Scheduler struct { | ||||||
| 	sch sched.Scheduler | 	sch sched.Scheduler | ||||||
| 	ts  map[string]*task | 	ts  map[string]*task | ||||||
| 	mu  sync.Mutex | 	mu  sync.Mutex | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Start will start the Scheduler background routine, returning success. | // Start attempts to start the scheduler. Returns false if already running. | ||||||
| // Note that this creates a new internal task map, stopping and dropping |  | ||||||
| // all previously known running tasks. |  | ||||||
| func (sch *Scheduler) Start() bool { | func (sch *Scheduler) Start() bool { | ||||||
| 	if sch.sch.Start(nil) { | 	if sch.sch.Start(nil) { | ||||||
| 		sch.ts = make(map[string]*task) | 		sch.ts = make(map[string]*task) | ||||||
|  | @ -46,9 +44,8 @@ func (sch *Scheduler) Start() bool { | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Stop will stop the Scheduler background routine, returning success. | // Stop attempts to stop scheduler, cancelling | ||||||
| // Note that this nils-out the internal task map, stopping and dropping | // all running tasks. Returns false if not running. | ||||||
| // all previously known running tasks. |  | ||||||
| func (sch *Scheduler) Stop() bool { | func (sch *Scheduler) Stop() bool { | ||||||
| 	if sch.sch.Stop() { | 	if sch.sch.Stop() { | ||||||
| 		sch.ts = nil | 		sch.ts = nil | ||||||
|  | @ -57,18 +54,17 @@ func (sch *Scheduler) Stop() bool { | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // AddOnce adds a run-once job with given id, function and timing parameters, returning success. | // AddOnce schedules the given task to run at time, registered under the given ID. Returns false if task already exists for id. | ||||||
| func (sch *Scheduler) AddOnce(id string, start time.Time, fn func(context.Context, time.Time)) bool { | func (sch *Scheduler) AddOnce(id string, start time.Time, fn func(context.Context, time.Time)) bool { | ||||||
| 	return sch.schedule(id, fn, (*sched.Once)(&start)) | 	return sch.schedule(id, fn, (*sched.Once)(&start)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // AddRecurring adds a new recurring job with given id, function and timing parameters, returning success. | // AddRecurring schedules the given task to return at given period, starting at given time, registered under given id. Returns false if task already exists for id. | ||||||
| func (sch *Scheduler) AddRecurring(id string, start time.Time, freq time.Duration, fn func(context.Context, time.Time)) bool { | func (sch *Scheduler) AddRecurring(id string, start time.Time, freq time.Duration, fn func(context.Context, time.Time)) bool { | ||||||
| 	return sch.schedule(id, fn, &sched.PeriodicAt{Once: sched.Once(start), Period: sched.Periodic(freq)}) | 	return sch.schedule(id, fn, &sched.PeriodicAt{Once: sched.Once(start), Period: sched.Periodic(freq)}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Cancel will attempt to cancel job with given id, | // Cancel attempts to cancel a scheduled task with id, returns false if no task found. | ||||||
| // dropping it from internal scheduler and task map. |  | ||||||
| func (sch *Scheduler) Cancel(id string) bool { | func (sch *Scheduler) Cancel(id string) bool { | ||||||
| 	// Attempt to acquire and | 	// Attempt to acquire and | ||||||
| 	// delete task with iD. | 	// delete task with iD. | ||||||
|  | @ -125,6 +121,8 @@ func (sch *Scheduler) schedule(id string, fn func(context.Context, time.Time), t | ||||||
| 	return true | 	return true | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // task simply wraps together a scheduled | ||||||
|  | // job, and the matching cancel function. | ||||||
| type task struct { | type task struct { | ||||||
| 	job  *sched.Job | 	job  *sched.Job | ||||||
| 	cncl func() | 	cncl func() | ||||||
|  |  | ||||||
|  | @ -228,7 +228,7 @@ func (suite *GetTestSuite) TestGetNewTimelineMoreThanPossible() { | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		suite.FailNow(err.Error()) | 		suite.FailNow(err.Error()) | ||||||
| 	} | 	} | ||||||
| 	suite.checkStatuses(statuses, id.Highest, id.Lowest, 16) | 	suite.checkStatuses(statuses, id.Highest, id.Lowest, 18) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *GetTestSuite) TestGetNewTimelineMoreThanPossiblePageUp() { | func (suite *GetTestSuite) TestGetNewTimelineMoreThanPossiblePageUp() { | ||||||
|  | @ -255,7 +255,7 @@ func (suite *GetTestSuite) TestGetNewTimelineMoreThanPossiblePageUp() { | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		suite.FailNow(err.Error()) | 		suite.FailNow(err.Error()) | ||||||
| 	} | 	} | ||||||
| 	suite.checkStatuses(statuses, id.Highest, id.Lowest, 16) | 	suite.checkStatuses(statuses, id.Highest, id.Lowest, 18) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *GetTestSuite) TestGetNewTimelineNoFollowing() { | func (suite *GetTestSuite) TestGetNewTimelineNoFollowing() { | ||||||
|  | @ -284,7 +284,7 @@ func (suite *GetTestSuite) TestGetNewTimelineNoFollowing() { | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		suite.FailNow(err.Error()) | 		suite.FailNow(err.Error()) | ||||||
| 	} | 	} | ||||||
| 	suite.checkStatuses(statuses, id.Highest, id.Lowest, 5) | 	suite.checkStatuses(statuses, id.Highest, id.Lowest, 6) | ||||||
| 
 | 
 | ||||||
| 	for _, s := range statuses { | 	for _, s := range statuses { | ||||||
| 		if s.GetAccountID() != testAccount.ID { | 		if s.GetAccountID() != testAccount.ID { | ||||||
|  |  | ||||||
|  | @ -40,7 +40,7 @@ func (suite *PruneTestSuite) TestPrune() { | ||||||
| 
 | 
 | ||||||
| 	pruned, err := suite.state.Timelines.Home.Prune(ctx, testAccountID, desiredPreparedItemsLength, desiredIndexedItemsLength) | 	pruned, err := suite.state.Timelines.Home.Prune(ctx, testAccountID, desiredPreparedItemsLength, desiredIndexedItemsLength) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	suite.Equal(12, pruned) | 	suite.Equal(15, pruned) | ||||||
| 	suite.Equal(5, suite.state.Timelines.Home.GetIndexedLength(ctx, testAccountID)) | 	suite.Equal(5, suite.state.Timelines.Home.GetIndexedLength(ctx, testAccountID)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -56,7 +56,7 @@ func (suite *PruneTestSuite) TestPruneTwice() { | ||||||
| 
 | 
 | ||||||
| 	pruned, err := suite.state.Timelines.Home.Prune(ctx, testAccountID, desiredPreparedItemsLength, desiredIndexedItemsLength) | 	pruned, err := suite.state.Timelines.Home.Prune(ctx, testAccountID, desiredPreparedItemsLength, desiredIndexedItemsLength) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	suite.Equal(12, pruned) | 	suite.Equal(15, pruned) | ||||||
| 	suite.Equal(5, suite.state.Timelines.Home.GetIndexedLength(ctx, testAccountID)) | 	suite.Equal(5, suite.state.Timelines.Home.GetIndexedLength(ctx, testAccountID)) | ||||||
| 
 | 
 | ||||||
| 	// Prune same again, nothing should be pruned this time. | 	// Prune same again, nothing should be pruned this time. | ||||||
|  | @ -78,7 +78,7 @@ func (suite *PruneTestSuite) TestPruneTo0() { | ||||||
| 
 | 
 | ||||||
| 	pruned, err := suite.state.Timelines.Home.Prune(ctx, testAccountID, desiredPreparedItemsLength, desiredIndexedItemsLength) | 	pruned, err := suite.state.Timelines.Home.Prune(ctx, testAccountID, desiredPreparedItemsLength, desiredIndexedItemsLength) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	suite.Equal(17, pruned) | 	suite.Equal(20, pruned) | ||||||
| 	suite.Equal(0, suite.state.Timelines.Home.GetIndexedLength(ctx, testAccountID)) | 	suite.Equal(0, suite.state.Timelines.Home.GetIndexedLength(ctx, testAccountID)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -95,7 +95,7 @@ func (suite *PruneTestSuite) TestPruneToInfinityAndBeyond() { | ||||||
| 	pruned, err := suite.state.Timelines.Home.Prune(ctx, testAccountID, desiredPreparedItemsLength, desiredIndexedItemsLength) | 	pruned, err := suite.state.Timelines.Home.Prune(ctx, testAccountID, desiredPreparedItemsLength, desiredIndexedItemsLength) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	suite.Equal(0, pruned) | 	suite.Equal(0, pruned) | ||||||
| 	suite.Equal(17, suite.state.Timelines.Home.GetIndexedLength(ctx, testAccountID)) | 	suite.Equal(20, suite.state.Timelines.Home.GetIndexedLength(ctx, testAccountID)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestPruneTestSuite(t *testing.T) { | func TestPruneTestSuite(t *testing.T) { | ||||||
|  |  | ||||||
|  | @ -261,8 +261,10 @@ func (c *Converter) ASStatusToStatus(ctx context.Context, statusable ap.Statusab | ||||||
| 	// Attached poll information (the statusable will actually | 	// Attached poll information (the statusable will actually | ||||||
| 	// be a Pollable, as a Question is a subset of our Status). | 	// be a Pollable, as a Question is a subset of our Status). | ||||||
| 	if pollable, ok := ap.ToPollable(statusable); ok { | 	if pollable, ok := ap.ToPollable(statusable); ok { | ||||||
| 		// TODO: handle decoding poll data | 		status.Poll, err = ap.ExtractPoll(pollable) | ||||||
| 		_ = pollable | 		if err != nil { | ||||||
|  | 			l.Warnf("error(s) extracting poll: %v", err) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// status.Hashtags | 	// status.Hashtags | ||||||
|  |  | ||||||
|  | @ -412,8 +412,24 @@ func (c *Converter) StatusToAS(ctx context.Context, s *gtsmodel.Status) (ap.Stat | ||||||
| 		return nil, gtserror.Newf("error populating status: %w", err) | 		return nil, gtserror.Newf("error populating status: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// We convert it as an AS Note. | 	var status ap.Statusable | ||||||
| 	status := streams.NewActivityStreamsNote() | 
 | ||||||
|  | 	if s.Poll != nil { | ||||||
|  | 		// If status has poll available, we convert | ||||||
|  | 		// it as an AS Question (similar to a Note). | ||||||
|  | 		poll := streams.NewActivityStreamsQuestion() | ||||||
|  | 
 | ||||||
|  | 		// Add required status poll data to AS Question. | ||||||
|  | 		if err := c.addPollToAS(ctx, s.Poll, poll); err != nil { | ||||||
|  | 			return nil, gtserror.Newf("error converting poll: %w", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Set poll as status. | ||||||
|  | 		status = poll | ||||||
|  | 	} else { | ||||||
|  | 		// Else we converter it as an AS Note. | ||||||
|  | 		status = streams.NewActivityStreamsNote() | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	// id | 	// id | ||||||
| 	statusURI, err := url.Parse(s.URI) | 	statusURI, err := url.Parse(s.URI) | ||||||
|  | @ -636,6 +652,73 @@ func (c *Converter) StatusToAS(ctx context.Context, s *gtsmodel.Status) (ap.Stat | ||||||
| 	return status, nil | 	return status, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (c *Converter) addPollToAS(ctx context.Context, poll *gtsmodel.Poll, dst ap.Pollable) error { | ||||||
|  | 	var optionsProp interface { | ||||||
|  | 		// the minimum interface for appending AS Notes | ||||||
|  | 		// to an AS type options property of some kind. | ||||||
|  | 		AppendActivityStreamsNote(vocab.ActivityStreamsNote) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(poll.Options) != len(poll.Votes) { | ||||||
|  | 		return gtserror.Newf("invalid poll %s", poll.ID) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !*poll.HideCounts { | ||||||
|  | 		// Set total no. voting accounts. | ||||||
|  | 		ap.SetVotersCount(dst, *poll.Voters) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if *poll.Multiple { | ||||||
|  | 		// Create new multiple-choice (AnyOf) property for poll. | ||||||
|  | 		anyOfProp := streams.NewActivityStreamsAnyOfProperty() | ||||||
|  | 		dst.SetActivityStreamsAnyOf(anyOfProp) | ||||||
|  | 		optionsProp = anyOfProp | ||||||
|  | 	} else { | ||||||
|  | 		// Create new single-choice (OneOf) property for poll. | ||||||
|  | 		oneOfProp := streams.NewActivityStreamsOneOfProperty() | ||||||
|  | 		dst.SetActivityStreamsOneOf(oneOfProp) | ||||||
|  | 		optionsProp = oneOfProp | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for i, name := range poll.Options { | ||||||
|  | 		// Create new Note object to represent option. | ||||||
|  | 		note := streams.NewActivityStreamsNote() | ||||||
|  | 
 | ||||||
|  | 		// Create new name property and set the option name. | ||||||
|  | 		nameProp := streams.NewActivityStreamsNameProperty() | ||||||
|  | 		nameProp.AppendXMLSchemaString(name) | ||||||
|  | 		note.SetActivityStreamsName(nameProp) | ||||||
|  | 
 | ||||||
|  | 		if !*poll.HideCounts { | ||||||
|  | 			// Create new total items property to hold the vote count. | ||||||
|  | 			totalItemsProp := streams.NewActivityStreamsTotalItemsProperty() | ||||||
|  | 			totalItemsProp.Set(poll.Votes[i]) | ||||||
|  | 
 | ||||||
|  | 			// Create new replies property with collection to encompass count. | ||||||
|  | 			repliesProp := streams.NewActivityStreamsRepliesProperty() | ||||||
|  | 			collection := streams.NewActivityStreamsCollection() | ||||||
|  | 			collection.SetActivityStreamsTotalItems(totalItemsProp) | ||||||
|  | 			repliesProp.SetActivityStreamsCollection(collection) | ||||||
|  | 
 | ||||||
|  | 			// Attach the replies to Note object. | ||||||
|  | 			note.SetActivityStreamsReplies(repliesProp) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Append the note to options property. | ||||||
|  | 		optionsProp.AppendActivityStreamsNote(note) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Set poll endTime property. | ||||||
|  | 	ap.SetEndTime(dst, poll.ExpiresAt) | ||||||
|  | 
 | ||||||
|  | 	if !poll.ClosedAt.IsZero() { | ||||||
|  | 		// Poll is closed, set closed property. | ||||||
|  | 		ap.AppendClosed(dst, poll.ClosedAt) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // StatusToASDelete converts a gts model status into a Delete of that status, using just the | // StatusToASDelete converts a gts model status into a Delete of that status, using just the | ||||||
| // URI of the status as object, and addressing the Delete appropriately. | // URI of the status as object, and addressing the Delete appropriately. | ||||||
| func (c *Converter) StatusToASDelete(ctx context.Context, s *gtsmodel.Status) (vocab.ActivityStreamsDelete, error) { | func (c *Converter) StatusToASDelete(ctx context.Context, s *gtsmodel.Status) (vocab.ActivityStreamsDelete, error) { | ||||||
|  | @ -1413,12 +1496,8 @@ func (c *Converter) StatusesToASOutboxPage(ctx context.Context, outboxID string, | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		create, err := c.WrapStatusableInCreate(note, true) | 		activity := WrapStatusableInCreate(note, true) | ||||||
| 		if err != nil { | 		itemsProp.AppendActivityStreamsCreate(activity) | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		itemsProp.AppendActivityStreamsCreate(create) |  | ||||||
| 
 | 
 | ||||||
| 		if highest == "" || s.ID > highest { | 		if highest == "" || s.ID > highest { | ||||||
| 			highest = s.ID | 			highest = s.ID | ||||||
|  | @ -1569,3 +1648,66 @@ func (c *Converter) ReportToASFlag(ctx context.Context, r *gtsmodel.Report) (voc | ||||||
| 
 | 
 | ||||||
| 	return flag, nil | 	return flag, nil | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func (c *Converter) PollVoteToASOptions(ctx context.Context, vote *gtsmodel.PollVote) ([]ap.PollOptionable, error) { | ||||||
|  | 	// Ensure the vote is fully populated (this fetches author). | ||||||
|  | 	if err := c.state.DB.PopulatePollVote(ctx, vote); err != nil { | ||||||
|  | 		return nil, gtserror.Newf("error populating vote from db: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get the vote author. | ||||||
|  | 	author := vote.Account | ||||||
|  | 
 | ||||||
|  | 	// Get the JSONLD ID IRI for vote author. | ||||||
|  | 	authorIRI, err := url.Parse(author.URI) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, gtserror.Newf("invalid author uri: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get the vote poll. | ||||||
|  | 	poll := vote.Poll | ||||||
|  | 
 | ||||||
|  | 	// Ensure the poll is fully populated with status. | ||||||
|  | 	if err := c.state.DB.PopulatePoll(ctx, poll); err != nil { | ||||||
|  | 		return nil, gtserror.Newf("error populating poll from db: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get the JSONLD ID IRI for poll's source status. | ||||||
|  | 	statusIRI, err := url.Parse(poll.Status.URI) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, gtserror.Newf("invalid status uri: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get the JSONLD ID IRI for poll's author account. | ||||||
|  | 	pollAuthorIRI, err := url.Parse(poll.Status.AccountURI) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, gtserror.Newf("invalid account uri: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Preallocate the return slice of notes. | ||||||
|  | 	notes := make([]ap.PollOptionable, len(vote.Choices)) | ||||||
|  | 
 | ||||||
|  | 	for i, choice := range vote.Choices { | ||||||
|  | 		// Create new note to represent vote. | ||||||
|  | 		note := streams.NewActivityStreamsNote() | ||||||
|  | 
 | ||||||
|  | 		// For AP IRI generate from author URI + poll ID + vote choice. | ||||||
|  | 		id := fmt.Sprintf("%s#%s/votes/%d", author.URI, poll.ID, choice) | ||||||
|  | 		ap.MustSet(ap.SetJSONLDIdStr, ap.WithJSONLDId(note), id) | ||||||
|  | 
 | ||||||
|  | 		// Attach new name property to note with vote choice. | ||||||
|  | 		nameProp := streams.NewActivityStreamsNameProperty() | ||||||
|  | 		nameProp.AppendXMLSchemaString(poll.Options[choice]) | ||||||
|  | 		note.SetActivityStreamsName(nameProp) | ||||||
|  | 
 | ||||||
|  | 		// Set 'to', 'attribTo', 'inReplyTo' fields. | ||||||
|  | 		ap.AppendAttributedTo(note, authorIRI) | ||||||
|  | 		ap.AppendInReplyTo(note, statusIRI) | ||||||
|  | 		ap.AppendTo(note, pollAuthorIRI) | ||||||
|  | 
 | ||||||
|  | 		// Set note in return slice. | ||||||
|  | 		notes[i] = note | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return notes, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -680,7 +680,7 @@ func (suite *InternalToASTestSuite) TestStatusesToASOutboxPage() { | ||||||
|     { |     { | ||||||
|       "actor": "http://localhost:8080/users/admin", |       "actor": "http://localhost:8080/users/admin", | ||||||
|       "cc": "http://localhost:8080/users/admin/followers", |       "cc": "http://localhost:8080/users/admin/followers", | ||||||
|       "id": "http://localhost:8080/users/admin/statuses/01F8MHAAY43M6RJ473VQFCVH37/activity", |       "id": "http://localhost:8080/users/admin/statuses/01F8MHAAY43M6RJ473VQFCVH37/activity#Create", | ||||||
|       "object": "http://localhost:8080/users/admin/statuses/01F8MHAAY43M6RJ473VQFCVH37", |       "object": "http://localhost:8080/users/admin/statuses/01F8MHAAY43M6RJ473VQFCVH37", | ||||||
|       "published": "2021-10-20T12:36:45Z", |       "published": "2021-10-20T12:36:45Z", | ||||||
|       "to": "https://www.w3.org/ns/activitystreams#Public", |       "to": "https://www.w3.org/ns/activitystreams#Public", | ||||||
|  | @ -689,7 +689,7 @@ func (suite *InternalToASTestSuite) TestStatusesToASOutboxPage() { | ||||||
|     { |     { | ||||||
|       "actor": "http://localhost:8080/users/admin", |       "actor": "http://localhost:8080/users/admin", | ||||||
|       "cc": "http://localhost:8080/users/admin/followers", |       "cc": "http://localhost:8080/users/admin/followers", | ||||||
|       "id": "http://localhost:8080/users/admin/statuses/01F8MH75CBF9JFX4ZAD54N0W0R/activity", |       "id": "http://localhost:8080/users/admin/statuses/01F8MH75CBF9JFX4ZAD54N0W0R/activity#Create", | ||||||
|       "object": "http://localhost:8080/users/admin/statuses/01F8MH75CBF9JFX4ZAD54N0W0R", |       "object": "http://localhost:8080/users/admin/statuses/01F8MH75CBF9JFX4ZAD54N0W0R", | ||||||
|       "published": "2021-10-20T11:36:45Z", |       "published": "2021-10-20T11:36:45Z", | ||||||
|       "to": "https://www.w3.org/ns/activitystreams#Public", |       "to": "https://www.w3.org/ns/activitystreams#Public", | ||||||
|  |  | ||||||
|  | @ -729,10 +729,13 @@ func (c *Converter) StatusToAPIStatus(ctx context.Context, s *gtsmodel.Status, r | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if appID := s.CreatedWithApplicationID; appID != "" { | 	if appID := s.CreatedWithApplicationID; appID != "" { | ||||||
| 		app, err := c.state.DB.GetApplicationByID(ctx, appID) | 		app := s.CreatedWithApplication | ||||||
|  | 		if app == nil { | ||||||
|  | 			app, err = c.state.DB.GetApplicationByID(ctx, appID) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return nil, fmt.Errorf("error getting application %s: %w", appID, err) | 				return nil, fmt.Errorf("error getting application %s: %w", appID, err) | ||||||
| 			} | 			} | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 		apiApp, err := c.AppToAPIAppPublic(ctx, app) | 		apiApp, err := c.AppToAPIAppPublic(ctx, app) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  | @ -742,6 +745,18 @@ func (c *Converter) StatusToAPIStatus(ctx context.Context, s *gtsmodel.Status, r | ||||||
| 		apiStatus.Application = apiApp | 		apiStatus.Application = apiApp | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	if s.Poll != nil { | ||||||
|  | 		// Set originating | ||||||
|  | 		// status on the poll. | ||||||
|  | 		poll := s.Poll | ||||||
|  | 		poll.Status = s | ||||||
|  | 
 | ||||||
|  | 		apiStatus.Poll, err = c.PollToAPIPoll(ctx, requestingAccount, poll) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, fmt.Errorf("error converting poll: %w", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Normalization. | 	// Normalization. | ||||||
| 
 | 
 | ||||||
| 	if s.URL == "" { | 	if s.URL == "" { | ||||||
|  | @ -1287,6 +1302,86 @@ func (c *Converter) MarkersToAPIMarker(ctx context.Context, markers []*gtsmodel. | ||||||
| 	return apiMarker, nil | 	return apiMarker, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // PollToAPIPoll converts a database (gtsmodel) Poll into an API model representation appropriate for the given requesting account. | ||||||
|  | func (c *Converter) PollToAPIPoll(ctx context.Context, requester *gtsmodel.Account, poll *gtsmodel.Poll) (*apimodel.Poll, error) { | ||||||
|  | 	// Ensure the poll model is fully populated for src status. | ||||||
|  | 	if err := c.state.DB.PopulatePoll(ctx, poll); err != nil { | ||||||
|  | 		return nil, gtserror.Newf("error populating poll: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var ( | ||||||
|  | 		totalVotes  int | ||||||
|  | 		totalVoters int | ||||||
|  | 		voteCounts  []int | ||||||
|  | 		ownChoices  []int | ||||||
|  | 		isAuthor    bool | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	if requester != nil { | ||||||
|  | 		// Get vote by requester in poll (if any). | ||||||
|  | 		vote, err := c.state.DB.GetPollVoteBy(ctx, | ||||||
|  | 			poll.ID, | ||||||
|  | 			requester.ID, | ||||||
|  | 		) | ||||||
|  | 		if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
|  | 			return nil, gtserror.Newf("error getting vote for poll %s: %w", poll.ID, err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if vote != nil { | ||||||
|  | 			// Set choices by requester. | ||||||
|  | 			ownChoices = vote.Choices | ||||||
|  | 
 | ||||||
|  | 			// Update default totals in the | ||||||
|  | 			// case that counts are hidden. | ||||||
|  | 			totalVotes = len(vote.Choices) | ||||||
|  | 			totalVoters = 1 | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Check if requester is author of source status. | ||||||
|  | 		isAuthor = (requester.ID == poll.Status.AccountID) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Preallocate a slice of frontend model poll choices. | ||||||
|  | 	options := make([]apimodel.PollOption, len(poll.Options)) | ||||||
|  | 
 | ||||||
|  | 	// Add the titles to all of the options. | ||||||
|  | 	for i, title := range poll.Options { | ||||||
|  | 		options[i].Title = title | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if isAuthor || !*poll.HideCounts { | ||||||
|  | 		// A remote status, | ||||||
|  | 		// the simple route! | ||||||
|  | 		// | ||||||
|  | 		// Pull cached remote values. | ||||||
|  | 		totalVoters = *poll.Voters | ||||||
|  | 		voteCounts = poll.Votes | ||||||
|  | 
 | ||||||
|  | 		// Accumulate total from all counts. | ||||||
|  | 		for _, count := range poll.Votes { | ||||||
|  | 			totalVotes += count | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// When this is status author, or hide counts | ||||||
|  | 		// is disabled, set the counts known per vote. | ||||||
|  | 		for i, count := range voteCounts { | ||||||
|  | 			options[i].VotesCount = count | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return &apimodel.Poll{ | ||||||
|  | 		ID:          poll.ID, | ||||||
|  | 		ExpiresAt:   util.FormatISO8601(poll.ExpiresAt), | ||||||
|  | 		Expired:     poll.Closed(), | ||||||
|  | 		Multiple:    *poll.Multiple, | ||||||
|  | 		VotesCount:  totalVotes, | ||||||
|  | 		VotersCount: totalVoters, | ||||||
|  | 		Voted:       (isAuthor || len(ownChoices) > 0), | ||||||
|  | 		OwnVotes:    ownChoices, | ||||||
|  | 		Options:     options, | ||||||
|  | 	}, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // convertAttachmentsToAPIAttachments will convert a slice of GTS model attachments to frontend API model attachments, falling back to IDs if no GTS models supplied. | // convertAttachmentsToAPIAttachments will convert a slice of GTS model attachments to frontend API model attachments, falling back to IDs if no GTS models supplied. | ||||||
| func (c *Converter) convertAttachmentsToAPIAttachments(ctx context.Context, attachments []*gtsmodel.MediaAttachment, attachmentIDs []string) ([]apimodel.Attachment, error) { | func (c *Converter) convertAttachmentsToAPIAttachments(ctx context.Context, attachments []*gtsmodel.MediaAttachment, attachmentIDs []string) ([]apimodel.Attachment, error) { | ||||||
| 	var errs gtserror.MultiError | 	var errs gtserror.MultiError | ||||||
|  |  | ||||||
|  | @ -58,8 +58,8 @@ func (suite *InternalToFrontendTestSuite) TestAccountToFrontend() { | ||||||
|   "header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg", |   "header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg", | ||||||
|   "followers_count": 2, |   "followers_count": 2, | ||||||
|   "following_count": 2, |   "following_count": 2, | ||||||
|   "statuses_count": 5, |   "statuses_count": 6, | ||||||
|   "last_status_at": "2022-05-20T11:37:55.000Z", |   "last_status_at": "2022-05-20T11:41:10.000Z", | ||||||
|   "emojis": [], |   "emojis": [], | ||||||
|   "fields": [], |   "fields": [], | ||||||
|   "enable_rss": true, |   "enable_rss": true, | ||||||
|  | @ -100,8 +100,8 @@ func (suite *InternalToFrontendTestSuite) TestAccountToFrontendWithEmojiStruct() | ||||||
|   "header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg", |   "header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg", | ||||||
|   "followers_count": 2, |   "followers_count": 2, | ||||||
|   "following_count": 2, |   "following_count": 2, | ||||||
|   "statuses_count": 5, |   "statuses_count": 6, | ||||||
|   "last_status_at": "2022-05-20T11:37:55.000Z", |   "last_status_at": "2022-05-20T11:41:10.000Z", | ||||||
|   "emojis": [ |   "emojis": [ | ||||||
|     { |     { | ||||||
|       "shortcode": "rainbow", |       "shortcode": "rainbow", | ||||||
|  | @ -148,8 +148,8 @@ func (suite *InternalToFrontendTestSuite) TestAccountToFrontendWithEmojiIDs() { | ||||||
|   "header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg", |   "header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg", | ||||||
|   "followers_count": 2, |   "followers_count": 2, | ||||||
|   "following_count": 2, |   "following_count": 2, | ||||||
|   "statuses_count": 5, |   "statuses_count": 6, | ||||||
|   "last_status_at": "2022-05-20T11:37:55.000Z", |   "last_status_at": "2022-05-20T11:41:10.000Z", | ||||||
|   "emojis": [ |   "emojis": [ | ||||||
|     { |     { | ||||||
|       "shortcode": "rainbow", |       "shortcode": "rainbow", | ||||||
|  | @ -192,8 +192,8 @@ func (suite *InternalToFrontendTestSuite) TestAccountToFrontendSensitive() { | ||||||
|   "header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg", |   "header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg", | ||||||
|   "followers_count": 2, |   "followers_count": 2, | ||||||
|   "following_count": 2, |   "following_count": 2, | ||||||
|   "statuses_count": 5, |   "statuses_count": 6, | ||||||
|   "last_status_at": "2022-05-20T11:37:55.000Z", |   "last_status_at": "2022-05-20T11:41:10.000Z", | ||||||
|   "emojis": [], |   "emojis": [], | ||||||
|   "fields": [], |   "fields": [], | ||||||
|   "source": { |   "source": { | ||||||
|  | @ -660,7 +660,7 @@ func (suite *InternalToFrontendTestSuite) TestInstanceV1ToFrontend() { | ||||||
|   }, |   }, | ||||||
|   "stats": { |   "stats": { | ||||||
|     "domain_count": 2, |     "domain_count": 2, | ||||||
|     "status_count": 16, |     "status_count": 18, | ||||||
|     "user_count": 4 |     "user_count": 4 | ||||||
|   }, |   }, | ||||||
|   "thumbnail": "http://localhost:8080/assets/logo.png", |   "thumbnail": "http://localhost:8080/assets/logo.png", | ||||||
|  | @ -910,8 +910,8 @@ func (suite *InternalToFrontendTestSuite) TestReportToFrontend1() { | ||||||
|     "header_static": "http://localhost:8080/assets/default_header.png", |     "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|     "followers_count": 0, |     "followers_count": 0, | ||||||
|     "following_count": 0, |     "following_count": 0, | ||||||
|     "statuses_count": 1, |     "statuses_count": 2, | ||||||
|     "last_status_at": "2021-09-20T10:40:37.000Z", |     "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|     "emojis": [], |     "emojis": [], | ||||||
|     "fields": [] |     "fields": [] | ||||||
|   } |   } | ||||||
|  | @ -953,8 +953,8 @@ func (suite *InternalToFrontendTestSuite) TestReportToFrontend2() { | ||||||
|     "header_static": "http://localhost:8080/assets/default_header.png", |     "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|     "followers_count": 1, |     "followers_count": 1, | ||||||
|     "following_count": 1, |     "following_count": 1, | ||||||
|     "statuses_count": 7, |     "statuses_count": 8, | ||||||
|     "last_status_at": "2021-10-20T10:40:37.000Z", |     "last_status_at": "2021-07-28T08:40:37.000Z", | ||||||
|     "emojis": [], |     "emojis": [], | ||||||
|     "fields": [ |     "fields": [ | ||||||
|       { |       { | ||||||
|  | @ -1027,8 +1027,8 @@ func (suite *InternalToFrontendTestSuite) TestAdminReportToFrontend1() { | ||||||
|       "header_static": "http://localhost:8080/assets/default_header.png", |       "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|       "followers_count": 0, |       "followers_count": 0, | ||||||
|       "following_count": 0, |       "following_count": 0, | ||||||
|       "statuses_count": 1, |       "statuses_count": 2, | ||||||
|       "last_status_at": "2021-09-20T10:40:37.000Z", |       "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|       "emojis": [], |       "emojis": [], | ||||||
|       "fields": [] |       "fields": [] | ||||||
|     } |     } | ||||||
|  | @ -1068,8 +1068,8 @@ func (suite *InternalToFrontendTestSuite) TestAdminReportToFrontend1() { | ||||||
|       "header_static": "http://localhost:8080/assets/default_header.png", |       "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|       "followers_count": 1, |       "followers_count": 1, | ||||||
|       "following_count": 1, |       "following_count": 1, | ||||||
|       "statuses_count": 7, |       "statuses_count": 8, | ||||||
|       "last_status_at": "2021-10-20T10:40:37.000Z", |       "last_status_at": "2021-07-28T08:40:37.000Z", | ||||||
|       "emojis": [], |       "emojis": [], | ||||||
|       "fields": [ |       "fields": [ | ||||||
|         { |         { | ||||||
|  | @ -1239,8 +1239,8 @@ func (suite *InternalToFrontendTestSuite) TestAdminReportToFrontend2() { | ||||||
|       "header_static": "http://localhost:8080/assets/default_header.png", |       "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|       "followers_count": 1, |       "followers_count": 1, | ||||||
|       "following_count": 1, |       "following_count": 1, | ||||||
|       "statuses_count": 7, |       "statuses_count": 8, | ||||||
|       "last_status_at": "2021-10-20T10:40:37.000Z", |       "last_status_at": "2021-07-28T08:40:37.000Z", | ||||||
|       "emojis": [], |       "emojis": [], | ||||||
|       "fields": [ |       "fields": [ | ||||||
|         { |         { | ||||||
|  | @ -1295,8 +1295,8 @@ func (suite *InternalToFrontendTestSuite) TestAdminReportToFrontend2() { | ||||||
|       "header_static": "http://localhost:8080/assets/default_header.png", |       "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|       "followers_count": 0, |       "followers_count": 0, | ||||||
|       "following_count": 0, |       "following_count": 0, | ||||||
|       "statuses_count": 1, |       "statuses_count": 2, | ||||||
|       "last_status_at": "2021-09-20T10:40:37.000Z", |       "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|       "emojis": [], |       "emojis": [], | ||||||
|       "fields": [] |       "fields": [] | ||||||
|     } |     } | ||||||
|  | @ -1342,8 +1342,8 @@ func (suite *InternalToFrontendTestSuite) TestAdminReportToFrontend2() { | ||||||
|         "header_static": "http://localhost:8080/assets/default_header.png", |         "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|         "followers_count": 0, |         "followers_count": 0, | ||||||
|         "following_count": 0, |         "following_count": 0, | ||||||
|         "statuses_count": 1, |         "statuses_count": 2, | ||||||
|         "last_status_at": "2021-09-20T10:40:37.000Z", |         "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|         "emojis": [], |         "emojis": [], | ||||||
|         "fields": [] |         "fields": [] | ||||||
|       }, |       }, | ||||||
|  | @ -1473,8 +1473,8 @@ func (suite *InternalToFrontendTestSuite) TestAdminReportToFrontendSuspendedLoca | ||||||
|       "header_static": "http://localhost:8080/assets/default_header.png", |       "header_static": "http://localhost:8080/assets/default_header.png", | ||||||
|       "followers_count": 0, |       "followers_count": 0, | ||||||
|       "following_count": 0, |       "following_count": 0, | ||||||
|       "statuses_count": 1, |       "statuses_count": 2, | ||||||
|       "last_status_at": "2021-09-20T10:40:37.000Z", |       "last_status_at": "2021-09-11T09:40:37.000Z", | ||||||
|       "emojis": [], |       "emojis": [], | ||||||
|       "fields": [] |       "fields": [] | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  | @ -19,6 +19,7 @@ package typeutils | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"net/url" | 	"net/url" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/superseriousbusiness/activity/pub" | 	"github.com/superseriousbusiness/activity/pub" | ||||||
| 	"github.com/superseriousbusiness/activity/streams" | 	"github.com/superseriousbusiness/activity/streams" | ||||||
|  | @ -84,132 +85,86 @@ func (c *Converter) WrapPersonInUpdate(person vocab.ActivityStreamsPerson, origi | ||||||
| 	return update, nil | 	return update, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // WrapNoteInCreate wraps a Statusable with a Create activity. | func WrapStatusableInCreate(status ap.Statusable, iriOnly bool) vocab.ActivityStreamsCreate { | ||||||
| // |  | ||||||
| // If objectIRIOnly is set to true, then the function won't put the *entire* note in the Object field of the Create, |  | ||||||
| // but just the AP URI of the note. This is useful in cases where you want to give a remote server something to dereference, |  | ||||||
| // and still have control over whether or not they're allowed to actually see the contents. |  | ||||||
| func (c *Converter) WrapStatusableInCreate(status ap.Statusable, objectIRIOnly bool) (vocab.ActivityStreamsCreate, error) { |  | ||||||
| 	create := streams.NewActivityStreamsCreate() | 	create := streams.NewActivityStreamsCreate() | ||||||
| 
 | 	wrapStatusableInActivity(create, status, iriOnly) | ||||||
| 	// Object property | 	return create | ||||||
| 	objectProp := streams.NewActivityStreamsObjectProperty() |  | ||||||
| 	if objectIRIOnly { |  | ||||||
| 		// Only append the object IRI to objectProp. |  | ||||||
| 		objectProp.AppendIRI(status.GetJSONLDId().GetIRI()) |  | ||||||
| 	} else { |  | ||||||
| 		// Our statusable's are always note types. |  | ||||||
| 		asNote := status.(vocab.ActivityStreamsNote) |  | ||||||
| 		objectProp.AppendActivityStreamsNote(asNote) |  | ||||||
| 	} |  | ||||||
| 	create.SetActivityStreamsObject(objectProp) |  | ||||||
| 
 |  | ||||||
| 	// ID property |  | ||||||
| 	idProp := streams.NewJSONLDIdProperty() |  | ||||||
| 	createID := status.GetJSONLDId().GetIRI().String() + "/activity" |  | ||||||
| 	createIDIRI, err := url.Parse(createID) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	idProp.SetIRI(createIDIRI) |  | ||||||
| 	create.SetJSONLDId(idProp) |  | ||||||
| 
 |  | ||||||
| 	// Actor Property |  | ||||||
| 	actorProp := streams.NewActivityStreamsActorProperty() |  | ||||||
| 	actorIRI, err := ap.ExtractAttributedToURI(status) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, gtserror.Newf("couldn't extract AttributedTo: %w", err) |  | ||||||
| 	} |  | ||||||
| 	actorProp.AppendIRI(actorIRI) |  | ||||||
| 	create.SetActivityStreamsActor(actorProp) |  | ||||||
| 
 |  | ||||||
| 	// Published Property |  | ||||||
| 	publishedProp := streams.NewActivityStreamsPublishedProperty() |  | ||||||
| 	published, err := ap.ExtractPublished(status) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, gtserror.Newf("couldn't extract Published: %w", err) |  | ||||||
| 	} |  | ||||||
| 	publishedProp.Set(published) |  | ||||||
| 	create.SetActivityStreamsPublished(publishedProp) |  | ||||||
| 
 |  | ||||||
| 	// To Property |  | ||||||
| 	toProp := streams.NewActivityStreamsToProperty() |  | ||||||
| 	if toURIs := ap.ExtractToURIs(status); len(toURIs) != 0 { |  | ||||||
| 		for _, toURI := range toURIs { |  | ||||||
| 			toProp.AppendIRI(toURI) |  | ||||||
| 		} |  | ||||||
| 		create.SetActivityStreamsTo(toProp) |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| 	// Cc Property | func WrapPollOptionablesInCreate(options ...ap.PollOptionable) vocab.ActivityStreamsCreate { | ||||||
| 	ccProp := streams.NewActivityStreamsCcProperty() | 	if len(options) == 0 { | ||||||
| 	if ccURIs := ap.ExtractCcURIs(status); len(ccURIs) != 0 { | 		panic("no options") | ||||||
| 		for _, ccURI := range ccURIs { |  | ||||||
| 			ccProp.AppendIRI(ccURI) |  | ||||||
| 		} |  | ||||||
| 		create.SetActivityStreamsCc(ccProp) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return create, nil | 	// Extract attributedTo IRI from any option. | ||||||
|  | 	attribTos := ap.GetAttributedTo(options[0]) | ||||||
|  | 	if len(attribTos) != 1 { | ||||||
|  | 		panic("invalid attributedTo count") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| // WrapStatusableInUpdate wraps a Statusable with an Update activity. | 	// Extract target status IRI from any option. | ||||||
| // | 	replyTos := ap.GetInReplyTo(options[0]) | ||||||
| // If objectIRIOnly is set to true, then the function won't put the *entire* note in the Object field of the Create, | 	if len(replyTos) != 1 { | ||||||
| // but just the AP URI of the note. This is useful in cases where you want to give a remote server something to dereference, | 		panic("invalid inReplyTo count") | ||||||
| // and still have control over whether or not they're allowed to actually see the contents. | 	} | ||||||
| func (c *Converter) WrapStatusableInUpdate(status ap.Statusable, objectIRIOnly bool) (vocab.ActivityStreamsUpdate, error) { | 
 | ||||||
|  | 	// Allocate create activity and copy over 'To' property. | ||||||
|  | 	create := streams.NewActivityStreamsCreate() | ||||||
|  | 	ap.AppendTo(create, ap.GetTo(options[0])...) | ||||||
|  | 
 | ||||||
|  | 	// Activity ID formatted as: {$statusIRI}/activity#vote/{$voterIRI}. | ||||||
|  | 	id := replyTos[0].String() + "/activity#vote/" + attribTos[0].String() | ||||||
|  | 	ap.MustSet(ap.SetJSONLDIdStr, ap.WithJSONLDId(create), id) | ||||||
|  | 
 | ||||||
|  | 	// Set a current publish time for activity. | ||||||
|  | 	ap.SetPublished(create, time.Now()) | ||||||
|  | 
 | ||||||
|  | 	// Append each poll option as object to activity. | ||||||
|  | 	for _, option := range options { | ||||||
|  | 		status, _ := ap.ToStatusable(option) | ||||||
|  | 		appendStatusableToActivity(create, status, false) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return create | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func WrapStatusableInUpdate(status ap.Statusable, iriOnly bool) vocab.ActivityStreamsUpdate { | ||||||
| 	update := streams.NewActivityStreamsUpdate() | 	update := streams.NewActivityStreamsUpdate() | ||||||
|  | 	wrapStatusableInActivity(update, status, iriOnly) | ||||||
|  | 	return update | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| 	// Object property | // wrapStatusableInActivity adds the required ap.Statusable data to the given ap.Activityable. | ||||||
| 	objectProp := streams.NewActivityStreamsObjectProperty() | func wrapStatusableInActivity(activity ap.Activityable, status ap.Statusable, iriOnly bool) { | ||||||
| 	if objectIRIOnly { | 	idIRI := ap.GetJSONLDId(status) // activity ID formatted as {$statusIRI}/activity#{$typeName} | ||||||
| 		objectProp.AppendIRI(status.GetJSONLDId().GetIRI()) | 	ap.MustSet(ap.SetJSONLDIdStr, ap.WithJSONLDId(activity), idIRI.String()+"/activity#"+activity.GetTypeName()) | ||||||
| 	} else if _, ok := status.(ap.Pollable); ok { | 	appendStatusableToActivity(activity, status, iriOnly) | ||||||
| 		asQuestion := status.(vocab.ActivityStreamsQuestion) | 	ap.AppendTo(activity, ap.GetTo(status)...) | ||||||
| 		objectProp.AppendActivityStreamsQuestion(asQuestion) | 	ap.AppendCc(activity, ap.GetCc(status)...) | ||||||
|  | 	ap.AppendActor(activity, ap.GetAttributedTo(status)...) | ||||||
|  | 	ap.SetPublished(activity, ap.GetPublished(status)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // appendStatusableToActivity appends a Statusable type to an Activityable, handling case of Question, Note or just IRI type. | ||||||
|  | func appendStatusableToActivity(activity ap.Activityable, status ap.Statusable, iriOnly bool) { | ||||||
|  | 	// Get existing object property or allocate new. | ||||||
|  | 	objProp := activity.GetActivityStreamsObject() | ||||||
|  | 	if objProp == nil { | ||||||
|  | 		objProp = streams.NewActivityStreamsObjectProperty() | ||||||
|  | 		activity.SetActivityStreamsObject(objProp) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if iriOnly { | ||||||
|  | 		// Only append status IRI. | ||||||
|  | 		idIRI := ap.GetJSONLDId(status) | ||||||
|  | 		objProp.AppendIRI(idIRI) | ||||||
|  | 	} else if poll, ok := ap.ToPollable(status); ok { | ||||||
|  | 		// Our Pollable implementer is an AS Question type. | ||||||
|  | 		question := poll.(vocab.ActivityStreamsQuestion) | ||||||
|  | 		objProp.AppendActivityStreamsQuestion(question) | ||||||
| 	} else { | 	} else { | ||||||
| 		asNote := status.(vocab.ActivityStreamsNote) | 		// All of our other Statusable types are AS Note. | ||||||
| 		objectProp.AppendActivityStreamsNote(asNote) | 		note := status.(vocab.ActivityStreamsNote) | ||||||
|  | 		objProp.AppendActivityStreamsNote(note) | ||||||
| 	} | 	} | ||||||
| 	update.SetActivityStreamsObject(objectProp) |  | ||||||
| 
 |  | ||||||
| 	// ID property |  | ||||||
| 	idProp := streams.NewJSONLDIdProperty() |  | ||||||
| 	createID := status.GetJSONLDId().GetIRI().String() + "/activity" |  | ||||||
| 	createIDIRI, err := url.Parse(createID) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	idProp.SetIRI(createIDIRI) |  | ||||||
| 	update.SetJSONLDId(idProp) |  | ||||||
| 
 |  | ||||||
| 	// Actor Property |  | ||||||
| 	actorProp := streams.NewActivityStreamsActorProperty() |  | ||||||
| 	actorIRI, err := ap.ExtractAttributedToURI(status) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, gtserror.Newf("couldn't extract AttributedTo: %w", err) |  | ||||||
| 	} |  | ||||||
| 	actorProp.AppendIRI(actorIRI) |  | ||||||
| 	update.SetActivityStreamsActor(actorProp) |  | ||||||
| 
 |  | ||||||
| 	// To Property |  | ||||||
| 	toProp := streams.NewActivityStreamsToProperty() |  | ||||||
| 	if toURIs := ap.ExtractToURIs(status); len(toURIs) != 0 { |  | ||||||
| 		for _, toURI := range toURIs { |  | ||||||
| 			toProp.AppendIRI(toURI) |  | ||||||
| 		} |  | ||||||
| 		update.SetActivityStreamsTo(toProp) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Cc Property |  | ||||||
| 	ccProp := streams.NewActivityStreamsCcProperty() |  | ||||||
| 	if ccURIs := ap.ExtractCcURIs(status); len(ccURIs) != 0 { |  | ||||||
| 		for _, ccURI := range ccURIs { |  | ||||||
| 			ccProp.AppendIRI(ccURI) |  | ||||||
| 		} |  | ||||||
| 		update.SetActivityStreamsCc(ccProp) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return update, nil |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -24,6 +24,7 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	"github.com/stretchr/testify/suite" | 	"github.com/stretchr/testify/suite" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/ap" | 	"github.com/superseriousbusiness/gotosocial/internal/ap" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/typeutils" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type WrapTestSuite struct { | type WrapTestSuite struct { | ||||||
|  | @ -36,7 +37,7 @@ func (suite *WrapTestSuite) TestWrapNoteInCreateIRIOnly() { | ||||||
| 	note, err := suite.typeconverter.StatusToAS(context.Background(), testStatus) | 	note, err := suite.typeconverter.StatusToAS(context.Background(), testStatus) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 
 | 
 | ||||||
| 	create, err := suite.typeconverter.WrapStatusableInCreate(note, true) | 	create := typeutils.WrapStatusableInCreate(note, true) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	suite.NotNil(create) | 	suite.NotNil(create) | ||||||
| 
 | 
 | ||||||
|  | @ -50,7 +51,7 @@ func (suite *WrapTestSuite) TestWrapNoteInCreateIRIOnly() { | ||||||
|   "@context": "https://www.w3.org/ns/activitystreams", |   "@context": "https://www.w3.org/ns/activitystreams", | ||||||
|   "actor": "http://localhost:8080/users/the_mighty_zork", |   "actor": "http://localhost:8080/users/the_mighty_zork", | ||||||
|   "cc": "http://localhost:8080/users/the_mighty_zork/followers", |   "cc": "http://localhost:8080/users/the_mighty_zork/followers", | ||||||
|   "id": "http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/activity", |   "id": "http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/activity#Create", | ||||||
|   "object": "http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY", |   "object": "http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY", | ||||||
|   "published": "2021-10-20T12:40:37+02:00", |   "published": "2021-10-20T12:40:37+02:00", | ||||||
|   "to": "https://www.w3.org/ns/activitystreams#Public", |   "to": "https://www.w3.org/ns/activitystreams#Public", | ||||||
|  | @ -64,7 +65,7 @@ func (suite *WrapTestSuite) TestWrapNoteInCreate() { | ||||||
| 	note, err := suite.typeconverter.StatusToAS(context.Background(), testStatus) | 	note, err := suite.typeconverter.StatusToAS(context.Background(), testStatus) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 
 | 
 | ||||||
| 	create, err := suite.typeconverter.WrapStatusableInCreate(note, false) | 	create := typeutils.WrapStatusableInCreate(note, false) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	suite.NotNil(create) | 	suite.NotNil(create) | ||||||
| 
 | 
 | ||||||
|  | @ -78,7 +79,7 @@ func (suite *WrapTestSuite) TestWrapNoteInCreate() { | ||||||
|   "@context": "https://www.w3.org/ns/activitystreams", |   "@context": "https://www.w3.org/ns/activitystreams", | ||||||
|   "actor": "http://localhost:8080/users/the_mighty_zork", |   "actor": "http://localhost:8080/users/the_mighty_zork", | ||||||
|   "cc": "http://localhost:8080/users/the_mighty_zork/followers", |   "cc": "http://localhost:8080/users/the_mighty_zork/followers", | ||||||
|   "id": "http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/activity", |   "id": "http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY/activity#Create", | ||||||
|   "object": { |   "object": { | ||||||
|     "attachment": [], |     "attachment": [], | ||||||
|     "attributedTo": "http://localhost:8080/users/the_mighty_zork", |     "attributedTo": "http://localhost:8080/users/the_mighty_zork", | ||||||
|  |  | ||||||
|  | @ -17,6 +17,18 @@ | ||||||
| 
 | 
 | ||||||
| package util | package util | ||||||
| 
 | 
 | ||||||
|  | // EqualPtrs returns whether the values contained within two comparable ptr types are equal. | ||||||
|  | func EqualPtrs[T comparable](t1, t2 *T) bool { | ||||||
|  | 	switch { | ||||||
|  | 	case t1 == nil: | ||||||
|  | 		return (t2 == nil) | ||||||
|  | 	case t2 == nil: | ||||||
|  | 		return false | ||||||
|  | 	default: | ||||||
|  | 		return (*t1 == *t2) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // Ptr returns a pointer to the passed in type | // Ptr returns a pointer to the passed in type | ||||||
| func Ptr[T any](t T) *T { | func Ptr[T any](t T) *T { | ||||||
| 	return &t | 	return &t | ||||||
|  |  | ||||||
|  | @ -43,6 +43,9 @@ EXPECT=$(cat << "EOF" | ||||||
|         "memory-target": 104857600, |         "memory-target": 104857600, | ||||||
|         "mention-mem-ratio": 2, |         "mention-mem-ratio": 2, | ||||||
|         "notification-mem-ratio": 2, |         "notification-mem-ratio": 2, | ||||||
|  |         "poll-mem-ratio": 1, | ||||||
|  |         "poll-vote-ids-mem-ratio": 2, | ||||||
|  |         "poll-vote-mem-ratio": 2, | ||||||
|         "report-mem-ratio": 1, |         "report-mem-ratio": 1, | ||||||
|         "status-fave-ids-mem-ratio": 3, |         "status-fave-ids-mem-ratio": 3, | ||||||
|         "status-fave-mem-ratio": 2, |         "status-fave-mem-ratio": 2, | ||||||
|  |  | ||||||
|  | @ -44,6 +44,8 @@ var testModels = []interface{}{ | ||||||
| 	>smodel.Marker{}, | 	>smodel.Marker{}, | ||||||
| 	>smodel.MediaAttachment{}, | 	>smodel.MediaAttachment{}, | ||||||
| 	>smodel.Mention{}, | 	>smodel.Mention{}, | ||||||
|  | 	>smodel.Poll{}, | ||||||
|  | 	>smodel.PollVote{}, | ||||||
| 	>smodel.Status{}, | 	>smodel.Status{}, | ||||||
| 	>smodel.StatusToEmoji{}, | 	>smodel.StatusToEmoji{}, | ||||||
| 	>smodel.StatusToTag{}, | 	>smodel.StatusToTag{}, | ||||||
|  | @ -315,6 +317,18 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) { | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	for _, v := range NewTestPolls() { | ||||||
|  | 		if err := db.Put(ctx, v); err != nil { | ||||||
|  | 			log.Panic(nil, err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, v := range NewTestPollVotes() { | ||||||
|  | 		if err := db.Put(ctx, v); err != nil { | ||||||
|  | 			log.Panic(nil, err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	if err := db.CreateInstanceAccount(ctx); err != nil { | 	if err := db.CreateInstanceAccount(ctx); err != nil { | ||||||
| 		log.Panic(nil, err) | 		log.Panic(nil, err) | ||||||
| 	} | 	} | ||||||
|  | @ -330,7 +344,7 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) { | ||||||
| func StandardDBTeardown(db db.DB) { | func StandardDBTeardown(db db.DB) { | ||||||
| 	ctx := context.Background() | 	ctx := context.Background() | ||||||
| 	if db == nil { | 	if db == nil { | ||||||
| 		log.Panic(nil, "db was nil") | 		return | ||||||
| 	} | 	} | ||||||
| 	for _, m := range testModels { | 	for _, m := range testModels { | ||||||
| 		if err := db.DropTable(ctx, m); err != nil { | 		if err := db.DropTable(ctx, m); err != nil { | ||||||
|  |  | ||||||
|  | @ -1538,6 +1538,33 @@ func NewTestStatuses() map[string]*gtsmodel.Status { | ||||||
| 			Likeable:                 util.Ptr(true), | 			Likeable:                 util.Ptr(true), | ||||||
| 			ActivityStreamsType:      ap.ObjectNote, | 			ActivityStreamsType:      ap.ObjectNote, | ||||||
| 		}, | 		}, | ||||||
|  | 		"local_account_1_status_6": { | ||||||
|  | 			ID:                       "01HEN2RZ8BG29Y5Z9VJC73HZW7", | ||||||
|  | 			URI:                      "http://localhost:8080/users/the_mighty_zork/statuses/065TKBPE0H2AH8S5X8JCK4XC58", | ||||||
|  | 			URL:                      "http://localhost:8080/@the_mighty_zork/statuses/065TKBPE0H2AH8S5X8JCK4XC58", | ||||||
|  | 			Content:                  "what do you think of sloths?", | ||||||
|  | 			Text:                     "what do you think of sloths?", | ||||||
|  | 			AttachmentIDs:            nil, | ||||||
|  | 			CreatedAt:                TimeMustParse("2022-05-20T11:41:10Z"), | ||||||
|  | 			UpdatedAt:                TimeMustParse("2022-05-20T11:41:10Z"), | ||||||
|  | 			Local:                    util.Ptr(true), | ||||||
|  | 			AccountURI:               "http://localhost:8080/users/the_mighty_zork", | ||||||
|  | 			AccountID:                "01F8MH1H7YV1Z7D2C8K2730QBF", | ||||||
|  | 			InReplyToID:              "", | ||||||
|  | 			BoostOfID:                "", | ||||||
|  | 			ThreadID:                 "", | ||||||
|  | 			ContentWarning:           "", | ||||||
|  | 			Visibility:               gtsmodel.VisibilityFollowersOnly, | ||||||
|  | 			Sensitive:                util.Ptr(false), | ||||||
|  | 			Language:                 "en", | ||||||
|  | 			CreatedWithApplicationID: "01F8MGY43H3N2C8EWPR2FPYEXG", | ||||||
|  | 			Federated:                util.Ptr(true), | ||||||
|  | 			Boostable:                util.Ptr(true), | ||||||
|  | 			Replyable:                util.Ptr(true), | ||||||
|  | 			Likeable:                 util.Ptr(true), | ||||||
|  | 			ActivityStreamsType:      ap.ActivityQuestion, | ||||||
|  | 			PollID:                   "01HEN2RKT1YTEZ80SA8HGP105F", | ||||||
|  | 		}, | ||||||
| 		"local_account_2_status_1": { | 		"local_account_2_status_1": { | ||||||
| 			ID:                       "01F8MHBQCBTDKN6X5VHGMMN4MA", | 			ID:                       "01F8MHBQCBTDKN6X5VHGMMN4MA", | ||||||
| 			URI:                      "http://localhost:8080/users/1happyturtle/statuses/01F8MHBQCBTDKN6X5VHGMMN4MA", | 			URI:                      "http://localhost:8080/users/1happyturtle/statuses/01F8MHBQCBTDKN6X5VHGMMN4MA", | ||||||
|  | @ -1722,6 +1749,33 @@ func NewTestStatuses() map[string]*gtsmodel.Status { | ||||||
| 			Likeable:                 util.Ptr(true), | 			Likeable:                 util.Ptr(true), | ||||||
| 			ActivityStreamsType:      ap.ObjectNote, | 			ActivityStreamsType:      ap.ObjectNote, | ||||||
| 		}, | 		}, | ||||||
|  | 		"local_account_2_status_8": { | ||||||
|  | 			ID:                       "01HEN2PRXT0TF4YDRA64FZZRN7", | ||||||
|  | 			URI:                      "http://localhost:8080/users/1happyturtle/statuses/065TKBPE0EJ6X3QDR1AH9DAB8M", | ||||||
|  | 			URL:                      "http://localhost:8080/@1happyturtle/statuses/065TKBPE0EJ6X3QDR1AH9DAB8M", | ||||||
|  | 			Content:                  "hey everyone i got stuck in a shed. any ideas for how to get out?", | ||||||
|  | 			Text:                     "hey everyone i got stuck in a shed. any ideas for how to get out?", | ||||||
|  | 			AttachmentIDs:            nil, | ||||||
|  | 			CreatedAt:                TimeMustParse("2021-07-28T10:40:37+02:00"), | ||||||
|  | 			UpdatedAt:                TimeMustParse("2021-07-28T10:40:37+02:00"), | ||||||
|  | 			Local:                    util.Ptr(true), | ||||||
|  | 			AccountURI:               "http://localhost:8080/users/1happyturtle", | ||||||
|  | 			AccountID:                "01F8MH5NBDF2MV7CTC4Q5128HF", | ||||||
|  | 			InReplyToID:              "", | ||||||
|  | 			BoostOfID:                "", | ||||||
|  | 			ThreadID:                 "", | ||||||
|  | 			ContentWarning:           "", | ||||||
|  | 			Visibility:               gtsmodel.VisibilityPublic, | ||||||
|  | 			Sensitive:                util.Ptr(false), | ||||||
|  | 			Language:                 "en", | ||||||
|  | 			CreatedWithApplicationID: "01F8MGYG9E893WRHW0TAEXR8GJ", | ||||||
|  | 			Federated:                util.Ptr(true), | ||||||
|  | 			Boostable:                util.Ptr(true), | ||||||
|  | 			Replyable:                util.Ptr(true), | ||||||
|  | 			Likeable:                 util.Ptr(true), | ||||||
|  | 			ActivityStreamsType:      ap.ActivityQuestion, | ||||||
|  | 			PollID:                   "01HEN2QB5NR4NCEHGYC3HN84K6", | ||||||
|  | 		}, | ||||||
| 		"remote_account_1_status_1": { | 		"remote_account_1_status_1": { | ||||||
| 			ID:                       "01FVW7JHQFSFK166WWKR8CBA6M", | 			ID:                       "01FVW7JHQFSFK166WWKR8CBA6M", | ||||||
| 			URI:                      "http://fossbros-anonymous.io/users/foss_satan/statuses/01FVW7JHQFSFK166WWKR8CBA6M", | 			URI:                      "http://fossbros-anonymous.io/users/foss_satan/statuses/01FVW7JHQFSFK166WWKR8CBA6M", | ||||||
|  | @ -1749,6 +1803,136 @@ func NewTestStatuses() map[string]*gtsmodel.Status { | ||||||
| 			Likeable:                 util.Ptr(true), | 			Likeable:                 util.Ptr(true), | ||||||
| 			ActivityStreamsType:      ap.ObjectNote, | 			ActivityStreamsType:      ap.ObjectNote, | ||||||
| 		}, | 		}, | ||||||
|  | 		"remote_account_1_status_2": { | ||||||
|  | 			ID:                       "01HEN2QRFA8H3C6QPN7RD4KSR6", | ||||||
|  | 			URI:                      "http://fossbros-anonymous.io/users/foss_satan/statuses/065TKDN4BX1PC8N19TSY9SD2N4", | ||||||
|  | 			URL:                      "http://fossbros-anonymous.io/@foss_satan/statuses/065TKDN4BX1PC8N19TSY9SD2N4", | ||||||
|  | 			Content:                  "what products should i buy at the grocery store?", | ||||||
|  | 			AttachmentIDs:            []string{"01FVW7RXPQ8YJHTEXYPE7Q8ZY0"}, | ||||||
|  | 			CreatedAt:                TimeMustParse("2021-09-11T11:40:37+02:00"), | ||||||
|  | 			UpdatedAt:                TimeMustParse("2021-09-11T11:40:37+02:00"), | ||||||
|  | 			Local:                    util.Ptr(false), | ||||||
|  | 			AccountURI:               "http://fossbros-anonymous.io/users/foss_satan", | ||||||
|  | 			AccountID:                "01F8MH5ZK5VRH73AKHQM6Y9VNX", | ||||||
|  | 			InReplyToID:              "", | ||||||
|  | 			InReplyToAccountID:       "", | ||||||
|  | 			InReplyToURI:             "", | ||||||
|  | 			BoostOfID:                "", | ||||||
|  | 			ContentWarning:           "", | ||||||
|  | 			Visibility:               gtsmodel.VisibilityUnlocked, | ||||||
|  | 			Sensitive:                util.Ptr(false), | ||||||
|  | 			Language:                 "en", | ||||||
|  | 			CreatedWithApplicationID: "", | ||||||
|  | 			Federated:                util.Ptr(true), | ||||||
|  | 			Boostable:                util.Ptr(true), | ||||||
|  | 			Replyable:                util.Ptr(true), | ||||||
|  | 			Likeable:                 util.Ptr(true), | ||||||
|  | 			ActivityStreamsType:      ap.ActivityQuestion, | ||||||
|  | 			PollID:                   "01HEN2R65468ZG657C4ZPHJ4EX", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func NewTestPolls() map[string]*gtsmodel.Poll { | ||||||
|  | 	return map[string]*gtsmodel.Poll{ | ||||||
|  | 		"local_account_1_status_6_poll": { | ||||||
|  | 			ID:         "01HEN2RKT1YTEZ80SA8HGP105F", | ||||||
|  | 			Multiple:   util.Ptr(false), | ||||||
|  | 			HideCounts: util.Ptr(true), | ||||||
|  | 			Options:    []string{"good", "bad", "meh"}, | ||||||
|  | 			Votes:      []int{2, 0, 0}, // needs to match stored poll votes | ||||||
|  | 			Voters:     util.Ptr(2),    // needs to match stored poll votes | ||||||
|  | 			StatusID:   "01HEN2RZ8BG29Y5Z9VJC73HZW7", | ||||||
|  | 			Status:     nil, | ||||||
|  | 			ExpiresAt:  TimeMustParse("2022-05-21T11:41:10Z"), | ||||||
|  | 			ClosedAt:   time.Time{}, | ||||||
|  | 			Closing:    false, | ||||||
|  | 		}, | ||||||
|  | 		"local_account_2_status_8_poll": { | ||||||
|  | 			ID:         "01HEN2QB5NR4NCEHGYC3HN84K6", | ||||||
|  | 			Multiple:   util.Ptr(false), | ||||||
|  | 			HideCounts: util.Ptr(false), | ||||||
|  | 			Options:    []string{"50:50", "phone a friend", "ask the audience"}, | ||||||
|  | 			Votes:      []int{0, 1, 1}, // needs to match stored poll votes | ||||||
|  | 			Voters:     util.Ptr(2),    // needs to match stored poll votes | ||||||
|  | 			StatusID:   "01HEN2PRXT0TF4YDRA64FZZRN7", | ||||||
|  | 			Status:     nil, | ||||||
|  | 			ExpiresAt:  TimeMustParse("2021-08-28T10:40:37+02:00"), | ||||||
|  | 			ClosedAt:   TimeMustParse("2021-08-28T10:40:37+02:00"), | ||||||
|  | 			Closing:    false, | ||||||
|  | 		}, | ||||||
|  | 		"remote_account_1_status_2_poll": { | ||||||
|  | 			ID:         "01HEN2R65468ZG657C4ZPHJ4EX", | ||||||
|  | 			Multiple:   util.Ptr(true), | ||||||
|  | 			HideCounts: util.Ptr(false), | ||||||
|  | 			Options:    []string{"vaseline", "tissues", "financial times"}, | ||||||
|  | 			Votes:      []int{3, 2, 18}, | ||||||
|  | 			Voters:     util.Ptr(6), | ||||||
|  | 			StatusID:   "01HEN2QRFA8H3C6QPN7RD4KSR6", | ||||||
|  | 			Status:     nil, | ||||||
|  | 			ExpiresAt:  TimeMustParse("2021-09-11T12:40:37+02:00"), | ||||||
|  | 			ClosedAt:   TimeMustParse("2021-09-11T12:40:37+02:00"), | ||||||
|  | 			Closing:    false, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func NewTestPollVotes() map[string]*gtsmodel.PollVote { | ||||||
|  | 	return map[string]*gtsmodel.PollVote{ | ||||||
|  | 		"local_account_1_status_6_poll_vote_local_account_2": { | ||||||
|  | 			ID:        "01HEN2VN4DZ4ENCK6AS4PKM5B3", | ||||||
|  | 			Choices:   []int{0}, | ||||||
|  | 			AccountID: "01F8MH5NBDF2MV7CTC4Q5128HF", | ||||||
|  | 			Account:   nil, | ||||||
|  | 			PollID:    "01HEN2RKT1YTEZ80SA8HGP105F", | ||||||
|  | 			Poll:      nil, | ||||||
|  | 			CreatedAt: TimeMustParse("2022-05-20T14:41:10Z"), | ||||||
|  | 		}, | ||||||
|  | 		"local_account_1_status_6_poll_vote_remote_account_1": { | ||||||
|  | 			ID:        "01HEN2VM975JG8N9KPFQ597KGF", | ||||||
|  | 			Choices:   []int{0}, | ||||||
|  | 			AccountID: "01F8MH5ZK5VRH73AKHQM6Y9VNX", | ||||||
|  | 			Account:   nil, | ||||||
|  | 			PollID:    "01HEN2RKT1YTEZ80SA8HGP105F", | ||||||
|  | 			Poll:      nil, | ||||||
|  | 			CreatedAt: TimeMustParse("2022-05-20T15:41:10Z"), | ||||||
|  | 		}, | ||||||
|  | 		"local_account_2_status_8_poll_vote_local_account_1": { | ||||||
|  | 			ID:        "01HEN2VK9TX5BTD3B0CSRBWE89", | ||||||
|  | 			Choices:   []int{2}, | ||||||
|  | 			AccountID: "01F8MH1H7YV1Z7D2C8K2730QBF", | ||||||
|  | 			Account:   nil, | ||||||
|  | 			PollID:    "01HEN2QB5NR4NCEHGYC3HN84K6", | ||||||
|  | 			Poll:      nil, | ||||||
|  | 			CreatedAt: TimeMustParse("2021-07-29T10:40:37+02:00"), | ||||||
|  | 		}, | ||||||
|  | 		"local_account_2_status_8_poll_vote_remote_account_1": { | ||||||
|  | 			ID:        "01HEN2VHW4HAHBM4YH3N55794D", | ||||||
|  | 			Choices:   []int{1}, | ||||||
|  | 			AccountID: "01F8MH5ZK5VRH73AKHQM6Y9VNX", | ||||||
|  | 			Account:   nil, | ||||||
|  | 			PollID:    "01HEN2QB5NR4NCEHGYC3HN84K6", | ||||||
|  | 			Poll:      nil, | ||||||
|  | 			CreatedAt: TimeMustParse("2021-08-10T10:40:37+02:00"), | ||||||
|  | 		}, | ||||||
|  | 		"remote_account_1_status_2_poll_vote_local_account_1": { | ||||||
|  | 			ID:        "01HEN2VH077W1QY7VKQFPKD6B6", | ||||||
|  | 			Choices:   []int{1, 2}, | ||||||
|  | 			AccountID: "01F8MH1H7YV1Z7D2C8K2730QBF", | ||||||
|  | 			Account:   nil, | ||||||
|  | 			PollID:    "01HEN2R65468ZG657C4ZPHJ4EX", | ||||||
|  | 			Poll:      nil, | ||||||
|  | 			CreatedAt: TimeMustParse("2021-09-11T11:45:37+02:00"), | ||||||
|  | 		}, | ||||||
|  | 		"remote_account_1_status_2_poll_vote_local_account_2": { | ||||||
|  | 			ID:        "01HEN2VG6EP3GJA208586H356K", | ||||||
|  | 			Choices:   []int{0, 2}, | ||||||
|  | 			AccountID: "01F8MH5NBDF2MV7CTC4Q5128HF", | ||||||
|  | 			Account:   nil, | ||||||
|  | 			PollID:    "01HEN2R65468ZG657C4ZPHJ4EX", | ||||||
|  | 			Poll:      nil, | ||||||
|  | 			CreatedAt: TimeMustParse("2021-09-11T11:47:37+02:00"), | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue