| 
									
										
										
										
											2023-03-12 16:00:57 +01:00
										 |  |  | // GoToSocial | 
					
						
							|  |  |  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | 
					
						
							|  |  |  | // SPDX-License-Identifier: AGPL-3.0-or-later | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // This program is free software: you can redistribute it and/or modify | 
					
						
							|  |  |  | // it under the terms of the GNU Affero General Public License as published by | 
					
						
							|  |  |  | // the Free Software Foundation, either version 3 of the License, or | 
					
						
							|  |  |  | // (at your option) any later version. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // This program is distributed in the hope that it will be useful, | 
					
						
							|  |  |  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | 
					
						
							|  |  |  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | 
					
						
							|  |  |  | // GNU Affero General Public License for more details. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // You should have received a copy of the GNU Affero General Public License | 
					
						
							|  |  |  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | 
					
						
							| 
									
										
										
										
											2022-10-03 10:46:11 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | package bundb | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"context" | 
					
						
							|  |  |  | 	"time" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-12 10:15:54 +01:00
										 |  |  | 	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" | 
					
						
							| 
									
										
										
										
											2023-08-10 15:08:41 +01:00
										 |  |  | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | 
					
						
							| 
									
										
										
										
											2022-10-03 10:46:11 +02:00
										 |  |  | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 
					
						
							| 
									
										
										
										
											2022-12-08 17:35:14 +00:00
										 |  |  | 	"github.com/superseriousbusiness/gotosocial/internal/state" | 
					
						
							| 
									
										
										
										
											2022-10-03 10:46:11 +02:00
										 |  |  | 	"github.com/uptrace/bun" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type userDB struct { | 
					
						
							| 
									
										
										
										
											2024-02-07 14:43:27 +00:00
										 |  |  | 	db    *bun.DB | 
					
						
							| 
									
										
										
										
											2022-12-08 17:35:14 +00:00
										 |  |  | 	state *state.State | 
					
						
							| 
									
										
										
										
											2022-10-03 10:46:11 +02:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-25 09:34:05 +01:00
										 |  |  | func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, error) { | 
					
						
							| 
									
										
										
										
											2023-08-10 15:08:41 +01:00
										 |  |  | 	return u.getUser( | 
					
						
							|  |  |  | 		ctx, | 
					
						
							|  |  |  | 		"ID", | 
					
						
							|  |  |  | 		func(user *gtsmodel.User) error { | 
					
						
							|  |  |  | 			return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("id"), id).Scan(ctx) | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		id, | 
					
						
							|  |  |  | 	) | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2022-10-03 10:46:11 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-10 15:08:41 +01:00
										 |  |  | func (u *userDB) GetUsersByIDs(ctx context.Context, ids []string) ([]*gtsmodel.User, error) { | 
					
						
							|  |  |  | 	var ( | 
					
						
							|  |  |  | 		users = make([]*gtsmodel.User, 0, len(ids)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// Collect errors instead of | 
					
						
							|  |  |  | 		// returning early on any. | 
					
						
							|  |  |  | 		errs gtserror.MultiError | 
					
						
							|  |  |  | 	) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, id := range ids { | 
					
						
							|  |  |  | 		// Attempt to fetch user from DB. | 
					
						
							|  |  |  | 		user, err := u.GetUserByID(ctx, id) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			errs.Appendf("error getting user %s: %w", id, err) | 
					
						
							|  |  |  | 			continue | 
					
						
							| 
									
										
										
										
											2022-10-03 10:46:11 +02:00
										 |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-10 15:08:41 +01:00
										 |  |  | 		// Append user to return slice. | 
					
						
							|  |  |  | 		users = append(users, user) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return users, errs.Combine() | 
					
						
							| 
									
										
										
										
											2022-10-03 10:46:11 +02:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-25 09:34:05 +01:00
										 |  |  | func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, error) { | 
					
						
							| 
									
										
										
										
											2023-08-10 15:08:41 +01:00
										 |  |  | 	return u.getUser( | 
					
						
							|  |  |  | 		ctx, | 
					
						
							|  |  |  | 		"AccountID", | 
					
						
							|  |  |  | 		func(user *gtsmodel.User) error { | 
					
						
							|  |  |  | 			return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("account_id"), accountID).Scan(ctx) | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		accountID, | 
					
						
							|  |  |  | 	) | 
					
						
							| 
									
										
										
										
											2022-10-03 10:46:11 +02:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-10 15:08:41 +01:00
										 |  |  | func (u *userDB) GetUserByEmailAddress(ctx context.Context, email string) (*gtsmodel.User, error) { | 
					
						
							|  |  |  | 	return u.getUser( | 
					
						
							|  |  |  | 		ctx, | 
					
						
							|  |  |  | 		"Email", | 
					
						
							|  |  |  | 		func(user *gtsmodel.User) error { | 
					
						
							|  |  |  | 			return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("email"), email).Scan(ctx) | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		email, | 
					
						
							|  |  |  | 	) | 
					
						
							| 
									
										
										
										
											2022-10-03 10:46:11 +02:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2022-12-06 14:15:56 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-25 09:34:05 +01:00
										 |  |  | func (u *userDB) GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, error) { | 
					
						
							| 
									
										
										
										
											2023-08-10 15:08:41 +01:00
										 |  |  | 	return u.getUser( | 
					
						
							|  |  |  | 		ctx, | 
					
						
							|  |  |  | 		"ExternalID", | 
					
						
							|  |  |  | 		func(user *gtsmodel.User) error { | 
					
						
							|  |  |  | 			return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("external_id"), id).Scan(ctx) | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		id, | 
					
						
							|  |  |  | 	) | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2022-12-06 14:15:56 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-10 15:08:41 +01:00
										 |  |  | func (u *userDB) GetUserByConfirmationToken(ctx context.Context, token string) (*gtsmodel.User, error) { | 
					
						
							|  |  |  | 	return u.getUser( | 
					
						
							|  |  |  | 		ctx, | 
					
						
							|  |  |  | 		"ConfirmationToken", | 
					
						
							|  |  |  | 		func(user *gtsmodel.User) error { | 
					
						
							|  |  |  | 			return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("confirmation_token"), token).Scan(ctx) | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		token, | 
					
						
							|  |  |  | 	) | 
					
						
							| 
									
										
										
										
											2022-12-06 14:15:56 +01:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2022-10-03 10:46:11 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-10 15:08:41 +01:00
										 |  |  | func (u *userDB) getUser(ctx context.Context, lookup string, dbQuery func(*gtsmodel.User) error, keyParts ...any) (*gtsmodel.User, error) { | 
					
						
							|  |  |  | 	// Fetch user from database cache with loader callback. | 
					
						
							| 
									
										
										
										
											2024-07-24 09:41:43 +01:00
										 |  |  | 	user, err := u.state.Caches.DB.User.LoadOne(lookup, func() (*gtsmodel.User, error) { | 
					
						
							| 
									
										
										
										
											2022-11-15 18:45:15 +00:00
										 |  |  | 		var user gtsmodel.User | 
					
						
							| 
									
										
										
										
											2022-10-03 10:46:11 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-10 15:08:41 +01:00
										 |  |  | 		// Not cached! perform database query. | 
					
						
							|  |  |  | 		if err := dbQuery(&user); err != nil { | 
					
						
							| 
									
										
										
										
											2023-08-17 17:26:21 +01:00
										 |  |  | 			return nil, err | 
					
						
							| 
									
										
										
										
											2022-11-15 18:45:15 +00:00
										 |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		return &user, nil | 
					
						
							| 
									
										
										
										
											2023-08-10 15:08:41 +01:00
										 |  |  | 	}, keyParts...) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-27 16:39:44 +01:00
										 |  |  | 	if gtscontext.Barebones(ctx) { | 
					
						
							|  |  |  | 		// Return without populating. | 
					
						
							|  |  |  | 		return user, nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := u.PopulateUser(ctx, user); err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							| 
									
										
										
										
											2023-08-10 15:08:41 +01:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return user, nil | 
					
						
							| 
									
										
										
										
											2022-10-03 10:46:11 +02:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-27 16:39:44 +01:00
										 |  |  | // PopulateUser ensures that the user's struct fields are populated. | 
					
						
							|  |  |  | func (u *userDB) PopulateUser(ctx context.Context, user *gtsmodel.User) error { | 
					
						
							|  |  |  | 	var ( | 
					
						
							|  |  |  | 		errs = gtserror.NewMultiError(1) | 
					
						
							|  |  |  | 		err  error | 
					
						
							|  |  |  | 	) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if user.Account == nil { | 
					
						
							|  |  |  | 		// Fetch the related account model for this user. | 
					
						
							|  |  |  | 		user.Account, err = u.state.DB.GetAccountByID( | 
					
						
							|  |  |  | 			gtscontext.SetBarebones(ctx), | 
					
						
							|  |  |  | 			user.AccountID, | 
					
						
							|  |  |  | 		) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			errs.Appendf("error populating user account: %w", err) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return errs.Combine() | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-25 09:34:05 +01:00
										 |  |  | func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) { | 
					
						
							| 
									
										
										
										
											2023-08-10 15:08:41 +01:00
										 |  |  | 	var userIDs []string | 
					
						
							| 
									
										
										
										
											2023-03-27 16:02:26 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-10 15:08:41 +01:00
										 |  |  | 	// Scan all user IDs into slice. | 
					
						
							|  |  |  | 	if err := u.db.NewSelect(). | 
					
						
							|  |  |  | 		Table("users"). | 
					
						
							|  |  |  | 		Column("id"). | 
					
						
							|  |  |  | 		Scan(ctx, &userIDs); err != nil { | 
					
						
							| 
									
										
										
										
											2023-08-17 17:26:21 +01:00
										 |  |  | 		return nil, err | 
					
						
							| 
									
										
										
										
											2023-03-27 16:02:26 +02:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-10 15:08:41 +01:00
										 |  |  | 	// Transform user IDs into user slice. | 
					
						
							|  |  |  | 	return u.GetUsersByIDs(ctx, userIDs) | 
					
						
							| 
									
										
										
										
											2023-03-27 16:02:26 +02:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-25 09:34:05 +01:00
										 |  |  | func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) error { | 
					
						
							| 
									
										
										
										
											2024-07-24 09:41:43 +01:00
										 |  |  | 	return u.state.Caches.DB.User.Store(user, func() error { | 
					
						
							| 
									
										
										
										
											2023-07-25 09:34:05 +01:00
										 |  |  | 		_, err := u.db. | 
					
						
							| 
									
										
										
										
											2022-11-15 18:45:15 +00:00
										 |  |  | 			NewInsert(). | 
					
						
							|  |  |  | 			Model(user). | 
					
						
							|  |  |  | 			Exec(ctx) | 
					
						
							| 
									
										
										
										
											2023-08-17 17:26:21 +01:00
										 |  |  | 		return err | 
					
						
							| 
									
										
										
										
											2022-11-15 18:45:15 +00:00
										 |  |  | 	}) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-25 09:34:05 +01:00
										 |  |  | func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) error { | 
					
						
							| 
									
										
										
										
											2022-10-03 10:46:11 +02:00
										 |  |  | 	// Update the user's last-updated | 
					
						
							|  |  |  | 	user.UpdatedAt = time.Now() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-16 11:27:08 +01:00
										 |  |  | 	if len(columns) > 0 { | 
					
						
							|  |  |  | 		// If we're updating by column, ensure "updated_at" is included | 
					
						
							|  |  |  | 		columns = append(columns, "updated_at") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-24 09:41:43 +01:00
										 |  |  | 	return u.state.Caches.DB.User.Store(user, func() error { | 
					
						
							| 
									
										
										
										
											2023-07-25 09:34:05 +01:00
										 |  |  | 		_, err := u.db. | 
					
						
							| 
									
										
										
										
											2023-05-12 10:15:54 +01:00
										 |  |  | 			NewUpdate(). | 
					
						
							|  |  |  | 			Model(user). | 
					
						
							|  |  |  | 			Where("? = ?", bun.Ident("user.id"), user.ID). | 
					
						
							|  |  |  | 			Column(columns...). | 
					
						
							|  |  |  | 			Exec(ctx) | 
					
						
							| 
									
										
										
										
											2023-08-17 17:26:21 +01:00
										 |  |  | 		return err | 
					
						
							| 
									
										
										
										
											2023-05-12 10:15:54 +01:00
										 |  |  | 	}) | 
					
						
							| 
									
										
										
										
											2022-10-03 10:46:11 +02:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-25 09:34:05 +01:00
										 |  |  | func (u *userDB) DeleteUserByID(ctx context.Context, userID string) error { | 
					
						
							| 
									
										
										
										
											2024-09-16 16:46:09 +00:00
										 |  |  | 	// Gather necessary fields from | 
					
						
							|  |  |  | 	// deleted for cache invaliation. | 
					
						
							|  |  |  | 	var deleted gtsmodel.User | 
					
						
							|  |  |  | 	deleted.ID = userID | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Delete user from DB. | 
					
						
							|  |  |  | 	if _, err := u.db.NewDelete(). | 
					
						
							|  |  |  | 		Model(&deleted). | 
					
						
							|  |  |  | 		Where("? = ?", bun.Ident("id"), userID). | 
					
						
							|  |  |  | 		Returning("?", bun.Ident("account_id")). | 
					
						
							|  |  |  | 		Exec(ctx); err != nil { | 
					
						
							| 
									
										
										
										
											2023-05-12 10:15:54 +01:00
										 |  |  | 		return err | 
					
						
							| 
									
										
										
										
											2022-10-03 10:46:11 +02:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-16 16:46:09 +00:00
										 |  |  | 	// Invalidate cached user by ID, manually | 
					
						
							|  |  |  | 	// call invalidate hook in case not cached. | 
					
						
							|  |  |  | 	u.state.Caches.DB.User.Invalidate("ID", userID) | 
					
						
							|  |  |  | 	u.state.Caches.OnInvalidateUser(&deleted) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return nil | 
					
						
							| 
									
										
										
										
											2022-10-03 10:46:11 +02:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2024-04-13 13:25:10 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | func (u *userDB) PutDeniedUser(ctx context.Context, deniedUser *gtsmodel.DeniedUser) error { | 
					
						
							|  |  |  | 	_, err := u.db.NewInsert(). | 
					
						
							|  |  |  | 		Model(deniedUser). | 
					
						
							|  |  |  | 		Exec(ctx) | 
					
						
							|  |  |  | 	return err | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (u *userDB) GetDeniedUserByID(ctx context.Context, id string) (*gtsmodel.DeniedUser, error) { | 
					
						
							|  |  |  | 	deniedUser := new(gtsmodel.DeniedUser) | 
					
						
							|  |  |  | 	if err := u.db. | 
					
						
							|  |  |  | 		NewSelect(). | 
					
						
							|  |  |  | 		Model(deniedUser). | 
					
						
							|  |  |  | 		Where("? = ?", bun.Ident("denied_user.id"), id). | 
					
						
							|  |  |  | 		Scan(ctx); err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return deniedUser, nil | 
					
						
							|  |  |  | } |