| 
									
										
										
										
											2021-08-29 15:41:41 +01:00
										 |  |  | package bundb | 
					
						
							|  |  |  | 
 | 
					
						
import (
	"context"
	"database/sql"
	"errors"

	"github.com/superseriousbusiness/gotosocial/internal/db"
	"github.com/uptrace/bun"
	"github.com/uptrace/bun/dialect"
)
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-30 11:16:23 +02:00
										 |  |  | // DBConn wrapps a bun.DB conn to provide SQL-type specific additional functionality | 
					
						
							| 
									
										
										
										
											2021-08-29 15:41:41 +01:00
										 |  |  | type DBConn struct { | 
					
						
							|  |  |  | 	errProc func(error) db.Error // errProc is the SQL-type specific error processor | 
					
						
							|  |  |  | 	*bun.DB                      // DB is the underlying bun.DB connection | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-07-10 16:18:21 +01:00
										 |  |  | // WrapDBConn wraps a bun DB connection to provide our own error processing dependent on DB dialect. | 
					
						
							| 
									
										
										
										
											2021-10-11 05:37:33 -07:00
										 |  |  | func WrapDBConn(dbConn *bun.DB) *DBConn { | 
					
						
							| 
									
										
										
										
											2021-08-29 15:41:41 +01:00
										 |  |  | 	var errProc func(error) db.Error | 
					
						
							|  |  |  | 	switch dbConn.Dialect().Name() { | 
					
						
							|  |  |  | 	case dialect.PG: | 
					
						
							|  |  |  | 		errProc = processPostgresError | 
					
						
							|  |  |  | 	case dialect.SQLite: | 
					
						
							|  |  |  | 		errProc = processSQLiteError | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		panic("unknown dialect name: " + dbConn.Dialect().Name().String()) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return &DBConn{ | 
					
						
							|  |  |  | 		errProc: errProc, | 
					
						
							|  |  |  | 		DB:      dbConn, | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-30 11:16:23 +02:00
										 |  |  | // RunInTx wraps execution of the supplied transaction function. | 
					
						
							| 
									
										
										
										
											2021-09-01 10:08:21 +01:00
										 |  |  | func (conn *DBConn) RunInTx(ctx context.Context, fn func(bun.Tx) error) db.Error { | 
					
						
							| 
									
										
										
										
											2022-07-10 16:18:21 +01:00
										 |  |  | 	return conn.ProcessError(func() error { | 
					
						
							|  |  |  | 		// Acquire a new transaction | 
					
						
							|  |  |  | 		tx, err := conn.BeginTx(ctx, nil) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			return err | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2021-09-01 10:08:21 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-07-10 16:18:21 +01:00
										 |  |  | 		var done bool | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		defer func() { | 
					
						
							|  |  |  | 			if !done { | 
					
						
							|  |  |  | 				_ = tx.Rollback() | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		}() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// Perform supplied transaction | 
					
						
							|  |  |  | 		if err := fn(tx); err != nil { | 
					
						
							|  |  |  | 			return err | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2021-09-01 10:08:21 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-07-10 16:18:21 +01:00
										 |  |  | 		// Finally, commit | 
					
						
							| 
									
										
										
										
											2022-09-28 18:30:40 +01:00
										 |  |  | 		err = tx.Commit() //nolint:contextcheck | 
					
						
							| 
									
										
										
										
											2022-07-10 16:18:21 +01:00
										 |  |  | 		done = true | 
					
						
							|  |  |  | 		return err | 
					
						
							|  |  |  | 	}()) | 
					
						
							| 
									
										
										
										
											2021-09-01 10:08:21 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-29 15:41:41 +01:00
										 |  |  | // ProcessError processes an error to replace any known values with our own db.Error types, | 
					
						
							|  |  |  | // making it easier to catch specific situations (e.g. no rows, already exists, etc) | 
					
						
							|  |  |  | func (conn *DBConn) ProcessError(err error) db.Error { | 
					
						
							|  |  |  | 	switch { | 
					
						
							|  |  |  | 	case err == nil: | 
					
						
							|  |  |  | 		return nil | 
					
						
							|  |  |  | 	case err == sql.ErrNoRows: | 
					
						
							|  |  |  | 		return db.ErrNoEntries | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		return conn.errProc(err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors | 
					
						
							|  |  |  | func (conn *DBConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) { | 
					
						
							| 
									
										
										
										
											2022-05-02 12:53:46 +02:00
										 |  |  | 	exists, err := query.Exists(ctx) | 
					
						
							| 
									
										
										
										
											2021-08-29 15:41:41 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	// Process error as our own and check if it exists | 
					
						
							|  |  |  | 	switch err := conn.ProcessError(err); err { | 
					
						
							|  |  |  | 	case nil: | 
					
						
							| 
									
										
										
										
											2022-05-02 12:53:46 +02:00
										 |  |  | 		return exists, nil | 
					
						
							| 
									
										
										
										
											2021-08-29 15:41:41 +01:00
										 |  |  | 	case db.ErrNoEntries: | 
					
						
							|  |  |  | 		return false, nil | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		return false, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // NotExists is the functional opposite of conn.Exists() | 
					
						
							|  |  |  | func (conn *DBConn) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) { | 
					
						
							|  |  |  | 	exists, err := conn.Exists(ctx, query) | 
					
						
							|  |  |  | 	return !exists, err | 
					
						
							|  |  |  | } |