mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 14:32:24 -05:00 
			
		
		
		
	[feature/performance] sqlite pragma optimize on close (#2596)
* wrap database drivers in order to handle error processing, hooks, etc * remove dead code * add code comment, remove unused blank imports
This commit is contained in:
		
					parent
					
						
							
								b6fe8e7a5b
							
						
					
				
			
			
				commit
				
					
						6738fd5bb0
					
				
			
		
					 31 changed files with 372 additions and 660 deletions
				
			
		|  | @ -37,7 +37,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type accountDB struct { | type accountDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -334,7 +334,7 @@ func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) e | ||||||
| 		// It is safe to run this database transaction within cache.Store | 		// It is safe to run this database transaction within cache.Store | ||||||
| 		// as the cache does not attempt a mutex lock until AFTER hook. | 		// as the cache does not attempt a mutex lock until AFTER hook. | ||||||
| 		// | 		// | ||||||
| 		return a.db.RunInTx(ctx, func(tx Tx) error { | 		return a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { | ||||||
| 			// create links between this account and any emojis it uses | 			// create links between this account and any emojis it uses | ||||||
| 			for _, i := range account.EmojiIDs { | 			for _, i := range account.EmojiIDs { | ||||||
| 				if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ | 				if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{ | ||||||
|  | @ -363,7 +363,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account | ||||||
| 		// It is safe to run this database transaction within cache.Store | 		// It is safe to run this database transaction within cache.Store | ||||||
| 		// as the cache does not attempt a mutex lock until AFTER hook. | 		// as the cache does not attempt a mutex lock until AFTER hook. | ||||||
| 		// | 		// | ||||||
| 		return a.db.RunInTx(ctx, func(tx Tx) error { | 		return a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { | ||||||
| 			// create links between this account and any emojis it uses | 			// create links between this account and any emojis it uses | ||||||
| 			// first clear out any old emoji links | 			// first clear out any old emoji links | ||||||
| 			if _, err := tx. | 			if _, err := tx. | ||||||
|  | @ -411,7 +411,7 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) error { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return a.db.RunInTx(ctx, func(tx Tx) error { | 	return a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { | ||||||
| 		// clear out any emoji links | 		// clear out any emoji links | ||||||
| 		if _, err := tx. | 		if _, err := tx. | ||||||
| 			NewDelete(). | 			NewDelete(). | ||||||
|  |  | ||||||
|  | @ -45,7 +45,7 @@ import ( | ||||||
| const rsaKeyBits = 2048 | const rsaKeyBits = 2048 | ||||||
| 
 | 
 | ||||||
| type adminDB struct { | type adminDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -56,7 +56,7 @@ func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (boo | ||||||
| 		Column("account.id"). | 		Column("account.id"). | ||||||
| 		Where("? = ?", bun.Ident("account.username"), username). | 		Where("? = ?", bun.Ident("account.username"), username). | ||||||
| 		Where("? IS NULL", bun.Ident("account.domain")) | 		Where("? IS NULL", bun.Ident("account.domain")) | ||||||
| 	return a.db.NotExists(ctx, q) | 	return notExists(ctx, q) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, error) { | func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, error) { | ||||||
|  | @ -73,7 +73,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, err | ||||||
| 		TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")). | 		TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")). | ||||||
| 		Column("email_domain_block.id"). | 		Column("email_domain_block.id"). | ||||||
| 		Where("? = ?", bun.Ident("email_domain_block.domain"), domain) | 		Where("? = ?", bun.Ident("email_domain_block.domain"), domain) | ||||||
| 	emailDomainBlocked, err := a.db.Exists(ctx, emailDomainBlockedQ) | 	emailDomainBlocked, err := exists(ctx, emailDomainBlockedQ) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return false, err | 		return false, err | ||||||
| 	} | 	} | ||||||
|  | @ -88,7 +88,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, err | ||||||
| 		Column("user.id"). | 		Column("user.id"). | ||||||
| 		Where("? = ?", bun.Ident("user.email"), email). | 		Where("? = ?", bun.Ident("user.email"), email). | ||||||
| 		WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email) | 		WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email) | ||||||
| 	return a.db.NotExists(ctx, q) | 	return notExists(ctx, q) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (a *adminDB) NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (*gtsmodel.User, error) { | func (a *adminDB) NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (*gtsmodel.User, error) { | ||||||
|  | @ -229,7 +229,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) error { | ||||||
| 		Where("? = ?", bun.Ident("account.username"), username). | 		Where("? = ?", bun.Ident("account.username"), username). | ||||||
| 		Where("? IS NULL", bun.Ident("account.domain")) | 		Where("? IS NULL", bun.Ident("account.domain")) | ||||||
| 
 | 
 | ||||||
| 	exists, err := a.db.Exists(ctx, q) | 	exists, err := exists(ctx, q) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | @ -287,7 +287,7 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) error { | ||||||
| 		TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")). | 		TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")). | ||||||
| 		Where("? = ?", bun.Ident("instance.domain"), host) | 		Where("? = ?", bun.Ident("instance.domain"), host) | ||||||
| 
 | 
 | ||||||
| 	exists, err := a.db.Exists(ctx, q) | 	exists, err := exists(ctx, q) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -26,7 +26,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type applicationDB struct { | type applicationDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -27,7 +27,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type basicDB struct { | type basicDB struct { | ||||||
| 	db *DB | 	db *bun.DB | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (b *basicDB) Put(ctx context.Context, i interface{}) error { | func (b *basicDB) Put(ctx context.Context, i interface{}) error { | ||||||
|  |  | ||||||
|  | @ -52,13 +52,6 @@ import ( | ||||||
| 	"modernc.org/sqlite" | 	"modernc.org/sqlite" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var registerTables = []interface{}{ |  | ||||||
| 	>smodel.AccountToEmoji{}, |  | ||||||
| 	>smodel.StatusToEmoji{}, |  | ||||||
| 	>smodel.StatusToTag{}, |  | ||||||
| 	>smodel.ThreadToStatus{}, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // DBService satisfies the DB interface | // DBService satisfies the DB interface | ||||||
| type DBService struct { | type DBService struct { | ||||||
| 	db.Account | 	db.Account | ||||||
|  | @ -88,12 +81,12 @@ type DBService struct { | ||||||
| 	db.Timeline | 	db.Timeline | ||||||
| 	db.User | 	db.User | ||||||
| 	db.Tombstone | 	db.Tombstone | ||||||
| 	db *DB | 	db *bun.DB | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetDB returns the underlying database connection pool. | // GetDB returns the underlying database connection pool. | ||||||
| // Should only be used in testing + exceptional circumstance. | // Should only be used in testing + exceptional circumstance. | ||||||
| func (dbService *DBService) DB() *DB { | func (dbService *DBService) DB() *bun.DB { | ||||||
| 	return dbService.db | 	return dbService.db | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -129,18 +122,18 @@ func doMigration(ctx context.Context, db *bun.DB) error { | ||||||
| // NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface. | // NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface. | ||||||
| // Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection. | // Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection. | ||||||
| func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { | func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { | ||||||
| 	var db *DB | 	var db *bun.DB | ||||||
| 	var err error | 	var err error | ||||||
| 	t := strings.ToLower(config.GetDbType()) | 	t := strings.ToLower(config.GetDbType()) | ||||||
| 
 | 
 | ||||||
| 	switch t { | 	switch t { | ||||||
| 	case "postgres": | 	case "postgres": | ||||||
| 		db, err = pgConn(ctx) | 		db, err = pgConn(ctx, state) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 	case "sqlite": | 	case "sqlite": | ||||||
| 		db, err = sqliteConn(ctx) | 		db, err = sqliteConn(ctx, state) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
|  | @ -159,14 +152,19 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { | ||||||
| 
 | 
 | ||||||
| 	// table registration is needed for many-to-many, see: | 	// table registration is needed for many-to-many, see: | ||||||
| 	// https://bun.uptrace.dev/orm/many-to-many-relation/ | 	// https://bun.uptrace.dev/orm/many-to-many-relation/ | ||||||
| 	for _, t := range registerTables { | 	for _, t := range []interface{}{ | ||||||
|  | 		>smodel.AccountToEmoji{}, | ||||||
|  | 		>smodel.StatusToEmoji{}, | ||||||
|  | 		>smodel.StatusToTag{}, | ||||||
|  | 		>smodel.ThreadToStatus{}, | ||||||
|  | 	} { | ||||||
| 		db.RegisterModel(t) | 		db.RegisterModel(t) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// perform any pending database migrations: this includes | 	// perform any pending database migrations: this includes | ||||||
| 	// the very first 'migration' on startup which just creates | 	// the very first 'migration' on startup which just creates | ||||||
| 	// necessary tables | 	// necessary tables | ||||||
| 	if err := doMigration(ctx, db.bun); err != nil { | 	if err := doMigration(ctx, db); err != nil { | ||||||
| 		return nil, fmt.Errorf("db migration error: %s", err) | 		return nil, fmt.Errorf("db migration error: %s", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -284,13 +282,18 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { | ||||||
| 	return ps, nil | 	return ps, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func pgConn(ctx context.Context) (*DB, error) { | func pgConn(ctx context.Context, state *state.State) (*bun.DB, error) { | ||||||
| 	opts, err := deriveBunDBPGOptions() //nolint:contextcheck | 	opts, err := deriveBunDBPGOptions() //nolint:contextcheck | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("could not create bundb postgres options: %s", err) | 		return nil, fmt.Errorf("could not create bundb postgres options: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	sqldb := stdlib.OpenDB(*opts) | 	cfg := stdlib.RegisterConnConfig(opts) | ||||||
|  | 
 | ||||||
|  | 	sqldb, err := sql.Open("pgx-gts", cfg) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("could not open postgres db: %w", err) | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	// Tune db connections for postgres, see: | 	// Tune db connections for postgres, see: | ||||||
| 	// - https://bun.uptrace.dev/guide/running-bun-in-production.html#database-sql | 	// - https://bun.uptrace.dev/guide/running-bun-in-production.html#database-sql | ||||||
|  | @ -299,18 +302,18 @@ func pgConn(ctx context.Context) (*DB, error) { | ||||||
| 	sqldb.SetMaxIdleConns(2)                  // assume default 2; if max idle is less than max open, it will be automatically adjusted | 	sqldb.SetMaxIdleConns(2)                  // assume default 2; if max idle is less than max open, it will be automatically adjusted | ||||||
| 	sqldb.SetConnMaxLifetime(5 * time.Minute) // fine to kill old connections | 	sqldb.SetConnMaxLifetime(5 * time.Minute) // fine to kill old connections | ||||||
| 
 | 
 | ||||||
| 	db := WrapDB(bun.NewDB(sqldb, pgdialect.New())) | 	db := bun.NewDB(sqldb, pgdialect.New()) | ||||||
| 
 | 
 | ||||||
| 	// ping to check the db is there and listening | 	// ping to check the db is there and listening | ||||||
| 	if err := db.PingContext(ctx); err != nil { | 	if err := db.PingContext(ctx); err != nil { | ||||||
| 		return nil, fmt.Errorf("postgres ping: %s", err) | 		return nil, fmt.Errorf("postgres ping: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	log.Info(ctx, "connected to POSTGRES database") | 	log.Info(ctx, "connected to POSTGRES database") | ||||||
| 	return db, nil | 	return db, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func sqliteConn(ctx context.Context) (*DB, error) { | func sqliteConn(ctx context.Context, state *state.State) (*bun.DB, error) { | ||||||
| 	// validate db address has actually been set | 	// validate db address has actually been set | ||||||
| 	address := config.GetDbAddress() | 	address := config.GetDbAddress() | ||||||
| 	if address == "" { | 	if address == "" { | ||||||
|  | @ -321,7 +324,7 @@ func sqliteConn(ctx context.Context) (*DB, error) { | ||||||
| 	address = buildSQLiteAddress(address) | 	address = buildSQLiteAddress(address) | ||||||
| 
 | 
 | ||||||
| 	// Open new DB instance | 	// Open new DB instance | ||||||
| 	sqldb, err := sql.Open("sqlite", address) | 	sqldb, err := sql.Open("sqlite-gts", address) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		if errWithCode, ok := err.(*sqlite.Error); ok { | 		if errWithCode, ok := err.(*sqlite.Error); ok { | ||||||
| 			err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()]) | 			err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()]) | ||||||
|  | @ -336,15 +339,14 @@ func sqliteConn(ctx context.Context) (*DB, error) { | ||||||
| 	sqldb.SetMaxIdleConns(1)              // only keep max 1 idle connection around | 	sqldb.SetMaxIdleConns(1)              // only keep max 1 idle connection around | ||||||
| 	sqldb.SetConnMaxLifetime(0)           // don't kill connections due to age | 	sqldb.SetConnMaxLifetime(0)           // don't kill connections due to age | ||||||
| 
 | 
 | ||||||
| 	// Wrap Bun database conn in our own wrapper | 	db := bun.NewDB(sqldb, sqlitedialect.New()) | ||||||
| 	db := WrapDB(bun.NewDB(sqldb, sqlitedialect.New())) |  | ||||||
| 
 | 
 | ||||||
| 	// ping to check the db is there and listening | 	// ping to check the db is there and listening | ||||||
| 	if err := db.PingContext(ctx); err != nil { | 	if err := db.PingContext(ctx); err != nil { | ||||||
| 		if errWithCode, ok := err.(*sqlite.Error); ok { | 		if errWithCode, ok := err.(*sqlite.Error); ok { | ||||||
| 			err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()]) | 			err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()]) | ||||||
| 		} | 		} | ||||||
| 		return nil, fmt.Errorf("sqlite ping: %s", err) | 		return nil, fmt.Errorf("sqlite ping: %w", err) | ||||||
| 	} | 	} | ||||||
| 	log.Infof(ctx, "connected to SQLITE database with address %s", address) | 	log.Infof(ctx, "connected to SQLITE database with address %s", address) | ||||||
| 
 | 
 | ||||||
|  | @ -418,7 +420,7 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) { | ||||||
| 		// parse the PEM block into the certificate | 		// parse the PEM block into the certificate | ||||||
| 		caCert, err := x509.ParseCertificate(caPem.Bytes) | 		caCert, err := x509.ParseCertificate(caPem.Bytes) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", certPath, err) | 			return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %w", certPath, err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// we're happy, add it to the existing pool and then use this pool in our tls config | 		// we're happy, add it to the existing pool and then use this pool in our tls config | ||||||
|  |  | ||||||
|  | @ -1,578 +0,0 @@ | ||||||
| // GoToSocial |  | ||||||
| // Copyright (C) GoToSocial Authors admin@gotosocial.org |  | ||||||
| // SPDX-License-Identifier: AGPL-3.0-or-later |  | ||||||
| // |  | ||||||
| // This program is free software: you can redistribute it and/or modify |  | ||||||
| // it under the terms of the GNU Affero General Public License as published by |  | ||||||
| // the Free Software Foundation, either version 3 of the License, or |  | ||||||
| // (at your option) any later version. |  | ||||||
| // |  | ||||||
| // This program is distributed in the hope that it will be useful, |  | ||||||
| // but WITHOUT ANY WARRANTY; without even the implied warranty of |  | ||||||
| // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the |  | ||||||
| // GNU Affero General Public License for more details. |  | ||||||
| // |  | ||||||
| // You should have received a copy of the GNU Affero General Public License |  | ||||||
| // along with this program.  If not, see <http://www.gnu.org/licenses/>. |  | ||||||
| 
 |  | ||||||
| package bundb |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"context" |  | ||||||
| 	"database/sql" |  | ||||||
| 	"time" |  | ||||||
| 	"unsafe" |  | ||||||
| 
 |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" |  | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" |  | ||||||
| 	"github.com/uptrace/bun" |  | ||||||
| 	"github.com/uptrace/bun/dialect" |  | ||||||
| 	"github.com/uptrace/bun/schema" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| // DB wraps a bun database instance |  | ||||||
| // to provide common per-dialect SQL error |  | ||||||
| // conversions to common types, and retries |  | ||||||
| // on returned busy (SQLite only). |  | ||||||
| type DB struct { |  | ||||||
| 	// our own wrapped db type |  | ||||||
| 	// with retry backoff support. |  | ||||||
| 	// kept separate to the *bun.DB |  | ||||||
| 	// type to be passed into query |  | ||||||
| 	// builders as bun.IConn iface |  | ||||||
| 	// (this prevents double firing |  | ||||||
| 	// bun query hooks). |  | ||||||
| 	// |  | ||||||
| 	// also holds per-dialect |  | ||||||
| 	// error hook function. |  | ||||||
| 	raw rawdb |  | ||||||
| 
 |  | ||||||
| 	// bun DB interface we use |  | ||||||
| 	// for dialects, and improved |  | ||||||
| 	// struct marshal/unmarshaling. |  | ||||||
| 	bun *bun.DB |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // WrapDB wraps a bun database instance in our database type. |  | ||||||
| func WrapDB(db *bun.DB) *DB { |  | ||||||
| 	var errProc func(error) error |  | ||||||
| 	switch name := db.Dialect().Name(); name { |  | ||||||
| 	case dialect.PG: |  | ||||||
| 		errProc = processPostgresError |  | ||||||
| 	case dialect.SQLite: |  | ||||||
| 		errProc = processSQLiteError |  | ||||||
| 	default: |  | ||||||
| 		panic("unknown dialect name: " + name.String()) |  | ||||||
| 	} |  | ||||||
| 	return &DB{ |  | ||||||
| 		raw: rawdb{ |  | ||||||
| 			errHook: errProc, |  | ||||||
| 			db:      db.DB, |  | ||||||
| 		}, |  | ||||||
| 		bun: db, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Dialect is a direct call-through to bun.DB.Dialect(). |  | ||||||
| func (db *DB) Dialect() schema.Dialect { return db.bun.Dialect() } |  | ||||||
| 
 |  | ||||||
| // AddQueryHook is a direct call-through to bun.DB.AddQueryHook(). |  | ||||||
| func (db *DB) AddQueryHook(hook bun.QueryHook) { db.bun.AddQueryHook(hook) } |  | ||||||
| 
 |  | ||||||
| // RegisterModels is a direct call-through to bun.DB.RegisterModels(). |  | ||||||
| func (db *DB) RegisterModel(models ...any) { db.bun.RegisterModel(models...) } |  | ||||||
| 
 |  | ||||||
| // PingContext is a direct call-through to bun.DB.PingContext(). |  | ||||||
| func (db *DB) PingContext(ctx context.Context) error { return db.bun.PingContext(ctx) } |  | ||||||
| 
 |  | ||||||
| // Close is a direct call-through to bun.DB.Close(). |  | ||||||
| func (db *DB) Close() error { return db.bun.Close() } |  | ||||||
| 
 |  | ||||||
| // ExecContext wraps bun.DB.ExecContext() with retry-busy timeout and our own error processing. |  | ||||||
| func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) { |  | ||||||
| 	bundb := db.bun // use underlying *bun.DB interface for their query formatting |  | ||||||
| 	err = retryOnBusy(ctx, func() error { |  | ||||||
| 		result, err = bundb.ExecContext(ctx, query, args...) |  | ||||||
| 		err = db.raw.errHook(err) |  | ||||||
| 		return err |  | ||||||
| 	}) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // QueryContext wraps bun.DB.ExecContext() with retry-busy timeout and our own error processing. |  | ||||||
| func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) { |  | ||||||
| 	bundb := db.bun // use underlying *bun.DB interface for their query formatting |  | ||||||
| 	err = retryOnBusy(ctx, func() error { |  | ||||||
| 		rows, err = bundb.QueryContext(ctx, query, args...) |  | ||||||
| 		err = db.raw.errHook(err) |  | ||||||
| 		return err |  | ||||||
| 	}) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // QueryRowContext wraps bun.DB.ExecContext() with retry-busy timeout and our own error processing. |  | ||||||
| func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) (row *sql.Row) { |  | ||||||
| 	bundb := db.bun // use underlying *bun.DB interface for their query formatting |  | ||||||
| 	_ = retryOnBusy(ctx, func() error { |  | ||||||
| 		row = bundb.QueryRowContext(ctx, query, args...) |  | ||||||
| 		if err := db.raw.errHook(row.Err()); err != nil { |  | ||||||
| 			updateRowError(row, err) // set new error |  | ||||||
| 		} |  | ||||||
| 		return row.Err() |  | ||||||
| 	}) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // BeginTx wraps bun.DB.BeginTx() with retry-busy timeout and our own error processing. |  | ||||||
| func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (tx Tx, err error) { |  | ||||||
| 	var buntx bun.Tx // captured bun.Tx |  | ||||||
| 	bundb := db.bun  // use *bun.DB interface to return bun.Tx type |  | ||||||
| 
 |  | ||||||
| 	err = retryOnBusy(ctx, func() error { |  | ||||||
| 		buntx, err = bundb.BeginTx(ctx, opts) |  | ||||||
| 		err = db.raw.errHook(err) |  | ||||||
| 		return err |  | ||||||
| 	}) |  | ||||||
| 
 |  | ||||||
| 	if err == nil { |  | ||||||
| 		// Wrap bun.Tx in our type. |  | ||||||
| 		tx = wrapTx(db, &buntx) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // RunInTx is functionally the same as bun.DB.RunInTx() but with retry-busy timeouts. |  | ||||||
| func (db *DB) RunInTx(ctx context.Context, fn func(Tx) error) error { |  | ||||||
| 	// Attempt to start new transaction. |  | ||||||
| 	tx, err := db.BeginTx(ctx, nil) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	var done bool |  | ||||||
| 
 |  | ||||||
| 	defer func() { |  | ||||||
| 		if !done { |  | ||||||
| 			// Rollback tx. |  | ||||||
| 			_ = tx.Rollback() |  | ||||||
| 		} |  | ||||||
| 	}() |  | ||||||
| 
 |  | ||||||
| 	// Perform supplied transaction |  | ||||||
| 	if err := fn(tx); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Commit tx. |  | ||||||
| 	err = tx.Commit() |  | ||||||
| 	done = true |  | ||||||
| 	return err |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (db *DB) NewValues(model interface{}) *bun.ValuesQuery { |  | ||||||
| 	// note: passing in rawdb as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.DB.Query___() functions. |  | ||||||
| 	return bun.NewValuesQuery(db.bun, model).Conn(&db.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (db *DB) NewMerge() *bun.MergeQuery { |  | ||||||
| 	// note: passing in rawdb as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.DB.Query___() functions. |  | ||||||
| 	return bun.NewMergeQuery(db.bun).Conn(&db.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (db *DB) NewSelect() *bun.SelectQuery { |  | ||||||
| 	// note: passing in rawdb as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.DB.Query___() functions. |  | ||||||
| 	return bun.NewSelectQuery(db.bun).Conn(&db.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (db *DB) NewInsert() *bun.InsertQuery { |  | ||||||
| 	// note: passing in rawdb as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.DB.Query___() functions. |  | ||||||
| 	return bun.NewInsertQuery(db.bun).Conn(&db.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (db *DB) NewUpdate() *bun.UpdateQuery { |  | ||||||
| 	// note: passing in rawdb as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.DB.Query___() functions. |  | ||||||
| 	return bun.NewUpdateQuery(db.bun).Conn(&db.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (db *DB) NewDelete() *bun.DeleteQuery { |  | ||||||
| 	// note: passing in rawdb as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.DB.Query___() functions. |  | ||||||
| 	return bun.NewDeleteQuery(db.bun).Conn(&db.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (db *DB) NewRaw(query string, args ...interface{}) *bun.RawQuery { |  | ||||||
| 	// note: passing in rawdb as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.DB.Query___() functions. |  | ||||||
| 	return bun.NewRawQuery(db.bun, query, args...).Conn(&db.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (db *DB) NewCreateTable() *bun.CreateTableQuery { |  | ||||||
| 	// note: passing in rawdb as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.DB.Query___() functions. |  | ||||||
| 	return bun.NewCreateTableQuery(db.bun).Conn(&db.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (db *DB) NewDropTable() *bun.DropTableQuery { |  | ||||||
| 	// note: passing in rawdb as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.DB.Query___() functions. |  | ||||||
| 	return bun.NewDropTableQuery(db.bun).Conn(&db.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (db *DB) NewCreateIndex() *bun.CreateIndexQuery { |  | ||||||
| 	// note: passing in rawdb as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.DB.Query___() functions. |  | ||||||
| 	return bun.NewCreateIndexQuery(db.bun).Conn(&db.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (db *DB) NewDropIndex() *bun.DropIndexQuery { |  | ||||||
| 	// note: passing in rawdb as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.DB.Query___() functions. |  | ||||||
| 	return bun.NewDropIndexQuery(db.bun).Conn(&db.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (db *DB) NewTruncateTable() *bun.TruncateTableQuery { |  | ||||||
| 	// note: passing in rawdb as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.DB.Query___() functions. |  | ||||||
| 	return bun.NewTruncateTableQuery(db.bun).Conn(&db.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (db *DB) NewAddColumn() *bun.AddColumnQuery { |  | ||||||
| 	// note: passing in rawdb as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.DB.Query___() functions. |  | ||||||
| 	return bun.NewAddColumnQuery(db.bun).Conn(&db.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (db *DB) NewDropColumn() *bun.DropColumnQuery { |  | ||||||
| 	// note: passing in rawdb as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.DB.Query___() functions. |  | ||||||
| 	return bun.NewDropColumnQuery(db.bun).Conn(&db.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors. |  | ||||||
| func (db *DB) Exists(ctx context.Context, query *bun.SelectQuery) (bool, error) { |  | ||||||
| 	exists, err := query.Exists(ctx) |  | ||||||
| 	switch err { |  | ||||||
| 	case nil: |  | ||||||
| 		return exists, nil |  | ||||||
| 	case sql.ErrNoRows: |  | ||||||
| 		return false, nil |  | ||||||
| 	default: |  | ||||||
| 		return false, err |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // NotExists checks the results of a SelectQuery for the non-existence of the data in question, masking ErrNoEntries errors. |  | ||||||
| func (db *DB) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, error) { |  | ||||||
| 	exists, err := db.Exists(ctx, query) |  | ||||||
| 	return !exists, err |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type rawdb struct { |  | ||||||
| 	// dialect specific error |  | ||||||
| 	// processing function hook. |  | ||||||
| 	errHook func(error) error |  | ||||||
| 
 |  | ||||||
| 	// embedded raw |  | ||||||
| 	// db interface |  | ||||||
| 	db *sql.DB |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // ExecContext wraps sql.DB.ExecContext() with retry-busy timeout and our own error processing. |  | ||||||
| func (db *rawdb) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) { |  | ||||||
| 	err = retryOnBusy(ctx, func() error { |  | ||||||
| 		result, err = db.db.ExecContext(ctx, query, args...) |  | ||||||
| 		err = db.errHook(err) |  | ||||||
| 		return err |  | ||||||
| 	}) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // QueryContext wraps sql.DB.QueryContext() with retry-busy timeout and our own error processing. |  | ||||||
| func (db *rawdb) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) { |  | ||||||
| 	err = retryOnBusy(ctx, func() error { |  | ||||||
| 		rows, err = db.db.QueryContext(ctx, query, args...) |  | ||||||
| 		err = db.errHook(err) |  | ||||||
| 		return err |  | ||||||
| 	}) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // QueryRowContext wraps sql.DB.QueryRowContext() with retry-busy timeout and our own error processing. |  | ||||||
| func (db *rawdb) QueryRowContext(ctx context.Context, query string, args ...any) (row *sql.Row) { |  | ||||||
| 	_ = retryOnBusy(ctx, func() error { |  | ||||||
| 		row = db.db.QueryRowContext(ctx, query, args...) |  | ||||||
| 		err := db.errHook(row.Err()) |  | ||||||
| 		return err |  | ||||||
| 	}) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Tx wraps a bun transaction instance |  | ||||||
| // to provide common per-dialect SQL error |  | ||||||
| // conversions to common types, and retries |  | ||||||
| // on busy commit/rollback (SQLite only). |  | ||||||
| type Tx struct { |  | ||||||
| 	// our own wrapped Tx type |  | ||||||
| 	// kept separate to the *bun.Tx |  | ||||||
| 	// type to be passed into query |  | ||||||
| 	// builders as bun.IConn iface |  | ||||||
| 	// (this prevents double firing |  | ||||||
| 	// bun query hooks). |  | ||||||
| 	// |  | ||||||
| 	// also holds per-dialect |  | ||||||
| 	// error hook function. |  | ||||||
| 	raw rawtx |  | ||||||
| 
 |  | ||||||
| 	// bun Tx interface we use |  | ||||||
| 	// for dialects, and improved |  | ||||||
| 	// struct marshal/unmarshaling. |  | ||||||
| 	bun *bun.Tx |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // wrapTx wraps a given bun.Tx in our own wrapping Tx type. |  | ||||||
| func wrapTx(db *DB, tx *bun.Tx) Tx { |  | ||||||
| 	return Tx{ |  | ||||||
| 		raw: rawtx{ |  | ||||||
| 			errHook: db.raw.errHook, |  | ||||||
| 			tx:      tx.Tx, |  | ||||||
| 		}, |  | ||||||
| 		bun: tx, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // ExecContext wraps bun.Tx.ExecContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction). |  | ||||||
| func (tx Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { |  | ||||||
| 	buntx := tx.bun // use underlying *bun.Tx interface for their query formatting |  | ||||||
| 	res, err := buntx.ExecContext(ctx, query, args...) |  | ||||||
| 	err = tx.raw.errHook(err) |  | ||||||
| 	return res, err |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // QueryContext wraps bun.Tx.QueryContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction). |  | ||||||
| func (tx Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { |  | ||||||
| 	buntx := tx.bun // use underlying *bun.Tx interface for their query formatting |  | ||||||
| 	rows, err := buntx.QueryContext(ctx, query, args...) |  | ||||||
| 	err = tx.raw.errHook(err) |  | ||||||
| 	return rows, err |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // QueryRowContext wraps bun.Tx.QueryRowContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction). |  | ||||||
| func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { |  | ||||||
| 	buntx := tx.bun // use underlying *bun.Tx interface for their query formatting |  | ||||||
| 	row := buntx.QueryRowContext(ctx, query, args...) |  | ||||||
| 	if err := tx.raw.errHook(row.Err()); err != nil { |  | ||||||
| 		updateRowError(row, err) // set new error |  | ||||||
| 	} |  | ||||||
| 	return row |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Commit wraps bun.Tx.Commit() with retry-busy timeout and our own error processing. |  | ||||||
| func (tx Tx) Commit() (err error) { |  | ||||||
| 	buntx := tx.bun // use *bun.Tx interface |  | ||||||
| 	err = retryOnBusy(context.TODO(), func() error { |  | ||||||
| 		err = buntx.Commit() |  | ||||||
| 		err = tx.raw.errHook(err) |  | ||||||
| 		return err |  | ||||||
| 	}) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Rollback wraps bun.Tx.Rollback() with retry-busy timeout and our own error processing. |  | ||||||
| func (tx Tx) Rollback() (err error) { |  | ||||||
| 	buntx := tx.bun // use *bun.Tx interface |  | ||||||
| 	err = retryOnBusy(context.TODO(), func() error { |  | ||||||
| 		err = buntx.Rollback() |  | ||||||
| 		err = tx.raw.errHook(err) |  | ||||||
| 		return err |  | ||||||
| 	}) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Dialect is a direct call-through to bun.DB.Dialect(). |  | ||||||
| func (tx Tx) Dialect() schema.Dialect { |  | ||||||
| 	return tx.bun.Dialect() |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (tx Tx) NewValues(model interface{}) *bun.ValuesQuery { |  | ||||||
| 	// note: passing in rawtx as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.Tx.Query___() functions. |  | ||||||
| 	return tx.bun.NewValues(model).Conn(&tx.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (tx Tx) NewMerge() *bun.MergeQuery { |  | ||||||
| 	// note: passing in rawtx as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.Tx.Query___() functions. |  | ||||||
| 	return tx.bun.NewMerge().Conn(&tx.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (tx Tx) NewSelect() *bun.SelectQuery { |  | ||||||
| 	// note: passing in rawtx as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.Tx.Query___() functions. |  | ||||||
| 	return tx.bun.NewSelect().Conn(&tx.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (tx Tx) NewInsert() *bun.InsertQuery { |  | ||||||
| 	// note: passing in rawtx as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.Tx.Query___() functions. |  | ||||||
| 	return tx.bun.NewInsert().Conn(&tx.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (tx Tx) NewUpdate() *bun.UpdateQuery { |  | ||||||
| 	// note: passing in rawtx as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.Tx.Query___() functions. |  | ||||||
| 	return tx.bun.NewUpdate().Conn(&tx.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (tx Tx) NewDelete() *bun.DeleteQuery { |  | ||||||
| 	// note: passing in rawtx as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.Tx.Query___() functions. |  | ||||||
| 	return tx.bun.NewDelete().Conn(&tx.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (tx Tx) NewRaw(query string, args ...interface{}) *bun.RawQuery { |  | ||||||
| 	// note: passing in rawtx as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.Tx.Query___() functions. |  | ||||||
| 	return tx.bun.NewRaw(query, args...).Conn(&tx.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (tx Tx) NewCreateTable() *bun.CreateTableQuery { |  | ||||||
| 	// note: passing in rawtx as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.Tx.Query___() functions. |  | ||||||
| 	return tx.bun.NewCreateTable().Conn(&tx.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (tx Tx) NewDropTable() *bun.DropTableQuery { |  | ||||||
| 	// note: passing in rawtx as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.Tx.Query___() functions. |  | ||||||
| 	return tx.bun.NewDropTable().Conn(&tx.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (tx Tx) NewCreateIndex() *bun.CreateIndexQuery { |  | ||||||
| 	// note: passing in rawtx as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.Tx.Query___() functions. |  | ||||||
| 	return tx.bun.NewCreateIndex().Conn(&tx.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (tx Tx) NewDropIndex() *bun.DropIndexQuery { |  | ||||||
| 	// note: passing in rawtx as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.Tx.Query___() functions. |  | ||||||
| 	return tx.bun.NewDropIndex().Conn(&tx.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (tx Tx) NewTruncateTable() *bun.TruncateTableQuery { |  | ||||||
| 	// note: passing in rawtx as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.Tx.Query___() functions. |  | ||||||
| 	return tx.bun.NewTruncateTable().Conn(&tx.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (tx Tx) NewAddColumn() *bun.AddColumnQuery { |  | ||||||
| 	// note: passing in rawtx as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.Tx.Query___() functions. |  | ||||||
| 	return tx.bun.NewAddColumn().Conn(&tx.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (tx Tx) NewDropColumn() *bun.DropColumnQuery { |  | ||||||
| 	// note: passing in rawtx as conn iface so no double query-hook |  | ||||||
| 	// firing when passed through the bun.Tx.Query___() functions. |  | ||||||
| 	return tx.bun.NewDropColumn().Conn(&tx.raw) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type rawtx struct { |  | ||||||
| 	// dialect specific error |  | ||||||
| 	// processing function hook. |  | ||||||
| 	errHook func(error) error |  | ||||||
| 
 |  | ||||||
| 	// embedded raw |  | ||||||
| 	// tx interface |  | ||||||
| 	tx *sql.Tx |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // ExecContext wraps sql.Tx.ExecContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction). |  | ||||||
| func (tx *rawtx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { |  | ||||||
| 	res, err := tx.tx.ExecContext(ctx, query, args...) |  | ||||||
| 	err = tx.errHook(err) |  | ||||||
| 	return res, err |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // QueryContext wraps sql.Tx.QueryContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction). |  | ||||||
| func (tx *rawtx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { |  | ||||||
| 	rows, err := tx.tx.QueryContext(ctx, query, args...) |  | ||||||
| 	err = tx.errHook(err) |  | ||||||
| 	return rows, err |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // QueryRowContext wraps sql.Tx.QueryRowContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction). |  | ||||||
| func (tx *rawtx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { |  | ||||||
| 	row := tx.tx.QueryRowContext(ctx, query, args...) |  | ||||||
| 	if err := tx.errHook(row.Err()); err != nil { |  | ||||||
| 		updateRowError(row, err) // set new error |  | ||||||
| 	} |  | ||||||
| 	return row |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // updateRowError updates an sql.Row's internal error field using the unsafe package. |  | ||||||
| func updateRowError(sqlrow *sql.Row, err error) { |  | ||||||
| 	type row struct { |  | ||||||
| 		err  error |  | ||||||
| 		rows *sql.Rows |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// compile-time check to ensure sql.Row not changed. |  | ||||||
| 	if unsafe.Sizeof(row{}) != unsafe.Sizeof(sql.Row{}) { |  | ||||||
| 		panic("sql.Row has changed definition") |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// this code is awful and i must be shamed for this. |  | ||||||
| 	(*row)(unsafe.Pointer(sqlrow)).err = err |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // retryOnBusy will retry given function on returned 'errBusy'. |  | ||||||
| func retryOnBusy(ctx context.Context, fn func() error) error { |  | ||||||
| 	var backoff time.Duration |  | ||||||
| 
 |  | ||||||
| 	for i := 0; ; i++ { |  | ||||||
| 		// Perform func. |  | ||||||
| 		err := fn() |  | ||||||
| 
 |  | ||||||
| 		if err != errBusy { |  | ||||||
| 			// May be nil, or may be |  | ||||||
| 			// some other error, either |  | ||||||
| 			// way return here. |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// backoff according to a multiplier of 2ms * 2^2n, |  | ||||||
| 		// up to a maximum possible backoff time of 5 minutes. |  | ||||||
| 		// |  | ||||||
| 		// this works out as the following: |  | ||||||
| 		// 4ms |  | ||||||
| 		// 16ms |  | ||||||
| 		// 64ms |  | ||||||
| 		// 256ms |  | ||||||
| 		// 1.024s |  | ||||||
| 		// 4.096s |  | ||||||
| 		// 16.384s |  | ||||||
| 		// 1m5.536s |  | ||||||
| 		// 4m22.144s |  | ||||||
| 		backoff = 2 * time.Millisecond * (1 << (2*i + 1)) |  | ||||||
| 		if backoff >= 5*time.Minute { |  | ||||||
| 			break |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		select { |  | ||||||
| 		// Context cancelled. |  | ||||||
| 		case <-ctx.Done(): |  | ||||||
| 
 |  | ||||||
| 		// Backoff for some time. |  | ||||||
| 		case <-time.After(backoff): |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return gtserror.Newf("%w (waited > %s)", db.ErrBusyTimeout, backoff) |  | ||||||
| } |  | ||||||
|  | @ -31,7 +31,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type domainDB struct { | type domainDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										267
									
								
								internal/db/bundb/drivers.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										267
									
								
								internal/db/bundb/drivers.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,267 @@ | ||||||
|  | // GoToSocial | ||||||
|  | // Copyright (C) GoToSocial Authors admin@gotosocial.org | ||||||
|  | // SPDX-License-Identifier: AGPL-3.0-or-later | ||||||
|  | // | ||||||
|  | // This program is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Affero General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // This program is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  | // GNU Affero General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Affero General Public License | ||||||
|  | // along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | 
 | ||||||
|  | package bundb | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"database/sql" | ||||||
|  | 	"database/sql/driver" | ||||||
|  | 	"time" | ||||||
|  | 	_ "unsafe" // linkname shenanigans | ||||||
|  | 
 | ||||||
|  | 	pgx "github.com/jackc/pgx/v5/stdlib" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/gtserror" | ||||||
|  | 	"modernc.org/sqlite" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var ( | ||||||
|  | 	// global SQL driver instances. | ||||||
|  | 	postgresDriver = pgx.GetDefaultDriver() | ||||||
|  | 	sqliteDriver   = getSQLiteDriver() | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func init() { | ||||||
|  | 	sql.Register("pgx-gts", &PostgreSQLDriver{}) | ||||||
|  | 	sql.Register("sqlite-gts", &SQLiteDriver{}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //go:linkname getSQLiteDriver modernc.org/sqlite.newDriver | ||||||
|  | func getSQLiteDriver() *sqlite.Driver | ||||||
|  | 
 | ||||||
|  | // PostgreSQLDriver is our own wrapper around the | ||||||
|  | // pgx/stdlib.Driver{} type in order to wrap further | ||||||
|  | // SQL driver types with our own err processing. | ||||||
|  | type PostgreSQLDriver struct{} | ||||||
|  | 
 | ||||||
|  | func (d *PostgreSQLDriver) Open(name string) (driver.Conn, error) { | ||||||
|  | 	c, err := postgresDriver.Open(name) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return &PostgreSQLConn{conn: c.(conn)}, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type PostgreSQLConn struct{ conn } | ||||||
|  | 
 | ||||||
|  | func (c *PostgreSQLConn) Begin() (driver.Tx, error) { | ||||||
|  | 	return c.BeginTx(context.Background(), driver.TxOptions{}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *PostgreSQLConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { | ||||||
|  | 	tx, err := c.conn.BeginTx(ctx, opts) | ||||||
|  | 	err = processPostgresError(err) | ||||||
|  | 	return tx, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *PostgreSQLConn) Prepare(query string) (driver.Stmt, error) { | ||||||
|  | 	return c.PrepareContext(context.Background(), query) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *PostgreSQLConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { | ||||||
|  | 	stmt, err := c.conn.PrepareContext(ctx, query) | ||||||
|  | 	err = processPostgresError(err) | ||||||
|  | 	return stmt, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *PostgreSQLConn) Exec(query string, args []driver.NamedValue) (driver.Result, error) { | ||||||
|  | 	return c.ExecContext(context.Background(), query, args) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *PostgreSQLConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | ||||||
|  | 	result, err := c.conn.ExecContext(ctx, query, args) | ||||||
|  | 	err = processPostgresError(err) | ||||||
|  | 	return result, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *PostgreSQLConn) Query(query string, args []driver.NamedValue) (driver.Rows, error) { | ||||||
|  | 	return c.QueryContext(context.Background(), query, args) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *PostgreSQLConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | ||||||
|  | 	rows, err := c.conn.QueryContext(ctx, query, args) | ||||||
|  | 	err = processPostgresError(err) | ||||||
|  | 	return rows, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *PostgreSQLConn) Close() error { | ||||||
|  | 	return c.conn.Close() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type PostgreSQLTx struct{ driver.Tx } | ||||||
|  | 
 | ||||||
|  | func (tx *PostgreSQLTx) Commit() error { | ||||||
|  | 	err := tx.Tx.Commit() | ||||||
|  | 	return processPostgresError(err) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (tx *PostgreSQLTx) Rollback() error { | ||||||
|  | 	err := tx.Tx.Rollback() | ||||||
|  | 	return processPostgresError(err) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SQLiteDriver is our own wrapper around the | ||||||
|  | // sqlite.Driver{} type in order to wrap further | ||||||
|  | // SQL driver types with our own functionality, | ||||||
|  | // e.g. hooks, retries and err processing. | ||||||
|  | type SQLiteDriver struct{} | ||||||
|  | 
 | ||||||
|  | func (d *SQLiteDriver) Open(name string) (driver.Conn, error) { | ||||||
|  | 	c, err := sqliteDriver.Open(name) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return &SQLiteConn{conn: c.(conn)}, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type SQLiteConn struct{ conn } | ||||||
|  | 
 | ||||||
|  | func (c *SQLiteConn) Begin() (driver.Tx, error) { | ||||||
|  | 	return c.BeginTx(context.Background(), driver.TxOptions{}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *SQLiteConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) { | ||||||
|  | 	err = retryOnBusy(ctx, func() error { | ||||||
|  | 		tx, err = c.conn.BeginTx(ctx, opts) | ||||||
|  | 		err = processSQLiteError(err) | ||||||
|  | 		return err | ||||||
|  | 	}) | ||||||
|  | 	return &SQLiteTx{Context: ctx, Tx: tx}, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { | ||||||
|  | 	return c.PrepareContext(context.Background(), query) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) { | ||||||
|  | 	err = retryOnBusy(ctx, func() error { | ||||||
|  | 		stmt, err = c.conn.PrepareContext(ctx, query) | ||||||
|  | 		err = processSQLiteError(err) | ||||||
|  | 		return err | ||||||
|  | 	}) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *SQLiteConn) Exec(query string, args []driver.NamedValue) (driver.Result, error) { | ||||||
|  | 	return c.ExecContext(context.Background(), query, args) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { | ||||||
|  | 	err = retryOnBusy(ctx, func() error { | ||||||
|  | 		result, err = c.conn.ExecContext(ctx, query, args) | ||||||
|  | 		err = processSQLiteError(err) | ||||||
|  | 		return err | ||||||
|  | 	}) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *SQLiteConn) Query(query string, args []driver.NamedValue) (driver.Rows, error) { | ||||||
|  | 	return c.QueryContext(context.Background(), query, args) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *SQLiteConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { | ||||||
|  | 	err = retryOnBusy(ctx, func() error { | ||||||
|  | 		rows, err = c.conn.QueryContext(ctx, query, args) | ||||||
|  | 		err = processSQLiteError(err) | ||||||
|  | 		return err | ||||||
|  | 	}) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *SQLiteConn) Close() error { | ||||||
|  | 	// see: https://www.sqlite.org/pragma.html#pragma_optimize | ||||||
|  | 	const onClose = "PRAGMA analysis_limit=1000; PRAGMA optimize;" | ||||||
|  | 	_, _ = c.conn.ExecContext(context.Background(), onClose, nil) | ||||||
|  | 	return c.conn.Close() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type SQLiteTx struct { | ||||||
|  | 	context.Context | ||||||
|  | 	driver.Tx | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (tx *SQLiteTx) Commit() (err error) { | ||||||
|  | 	err = retryOnBusy(tx.Context, func() error { | ||||||
|  | 		err = tx.Tx.Commit() | ||||||
|  | 		err = processSQLiteError(err) | ||||||
|  | 		return err | ||||||
|  | 	}) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (tx *SQLiteTx) Rollback() (err error) { | ||||||
|  | 	err = retryOnBusy(tx.Context, func() error { | ||||||
|  | 		err = tx.Tx.Rollback() | ||||||
|  | 		err = processSQLiteError(err) | ||||||
|  | 		return err | ||||||
|  | 	}) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type conn interface { | ||||||
|  | 	driver.Conn | ||||||
|  | 	driver.ConnPrepareContext | ||||||
|  | 	driver.ExecerContext | ||||||
|  | 	driver.QueryerContext | ||||||
|  | 	driver.ConnBeginTx | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // retryOnBusy will retry given function on returned 'errBusy'. | ||||||
|  | func retryOnBusy(ctx context.Context, fn func() error) error { | ||||||
|  | 	var backoff time.Duration | ||||||
|  | 
 | ||||||
|  | 	for i := 0; ; i++ { | ||||||
|  | 		// Perform func. | ||||||
|  | 		err := fn() | ||||||
|  | 
 | ||||||
|  | 		if err != errBusy { | ||||||
|  | 			// May be nil, or may be | ||||||
|  | 			// some other error, either | ||||||
|  | 			// way return here. | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// backoff according to a multiplier of 2ms * 2^2n, | ||||||
|  | 		// up to a maximum possible backoff time of 5 minutes. | ||||||
|  | 		// | ||||||
|  | 		// this works out as the following: | ||||||
|  | 		// 4ms | ||||||
|  | 		// 16ms | ||||||
|  | 		// 64ms | ||||||
|  | 		// 256ms | ||||||
|  | 		// 1.024s | ||||||
|  | 		// 4.096s | ||||||
|  | 		// 16.384s | ||||||
|  | 		// 1m5.536s | ||||||
|  | 		// 4m22.144s | ||||||
|  | 		backoff = 2 * time.Millisecond * (1 << (2*i + 1)) | ||||||
|  | 		if backoff >= 5*time.Minute { | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		select { | ||||||
|  | 		// Context cancelled. | ||||||
|  | 		case <-ctx.Done(): | ||||||
|  | 
 | ||||||
|  | 		// Backoff for some time. | ||||||
|  | 		case <-time.After(backoff): | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return gtserror.Newf("%w (waited > %s)", db.ErrBusyTimeout, backoff) | ||||||
|  | } | ||||||
|  | @ -38,7 +38,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type emojiDB struct { | type emojiDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -109,7 +109,7 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return e.db.RunInTx(ctx, func(tx Tx) error { | 	return e.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { | ||||||
| 		// Delete relational links between this emoji | 		// Delete relational links between this emoji | ||||||
| 		// and any statuses using it, returning the | 		// and any statuses using it, returning the | ||||||
| 		// status IDs so we can later update them. | 		// status IDs so we can later update them. | ||||||
|  |  | ||||||
|  | @ -29,7 +29,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type headerFilterDB struct { | type headerFilterDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -34,7 +34,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type instanceDB struct { | type instanceDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -35,7 +35,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type listDB struct { | type listDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -198,7 +198,7 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error { | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| 
 | 
 | ||||||
| 	return l.db.RunInTx(ctx, func(tx Tx) error { | 	return l.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { | ||||||
| 		// Delete all entries attached to list. | 		// Delete all entries attached to list. | ||||||
| 		if _, err := tx.NewDelete(). | 		if _, err := tx.NewDelete(). | ||||||
| 			Table("list_entries"). | 			Table("list_entries"). | ||||||
|  | @ -515,7 +515,7 @@ func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEnt | ||||||
| 	}() | 	}() | ||||||
| 
 | 
 | ||||||
| 	// Finally, insert each list entry into the database. | 	// Finally, insert each list entry into the database. | ||||||
| 	return l.db.RunInTx(ctx, func(tx Tx) error { | 	return l.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { | ||||||
| 		for _, entry := range entries { | 		for _, entry := range entries { | ||||||
| 			entry := entry // rescope | 			entry := entry // rescope | ||||||
| 			if err := l.state.Caches.GTS.ListEntry.Store(entry, func() error { | 			if err := l.state.Caches.GTS.ListEntry.Store(entry, func() error { | ||||||
|  |  | ||||||
|  | @ -30,7 +30,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type markerDB struct { | type markerDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -85,7 +85,7 @@ func (m *markerDB) UpdateMarker(ctx context.Context, marker *gtsmodel.Marker) er | ||||||
| 		// Optimistic concurrency control: start a transaction, try to update a row with a previously retrieved version. | 		// Optimistic concurrency control: start a transaction, try to update a row with a previously retrieved version. | ||||||
| 		// If the update in the transaction fails to actually change anything, another update happened concurrently, and | 		// If the update in the transaction fails to actually change anything, another update happened concurrently, and | ||||||
| 		// this update should be retried by the caller, which in this case involves sending HTTP 409 to the API client. | 		// this update should be retried by the caller, which in this case involves sending HTTP 409 to the API client. | ||||||
| 		return m.db.RunInTx(ctx, func(tx Tx) error { | 		return m.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { | ||||||
| 			result, err := tx.NewUpdate(). | 			result, err := tx.NewUpdate(). | ||||||
| 				Model(marker). | 				Model(marker). | ||||||
| 				WherePK(). | 				WherePK(). | ||||||
|  |  | ||||||
|  | @ -34,7 +34,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type mediaDB struct { | type mediaDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -151,7 +151,7 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error { | ||||||
| 	defer m.state.Caches.GTS.Media.Invalidate("ID", id) | 	defer m.state.Caches.GTS.Media.Invalidate("ID", id) | ||||||
| 
 | 
 | ||||||
| 	// Delete media attachment in new transaction. | 	// Delete media attachment in new transaction. | ||||||
| 	err = m.db.RunInTx(ctx, func(tx Tx) error { | 	err = m.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { | ||||||
| 		if media.AccountID != "" { | 		if media.AccountID != "" { | ||||||
| 			var account gtsmodel.Account | 			var account gtsmodel.Account | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -33,7 +33,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type mentionDB struct { | type mentionDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -34,7 +34,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type notificationDB struct { | type notificationDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -34,7 +34,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type pollDB struct { | type pollDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -154,7 +154,7 @@ func (p *pollDB) UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...st | ||||||
| 	poll.CheckVotes() | 	poll.CheckVotes() | ||||||
| 
 | 
 | ||||||
| 	return p.state.Caches.GTS.Poll.Store(poll, func() error { | 	return p.state.Caches.GTS.Poll.Store(poll, func() error { | ||||||
| 		return p.db.RunInTx(ctx, func(tx Tx) error { | 		return p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { | ||||||
| 			// Update the status' "updated_at" field. | 			// Update the status' "updated_at" field. | ||||||
| 			if _, err := tx.NewUpdate(). | 			if _, err := tx.NewUpdate(). | ||||||
| 				Table("statuses"). | 				Table("statuses"). | ||||||
|  | @ -362,7 +362,7 @@ func (p *pollDB) PopulatePollVote(ctx context.Context, vote *gtsmodel.PollVote) | ||||||
| 
 | 
 | ||||||
| func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error { | func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error { | ||||||
| 	return p.state.Caches.GTS.PollVote.Store(vote, func() error { | 	return p.state.Caches.GTS.PollVote.Store(vote, func() error { | ||||||
| 		return p.db.RunInTx(ctx, func(tx Tx) error { | 		return p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { | ||||||
| 			// Try insert vote into database. | 			// Try insert vote into database. | ||||||
| 			if _, err := tx.NewInsert(). | 			if _, err := tx.NewInsert(). | ||||||
| 				Model(vote). | 				Model(vote). | ||||||
|  | @ -398,7 +398,7 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { | func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { | ||||||
| 	err := p.db.RunInTx(ctx, func(tx Tx) error { | 	err := p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { | ||||||
| 		// Delete all votes in poll. | 		// Delete all votes in poll. | ||||||
| 		res, err := tx.NewDelete(). | 		res, err := tx.NewDelete(). | ||||||
| 			Table("poll_votes"). | 			Table("poll_votes"). | ||||||
|  | @ -469,7 +469,7 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error { | func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error { | ||||||
| 	err := p.db.RunInTx(ctx, func(tx Tx) error { | 	err := p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { | ||||||
| 		// Slice should only ever be of length | 		// Slice should only ever be of length | ||||||
| 		// 0 or 1; it's a slice of slices only | 		// 0 or 1; it's a slice of slices only | ||||||
| 		// because we can't LIMIT deletes to 1. | 		// because we can't LIMIT deletes to 1. | ||||||
|  | @ -569,7 +569,7 @@ func (p *pollDB) DeletePollVotesByAccountID(ctx context.Context, accountID strin | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // newSelectPollVotes returns a new select query for all rows in the poll_votes table with poll_id = pollID. | // newSelectPollVotes returns a new select query for all rows in the poll_votes table with poll_id = pollID. | ||||||
| func newSelectPollVotes(db *DB, pollID string) *bun.SelectQuery { | func newSelectPollVotes(db *bun.DB, pollID string) *bun.SelectQuery { | ||||||
| 	return db.NewSelect(). | 	return db.NewSelect(). | ||||||
| 		TableExpr("?", bun.Ident("poll_votes")). | 		TableExpr("?", bun.Ident("poll_votes")). | ||||||
| 		ColumnExpr("?", bun.Ident("id")). | 		ColumnExpr("?", bun.Ident("id")). | ||||||
|  |  | ||||||
|  | @ -31,7 +31,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type relationshipDB struct { | type relationshipDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -299,7 +299,7 @@ func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID strin | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID. | // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID. | ||||||
| func newSelectFollowRequests(db *DB, accountID string) *bun.SelectQuery { | func newSelectFollowRequests(db *bun.DB, accountID string) *bun.SelectQuery { | ||||||
| 	return db.NewSelect(). | 	return db.NewSelect(). | ||||||
| 		TableExpr("?", bun.Ident("follow_requests")). | 		TableExpr("?", bun.Ident("follow_requests")). | ||||||
| 		ColumnExpr("?", bun.Ident("id")). | 		ColumnExpr("?", bun.Ident("id")). | ||||||
|  | @ -308,7 +308,7 @@ func newSelectFollowRequests(db *DB, accountID string) *bun.SelectQuery { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID. | // newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID. | ||||||
| func newSelectFollowRequesting(db *DB, accountID string) *bun.SelectQuery { | func newSelectFollowRequesting(db *bun.DB, accountID string) *bun.SelectQuery { | ||||||
| 	return db.NewSelect(). | 	return db.NewSelect(). | ||||||
| 		TableExpr("?", bun.Ident("follow_requests")). | 		TableExpr("?", bun.Ident("follow_requests")). | ||||||
| 		ColumnExpr("?", bun.Ident("id")). | 		ColumnExpr("?", bun.Ident("id")). | ||||||
|  | @ -317,7 +317,7 @@ func newSelectFollowRequesting(db *DB, accountID string) *bun.SelectQuery { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID. | // newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID. | ||||||
| func newSelectFollows(db *DB, accountID string) *bun.SelectQuery { | func newSelectFollows(db *bun.DB, accountID string) *bun.SelectQuery { | ||||||
| 	return db.NewSelect(). | 	return db.NewSelect(). | ||||||
| 		Table("follows"). | 		Table("follows"). | ||||||
| 		Column("id"). | 		Column("id"). | ||||||
|  | @ -327,7 +327,7 @@ func newSelectFollows(db *DB, accountID string) *bun.SelectQuery { | ||||||
| 
 | 
 | ||||||
| // newSelectLocalFollows returns a new select query for all rows in the follows table with | // newSelectLocalFollows returns a new select query for all rows in the follows table with | ||||||
| // account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local). | // account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local). | ||||||
| func newSelectLocalFollows(db *DB, accountID string) *bun.SelectQuery { | func newSelectLocalFollows(db *bun.DB, accountID string) *bun.SelectQuery { | ||||||
| 	return db.NewSelect(). | 	return db.NewSelect(). | ||||||
| 		Table("follows"). | 		Table("follows"). | ||||||
| 		Column("id"). | 		Column("id"). | ||||||
|  | @ -344,7 +344,7 @@ func newSelectLocalFollows(db *DB, accountID string) *bun.SelectQuery { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID. | // newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID. | ||||||
| func newSelectFollowers(db *DB, accountID string) *bun.SelectQuery { | func newSelectFollowers(db *bun.DB, accountID string) *bun.SelectQuery { | ||||||
| 	return db.NewSelect(). | 	return db.NewSelect(). | ||||||
| 		Table("follows"). | 		Table("follows"). | ||||||
| 		Column("id"). | 		Column("id"). | ||||||
|  | @ -354,7 +354,7 @@ func newSelectFollowers(db *DB, accountID string) *bun.SelectQuery { | ||||||
| 
 | 
 | ||||||
| // newSelectLocalFollowers returns a new select query for all rows in the follows table with | // newSelectLocalFollowers returns a new select query for all rows in the follows table with | ||||||
| // target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local). | // target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local). | ||||||
| func newSelectLocalFollowers(db *DB, accountID string) *bun.SelectQuery { | func newSelectLocalFollowers(db *bun.DB, accountID string) *bun.SelectQuery { | ||||||
| 	return db.NewSelect(). | 	return db.NewSelect(). | ||||||
| 		Table("follows"). | 		Table("follows"). | ||||||
| 		Column("id"). | 		Column("id"). | ||||||
|  | @ -371,7 +371,7 @@ func newSelectLocalFollowers(db *DB, accountID string) *bun.SelectQuery { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID. | // newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID. | ||||||
| func newSelectBlocks(db *DB, accountID string) *bun.SelectQuery { | func newSelectBlocks(db *bun.DB, accountID string) *bun.SelectQuery { | ||||||
| 	return db.NewSelect(). | 	return db.NewSelect(). | ||||||
| 		TableExpr("?", bun.Ident("blocks")). | 		TableExpr("?", bun.Ident("blocks")). | ||||||
| 		ColumnExpr("?", bun.Ident("id")). | 		ColumnExpr("?", bun.Ident("id")). | ||||||
|  |  | ||||||
|  | @ -32,7 +32,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type reportDB struct { | type reportDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -32,7 +32,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type ruleDB struct { | type ruleDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -57,7 +57,7 @@ import ( | ||||||
| // This isn't ideal, of course, but at least we could cover the most common use case of | // This isn't ideal, of course, but at least we could cover the most common use case of | ||||||
| // a caller paging down through results. | // a caller paging down through results. | ||||||
| type searchDB struct { | type searchDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -24,10 +24,11 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/id" | 	"github.com/superseriousbusiness/gotosocial/internal/id" | ||||||
|  | 	"github.com/uptrace/bun" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type sessionDB struct { | type sessionDB struct { | ||||||
| 	db *DB | 	db *bun.DB | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, error) { | func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, error) { | ||||||
|  |  | ||||||
|  | @ -34,7 +34,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type statusDB struct { | type statusDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -330,7 +330,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error | ||||||
| 		// It is safe to run this database transaction within cache.Store | 		// It is safe to run this database transaction within cache.Store | ||||||
| 		// as the cache does not attempt a mutex lock until AFTER hook. | 		// as the cache does not attempt a mutex lock until AFTER hook. | ||||||
| 		// | 		// | ||||||
| 		return s.db.RunInTx(ctx, func(tx Tx) error { | 		return s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { | ||||||
| 			// create links between this status and any emojis it uses | 			// create links between this status and any emojis it uses | ||||||
| 			for _, i := range status.EmojiIDs { | 			for _, i := range status.EmojiIDs { | ||||||
| 				if _, err := tx. | 				if _, err := tx. | ||||||
|  | @ -414,7 +414,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co | ||||||
| 		// It is safe to run this database transaction within cache.Store | 		// It is safe to run this database transaction within cache.Store | ||||||
| 		// as the cache does not attempt a mutex lock until AFTER hook. | 		// as the cache does not attempt a mutex lock until AFTER hook. | ||||||
| 		// | 		// | ||||||
| 		return s.db.RunInTx(ctx, func(tx Tx) error { | 		return s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { | ||||||
| 			// create links between this status and any emojis it uses | 			// create links between this status and any emojis it uses | ||||||
| 			for _, i := range status.EmojiIDs { | 			for _, i := range status.EmojiIDs { | ||||||
| 				if _, err := tx. | 				if _, err := tx. | ||||||
|  | @ -509,7 +509,7 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error { | ||||||
| 	// On return ensure status invalidated from cache. | 	// On return ensure status invalidated from cache. | ||||||
| 	defer s.state.Caches.GTS.Status.Invalidate("ID", id) | 	defer s.state.Caches.GTS.Status.Invalidate("ID", id) | ||||||
| 
 | 
 | ||||||
| 	return s.db.RunInTx(ctx, func(tx Tx) error { | 	return s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { | ||||||
| 		// delete links between this status and any emojis it uses | 		// delete links between this status and any emojis it uses | ||||||
| 		if _, err := tx. | 		if _, err := tx. | ||||||
| 			NewDelete(). | 			NewDelete(). | ||||||
|  | @ -697,6 +697,5 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St | ||||||
| 		TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")). | 		TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")). | ||||||
| 		Where("? = ?", bun.Ident("status_bookmark.status_id"), status.ID). | 		Where("? = ?", bun.Ident("status_bookmark.status_id"), status.ID). | ||||||
| 		Where("? = ?", bun.Ident("status_bookmark.account_id"), accountID) | 		Where("? = ?", bun.Ident("status_bookmark.account_id"), accountID) | ||||||
| 
 | 	return exists(ctx, q) | ||||||
| 	return s.db.Exists(ctx, q) |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -29,7 +29,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type statusBookmarkDB struct { | type statusBookmarkDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -35,7 +35,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type statusFaveDB struct { | type statusFaveDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -28,7 +28,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type tagDB struct { | type tagDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -28,7 +28,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type threadDB struct { | type threadDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -34,7 +34,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type timelineDB struct { | type timelineDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -27,7 +27,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type tombstoneDB struct { | type tombstoneDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -31,7 +31,7 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type userDB struct { | type userDB struct { | ||||||
| 	db    *DB | 	db    *bun.DB | ||||||
| 	state *state.State | 	state *state.State | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -18,6 +18,8 @@ | ||||||
| package bundb | package bundb | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"database/sql" | ||||||
| 	"slices" | 	"slices" | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
|  | @ -113,6 +115,25 @@ func whereStartsLike( | ||||||
| 	) | 	) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors. | ||||||
|  | func exists(ctx context.Context, query *bun.SelectQuery) (bool, error) { | ||||||
|  | 	exists, err := query.Exists(ctx) | ||||||
|  | 	switch err { | ||||||
|  | 	case nil: | ||||||
|  | 		return exists, nil | ||||||
|  | 	case sql.ErrNoRows: | ||||||
|  | 		return false, nil | ||||||
|  | 	default: | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // notExists checks the results of a SelectQuery for the non-existence of the data in question, masking ErrNoEntries errors. | ||||||
|  | func notExists(ctx context.Context, query *bun.SelectQuery) (bool, error) { | ||||||
|  | 	exists, err := exists(ctx, query) | ||||||
|  | 	return !exists, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // loadPagedIDs loads a page of IDs from given SliceCache by `key`, resorting to `loadDESC` if required. Uses `page` to sort + page resulting IDs. | // loadPagedIDs loads a page of IDs from given SliceCache by `key`, resorting to `loadDESC` if required. Uses `page` to sort + page resulting IDs. | ||||||
| // NOTE: IDs returned from `cache` / `loadDESC` MUST be in descending order, otherwise paging will not work correctly / return things out of order. | // NOTE: IDs returned from `cache` / `loadDESC` MUST be in descending order, otherwise paging will not work correctly / return things out of order. | ||||||
| func loadPagedIDs(cache *cache.SliceCache[string], key string, page *paging.Page, loadDESC func() ([]string, error)) ([]string, error) { | func loadPagedIDs(cache *cache.SliceCache[string], key string, page *paging.Page, loadDESC func() ([]string, error)) ([]string, error) { | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue