mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 01:02:25 -05:00 
			
		
		
		
	[chore] Standardize database queries, use bun.Ident() properly (#886)
		
	* use bun.Ident for user queries * use bun.Ident for account queries * use bun.Ident for media queries * add DeleteAccount func * remove CaseInsensitive in Where+use Ident ipv Safe * update admin db * update domain, use ident * update emoji, use ident * update instance queries, use bun.Ident * fix media * update mentions, use bun ident * update relationship + tests * use tableexpr * add test follows to bun db test suite * update notifications * updatebyprimarykey => updatebyid * fix session * prefer explicit ID to pk * fix little fucky wucky * remove workaround * use proper db func for attachment selection * update status db * add m2m entries in test rig * fix up timeline * go fmt * fix status put issue * update GetAccountStatuses
This commit is contained in:
		
					parent
					
						
							
								e58a6a2da3
							
						
					
				
			
			
				commit
				
					
						aa07750bdb
					
				
			
		
					 45 changed files with 1074 additions and 570 deletions
				
			
		|  | @ -101,7 +101,7 @@ var Confirm action.GTSAction = func(ctx context.Context) error { | ||||||
| 	u.Email = u.UnconfirmedEmail | 	u.Email = u.UnconfirmedEmail | ||||||
| 	u.ConfirmedAt = time.Now() | 	u.ConfirmedAt = time.Now() | ||||||
| 	u.UpdatedAt = time.Now() | 	u.UpdatedAt = time.Now() | ||||||
| 	if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil { | 	if err := dbConn.UpdateByID(ctx, u, u.ID, updatingColumns...); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										5
									
								
								internal/cache/account.go
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								internal/cache/account.go
									
										
									
									
										vendored
									
									
								
							|  | @ -101,6 +101,11 @@ func (c *AccountCache) Put(account *gtsmodel.Account) { | ||||||
| 	c.cache.Set(account.ID, copyAccount(account)) | 	c.cache.Set(account.ID, copyAccount(account)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Invalidate removes (invalidates) one account from the cache by its ID. | ||||||
|  | func (c *AccountCache) Invalidate(id string) { | ||||||
|  | 	c.cache.Invalidate(id) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects. | // copyAccount performs a surface-level copy of account, only keeping attached IDs intact, not the objects. | ||||||
| // due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr) | // due to all the data being copied being 99% primitive types or strings (which are immutable and passed by ptr) | ||||||
| // this should be a relatively cheap process | // this should be a relatively cheap process | ||||||
|  |  | ||||||
|  | @ -48,6 +48,11 @@ type Account interface { | ||||||
| 	// UpdateAccount updates one account by ID. | 	// UpdateAccount updates one account by ID. | ||||||
| 	UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) | 	UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error) | ||||||
| 
 | 
 | ||||||
|  | 	// DeleteAccount deletes one account from the database by its ID. | ||||||
|  | 	// DO NOT USE THIS WHEN SUSPENDING ACCOUNTS! In that case you should mark the | ||||||
|  | 	// account as suspended instead, rather than deleting from the db entirely. | ||||||
|  | 	DeleteAccount(ctx context.Context, id string) Error | ||||||
|  | 
 | ||||||
| 	// GetAccountCustomCSSByUsername returns the custom css of an account on this instance with the given username. | 	// GetAccountCustomCSSByUsername returns the custom css of an account on this instance with the given username. | ||||||
| 	GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, Error) | 	GetAccountCustomCSSByUsername(ctx context.Context, username string) (string, Error) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -62,11 +62,11 @@ type Basic interface { | ||||||
| 	// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. | 	// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. | ||||||
| 	Put(ctx context.Context, i interface{}) Error | 	Put(ctx context.Context, i interface{}) Error | ||||||
| 
 | 
 | ||||||
| 	// UpdateByPrimaryKey updates values of i based on its primary key. | 	// UpdateByID updates values of i based on its id. | ||||||
| 	// If any columns are specified, these will be updated exclusively. | 	// If any columns are specified, these will be updated exclusively. | ||||||
| 	// Otherwise, the whole model will be updated. | 	// Otherwise, the whole model will be updated. | ||||||
| 	// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. | 	// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. | ||||||
| 	UpdateByPrimaryKey(ctx context.Context, i interface{}, columns ...string) Error | 	UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) Error | ||||||
| 
 | 
 | ||||||
| 	// UpdateWhere updates column key of interface i with the given value, where the given parameters apply. | 	// UpdateWhere updates column key of interface i with the given value, where the given parameters apply. | ||||||
| 	UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error | 	UpdateWhere(ctx context.Context, where []Where, key string, value interface{}, i interface{}) Error | ||||||
|  |  | ||||||
|  | @ -21,7 +21,6 @@ package bundb | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | @ -56,7 +55,7 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac | ||||||
| 			return a.cache.GetByID(id) | 			return a.cache.GetByID(id) | ||||||
| 		}, | 		}, | ||||||
| 		func(account *gtsmodel.Account) error { | 		func(account *gtsmodel.Account) error { | ||||||
| 			return a.newAccountQ(account).Where("account.id = ?", id).Scan(ctx) | 			return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx) | ||||||
| 		}, | 		}, | ||||||
| 	) | 	) | ||||||
| } | } | ||||||
|  | @ -68,7 +67,7 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel. | ||||||
| 			return a.cache.GetByURI(uri) | 			return a.cache.GetByURI(uri) | ||||||
| 		}, | 		}, | ||||||
| 		func(account *gtsmodel.Account) error { | 		func(account *gtsmodel.Account) error { | ||||||
| 			return a.newAccountQ(account).Where("account.uri = ?", uri).Scan(ctx) | 			return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx) | ||||||
| 		}, | 		}, | ||||||
| 	) | 	) | ||||||
| } | } | ||||||
|  | @ -80,7 +79,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel. | ||||||
| 			return a.cache.GetByURL(url) | 			return a.cache.GetByURL(url) | ||||||
| 		}, | 		}, | ||||||
| 		func(account *gtsmodel.Account) error { | 		func(account *gtsmodel.Account) error { | ||||||
| 			return a.newAccountQ(account).Where("account.url = ?", url).Scan(ctx) | 			return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx) | ||||||
| 		}, | 		}, | ||||||
| 	) | 	) | ||||||
| } | } | ||||||
|  | @ -95,11 +94,11 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str | ||||||
| 			q := a.newAccountQ(account) | 			q := a.newAccountQ(account) | ||||||
| 
 | 
 | ||||||
| 			if domain != "" { | 			if domain != "" { | ||||||
| 				q = q.Where("account.username = ?", username) | 				q = q.Where("? = ?", bun.Ident("account.username"), username) | ||||||
| 				q = q.Where("account.domain = ?", domain) | 				q = q.Where("? = ?", bun.Ident("account.domain"), domain) | ||||||
| 			} else { | 			} else { | ||||||
| 				q = q.Where("account.username = ?", strings.ToLower(username)) | 				q = q.Where("? = ?", bun.Ident("account.username"), strings.ToLower(username)) | ||||||
| 				q = q.Where("account.domain IS NULL") | 				q = q.Where("? IS NULL", bun.Ident("account.domain")) | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			return q.Scan(ctx) | 			return q.Scan(ctx) | ||||||
|  | @ -114,7 +113,7 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo | ||||||
| 			return a.cache.GetByPubkeyID(id) | 			return a.cache.GetByPubkeyID(id) | ||||||
| 		}, | 		}, | ||||||
| 		func(account *gtsmodel.Account) error { | 		func(account *gtsmodel.Account) error { | ||||||
| 			return a.newAccountQ(account).Where("account.public_key_uri = ?", id).Scan(ctx) | 			return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx) | ||||||
| 		}, | 		}, | ||||||
| 	) | 	) | ||||||
| } | } | ||||||
|  | @ -169,16 +168,19 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account | ||||||
| 	if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { | 	if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { | ||||||
| 		// create links between this account and any emojis it uses | 		// create links between this account and any emojis it uses | ||||||
| 		// first clear out any old emoji links | 		// first clear out any old emoji links | ||||||
| 		if _, err := tx.NewDelete(). | 		if _, err := tx. | ||||||
| 			Model(&[]*gtsmodel.AccountToEmoji{}). | 			NewDelete(). | ||||||
| 			Where("account_id = ?", account.ID). | 			TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). | ||||||
|  | 			Where("? = ?", bun.Ident("account_to_emoji.account_id"), account.ID). | ||||||
| 			Exec(ctx); err != nil { | 			Exec(ctx); err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// now populate new emoji links | 		// now populate new emoji links | ||||||
| 		for _, i := range account.EmojiIDs { | 		for _, i := range account.EmojiIDs { | ||||||
| 			if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ | 			if _, err := tx. | ||||||
|  | 				NewInsert(). | ||||||
|  | 				Model(>smodel.AccountToEmoji{ | ||||||
| 					AccountID: account.ID, | 					AccountID: account.ID, | ||||||
| 					EmojiID:   i, | 					EmojiID:   i, | ||||||
| 				}).Exec(ctx); err != nil { | 				}).Exec(ctx); err != nil { | ||||||
|  | @ -187,8 +189,15 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// update the account | 		// update the account | ||||||
| 		_, err := tx.NewUpdate().Model(account).WherePK().Exec(ctx) | 		if _, err := tx. | ||||||
|  | 			NewUpdate(). | ||||||
|  | 			Model(account). | ||||||
|  | 			Where("? = ?", bun.Ident("account.id"), account.ID). | ||||||
|  | 			Exec(ctx); err != nil { | ||||||
| 			return err | 			return err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return nil | ||||||
| 	}); err != nil { | 	}); err != nil { | ||||||
| 		return nil, a.conn.ProcessError(err) | 		return nil, a.conn.ProcessError(err) | ||||||
| 	} | 	} | ||||||
|  | @ -197,6 +206,32 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account | ||||||
| 	return account, nil | 	return account, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error { | ||||||
|  | 	if err := a.conn.RunInTx(ctx, func(tx bun.Tx) error { | ||||||
|  | 		// clear out any emoji links | ||||||
|  | 		if _, err := tx. | ||||||
|  | 			NewDelete(). | ||||||
|  | 			TableExpr("? AS ?", bun.Ident("account_to_emojis"), bun.Ident("account_to_emoji")). | ||||||
|  | 			Where("? = ?", bun.Ident("account_to_emoji.account_id"), id). | ||||||
|  | 			Exec(ctx); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// delete the account | ||||||
|  | 		_, err := tx. | ||||||
|  | 			NewUpdate(). | ||||||
|  | 			TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). | ||||||
|  | 			Where("? = ?", bun.Ident("account.id"), id). | ||||||
|  | 			Exec(ctx) | ||||||
|  | 		return err | ||||||
|  | 	}); err != nil { | ||||||
|  | 		return a.conn.ProcessError(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	a.cache.Invalidate(id) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { | func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) { | ||||||
| 	account := new(gtsmodel.Account) | 	account := new(gtsmodel.Account) | ||||||
| 
 | 
 | ||||||
|  | @ -204,11 +239,11 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts | ||||||
| 
 | 
 | ||||||
| 	if domain != "" { | 	if domain != "" { | ||||||
| 		q = q. | 		q = q. | ||||||
| 			Where("account.username = ?", domain). | 			Where("? = ?", bun.Ident("account.username"), domain). | ||||||
| 			Where("account.domain = ?", domain) | 			Where("? = ?", bun.Ident("account.domain"), domain) | ||||||
| 	} else { | 	} else { | ||||||
| 		q = q. | 		q = q. | ||||||
| 			Where("account.username = ?", config.GetHost()). | 			Where("? = ?", bun.Ident("account.username"), config.GetHost()). | ||||||
| 			WhereGroup(" AND ", whereEmptyOrNull("domain")) | 			WhereGroup(" AND ", whereEmptyOrNull("domain")) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -224,10 +259,10 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) | ||||||
| 	q := a.conn. | 	q := a.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(status). | 		Model(status). | ||||||
| 		Order("id DESC"). | 		Column("status.created_at"). | ||||||
| 		Limit(1). | 		Where("? = ?", bun.Ident("status.account_id"), accountID). | ||||||
| 		Where("account_id = ?", accountID). | 		Order("status.id DESC"). | ||||||
| 		Column("created_at") | 		Limit(1) | ||||||
| 
 | 
 | ||||||
| 	if err := q.Scan(ctx); err != nil { | 	if err := q.Scan(ctx); err != nil { | ||||||
| 		return time.Time{}, a.conn.ProcessError(err) | 		return time.Time{}, a.conn.ProcessError(err) | ||||||
|  | @ -240,12 +275,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen | ||||||
| 		return errors.New("one media attachment cannot be both header and avatar") | 		return errors.New("one media attachment cannot be both header and avatar") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var headerOrAVI string | 	var column bun.Ident | ||||||
| 	switch { | 	switch { | ||||||
| 	case *mediaAttachment.Avatar: | 	case *mediaAttachment.Avatar: | ||||||
| 		headerOrAVI = "avatar" | 		column = bun.Ident("account.avatar_media_attachment_id") | ||||||
| 	case *mediaAttachment.Header: | 	case *mediaAttachment.Header: | ||||||
| 		headerOrAVI = "header" | 		column = bun.Ident("account.header_media_attachment_id") | ||||||
| 	default: | 	default: | ||||||
| 		return errors.New("given media attachment was neither a header nor an avatar") | 		return errors.New("given media attachment was neither a header nor an avatar") | ||||||
| 	} | 	} | ||||||
|  | @ -257,11 +292,12 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen | ||||||
| 		Exec(ctx); err != nil { | 		Exec(ctx); err != nil { | ||||||
| 		return a.conn.ProcessError(err) | 		return a.conn.ProcessError(err) | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	if _, err := a.conn. | 	if _, err := a.conn. | ||||||
| 		NewUpdate(). | 		NewUpdate(). | ||||||
| 		Model(>smodel.Account{}). | 		TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). | ||||||
| 		Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID). | 		Set("? = ?", column, mediaAttachment.ID). | ||||||
| 		Where("id = ?", accountID). | 		Where("? = ?", bun.Ident("account.id"), accountID). | ||||||
| 		Exec(ctx); err != nil { | 		Exec(ctx); err != nil { | ||||||
| 		return a.conn.ProcessError(err) | 		return a.conn.ProcessError(err) | ||||||
| 	} | 	} | ||||||
|  | @ -284,7 +320,7 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g | ||||||
| 	if err := a.conn. | 	if err := a.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(faves). | 		Model(faves). | ||||||
| 		Where("account_id = ?", accountID). | 		Where("? = ?", bun.Ident("status_fave.account_id"), accountID). | ||||||
| 		Scan(ctx); err != nil { | 		Scan(ctx); err != nil { | ||||||
| 		return nil, a.conn.ProcessError(err) | 		return nil, a.conn.ProcessError(err) | ||||||
| 	} | 	} | ||||||
|  | @ -295,8 +331,8 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g | ||||||
| func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) { | func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, db.Error) { | ||||||
| 	return a.conn. | 	return a.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(>smodel.Status{}). | 		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). | ||||||
| 		Where("account_id = ?", accountID). | 		Where("? = ?", bun.Ident("status.account_id"), accountID). | ||||||
| 		Count(ctx) | 		Count(ctx) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -305,12 +341,12 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li | ||||||
| 
 | 
 | ||||||
| 	q := a.conn. | 	q := a.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Table("statuses"). | 		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). | ||||||
| 		Column("id"). | 		Column("status.id"). | ||||||
| 		Order("id DESC") | 		Order("status.id DESC") | ||||||
| 
 | 
 | ||||||
| 	if accountID != "" { | 	if accountID != "" { | ||||||
| 		q = q.Where("account_id = ?", accountID) | 		q = q.Where("? = ?", bun.Ident("status.account_id"), accountID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if limit != 0 { | 	if limit != 0 { | ||||||
|  | @ -321,27 +357,27 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li | ||||||
| 		// include self-replies (threads) | 		// include self-replies (threads) | ||||||
| 		whereGroup := func(*bun.SelectQuery) *bun.SelectQuery { | 		whereGroup := func(*bun.SelectQuery) *bun.SelectQuery { | ||||||
| 			return q. | 			return q. | ||||||
| 				WhereOr("in_reply_to_account_id = ?", accountID). | 				WhereOr("? = ?", bun.Ident("status.in_reply_to_account_id"), accountID). | ||||||
| 				WhereGroup(" OR ", whereEmptyOrNull("in_reply_to_uri")) | 				WhereGroup(" OR ", whereEmptyOrNull("status.in_reply_to_uri")) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		q = q.WhereGroup(" AND ", whereGroup) | 		q = q.WhereGroup(" AND ", whereGroup) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if excludeReblogs { | 	if excludeReblogs { | ||||||
| 		q = q.WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")) | 		q = q.WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if maxID != "" { | 	if maxID != "" { | ||||||
| 		q = q.Where("id < ?", maxID) | 		q = q.Where("? < ?", bun.Ident("status.id"), maxID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if minID != "" { | 	if minID != "" { | ||||||
| 		q = q.Where("id > ?", minID) | 		q = q.Where("? > ?", bun.Ident("status.id"), minID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if pinnedOnly { | 	if pinnedOnly { | ||||||
| 		q = q.Where("pinned = ?", true) | 		q = q.Where("? = ?", bun.Ident("status.pinned"), true) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if mediaOnly { | 	if mediaOnly { | ||||||
|  | @ -352,15 +388,15 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li | ||||||
| 			switch a.conn.Dialect().Name() { | 			switch a.conn.Dialect().Name() { | ||||||
| 			case dialect.PG: | 			case dialect.PG: | ||||||
| 				return q. | 				return q. | ||||||
| 					Where("? IS NOT NULL", bun.Ident("attachments")). | 					Where("? IS NOT NULL", bun.Ident("status.attachments")). | ||||||
| 					Where("? != '{}'", bun.Ident("attachments")) | 					Where("? != '{}'", bun.Ident("status.attachments")) | ||||||
| 			case dialect.SQLite: | 			case dialect.SQLite: | ||||||
| 				return q. | 				return q. | ||||||
| 					Where("? IS NOT NULL", bun.Ident("attachments")). | 					Where("? IS NOT NULL", bun.Ident("status.attachments")). | ||||||
| 					Where("? != ''", bun.Ident("attachments")). | 					Where("? != ''", bun.Ident("status.attachments")). | ||||||
| 					Where("? != 'null'", bun.Ident("attachments")). | 					Where("? != 'null'", bun.Ident("status.attachments")). | ||||||
| 					Where("? != '{}'", bun.Ident("attachments")). | 					Where("? != '{}'", bun.Ident("status.attachments")). | ||||||
| 					Where("? != '[]'", bun.Ident("attachments")) | 					Where("? != '[]'", bun.Ident("status.attachments")) | ||||||
| 			default: | 			default: | ||||||
| 				log.Panic("db dialect was neither pg nor sqlite") | 				log.Panic("db dialect was neither pg nor sqlite") | ||||||
| 				return q | 				return q | ||||||
|  | @ -369,7 +405,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if publicOnly { | 	if publicOnly { | ||||||
| 		q = q.Where("visibility = ?", gtsmodel.VisibilityPublic) | 		q = q.Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := q.Scan(ctx, &statusIDs); err != nil { | 	if err := q.Scan(ctx, &statusIDs); err != nil { | ||||||
|  | @ -384,19 +420,19 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, | ||||||
| 
 | 
 | ||||||
| 	q := a.conn. | 	q := a.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Table("statuses"). | 		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). | ||||||
| 		Column("id"). | 		Column("status.id"). | ||||||
| 		Where("account_id = ?", accountID). | 		Where("? = ?", bun.Ident("status.account_id"), accountID). | ||||||
| 		WhereGroup(" AND ", whereEmptyOrNull("in_reply_to_uri")). | 		WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")). | ||||||
| 		WhereGroup(" AND ", whereEmptyOrNull("boost_of_id")). | 		WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")). | ||||||
| 		Where("visibility = ?", gtsmodel.VisibilityPublic). | 		Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic). | ||||||
| 		Where("federated = ?", true) | 		Where("? = ?", bun.Ident("status.federated"), true) | ||||||
| 
 | 
 | ||||||
| 	if maxID != "" { | 	if maxID != "" { | ||||||
| 		q = q.Where("id < ?", maxID) | 		q = q.Where("? < ?", bun.Ident("status.id"), maxID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	q = q.Limit(limit).Order("id DESC") | 	q = q.Limit(limit).Order("status.id DESC") | ||||||
| 
 | 
 | ||||||
| 	if err := q.Scan(ctx, &statusIDs); err != nil { | 	if err := q.Scan(ctx, &statusIDs); err != nil { | ||||||
| 		return nil, a.conn.ProcessError(err) | 		return nil, a.conn.ProcessError(err) | ||||||
|  | @ -411,16 +447,16 @@ func (a *accountDB) GetAccountBlocks(ctx context.Context, accountID string, maxI | ||||||
| 	fq := a.conn. | 	fq := a.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(&blocks). | 		Model(&blocks). | ||||||
| 		Where("block.account_id = ?", accountID). | 		Where("? = ?", bun.Ident("block.account_id"), accountID). | ||||||
| 		Relation("TargetAccount"). | 		Relation("TargetAccount"). | ||||||
| 		Order("block.id DESC") | 		Order("block.id DESC") | ||||||
| 
 | 
 | ||||||
| 	if maxID != "" { | 	if maxID != "" { | ||||||
| 		fq = fq.Where("block.id < ?", maxID) | 		fq = fq.Where("? < ?", bun.Ident("block.id"), maxID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if sinceID != "" { | 	if sinceID != "" { | ||||||
| 		fq = fq.Where("block.id > ?", sinceID) | 		fq = fq.Where("? > ?", bun.Ident("block.id"), sinceID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if limit > 0 { | 	if limit > 0 { | ||||||
|  |  | ||||||
|  | @ -42,6 +42,18 @@ func (suite *AccountTestSuite) TestGetAccountStatuses() { | ||||||
| 	suite.Len(statuses, 5) | 	suite.Len(statuses, 5) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogs() { | ||||||
|  | 	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false, false) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Len(statuses, 5) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *AccountTestSuite) TestGetAccountStatusesExcludeRepliesAndReblogsPublicOnly() { | ||||||
|  | 	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, true, true, "", "", false, false, true) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Len(statuses, 1) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (suite *AccountTestSuite) TestGetAccountStatusesMediaOnly() { | func (suite *AccountTestSuite) TestGetAccountStatusesMediaOnly() { | ||||||
| 	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, false, false, "", "", false, true, false) | 	statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, false, false, "", "", false, true, false) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
|  | @ -99,7 +111,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() { | ||||||
| 	err = dbService.GetConn(). | 	err = dbService.GetConn(). | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(noCache). | 		Model(noCache). | ||||||
| 		Where("account.id = ?", bun.Ident(testAccount.ID)). | 		Where("? = ?", bun.Ident("account.id"), testAccount.ID). | ||||||
| 		Relation("AvatarMediaAttachment"). | 		Relation("AvatarMediaAttachment"). | ||||||
| 		Relation("HeaderMediaAttachment"). | 		Relation("HeaderMediaAttachment"). | ||||||
| 		Relation("Emojis"). | 		Relation("Emojis"). | ||||||
|  | @ -127,7 +139,7 @@ func (suite *AccountTestSuite) TestUpdateAccount() { | ||||||
| 	err = dbService.GetConn(). | 	err = dbService.GetConn(). | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(noCache). | 		Model(noCache). | ||||||
| 		Where("account.id = ?", bun.Ident(testAccount.ID)). | 		Where("? = ?", bun.Ident("account.id"), testAccount.ID). | ||||||
| 		Relation("AvatarMediaAttachment"). | 		Relation("AvatarMediaAttachment"). | ||||||
| 		Relation("HeaderMediaAttachment"). | 		Relation("HeaderMediaAttachment"). | ||||||
| 		Relation("Emojis"). | 		Relation("Emojis"). | ||||||
|  |  | ||||||
|  | @ -22,7 +22,6 @@ import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"crypto/rand" | 	"crypto/rand" | ||||||
| 	"crypto/rsa" | 	"crypto/rsa" | ||||||
| 	"database/sql" |  | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/mail" | 	"net/mail" | ||||||
|  | @ -37,21 +36,26 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/id" | 	"github.com/superseriousbusiness/gotosocial/internal/id" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/uris" | 	"github.com/superseriousbusiness/gotosocial/internal/uris" | ||||||
|  | 	"github.com/uptrace/bun" | ||||||
| 	"golang.org/x/crypto/bcrypt" | 	"golang.org/x/crypto/bcrypt" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // generate RSA keys of this length | ||||||
|  | const rsaKeyBits = 2048 | ||||||
|  | 
 | ||||||
| type adminDB struct { | type adminDB struct { | ||||||
| 	conn         *DBConn | 	conn         *DBConn | ||||||
| 	userCache    *cache.UserCache | 	userCache    *cache.UserCache | ||||||
|  | 	accountCache *cache.AccountCache | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { | func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { | ||||||
| 	q := a.conn. | 	q := a.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(>smodel.Account{}). | 		TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). | ||||||
| 		Where("username = ?", username). | 		Column("account.id"). | ||||||
| 		Where("domain = ?", nil) | 		Where("? = ?", bun.Ident("account.username"), username). | ||||||
| 
 | 		Where("? IS NULL", bun.Ident("account.domain")) | ||||||
| 	return a.conn.NotExists(ctx, q) | 	return a.conn.NotExists(ctx, q) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -64,29 +68,31 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db. | ||||||
| 	domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @ | 	domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @ | ||||||
| 
 | 
 | ||||||
| 	// check if the email domain is blocked | 	// check if the email domain is blocked | ||||||
| 	if err := a.conn. | 	emailDomainBlockedQ := a.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(>smodel.EmailDomainBlock{}). | 		TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")). | ||||||
| 		Where("domain = ?", domain). | 		Column("email_domain_block.id"). | ||||||
| 		Scan(ctx); err == nil { | 		Where("? = ?", bun.Ident("email_domain_block.domain"), domain) | ||||||
| 		// fail because we found something | 	emailDomainBlocked, err := a.conn.Exists(ctx, emailDomainBlockedQ) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  | 	if emailDomainBlocked { | ||||||
| 		return false, fmt.Errorf("email domain %s is blocked", domain) | 		return false, fmt.Errorf("email domain %s is blocked", domain) | ||||||
| 	} else if err != sql.ErrNoRows { |  | ||||||
| 		return false, a.conn.ProcessError(err) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// check if this email is associated with a user already | 	// check if this email is associated with a user already | ||||||
| 	q := a.conn. | 	q := a.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(>smodel.User{}). | 		TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")). | ||||||
| 		Where("email = ?", email). | 		Column("user.id"). | ||||||
| 		WhereOr("unconfirmed_email = ?", email) | 		Where("? = ?", bun.Ident("user.email"), email). | ||||||
| 
 | 		WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email) | ||||||
| 	return a.conn.NotExists(ctx, q) | 	return a.conn.NotExists(ctx, q) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) { | func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) { | ||||||
| 	key, err := rsa.GenerateKey(rand.Reader, 2048) | 	key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Errorf("error creating new rsa key: %s", err) | 		log.Errorf("error creating new rsa key: %s", err) | ||||||
| 		return nil, err | 		return nil, err | ||||||
|  | @ -94,13 +100,20 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, | ||||||
| 
 | 
 | ||||||
| 	// if something went wrong while creating a user, we might already have an account, so check here first... | 	// if something went wrong while creating a user, we might already have an account, so check here first... | ||||||
| 	acct := >smodel.Account{} | 	acct := >smodel.Account{} | ||||||
| 	q := a.conn.NewSelect(). | 	if err := a.conn. | ||||||
|  | 		NewSelect(). | ||||||
| 		Model(acct). | 		Model(acct). | ||||||
| 		Where("username = ?", username). | 		Where("? = ?", bun.Ident("account.username"), username). | ||||||
| 		WhereGroup(" AND ", whereEmptyOrNull("domain")) | 		WhereGroup(" AND ", whereEmptyOrNull("account.domain")). | ||||||
|  | 		Scan(ctx); err != nil { | ||||||
|  | 		err = a.conn.ProcessError(err) | ||||||
|  | 		if err != db.ErrNoEntries { | ||||||
|  | 			log.Errorf("error checking for existing account: %s", err) | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 	if err := q.Scan(ctx); err != nil { | 		// if we have db.ErrNoEntries, we just don't have an | ||||||
| 		// we just don't have an account yet so create one before we proceed | 		// account yet so create one before we proceed | ||||||
| 		accountURIs := uris.GenerateURIsForAccount(username) | 		accountURIs := uris.GenerateURIsForAccount(username) | ||||||
| 		accountID, err := id.NewRandomULID() | 		accountID, err := id.NewRandomULID() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  | @ -126,14 +139,19 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, | ||||||
| 			FeaturedCollectionURI: accountURIs.CollectionURI, | 			FeaturedCollectionURI: accountURIs.CollectionURI, | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		// insert the new account! | ||||||
| 		if _, err = a.conn. | 		if _, err = a.conn. | ||||||
| 			NewInsert(). | 			NewInsert(). | ||||||
| 			Model(acct). | 			Model(acct). | ||||||
| 			Exec(ctx); err != nil { | 			Exec(ctx); err != nil { | ||||||
| 			return nil, a.conn.ProcessError(err) | 			return nil, a.conn.ProcessError(err) | ||||||
| 		} | 		} | ||||||
|  | 		a.accountCache.Put(acct) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// we either created or already had an account by now, | ||||||
|  | 	// so proceed with creating a user for that account | ||||||
|  | 
 | ||||||
| 	pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) | 	pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("error hashing password: %s", err) | 		return nil, fmt.Errorf("error hashing password: %s", err) | ||||||
|  | @ -171,6 +189,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, | ||||||
| 		u.Moderator = &moderator | 		u.Moderator = &moderator | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// insert the user! | ||||||
| 	if _, err = a.conn. | 	if _, err = a.conn. | ||||||
| 		NewInsert(). | 		NewInsert(). | ||||||
| 		Model(u). | 		Model(u). | ||||||
|  | @ -187,9 +206,10 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { | ||||||
| 
 | 
 | ||||||
| 	q := a.conn. | 	q := a.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(>smodel.Account{}). | 		TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). | ||||||
| 		Where("username = ?", username). | 		Column("account.id"). | ||||||
| 		WhereGroup(" AND ", whereEmptyOrNull("domain")) | 		Where("? = ?", bun.Ident("account.username"), username). | ||||||
|  | 		WhereGroup(" AND ", whereEmptyOrNull("account.domain")) | ||||||
| 
 | 
 | ||||||
| 	exists, err := a.conn.Exists(ctx, q) | 	exists, err := a.conn.Exists(ctx, q) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -200,7 +220,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	key, err := rsa.GenerateKey(rand.Reader, 2048) | 	key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Errorf("error creating new rsa key: %s", err) | 		log.Errorf("error creating new rsa key: %s", err) | ||||||
| 		return err | 		return err | ||||||
|  | @ -237,6 +257,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { | ||||||
| 		return a.conn.ProcessError(err) | 		return a.conn.ProcessError(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	a.accountCache.Put(acct) | ||||||
| 	log.Infof("instance account %s CREATED with id %s", username, acct.ID) | 	log.Infof("instance account %s CREATED with id %s", username, acct.ID) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  | @ -248,8 +269,9 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error { | ||||||
| 	// check if instance entry already exists | 	// check if instance entry already exists | ||||||
| 	q := a.conn. | 	q := a.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(>smodel.Instance{}). | 		Column("instance.id"). | ||||||
| 		Where("domain = ?", host) | 		TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")). | ||||||
|  | 		Where("? = ?", bun.Ident("instance.domain"), host) | ||||||
| 
 | 
 | ||||||
| 	exists, err := a.conn.Exists(ctx, q) | 	exists, err := a.conn.Exists(ctx, q) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  |  | ||||||
|  | @ -23,6 +23,7 @@ import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"github.com/stretchr/testify/suite" | 	"github.com/stretchr/testify/suite" | ||||||
|  | 	gtsmodel "github.com/superseriousbusiness/gotosocial/internal/db/bundb/migrations/20211113114307_init" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/testrig" | 	"github.com/superseriousbusiness/gotosocial/testrig" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -30,6 +31,44 @@ type AdminTestSuite struct { | ||||||
| 	BunDBStandardTestSuite | 	BunDBStandardTestSuite | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (suite *AdminTestSuite) TestIsUsernameAvailableNo() { | ||||||
|  | 	available, err := suite.db.IsUsernameAvailable(context.Background(), "the_mighty_zork") | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.False(available) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *AdminTestSuite) TestIsUsernameAvailableYes() { | ||||||
|  | 	available, err := suite.db.IsUsernameAvailable(context.Background(), "someone_completely_different") | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.True(available) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *AdminTestSuite) TestIsEmailAvailableNo() { | ||||||
|  | 	available, err := suite.db.IsEmailAvailable(context.Background(), "zork@example.org") | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.False(available) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *AdminTestSuite) TestIsEmailAvailableYes() { | ||||||
|  | 	available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com") | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.True(available) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() { | ||||||
|  | 	if err := suite.db.Put(context.Background(), >smodel.EmailDomainBlock{ | ||||||
|  | 		ID:                 "01GEEV2R2YC5GRSN96761YJE47", | ||||||
|  | 		Domain:             "somewhere.com", | ||||||
|  | 		CreatedByAccountID: suite.testAccounts["admin_account"].ID, | ||||||
|  | 	}); err != nil { | ||||||
|  | 		suite.FailNow(err.Error()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	available, err := suite.db.IsEmailAvailable(context.Background(), "someone@somewhere.com") | ||||||
|  | 	suite.EqualError(err, "email domain somewhere.com is blocked") | ||||||
|  | 	suite.False(available) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (suite *AdminTestSuite) TestCreateInstanceAccount() { | func (suite *AdminTestSuite) TestCreateInstanceAccount() { | ||||||
| 	// we need to take an empty db for this... | 	// we need to take an empty db for this... | ||||||
| 	testrig.StandardDBTeardown(suite.db) | 	testrig.StandardDBTeardown(suite.db) | ||||||
|  |  | ||||||
|  | @ -94,12 +94,12 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface | ||||||
| 	return b.conn.ProcessError(err) | 	return b.conn.ProcessError(err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (b *basicDB) UpdateByPrimaryKey(ctx context.Context, i interface{}, columns ...string) db.Error { | func (b *basicDB) UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) db.Error { | ||||||
| 	q := b.conn. | 	q := b.conn. | ||||||
| 		NewUpdate(). | 		NewUpdate(). | ||||||
| 		Model(i). | 		Model(i). | ||||||
| 		Column(columns...). | 		Column(columns...). | ||||||
| 		WherePK() | 		Where("? = ?", bun.Ident("id"), id) | ||||||
| 
 | 
 | ||||||
| 	_, err := q.Exec(ctx) | 	_, err := q.Exec(ctx) | ||||||
| 	return b.conn.ProcessError(err) | 	return b.conn.ProcessError(err) | ||||||
|  | @ -110,7 +110,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, | ||||||
| 
 | 
 | ||||||
| 	updateWhere(q, where) | 	updateWhere(q, where) | ||||||
| 
 | 
 | ||||||
| 	q = q.Set("? = ?", bun.Safe(key), value) | 	q = q.Set("? = ?", bun.Ident(key), value) | ||||||
| 
 | 
 | ||||||
| 	_, err := q.Exec(ctx) | 	_, err := q.Exec(ctx) | ||||||
| 	return b.conn.ProcessError(err) | 	return b.conn.ProcessError(err) | ||||||
|  |  | ||||||
|  | @ -159,17 +159,11 @@ func NewBunDBService(ctx context.Context) (db.DB, error) { | ||||||
| 		return nil, fmt.Errorf("db migration error: %s", err) | 		return nil, fmt.Errorf("db migration error: %s", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Create DB structs that require ptrs to each other | 	// Prepare caches required by more than one struct | ||||||
| 	accounts := &accountDB{conn: conn, cache: cache.NewAccountCache()} | 	userCache := cache.NewUserCache() | ||||||
| 	status := &statusDB{conn: conn, cache: cache.NewStatusCache()} | 	accountCache := cache.NewAccountCache() | ||||||
| 	emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()} |  | ||||||
| 	timeline := &timelineDB{conn: conn} |  | ||||||
| 
 |  | ||||||
| 	// Setup DB cross-referencing |  | ||||||
| 	accounts.status = status |  | ||||||
| 	status.accounts = accounts |  | ||||||
| 	timeline.status = status |  | ||||||
| 
 | 
 | ||||||
|  | 	// Prepare other caches | ||||||
| 	// Prepare mentions cache | 	// Prepare mentions cache | ||||||
| 	// TODO: move into internal/cache | 	// TODO: move into internal/cache | ||||||
| 	mentionCache := grufcache.New[string, *gtsmodel.Mention]() | 	mentionCache := grufcache.New[string, *gtsmodel.Mention]() | ||||||
|  | @ -182,22 +176,30 @@ func NewBunDBService(ctx context.Context) (db.DB, error) { | ||||||
| 	notifCache.SetTTL(time.Minute*5, false) | 	notifCache.SetTTL(time.Minute*5, false) | ||||||
| 	notifCache.Start(time.Second * 10) | 	notifCache.Start(time.Second * 10) | ||||||
| 
 | 
 | ||||||
| 	// Prepare other caches | 	// Create DB structs that require ptrs to each other | ||||||
| 	blockCache := cache.NewDomainBlockCache() | 	accounts := &accountDB{conn: conn, cache: accountCache} | ||||||
| 	userCache := cache.NewUserCache() | 	status := &statusDB{conn: conn, cache: cache.NewStatusCache()} | ||||||
|  | 	emoji := &emojiDB{conn: conn, cache: cache.NewEmojiCache()} | ||||||
|  | 	timeline := &timelineDB{conn: conn} | ||||||
|  | 
 | ||||||
|  | 	// Setup DB cross-referencing | ||||||
|  | 	accounts.status = status | ||||||
|  | 	status.accounts = accounts | ||||||
|  | 	timeline.status = status | ||||||
| 
 | 
 | ||||||
| 	ps := &DBService{ | 	ps := &DBService{ | ||||||
| 		Account: accounts, | 		Account: accounts, | ||||||
| 		Admin: &adminDB{ | 		Admin: &adminDB{ | ||||||
| 			conn:         conn, | 			conn:         conn, | ||||||
| 			userCache:    userCache, | 			userCache:    userCache, | ||||||
|  | 			accountCache: accountCache, | ||||||
| 		}, | 		}, | ||||||
| 		Basic: &basicDB{ | 		Basic: &basicDB{ | ||||||
| 			conn: conn, | 			conn: conn, | ||||||
| 		}, | 		}, | ||||||
| 		Domain: &domainDB{ | 		Domain: &domainDB{ | ||||||
| 			conn:  conn, | 			conn:  conn, | ||||||
| 			cache: blockCache, | 			cache: cache.NewDomainBlockCache(), | ||||||
| 		}, | 		}, | ||||||
| 		Emoji: emoji, | 		Emoji: emoji, | ||||||
| 		Instance: &instanceDB{ | 		Instance: &instanceDB{ | ||||||
|  |  | ||||||
|  | @ -40,6 +40,7 @@ type BunDBStandardTestSuite struct { | ||||||
| 	testStatuses     map[string]*gtsmodel.Status | 	testStatuses     map[string]*gtsmodel.Status | ||||||
| 	testTags         map[string]*gtsmodel.Tag | 	testTags         map[string]*gtsmodel.Tag | ||||||
| 	testMentions     map[string]*gtsmodel.Mention | 	testMentions     map[string]*gtsmodel.Mention | ||||||
|  | 	testFollows      map[string]*gtsmodel.Follow | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *BunDBStandardTestSuite) SetupSuite() { | func (suite *BunDBStandardTestSuite) SetupSuite() { | ||||||
|  | @ -52,6 +53,7 @@ func (suite *BunDBStandardTestSuite) SetupSuite() { | ||||||
| 	suite.testStatuses = testrig.NewTestStatuses() | 	suite.testStatuses = testrig.NewTestStatuses() | ||||||
| 	suite.testTags = testrig.NewTestTags() | 	suite.testTags = testrig.NewTestTags() | ||||||
| 	suite.testMentions = testrig.NewTestMentions() | 	suite.testMentions = testrig.NewTestMentions() | ||||||
|  | 	suite.testFollows = testrig.NewTestFollows() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *BunDBStandardTestSuite) SetupTest() { | func (suite *BunDBStandardTestSuite) SetupTest() { | ||||||
|  |  | ||||||
|  | @ -28,6 +28,7 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/config" | 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/uptrace/bun" | ||||||
| 	"golang.org/x/net/idna" | 	"golang.org/x/net/idna" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -95,7 +96,7 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel | ||||||
| 	q := d.conn. | 	q := d.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(block). | 		Model(block). | ||||||
| 		Where("domain = ?", domain). | 		Where("? = ?", bun.Ident("domain_block.domain"), domain). | ||||||
| 		Limit(1) | 		Limit(1) | ||||||
| 
 | 
 | ||||||
| 	// Query database for domain block | 	// Query database for domain block | ||||||
|  | @ -126,7 +127,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro | ||||||
| 	// Attempt to delete domain block | 	// Attempt to delete domain block | ||||||
| 	if _, err := d.conn.NewDelete(). | 	if _, err := d.conn.NewDelete(). | ||||||
| 		Model((*gtsmodel.DomainBlock)(nil)). | 		Model((*gtsmodel.DomainBlock)(nil)). | ||||||
| 		Where("domain = ?", domain). | 		Where("? = ?", bun.Ident("domain_block.domain"), domain). | ||||||
| 		Exec(ctx); err != nil { | 		Exec(ctx); err != nil { | ||||||
| 		return d.conn.ProcessError(err) | 		return d.conn.ProcessError(err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -54,12 +54,12 @@ func (e *emojiDB) GetCustomEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Er | ||||||
| 
 | 
 | ||||||
| 	q := e.conn. | 	q := e.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Table("emojis"). | 		TableExpr("? AS ?", bun.Ident("emojis"), bun.Ident("emoji")). | ||||||
| 		Column("id"). | 		Column("emoji.id"). | ||||||
| 		Where("visible_in_picker = true"). | 		Where("? = ?", bun.Ident("emoji.visible_in_picker"), true). | ||||||
| 		Where("disabled = false"). | 		Where("? = ?", bun.Ident("emoji.disabled"), false). | ||||||
| 		Where("domain IS NULL"). | 		Where("? IS NULL", bun.Ident("emoji.domain")). | ||||||
| 		Order("shortcode ASC") | 		Order("emoji.shortcode ASC") | ||||||
| 
 | 
 | ||||||
| 	if err := q.Scan(ctx, &emojiIDs); err != nil { | 	if err := q.Scan(ctx, &emojiIDs); err != nil { | ||||||
| 		return nil, e.conn.ProcessError(err) | 		return nil, e.conn.ProcessError(err) | ||||||
|  | @ -75,7 +75,7 @@ func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, | ||||||
| 			return e.cache.GetByID(id) | 			return e.cache.GetByID(id) | ||||||
| 		}, | 		}, | ||||||
| 		func(emoji *gtsmodel.Emoji) error { | 		func(emoji *gtsmodel.Emoji) error { | ||||||
| 			return e.newEmojiQ(emoji).Where("emoji.id = ?", id).Scan(ctx) | 			return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.id"), id).Scan(ctx) | ||||||
| 		}, | 		}, | ||||||
| 	) | 	) | ||||||
| } | } | ||||||
|  | @ -87,7 +87,7 @@ func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoj | ||||||
| 			return e.cache.GetByURI(uri) | 			return e.cache.GetByURI(uri) | ||||||
| 		}, | 		}, | ||||||
| 		func(emoji *gtsmodel.Emoji) error { | 		func(emoji *gtsmodel.Emoji) error { | ||||||
| 			return e.newEmojiQ(emoji).Where("emoji.uri = ?", uri).Scan(ctx) | 			return e.newEmojiQ(emoji).Where("? = ?", bun.Ident("emoji.uri"), uri).Scan(ctx) | ||||||
| 		}, | 		}, | ||||||
| 	) | 	) | ||||||
| } | } | ||||||
|  | @ -102,11 +102,11 @@ func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode strin | ||||||
| 			q := e.newEmojiQ(emoji) | 			q := e.newEmojiQ(emoji) | ||||||
| 
 | 
 | ||||||
| 			if domain != "" { | 			if domain != "" { | ||||||
| 				q = q.Where("emoji.shortcode = ?", shortcode) | 				q = q.Where("? = ?", bun.Ident("emoji.shortcode"), shortcode) | ||||||
| 				q = q.Where("emoji.domain = ?", domain) | 				q = q.Where("? = ?", bun.Ident("emoji.domain"), domain) | ||||||
| 			} else { | 			} else { | ||||||
| 				q = q.Where("emoji.shortcode = ?", strings.ToLower(shortcode)) | 				q = q.Where("? = ?", bun.Ident("emoji.shortcode"), strings.ToLower(shortcode)) | ||||||
| 				q = q.Where("emoji.domain IS NULL") | 				q = q.Where("? IS NULL", bun.Ident("emoji.domain")) | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			return q.Scan(ctx) | 			return q.Scan(ctx) | ||||||
|  |  | ||||||
|  | @ -24,7 +24,6 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/config" | 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" |  | ||||||
| 	"github.com/uptrace/bun" | 	"github.com/uptrace/bun" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -35,15 +34,16 @@ type instanceDB struct { | ||||||
| func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) { | func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) { | ||||||
| 	q := i.conn. | 	q := i.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(&[]*gtsmodel.Account{}). | 		TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")). | ||||||
| 		Where("username != ?", domain). | 		Column("account.id"). | ||||||
| 		Where("? IS NULL", bun.Ident("suspended_at")) | 		Where("? != ?", bun.Ident("account.username"), domain). | ||||||
|  | 		Where("? IS NULL", bun.Ident("account.suspended_at")) | ||||||
| 
 | 
 | ||||||
| 	if domain == config.GetHost() { | 	if domain == config.GetHost() || domain == config.GetAccountDomain() { | ||||||
| 		// if the domain is *this* domain, just count where the domain field is null | 		// if the domain is *this* domain, just count where the domain field is null | ||||||
| 		q = q.WhereGroup(" AND ", whereEmptyOrNull("domain")) | 		q = q.WhereGroup(" AND ", whereEmptyOrNull("account.domain")) | ||||||
| 	} else { | 	} else { | ||||||
| 		q = q.Where("domain = ?", domain) | 		q = q.Where("? = ?", bun.Ident("account.domain"), domain) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	count, err := q.Count(ctx) | 	count, err := q.Count(ctx) | ||||||
|  | @ -56,15 +56,16 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int | ||||||
| func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) { | func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) { | ||||||
| 	q := i.conn. | 	q := i.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(&[]*gtsmodel.Status{}) | 		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")) | ||||||
| 
 | 
 | ||||||
| 	if domain == config.GetHost() { | 	if domain == config.GetHost() || domain == config.GetAccountDomain() { | ||||||
| 		// if the domain is *this* domain, just count where local is true | 		// if the domain is *this* domain, just count where local is true | ||||||
| 		q = q.Where("local = ?", true) | 		q = q.Where("? = ?", bun.Ident("status.local"), true) | ||||||
| 	} else { | 	} else { | ||||||
| 		// join on the domain of the account | 		// join on the domain of the account | ||||||
| 		q = q.Join("JOIN accounts AS account ON account.id = status.account_id"). | 		q = q. | ||||||
| 			Where("account.domain = ?", domain) | 			Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("account.id"), bun.Ident("status.account_id")). | ||||||
|  | 			Where("? = ?", bun.Ident("account.domain"), domain) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	count, err := q.Count(ctx) | 	count, err := q.Count(ctx) | ||||||
|  | @ -77,14 +78,14 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) ( | ||||||
| func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) { | func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) { | ||||||
| 	q := i.conn. | 	q := i.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(&[]*gtsmodel.Instance{}) | 		TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")) | ||||||
| 
 | 
 | ||||||
| 	if domain == config.GetHost() { | 	if domain == config.GetHost() { | ||||||
| 		// if the domain is *this* domain, just count other instances it knows about | 		// if the domain is *this* domain, just count other instances it knows about | ||||||
| 		// exclude domains that are blocked | 		// exclude domains that are blocked | ||||||
| 		q = q. | 		q = q. | ||||||
| 			Where("domain != ?", domain). | 			Where("? != ?", bun.Ident("instance.domain"), domain). | ||||||
| 			Where("? IS NULL", bun.Ident("suspended_at")) | 			Where("? IS NULL", bun.Ident("instance.suspended_at")) | ||||||
| 	} else { | 	} else { | ||||||
| 		// TODO: implement federated domain counting properly for remote domains | 		// TODO: implement federated domain counting properly for remote domains | ||||||
| 		return 0, nil | 		return 0, nil | ||||||
|  | @ -103,10 +104,10 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool | ||||||
| 	q := i.conn. | 	q := i.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(&instances). | 		Model(&instances). | ||||||
| 		Where("domain != ?", config.GetHost()) | 		Where("? != ?", bun.Ident("instance.domain"), config.GetHost()) | ||||||
| 
 | 
 | ||||||
| 	if !includeSuspended { | 	if !includeSuspended { | ||||||
| 		q = q.Where("? IS NULL", bun.Ident("suspended_at")) | 		q = q.Where("? IS NULL", bun.Ident("instance.suspended_at")) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := q.Scan(ctx); err != nil { | 	if err := q.Scan(ctx); err != nil { | ||||||
|  | @ -117,17 +118,15 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { | func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { | ||||||
| 	log.Debug("GetAccountsForInstance") |  | ||||||
| 
 |  | ||||||
| 	accounts := []*gtsmodel.Account{} | 	accounts := []*gtsmodel.Account{} | ||||||
| 
 | 
 | ||||||
| 	q := i.conn.NewSelect(). | 	q := i.conn.NewSelect(). | ||||||
| 		Model(&accounts). | 		Model(&accounts). | ||||||
| 		Where("domain = ?", domain). | 		Where("? = ?", bun.Ident("account.domain"), domain). | ||||||
| 		Order("id DESC") | 		Order("account.id DESC") | ||||||
| 
 | 
 | ||||||
| 	if maxID != "" { | 	if maxID != "" { | ||||||
| 		q = q.Where("id < ?", maxID) | 		q = q.Where("? < ?", bun.Ident("account.id"), maxID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if limit > 0 { | 	if limit > 0 { | ||||||
|  |  | ||||||
							
								
								
									
										83
									
								
								internal/db/bundb/instance_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								internal/db/bundb/instance_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,83 @@ | ||||||
|  | /* | ||||||
|  |    GoToSocial | ||||||
|  |    Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org | ||||||
|  | 
 | ||||||
|  |    This program is free software: you can redistribute it and/or modify | ||||||
|  |    it under the terms of the GNU Affero General Public License as published by | ||||||
|  |    the Free Software Foundation, either version 3 of the License, or | ||||||
|  |    (at your option) any later version. | ||||||
|  | 
 | ||||||
|  |    This program is distributed in the hope that it will be useful, | ||||||
|  |    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  |    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  |    GNU Affero General Public License for more details. | ||||||
|  | 
 | ||||||
|  |    You should have received a copy of the GNU Affero General Public License | ||||||
|  |    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | */ | ||||||
|  | 
 | ||||||
|  | package bundb_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"github.com/stretchr/testify/suite" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/config" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type InstanceTestSuite struct { | ||||||
|  | 	BunDBStandardTestSuite | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *InstanceTestSuite) TestCountInstanceUsers() { | ||||||
|  | 	count, err := suite.db.CountInstanceUsers(context.Background(), config.GetHost()) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Equal(4, count) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *InstanceTestSuite) TestCountInstanceUsersRemote() { | ||||||
|  | 	count, err := suite.db.CountInstanceUsers(context.Background(), "fossbros-anonymous.io") | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Equal(1, count) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *InstanceTestSuite) TestCountInstanceStatuses() { | ||||||
|  | 	count, err := suite.db.CountInstanceStatuses(context.Background(), config.GetHost()) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Equal(16, count) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *InstanceTestSuite) TestCountInstanceStatusesRemote() { | ||||||
|  | 	count, err := suite.db.CountInstanceStatuses(context.Background(), "fossbros-anonymous.io") | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Equal(1, count) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *InstanceTestSuite) TestCountInstanceDomains() { | ||||||
|  | 	count, err := suite.db.CountInstanceDomains(context.Background(), config.GetHost()) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Equal(2, count) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *InstanceTestSuite) TestGetInstancePeers() { | ||||||
|  | 	peers, err := suite.db.GetInstancePeers(context.Background(), false) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Len(peers, 2) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *InstanceTestSuite) TestGetInstancePeersIncludeSuspended() { | ||||||
|  | 	peers, err := suite.db.GetInstancePeers(context.Background(), true) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Len(peers, 2) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *InstanceTestSuite) TestGetInstanceAccounts() { | ||||||
|  | 	accounts, err := suite.db.GetInstanceAccounts(context.Background(), "fossbros-anonymous.io", "", 10) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Len(accounts, 1) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestInstanceTestSuite(t *testing.T) { | ||||||
|  | 	suite.Run(t, new(InstanceTestSuite)) | ||||||
|  | } | ||||||
|  | @ -42,7 +42,7 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M | ||||||
| 	attachment := >smodel.MediaAttachment{} | 	attachment := >smodel.MediaAttachment{} | ||||||
| 
 | 
 | ||||||
| 	q := m.newMediaQ(attachment). | 	q := m.newMediaQ(attachment). | ||||||
| 		Where("media_attachment.id = ?", id) | 		Where("? = ?", bun.Ident("media_attachment.id"), id) | ||||||
| 
 | 
 | ||||||
| 	if err := q.Scan(ctx); err != nil { | 	if err := q.Scan(ctx); err != nil { | ||||||
| 		return nil, m.conn.ProcessError(err) | 		return nil, m.conn.ProcessError(err) | ||||||
|  | @ -56,10 +56,10 @@ func (m *mediaDB) GetRemoteOlderThan(ctx context.Context, olderThan time.Time, l | ||||||
| 	q := m.conn. | 	q := m.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(&attachments). | 		Model(&attachments). | ||||||
| 		Where("media_attachment.cached = true"). | 		Where("? = ?", bun.Ident("media_attachment.cached"), true). | ||||||
| 		Where("media_attachment.avatar = false"). | 		Where("? = ?", bun.Ident("media_attachment.avatar"), false). | ||||||
| 		Where("media_attachment.header = false"). | 		Where("? = ?", bun.Ident("media_attachment.header"), false). | ||||||
| 		Where("media_attachment.created_at < ?", olderThan). | 		Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan). | ||||||
| 		WhereGroup(" AND ", whereNotEmptyAndNotNull("media_attachment.remote_url")). | 		WhereGroup(" AND ", whereNotEmptyAndNotNull("media_attachment.remote_url")). | ||||||
| 		Order("media_attachment.created_at DESC") | 		Order("media_attachment.created_at DESC") | ||||||
| 
 | 
 | ||||||
|  | @ -79,13 +79,13 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit | ||||||
| 	q := m.newMediaQ(&attachments). | 	q := m.newMediaQ(&attachments). | ||||||
| 		WhereGroup(" AND ", func(innerQ *bun.SelectQuery) *bun.SelectQuery { | 		WhereGroup(" AND ", func(innerQ *bun.SelectQuery) *bun.SelectQuery { | ||||||
| 			return innerQ. | 			return innerQ. | ||||||
| 				WhereOr("media_attachment.avatar = true"). | 				WhereOr("? = ?", bun.Ident("media_attachment.avatar"), true). | ||||||
| 				WhereOr("media_attachment.header = true") | 				WhereOr("? = ?", bun.Ident("media_attachment.header"), true) | ||||||
| 		}). | 		}). | ||||||
| 		Order("media_attachment.id DESC") | 		Order("media_attachment.id DESC") | ||||||
| 
 | 
 | ||||||
| 	if maxID != "" { | 	if maxID != "" { | ||||||
| 		q = q.Where("media_attachment.id < ?", maxID) | 		q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if limit != 0 { | 	if limit != 0 { | ||||||
|  | @ -103,15 +103,15 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim | ||||||
| 	attachments := []*gtsmodel.MediaAttachment{} | 	attachments := []*gtsmodel.MediaAttachment{} | ||||||
| 
 | 
 | ||||||
| 	q := m.newMediaQ(&attachments). | 	q := m.newMediaQ(&attachments). | ||||||
| 		Where("media_attachment.cached = true"). | 		Where("? = ?", bun.Ident("media_attachment.cached"), true). | ||||||
| 		Where("media_attachment.avatar = false"). | 		Where("? = ?", bun.Ident("media_attachment.avatar"), false). | ||||||
| 		Where("media_attachment.header = false"). | 		Where("? = ?", bun.Ident("media_attachment.header"), false). | ||||||
| 		Where("media_attachment.created_at < ?", olderThan). | 		Where("? < ?", bun.Ident("media_attachment.created_at"), olderThan). | ||||||
| 		Where("media_attachment.remote_url IS NULL"). | 		Where("? IS NULL", bun.Ident("media_attachment.remote_url")). | ||||||
| 		Where("media_attachment.status_id IS NULL") | 		Where("? IS NULL", bun.Ident("media_attachment.status_id")) | ||||||
| 
 | 
 | ||||||
| 	if maxID != "" { | 	if maxID != "" { | ||||||
| 		q = q.Where("media_attachment.id < ?", maxID) | 		q = q.Where("? < ?", bun.Ident("media_attachment.id"), maxID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if limit != 0 { | 	if limit != 0 { | ||||||
|  |  | ||||||
|  | @ -46,7 +46,7 @@ func (m *mentionDB) getMentionDB(ctx context.Context, id string) (*gtsmodel.Ment | ||||||
| 	mention := gtsmodel.Mention{} | 	mention := gtsmodel.Mention{} | ||||||
| 
 | 
 | ||||||
| 	q := m.newMentionQ(&mention). | 	q := m.newMentionQ(&mention). | ||||||
| 		Where("mention.id = ?", id) | 		Where("? = ?", bun.Ident("mention.id"), id) | ||||||
| 
 | 
 | ||||||
| 	if err := q.Scan(ctx); err != nil { | 	if err := q.Scan(ctx); err != nil { | ||||||
| 		return nil, m.conn.ProcessError(err) | 		return nil, m.conn.ProcessError(err) | ||||||
|  |  | ||||||
|  | @ -47,8 +47,8 @@ func init() { | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if _, err := tx.NewDelete(). | 		if _, err := tx.NewDelete(). | ||||||
| 			Model(a). | 			TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")). | ||||||
| 			WherePK(). | 			Where("? = ?", bun.Ident("media_attachment.id"), a.ID). | ||||||
| 			Exec(ctx); err != nil { | 			Exec(ctx); err != nil { | ||||||
| 			l.Errorf("error deleting attachment with id %s: %s", a.ID, err) | 			l.Errorf("error deleting attachment with id %s: %s", a.ID, err) | ||||||
| 		} else { | 		} else { | ||||||
|  |  | ||||||
|  | @ -25,6 +25,7 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/log" | 	"github.com/superseriousbusiness/gotosocial/internal/log" | ||||||
|  | 	"github.com/uptrace/bun" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type notificationDB struct { | type notificationDB struct { | ||||||
|  | @ -44,7 +45,7 @@ func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmo | ||||||
| 		Relation("OriginAccount"). | 		Relation("OriginAccount"). | ||||||
| 		Relation("TargetAccount"). | 		Relation("TargetAccount"). | ||||||
| 		Relation("Status"). | 		Relation("Status"). | ||||||
| 		WherePK() | 		Where("? = ?", bun.Ident("notification.id"), id) | ||||||
| 
 | 
 | ||||||
| 	if err := q.Scan(ctx); err != nil { | 	if err := q.Scan(ctx); err != nil { | ||||||
| 		return nil, n.conn.ProcessError(err) | 		return nil, n.conn.ProcessError(err) | ||||||
|  | @ -67,24 +68,24 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, | ||||||
| 
 | 
 | ||||||
| 	q := n.conn. | 	q := n.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Table("notifications"). | 		TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). | ||||||
| 		Column("id") | 		Column("notification.id") | ||||||
| 
 | 
 | ||||||
| 	if maxID != "" { | 	if maxID != "" { | ||||||
| 		q = q.Where("id < ?", maxID) | 		q = q.Where("? < ?", bun.Ident("notification.id"), maxID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if sinceID != "" { | 	if sinceID != "" { | ||||||
| 		q = q.Where("id > ?", sinceID) | 		q = q.Where("? > ?", bun.Ident("notification.id"), sinceID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, excludeType := range excludeTypes { | 	for _, excludeType := range excludeTypes { | ||||||
| 		q = q.Where("notification_type != ?", excludeType) | 		q = q.Where("? != ?", bun.Ident("notification.notification_type"), excludeType) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	q = q. | 	q = q. | ||||||
| 		Where("target_account_id = ?", accountID). | 		Where("? = ?", bun.Ident("notification.target_account_id"), accountID). | ||||||
| 		Order("id DESC") | 		Order("notification.id DESC") | ||||||
| 
 | 
 | ||||||
| 	if limit != 0 { | 	if limit != 0 { | ||||||
| 		q = q.Limit(limit) | 		q = q.Limit(limit) | ||||||
|  | @ -116,13 +117,12 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, | ||||||
| func (n *notificationDB) ClearNotifications(ctx context.Context, accountID string) db.Error { | func (n *notificationDB) ClearNotifications(ctx context.Context, accountID string) db.Error { | ||||||
| 	if _, err := n.conn. | 	if _, err := n.conn. | ||||||
| 		NewDelete(). | 		NewDelete(). | ||||||
| 		Table("notifications"). | 		TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). | ||||||
| 		Where("target_account_id = ?", accountID). | 		Where("? = ?", bun.Ident("notification.target_account_id"), accountID). | ||||||
| 		Exec(ctx); err != nil { | 		Exec(ctx); err != nil { | ||||||
| 		return n.conn.ProcessError(err) | 		return n.conn.ProcessError(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	n.cache.Clear() | 	n.cache.Clear() | ||||||
| 
 |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -51,26 +51,25 @@ func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery { | ||||||
| func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) { | func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) { | ||||||
| 	q := r.conn. | 	q := r.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(>smodel.Block{}). | 		TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). | ||||||
| 		ExcludeColumn("id", "created_at", "updated_at", "uri"). | 		Column("block.id") | ||||||
| 		Limit(1) |  | ||||||
| 
 | 
 | ||||||
| 	if eitherDirection { | 	if eitherDirection { | ||||||
| 		q = q. | 		q = q. | ||||||
| 			WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery { | 			WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery { | ||||||
| 				return inner. | 				return inner. | ||||||
| 					Where("account_id = ?", account1). | 					Where("? = ?", bun.Ident("block.account_id"), account1). | ||||||
| 					Where("target_account_id = ?", account2) | 					Where("? = ?", bun.Ident("block.target_account_id"), account2) | ||||||
| 			}). | 			}). | ||||||
| 			WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery { | 			WhereGroup(" OR ", func(inner *bun.SelectQuery) *bun.SelectQuery { | ||||||
| 				return inner. | 				return inner. | ||||||
| 					Where("account_id = ?", account2). | 					Where("? = ?", bun.Ident("block.account_id"), account2). | ||||||
| 					Where("target_account_id = ?", account1) | 					Where("? = ?", bun.Ident("block.target_account_id"), account1) | ||||||
| 			}) | 			}) | ||||||
| 	} else { | 	} else { | ||||||
| 		q = q. | 		q = q. | ||||||
| 			Where("account_id = ?", account1). | 			Where("? = ?", bun.Ident("block.account_id"), account1). | ||||||
| 			Where("target_account_id = ?", account2) | 			Where("? = ?", bun.Ident("block.target_account_id"), account2) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return r.conn.Exists(ctx, q) | 	return r.conn.Exists(ctx, q) | ||||||
|  | @ -80,8 +79,8 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 | ||||||
| 	block := >smodel.Block{} | 	block := >smodel.Block{} | ||||||
| 
 | 
 | ||||||
| 	q := r.newBlockQ(block). | 	q := r.newBlockQ(block). | ||||||
| 		Where("block.account_id = ?", account1). | 		Where("? = ?", bun.Ident("block.account_id"), account1). | ||||||
| 		Where("block.target_account_id = ?", account2) | 		Where("? = ?", bun.Ident("block.target_account_id"), account2) | ||||||
| 
 | 
 | ||||||
| 	if err := q.Scan(ctx); err != nil { | 	if err := q.Scan(ctx); err != nil { | ||||||
| 		return nil, r.conn.ProcessError(err) | 		return nil, r.conn.ProcessError(err) | ||||||
|  | @ -99,13 +98,13 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount | ||||||
| 	if err := r.conn. | 	if err := r.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(follow). | 		Model(follow). | ||||||
| 		Where("account_id = ?", requestingAccount). | 		Column("follow.show_reblogs", "follow.notify"). | ||||||
| 		Where("target_account_id = ?", targetAccount). | 		Where("? = ?", bun.Ident("follow.account_id"), requestingAccount). | ||||||
|  | 		Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount). | ||||||
| 		Limit(1). | 		Limit(1). | ||||||
| 		Scan(ctx); err != nil { | 		Scan(ctx); err != nil { | ||||||
| 		if err != sql.ErrNoRows { | 		if err := r.conn.ProcessError(err); err != db.ErrNoEntries { | ||||||
| 			// a proper error | 			return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err) | ||||||
| 			return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err) |  | ||||||
| 		} | 		} | ||||||
| 		// no follow exists so these are all false | 		// no follow exists so these are all false | ||||||
| 		rel.Following = false | 		rel.Following = false | ||||||
|  | @ -119,55 +118,56 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// check if the target account follows the requesting account | 	// check if the target account follows the requesting account | ||||||
| 	count, err := r.conn. | 	followedByQ := r.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(>smodel.Follow{}). | 		TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). | ||||||
| 		Where("account_id = ?", targetAccount). | 		Column("follow.id"). | ||||||
| 		Where("target_account_id = ?", requestingAccount). | 		Where("? = ?", bun.Ident("follow.account_id"), targetAccount). | ||||||
| 		Limit(1). | 		Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount) | ||||||
| 		Count(ctx) | 	followedBy, err := r.conn.Exists(ctx, followedByQ) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err) | 		return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err) | ||||||
| 	} | 	} | ||||||
| 	rel.FollowedBy = count > 0 | 	rel.FollowedBy = followedBy | ||||||
| 
 |  | ||||||
| 	// check if the requesting account blocks the target account |  | ||||||
| 	count, err = r.conn.NewSelect(). |  | ||||||
| 		Model(>smodel.Block{}). |  | ||||||
| 		Where("account_id = ?", requestingAccount). |  | ||||||
| 		Where("target_account_id = ?", targetAccount). |  | ||||||
| 		Limit(1). |  | ||||||
| 		Count(ctx) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err) |  | ||||||
| 	} |  | ||||||
| 	rel.Blocking = count > 0 |  | ||||||
| 
 |  | ||||||
| 	// check if the target account blocks the requesting account |  | ||||||
| 	count, err = r.conn. |  | ||||||
| 		NewSelect(). |  | ||||||
| 		Model(>smodel.Block{}). |  | ||||||
| 		Where("account_id = ?", targetAccount). |  | ||||||
| 		Where("target_account_id = ?", requestingAccount). |  | ||||||
| 		Limit(1). |  | ||||||
| 		Count(ctx) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) |  | ||||||
| 	} |  | ||||||
| 	rel.BlockedBy = count > 0 |  | ||||||
| 
 | 
 | ||||||
| 	// check if there's a pending following request from requesting account to target account | 	// check if there's a pending following request from requesting account to target account | ||||||
| 	count, err = r.conn. | 	requestedQ := r.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(>smodel.FollowRequest{}). | 		TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). | ||||||
| 		Where("account_id = ?", requestingAccount). | 		Column("follow_request.id"). | ||||||
| 		Where("target_account_id = ?", targetAccount). | 		Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount). | ||||||
| 		Limit(1). | 		Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount) | ||||||
| 		Count(ctx) | 	requested, err := r.conn.Exists(ctx, requestedQ) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) | 		return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err) | ||||||
| 	} | 	} | ||||||
| 	rel.Requested = count > 0 | 	rel.Requested = requested | ||||||
|  | 
 | ||||||
|  | 	// check if the requesting account is blocking the target account | ||||||
|  | 	blockingQ := r.conn. | ||||||
|  | 		NewSelect(). | ||||||
|  | 		TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). | ||||||
|  | 		Column("block.id"). | ||||||
|  | 		Where("? = ?", bun.Ident("block.account_id"), requestingAccount). | ||||||
|  | 		Where("? = ?", bun.Ident("block.target_account_id"), targetAccount) | ||||||
|  | 	blocking, err := r.conn.Exists(ctx, blockingQ) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err) | ||||||
|  | 	} | ||||||
|  | 	rel.Blocking = blocking | ||||||
|  | 
 | ||||||
|  | 	// check if the requesting account is blocked by the target account | ||||||
|  | 	blockedByQ := r.conn. | ||||||
|  | 		NewSelect(). | ||||||
|  | 		TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). | ||||||
|  | 		Column("block.id"). | ||||||
|  | 		Where("? = ?", bun.Ident("block.account_id"), targetAccount). | ||||||
|  | 		Where("? = ?", bun.Ident("block.target_account_id"), requestingAccount) | ||||||
|  | 	blockedBy, err := r.conn.Exists(ctx, blockedByQ) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err) | ||||||
|  | 	} | ||||||
|  | 	rel.BlockedBy = blockedBy | ||||||
| 
 | 
 | ||||||
| 	return rel, nil | 	return rel, nil | ||||||
| } | } | ||||||
|  | @ -179,10 +179,10 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode | ||||||
| 
 | 
 | ||||||
| 	q := r.conn. | 	q := r.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(>smodel.Follow{}). | 		TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). | ||||||
| 		Where("account_id = ?", sourceAccount.ID). | 		Column("follow.id"). | ||||||
| 		Where("target_account_id = ?", targetAccount.ID). | 		Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID). | ||||||
| 		Limit(1) | 		Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID) | ||||||
| 
 | 
 | ||||||
| 	return r.conn.Exists(ctx, q) | 	return r.conn.Exists(ctx, q) | ||||||
| } | } | ||||||
|  | @ -194,9 +194,10 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g | ||||||
| 
 | 
 | ||||||
| 	q := r.conn. | 	q := r.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(>smodel.FollowRequest{}). | 		TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). | ||||||
| 		Where("account_id = ?", sourceAccount.ID). | 		Column("follow_request.id"). | ||||||
| 		Where("target_account_id = ?", targetAccount.ID) | 		Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID). | ||||||
|  | 		Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID) | ||||||
| 
 | 
 | ||||||
| 	return r.conn.Exists(ctx, q) | 	return r.conn.Exists(ctx, q) | ||||||
| } | } | ||||||
|  | @ -222,82 +223,98 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { | func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { | ||||||
| 	// make sure the original follow request exists | 	var follow *gtsmodel.Follow | ||||||
| 	fr := >smodel.FollowRequest{} | 
 | ||||||
| 	if err := r.conn. | 	if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error { | ||||||
|  | 		// get original follow request | ||||||
|  | 		followRequest := >smodel.FollowRequest{} | ||||||
|  | 		if err := tx. | ||||||
| 			NewSelect(). | 			NewSelect(). | ||||||
| 		Model(fr). | 			Model(followRequest). | ||||||
| 		Where("account_id = ?", originAccountID). | 			Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID). | ||||||
| 		Where("target_account_id = ?", targetAccountID). | 			Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID). | ||||||
| 			Scan(ctx); err != nil { | 			Scan(ctx); err != nil { | ||||||
| 		return nil, r.conn.ProcessError(err) | 			return err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// create a new follow to 'replace' the request with | 		// create a new follow to 'replace' the request with | ||||||
| 	follow := >smodel.Follow{ | 		follow = >smodel.Follow{ | ||||||
| 		ID:              fr.ID, | 			ID:              followRequest.ID, | ||||||
| 			AccountID:       originAccountID, | 			AccountID:       originAccountID, | ||||||
| 			TargetAccountID: targetAccountID, | 			TargetAccountID: targetAccountID, | ||||||
| 		URI:             fr.URI, | 			URI:             followRequest.URI, | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// if the follow already exists, just update the URI -- we don't need to do anything else | 		// if the follow already exists, just update the URI -- we don't need to do anything else | ||||||
| 	if _, err := r.conn. | 		if _, err := tx. | ||||||
| 			NewInsert(). | 			NewInsert(). | ||||||
| 			Model(follow). | 			Model(follow). | ||||||
| 		On("CONFLICT (account_id,target_account_id) DO UPDATE set uri = ?", follow.URI). | 			On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI). | ||||||
| 			Exec(ctx); err != nil { | 			Exec(ctx); err != nil { | ||||||
| 		return nil, r.conn.ProcessError(err) | 			return err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// now remove the follow request | 		// now remove the follow request | ||||||
| 	if _, err := r.conn. | 		if _, err := tx. | ||||||
| 			NewDelete(). | 			NewDelete(). | ||||||
| 		Model(>smodel.FollowRequest{}). | 			TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). | ||||||
| 		Where("account_id = ?", originAccountID). | 			Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). | ||||||
| 		Where("target_account_id = ?", targetAccountID). |  | ||||||
| 			Exec(ctx); err != nil { | 			Exec(ctx); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return nil | ||||||
|  | 	}); err != nil { | ||||||
| 		return nil, r.conn.ProcessError(err) | 		return nil, r.conn.ProcessError(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// return the new follow | ||||||
| 	return follow, nil | 	return follow, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) { | func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) { | ||||||
| 	// first get the follow request out of the database | 	followRequest := >smodel.FollowRequest{} | ||||||
| 	fr := >smodel.FollowRequest{} | 
 | ||||||
| 	if err := r.conn. | 	if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error { | ||||||
|  | 		// get original follow request | ||||||
|  | 		if err := tx. | ||||||
| 			NewSelect(). | 			NewSelect(). | ||||||
| 		Model(fr). | 			Model(followRequest). | ||||||
| 		Where("account_id = ?", originAccountID). | 			Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID). | ||||||
| 		Where("target_account_id = ?", targetAccountID). | 			Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID). | ||||||
| 			Scan(ctx); err != nil { | 			Scan(ctx); err != nil { | ||||||
| 		return nil, r.conn.ProcessError(err) | 			return err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// now delete it from the database by ID | 		// now delete it from the database by ID | ||||||
| 	if _, err := r.conn. | 		if _, err := tx. | ||||||
| 			NewDelete(). | 			NewDelete(). | ||||||
| 		Model(>smodel.FollowRequest{ID: fr.ID}). | 			TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). | ||||||
| 		WherePK(). | 			Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). | ||||||
| 			Exec(ctx); err != nil { | 			Exec(ctx); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return nil | ||||||
|  | 	}); err != nil { | ||||||
| 		return nil, r.conn.ProcessError(err) | 		return nil, r.conn.ProcessError(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// return the deleted follow request | 	// return the deleted follow request | ||||||
| 	return fr, nil | 	return followRequest, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) { | func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) { | ||||||
| 	followRequests := []*gtsmodel.FollowRequest{} | 	followRequests := []*gtsmodel.FollowRequest{} | ||||||
| 
 | 
 | ||||||
| 	q := r.newFollowQ(&followRequests). | 	q := r.newFollowQ(&followRequests). | ||||||
| 		Where("target_account_id = ?", accountID). | 		Where("? = ?", bun.Ident("follow_request.target_account_id"), accountID). | ||||||
| 		Order("follow_request.updated_at DESC") | 		Order("follow_request.updated_at DESC") | ||||||
| 
 | 
 | ||||||
| 	if err := q.Scan(ctx); err != nil { | 	if err := q.Scan(ctx); err != nil { | ||||||
| 		return nil, r.conn.ProcessError(err) | 		return nil, r.conn.ProcessError(err) | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	return followRequests, nil | 	return followRequests, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -305,21 +322,31 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string | ||||||
| 	follows := []*gtsmodel.Follow{} | 	follows := []*gtsmodel.Follow{} | ||||||
| 
 | 
 | ||||||
| 	q := r.newFollowQ(&follows). | 	q := r.newFollowQ(&follows). | ||||||
| 		Where("account_id = ?", accountID). | 		Where("? = ?", bun.Ident("follow.account_id"), accountID). | ||||||
| 		Order("follow.updated_at DESC") | 		Order("follow.updated_at DESC") | ||||||
| 
 | 
 | ||||||
| 	if err := q.Scan(ctx); err != nil { | 	if err := q.Scan(ctx); err != nil { | ||||||
| 		return nil, r.conn.ProcessError(err) | 		return nil, r.conn.ProcessError(err) | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	return follows, nil | 	return follows, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { | func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { | ||||||
| 	return r.conn. | 	q := r.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(&[]*gtsmodel.Follow{}). | 		TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")) | ||||||
| 		Where("account_id = ?", accountID). | 
 | ||||||
| 		Count(ctx) | 	if localOnly { | ||||||
|  | 		q = q. | ||||||
|  | 			Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.target_account_id"), bun.Ident("account.id")). | ||||||
|  | 			Where("? = ?", bun.Ident("follow.account_id"), accountID). | ||||||
|  | 			Where("? IS NULL", bun.Ident("account.domain")) | ||||||
|  | 	} else { | ||||||
|  | 		q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return q.Count(ctx) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { | func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { | ||||||
|  | @ -331,12 +358,12 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str | ||||||
| 		Order("follow.updated_at DESC") | 		Order("follow.updated_at DESC") | ||||||
| 
 | 
 | ||||||
| 	if localOnly { | 	if localOnly { | ||||||
| 		q = q.ColumnExpr("follow.*"). | 		q = q. | ||||||
| 			Join("JOIN accounts AS a ON follow.account_id = CAST(a.id as TEXT)"). | 			Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")). | ||||||
| 			Where("follow.target_account_id = ?", accountID). | 			Where("? = ?", bun.Ident("follow.target_account_id"), accountID). | ||||||
| 			WhereGroup(" AND ", whereEmptyOrNull("a.domain")) | 			Where("? IS NULL", bun.Ident("account.domain")) | ||||||
| 	} else { | 	} else { | ||||||
| 		q = q.Where("target_account_id = ?", accountID) | 		q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err := q.Scan(ctx) | 	err := q.Scan(ctx) | ||||||
|  | @ -347,9 +374,18 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { | func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { | ||||||
| 	return r.conn. | 	q := r.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(&[]*gtsmodel.Follow{}). | 		TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")) | ||||||
| 		Where("target_account_id = ?", accountID). | 
 | ||||||
| 		Count(ctx) | 	if localOnly { | ||||||
|  | 		q = q. | ||||||
|  | 			Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")). | ||||||
|  | 			Where("? = ?", bun.Ident("follow.target_account_id"), accountID). | ||||||
|  | 			Where("? IS NULL", bun.Ident("account.domain")) | ||||||
|  | 	} else { | ||||||
|  | 		q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return q.Count(ctx) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -20,7 +20,6 @@ package bundb_test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" |  | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"github.com/stretchr/testify/suite" | 	"github.com/stretchr/testify/suite" | ||||||
|  | @ -48,12 +47,14 @@ func (suite *RelationshipTestSuite) TestIsBlocked() { | ||||||
| 	suite.False(blocked) | 	suite.False(blocked) | ||||||
| 
 | 
 | ||||||
| 	// have account1 block account2 | 	// have account1 block account2 | ||||||
| 	suite.db.Put(ctx, >smodel.Block{ | 	if err := suite.db.Put(ctx, >smodel.Block{ | ||||||
| 		ID:              "01G202BCSXXJZ70BHB5KCAHH8C", | 		ID:              "01G202BCSXXJZ70BHB5KCAHH8C", | ||||||
| 		URI:             "http://localhost:8080/some_block_uri_1", | 		URI:             "http://localhost:8080/some_block_uri_1", | ||||||
| 		AccountID:       account1, | 		AccountID:       account1, | ||||||
| 		TargetAccountID: account2, | 		TargetAccountID: account2, | ||||||
| 	}) | 	}); err != nil { | ||||||
|  | 		suite.FailNow(err.Error()) | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	// account 1 now blocks account 2 | 	// account 1 now blocks account 2 | ||||||
| 	blocked, err = suite.db.IsBlocked(ctx, account1, account2, false) | 	blocked, err = suite.db.IsBlocked(ctx, account1, account2, false) | ||||||
|  | @ -75,62 +76,242 @@ func (suite *RelationshipTestSuite) TestIsBlocked() { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *RelationshipTestSuite) TestGetBlock() { | func (suite *RelationshipTestSuite) TestGetBlock() { | ||||||
| 	suite.Suite.T().Skip("TODO: implement") | 	ctx := context.Background() | ||||||
|  | 
 | ||||||
|  | 	account1 := suite.testAccounts["local_account_1"].ID | ||||||
|  | 	account2 := suite.testAccounts["local_account_2"].ID | ||||||
|  | 
 | ||||||
|  | 	if err := suite.db.Put(ctx, >smodel.Block{ | ||||||
|  | 		ID:              "01G202BCSXXJZ70BHB5KCAHH8C", | ||||||
|  | 		URI:             "http://localhost:8080/some_block_uri_1", | ||||||
|  | 		AccountID:       account1, | ||||||
|  | 		TargetAccountID: account2, | ||||||
|  | 	}); err != nil { | ||||||
|  | 		suite.FailNow(err.Error()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	block, err := suite.db.GetBlock(ctx, account1, account2) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.NotNil(block) | ||||||
|  | 	suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *RelationshipTestSuite) TestGetRelationship() { | func (suite *RelationshipTestSuite) TestGetRelationship() { | ||||||
| 	suite.Suite.T().Skip("TODO: implement") | 	requestingAccount := suite.testAccounts["local_account_1"] | ||||||
|  | 	targetAccount := suite.testAccounts["admin_account"] | ||||||
|  | 
 | ||||||
|  | 	relationship, err := suite.db.GetRelationship(context.Background(), requestingAccount.ID, targetAccount.ID) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.NotNil(relationship) | ||||||
|  | 
 | ||||||
|  | 	suite.True(relationship.Following) | ||||||
|  | 	suite.True(relationship.ShowingReblogs) | ||||||
|  | 	suite.False(relationship.Notifying) | ||||||
|  | 	suite.True(relationship.FollowedBy) | ||||||
|  | 	suite.False(relationship.Blocking) | ||||||
|  | 	suite.False(relationship.BlockedBy) | ||||||
|  | 	suite.False(relationship.Muting) | ||||||
|  | 	suite.False(relationship.MutingNotifications) | ||||||
|  | 	suite.False(relationship.Requested) | ||||||
|  | 	suite.False(relationship.DomainBlocking) | ||||||
|  | 	suite.False(relationship.Endorsed) | ||||||
|  | 	suite.Empty(relationship.Note) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *RelationshipTestSuite) TestIsFollowing() { | func (suite *RelationshipTestSuite) TestIsFollowingYes() { | ||||||
| 	suite.Suite.T().Skip("TODO: implement") | 	requestingAccount := suite.testAccounts["local_account_1"] | ||||||
|  | 	targetAccount := suite.testAccounts["admin_account"] | ||||||
|  | 	isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.True(isFollowing) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *RelationshipTestSuite) TestIsFollowingNo() { | ||||||
|  | 	requestingAccount := suite.testAccounts["admin_account"] | ||||||
|  | 	targetAccount := suite.testAccounts["local_account_2"] | ||||||
|  | 	isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.False(isFollowing) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *RelationshipTestSuite) TestIsMutualFollowing() { | func (suite *RelationshipTestSuite) TestIsMutualFollowing() { | ||||||
| 	suite.Suite.T().Skip("TODO: implement") | 	requestingAccount := suite.testAccounts["local_account_1"] | ||||||
|  | 	targetAccount := suite.testAccounts["admin_account"] | ||||||
|  | 	isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.True(isMutualFollowing) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *RelationshipTestSuite) AcceptFollowRequest() { | func (suite *RelationshipTestSuite) TestIsMutualFollowingNo() { | ||||||
| 	for _, account := range suite.testAccounts { | 	requestingAccount := suite.testAccounts["local_account_1"] | ||||||
| 		_, err := suite.db.AcceptFollowRequest(context.Background(), account.ID, "NON-EXISTENT-ID") | 	targetAccount := suite.testAccounts["local_account_2"] | ||||||
| 		if err != nil && !errors.Is(err, db.ErrNoEntries) { | 	isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount) | ||||||
| 			suite.Suite.Fail("error accepting follow request: %v", err) | 	suite.NoError(err) | ||||||
| 		} | 	suite.True(isMutualFollowing) | ||||||
| 	} |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *RelationshipTestSuite) GetAccountFollowRequests() { | func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() { | ||||||
| 	suite.Suite.T().Skip("TODO: implement") | 	ctx := context.Background() | ||||||
| } | 	account := suite.testAccounts["admin_account"] | ||||||
|  | 	targetAccount := suite.testAccounts["local_account_2"] | ||||||
| 
 | 
 | ||||||
| func (suite *RelationshipTestSuite) GetAccountFollows() { | 	followRequest := >smodel.FollowRequest{ | ||||||
| 	suite.Suite.T().Skip("TODO: implement") | 		ID:              "01GEF753FWHCHRDWR0QEHBXM8W", | ||||||
| } | 		URI:             "http://localhost:8080/weeeeeeeeeeeeeeeee", | ||||||
| 
 | 		AccountID:       account.ID, | ||||||
| func (suite *RelationshipTestSuite) CountAccountFollows() { | 		TargetAccountID: targetAccount.ID, | ||||||
| 	suite.Suite.T().Skip("TODO: implement") |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (suite *RelationshipTestSuite) GetAccountFollowedBy() { |  | ||||||
| 	// TODO: more comprehensive tests here |  | ||||||
| 
 |  | ||||||
| 	for _, account := range suite.testAccounts { |  | ||||||
| 		var err error |  | ||||||
| 
 |  | ||||||
| 		_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, false) |  | ||||||
| 		if err != nil { |  | ||||||
| 			suite.Suite.Fail("error checking accounts followed by: %v", err) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 		_, err = suite.db.GetAccountFollowedBy(context.Background(), account.ID, true) | 	if err := suite.db.Put(ctx, followRequest); err != nil { | ||||||
| 		if err != nil { | 		suite.FailNow(err.Error()) | ||||||
| 			suite.Suite.Fail("error checking localOnly accounts followed by: %v", err) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.NotNil(follow) | ||||||
|  | 	suite.Equal(followRequest.URI, follow.URI) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *RelationshipTestSuite) CountAccountFollowedBy() { | func (suite *RelationshipTestSuite) TestAcceptFollowRequestNotExisting() { | ||||||
| 	suite.Suite.T().Skip("TODO: implement") | 	ctx := context.Background() | ||||||
|  | 	account := suite.testAccounts["admin_account"] | ||||||
|  | 	targetAccount := suite.testAccounts["local_account_2"] | ||||||
|  | 
 | ||||||
|  | 	follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID) | ||||||
|  | 	suite.ErrorIs(err, db.ErrNoEntries) | ||||||
|  | 	suite.Nil(follow) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *RelationshipTestSuite) TestAcceptFollowRequestFollowAlreadyExists() { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	account := suite.testAccounts["local_account_1"] | ||||||
|  | 	targetAccount := suite.testAccounts["admin_account"] | ||||||
|  | 
 | ||||||
|  | 	// follow already exists in the db from local_account_1 -> admin_account | ||||||
|  | 	existingFollow := >smodel.Follow{} | ||||||
|  | 	if err := suite.db.GetByID(ctx, suite.testFollows["local_account_1_admin_account"].ID, existingFollow); err != nil { | ||||||
|  | 		suite.FailNow(err.Error()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	followRequest := >smodel.FollowRequest{ | ||||||
|  | 		ID:              "01GEF753FWHCHRDWR0QEHBXM8W", | ||||||
|  | 		URI:             "http://localhost:8080/weeeeeeeeeeeeeeeee", | ||||||
|  | 		AccountID:       account.ID, | ||||||
|  | 		TargetAccountID: targetAccount.ID, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := suite.db.Put(ctx, followRequest); err != nil { | ||||||
|  | 		suite.FailNow(err.Error()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	follow, err := suite.db.AcceptFollowRequest(ctx, account.ID, targetAccount.ID) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.NotNil(follow) | ||||||
|  | 
 | ||||||
|  | 	// uri should be equal to value of new/overlapping follow request | ||||||
|  | 	suite.NotEqual(followRequest.URI, existingFollow.URI) | ||||||
|  | 	suite.Equal(followRequest.URI, follow.URI) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	account := suite.testAccounts["admin_account"] | ||||||
|  | 	targetAccount := suite.testAccounts["local_account_2"] | ||||||
|  | 
 | ||||||
|  | 	followRequest := >smodel.FollowRequest{ | ||||||
|  | 		ID:              "01GEF753FWHCHRDWR0QEHBXM8W", | ||||||
|  | 		URI:             "http://localhost:8080/weeeeeeeeeeeeeeeee", | ||||||
|  | 		AccountID:       account.ID, | ||||||
|  | 		TargetAccountID: targetAccount.ID, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := suite.db.Put(ctx, followRequest); err != nil { | ||||||
|  | 		suite.FailNow(err.Error()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.NotNil(rejectedFollowRequest) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *RelationshipTestSuite) TestRejectFollowRequestNotExisting() { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	account := suite.testAccounts["admin_account"] | ||||||
|  | 	targetAccount := suite.testAccounts["local_account_2"] | ||||||
|  | 
 | ||||||
|  | 	rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID) | ||||||
|  | 	suite.ErrorIs(err, db.ErrNoEntries) | ||||||
|  | 	suite.Nil(rejectedFollowRequest) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	account := suite.testAccounts["admin_account"] | ||||||
|  | 	targetAccount := suite.testAccounts["local_account_2"] | ||||||
|  | 
 | ||||||
|  | 	followRequest := >smodel.FollowRequest{ | ||||||
|  | 		ID:              "01GEF753FWHCHRDWR0QEHBXM8W", | ||||||
|  | 		URI:             "http://localhost:8080/weeeeeeeeeeeeeeeee", | ||||||
|  | 		AccountID:       account.ID, | ||||||
|  | 		TargetAccountID: targetAccount.ID, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := suite.db.Put(ctx, followRequest); err != nil { | ||||||
|  | 		suite.FailNow(err.Error()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Len(followRequests, 1) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *RelationshipTestSuite) TestGetAccountFollows() { | ||||||
|  | 	account := suite.testAccounts["local_account_1"] | ||||||
|  | 	follows, err := suite.db.GetAccountFollows(context.Background(), account.ID) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Len(follows, 2) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() { | ||||||
|  | 	account := suite.testAccounts["local_account_1"] | ||||||
|  | 	followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, true) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Equal(2, followsCount) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *RelationshipTestSuite) TestCountAccountFollows() { | ||||||
|  | 	account := suite.testAccounts["local_account_1"] | ||||||
|  | 	followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID, false) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Equal(2, followsCount) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *RelationshipTestSuite) TestGetAccountFollowedBy() { | ||||||
|  | 	account := suite.testAccounts["local_account_1"] | ||||||
|  | 	follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, false) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Len(follows, 2) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *RelationshipTestSuite) TestGetAccountFollowedByLocalOnly() { | ||||||
|  | 	account := suite.testAccounts["local_account_1"] | ||||||
|  | 	follows, err := suite.db.GetAccountFollowedBy(context.Background(), account.ID, true) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Len(follows, 2) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *RelationshipTestSuite) TestCountAccountFollowedBy() { | ||||||
|  | 	account := suite.testAccounts["local_account_1"] | ||||||
|  | 	followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, false) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Equal(2, followsCount) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *RelationshipTestSuite) TestCountAccountFollowedByLocalOnly() { | ||||||
|  | 	account := suite.testAccounts["local_account_1"] | ||||||
|  | 	followsCount, err := suite.db.CountAccountFollowedBy(context.Background(), account.ID, true) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 	suite.Equal(2, followsCount) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestRelationshipTestSuite(t *testing.T) { | func TestRelationshipTestSuite(t *testing.T) { | ||||||
|  |  | ||||||
|  | @ -21,7 +21,6 @@ package bundb | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"crypto/rand" | 	"crypto/rand" | ||||||
| 	"errors" |  | ||||||
| 
 | 
 | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
|  | @ -35,29 +34,22 @@ type sessionDB struct { | ||||||
| func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { | func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { | ||||||
| 	rss := make([]*gtsmodel.RouterSession, 0, 1) | 	rss := make([]*gtsmodel.RouterSession, 0, 1) | ||||||
| 
 | 
 | ||||||
| 	_, err := s.conn. | 	// get the first router session in the db or... | ||||||
|  | 	if err := s.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(&rss). | 		Model(&rss). | ||||||
| 		Limit(1). | 		Limit(1). | ||||||
| 		Order("id DESC"). | 		Order("router_session.id DESC"). | ||||||
| 		Exec(ctx) | 		Scan(ctx); err != nil { | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, s.conn.ProcessError(err) | 		return nil, s.conn.ProcessError(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// ... create a new one | ||||||
| 	if len(rss) == 0 { | 	if len(rss) == 0 { | ||||||
| 		// no session created yet, so make one |  | ||||||
| 		return s.createSession(ctx) | 		return s.createSession(ctx) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if len(rss) != 1 { | 	return rss[0], nil | ||||||
| 		// we asked for 1 so we should get 1 |  | ||||||
| 		return nil, errors.New("more than 1 router session was returned") |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// return the one session found |  | ||||||
| 	rs := rss[0] |  | ||||||
| 	return rs, nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { | func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession, db.Error) { | ||||||
|  | @ -71,24 +63,23 @@ func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession, | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	rid, err := id.NewULID() | 	id, err := id.NewULID() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	rs := >smodel.RouterSession{ | 	rs := >smodel.RouterSession{ | ||||||
| 		ID:    rid, | 		ID:    id, | ||||||
| 		Auth:  auth, | 		Auth:  auth, | ||||||
| 		Crypt: crypt, | 		Crypt: crypt, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	q := s.conn. | 	if _, err := s.conn. | ||||||
| 		NewInsert(). | 		NewInsert(). | ||||||
| 		Model(rs) | 		Model(rs). | ||||||
| 
 | 		Exec(ctx); err != nil { | ||||||
| 	_, err = q.Exec(ctx) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, s.conn.ProcessError(err) | 		return nil, s.conn.ProcessError(err) | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	return rs, nil | 	return rs, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -37,14 +37,13 @@ func (suite *SessionTestSuite) TestGetSession() { | ||||||
| 	suite.NotEmpty(session.Crypt) | 	suite.NotEmpty(session.Crypt) | ||||||
| 	suite.NotEmpty(session.ID) | 	suite.NotEmpty(session.ID) | ||||||
| 
 | 
 | ||||||
| 	// TODO -- the same session should be returned with consecutive selects | 	// the same session should be returned with consecutive selects | ||||||
| 	// right now there's an issue with bytea in bun, so uncomment this when that issue is fixed: https://github.com/uptrace/bun/issues/122 | 	session2, err := suite.db.GetSession(context.Background()) | ||||||
| 	// session2, err := suite.db.GetSession(context.Background()) | 	suite.NoError(err) | ||||||
| 	// suite.NoError(err) | 	suite.NotNil(session2) | ||||||
| 	// suite.NotNil(session2) | 	suite.Equal(session.Auth, session2.Auth) | ||||||
| 	// suite.Equal(session.Auth, session2.Auth) | 	suite.Equal(session.Crypt, session2.Crypt) | ||||||
| 	// suite.Equal(session.Crypt, session2.Crypt) | 	suite.Equal(session.ID, session2.ID) | ||||||
| 	// suite.Equal(session.ID, session2.ID) |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSessionTestSuite(t *testing.T) { | func TestSessionTestSuite(t *testing.T) { | ||||||
|  |  | ||||||
|  | @ -72,7 +72,7 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat | ||||||
| 			return s.cache.GetByID(id) | 			return s.cache.GetByID(id) | ||||||
| 		}, | 		}, | ||||||
| 		func(status *gtsmodel.Status) error { | 		func(status *gtsmodel.Status) error { | ||||||
| 			return s.newStatusQ(status).Where("status.id = ?", id).Scan(ctx) | 			return s.newStatusQ(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx) | ||||||
| 		}, | 		}, | ||||||
| 	) | 	) | ||||||
| } | } | ||||||
|  | @ -84,7 +84,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St | ||||||
| 			return s.cache.GetByURI(uri) | 			return s.cache.GetByURI(uri) | ||||||
| 		}, | 		}, | ||||||
| 		func(status *gtsmodel.Status) error { | 		func(status *gtsmodel.Status) error { | ||||||
| 			return s.newStatusQ(status).Where("status.uri = ?", uri).Scan(ctx) | 			return s.newStatusQ(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx) | ||||||
| 		}, | 		}, | ||||||
| 	) | 	) | ||||||
| } | } | ||||||
|  | @ -96,7 +96,7 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.St | ||||||
| 			return s.cache.GetByURL(url) | 			return s.cache.GetByURL(url) | ||||||
| 		}, | 		}, | ||||||
| 		func(status *gtsmodel.Status) error { | 		func(status *gtsmodel.Status) error { | ||||||
| 			return s.newStatusQ(status).Where("status.url = ?", url).Scan(ctx) | 			return s.newStatusQ(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx) | ||||||
| 		}, | 		}, | ||||||
| 	) | 	) | ||||||
| } | } | ||||||
|  | @ -109,8 +109,7 @@ func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Sta | ||||||
| 		status = >smodel.Status{} | 		status = >smodel.Status{} | ||||||
| 
 | 
 | ||||||
| 		// Not cached! Perform database query | 		// Not cached! Perform database query | ||||||
| 		err := dbQuery(status) | 		if err := dbQuery(status); err != nil { | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, s.conn.ProcessError(err) | 			return nil, s.conn.ProcessError(err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -138,49 +137,12 @@ func (s *statusDB) getStatus(ctx context.Context, cacheGet func() (*gtsmodel.Sta | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { | func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error { | ||||||
| 	return s.conn.RunInTx(ctx, func(tx bun.Tx) error { |  | ||||||
| 		// create links between this status and any emojis it uses |  | ||||||
| 		for _, i := range status.EmojiIDs { |  | ||||||
| 			if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{ |  | ||||||
| 				StatusID: status.ID, |  | ||||||
| 				EmojiID:  i, |  | ||||||
| 			}).Exec(ctx); err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// create links between this status and any tags it uses |  | ||||||
| 		for _, i := range status.TagIDs { |  | ||||||
| 			if _, err := tx.NewInsert().Model(>smodel.StatusToTag{ |  | ||||||
| 				StatusID: status.ID, |  | ||||||
| 				TagID:    i, |  | ||||||
| 			}).Exec(ctx); err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// change the status ID of the media attachments to the new status |  | ||||||
| 		for _, a := range status.Attachments { |  | ||||||
| 			a.StatusID = status.ID |  | ||||||
| 			a.UpdatedAt = time.Now() |  | ||||||
| 			if _, err := tx.NewUpdate().Model(a). |  | ||||||
| 				Where("id = ?", a.ID). |  | ||||||
| 				Exec(ctx); err != nil { |  | ||||||
| 				return err |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// Finally, insert the status |  | ||||||
| 		_, err := tx.NewInsert().Model(status).Exec(ctx) |  | ||||||
| 		return err |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*gtsmodel.Status, db.Error) { |  | ||||||
| 	err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { | 	err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { | ||||||
| 		// create links between this status and any emojis it uses | 		// create links between this status and any emojis it uses | ||||||
| 		for _, i := range status.EmojiIDs { | 		for _, i := range status.EmojiIDs { | ||||||
| 			if _, err := tx.NewInsert().Model(>smodel.StatusToEmoji{ | 			if _, err := tx. | ||||||
|  | 				NewInsert(). | ||||||
|  | 				Model(>smodel.StatusToEmoji{ | ||||||
| 					StatusID: status.ID, | 					StatusID: status.ID, | ||||||
| 					EmojiID:  i, | 					EmojiID:  i, | ||||||
| 				}).Exec(ctx); err != nil { | 				}).Exec(ctx); err != nil { | ||||||
|  | @ -193,7 +155,75 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (* | ||||||
| 
 | 
 | ||||||
| 		// create links between this status and any tags it uses | 		// create links between this status and any tags it uses | ||||||
| 		for _, i := range status.TagIDs { | 		for _, i := range status.TagIDs { | ||||||
| 			if _, err := tx.NewInsert().Model(>smodel.StatusToTag{ | 			if _, err := tx. | ||||||
|  | 				NewInsert(). | ||||||
|  | 				Model(>smodel.StatusToTag{ | ||||||
|  | 					StatusID: status.ID, | ||||||
|  | 					TagID:    i, | ||||||
|  | 				}).Exec(ctx); err != nil { | ||||||
|  | 				err = s.conn.errProc(err) | ||||||
|  | 				if !errors.Is(err, db.ErrAlreadyExists) { | ||||||
|  | 					return err | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// change the status ID of the media attachments to the new status | ||||||
|  | 		for _, a := range status.Attachments { | ||||||
|  | 			a.StatusID = status.ID | ||||||
|  | 			a.UpdatedAt = time.Now() | ||||||
|  | 			if _, err := tx. | ||||||
|  | 				NewUpdate(). | ||||||
|  | 				Model(a). | ||||||
|  | 				Where("? = ?", bun.Ident("media_attachment.id"), a.ID). | ||||||
|  | 				Exec(ctx); err != nil { | ||||||
|  | 				err = s.conn.errProc(err) | ||||||
|  | 				if !errors.Is(err, db.ErrAlreadyExists) { | ||||||
|  | 					return err | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Finally, insert the status | ||||||
|  | 		if _, err := tx. | ||||||
|  | 			NewInsert(). | ||||||
|  | 			Model(status). | ||||||
|  | 			Exec(ctx); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return nil | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return s.conn.ProcessError(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	s.cache.Put(status) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (*gtsmodel.Status, db.Error) { | ||||||
|  | 	err := s.conn.RunInTx(ctx, func(tx bun.Tx) error { | ||||||
|  | 		// create links between this status and any emojis it uses | ||||||
|  | 		for _, i := range status.EmojiIDs { | ||||||
|  | 			if _, err := tx. | ||||||
|  | 				NewInsert(). | ||||||
|  | 				Model(>smodel.StatusToEmoji{ | ||||||
|  | 					StatusID: status.ID, | ||||||
|  | 					EmojiID:  i, | ||||||
|  | 				}).Exec(ctx); err != nil { | ||||||
|  | 				err = s.conn.errProc(err) | ||||||
|  | 				if !errors.Is(err, db.ErrAlreadyExists) { | ||||||
|  | 					return err | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// create links between this status and any tags it uses | ||||||
|  | 		for _, i := range status.TagIDs { | ||||||
|  | 			if _, err := tx. | ||||||
|  | 				NewInsert(). | ||||||
|  | 				Model(>smodel.StatusToTag{ | ||||||
| 					StatusID: status.ID, | 					StatusID: status.ID, | ||||||
| 					TagID:    i, | 					TagID:    i, | ||||||
| 				}).Exec(ctx); err != nil { | 				}).Exec(ctx); err != nil { | ||||||
|  | @ -208,23 +238,32 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status) (* | ||||||
| 		for _, a := range status.Attachments { | 		for _, a := range status.Attachments { | ||||||
| 			a.StatusID = status.ID | 			a.StatusID = status.ID | ||||||
| 			a.UpdatedAt = time.Now() | 			a.UpdatedAt = time.Now() | ||||||
| 			if _, err := tx.NewUpdate().Model(a). | 			if _, err := tx. | ||||||
| 				Where("id = ?", a.ID). | 				NewUpdate(). | ||||||
|  | 				Model(a). | ||||||
|  | 				Where("? = ?", bun.Ident("media_attachment.id"), a.ID). | ||||||
| 				Exec(ctx); err != nil { | 				Exec(ctx); err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// Finally, update the status itself | 		// Finally, update the status itself | ||||||
| 		if _, err := tx.NewUpdate().Model(status).WherePK().Exec(ctx); err != nil { | 		if _, err := tx. | ||||||
|  | 			NewUpdate(). | ||||||
|  | 			Model(status). | ||||||
|  | 			Where("? = ?", bun.Ident("status.id"), status.ID). | ||||||
|  | 			Exec(ctx); err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		s.cache.Put(status) |  | ||||||
| 		return nil | 		return nil | ||||||
| 	}) | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, s.conn.ProcessError(err) | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	return status, err | 	s.cache.Put(status) | ||||||
|  | 	return status, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { | func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { | ||||||
|  | @ -232,8 +271,8 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { | ||||||
| 		// delete links between this status and any emojis it uses | 		// delete links between this status and any emojis it uses | ||||||
| 		if _, err := tx. | 		if _, err := tx. | ||||||
| 			NewDelete(). | 			NewDelete(). | ||||||
| 			Model(>smodel.StatusToEmoji{}). | 			TableExpr("? AS ?", bun.Ident("status_to_emojis"), bun.Ident("status_to_emoji")). | ||||||
| 			Where("status_id = ?", id). | 			Where("? = ?", bun.Ident("status_to_emoji.status_id"), id). | ||||||
| 			Exec(ctx); err != nil { | 			Exec(ctx); err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
|  | @ -241,8 +280,8 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { | ||||||
| 		// delete links between this status and any tags it uses | 		// delete links between this status and any tags it uses | ||||||
| 		if _, err := tx. | 		if _, err := tx. | ||||||
| 			NewDelete(). | 			NewDelete(). | ||||||
| 			Model(>smodel.StatusToTag{}). | 			TableExpr("? AS ?", bun.Ident("status_to_tags"), bun.Ident("status_to_tag")). | ||||||
| 			Where("status_id = ?", id). | 			Where("? = ?", bun.Ident("status_to_tag.status_id"), id). | ||||||
| 			Exec(ctx); err != nil { | 			Exec(ctx); err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
|  | @ -250,17 +289,20 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error { | ||||||
| 		// delete the status itself | 		// delete the status itself | ||||||
| 		if _, err := tx. | 		if _, err := tx. | ||||||
| 			NewDelete(). | 			NewDelete(). | ||||||
| 			Model(>smodel.Status{ID: id}). | 			TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). | ||||||
| 			WherePK(). | 			Where("? = ?", bun.Ident("status.id"), id). | ||||||
| 			Exec(ctx); err != nil { | 			Exec(ctx); err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		s.cache.Invalidate(id) |  | ||||||
| 		return nil | 		return nil | ||||||
| 	}) | 	}) | ||||||
| 
 | 	if err != nil { | ||||||
| 		return s.conn.ProcessError(err) | 		return s.conn.ProcessError(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	s.cache.Invalidate(id) | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { | func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { | ||||||
|  | @ -312,11 +354,11 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, | ||||||
| 
 | 
 | ||||||
| 	q := s.conn. | 	q := s.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Table("statuses"). | 		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). | ||||||
| 		Column("id"). | 		Column("status.id"). | ||||||
| 		Where("in_reply_to_id = ?", status.ID) | 		Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID) | ||||||
| 	if minID != "" { | 	if minID != "" { | ||||||
| 		q = q.Where("id > ?", minID) | 		q = q.Where("? > ?", bun.Ident("status.id"), minID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := q.Scan(ctx, &childIDs); err != nil { | 	if err := q.Scan(ctx, &childIDs); err != nil { | ||||||
|  | @ -356,23 +398,35 @@ func (s *statusDB) statusChildren(ctx context.Context, status *gtsmodel.Status, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { | func (s *statusDB) CountStatusReplies(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { | ||||||
| 	return s.conn.NewSelect().Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count(ctx) | 	return s.conn. | ||||||
|  | 		NewSelect(). | ||||||
|  | 		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). | ||||||
|  | 		Where("? = ?", bun.Ident("status.in_reply_to_id"), status.ID). | ||||||
|  | 		Count(ctx) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { | func (s *statusDB) CountStatusReblogs(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { | ||||||
| 	return s.conn.NewSelect().Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count(ctx) | 	return s.conn. | ||||||
|  | 		NewSelect(). | ||||||
|  | 		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). | ||||||
|  | 		Where("? = ?", bun.Ident("status.boost_of_id"), status.ID). | ||||||
|  | 		Count(ctx) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { | func (s *statusDB) CountStatusFaves(ctx context.Context, status *gtsmodel.Status) (int, db.Error) { | ||||||
| 	return s.conn.NewSelect().Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count(ctx) | 	return s.conn. | ||||||
|  | 		NewSelect(). | ||||||
|  | 		TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). | ||||||
|  | 		Where("? = ?", bun.Ident("status_fave.status_id"), status.ID). | ||||||
|  | 		Count(ctx) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { | func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { | ||||||
| 	q := s.conn. | 	q := s.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(>smodel.StatusFave{}). | 		TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")). | ||||||
| 		Where("status_id = ?", status.ID). | 		Where("? = ?", bun.Ident("status_fave.status_id"), status.ID). | ||||||
| 		Where("account_id = ?", accountID) | 		Where("? = ?", bun.Ident("status_fave.account_id"), accountID) | ||||||
| 
 | 
 | ||||||
| 	return s.conn.Exists(ctx, q) | 	return s.conn.Exists(ctx, q) | ||||||
| } | } | ||||||
|  | @ -380,9 +434,9 @@ func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, | ||||||
| func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { | func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { | ||||||
| 	q := s.conn. | 	q := s.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(>smodel.Status{}). | 		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). | ||||||
| 		Where("boost_of_id = ?", status.ID). | 		Where("? = ?", bun.Ident("status.boost_of_id"), status.ID). | ||||||
| 		Where("account_id = ?", accountID) | 		Where("? = ?", bun.Ident("status.account_id"), accountID) | ||||||
| 
 | 
 | ||||||
| 	return s.conn.Exists(ctx, q) | 	return s.conn.Exists(ctx, q) | ||||||
| } | } | ||||||
|  | @ -390,9 +444,9 @@ func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Sta | ||||||
| func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { | func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { | ||||||
| 	q := s.conn. | 	q := s.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(>smodel.StatusMute{}). | 		TableExpr("? AS ?", bun.Ident("status_mutes"), bun.Ident("status_mute")). | ||||||
| 		Where("status_id = ?", status.ID). | 		Where("? = ?", bun.Ident("status_mute.status_id"), status.ID). | ||||||
| 		Where("account_id = ?", accountID) | 		Where("? = ?", bun.Ident("status_mute.account_id"), accountID) | ||||||
| 
 | 
 | ||||||
| 	return s.conn.Exists(ctx, q) | 	return s.conn.Exists(ctx, q) | ||||||
| } | } | ||||||
|  | @ -400,9 +454,9 @@ func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, | ||||||
| func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { | func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { | ||||||
| 	q := s.conn. | 	q := s.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(>smodel.StatusBookmark{}). | 		TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")). | ||||||
| 		Where("status_id = ?", status.ID). | 		Where("? = ?", bun.Ident("status_bookmark.status_id"), status.ID). | ||||||
| 		Where("account_id = ?", accountID) | 		Where("? = ?", bun.Ident("status_bookmark.account_id"), accountID) | ||||||
| 
 | 
 | ||||||
| 	return s.conn.Exists(ctx, q) | 	return s.conn.Exists(ctx, q) | ||||||
| } | } | ||||||
|  | @ -410,8 +464,9 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St | ||||||
| func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) { | func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) { | ||||||
| 	faves := []*gtsmodel.StatusFave{} | 	faves := []*gtsmodel.StatusFave{} | ||||||
| 
 | 
 | ||||||
| 	q := s.newFaveQ(&faves). | 	q := s. | ||||||
| 		Where("status_id = ?", status.ID) | 		newFaveQ(&faves). | ||||||
|  | 		Where("? = ?", bun.Ident("status_fave.status_id"), status.ID) | ||||||
| 
 | 
 | ||||||
| 	if err := q.Scan(ctx); err != nil { | 	if err := q.Scan(ctx); err != nil { | ||||||
| 		return nil, s.conn.ProcessError(err) | 		return nil, s.conn.ProcessError(err) | ||||||
|  | @ -422,8 +477,9 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) | ||||||
| func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) { | func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) { | ||||||
| 	reblogs := []*gtsmodel.Status{} | 	reblogs := []*gtsmodel.Status{} | ||||||
| 
 | 
 | ||||||
| 	q := s.newStatusQ(&reblogs). | 	q := s. | ||||||
| 		Where("boost_of_id = ?", status.ID) | 		newStatusQ(&reblogs). | ||||||
|  | 		Where("? = ?", bun.Ident("status.boost_of_id"), status.ID) | ||||||
| 
 | 
 | ||||||
| 	if err := q.Scan(ctx); err != nil { | 	if err := q.Scan(ctx); err != nil { | ||||||
| 		return nil, s.conn.ProcessError(err) | 		return nil, s.conn.ProcessError(err) | ||||||
|  |  | ||||||
|  | @ -108,14 +108,14 @@ func (suite *StatusTestSuite) TestGetStatusTwice() { | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	after1 := time.Now() | 	after1 := time.Now() | ||||||
| 	duration1 := after1.Sub(before1) | 	duration1 := after1.Sub(before1) | ||||||
| 	fmt.Println(duration1.Milliseconds()) | 	fmt.Println(duration1.Microseconds()) | ||||||
| 
 | 
 | ||||||
| 	before2 := time.Now() | 	before2 := time.Now() | ||||||
| 	_, err = suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI) | 	_, err = suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_1_status_1"].URI) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	after2 := time.Now() | 	after2 := time.Now() | ||||||
| 	duration2 := after2.Sub(before2) | 	duration2 := after2.Sub(before2) | ||||||
| 	fmt.Println(duration2.Milliseconds()) | 	fmt.Println(duration2.Microseconds()) | ||||||
| 
 | 
 | ||||||
| 	// second retrieval should be several orders faster since it will be cached now | 	// second retrieval should be several orders faster since it will be cached now | ||||||
| 	suite.Less(duration2, duration1) | 	suite.Less(duration2, duration1) | ||||||
|  |  | ||||||
|  | @ -34,38 +34,48 @@ type timelineDB struct { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { | func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { | ||||||
|  | 	// Ensure reasonable | ||||||
|  | 	if limit < 0 { | ||||||
|  | 		limit = 0 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Make educated guess for slice size | 	// Make educated guess for slice size | ||||||
| 	statusIDs := make([]string, 0, limit) | 	statusIDs := make([]string, 0, limit) | ||||||
| 
 | 
 | ||||||
| 	q := t.conn. | 	q := t.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Table("statuses"). | 		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). | ||||||
| 
 |  | ||||||
| 		// Select only IDs from table | 		// Select only IDs from table | ||||||
| 		Column("statuses.id"). | 		Column("status.id"). | ||||||
| 		// Find out who accountID follows. | 		// Find out who accountID follows. | ||||||
| 		Join("LEFT JOIN follows ON follows.target_account_id = statuses.account_id AND follows.account_id = ?", accountID). | 		Join("LEFT JOIN ? AS ? ON ? = ? AND ? = ?", | ||||||
|  | 			bun.Ident("follows"), | ||||||
|  | 			bun.Ident("follow"), | ||||||
|  | 			bun.Ident("follow.target_account_id"), | ||||||
|  | 			bun.Ident("status.account_id"), | ||||||
|  | 			bun.Ident("follow.account_id"), | ||||||
|  | 			accountID). | ||||||
| 		// Sort by highest ID (newest) to lowest ID (oldest) | 		// Sort by highest ID (newest) to lowest ID (oldest) | ||||||
| 		Order("statuses.id DESC") | 		Order("status.id DESC") | ||||||
| 
 | 
 | ||||||
| 	if maxID != "" { | 	if maxID != "" { | ||||||
| 		// return only statuses LOWER (ie., older) than maxID | 		// return only statuses LOWER (ie., older) than maxID | ||||||
| 		q = q.Where("statuses.id < ?", maxID) | 		q = q.Where("? < ?", bun.Ident("status.id"), maxID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if sinceID != "" { | 	if sinceID != "" { | ||||||
| 		// return only statuses HIGHER (ie., newer) than sinceID | 		// return only statuses HIGHER (ie., newer) than sinceID | ||||||
| 		q = q.Where("statuses.id > ?", sinceID) | 		q = q.Where("? > ?", bun.Ident("status.id"), sinceID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if minID != "" { | 	if minID != "" { | ||||||
| 		// return only statuses HIGHER (ie., newer) than minID | 		// return only statuses HIGHER (ie., newer) than minID | ||||||
| 		q = q.Where("statuses.id > ?", minID) | 		q = q.Where("? > ?", bun.Ident("status.id"), minID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if local { | 	if local { | ||||||
| 		// return only statuses posted by local account havers | 		// return only statuses posted by local account havers | ||||||
| 		q = q.Where("statuses.local = ?", local) | 		q = q.Where("? = ?", bun.Ident("status.local"), local) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if limit > 0 { | 	if limit > 0 { | ||||||
|  | @ -78,13 +88,11 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI | ||||||
| 	// | 	// | ||||||
| 	// This is equivalent to something like WHERE ... AND (... OR ...) | 	// This is equivalent to something like WHERE ... AND (... OR ...) | ||||||
| 	// See: https://bun.uptrace.dev/guide/queries.html#select | 	// See: https://bun.uptrace.dev/guide/queries.html#select | ||||||
| 	whereGroup := func(*bun.SelectQuery) *bun.SelectQuery { | 	q = q.WhereGroup(" AND ", func(*bun.SelectQuery) *bun.SelectQuery { | ||||||
| 		return q. | 		return q. | ||||||
| 			WhereOr("follows.account_id = ?", accountID). | 			WhereOr("? = ?", bun.Ident("follow.account_id"), accountID). | ||||||
| 			WhereOr("statuses.account_id = ?", accountID) | 			WhereOr("? = ?", bun.Ident("status.account_id"), accountID) | ||||||
| 	} | 	}) | ||||||
| 
 |  | ||||||
| 	q = q.WhereGroup(" AND ", whereGroup) |  | ||||||
| 
 | 
 | ||||||
| 	if err := q.Scan(ctx, &statusIDs); err != nil { | 	if err := q.Scan(ctx, &statusIDs); err != nil { | ||||||
| 		return nil, t.conn.ProcessError(err) | 		return nil, t.conn.ProcessError(err) | ||||||
|  | @ -118,28 +126,28 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, ma | ||||||
| 
 | 
 | ||||||
| 	q := t.conn. | 	q := t.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Table("statuses"). | 		TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). | ||||||
| 		Column("statuses.id"). | 		Column("status.id"). | ||||||
| 		Where("statuses.visibility = ?", gtsmodel.VisibilityPublic). | 		Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic). | ||||||
| 		WhereGroup(" AND ", whereEmptyOrNull("statuses.in_reply_to_id")). | 		WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_id")). | ||||||
| 		WhereGroup(" AND ", whereEmptyOrNull("statuses.in_reply_to_uri")). | 		WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")). | ||||||
| 		WhereGroup(" AND ", whereEmptyOrNull("statuses.boost_of_id")). | 		WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")). | ||||||
| 		Order("statuses.id DESC") | 		Order("status.id DESC") | ||||||
| 
 | 
 | ||||||
| 	if maxID != "" { | 	if maxID != "" { | ||||||
| 		q = q.Where("statuses.id < ?", maxID) | 		q = q.Where("? < ?", bun.Ident("status.id"), maxID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if sinceID != "" { | 	if sinceID != "" { | ||||||
| 		q = q.Where("statuses.id > ?", sinceID) | 		q = q.Where("? > ?", bun.Ident("status.id"), sinceID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if minID != "" { | 	if minID != "" { | ||||||
| 		q = q.Where("statuses.id > ?", minID) | 		q = q.Where("? > ?", bun.Ident("status.id"), minID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if local { | 	if local { | ||||||
| 		q = q.Where("statuses.local = ?", local) | 		q = q.Where("? = ?", bun.Ident("status.local"), local) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if limit > 0 { | 	if limit > 0 { | ||||||
|  | @ -181,15 +189,15 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max | ||||||
| 	fq := t.conn. | 	fq := t.conn. | ||||||
| 		NewSelect(). | 		NewSelect(). | ||||||
| 		Model(&faves). | 		Model(&faves). | ||||||
| 		Where("account_id = ?", accountID). | 		Where("? = ?", bun.Ident("status_fave.account_id"), accountID). | ||||||
| 		Order("id DESC") | 		Order("status_fave.id DESC") | ||||||
| 
 | 
 | ||||||
| 	if maxID != "" { | 	if maxID != "" { | ||||||
| 		fq = fq.Where("id < ?", maxID) | 		fq = fq.Where("? < ?", bun.Ident("status_fave.id"), maxID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if minID != "" { | 	if minID != "" { | ||||||
| 		fq = fq.Where("id > ?", minID) | 		fq = fq.Where("? > ?", bun.Ident("status_fave.id"), minID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if limit > 0 { | 	if limit > 0 { | ||||||
|  |  | ||||||
|  | @ -38,6 +38,15 @@ func (suite *TimelineTestSuite) TestGetPublicTimeline() { | ||||||
| 	suite.Len(s, 6) | 	suite.Len(s, 6) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (suite *TimelineTestSuite) TestGetHomeTimeline() { | ||||||
|  | 	viewingAccount := suite.testAccounts["local_account_1"] | ||||||
|  | 
 | ||||||
|  | 	s, err := suite.db.GetHomeTimeline(context.Background(), viewingAccount.ID, "", "", "", 20, false) | ||||||
|  | 	suite.NoError(err) | ||||||
|  | 
 | ||||||
|  | 	suite.Len(s, 16) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestTimelineTestSuite(t *testing.T) { | func TestTimelineTestSuite(t *testing.T) { | ||||||
| 	suite.Run(t, new(TimelineTestSuite)) | 	suite.Run(t, new(TimelineTestSuite)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -67,7 +67,7 @@ func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db | ||||||
| 			return u.cache.GetByID(id) | 			return u.cache.GetByID(id) | ||||||
| 		}, | 		}, | ||||||
| 		func(user *gtsmodel.User) error { | 		func(user *gtsmodel.User) error { | ||||||
| 			return u.newUserQ(user).Where("user.id = ?", id).Scan(ctx) | 			return u.newUserQ(user).Where("? = ?", bun.Ident("user.id"), id).Scan(ctx) | ||||||
| 		}, | 		}, | ||||||
| 	) | 	) | ||||||
| } | } | ||||||
|  | @ -79,7 +79,7 @@ func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gts | ||||||
| 			return u.cache.GetByAccountID(accountID) | 			return u.cache.GetByAccountID(accountID) | ||||||
| 		}, | 		}, | ||||||
| 		func(user *gtsmodel.User) error { | 		func(user *gtsmodel.User) error { | ||||||
| 			return u.newUserQ(user).Where("user.account_id = ?", accountID).Scan(ctx) | 			return u.newUserQ(user).Where("? = ?", bun.Ident("user.account_id"), accountID).Scan(ctx) | ||||||
| 		}, | 		}, | ||||||
| 	) | 	) | ||||||
| } | } | ||||||
|  | @ -91,7 +91,7 @@ func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) | ||||||
| 			return u.cache.GetByEmail(emailAddress) | 			return u.cache.GetByEmail(emailAddress) | ||||||
| 		}, | 		}, | ||||||
| 		func(user *gtsmodel.User) error { | 		func(user *gtsmodel.User) error { | ||||||
| 			return u.newUserQ(user).Where("user.email = ?", emailAddress).Scan(ctx) | 			return u.newUserQ(user).Where("? = ?", bun.Ident("user.email"), emailAddress).Scan(ctx) | ||||||
| 		}, | 		}, | ||||||
| 	) | 	) | ||||||
| } | } | ||||||
|  | @ -103,7 +103,7 @@ func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationTok | ||||||
| 			return u.cache.GetByConfirmationToken(confirmationToken) | 			return u.cache.GetByConfirmationToken(confirmationToken) | ||||||
| 		}, | 		}, | ||||||
| 		func(user *gtsmodel.User) error { | 		func(user *gtsmodel.User) error { | ||||||
| 			return u.newUserQ(user).Where("user.confirmation_token = ?", confirmationToken).Scan(ctx) | 			return u.newUserQ(user).Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken).Scan(ctx) | ||||||
| 		}, | 		}, | ||||||
| 	) | 	) | ||||||
| } | } | ||||||
|  | @ -127,7 +127,7 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns .. | ||||||
| 	if _, err := u.conn. | 	if _, err := u.conn. | ||||||
| 		NewUpdate(). | 		NewUpdate(). | ||||||
| 		Model(user). | 		Model(user). | ||||||
| 		WherePK(). | 		Where("? = ?", bun.Ident("user.id"), user.ID). | ||||||
| 		Column(columns...). | 		Column(columns...). | ||||||
| 		Exec(ctx); err != nil { | 		Exec(ctx); err != nil { | ||||||
| 		return nil, u.conn.ProcessError(err) | 		return nil, u.conn.ProcessError(err) | ||||||
|  | @ -140,8 +140,8 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns .. | ||||||
| func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error { | func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error { | ||||||
| 	if _, err := u.conn. | 	if _, err := u.conn. | ||||||
| 		NewDelete(). | 		NewDelete(). | ||||||
| 		Model(>smodel.User{ID: userID}). | 		TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")). | ||||||
| 		WherePK(). | 		Where("? = ?", bun.Ident("user.id"), userID). | ||||||
| 		Exec(ctx); err != nil { | 		Exec(ctx); err != nil { | ||||||
| 		return u.conn.ProcessError(err) | 		return u.conn.ProcessError(err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -85,14 +85,8 @@ func parseWhere(w db.Where) (query string, args []interface{}) { | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if w.CaseInsensitive { |  | ||||||
| 			query = "LOWER(?) != LOWER(?)" |  | ||||||
| 			args = []interface{}{bun.Safe(w.Key), w.Value} |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		query = "? != ?" | 		query = "? != ?" | ||||||
| 		args = []interface{}{bun.Safe(w.Key), w.Value} | 		args = []interface{}{bun.Ident(w.Key), w.Value} | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -102,13 +96,7 @@ func parseWhere(w db.Where) (query string, args []interface{}) { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if w.CaseInsensitive { |  | ||||||
| 		query = "LOWER(?) = LOWER(?)" |  | ||||||
| 		args = []interface{}{bun.Safe(w.Key), w.Value} |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	query = "? = ?" | 	query = "? = ?" | ||||||
| 	args = []interface{}{bun.Safe(w.Key), w.Value} | 	args = []interface{}{bun.Ident(w.Key), w.Value} | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -24,9 +24,6 @@ type Where struct { | ||||||
| 	Key string | 	Key string | ||||||
| 	// The value to match. | 	// The value to match. | ||||||
| 	Value interface{} | 	Value interface{} | ||||||
| 	// Whether the value (if a string) should be case sensitive or not. |  | ||||||
| 	// Defaults to false. |  | ||||||
| 	CaseInsensitive bool |  | ||||||
| 	// If set, reverse the where. | 	// If set, reverse the where. | ||||||
| 	// `WHERE k = v` becomes `WHERE k != v`. | 	// `WHERE k = v` becomes `WHERE k != v`. | ||||||
| 	// `WHERE k IS NULL` becomes `WHERE k IS NOT NULL` | 	// `WHERE k IS NULL` becomes `WHERE k IS NOT NULL` | ||||||
|  |  | ||||||
|  | @ -101,7 +101,7 @@ func (p *ProcessingMedia) LoadAttachment(ctx context.Context) (*gtsmodel.MediaAt | ||||||
| 	if !p.insertedInDB { | 	if !p.insertedInDB { | ||||||
| 		if p.recache { | 		if p.recache { | ||||||
| 			// if it's a recache we should only need to update | 			// if it's a recache we should only need to update | ||||||
| 			if err := p.database.UpdateByPrimaryKey(ctx, p.attachment); err != nil { | 			if err := p.database.UpdateByID(ctx, p.attachment, p.attachment.ID); err != nil { | ||||||
| 				return nil, err | 				return nil, err | ||||||
| 			} | 			} | ||||||
| 		} else { | 		} else { | ||||||
|  |  | ||||||
|  | @ -40,7 +40,7 @@ func (suite *PruneMetaTestSuite) TestPruneMeta() { | ||||||
| 	zork := suite.testAccounts["local_account_1"] | 	zork := suite.testAccounts["local_account_1"] | ||||||
| 	zork.AvatarMediaAttachmentID = "" | 	zork.AvatarMediaAttachmentID = "" | ||||||
| 	zork.HeaderMediaAttachmentID = "" | 	zork.HeaderMediaAttachmentID = "" | ||||||
| 	if err := suite.db.UpdateByPrimaryKey(ctx, zork, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil { | 	if err := suite.db.UpdateByID(ctx, zork, zork.ID, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil { | ||||||
| 		panic(err) | 		panic(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -72,7 +72,7 @@ func (suite *PruneMetaTestSuite) TestPruneMetaTwice() { | ||||||
| 	zork := suite.testAccounts["local_account_1"] | 	zork := suite.testAccounts["local_account_1"] | ||||||
| 	zork.AvatarMediaAttachmentID = "" | 	zork.AvatarMediaAttachmentID = "" | ||||||
| 	zork.HeaderMediaAttachmentID = "" | 	zork.HeaderMediaAttachmentID = "" | ||||||
| 	if err := suite.db.UpdateByPrimaryKey(ctx, zork, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil { | 	if err := suite.db.UpdateByID(ctx, zork, zork.ID, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil { | ||||||
| 		panic(err) | 		panic(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -95,14 +95,14 @@ func (suite *PruneMetaTestSuite) TestPruneMetaMultipleAccounts() { | ||||||
| 	zork := suite.testAccounts["local_account_1"] | 	zork := suite.testAccounts["local_account_1"] | ||||||
| 	zork.AvatarMediaAttachmentID = "" | 	zork.AvatarMediaAttachmentID = "" | ||||||
| 	zork.HeaderMediaAttachmentID = "" | 	zork.HeaderMediaAttachmentID = "" | ||||||
| 	if err := suite.db.UpdateByPrimaryKey(ctx, zork, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil { | 	if err := suite.db.UpdateByID(ctx, zork, zork.ID, "avatar_media_attachment_id", "header_media_attachment_id"); err != nil { | ||||||
| 		panic(err) | 		panic(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// set zork's unused header as belonging to turtle | 	// set zork's unused header as belonging to turtle | ||||||
| 	turtle := suite.testAccounts["local_account_1"] | 	turtle := suite.testAccounts["local_account_1"] | ||||||
| 	zorkOldHeader.AccountID = turtle.ID | 	zorkOldHeader.AccountID = turtle.ID | ||||||
| 	if err := suite.db.UpdateByPrimaryKey(ctx, zorkOldHeader, "account_id"); err != nil { | 	if err := suite.db.UpdateByID(ctx, zorkOldHeader, zorkOldHeader.ID, "account_id"); err != nil { | ||||||
| 		panic(err) | 		panic(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -90,7 +90,7 @@ func (m *manager) pruneOneRemote(ctx context.Context, attachment *gtsmodel.Media | ||||||
| 
 | 
 | ||||||
| 	// update the attachment to reflect that we no longer have it cached | 	// update the attachment to reflect that we no longer have it cached | ||||||
| 	if changed { | 	if changed { | ||||||
| 		return m.db.UpdateByPrimaryKey(ctx, attachment, "updated_at", "cached") | 		return m.db.UpdateByID(ctx, attachment, attachment.ID, "updated_at", "cached") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  |  | ||||||
|  | @ -128,15 +128,17 @@ func (p *processor) initiateDomainBlockSideEffects(ctx context.Context, account | ||||||
| 		instance.ContactAccountUsername = "" | 		instance.ContactAccountUsername = "" | ||||||
| 		instance.ContactAccountID = "" | 		instance.ContactAccountID = "" | ||||||
| 		instance.Version = "" | 		instance.Version = "" | ||||||
| 		if err := p.db.UpdateByPrimaryKey(ctx, instance, updatingColumns...); err != nil { | 		if err := p.db.UpdateByID(ctx, instance, instance.ID, updatingColumns...); err != nil { | ||||||
| 			l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err) | 			l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err) | ||||||
| 		} | 		} | ||||||
| 		l.Debug("domainBlockProcessSideEffects: instance entry updated") | 		l.Debug("domainBlockProcessSideEffects: instance entry updated") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// if we have an instance account for this instance, delete it | 	// if we have an instance account for this instance, delete it | ||||||
| 	if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "username", Value: block.Domain, CaseInsensitive: true}}, >smodel.Account{}); err != nil { | 	if instanceAccount, err := p.db.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil { | ||||||
| 		l.Errorf("domainBlockProcessSideEffects: db error removing instance account: %s", err) | 		if err := p.db.DeleteAccount(ctx, instanceAccount.ID); err != nil { | ||||||
|  | 			l.Errorf("domainBlockProcessSideEffects: db error deleting instance account: %s", err) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// delete accounts through the normal account deletion system (which should also delete media + posts + remove posts from timelines) | 	// delete accounts through the normal account deletion system (which should also delete media + posts + remove posts from timelines) | ||||||
|  |  | ||||||
|  | @ -55,14 +55,14 @@ func (p *processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc | ||||||
| 	// remove the domain block reference from the instance, if we have an entry for it | 	// remove the domain block reference from the instance, if we have an entry for it | ||||||
| 	i := >smodel.Instance{} | 	i := >smodel.Instance{} | ||||||
| 	if err := p.db.GetWhere(ctx, []db.Where{ | 	if err := p.db.GetWhere(ctx, []db.Where{ | ||||||
| 		{Key: "domain", Value: domainBlock.Domain, CaseInsensitive: true}, | 		{Key: "domain", Value: domainBlock.Domain}, | ||||||
| 		{Key: "domain_block_id", Value: id}, | 		{Key: "domain_block_id", Value: id}, | ||||||
| 	}, i); err == nil { | 	}, i); err == nil { | ||||||
| 		updatingColumns := []string{"suspended_at", "domain_block_id", "updated_at"} | 		updatingColumns := []string{"suspended_at", "domain_block_id", "updated_at"} | ||||||
| 		i.SuspendedAt = time.Time{} | 		i.SuspendedAt = time.Time{} | ||||||
| 		i.DomainBlockID = "" | 		i.DomainBlockID = "" | ||||||
| 		i.UpdatedAt = time.Now() | 		i.UpdatedAt = time.Now() | ||||||
| 		if err := p.db.UpdateByPrimaryKey(ctx, i, updatingColumns...); err != nil { | 		if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil { | ||||||
| 			return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err)) | 			return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err)) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -224,7 +224,7 @@ func (p *processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.db.UpdateByPrimaryKey(ctx, i, updatingColumns...); err != nil { | 	if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil { | ||||||
| 		return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", host, err)) | 		return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", host, err)) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -69,7 +69,7 @@ func (suite *GetFileTestSuite) TestGetRemoteFileUncached() { | ||||||
| 	// uncache the file from local | 	// uncache the file from local | ||||||
| 	testAttachment := suite.testAttachments["remote_account_1_status_1_attachment_1"] | 	testAttachment := suite.testAttachments["remote_account_1_status_1_attachment_1"] | ||||||
| 	testAttachment.Cached = testrig.FalseBool() | 	testAttachment.Cached = testrig.FalseBool() | ||||||
| 	err := suite.db.UpdateByPrimaryKey(ctx, testAttachment, "cached") | 	err := suite.db.UpdateByID(ctx, testAttachment, testAttachment.ID, "cached") | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	err = suite.storage.Delete(ctx, testAttachment.File.Path) | 	err = suite.storage.Delete(ctx, testAttachment.File.Path) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
|  | @ -124,7 +124,7 @@ func (suite *GetFileTestSuite) TestGetRemoteFileUncachedInterrupted() { | ||||||
| 	// uncache the file from local | 	// uncache the file from local | ||||||
| 	testAttachment := suite.testAttachments["remote_account_1_status_1_attachment_1"] | 	testAttachment := suite.testAttachments["remote_account_1_status_1_attachment_1"] | ||||||
| 	testAttachment.Cached = testrig.FalseBool() | 	testAttachment.Cached = testrig.FalseBool() | ||||||
| 	err := suite.db.UpdateByPrimaryKey(ctx, testAttachment, "cached") | 	err := suite.db.UpdateByID(ctx, testAttachment, testAttachment.ID, "cached") | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	err = suite.storage.Delete(ctx, testAttachment.File.Path) | 	err = suite.storage.Delete(ctx, testAttachment.File.Path) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
|  | @ -179,7 +179,7 @@ func (suite *GetFileTestSuite) TestGetRemoteFileThumbnailUncached() { | ||||||
| 
 | 
 | ||||||
| 	// uncache the file from local | 	// uncache the file from local | ||||||
| 	testAttachment.Cached = testrig.FalseBool() | 	testAttachment.Cached = testrig.FalseBool() | ||||||
| 	err = suite.db.UpdateByPrimaryKey(ctx, testAttachment, "cached") | 	err = suite.db.UpdateByID(ctx, testAttachment, testAttachment.ID, "cached") | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 	err = suite.storage.Delete(ctx, testAttachment.File.Path) | 	err = suite.storage.Delete(ctx, testAttachment.File.Path) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
|  |  | ||||||
|  | @ -47,7 +47,7 @@ func (p *processor) Unattach(ctx context.Context, account *gtsmodel.Account, med | ||||||
| 	attachment.UpdatedAt = time.Now() | 	attachment.UpdatedAt = time.Now() | ||||||
| 	attachment.StatusID = "" | 	attachment.StatusID = "" | ||||||
| 
 | 
 | ||||||
| 	if err := p.db.UpdateByPrimaryKey(ctx, attachment, updatingColumns...); err != nil { | 	if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil { | ||||||
| 		return nil, gtserror.NewErrorNotFound(fmt.Errorf("db error updating attachment: %s", err)) | 		return nil, gtserror.NewErrorNotFound(fmt.Errorf("db error updating attachment: %s", err)) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -61,7 +61,7 @@ func (p *processor) Update(ctx context.Context, account *gtsmodel.Account, media | ||||||
| 		updatingColumns = append(updatingColumns, "focus_x", "focus_y") | 		updatingColumns = append(updatingColumns, "focus_x", "focus_y") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := p.db.UpdateByPrimaryKey(ctx, attachment, updatingColumns...); err != nil { | 	if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil { | ||||||
| 		return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error updating media: %s", err)) | 		return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error updating media: %s", err)) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -162,27 +162,28 @@ func (p *processor) ProcessMediaIDs(ctx context.Context, form *apimodel.Advanced | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	gtsMediaAttachments := []*gtsmodel.MediaAttachment{} | 	attachments := []*gtsmodel.MediaAttachment{} | ||||||
| 	attachments := []string{} | 	attachmentIDs := []string{} | ||||||
| 	for _, mediaID := range form.MediaIDs { | 	for _, mediaID := range form.MediaIDs { | ||||||
| 		// check these attachments exist | 		attachment, err := p.db.GetAttachmentByID(ctx, mediaID) | ||||||
| 		a := >smodel.MediaAttachment{} | 		if err != nil { | ||||||
| 		if err := p.db.GetByID(ctx, mediaID, a); err != nil { | 			return fmt.Errorf("ProcessMediaIDs: invalid media type or media not found for media id %s", mediaID) | ||||||
| 			return fmt.Errorf("invalid media type or media not found for media id %s", mediaID) |  | ||||||
| 		} | 		} | ||||||
| 		// check they belong to the requesting account id | 
 | ||||||
| 		if a.AccountID != thisAccountID { | 		if attachment.AccountID != thisAccountID { | ||||||
| 			return fmt.Errorf("media with id %s does not belong to account %s", mediaID, thisAccountID) | 			return fmt.Errorf("ProcessMediaIDs: media with id %s does not belong to account %s", mediaID, thisAccountID) | ||||||
| 		} | 		} | ||||||
| 		// check they're not already used in a status | 
 | ||||||
| 		if a.StatusID != "" || a.ScheduledStatusID != "" { | 		if attachment.StatusID != "" || attachment.ScheduledStatusID != "" { | ||||||
| 			return fmt.Errorf("media with id %s is already attached to a status", mediaID) | 			return fmt.Errorf("ProcessMediaIDs: media with id %s is already attached to a status", mediaID) | ||||||
| 		} | 		} | ||||||
| 		gtsMediaAttachments = append(gtsMediaAttachments, a) | 
 | ||||||
| 		attachments = append(attachments, a.ID) | 		attachments = append(attachments, attachment) | ||||||
|  | 		attachmentIDs = append(attachmentIDs, attachment.ID) | ||||||
| 	} | 	} | ||||||
| 	status.Attachments = gtsMediaAttachments | 
 | ||||||
| 	status.AttachmentIDs = attachments | 	status.Attachments = attachments | ||||||
|  | 	status.AttachmentIDs = attachmentIDs | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -45,7 +45,7 @@ func (p *processor) ChangePassword(ctx context.Context, user *gtsmodel.User, old | ||||||
| 	user.EncryptedPassword = string(newPasswordHash) | 	user.EncryptedPassword = string(newPasswordHash) | ||||||
| 	user.UpdatedAt = time.Now() | 	user.UpdatedAt = time.Now() | ||||||
| 
 | 
 | ||||||
| 	if err := p.db.UpdateByPrimaryKey(ctx, user, "encrypted_password", "updated_at"); err != nil { | 	if err := p.db.UpdateByID(ctx, user, user.ID, "encrypted_password", "updated_at"); err != nil { | ||||||
| 		return gtserror.NewErrorInternalError(err, "database error") | 		return gtserror.NewErrorInternalError(err, "database error") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -77,7 +77,7 @@ func (p *processor) SendConfirmEmail(ctx context.Context, user *gtsmodel.User, u | ||||||
| 	user.LastEmailedAt = time.Now() | 	user.LastEmailedAt = time.Now() | ||||||
| 	user.UpdatedAt = time.Now() | 	user.UpdatedAt = time.Now() | ||||||
| 
 | 
 | ||||||
| 	if err := p.db.UpdateByPrimaryKey(ctx, user, updatingColumns...); err != nil { | 	if err := p.db.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil { | ||||||
| 		return fmt.Errorf("SendConfirmEmail: error updating user entry after email sent: %s", err) | 		return fmt.Errorf("SendConfirmEmail: error updating user entry after email sent: %s", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -126,7 +126,7 @@ func (p *processor) ConfirmEmail(ctx context.Context, token string) (*gtsmodel.U | ||||||
| 	user.ConfirmationToken = "" | 	user.ConfirmationToken = "" | ||||||
| 	user.UpdatedAt = time.Now() | 	user.UpdatedAt = time.Now() | ||||||
| 
 | 
 | ||||||
| 	if err := p.db.UpdateByPrimaryKey(ctx, user, updatingColumns...); err != nil { | 	if err := p.db.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil { | ||||||
| 		return nil, gtserror.NewErrorInternalError(err) | 		return nil, gtserror.NewErrorInternalError(err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -74,7 +74,7 @@ func (suite *EmailConfirmTestSuite) TestConfirmEmail() { | ||||||
| 	user.ConfirmationSentAt = time.Now().Add(-5 * time.Minute) | 	user.ConfirmationSentAt = time.Now().Add(-5 * time.Minute) | ||||||
| 	user.ConfirmationToken = "1d1aa44b-afa4-49c8-ac4b-eceb61715cc6" | 	user.ConfirmationToken = "1d1aa44b-afa4-49c8-ac4b-eceb61715cc6" | ||||||
| 
 | 
 | ||||||
| 	err := suite.db.UpdateByPrimaryKey(ctx, user, updatingColumns...) | 	err := suite.db.UpdateByID(ctx, user, user.ID, updatingColumns...) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 
 | 
 | ||||||
| 	// confirm with the token set above | 	// confirm with the token set above | ||||||
|  | @ -102,7 +102,7 @@ func (suite *EmailConfirmTestSuite) TestConfirmEmailOldToken() { | ||||||
| 	user.ConfirmationSentAt = time.Now().Add(-192 * time.Hour) | 	user.ConfirmationSentAt = time.Now().Add(-192 * time.Hour) | ||||||
| 	user.ConfirmationToken = "1d1aa44b-afa4-49c8-ac4b-eceb61715cc6" | 	user.ConfirmationToken = "1d1aa44b-afa4-49c8-ac4b-eceb61715cc6" | ||||||
| 
 | 
 | ||||||
| 	err := suite.db.UpdateByPrimaryKey(ctx, user, updatingColumns...) | 	err := suite.db.UpdateByID(ctx, user, user.ID, updatingColumns...) | ||||||
| 	suite.NoError(err) | 	suite.NoError(err) | ||||||
| 
 | 
 | ||||||
| 	// confirm with the token set above | 	// confirm with the token set above | ||||||
|  |  | ||||||
|  | @ -187,7 +187,7 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, v := range NewTestStatuses() { | 	for _, v := range NewTestStatuses() { | ||||||
| 		if err := db.PutStatus(ctx, v); err != nil { | 		if err := db.Put(ctx, v); err != nil { | ||||||
| 			log.Panic(err) | 			log.Panic(err) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | @ -198,12 +198,24 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) { | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	for _, v := range NewTestStatusToEmojis() { | ||||||
|  | 		if err := db.Put(ctx, v); err != nil { | ||||||
|  | 			log.Panic(err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	for _, v := range NewTestTags() { | 	for _, v := range NewTestTags() { | ||||||
| 		if err := db.Put(ctx, v); err != nil { | 		if err := db.Put(ctx, v); err != nil { | ||||||
| 			log.Panic(err) | 			log.Panic(err) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	for _, v := range NewTestStatusToTags() { | ||||||
|  | 		if err := db.Put(ctx, v); err != nil { | ||||||
|  | 			log.Panic(err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	for _, v := range NewTestMentions() { | 	for _, v := range NewTestMentions() { | ||||||
| 		if err := db.Put(ctx, v); err != nil { | 		if err := db.Put(ctx, v); err != nil { | ||||||
| 			log.Panic(err) | 			log.Panic(err) | ||||||
|  |  | ||||||
|  | @ -977,6 +977,15 @@ func NewTestEmojis() map[string]*gtsmodel.Emoji { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func NewTestStatusToEmojis() map[string]*gtsmodel.StatusToEmoji { | ||||||
|  | 	return map[string]*gtsmodel.StatusToEmoji{ | ||||||
|  | 		"admin_account_status_1_rainbow": { | ||||||
|  | 			StatusID: "01F8MH75CBF9JFX4ZAD54N0W0R", | ||||||
|  | 			EmojiID:  "01F8MH9H8E4VG3KDYJR9EGPXCQ", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func NewTestInstances() map[string]*gtsmodel.Instance { | func NewTestInstances() map[string]*gtsmodel.Instance { | ||||||
| 	return map[string]*gtsmodel.Instance{ | 	return map[string]*gtsmodel.Instance{ | ||||||
| 		"localhost:8080": { | 		"localhost:8080": { | ||||||
|  | @ -1540,6 +1549,15 @@ func NewTestTags() map[string]*gtsmodel.Tag { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func NewTestStatusToTags() map[string]*gtsmodel.StatusToTag { | ||||||
|  | 	return map[string]*gtsmodel.StatusToTag{ | ||||||
|  | 		"admin_account_status_1_welcome": { | ||||||
|  | 			StatusID: "01F8MH75CBF9JFX4ZAD54N0W0R", | ||||||
|  | 			TagID:    "01F8MHA1A2NF9MJ3WCCQ3K8BSZ", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // NewTestMentions returns a map of gts model mentions keyed by their name. | // NewTestMentions returns a map of gts model mentions keyed by their name. | ||||||
| func NewTestMentions() map[string]*gtsmodel.Mention { | func NewTestMentions() map[string]*gtsmodel.Mention { | ||||||
| 	return map[string]*gtsmodel.Mention{ | 	return map[string]*gtsmodel.Mention{ | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue