diff --git a/go.mod b/go.mod index 25d3a1f44..9e0ef9a08 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/gorilla/sessions v1.2.1 // indirect github.com/gorilla/websocket v1.4.2 github.com/h2non/filetype v1.1.1 + github.com/jackc/pgconn v1.10.0 github.com/jackc/pgx/v4 v4.13.0 github.com/json-iterator/go v1.1.11 // indirect github.com/leodido/go-urn v1.2.1 // indirect diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index c96d0df9e..f0ccf8711 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -34,7 +34,7 @@ import ( type accountDB struct { config *config.Config - conn *bun.DB + conn *dbConn log *logrus.Logger } @@ -52,7 +52,7 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac q := a.newAccountQ(account). Where("account.id = ?", id) - err := processErrorResponse(q.Scan(ctx)) + err := a.conn.ProcessError(q.Scan(ctx)) return account, err } @@ -63,7 +63,7 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel. q := a.newAccountQ(account). Where("account.uri = ?", uri) - err := processErrorResponse(q.Scan(ctx)) + err := a.conn.ProcessError(q.Scan(ctx)) return account, err } @@ -74,7 +74,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel. q := a.newAccountQ(account). Where("account.url = ?", uri) - err := processErrorResponse(q.Scan(ctx)) + err := a.conn.ProcessError(q.Scan(ctx)) return account, err } @@ -93,7 +93,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account _, err := q.Exec(ctx) - err = processErrorResponse(err) + err = a.conn.ProcessError(err) return account, err } @@ -113,7 +113,7 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts WhereGroup(" AND ", whereEmptyOrNull("domain")) } - err := processErrorResponse(q.Scan(ctx)) + err := a.conn.ProcessError(q.Scan(ctx)) return account, err } @@ -129,7 +129,7 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string) Where("account_id = ?", accountID). Column("created_at") - err := processErrorResponse(q.Scan(ctx)) + err := a.conn.ProcessError(q.Scan(ctx)) return status.CreatedAt, err } @@ -174,7 +174,7 @@ func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username stri Where("username = ?", username). WhereGroup(" AND ", whereEmptyOrNull("domain")) - err := processErrorResponse(q.Scan(ctx)) + err := a.conn.ProcessError(q.Scan(ctx)) return account, err } diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index 09f2d3bff..29c353a56 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -35,13 +35,12 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/util" - "github.com/uptrace/bun" "golang.org/x/crypto/bcrypt" ) type adminDB struct { config *config.Config - conn *bun.DB + conn *dbConn log *logrus.Logger } @@ -52,7 +51,7 @@ func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (boo Where("username = ?", username). Where("domain = ?", nil) - return notExists(ctx, q) + return a.conn.NotExists(ctx, q) } func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) { @@ -72,7 +71,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db. // fail because we found something return false, fmt.Errorf("email domain %s is blocked", domain) } else if err != sql.ErrNoRows { - return false, processErrorResponse(err) + return false, a.conn.ProcessError(err) } // check if this email is associated with a user already @@ -82,7 +81,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db. Where("email = ?", email). WhereOr("unconfirmed_email = ?", email) - return notExists(ctx, q) + return a.conn.NotExists(ctx, q) } func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) { @@ -187,7 +186,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error { a.log.Infof("instance account %s already exists", username) return nil } else if err != sql.ErrNoRows { - return processErrorResponse(err) + return a.conn.ProcessError(err) } key, err := rsa.GenerateKey(rand.Reader, 2048) @@ -245,7 +244,7 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error { a.log.Infof("instance instance %s already exists", domain) return nil } else if err != sql.ErrNoRows { - return processErrorResponse(err) + return a.conn.ProcessError(err) } iID, err := id.NewRandomULID() diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go index 983b6b810..acf706b65 100644 --- a/internal/db/bundb/basic.go +++ b/internal/db/bundb/basic.go @@ -31,7 +31,7 @@ import ( type basicDB struct { config *config.Config - conn *bun.DB + conn *dbConn log *logrus.Logger } @@ -49,7 +49,7 @@ func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Erro Model(i). Where("id = ?", id) - return processErrorResponse(q.Scan(ctx)) + return b.conn.ProcessError(q.Scan(ctx)) } func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error { @@ -59,7 +59,6 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) q := b.conn.NewSelect().Model(i) for _, w := range where { - if w.Value == nil { q = q.Where("? IS NULL", bun.Ident(w.Key)) } else { @@ -71,7 +70,7 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) } } - return processErrorResponse(q.Scan(ctx)) + return b.conn.ProcessError(q.Scan(ctx)) } func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error { @@ -79,7 +78,7 @@ func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error { NewSelect(). Model(i) - return processErrorResponse(q.Scan(ctx)) + return b.conn.ProcessError(q.Scan(ctx)) } func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error { @@ -90,7 +89,7 @@ func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.E _, err := q.Exec(ctx) - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error { @@ -108,7 +107,7 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface _, err := q.Exec(ctx) - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error { @@ -119,7 +118,7 @@ func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.E _, err := q.Exec(ctx) - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error { @@ -130,7 +129,7 @@ func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, valu _, err := q.Exec(ctx) - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error { @@ -152,7 +151,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, _, err := q.Exec(ctx) - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error { @@ -162,7 +161,7 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error { func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error { _, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx) - return processErrorResponse(err) + return b.conn.ProcessError(err) } func (b *basicDB) IsHealthy(ctx context.Context) db.Error { diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index a24cee3ae..9b9e6c53a 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -68,7 +68,7 @@ type bunDBService struct { db.Status db.Timeline config *config.Config - conn *bun.DB + conn *dbConn log *logrus.Logger } @@ -76,7 +76,7 @@ type bunDBService struct { // Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection. func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) { var sqldb *sql.DB - var conn *bun.DB + var conn *dbConn // depending on the database type we're trying to create, we need to use a different driver... switch strings.ToLower(c.DBConfig.Type) { @@ -87,25 +87,30 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) return nil, fmt.Errorf("could not create bundb postgres options: %s", err) } sqldb = stdlib.OpenDB(*opts) - conn = bun.NewDB(sqldb, pgdialect.New()) + conn = &dbConn{ + DB: bun.NewDB(sqldb, pgdialect.New()), + errProc: processPostgresError, + } case dbTypeSqlite: // SQLITE var err error - sqldb, err = sql.Open("sqlite", c.DBConfig.Address) if err != nil { return nil, fmt.Errorf("could not open sqlite db: %s", err) } - conn = bun.NewDB(sqldb, sqlitedialect.New()) + conn = &dbConn{ + DB: bun.NewDB(sqldb, sqlitedialect.New()), + errProc: processSQLiteError, + } if strings.HasPrefix(strings.TrimPrefix(c.DBConfig.Address, "file:"), ":memory:") { log.Warn("sqlite in-memory database should only be used for debugging") - // don't close connections on close -- otherwise + // don't close connections on disconnect -- otherwise // the SQLite database will be deleted when there // are no active connections - sqldb.SetConnMaxLifetime(0) sqldb.SetMaxOpenConns(1000) + sqldb.SetConnMaxLifetime(0) } default: return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type)) diff --git a/internal/db/bundb/conn.go b/internal/db/bundb/conn.go new file mode 100644 index 000000000..e6cf85499 --- /dev/null +++ b/internal/db/bundb/conn.go @@ -0,0 +1,44 @@ +package bundb + +import ( + "context" + "database/sql" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/uptrace/bun" +) + +type dbConn struct { + errProc func(error) db.Error // errProc is the SQL-type specific error processor + *bun.DB // DB is the underlying bun.DB connection +} + +func (conn *dbConn) ProcessError(err error) db.Error { + switch { + case err == nil: + return nil + case err == sql.ErrNoRows: + return db.ErrNoEntries + default: + return conn.errProc(err) + } +} + +func (conn *dbConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) { + // Get the select query result + count, err := query.Count(ctx) + + // Process error as our own and check if it exists + switch err := conn.ProcessError(err); err { + case nil, db.ErrAlreadyExists: + return (count != 0), nil + default: + return false, err + } +} + +func (conn *dbConn) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) { + // Simply inverse of conn.exists() + exists, err := conn.Exists(ctx, query) + return !exists, err +} diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index 6aa2b8ffe..e3abaa17f 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -27,12 +27,11 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/util" - "github.com/uptrace/bun" ) type domainDB struct { config *config.Config - conn *bun.DB + conn *dbConn log *logrus.Logger } @@ -47,7 +46,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db Where("LOWER(domain) = LOWER(?)", domain). Limit(1) - return exists(ctx, q) + return d.conn.Exists(ctx, q) } func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) { diff --git a/internal/db/bundb/errors.go b/internal/db/bundb/errors.go new file mode 100644 index 000000000..731e3d7e2 --- /dev/null +++ b/internal/db/bundb/errors.go @@ -0,0 +1,43 @@ +package bundb + +import ( + "github.com/jackc/pgconn" + "github.com/superseriousbusiness/gotosocial/internal/db" + "modernc.org/sqlite" + sqlite3 "modernc.org/sqlite/lib" +) + +// processPostgresError +func processPostgresError(err error) db.Error { + // Attempt to cast as postgres + pgErr, ok := err.(*pgconn.PgError) + if !ok { + return err + } + + // Handle supplied error code: + // (https://www.postgresql.org/docs/10/errcodes-appendix.html) + switch pgErr.Code { + case "23505" /* unique_violation */ : + return db.ErrAlreadyExists + default: + return err + } +} + +// processSQLiteError +func processSQLiteError(err error) db.Error { + // Attempt to cast as sqlite + sqliteErr, ok := err.(*sqlite.Error) + if !ok { + return err + } + + // Handle supplied error code: + switch sqliteErr.Code() { + case sqlite3.SQLITE_CONSTRAINT_UNIQUE: + return db.ErrAlreadyExists + default: + return nil + } +} diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go index 141b255cf..455bf0223 100644 --- a/internal/db/bundb/instance.go +++ b/internal/db/bundb/instance.go @@ -30,7 +30,7 @@ import ( type instanceDB struct { config *config.Config - conn *bun.DB + conn *dbConn log *logrus.Logger } @@ -53,7 +53,7 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int count, err := q.Count(ctx) - return count, processErrorResponse(err) + return count, i.conn.ProcessError(err) } func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) { @@ -72,7 +72,7 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) ( count, err := q.Count(ctx) - return count, processErrorResponse(err) + return count, i.conn.ProcessError(err) } func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) { @@ -93,7 +93,7 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i count, err := q.Count(ctx) - return count, processErrorResponse(err) + return count, i.conn.ProcessError(err) } func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) { @@ -114,7 +114,7 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max q = q.Limit(limit) } - err := processErrorResponse(q.Scan(ctx)) + err := i.conn.ProcessError(q.Scan(ctx)) return accounts, err } diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go index 04e55ca62..67c2d43d3 100644 --- a/internal/db/bundb/media.go +++ b/internal/db/bundb/media.go @@ -30,7 +30,7 @@ import ( type mediaDB struct { config *config.Config - conn *bun.DB + conn *dbConn log *logrus.Logger } @@ -47,7 +47,7 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M q := m.newMediaQ(attachment). Where("media_attachment.id = ?", id) - err := processErrorResponse(q.Scan(ctx)) + err := m.conn.ProcessError(q.Scan(ctx)) return attachment, err } diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go index a444f9b5f..9645c09b6 100644 --- a/internal/db/bundb/mention.go +++ b/internal/db/bundb/mention.go @@ -31,7 +31,7 @@ import ( type mentionDB struct { config *config.Config - conn *bun.DB + conn *dbConn log *logrus.Logger cache cache.Cache } @@ -84,7 +84,7 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio q := m.newMentionQ(mention). Where("mention.id = ?", id) - err := processErrorResponse(q.Scan(ctx)) + err := m.conn.ProcessError(q.Scan(ctx)) if err == nil && mention != nil { m.cacheMention(id, mention) @@ -99,7 +99,7 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel. for _, i := range ids { mention, err := m.GetMention(ctx, i) if err != nil { - return nil, processErrorResponse(err) + return nil, m.conn.ProcessError(err) } mentions = append(mentions, mention) } diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index 1c30837ec..b95cd02ac 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -31,7 +31,7 @@ import ( type notificationDB struct { config *config.Config - conn *bun.DB + conn *dbConn log *logrus.Logger cache cache.Cache } @@ -84,7 +84,7 @@ func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmo q := n.newNotificationQ(notification). Where("notification.id = ?", id) - err := processErrorResponse(q.Scan(ctx)) + err := n.conn.ProcessError(q.Scan(ctx)) if err == nil && notification != nil { n.cacheNotification(id, notification) @@ -115,7 +115,7 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, q = q.Limit(limit) } - err := processErrorResponse(q.Scan(ctx)) + err := n.conn.ProcessError(q.Scan(ctx)) if err != nil { return nil, err } @@ -125,7 +125,7 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, notifications := []*gtsmodel.Notification{} for _, notifID := range notifIDs { notif, err := n.GetNotification(ctx, notifID.ID) - errP := processErrorResponse(err) + errP := n.conn.ProcessError(err) if errP != nil { return nil, errP } diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index ed144669e..ace11d128 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -32,7 +32,7 @@ import ( type relationshipDB struct { config *config.Config - conn *bun.DB + conn *dbConn log *logrus.Logger } @@ -66,7 +66,7 @@ func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account Where("account_id = ?", account2) } - return exists(ctx, q) + return r.conn.Exists(ctx, q) } func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) { @@ -76,7 +76,7 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 Where("block.account_id = ?", account1). Where("block.target_account_id = ?", account2) - err := processErrorResponse(q.Scan(ctx)) + err := r.conn.ProcessError(q.Scan(ctx)) return block, err } @@ -176,7 +176,7 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode Where("target_account_id = ?", targetAccount.ID). Limit(1) - return exists(ctx, q) + return r.conn.Exists(ctx, q) } func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) { @@ -190,7 +190,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g Where("account_id = ?", sourceAccount.ID). Where("target_account_id = ?", targetAccount.ID) - return exists(ctx, q) + return r.conn.Exists(ctx, q) } func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) { @@ -201,13 +201,13 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod // make sure account 1 follows account 2 f1, err := r.IsFollowing(ctx, account1, account2) if err != nil { - return false, processErrorResponse(err) + return false, r.conn.ProcessError(err) } // make sure account 2 follows account 1 f2, err := r.IsFollowing(ctx, account2, account1) if err != nil { - return false, processErrorResponse(err) + return false, r.conn.ProcessError(err) } return f1 && f2, nil @@ -222,7 +222,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI Where("account_id = ?", originAccountID). Where("target_account_id = ?", targetAccountID). Scan(ctx); err != nil { - return nil, processErrorResponse(err) + return nil, r.conn.ProcessError(err) } // create a new follow to 'replace' the request with @@ -239,7 +239,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI Model(follow). On("CONFLICT CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI). Exec(ctx); err != nil { - return nil, processErrorResponse(err) + return nil, r.conn.ProcessError(err) } // now remove the follow request @@ -249,7 +249,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI Where("account_id = ?", originAccountID). Where("target_account_id = ?", targetAccountID). Exec(ctx); err != nil { - return nil, processErrorResponse(err) + return nil, r.conn.ProcessError(err) } return follow, nil @@ -261,7 +261,7 @@ func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID q := r.newFollowQ(&followRequests). Where("target_account_id = ?", accountID) - err := processErrorResponse(q.Scan(ctx)) + err := r.conn.ProcessError(q.Scan(ctx)) return followRequests, err } @@ -272,7 +272,7 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string q := r.newFollowQ(&follows). Where("account_id = ?", accountID) - err := processErrorResponse(q.Scan(ctx)) + err := r.conn.ProcessError(q.Scan(ctx)) return follows, err } @@ -286,7 +286,6 @@ func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID stri } func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) { - follows := []*gtsmodel.Follow{} q := r.conn. @@ -306,7 +305,7 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str if err == sql.ErrNoRows { return follows, nil } - return nil, processErrorResponse(err) + return nil, r.conn.ProcessError(err) } return follows, nil } diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go index 87e20673d..3a3192a19 100644 --- a/internal/db/bundb/session.go +++ b/internal/db/bundb/session.go @@ -27,12 +27,11 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" - "github.com/uptrace/bun" ) type sessionDB struct { config *config.Config - conn *bun.DB + conn *dbConn log *logrus.Logger } @@ -46,7 +45,7 @@ func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db _, err := q.Exec(ctx) - err = processErrorResponse(err) + err = s.conn.ProcessError(err) return rs, err } @@ -79,7 +78,7 @@ func (s *sessionDB) CreateSession(ctx context.Context) (*gtsmodel.RouterSession, _, err = q.Exec(ctx) - err = processErrorResponse(err) + err = s.conn.ProcessError(err) return rs, err } diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index da8d8ca41..b735ffaf4 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -34,7 +34,7 @@ import ( type statusDB struct { config *config.Config - conn *bun.DB + conn *dbConn log *logrus.Logger cache cache.Cache } @@ -121,8 +121,7 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat q := s.newStatusQ(status). Where("status.id = ?", id) - err := processErrorResponse(q.Scan(ctx)) - + err := s.conn.ProcessError(q.Scan(ctx)) if err != nil { return nil, err } @@ -144,8 +143,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St q := s.newStatusQ(status). Where("LOWER(status.uri) = LOWER(?)", uri) - err := processErrorResponse(q.Scan(ctx)) - + err := s.conn.ProcessError(q.Scan(ctx)) if err != nil { return nil, err } @@ -167,8 +165,7 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.St q := s.newStatusQ(status). Where("LOWER(status.url) = LOWER(?)", uri) - err := processErrorResponse(q.Scan(ctx)) - + err := s.conn.ProcessError(q.Scan(ctx)) if err != nil { return nil, err } @@ -217,7 +214,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er return err } - return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction)) + return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction)) } func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) { @@ -321,7 +318,7 @@ func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status, Where("status_id = ?", status.ID). Where("account_id = ?", accountID) - return exists(ctx, q) + return s.conn.Exists(ctx, q) } func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { @@ -331,7 +328,7 @@ func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Sta Where("boost_of_id = ?", status.ID). Where("account_id = ?", accountID) - return exists(ctx, q) + return s.conn.Exists(ctx, q) } func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { @@ -341,7 +338,7 @@ func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, Where("status_id = ?", status.ID). Where("account_id = ?", accountID) - return exists(ctx, q) + return s.conn.Exists(ctx, q) } func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) { @@ -351,7 +348,7 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St Where("status_id = ?", status.ID). Where("account_id = ?", accountID) - return exists(ctx, q) + return s.conn.Exists(ctx, q) } func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) { @@ -360,7 +357,7 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) q := s.newFaveQ(&faves). Where("status_id = ?", status.ID) - err := processErrorResponse(q.Scan(ctx)) + err := s.conn.ProcessError(q.Scan(ctx)) return faves, err } @@ -370,6 +367,6 @@ func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status q := s.newStatusQ(&reblogs). Where("boost_of_id = ?", status.ID) - err := processErrorResponse(q.Scan(ctx)) + err := s.conn.ProcessError(q.Scan(ctx)) return reblogs, err } diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go index cd202f436..dc75f1163 100644 --- a/internal/db/bundb/timeline.go +++ b/internal/db/bundb/timeline.go @@ -32,7 +32,7 @@ import ( type timelineDB struct { config *config.Config - conn *bun.DB + conn *dbConn log *logrus.Logger } @@ -86,7 +86,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI q = q.WhereGroup(" AND ", whereGroup) - return statuses, processErrorResponse(q.Scan(ctx)) + return statuses, t.conn.ProcessError(q.Scan(ctx)) } func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) { @@ -121,13 +121,12 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, ma q = q.Limit(limit) } - return statuses, processErrorResponse(q.Scan(ctx)) + return statuses, t.conn.ProcessError(q.Scan(ctx)) } // TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20! // It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds. func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) { - faves := []*gtsmodel.StatusFave{} fq := t.conn. diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go index faa80221f..9e1afb87e 100644 --- a/internal/db/bundb/util.go +++ b/internal/db/bundb/util.go @@ -19,64 +19,9 @@ package bundb import ( - "context" - "strings" - - "database/sql" - - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/uptrace/bun" ) -// processErrorResponse parses the given error and returns an appropriate DBError. -func processErrorResponse(err error) db.Error { - switch err { - case nil: - return nil - case sql.ErrNoRows: - return db.ErrNoEntries - default: - if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { - return db.ErrAlreadyExists - } - return err - } -} - -func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) { - count, err := q.Count(ctx) - - exists := count != 0 - - err = processErrorResponse(err) - - if err != nil { - if err == db.ErrNoEntries { - return false, nil - } - return false, err - } - - return exists, nil -} - -func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) { - count, err := q.Count(ctx) - - notExists := count == 0 - - err = processErrorResponse(err) - - if err != nil { - if err == db.ErrNoEntries { - return true, nil - } - return false, err - } - - return notExists, nil -} - // whereEmptyOrNull is a convenience function to return a bun WhereGroup that specifies // that the given column should be EITHER an empty string OR null. // diff --git a/vendor/modules.txt b/vendor/modules.txt index be4982c3d..d98ddb2cf 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -302,6 +302,7 @@ github.com/h2non/filetype/types # github.com/jackc/chunkreader/v2 v2.0.1 github.com/jackc/chunkreader/v2 # github.com/jackc/pgconn v1.10.0 +## explicit github.com/jackc/pgconn github.com/jackc/pgconn/internal/ctxwatch github.com/jackc/pgconn/stmtcache