mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-12-29 17:46:15 -06:00
wrap bun.DB in dbConn for dbconn specific error processing etc
Signed-off-by: kim (grufwub) <grufwub@gmail.com>
This commit is contained in:
parent
fc561ef308
commit
0000c2c4a5
18 changed files with 171 additions and 141 deletions
1
go.mod
1
go.mod
|
|
@ -29,6 +29,7 @@ require (
|
|||
github.com/gorilla/sessions v1.2.1 // indirect
|
||||
github.com/gorilla/websocket v1.4.2
|
||||
github.com/h2non/filetype v1.1.1
|
||||
github.com/jackc/pgconn v1.10.0
|
||||
github.com/jackc/pgx/v4 v4.13.0
|
||||
github.com/json-iterator/go v1.1.11 // indirect
|
||||
github.com/leodido/go-urn v1.2.1 // indirect
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ import (
|
|||
|
||||
type accountDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
conn *dbConn
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
|
|
@ -52,7 +52,7 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac
|
|||
q := a.newAccountQ(account).
|
||||
Where("account.id = ?", id)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := a.conn.ProcessError(q.Scan(ctx))
|
||||
|
||||
return account, err
|
||||
}
|
||||
|
|
@ -63,7 +63,7 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
|
|||
q := a.newAccountQ(account).
|
||||
Where("account.uri = ?", uri)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := a.conn.ProcessError(q.Scan(ctx))
|
||||
|
||||
return account, err
|
||||
}
|
||||
|
|
@ -74,7 +74,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.
|
|||
q := a.newAccountQ(account).
|
||||
Where("account.url = ?", uri)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := a.conn.ProcessError(q.Scan(ctx))
|
||||
|
||||
return account, err
|
||||
}
|
||||
|
|
@ -93,7 +93,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
|
|||
|
||||
_, err := q.Exec(ctx)
|
||||
|
||||
err = processErrorResponse(err)
|
||||
err = a.conn.ProcessError(err)
|
||||
|
||||
return account, err
|
||||
}
|
||||
|
|
@ -113,7 +113,7 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
|
|||
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
||||
}
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := a.conn.ProcessError(q.Scan(ctx))
|
||||
|
||||
return account, err
|
||||
}
|
||||
|
|
@ -129,7 +129,7 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string)
|
|||
Where("account_id = ?", accountID).
|
||||
Column("created_at")
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := a.conn.ProcessError(q.Scan(ctx))
|
||||
|
||||
return status.CreatedAt, err
|
||||
}
|
||||
|
|
@ -174,7 +174,7 @@ func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username stri
|
|||
Where("username = ?", username).
|
||||
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := a.conn.ProcessError(q.Scan(ctx))
|
||||
|
||||
return account, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,13 +35,12 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||
"github.com/uptrace/bun"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type adminDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
conn *dbConn
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
|
|
@ -52,7 +51,7 @@ func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (boo
|
|||
Where("username = ?", username).
|
||||
Where("domain = ?", nil)
|
||||
|
||||
return notExists(ctx, q)
|
||||
return a.conn.NotExists(ctx, q)
|
||||
}
|
||||
|
||||
func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) {
|
||||
|
|
@ -72,7 +71,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
|
|||
// fail because we found something
|
||||
return false, fmt.Errorf("email domain %s is blocked", domain)
|
||||
} else if err != sql.ErrNoRows {
|
||||
return false, processErrorResponse(err)
|
||||
return false, a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// check if this email is associated with a user already
|
||||
|
|
@ -82,7 +81,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
|
|||
Where("email = ?", email).
|
||||
WhereOr("unconfirmed_email = ?", email)
|
||||
|
||||
return notExists(ctx, q)
|
||||
return a.conn.NotExists(ctx, q)
|
||||
}
|
||||
|
||||
func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
|
||||
|
|
@ -187,7 +186,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
|
|||
a.log.Infof("instance account %s already exists", username)
|
||||
return nil
|
||||
} else if err != sql.ErrNoRows {
|
||||
return processErrorResponse(err)
|
||||
return a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
|
|
@ -245,7 +244,7 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
|
|||
a.log.Infof("instance instance %s already exists", domain)
|
||||
return nil
|
||||
} else if err != sql.ErrNoRows {
|
||||
return processErrorResponse(err)
|
||||
return a.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
iID, err := id.NewRandomULID()
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ import (
|
|||
|
||||
type basicDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
conn *dbConn
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
|
|
@ -49,7 +49,7 @@ func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Erro
|
|||
Model(i).
|
||||
Where("id = ?", id)
|
||||
|
||||
return processErrorResponse(q.Scan(ctx))
|
||||
return b.conn.ProcessError(q.Scan(ctx))
|
||||
}
|
||||
|
||||
func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
|
||||
|
|
@ -59,7 +59,6 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{})
|
|||
|
||||
q := b.conn.NewSelect().Model(i)
|
||||
for _, w := range where {
|
||||
|
||||
if w.Value == nil {
|
||||
q = q.Where("? IS NULL", bun.Ident(w.Key))
|
||||
} else {
|
||||
|
|
@ -71,7 +70,7 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{})
|
|||
}
|
||||
}
|
||||
|
||||
return processErrorResponse(q.Scan(ctx))
|
||||
return b.conn.ProcessError(q.Scan(ctx))
|
||||
}
|
||||
|
||||
func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
|
||||
|
|
@ -79,7 +78,7 @@ func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
|
|||
NewSelect().
|
||||
Model(i)
|
||||
|
||||
return processErrorResponse(q.Scan(ctx))
|
||||
return b.conn.ProcessError(q.Scan(ctx))
|
||||
}
|
||||
|
||||
func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error {
|
||||
|
|
@ -90,7 +89,7 @@ func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.E
|
|||
|
||||
_, err := q.Exec(ctx)
|
||||
|
||||
return processErrorResponse(err)
|
||||
return b.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
|
||||
|
|
@ -108,7 +107,7 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface
|
|||
|
||||
_, err := q.Exec(ctx)
|
||||
|
||||
return processErrorResponse(err)
|
||||
return b.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error {
|
||||
|
|
@ -119,7 +118,7 @@ func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.E
|
|||
|
||||
_, err := q.Exec(ctx)
|
||||
|
||||
return processErrorResponse(err)
|
||||
return b.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error {
|
||||
|
|
@ -130,7 +129,7 @@ func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, valu
|
|||
|
||||
_, err := q.Exec(ctx)
|
||||
|
||||
return processErrorResponse(err)
|
||||
return b.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error {
|
||||
|
|
@ -152,7 +151,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string,
|
|||
|
||||
_, err := q.Exec(ctx)
|
||||
|
||||
return processErrorResponse(err)
|
||||
return b.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
|
||||
|
|
@ -162,7 +161,7 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
|
|||
|
||||
func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {
|
||||
_, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx)
|
||||
return processErrorResponse(err)
|
||||
return b.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ type bunDBService struct {
|
|||
db.Status
|
||||
db.Timeline
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
conn *dbConn
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
|
|
@ -76,7 +76,7 @@ type bunDBService struct {
|
|||
// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
|
||||
func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
|
||||
var sqldb *sql.DB
|
||||
var conn *bun.DB
|
||||
var conn *dbConn
|
||||
|
||||
// depending on the database type we're trying to create, we need to use a different driver...
|
||||
switch strings.ToLower(c.DBConfig.Type) {
|
||||
|
|
@ -87,25 +87,30 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
|
|||
return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
|
||||
}
|
||||
sqldb = stdlib.OpenDB(*opts)
|
||||
conn = bun.NewDB(sqldb, pgdialect.New())
|
||||
conn = &dbConn{
|
||||
DB: bun.NewDB(sqldb, pgdialect.New()),
|
||||
errProc: processPostgresError,
|
||||
}
|
||||
case dbTypeSqlite:
|
||||
// SQLITE
|
||||
var err error
|
||||
|
||||
sqldb, err = sql.Open("sqlite", c.DBConfig.Address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not open sqlite db: %s", err)
|
||||
}
|
||||
conn = bun.NewDB(sqldb, sqlitedialect.New())
|
||||
conn = &dbConn{
|
||||
DB: bun.NewDB(sqldb, sqlitedialect.New()),
|
||||
errProc: processSQLiteError,
|
||||
}
|
||||
|
||||
if strings.HasPrefix(strings.TrimPrefix(c.DBConfig.Address, "file:"), ":memory:") {
|
||||
log.Warn("sqlite in-memory database should only be used for debugging")
|
||||
|
||||
// don't close connections on close -- otherwise
|
||||
// don't close connections on disconnect -- otherwise
|
||||
// the SQLite database will be deleted when there
|
||||
// are no active connections
|
||||
sqldb.SetConnMaxLifetime(0)
|
||||
sqldb.SetMaxOpenConns(1000)
|
||||
sqldb.SetConnMaxLifetime(0)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type))
|
||||
|
|
|
|||
44
internal/db/bundb/conn.go
Normal file
44
internal/db/bundb/conn.go
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
package bundb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type dbConn struct {
|
||||
errProc func(error) db.Error // errProc is the SQL-type specific error processor
|
||||
*bun.DB // DB is the underlying bun.DB connection
|
||||
}
|
||||
|
||||
func (conn *dbConn) ProcessError(err error) db.Error {
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case err == sql.ErrNoRows:
|
||||
return db.ErrNoEntries
|
||||
default:
|
||||
return conn.errProc(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *dbConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
|
||||
// Get the select query result
|
||||
count, err := query.Count(ctx)
|
||||
|
||||
// Process error as our own and check if it exists
|
||||
switch err := conn.ProcessError(err); err {
|
||||
case nil, db.ErrAlreadyExists:
|
||||
return (count != 0), nil
|
||||
default:
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *dbConn) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
|
||||
// Simply inverse of conn.exists()
|
||||
exists, err := conn.Exists(ctx, query)
|
||||
return !exists, err
|
||||
}
|
||||
|
|
@ -27,12 +27,11 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type domainDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
conn *dbConn
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
|
|
@ -47,7 +46,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db
|
|||
Where("LOWER(domain) = LOWER(?)", domain).
|
||||
Limit(1)
|
||||
|
||||
return exists(ctx, q)
|
||||
return d.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {
|
||||
|
|
|
|||
43
internal/db/bundb/errors.go
Normal file
43
internal/db/bundb/errors.go
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
package bundb
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"modernc.org/sqlite"
|
||||
sqlite3 "modernc.org/sqlite/lib"
|
||||
)
|
||||
|
||||
// processPostgresError
|
||||
func processPostgresError(err error) db.Error {
|
||||
// Attempt to cast as postgres
|
||||
pgErr, ok := err.(*pgconn.PgError)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle supplied error code:
|
||||
// (https://www.postgresql.org/docs/10/errcodes-appendix.html)
|
||||
switch pgErr.Code {
|
||||
case "23505" /* unique_violation */ :
|
||||
return db.ErrAlreadyExists
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// processSQLiteError
|
||||
func processSQLiteError(err error) db.Error {
|
||||
// Attempt to cast as sqlite
|
||||
sqliteErr, ok := err.(*sqlite.Error)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle supplied error code:
|
||||
switch sqliteErr.Code() {
|
||||
case sqlite3.SQLITE_CONSTRAINT_UNIQUE:
|
||||
return db.ErrAlreadyExists
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
@ -30,7 +30,7 @@ import (
|
|||
|
||||
type instanceDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
conn *dbConn
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
|
|
@ -53,7 +53,7 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int
|
|||
|
||||
count, err := q.Count(ctx)
|
||||
|
||||
return count, processErrorResponse(err)
|
||||
return count, i.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
|
||||
|
|
@ -72,7 +72,7 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (
|
|||
|
||||
count, err := q.Count(ctx)
|
||||
|
||||
return count, processErrorResponse(err)
|
||||
return count, i.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
|
||||
|
|
@ -93,7 +93,7 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i
|
|||
|
||||
count, err := q.Count(ctx)
|
||||
|
||||
return count, processErrorResponse(err)
|
||||
return count, i.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
|
||||
|
|
@ -114,7 +114,7 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max
|
|||
q = q.Limit(limit)
|
||||
}
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := i.conn.ProcessError(q.Scan(ctx))
|
||||
|
||||
return accounts, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ import (
|
|||
|
||||
type mediaDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
conn *dbConn
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
|
|
@ -47,7 +47,7 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
|
|||
q := m.newMediaQ(attachment).
|
||||
Where("media_attachment.id = ?", id)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := m.conn.ProcessError(q.Scan(ctx))
|
||||
|
||||
return attachment, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ import (
|
|||
|
||||
type mentionDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
conn *dbConn
|
||||
log *logrus.Logger
|
||||
cache cache.Cache
|
||||
}
|
||||
|
|
@ -84,7 +84,7 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio
|
|||
q := m.newMentionQ(mention).
|
||||
Where("mention.id = ?", id)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := m.conn.ProcessError(q.Scan(ctx))
|
||||
|
||||
if err == nil && mention != nil {
|
||||
m.cacheMention(id, mention)
|
||||
|
|
@ -99,7 +99,7 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.
|
|||
for _, i := range ids {
|
||||
mention, err := m.GetMention(ctx, i)
|
||||
if err != nil {
|
||||
return nil, processErrorResponse(err)
|
||||
return nil, m.conn.ProcessError(err)
|
||||
}
|
||||
mentions = append(mentions, mention)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ import (
|
|||
|
||||
type notificationDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
conn *dbConn
|
||||
log *logrus.Logger
|
||||
cache cache.Cache
|
||||
}
|
||||
|
|
@ -84,7 +84,7 @@ func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmo
|
|||
q := n.newNotificationQ(notification).
|
||||
Where("notification.id = ?", id)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := n.conn.ProcessError(q.Scan(ctx))
|
||||
|
||||
if err == nil && notification != nil {
|
||||
n.cacheNotification(id, notification)
|
||||
|
|
@ -115,7 +115,7 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
|
|||
q = q.Limit(limit)
|
||||
}
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := n.conn.ProcessError(q.Scan(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -125,7 +125,7 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
|
|||
notifications := []*gtsmodel.Notification{}
|
||||
for _, notifID := range notifIDs {
|
||||
notif, err := n.GetNotification(ctx, notifID.ID)
|
||||
errP := processErrorResponse(err)
|
||||
errP := n.conn.ProcessError(err)
|
||||
if errP != nil {
|
||||
return nil, errP
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ import (
|
|||
|
||||
type relationshipDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
conn *dbConn
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
|
|
@ -66,7 +66,7 @@ func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account
|
|||
Where("account_id = ?", account2)
|
||||
}
|
||||
|
||||
return exists(ctx, q)
|
||||
return r.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
|
||||
|
|
@ -76,7 +76,7 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2
|
|||
Where("block.account_id = ?", account1).
|
||||
Where("block.target_account_id = ?", account2)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := r.conn.ProcessError(q.Scan(ctx))
|
||||
|
||||
return block, err
|
||||
}
|
||||
|
|
@ -176,7 +176,7 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode
|
|||
Where("target_account_id = ?", targetAccount.ID).
|
||||
Limit(1)
|
||||
|
||||
return exists(ctx, q)
|
||||
return r.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
|
||||
|
|
@ -190,7 +190,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g
|
|||
Where("account_id = ?", sourceAccount.ID).
|
||||
Where("target_account_id = ?", targetAccount.ID)
|
||||
|
||||
return exists(ctx, q)
|
||||
return r.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
|
||||
|
|
@ -201,13 +201,13 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod
|
|||
// make sure account 1 follows account 2
|
||||
f1, err := r.IsFollowing(ctx, account1, account2)
|
||||
if err != nil {
|
||||
return false, processErrorResponse(err)
|
||||
return false, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// make sure account 2 follows account 1
|
||||
f2, err := r.IsFollowing(ctx, account2, account1)
|
||||
if err != nil {
|
||||
return false, processErrorResponse(err)
|
||||
return false, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return f1 && f2, nil
|
||||
|
|
@ -222,7 +222,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
|
|||
Where("account_id = ?", originAccountID).
|
||||
Where("target_account_id = ?", targetAccountID).
|
||||
Scan(ctx); err != nil {
|
||||
return nil, processErrorResponse(err)
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// create a new follow to 'replace' the request with
|
||||
|
|
@ -239,7 +239,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
|
|||
Model(follow).
|
||||
On("CONFLICT CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, processErrorResponse(err)
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// now remove the follow request
|
||||
|
|
@ -249,7 +249,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
|
|||
Where("account_id = ?", originAccountID).
|
||||
Where("target_account_id = ?", targetAccountID).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, processErrorResponse(err)
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return follow, nil
|
||||
|
|
@ -261,7 +261,7 @@ func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID
|
|||
q := r.newFollowQ(&followRequests).
|
||||
Where("target_account_id = ?", accountID)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := r.conn.ProcessError(q.Scan(ctx))
|
||||
|
||||
return followRequests, err
|
||||
}
|
||||
|
|
@ -272,7 +272,7 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string
|
|||
q := r.newFollowQ(&follows).
|
||||
Where("account_id = ?", accountID)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := r.conn.ProcessError(q.Scan(ctx))
|
||||
|
||||
return follows, err
|
||||
}
|
||||
|
|
@ -286,7 +286,6 @@ func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID stri
|
|||
}
|
||||
|
||||
func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
|
||||
|
||||
follows := []*gtsmodel.Follow{}
|
||||
|
||||
q := r.conn.
|
||||
|
|
@ -306,7 +305,7 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
|
|||
if err == sql.ErrNoRows {
|
||||
return follows, nil
|
||||
}
|
||||
return nil, processErrorResponse(err)
|
||||
return nil, r.conn.ProcessError(err)
|
||||
}
|
||||
return follows, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,12 +27,11 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type sessionDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
conn *dbConn
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
|
|
@ -46,7 +45,7 @@ func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db
|
|||
|
||||
_, err := q.Exec(ctx)
|
||||
|
||||
err = processErrorResponse(err)
|
||||
err = s.conn.ProcessError(err)
|
||||
|
||||
return rs, err
|
||||
}
|
||||
|
|
@ -79,7 +78,7 @@ func (s *sessionDB) CreateSession(ctx context.Context) (*gtsmodel.RouterSession,
|
|||
|
||||
_, err = q.Exec(ctx)
|
||||
|
||||
err = processErrorResponse(err)
|
||||
err = s.conn.ProcessError(err)
|
||||
|
||||
return rs, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ import (
|
|||
|
||||
type statusDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
conn *dbConn
|
||||
log *logrus.Logger
|
||||
cache cache.Cache
|
||||
}
|
||||
|
|
@ -121,8 +121,7 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat
|
|||
q := s.newStatusQ(status).
|
||||
Where("status.id = ?", id)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
err := s.conn.ProcessError(q.Scan(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -144,8 +143,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St
|
|||
q := s.newStatusQ(status).
|
||||
Where("LOWER(status.uri) = LOWER(?)", uri)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
err := s.conn.ProcessError(q.Scan(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -167,8 +165,7 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.St
|
|||
q := s.newStatusQ(status).
|
||||
Where("LOWER(status.url) = LOWER(?)", uri)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
|
||||
err := s.conn.ProcessError(q.Scan(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -217,7 +214,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
|
|||
return err
|
||||
}
|
||||
|
||||
return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction))
|
||||
return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction))
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
|
||||
|
|
@ -321,7 +318,7 @@ func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status,
|
|||
Where("status_id = ?", status.ID).
|
||||
Where("account_id = ?", accountID)
|
||||
|
||||
return exists(ctx, q)
|
||||
return s.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
|
|
@ -331,7 +328,7 @@ func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Sta
|
|||
Where("boost_of_id = ?", status.ID).
|
||||
Where("account_id = ?", accountID)
|
||||
|
||||
return exists(ctx, q)
|
||||
return s.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
|
|
@ -341,7 +338,7 @@ func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status,
|
|||
Where("status_id = ?", status.ID).
|
||||
Where("account_id = ?", accountID)
|
||||
|
||||
return exists(ctx, q)
|
||||
return s.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||
|
|
@ -351,7 +348,7 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St
|
|||
Where("status_id = ?", status.ID).
|
||||
Where("account_id = ?", accountID)
|
||||
|
||||
return exists(ctx, q)
|
||||
return s.conn.Exists(ctx, q)
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
|
||||
|
|
@ -360,7 +357,7 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status)
|
|||
q := s.newFaveQ(&faves).
|
||||
Where("status_id = ?", status.ID)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := s.conn.ProcessError(q.Scan(ctx))
|
||||
return faves, err
|
||||
}
|
||||
|
||||
|
|
@ -370,6 +367,6 @@ func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status
|
|||
q := s.newStatusQ(&reblogs).
|
||||
Where("boost_of_id = ?", status.ID)
|
||||
|
||||
err := processErrorResponse(q.Scan(ctx))
|
||||
err := s.conn.ProcessError(q.Scan(ctx))
|
||||
return reblogs, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ import (
|
|||
|
||||
type timelineDB struct {
|
||||
config *config.Config
|
||||
conn *bun.DB
|
||||
conn *dbConn
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
|
|
@ -86,7 +86,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
|
|||
|
||||
q = q.WhereGroup(" AND ", whereGroup)
|
||||
|
||||
return statuses, processErrorResponse(q.Scan(ctx))
|
||||
return statuses, t.conn.ProcessError(q.Scan(ctx))
|
||||
}
|
||||
|
||||
func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
|
||||
|
|
@ -121,13 +121,12 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, ma
|
|||
q = q.Limit(limit)
|
||||
}
|
||||
|
||||
return statuses, processErrorResponse(q.Scan(ctx))
|
||||
return statuses, t.conn.ProcessError(q.Scan(ctx))
|
||||
}
|
||||
|
||||
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
|
||||
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
|
||||
func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
|
||||
|
||||
faves := []*gtsmodel.StatusFave{}
|
||||
|
||||
fq := t.conn.
|
||||
|
|
|
|||
|
|
@ -19,64 +19,9 @@
|
|||
package bundb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"database/sql"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
// processErrorResponse parses the given error and returns an appropriate DBError.
|
||||
func processErrorResponse(err error) db.Error {
|
||||
switch err {
|
||||
case nil:
|
||||
return nil
|
||||
case sql.ErrNoRows:
|
||||
return db.ErrNoEntries
|
||||
default:
|
||||
if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
|
||||
return db.ErrAlreadyExists
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
|
||||
count, err := q.Count(ctx)
|
||||
|
||||
exists := count != 0
|
||||
|
||||
err = processErrorResponse(err)
|
||||
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
|
||||
count, err := q.Count(ctx)
|
||||
|
||||
notExists := count == 0
|
||||
|
||||
err = processErrorResponse(err)
|
||||
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
return notExists, nil
|
||||
}
|
||||
|
||||
// whereEmptyOrNull is a convenience function to return a bun WhereGroup that specifies
|
||||
// that the given column should be EITHER an empty string OR null.
|
||||
//
|
||||
|
|
|
|||
1
vendor/modules.txt
vendored
1
vendor/modules.txt
vendored
|
|
@ -302,6 +302,7 @@ github.com/h2non/filetype/types
|
|||
# github.com/jackc/chunkreader/v2 v2.0.1
|
||||
github.com/jackc/chunkreader/v2
|
||||
# github.com/jackc/pgconn v1.10.0
|
||||
## explicit
|
||||
github.com/jackc/pgconn
|
||||
github.com/jackc/pgconn/internal/ctxwatch
|
||||
github.com/jackc/pgconn/stmtcache
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue