wrap bun.DB in dbConn for dbconn specific error processing etc

Signed-off-by: kim (grufwub) <grufwub@gmail.com>
This commit is contained in:
kim (grufwub) 2021-08-26 15:11:48 +01:00
commit 0000c2c4a5
18 changed files with 171 additions and 141 deletions

1
go.mod
View file

@ -29,6 +29,7 @@ require (
github.com/gorilla/sessions v1.2.1 // indirect
github.com/gorilla/websocket v1.4.2
github.com/h2non/filetype v1.1.1
github.com/jackc/pgconn v1.10.0
github.com/jackc/pgx/v4 v4.13.0
github.com/json-iterator/go v1.1.11 // indirect
github.com/leodido/go-urn v1.2.1 // indirect

View file

@ -34,7 +34,7 @@ import (
type accountDB struct {
config *config.Config
conn *bun.DB
conn *dbConn
log *logrus.Logger
}
@ -52,7 +52,7 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac
q := a.newAccountQ(account).
Where("account.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
err := a.conn.ProcessError(q.Scan(ctx))
return account, err
}
@ -63,7 +63,7 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
q := a.newAccountQ(account).
Where("account.uri = ?", uri)
err := processErrorResponse(q.Scan(ctx))
err := a.conn.ProcessError(q.Scan(ctx))
return account, err
}
@ -74,7 +74,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.
q := a.newAccountQ(account).
Where("account.url = ?", uri)
err := processErrorResponse(q.Scan(ctx))
err := a.conn.ProcessError(q.Scan(ctx))
return account, err
}
@ -93,7 +93,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
_, err := q.Exec(ctx)
err = processErrorResponse(err)
err = a.conn.ProcessError(err)
return account, err
}
@ -113,7 +113,7 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
WhereGroup(" AND ", whereEmptyOrNull("domain"))
}
err := processErrorResponse(q.Scan(ctx))
err := a.conn.ProcessError(q.Scan(ctx))
return account, err
}
@ -129,7 +129,7 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string)
Where("account_id = ?", accountID).
Column("created_at")
err := processErrorResponse(q.Scan(ctx))
err := a.conn.ProcessError(q.Scan(ctx))
return status.CreatedAt, err
}
@ -174,7 +174,7 @@ func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username stri
Where("username = ?", username).
WhereGroup(" AND ", whereEmptyOrNull("domain"))
err := processErrorResponse(q.Scan(ctx))
err := a.conn.ProcessError(q.Scan(ctx))
return account, err
}

View file

@ -35,13 +35,12 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
"golang.org/x/crypto/bcrypt"
)
type adminDB struct {
config *config.Config
conn *bun.DB
conn *dbConn
log *logrus.Logger
}
@ -52,7 +51,7 @@ func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (boo
Where("username = ?", username).
Where("domain = ?", nil)
return notExists(ctx, q)
return a.conn.NotExists(ctx, q)
}
func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) {
@ -72,7 +71,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
// fail because we found something
return false, fmt.Errorf("email domain %s is blocked", domain)
} else if err != sql.ErrNoRows {
return false, processErrorResponse(err)
return false, a.conn.ProcessError(err)
}
// check if this email is associated with a user already
@ -82,7 +81,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
Where("email = ?", email).
WhereOr("unconfirmed_email = ?", email)
return notExists(ctx, q)
return a.conn.NotExists(ctx, q)
}
func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
@ -187,7 +186,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
a.log.Infof("instance account %s already exists", username)
return nil
} else if err != sql.ErrNoRows {
return processErrorResponse(err)
return a.conn.ProcessError(err)
}
key, err := rsa.GenerateKey(rand.Reader, 2048)
@ -245,7 +244,7 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
a.log.Infof("instance instance %s already exists", domain)
return nil
} else if err != sql.ErrNoRows {
return processErrorResponse(err)
return a.conn.ProcessError(err)
}
iID, err := id.NewRandomULID()

View file

@ -31,7 +31,7 @@ import (
type basicDB struct {
config *config.Config
conn *bun.DB
conn *dbConn
log *logrus.Logger
}
@ -49,7 +49,7 @@ func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Erro
Model(i).
Where("id = ?", id)
return processErrorResponse(q.Scan(ctx))
return b.conn.ProcessError(q.Scan(ctx))
}
func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
@ -59,7 +59,6 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{})
q := b.conn.NewSelect().Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", bun.Ident(w.Key))
} else {
@ -71,7 +70,7 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{})
}
}
return processErrorResponse(q.Scan(ctx))
return b.conn.ProcessError(q.Scan(ctx))
}
func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
@ -79,7 +78,7 @@ func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
NewSelect().
Model(i)
return processErrorResponse(q.Scan(ctx))
return b.conn.ProcessError(q.Scan(ctx))
}
func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error {
@ -90,7 +89,7 @@ func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.E
_, err := q.Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
@ -108,7 +107,7 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface
_, err := q.Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error {
@ -119,7 +118,7 @@ func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.E
_, err := q.Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error {
@ -130,7 +129,7 @@ func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, valu
_, err := q.Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error {
@ -152,7 +151,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string,
_, err := q.Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
@ -162,7 +161,7 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {
_, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx)
return processErrorResponse(err)
return b.conn.ProcessError(err)
}
func (b *basicDB) IsHealthy(ctx context.Context) db.Error {

View file

@ -68,7 +68,7 @@ type bunDBService struct {
db.Status
db.Timeline
config *config.Config
conn *bun.DB
conn *dbConn
log *logrus.Logger
}
@ -76,7 +76,7 @@ type bunDBService struct {
// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
var sqldb *sql.DB
var conn *bun.DB
var conn *dbConn
// depending on the database type we're trying to create, we need to use a different driver...
switch strings.ToLower(c.DBConfig.Type) {
@ -87,25 +87,30 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
}
sqldb = stdlib.OpenDB(*opts)
conn = bun.NewDB(sqldb, pgdialect.New())
conn = &dbConn{
DB: bun.NewDB(sqldb, pgdialect.New()),
errProc: processPostgresError,
}
case dbTypeSqlite:
// SQLITE
var err error
sqldb, err = sql.Open("sqlite", c.DBConfig.Address)
if err != nil {
return nil, fmt.Errorf("could not open sqlite db: %s", err)
}
conn = bun.NewDB(sqldb, sqlitedialect.New())
conn = &dbConn{
DB: bun.NewDB(sqldb, sqlitedialect.New()),
errProc: processSQLiteError,
}
if strings.HasPrefix(strings.TrimPrefix(c.DBConfig.Address, "file:"), ":memory:") {
log.Warn("sqlite in-memory database should only be used for debugging")
// don't close connections on close -- otherwise
// don't close connections on disconnect -- otherwise
// the SQLite database will be deleted when there
// are no active connections
sqldb.SetConnMaxLifetime(0)
sqldb.SetMaxOpenConns(1000)
sqldb.SetConnMaxLifetime(0)
}
default:
return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type))

44
internal/db/bundb/conn.go Normal file
View file

@ -0,0 +1,44 @@
package bundb
import (
"context"
"database/sql"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/uptrace/bun"
)
type dbConn struct {
errProc func(error) db.Error // errProc is the SQL-type specific error processor
*bun.DB // DB is the underlying bun.DB connection
}
func (conn *dbConn) ProcessError(err error) db.Error {
switch {
case err == nil:
return nil
case err == sql.ErrNoRows:
return db.ErrNoEntries
default:
return conn.errProc(err)
}
}
func (conn *dbConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
// Get the select query result
count, err := query.Count(ctx)
// Process error as our own and check if it exists
switch err := conn.ProcessError(err); err {
case nil, db.ErrAlreadyExists:
return (count != 0), nil
default:
return false, err
}
}
func (conn *dbConn) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
// Simply inverse of conn.exists()
exists, err := conn.Exists(ctx, query)
return !exists, err
}

View file

@ -27,12 +27,11 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
)
type domainDB struct {
config *config.Config
conn *bun.DB
conn *dbConn
log *logrus.Logger
}
@ -47,7 +46,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db
Where("LOWER(domain) = LOWER(?)", domain).
Limit(1)
return exists(ctx, q)
return d.conn.Exists(ctx, q)
}
func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {

View file

@ -0,0 +1,43 @@
package bundb
import (
"github.com/jackc/pgconn"
"github.com/superseriousbusiness/gotosocial/internal/db"
"modernc.org/sqlite"
sqlite3 "modernc.org/sqlite/lib"
)
// processPostgresError
func processPostgresError(err error) db.Error {
// Attempt to cast as postgres
pgErr, ok := err.(*pgconn.PgError)
if !ok {
return err
}
// Handle supplied error code:
// (https://www.postgresql.org/docs/10/errcodes-appendix.html)
switch pgErr.Code {
case "23505" /* unique_violation */ :
return db.ErrAlreadyExists
default:
return err
}
}
// processSQLiteError
func processSQLiteError(err error) db.Error {
// Attempt to cast as sqlite
sqliteErr, ok := err.(*sqlite.Error)
if !ok {
return err
}
// Handle supplied error code:
switch sqliteErr.Code() {
case sqlite3.SQLITE_CONSTRAINT_UNIQUE:
return db.ErrAlreadyExists
default:
return nil
}
}

View file

@ -30,7 +30,7 @@ import (
type instanceDB struct {
config *config.Config
conn *bun.DB
conn *dbConn
log *logrus.Logger
}
@ -53,7 +53,7 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int
count, err := q.Count(ctx)
return count, processErrorResponse(err)
return count, i.conn.ProcessError(err)
}
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
@ -72,7 +72,7 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (
count, err := q.Count(ctx)
return count, processErrorResponse(err)
return count, i.conn.ProcessError(err)
}
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
@ -93,7 +93,7 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i
count, err := q.Count(ctx)
return count, processErrorResponse(err)
return count, i.conn.ProcessError(err)
}
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
@ -114,7 +114,7 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max
q = q.Limit(limit)
}
err := processErrorResponse(q.Scan(ctx))
err := i.conn.ProcessError(q.Scan(ctx))
return accounts, err
}

View file

@ -30,7 +30,7 @@ import (
type mediaDB struct {
config *config.Config
conn *bun.DB
conn *dbConn
log *logrus.Logger
}
@ -47,7 +47,7 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
q := m.newMediaQ(attachment).
Where("media_attachment.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
err := m.conn.ProcessError(q.Scan(ctx))
return attachment, err
}

View file

@ -31,7 +31,7 @@ import (
type mentionDB struct {
config *config.Config
conn *bun.DB
conn *dbConn
log *logrus.Logger
cache cache.Cache
}
@ -84,7 +84,7 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio
q := m.newMentionQ(mention).
Where("mention.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
err := m.conn.ProcessError(q.Scan(ctx))
if err == nil && mention != nil {
m.cacheMention(id, mention)
@ -99,7 +99,7 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.
for _, i := range ids {
mention, err := m.GetMention(ctx, i)
if err != nil {
return nil, processErrorResponse(err)
return nil, m.conn.ProcessError(err)
}
mentions = append(mentions, mention)
}

View file

@ -31,7 +31,7 @@ import (
type notificationDB struct {
config *config.Config
conn *bun.DB
conn *dbConn
log *logrus.Logger
cache cache.Cache
}
@ -84,7 +84,7 @@ func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmo
q := n.newNotificationQ(notification).
Where("notification.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
err := n.conn.ProcessError(q.Scan(ctx))
if err == nil && notification != nil {
n.cacheNotification(id, notification)
@ -115,7 +115,7 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
q = q.Limit(limit)
}
err := processErrorResponse(q.Scan(ctx))
err := n.conn.ProcessError(q.Scan(ctx))
if err != nil {
return nil, err
}
@ -125,7 +125,7 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
notifications := []*gtsmodel.Notification{}
for _, notifID := range notifIDs {
notif, err := n.GetNotification(ctx, notifID.ID)
errP := processErrorResponse(err)
errP := n.conn.ProcessError(err)
if errP != nil {
return nil, errP
}

View file

@ -32,7 +32,7 @@ import (
type relationshipDB struct {
config *config.Config
conn *bun.DB
conn *dbConn
log *logrus.Logger
}
@ -66,7 +66,7 @@ func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account
Where("account_id = ?", account2)
}
return exists(ctx, q)
return r.conn.Exists(ctx, q)
}
func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
@ -76,7 +76,7 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2
Where("block.account_id = ?", account1).
Where("block.target_account_id = ?", account2)
err := processErrorResponse(q.Scan(ctx))
err := r.conn.ProcessError(q.Scan(ctx))
return block, err
}
@ -176,7 +176,7 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode
Where("target_account_id = ?", targetAccount.ID).
Limit(1)
return exists(ctx, q)
return r.conn.Exists(ctx, q)
}
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
@ -190,7 +190,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g
Where("account_id = ?", sourceAccount.ID).
Where("target_account_id = ?", targetAccount.ID)
return exists(ctx, q)
return r.conn.Exists(ctx, q)
}
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
@ -201,13 +201,13 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod
// make sure account 1 follows account 2
f1, err := r.IsFollowing(ctx, account1, account2)
if err != nil {
return false, processErrorResponse(err)
return false, r.conn.ProcessError(err)
}
// make sure account 2 follows account 1
f2, err := r.IsFollowing(ctx, account2, account1)
if err != nil {
return false, processErrorResponse(err)
return false, r.conn.ProcessError(err)
}
return f1 && f2, nil
@ -222,7 +222,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
Where("account_id = ?", originAccountID).
Where("target_account_id = ?", targetAccountID).
Scan(ctx); err != nil {
return nil, processErrorResponse(err)
return nil, r.conn.ProcessError(err)
}
// create a new follow to 'replace' the request with
@ -239,7 +239,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
Model(follow).
On("CONFLICT CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).
Exec(ctx); err != nil {
return nil, processErrorResponse(err)
return nil, r.conn.ProcessError(err)
}
// now remove the follow request
@ -249,7 +249,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
Where("account_id = ?", originAccountID).
Where("target_account_id = ?", targetAccountID).
Exec(ctx); err != nil {
return nil, processErrorResponse(err)
return nil, r.conn.ProcessError(err)
}
return follow, nil
@ -261,7 +261,7 @@ func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID
q := r.newFollowQ(&followRequests).
Where("target_account_id = ?", accountID)
err := processErrorResponse(q.Scan(ctx))
err := r.conn.ProcessError(q.Scan(ctx))
return followRequests, err
}
@ -272,7 +272,7 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string
q := r.newFollowQ(&follows).
Where("account_id = ?", accountID)
err := processErrorResponse(q.Scan(ctx))
err := r.conn.ProcessError(q.Scan(ctx))
return follows, err
}
@ -286,7 +286,6 @@ func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID stri
}
func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
follows := []*gtsmodel.Follow{}
q := r.conn.
@ -306,7 +305,7 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
if err == sql.ErrNoRows {
return follows, nil
}
return nil, processErrorResponse(err)
return nil, r.conn.ProcessError(err)
}
return follows, nil
}

View file

@ -27,12 +27,11 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/uptrace/bun"
)
type sessionDB struct {
config *config.Config
conn *bun.DB
conn *dbConn
log *logrus.Logger
}
@ -46,7 +45,7 @@ func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db
_, err := q.Exec(ctx)
err = processErrorResponse(err)
err = s.conn.ProcessError(err)
return rs, err
}
@ -79,7 +78,7 @@ func (s *sessionDB) CreateSession(ctx context.Context) (*gtsmodel.RouterSession,
_, err = q.Exec(ctx)
err = processErrorResponse(err)
err = s.conn.ProcessError(err)
return rs, err
}

View file

@ -34,7 +34,7 @@ import (
type statusDB struct {
config *config.Config
conn *bun.DB
conn *dbConn
log *logrus.Logger
cache cache.Cache
}
@ -121,8 +121,7 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat
q := s.newStatusQ(status).
Where("status.id = ?", id)
err := processErrorResponse(q.Scan(ctx))
err := s.conn.ProcessError(q.Scan(ctx))
if err != nil {
return nil, err
}
@ -144,8 +143,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St
q := s.newStatusQ(status).
Where("LOWER(status.uri) = LOWER(?)", uri)
err := processErrorResponse(q.Scan(ctx))
err := s.conn.ProcessError(q.Scan(ctx))
if err != nil {
return nil, err
}
@ -167,8 +165,7 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.St
q := s.newStatusQ(status).
Where("LOWER(status.url) = LOWER(?)", uri)
err := processErrorResponse(q.Scan(ctx))
err := s.conn.ProcessError(q.Scan(ctx))
if err != nil {
return nil, err
}
@ -217,7 +214,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
return err
}
return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction))
return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction))
}
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
@ -321,7 +318,7 @@ func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status,
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
return s.conn.Exists(ctx, q)
}
func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
@ -331,7 +328,7 @@ func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Sta
Where("boost_of_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
return s.conn.Exists(ctx, q)
}
func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
@ -341,7 +338,7 @@ func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status,
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
return s.conn.Exists(ctx, q)
}
func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
@ -351,7 +348,7 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St
Where("status_id = ?", status.ID).
Where("account_id = ?", accountID)
return exists(ctx, q)
return s.conn.Exists(ctx, q)
}
func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
@ -360,7 +357,7 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status)
q := s.newFaveQ(&faves).
Where("status_id = ?", status.ID)
err := processErrorResponse(q.Scan(ctx))
err := s.conn.ProcessError(q.Scan(ctx))
return faves, err
}
@ -370,6 +367,6 @@ func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status
q := s.newStatusQ(&reblogs).
Where("boost_of_id = ?", status.ID)
err := processErrorResponse(q.Scan(ctx))
err := s.conn.ProcessError(q.Scan(ctx))
return reblogs, err
}

View file

@ -32,7 +32,7 @@ import (
type timelineDB struct {
config *config.Config
conn *bun.DB
conn *dbConn
log *logrus.Logger
}
@ -86,7 +86,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
q = q.WhereGroup(" AND ", whereGroup)
return statuses, processErrorResponse(q.Scan(ctx))
return statuses, t.conn.ProcessError(q.Scan(ctx))
}
func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
@ -121,13 +121,12 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, ma
q = q.Limit(limit)
}
return statuses, processErrorResponse(q.Scan(ctx))
return statuses, t.conn.ProcessError(q.Scan(ctx))
}
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
faves := []*gtsmodel.StatusFave{}
fq := t.conn.

View file

@ -19,64 +19,9 @@
package bundb
import (
"context"
"strings"
"database/sql"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/uptrace/bun"
)
// processErrorResponse parses the given error and returns an appropriate DBError.
func processErrorResponse(err error) db.Error {
switch err {
case nil:
return nil
case sql.ErrNoRows:
return db.ErrNoEntries
default:
if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
}
}
func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
count, err := q.Count(ctx)
exists := count != 0
err = processErrorResponse(err)
if err != nil {
if err == db.ErrNoEntries {
return false, nil
}
return false, err
}
return exists, nil
}
func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
count, err := q.Count(ctx)
notExists := count == 0
err = processErrorResponse(err)
if err != nil {
if err == db.ErrNoEntries {
return true, nil
}
return false, err
}
return notExists, nil
}
// whereEmptyOrNull is a convenience function to return a bun WhereGroup that specifies
// that the given column should be EITHER an empty string OR null.
//

1
vendor/modules.txt vendored
View file

@ -302,6 +302,7 @@ github.com/h2non/filetype/types
# github.com/jackc/chunkreader/v2 v2.0.1
github.com/jackc/chunkreader/v2
# github.com/jackc/pgconn v1.10.0
## explicit
github.com/jackc/pgconn
github.com/jackc/pgconn/internal/ctxwatch
github.com/jackc/pgconn/stmtcache