mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-12-29 22:36:14 -06:00
wrap bun.DB in dbConn for dbconn specific error processing etc
Signed-off-by: kim (grufwub) <grufwub@gmail.com>
This commit is contained in:
parent
fc561ef308
commit
0000c2c4a5
18 changed files with 171 additions and 141 deletions
1
go.mod
1
go.mod
|
|
@ -29,6 +29,7 @@ require (
|
||||||
github.com/gorilla/sessions v1.2.1 // indirect
|
github.com/gorilla/sessions v1.2.1 // indirect
|
||||||
github.com/gorilla/websocket v1.4.2
|
github.com/gorilla/websocket v1.4.2
|
||||||
github.com/h2non/filetype v1.1.1
|
github.com/h2non/filetype v1.1.1
|
||||||
|
github.com/jackc/pgconn v1.10.0
|
||||||
github.com/jackc/pgx/v4 v4.13.0
|
github.com/jackc/pgx/v4 v4.13.0
|
||||||
github.com/json-iterator/go v1.1.11 // indirect
|
github.com/json-iterator/go v1.1.11 // indirect
|
||||||
github.com/leodido/go-urn v1.2.1 // indirect
|
github.com/leodido/go-urn v1.2.1 // indirect
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ import (
|
||||||
|
|
||||||
type accountDB struct {
|
type accountDB struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
conn *bun.DB
|
conn *dbConn
|
||||||
log *logrus.Logger
|
log *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -52,7 +52,7 @@ func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Ac
|
||||||
q := a.newAccountQ(account).
|
q := a.newAccountQ(account).
|
||||||
Where("account.id = ?", id)
|
Where("account.id = ?", id)
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := a.conn.ProcessError(q.Scan(ctx))
|
||||||
|
|
||||||
return account, err
|
return account, err
|
||||||
}
|
}
|
||||||
|
|
@ -63,7 +63,7 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
|
||||||
q := a.newAccountQ(account).
|
q := a.newAccountQ(account).
|
||||||
Where("account.uri = ?", uri)
|
Where("account.uri = ?", uri)
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := a.conn.ProcessError(q.Scan(ctx))
|
||||||
|
|
||||||
return account, err
|
return account, err
|
||||||
}
|
}
|
||||||
|
|
@ -74,7 +74,7 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.
|
||||||
q := a.newAccountQ(account).
|
q := a.newAccountQ(account).
|
||||||
Where("account.url = ?", uri)
|
Where("account.url = ?", uri)
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := a.conn.ProcessError(q.Scan(ctx))
|
||||||
|
|
||||||
return account, err
|
return account, err
|
||||||
}
|
}
|
||||||
|
|
@ -93,7 +93,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
|
||||||
|
|
||||||
_, err := q.Exec(ctx)
|
_, err := q.Exec(ctx)
|
||||||
|
|
||||||
err = processErrorResponse(err)
|
err = a.conn.ProcessError(err)
|
||||||
|
|
||||||
return account, err
|
return account, err
|
||||||
}
|
}
|
||||||
|
|
@ -113,7 +113,7 @@ func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gts
|
||||||
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
||||||
}
|
}
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := a.conn.ProcessError(q.Scan(ctx))
|
||||||
|
|
||||||
return account, err
|
return account, err
|
||||||
}
|
}
|
||||||
|
|
@ -129,7 +129,7 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string)
|
||||||
Where("account_id = ?", accountID).
|
Where("account_id = ?", accountID).
|
||||||
Column("created_at")
|
Column("created_at")
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := a.conn.ProcessError(q.Scan(ctx))
|
||||||
|
|
||||||
return status.CreatedAt, err
|
return status.CreatedAt, err
|
||||||
}
|
}
|
||||||
|
|
@ -174,7 +174,7 @@ func (a *accountDB) GetLocalAccountByUsername(ctx context.Context, username stri
|
||||||
Where("username = ?", username).
|
Where("username = ?", username).
|
||||||
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
WhereGroup(" AND ", whereEmptyOrNull("domain"))
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := a.conn.ProcessError(q.Scan(ctx))
|
||||||
|
|
||||||
return account, err
|
return account, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -35,13 +35,12 @@ import (
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||||
"github.com/uptrace/bun"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
type adminDB struct {
|
type adminDB struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
conn *bun.DB
|
conn *dbConn
|
||||||
log *logrus.Logger
|
log *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -52,7 +51,7 @@ func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (boo
|
||||||
Where("username = ?", username).
|
Where("username = ?", username).
|
||||||
Where("domain = ?", nil)
|
Where("domain = ?", nil)
|
||||||
|
|
||||||
return notExists(ctx, q)
|
return a.conn.NotExists(ctx, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) {
|
func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.Error) {
|
||||||
|
|
@ -72,7 +71,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
|
||||||
// fail because we found something
|
// fail because we found something
|
||||||
return false, fmt.Errorf("email domain %s is blocked", domain)
|
return false, fmt.Errorf("email domain %s is blocked", domain)
|
||||||
} else if err != sql.ErrNoRows {
|
} else if err != sql.ErrNoRows {
|
||||||
return false, processErrorResponse(err)
|
return false, a.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if this email is associated with a user already
|
// check if this email is associated with a user already
|
||||||
|
|
@ -82,7 +81,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, db.
|
||||||
Where("email = ?", email).
|
Where("email = ?", email).
|
||||||
WhereOr("unconfirmed_email = ?", email)
|
WhereOr("unconfirmed_email = ?", email)
|
||||||
|
|
||||||
return notExists(ctx, q)
|
return a.conn.NotExists(ctx, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
|
func (a *adminDB) NewSignup(ctx context.Context, username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
|
||||||
|
|
@ -187,7 +186,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) db.Error {
|
||||||
a.log.Infof("instance account %s already exists", username)
|
a.log.Infof("instance account %s already exists", username)
|
||||||
return nil
|
return nil
|
||||||
} else if err != sql.ErrNoRows {
|
} else if err != sql.ErrNoRows {
|
||||||
return processErrorResponse(err)
|
return a.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
|
@ -245,7 +244,7 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) db.Error {
|
||||||
a.log.Infof("instance instance %s already exists", domain)
|
a.log.Infof("instance instance %s already exists", domain)
|
||||||
return nil
|
return nil
|
||||||
} else if err != sql.ErrNoRows {
|
} else if err != sql.ErrNoRows {
|
||||||
return processErrorResponse(err)
|
return a.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
iID, err := id.NewRandomULID()
|
iID, err := id.NewRandomULID()
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ import (
|
||||||
|
|
||||||
type basicDB struct {
|
type basicDB struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
conn *bun.DB
|
conn *dbConn
|
||||||
log *logrus.Logger
|
log *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -49,7 +49,7 @@ func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) db.Erro
|
||||||
Model(i).
|
Model(i).
|
||||||
Where("id = ?", id)
|
Where("id = ?", id)
|
||||||
|
|
||||||
return processErrorResponse(q.Scan(ctx))
|
return b.conn.ProcessError(q.Scan(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
|
func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
|
||||||
|
|
@ -59,7 +59,6 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{})
|
||||||
|
|
||||||
q := b.conn.NewSelect().Model(i)
|
q := b.conn.NewSelect().Model(i)
|
||||||
for _, w := range where {
|
for _, w := range where {
|
||||||
|
|
||||||
if w.Value == nil {
|
if w.Value == nil {
|
||||||
q = q.Where("? IS NULL", bun.Ident(w.Key))
|
q = q.Where("? IS NULL", bun.Ident(w.Key))
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -71,7 +70,7 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return processErrorResponse(q.Scan(ctx))
|
return b.conn.ProcessError(q.Scan(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
|
func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
|
||||||
|
|
@ -79,7 +78,7 @@ func (b *basicDB) GetAll(ctx context.Context, i interface{}) db.Error {
|
||||||
NewSelect().
|
NewSelect().
|
||||||
Model(i)
|
Model(i)
|
||||||
|
|
||||||
return processErrorResponse(q.Scan(ctx))
|
return b.conn.ProcessError(q.Scan(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error {
|
func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.Error {
|
||||||
|
|
@ -90,7 +89,7 @@ func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) db.E
|
||||||
|
|
||||||
_, err := q.Exec(ctx)
|
_, err := q.Exec(ctx)
|
||||||
|
|
||||||
return processErrorResponse(err)
|
return b.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
|
func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) db.Error {
|
||||||
|
|
@ -108,7 +107,7 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface
|
||||||
|
|
||||||
_, err := q.Exec(ctx)
|
_, err := q.Exec(ctx)
|
||||||
|
|
||||||
return processErrorResponse(err)
|
return b.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error {
|
func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.Error {
|
||||||
|
|
@ -119,7 +118,7 @@ func (b *basicDB) UpdateByID(ctx context.Context, id string, i interface{}) db.E
|
||||||
|
|
||||||
_, err := q.Exec(ctx)
|
_, err := q.Exec(ctx)
|
||||||
|
|
||||||
return processErrorResponse(err)
|
return b.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error {
|
func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, value interface{}, i interface{}) db.Error {
|
||||||
|
|
@ -130,7 +129,7 @@ func (b *basicDB) UpdateOneByID(ctx context.Context, id string, key string, valu
|
||||||
|
|
||||||
_, err := q.Exec(ctx)
|
_, err := q.Exec(ctx)
|
||||||
|
|
||||||
return processErrorResponse(err)
|
return b.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error {
|
func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) db.Error {
|
||||||
|
|
@ -152,7 +151,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string,
|
||||||
|
|
||||||
_, err := q.Exec(ctx)
|
_, err := q.Exec(ctx)
|
||||||
|
|
||||||
return processErrorResponse(err)
|
return b.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
|
func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
|
||||||
|
|
@ -162,7 +161,7 @@ func (b *basicDB) CreateTable(ctx context.Context, i interface{}) db.Error {
|
||||||
|
|
||||||
func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {
|
func (b *basicDB) DropTable(ctx context.Context, i interface{}) db.Error {
|
||||||
_, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx)
|
_, err := b.conn.NewDropTable().Model(i).IfExists().Exec(ctx)
|
||||||
return processErrorResponse(err)
|
return b.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
|
func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ type bunDBService struct {
|
||||||
db.Status
|
db.Status
|
||||||
db.Timeline
|
db.Timeline
|
||||||
config *config.Config
|
config *config.Config
|
||||||
conn *bun.DB
|
conn *dbConn
|
||||||
log *logrus.Logger
|
log *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -76,7 +76,7 @@ type bunDBService struct {
|
||||||
// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
|
// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
|
||||||
func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
|
func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
|
||||||
var sqldb *sql.DB
|
var sqldb *sql.DB
|
||||||
var conn *bun.DB
|
var conn *dbConn
|
||||||
|
|
||||||
// depending on the database type we're trying to create, we need to use a different driver...
|
// depending on the database type we're trying to create, we need to use a different driver...
|
||||||
switch strings.ToLower(c.DBConfig.Type) {
|
switch strings.ToLower(c.DBConfig.Type) {
|
||||||
|
|
@ -87,25 +87,30 @@ func NewBunDBService(ctx context.Context, c *config.Config, log *logrus.Logger)
|
||||||
return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
|
return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
|
||||||
}
|
}
|
||||||
sqldb = stdlib.OpenDB(*opts)
|
sqldb = stdlib.OpenDB(*opts)
|
||||||
conn = bun.NewDB(sqldb, pgdialect.New())
|
conn = &dbConn{
|
||||||
|
DB: bun.NewDB(sqldb, pgdialect.New()),
|
||||||
|
errProc: processPostgresError,
|
||||||
|
}
|
||||||
case dbTypeSqlite:
|
case dbTypeSqlite:
|
||||||
// SQLITE
|
// SQLITE
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
sqldb, err = sql.Open("sqlite", c.DBConfig.Address)
|
sqldb, err = sql.Open("sqlite", c.DBConfig.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not open sqlite db: %s", err)
|
return nil, fmt.Errorf("could not open sqlite db: %s", err)
|
||||||
}
|
}
|
||||||
conn = bun.NewDB(sqldb, sqlitedialect.New())
|
conn = &dbConn{
|
||||||
|
DB: bun.NewDB(sqldb, sqlitedialect.New()),
|
||||||
|
errProc: processSQLiteError,
|
||||||
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(strings.TrimPrefix(c.DBConfig.Address, "file:"), ":memory:") {
|
if strings.HasPrefix(strings.TrimPrefix(c.DBConfig.Address, "file:"), ":memory:") {
|
||||||
log.Warn("sqlite in-memory database should only be used for debugging")
|
log.Warn("sqlite in-memory database should only be used for debugging")
|
||||||
|
|
||||||
// don't close connections on close -- otherwise
|
// don't close connections on disconnect -- otherwise
|
||||||
// the SQLite database will be deleted when there
|
// the SQLite database will be deleted when there
|
||||||
// are no active connections
|
// are no active connections
|
||||||
sqldb.SetConnMaxLifetime(0)
|
|
||||||
sqldb.SetMaxOpenConns(1000)
|
sqldb.SetMaxOpenConns(1000)
|
||||||
|
sqldb.SetConnMaxLifetime(0)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type))
|
return nil, fmt.Errorf("database type %s not supported for bundb", strings.ToLower(c.DBConfig.Type))
|
||||||
|
|
|
||||||
44
internal/db/bundb/conn.go
Normal file
44
internal/db/bundb/conn.go
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
package bundb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dbConn struct {
|
||||||
|
errProc func(error) db.Error // errProc is the SQL-type specific error processor
|
||||||
|
*bun.DB // DB is the underlying bun.DB connection
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *dbConn) ProcessError(err error) db.Error {
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
return nil
|
||||||
|
case err == sql.ErrNoRows:
|
||||||
|
return db.ErrNoEntries
|
||||||
|
default:
|
||||||
|
return conn.errProc(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *dbConn) Exists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
|
||||||
|
// Get the select query result
|
||||||
|
count, err := query.Count(ctx)
|
||||||
|
|
||||||
|
// Process error as our own and check if it exists
|
||||||
|
switch err := conn.ProcessError(err); err {
|
||||||
|
case nil, db.ErrAlreadyExists:
|
||||||
|
return (count != 0), nil
|
||||||
|
default:
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *dbConn) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, db.Error) {
|
||||||
|
// Simply inverse of conn.exists()
|
||||||
|
exists, err := conn.Exists(ctx, query)
|
||||||
|
return !exists, err
|
||||||
|
}
|
||||||
|
|
@ -27,12 +27,11 @@ import (
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||||
"github.com/uptrace/bun"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type domainDB struct {
|
type domainDB struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
conn *bun.DB
|
conn *dbConn
|
||||||
log *logrus.Logger
|
log *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -47,7 +46,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db
|
||||||
Where("LOWER(domain) = LOWER(?)", domain).
|
Where("LOWER(domain) = LOWER(?)", domain).
|
||||||
Limit(1)
|
Limit(1)
|
||||||
|
|
||||||
return exists(ctx, q)
|
return d.conn.Exists(ctx, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {
|
func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {
|
||||||
|
|
|
||||||
43
internal/db/bundb/errors.go
Normal file
43
internal/db/bundb/errors.go
Normal file
|
|
@ -0,0 +1,43 @@
|
||||||
|
package bundb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
|
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||||
|
"modernc.org/sqlite"
|
||||||
|
sqlite3 "modernc.org/sqlite/lib"
|
||||||
|
)
|
||||||
|
|
||||||
|
// processPostgresError
|
||||||
|
func processPostgresError(err error) db.Error {
|
||||||
|
// Attempt to cast as postgres
|
||||||
|
pgErr, ok := err.(*pgconn.PgError)
|
||||||
|
if !ok {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle supplied error code:
|
||||||
|
// (https://www.postgresql.org/docs/10/errcodes-appendix.html)
|
||||||
|
switch pgErr.Code {
|
||||||
|
case "23505" /* unique_violation */ :
|
||||||
|
return db.ErrAlreadyExists
|
||||||
|
default:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processSQLiteError
|
||||||
|
func processSQLiteError(err error) db.Error {
|
||||||
|
// Attempt to cast as sqlite
|
||||||
|
sqliteErr, ok := err.(*sqlite.Error)
|
||||||
|
if !ok {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle supplied error code:
|
||||||
|
switch sqliteErr.Code() {
|
||||||
|
case sqlite3.SQLITE_CONSTRAINT_UNIQUE:
|
||||||
|
return db.ErrAlreadyExists
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -30,7 +30,7 @@ import (
|
||||||
|
|
||||||
type instanceDB struct {
|
type instanceDB struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
conn *bun.DB
|
conn *dbConn
|
||||||
log *logrus.Logger
|
log *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -53,7 +53,7 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int
|
||||||
|
|
||||||
count, err := q.Count(ctx)
|
count, err := q.Count(ctx)
|
||||||
|
|
||||||
return count, processErrorResponse(err)
|
return count, i.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
|
func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (int, db.Error) {
|
||||||
|
|
@ -72,7 +72,7 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) (
|
||||||
|
|
||||||
count, err := q.Count(ctx)
|
count, err := q.Count(ctx)
|
||||||
|
|
||||||
return count, processErrorResponse(err)
|
return count, i.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
|
func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (int, db.Error) {
|
||||||
|
|
@ -93,7 +93,7 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i
|
||||||
|
|
||||||
count, err := q.Count(ctx)
|
count, err := q.Count(ctx)
|
||||||
|
|
||||||
return count, processErrorResponse(err)
|
return count, i.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
|
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
|
||||||
|
|
@ -114,7 +114,7 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max
|
||||||
q = q.Limit(limit)
|
q = q.Limit(limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := i.conn.ProcessError(q.Scan(ctx))
|
||||||
|
|
||||||
return accounts, err
|
return accounts, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ import (
|
||||||
|
|
||||||
type mediaDB struct {
|
type mediaDB struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
conn *bun.DB
|
conn *dbConn
|
||||||
log *logrus.Logger
|
log *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -47,7 +47,7 @@ func (m *mediaDB) GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.M
|
||||||
q := m.newMediaQ(attachment).
|
q := m.newMediaQ(attachment).
|
||||||
Where("media_attachment.id = ?", id)
|
Where("media_attachment.id = ?", id)
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := m.conn.ProcessError(q.Scan(ctx))
|
||||||
|
|
||||||
return attachment, err
|
return attachment, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ import (
|
||||||
|
|
||||||
type mentionDB struct {
|
type mentionDB struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
conn *bun.DB
|
conn *dbConn
|
||||||
log *logrus.Logger
|
log *logrus.Logger
|
||||||
cache cache.Cache
|
cache cache.Cache
|
||||||
}
|
}
|
||||||
|
|
@ -84,7 +84,7 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio
|
||||||
q := m.newMentionQ(mention).
|
q := m.newMentionQ(mention).
|
||||||
Where("mention.id = ?", id)
|
Where("mention.id = ?", id)
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := m.conn.ProcessError(q.Scan(ctx))
|
||||||
|
|
||||||
if err == nil && mention != nil {
|
if err == nil && mention != nil {
|
||||||
m.cacheMention(id, mention)
|
m.cacheMention(id, mention)
|
||||||
|
|
@ -99,7 +99,7 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.
|
||||||
for _, i := range ids {
|
for _, i := range ids {
|
||||||
mention, err := m.GetMention(ctx, i)
|
mention, err := m.GetMention(ctx, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, processErrorResponse(err)
|
return nil, m.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
mentions = append(mentions, mention)
|
mentions = append(mentions, mention)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ import (
|
||||||
|
|
||||||
type notificationDB struct {
|
type notificationDB struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
conn *bun.DB
|
conn *dbConn
|
||||||
log *logrus.Logger
|
log *logrus.Logger
|
||||||
cache cache.Cache
|
cache cache.Cache
|
||||||
}
|
}
|
||||||
|
|
@ -84,7 +84,7 @@ func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmo
|
||||||
q := n.newNotificationQ(notification).
|
q := n.newNotificationQ(notification).
|
||||||
Where("notification.id = ?", id)
|
Where("notification.id = ?", id)
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := n.conn.ProcessError(q.Scan(ctx))
|
||||||
|
|
||||||
if err == nil && notification != nil {
|
if err == nil && notification != nil {
|
||||||
n.cacheNotification(id, notification)
|
n.cacheNotification(id, notification)
|
||||||
|
|
@ -115,7 +115,7 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
|
||||||
q = q.Limit(limit)
|
q = q.Limit(limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := n.conn.ProcessError(q.Scan(ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -125,7 +125,7 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
|
||||||
notifications := []*gtsmodel.Notification{}
|
notifications := []*gtsmodel.Notification{}
|
||||||
for _, notifID := range notifIDs {
|
for _, notifID := range notifIDs {
|
||||||
notif, err := n.GetNotification(ctx, notifID.ID)
|
notif, err := n.GetNotification(ctx, notifID.ID)
|
||||||
errP := processErrorResponse(err)
|
errP := n.conn.ProcessError(err)
|
||||||
if errP != nil {
|
if errP != nil {
|
||||||
return nil, errP
|
return nil, errP
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ import (
|
||||||
|
|
||||||
type relationshipDB struct {
|
type relationshipDB struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
conn *bun.DB
|
conn *dbConn
|
||||||
log *logrus.Logger
|
log *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -66,7 +66,7 @@ func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account
|
||||||
Where("account_id = ?", account2)
|
Where("account_id = ?", account2)
|
||||||
}
|
}
|
||||||
|
|
||||||
return exists(ctx, q)
|
return r.conn.Exists(ctx, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
|
func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
|
||||||
|
|
@ -76,7 +76,7 @@ func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2
|
||||||
Where("block.account_id = ?", account1).
|
Where("block.account_id = ?", account1).
|
||||||
Where("block.target_account_id = ?", account2)
|
Where("block.target_account_id = ?", account2)
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := r.conn.ProcessError(q.Scan(ctx))
|
||||||
|
|
||||||
return block, err
|
return block, err
|
||||||
}
|
}
|
||||||
|
|
@ -176,7 +176,7 @@ func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmode
|
||||||
Where("target_account_id = ?", targetAccount.ID).
|
Where("target_account_id = ?", targetAccount.ID).
|
||||||
Limit(1)
|
Limit(1)
|
||||||
|
|
||||||
return exists(ctx, q)
|
return r.conn.Exists(ctx, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
|
func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
|
||||||
|
|
@ -190,7 +190,7 @@ func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *g
|
||||||
Where("account_id = ?", sourceAccount.ID).
|
Where("account_id = ?", sourceAccount.ID).
|
||||||
Where("target_account_id = ?", targetAccount.ID)
|
Where("target_account_id = ?", targetAccount.ID)
|
||||||
|
|
||||||
return exists(ctx, q)
|
return r.conn.Exists(ctx, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
|
func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
|
||||||
|
|
@ -201,13 +201,13 @@ func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmod
|
||||||
// make sure account 1 follows account 2
|
// make sure account 1 follows account 2
|
||||||
f1, err := r.IsFollowing(ctx, account1, account2)
|
f1, err := r.IsFollowing(ctx, account1, account2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, processErrorResponse(err)
|
return false, r.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// make sure account 2 follows account 1
|
// make sure account 2 follows account 1
|
||||||
f2, err := r.IsFollowing(ctx, account2, account1)
|
f2, err := r.IsFollowing(ctx, account2, account1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, processErrorResponse(err)
|
return false, r.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return f1 && f2, nil
|
return f1 && f2, nil
|
||||||
|
|
@ -222,7 +222,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
|
||||||
Where("account_id = ?", originAccountID).
|
Where("account_id = ?", originAccountID).
|
||||||
Where("target_account_id = ?", targetAccountID).
|
Where("target_account_id = ?", targetAccountID).
|
||||||
Scan(ctx); err != nil {
|
Scan(ctx); err != nil {
|
||||||
return nil, processErrorResponse(err)
|
return nil, r.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a new follow to 'replace' the request with
|
// create a new follow to 'replace' the request with
|
||||||
|
|
@ -239,7 +239,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
|
||||||
Model(follow).
|
Model(follow).
|
||||||
On("CONFLICT CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).
|
On("CONFLICT CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
return nil, processErrorResponse(err)
|
return nil, r.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// now remove the follow request
|
// now remove the follow request
|
||||||
|
|
@ -249,7 +249,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountI
|
||||||
Where("account_id = ?", originAccountID).
|
Where("account_id = ?", originAccountID).
|
||||||
Where("target_account_id = ?", targetAccountID).
|
Where("target_account_id = ?", targetAccountID).
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
return nil, processErrorResponse(err)
|
return nil, r.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return follow, nil
|
return follow, nil
|
||||||
|
|
@ -261,7 +261,7 @@ func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID
|
||||||
q := r.newFollowQ(&followRequests).
|
q := r.newFollowQ(&followRequests).
|
||||||
Where("target_account_id = ?", accountID)
|
Where("target_account_id = ?", accountID)
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := r.conn.ProcessError(q.Scan(ctx))
|
||||||
|
|
||||||
return followRequests, err
|
return followRequests, err
|
||||||
}
|
}
|
||||||
|
|
@ -272,7 +272,7 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string
|
||||||
q := r.newFollowQ(&follows).
|
q := r.newFollowQ(&follows).
|
||||||
Where("account_id = ?", accountID)
|
Where("account_id = ?", accountID)
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := r.conn.ProcessError(q.Scan(ctx))
|
||||||
|
|
||||||
return follows, err
|
return follows, err
|
||||||
}
|
}
|
||||||
|
|
@ -286,7 +286,6 @@ func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID stri
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
|
func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
|
||||||
|
|
||||||
follows := []*gtsmodel.Follow{}
|
follows := []*gtsmodel.Follow{}
|
||||||
|
|
||||||
q := r.conn.
|
q := r.conn.
|
||||||
|
|
@ -306,7 +305,7 @@ func (r *relationshipDB) GetAccountFollowedBy(ctx context.Context, accountID str
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return follows, nil
|
return follows, nil
|
||||||
}
|
}
|
||||||
return nil, processErrorResponse(err)
|
return nil, r.conn.ProcessError(err)
|
||||||
}
|
}
|
||||||
return follows, nil
|
return follows, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -27,12 +27,11 @@ import (
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||||
"github.com/uptrace/bun"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type sessionDB struct {
|
type sessionDB struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
conn *bun.DB
|
conn *dbConn
|
||||||
log *logrus.Logger
|
log *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -46,7 +45,7 @@ func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, db
|
||||||
|
|
||||||
_, err := q.Exec(ctx)
|
_, err := q.Exec(ctx)
|
||||||
|
|
||||||
err = processErrorResponse(err)
|
err = s.conn.ProcessError(err)
|
||||||
|
|
||||||
return rs, err
|
return rs, err
|
||||||
}
|
}
|
||||||
|
|
@ -79,7 +78,7 @@ func (s *sessionDB) CreateSession(ctx context.Context) (*gtsmodel.RouterSession,
|
||||||
|
|
||||||
_, err = q.Exec(ctx)
|
_, err = q.Exec(ctx)
|
||||||
|
|
||||||
err = processErrorResponse(err)
|
err = s.conn.ProcessError(err)
|
||||||
|
|
||||||
return rs, err
|
return rs, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ import (
|
||||||
|
|
||||||
type statusDB struct {
|
type statusDB struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
conn *bun.DB
|
conn *dbConn
|
||||||
log *logrus.Logger
|
log *logrus.Logger
|
||||||
cache cache.Cache
|
cache cache.Cache
|
||||||
}
|
}
|
||||||
|
|
@ -121,8 +121,7 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat
|
||||||
q := s.newStatusQ(status).
|
q := s.newStatusQ(status).
|
||||||
Where("status.id = ?", id)
|
Where("status.id = ?", id)
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := s.conn.ProcessError(q.Scan(ctx))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -144,8 +143,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St
|
||||||
q := s.newStatusQ(status).
|
q := s.newStatusQ(status).
|
||||||
Where("LOWER(status.uri) = LOWER(?)", uri)
|
Where("LOWER(status.uri) = LOWER(?)", uri)
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := s.conn.ProcessError(q.Scan(ctx))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -167,8 +165,7 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.St
|
||||||
q := s.newStatusQ(status).
|
q := s.newStatusQ(status).
|
||||||
Where("LOWER(status.url) = LOWER(?)", uri)
|
Where("LOWER(status.url) = LOWER(?)", uri)
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := s.conn.ProcessError(q.Scan(ctx))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -217,7 +214,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return processErrorResponse(s.conn.RunInTx(ctx, nil, transaction))
|
return s.conn.ProcessError(s.conn.RunInTx(ctx, nil, transaction))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
|
func (s *statusDB) GetStatusParents(ctx context.Context, status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
|
||||||
|
|
@ -321,7 +318,7 @@ func (s *statusDB) IsStatusFavedBy(ctx context.Context, status *gtsmodel.Status,
|
||||||
Where("status_id = ?", status.ID).
|
Where("status_id = ?", status.ID).
|
||||||
Where("account_id = ?", accountID)
|
Where("account_id = ?", accountID)
|
||||||
|
|
||||||
return exists(ctx, q)
|
return s.conn.Exists(ctx, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||||
|
|
@ -331,7 +328,7 @@ func (s *statusDB) IsStatusRebloggedBy(ctx context.Context, status *gtsmodel.Sta
|
||||||
Where("boost_of_id = ?", status.ID).
|
Where("boost_of_id = ?", status.ID).
|
||||||
Where("account_id = ?", accountID)
|
Where("account_id = ?", accountID)
|
||||||
|
|
||||||
return exists(ctx, q)
|
return s.conn.Exists(ctx, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||||
|
|
@ -341,7 +338,7 @@ func (s *statusDB) IsStatusMutedBy(ctx context.Context, status *gtsmodel.Status,
|
||||||
Where("status_id = ?", status.ID).
|
Where("status_id = ?", status.ID).
|
||||||
Where("account_id = ?", accountID)
|
Where("account_id = ?", accountID)
|
||||||
|
|
||||||
return exists(ctx, q)
|
return s.conn.Exists(ctx, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.Status, accountID string) (bool, db.Error) {
|
||||||
|
|
@ -351,7 +348,7 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St
|
||||||
Where("status_id = ?", status.ID).
|
Where("status_id = ?", status.ID).
|
||||||
Where("account_id = ?", accountID)
|
Where("account_id = ?", accountID)
|
||||||
|
|
||||||
return exists(ctx, q)
|
return s.conn.Exists(ctx, q)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
|
func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
|
||||||
|
|
@ -360,7 +357,7 @@ func (s *statusDB) GetStatusFaves(ctx context.Context, status *gtsmodel.Status)
|
||||||
q := s.newFaveQ(&faves).
|
q := s.newFaveQ(&faves).
|
||||||
Where("status_id = ?", status.ID)
|
Where("status_id = ?", status.ID)
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := s.conn.ProcessError(q.Scan(ctx))
|
||||||
return faves, err
|
return faves, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -370,6 +367,6 @@ func (s *statusDB) GetStatusReblogs(ctx context.Context, status *gtsmodel.Status
|
||||||
q := s.newStatusQ(&reblogs).
|
q := s.newStatusQ(&reblogs).
|
||||||
Where("boost_of_id = ?", status.ID)
|
Where("boost_of_id = ?", status.ID)
|
||||||
|
|
||||||
err := processErrorResponse(q.Scan(ctx))
|
err := s.conn.ProcessError(q.Scan(ctx))
|
||||||
return reblogs, err
|
return reblogs, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ import (
|
||||||
|
|
||||||
type timelineDB struct {
|
type timelineDB struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
conn *bun.DB
|
conn *dbConn
|
||||||
log *logrus.Logger
|
log *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -86,7 +86,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
|
||||||
|
|
||||||
q = q.WhereGroup(" AND ", whereGroup)
|
q = q.WhereGroup(" AND ", whereGroup)
|
||||||
|
|
||||||
return statuses, processErrorResponse(q.Scan(ctx))
|
return statuses, t.conn.ProcessError(q.Scan(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
|
func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
|
||||||
|
|
@ -121,13 +121,12 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, accountID string, ma
|
||||||
q = q.Limit(limit)
|
q = q.Limit(limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
return statuses, processErrorResponse(q.Scan(ctx))
|
return statuses, t.conn.ProcessError(q.Scan(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
|
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
|
||||||
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
|
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
|
||||||
func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
|
func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
|
||||||
|
|
||||||
faves := []*gtsmodel.StatusFave{}
|
faves := []*gtsmodel.StatusFave{}
|
||||||
|
|
||||||
fq := t.conn.
|
fq := t.conn.
|
||||||
|
|
|
||||||
|
|
@ -19,64 +19,9 @@
|
||||||
package bundb
|
package bundb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"database/sql"
|
|
||||||
|
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
)
|
)
|
||||||
|
|
||||||
// processErrorResponse parses the given error and returns an appropriate DBError.
|
|
||||||
func processErrorResponse(err error) db.Error {
|
|
||||||
switch err {
|
|
||||||
case nil:
|
|
||||||
return nil
|
|
||||||
case sql.ErrNoRows:
|
|
||||||
return db.ErrNoEntries
|
|
||||||
default:
|
|
||||||
if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
|
|
||||||
return db.ErrAlreadyExists
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func exists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
|
|
||||||
count, err := q.Count(ctx)
|
|
||||||
|
|
||||||
exists := count != 0
|
|
||||||
|
|
||||||
err = processErrorResponse(err)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if err == db.ErrNoEntries {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return exists, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func notExists(ctx context.Context, q *bun.SelectQuery) (bool, db.Error) {
|
|
||||||
count, err := q.Count(ctx)
|
|
||||||
|
|
||||||
notExists := count == 0
|
|
||||||
|
|
||||||
err = processErrorResponse(err)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if err == db.ErrNoEntries {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return notExists, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// whereEmptyOrNull is a convenience function to return a bun WhereGroup that specifies
|
// whereEmptyOrNull is a convenience function to return a bun WhereGroup that specifies
|
||||||
// that the given column should be EITHER an empty string OR null.
|
// that the given column should be EITHER an empty string OR null.
|
||||||
//
|
//
|
||||||
|
|
|
||||||
1
vendor/modules.txt
vendored
1
vendor/modules.txt
vendored
|
|
@ -302,6 +302,7 @@ github.com/h2non/filetype/types
|
||||||
# github.com/jackc/chunkreader/v2 v2.0.1
|
# github.com/jackc/chunkreader/v2 v2.0.1
|
||||||
github.com/jackc/chunkreader/v2
|
github.com/jackc/chunkreader/v2
|
||||||
# github.com/jackc/pgconn v1.10.0
|
# github.com/jackc/pgconn v1.10.0
|
||||||
|
## explicit
|
||||||
github.com/jackc/pgconn
|
github.com/jackc/pgconn
|
||||||
github.com/jackc/pgconn/internal/ctxwatch
|
github.com/jackc/pgconn/internal/ctxwatch
|
||||||
github.com/jackc/pgconn/stmtcache
|
github.com/jackc/pgconn/stmtcache
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue