mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 06:02:26 -05:00 
			
		
		
		
	[feature] add paging to account follows, followers and follow requests endpoints (#2186)
This commit is contained in:
		
					parent
					
						
							
								4b594516ec
							
						
					
				
			
			
				commit
				
					
						7293d6029b
					
				
			
		
					 51 changed files with 2281 additions and 641 deletions
				
			
		|  | @ -3072,6 +3072,13 @@ paths: | ||||||
|                 - accounts |                 - accounts | ||||||
|     /api/v1/accounts/{id}/followers: |     /api/v1/accounts/{id}/followers: | ||||||
|         get: |         get: | ||||||
|  |             description: |- | ||||||
|  |                 The next and previous queries can be parsed from the returned Link header. | ||||||
|  |                 Example: | ||||||
|  | 
 | ||||||
|  |                 ``` | ||||||
|  |                 <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/followers?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/followers?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev" | ||||||
|  |                 ```` | ||||||
|             operationId: accountFollowers |             operationId: accountFollowers | ||||||
|             parameters: |             parameters: | ||||||
|                 - description: Account ID. |                 - description: Account ID. | ||||||
|  | @ -3079,6 +3086,25 @@ paths: | ||||||
|                   name: id |                   name: id | ||||||
|                   required: true |                   required: true | ||||||
|                   type: string |                   type: string | ||||||
|  |                 - description: 'Return only follower accounts *OLDER* than the given max ID. The follower account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.' | ||||||
|  |                   in: query | ||||||
|  |                   name: max_id | ||||||
|  |                   type: string | ||||||
|  |                 - description: 'Return only follower accounts *NEWER* than the given since ID. The follower account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.' | ||||||
|  |                   in: query | ||||||
|  |                   name: since_id | ||||||
|  |                   type: string | ||||||
|  |                 - description: 'Return only follower accounts *IMMEDIATELY NEWER* than the given min ID. The follower account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.' | ||||||
|  |                   in: query | ||||||
|  |                   name: min_id | ||||||
|  |                   type: string | ||||||
|  |                 - default: 40 | ||||||
|  |                   description: Number of follower accounts to return. | ||||||
|  |                   in: query | ||||||
|  |                   maximum: 80 | ||||||
|  |                   minimum: 1 | ||||||
|  |                   name: limit | ||||||
|  |                   type: integer | ||||||
|             produces: |             produces: | ||||||
|                 - application/json |                 - application/json | ||||||
|             responses: |             responses: | ||||||
|  | @ -3106,6 +3132,13 @@ paths: | ||||||
|                 - accounts |                 - accounts | ||||||
|     /api/v1/accounts/{id}/following: |     /api/v1/accounts/{id}/following: | ||||||
|         get: |         get: | ||||||
|  |             description: |- | ||||||
|  |                 The next and previous queries can be parsed from the returned Link header. | ||||||
|  |                 Example: | ||||||
|  | 
 | ||||||
|  |                 ``` | ||||||
|  |                 <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/following?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/following?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev" | ||||||
|  |                 ```` | ||||||
|             operationId: accountFollowing |             operationId: accountFollowing | ||||||
|             parameters: |             parameters: | ||||||
|                 - description: Account ID. |                 - description: Account ID. | ||||||
|  | @ -3113,6 +3146,25 @@ paths: | ||||||
|                   name: id |                   name: id | ||||||
|                   required: true |                   required: true | ||||||
|                   type: string |                   type: string | ||||||
|  |                 - description: 'Return only following accounts *OLDER* than the given max ID. The following account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.' | ||||||
|  |                   in: query | ||||||
|  |                   name: max_id | ||||||
|  |                   type: string | ||||||
|  |                 - description: 'Return only following accounts *NEWER* than the given since ID. The following account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.' | ||||||
|  |                   in: query | ||||||
|  |                   name: since_id | ||||||
|  |                   type: string | ||||||
|  |                 - description: 'Return only following accounts *IMMEDIATELY NEWER* than the given min ID. The following account with the specified ID will not be included in the response. NOTE: the ID is of the internal follow, NOT any of the returned accounts.' | ||||||
|  |                   in: query | ||||||
|  |                   name: min_id | ||||||
|  |                   type: string | ||||||
|  |                 - default: 40 | ||||||
|  |                   description: Number of following accounts to return. | ||||||
|  |                   in: query | ||||||
|  |                   maximum: 80 | ||||||
|  |                   minimum: 1 | ||||||
|  |                   name: limit | ||||||
|  |                   type: integer | ||||||
|             produces: |             produces: | ||||||
|                 - application/json |                 - application/json | ||||||
|             responses: |             responses: | ||||||
|  | @ -4679,19 +4731,25 @@ paths: | ||||||
|                 ```` |                 ```` | ||||||
|             operationId: blocksGet |             operationId: blocksGet | ||||||
|             parameters: |             parameters: | ||||||
|                 - default: 20 |                 - description: 'Return only blocked accounts *OLDER* than the given max ID. The blocked account with the specified ID will not be included in the response. NOTE: the ID is of the internal block, NOT any of the returned accounts.' | ||||||
|                   description: Number of blocks to return. |  | ||||||
|                   in: query |  | ||||||
|                   name: limit |  | ||||||
|                   type: integer |  | ||||||
|                 - description: Return only blocks *OLDER* than the given block ID. The block with the specified ID will not be included in the response. |  | ||||||
|                   in: query |                   in: query | ||||||
|                   name: max_id |                   name: max_id | ||||||
|                   type: string |                   type: string | ||||||
|                 - description: Return only blocks *NEWER* than the given block ID. The block with the specified ID will not be included in the response. |                 - description: 'Return only blocked accounts *NEWER* than the given since ID. The blocked account with the specified ID will not be included in the response. NOTE: the ID is of the internal block, NOT any of the returned accounts.' | ||||||
|                   in: query |                   in: query | ||||||
|                   name: since_id |                   name: since_id | ||||||
|                   type: string |                   type: string | ||||||
|  |                 - description: 'Return only blocked accounts *IMMEDIATELY NEWER* than the given min ID. The blocked account with the specified ID will not be included in the response. NOTE: the ID is of the internal block, NOT any of the returned accounts.' | ||||||
|  |                   in: query | ||||||
|  |                   name: min_id | ||||||
|  |                   type: string | ||||||
|  |                 - default: 40 | ||||||
|  |                   description: Number of blocked accounts to return. | ||||||
|  |                   in: query | ||||||
|  |                   maximum: 80 | ||||||
|  |                   minimum: 1 | ||||||
|  |                   name: limit | ||||||
|  |                   type: integer | ||||||
|             produces: |             produces: | ||||||
|                 - application/json |                 - application/json | ||||||
|             responses: |             responses: | ||||||
|  | @ -4857,12 +4915,32 @@ paths: | ||||||
|                 - featured_tags |                 - featured_tags | ||||||
|     /api/v1/follow_requests: |     /api/v1/follow_requests: | ||||||
|         get: |         get: | ||||||
|             description: Accounts will be sorted in order of follow request date descending (newest first). |             description: |- | ||||||
|  |                 The next and previous queries can be parsed from the returned Link header. | ||||||
|  |                 Example: | ||||||
|  | 
 | ||||||
|  |                 ``` | ||||||
|  |                 <https://example.org/api/v1/follow_requests?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/follow_requests?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev" | ||||||
|  |                 ```` | ||||||
|             operationId: getFollowRequests |             operationId: getFollowRequests | ||||||
|             parameters: |             parameters: | ||||||
|                 - default: 40 |                 - description: 'Return only follow requesting accounts *OLDER* than the given max ID. The follow requester with the specified ID will not be included in the response. NOTE: the ID is of the internal follow request, NOT any of the returned accounts.' | ||||||
|                   description: Number of accounts to return. |  | ||||||
|                   in: query |                   in: query | ||||||
|  |                   name: max_id | ||||||
|  |                   type: string | ||||||
|  |                 - description: 'Return only follow requesting accounts *NEWER* than the given since ID. The follow requester with the specified ID will not be included in the response. NOTE: the ID is of the internal follow request, NOT any of the returned accounts.' | ||||||
|  |                   in: query | ||||||
|  |                   name: since_id | ||||||
|  |                   type: string | ||||||
|  |                 - description: 'Return only follow requesting accounts *IMMEDIATELY NEWER* than the given min ID. The follow requester with the specified ID will not be included in the response. NOTE: the ID is of the internal follow request, NOT any of the returned accounts.' | ||||||
|  |                   in: query | ||||||
|  |                   name: min_id | ||||||
|  |                   type: string | ||||||
|  |                 - default: 40 | ||||||
|  |                   description: Number of follow requesting accounts to return. | ||||||
|  |                   in: query | ||||||
|  |                   maximum: 80 | ||||||
|  |                   minimum: 1 | ||||||
|                   name: limit |                   name: limit | ||||||
|                   type: integer |                   type: integer | ||||||
|             produces: |             produces: | ||||||
|  |  | ||||||
							
								
								
									
										1
									
								
								go.mod
									
										
									
									
									
								
							
							
						
						
									
										1
									
								
								go.mod
									
										
									
									
									
								
							|  | @ -46,6 +46,7 @@ require ( | ||||||
| 	github.com/superseriousbusiness/exif-terminator v0.5.0 | 	github.com/superseriousbusiness/exif-terminator v0.5.0 | ||||||
| 	github.com/superseriousbusiness/oauth2/v4 v4.3.2-SSB.0.20230227143000-f4900831d6c8 | 	github.com/superseriousbusiness/oauth2/v4 v4.3.2-SSB.0.20230227143000-f4900831d6c8 | ||||||
| 	github.com/tdewolff/minify/v2 v2.12.9 | 	github.com/tdewolff/minify/v2 v2.12.9 | ||||||
|  | 	github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 | ||||||
| 	github.com/ulule/limiter/v3 v3.11.2 | 	github.com/ulule/limiter/v3 v3.11.2 | ||||||
| 	github.com/uptrace/bun v1.1.15 | 	github.com/uptrace/bun v1.1.15 | ||||||
| 	github.com/uptrace/bun/dialect/pgdialect v1.1.15 | 	github.com/uptrace/bun/dialect/pgdialect v1.1.15 | ||||||
|  |  | ||||||
							
								
								
									
										2
									
								
								go.sum
									
										
									
									
									
								
							
							
						
						
									
										2
									
								
								go.sum
									
										
									
									
									
								
							|  | @ -568,6 +568,8 @@ github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563 h1:Otn9S136ELckZ | ||||||
| github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563/go.mod h1:mLqSmt7Dv/CNneF2wfcChfN1rvapyQr01LGKnKex0DQ= | github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563/go.mod h1:mLqSmt7Dv/CNneF2wfcChfN1rvapyQr01LGKnKex0DQ= | ||||||
| github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= | github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= | ||||||
| github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= | github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= | ||||||
|  | github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 h1:nrZ3ySNYwJbSpD6ce9duiP+QkD3JuLCcWkdaehUS/3Y= | ||||||
|  | github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80/go.mod h1:iFyPdL66DjUD96XmzVL3ZntbzcflLnznH0fr99w5VqE= | ||||||
| github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= | github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= | ||||||
| github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= | github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= | ||||||
| github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= | github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= | ||||||
|  |  | ||||||
|  | @ -18,21 +18,33 @@ | ||||||
| package accounts_test | package accounts_test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
|  | 	"math/rand" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
|  | 	"net/url" | ||||||
|  | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| 	"github.com/stretchr/testify/suite" | 	"github.com/stretchr/testify/suite" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/api/client/accounts" | 	"github.com/superseriousbusiness/gotosocial/internal/api/client/accounts" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/testrig" | 	"github.com/superseriousbusiness/gotosocial/testrig" | ||||||
|  | 	"github.com/tomnomnom/linkheader" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // random reader according to current-time source seed. | ||||||
|  | var randRd = rand.New(rand.NewSource(time.Now().Unix())) | ||||||
|  | 
 | ||||||
| type FollowTestSuite struct { | type FollowTestSuite struct { | ||||||
| 	AccountStandardTestSuite | 	AccountStandardTestSuite | ||||||
| } | } | ||||||
|  | @ -69,6 +81,405 @@ func (suite *FollowTestSuite) TestFollowSelf() { | ||||||
| 	assert.NoError(suite.T(), err) | 	assert.NoError(suite.T(), err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (suite *FollowTestSuite) TestGetFollowersPageBackwardLimit2() { | ||||||
|  | 	suite.testGetFollowersPage(2, "backward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *FollowTestSuite) TestGetFollowersPageBackwardLimit4() { | ||||||
|  | 	suite.testGetFollowersPage(4, "backward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *FollowTestSuite) TestGetFollowersPageBackwardLimit6() { | ||||||
|  | 	suite.testGetFollowersPage(6, "backward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *FollowTestSuite) TestGetFollowersPageForwardLimit2() { | ||||||
|  | 	suite.testGetFollowersPage(2, "forward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *FollowTestSuite) TestGetFollowersPageForwardLimit4() { | ||||||
|  | 	suite.testGetFollowersPage(4, "forward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *FollowTestSuite) TestGetFollowersPageForwardLimit6() { | ||||||
|  | 	suite.testGetFollowersPage(6, "forward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *FollowTestSuite) testGetFollowersPage(limit int, direction string) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 
 | ||||||
|  | 	// The authed local account we are going to use for HTTP requests | ||||||
|  | 	requestingAccount := suite.testAccounts["local_account_1"] | ||||||
|  | 	suite.clearAccountRelations(requestingAccount.ID) | ||||||
|  | 
 | ||||||
|  | 	// Get current time. | ||||||
|  | 	now := time.Now() | ||||||
|  | 
 | ||||||
|  | 	var i int | ||||||
|  | 
 | ||||||
|  | 	for _, targetAccount := range suite.testAccounts { | ||||||
|  | 		if targetAccount.ID == requestingAccount.ID { | ||||||
|  | 			// we cannot be our own target... | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Get next simple ID. | ||||||
|  | 		id := strconv.Itoa(i) | ||||||
|  | 		i++ | ||||||
|  | 
 | ||||||
|  | 		// put a follow in the database | ||||||
|  | 		err := suite.db.PutFollow(ctx, >smodel.Follow{ | ||||||
|  | 			ID:              id, | ||||||
|  | 			CreatedAt:       now, | ||||||
|  | 			UpdatedAt:       now, | ||||||
|  | 			URI:             fmt.Sprintf("%s/follow/%s", targetAccount.URI, id), | ||||||
|  | 			AccountID:       targetAccount.ID, | ||||||
|  | 			TargetAccountID: requestingAccount.ID, | ||||||
|  | 		}) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// Bump now by 1 second. | ||||||
|  | 		now = now.Add(time.Second) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get _ALL_ follows we expect to see without any paging (this filters invisible). | ||||||
|  | 	apiRsp, err := suite.processor.Account().FollowersGet(ctx, requestingAccount, requestingAccount.ID, nil) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	expectAccounts := apiRsp.Items // interfaced{} account slice | ||||||
|  | 
 | ||||||
|  | 	// Iteratively set | ||||||
|  | 	// link query string. | ||||||
|  | 	var query string | ||||||
|  | 
 | ||||||
|  | 	switch direction { | ||||||
|  | 	case "backward": | ||||||
|  | 		// Set the starting query to page backward from newest. | ||||||
|  | 		acc := expectAccounts[0].(*model.Account) | ||||||
|  | 		newest, _ := suite.db.GetFollow(ctx, acc.ID, requestingAccount.ID) | ||||||
|  | 		expectAccounts = expectAccounts[1:] | ||||||
|  | 		query = fmt.Sprintf("limit=%d&max_id=%s", limit, newest.ID) | ||||||
|  | 
 | ||||||
|  | 	case "forward": | ||||||
|  | 		// Set the starting query to page forward from the oldest. | ||||||
|  | 		acc := expectAccounts[len(expectAccounts)-1].(*model.Account) | ||||||
|  | 		oldest, _ := suite.db.GetFollow(ctx, acc.ID, requestingAccount.ID) | ||||||
|  | 		expectAccounts = expectAccounts[:len(expectAccounts)-1] | ||||||
|  | 		query = fmt.Sprintf("limit=%d&min_id=%s", limit, oldest.ID) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for p := 0; ; p++ { | ||||||
|  | 		// Prepare new request for endpoint | ||||||
|  | 		recorder := httptest.NewRecorder() | ||||||
|  | 		endpoint := fmt.Sprintf("/api/v1/accounts/%s/followers", requestingAccount.ID) | ||||||
|  | 		ctx := suite.newContext(recorder, http.MethodGet, []byte{}, endpoint, "") | ||||||
|  | 		ctx.Params = gin.Params{{Key: "id", Value: requestingAccount.ID}} | ||||||
|  | 		ctx.Request.URL.RawQuery = query // setting provided next query value | ||||||
|  | 
 | ||||||
|  | 		// call the handler and check for valid response code. | ||||||
|  | 		suite.T().Logf("direction=%q page=%d query=%q", direction, p, query) | ||||||
|  | 		suite.accountsModule.AccountFollowersGETHandler(ctx) | ||||||
|  | 		suite.Equal(http.StatusOK, recorder.Code) | ||||||
|  | 
 | ||||||
|  | 		var accounts []*model.Account | ||||||
|  | 
 | ||||||
|  | 		// Decode response body into API account models | ||||||
|  | 		result := recorder.Result() | ||||||
|  | 		dec := json.NewDecoder(result.Body) | ||||||
|  | 		err := dec.Decode(&accounts) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 		_ = result.Body.Close() | ||||||
|  | 
 | ||||||
|  | 		var ( | ||||||
|  | 
 | ||||||
|  | 			// start provides the starting index for loop in accounts. | ||||||
|  | 			start func([]*model.Account) int | ||||||
|  | 
 | ||||||
|  | 			// iter performs the loop iter step with index. | ||||||
|  | 			iter func(int) int | ||||||
|  | 
 | ||||||
|  | 			// check performs the loop conditional check against index and accounts. | ||||||
|  | 			check func(int, []*model.Account) bool | ||||||
|  | 
 | ||||||
|  | 			// expect pulls the next account to check against from expectAccounts. | ||||||
|  | 			expect func([]interface{}) interface{} | ||||||
|  | 
 | ||||||
|  | 			// trunc drops the last checked account from expectAccounts. | ||||||
|  | 			trunc func([]interface{}) []interface{} | ||||||
|  | 		) | ||||||
|  | 
 | ||||||
|  | 		switch direction { | ||||||
|  | 		case "backward": | ||||||
|  | 			// When paging backwards (DESC) we: | ||||||
|  | 			// - iter from end of received accounts | ||||||
|  | 			// - iterate backward through received accounts | ||||||
|  | 			// - stop when we reach last index of received accounts | ||||||
|  | 			// - compare each received with the first index of expected accounts | ||||||
|  | 			// - after each compare, drop the first index of expected accounts | ||||||
|  | 			start = func([]*model.Account) int { return 0 } | ||||||
|  | 			iter = func(i int) int { return i + 1 } | ||||||
|  | 			check = func(idx int, i []*model.Account) bool { return idx < len(i) } | ||||||
|  | 			expect = func(i []interface{}) interface{} { return i[0] } | ||||||
|  | 			trunc = func(i []interface{}) []interface{} { return i[1:] } | ||||||
|  | 
 | ||||||
|  | 		case "forward": | ||||||
|  | 			// When paging forwards (ASC) we: | ||||||
|  | 			// - iter from end of received accounts | ||||||
|  | 			// - iterate backward through received accounts | ||||||
|  | 			// - stop when we reach first index of received accounts | ||||||
|  | 			// - compare each received with the last index of expected accounts | ||||||
|  | 			// - after each compare, drop the last index of expected accounts | ||||||
|  | 			start = func(i []*model.Account) int { return len(i) - 1 } | ||||||
|  | 			iter = func(i int) int { return i - 1 } | ||||||
|  | 			check = func(idx int, i []*model.Account) bool { return idx >= 0 } | ||||||
|  | 			expect = func(i []interface{}) interface{} { return i[len(i)-1] } | ||||||
|  | 			trunc = func(i []interface{}) []interface{} { return i[:len(i)-1] } | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for i := start(accounts); check(i, accounts); i = iter(i) { | ||||||
|  | 			// Get next expected account. | ||||||
|  | 			iface := expect(expectAccounts) | ||||||
|  | 
 | ||||||
|  | 			// Check that expected account matches received. | ||||||
|  | 			expectAccID := iface.(*model.Account).ID | ||||||
|  | 			receivdAccID := accounts[i].ID | ||||||
|  | 			suite.Equal(expectAccID, receivdAccID, "unexpected account at position in response on page=%d", p) | ||||||
|  | 
 | ||||||
|  | 			// Drop checked from expected accounts. | ||||||
|  | 			expectAccounts = trunc(expectAccounts) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if len(expectAccounts) == 0 { | ||||||
|  | 			// Reached end. | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Parse response link header values. | ||||||
|  | 		values := result.Header.Values("Link") | ||||||
|  | 		links := linkheader.ParseMultiple(values) | ||||||
|  | 		filteredLinks := links.FilterByRel("next") | ||||||
|  | 		suite.NotEmpty(filteredLinks, "no next link provided with more remaining accounts on page=%d", p) | ||||||
|  | 
 | ||||||
|  | 		// A ref link header was set. | ||||||
|  | 		link := filteredLinks[0] | ||||||
|  | 
 | ||||||
|  | 		// Parse URI from URI string. | ||||||
|  | 		uri, err := url.Parse(link.URL) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// Set next raw query value. | ||||||
|  | 		query = uri.RawQuery | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *FollowTestSuite) TestGetFollowingPageBackwardLimit2() { | ||||||
|  | 	suite.testGetFollowingPage(2, "backward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *FollowTestSuite) TestGetFollowingPageBackwardLimit4() { | ||||||
|  | 	suite.testGetFollowingPage(4, "backward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *FollowTestSuite) TestGetFollowingPageBackwardLimit6() { | ||||||
|  | 	suite.testGetFollowingPage(6, "backward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *FollowTestSuite) TestGetFollowingPageForwardLimit2() { | ||||||
|  | 	suite.testGetFollowingPage(2, "forward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *FollowTestSuite) TestGetFollowingPageForwardLimit4() { | ||||||
|  | 	suite.testGetFollowingPage(4, "forward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *FollowTestSuite) TestGetFollowingPageForwardLimit6() { | ||||||
|  | 	suite.testGetFollowingPage(6, "forward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *FollowTestSuite) testGetFollowingPage(limit int, direction string) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 
 | ||||||
|  | 	// The authed local account we are going to use for HTTP requests | ||||||
|  | 	requestingAccount := suite.testAccounts["local_account_1"] | ||||||
|  | 	suite.clearAccountRelations(requestingAccount.ID) | ||||||
|  | 
 | ||||||
|  | 	// Get current time. | ||||||
|  | 	now := time.Now() | ||||||
|  | 
 | ||||||
|  | 	var i int | ||||||
|  | 
 | ||||||
|  | 	for _, targetAccount := range suite.testAccounts { | ||||||
|  | 		if targetAccount.ID == requestingAccount.ID { | ||||||
|  | 			// we cannot be our own target... | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Get next simple ID. | ||||||
|  | 		id := strconv.Itoa(i) | ||||||
|  | 		i++ | ||||||
|  | 
 | ||||||
|  | 		// put a follow in the database | ||||||
|  | 		err := suite.db.PutFollow(ctx, >smodel.Follow{ | ||||||
|  | 			ID:              id, | ||||||
|  | 			CreatedAt:       now, | ||||||
|  | 			UpdatedAt:       now, | ||||||
|  | 			URI:             fmt.Sprintf("%s/follow/%s", requestingAccount.URI, id), | ||||||
|  | 			AccountID:       requestingAccount.ID, | ||||||
|  | 			TargetAccountID: targetAccount.ID, | ||||||
|  | 		}) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// Bump now by 1 second. | ||||||
|  | 		now = now.Add(time.Second) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get _ALL_ follows we expect to see without any paging (this filters invisible). | ||||||
|  | 	apiRsp, err := suite.processor.Account().FollowingGet(ctx, requestingAccount, requestingAccount.ID, nil) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	expectAccounts := apiRsp.Items // interfaced{} account slice | ||||||
|  | 
 | ||||||
|  | 	// Iteratively set | ||||||
|  | 	// link query string. | ||||||
|  | 	var query string | ||||||
|  | 
 | ||||||
|  | 	switch direction { | ||||||
|  | 	case "backward": | ||||||
|  | 		// Set the starting query to page backward from newest. | ||||||
|  | 		acc := expectAccounts[0].(*model.Account) | ||||||
|  | 		newest, _ := suite.db.GetFollow(ctx, requestingAccount.ID, acc.ID) | ||||||
|  | 		expectAccounts = expectAccounts[1:] | ||||||
|  | 		query = fmt.Sprintf("limit=%d&max_id=%s", limit, newest.ID) | ||||||
|  | 
 | ||||||
|  | 	case "forward": | ||||||
|  | 		// Set the starting query to page forward from the oldest. | ||||||
|  | 		acc := expectAccounts[len(expectAccounts)-1].(*model.Account) | ||||||
|  | 		oldest, _ := suite.db.GetFollow(ctx, requestingAccount.ID, acc.ID) | ||||||
|  | 		expectAccounts = expectAccounts[:len(expectAccounts)-1] | ||||||
|  | 		query = fmt.Sprintf("limit=%d&min_id=%s", limit, oldest.ID) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for p := 0; ; p++ { | ||||||
|  | 		// Prepare new request for endpoint | ||||||
|  | 		recorder := httptest.NewRecorder() | ||||||
|  | 		endpoint := fmt.Sprintf("/api/v1/accounts/%s/following", requestingAccount.ID) | ||||||
|  | 		ctx := suite.newContext(recorder, http.MethodGet, []byte{}, endpoint, "") | ||||||
|  | 		ctx.Params = gin.Params{{Key: "id", Value: requestingAccount.ID}} | ||||||
|  | 		ctx.Request.URL.RawQuery = query // setting provided next query value | ||||||
|  | 
 | ||||||
|  | 		// call the handler and check for valid response code. | ||||||
|  | 		suite.T().Logf("direction=%q page=%d query=%q", direction, p, query) | ||||||
|  | 		suite.accountsModule.AccountFollowingGETHandler(ctx) | ||||||
|  | 		suite.Equal(http.StatusOK, recorder.Code) | ||||||
|  | 
 | ||||||
|  | 		var accounts []*model.Account | ||||||
|  | 
 | ||||||
|  | 		// Decode response body into API account models | ||||||
|  | 		result := recorder.Result() | ||||||
|  | 		dec := json.NewDecoder(result.Body) | ||||||
|  | 		err := dec.Decode(&accounts) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 		_ = result.Body.Close() | ||||||
|  | 
 | ||||||
|  | 		var ( | ||||||
|  | 			// start provides the starting index for loop in accounts. | ||||||
|  | 			start func([]*model.Account) int | ||||||
|  | 
 | ||||||
|  | 			// iter performs the loop iter step with index. | ||||||
|  | 			iter func(int) int | ||||||
|  | 
 | ||||||
|  | 			// check performs the loop conditional check against index and accounts. | ||||||
|  | 			check func(int, []*model.Account) bool | ||||||
|  | 
 | ||||||
|  | 			// expect pulls the next account to check against from expectAccounts. | ||||||
|  | 			expect func([]interface{}) interface{} | ||||||
|  | 
 | ||||||
|  | 			// trunc drops the last checked account from expectAccounts. | ||||||
|  | 			trunc func([]interface{}) []interface{} | ||||||
|  | 		) | ||||||
|  | 
 | ||||||
|  | 		switch direction { | ||||||
|  | 		case "backward": | ||||||
|  | 			// When paging backwards (DESC) we: | ||||||
|  | 			// - iter from end of received accounts | ||||||
|  | 			// - iterate backward through received accounts | ||||||
|  | 			// - stop when we reach last index of received accounts | ||||||
|  | 			// - compare each received with the first index of expected accounts | ||||||
|  | 			// - after each compare, drop the first index of expected accounts | ||||||
|  | 			start = func([]*model.Account) int { return 0 } | ||||||
|  | 			iter = func(i int) int { return i + 1 } | ||||||
|  | 			check = func(idx int, i []*model.Account) bool { return idx < len(i) } | ||||||
|  | 			expect = func(i []interface{}) interface{} { return i[0] } | ||||||
|  | 			trunc = func(i []interface{}) []interface{} { return i[1:] } | ||||||
|  | 
 | ||||||
|  | 		case "forward": | ||||||
|  | 			// When paging forwards (ASC) we: | ||||||
|  | 			// - iter from end of received accounts | ||||||
|  | 			// - iterate backward through received accounts | ||||||
|  | 			// - stop when we reach first index of received accounts | ||||||
|  | 			// - compare each received with the last index of expected accounts | ||||||
|  | 			// - after each compare, drop the last index of expected accounts | ||||||
|  | 			start = func(i []*model.Account) int { return len(i) - 1 } | ||||||
|  | 			iter = func(i int) int { return i - 1 } | ||||||
|  | 			check = func(idx int, i []*model.Account) bool { return idx >= 0 } | ||||||
|  | 			expect = func(i []interface{}) interface{} { return i[len(i)-1] } | ||||||
|  | 			trunc = func(i []interface{}) []interface{} { return i[:len(i)-1] } | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for i := start(accounts); check(i, accounts); i = iter(i) { | ||||||
|  | 			// Get next expected account. | ||||||
|  | 			iface := expect(expectAccounts) | ||||||
|  | 
 | ||||||
|  | 			// Check that expected account matches received. | ||||||
|  | 			expectAccID := iface.(*model.Account).ID | ||||||
|  | 			receivdAccID := accounts[i].ID | ||||||
|  | 			suite.Equal(expectAccID, receivdAccID, "unexpected account at position in response on page=%d", p) | ||||||
|  | 
 | ||||||
|  | 			// Drop checked from expected accounts. | ||||||
|  | 			expectAccounts = trunc(expectAccounts) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if len(expectAccounts) == 0 { | ||||||
|  | 			// Reached end. | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Parse response link header values. | ||||||
|  | 		values := result.Header.Values("Link") | ||||||
|  | 		links := linkheader.ParseMultiple(values) | ||||||
|  | 		filteredLinks := links.FilterByRel("next") | ||||||
|  | 		suite.NotEmpty(filteredLinks, "no next link provided with more remaining accounts on page=%d", p) | ||||||
|  | 
 | ||||||
|  | 		// A ref link header was set. | ||||||
|  | 		link := filteredLinks[0] | ||||||
|  | 
 | ||||||
|  | 		// Parse URI from URI string. | ||||||
|  | 		uri, err := url.Parse(link.URL) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// Set next raw query value. | ||||||
|  | 		query = uri.RawQuery | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *FollowTestSuite) clearAccountRelations(id string) { | ||||||
|  | 	// Esnure no account blocks exist between accounts. | ||||||
|  | 	_ = suite.db.DeleteAccountBlocks( | ||||||
|  | 		context.Background(), | ||||||
|  | 		id, | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	// Ensure no account follows exist between accounts. | ||||||
|  | 	_ = suite.db.DeleteAccountFollows( | ||||||
|  | 		context.Background(), | ||||||
|  | 		id, | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	// Ensure no account follow_requests exist between accounts. | ||||||
|  | 	_ = suite.db.DeleteAccountFollowRequests( | ||||||
|  | 		context.Background(), | ||||||
|  | 		id, | ||||||
|  | 	) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestFollowTestSuite(t *testing.T) { | func TestFollowTestSuite(t *testing.T) { | ||||||
| 	suite.Run(t, new(FollowTestSuite)) | 	suite.Run(t, new(FollowTestSuite)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -25,12 +25,20 @@ import ( | ||||||
| 	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" | 	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/paging" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // AccountFollowersGETHandler swagger:operation GET /api/v1/accounts/{id}/followers accountFollowers | // AccountFollowersGETHandler swagger:operation GET /api/v1/accounts/{id}/followers accountFollowers | ||||||
| // | // | ||||||
| // See followers of account with given id. | // See followers of account with given id. | ||||||
| // | // | ||||||
|  | // The next and previous queries can be parsed from the returned Link header. | ||||||
|  | // Example: | ||||||
|  | // | ||||||
|  | // ``` | ||||||
|  | // <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/followers?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/followers?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev" | ||||||
|  | // ```` | ||||||
|  | // | ||||||
| //	--- | //	--- | ||||||
| //	tags: | //	tags: | ||||||
| //	- accounts | //	- accounts | ||||||
|  | @ -45,6 +53,42 @@ import ( | ||||||
| //		description: Account ID. | //		description: Account ID. | ||||||
| //		in: path | //		in: path | ||||||
| //		required: true | //		required: true | ||||||
|  | //	- | ||||||
|  | //		name: max_id | ||||||
|  | //		type: string | ||||||
|  | //		description: >- | ||||||
|  | //			Return only follower accounts *OLDER* than the given max ID. | ||||||
|  | //			The follower account with the specified ID will not be included in the response. | ||||||
|  | //			NOTE: the ID is of the internal follow, NOT any of the returned accounts. | ||||||
|  | //		in: query | ||||||
|  | //		required: false | ||||||
|  | //	- | ||||||
|  | //		name: since_id | ||||||
|  | //		type: string | ||||||
|  | //		description: >- | ||||||
|  | //			Return only follower accounts *NEWER* than the given since ID. | ||||||
|  | //			The follower account with the specified ID will not be included in the response. | ||||||
|  | //			NOTE: the ID is of the internal follow, NOT any of the returned accounts. | ||||||
|  | //		in: query | ||||||
|  | //		required: false | ||||||
|  | //	- | ||||||
|  | //		name: min_id | ||||||
|  | //		type: string | ||||||
|  | //		description: >- | ||||||
|  | //			Return only follower accounts *IMMEDIATELY NEWER* than the given min ID. | ||||||
|  | //			The follower account with the specified ID will not be included in the response. | ||||||
|  | //			NOTE: the ID is of the internal follow, NOT any of the returned accounts. | ||||||
|  | //		in: query | ||||||
|  | //		required: false | ||||||
|  | //	- | ||||||
|  | //		name: limit | ||||||
|  | //		type: integer | ||||||
|  | //		description: Number of follower accounts to return. | ||||||
|  | //		default: 40 | ||||||
|  | //		minimum: 1 | ||||||
|  | //		maximum: 80 | ||||||
|  | //		in: query | ||||||
|  | //		required: false | ||||||
| // | // | ||||||
| //	security: | //	security: | ||||||
| //	- OAuth2 Bearer: | //	- OAuth2 Bearer: | ||||||
|  | @ -87,11 +131,25 @@ func (m *Module) AccountFollowersGETHandler(c *gin.Context) { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	followers, errWithCode := m.processor.Account().FollowersGet(c.Request.Context(), authed.Account, targetAcctID) | 	page, errWithCode := paging.ParseIDPage(c, | ||||||
|  | 		1,  // min limit | ||||||
|  | 		80, // max limit | ||||||
|  | 		40, // default limit | ||||||
|  | 	) | ||||||
| 	if errWithCode != nil { | 	if errWithCode != nil { | ||||||
| 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	c.JSON(http.StatusOK, followers) | 	resp, errWithCode := m.processor.Account().FollowersGet(c.Request.Context(), authed.Account, targetAcctID, page) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if resp.LinkHeader != "" { | ||||||
|  | 		c.Header("Link", resp.LinkHeader) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	c.JSON(http.StatusOK, resp.Items) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -25,12 +25,20 @@ import ( | ||||||
| 	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" | 	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/paging" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // AccountFollowingGETHandler swagger:operation GET /api/v1/accounts/{id}/following accountFollowing | // AccountFollowingGETHandler swagger:operation GET /api/v1/accounts/{id}/following accountFollowing | ||||||
| // | // | ||||||
| // See accounts followed by given account id. | // See accounts followed by given account id. | ||||||
| // | // | ||||||
|  | // The next and previous queries can be parsed from the returned Link header. | ||||||
|  | // Example: | ||||||
|  | // | ||||||
|  | // ``` | ||||||
|  | // <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/following?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/accounts/0657WMDEC3KQDTD6NZ4XJZBK4M/following?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev" | ||||||
|  | // ```` | ||||||
|  | // | ||||||
| //	--- | //	--- | ||||||
| //	tags: | //	tags: | ||||||
| //	- accounts | //	- accounts | ||||||
|  | @ -45,6 +53,42 @@ import ( | ||||||
| //		description: Account ID. | //		description: Account ID. | ||||||
| //		in: path | //		in: path | ||||||
| //		required: true | //		required: true | ||||||
|  | //	- | ||||||
|  | //		name: max_id | ||||||
|  | //		type: string | ||||||
|  | //		description: >- | ||||||
|  | //			Return only following accounts *OLDER* than the given max ID. | ||||||
|  | //			The following account with the specified ID will not be included in the response. | ||||||
|  | //			NOTE: the ID is of the internal follow, NOT any of the returned accounts. | ||||||
|  | //		in: query | ||||||
|  | //		required: false | ||||||
|  | //	- | ||||||
|  | //		name: since_id | ||||||
|  | //		type: string | ||||||
|  | //		description: >- | ||||||
|  | //			Return only following accounts *NEWER* than the given since ID. | ||||||
|  | //			The following account with the specified ID will not be included in the response. | ||||||
|  | //			NOTE: the ID is of the internal follow, NOT any of the returned accounts. | ||||||
|  | //		in: query | ||||||
|  | //		required: false | ||||||
|  | //	- | ||||||
|  | //		name: min_id | ||||||
|  | //		type: string | ||||||
|  | //		description: >- | ||||||
|  | //			Return only following accounts *IMMEDIATELY NEWER* than the given min ID. | ||||||
|  | //			The following account with the specified ID will not be included in the response. | ||||||
|  | //			NOTE: the ID is of the internal follow, NOT any of the returned accounts. | ||||||
|  | //		in: query | ||||||
|  | //		required: false | ||||||
|  | //	- | ||||||
|  | //		name: limit | ||||||
|  | //		type: integer | ||||||
|  | //		description: Number of following accounts to return. | ||||||
|  | //		default: 40 | ||||||
|  | //		minimum: 1 | ||||||
|  | //		maximum: 80 | ||||||
|  | //		in: query | ||||||
|  | //		required: false | ||||||
| // | // | ||||||
| //	security: | //	security: | ||||||
| //	- OAuth2 Bearer: | //	- OAuth2 Bearer: | ||||||
|  | @ -87,11 +131,25 @@ func (m *Module) AccountFollowingGETHandler(c *gin.Context) { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	following, errWithCode := m.processor.Account().FollowingGet(c.Request.Context(), authed.Account, targetAcctID) | 	page, errWithCode := paging.ParseIDPage(c, | ||||||
|  | 		1,  // min limit | ||||||
|  | 		80, // max limit | ||||||
|  | 		40, // default limit | ||||||
|  | 	) | ||||||
| 	if errWithCode != nil { | 	if errWithCode != nil { | ||||||
| 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	c.JSON(http.StatusOK, following) | 	resp, errWithCode := m.processor.Account().FollowingGet(c.Request.Context(), authed.Account, targetAcctID, page) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if resp.LinkHeader != "" { | ||||||
|  | 		c.Header("Link", resp.LinkHeader) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	c.JSON(http.StatusOK, resp.Items) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -47,25 +47,40 @@ import ( | ||||||
| // | // | ||||||
| //	parameters: | //	parameters: | ||||||
| //	- | //	- | ||||||
| //		name: limit |  | ||||||
| //		type: integer |  | ||||||
| //		description: Number of blocks to return. |  | ||||||
| //		default: 20 |  | ||||||
| //		in: query |  | ||||||
| //	- |  | ||||||
| //		name: max_id | //		name: max_id | ||||||
| //		type: string | //		type: string | ||||||
| //		description: >- | //		description: >- | ||||||
| //			Return only blocks *OLDER* than the given block ID. | //			Return only blocked accounts *OLDER* than the given max ID. | ||||||
| //			The block with the specified ID will not be included in the response. | //			The blocked account with the specified ID will not be included in the response. | ||||||
|  | //			NOTE: the ID is of the internal block, NOT any of the returned accounts. | ||||||
| //		in: query | //		in: query | ||||||
|  | //		required: false | ||||||
| //	- | //	- | ||||||
| //		name: since_id | //		name: since_id | ||||||
| //		type: string | //		type: string | ||||||
| //		description: >- | //		description: >- | ||||||
| //		  Return only blocks *NEWER* than the given block ID. | //			Return only blocked accounts *NEWER* than the given since ID. | ||||||
| //		  The block with the specified ID will not be included in the response. | //			The blocked account with the specified ID will not be included in the response. | ||||||
|  | //			NOTE: the ID is of the internal block, NOT any of the returned accounts. | ||||||
| //		in: query | //		in: query | ||||||
|  | //	- | ||||||
|  | //		name: min_id | ||||||
|  | //		type: string | ||||||
|  | //		description: >- | ||||||
|  | //			Return only blocked accounts *IMMEDIATELY NEWER* than the given min ID. | ||||||
|  | //			The blocked account with the specified ID will not be included in the response. | ||||||
|  | //			NOTE: the ID is of the internal block, NOT any of the returned accounts. | ||||||
|  | //		in: query | ||||||
|  | //		required: false | ||||||
|  | //	- | ||||||
|  | //		name: limit | ||||||
|  | //		type: integer | ||||||
|  | //		description: Number of blocked accounts to return. | ||||||
|  | //		default: 40 | ||||||
|  | //		minimum: 1 | ||||||
|  | //		maximum: 80 | ||||||
|  | //		in: query | ||||||
|  | //		required: false | ||||||
| // | // | ||||||
| //	security: | //	security: | ||||||
| //	- OAuth2 Bearer: | //	- OAuth2 Bearer: | ||||||
|  | @ -105,15 +120,15 @@ func (m *Module) BlocksGETHandler(c *gin.Context) { | ||||||
| 
 | 
 | ||||||
| 	page, errWithCode := paging.ParseIDPage(c, | 	page, errWithCode := paging.ParseIDPage(c, | ||||||
| 		1,  // min limit | 		1,  // min limit | ||||||
| 		100, // max limit | 		80, // max limit | ||||||
| 		20,  // default limit | 		40, // default limit | ||||||
| 	) | 	) | ||||||
| 	if errWithCode != nil { | 	if errWithCode != nil { | ||||||
| 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	resp, errWithCode := m.processor.BlocksGet( | 	resp, errWithCode := m.processor.Account().BlocksGet( | ||||||
| 		c.Request.Context(), | 		c.Request.Context(), | ||||||
| 		authed.Account, | 		authed.Account, | ||||||
| 		page, | 		page, | ||||||
|  |  | ||||||
|  | @ -87,7 +87,7 @@ func (m *Module) FollowRequestAuthorizePOSTHandler(c *gin.Context) { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	relationship, errWithCode := m.processor.FollowRequestAccept(c.Request.Context(), authed, originAccountID) | 	relationship, errWithCode := m.processor.Account().FollowRequestAccept(c.Request.Context(), authed.Account, originAccountID) | ||||||
| 	if errWithCode != nil { | 	if errWithCode != nil { | ||||||
| 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
| 		return | 		return | ||||||
|  |  | ||||||
|  | @ -24,12 +24,19 @@ import ( | ||||||
| 	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" | 	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/paging" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // FollowRequestGETHandler swagger:operation GET /api/v1/follow_requests getFollowRequests | // FollowRequestGETHandler swagger:operation GET /api/v1/follow_requests getFollowRequests | ||||||
| // | // | ||||||
| // Get an array of accounts that have requested to follow you. | // Get an array of accounts that have requested to follow you. | ||||||
| // Accounts will be sorted in order of follow request date descending (newest first). | // | ||||||
|  | // The next and previous queries can be parsed from the returned Link header. | ||||||
|  | // Example: | ||||||
|  | // | ||||||
|  | // ``` | ||||||
|  | // <https://example.org/api/v1/follow_requests?limit=80&max_id=01FC0SKA48HNSVR6YKZCQGS2V8>; rel="next", <https://example.org/api/v1/follow_requests?limit=80&min_id=01FC0SKW5JK2Q4EVAV2B462YY0>; rel="prev" | ||||||
|  | // ```` | ||||||
| // | // | ||||||
| //	--- | //	--- | ||||||
| //	tags: | //	tags: | ||||||
|  | @ -40,11 +47,41 @@ import ( | ||||||
| // | // | ||||||
| //	parameters: | //	parameters: | ||||||
| //	- | //	- | ||||||
|  | //		name: max_id | ||||||
|  | //		type: string | ||||||
|  | //		description: >- | ||||||
|  | //			Return only follow requesting accounts *OLDER* than the given max ID. | ||||||
|  | //			The follow requester with the specified ID will not be included in the response. | ||||||
|  | //			NOTE: the ID is of the internal follow request, NOT any of the returned accounts. | ||||||
|  | //		in: query | ||||||
|  | //		required: false | ||||||
|  | //	- | ||||||
|  | //		name: since_id | ||||||
|  | //		type: string | ||||||
|  | //		description: >- | ||||||
|  | //			Return only follow requesting accounts *NEWER* than the given since ID. | ||||||
|  | //			The follow requester with the specified ID will not be included in the response. | ||||||
|  | //			NOTE: the ID is of the internal follow request, NOT any of the returned accounts. | ||||||
|  | //		in: query | ||||||
|  | //		required: false | ||||||
|  | //	- | ||||||
|  | //		name: min_id | ||||||
|  | //		type: string | ||||||
|  | //		description: >- | ||||||
|  | //			Return only follow requesting accounts *IMMEDIATELY NEWER* than the given min ID. | ||||||
|  | //			The follow requester with the specified ID will not be included in the response. | ||||||
|  | //			NOTE: the ID is of the internal follow request, NOT any of the returned accounts. | ||||||
|  | //		in: query | ||||||
|  | //		required: false | ||||||
|  | //	- | ||||||
| //		name: limit | //		name: limit | ||||||
| //		type: integer | //		type: integer | ||||||
| //		description: Number of accounts to return. | //		description: Number of follow requesting accounts to return. | ||||||
| //		default: 40 | //		default: 40 | ||||||
|  | //		minimum: 1 | ||||||
|  | //		maximum: 80 | ||||||
| //		in: query | //		in: query | ||||||
|  | //		required: false | ||||||
| // | // | ||||||
| //	security: | //	security: | ||||||
| //	- OAuth2 Bearer: | //	- OAuth2 Bearer: | ||||||
|  | @ -82,11 +119,25 @@ func (m *Module) FollowRequestGETHandler(c *gin.Context) { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	accts, errWithCode := m.processor.FollowRequestsGet(c.Request.Context(), authed) | 	page, errWithCode := paging.ParseIDPage(c, | ||||||
|  | 		1,  // min limit | ||||||
|  | 		80, // max limit | ||||||
|  | 		40, // default limit | ||||||
|  | 	) | ||||||
| 	if errWithCode != nil { | 	if errWithCode != nil { | ||||||
| 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	c.JSON(http.StatusOK, accts) | 	resp, errWithCode := m.processor.Account().FollowRequestsGet(c.Request.Context(), authed.Account, page) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if resp.LinkHeader != "" { | ||||||
|  | 		c.Header("Link", resp.LinkHeader) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	c.JSON(http.StatusOK, resp.Items) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -22,17 +22,25 @@ import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io/ioutil" | 	"io" | ||||||
|  | 	"math/rand" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
|  | 	"net/url" | ||||||
|  | 	"strconv" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| 	"github.com/stretchr/testify/suite" | 	"github.com/stretchr/testify/suite" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/tomnomnom/linkheader" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // random reader according to current-time source seed. | ||||||
|  | var randRd = rand.New(rand.NewSource(time.Now().Unix())) | ||||||
|  | 
 | ||||||
| type GetTestSuite struct { | type GetTestSuite struct { | ||||||
| 	FollowRequestStandardTestSuite | 	FollowRequestStandardTestSuite | ||||||
| } | } | ||||||
|  | @ -68,7 +76,7 @@ func (suite *GetTestSuite) TestGet() { | ||||||
| 	defer result.Body.Close() | 	defer result.Body.Close() | ||||||
| 
 | 
 | ||||||
| 	// check the response | 	// check the response | ||||||
| 	b, err := ioutil.ReadAll(result.Body) | 	b, err := io.ReadAll(result.Body) | ||||||
| 	assert.NoError(suite.T(), err) | 	assert.NoError(suite.T(), err) | ||||||
| 	dst := new(bytes.Buffer) | 	dst := new(bytes.Buffer) | ||||||
| 	err = json.Indent(dst, b, "", "  ") | 	err = json.Indent(dst, b, "", "  ") | ||||||
|  | @ -99,6 +107,214 @@ func (suite *GetTestSuite) TestGet() { | ||||||
| ]`, dst.String()) | ]`, dst.String()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (suite *GetTestSuite) TestGetPageBackwardLimit2() { | ||||||
|  | 	suite.testGetPage(2, "backward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *GetTestSuite) TestGetPageBackwardLimit4() { | ||||||
|  | 	suite.testGetPage(4, "backward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *GetTestSuite) TestGetPageBackwardLimit6() { | ||||||
|  | 	suite.testGetPage(6, "backward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *GetTestSuite) TestGetPageForwardLimit2() { | ||||||
|  | 	suite.testGetPage(2, "forward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *GetTestSuite) TestGetPageForwardLimit4() { | ||||||
|  | 	suite.testGetPage(4, "forward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *GetTestSuite) TestGetPageForwardLimit6() { | ||||||
|  | 	suite.testGetPage(6, "forward") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *GetTestSuite) testGetPage(limit int, direction string) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 
 | ||||||
|  | 	// The authed local account we are going to use for HTTP requests | ||||||
|  | 	requestingAccount := suite.testAccounts["local_account_1"] | ||||||
|  | 	suite.clearAccountRelations(requestingAccount.ID) | ||||||
|  | 
 | ||||||
|  | 	// Get current time. | ||||||
|  | 	now := time.Now() | ||||||
|  | 
 | ||||||
|  | 	var i int | ||||||
|  | 
 | ||||||
|  | 	for _, targetAccount := range suite.testAccounts { | ||||||
|  | 		if targetAccount.ID == requestingAccount.ID { | ||||||
|  | 			// we cannot be our own target... | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Get next simple ID. | ||||||
|  | 		id := strconv.Itoa(i) | ||||||
|  | 		i++ | ||||||
|  | 
 | ||||||
|  | 		// put a follow request in the database | ||||||
|  | 		err := suite.db.PutFollowRequest(ctx, >smodel.FollowRequest{ | ||||||
|  | 			ID:              id, | ||||||
|  | 			CreatedAt:       now, | ||||||
|  | 			UpdatedAt:       now, | ||||||
|  | 			URI:             fmt.Sprintf("%s/follow/%s", targetAccount.URI, id), | ||||||
|  | 			AccountID:       targetAccount.ID, | ||||||
|  | 			TargetAccountID: requestingAccount.ID, | ||||||
|  | 		}) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// Bump now by 1 second. | ||||||
|  | 		now = now.Add(time.Second) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get _ALL_ follow requests we expect to see without any paging (this filters invisible). | ||||||
|  | 	apiRsp, err := suite.processor.Account().FollowRequestsGet(ctx, requestingAccount, nil) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	expectAccounts := apiRsp.Items // interfaced{} account slice | ||||||
|  | 
 | ||||||
|  | 	// Iteratively set | ||||||
|  | 	// link query string. | ||||||
|  | 	var query string | ||||||
|  | 
 | ||||||
|  | 	switch direction { | ||||||
|  | 	case "backward": | ||||||
|  | 		// Set the starting query to page backward from newest. | ||||||
|  | 		acc := expectAccounts[0].(*model.Account) | ||||||
|  | 		newest, _ := suite.db.GetFollowRequest(ctx, acc.ID, requestingAccount.ID) | ||||||
|  | 		expectAccounts = expectAccounts[1:] | ||||||
|  | 		query = fmt.Sprintf("limit=%d&max_id=%s", limit, newest.ID) | ||||||
|  | 
 | ||||||
|  | 	case "forward": | ||||||
|  | 		// Set the starting query to page forward from the oldest. | ||||||
|  | 		acc := expectAccounts[len(expectAccounts)-1].(*model.Account) | ||||||
|  | 		oldest, _ := suite.db.GetFollowRequest(ctx, acc.ID, requestingAccount.ID) | ||||||
|  | 		expectAccounts = expectAccounts[:len(expectAccounts)-1] | ||||||
|  | 		query = fmt.Sprintf("limit=%d&min_id=%s", limit, oldest.ID) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for p := 0; ; p++ { | ||||||
|  | 		// Prepare new request for endpoint | ||||||
|  | 		recorder := httptest.NewRecorder() | ||||||
|  | 		ctx := suite.newContext(recorder, http.MethodGet, []byte{}, "/api/v1/follow_requests", "") | ||||||
|  | 		ctx.Request.URL.RawQuery = query // setting provided next query value | ||||||
|  | 
 | ||||||
|  | 		// call the handler and check for valid response code. | ||||||
|  | 		suite.T().Logf("direction=%q page=%d query=%q", direction, p, query) | ||||||
|  | 		suite.followRequestModule.FollowRequestGETHandler(ctx) | ||||||
|  | 		suite.Equal(http.StatusOK, recorder.Code) | ||||||
|  | 
 | ||||||
|  | 		var accounts []*model.Account | ||||||
|  | 
 | ||||||
|  | 		// Decode response body into API account models | ||||||
|  | 		result := recorder.Result() | ||||||
|  | 		dec := json.NewDecoder(result.Body) | ||||||
|  | 		err := dec.Decode(&accounts) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 		_ = result.Body.Close() | ||||||
|  | 
 | ||||||
|  | 		var ( | ||||||
|  | 
 | ||||||
|  | 			// start provides the starting index for loop in accounts. | ||||||
|  | 			start func([]*model.Account) int | ||||||
|  | 
 | ||||||
|  | 			// iter performs the loop iter step with index. | ||||||
|  | 			iter func(int) int | ||||||
|  | 
 | ||||||
|  | 			// check performs the loop conditional check against index and accounts. | ||||||
|  | 			check func(int, []*model.Account) bool | ||||||
|  | 
 | ||||||
|  | 			// expect pulls the next account to check against from expectAccounts. | ||||||
|  | 			expect func([]interface{}) interface{} | ||||||
|  | 
 | ||||||
|  | 			// trunc drops the last checked account from expectAccounts. | ||||||
|  | 			trunc func([]interface{}) []interface{} | ||||||
|  | 		) | ||||||
|  | 
 | ||||||
|  | 		switch direction { | ||||||
|  | 		case "backward": | ||||||
|  | 			// When paging backwards (DESC) we: | ||||||
|  | 			// - iter from end of received accounts | ||||||
|  | 			// - iterate backward through received accounts | ||||||
|  | 			// - stop when we reach last index of received accounts | ||||||
|  | 			// - compare each received with the first index of expected accounts | ||||||
|  | 			// - after each compare, drop the first index of expected accounts | ||||||
|  | 			start = func([]*model.Account) int { return 0 } | ||||||
|  | 			iter = func(i int) int { return i + 1 } | ||||||
|  | 			check = func(idx int, i []*model.Account) bool { return idx < len(i) } | ||||||
|  | 			expect = func(i []interface{}) interface{} { return i[0] } | ||||||
|  | 			trunc = func(i []interface{}) []interface{} { return i[1:] } | ||||||
|  | 
 | ||||||
|  | 		case "forward": | ||||||
|  | 			// When paging forwards (ASC) we: | ||||||
|  | 			// - iter from end of received accounts | ||||||
|  | 			// - iterate backward through received accounts | ||||||
|  | 			// - stop when we reach first index of received accounts | ||||||
|  | 			// - compare each received with the last index of expected accounts | ||||||
|  | 			// - after each compare, drop the last index of expected accounts | ||||||
|  | 			start = func(i []*model.Account) int { return len(i) - 1 } | ||||||
|  | 			iter = func(i int) int { return i - 1 } | ||||||
|  | 			check = func(idx int, i []*model.Account) bool { return idx >= 0 } | ||||||
|  | 			expect = func(i []interface{}) interface{} { return i[len(i)-1] } | ||||||
|  | 			trunc = func(i []interface{}) []interface{} { return i[:len(i)-1] } | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for i := start(accounts); check(i, accounts); i = iter(i) { | ||||||
|  | 			// Get next expected account. | ||||||
|  | 			iface := expect(expectAccounts) | ||||||
|  | 
 | ||||||
|  | 			// Check that expected account matches received. | ||||||
|  | 			expectAccID := iface.(*model.Account).ID | ||||||
|  | 			receivdAccID := accounts[i].ID | ||||||
|  | 			suite.Equal(expectAccID, receivdAccID, "unexpected account at position in response on page=%d", p) | ||||||
|  | 
 | ||||||
|  | 			// Drop checked from expected accounts. | ||||||
|  | 			expectAccounts = trunc(expectAccounts) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if len(expectAccounts) == 0 { | ||||||
|  | 			// Reached end. | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Parse response link header values. | ||||||
|  | 		values := result.Header.Values("Link") | ||||||
|  | 		links := linkheader.ParseMultiple(values) | ||||||
|  | 		filteredLinks := links.FilterByRel("next") | ||||||
|  | 		suite.NotEmpty(filteredLinks, "no next link provided with more remaining accounts on page=%d", p) | ||||||
|  | 
 | ||||||
|  | 		// A ref link header was set. | ||||||
|  | 		link := filteredLinks[0] | ||||||
|  | 
 | ||||||
|  | 		// Parse URI from URI string. | ||||||
|  | 		uri, err := url.Parse(link.URL) | ||||||
|  | 		suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 		// Set next raw query value. | ||||||
|  | 		query = uri.RawQuery | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *GetTestSuite) clearAccountRelations(id string) { | ||||||
|  | 	// Esnure no account blocks exist between accounts. | ||||||
|  | 	_ = suite.db.DeleteAccountBlocks( | ||||||
|  | 		context.Background(), | ||||||
|  | 		id, | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	// Ensure no account follows exist between accounts. | ||||||
|  | 	_ = suite.db.DeleteAccountFollows( | ||||||
|  | 		context.Background(), | ||||||
|  | 		id, | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	// Ensure no account follow_requests exist between accounts. | ||||||
|  | 	_ = suite.db.DeleteAccountFollowRequests( | ||||||
|  | 		context.Background(), | ||||||
|  | 		id, | ||||||
|  | 	) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestGetTestSuite(t *testing.T) { | func TestGetTestSuite(t *testing.T) { | ||||||
| 	suite.Run(t, &GetTestSuite{}) | 	suite.Run(t, &GetTestSuite{}) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -85,7 +85,7 @@ func (m *Module) FollowRequestRejectPOSTHandler(c *gin.Context) { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	relationship, errWithCode := m.processor.FollowRequestReject(c.Request.Context(), authed, originAccountID) | 	relationship, errWithCode := m.processor.Account().FollowRequestReject(c.Request.Context(), authed.Account, originAccountID) | ||||||
| 	if errWithCode != nil { | 	if errWithCode != nil { | ||||||
| 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) | ||||||
| 		return | 		return | ||||||
|  |  | ||||||
|  | @ -102,8 +102,8 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount | ||||||
| 	return &rel, nil | 	return &rel, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { | func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) { | ||||||
| 	followIDs, err := r.getAccountFollowIDs(ctx, accountID) | 	followIDs, err := r.getAccountFollowIDs(ctx, accountID, page) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | @ -118,8 +118,8 @@ func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID s | ||||||
| 	return r.GetFollowsByIDs(ctx, followIDs) | 	return r.GetFollowsByIDs(ctx, followIDs) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { | func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) { | ||||||
| 	followerIDs, err := r.getAccountFollowerIDs(ctx, accountID) | 	followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, page) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | @ -134,16 +134,16 @@ func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID | ||||||
| 	return r.GetFollowsByIDs(ctx, followerIDs) | 	return r.GetFollowsByIDs(ctx, followerIDs) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { | func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) { | ||||||
| 	followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID) | 	followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, page) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	return r.GetFollowRequestsByIDs(ctx, followReqIDs) | 	return r.GetFollowRequestsByIDs(ctx, followReqIDs) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) { | func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) { | ||||||
| 	followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID) | 	followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, page) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | @ -151,39 +151,15 @@ func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, account | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) { | func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) { | ||||||
| 	// Load block IDs from cache with database loader callback. | 	blockIDs, err := r.getAccountBlockIDs(ctx, accountID, page) | ||||||
| 	blockIDs, err := r.state.Caches.GTS.BlockIDs().Load(accountID, func() ([]string, error) { |  | ||||||
| 		var blockIDs []string |  | ||||||
| 
 |  | ||||||
| 		// Block IDs not in cache, perform DB query! |  | ||||||
| 		q := newSelectBlocks(r.db, accountID) |  | ||||||
| 		if _, err := q.Exec(ctx, &blockIDs); err != nil { |  | ||||||
| 			return nil, err |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		return blockIDs, nil |  | ||||||
| 	}) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 |  | ||||||
| 	// Our cached / selected block IDs are |  | ||||||
| 	// ALWAYS stored in descending order. |  | ||||||
| 	// Depending on the paging requested |  | ||||||
| 	// this may be an unexpected order. |  | ||||||
| 	if !page.GetOrder().Ascending() { |  | ||||||
| 		blockIDs = paging.Reverse(blockIDs) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Page the resulting block IDs. |  | ||||||
| 	blockIDs = page.Page(blockIDs) |  | ||||||
| 
 |  | ||||||
| 	// Convert these IDs to full block objects. |  | ||||||
| 	return r.GetBlocksByIDs(ctx, blockIDs) | 	return r.GetBlocksByIDs(ctx, blockIDs) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) { | func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) { | ||||||
| 	followIDs, err := r.getAccountFollowIDs(ctx, accountID) | 	followIDs, err := r.getAccountFollowIDs(ctx, accountID, nil) | ||||||
| 	return len(followIDs), err | 	return len(followIDs), err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -193,7 +169,7 @@ func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) { | func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) { | ||||||
| 	followerIDs, err := r.getAccountFollowerIDs(ctx, accountID) | 	followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, nil) | ||||||
| 	return len(followerIDs), err | 	return len(followerIDs), err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -203,17 +179,22 @@ func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, account | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { | func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { | ||||||
| 	followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID) | 	followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, nil) | ||||||
| 	return len(followReqIDs), err | 	return len(followReqIDs), err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { | func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { | ||||||
| 	followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID) | 	followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, nil) | ||||||
| 	return len(followReqIDs), err | 	return len(followReqIDs), err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string) ([]string, error) { | func (r *relationshipDB) CountAccountBlocks(ctx context.Context, accountID string) (int, error) { | ||||||
| 	return r.state.Caches.GTS.FollowIDs().Load(">"+accountID, func() ([]string, error) { | 	blockIDs, err := r.getAccountBlockIDs(ctx, accountID, nil) | ||||||
|  | 	return len(blockIDs), err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { | ||||||
|  | 	return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), ">"+accountID, page, func() ([]string, error) { | ||||||
| 		var followIDs []string | 		var followIDs []string | ||||||
| 
 | 
 | ||||||
| 		// Follow IDs not in cache, perform DB query! | 		// Follow IDs not in cache, perform DB query! | ||||||
|  | @ -240,8 +221,8 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string) ([]string, error) { | func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { | ||||||
| 	return r.state.Caches.GTS.FollowIDs().Load("<"+accountID, func() ([]string, error) { | 	return loadPagedIDs(r.state.Caches.GTS.FollowIDs(), "<"+accountID, page, func() ([]string, error) { | ||||||
| 		var followIDs []string | 		var followIDs []string | ||||||
| 
 | 
 | ||||||
| 		// Follow IDs not in cache, perform DB query! | 		// Follow IDs not in cache, perform DB query! | ||||||
|  | @ -268,8 +249,8 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string) ([]string, error) { | func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { | ||||||
| 	return r.state.Caches.GTS.FollowRequestIDs().Load(">"+accountID, func() ([]string, error) { | 	return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), ">"+accountID, page, func() ([]string, error) { | ||||||
| 		var followReqIDs []string | 		var followReqIDs []string | ||||||
| 
 | 
 | ||||||
| 		// Follow request IDs not in cache, perform DB query! | 		// Follow request IDs not in cache, perform DB query! | ||||||
|  | @ -282,8 +263,8 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string) ([]string, error) { | func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { | ||||||
| 	return r.state.Caches.GTS.FollowRequestIDs().Load("<"+accountID, func() ([]string, error) { | 	return loadPagedIDs(r.state.Caches.GTS.FollowRequestIDs(), "<"+accountID, page, func() ([]string, error) { | ||||||
| 		var followReqIDs []string | 		var followReqIDs []string | ||||||
| 
 | 
 | ||||||
| 		// Follow request IDs not in cache, perform DB query! | 		// Follow request IDs not in cache, perform DB query! | ||||||
|  | @ -296,13 +277,27 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { | ||||||
|  | 	return loadPagedIDs(r.state.Caches.GTS.BlockIDs(), accountID, page, func() ([]string, error) { | ||||||
|  | 		var blockIDs []string | ||||||
|  | 
 | ||||||
|  | 		// Block IDs not in cache, perform DB query! | ||||||
|  | 		q := newSelectBlocks(r.db, accountID) | ||||||
|  | 		if _, err := q.Exec(ctx, &blockIDs); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return blockIDs, nil | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID. | // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID. | ||||||
| func newSelectFollowRequests(db *DB, accountID string) *bun.SelectQuery { | func newSelectFollowRequests(db *DB, accountID string) *bun.SelectQuery { | ||||||
| 	return db.NewSelect(). | 	return db.NewSelect(). | ||||||
| 		TableExpr("?", bun.Ident("follow_requests")). | 		TableExpr("?", bun.Ident("follow_requests")). | ||||||
| 		ColumnExpr("?", bun.Ident("id")). | 		ColumnExpr("?", bun.Ident("id")). | ||||||
| 		Where("? = ?", bun.Ident("target_account_id"), accountID). | 		Where("? = ?", bun.Ident("target_account_id"), accountID). | ||||||
| 		OrderExpr("? DESC", bun.Ident("updated_at")) | 		OrderExpr("? DESC", bun.Ident("id")) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID. | // newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID. | ||||||
|  | @ -311,7 +306,7 @@ func newSelectFollowRequesting(db *DB, accountID string) *bun.SelectQuery { | ||||||
| 		TableExpr("?", bun.Ident("follow_requests")). | 		TableExpr("?", bun.Ident("follow_requests")). | ||||||
| 		ColumnExpr("?", bun.Ident("id")). | 		ColumnExpr("?", bun.Ident("id")). | ||||||
| 		Where("? = ?", bun.Ident("target_account_id"), accountID). | 		Where("? = ?", bun.Ident("target_account_id"), accountID). | ||||||
| 		OrderExpr("? DESC", bun.Ident("updated_at")) | 		OrderExpr("? DESC", bun.Ident("id")) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID. | // newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID. | ||||||
|  | @ -320,7 +315,7 @@ func newSelectFollows(db *DB, accountID string) *bun.SelectQuery { | ||||||
| 		Table("follows"). | 		Table("follows"). | ||||||
| 		Column("id"). | 		Column("id"). | ||||||
| 		Where("? = ?", bun.Ident("account_id"), accountID). | 		Where("? = ?", bun.Ident("account_id"), accountID). | ||||||
| 		OrderExpr("? DESC", bun.Ident("updated_at")) | 		OrderExpr("? DESC", bun.Ident("id")) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // newSelectLocalFollows returns a new select query for all rows in the follows table with | // newSelectLocalFollows returns a new select query for all rows in the follows table with | ||||||
|  | @ -338,7 +333,7 @@ func newSelectLocalFollows(db *DB, accountID string) *bun.SelectQuery { | ||||||
| 				Column("id"). | 				Column("id"). | ||||||
| 				Where("? IS NULL", bun.Ident("domain")), | 				Where("? IS NULL", bun.Ident("domain")), | ||||||
| 		). | 		). | ||||||
| 		OrderExpr("? DESC", bun.Ident("updated_at")) | 		OrderExpr("? DESC", bun.Ident("id")) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID. | // newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID. | ||||||
|  | @ -347,7 +342,7 @@ func newSelectFollowers(db *DB, accountID string) *bun.SelectQuery { | ||||||
| 		Table("follows"). | 		Table("follows"). | ||||||
| 		Column("id"). | 		Column("id"). | ||||||
| 		Where("? = ?", bun.Ident("target_account_id"), accountID). | 		Where("? = ?", bun.Ident("target_account_id"), accountID). | ||||||
| 		OrderExpr("? DESC", bun.Ident("updated_at")) | 		OrderExpr("? DESC", bun.Ident("id")) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // newSelectLocalFollowers returns a new select query for all rows in the follows table with | // newSelectLocalFollowers returns a new select query for all rows in the follows table with | ||||||
|  | @ -365,14 +360,14 @@ func newSelectLocalFollowers(db *DB, accountID string) *bun.SelectQuery { | ||||||
| 				Column("id"). | 				Column("id"). | ||||||
| 				Where("? IS NULL", bun.Ident("domain")), | 				Where("? IS NULL", bun.Ident("domain")), | ||||||
| 		). | 		). | ||||||
| 		OrderExpr("? DESC", bun.Ident("updated_at")) | 		OrderExpr("? DESC", bun.Ident("id")) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID. | // newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID. | ||||||
| func newSelectBlocks(db *DB, accountID string) *bun.SelectQuery { | func newSelectBlocks(db *DB, accountID string) *bun.SelectQuery { | ||||||
| 	return db.NewSelect(). | 	return db.NewSelect(). | ||||||
| 		TableExpr("?", bun.Ident("blocks")). | 		TableExpr("?", bun.Ident("blocks")). | ||||||
| 		ColumnExpr("?", bun.Ident("?")). | 		ColumnExpr("?", bun.Ident("id")). | ||||||
| 		Where("? = ?", bun.Ident("account_id"), accountID). | 		Where("? = ?", bun.Ident("account_id"), accountID). | ||||||
| 		OrderExpr("? DESC", bun.Ident("updated_at")) | 		OrderExpr("? DESC", bun.Ident("id")) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -753,14 +753,14 @@ func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() { | ||||||
| 		suite.FailNow(err.Error()) | 		suite.FailNow(err.Error()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID) | 	followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID, nil) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	suite.Len(followRequests, 1) | 	suite.Len(followRequests, 1) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *RelationshipTestSuite) TestGetAccountFollows() { | func (suite *RelationshipTestSuite) TestGetAccountFollows() { | ||||||
| 	account := suite.testAccounts["local_account_1"] | 	account := suite.testAccounts["local_account_1"] | ||||||
| 	follows, err := suite.db.GetAccountFollows(context.Background(), account.ID) | 	follows, err := suite.db.GetAccountFollows(context.Background(), account.ID, nil) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	suite.Len(follows, 2) | 	suite.Len(follows, 2) | ||||||
| } | } | ||||||
|  | @ -781,7 +781,7 @@ func (suite *RelationshipTestSuite) TestCountAccountFollows() { | ||||||
| 
 | 
 | ||||||
| func (suite *RelationshipTestSuite) TestGetAccountFollowers() { | func (suite *RelationshipTestSuite) TestGetAccountFollowers() { | ||||||
| 	account := suite.testAccounts["local_account_1"] | 	account := suite.testAccounts["local_account_1"] | ||||||
| 	follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID) | 	follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID, nil) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	suite.Len(follows, 2) | 	suite.Len(follows, 2) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -114,6 +114,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI | ||||||
| 	follows, err := t.state.DB.GetAccountFollows( | 	follows, err := t.state.DB.GetAccountFollows( | ||||||
| 		gtscontext.SetBarebones(ctx), | 		gtscontext.SetBarebones(ctx), | ||||||
| 		accountID, | 		accountID, | ||||||
|  | 		nil, // select all | ||||||
| 	) | 	) | ||||||
| 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		return nil, gtserror.Newf("db error getting follows for account %s: %w", accountID, err) | 		return nil, gtserror.Newf("db error getting follows for account %s: %w", accountID, err) | ||||||
|  |  | ||||||
|  | @ -167,8 +167,8 @@ func (suite *TimelineTestSuite) TestGetHomeTimelineNoFollowing() { | ||||||
| 	follows, err := suite.state.DB.GetAccountFollows( | 	follows, err := suite.state.DB.GetAccountFollows( | ||||||
| 		gtscontext.SetBarebones(ctx), | 		gtscontext.SetBarebones(ctx), | ||||||
| 		viewingAccount.ID, | 		viewingAccount.ID, | ||||||
|  | 		nil, // select all | ||||||
| 	) | 	) | ||||||
| 
 |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		suite.FailNow(err.Error()) | 		suite.FailNow(err.Error()) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -20,7 +20,9 @@ package bundb | ||||||
| import ( | import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/cache" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/paging" | ||||||
| 	"github.com/uptrace/bun" | 	"github.com/uptrace/bun" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -83,6 +85,29 @@ func whereStartsLike( | ||||||
| 	) | 	) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // loadPagedIDs loads a page of IDs from given SliceCache by `key`, resorting to `loadDESC` if required. Uses `page` to sort + page resulting IDs. | ||||||
|  | // NOTE: IDs returned from `cache` / `loadDESC` MUST be in descending order, otherwise paging will not work correctly / return things out of order. | ||||||
|  | func loadPagedIDs(cache *cache.SliceCache[string], key string, page *paging.Page, loadDESC func() ([]string, error)) ([]string, error) { | ||||||
|  | 	// Check cache for IDs, else load. | ||||||
|  | 	ids, err := cache.Load(key, loadDESC) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Our cached / selected IDs are ALWAYS | ||||||
|  | 	// fetched from `loadDESC` in descending | ||||||
|  | 	// order. Depending on the paging requested | ||||||
|  | 	// this may be an unexpected order. | ||||||
|  | 	if page.GetOrder().Ascending() { | ||||||
|  | 		ids = paging.Reverse(ids) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Page the resulting IDs. | ||||||
|  | 	ids = page.Page(ids) | ||||||
|  | 
 | ||||||
|  | 	return ids, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // updateWhere parses []db.Where and adds it to the given update query. | // updateWhere parses []db.Where and adds it to the given update query. | ||||||
| func updateWhere(q *bun.UpdateQuery, where []db.Where) { | func updateWhere(q *bun.UpdateQuery, where []db.Where) { | ||||||
| 	for _, w := range where { | 	for _, w := range where { | ||||||
|  |  | ||||||
|  | @ -138,43 +138,46 @@ type Relationship interface { | ||||||
| 	RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) error | 	RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) error | ||||||
| 
 | 
 | ||||||
| 	// GetAccountFollows returns a slice of follows owned by the given accountID. | 	// GetAccountFollows returns a slice of follows owned by the given accountID. | ||||||
| 	GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) | 	GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) | ||||||
| 
 | 
 | ||||||
| 	// GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance. | 	// GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance. | ||||||
| 	GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) | 	GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) | ||||||
| 
 | 
 | ||||||
|  | 	// GetAccountFollowers fetches follows that target given accountID. | ||||||
|  | 	GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) | ||||||
|  | 
 | ||||||
|  | 	// GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance. | ||||||
|  | 	GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) | ||||||
|  | 
 | ||||||
|  | 	// GetAccountFollowRequests returns all follow requests targeting the given account. | ||||||
|  | 	GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) | ||||||
|  | 
 | ||||||
|  | 	// GetAccountFollowRequesting returns all follow requests originating from the given account. | ||||||
|  | 	GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) | ||||||
|  | 
 | ||||||
|  | 	// GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters. | ||||||
|  | 	GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Page) ([]*gtsmodel.Block, error) | ||||||
|  | 
 | ||||||
| 	// CountAccountFollows returns the amount of accounts that the given accountID is following. | 	// CountAccountFollows returns the amount of accounts that the given accountID is following. | ||||||
| 	CountAccountFollows(ctx context.Context, accountID string) (int, error) | 	CountAccountFollows(ctx context.Context, accountID string) (int, error) | ||||||
| 
 | 
 | ||||||
| 	// CountAccountLocalFollows returns the amount of accounts that the given accountID is following, only including follows from this instance. | 	// CountAccountLocalFollows returns the amount of accounts that the given accountID is following, only including follows from this instance. | ||||||
| 	CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) | 	CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) | ||||||
| 
 | 
 | ||||||
| 	// GetAccountFollowers fetches follows that target given accountID. |  | ||||||
| 	GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) |  | ||||||
| 
 |  | ||||||
| 	// GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance. |  | ||||||
| 	GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) |  | ||||||
| 
 |  | ||||||
| 	// CountAccountFollowers returns the amounts that the given ID is followed by. | 	// CountAccountFollowers returns the amounts that the given ID is followed by. | ||||||
| 	CountAccountFollowers(ctx context.Context, accountID string) (int, error) | 	CountAccountFollowers(ctx context.Context, accountID string) (int, error) | ||||||
| 
 | 
 | ||||||
| 	// CountAccountLocalFollowers returns the amounts that the given ID is followed by, only including follows from this instance. | 	// CountAccountLocalFollowers returns the amounts that the given ID is followed by, only including follows from this instance. | ||||||
| 	CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) | 	CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) | ||||||
| 
 | 
 | ||||||
| 	// GetAccountFollowRequests returns all follow requests targeting the given account. |  | ||||||
| 	GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) |  | ||||||
| 
 |  | ||||||
| 	// GetAccountFollowRequesting returns all follow requests originating from the given account. |  | ||||||
| 	GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) |  | ||||||
| 
 |  | ||||||
| 	// CountAccountFollowRequests returns number of follow requests targeting the given account. | 	// CountAccountFollowRequests returns number of follow requests targeting the given account. | ||||||
| 	CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) | 	CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) | ||||||
| 
 | 
 | ||||||
| 	// CountAccountFollowerRequests returns number of follow requests originating from the given account. | 	// CountAccountFollowerRequests returns number of follow requests originating from the given account. | ||||||
| 	CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) | 	CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) | ||||||
| 
 | 
 | ||||||
| 	// GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters. | 	// CountAccountBlocks ... | ||||||
| 	GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Page) ([]*gtsmodel.Block, error) | 	CountAccountBlocks(ctx context.Context, accountID string) (int, error) | ||||||
| 
 | 
 | ||||||
| 	// GetNote gets a private note from a source account on a target account, if it exists. | 	// GetNote gets a private note from a source account on a target account, if it exists. | ||||||
| 	GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error) | 	GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error) | ||||||
|  |  | ||||||
|  | @ -38,7 +38,7 @@ func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (follow | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	follows, err := f.state.DB.GetAccountFollowers(ctx, acct.ID) | 	follows, err := f.state.DB.GetAccountFollowers(ctx, acct.ID, nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("Followers: db error getting followers for account id %s: %s", acct.ID, err) | 		return nil, fmt.Errorf("Followers: db error getting followers for account id %s: %s", acct.ID, err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -38,7 +38,7 @@ func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (follow | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	follows, err := f.state.DB.GetAccountFollows(ctx, acct.ID) | 	follows, err := f.state.DB.GetAccountFollows(ctx, acct.ID, nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("Following: db error getting following for account id %s: %w", acct.ID, err) | 		return nil, fmt.Errorf("Following: db error getting following for account id %s: %w", acct.ID, err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -47,8 +47,8 @@ func (suite *FollowingTestSuite) TestGetFollowing() { | ||||||
| 	suite.Equal(`{ | 	suite.Equal(`{ | ||||||
|   "@context": "https://www.w3.org/ns/activitystreams", |   "@context": "https://www.w3.org/ns/activitystreams", | ||||||
|   "items": [ |   "items": [ | ||||||
|     "http://localhost:8080/users/admin", |     "http://localhost:8080/users/1happyturtle", | ||||||
|     "http://localhost:8080/users/1happyturtle" |     "http://localhost:8080/users/admin" | ||||||
|   ], |   ], | ||||||
|   "type": "Collection" |   "type": "Collection" | ||||||
| }`, string(fJson)) | }`, string(fJson)) | ||||||
|  |  | ||||||
|  | @ -89,7 +89,7 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs | ||||||
| 			return nil, fmt.Errorf("couldn't find local account with username %s: %s", localAccountUsername, err) | 			return nil, fmt.Errorf("couldn't find local account with username %s: %s", localAccountUsername, err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		follows, err := f.state.DB.GetAccountFollowers(c, account.ID) | 		follows, err := f.state.DB.GetAccountFollowers(c, account.ID, nil) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, fmt.Errorf("couldn't get followers of local account %s: %s", localAccountUsername, err) | 			return nil, fmt.Errorf("couldn't get followers of local account %s: %s", localAccountUsername, err) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | @ -17,10 +17,10 @@ | ||||||
| 
 | 
 | ||||||
| package paging | package paging | ||||||
| 
 | 
 | ||||||
| // MinID returns an ID boundary with given min ID value, | // EitherMinID returns an ID boundary with given min ID value, | ||||||
| // using either the `since_id`,"DESC" name,ordering or | // using either the `since_id`,"DESC" name,ordering or | ||||||
| // `min_id`,"ASC" name,ordering depending on which is set. | // `min_id`,"ASC" name,ordering depending on which is set. | ||||||
| func MinID(minID, sinceID string) Boundary { | func EitherMinID(minID, sinceID string) Boundary { | ||||||
| 	/* | 	/* | ||||||
| 
 | 
 | ||||||
| 	           Paging with `since_id` vs `min_id`: | 	           Paging with `since_id` vs `min_id`: | ||||||
|  | @ -47,18 +47,28 @@ func MinID(minID, sinceID string) Boundary { | ||||||
| 	*/ | 	*/ | ||||||
| 	switch { | 	switch { | ||||||
| 	case minID != "": | 	case minID != "": | ||||||
| 		return Boundary{ | 		return MinID(minID) | ||||||
| 			Name:  "min_id", |  | ||||||
| 			Value: minID, |  | ||||||
| 			Order: OrderAscending, |  | ||||||
| 		} |  | ||||||
| 	default: | 	default: | ||||||
| 		// default min is `since_id` | 		// default min is `since_id` | ||||||
|  | 		return SinceID(sinceID) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SinceID ... | ||||||
|  | func SinceID(sinceID string) Boundary { | ||||||
| 	return Boundary{ | 	return Boundary{ | ||||||
| 		Name:  "since_id", | 		Name:  "since_id", | ||||||
| 		Value: sinceID, | 		Value: sinceID, | ||||||
| 		Order: OrderDescending, | 		Order: OrderDescending, | ||||||
| 	} | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // MinID ... | ||||||
|  | func MinID(minID string) Boundary { | ||||||
|  | 	return Boundary{ | ||||||
|  | 		Name:  "min_id", | ||||||
|  | 		Value: minID, | ||||||
|  | 		Order: OrderAscending, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -111,7 +121,7 @@ func (b Boundary) new(value string) Boundary { | ||||||
| 
 | 
 | ||||||
| // Find finds the boundary's set value in input slice, or returns -1. | // Find finds the boundary's set value in input slice, or returns -1. | ||||||
| func (b Boundary) Find(in []string) int { | func (b Boundary) Find(in []string) int { | ||||||
| 	if zero(b.Value) { | 	if b.Value == "" { | ||||||
| 		return -1 | 		return -1 | ||||||
| 	} | 	} | ||||||
| 	for i := range in { | 	for i := range in { | ||||||
|  | @ -121,15 +131,3 @@ func (b Boundary) Find(in []string) int { | ||||||
| 	} | 	} | ||||||
| 	return -1 | 	return -1 | ||||||
| } | } | ||||||
| 
 |  | ||||||
| // Query returns this boundary as assembled query key=value pair. |  | ||||||
| func (b Boundary) Query() string { |  | ||||||
| 	switch { |  | ||||||
| 	case zero(b.Value): |  | ||||||
| 		return "" |  | ||||||
| 	case b.Name == "": |  | ||||||
| 		panic("value without boundary name") |  | ||||||
| 	default: |  | ||||||
| 		return b.Name + "=" + b.Value |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -20,7 +20,6 @@ package paging | ||||||
| import ( | import ( | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" |  | ||||||
| 
 | 
 | ||||||
| 	"golang.org/x/exp/slices" | 	"golang.org/x/exp/slices" | ||||||
| ) | ) | ||||||
|  | @ -70,26 +69,10 @@ func (p *Page) GetOrder() Order { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *Page) order() Order { | func (p *Page) order() Order { | ||||||
| 	var ( |  | ||||||
| 		// Check if min/max values set. |  | ||||||
| 		minValue = zero(p.Min.Value) |  | ||||||
| 		maxValue = zero(p.Max.Value) |  | ||||||
| 
 |  | ||||||
| 		// Check if min/max orders set. |  | ||||||
| 		minOrder = (p.Min.Order != 0) |  | ||||||
| 		maxOrder = (p.Max.Order != 0) |  | ||||||
| 	) |  | ||||||
| 
 |  | ||||||
| 	switch { | 	switch { | ||||||
| 	// Boundaries with a value AND order set | 	case p.Min.Order != 0: | ||||||
| 	// take priority. Min always comes first. |  | ||||||
| 	case minValue && minOrder: |  | ||||||
| 		return p.Min.Order | 		return p.Min.Order | ||||||
| 	case maxValue && maxOrder: | 	case p.Max.Order != 0: | ||||||
| 		return p.Max.Order |  | ||||||
| 	case minOrder: |  | ||||||
| 		return p.Min.Order |  | ||||||
| 	case maxOrder: |  | ||||||
| 		return p.Max.Order | 		return p.Max.Order | ||||||
| 	default: | 	default: | ||||||
| 		return 0 | 		return 0 | ||||||
|  | @ -108,31 +91,9 @@ func (p *Page) Page(in []string) []string { | ||||||
| 		return in | 		return in | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if o := p.order(); !o.Ascending() { | 	if p.order().Ascending() { | ||||||
| 		// Default sort is descending, |  | ||||||
| 		// catching all cases when NOT |  | ||||||
| 		// ascending (even zero value). |  | ||||||
| 		// |  | ||||||
| 		// NOTE: sorted data does not always |  | ||||||
| 		// occur according to string ineqs |  | ||||||
| 		// so we unfortunately cannot check. |  | ||||||
| 
 |  | ||||||
| 		if maxIdx := p.Max.Find(in); maxIdx != -1 { |  | ||||||
| 			// Reslice skipping up to max. |  | ||||||
| 			in = in[maxIdx+1:] |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if minIdx := p.Min.Find(in); minIdx != -1 { |  | ||||||
| 			// Reslice stripping past min. |  | ||||||
| 			in = in[:minIdx] |  | ||||||
| 		} |  | ||||||
| 	} else { |  | ||||||
| 		// Sort type is ascending, input | 		// Sort type is ascending, input | ||||||
| 		// data is assumed to be ascending. | 		// data is assumed to be ascending. | ||||||
| 		// |  | ||||||
| 		// NOTE: sorted data does not always |  | ||||||
| 		// occur according to string ineqs |  | ||||||
| 		// so we unfortunately cannot check. |  | ||||||
| 
 | 
 | ||||||
| 		if minIdx := p.Min.Find(in); minIdx != -1 { | 		if minIdx := p.Min.Find(in); minIdx != -1 { | ||||||
| 			// Reslice skipping up to min. | 			// Reslice skipping up to min. | ||||||
|  | @ -144,6 +105,11 @@ func (p *Page) Page(in []string) []string { | ||||||
| 			in = in[:maxIdx] | 			in = in[:maxIdx] | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		if p.Limit > 0 && p.Limit < len(in) { | ||||||
|  | 			// Reslice input to limit. | ||||||
|  | 			in = in[:p.Limit] | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		if len(in) > 1 { | 		if len(in) > 1 { | ||||||
| 			// Clone input before | 			// Clone input before | ||||||
| 			// any modifications. | 			// any modifications. | ||||||
|  | @ -153,20 +119,34 @@ func (p *Page) Page(in []string) []string { | ||||||
| 			// ALWAYS be descending. | 			// ALWAYS be descending. | ||||||
| 			in = Reverse(in) | 			in = Reverse(in) | ||||||
| 		} | 		} | ||||||
|  | 	} else { | ||||||
|  | 		// Default sort is descending, | ||||||
|  | 		// catching all cases when NOT | ||||||
|  | 		// ascending (even zero value). | ||||||
|  | 
 | ||||||
|  | 		if maxIdx := p.Max.Find(in); maxIdx != -1 { | ||||||
|  | 			// Reslice skipping up to max. | ||||||
|  | 			in = in[maxIdx+1:] | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if minIdx := p.Min.Find(in); minIdx != -1 { | ||||||
|  | 			// Reslice stripping past min. | ||||||
|  | 			in = in[:minIdx] | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if p.Limit > 0 && p.Limit < len(in) { | 		if p.Limit > 0 && p.Limit < len(in) { | ||||||
| 			// Reslice input to limit. | 			// Reslice input to limit. | ||||||
| 			in = in[:p.Limit] | 			in = in[:p.Limit] | ||||||
| 		} | 		} | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	return in | 	return in | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Next creates a new instance for the next returnable page, using | // Next creates a new instance for the next returnable page, using | ||||||
| // given max value. This preserves original limit and max key name. | // given max value. This preserves original limit and max key name. | ||||||
| func (p *Page) Next(max string) *Page { | func (p *Page) Next(lo, hi string) *Page { | ||||||
| 	if p == nil || max == "" { | 	if p == nil || lo == "" || hi == "" { | ||||||
| 		// no paging. | 		// no paging. | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  | @ -177,16 +157,27 @@ func (p *Page) Next(max string) *Page { | ||||||
| 	// Set original limit. | 	// Set original limit. | ||||||
| 	p2.Limit = p.Limit | 	p2.Limit = p.Limit | ||||||
| 
 | 
 | ||||||
| 	// Create new from old. | 	if p.order().Ascending() { | ||||||
| 	p2.Max = p.Max.new(max) | 		// When ascending, next page | ||||||
|  | 		// needs to start with min at | ||||||
|  | 		// the next highest value. | ||||||
|  | 		p2.Min = p.Min.new(hi) | ||||||
|  | 		p2.Max = p.Max.new("") | ||||||
|  | 	} else { | ||||||
|  | 		// When descending, next page | ||||||
|  | 		// needs to start with max at | ||||||
|  | 		// the next lowest value. | ||||||
|  | 		p2.Min = p.Min.new("") | ||||||
|  | 		p2.Max = p.Max.new(lo) | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	return p2 | 	return p2 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Prev creates a new instance for the prev returnable page, using | // Prev creates a new instance for the prev returnable page, using | ||||||
| // given min value. This preserves original limit and min key name. | // given min value. This preserves original limit and min key name. | ||||||
| func (p *Page) Prev(min string) *Page { | func (p *Page) Prev(lo, hi string) *Page { | ||||||
| 	if p == nil || min == "" { | 	if p == nil || lo == "" || hi == "" { | ||||||
| 		// no paging. | 		// no paging. | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  | @ -197,55 +188,56 @@ func (p *Page) Prev(min string) *Page { | ||||||
| 	// Set original limit. | 	// Set original limit. | ||||||
| 	p2.Limit = p.Limit | 	p2.Limit = p.Limit | ||||||
| 
 | 
 | ||||||
| 	// Create new from old. | 	if p.order().Ascending() { | ||||||
| 	p2.Min = p.Min.new(min) | 		// When ascending, prev page | ||||||
|  | 		// needs to start with max at | ||||||
|  | 		// the next lowest value. | ||||||
|  | 		p2.Min = p.Min.new("") | ||||||
|  | 		p2.Max = p.Max.new(lo) | ||||||
|  | 	} else { | ||||||
|  | 		// When descending, next page | ||||||
|  | 		// needs to start with max at | ||||||
|  | 		// the next lowest value. | ||||||
|  | 		p2.Min = p.Min.new(hi) | ||||||
|  | 		p2.Max = p.Max.new("") | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	return p2 | 	return p2 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ToLink builds a URL link for given endpoint information and extra query parameters, | // ToLink builds a URL link for given endpoint information and extra query parameters, | ||||||
| // appending this Page's minimum / maximum boundaries and available limit (if any). | // appending this Page's minimum / maximum boundaries and available limit (if any). | ||||||
| func (p *Page) ToLink(proto, host, path string, queryParams []string) string { | func (p *Page) ToLink(proto, host, path string, queryParams url.Values) string { | ||||||
| 	if p == nil { | 	if p == nil { | ||||||
| 		// no paging. | 		// no paging. | ||||||
| 		return "" | 		return "" | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Check length before | 	if queryParams == nil { | ||||||
| 	// adding boundary params. | 		// Allocate new query parameters. | ||||||
| 	old := len(queryParams) | 		queryParams = make(url.Values) | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	if minParam := p.Min.Query(); minParam != "" { | 	if p.Min.Value != "" { | ||||||
| 		// A page-minimum query parameter is available. | 		// A page-minimum query parameter is available. | ||||||
| 		queryParams = append(queryParams, minParam) | 		queryParams.Add(p.Min.Name, p.Min.Value) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if maxParam := p.Max.Query(); maxParam != "" { | 	if p.Max.Value != "" { | ||||||
| 		// A page-maximum query parameter is available. | 		// A page-maximum query parameter is available. | ||||||
| 		queryParams = append(queryParams, maxParam) | 		queryParams.Add(p.Max.Name, p.Max.Value) | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if len(queryParams) == old { |  | ||||||
| 		// No page boundaries. |  | ||||||
| 		return "" |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if p.Limit > 0 { | 	if p.Limit > 0 { | ||||||
| 		// Build limit key-value query parameter. | 		// A page limit query parameter is available. | ||||||
| 		param := "limit=" + strconv.Itoa(p.Limit) | 		queryParams.Add("limit", strconv.Itoa(p.Limit)) | ||||||
| 
 |  | ||||||
| 		// Append `limit=$value` query parameter. |  | ||||||
| 		queryParams = append(queryParams, param) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Join collected params into query str. |  | ||||||
| 	query := strings.Join(queryParams, "&") |  | ||||||
| 
 |  | ||||||
| 	// Build URL string. | 	// Build URL string. | ||||||
| 	return (&url.URL{ | 	return (&url.URL{ | ||||||
| 		Scheme:   proto, | 		Scheme:   proto, | ||||||
| 		Host:     host, | 		Host:     host, | ||||||
| 		Path:     path, | 		Path:     path, | ||||||
| 		RawQuery: query, | 		RawQuery: queryParams.Encode(), | ||||||
| 	}).String() | 	}).String() | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -97,7 +97,7 @@ var cases = []Case{ | ||||||
| 
 | 
 | ||||||
| 		// Return page and expected IDs. | 		// Return page and expected IDs. | ||||||
| 		return ids, &paging.Page{ | 		return ids, &paging.Page{ | ||||||
| 			Min: paging.MinID(minID, ""), | 			Min: paging.MinID(minID), | ||||||
| 			Max: paging.MaxID(maxID), | 			Max: paging.MaxID(maxID), | ||||||
| 		}, expect | 		}, expect | ||||||
| 	}), | 	}), | ||||||
|  | @ -129,7 +129,7 @@ var cases = []Case{ | ||||||
| 
 | 
 | ||||||
| 		// Return page and expected IDs. | 		// Return page and expected IDs. | ||||||
| 		return ids, &paging.Page{ | 		return ids, &paging.Page{ | ||||||
| 			Min:   paging.MinID(minID, ""), | 			Min:   paging.MinID(minID), | ||||||
| 			Max:   paging.MaxID(maxID), | 			Max:   paging.MaxID(maxID), | ||||||
| 			Limit: limit, | 			Limit: limit, | ||||||
| 		}, expect | 		}, expect | ||||||
|  | @ -156,7 +156,7 @@ var cases = []Case{ | ||||||
| 
 | 
 | ||||||
| 		// Return page and expected IDs. | 		// Return page and expected IDs. | ||||||
| 		return ids, &paging.Page{ | 		return ids, &paging.Page{ | ||||||
| 			Min:   paging.MinID(minID, ""), | 			Min:   paging.MinID(minID), | ||||||
| 			Max:   paging.MaxID(maxID), | 			Max:   paging.MaxID(maxID), | ||||||
| 			Limit: len(ids) * 2, | 			Limit: len(ids) * 2, | ||||||
| 		}, expect | 		}, expect | ||||||
|  | @ -182,7 +182,7 @@ var cases = []Case{ | ||||||
| 
 | 
 | ||||||
| 		// Return page and expected IDs. | 		// Return page and expected IDs. | ||||||
| 		return ids, &paging.Page{ | 		return ids, &paging.Page{ | ||||||
| 			Min: paging.MinID("", sinceID), | 			Min: paging.SinceID(sinceID), | ||||||
| 			Max: paging.MaxID(maxID), | 			Max: paging.MaxID(maxID), | ||||||
| 		}, expect | 		}, expect | ||||||
| 	}), | 	}), | ||||||
|  | @ -225,7 +225,7 @@ var cases = []Case{ | ||||||
| 
 | 
 | ||||||
| 		// Return page and expected IDs. | 		// Return page and expected IDs. | ||||||
| 		return ids, &paging.Page{ | 		return ids, &paging.Page{ | ||||||
| 			Min: paging.MinID("", sinceID), | 			Min: paging.SinceID(sinceID), | ||||||
| 		}, expect | 		}, expect | ||||||
| 	}), | 	}), | ||||||
| 	CreateCase("minID set", func(ids []string) ([]string, *paging.Page, []string) { | 	CreateCase("minID set", func(ids []string) ([]string, *paging.Page, []string) { | ||||||
|  | @ -247,7 +247,7 @@ var cases = []Case{ | ||||||
| 
 | 
 | ||||||
| 		// Return page and expected IDs. | 		// Return page and expected IDs. | ||||||
| 		return ids, &paging.Page{ | 		return ids, &paging.Page{ | ||||||
| 			Min: paging.MinID(minID, ""), | 			Min: paging.MinID(minID), | ||||||
| 		}, expect | 		}, expect | ||||||
| 	}), | 	}), | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -30,9 +30,9 @@ import ( | ||||||
| // While conversely, a zero default limit will not enforce paging, returning a nil page value. | // While conversely, a zero default limit will not enforce paging, returning a nil page value. | ||||||
| func ParseIDPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCode) { | func ParseIDPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCode) { | ||||||
| 	// Extract request query params. | 	// Extract request query params. | ||||||
| 	sinceID := c.Query("since_id") | 	sinceID, haveSince := c.GetQuery("since_id") | ||||||
| 	minID := c.Query("min_id") | 	minID, haveMin := c.GetQuery("min_id") | ||||||
| 	maxID := c.Query("max_id") | 	maxID, haveMax := c.GetQuery("max_id") | ||||||
| 
 | 
 | ||||||
| 	// Extract request limit parameter. | 	// Extract request limit parameter. | ||||||
| 	limit, errWithCode := ParseLimit(c, min, max, _default) | 	limit, errWithCode := ParseLimit(c, min, max, _default) | ||||||
|  | @ -40,20 +40,38 @@ func ParseIDPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCo | ||||||
| 		return nil, errWithCode | 		return nil, errWithCode | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if sinceID == "" && | 	switch { | ||||||
| 		minID == "" && | 	case haveMin: | ||||||
| 		maxID == "" && | 		// A min_id was supplied, even if the value | ||||||
| 		limit == 0 { | 		// itself is empty. This indicates ASC order. | ||||||
| 		// No ID paging params provided, and no default |  | ||||||
| 		// limit value which indicates paging not enforced. |  | ||||||
| 		return nil, nil |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 		return &Page{ | 		return &Page{ | ||||||
| 		Min:   MinID(minID, sinceID), | 			Min:   MinID(minID), | ||||||
| 			Max:   MaxID(maxID), | 			Max:   MaxID(maxID), | ||||||
| 			Limit: limit, | 			Limit: limit, | ||||||
| 		}, nil | 		}, nil | ||||||
|  | 
 | ||||||
|  | 	case haveMax || haveSince: | ||||||
|  | 		// A max_id or since_id was supplied, even if the | ||||||
|  | 		// value itself is empty. This indicates DESC order. | ||||||
|  | 		return &Page{ | ||||||
|  | 			Min:   SinceID(sinceID), | ||||||
|  | 			Max:   MaxID(maxID), | ||||||
|  | 			Limit: limit, | ||||||
|  | 		}, nil | ||||||
|  | 
 | ||||||
|  | 	case limit == 0: | ||||||
|  | 		// No ID paging params provided, and no default | ||||||
|  | 		// limit value which indicates paging not enforced. | ||||||
|  | 		return nil, nil | ||||||
|  | 
 | ||||||
|  | 	default: | ||||||
|  | 		// only limit. | ||||||
|  | 		return &Page{ | ||||||
|  | 			Min:   SinceID(""), | ||||||
|  | 			Max:   MaxID(""), | ||||||
|  | 			Limit: limit, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ParseShortcodeDomainPage parses an emoji shortcode domain Page from a request context, returning BadRequest | // ParseShortcodeDomainPage parses an emoji shortcode domain Page from a request context, returning BadRequest | ||||||
|  | @ -62,8 +80,8 @@ func ParseIDPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCo | ||||||
| // a zero default limit will not enforce paging, returning a nil page value. | // a zero default limit will not enforce paging, returning a nil page value. | ||||||
| func ParseShortcodeDomainPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCode) { | func ParseShortcodeDomainPage(c *gin.Context, min, max, _default int) (*Page, gtserror.WithCode) { | ||||||
| 	// Extract request query parameters. | 	// Extract request query parameters. | ||||||
| 	minShortcode := c.Query("min_shortcode_domain") | 	minShortcode, haveMin := c.GetQuery("min_shortcode_domain") | ||||||
| 	maxShortcode := c.Query("max_shortcode_domain") | 	maxShortcode, haveMax := c.GetQuery("max_shortcode_domain") | ||||||
| 
 | 
 | ||||||
| 	// Extract request limit parameter. | 	// Extract request limit parameter. | ||||||
| 	limit, errWithCode := ParseLimit(c, min, max, _default) | 	limit, errWithCode := ParseLimit(c, min, max, _default) | ||||||
|  | @ -71,8 +89,8 @@ func ParseShortcodeDomainPage(c *gin.Context, min, max, _default int) (*Page, gt | ||||||
| 		return nil, errWithCode | 		return nil, errWithCode | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if minShortcode == "" && | 	if !haveMin && | ||||||
| 		maxShortcode == "" && | 		!haveMax && | ||||||
| 		limit == 0 { | 		limit == 0 { | ||||||
| 		// No ID paging params provided, and no default | 		// No ID paging params provided, and no default | ||||||
| 		// limit value which indicates paging not enforced. | 		// limit value which indicates paging not enforced. | ||||||
|  | @ -89,7 +107,10 @@ func ParseShortcodeDomainPage(c *gin.Context, min, max, _default int) (*Page, gt | ||||||
| // ParseLimit parses the limit query parameter from a request context, returning BadRequest on error parsing and _default if zero limit given. | // ParseLimit parses the limit query parameter from a request context, returning BadRequest on error parsing and _default if zero limit given. | ||||||
| func ParseLimit(c *gin.Context, min, max, _default int) (int, gtserror.WithCode) { | func ParseLimit(c *gin.Context, min, max, _default int) (int, gtserror.WithCode) { | ||||||
| 	// Get limit query param. | 	// Get limit query param. | ||||||
| 	str := c.Query("limit") | 	str, ok := c.GetQuery("limit") | ||||||
|  | 	if !ok { | ||||||
|  | 		return _default, nil | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	// Attempt to parse limit int. | 	// Attempt to parse limit int. | ||||||
| 	i, err := strconv.Atoi(str) | 	i, err := strconv.Atoi(str) | ||||||
|  |  | ||||||
|  | @ -18,6 +18,7 @@ | ||||||
| package paging | package paging | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"net/url" | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
|  | @ -35,18 +36,13 @@ type ResponseParams struct { | ||||||
| 	Path  string        // path to use for next/prev queries in the link header | 	Path  string        // path to use for next/prev queries in the link header | ||||||
| 	Next  *Page         // page details for the next page | 	Next  *Page         // page details for the next page | ||||||
| 	Prev  *Page         // page details for the previous page | 	Prev  *Page         // page details for the previous page | ||||||
| 	Query []string      // any extra query parameters to provide in the link header, should be in the format 'example=value' | 	Query url.Values    // any extra query parameters to provide in the link header, should be in the format 'example=value' | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // PackageResponse is a convenience function for returning | // PackageResponse is a convenience function for returning | ||||||
| // a bunch of pageable items (notifications, statuses, etc), as well | // a bunch of pageable items (notifications, statuses, etc), as well | ||||||
| // as a Link header to inform callers of where to find next/prev items. | // as a Link header to inform callers of where to find next/prev items. | ||||||
| func PackageResponse(params ResponseParams) *apimodel.PageableResponse { | func PackageResponse(params ResponseParams) *apimodel.PageableResponse { | ||||||
| 	if len(params.Items) == 0 { |  | ||||||
| 		// No items to page through. |  | ||||||
| 		return EmptyResponse() |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	var ( | 	var ( | ||||||
| 		// Extract paging params. | 		// Extract paging params. | ||||||
| 		nextPg = params.Next | 		nextPg = params.Next | ||||||
|  |  | ||||||
|  | @ -42,9 +42,9 @@ func (suite *PagingSuite) TestPagingStandard() { | ||||||
| 	resp := paging.PackageResponse(params) | 	resp := paging.PackageResponse(params) | ||||||
| 
 | 
 | ||||||
| 	suite.Equal(make([]interface{}, 10, 10), resp.Items) | 	suite.Equal(make([]interface{}, 10, 10), resp.Items) | ||||||
| 	suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?max_id=01H11KA1DM2VH3747YDE7FV5HN&limit=10>; rel="next", <https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?min_id=01H11KBBVRRDYYC5KEPME1NP5R&limit=10>; rel="prev"`, resp.LinkHeader) | 	suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN>; rel="next", <https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R>; rel="prev"`, resp.LinkHeader) | ||||||
| 	suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?max_id=01H11KA1DM2VH3747YDE7FV5HN&limit=10`, resp.NextLink) | 	suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN`, resp.NextLink) | ||||||
| 	suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?min_id=01H11KBBVRRDYYC5KEPME1NP5R&limit=10`, resp.PrevLink) | 	suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R`, resp.PrevLink) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *PagingSuite) TestPagingNoLimit() { | func (suite *PagingSuite) TestPagingNoLimit() { | ||||||
|  | @ -77,9 +77,9 @@ func (suite *PagingSuite) TestPagingNoNextID() { | ||||||
| 	resp := paging.PackageResponse(params) | 	resp := paging.PackageResponse(params) | ||||||
| 
 | 
 | ||||||
| 	suite.Equal(make([]interface{}, 10, 10), resp.Items) | 	suite.Equal(make([]interface{}, 10, 10), resp.Items) | ||||||
| 	suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?min_id=01H11KBBVRRDYYC5KEPME1NP5R&limit=10>; rel="prev"`, resp.LinkHeader) | 	suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R>; rel="prev"`, resp.LinkHeader) | ||||||
| 	suite.Equal(``, resp.NextLink) | 	suite.Equal(``, resp.NextLink) | ||||||
| 	suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?min_id=01H11KBBVRRDYYC5KEPME1NP5R&limit=10`, resp.PrevLink) | 	suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R`, resp.PrevLink) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *PagingSuite) TestPagingNoPrevID() { | func (suite *PagingSuite) TestPagingNoPrevID() { | ||||||
|  | @ -94,27 +94,11 @@ func (suite *PagingSuite) TestPagingNoPrevID() { | ||||||
| 	resp := paging.PackageResponse(params) | 	resp := paging.PackageResponse(params) | ||||||
| 
 | 
 | ||||||
| 	suite.Equal(make([]interface{}, 10, 10), resp.Items) | 	suite.Equal(make([]interface{}, 10, 10), resp.Items) | ||||||
| 	suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?max_id=01H11KA1DM2VH3747YDE7FV5HN&limit=10>; rel="next"`, resp.LinkHeader) | 	suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN>; rel="next"`, resp.LinkHeader) | ||||||
| 	suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?max_id=01H11KA1DM2VH3747YDE7FV5HN&limit=10`, resp.NextLink) | 	suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN`, resp.NextLink) | ||||||
| 	suite.Equal(``, resp.PrevLink) | 	suite.Equal(``, resp.PrevLink) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *PagingSuite) TestPagingNoItems() { |  | ||||||
| 	config.SetHost("example.org") |  | ||||||
| 
 |  | ||||||
| 	params := paging.ResponseParams{ |  | ||||||
| 		Next: nextPage("01H11KA1DM2VH3747YDE7FV5HN", 10), |  | ||||||
| 		Prev: prevPage("01H11KBBVRRDYYC5KEPME1NP5R", 10), |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	resp := paging.PackageResponse(params) |  | ||||||
| 
 |  | ||||||
| 	suite.Empty(resp.Items) |  | ||||||
| 	suite.Empty(resp.LinkHeader) |  | ||||||
| 	suite.Empty(resp.NextLink) |  | ||||||
| 	suite.Empty(resp.PrevLink) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestPagingSuite(t *testing.T) { | func TestPagingSuite(t *testing.T) { | ||||||
| 	suite.Run(t, &PagingSuite{}) | 	suite.Run(t, &PagingSuite{}) | ||||||
| } | } | ||||||
|  | @ -128,7 +112,7 @@ func nextPage(id string, limit int) *paging.Page { | ||||||
| 
 | 
 | ||||||
| func prevPage(id string, limit int) *paging.Page { | func prevPage(id string, limit int) *paging.Page { | ||||||
| 	return &paging.Page{ | 	return &paging.Page{ | ||||||
| 		Min:   paging.MinID(id, ""), | 		Min:   paging.MinID(id), | ||||||
| 		Limit: limit, | 		Limit: limit, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -41,9 +41,3 @@ func Reverse(in []string) []string { | ||||||
| 
 | 
 | ||||||
| 	return in | 	return in | ||||||
| } | } | ||||||
| 
 |  | ||||||
| // zero is a shorthand to check a generic value is its zero value. |  | ||||||
| func zero[T comparable](t T) bool { |  | ||||||
| 	var z T |  | ||||||
| 	return t == z |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -22,6 +22,7 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/media" | 	"github.com/superseriousbusiness/gotosocial/internal/media" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/processing/common" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/state" | 	"github.com/superseriousbusiness/gotosocial/internal/state" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/text" | 	"github.com/superseriousbusiness/gotosocial/internal/text" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/typeutils" | 	"github.com/superseriousbusiness/gotosocial/internal/typeutils" | ||||||
|  | @ -32,6 +33,9 @@ import ( | ||||||
| // | // | ||||||
| // It also contains logic for actions towards accounts such as following, blocking, seeing follows, etc. | // It also contains logic for actions towards accounts such as following, blocking, seeing follows, etc. | ||||||
| type Processor struct { | type Processor struct { | ||||||
|  | 	// common processor logic | ||||||
|  | 	c *common.Processor | ||||||
|  | 
 | ||||||
| 	state        *state.State | 	state        *state.State | ||||||
| 	tc           typeutils.TypeConverter | 	tc           typeutils.TypeConverter | ||||||
| 	mediaManager *media.Manager | 	mediaManager *media.Manager | ||||||
|  | @ -44,6 +48,7 @@ type Processor struct { | ||||||
| 
 | 
 | ||||||
| // New returns a new account processor. | // New returns a new account processor. | ||||||
| func New( | func New( | ||||||
|  | 	common *common.Processor, | ||||||
| 	state *state.State, | 	state *state.State, | ||||||
| 	tc typeutils.TypeConverter, | 	tc typeutils.TypeConverter, | ||||||
| 	mediaManager *media.Manager, | 	mediaManager *media.Manager, | ||||||
|  | @ -53,6 +58,7 @@ func New( | ||||||
| 	parseMention gtsmodel.ParseMentionFunc, | 	parseMention gtsmodel.ParseMentionFunc, | ||||||
| ) Processor { | ) Processor { | ||||||
| 	return Processor{ | 	return Processor{ | ||||||
|  | 		c:            common, | ||||||
| 		state:        state, | 		state:        state, | ||||||
| 		tc:           tc, | 		tc:           tc, | ||||||
| 		mediaManager: mediaManager, | 		mediaManager: mediaManager, | ||||||
|  |  | ||||||
|  | @ -30,6 +30,7 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing" | 	"github.com/superseriousbusiness/gotosocial/internal/processing" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing/account" | 	"github.com/superseriousbusiness/gotosocial/internal/processing/account" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/processing/common" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/state" | 	"github.com/superseriousbusiness/gotosocial/internal/state" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/storage" | 	"github.com/superseriousbusiness/gotosocial/internal/storage" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/transport" | 	"github.com/superseriousbusiness/gotosocial/internal/transport" | ||||||
|  | @ -113,7 +114,8 @@ func (suite *AccountStandardTestSuite) SetupTest() { | ||||||
| 	suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails) | 	suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails) | ||||||
| 
 | 
 | ||||||
| 	filter := visibility.NewFilter(&suite.state) | 	filter := visibility.NewFilter(&suite.state) | ||||||
| 	suite.accountProcessor = account.New(&suite.state, suite.tc, suite.mediaManager, suite.oauthServer, suite.federator, filter, processing.GetParseMentionFunc(suite.db, suite.federator)) | 	common := common.New(&suite.state, suite.tc, suite.federator, filter) | ||||||
|  | 	suite.accountProcessor = account.New(&common, &suite.state, suite.tc, suite.mediaManager, suite.oauthServer, suite.federator, filter, processing.GetParseMentionFunc(suite.db, suite.federator)) | ||||||
| 	testrig.StandardDBSetup(suite.db, nil) | 	testrig.StandardDBSetup(suite.db, nil) | ||||||
| 	testrig.StandardStorageSetup(suite.storage, "../../../testrig/media") | 	testrig.StandardStorageSetup(suite.storage, "../../../testrig/media") | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -28,8 +28,11 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/id" | 	"github.com/superseriousbusiness/gotosocial/internal/id" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/messages" | 	"github.com/superseriousbusiness/gotosocial/internal/messages" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/paging" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/uris" | 	"github.com/superseriousbusiness/gotosocial/internal/uris" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/util" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // BlockCreate handles the creation of a block from requestingAccount to targetAccountID, either remote or local. | // BlockCreate handles the creation of a block from requestingAccount to targetAccountID, either remote or local. | ||||||
|  | @ -128,6 +131,53 @@ func (p *Processor) BlockRemove(ctx context.Context, requestingAccount *gtsmodel | ||||||
| 	return p.RelationshipGet(ctx, requestingAccount, targetAccountID) | 	return p.RelationshipGet(ctx, requestingAccount, targetAccountID) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // BlocksGet ... | ||||||
|  | func (p *Processor) BlocksGet( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	requestingAccount *gtsmodel.Account, | ||||||
|  | 	page *paging.Page, | ||||||
|  | ) (*apimodel.PageableResponse, gtserror.WithCode) { | ||||||
|  | 	blocks, err := p.state.DB.GetAccountBlocks(ctx, | ||||||
|  | 		requestingAccount.ID, | ||||||
|  | 		page, | ||||||
|  | 	) | ||||||
|  | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
|  | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Check for empty response. | ||||||
|  | 	count := len(blocks) | ||||||
|  | 	if len(blocks) == 0 { | ||||||
|  | 		return util.EmptyPageableResponse(), nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	items := make([]interface{}, 0, count) | ||||||
|  | 
 | ||||||
|  | 	for _, block := range blocks { | ||||||
|  | 		// Convert target account to frontend API model. (target will never be nil) | ||||||
|  | 		account, err := p.tc.AccountToAPIAccountBlocked(ctx, block.TargetAccount) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Errorf(ctx, "error converting account to public api account: %v", err) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Append target to return items. | ||||||
|  | 		items = append(items, account) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get the lowest and highest | ||||||
|  | 	// ID values, used for paging. | ||||||
|  | 	lo := blocks[count-1].ID | ||||||
|  | 	hi := blocks[0].ID | ||||||
|  | 
 | ||||||
|  | 	return paging.PackageResponse(paging.ResponseParams{ | ||||||
|  | 		Items: items, | ||||||
|  | 		Path:  "/api/v1/blocks", | ||||||
|  | 		Next:  page.Next(lo, hi), | ||||||
|  | 		Prev:  page.Prev(lo, hi), | ||||||
|  | 	}), nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (p *Processor) getBlockTarget(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*gtsmodel.Account, *gtsmodel.Block, gtserror.WithCode) { | func (p *Processor) getBlockTarget(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*gtsmodel.Account, *gtsmodel.Block, gtserror.WithCode) { | ||||||
| 	// Account should not block or unblock itself. | 	// Account should not block or unblock itself. | ||||||
| 	if requestingAccount.ID == targetAccountID { | 	if requestingAccount.ID == targetAccountID { | ||||||
|  |  | ||||||
|  | @ -160,7 +160,7 @@ func (p *Processor) deleteUserAndTokensForAccount(ctx context.Context, account * | ||||||
| //   - Follow requests created by account. | //   - Follow requests created by account. | ||||||
| func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.Account) error { | func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.Account) error { | ||||||
| 	// Delete follows targeting this account. | 	// Delete follows targeting this account. | ||||||
| 	followedBy, err := p.state.DB.GetAccountFollowers(ctx, account.ID) | 	followedBy, err := p.state.DB.GetAccountFollowers(ctx, account.ID, nil) | ||||||
| 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		return gtserror.Newf("db error getting follows targeting account %s: %w", account.ID, err) | 		return gtserror.Newf("db error getting follows targeting account %s: %w", account.ID, err) | ||||||
| 	} | 	} | ||||||
|  | @ -172,7 +172,7 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel. | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Delete follow requests targeting this account. | 	// Delete follow requests targeting this account. | ||||||
| 	followRequestedBy, err := p.state.DB.GetAccountFollowRequests(ctx, account.ID) | 	followRequestedBy, err := p.state.DB.GetAccountFollowRequests(ctx, account.ID, nil) | ||||||
| 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		return gtserror.Newf("db error getting follow requests targeting account %s: %w", account.ID, err) | 		return gtserror.Newf("db error getting follow requests targeting account %s: %w", account.ID, err) | ||||||
| 	} | 	} | ||||||
|  | @ -193,7 +193,7 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel. | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	// Delete follows originating from this account. | 	// Delete follows originating from this account. | ||||||
| 	following, err := p.state.DB.GetAccountFollows(ctx, account.ID) | 	following, err := p.state.DB.GetAccountFollows(ctx, account.ID, nil) | ||||||
| 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		return gtserror.Newf("db error getting follows owned by account %s: %w", account.ID, err) | 		return gtserror.Newf("db error getting follows owned by account %s: %w", account.ID, err) | ||||||
| 	} | 	} | ||||||
|  | @ -211,7 +211,7 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel. | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Delete follow requests originating from this account. | 	// Delete follow requests originating from this account. | ||||||
| 	followRequesting, err := p.state.DB.GetAccountFollowRequesting(ctx, account.ID) | 	followRequesting, err := p.state.DB.GetAccountFollowRequesting(ctx, account.ID, nil) | ||||||
| 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		return gtserror.Newf("db error getting follow requests owned by account %s: %w", account.ID, err) | 		return gtserror.Newf("db error getting follow requests owned by account %s: %w", account.ID, err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -20,7 +20,6 @@ package account | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" |  | ||||||
| 
 | 
 | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/ap" | 	"github.com/superseriousbusiness/gotosocial/internal/ap" | ||||||
| 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
|  | @ -35,7 +34,7 @@ import ( | ||||||
| 
 | 
 | ||||||
| // FollowCreate handles a follow request to an account, either remote or local. | // FollowCreate handles a follow request to an account, either remote or local. | ||||||
| func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) { | func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) { | ||||||
| 	targetAccount, errWithCode := p.getFollowTarget(ctx, requestingAccount.ID, form.ID) | 	targetAccount, errWithCode := p.getFollowTarget(ctx, requestingAccount, form.ID) | ||||||
| 	if errWithCode != nil { | 	if errWithCode != nil { | ||||||
| 		return nil, errWithCode | 		return nil, errWithCode | ||||||
| 	} | 	} | ||||||
|  | @ -46,7 +45,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode | ||||||
| 		requestingAccount.ID, | 		requestingAccount.ID, | ||||||
| 		targetAccount.ID, | 		targetAccount.ID, | ||||||
| 	); err != nil && !errors.Is(err, db.ErrNoEntries) { | 	); err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		err = fmt.Errorf("FollowCreate: db error checking existing follow: %w", err) | 		err = gtserror.Newf("db error checking existing follow: %w", err) | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
| 	} else if follow != nil { | 	} else if follow != nil { | ||||||
| 		// Already follows, update if necessary + return relationship. | 		// Already follows, update if necessary + return relationship. | ||||||
|  | @ -66,7 +65,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode | ||||||
| 		requestingAccount.ID, | 		requestingAccount.ID, | ||||||
| 		targetAccount.ID, | 		targetAccount.ID, | ||||||
| 	); err != nil && !errors.Is(err, db.ErrNoEntries) { | 	); err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		err = fmt.Errorf("FollowCreate: db error checking existing follow request: %w", err) | 		err = gtserror.Newf("db error checking existing follow request: %w", err) | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
| 	} else if followRequest != nil { | 	} else if followRequest != nil { | ||||||
| 		// Already requested, update if necessary + return relationship. | 		// Already requested, update if necessary + return relationship. | ||||||
|  | @ -100,7 +99,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.state.DB.PutFollowRequest(ctx, fr); err != nil { | 	if err := p.state.DB.PutFollowRequest(ctx, fr); err != nil { | ||||||
| 		err = fmt.Errorf("FollowCreate: error creating follow request in db: %s", err) | 		err = gtserror.Newf("error creating follow request in db: %s", err) | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -112,7 +111,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode | ||||||
| 		// Because we know the requestingAccount is also | 		// Because we know the requestingAccount is also | ||||||
| 		// local, we don't need to federate the accept out. | 		// local, we don't need to federate the accept out. | ||||||
| 		if _, err := p.state.DB.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil { | 		if _, err := p.state.DB.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil { | ||||||
| 			err = fmt.Errorf("FollowCreate: error accepting follow request for local unlocked account: %w", err) | 			err = gtserror.Newf("error accepting follow request for local unlocked account: %w", err) | ||||||
| 			return nil, gtserror.NewErrorInternalError(err) | 			return nil, gtserror.NewErrorInternalError(err) | ||||||
| 		} | 		} | ||||||
| 	} else if targetAccount.IsRemote() { | 	} else if targetAccount.IsRemote() { | ||||||
|  | @ -132,7 +131,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode | ||||||
| 
 | 
 | ||||||
| // FollowRemove handles the removal of a follow/follow request to an account, either remote or local. | // FollowRemove handles the removal of a follow/follow request to an account, either remote or local. | ||||||
| func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { | func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { | ||||||
| 	targetAccount, errWithCode := p.getFollowTarget(ctx, requestingAccount.ID, targetAccountID) | 	targetAccount, errWithCode := p.getFollowTarget(ctx, requestingAccount, targetAccountID) | ||||||
| 	if errWithCode != nil { | 	if errWithCode != nil { | ||||||
| 		return nil, errWithCode | 		return nil, errWithCode | ||||||
| 	} | 	} | ||||||
|  | @ -140,7 +139,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode | ||||||
| 	// Unfollow and deal with side effects. | 	// Unfollow and deal with side effects. | ||||||
| 	msgs, err := p.unfollow(ctx, requestingAccount, targetAccount) | 	msgs, err := p.unfollow(ctx, requestingAccount, targetAccount) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, gtserror.NewErrorNotFound(fmt.Errorf("FollowRemove: account %s not found in the db: %s", targetAccountID, err)) | 		return nil, gtserror.NewErrorNotFound(gtserror.Newf("account %s not found in the db: %s", targetAccountID, err)) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Batch queue accreted client api messages. | 	// Batch queue accreted client api messages. | ||||||
|  | @ -166,7 +165,6 @@ func (p *Processor) updateFollow( | ||||||
| 	currentNotify *bool, | 	currentNotify *bool, | ||||||
| 	update func(...string) error, | 	update func(...string) error, | ||||||
| ) (*apimodel.Relationship, gtserror.WithCode) { | ) (*apimodel.Relationship, gtserror.WithCode) { | ||||||
| 
 |  | ||||||
| 	if form.Reblogs == nil && form.Notify == nil { | 	if form.Reblogs == nil && form.Notify == nil { | ||||||
| 		// There's nothing to update. | 		// There's nothing to update. | ||||||
| 		return p.RelationshipGet(ctx, requestingAccount, form.ID) | 		return p.RelationshipGet(ctx, requestingAccount, form.ID) | ||||||
|  | @ -192,7 +190,7 @@ func (p *Processor) updateFollow( | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := update(columns...); err != nil { | 	if err := update(columns...); err != nil { | ||||||
| 		err = fmt.Errorf("updateFollow: error updating existing follow (request): %w", err) | 		err = gtserror.Newf("error updating existing follow (request): %w", err) | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -201,38 +199,23 @@ func (p *Processor) updateFollow( | ||||||
| 
 | 
 | ||||||
| // getFollowTarget is a convenience function which: | // getFollowTarget is a convenience function which: | ||||||
| //   - Checks if account is trying to follow/unfollow itself. | //   - Checks if account is trying to follow/unfollow itself. | ||||||
| //   - Returns not found if there's a block in place between accounts. | //   - Returns not found if target should not be visible to requester. | ||||||
| //   - Returns target account according to its id. | //   - Returns target account according to its id. | ||||||
| func (p *Processor) getFollowTarget(ctx context.Context, requestingAccountID string, targetAccountID string) (*gtsmodel.Account, gtserror.WithCode) { | func (p *Processor) getFollowTarget(ctx context.Context, requester *gtsmodel.Account, targetID string) (*gtsmodel.Account, gtserror.WithCode) { | ||||||
|  | 	// Check for requester. | ||||||
|  | 	if requester == nil { | ||||||
|  | 		err := errors.New("no authorized user") | ||||||
|  | 		return nil, gtserror.NewErrorUnauthorized(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Account can't follow or unfollow itself. | 	// Account can't follow or unfollow itself. | ||||||
| 	if requestingAccountID == targetAccountID { | 	if requester.ID == targetID { | ||||||
| 		err := errors.New("account can't follow or unfollow itself") | 		err := errors.New("account can't follow or unfollow itself") | ||||||
| 		return nil, gtserror.NewErrorNotAcceptable(err) | 		return nil, gtserror.NewErrorNotAcceptable(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Do nothing if a block exists in either direction between accounts. | 	// Fetch the target account for requesting user account. | ||||||
| 	if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, targetAccountID); err != nil { | 	return p.c.GetVisibleTargetAccount(ctx, requester, targetID) | ||||||
| 		err = fmt.Errorf("db error checking block between accounts: %w", err) |  | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) |  | ||||||
| 	} else if blocked { |  | ||||||
| 		err = errors.New("block exists between accounts") |  | ||||||
| 		return nil, gtserror.NewErrorNotFound(err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Ensure target account retrievable. |  | ||||||
| 	targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID) |  | ||||||
| 	if err != nil { |  | ||||||
| 		if !errors.Is(err, db.ErrNoEntries) { |  | ||||||
| 			// Real db error. |  | ||||||
| 			err = fmt.Errorf("db error looking for target account %s: %w", targetAccountID, err) |  | ||||||
| 			return nil, gtserror.NewErrorInternalError(err) |  | ||||||
| 		} |  | ||||||
| 		// Account not found. |  | ||||||
| 		err = fmt.Errorf("target account %s not found in the db", targetAccountID) |  | ||||||
| 		return nil, gtserror.NewErrorNotFound(err, err.Error()) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return targetAccount, nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // unfollow is a convenience function for having requesting account | // unfollow is a convenience function for having requesting account | ||||||
|  | @ -248,7 +231,7 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac | ||||||
| 	// Get follow from requesting account to target account. | 	// Get follow from requesting account to target account. | ||||||
| 	follow, err := p.state.DB.GetFollow(ctx, requestingAccount.ID, targetAccount.ID) | 	follow, err := p.state.DB.GetFollow(ctx, requestingAccount.ID, targetAccount.ID) | ||||||
| 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		err = fmt.Errorf("unfollow: error getting follow from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) | 		err = gtserror.Newf("error getting follow from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -257,7 +240,7 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac | ||||||
| 		err = p.state.DB.DeleteFollowByID(ctx, follow.ID) | 		err = p.state.DB.DeleteFollowByID(ctx, follow.ID) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			if !errors.Is(err, db.ErrNoEntries) { | 			if !errors.Is(err, db.ErrNoEntries) { | ||||||
| 				err = fmt.Errorf("unfollow: error deleting request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) | 				err = gtserror.Newf("error deleting request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) | ||||||
| 				return nil, err | 				return nil, err | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
|  | @ -284,7 +267,7 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac | ||||||
| 	// Get follow request from requesting account to target account. | 	// Get follow request from requesting account to target account. | ||||||
| 	followReq, err := p.state.DB.GetFollowRequest(ctx, requestingAccount.ID, targetAccount.ID) | 	followReq, err := p.state.DB.GetFollowRequest(ctx, requestingAccount.ID, targetAccount.ID) | ||||||
| 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		err = fmt.Errorf("unfollow: error getting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) | 		err = gtserror.Newf("error getting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -293,7 +276,7 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac | ||||||
| 		err = p.state.DB.DeleteFollowRequestByID(ctx, followReq.ID) | 		err = p.state.DB.DeleteFollowRequestByID(ctx, followReq.ID) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			if !errors.Is(err, db.ErrNoEntries) { | 			if !errors.Is(err, db.ErrNoEntries) { | ||||||
| 				err = fmt.Errorf("unfollow: error deleting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) | 				err = gtserror.Newf("error deleting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err) | ||||||
| 				return nil, err | 				return nil, err | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										119
									
								
								internal/processing/account/follow_request.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								internal/processing/account/follow_request.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,119 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package account | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 
 | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/ap" | ||||||
|  | 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/messages" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/paging" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // FollowRequestAccept handles the accepting of a follow request from the sourceAccountID to the requestingAccount (the currently authorized account). | ||||||
|  | func (p *Processor) FollowRequestAccept(ctx context.Context, requestingAccount *gtsmodel.Account, sourceAccountID string) (*apimodel.Relationship, gtserror.WithCode) { | ||||||
|  | 	follow, err := p.state.DB.AcceptFollowRequest(ctx, sourceAccountID, requestingAccount.ID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, gtserror.NewErrorNotFound(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if follow.Account != nil { | ||||||
|  | 		// Only enqueue work in the case we have a request creating account stored. | ||||||
|  | 		// NOTE: due to how AcceptFollowRequest works, the inverse shouldn't be possible. | ||||||
|  | 		p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ | ||||||
|  | 			APObjectType:   ap.ActivityFollow, | ||||||
|  | 			APActivityType: ap.ActivityAccept, | ||||||
|  | 			GTSModel:       follow, | ||||||
|  | 			OriginAccount:  follow.Account, | ||||||
|  | 			TargetAccount:  follow.TargetAccount, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return p.RelationshipGet(ctx, requestingAccount, sourceAccountID) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // FollowRequestReject handles the rejection of a follow request from the sourceAccountID to the requestingAccount (the currently authorized account). | ||||||
|  | func (p *Processor) FollowRequestReject(ctx context.Context, requestingAccount *gtsmodel.Account, sourceAccountID string) (*apimodel.Relationship, gtserror.WithCode) { | ||||||
|  | 	followRequest, err := p.state.DB.GetFollowRequest(ctx, sourceAccountID, requestingAccount.ID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, gtserror.NewErrorNotFound(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	err = p.state.DB.RejectFollowRequest(ctx, sourceAccountID, requestingAccount.ID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, gtserror.NewErrorNotFound(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if followRequest.Account != nil { | ||||||
|  | 		// Only enqueue work in the case we have a request creating account stored. | ||||||
|  | 		// NOTE: due to how GetFollowRequest works, the inverse shouldn't be possible. | ||||||
|  | 		p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ | ||||||
|  | 			APObjectType:   ap.ActivityFollow, | ||||||
|  | 			APActivityType: ap.ActivityReject, | ||||||
|  | 			GTSModel:       followRequest, | ||||||
|  | 			OriginAccount:  followRequest.Account, | ||||||
|  | 			TargetAccount:  followRequest.TargetAccount, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return p.RelationshipGet(ctx, requestingAccount, sourceAccountID) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // FollowRequestsGet fetches a list of the accounts that are follow requesting the given requestingAccount (the currently authorized account). | ||||||
|  | func (p *Processor) FollowRequestsGet(ctx context.Context, requestingAccount *gtsmodel.Account, page *paging.Page) (*apimodel.PageableResponse, gtserror.WithCode) { | ||||||
|  | 	// Fetch follow requests targeting the given requesting account model. | ||||||
|  | 	followRequests, err := p.state.DB.GetAccountFollowRequests(ctx, requestingAccount.ID, page) | ||||||
|  | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
|  | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Check for empty response. | ||||||
|  | 	count := len(followRequests) | ||||||
|  | 	if count == 0 { | ||||||
|  | 		return paging.EmptyResponse(), nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get the lowest and highest | ||||||
|  | 	// ID values, used for paging. | ||||||
|  | 	lo := followRequests[count-1].ID | ||||||
|  | 	hi := followRequests[0].ID | ||||||
|  | 
 | ||||||
|  | 	// Func to fetch follow source at index. | ||||||
|  | 	getIdx := func(i int) *gtsmodel.Account { | ||||||
|  | 		return followRequests[i].Account | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get a filtered slice of public API account models. | ||||||
|  | 	items := p.c.GetVisibleAPIAccountsPaged(ctx, | ||||||
|  | 		requestingAccount, | ||||||
|  | 		getIdx, | ||||||
|  | 		count, | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	return paging.PackageResponse(paging.ResponseParams{ | ||||||
|  | 		Items: items, | ||||||
|  | 		Path:  "/api/v1/follow_requests", | ||||||
|  | 		Next:  page.Next(lo, hi), | ||||||
|  | 		Prev:  page.Prev(lo, hi), | ||||||
|  | 	}), nil | ||||||
|  | } | ||||||
|  | @ -20,128 +20,120 @@ package account | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" |  | ||||||
| 
 | 
 | ||||||
| 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | 	"github.com/superseriousbusiness/gotosocial/internal/paging" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // FollowersGet fetches a list of the target account's followers. | // FollowersGet fetches a list of the target account's followers. | ||||||
| func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { | func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, page *paging.Page) (*apimodel.PageableResponse, gtserror.WithCode) { | ||||||
| 	if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccountID); err != nil { | 	// Fetch target account to check it exists, and visibility of requester->target. | ||||||
| 		err = fmt.Errorf("FollowersGet: db error checking block: %w", err) | 	_, errWithCode := p.c.GetVisibleTargetAccount(ctx, requestingAccount, targetAccountID) | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 	if errWithCode != nil { | ||||||
| 	} else if blocked { | 		return nil, errWithCode | ||||||
| 		err = errors.New("FollowersGet: block exists between accounts") |  | ||||||
| 		return nil, gtserror.NewErrorNotFound(err) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	follows, err := p.state.DB.GetAccountFollowers(ctx, targetAccountID) | 	follows, err := p.state.DB.GetAccountFollowers(ctx, targetAccountID, page) | ||||||
| 	if err != nil { | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		if !errors.Is(err, db.ErrNoEntries) { | 		err = gtserror.Newf("db error getting followers: %w", err) | ||||||
| 			err = fmt.Errorf("FollowersGet: db error getting followers: %w", err) |  | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
| 	} | 	} | ||||||
| 		return []apimodel.Account{}, nil | 
 | ||||||
|  | 	// Check for empty response. | ||||||
|  | 	count := len(follows) | ||||||
|  | 	if count == 0 { | ||||||
|  | 		return paging.EmptyResponse(), nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return p.accountsFromFollows(ctx, follows, requestingAccount.ID) | 	// Get the lowest and highest | ||||||
|  | 	// ID values, used for paging. | ||||||
|  | 	lo := follows[count-1].ID | ||||||
|  | 	hi := follows[0].ID | ||||||
|  | 
 | ||||||
|  | 	// Func to fetch follow source at index. | ||||||
|  | 	getIdx := func(i int) *gtsmodel.Account { | ||||||
|  | 		return follows[i].Account | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get a filtered slice of public API account models. | ||||||
|  | 	items := p.c.GetVisibleAPIAccountsPaged(ctx, | ||||||
|  | 		requestingAccount, | ||||||
|  | 		getIdx, | ||||||
|  | 		len(follows), | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	return paging.PackageResponse(paging.ResponseParams{ | ||||||
|  | 		Items: items, | ||||||
|  | 		Path:  "/api/v1/accounts/" + targetAccountID + "/followers", | ||||||
|  | 		Next:  page.Next(lo, hi), | ||||||
|  | 		Prev:  page.Prev(lo, hi), | ||||||
|  | 	}), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // FollowingGet fetches a list of the accounts that target account is following. | // FollowingGet fetches a list of the accounts that target account is following. | ||||||
| func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { | func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, page *paging.Page) (*apimodel.PageableResponse, gtserror.WithCode) { | ||||||
| 	if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccountID); err != nil { | 	// Fetch target account to check it exists, and visibility of requester->target. | ||||||
| 		err = fmt.Errorf("FollowingGet: db error checking block: %w", err) | 	_, errWithCode := p.c.GetVisibleTargetAccount(ctx, requestingAccount, targetAccountID) | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 	if errWithCode != nil { | ||||||
| 	} else if blocked { | 		return nil, errWithCode | ||||||
| 		err = errors.New("FollowingGet: block exists between accounts") |  | ||||||
| 		return nil, gtserror.NewErrorNotFound(err) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	follows, err := p.state.DB.GetAccountFollows(ctx, targetAccountID) | 	// Fetch known accounts that follow given target account ID. | ||||||
| 	if err != nil { | 	follows, err := p.state.DB.GetAccountFollows(ctx, targetAccountID, page) | ||||||
| 		if !errors.Is(err, db.ErrNoEntries) { | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 			err = fmt.Errorf("FollowingGet: db error getting followers: %w", err) | 		err = gtserror.Newf("db error getting followers: %w", err) | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
| 	} | 	} | ||||||
| 		return []apimodel.Account{}, nil | 
 | ||||||
|  | 	// Check for empty response. | ||||||
|  | 	count := len(follows) | ||||||
|  | 	if count == 0 { | ||||||
|  | 		return paging.EmptyResponse(), nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return p.targetAccountsFromFollows(ctx, follows, requestingAccount.ID) | 	// Get the lowest and highest | ||||||
|  | 	// ID values, used for paging. | ||||||
|  | 	lo := follows[count-1].ID | ||||||
|  | 	hi := follows[0].ID | ||||||
|  | 
 | ||||||
|  | 	// Func to fetch follow source at index. | ||||||
|  | 	getIdx := func(i int) *gtsmodel.Account { | ||||||
|  | 		return follows[i].TargetAccount | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Get a filtered slice of public API account models. | ||||||
|  | 	items := p.c.GetVisibleAPIAccountsPaged(ctx, | ||||||
|  | 		requestingAccount, | ||||||
|  | 		getIdx, | ||||||
|  | 		len(follows), | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	return paging.PackageResponse(paging.ResponseParams{ | ||||||
|  | 		Items: items, | ||||||
|  | 		Path:  "/api/v1/accounts/" + targetAccountID + "/following", | ||||||
|  | 		Next:  page.Next(lo, hi), | ||||||
|  | 		Prev:  page.Prev(lo, hi), | ||||||
|  | 	}), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RelationshipGet returns a relationship model describing the relationship of the targetAccount to the Authed account. | // RelationshipGet returns a relationship model describing the relationship of the targetAccount to the Authed account. | ||||||
| func (p *Processor) RelationshipGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { | func (p *Processor) RelationshipGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { | ||||||
| 	if requestingAccount == nil { | 	if requestingAccount == nil { | ||||||
| 		return nil, gtserror.NewErrorForbidden(errors.New("not authed")) | 		return nil, gtserror.NewErrorForbidden(gtserror.New("not authed")) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	gtsR, err := p.state.DB.GetRelationship(ctx, requestingAccount.ID, targetAccountID) | 	gtsR, err := p.state.DB.GetRelationship(ctx, requestingAccount.ID, targetAccountID) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting relationship: %s", err)) | 		return nil, gtserror.NewErrorInternalError(gtserror.Newf("error getting relationship: %s", err)) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	r, err := p.tc.RelationshipToAPIRelationship(ctx, gtsR) | 	r, err := p.tc.RelationshipToAPIRelationship(ctx, gtsR) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting relationship: %s", err)) | 		return nil, gtserror.NewErrorInternalError(gtserror.Newf("error converting relationship: %s", err)) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return r, nil | 	return r, nil | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func (p *Processor) accountsFromFollows(ctx context.Context, follows []*gtsmodel.Follow, requestingAccountID string) ([]apimodel.Account, gtserror.WithCode) { |  | ||||||
| 	accounts := make([]apimodel.Account, 0, len(follows)) |  | ||||||
| 	for _, follow := range follows { |  | ||||||
| 		if follow.Account == nil { |  | ||||||
| 			// No account set for some reason; just skip. |  | ||||||
| 			log.WithContext(ctx).WithField("follow", follow).Warn("follow had no associated account") |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, follow.AccountID); err != nil { |  | ||||||
| 			err = fmt.Errorf("accountsFromFollows: db error checking block: %w", err) |  | ||||||
| 			return nil, gtserror.NewErrorInternalError(err) |  | ||||||
| 		} else if blocked { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		account, err := p.tc.AccountToAPIAccountPublic(ctx, follow.Account) |  | ||||||
| 		if err != nil { |  | ||||||
| 			err = fmt.Errorf("accountsFromFollows: error converting account to api account: %w", err) |  | ||||||
| 			return nil, gtserror.NewErrorInternalError(err) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		accounts = append(accounts, *account) |  | ||||||
| 	} |  | ||||||
| 	return accounts, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *Processor) targetAccountsFromFollows(ctx context.Context, follows []*gtsmodel.Follow, requestingAccountID string) ([]apimodel.Account, gtserror.WithCode) { |  | ||||||
| 	accounts := make([]apimodel.Account, 0, len(follows)) |  | ||||||
| 	for _, follow := range follows { |  | ||||||
| 		if follow.TargetAccount == nil { |  | ||||||
| 			// No account set for some reason; just skip. |  | ||||||
| 			log.WithContext(ctx).WithField("follow", follow).Warn("follow had no associated target account") |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, follow.TargetAccountID); err != nil { |  | ||||||
| 			err = fmt.Errorf("targetAccountsFromFollows: db error checking block: %w", err) |  | ||||||
| 			return nil, gtserror.NewErrorInternalError(err) |  | ||||||
| 		} else if blocked { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		account, err := p.tc.AccountToAPIAccountPublic(ctx, follow.TargetAccount) |  | ||||||
| 		if err != nil { |  | ||||||
| 			err = fmt.Errorf("targetAccountsFromFollows: error converting account to api account: %w", err) |  | ||||||
| 			return nil, gtserror.NewErrorInternalError(err) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		accounts = append(accounts, *account) |  | ||||||
| 	} |  | ||||||
| 	return accounts, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -1,86 +0,0 @@ | ||||||
| // GoToSocial |  | ||||||
| // Copyright (C) GoToSocial Authors admin@gotosocial.org |  | ||||||
| // SPDX-License-Identifier: AGPL-3.0-or-later |  | ||||||
| // |  | ||||||
| // This program is free software: you can redistribute it and/or modify |  | ||||||
| // it under the terms of the GNU Affero General Public License as published by |  | ||||||
| // the Free Software Foundation, either version 3 of the License, or |  | ||||||
| // (at your option) any later version. |  | ||||||
| // |  | ||||||
| // This program is distributed in the hope that it will be useful, |  | ||||||
| // but WITHOUT ANY WARRANTY; without even the implied warranty of |  | ||||||
| // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the |  | ||||||
| // GNU Affero General Public License for more details. |  | ||||||
| // |  | ||||||
| // You should have received a copy of the GNU Affero General Public License |  | ||||||
| // along with this program.  If not, see <http://www.gnu.org/licenses/>. |  | ||||||
| 
 |  | ||||||
| package processing |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"context" |  | ||||||
| 	"errors" |  | ||||||
| 
 |  | ||||||
| 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/paging" |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/util" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| // BlocksGet ... |  | ||||||
| func (p *Processor) BlocksGet( |  | ||||||
| 	ctx context.Context, |  | ||||||
| 	requestingAccount *gtsmodel.Account, |  | ||||||
| 	page *paging.Page, |  | ||||||
| ) (*apimodel.PageableResponse, gtserror.WithCode) { |  | ||||||
| 	blocks, err := p.state.DB.GetAccountBlocks(ctx, |  | ||||||
| 		requestingAccount.ID, |  | ||||||
| 		page, |  | ||||||
| 	) |  | ||||||
| 	if err != nil && !errors.Is(err, db.ErrNoEntries) { |  | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Check for zero length. |  | ||||||
| 	count := len(blocks) |  | ||||||
| 	if len(blocks) == 0 { |  | ||||||
| 		return util.EmptyPageableResponse(), nil |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	var ( |  | ||||||
| 		items = make([]interface{}, 0, count) |  | ||||||
| 
 |  | ||||||
| 		// Set next + prev values before API converting |  | ||||||
| 		// so the caller can still page even on error. |  | ||||||
| 		nextMaxIDValue = blocks[count-1].ID |  | ||||||
| 		prevMinIDValue = blocks[0].ID |  | ||||||
| 	) |  | ||||||
| 
 |  | ||||||
| 	for _, block := range blocks { |  | ||||||
| 		if block.TargetAccount == nil { |  | ||||||
| 			// All models should be populated at this point. |  | ||||||
| 			log.Warnf(ctx, "block target account was nil: %v", err) |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// Convert target account to frontend API model. |  | ||||||
| 		account, err := p.tc.AccountToAPIAccountBlocked(ctx, block.TargetAccount) |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Errorf(ctx, "error converting account to public api account: %v", err) |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// Append target to return items. |  | ||||||
| 		items = append(items, account) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return paging.PackageResponse(paging.ResponseParams{ |  | ||||||
| 		Items: items, |  | ||||||
| 		Path:  "/api/v1/blocks", |  | ||||||
| 		Next:  page.Next(nextMaxIDValue), |  | ||||||
| 		Prev:  page.Prev(prevMinIDValue), |  | ||||||
| 	}), nil |  | ||||||
| } |  | ||||||
							
								
								
									
										238
									
								
								internal/processing/common/account.go.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										238
									
								
								internal/processing/common/account.go.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,238 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package common | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 
 | ||||||
|  | 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // GetTargetAccountBy fetches the target account with db load function, given the authorized (or, nil) requester's | ||||||
|  | // account. This returns an approprate gtserror.WithCode accounting (ha) for not found and visibility to requester. | ||||||
|  | func (p *Processor) GetTargetAccountBy( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	requester *gtsmodel.Account, | ||||||
|  | 	getTargetFromDB func() (*gtsmodel.Account, error), | ||||||
|  | ) ( | ||||||
|  | 	account *gtsmodel.Account, | ||||||
|  | 	visible bool, | ||||||
|  | 	errWithCode gtserror.WithCode, | ||||||
|  | ) { | ||||||
|  | 	// Fetch the target account from db. | ||||||
|  | 	target, err := getTargetFromDB() | ||||||
|  | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
|  | 		return nil, false, gtserror.NewErrorInternalError(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if target == nil { | ||||||
|  | 		// DB loader could not find account in database. | ||||||
|  | 		err := errors.New("target account not found") | ||||||
|  | 		return nil, false, gtserror.NewErrorNotFound(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Check whether target account is visible to requesting account. | ||||||
|  | 	visible, err = p.filter.AccountVisible(ctx, requester, target) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, false, gtserror.NewErrorInternalError(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if requester != nil && visible { | ||||||
|  | 		// Ensure the account is up-to-date. | ||||||
|  | 		p.federator.RefreshAccountAsync(ctx, | ||||||
|  | 			requester.Username, | ||||||
|  | 			target, | ||||||
|  | 			nil, | ||||||
|  | 			false, | ||||||
|  | 		) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return target, visible, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetTargetAccountByID is a call-through to GetTargetAccountBy() using the db GetAccountByID() function. | ||||||
|  | func (p *Processor) GetTargetAccountByID( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	requester *gtsmodel.Account, | ||||||
|  | 	targetID string, | ||||||
|  | ) ( | ||||||
|  | 	account *gtsmodel.Account, | ||||||
|  | 	visible bool, | ||||||
|  | 	errWithCode gtserror.WithCode, | ||||||
|  | ) { | ||||||
|  | 	return p.GetTargetAccountBy(ctx, requester, func() (*gtsmodel.Account, error) { | ||||||
|  | 		return p.state.DB.GetAccountByID(ctx, targetID) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetVisibleTargetAccount calls GetTargetAccountByID(), | ||||||
|  | // but converts a non-visible result to not-found error. | ||||||
|  | func (p *Processor) GetVisibleTargetAccount( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	requester *gtsmodel.Account, | ||||||
|  | 	targetID string, | ||||||
|  | ) ( | ||||||
|  | 	account *gtsmodel.Account, | ||||||
|  | 	errWithCode gtserror.WithCode, | ||||||
|  | ) { | ||||||
|  | 	// Fetch the target account by ID from the database. | ||||||
|  | 	target, visible, errWithCode := p.GetTargetAccountByID(ctx, | ||||||
|  | 		requester, | ||||||
|  | 		targetID, | ||||||
|  | 	) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		return nil, errWithCode | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !visible { | ||||||
|  | 		// Pretend account doesn't exist if not visible. | ||||||
|  | 		err := errors.New("target account not found") | ||||||
|  | 		return nil, gtserror.NewErrorNotFound(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return target, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetAPIAccount fetches the appropriate API account model depending on whether requester = target. | ||||||
|  | func (p *Processor) GetAPIAccount( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	requester *gtsmodel.Account, | ||||||
|  | 	target *gtsmodel.Account, | ||||||
|  | ) ( | ||||||
|  | 	apiAcc *apimodel.Account, | ||||||
|  | 	errWithCode gtserror.WithCode, | ||||||
|  | ) { | ||||||
|  | 	var err error | ||||||
|  | 
 | ||||||
|  | 	if requester != nil && requester.ID == target.ID { | ||||||
|  | 		// Only return sensitive account model _if_ requester = target. | ||||||
|  | 		apiAcc, err = p.converter.AccountToAPIAccountSensitive(ctx, target) | ||||||
|  | 	} else { | ||||||
|  | 		// Else, fall back to returning the public account model. | ||||||
|  | 		apiAcc, err = p.converter.AccountToAPIAccountPublic(ctx, target) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		err := gtserror.Newf("error converting account: %w", err) | ||||||
|  | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return apiAcc, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetAPIAccountBlocked fetches the limited "blocked" account model for given target. | ||||||
|  | func (p *Processor) GetAPIAccountBlocked( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	targetAcc *gtsmodel.Account, | ||||||
|  | ) ( | ||||||
|  | 	apiAcc *apimodel.Account, | ||||||
|  | 	errWithCode gtserror.WithCode, | ||||||
|  | ) { | ||||||
|  | 	apiAccount, err := p.converter.AccountToAPIAccountBlocked(ctx, targetAcc) | ||||||
|  | 	if err != nil { | ||||||
|  | 		err = gtserror.Newf("error converting account: %w", err) | ||||||
|  | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
|  | 	} | ||||||
|  | 	return apiAccount, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetVisibleAPIAccounts converts an array of gtsmodel.Accounts (inputted by next function) into | ||||||
|  | // public API model accounts, checking first for visibility. Please note that all errors will be | ||||||
|  | // logged at ERROR level, but will not be returned. Callers are likely to run into show-stopping | ||||||
|  | // errors in the lead-up to this function, whereas calling this should not be a show-stopper. | ||||||
|  | func (p *Processor) GetVisibleAPIAccounts( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	requester *gtsmodel.Account, | ||||||
|  | 	next func(int) *gtsmodel.Account, | ||||||
|  | 	length int, | ||||||
|  | ) []*apimodel.Account { | ||||||
|  | 	return p.getVisibleAPIAccounts(ctx, 3, requester, next, length) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetVisibleAPIAccountsPaged is functionally equivalent to GetVisibleAPIAccounts(), | ||||||
|  | // except the accounts are returned as a converted slice of accounts as interface{}. | ||||||
|  | func (p *Processor) GetVisibleAPIAccountsPaged( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	requester *gtsmodel.Account, | ||||||
|  | 	next func(int) *gtsmodel.Account, | ||||||
|  | 	length int, | ||||||
|  | ) []interface{} { | ||||||
|  | 	accounts := p.getVisibleAPIAccounts(ctx, 3, requester, next, length) | ||||||
|  | 	if len(accounts) == 0 { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	items := make([]interface{}, len(accounts)) | ||||||
|  | 	for i, account := range accounts { | ||||||
|  | 		items[i] = account | ||||||
|  | 	} | ||||||
|  | 	return items | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *Processor) getVisibleAPIAccounts( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	calldepth int, // used to skip wrapping func above these's names | ||||||
|  | 	requester *gtsmodel.Account, | ||||||
|  | 	next func(int) *gtsmodel.Account, | ||||||
|  | 	length int, | ||||||
|  | ) []*apimodel.Account { | ||||||
|  | 	// Start new log entry with | ||||||
|  | 	// the above calling func's name. | ||||||
|  | 	l := log. | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithField("caller", log.Caller(calldepth+1)) | ||||||
|  | 
 | ||||||
|  | 	// Preallocate slice according to expected length. | ||||||
|  | 	accounts := make([]*apimodel.Account, 0, length) | ||||||
|  | 
 | ||||||
|  | 	for i := 0; i < length; i++ { | ||||||
|  | 		// Get next account. | ||||||
|  | 		account := next(i) | ||||||
|  | 		if account == nil { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Check whether this account is visible to requesting account. | ||||||
|  | 		visible, err := p.filter.AccountVisible(ctx, requester, account) | ||||||
|  | 		if err != nil { | ||||||
|  | 			l.Errorf("error checking account visibility: %v", err) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if !visible { | ||||||
|  | 			// Not visible to requester. | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Convert the account to a public API model representation. | ||||||
|  | 		apiAcc, err := p.converter.AccountToAPIAccountPublic(ctx, account) | ||||||
|  | 		if err != nil { | ||||||
|  | 			l.Errorf("error converting account: %v", err) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Append API model to return slice. | ||||||
|  | 		accounts = append(accounts, apiAcc) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return accounts | ||||||
|  | } | ||||||
							
								
								
									
										50
									
								
								internal/processing/common/common.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								internal/processing/common/common.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,50 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package common | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/federation" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/state" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/typeutils" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/visibility" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Processor provides a processor with logic | ||||||
|  | // common to multiple logical domains of the | ||||||
|  | // processing subsection of the codebase. | ||||||
|  | type Processor struct { | ||||||
|  | 	state     *state.State | ||||||
|  | 	converter typeutils.TypeConverter | ||||||
|  | 	federator federation.Federator | ||||||
|  | 	filter    *visibility.Filter | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // New returns a new Processor instance. | ||||||
|  | func New( | ||||||
|  | 	state *state.State, | ||||||
|  | 	converter typeutils.TypeConverter, | ||||||
|  | 	federator federation.Federator, | ||||||
|  | 	filter *visibility.Filter, | ||||||
|  | ) Processor { | ||||||
|  | 	return Processor{ | ||||||
|  | 		state:     state, | ||||||
|  | 		converter: converter, | ||||||
|  | 		federator: federator, | ||||||
|  | 		filter:    filter, | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										248
									
								
								internal/processing/common/status.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										248
									
								
								internal/processing/common/status.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,248 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package common | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 
 | ||||||
|  | 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // GetTargetStatusBy fetches the target status with db load function, given the authorized (or, nil) requester's | ||||||
|  | // account. This returns an approprate gtserror.WithCode accounting for not found and visibility to requester. | ||||||
|  | func (p *Processor) GetTargetStatusBy( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	requester *gtsmodel.Account, | ||||||
|  | 	getTargetFromDB func() (*gtsmodel.Status, error), | ||||||
|  | ) ( | ||||||
|  | 	status *gtsmodel.Status, | ||||||
|  | 	visible bool, | ||||||
|  | 	errWithCode gtserror.WithCode, | ||||||
|  | ) { | ||||||
|  | 	// Fetch the target status from db. | ||||||
|  | 	target, err := getTargetFromDB() | ||||||
|  | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
|  | 		return nil, false, gtserror.NewErrorInternalError(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if target == nil { | ||||||
|  | 		// DB loader could not find status in database. | ||||||
|  | 		err := errors.New("target status not found") | ||||||
|  | 		return nil, false, gtserror.NewErrorNotFound(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Check whether target status is visible to requesting account. | ||||||
|  | 	visible, err = p.filter.StatusVisible(ctx, requester, target) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, false, gtserror.NewErrorInternalError(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if requester != nil && visible { | ||||||
|  | 		// Ensure remote status is up-to-date. | ||||||
|  | 		p.federator.RefreshStatusAsync(ctx, | ||||||
|  | 			requester.Username, | ||||||
|  | 			target, | ||||||
|  | 			nil, | ||||||
|  | 			false, | ||||||
|  | 		) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return target, visible, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetTargetStatusByID is a call-through to GetTargetStatus() using the db GetStatusByID() function. | ||||||
|  | func (p *Processor) GetTargetStatusByID( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	requester *gtsmodel.Account, | ||||||
|  | 	targetID string, | ||||||
|  | ) ( | ||||||
|  | 	status *gtsmodel.Status, | ||||||
|  | 	visible bool, | ||||||
|  | 	errWithCode gtserror.WithCode, | ||||||
|  | ) { | ||||||
|  | 	return p.GetTargetStatusBy(ctx, requester, func() (*gtsmodel.Status, error) { | ||||||
|  | 		return p.state.DB.GetStatusByID(ctx, targetID) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetVisibleTargetStatus calls GetTargetStatusByID(), | ||||||
|  | // but converts a non-visible result to not-found error. | ||||||
|  | func (p *Processor) GetVisibleTargetStatus( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	requester *gtsmodel.Account, | ||||||
|  | 	targetID string, | ||||||
|  | ) ( | ||||||
|  | 	status *gtsmodel.Status, | ||||||
|  | 	errWithCode gtserror.WithCode, | ||||||
|  | ) { | ||||||
|  | 	// Fetch the target status by ID from the database. | ||||||
|  | 	target, visible, errWithCode := p.GetTargetStatusByID(ctx, | ||||||
|  | 		requester, | ||||||
|  | 		targetID, | ||||||
|  | 	) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		return nil, errWithCode | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !visible { | ||||||
|  | 		// Target should not be seen by requester. | ||||||
|  | 		err := errors.New("target status not found") | ||||||
|  | 		return nil, gtserror.NewErrorNotFound(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return target, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetAPIStatus fetches the appropriate API status model for target. | ||||||
|  | func (p *Processor) GetAPIStatus( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	requester *gtsmodel.Account, | ||||||
|  | 	target *gtsmodel.Status, | ||||||
|  | ) ( | ||||||
|  | 	apiStatus *apimodel.Status, | ||||||
|  | 	errWithCode gtserror.WithCode, | ||||||
|  | ) { | ||||||
|  | 	apiStatus, err := p.converter.StatusToAPIStatus(ctx, target, requester) | ||||||
|  | 	if err != nil { | ||||||
|  | 		err = gtserror.Newf("error converting status: %w", err) | ||||||
|  | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
|  | 	} | ||||||
|  | 	return apiStatus, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetVisibleAPIStatuses converts an array of gtsmodel.Status (inputted by next function) into | ||||||
|  | // API model statuses, checking first for visibility. Please note that all errors will be | ||||||
|  | // logged at ERROR level, but will not be returned. Callers are likely to run into show-stopping | ||||||
|  | // errors in the lead-up to this function, whereas calling this should not be a show-stopper. | ||||||
|  | func (p *Processor) GetVisibleAPIStatuses( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	requester *gtsmodel.Account, | ||||||
|  | 	next func(int) *gtsmodel.Status, | ||||||
|  | 	length int, | ||||||
|  | ) []*apimodel.Status { | ||||||
|  | 	return p.getVisibleAPIStatuses(ctx, 3, requester, next, length) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetVisibleAPIStatusesPaged is functionally equivalent to GetVisibleAPIStatuses(), | ||||||
|  | // except the statuses are returned as a converted slice of statuses as interface{}. | ||||||
|  | func (p *Processor) GetVisibleAPIStatusesPaged( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	requester *gtsmodel.Account, | ||||||
|  | 	next func(int) *gtsmodel.Status, | ||||||
|  | 	length int, | ||||||
|  | ) []interface{} { | ||||||
|  | 	statuses := p.getVisibleAPIStatuses(ctx, 3, requester, next, length) | ||||||
|  | 	if len(statuses) == 0 { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	items := make([]interface{}, len(statuses)) | ||||||
|  | 	for i, status := range statuses { | ||||||
|  | 		items[i] = status | ||||||
|  | 	} | ||||||
|  | 	return items | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *Processor) getVisibleAPIStatuses( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	calldepth int, // used to skip wrapping func above these's names | ||||||
|  | 	requester *gtsmodel.Account, | ||||||
|  | 	next func(int) *gtsmodel.Status, | ||||||
|  | 	length int, | ||||||
|  | ) []*apimodel.Status { | ||||||
|  | 	// Start new log entry with | ||||||
|  | 	// the above calling func's name. | ||||||
|  | 	l := log. | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithField("caller", log.Caller(calldepth+1)) | ||||||
|  | 
 | ||||||
|  | 	// Preallocate slice according to expected length. | ||||||
|  | 	statuses := make([]*apimodel.Status, 0, length) | ||||||
|  | 
 | ||||||
|  | 	for i := 0; i < length; i++ { | ||||||
|  | 		// Get next status. | ||||||
|  | 		status := next(i) | ||||||
|  | 		if status == nil { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Check whether this status is visible to requesting account. | ||||||
|  | 		visible, err := p.filter.StatusVisible(ctx, requester, status) | ||||||
|  | 		if err != nil { | ||||||
|  | 			l.Errorf("error checking status visibility: %v", err) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if !visible { | ||||||
|  | 			// Not visible to requester. | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Convert the status to an API model representation. | ||||||
|  | 		apiStatus, err := p.converter.StatusToAPIStatus(ctx, status, requester) | ||||||
|  | 		if err != nil { | ||||||
|  | 			l.Errorf("error converting status: %v", err) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Append API model to return slice. | ||||||
|  | 		statuses = append(statuses, apiStatus) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return statuses | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // InvalidateTimelinedStatus is a shortcut function for invalidating the cached | ||||||
|  | // representation one status in the home timeline and all list timelines of the | ||||||
|  | // given accountID. It should only be called in cases where a status update | ||||||
|  | // does *not* need to be passed into the processor via the worker queue, since | ||||||
|  | // such invalidation will, in that case, be handled by the processor instead. | ||||||
|  | func (p *Processor) InvalidateTimelinedStatus(ctx context.Context, accountID string, statusID string) error { | ||||||
|  | 	// Get lists first + bail if this fails. | ||||||
|  | 	lists, err := p.state.DB.GetListsForAccountID(ctx, accountID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return gtserror.Newf("db error getting lists for account %s: %w", accountID, err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Start new log entry with | ||||||
|  | 	// the above calling func's name. | ||||||
|  | 	l := log. | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithField("caller", log.Caller(3)). | ||||||
|  | 		WithField("accountID", accountID). | ||||||
|  | 		WithField("statusID", statusID) | ||||||
|  | 
 | ||||||
|  | 	// Unprepare item from home + list timelines, just log | ||||||
|  | 	// if something goes wrong since this is not a showstopper. | ||||||
|  | 
 | ||||||
|  | 	if err := p.state.Timelines.Home.UnprepareItem(ctx, accountID, statusID); err != nil { | ||||||
|  | 		l.Errorf("error unpreparing item from home timeline: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, list := range lists { | ||||||
|  | 		if err := p.state.Timelines.List.UnprepareItem(ctx, list.ID, statusID); err != nil { | ||||||
|  | 			l.Errorf("error unpreparing item from list timeline %s: %v", list.ID, err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | @ -1,123 +0,0 @@ | ||||||
| // GoToSocial |  | ||||||
| // Copyright (C) GoToSocial Authors admin@gotosocial.org |  | ||||||
| // SPDX-License-Identifier: AGPL-3.0-or-later |  | ||||||
| // |  | ||||||
| // This program is free software: you can redistribute it and/or modify |  | ||||||
| // it under the terms of the GNU Affero General Public License as published by |  | ||||||
| // the Free Software Foundation, either version 3 of the License, or |  | ||||||
| // (at your option) any later version. |  | ||||||
| // |  | ||||||
| // This program is distributed in the hope that it will be useful, |  | ||||||
| // but WITHOUT ANY WARRANTY; without even the implied warranty of |  | ||||||
| // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the |  | ||||||
| // GNU Affero General Public License for more details. |  | ||||||
| // |  | ||||||
| // You should have received a copy of the GNU Affero General Public License |  | ||||||
| // along with this program.  If not, see <http://www.gnu.org/licenses/>. |  | ||||||
| 
 |  | ||||||
| package processing |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"context" |  | ||||||
| 	"errors" |  | ||||||
| 
 |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/ap" |  | ||||||
| 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/messages" |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/oauth" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) { |  | ||||||
| 	followRequests, err := p.state.DB.GetAccountFollowRequests(ctx, auth.Account.ID) |  | ||||||
| 	if err != nil && !errors.Is(err, db.ErrNoEntries) { |  | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	accts := make([]apimodel.Account, 0, len(followRequests)) |  | ||||||
| 	for _, followRequest := range followRequests { |  | ||||||
| 		if followRequest.Account == nil { |  | ||||||
| 			// The creator of the follow doesn't exist, |  | ||||||
| 			// just skip this one. |  | ||||||
| 			log.WithContext(ctx).WithField("followRequest", followRequest).Warn("follow request had no associated account") |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		apiAcct, err := p.tc.AccountToAPIAccountPublic(ctx, followRequest.Account) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, gtserror.NewErrorInternalError(err) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		accts = append(accts, *apiAcct) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return accts, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) { |  | ||||||
| 	follow, err := p.state.DB.AcceptFollowRequest(ctx, accountID, auth.Account.ID) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, gtserror.NewErrorNotFound(err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if follow.Account == nil { |  | ||||||
| 		// The creator of the follow doesn't exist, |  | ||||||
| 		// so we can't do further processing. |  | ||||||
| 		log.WithContext(ctx).WithField("follow", follow).Warn("follow had no associated account") |  | ||||||
| 		return p.relationship(ctx, auth.Account.ID, accountID) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ |  | ||||||
| 		APObjectType:   ap.ActivityFollow, |  | ||||||
| 		APActivityType: ap.ActivityAccept, |  | ||||||
| 		GTSModel:       follow, |  | ||||||
| 		OriginAccount:  follow.Account, |  | ||||||
| 		TargetAccount:  follow.TargetAccount, |  | ||||||
| 	}) |  | ||||||
| 
 |  | ||||||
| 	return p.relationship(ctx, auth.Account.ID, accountID) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) { |  | ||||||
| 	followRequest, err := p.state.DB.GetFollowRequest(ctx, accountID, auth.Account.ID) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, gtserror.NewErrorNotFound(err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	err = p.state.DB.RejectFollowRequest(ctx, accountID, auth.Account.ID) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, gtserror.NewErrorNotFound(err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if followRequest.Account == nil { |  | ||||||
| 		// The creator of the request doesn't exist, |  | ||||||
| 		// so we can't do further processing. |  | ||||||
| 		return p.relationship(ctx, auth.Account.ID, accountID) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ |  | ||||||
| 		APObjectType:   ap.ActivityFollow, |  | ||||||
| 		APActivityType: ap.ActivityReject, |  | ||||||
| 		GTSModel:       followRequest, |  | ||||||
| 		OriginAccount:  followRequest.Account, |  | ||||||
| 		TargetAccount:  followRequest.TargetAccount, |  | ||||||
| 	}) |  | ||||||
| 
 |  | ||||||
| 	return p.relationship(ctx, auth.Account.ID, accountID) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *Processor) relationship(ctx context.Context, accountID string, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { |  | ||||||
| 	relationship, err := p.state.DB.GetRelationship(ctx, accountID, targetAccountID) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	apiRelationship, err := p.tc.RelationshipToAPIRelationship(ctx, relationship) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return apiRelationship, nil |  | ||||||
| } |  | ||||||
|  | @ -30,35 +30,57 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/testrig" | 	"github.com/superseriousbusiness/gotosocial/testrig" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // TODO: move this to the "internal/processing/account" pkg | ||||||
| type FollowRequestTestSuite struct { | type FollowRequestTestSuite struct { | ||||||
| 	ProcessingStandardTestSuite | 	ProcessingStandardTestSuite | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *FollowRequestTestSuite) TestFollowRequestAccept() { | func (suite *FollowRequestTestSuite) TestFollowRequestAccept() { | ||||||
| 	requestingAccount := suite.testAccounts["remote_account_2"] | 	// The authed local account we are going to use for HTTP requests | ||||||
| 	targetAccount := suite.testAccounts["local_account_1"] | 	requestingAccount := suite.testAccounts["local_account_1"] | ||||||
|  | 
 | ||||||
|  | 	// The remote account whose follow request we are accepting | ||||||
|  | 	targetAccount := suite.testAccounts["remote_account_2"] | ||||||
| 
 | 
 | ||||||
| 	// put a follow request in the database | 	// put a follow request in the database | ||||||
| 	fr := >smodel.FollowRequest{ | 	fr := >smodel.FollowRequest{ | ||||||
| 		ID:              "01FJ1S8DX3STJJ6CEYPMZ1M0R3", | 		ID:              "01FJ1S8DX3STJJ6CEYPMZ1M0R3", | ||||||
| 		CreatedAt:       time.Now(), | 		CreatedAt:       time.Now(), | ||||||
| 		UpdatedAt:       time.Now(), | 		UpdatedAt:       time.Now(), | ||||||
| 		URI:             fmt.Sprintf("%s/follow/01FJ1S8DX3STJJ6CEYPMZ1M0R3", requestingAccount.URI), | 		URI:             fmt.Sprintf("%s/follow/01FJ1S8DX3STJJ6CEYPMZ1M0R3", targetAccount.URI), | ||||||
| 		AccountID:       requestingAccount.ID, | 		AccountID:       targetAccount.ID, | ||||||
| 		TargetAccountID: targetAccount.ID, | 		TargetAccountID: requestingAccount.ID, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err := suite.db.Put(context.Background(), fr) | 	err := suite.db.Put(context.Background(), fr) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 
 | 
 | ||||||
| 	relationship, errWithCode := suite.processor.FollowRequestAccept(context.Background(), suite.testAutheds["local_account_1"], requestingAccount.ID) | 	relationship, errWithCode := suite.processor.Account().FollowRequestAccept( | ||||||
|  | 		context.Background(), | ||||||
|  | 		requestingAccount, | ||||||
|  | 		targetAccount.ID, | ||||||
|  | 	) | ||||||
| 	suite.NoError(errWithCode) | 	suite.NoError(errWithCode) | ||||||
| 	suite.EqualValues(&apimodel.Relationship{ID: "01FHMQX3GAABWSM0S2VZEC2SWC", Following: false, ShowingReblogs: false, Notifying: false, FollowedBy: true, Blocking: false, BlockedBy: false, Muting: false, MutingNotifications: false, Requested: false, DomainBlocking: false, Endorsed: false, Note: ""}, relationship) | 	suite.EqualValues(&apimodel.Relationship{ | ||||||
|  | 		ID:                  "01FHMQX3GAABWSM0S2VZEC2SWC", | ||||||
|  | 		Following:           false, | ||||||
|  | 		ShowingReblogs:      false, | ||||||
|  | 		Notifying:           false, | ||||||
|  | 		FollowedBy:          true, | ||||||
|  | 		Blocking:            false, | ||||||
|  | 		BlockedBy:           false, | ||||||
|  | 		Muting:              false, | ||||||
|  | 		MutingNotifications: false, | ||||||
|  | 		Requested:           false, | ||||||
|  | 		DomainBlocking:      false, | ||||||
|  | 		Endorsed:            false, | ||||||
|  | 		Note:                "", | ||||||
|  | 	}, relationship) | ||||||
| 
 | 
 | ||||||
| 	// accept should be sent to Some_User | 	// accept should be sent to Some_User | ||||||
| 	var sent [][]byte | 	var sent [][]byte | ||||||
| 	if !testrig.WaitFor(func() bool { | 	if !testrig.WaitFor(func() bool { | ||||||
| 		sentI, ok := suite.httpClient.SentMessages.Load(requestingAccount.InboxURI) | 		sentI, ok := suite.httpClient.SentMessages.Load(targetAccount.InboxURI) | ||||||
| 		if ok { | 		if ok { | ||||||
| 			sent, ok = sentI.([][]byte) | 			sent, ok = sentI.([][]byte) | ||||||
| 			if !ok { | 			if !ok { | ||||||
|  | @ -87,41 +109,45 @@ func (suite *FollowRequestTestSuite) TestFollowRequestAccept() { | ||||||
| 	err = json.Unmarshal(sent[0], accept) | 	err = json.Unmarshal(sent[0], accept) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 
 | 
 | ||||||
| 	suite.Equal(targetAccount.URI, accept.Actor) | 	suite.Equal(requestingAccount.URI, accept.Actor) | ||||||
| 	suite.Equal(requestingAccount.URI, accept.Object.Actor) | 	suite.Equal(targetAccount.URI, accept.Object.Actor) | ||||||
| 	suite.Equal(fr.URI, accept.Object.ID) | 	suite.Equal(fr.URI, accept.Object.ID) | ||||||
| 	suite.Equal(targetAccount.URI, accept.Object.Object) | 	suite.Equal(requestingAccount.URI, accept.Object.Object) | ||||||
| 	suite.Equal(targetAccount.URI, accept.Object.To) | 	suite.Equal(requestingAccount.URI, accept.Object.To) | ||||||
| 	suite.Equal("Follow", accept.Object.Type) | 	suite.Equal("Follow", accept.Object.Type) | ||||||
| 	suite.Equal(requestingAccount.URI, accept.To) | 	suite.Equal(targetAccount.URI, accept.To) | ||||||
| 	suite.Equal("Accept", accept.Type) | 	suite.Equal("Accept", accept.Type) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *FollowRequestTestSuite) TestFollowRequestReject() { | func (suite *FollowRequestTestSuite) TestFollowRequestReject() { | ||||||
| 	requestingAccount := suite.testAccounts["remote_account_2"] | 	requestingAccount := suite.testAccounts["local_account_1"] | ||||||
| 	targetAccount := suite.testAccounts["local_account_1"] | 	targetAccount := suite.testAccounts["remote_account_2"] | ||||||
| 
 | 
 | ||||||
| 	// put a follow request in the database | 	// put a follow request in the database | ||||||
| 	fr := >smodel.FollowRequest{ | 	fr := >smodel.FollowRequest{ | ||||||
| 		ID:              "01FJ1S8DX3STJJ6CEYPMZ1M0R3", | 		ID:              "01FJ1S8DX3STJJ6CEYPMZ1M0R3", | ||||||
| 		CreatedAt:       time.Now(), | 		CreatedAt:       time.Now(), | ||||||
| 		UpdatedAt:       time.Now(), | 		UpdatedAt:       time.Now(), | ||||||
| 		URI:             fmt.Sprintf("%s/follow/01FJ1S8DX3STJJ6CEYPMZ1M0R3", requestingAccount.URI), | 		URI:             fmt.Sprintf("%s/follow/01FJ1S8DX3STJJ6CEYPMZ1M0R3", targetAccount.URI), | ||||||
| 		AccountID:       requestingAccount.ID, | 		AccountID:       targetAccount.ID, | ||||||
| 		TargetAccountID: targetAccount.ID, | 		TargetAccountID: requestingAccount.ID, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err := suite.db.Put(context.Background(), fr) | 	err := suite.db.Put(context.Background(), fr) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 
 | 
 | ||||||
| 	relationship, errWithCode := suite.processor.FollowRequestReject(context.Background(), suite.testAutheds["local_account_1"], requestingAccount.ID) | 	relationship, errWithCode := suite.processor.Account().FollowRequestReject( | ||||||
|  | 		context.Background(), | ||||||
|  | 		requestingAccount, | ||||||
|  | 		targetAccount.ID, | ||||||
|  | 	) | ||||||
| 	suite.NoError(errWithCode) | 	suite.NoError(errWithCode) | ||||||
| 	suite.EqualValues(&apimodel.Relationship{ID: "01FHMQX3GAABWSM0S2VZEC2SWC", Following: false, ShowingReblogs: false, Notifying: false, FollowedBy: false, Blocking: false, BlockedBy: false, Muting: false, MutingNotifications: false, Requested: false, DomainBlocking: false, Endorsed: false, Note: ""}, relationship) | 	suite.EqualValues(&apimodel.Relationship{ID: "01FHMQX3GAABWSM0S2VZEC2SWC", Following: false, ShowingReblogs: false, Notifying: false, FollowedBy: false, Blocking: false, BlockedBy: false, Muting: false, MutingNotifications: false, Requested: false, DomainBlocking: false, Endorsed: false, Note: ""}, relationship) | ||||||
| 
 | 
 | ||||||
| 	// reject should be sent to Some_User | 	// reject should be sent to Some_User | ||||||
| 	var sent [][]byte | 	var sent [][]byte | ||||||
| 	if !testrig.WaitFor(func() bool { | 	if !testrig.WaitFor(func() bool { | ||||||
| 		sentI, ok := suite.httpClient.SentMessages.Load(requestingAccount.InboxURI) | 		sentI, ok := suite.httpClient.SentMessages.Load(targetAccount.InboxURI) | ||||||
| 		if ok { | 		if ok { | ||||||
| 			sent, ok = sentI.([][]byte) | 			sent, ok = sentI.([][]byte) | ||||||
| 			if !ok { | 			if !ok { | ||||||
|  | @ -150,13 +176,13 @@ func (suite *FollowRequestTestSuite) TestFollowRequestReject() { | ||||||
| 	err = json.Unmarshal(sent[0], reject) | 	err = json.Unmarshal(sent[0], reject) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 
 | 
 | ||||||
| 	suite.Equal(targetAccount.URI, reject.Actor) | 	suite.Equal(requestingAccount.URI, reject.Actor) | ||||||
| 	suite.Equal(requestingAccount.URI, reject.Object.Actor) | 	suite.Equal(targetAccount.URI, reject.Object.Actor) | ||||||
| 	suite.Equal(fr.URI, reject.Object.ID) | 	suite.Equal(fr.URI, reject.Object.ID) | ||||||
| 	suite.Equal(targetAccount.URI, reject.Object.Object) | 	suite.Equal(requestingAccount.URI, reject.Object.Object) | ||||||
| 	suite.Equal(targetAccount.URI, reject.Object.To) | 	suite.Equal(requestingAccount.URI, reject.Object.To) | ||||||
| 	suite.Equal("Follow", reject.Object.Type) | 	suite.Equal("Follow", reject.Object.Type) | ||||||
| 	suite.Equal(requestingAccount.URI, reject.To) | 	suite.Equal(targetAccount.URI, reject.To) | ||||||
| 	suite.Equal("Reject", reject.Type) | 	suite.Equal("Reject", reject.Type) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -24,6 +24,7 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing/account" | 	"github.com/superseriousbusiness/gotosocial/internal/processing/account" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing/admin" | 	"github.com/superseriousbusiness/gotosocial/internal/processing/admin" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/processing/common" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing/fedi" | 	"github.com/superseriousbusiness/gotosocial/internal/processing/fedi" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing/list" | 	"github.com/superseriousbusiness/gotosocial/internal/processing/list" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing/markers" | 	"github.com/superseriousbusiness/gotosocial/internal/processing/markers" | ||||||
|  | @ -147,7 +148,8 @@ func NewProcessor( | ||||||
| 	// | 	// | ||||||
| 	// Start with sub processors that will | 	// Start with sub processors that will | ||||||
| 	// be required by the workers processor. | 	// be required by the workers processor. | ||||||
| 	accountProcessor := account.New(state, tc, mediaManager, oauthServer, federator, filter, parseMentionFunc) | 	commonProcessor := common.New(state, tc, federator, filter) | ||||||
|  | 	accountProcessor := account.New(&commonProcessor, state, tc, mediaManager, oauthServer, federator, filter, parseMentionFunc) | ||||||
| 	mediaProcessor := media.New(state, tc, mediaManager, federator.TransportController()) | 	mediaProcessor := media.New(state, tc, mediaManager, federator.TransportController()) | ||||||
| 	streamProcessor := stream.New(state, oauthServer) | 	streamProcessor := stream.New(state, oauthServer) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -66,6 +66,7 @@ func (suite *GetTestSuite) emptyAccountFollows(ctx context.Context, accountID st | ||||||
| 	follows, err := suite.state.DB.GetAccountFollows( | 	follows, err := suite.state.DB.GetAccountFollows( | ||||||
| 		gtscontext.SetBarebones(ctx), | 		gtscontext.SetBarebones(ctx), | ||||||
| 		accountID, | 		accountID, | ||||||
|  | 		nil, // select all | ||||||
| 	) | 	) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		suite.FailNow(err.Error()) | 		suite.FailNow(err.Error()) | ||||||
|  | @ -82,6 +83,7 @@ func (suite *GetTestSuite) emptyAccountFollows(ctx context.Context, accountID st | ||||||
| 	follows, err = suite.state.DB.GetAccountFollows( | 	follows, err = suite.state.DB.GetAccountFollows( | ||||||
| 		gtscontext.SetBarebones(ctx), | 		gtscontext.SetBarebones(ctx), | ||||||
| 		accountID, | 		accountID, | ||||||
|  | 		nil, // select all | ||||||
| 	) | 	) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		suite.FailNow(err.Error()) | 		suite.FailNow(err.Error()) | ||||||
|  |  | ||||||
|  | @ -364,6 +364,7 @@ func NewTestAccounts() map[string]*gtsmodel.Account { | ||||||
| 			SuspendedAt:             time.Time{}, | 			SuspendedAt:             time.Time{}, | ||||||
| 			HideCollections:         util.Ptr(false), | 			HideCollections:         util.Ptr(false), | ||||||
| 			SuspensionOrigin:        "", | 			SuspensionOrigin:        "", | ||||||
|  | 			EnableRSS:               util.Ptr(false), | ||||||
| 		}, | 		}, | ||||||
| 		"admin_account": { | 		"admin_account": { | ||||||
| 			ID:                      "01F8MH17FWEB39HZJ76B6VXSKF", | 			ID:                      "01F8MH17FWEB39HZJ76B6VXSKF", | ||||||
|  | @ -539,6 +540,7 @@ func NewTestAccounts() map[string]*gtsmodel.Account { | ||||||
| 			SuspendedAt:           time.Time{}, | 			SuspendedAt:           time.Time{}, | ||||||
| 			HideCollections:       util.Ptr(false), | 			HideCollections:       util.Ptr(false), | ||||||
| 			SuspensionOrigin:      "", | 			SuspensionOrigin:      "", | ||||||
|  | 			EnableRSS:             util.Ptr(false), | ||||||
| 		}, | 		}, | ||||||
| 		"remote_account_2": { | 		"remote_account_2": { | ||||||
| 			ID:                    "01FHMQX3GAABWSM0S2VZEC2SWC", | 			ID:                    "01FHMQX3GAABWSM0S2VZEC2SWC", | ||||||
|  | @ -575,6 +577,7 @@ func NewTestAccounts() map[string]*gtsmodel.Account { | ||||||
| 			SuspendedAt:           time.Time{}, | 			SuspendedAt:           time.Time{}, | ||||||
| 			HideCollections:       util.Ptr(false), | 			HideCollections:       util.Ptr(false), | ||||||
| 			SuspensionOrigin:      "", | 			SuspensionOrigin:      "", | ||||||
|  | 			EnableRSS:             util.Ptr(false), | ||||||
| 		}, | 		}, | ||||||
| 		"remote_account_3": { | 		"remote_account_3": { | ||||||
| 			ID:                      "062G5WYKY35KKD12EMSM3F8PJ8", | 			ID:                      "062G5WYKY35KKD12EMSM3F8PJ8", | ||||||
|  | @ -612,6 +615,7 @@ func NewTestAccounts() map[string]*gtsmodel.Account { | ||||||
| 			HideCollections:         util.Ptr(false), | 			HideCollections:         util.Ptr(false), | ||||||
| 			SuspensionOrigin:        "", | 			SuspensionOrigin:        "", | ||||||
| 			HeaderMediaAttachmentID: "01PFPMWK2FF0D9WMHEJHR07C3R", | 			HeaderMediaAttachmentID: "01PFPMWK2FF0D9WMHEJHR07C3R", | ||||||
|  | 			EnableRSS:               util.Ptr(false), | ||||||
| 		}, | 		}, | ||||||
| 		"remote_account_4": { | 		"remote_account_4": { | ||||||
| 			ID:                      "07GZRBAEMBNKGZ8Z9VSKSXKR98", | 			ID:                      "07GZRBAEMBNKGZ8Z9VSKSXKR98", | ||||||
|  |  | ||||||
							
								
								
									
										2
									
								
								vendor/github.com/tomnomnom/linkheader/.gitignore
									
										
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								vendor/github.com/tomnomnom/linkheader/.gitignore
									
										
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							|  | @ -0,0 +1,2 @@ | ||||||
|  | cpu.out | ||||||
|  | linkheader.test | ||||||
							
								
								
									
										6
									
								
								vendor/github.com/tomnomnom/linkheader/.travis.yml
									
										
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								vendor/github.com/tomnomnom/linkheader/.travis.yml
									
										
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							|  | @ -0,0 +1,6 @@ | ||||||
|  | language: go | ||||||
|  | 
 | ||||||
|  | go: | ||||||
|  |   - 1.6 | ||||||
|  |   - 1.7 | ||||||
|  |   - tip | ||||||
							
								
								
									
										10
									
								
								vendor/github.com/tomnomnom/linkheader/CONTRIBUTING.mkd
									
										
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								vendor/github.com/tomnomnom/linkheader/CONTRIBUTING.mkd
									
										
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							|  | @ -0,0 +1,10 @@ | ||||||
|  | # Contributing | ||||||
|  | 
 | ||||||
|  | * Raise an issue if appropriate | ||||||
|  | * Fork the repo | ||||||
|  | * Bootstrap the dev dependencies (run `./script/bootstrap`) | ||||||
|  | * Make your changes | ||||||
|  | * Use [gofmt](https://golang.org/cmd/gofmt/) | ||||||
|  | * Make sure the tests pass (run `./script/test`) | ||||||
|  | * Make sure the linters pass (run `./script/lint`) | ||||||
|  | * Issue a pull request | ||||||
							
								
								
									
										21
									
								
								vendor/github.com/tomnomnom/linkheader/LICENSE
									
										
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								vendor/github.com/tomnomnom/linkheader/LICENSE
									
										
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							|  | @ -0,0 +1,21 @@ | ||||||
|  | MIT License | ||||||
|  | 
 | ||||||
|  | Copyright (c) 2016 Tom Hudson | ||||||
|  | 
 | ||||||
|  | Permission is hereby granted, free of charge, to any person obtaining a copy | ||||||
|  | of this software and associated documentation files (the "Software"), to deal | ||||||
|  | in the Software without restriction, including without limitation the rights | ||||||
|  | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||||
|  | copies of the Software, and to permit persons to whom the Software is | ||||||
|  | furnished to do so, subject to the following conditions: | ||||||
|  | 
 | ||||||
|  | The above copyright notice and this permission notice shall be included in all | ||||||
|  | copies or substantial portions of the Software. | ||||||
|  | 
 | ||||||
|  | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||||
|  | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||||
|  | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||||
|  | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||||
|  | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||||
|  | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||||
|  | SOFTWARE. | ||||||
							
								
								
									
										35
									
								
								vendor/github.com/tomnomnom/linkheader/README.mkd
									
										
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								vendor/github.com/tomnomnom/linkheader/README.mkd
									
										
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							|  | @ -0,0 +1,35 @@ | ||||||
|  | # Golang Link Header Parser | ||||||
|  | 
 | ||||||
|  | Library for parsing HTTP Link headers. Requires Go 1.6 or higher. | ||||||
|  | 
 | ||||||
|  | Docs can be found on [the GoDoc page](https://godoc.org/github.com/tomnomnom/linkheader). | ||||||
|  | 
 | ||||||
|  | [](https://travis-ci.org/tomnomnom/linkheader) | ||||||
|  | 
 | ||||||
|  | ## Basic Example | ||||||
|  | 
 | ||||||
|  | ```go | ||||||
|  | package main | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 
 | ||||||
|  | 	"github.com/tomnomnom/linkheader" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func main() { | ||||||
|  | 	header := "<https://api.github.com/user/58276/repos?page=2>; rel=\"next\"," + | ||||||
|  | 		"<https://api.github.com/user/58276/repos?page=2>; rel=\"last\"" | ||||||
|  | 	links := linkheader.Parse(header) | ||||||
|  | 
 | ||||||
|  | 	for _, link := range links { | ||||||
|  | 		fmt.Printf("URL: %s; Rel: %s\n", link.URL, link.Rel) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Output: | ||||||
|  | // URL: https://api.github.com/user/58276/repos?page=2; Rel: next | ||||||
|  | // URL: https://api.github.com/user/58276/repos?page=2; Rel: last | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
							
								
								
									
										151
									
								
								vendor/github.com/tomnomnom/linkheader/main.go
									
										
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										151
									
								
								vendor/github.com/tomnomnom/linkheader/main.go
									
										
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							|  | @ -0,0 +1,151 @@ | ||||||
|  | // Package linkheader provides functions for parsing HTTP Link headers | ||||||
|  | package linkheader | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // A Link is a single URL and related parameters | ||||||
|  | type Link struct { | ||||||
|  | 	URL    string | ||||||
|  | 	Rel    string | ||||||
|  | 	Params map[string]string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // HasParam returns if a Link has a particular parameter or not | ||||||
|  | func (l Link) HasParam(key string) bool { | ||||||
|  | 	for p := range l.Params { | ||||||
|  | 		if p == key { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Param returns the value of a parameter if it exists | ||||||
|  | func (l Link) Param(key string) string { | ||||||
|  | 	for k, v := range l.Params { | ||||||
|  | 		if key == k { | ||||||
|  | 			return v | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // String returns the string representation of a link | ||||||
|  | func (l Link) String() string { | ||||||
|  | 
 | ||||||
|  | 	p := make([]string, 0, len(l.Params)) | ||||||
|  | 	for k, v := range l.Params { | ||||||
|  | 		p = append(p, fmt.Sprintf("%s=\"%s\"", k, v)) | ||||||
|  | 	} | ||||||
|  | 	if l.Rel != "" { | ||||||
|  | 		p = append(p, fmt.Sprintf("%s=\"%s\"", "rel", l.Rel)) | ||||||
|  | 	} | ||||||
|  | 	return fmt.Sprintf("<%s>; %s", l.URL, strings.Join(p, "; ")) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Links is a slice of Link structs | ||||||
|  | type Links []Link | ||||||
|  | 
 | ||||||
|  | // FilterByRel filters a group of Links by the provided Rel attribute | ||||||
|  | func (l Links) FilterByRel(r string) Links { | ||||||
|  | 	links := make(Links, 0) | ||||||
|  | 	for _, link := range l { | ||||||
|  | 		if link.Rel == r { | ||||||
|  | 			links = append(links, link) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return links | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // String returns the string representation of multiple Links | ||||||
|  | // for use in HTTP responses etc | ||||||
|  | func (l Links) String() string { | ||||||
|  | 	if l == nil { | ||||||
|  | 		return fmt.Sprint(nil) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var strs []string | ||||||
|  | 	for _, link := range l { | ||||||
|  | 		strs = append(strs, link.String()) | ||||||
|  | 	} | ||||||
|  | 	return strings.Join(strs, ", ") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Parse parses a raw Link header in the form: | ||||||
|  | //   <url>; rel="foo", <url>; rel="bar"; wat="dis" | ||||||
|  | // returning a slice of Link structs | ||||||
|  | func Parse(raw string) Links { | ||||||
|  | 	var links Links | ||||||
|  | 
 | ||||||
|  | 	// One chunk: <url>; rel="foo" | ||||||
|  | 	for _, chunk := range strings.Split(raw, ",") { | ||||||
|  | 
 | ||||||
|  | 		link := Link{URL: "", Rel: "", Params: make(map[string]string)} | ||||||
|  | 
 | ||||||
|  | 		// Figure out what each piece of the chunk is | ||||||
|  | 		for _, piece := range strings.Split(chunk, ";") { | ||||||
|  | 
 | ||||||
|  | 			piece = strings.Trim(piece, " ") | ||||||
|  | 			if piece == "" { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// URL | ||||||
|  | 			if piece[0] == '<' && piece[len(piece)-1] == '>' { | ||||||
|  | 				link.URL = strings.Trim(piece, "<>") | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Params | ||||||
|  | 			key, val := parseParam(piece) | ||||||
|  | 			if key == "" { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Special case for rel | ||||||
|  | 			if strings.ToLower(key) == "rel" { | ||||||
|  | 				link.Rel = val | ||||||
|  | 			} else { | ||||||
|  | 				link.Params[key] = val | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if link.URL != "" { | ||||||
|  | 			links = append(links, link) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return links | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ParseMultiple is like Parse, but accepts a slice of headers | ||||||
|  | // rather than just one header string | ||||||
|  | func ParseMultiple(headers []string) Links { | ||||||
|  | 	links := make(Links, 0) | ||||||
|  | 	for _, header := range headers { | ||||||
|  | 		links = append(links, Parse(header)...) | ||||||
|  | 	} | ||||||
|  | 	return links | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // parseParam takes a raw param in the form key="val" and | ||||||
|  | // returns the key and value as seperate strings | ||||||
|  | func parseParam(raw string) (key, val string) { | ||||||
|  | 
 | ||||||
|  | 	parts := strings.SplitN(raw, "=", 2) | ||||||
|  | 	if len(parts) == 1 { | ||||||
|  | 		return parts[0], "" | ||||||
|  | 	} | ||||||
|  | 	if len(parts) != 2 { | ||||||
|  | 		return "", "" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	key = parts[0] | ||||||
|  | 	val = strings.Trim(parts[1], "\"") | ||||||
|  | 
 | ||||||
|  | 	return key, val | ||||||
|  | 
 | ||||||
|  | } | ||||||
							
								
								
									
										3
									
								
								vendor/modules.txt
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								vendor/modules.txt
									
										
									
									
										vendored
									
									
								
							|  | @ -672,6 +672,9 @@ github.com/tdewolff/parse/v2/strconv | ||||||
| # github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc | # github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc | ||||||
| ## explicit | ## explicit | ||||||
| github.com/tmthrgd/go-hex | github.com/tmthrgd/go-hex | ||||||
|  | # github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 | ||||||
|  | ## explicit | ||||||
|  | github.com/tomnomnom/linkheader | ||||||
| # github.com/twitchyliquid64/golang-asm v0.15.1 | # github.com/twitchyliquid64/golang-asm v0.15.1 | ||||||
| ## explicit; go 1.13 | ## explicit; go 1.13 | ||||||
| github.com/twitchyliquid64/golang-asm/asm/arch | github.com/twitchyliquid64/golang-asm/asm/arch | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue