| 
									
										
										
										
											2023-03-12 16:00:57 +01:00
										 |  |  | // GoToSocial | 
					
						
							|  |  |  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | 
					
						
							|  |  |  | // SPDX-License-Identifier: AGPL-3.0-or-later | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // This program is free software: you can redistribute it and/or modify | 
					
						
							|  |  |  | // it under the terms of the GNU Affero General Public License as published by | 
					
						
							|  |  |  | // the Free Software Foundation, either version 3 of the License, or | 
					
						
							|  |  |  | // (at your option) any later version. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // This program is distributed in the hope that it will be useful, | 
					
						
							|  |  |  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | 
					
						
							|  |  |  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | 
					
						
							|  |  |  | // GNU Affero General Public License for more details. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // You should have received a copy of the GNU Affero General Public License | 
					
						
							|  |  |  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | package bundb | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"context" | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | 	"database/sql" | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | 	"errors" | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	"fmt" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/superseriousbusiness/gotosocial/internal/db" | 
					
						
							|  |  |  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 
					
						
							| 
									
										
										
										
											2022-12-08 17:35:14 +00:00
										 |  |  | 	"github.com/superseriousbusiness/gotosocial/internal/state" | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | 	"github.com/uptrace/bun" | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type relationshipDB struct { | 
					
						
							| 
									
										
										
										
											2022-12-08 17:35:14 +00:00
										 |  |  | 	conn  *DBConn | 
					
						
							|  |  |  | 	state *state.State | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | func (r *relationshipDB) newFollowQ(follow interface{}) *bun.SelectQuery { | 
					
						
							|  |  |  | 	return r.conn. | 
					
						
							|  |  |  | 		NewSelect(). | 
					
						
							|  |  |  | 		Model(follow). | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 		Relation("Account"). | 
					
						
							|  |  |  | 		Relation("TargetAccount") | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) { | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | 	// Look for a block in direction of account1->account2 | 
					
						
							|  |  |  | 	block1, err := r.getBlock(ctx, account1, account2) | 
					
						
							|  |  |  | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | 
					
						
							|  |  |  | 		return false, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if block1 != nil { | 
					
						
							|  |  |  | 		// account1 blocks account2 | 
					
						
							|  |  |  | 		return true, nil | 
					
						
							|  |  |  | 	} else if !eitherDirection { | 
					
						
							|  |  |  | 		// Don't check for mutli-directional | 
					
						
							|  |  |  | 		return false, nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Look for a block in direction of account2->account1 | 
					
						
							|  |  |  | 	block2, err := r.getBlock(ctx, account2, account1) | 
					
						
							|  |  |  | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | 
					
						
							|  |  |  | 		return false, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return (block2 != nil), nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) { | 
					
						
							|  |  |  | 	// Fetch block from database | 
					
						
							|  |  |  | 	block, err := r.getBlock(ctx, account1, account2) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Set the block originating account | 
					
						
							| 
									
										
										
										
											2022-12-08 17:35:14 +00:00
										 |  |  | 	block.Account, err = r.state.DB.GetAccountByID(ctx, block.AccountID) | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Set the block target account | 
					
						
							| 
									
										
										
										
											2022-12-08 17:35:14 +00:00
										 |  |  | 	block.TargetAccount, err = r.state.DB.GetAccountByID(ctx, block.TargetAccountID) | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return block, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (r *relationshipDB) getBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) { | 
					
						
							| 
									
										
										
										
											2022-12-08 17:35:14 +00:00
										 |  |  | 	return r.state.Caches.GTS.Block().Load("AccountID.TargetAccountID", func() (*gtsmodel.Block, error) { | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | 		var block gtsmodel.Block | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		q := r.conn.NewSelect().Model(&block). | 
					
						
							|  |  |  | 			Where("? = ?", bun.Ident("block.account_id"), account1). | 
					
						
							|  |  |  | 			Where("? = ?", bun.Ident("block.target_account_id"), account2) | 
					
						
							|  |  |  | 		if err := q.Scan(ctx); err != nil { | 
					
						
							|  |  |  | 			return nil, r.conn.ProcessError(err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		return &block, nil | 
					
						
							|  |  |  | 	}, account1, account2) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) db.Error { | 
					
						
							| 
									
										
										
										
											2022-12-08 17:35:14 +00:00
										 |  |  | 	return r.state.Caches.GTS.Block().Store(block, func() error { | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | 		_, err := r.conn.NewInsert().Model(block).Exec(ctx) | 
					
						
							|  |  |  | 		return r.conn.ProcessError(err) | 
					
						
							|  |  |  | 	}) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) db.Error { | 
					
						
							|  |  |  | 	if _, err := r.conn. | 
					
						
							|  |  |  | 		NewDelete(). | 
					
						
							|  |  |  | 		TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). | 
					
						
							|  |  |  | 		Where("? = ?", bun.Ident("block.id"), id). | 
					
						
							|  |  |  | 		Exec(ctx); err != nil { | 
					
						
							|  |  |  | 		return r.conn.ProcessError(err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Drop any old value from cache by this ID | 
					
						
							| 
									
										
										
										
											2022-12-08 17:35:14 +00:00
										 |  |  | 	r.state.Caches.GTS.Block().Invalidate("ID", id) | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) db.Error { | 
					
						
							|  |  |  | 	if _, err := r.conn. | 
					
						
							|  |  |  | 		NewDelete(). | 
					
						
							|  |  |  | 		TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). | 
					
						
							|  |  |  | 		Where("? = ?", bun.Ident("block.uri"), uri). | 
					
						
							|  |  |  | 		Exec(ctx); err != nil { | 
					
						
							|  |  |  | 		return r.conn.ProcessError(err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Drop any old value from cache by this URI | 
					
						
							| 
									
										
										
										
											2022-12-08 17:35:14 +00:00
										 |  |  | 	r.state.Caches.GTS.Block().Invalidate("URI", uri) | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (r *relationshipDB) DeleteBlocksByOriginAccountID(ctx context.Context, originAccountID string) db.Error { | 
					
						
							|  |  |  | 	blockIDs := []string{} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	q := r.conn. | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | 		NewSelect(). | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | 		Column("block.id"). | 
					
						
							|  |  |  | 		Where("? = ?", bun.Ident("block.account_id"), originAccountID) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := q.Scan(ctx, &blockIDs); err != nil { | 
					
						
							|  |  |  | 		return r.conn.ProcessError(err) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | 	for _, blockID := range blockIDs { | 
					
						
							|  |  |  | 		if err := r.DeleteBlockByID(ctx, blockID); err != nil { | 
					
						
							|  |  |  | 			return err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | func (r *relationshipDB) DeleteBlocksByTargetAccountID(ctx context.Context, targetAccountID string) db.Error { | 
					
						
							|  |  |  | 	blockIDs := []string{} | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | 	q := r.conn. | 
					
						
							|  |  |  | 		NewSelect(). | 
					
						
							|  |  |  | 		TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")). | 
					
						
							|  |  |  | 		Column("block.id"). | 
					
						
							|  |  |  | 		Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | 	if err := q.Scan(ctx, &blockIDs); err != nil { | 
					
						
							|  |  |  | 		return r.conn.ProcessError(err) | 
					
						
							| 
									
										
										
										
											2021-08-29 15:41:41 +01:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	for _, blockID := range blockIDs { | 
					
						
							|  |  |  | 		if err := r.DeleteBlockByID(ctx, blockID); err != nil { | 
					
						
							|  |  |  | 			return err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) { | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	rel := >smodel.Relationship{ | 
					
						
							|  |  |  | 		ID: targetAccount, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// check if the requesting account follows the target account | 
					
						
							|  |  |  | 	follow := >smodel.Follow{} | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | 	if err := r.conn. | 
					
						
							|  |  |  | 		NewSelect(). | 
					
						
							|  |  |  | 		Model(follow). | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		Column("follow.show_reblogs", "follow.notify"). | 
					
						
							|  |  |  | 		Where("? = ?", bun.Ident("follow.account_id"), requestingAccount). | 
					
						
							|  |  |  | 		Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount). | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | 		Limit(1). | 
					
						
							|  |  |  | 		Scan(ctx); err != nil { | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		if err := r.conn.ProcessError(err); err != db.ErrNoEntries { | 
					
						
							|  |  |  | 			return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 		} | 
					
						
							|  |  |  | 		// no follow exists so these are all false | 
					
						
							|  |  |  | 		rel.Following = false | 
					
						
							|  |  |  | 		rel.ShowingReblogs = false | 
					
						
							|  |  |  | 		rel.Notifying = false | 
					
						
							|  |  |  | 	} else { | 
					
						
							|  |  |  | 		// follow exists so we can fill these fields out... | 
					
						
							|  |  |  | 		rel.Following = true | 
					
						
							| 
									
										
										
										
											2022-08-15 12:35:05 +02:00
										 |  |  | 		rel.ShowingReblogs = *follow.ShowReblogs | 
					
						
							|  |  |  | 		rel.Notifying = *follow.Notify | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// check if the target account follows the requesting account | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 	followedByQ := r.conn. | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | 		NewSelect(). | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). | 
					
						
							|  |  |  | 		Column("follow.id"). | 
					
						
							|  |  |  | 		Where("? = ?", bun.Ident("follow.account_id"), targetAccount). | 
					
						
							|  |  |  | 		Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount) | 
					
						
							|  |  |  | 	followedBy, err := r.conn.Exists(ctx, followedByQ) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 	rel.FollowedBy = followedBy | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 	// check if there's a pending following request from requesting account to target account | 
					
						
							|  |  |  | 	requestedQ := r.conn. | 
					
						
							|  |  |  | 		NewSelect(). | 
					
						
							|  |  |  | 		TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). | 
					
						
							|  |  |  | 		Column("follow_request.id"). | 
					
						
							|  |  |  | 		Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount). | 
					
						
							|  |  |  | 		Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount) | 
					
						
							|  |  |  | 	requested, err := r.conn.Exists(ctx, requestedQ) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 	rel.Requested = requested | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 	// check if the requesting account is blocking the target account | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | 	blockA2T, err := r.getBlock(ctx, requestingAccount, targetAccount) | 
					
						
							|  |  |  | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | 	rel.Blocking = (blockA2T != nil) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 	// check if the requesting account is blocked by the target account | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | 	blockT2A, err := r.getBlock(ctx, targetAccount, requestingAccount) | 
					
						
							|  |  |  | 	if err != nil && !errors.Is(err, db.ErrNoEntries) { | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2022-11-20 16:33:49 +00:00
										 |  |  | 	rel.BlockedBy = (blockT2A != nil) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	return rel, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	if sourceAccount == nil || targetAccount == nil { | 
					
						
							|  |  |  | 		return false, nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	q := r.conn. | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | 		NewSelect(). | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")). | 
					
						
							|  |  |  | 		Column("follow.id"). | 
					
						
							|  |  |  | 		Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID). | 
					
						
							|  |  |  | 		Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-29 15:41:41 +01:00
										 |  |  | 	return r.conn.Exists(ctx, q) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	if sourceAccount == nil || targetAccount == nil { | 
					
						
							|  |  |  | 		return false, nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	q := r.conn. | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | 		NewSelect(). | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). | 
					
						
							|  |  |  | 		Column("follow_request.id"). | 
					
						
							|  |  |  | 		Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID). | 
					
						
							|  |  |  | 		Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-29 15:41:41 +01:00
										 |  |  | 	return r.conn.Exists(ctx, q) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) { | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	if account1 == nil || account2 == nil { | 
					
						
							|  |  |  | 		return false, nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// make sure account 1 follows account 2 | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | 	f1, err := r.IsFollowing(ctx, account1, account2) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2021-08-29 15:41:41 +01:00
										 |  |  | 		return false, err | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// make sure account 2 follows account 1 | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | 	f2, err := r.IsFollowing(ctx, account2, account1) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2021-08-29 15:41:41 +01:00
										 |  |  | 		return false, err | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return f1 && f2, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) { | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 	var follow *gtsmodel.Follow | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error { | 
					
						
							|  |  |  | 		// get original follow request | 
					
						
							|  |  |  | 		followRequest := >smodel.FollowRequest{} | 
					
						
							|  |  |  | 		if err := tx. | 
					
						
							|  |  |  | 			NewSelect(). | 
					
						
							|  |  |  | 			Model(followRequest). | 
					
						
							|  |  |  | 			Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID). | 
					
						
							|  |  |  | 			Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID). | 
					
						
							|  |  |  | 			Scan(ctx); err != nil { | 
					
						
							|  |  |  | 			return err | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		// create a new follow to 'replace' the request with | 
					
						
							|  |  |  | 		follow = >smodel.Follow{ | 
					
						
							|  |  |  | 			ID:              followRequest.ID, | 
					
						
							|  |  |  | 			AccountID:       originAccountID, | 
					
						
							|  |  |  | 			TargetAccountID: targetAccountID, | 
					
						
							|  |  |  | 			URI:             followRequest.URI, | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		// if the follow already exists, just update the URI -- we don't need to do anything else | 
					
						
							|  |  |  | 		if _, err := tx. | 
					
						
							|  |  |  | 			NewInsert(). | 
					
						
							|  |  |  | 			Model(follow). | 
					
						
							|  |  |  | 			On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI). | 
					
						
							|  |  |  | 			Exec(ctx); err != nil { | 
					
						
							|  |  |  | 			return err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// now remove the follow request | 
					
						
							|  |  |  | 		if _, err := tx. | 
					
						
							|  |  |  | 			NewDelete(). | 
					
						
							|  |  |  | 			TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). | 
					
						
							|  |  |  | 			Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). | 
					
						
							|  |  |  | 			Exec(ctx); err != nil { | 
					
						
							|  |  |  | 			return err | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		return nil | 
					
						
							|  |  |  | 	}); err != nil { | 
					
						
							| 
									
										
										
										
											2021-08-29 15:41:41 +01:00
										 |  |  | 		return nil, r.conn.ProcessError(err) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 	// return the new follow | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	return follow, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-10-16 13:27:43 +02:00
										 |  |  | func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) { | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 	followRequest := >smodel.FollowRequest{} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := r.conn.RunInTx(ctx, func(tx bun.Tx) error { | 
					
						
							|  |  |  | 		// get original follow request | 
					
						
							|  |  |  | 		if err := tx. | 
					
						
							|  |  |  | 			NewSelect(). | 
					
						
							|  |  |  | 			Model(followRequest). | 
					
						
							|  |  |  | 			Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID). | 
					
						
							|  |  |  | 			Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID). | 
					
						
							|  |  |  | 			Scan(ctx); err != nil { | 
					
						
							|  |  |  | 			return err | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2021-10-16 13:27:43 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		// now delete it from the database by ID | 
					
						
							|  |  |  | 		if _, err := tx. | 
					
						
							|  |  |  | 			NewDelete(). | 
					
						
							|  |  |  | 			TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")). | 
					
						
							|  |  |  | 			Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). | 
					
						
							|  |  |  | 			Exec(ctx); err != nil { | 
					
						
							|  |  |  | 			return err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		return nil | 
					
						
							|  |  |  | 	}); err != nil { | 
					
						
							| 
									
										
										
										
											2021-10-16 13:27:43 +02:00
										 |  |  | 		return nil, r.conn.ProcessError(err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// return the deleted follow request | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 	return followRequest, nil | 
					
						
							| 
									
										
										
										
											2021-10-16 13:27:43 +02:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, db.Error) { | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	followRequests := []*gtsmodel.FollowRequest{} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	q := r.newFollowQ(&followRequests). | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		Where("? = ?", bun.Ident("follow_request.target_account_id"), accountID). | 
					
						
							| 
									
										
										
										
											2022-08-31 05:27:39 -04:00
										 |  |  | 		Order("follow_request.updated_at DESC") | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-22 08:46:19 +01:00
										 |  |  | 	if err := q.Scan(ctx); err != nil { | 
					
						
							| 
									
										
										
										
											2021-08-29 15:41:41 +01:00
										 |  |  | 		return nil, r.conn.ProcessError(err) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-29 15:41:41 +01:00
										 |  |  | 	return followRequests, nil | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, db.Error) { | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	follows := []*gtsmodel.Follow{} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	q := r.newFollowQ(&follows). | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		Where("? = ?", bun.Ident("follow.account_id"), accountID). | 
					
						
							| 
									
										
										
										
											2022-08-31 05:27:39 -04:00
										 |  |  | 		Order("follow.updated_at DESC") | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-22 08:46:19 +01:00
										 |  |  | 	if err := q.Scan(ctx); err != nil { | 
					
						
							| 
									
										
										
										
											2021-08-29 15:41:41 +01:00
										 |  |  | 		return nil, r.conn.ProcessError(err) | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-29 15:41:41 +01:00
										 |  |  | 	return follows, nil | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 	q := r.conn. | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | 		NewSelect(). | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if localOnly { | 
					
						
							|  |  |  | 		q = q. | 
					
						
							|  |  |  | 			Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.target_account_id"), bun.Ident("account.id")). | 
					
						
							|  |  |  | 			Where("? = ?", bun.Ident("follow.account_id"), accountID). | 
					
						
							|  |  |  | 			Where("? IS NULL", bun.Ident("account.domain")) | 
					
						
							|  |  |  | 	} else { | 
					
						
							|  |  |  | 		q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return q.Count(ctx) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	follows := []*gtsmodel.Follow{} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | 	q := r.conn. | 
					
						
							|  |  |  | 		NewSelect(). | 
					
						
							| 
									
										
										
										
											2022-08-31 05:27:39 -04:00
										 |  |  | 		Model(&follows). | 
					
						
							|  |  |  | 		Order("follow.updated_at DESC") | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	if localOnly { | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		q = q. | 
					
						
							|  |  |  | 			Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")). | 
					
						
							|  |  |  | 			Where("? = ?", bun.Ident("follow.target_account_id"), accountID). | 
					
						
							|  |  |  | 			Where("? IS NULL", bun.Ident("account.domain")) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	} else { | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-29 15:41:41 +01:00
										 |  |  | 	err := q.Scan(ctx) | 
					
						
							|  |  |  | 	if err != nil && err != sql.ErrNoRows { | 
					
						
							|  |  |  | 		return nil, r.conn.ProcessError(err) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | 	} | 
					
						
							|  |  |  | 	return follows, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | func (r *relationshipDB) CountAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) (int, db.Error) { | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 	q := r.conn. | 
					
						
							| 
									
										
										
										
											2021-08-25 15:34:33 +02:00
										 |  |  | 		NewSelect(). | 
					
						
							| 
									
										
										
										
											2022-10-08 13:50:48 +02:00
										 |  |  | 		TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if localOnly { | 
					
						
							|  |  |  | 		q = q. | 
					
						
							|  |  |  | 			Join("JOIN ? AS ? ON ? = ?", bun.Ident("accounts"), bun.Ident("account"), bun.Ident("follow.account_id"), bun.Ident("account.id")). | 
					
						
							|  |  |  | 			Where("? = ?", bun.Ident("follow.target_account_id"), accountID). | 
					
						
							|  |  |  | 			Where("? IS NULL", bun.Ident("account.domain")) | 
					
						
							|  |  |  | 	} else { | 
					
						
							|  |  |  | 		q = q.Where("? = ?", bun.Ident("follow.target_account_id"), accountID) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return q.Count(ctx) | 
					
						
							| 
									
										
										
										
											2021-08-20 12:26:56 +02:00
										 |  |  | } |