| 
									
										
										
										
											2024-08-02 13:41:46 +02:00
										 |  |  | // GoToSocial | 
					
						
							|  |  |  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | 
					
						
							|  |  |  | // SPDX-License-Identifier: AGPL-3.0-or-later | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // This program is free software: you can redistribute it and/or modify | 
					
						
							|  |  |  | // it under the terms of the GNU Affero General Public License as published by | 
					
						
							|  |  |  | // the Free Software Foundation, either version 3 of the License, or | 
					
						
							|  |  |  | // (at your option) any later version. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // This program is distributed in the hope that it will be useful, | 
					
						
							|  |  |  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | 
					
						
							|  |  |  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | 
					
						
							|  |  |  | // GNU Affero General Public License for more details. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // You should have received a copy of the GNU Affero General Public License | 
					
						
							|  |  |  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | package account | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"context" | 
					
						
							|  |  |  | 	"encoding/csv" | 
					
						
							|  |  |  | 	"errors" | 
					
						
							|  |  |  | 	"fmt" | 
					
						
							|  |  |  | 	"mime/multipart" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" | 
					
						
							|  |  |  | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 
					
						
							|  |  |  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 
					
						
							|  |  |  | 	"github.com/superseriousbusiness/gotosocial/internal/log" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (p *Processor) ImportData( | 
					
						
							|  |  |  | 	ctx context.Context, | 
					
						
							|  |  |  | 	requester *gtsmodel.Account, | 
					
						
							|  |  |  | 	data *multipart.FileHeader, | 
					
						
							|  |  |  | 	importType string, | 
					
						
							|  |  |  | 	overwrite bool, | 
					
						
							|  |  |  | ) gtserror.WithCode { | 
					
						
							|  |  |  | 	switch importType { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	case "following": | 
					
						
							|  |  |  | 		return p.importFollowing( | 
					
						
							|  |  |  | 			ctx, | 
					
						
							|  |  |  | 			requester, | 
					
						
							|  |  |  | 			data, | 
					
						
							|  |  |  | 			overwrite, | 
					
						
							|  |  |  | 		) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	case "blocks": | 
					
						
							|  |  |  | 		return p.importBlocks( | 
					
						
							|  |  |  | 			ctx, | 
					
						
							|  |  |  | 			requester, | 
					
						
							|  |  |  | 			data, | 
					
						
							|  |  |  | 			overwrite, | 
					
						
							|  |  |  | 		) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		const text = "import type not yet supported" | 
					
						
							|  |  |  | 		return gtserror.NewErrorUnprocessableEntity(errors.New(text), text) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (p *Processor) importFollowing( | 
					
						
							|  |  |  | 	ctx context.Context, | 
					
						
							|  |  |  | 	requester *gtsmodel.Account, | 
					
						
							|  |  |  | 	followingData *multipart.FileHeader, | 
					
						
							|  |  |  | 	overwrite bool, | 
					
						
							|  |  |  | ) gtserror.WithCode { | 
					
						
							|  |  |  | 	file, err := followingData.Open() | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		err := fmt.Errorf("error opening following data file: %w", err) | 
					
						
							|  |  |  | 		return gtserror.NewErrorBadRequest(err, err.Error()) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	defer file.Close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Parse records out of the file. | 
					
						
							|  |  |  | 	records, err := csv.NewReader(file).ReadAll() | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		err := fmt.Errorf("error reading following data file: %w", err) | 
					
						
							|  |  |  | 		return gtserror.NewErrorBadRequest(err, err.Error()) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Convert the records into a slice of barebones follows. | 
					
						
							|  |  |  | 	// | 
					
						
							|  |  |  | 	// Only TargetAccount.Username, TargetAccount.Domain, | 
					
						
							|  |  |  | 	// and ShowReblogs will be set on each Follow. | 
					
						
							|  |  |  | 	follows, err := p.converter.CSVToFollowing(ctx, records) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		err := fmt.Errorf("error converting records to follows: %w", err) | 
					
						
							|  |  |  | 		return gtserror.NewErrorBadRequest(err, err.Error()) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Do remaining processing of this import asynchronously. | 
					
						
							|  |  |  | 	f := importFollowingAsyncF(p, requester, follows, overwrite) | 
					
						
							|  |  |  | 	p.state.Workers.Processing.Queue.Push(f) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func importFollowingAsyncF( | 
					
						
							|  |  |  | 	p *Processor, | 
					
						
							|  |  |  | 	requester *gtsmodel.Account, | 
					
						
							|  |  |  | 	follows []*gtsmodel.Follow, | 
					
						
							|  |  |  | 	overwrite bool, | 
					
						
							|  |  |  | ) func(context.Context) { | 
					
						
							|  |  |  | 	return func(ctx context.Context) { | 
					
						
							|  |  |  | 		// Map used to store wanted | 
					
						
							|  |  |  | 		// follow targets (if overwriting). | 
					
						
							|  |  |  | 		var wantedFollows map[string]struct{} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if overwrite { | 
					
						
							|  |  |  | 			// If we're overwriting, we need to get current | 
					
						
							|  |  |  | 			// follow(-req)s owned by requester *before* | 
					
						
							|  |  |  | 			// making any changes, so that we can remove | 
					
						
							|  |  |  | 			// unwanted follows after we've created new ones. | 
					
						
							|  |  |  | 			prevFollows, err := p.state.DB.GetAccountFollows(ctx, requester.ID, nil) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				log.Errorf(ctx, "db error getting following: %v", err) | 
					
						
							|  |  |  | 				return | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			prevFollowReqs, err := p.state.DB.GetAccountFollowRequesting(ctx, requester.ID, nil) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				log.Errorf(ctx, "db error getting follow requesting: %v", err) | 
					
						
							|  |  |  | 				return | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// Initialize new follows map. | 
					
						
							|  |  |  | 			wantedFollows = make(map[string]struct{}, len(follows)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// Once we've created (or tried to create) | 
					
						
							|  |  |  | 			// the required follows, go through previous | 
					
						
							|  |  |  | 			// follow(-request)s and remove unwanted ones. | 
					
						
							|  |  |  | 			defer func() { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				// AccountIDs to unfollow. | 
					
						
							|  |  |  | 				toRemove := []string{} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				// Check previous follows. | 
					
						
							|  |  |  | 				for _, prev := range prevFollows { | 
					
						
							|  |  |  | 					username := prev.TargetAccount.Username | 
					
						
							|  |  |  | 					domain := prev.TargetAccount.Domain | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 					_, wanted := wantedFollows[username+"@"+domain] | 
					
						
							|  |  |  | 					if !wanted { | 
					
						
							|  |  |  | 						toRemove = append(toRemove, prev.TargetAccountID) | 
					
						
							|  |  |  | 					} | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				// Now any pending follow requests. | 
					
						
							|  |  |  | 				for _, prev := range prevFollowReqs { | 
					
						
							|  |  |  | 					username := prev.TargetAccount.Username | 
					
						
							|  |  |  | 					domain := prev.TargetAccount.Domain | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 					_, wanted := wantedFollows[username+"@"+domain] | 
					
						
							|  |  |  | 					if !wanted { | 
					
						
							|  |  |  | 						toRemove = append(toRemove, prev.TargetAccountID) | 
					
						
							|  |  |  | 					} | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				// Remove each discovered | 
					
						
							|  |  |  | 				// unwanted follow. | 
					
						
							|  |  |  | 				for _, accountID := range toRemove { | 
					
						
							|  |  |  | 					if _, errWithCode := p.FollowRemove( | 
					
						
							|  |  |  | 						ctx, | 
					
						
							|  |  |  | 						requester, | 
					
						
							|  |  |  | 						accountID, | 
					
						
							|  |  |  | 					); errWithCode != nil { | 
					
						
							|  |  |  | 						log.Errorf(ctx, "could not unfollow account: %v", errWithCode.Unwrap()) | 
					
						
							|  |  |  | 						continue | 
					
						
							|  |  |  | 					} | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			}() | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// Go through the follows parsed from CSV | 
					
						
							|  |  |  | 		// file, and create / update each one. | 
					
						
							|  |  |  | 		for _, follow := range follows { | 
					
						
							|  |  |  | 			var ( | 
					
						
							|  |  |  | 				// Username of the target. | 
					
						
							|  |  |  | 				username = follow.TargetAccount.Username | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				// Domain of the target. | 
					
						
							|  |  |  | 				// Empty for our domain. | 
					
						
							|  |  |  | 				domain = follow.TargetAccount.Domain | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				// Show reblogs on | 
					
						
							|  |  |  | 				// the new follow. | 
					
						
							|  |  |  | 				showReblogs = follow.ShowReblogs | 
					
						
							| 
									
										
										
										
											2024-09-16 20:39:15 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 				// Notify when new | 
					
						
							|  |  |  | 				// follow posts. | 
					
						
							|  |  |  | 				notify = follow.Notify | 
					
						
							| 
									
										
										
										
											2024-08-02 13:41:46 +02:00
										 |  |  | 			) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			if overwrite { | 
					
						
							|  |  |  | 				// We'll be overwriting, so store | 
					
						
							|  |  |  | 				// this new follow in our handy map. | 
					
						
							|  |  |  | 				wantedFollows[username+"@"+domain] = struct{}{} | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// Get the target account, dereferencing it if necessary. | 
					
						
							|  |  |  | 			targetAcct, _, err := p.federator.Dereferencer.GetAccountByUsernameDomain( | 
					
						
							|  |  |  | 				ctx, | 
					
						
							|  |  |  | 				requester.Username, | 
					
						
							|  |  |  | 				username, | 
					
						
							|  |  |  | 				domain, | 
					
						
							|  |  |  | 			) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				log.Errorf(ctx, "could not retrieve account: %v", err) | 
					
						
							|  |  |  | 				continue | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// Use the processor's FollowCreate function | 
					
						
							|  |  |  | 			// to create or update the follow. This takes | 
					
						
							|  |  |  | 			// account of existing follows, and also sends | 
					
						
							|  |  |  | 			// the follow to the FromClientAPI processor. | 
					
						
							|  |  |  | 			if _, errWithCode := p.FollowCreate( | 
					
						
							|  |  |  | 				ctx, | 
					
						
							|  |  |  | 				requester, | 
					
						
							|  |  |  | 				&apimodel.AccountFollowRequest{ | 
					
						
							|  |  |  | 					ID:      targetAcct.ID, | 
					
						
							|  |  |  | 					Reblogs: showReblogs, | 
					
						
							| 
									
										
										
										
											2024-09-16 20:39:15 +02:00
										 |  |  | 					Notify:  notify, | 
					
						
							| 
									
										
										
										
											2024-08-02 13:41:46 +02:00
										 |  |  | 				}, | 
					
						
							|  |  |  | 			); errWithCode != nil { | 
					
						
							|  |  |  | 				log.Errorf(ctx, "could not follow account: %v", errWithCode.Unwrap()) | 
					
						
							|  |  |  | 				continue | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (p *Processor) importBlocks( | 
					
						
							|  |  |  | 	ctx context.Context, | 
					
						
							|  |  |  | 	requester *gtsmodel.Account, | 
					
						
							|  |  |  | 	blocksData *multipart.FileHeader, | 
					
						
							|  |  |  | 	overwrite bool, | 
					
						
							|  |  |  | ) gtserror.WithCode { | 
					
						
							|  |  |  | 	file, err := blocksData.Open() | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		err := fmt.Errorf("error opening blocks data file: %w", err) | 
					
						
							|  |  |  | 		return gtserror.NewErrorBadRequest(err, err.Error()) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	defer file.Close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Parse records out of the file. | 
					
						
							|  |  |  | 	records, err := csv.NewReader(file).ReadAll() | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		err := fmt.Errorf("error reading blocks data file: %w", err) | 
					
						
							|  |  |  | 		return gtserror.NewErrorBadRequest(err, err.Error()) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Convert the records into a slice of barebones blocks. | 
					
						
							|  |  |  | 	// | 
					
						
							|  |  |  | 	// Only TargetAccount.Username and TargetAccount.Domain, | 
					
						
							|  |  |  | 	// will be set on each Block. | 
					
						
							|  |  |  | 	blocks, err := p.converter.CSVToBlocks(ctx, records) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		err := fmt.Errorf("error converting records to blocks: %w", err) | 
					
						
							|  |  |  | 		return gtserror.NewErrorBadRequest(err, err.Error()) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Do remaining processing of this import asynchronously. | 
					
						
							|  |  |  | 	f := importBlocksAsyncF(p, requester, blocks, overwrite) | 
					
						
							|  |  |  | 	p.state.Workers.Processing.Queue.Push(f) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func importBlocksAsyncF( | 
					
						
							|  |  |  | 	p *Processor, | 
					
						
							|  |  |  | 	requester *gtsmodel.Account, | 
					
						
							|  |  |  | 	blocks []*gtsmodel.Block, | 
					
						
							|  |  |  | 	overwrite bool, | 
					
						
							|  |  |  | ) func(context.Context) { | 
					
						
							|  |  |  | 	return func(ctx context.Context) { | 
					
						
							|  |  |  | 		// Map used to store wanted | 
					
						
							|  |  |  | 		// block targets (if overwriting). | 
					
						
							|  |  |  | 		var wantedBlocks map[string]struct{} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if overwrite { | 
					
						
							|  |  |  | 			// If we're overwriting, we need to get current | 
					
						
							|  |  |  | 			// blocks owned by requester *before* making any | 
					
						
							|  |  |  | 			// changes, so that we can remove unwanted blocks | 
					
						
							|  |  |  | 			// after we've created new ones. | 
					
						
							|  |  |  | 			var ( | 
					
						
							|  |  |  | 				prevBlocks []*gtsmodel.Block | 
					
						
							|  |  |  | 				err        error | 
					
						
							|  |  |  | 			) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			prevBlocks, err = p.state.DB.GetAccountBlocks(ctx, requester.ID, nil) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				log.Errorf(ctx, "db error getting blocks: %v", err) | 
					
						
							|  |  |  | 				return | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// Initialize new blocks map. | 
					
						
							|  |  |  | 			wantedBlocks = make(map[string]struct{}, len(blocks)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// Once we've created (or tried to create) | 
					
						
							|  |  |  | 			// the required blocks, go through previous | 
					
						
							|  |  |  | 			// blocks and remove unwanted ones. | 
					
						
							|  |  |  | 			defer func() { | 
					
						
							|  |  |  | 				for _, prev := range prevBlocks { | 
					
						
							|  |  |  | 					username := prev.TargetAccount.Username | 
					
						
							|  |  |  | 					domain := prev.TargetAccount.Domain | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 					_, wanted := wantedBlocks[username+"@"+domain] | 
					
						
							|  |  |  | 					if wanted { | 
					
						
							|  |  |  | 						// Leave this | 
					
						
							|  |  |  | 						// one alone. | 
					
						
							|  |  |  | 						continue | 
					
						
							|  |  |  | 					} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 					if _, errWithCode := p.BlockRemove( | 
					
						
							|  |  |  | 						ctx, | 
					
						
							|  |  |  | 						requester, | 
					
						
							|  |  |  | 						prev.TargetAccountID, | 
					
						
							|  |  |  | 					); errWithCode != nil { | 
					
						
							|  |  |  | 						log.Errorf(ctx, "could not unblock account: %v", errWithCode.Unwrap()) | 
					
						
							|  |  |  | 						continue | 
					
						
							|  |  |  | 					} | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			}() | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// Go through the blocks parsed from CSV | 
					
						
							|  |  |  | 		// file, and create / update each one. | 
					
						
							|  |  |  | 		for _, block := range blocks { | 
					
						
							|  |  |  | 			var ( | 
					
						
							|  |  |  | 				// Username of the target. | 
					
						
							|  |  |  | 				username = block.TargetAccount.Username | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				// Domain of the target. | 
					
						
							|  |  |  | 				// Empty for our domain. | 
					
						
							|  |  |  | 				domain = block.TargetAccount.Domain | 
					
						
							|  |  |  | 			) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			if overwrite { | 
					
						
							|  |  |  | 				// We'll be overwriting, so store | 
					
						
							|  |  |  | 				// this new block in our handy map. | 
					
						
							|  |  |  | 				wantedBlocks[username+"@"+domain] = struct{}{} | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// Get the target account, dereferencing it if necessary. | 
					
						
							|  |  |  | 			targetAcct, _, err := p.federator.Dereferencer.GetAccountByUsernameDomain( | 
					
						
							|  |  |  | 				ctx, | 
					
						
							|  |  |  | 				// Provide empty request user to use the | 
					
						
							|  |  |  | 				// instance account to deref the account. | 
					
						
							|  |  |  | 				// | 
					
						
							|  |  |  | 				// It's pointless to make lots of calls | 
					
						
							|  |  |  | 				// to a remote from an account that's about | 
					
						
							|  |  |  | 				// to block that account. | 
					
						
							|  |  |  | 				"", | 
					
						
							|  |  |  | 				username, | 
					
						
							|  |  |  | 				domain, | 
					
						
							|  |  |  | 			) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				log.Errorf(ctx, "could not retrieve account: %v", err) | 
					
						
							|  |  |  | 				continue | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// Use the processor's BlockCreate function | 
					
						
							|  |  |  | 			// to create or update the block. This takes | 
					
						
							|  |  |  | 			// account of existing blocks, and also sends | 
					
						
							|  |  |  | 			// the block to the FromClientAPI processor. | 
					
						
							|  |  |  | 			if _, errWithCode := p.BlockCreate( | 
					
						
							|  |  |  | 				ctx, | 
					
						
							|  |  |  | 				requester, | 
					
						
							|  |  |  | 				targetAcct.ID, | 
					
						
							|  |  |  | 			); errWithCode != nil { | 
					
						
							|  |  |  | 				log.Errorf(ctx, "could not block account: %v", errWithCode.Unwrap()) | 
					
						
							|  |  |  | 				continue | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } |