mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 14:42:26 -05:00 
			
		
		
		
	[chore] update account statuses paging logic (#1814)
This commit is contained in:
		
					parent
					
						
							
								ea1bbacf4b
							
						
					
				
			
			
				commit
				
					
						c48abd8bc0
					
				
			
		
					 5 changed files with 389 additions and 126 deletions
				
			
		|  | @ -29,6 +29,7 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/id" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/state" | 	"github.com/superseriousbusiness/gotosocial/internal/state" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/util" | 	"github.com/superseriousbusiness/gotosocial/internal/util" | ||||||
|  | @ -475,48 +476,41 @@ func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (i | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, db.Error) { | func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, db.Error) { | ||||||
| 	statusIDs := []string{} | 	// Ensure reasonable | ||||||
|  | 	if limit < 0 { | ||||||
|  | 		limit = 0 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Make educated guess for slice size | ||||||
|  | 	var ( | ||||||
|  | 		statusIDs   = make([]string, 0, limit) | ||||||
|  | 		frontToBack = true | ||||||
|  | 	) | ||||||
| 
 | 
 | ||||||
| 	q := a.conn. | 	q := a.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). | 		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). | ||||||
|  | 		// Select only IDs from table | ||||||
| 		Column("status.id"). | 		Column("status.id"). | ||||||
| 		Order("status.id DESC") | 		Where("? = ?", bun.Ident("status.account_id"), accountID) | ||||||
| 
 |  | ||||||
| 	if accountID != "" { |  | ||||||
| 		q = q.Where("? = ?", bun.Ident("status.account_id"), accountID) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if limit != 0 { |  | ||||||
| 		q = q.Limit(limit) |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	if excludeReplies { | 	if excludeReplies { | ||||||
| 		// include self-replies (threads) | 		q = q.WhereGroup(" AND ", func(*bun.SelectQuery) *bun.SelectQuery { | ||||||
| 		whereGroup := func(*bun.SelectQuery) *bun.SelectQuery { |  | ||||||
| 			return q. | 			return q. | ||||||
| 				WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID). | 				// Do include self replies (threads), but | ||||||
| 				WhereGroup(" OR ", whereEmptyOrNull("status.in_reply_to_uri")) | 				// don't include replies to other people. | ||||||
| 		} | 				Where("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID). | ||||||
| 
 | 				WhereOr("? IS NULL", bun.Ident("status.in_reply_to_uri")) | ||||||
| 		q = q.WhereGroup(" AND ", whereGroup) | 		}) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if excludeReblogs { | 	if excludeReblogs { | ||||||
| 		q = q.WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")) | 		q = q.Where("? IS NULL", bun.Ident("status.boost_of_id")) | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if maxID != "" { |  | ||||||
| 		q = q.Where("? < ?", bun.Ident("status.id"), maxID) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if minID != "" { |  | ||||||
| 		q = q.Where("? > ?", bun.Ident("status.id"), minID) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if mediaOnly { | 	if mediaOnly { | ||||||
| 		// attachments are stored as a json object; | 		// Attachments are stored as a json object; this | ||||||
| 		// this implementation differs between sqlite and postgres, | 		// implementation differs between SQLite and Postgres, | ||||||
| 		// so we have to be thorough to cover all eventualities | 		// so we have to be thorough to cover all eventualities | ||||||
| 		q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery { | 		q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery { | ||||||
| 			switch a.conn.Dialect().Name() { | 			switch a.conn.Dialect().Name() { | ||||||
|  | @ -542,10 +536,46 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li | ||||||
| 		q = q.Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic) | 		q = q.Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// return only statuses LOWER (ie., older) than maxID | ||||||
|  | 	if maxID == "" { | ||||||
|  | 		maxID = id.Highest | ||||||
|  | 	} | ||||||
|  | 	q = q.Where("? < ?", bun.Ident("status.id"), maxID) | ||||||
|  | 
 | ||||||
|  | 	if minID != "" { | ||||||
|  | 		// return only statuses HIGHER (ie., newer) than minID | ||||||
|  | 		q = q.Where("? > ?", bun.Ident("status.id"), minID) | ||||||
|  | 
 | ||||||
|  | 		// page up | ||||||
|  | 		frontToBack = false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if limit > 0 { | ||||||
|  | 		// limit amount of statuses returned | ||||||
|  | 		q = q.Limit(limit) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if frontToBack { | ||||||
|  | 		// Page down. | ||||||
|  | 		q = q.Order("status.id DESC") | ||||||
|  | 	} else { | ||||||
|  | 		// Page up. | ||||||
|  | 		q = q.Order("status.id ASC") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	if err := q.Scan(ctx, &statusIDs); err != nil { | 	if err := q.Scan(ctx, &statusIDs); err != nil { | ||||||
| 		return nil, a.conn.ProcessError(err) | 		return nil, a.conn.ProcessError(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// If we're paging up, we still want statuses | ||||||
|  | 	// to be sorted by ID desc, so reverse ids slice. | ||||||
|  | 	// https://zchee.github.io/golang-wiki/SliceTricks/#reversing | ||||||
|  | 	if !frontToBack { | ||||||
|  | 		for l, r := 0, len(statusIDs)-1; l < r; l, r = l+1, r-1 { | ||||||
|  | 			statusIDs[l], statusIDs[r] = statusIDs[r], statusIDs[l] | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	return a.statusesFromIDs(ctx, statusIDs) | 	return a.statusesFromIDs(ctx, statusIDs) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -568,23 +598,45 @@ func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID stri | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, db.Error) { | func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, db.Error) { | ||||||
| 	statusIDs := []string{} | 	// Ensure reasonable | ||||||
|  | 	if limit < 0 { | ||||||
|  | 		limit = 0 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Make educated guess for slice size | ||||||
|  | 	statusIDs := make([]string, 0, limit) | ||||||
| 
 | 
 | ||||||
| 	q := a.conn. | 	q := a.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). | 		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). | ||||||
|  | 		// Select only IDs from table | ||||||
| 		Column("status.id"). | 		Column("status.id"). | ||||||
| 		Where("? = ?", bun.Ident("status.account_id"), accountID). | 		Where("? = ?", bun.Ident("status.account_id"), accountID). | ||||||
| 		WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")). | 		// Don't show replies or boosts. | ||||||
| 		WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")). | 		Where("? IS NULL", bun.Ident("status.in_reply_to_uri")). | ||||||
|  | 		Where("? IS NULL", bun.Ident("status.boost_of_id")). | ||||||
|  | 		// Only Public statuses. | ||||||
| 		Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic). | 		Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic). | ||||||
|  | 		// Don't show local-only statuses on the web view. | ||||||
| 		Where("? = ?", bun.Ident("status.federated"), true) | 		Where("? = ?", bun.Ident("status.federated"), true) | ||||||
| 
 | 
 | ||||||
| 	if maxID != "" { | 	// return only statuses LOWER (ie., older) than maxID | ||||||
| 		q = q.Where("? < ?", bun.Ident("status.id"), maxID) | 	if maxID == "" { | ||||||
|  | 		maxID = id.Highest | ||||||
|  | 	} | ||||||
|  | 	q = q.Where("? < ?", bun.Ident("status.id"), maxID) | ||||||
|  | 
 | ||||||
|  | 	if limit > 0 { | ||||||
|  | 		// limit amount of statuses returned | ||||||
|  | 		q = q.Limit(limit) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	q = q.Limit(limit).Order("status.id DESC") | 	if limit > 0 { | ||||||
|  | 		// limit amount of statuses returned | ||||||
|  | 		q = q.Limit(limit) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	q = q.Order("status.id DESC") | ||||||
| 
 | 
 | ||||||
| 	if err := q.Scan(ctx, &statusIDs); err != nil { | 	if err := q.Scan(ctx, &statusIDs); err != nil { | ||||||
| 		return nil, a.conn.ProcessError(err) | 		return nil, a.conn.ProcessError(err) | ||||||
|  |  | ||||||
|  | @ -45,6 +45,34 @@ func (suite *AccountTestSuite) TestGetAccountStatuses() { | ||||||
| 	suite.Len(statuses, 5) | 	suite.Len(statuses, 5) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (suite *AccountTestSuite) TestGetAccountStatusesPageDown() { | ||||||
|  | 	// get the first page | ||||||
|  | 	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, "", "", false, false) | ||||||
|  | 	if err != nil { | ||||||
|  | 		suite.FailNow(err.Error()) | ||||||
|  | 	} | ||||||
|  | 	suite.Len(statuses, 2) | ||||||
|  | 
 | ||||||
|  | 	// get the second page | ||||||
|  | 	statuses, err = suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, statuses[len(statuses)-1].ID, "", false, false) | ||||||
|  | 	if err != nil { | ||||||
|  | 		suite.FailNow(err.Error()) | ||||||
|  | 	} | ||||||
|  | 	suite.Len(statuses, 2) | ||||||
|  | 
 | ||||||
|  | 	// get the third page | ||||||
|  | 	statuses, err = suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, statuses[len(statuses)-1].ID, "", false, false) | ||||||
|  | 	if err != nil { | ||||||
|  | 		suite.FailNow(err.Error()) | ||||||
|  | 	} | ||||||
|  | 	suite.Len(statuses, 1) | ||||||
|  | 
 | ||||||
|  | 	// try to get the last page (should be empty) | ||||||
|  | 	statuses, err = suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 2, false, false, statuses[len(statuses)-1].ID, "", false, false) | ||||||
|  | 	suite.ErrorIs(err, db.ErrNoEntries) | ||||||
|  | 	suite.Empty(statuses) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogs() { | func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogs() { | ||||||
| 	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false) | 	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
|  |  | ||||||
|  | @ -26,16 +26,32 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/util" | 	"github.com/superseriousbusiness/gotosocial/internal/util" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // StatusesGet fetches a number of statuses (in time descending order) from the given account, filtered by visibility for | // StatusesGet fetches a number of statuses (in time descending order) from the | ||||||
| // the account given in authed. | // target account, filtered by visibility according to the requesting account. | ||||||
| func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, pinned bool, mediaOnly bool, publicOnly bool) (*apimodel.PageableResponse, gtserror.WithCode) { | func (p *Processor) StatusesGet( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	requestingAccount *gtsmodel.Account, | ||||||
|  | 	targetAccountID string, | ||||||
|  | 	limit int, | ||||||
|  | 	excludeReplies bool, | ||||||
|  | 	excludeReblogs bool, | ||||||
|  | 	maxID string, | ||||||
|  | 	minID string, | ||||||
|  | 	pinned bool, | ||||||
|  | 	mediaOnly bool, | ||||||
|  | 	publicOnly bool, | ||||||
|  | ) (*apimodel.PageableResponse, gtserror.WithCode) { | ||||||
| 	if requestingAccount != nil { | 	if requestingAccount != nil { | ||||||
| 		if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccountID); err != nil { | 		blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccountID) | ||||||
|  | 		if err != nil { | ||||||
| 			return nil, gtserror.NewErrorInternalError(err) | 			return nil, gtserror.NewErrorInternalError(err) | ||||||
| 		} else if blocked { | 		} | ||||||
|  | 
 | ||||||
|  | 		if blocked { | ||||||
| 			err := errors.New("block exists between accounts") | 			err := errors.New("block exists between accounts") | ||||||
| 			return nil, gtserror.NewErrorNotFound(err) | 			return nil, gtserror.NewErrorNotFound(err) | ||||||
| 		} | 		} | ||||||
|  | @ -45,6 +61,7 @@ func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel | ||||||
| 		statuses []*gtsmodel.Status | 		statuses []*gtsmodel.Status | ||||||
| 		err      error | 		err      error | ||||||
| 	) | 	) | ||||||
|  | 
 | ||||||
| 	if pinned { | 	if pinned { | ||||||
| 		// Get *ONLY* pinned statuses. | 		// Get *ONLY* pinned statuses. | ||||||
| 		statuses, err = p.state.DB.GetAccountPinnedStatuses(ctx, targetAccountID) | 		statuses, err = p.state.DB.GetAccountPinnedStatuses(ctx, targetAccountID) | ||||||
|  | @ -52,14 +69,17 @@ func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel | ||||||
| 		// Get account statuses which *may* include pinned ones. | 		// Get account statuses which *may* include pinned ones. | ||||||
| 		statuses, err = p.state.DB.GetAccountStatuses(ctx, targetAccountID, limit, excludeReplies, excludeReblogs, maxID, minID, mediaOnly, publicOnly) | 		statuses, err = p.state.DB.GetAccountStatuses(ctx, targetAccountID, limit, excludeReplies, excludeReblogs, maxID, minID, mediaOnly, publicOnly) | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 
 | ||||||
| 		if err == db.ErrNoEntries { | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 			return util.EmptyPageableResponse(), nil |  | ||||||
| 		} |  | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Filtering + serialization process is the same for either pinned status queries or 'normal' ones. | 	if len(statuses) == 0 { | ||||||
|  | 		return util.EmptyPageableResponse(), nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Filtering + serialization process is the same for | ||||||
|  | 	// both pinned status queries and 'normal' ones. | ||||||
| 	filtered, err := p.filter.StatusesVisible(ctx, requestingAccount, statuses) | 	filtered, err := p.filter.StatusesVisible(ctx, requestingAccount, statuses) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
|  | @ -67,24 +87,32 @@ func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel | ||||||
| 
 | 
 | ||||||
| 	count := len(filtered) | 	count := len(filtered) | ||||||
| 	if count == 0 { | 	if count == 0 { | ||||||
|  | 		// After filtering there were | ||||||
|  | 		// no statuses left to serve. | ||||||
| 		return util.EmptyPageableResponse(), nil | 		return util.EmptyPageableResponse(), nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	items := make([]interface{}, 0, count) | 	var ( | ||||||
| 	nextMaxIDValue := "" | 		items          = make([]interface{}, 0, count) | ||||||
| 	prevMinIDValue := "" | 		nextMaxIDValue string | ||||||
| 	for i, s := range filtered { | 		prevMinIDValue string | ||||||
| 		item, err := p.tc.StatusToAPIStatus(ctx, s, requestingAccount) | 	) | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status to api: %s", err)) |  | ||||||
| 		} |  | ||||||
| 
 | 
 | ||||||
|  | 	for i, s := range filtered { | ||||||
|  | 		// Set next + prev values before filtering and API | ||||||
|  | 		// converting, so caller can still page properly. | ||||||
| 		if i == count-1 { | 		if i == count-1 { | ||||||
| 			nextMaxIDValue = item.GetID() | 			nextMaxIDValue = s.ID | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if i == 0 { | 		if i == 0 { | ||||||
| 			prevMinIDValue = item.GetID() | 			prevMinIDValue = s.ID | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		item, err := p.tc.StatusToAPIStatus(ctx, s, requestingAccount) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Debugf(ctx, "skipping status %s because it couldn't be converted to its api representation: %s", s.ID, err) | ||||||
|  | 			continue | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		items = append(items, item) | 		items = append(items, item) | ||||||
|  | @ -100,7 +128,7 @@ func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel | ||||||
| 
 | 
 | ||||||
| 	return util.PackagePageableResponse(util.PageableResponseParams{ | 	return util.PackagePageableResponse(util.PageableResponseParams{ | ||||||
| 		Items:          items, | 		Items:          items, | ||||||
| 		Path:           fmt.Sprintf("/api/v1/accounts/%s/statuses", targetAccountID), | 		Path:           "/api/v1/accounts/" + targetAccountID + "/statuses", | ||||||
| 		NextMaxIDValue: nextMaxIDValue, | 		NextMaxIDValue: nextMaxIDValue, | ||||||
| 		PrevMinIDValue: prevMinIDValue, | 		PrevMinIDValue: prevMinIDValue, | ||||||
| 		Limit:          limit, | 		Limit:          limit, | ||||||
|  | @ -114,62 +142,58 @@ func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // WebStatusesGet fetches a number of statuses (in descending order) from the given account. It selects only | // WebStatusesGet fetches a number of statuses (in descending order) | ||||||
| // statuses which are suitable for showing on the public web profile of an account. | // from the given account. It selects only statuses which are suitable | ||||||
|  | // for showing on the public web profile of an account. | ||||||
| func (p *Processor) WebStatusesGet(ctx context.Context, targetAccountID string, maxID string) (*apimodel.PageableResponse, gtserror.WithCode) { | func (p *Processor) WebStatusesGet(ctx context.Context, targetAccountID string, maxID string) (*apimodel.PageableResponse, gtserror.WithCode) { | ||||||
| 	acct, err := p.state.DB.GetAccountByID(ctx, targetAccountID) | 	account, err := p.state.DB.GetAccountByID(ctx, targetAccountID) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		if err == db.ErrNoEntries { | 		if errors.Is(err, db.ErrNoEntries) { | ||||||
| 			err := fmt.Errorf("account %s not found in the db, not getting web statuses for it", targetAccountID) | 			err := fmt.Errorf("account %s not found in the db, not getting web statuses for it", targetAccountID) | ||||||
| 			return nil, gtserror.NewErrorNotFound(err) | 			return nil, gtserror.NewErrorNotFound(err) | ||||||
| 		} | 		} | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if acct.Domain != "" { | 	if account.Domain != "" { | ||||||
| 		err := fmt.Errorf("account %s was not a local account, not getting web statuses for it", targetAccountID) | 		err := fmt.Errorf("account %s was not a local account, not getting web statuses for it", targetAccountID) | ||||||
| 		return nil, gtserror.NewErrorNotFound(err) | 		return nil, gtserror.NewErrorNotFound(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	statuses, err := p.state.DB.GetAccountWebStatuses(ctx, targetAccountID, 10, maxID) | 	statuses, err := p.state.DB.GetAccountWebStatuses(ctx, targetAccountID, 10, maxID) | ||||||
| 	if err != nil { | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | ||||||
| 		if err == db.ErrNoEntries { |  | ||||||
| 			return util.EmptyPageableResponse(), nil |  | ||||||
| 		} |  | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	count := len(statuses) | 	count := len(statuses) | ||||||
| 
 |  | ||||||
| 	if count == 0 { | 	if count == 0 { | ||||||
| 		return util.EmptyPageableResponse(), nil | 		return util.EmptyPageableResponse(), nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	items := []interface{}{} | 	var ( | ||||||
| 	nextMaxIDValue := "" | 		items          = make([]interface{}, 0, count) | ||||||
| 	prevMinIDValue := "" | 		nextMaxIDValue string | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
| 	for i, s := range statuses { | 	for i, s := range statuses { | ||||||
|  | 		// Set next value before API converting, | ||||||
|  | 		// so caller can still page properly. | ||||||
|  | 		if i == count-1 { | ||||||
|  | 			nextMaxIDValue = s.ID | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		item, err := p.tc.StatusToAPIStatus(ctx, s, nil) | 		item, err := p.tc.StatusToAPIStatus(ctx, s, nil) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting status to api: %s", err)) | 			log.Debugf(ctx, "skipping status %s because it couldn't be converted to its api representation: %s", s.ID, err) | ||||||
| 		} | 			continue | ||||||
| 
 |  | ||||||
| 		if i == count-1 { |  | ||||||
| 			nextMaxIDValue = item.GetID() |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if i == 0 { |  | ||||||
| 			prevMinIDValue = item.GetID() |  | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		items = append(items, item) | 		items = append(items, item) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return util.PackagePageableResponse(util.PageableResponseParams{ | 	return util.PackagePageableResponse(util.PageableResponseParams{ | ||||||
| 		Items:            items, | 		Items:          items, | ||||||
| 		Path:             "/@" + acct.Username, | 		Path:           "/@" + account.Username, | ||||||
| 		NextMaxIDValue:   nextMaxIDValue, | 		NextMaxIDValue: nextMaxIDValue, | ||||||
| 		PrevMinIDValue:   prevMinIDValue, |  | ||||||
| 		ExtraQueryParams: []string{}, |  | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -20,6 +20,7 @@ package util | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/url" | 	"net/url" | ||||||
|  | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/config" | 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||||
|  | @ -47,6 +48,13 @@ type PageableResponseParams struct { | ||||||
| // a bunch of pageable items (notifications, statuses, etc), as well | // a bunch of pageable items (notifications, statuses, etc), as well | ||||||
| // as a Link header to inform callers of where to find next/prev items. | // as a Link header to inform callers of where to find next/prev items. | ||||||
| func PackagePageableResponse(params PageableResponseParams) (*apimodel.PageableResponse, gtserror.WithCode) { | func PackagePageableResponse(params PageableResponseParams) (*apimodel.PageableResponse, gtserror.WithCode) { | ||||||
|  | 	if len(params.Items) == 0 { | ||||||
|  | 		// No items to page through. | ||||||
|  | 		return EmptyPageableResponse(), nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Set default paging values, if | ||||||
|  | 	// they weren't set by the caller. | ||||||
| 	if params.NextMaxIDKey == "" { | 	if params.NextMaxIDKey == "" { | ||||||
| 		params.NextMaxIDKey = "max_id" | 		params.NextMaxIDKey = "max_id" | ||||||
| 	} | 	} | ||||||
|  | @ -55,58 +63,70 @@ func PackagePageableResponse(params PageableResponseParams) (*apimodel.PageableR | ||||||
| 		params.PrevMinIDKey = "min_id" | 		params.PrevMinIDKey = "min_id" | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	pageableResponse := EmptyPageableResponse() | 	var ( | ||||||
|  | 		protocol        = config.GetProtocol() | ||||||
|  | 		host            = config.GetHost() | ||||||
|  | 		nextLink        string | ||||||
|  | 		prevLink        string | ||||||
|  | 		linkHeaderParts = make([]string, 0, 2) | ||||||
|  | 	) | ||||||
| 
 | 
 | ||||||
| 	if len(params.Items) == 0 { | 	// Parse next link. | ||||||
| 		return pageableResponse, nil | 	if params.NextMaxIDValue != "" { | ||||||
|  | 		nextRaw := params.NextMaxIDKey + "=" + params.NextMaxIDValue | ||||||
|  | 
 | ||||||
|  | 		if params.Limit != 0 { | ||||||
|  | 			nextRaw = fmt.Sprintf("limit=%d&", params.Limit) + nextRaw | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for _, p := range params.ExtraQueryParams { | ||||||
|  | 			nextRaw += "&" + p | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		nextLink = func() string { | ||||||
|  | 			u := &url.URL{ | ||||||
|  | 				Scheme:   protocol, | ||||||
|  | 				Host:     host, | ||||||
|  | 				Path:     params.Path, | ||||||
|  | 				RawQuery: nextRaw, | ||||||
|  | 			} | ||||||
|  | 			return u.String() | ||||||
|  | 		}() | ||||||
|  | 
 | ||||||
|  | 		linkHeaderParts = append(linkHeaderParts, `<`+nextLink+`>; rel="next"`) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// items | 	// Parse prev link. | ||||||
| 	pageableResponse.Items = params.Items | 	if params.PrevMinIDValue != "" { | ||||||
|  | 		prevRaw := params.PrevMinIDKey + "=" + params.PrevMinIDValue | ||||||
| 
 | 
 | ||||||
| 	protocol := config.GetProtocol() | 		if params.Limit != 0 { | ||||||
| 	host := config.GetHost() | 			prevRaw = fmt.Sprintf("limit=%d&", params.Limit) + prevRaw | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 	// next | 		for _, p := range params.ExtraQueryParams { | ||||||
| 	nextRaw := params.NextMaxIDKey + "=" + params.NextMaxIDValue | 			prevRaw = prevRaw + "&" + p | ||||||
| 	if params.Limit != 0 { | 		} | ||||||
| 		nextRaw = fmt.Sprintf("limit=%d&", params.Limit) + nextRaw |  | ||||||
| 	} |  | ||||||
| 	for _, p := range params.ExtraQueryParams { |  | ||||||
| 		nextRaw = nextRaw + "&" + p |  | ||||||
| 	} |  | ||||||
| 	nextLink := &url.URL{ |  | ||||||
| 		Scheme:   protocol, |  | ||||||
| 		Host:     host, |  | ||||||
| 		Path:     params.Path, |  | ||||||
| 		RawQuery: nextRaw, |  | ||||||
| 	} |  | ||||||
| 	nextLinkString := nextLink.String() |  | ||||||
| 	pageableResponse.NextLink = nextLinkString |  | ||||||
| 
 | 
 | ||||||
| 	// prev | 		prevLink = func() string { | ||||||
| 	prevRaw := params.PrevMinIDKey + "=" + params.PrevMinIDValue | 			u := &url.URL{ | ||||||
| 	if params.Limit != 0 { | 				Scheme:   protocol, | ||||||
| 		prevRaw = fmt.Sprintf("limit=%d&", params.Limit) + prevRaw | 				Host:     host, | ||||||
| 	} | 				Path:     params.Path, | ||||||
| 	for _, p := range params.ExtraQueryParams { | 				RawQuery: prevRaw, | ||||||
| 		prevRaw = prevRaw + "&" + p | 			} | ||||||
| 	} | 			return u.String() | ||||||
| 	prevLink := &url.URL{ | 		}() | ||||||
| 		Scheme:   protocol, |  | ||||||
| 		Host:     host, |  | ||||||
| 		Path:     params.Path, |  | ||||||
| 		RawQuery: prevRaw, |  | ||||||
| 	} |  | ||||||
| 	prevLinkString := prevLink.String() |  | ||||||
| 	pageableResponse.PrevLink = prevLinkString |  | ||||||
| 
 | 
 | ||||||
| 	// link header | 		linkHeaderParts = append(linkHeaderParts, `<`+prevLink+`>; rel="prev"`) | ||||||
| 	next := fmt.Sprintf("<%s>; rel=\"next\"", nextLinkString) | 	} | ||||||
| 	prev := fmt.Sprintf("<%s>; rel=\"prev\"", prevLinkString) |  | ||||||
| 	pageableResponse.LinkHeader = next + ", " + prev |  | ||||||
| 
 | 
 | ||||||
| 	return pageableResponse, nil | 	return &apimodel.PageableResponse{ | ||||||
|  | 		Items:      params.Items, | ||||||
|  | 		LinkHeader: strings.Join(linkHeaderParts, ", "), | ||||||
|  | 		NextLink:   nextLink, | ||||||
|  | 		PrevLink:   prevLink, | ||||||
|  | 	}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // EmptyPageableResponse just returns an empty | // EmptyPageableResponse just returns an empty | ||||||
|  |  | ||||||
							
								
								
									
										139
									
								
								internal/util/paging_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								internal/util/paging_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,139 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package util_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"github.com/stretchr/testify/suite" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/util" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type PagingSuite struct { | ||||||
|  | 	suite.Suite | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PagingSuite) TestPagingStandard() { | ||||||
|  | 	config.SetHost("example.org") | ||||||
|  | 
 | ||||||
|  | 	params := util.PageableResponseParams{ | ||||||
|  | 		Items:          make([]interface{}, 10, 10), | ||||||
|  | 		Path:           "/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses", | ||||||
|  | 		NextMaxIDValue: "01H11KA1DM2VH3747YDE7FV5HN", | ||||||
|  | 		PrevMinIDValue: "01H11KBBVRRDYYC5KEPME1NP5R", | ||||||
|  | 		Limit:          10, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	resp, errWithCode := util.PackagePageableResponse(params) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		suite.FailNow(errWithCode.Error()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	suite.Equal(make([]interface{}, 10, 10), resp.Items) | ||||||
|  | 	suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN>; rel="next", <https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R>; rel="prev"`, resp.LinkHeader) | ||||||
|  | 	suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN`, resp.NextLink) | ||||||
|  | 	suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R`, resp.PrevLink) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PagingSuite) TestPagingNoLimit() { | ||||||
|  | 	config.SetHost("example.org") | ||||||
|  | 
 | ||||||
|  | 	params := util.PageableResponseParams{ | ||||||
|  | 		Items:          make([]interface{}, 10, 10), | ||||||
|  | 		Path:           "/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses", | ||||||
|  | 		NextMaxIDValue: "01H11KA1DM2VH3747YDE7FV5HN", | ||||||
|  | 		PrevMinIDValue: "01H11KBBVRRDYYC5KEPME1NP5R", | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	resp, errWithCode := util.PackagePageableResponse(params) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		suite.FailNow(errWithCode.Error()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	suite.Equal(make([]interface{}, 10, 10), resp.Items) | ||||||
|  | 	suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?max_id=01H11KA1DM2VH3747YDE7FV5HN>; rel="next", <https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?min_id=01H11KBBVRRDYYC5KEPME1NP5R>; rel="prev"`, resp.LinkHeader) | ||||||
|  | 	suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?max_id=01H11KA1DM2VH3747YDE7FV5HN`, resp.NextLink) | ||||||
|  | 	suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?min_id=01H11KBBVRRDYYC5KEPME1NP5R`, resp.PrevLink) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PagingSuite) TestPagingNoNextID() { | ||||||
|  | 	config.SetHost("example.org") | ||||||
|  | 
 | ||||||
|  | 	params := util.PageableResponseParams{ | ||||||
|  | 		Items:          make([]interface{}, 10, 10), | ||||||
|  | 		Path:           "/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses", | ||||||
|  | 		PrevMinIDValue: "01H11KBBVRRDYYC5KEPME1NP5R", | ||||||
|  | 		Limit:          10, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	resp, errWithCode := util.PackagePageableResponse(params) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		suite.FailNow(errWithCode.Error()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	suite.Equal(make([]interface{}, 10, 10), resp.Items) | ||||||
|  | 	suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R>; rel="prev"`, resp.LinkHeader) | ||||||
|  | 	suite.Equal(``, resp.NextLink) | ||||||
|  | 	suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&min_id=01H11KBBVRRDYYC5KEPME1NP5R`, resp.PrevLink) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PagingSuite) TestPagingNoPrevID() { | ||||||
|  | 	config.SetHost("example.org") | ||||||
|  | 
 | ||||||
|  | 	params := util.PageableResponseParams{ | ||||||
|  | 		Items:          make([]interface{}, 10, 10), | ||||||
|  | 		Path:           "/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses", | ||||||
|  | 		NextMaxIDValue: "01H11KA1DM2VH3747YDE7FV5HN", | ||||||
|  | 		Limit:          10, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	resp, errWithCode := util.PackagePageableResponse(params) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		suite.FailNow(errWithCode.Error()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	suite.Equal(make([]interface{}, 10, 10), resp.Items) | ||||||
|  | 	suite.Equal(`<https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN>; rel="next"`, resp.LinkHeader) | ||||||
|  | 	suite.Equal(`https://example.org/api/v1/accounts/01H11KA68PM4NNYJEG0FJQ90R3/statuses?limit=10&max_id=01H11KA1DM2VH3747YDE7FV5HN`, resp.NextLink) | ||||||
|  | 	suite.Equal(``, resp.PrevLink) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *PagingSuite) TestPagingNoItems() { | ||||||
|  | 	config.SetHost("example.org") | ||||||
|  | 
 | ||||||
|  | 	params := util.PageableResponseParams{ | ||||||
|  | 		NextMaxIDValue: "01H11KA1DM2VH3747YDE7FV5HN", | ||||||
|  | 		PrevMinIDValue: "01H11KBBVRRDYYC5KEPME1NP5R", | ||||||
|  | 		Limit:          10, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	resp, errWithCode := util.PackagePageableResponse(params) | ||||||
|  | 	if errWithCode != nil { | ||||||
|  | 		suite.FailNow(errWithCode.Error()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	suite.Empty(resp.Items) | ||||||
|  | 	suite.Empty(resp.LinkHeader) | ||||||
|  | 	suite.Empty(resp.NextLink) | ||||||
|  | 	suite.Empty(resp.PrevLink) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestPagingSuite(t *testing.T) { | ||||||
|  | 	suite.Run(t, &PagingSuite{}) | ||||||
|  | } | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue