continue moving db stuff around

This commit is contained in:
tsmethurst 2021-08-17 15:23:28 +02:00
commit 15153ee0c8
72 changed files with 923 additions and 614 deletions

View file

@ -1,62 +1,100 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"errors"
"fmt"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) GetAccountHeader(header *gtsmodel.MediaAttachment, accountID string) error {
acct := &gtsmodel.Account{}
if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
if acct.HeaderMediaAttachmentID == "" {
return db.ErrNoEntries{}
}
if err := ps.conn.Model(header).Where("id = ?", acct.HeaderMediaAttachmentID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
type accountDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (ps *postgresService) GetAccountAvatar(avatar *gtsmodel.MediaAttachment, accountID string) error {
acct := &gtsmodel.Account{}
if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
if acct.AvatarMediaAttachmentID == "" {
return db.ErrNoEntries{}
}
if err := ps.conn.Model(avatar).Where("id = ?", acct.AvatarMediaAttachmentID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query {
return a.conn.Model(account).
Relation("AvatarMediaAttachment").
Relation("HeaderMediaAttachment")
}
func (ps *postgresService) GetAccountLastStatus(accountID string, status *gtsmodel.Status) error {
if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
func (a *accountDB) processResponse(account *gtsmodel.Account, err error) (*gtsmodel.Account, db.DBError) {
switch err {
case pg.ErrNoRows:
return nil, db.ErrNoEntries
case nil:
return account, nil
default:
return nil, err
}
}
func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.DBError) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account).
Where("account.id = ?", id)
return a.processResponse(account, q.Select())
}
func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.DBError) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account).
Where("account.uri = ?", uri)
return a.processResponse(account, q.Select())
}
func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.DBError) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account)
if domain == "" {
q = q.
Where("account.username = ?", domain).
Where("account.domain = ?", domain)
} else {
q = q.
Where("account.username = ?", domain).
Where("? IS NULL", pg.Ident("domain"))
}
return a.processResponse(account, q.Select())
}
func (a *accountDB) GetAccountLastStatus(accountID string, status *gtsmodel.Status) db.DBError {
if err := a.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
return db.ErrNoEntries
}
return err
}
@ -64,7 +102,7 @@ func (ps *postgresService) GetAccountLastStatus(accountID string, status *gtsmod
}
func (ps *postgresService) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) error {
func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.DBError {
if mediaAttachment.Avatar && mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar")
}
@ -79,47 +117,47 @@ func (ps *postgresService) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.Me
}
// TODO: there are probably more side effects here that need to be handled
if _, err := ps.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
return err
}
if _, err := ps.conn.Model(&gtsmodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
if _, err := a.conn.Model(&gtsmodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
return err
}
return nil
}
func (ps *postgresService) GetAccountByUserID(userID string, account *gtsmodel.Account) error {
func (a *accountDB) GetAccountByUserID(userID string, account *gtsmodel.Account) db.DBError {
user := &gtsmodel.User{
ID: userID,
}
if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
if err := a.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
return db.ErrNoEntries
}
return err
}
if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
if err := a.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
return db.ErrNoEntries
}
return err
}
return nil
}
func (ps *postgresService) GetLocalAccountByUsername(username string, account *gtsmodel.Account) error {
if err := ps.conn.Model(account).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select(); err != nil {
func (a *accountDB) GetLocalAccountByUsername(username string, account *gtsmodel.Account) db.DBError {
if err := a.conn.Model(account).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
return db.ErrNoEntries
}
return err
}
return nil
}
func (ps *postgresService) GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) error {
if err := ps.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil {
func (a *accountDB) GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) db.DBError {
if err := a.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
}
@ -128,8 +166,8 @@ func (ps *postgresService) GetAccountFollowRequests(accountID string, followRequ
return nil
}
func (ps *postgresService) GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) error {
if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
func (a *accountDB) GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) db.DBError {
if err := a.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
}
@ -138,9 +176,13 @@ func (ps *postgresService) GetAccountFollowing(accountID string, following *[]gt
return nil
}
func (ps *postgresService) GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) error {
func (a *accountDB) CountAccountFollowing(accountID string, localOnly bool) (int, db.DBError) {
return a.conn.Model(&[]*gtsmodel.Follow{}).Where("account_id = ?", accountID).Count()
}
q := ps.conn.Model(followers)
func (a *accountDB) GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) db.DBError {
q := a.conn.Model(followers)
if localOnly {
// for local accounts let's get where domain is null OR where domain is an empty string, just to be safe
@ -168,8 +210,12 @@ func (ps *postgresService) GetAccountFollowers(accountID string, followers *[]gt
return nil
}
func (ps *postgresService) GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) error {
if err := ps.conn.Model(faves).Where("account_id = ?", accountID).Select(); err != nil {
func (a *accountDB) CountAccountFollowers(accountID string, localOnly bool) (int, db.DBError) {
return a.conn.Model(&[]*gtsmodel.Follow{}).Where("target_account_id = ?", accountID).Count()
}
func (a *accountDB) GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) db.DBError {
if err := a.conn.Model(faves).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
}
@ -178,22 +224,15 @@ func (ps *postgresService) GetAccountFaves(accountID string, faves *[]gtsmodel.S
return nil
}
func (ps *postgresService) GetAccountStatusesCount(accountID string) (int, error) {
count, err := ps.conn.Model(&gtsmodel.Status{}).Where("account_id = ?", accountID).Count()
if err != nil {
if err == pg.ErrNoRows {
return 0, nil
}
return 0, err
}
return count, nil
func (a *accountDB) CountAccountStatuses(accountID string) (int, db.DBError) {
return a.conn.Model(&gtsmodel.Status{}).Where("account_id = ?", accountID).Count()
}
func (ps *postgresService) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error) {
ps.log.Debugf("getting statuses for account %s", accountID)
func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.DBError) {
a.log.Debugf("getting statuses for account %s", accountID)
statuses := []*gtsmodel.Status{}
q := ps.conn.Model(&statuses).Order("id DESC")
q := a.conn.Model(&statuses).Order("id DESC")
if accountID != "" {
q = q.Where("account_id = ?", accountID)
}
@ -222,15 +261,57 @@ func (ps *postgresService) GetAccountStatuses(accountID string, limit int, exclu
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
ps.log.Debugf("returning statuses for account %s", accountID)
a.log.Debugf("returning statuses for account %s", accountID)
return statuses, nil
}
func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.DBError) {
blocks := []*gtsmodel.Block{}
fq := a.conn.Model(&blocks).
Where("block.account_id = ?", accountID).
Relation("TargetAccount").
Order("block.id DESC")
if maxID != "" {
fq = fq.Where("block.id < ?", maxID)
}
if sinceID != "" {
fq = fq.Where("block.id > ?", sinceID)
}
if limit > 0 {
fq = fq.Limit(limit)
}
err := fq.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
}
if len(blocks) == 0 {
return nil, "", "", db.ErrNoEntries
}
accounts := []*gtsmodel.Account{}
for _, b := range blocks {
accounts = append(accounts, b.TargetAccount)
}
nextMaxID := blocks[len(blocks)-1].ID
prevMinID := blocks[0].ID
return accounts, nextMaxID, prevMinID, nil
}

View file

@ -1,6 +1,25 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"crypto/rand"
"crypto/rsa"
"fmt"
@ -10,17 +29,27 @@ import (
"time"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
"golang.org/x/crypto/bcrypt"
)
func (ps *postgresService) IsUsernameAvailable(username string) error {
type adminDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (a *adminDB) IsUsernameAvailable(username string) db.DBError {
// if no error we fail because it means we found something
// if error but it's not pg.ErrNoRows then we fail
// if err is pg.ErrNoRows we're good, we found nothing so continue
if err := ps.conn.Model(&gtsmodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
if err := a.conn.Model(&gtsmodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
return fmt.Errorf("username %s already in use", username)
} else if err != pg.ErrNoRows {
return fmt.Errorf("db error: %s", err)
@ -28,7 +57,7 @@ func (ps *postgresService) IsUsernameAvailable(username string) error {
return nil
}
func (ps *postgresService) IsEmailAvailable(email string) error {
func (a *adminDB) IsEmailAvailable(email string) db.DBError {
// parse the domain from the email
m, err := mail.ParseAddress(email)
if err != nil {
@ -37,7 +66,7 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked
if err := ps.conn.Model(&gtsmodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
if err := a.conn.Model(&gtsmodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email domain %s is blocked", domain)
} else if err != pg.ErrNoRows {
@ -46,7 +75,7 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
}
// check if this email is associated with a user already
if err := ps.conn.Model(&gtsmodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
if err := a.conn.Model(&gtsmodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email %s already in use", email)
} else if err != pg.ErrNoRows {
@ -56,16 +85,16 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
return nil
}
func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, error) {
func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.DBError) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
ps.log.Errorf("error creating new rsa key: %s", err)
a.log.Errorf("error creating new rsa key: %s", err)
return nil, err
}
// if something went wrong while creating a user, we might already have an account, so check here first...
a := &gtsmodel.Account{}
err = ps.conn.Model(a).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
acct := &gtsmodel.Account{}
err = a.conn.Model(acct).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
if err != nil {
// there's been an actual error
if err != pg.ErrNoRows {
@ -73,13 +102,13 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr
}
// we just don't have an account yet create one
newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host)
newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
newAccountID, err := id.NewRandomULID()
if err != nil {
return nil, err
}
a = &gtsmodel.Account{
acct = &gtsmodel.Account{
ID: newAccountID,
Username: username,
DisplayName: username,
@ -96,7 +125,7 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
if _, err = ps.conn.Model(a).Insert(); err != nil {
if _, err = a.conn.Model(acct).Insert(); err != nil {
return nil, err
}
}
@ -113,7 +142,7 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr
u := &gtsmodel.User{
ID: newUserID,
AccountID: a.ID,
AccountID: acct.ID,
EncryptedPassword: string(pw),
SignUpIP: signUpIP.To4(),
Locale: locale,
@ -132,18 +161,18 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr
u.Moderator = true
}
if _, err = ps.conn.Model(u).Insert(); err != nil {
if _, err = a.conn.Model(u).Insert(); err != nil {
return nil, err
}
return u, nil
}
func (ps *postgresService) CreateInstanceAccount() error {
username := ps.config.Host
func (a *adminDB) CreateInstanceAccount() db.DBError {
username := a.config.Host
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
ps.log.Errorf("error creating new rsa key: %s", err)
a.log.Errorf("error creating new rsa key: %s", err)
return err
}
@ -152,10 +181,10 @@ func (ps *postgresService) CreateInstanceAccount() error {
return err
}
newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host)
a := &gtsmodel.Account{
newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
acct := &gtsmodel.Account{
ID: aID,
Username: ps.config.Host,
Username: a.config.Host,
DisplayName: username,
URL: newAccountURIs.UserURL,
PrivateKey: key,
@ -169,19 +198,19 @@ func (ps *postgresService) CreateInstanceAccount() error {
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
inserted, err := ps.conn.Model(a).Where("username = ?", username).SelectOrInsert()
inserted, err := a.conn.Model(acct).Where("username = ?", username).SelectOrInsert()
if err != nil {
return err
}
if inserted {
ps.log.Infof("created instance account %s with id %s", username, a.ID)
a.log.Infof("created instance account %s with id %s", username, acct.ID)
} else {
ps.log.Infof("instance account %s already exists with id %s", username, a.ID)
a.log.Infof("instance account %s already exists with id %s", username, acct.ID)
}
return nil
}
func (ps *postgresService) CreateInstanceInstance() error {
func (a *adminDB) CreateInstanceInstance() db.DBError {
iID, err := id.NewRandomULID()
if err != nil {
return err
@ -189,18 +218,18 @@ func (ps *postgresService) CreateInstanceInstance() error {
i := &gtsmodel.Instance{
ID: iID,
Domain: ps.config.Host,
Title: ps.config.Host,
URI: fmt.Sprintf("%s://%s", ps.config.Protocol, ps.config.Host),
Domain: a.config.Host,
Title: a.config.Host,
URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host),
}
inserted, err := ps.conn.Model(i).Where("domain = ?", ps.config.Host).SelectOrInsert()
inserted, err := a.conn.Model(i).Where("domain = ?", a.config.Host).SelectOrInsert()
if err != nil {
return err
}
if inserted {
ps.log.Infof("created instance instance %s with id %s", ps.config.Host, i.ID)
a.log.Infof("created instance instance %s with id %s", a.config.Host, i.ID)
} else {
ps.log.Infof("instance instance %s already exists with id %s", ps.config.Host, i.ID)
a.log.Infof("instance instance %s already exists with id %s", a.config.Host, i.ID)
}
return nil
}

View file

@ -1,25 +1,55 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"errors"
"fmt"
"strings"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
func (ps *postgresService) Put(i interface{}) error {
_, err := ps.conn.Model(i).Insert(i)
type basicDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (b *basicDB) Put(i interface{}) db.DBError {
_, err := b.conn.Model(i).Insert(i)
if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists{}
return db.ErrAlreadyExists
}
return err
}
func (ps *postgresService) GetByID(id string, i interface{}) error {
if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil {
func (b *basicDB) GetByID(id string, i interface{}) db.DBError {
if err := b.conn.Model(i).Where("id = ?", id).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
return db.ErrNoEntries
}
return err
@ -27,12 +57,12 @@ func (ps *postgresService) GetByID(id string, i interface{}) error {
return nil
}
func (ps *postgresService) GetWhere(where []db.Where, i interface{}) error {
func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.DBError {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := ps.conn.Model(i)
q := b.conn.Model(i)
for _, w := range where {
if w.Value == nil {
@ -48,25 +78,25 @@ func (ps *postgresService) GetWhere(where []db.Where, i interface{}) error {
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
return db.ErrNoEntries
}
return err
}
return nil
}
func (ps *postgresService) GetAll(i interface{}) error {
if err := ps.conn.Model(i).Select(); err != nil {
func (b *basicDB) GetAll(i interface{}) db.DBError {
if err := b.conn.Model(i).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
return db.ErrNoEntries
}
return err
}
return nil
}
func (ps *postgresService) DeleteByID(id string, i interface{}) error {
if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
func (b *basicDB) DeleteByID(id string, i interface{}) db.DBError {
if _, err := b.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
// if there are no rows *anyway* then that's fine
// just return err if there's an actual error
if err != pg.ErrNoRows {
@ -76,12 +106,12 @@ func (ps *postgresService) DeleteByID(id string, i interface{}) error {
return nil
}
func (ps *postgresService) DeleteWhere(where []db.Where, i interface{}) error {
func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.DBError {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := ps.conn.Model(i)
q := b.conn.Model(i)
for _, w := range where {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
@ -95,3 +125,76 @@ func (ps *postgresService) DeleteWhere(where []db.Where, i interface{}) error {
}
return nil
}
func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.DBError {
if _, err := b.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) UpdateByID(id string, i interface{}) db.DBError {
if _, err := b.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.DBError {
_, err := b.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
return err
}
func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.DBError {
q := b.conn.Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", pg.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
}
}
q = q.Set("? = ?", pg.Safe(key), value)
_, err := q.Update()
return err
}
func (b *basicDB) CreateTable(i interface{}) db.DBError {
return b.conn.Model(i).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
}
func (b *basicDB) DropTable(i interface{}) db.DBError {
return b.conn.Model(i).DropTable(&orm.DropTableOptions{
IfExists: true,
})
}
func (b *basicDB) IsHealthy(ctx context.Context) db.DBError {
return b.conn.Ping(ctx)
}
func (b *basicDB) Stop(ctx context.Context) db.DBError {
b.log.Info("closing db connection")
if err := b.conn.Close(); err != nil {
// only cancel if there's a problem closing the db
b.cancel()
return err
}
return nil
}

View file

@ -19,15 +19,26 @@
package pg
import (
"context"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) {
q := ps.conn.Model(&[]*gtsmodel.Account{})
type instanceDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
if domain == ps.config.Host {
func (i *instanceDB) GetUserCountForInstance(domain string) (int, db.DBError) {
q := i.conn.Model(&[]*gtsmodel.Account{})
if domain == i.config.Host {
// if the domain is *this* domain, just count where the domain field is null
q = q.Where("? IS NULL", pg.Ident("domain"))
} else {
@ -40,10 +51,10 @@ func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) {
return q.Count()
}
func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error) {
q := ps.conn.Model(&[]*gtsmodel.Status{})
func (i *instanceDB) GetStatusCountForInstance(domain string) (int, db.DBError) {
q := i.conn.Model(&[]*gtsmodel.Status{})
if domain == ps.config.Host {
if domain == i.config.Host {
// if the domain is *this* domain, just count where local is true
q = q.Where("local = ?", true)
} else {
@ -55,10 +66,10 @@ func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error)
return q.Count()
}
func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error) {
q := ps.conn.Model(&[]*gtsmodel.Instance{})
func (i *instanceDB) GetDomainCountForInstance(domain string) (int, db.DBError) {
q := i.conn.Model(&[]*gtsmodel.Instance{})
if domain == ps.config.Host {
if domain == i.config.Host {
// if the domain is *this* domain, just count other instances it knows about
// exclude domains that are blocked
q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at"))
@ -70,12 +81,12 @@ func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error)
return q.Count()
}
func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error) {
ps.log.Debug("GetAccountsForInstance")
func (i *instanceDB) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.DBError) {
i.log.Debug("GetAccountsForInstance")
accounts := []*gtsmodel.Account{}
q := ps.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
if maxID != "" {
q = q.Where("id < ?", maxID)
@ -88,13 +99,13 @@ func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, l
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return nil, err
}
if len(accounts) == 0 {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return accounts, nil

View file

@ -24,44 +24,30 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) {
blocks := []*gtsmodel.Block{}
func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.DBError) {
notifications := []*gtsmodel.Notification{}
fq := ps.conn.Model(&blocks).
Where("block.account_id = ?", accountID).
Relation("TargetAccount").
Order("block.id DESC")
q := ps.conn.Model(&notifications).Where("target_account_id = ?", accountID)
if maxID != "" {
fq = fq.Where("block.id < ?", maxID)
q = q.Where("id < ?", maxID)
}
if sinceID != "" {
fq = fq.Where("block.id > ?", sinceID)
q = q.Where("id > ?", sinceID)
}
if limit > 0 {
fq = fq.Limit(limit)
if limit != 0 {
q = q.Limit(limit)
}
err := fq.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, "", "", db.ErrNoEntries{}
q = q.Order("created_at DESC")
if err := q.Select(); err != nil {
if err != pg.ErrNoRows {
return nil, err
}
return nil, "", "", err
}
if len(blocks) == 0 {
return nil, "", "", db.ErrNoEntries{}
}
accounts := []*gtsmodel.Account{}
for _, b := range blocks {
accounts = append(accounts, b.TargetAccount)
}
nextMaxID := blocks[len(blocks)-1].ID
prevMinID := blocks[0].ID
return accounts, nextMaxID, prevMinID, nil
return notifications, nil
}

View file

@ -31,7 +31,6 @@ import (
"github.com/go-pg/pg/extra/pgdebug"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
@ -41,6 +40,14 @@ import (
// postgresService satisfies the DB interface
type postgresService struct {
db.Account
db.Admin
db.Basic
db.Instance
db.Notification
db.Relationship
db.Status
db.Timeline
config *config.Config
conn *pg.DB
log *logrus.Logger
@ -85,6 +92,48 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
log.Infof("connected to postgres version: %s", version)
ps := &postgresService{
Account: &accountDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Admin: &adminDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Basic: &basicDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Instance: &instanceDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Relationship: &relationshipDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Status: &statusDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Timeline: &timelineDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
config: c,
conn: conn,
log: log,
@ -193,89 +242,6 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
return options, nil
}
/*
BASIC DB FUNCTIONALITY
*/
func (ps *postgresService) CreateTable(i interface{}) error {
return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
}
func (ps *postgresService) DropTable(i interface{}) error {
return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
IfExists: true,
})
}
func (ps *postgresService) Stop(ctx context.Context) error {
ps.log.Info("closing db connection")
if err := ps.conn.Close(); err != nil {
// only cancel if there's a problem closing the db
ps.cancel()
return err
}
return nil
}
func (ps *postgresService) IsHealthy(ctx context.Context) error {
return ps.conn.Ping(ctx)
}
func (ps *postgresService) CreateSchema(ctx context.Context) error {
models := []interface{}{
(*gtsmodel.Account)(nil),
(*gtsmodel.Status)(nil),
(*gtsmodel.User)(nil),
}
ps.log.Info("creating db schema")
for _, model := range models {
err := ps.conn.Model(model).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
if err != nil {
return err
}
}
ps.log.Info("db schema created")
return nil
}
/*
HANDY SHORTCUTS
*/
func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error) {
notifications := []*gtsmodel.Notification{}
q := ps.conn.Model(&notifications).Where("target_account_id = ?", accountID)
if maxID != "" {
q = q.Where("id < ?", maxID)
}
if sinceID != "" {
q = q.Where("id > ?", sinceID)
}
if limit != 0 {
q = q.Limit(limit)
}
q = q.Order("created_at DESC")
if err := q.Select(); err != nil {
if err != pg.ErrNoRows {
return nil, err
}
}
return notifications, nil
}
/*
CONVERSION FUNCTIONS
*/

47
internal/db/pg/pg_test.go Normal file
View file

@ -0,0 +1,47 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
type PGStandardTestSuite struct {
// standard suite interfaces
suite.Suite
config *config.Config
db db.DB
log *logrus.Logger
// standard suite models
testTokens map[string]*oauth.Token
testClients map[string]*oauth.Client
testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account
testAttachments map[string]*gtsmodel.MediaAttachment
testStatuses map[string]*gtsmodel.Status
testTags map[string]*gtsmodel.Tag
testMentions map[string]*gtsmodel.Mention
}

View file

@ -1,17 +1,45 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"fmt"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) Blocked(account1 string, account2 string) (bool, error) {
type relationshipDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (r *relationshipDB) Blocked(account1 string, account2 string) (bool, db.DBError) {
// TODO: check domain blocks as well
var blocked bool
if err := ps.conn.Model(&gtsmodel.Block{}).
if err := r.conn.Model(&gtsmodel.Block{}).
Where("account_id = ?", account1).Where("target_account_id = ?", account2).
WhereOr("target_account_id = ?", account1).Where("account_id = ?", account2).
Select(); err != nil {
@ -25,83 +53,83 @@ func (ps *postgresService) Blocked(account1 string, account2 string) (bool, erro
return blocked, nil
}
func (ps *postgresService) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) {
r := &gtsmodel.Relationship{
func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.DBError) {
rel := &gtsmodel.Relationship{
ID: targetAccount,
}
// check if the requesting account follows the target account
follow := &gtsmodel.Follow{}
if err := ps.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
if err := r.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
if err != pg.ErrNoRows {
// a proper error
return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
}
// no follow exists so these are all false
r.Following = false
r.ShowingReblogs = false
r.Notifying = false
rel.Following = false
rel.ShowingReblogs = false
rel.Notifying = false
} else {
// follow exists so we can fill these fields out...
r.Following = true
r.ShowingReblogs = follow.ShowReblogs
r.Notifying = follow.Notify
rel.Following = true
rel.ShowingReblogs = follow.ShowReblogs
rel.Notifying = follow.Notify
}
// check if the target account follows the requesting account
followedBy, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
followedBy, err := r.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
}
r.FollowedBy = followedBy
rel.FollowedBy = followedBy
// check if the requesting account blocks the target account
blocking, err := ps.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
blocking, err := r.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
}
r.Blocking = blocking
rel.Blocking = blocking
// check if the target account blocks the requesting account
blockedBy, err := ps.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
blockedBy, err := r.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
r.BlockedBy = blockedBy
rel.BlockedBy = blockedBy
// check if there's a pending following request from requesting account to target account
requested, err := ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
requested, err := r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
r.Requested = requested
rel.Requested = requested
return r, nil
return rel, nil
}
func (ps *postgresService) Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) {
func (r *relationshipDB) Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.DBError) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
return ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
return r.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
}
func (ps *postgresService) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) {
func (r *relationshipDB) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.DBError) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
return ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
return r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
}
func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error) {
func (r *relationshipDB) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.DBError) {
if account1 == nil || account2 == nil {
return false, nil
}
// make sure account 1 follows account 2
f1, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists()
f1, err := r.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists()
if err != nil {
if err == pg.ErrNoRows {
return false, nil
@ -110,7 +138,7 @@ func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmode
}
// make sure account 2 follows account 1
f2, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", account2.ID).Where("target_account_id = ?", account1.ID).Exists()
f2, err := r.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", account2.ID).Where("target_account_id = ?", account1.ID).Exists()
if err != nil {
if err == pg.ErrNoRows {
return false, nil
@ -121,12 +149,12 @@ func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmode
return f1 && f2, nil
}
func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.DBError) {
// make sure the original follow request exists
fr := &gtsmodel.FollowRequest{}
if err := ps.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
if err := r.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
if err == pg.ErrMultiRows {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return nil, err
}
@ -140,12 +168,12 @@ func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAcc
}
// if the follow already exists, just update the URI -- we don't need to do anything else
if _, err := ps.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
if _, err := r.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
return nil, err
}
// now remove the follow request
if _, err := ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
if _, err := r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
return nil, err
}

View file

@ -20,39 +20,90 @@ package pg
import (
"container/list"
"context"
"errors"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) {
type statusDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (s *statusDB) newStatusQ(status *gtsmodel.Status) *orm.Query {
return s.conn.Model(status).
Relation("Account").
Relation("InReplyTo").
Relation("InReplyToAccount").
Relation("BoostOf").
Relation("BoostOfAccount").
Relation("CreatedWithApplication")
}
func (s *statusDB) processResponse(status *gtsmodel.Status, err error) (*gtsmodel.Status, db.DBError) {
switch err {
case pg.ErrNoRows:
return nil, db.ErrNoEntries
case nil:
return status, nil
default:
return nil, err
}
}
func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.DBError) {
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("status.id = ?", id)
return s.processResponse(status, q.Select())
}
func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.DBError) {
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("LOWER(status.uri) = LOWER(?)", uri)
return s.processResponse(status, q.Select())
}
func (s *statusDB) StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.DBError) {
parents := []*gtsmodel.Status{}
ps.statusParent(status, &parents, onlyDirect)
s.statusParent(status, &parents, onlyDirect)
return parents, nil
}
func (ps *postgresService) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
if status.InReplyToID == "" {
return
}
parentStatus := &gtsmodel.Status{}
if err := ps.conn.Model(parentStatus).Where("id = ?", status.InReplyToID).Select(); err == nil {
if err := s.conn.Model(parentStatus).Where("id = ?", status.InReplyToID).Select(); err == nil {
*foundStatuses = append(*foundStatuses, parentStatus)
}
if onlyDirect {
return
}
ps.statusParent(parentStatus, foundStatuses, false)
s.statusParent(parentStatus, foundStatuses, false)
}
func (ps *postgresService) StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) {
func (s *statusDB) StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.DBError) {
foundStatuses := &list.List{}
foundStatuses.PushFront(status)
ps.statusChildren(status, foundStatuses, onlyDirect, minID)
s.statusChildren(status, foundStatuses, onlyDirect, minID)
children := []*gtsmodel.Status{}
for e := foundStatuses.Front(); e != nil; e = e.Next() {
@ -70,10 +121,10 @@ func (ps *postgresService) StatusChildren(status *gtsmodel.Status, onlyDirect bo
return children, nil
}
func (ps *postgresService) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
immediateChildren := []*gtsmodel.Status{}
q := ps.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
if minID != "" {
q = q.Where("status.id > ?", minID)
}
@ -100,43 +151,43 @@ func (ps *postgresService) statusChildren(status *gtsmodel.Status, foundStatuses
if onlyDirect {
return
}
ps.statusChildren(child, foundStatuses, false, minID)
s.statusChildren(child, foundStatuses, false, minID)
}
}
func (ps *postgresService) GetReplyCountForStatus(status *gtsmodel.Status) (int, error) {
return ps.conn.Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
func (s *statusDB) GetReplyCountForStatus(status *gtsmodel.Status) (int, db.DBError) {
return s.conn.Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
}
func (ps *postgresService) GetReblogCountForStatus(status *gtsmodel.Status) (int, error) {
return ps.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
func (s *statusDB) GetReblogCountForStatus(status *gtsmodel.Status) (int, db.DBError) {
return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
}
func (ps *postgresService) GetFaveCountForStatus(status *gtsmodel.Status) (int, error) {
return ps.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
func (s *statusDB) GetFaveCountForStatus(status *gtsmodel.Status) (int, db.DBError) {
return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
}
func (ps *postgresService) StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, error) {
return ps.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
func (s *statusDB) StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (ps *postgresService) StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, error) {
return ps.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
func (s *statusDB) StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (ps *postgresService) StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, error) {
return ps.conn.Model(&gtsmodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
func (s *statusDB) StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
return s.conn.Model(&gtsmodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (ps *postgresService) StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, error) {
return ps.conn.Model(&gtsmodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
func (s *statusDB) StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
return s.conn.Model(&gtsmodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) {
func (s *statusDB) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, db.DBError) {
accounts := []*gtsmodel.Account{}
faves := []*gtsmodel.StatusFave{}
if err := ps.conn.Model(&faves).Where("status_id = ?", status.ID).Select(); err != nil {
if err := s.conn.Model(&faves).Where("status_id = ?", status.ID).Select(); err != nil {
if err == pg.ErrNoRows {
return accounts, nil // no rows just means nobody has faved this status, so that's fine
}
@ -145,7 +196,7 @@ func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.
for _, f := range faves {
acc := &gtsmodel.Account{}
if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
if err := s.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
if err == pg.ErrNoRows {
continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it
}
@ -156,11 +207,11 @@ func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.
return accounts, nil
}
func (ps *postgresService) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) {
func (s *statusDB) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, db.DBError) {
accounts := []*gtsmodel.Account{}
boosts := []*gtsmodel.Status{}
if err := ps.conn.Model(&boosts).Where("boost_of_id = ?", status.ID).Select(); err != nil {
if err := s.conn.Model(&boosts).Where("boost_of_id = ?", status.ID).Select(); err != nil {
if err == pg.ErrNoRows {
return accounts, nil // no rows just means nobody has boosted this status, so that's fine
}
@ -169,7 +220,7 @@ func (ps *postgresService) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmode
for _, f := range boosts {
acc := &gtsmodel.Account{}
if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
if err := s.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
if err == pg.ErrNoRows {
continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it
}

View file

@ -0,0 +1,86 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
import (
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type StatusTestSuite struct {
PGStandardTestSuite
}
func (suite *PGStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts()
suite.testAttachments = testrig.NewTestAttachments()
suite.testStatuses = testrig.NewTestStatuses()
suite.testTags = testrig.NewTestTags()
suite.testMentions = testrig.NewTestMentions()
}
func (suite *PGStandardTestSuite) SetupTest() {
suite.config = testrig.NewTestConfig()
suite.db = testrig.NewTestDB()
suite.log = testrig.NewTestLog()
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
func (suite *PGStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
}
func (suite *PGStandardTestSuite) TestGetStatusByID() {
status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.Nil(status.BoostOf)
suite.Nil(status.BoostOfAccount)
suite.Nil(status.InReplyTo)
suite.Nil(status.InReplyToAccount)
}
func (suite *PGStandardTestSuite) TestGetStatusByURI() {
status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.Nil(status.BoostOf)
suite.Nil(status.BoostOfAccount)
suite.Nil(status.InReplyTo)
suite.Nil(status.InReplyToAccount)
}
func TestStatusTestSuite(t *testing.T) {
suite.Run(t, new(PGStandardTestSuite))
}

View file

@ -19,16 +19,26 @@
package pg
import (
"context"
"sort"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
type timelineDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (t *timelineDB) GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.DBError) {
statuses := []*gtsmodel.Status{}
q := ps.conn.Model(&statuses)
q := t.conn.Model(&statuses)
q = q.ColumnExpr("status.*").
// Find out who accountID follows.
@ -74,22 +84,22 @@ func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID str
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return statuses, nil
}
func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
func (t *timelineDB) GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.DBError) {
statuses := []*gtsmodel.Status{}
q := ps.conn.Model(&statuses).
q := t.conn.Model(&statuses).
Where("visibility = ?", gtsmodel.VisibilityPublic).
Where("? IS NULL", pg.Ident("in_reply_to_id")).
Where("? IS NULL", pg.Ident("in_reply_to_uri")).
@ -119,13 +129,13 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return statuses, nil
@ -133,11 +143,11 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error) {
func (t *timelineDB) GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.DBError) {
faves := []*gtsmodel.StatusFave{}
fq := ps.conn.Model(&faves).
fq := t.conn.Model(&faves).
Where("account_id = ?", accountID).
Order("id DESC")
@ -156,13 +166,13 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st
err := fq.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, "", "", db.ErrNoEntries{}
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
}
if len(faves) == 0 {
return nil, "", "", db.ErrNoEntries{}
return nil, "", "", db.ErrNoEntries
}
// map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID
@ -175,16 +185,16 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st
}
statuses := []*gtsmodel.Status{}
err = ps.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, "", "", db.ErrNoEntries{}
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
}
if len(statuses) == 0 {
return nil, "", "", db.ErrNoEntries{}
return nil, "", "", db.ErrNoEntries
}
// arrange statuses by fave ID

View file

@ -1,73 +0,0 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"fmt"
"github.com/go-pg/pg/v10"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
func (ps *postgresService) Upsert(i interface{}, conflictColumn string) error {
if _, err := ps.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) UpdateByID(id string, i interface{}) error {
if _, err := ps.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error {
_, err := ps.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
return err
}
func (ps *postgresService) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) error {
q := ps.conn.Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", pg.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
}
}
q = q.Set("? = ?", pg.Safe(key), value)
_, err := q.Update()
return err
}