continue moving db stuff around

This commit is contained in:
tsmethurst 2021-08-17 15:23:28 +02:00
commit 15153ee0c8
72 changed files with 923 additions and 614 deletions

View file

@ -116,7 +116,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
return user, nil
}
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// we have an actual error in the database
return nil, fmt.Errorf("error checking database for email %s: %s", claims.Email, err)
}
@ -128,7 +128,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email)
}
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// we have an actual error in the database
return nil, fmt.Errorf("error checking database for email %s: %s", claims.Email, err)
}

View file

@ -59,7 +59,7 @@ func (m *Module) blockedDomain(host string) (bool, error) {
return true, nil
}
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are no entries so there's no block
return false, nil
}

View file

@ -6,60 +6,66 @@ type Account interface {
// GetAccountByUserID is a shortcut for the common action of fetching an account corresponding to a user ID.
// The given account pointer will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetAccountByUserID(userID string, account *gtsmodel.Account) error
GetAccountByUserID(userID string, account *gtsmodel.Account) DBError
// GetAccountByID returns one account with the given ID, or an error if something goes wrong.
GetAccountByID(id string) (*gtsmodel.Account, DBError)
// GetAccountByURI returns one account with the given URI, or an error if something goes wrong.
GetAccountByURI(uri string) (*gtsmodel.Account, DBError)
// GetLocalAccountByUsername is a shortcut for the common action of fetching an account ON THIS INSTANCE
// according to its username, which should be unique.
// The given account pointer will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetLocalAccountByUsername(username string, account *gtsmodel.Account) error
GetLocalAccountByUsername(username string, account *gtsmodel.Account) DBError
// GetAccountFollowRequests is a shortcut for the common action of fetching a list of follow requests targeting the given account ID.
// The given slice 'followRequests' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) error
GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) DBError
// GetAccountFollowing is a shortcut for the common action of fetching a list of accounts that accountID is following.
// The given slice 'following' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) error
GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) DBError
CountAccountFollowing(accountID string, localOnly bool) (int, DBError)
// GetAccountFollowers is a shortcut for the common action of fetching a list of accounts that accountID is followed by.
// The given slice 'followers' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
//
// If localOnly is set to true, then only followers from *this instance* will be returned.
GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) error
GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) DBError
CountAccountFollowers(accountID string, localOnly bool) (int, DBError)
// GetAccountFaves is a shortcut for the common action of fetching a list of faves made by the given accountID.
// The given slice 'faves' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) error
GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) DBError
// GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID.
GetAccountStatusesCount(accountID string) (int, error)
CountAccountStatuses(accountID string) (int, DBError)
// GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
// be very memory intensive so you probably shouldn't do this!
// In case of no entries, a 'no entries' error will be returned
GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error)
GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, DBError)
GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error)
GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, DBError)
// GetAccountLastStatus simply gets the most recent status by the given account.
// The given slice 'status' pointer will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetAccountLastStatus(accountID string, status *gtsmodel.Status) error
GetAccountLastStatus(accountID string, status *gtsmodel.Status) DBError
// SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment.
SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) error
SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) DBError
// GetHeaderAvatarForAccountID gets the current avatar for the given account ID.
// The passed mediaAttachment pointer will be populated with the value of the avatar, if it exists.
GetAccountAvatar(avatar *gtsmodel.MediaAttachment, accountID string) error
// GetAccountHeader gets the current header for the given account ID.
// The passed mediaAttachment pointer will be populated with the value of the header, if it exists.
GetAccountHeader(header *gtsmodel.MediaAttachment, accountID string) error
// GetInstanceAccount returns the instance account for the given domain.
// If domain is empty, this instance account will be returned.
GetInstanceAccount(domain string) (*gtsmodel.Account, DBError)
}

View file

@ -9,26 +9,26 @@ import (
type Admin interface {
// IsUsernameAvailable checks whether a given username is available on our domain.
// Returns an error if the username is already taken, or something went wrong in the db.
IsUsernameAvailable(username string) error
IsUsernameAvailable(username string) DBError
// IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
// Return an error if:
// A) the email is already associated with an account
// B) we block signups from this email domain
// C) something went wrong in the db
IsEmailAvailable(email string) error
IsEmailAvailable(email string) DBError
// NewSignup creates a new user in the database with the given parameters.
// By the time this function is called, it should be assumed that all the parameters have passed validation!
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, error)
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, DBError)
// CreateInstanceAccount creates an account in the database with the same username as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
// This is needed for things like serving files that belong to the instance and not an individual user/account.
CreateInstanceAccount() error
CreateInstanceAccount() DBError
// CreateInstanceInstance creates an instance in the database with the same domain as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'.
// This is needed for things like serving instance information through /api/v1/instance
CreateInstanceInstance() error
CreateInstanceInstance() DBError
}

View file

@ -5,60 +5,60 @@ import "context"
type Basic interface {
// CreateTable creates a table for the given interface.
// For implementations that don't use tables, this can just return nil.
CreateTable(i interface{}) error
CreateTable(i interface{}) DBError
// DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil.
DropTable(i interface{}) error
DropTable(i interface{}) DBError
// Stop should stop and close the database connection cleanly, returning an error if this is not possible.
// If the database implementation doesn't need to be stopped, this can just return nil.
Stop(ctx context.Context) error
Stop(ctx context.Context) DBError
// IsHealthy should return nil if the database connection is healthy, or an error if not.
IsHealthy(ctx context.Context) error
IsHealthy(ctx context.Context) DBError
// GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry,
// for other implementations (for example, in-memory) it might just be the key of a map.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetByID(id string, i interface{}) error
GetByID(id string, i interface{}) DBError
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
// name of the key to select from.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetWhere(where []Where, i interface{}) error
GetWhere(where []Where, i interface{}) DBError
// GetAll will try to get all entries of type i.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetAll(i interface{}) error
GetAll(i interface{}) DBError
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Put(i interface{}) error
Put(i interface{}) DBError
// Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/
// It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Upsert(i interface{}, conflictColumn string) error
Upsert(i interface{}, conflictColumn string) DBError
// UpdateByID updates i with id id.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
UpdateByID(id string, i interface{}) error
UpdateByID(id string, i interface{}) DBError
// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
UpdateOneByID(id string, key string, value interface{}, i interface{}) error
UpdateOneByID(id string, key string, value interface{}, i interface{}) DBError
// UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
UpdateWhere(where []Where, key string, value interface{}, i interface{}) error
UpdateWhere(where []Where, key string, value interface{}, i interface{}) DBError
// DeleteByID removes i with id id.
// If i didn't exist anyway, then no error should be returned.
DeleteByID(id string, i interface{}) error
DeleteByID(id string, i interface{}) DBError
// DeleteWhere deletes i where key = value
// If i didn't exist anyway, then no error should be returned.
DeleteWhere(where []Where, i interface{}) error
DeleteWhere(where []Where, i interface{}) DBError
}

View file

@ -18,16 +18,12 @@
package db
// ErrNoEntries is to be returned from the DB interface when no entries are found for a given query.
type ErrNoEntries struct{}
import "fmt"
func (e ErrNoEntries) Error() string {
return "no entries"
}
type DBError error
// ErrAlreadyExists is to be returned from the DB interface when an entry already exists for a given query or its constraints.
type ErrAlreadyExists struct{}
func (e ErrAlreadyExists) Error() string {
return "already exists"
}
var (
ErrNoEntries DBError = fmt.Errorf("no entries")
ErrAlreadyExists DBError = fmt.Errorf("already exists")
ErrUnknown DBError = fmt.Errorf("unknown error")
)

View file

@ -4,14 +4,14 @@ import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
type Instance interface {
// GetUserCountForInstance returns the number of known accounts registered with the given domain.
GetUserCountForInstance(domain string) (int, error)
GetUserCountForInstance(domain string) (int, DBError)
// GetStatusCountForInstance returns the number of known statuses posted from the given domain.
GetStatusCountForInstance(domain string) (int, error)
GetStatusCountForInstance(domain string) (int, DBError)
// GetDomainCountForInstance returns the number of known instances known that the given domain federates with.
GetDomainCountForInstance(domain string) (int, error)
GetDomainCountForInstance(domain string) (int, DBError)
// GetAccountsForInstance returns a slice of accounts from the given instance, arranged by ID.
GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error)
GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, DBError)
}

View file

@ -4,5 +4,5 @@ import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
type Notification interface {
// GetNotificationsForAccount returns a list of notifications that pertain to the given accountID.
GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error)
GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, DBError)
}

View file

@ -1,62 +1,100 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"errors"
"fmt"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) GetAccountHeader(header *gtsmodel.MediaAttachment, accountID string) error {
acct := &gtsmodel.Account{}
if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
if acct.HeaderMediaAttachmentID == "" {
return db.ErrNoEntries{}
}
if err := ps.conn.Model(header).Where("id = ?", acct.HeaderMediaAttachmentID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
type accountDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (ps *postgresService) GetAccountAvatar(avatar *gtsmodel.MediaAttachment, accountID string) error {
acct := &gtsmodel.Account{}
if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
if acct.AvatarMediaAttachmentID == "" {
return db.ErrNoEntries{}
}
if err := ps.conn.Model(avatar).Where("id = ?", acct.AvatarMediaAttachmentID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query {
return a.conn.Model(account).
Relation("AvatarMediaAttachment").
Relation("HeaderMediaAttachment")
}
func (ps *postgresService) GetAccountLastStatus(accountID string, status *gtsmodel.Status) error {
if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
func (a *accountDB) processResponse(account *gtsmodel.Account, err error) (*gtsmodel.Account, db.DBError) {
switch err {
case pg.ErrNoRows:
return nil, db.ErrNoEntries
case nil:
return account, nil
default:
return nil, err
}
}
func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.DBError) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account).
Where("account.id = ?", id)
return a.processResponse(account, q.Select())
}
func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.DBError) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account).
Where("account.uri = ?", uri)
return a.processResponse(account, q.Select())
}
func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.DBError) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account)
if domain == "" {
q = q.
Where("account.username = ?", domain).
Where("account.domain = ?", domain)
} else {
q = q.
Where("account.username = ?", domain).
Where("? IS NULL", pg.Ident("domain"))
}
return a.processResponse(account, q.Select())
}
func (a *accountDB) GetAccountLastStatus(accountID string, status *gtsmodel.Status) db.DBError {
if err := a.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
return db.ErrNoEntries
}
return err
}
@ -64,7 +102,7 @@ func (ps *postgresService) GetAccountLastStatus(accountID string, status *gtsmod
}
func (ps *postgresService) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) error {
func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.DBError {
if mediaAttachment.Avatar && mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar")
}
@ -79,47 +117,47 @@ func (ps *postgresService) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.Me
}
// TODO: there are probably more side effects here that need to be handled
if _, err := ps.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
return err
}
if _, err := ps.conn.Model(&gtsmodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
if _, err := a.conn.Model(&gtsmodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
return err
}
return nil
}
func (ps *postgresService) GetAccountByUserID(userID string, account *gtsmodel.Account) error {
func (a *accountDB) GetAccountByUserID(userID string, account *gtsmodel.Account) db.DBError {
user := &gtsmodel.User{
ID: userID,
}
if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
if err := a.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
return db.ErrNoEntries
}
return err
}
if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
if err := a.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
return db.ErrNoEntries
}
return err
}
return nil
}
func (ps *postgresService) GetLocalAccountByUsername(username string, account *gtsmodel.Account) error {
if err := ps.conn.Model(account).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select(); err != nil {
func (a *accountDB) GetLocalAccountByUsername(username string, account *gtsmodel.Account) db.DBError {
if err := a.conn.Model(account).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
return db.ErrNoEntries
}
return err
}
return nil
}
func (ps *postgresService) GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) error {
if err := ps.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil {
func (a *accountDB) GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) db.DBError {
if err := a.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
}
@ -128,8 +166,8 @@ func (ps *postgresService) GetAccountFollowRequests(accountID string, followRequ
return nil
}
func (ps *postgresService) GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) error {
if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
func (a *accountDB) GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) db.DBError {
if err := a.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
}
@ -138,9 +176,13 @@ func (ps *postgresService) GetAccountFollowing(accountID string, following *[]gt
return nil
}
func (ps *postgresService) GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) error {
func (a *accountDB) CountAccountFollowing(accountID string, localOnly bool) (int, db.DBError) {
return a.conn.Model(&[]*gtsmodel.Follow{}).Where("account_id = ?", accountID).Count()
}
q := ps.conn.Model(followers)
func (a *accountDB) GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) db.DBError {
q := a.conn.Model(followers)
if localOnly {
// for local accounts let's get where domain is null OR where domain is an empty string, just to be safe
@ -168,8 +210,12 @@ func (ps *postgresService) GetAccountFollowers(accountID string, followers *[]gt
return nil
}
func (ps *postgresService) GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) error {
if err := ps.conn.Model(faves).Where("account_id = ?", accountID).Select(); err != nil {
func (a *accountDB) CountAccountFollowers(accountID string, localOnly bool) (int, db.DBError) {
return a.conn.Model(&[]*gtsmodel.Follow{}).Where("target_account_id = ?", accountID).Count()
}
func (a *accountDB) GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) db.DBError {
if err := a.conn.Model(faves).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
}
@ -178,22 +224,15 @@ func (ps *postgresService) GetAccountFaves(accountID string, faves *[]gtsmodel.S
return nil
}
func (ps *postgresService) GetAccountStatusesCount(accountID string) (int, error) {
count, err := ps.conn.Model(&gtsmodel.Status{}).Where("account_id = ?", accountID).Count()
if err != nil {
if err == pg.ErrNoRows {
return 0, nil
}
return 0, err
}
return count, nil
func (a *accountDB) CountAccountStatuses(accountID string) (int, db.DBError) {
return a.conn.Model(&gtsmodel.Status{}).Where("account_id = ?", accountID).Count()
}
func (ps *postgresService) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error) {
ps.log.Debugf("getting statuses for account %s", accountID)
func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.DBError) {
a.log.Debugf("getting statuses for account %s", accountID)
statuses := []*gtsmodel.Status{}
q := ps.conn.Model(&statuses).Order("id DESC")
q := a.conn.Model(&statuses).Order("id DESC")
if accountID != "" {
q = q.Where("account_id = ?", accountID)
}
@ -222,15 +261,57 @@ func (ps *postgresService) GetAccountStatuses(accountID string, limit int, exclu
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
ps.log.Debugf("returning statuses for account %s", accountID)
a.log.Debugf("returning statuses for account %s", accountID)
return statuses, nil
}
func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.DBError) {
blocks := []*gtsmodel.Block{}
fq := a.conn.Model(&blocks).
Where("block.account_id = ?", accountID).
Relation("TargetAccount").
Order("block.id DESC")
if maxID != "" {
fq = fq.Where("block.id < ?", maxID)
}
if sinceID != "" {
fq = fq.Where("block.id > ?", sinceID)
}
if limit > 0 {
fq = fq.Limit(limit)
}
err := fq.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
}
if len(blocks) == 0 {
return nil, "", "", db.ErrNoEntries
}
accounts := []*gtsmodel.Account{}
for _, b := range blocks {
accounts = append(accounts, b.TargetAccount)
}
nextMaxID := blocks[len(blocks)-1].ID
prevMinID := blocks[0].ID
return accounts, nextMaxID, prevMinID, nil
}

View file

@ -1,6 +1,25 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"crypto/rand"
"crypto/rsa"
"fmt"
@ -10,17 +29,27 @@ import (
"time"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
"golang.org/x/crypto/bcrypt"
)
func (ps *postgresService) IsUsernameAvailable(username string) error {
type adminDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (a *adminDB) IsUsernameAvailable(username string) db.DBError {
// if no error we fail because it means we found something
// if error but it's not pg.ErrNoRows then we fail
// if err is pg.ErrNoRows we're good, we found nothing so continue
if err := ps.conn.Model(&gtsmodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
if err := a.conn.Model(&gtsmodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
return fmt.Errorf("username %s already in use", username)
} else if err != pg.ErrNoRows {
return fmt.Errorf("db error: %s", err)
@ -28,7 +57,7 @@ func (ps *postgresService) IsUsernameAvailable(username string) error {
return nil
}
func (ps *postgresService) IsEmailAvailable(email string) error {
func (a *adminDB) IsEmailAvailable(email string) db.DBError {
// parse the domain from the email
m, err := mail.ParseAddress(email)
if err != nil {
@ -37,7 +66,7 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked
if err := ps.conn.Model(&gtsmodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
if err := a.conn.Model(&gtsmodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email domain %s is blocked", domain)
} else if err != pg.ErrNoRows {
@ -46,7 +75,7 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
}
// check if this email is associated with a user already
if err := ps.conn.Model(&gtsmodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
if err := a.conn.Model(&gtsmodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email %s already in use", email)
} else if err != pg.ErrNoRows {
@ -56,16 +85,16 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
return nil
}
func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, error) {
func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.DBError) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
ps.log.Errorf("error creating new rsa key: %s", err)
a.log.Errorf("error creating new rsa key: %s", err)
return nil, err
}
// if something went wrong while creating a user, we might already have an account, so check here first...
a := &gtsmodel.Account{}
err = ps.conn.Model(a).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
acct := &gtsmodel.Account{}
err = a.conn.Model(acct).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
if err != nil {
// there's been an actual error
if err != pg.ErrNoRows {
@ -73,13 +102,13 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr
}
// we just don't have an account yet create one
newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host)
newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
newAccountID, err := id.NewRandomULID()
if err != nil {
return nil, err
}
a = &gtsmodel.Account{
acct = &gtsmodel.Account{
ID: newAccountID,
Username: username,
DisplayName: username,
@ -96,7 +125,7 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
if _, err = ps.conn.Model(a).Insert(); err != nil {
if _, err = a.conn.Model(acct).Insert(); err != nil {
return nil, err
}
}
@ -113,7 +142,7 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr
u := &gtsmodel.User{
ID: newUserID,
AccountID: a.ID,
AccountID: acct.ID,
EncryptedPassword: string(pw),
SignUpIP: signUpIP.To4(),
Locale: locale,
@ -132,18 +161,18 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr
u.Moderator = true
}
if _, err = ps.conn.Model(u).Insert(); err != nil {
if _, err = a.conn.Model(u).Insert(); err != nil {
return nil, err
}
return u, nil
}
func (ps *postgresService) CreateInstanceAccount() error {
username := ps.config.Host
func (a *adminDB) CreateInstanceAccount() db.DBError {
username := a.config.Host
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
ps.log.Errorf("error creating new rsa key: %s", err)
a.log.Errorf("error creating new rsa key: %s", err)
return err
}
@ -152,10 +181,10 @@ func (ps *postgresService) CreateInstanceAccount() error {
return err
}
newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host)
a := &gtsmodel.Account{
newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
acct := &gtsmodel.Account{
ID: aID,
Username: ps.config.Host,
Username: a.config.Host,
DisplayName: username,
URL: newAccountURIs.UserURL,
PrivateKey: key,
@ -169,19 +198,19 @@ func (ps *postgresService) CreateInstanceAccount() error {
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
inserted, err := ps.conn.Model(a).Where("username = ?", username).SelectOrInsert()
inserted, err := a.conn.Model(acct).Where("username = ?", username).SelectOrInsert()
if err != nil {
return err
}
if inserted {
ps.log.Infof("created instance account %s with id %s", username, a.ID)
a.log.Infof("created instance account %s with id %s", username, acct.ID)
} else {
ps.log.Infof("instance account %s already exists with id %s", username, a.ID)
a.log.Infof("instance account %s already exists with id %s", username, acct.ID)
}
return nil
}
func (ps *postgresService) CreateInstanceInstance() error {
func (a *adminDB) CreateInstanceInstance() db.DBError {
iID, err := id.NewRandomULID()
if err != nil {
return err
@ -189,18 +218,18 @@ func (ps *postgresService) CreateInstanceInstance() error {
i := &gtsmodel.Instance{
ID: iID,
Domain: ps.config.Host,
Title: ps.config.Host,
URI: fmt.Sprintf("%s://%s", ps.config.Protocol, ps.config.Host),
Domain: a.config.Host,
Title: a.config.Host,
URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host),
}
inserted, err := ps.conn.Model(i).Where("domain = ?", ps.config.Host).SelectOrInsert()
inserted, err := a.conn.Model(i).Where("domain = ?", a.config.Host).SelectOrInsert()
if err != nil {
return err
}
if inserted {
ps.log.Infof("created instance instance %s with id %s", ps.config.Host, i.ID)
a.log.Infof("created instance instance %s with id %s", a.config.Host, i.ID)
} else {
ps.log.Infof("instance instance %s already exists with id %s", ps.config.Host, i.ID)
a.log.Infof("instance instance %s already exists with id %s", a.config.Host, i.ID)
}
return nil
}

View file

@ -1,25 +1,55 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"errors"
"fmt"
"strings"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
func (ps *postgresService) Put(i interface{}) error {
_, err := ps.conn.Model(i).Insert(i)
type basicDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (b *basicDB) Put(i interface{}) db.DBError {
_, err := b.conn.Model(i).Insert(i)
if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists{}
return db.ErrAlreadyExists
}
return err
}
func (ps *postgresService) GetByID(id string, i interface{}) error {
if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil {
func (b *basicDB) GetByID(id string, i interface{}) db.DBError {
if err := b.conn.Model(i).Where("id = ?", id).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
return db.ErrNoEntries
}
return err
@ -27,12 +57,12 @@ func (ps *postgresService) GetByID(id string, i interface{}) error {
return nil
}
func (ps *postgresService) GetWhere(where []db.Where, i interface{}) error {
func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.DBError {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := ps.conn.Model(i)
q := b.conn.Model(i)
for _, w := range where {
if w.Value == nil {
@ -48,25 +78,25 @@ func (ps *postgresService) GetWhere(where []db.Where, i interface{}) error {
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
return db.ErrNoEntries
}
return err
}
return nil
}
func (ps *postgresService) GetAll(i interface{}) error {
if err := ps.conn.Model(i).Select(); err != nil {
func (b *basicDB) GetAll(i interface{}) db.DBError {
if err := b.conn.Model(i).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
return db.ErrNoEntries
}
return err
}
return nil
}
func (ps *postgresService) DeleteByID(id string, i interface{}) error {
if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
func (b *basicDB) DeleteByID(id string, i interface{}) db.DBError {
if _, err := b.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
// if there are no rows *anyway* then that's fine
// just return err if there's an actual error
if err != pg.ErrNoRows {
@ -76,12 +106,12 @@ func (ps *postgresService) DeleteByID(id string, i interface{}) error {
return nil
}
func (ps *postgresService) DeleteWhere(where []db.Where, i interface{}) error {
func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.DBError {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := ps.conn.Model(i)
q := b.conn.Model(i)
for _, w := range where {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
@ -95,3 +125,76 @@ func (ps *postgresService) DeleteWhere(where []db.Where, i interface{}) error {
}
return nil
}
func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.DBError {
if _, err := b.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) UpdateByID(id string, i interface{}) db.DBError {
if _, err := b.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.DBError {
_, err := b.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
return err
}
func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.DBError {
q := b.conn.Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", pg.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
}
}
q = q.Set("? = ?", pg.Safe(key), value)
_, err := q.Update()
return err
}
func (b *basicDB) CreateTable(i interface{}) db.DBError {
return b.conn.Model(i).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
}
func (b *basicDB) DropTable(i interface{}) db.DBError {
return b.conn.Model(i).DropTable(&orm.DropTableOptions{
IfExists: true,
})
}
func (b *basicDB) IsHealthy(ctx context.Context) db.DBError {
return b.conn.Ping(ctx)
}
func (b *basicDB) Stop(ctx context.Context) db.DBError {
b.log.Info("closing db connection")
if err := b.conn.Close(); err != nil {
// only cancel if there's a problem closing the db
b.cancel()
return err
}
return nil
}

View file

@ -19,15 +19,26 @@
package pg
import (
"context"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) {
q := ps.conn.Model(&[]*gtsmodel.Account{})
type instanceDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
if domain == ps.config.Host {
func (i *instanceDB) GetUserCountForInstance(domain string) (int, db.DBError) {
q := i.conn.Model(&[]*gtsmodel.Account{})
if domain == i.config.Host {
// if the domain is *this* domain, just count where the domain field is null
q = q.Where("? IS NULL", pg.Ident("domain"))
} else {
@ -40,10 +51,10 @@ func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) {
return q.Count()
}
func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error) {
q := ps.conn.Model(&[]*gtsmodel.Status{})
func (i *instanceDB) GetStatusCountForInstance(domain string) (int, db.DBError) {
q := i.conn.Model(&[]*gtsmodel.Status{})
if domain == ps.config.Host {
if domain == i.config.Host {
// if the domain is *this* domain, just count where local is true
q = q.Where("local = ?", true)
} else {
@ -55,10 +66,10 @@ func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error)
return q.Count()
}
func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error) {
q := ps.conn.Model(&[]*gtsmodel.Instance{})
func (i *instanceDB) GetDomainCountForInstance(domain string) (int, db.DBError) {
q := i.conn.Model(&[]*gtsmodel.Instance{})
if domain == ps.config.Host {
if domain == i.config.Host {
// if the domain is *this* domain, just count other instances it knows about
// exclude domains that are blocked
q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at"))
@ -70,12 +81,12 @@ func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error)
return q.Count()
}
func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error) {
ps.log.Debug("GetAccountsForInstance")
func (i *instanceDB) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.DBError) {
i.log.Debug("GetAccountsForInstance")
accounts := []*gtsmodel.Account{}
q := ps.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
if maxID != "" {
q = q.Where("id < ?", maxID)
@ -88,13 +99,13 @@ func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, l
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return nil, err
}
if len(accounts) == 0 {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return accounts, nil

View file

@ -24,44 +24,30 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) {
blocks := []*gtsmodel.Block{}
func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.DBError) {
notifications := []*gtsmodel.Notification{}
fq := ps.conn.Model(&blocks).
Where("block.account_id = ?", accountID).
Relation("TargetAccount").
Order("block.id DESC")
q := ps.conn.Model(&notifications).Where("target_account_id = ?", accountID)
if maxID != "" {
fq = fq.Where("block.id < ?", maxID)
q = q.Where("id < ?", maxID)
}
if sinceID != "" {
fq = fq.Where("block.id > ?", sinceID)
q = q.Where("id > ?", sinceID)
}
if limit > 0 {
fq = fq.Limit(limit)
if limit != 0 {
q = q.Limit(limit)
}
err := fq.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, "", "", db.ErrNoEntries{}
q = q.Order("created_at DESC")
if err := q.Select(); err != nil {
if err != pg.ErrNoRows {
return nil, err
}
return nil, "", "", err
}
if len(blocks) == 0 {
return nil, "", "", db.ErrNoEntries{}
}
accounts := []*gtsmodel.Account{}
for _, b := range blocks {
accounts = append(accounts, b.TargetAccount)
}
nextMaxID := blocks[len(blocks)-1].ID
prevMinID := blocks[0].ID
return accounts, nextMaxID, prevMinID, nil
return notifications, nil
}

View file

@ -31,7 +31,6 @@ import (
"github.com/go-pg/pg/extra/pgdebug"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
@ -41,6 +40,14 @@ import (
// postgresService satisfies the DB interface
type postgresService struct {
db.Account
db.Admin
db.Basic
db.Instance
db.Notification
db.Relationship
db.Status
db.Timeline
config *config.Config
conn *pg.DB
log *logrus.Logger
@ -85,6 +92,48 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
log.Infof("connected to postgres version: %s", version)
ps := &postgresService{
Account: &accountDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Admin: &adminDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Basic: &basicDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Instance: &instanceDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Relationship: &relationshipDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Status: &statusDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Timeline: &timelineDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
config: c,
conn: conn,
log: log,
@ -193,89 +242,6 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
return options, nil
}
/*
BASIC DB FUNCTIONALITY
*/
func (ps *postgresService) CreateTable(i interface{}) error {
return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
}
func (ps *postgresService) DropTable(i interface{}) error {
return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
IfExists: true,
})
}
func (ps *postgresService) Stop(ctx context.Context) error {
ps.log.Info("closing db connection")
if err := ps.conn.Close(); err != nil {
// only cancel if there's a problem closing the db
ps.cancel()
return err
}
return nil
}
func (ps *postgresService) IsHealthy(ctx context.Context) error {
return ps.conn.Ping(ctx)
}
func (ps *postgresService) CreateSchema(ctx context.Context) error {
models := []interface{}{
(*gtsmodel.Account)(nil),
(*gtsmodel.Status)(nil),
(*gtsmodel.User)(nil),
}
ps.log.Info("creating db schema")
for _, model := range models {
err := ps.conn.Model(model).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
if err != nil {
return err
}
}
ps.log.Info("db schema created")
return nil
}
/*
HANDY SHORTCUTS
*/
func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error) {
notifications := []*gtsmodel.Notification{}
q := ps.conn.Model(&notifications).Where("target_account_id = ?", accountID)
if maxID != "" {
q = q.Where("id < ?", maxID)
}
if sinceID != "" {
q = q.Where("id > ?", sinceID)
}
if limit != 0 {
q = q.Limit(limit)
}
q = q.Order("created_at DESC")
if err := q.Select(); err != nil {
if err != pg.ErrNoRows {
return nil, err
}
}
return notifications, nil
}
/*
CONVERSION FUNCTIONS
*/

47
internal/db/pg/pg_test.go Normal file
View file

@ -0,0 +1,47 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
type PGStandardTestSuite struct {
// standard suite interfaces
suite.Suite
config *config.Config
db db.DB
log *logrus.Logger
// standard suite models
testTokens map[string]*oauth.Token
testClients map[string]*oauth.Client
testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account
testAttachments map[string]*gtsmodel.MediaAttachment
testStatuses map[string]*gtsmodel.Status
testTags map[string]*gtsmodel.Tag
testMentions map[string]*gtsmodel.Mention
}

View file

@ -1,17 +1,45 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"fmt"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) Blocked(account1 string, account2 string) (bool, error) {
type relationshipDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (r *relationshipDB) Blocked(account1 string, account2 string) (bool, db.DBError) {
// TODO: check domain blocks as well
var blocked bool
if err := ps.conn.Model(&gtsmodel.Block{}).
if err := r.conn.Model(&gtsmodel.Block{}).
Where("account_id = ?", account1).Where("target_account_id = ?", account2).
WhereOr("target_account_id = ?", account1).Where("account_id = ?", account2).
Select(); err != nil {
@ -25,83 +53,83 @@ func (ps *postgresService) Blocked(account1 string, account2 string) (bool, erro
return blocked, nil
}
func (ps *postgresService) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) {
r := &gtsmodel.Relationship{
func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.DBError) {
rel := &gtsmodel.Relationship{
ID: targetAccount,
}
// check if the requesting account follows the target account
follow := &gtsmodel.Follow{}
if err := ps.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
if err := r.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
if err != pg.ErrNoRows {
// a proper error
return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
}
// no follow exists so these are all false
r.Following = false
r.ShowingReblogs = false
r.Notifying = false
rel.Following = false
rel.ShowingReblogs = false
rel.Notifying = false
} else {
// follow exists so we can fill these fields out...
r.Following = true
r.ShowingReblogs = follow.ShowReblogs
r.Notifying = follow.Notify
rel.Following = true
rel.ShowingReblogs = follow.ShowReblogs
rel.Notifying = follow.Notify
}
// check if the target account follows the requesting account
followedBy, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
followedBy, err := r.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
}
r.FollowedBy = followedBy
rel.FollowedBy = followedBy
// check if the requesting account blocks the target account
blocking, err := ps.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
blocking, err := r.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
}
r.Blocking = blocking
rel.Blocking = blocking
// check if the target account blocks the requesting account
blockedBy, err := ps.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
blockedBy, err := r.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
r.BlockedBy = blockedBy
rel.BlockedBy = blockedBy
// check if there's a pending following request from requesting account to target account
requested, err := ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
requested, err := r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
r.Requested = requested
rel.Requested = requested
return r, nil
return rel, nil
}
func (ps *postgresService) Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) {
func (r *relationshipDB) Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.DBError) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
return ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
return r.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
}
func (ps *postgresService) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) {
func (r *relationshipDB) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.DBError) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
return ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
return r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
}
func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error) {
func (r *relationshipDB) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.DBError) {
if account1 == nil || account2 == nil {
return false, nil
}
// make sure account 1 follows account 2
f1, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists()
f1, err := r.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists()
if err != nil {
if err == pg.ErrNoRows {
return false, nil
@ -110,7 +138,7 @@ func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmode
}
// make sure account 2 follows account 1
f2, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", account2.ID).Where("target_account_id = ?", account1.ID).Exists()
f2, err := r.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", account2.ID).Where("target_account_id = ?", account1.ID).Exists()
if err != nil {
if err == pg.ErrNoRows {
return false, nil
@ -121,12 +149,12 @@ func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmode
return f1 && f2, nil
}
func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.DBError) {
// make sure the original follow request exists
fr := &gtsmodel.FollowRequest{}
if err := ps.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
if err := r.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
if err == pg.ErrMultiRows {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return nil, err
}
@ -140,12 +168,12 @@ func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAcc
}
// if the follow already exists, just update the URI -- we don't need to do anything else
if _, err := ps.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
if _, err := r.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
return nil, err
}
// now remove the follow request
if _, err := ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
if _, err := r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
return nil, err
}

View file

@ -20,39 +20,90 @@ package pg
import (
"container/list"
"context"
"errors"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) {
type statusDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (s *statusDB) newStatusQ(status *gtsmodel.Status) *orm.Query {
return s.conn.Model(status).
Relation("Account").
Relation("InReplyTo").
Relation("InReplyToAccount").
Relation("BoostOf").
Relation("BoostOfAccount").
Relation("CreatedWithApplication")
}
func (s *statusDB) processResponse(status *gtsmodel.Status, err error) (*gtsmodel.Status, db.DBError) {
switch err {
case pg.ErrNoRows:
return nil, db.ErrNoEntries
case nil:
return status, nil
default:
return nil, err
}
}
func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.DBError) {
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("status.id = ?", id)
return s.processResponse(status, q.Select())
}
func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.DBError) {
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("LOWER(status.uri) = LOWER(?)", uri)
return s.processResponse(status, q.Select())
}
func (s *statusDB) StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.DBError) {
parents := []*gtsmodel.Status{}
ps.statusParent(status, &parents, onlyDirect)
s.statusParent(status, &parents, onlyDirect)
return parents, nil
}
func (ps *postgresService) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
if status.InReplyToID == "" {
return
}
parentStatus := &gtsmodel.Status{}
if err := ps.conn.Model(parentStatus).Where("id = ?", status.InReplyToID).Select(); err == nil {
if err := s.conn.Model(parentStatus).Where("id = ?", status.InReplyToID).Select(); err == nil {
*foundStatuses = append(*foundStatuses, parentStatus)
}
if onlyDirect {
return
}
ps.statusParent(parentStatus, foundStatuses, false)
s.statusParent(parentStatus, foundStatuses, false)
}
func (ps *postgresService) StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) {
func (s *statusDB) StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.DBError) {
foundStatuses := &list.List{}
foundStatuses.PushFront(status)
ps.statusChildren(status, foundStatuses, onlyDirect, minID)
s.statusChildren(status, foundStatuses, onlyDirect, minID)
children := []*gtsmodel.Status{}
for e := foundStatuses.Front(); e != nil; e = e.Next() {
@ -70,10 +121,10 @@ func (ps *postgresService) StatusChildren(status *gtsmodel.Status, onlyDirect bo
return children, nil
}
func (ps *postgresService) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
immediateChildren := []*gtsmodel.Status{}
q := ps.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
if minID != "" {
q = q.Where("status.id > ?", minID)
}
@ -100,43 +151,43 @@ func (ps *postgresService) statusChildren(status *gtsmodel.Status, foundStatuses
if onlyDirect {
return
}
ps.statusChildren(child, foundStatuses, false, minID)
s.statusChildren(child, foundStatuses, false, minID)
}
}
func (ps *postgresService) GetReplyCountForStatus(status *gtsmodel.Status) (int, error) {
return ps.conn.Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
func (s *statusDB) GetReplyCountForStatus(status *gtsmodel.Status) (int, db.DBError) {
return s.conn.Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
}
func (ps *postgresService) GetReblogCountForStatus(status *gtsmodel.Status) (int, error) {
return ps.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
func (s *statusDB) GetReblogCountForStatus(status *gtsmodel.Status) (int, db.DBError) {
return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
}
func (ps *postgresService) GetFaveCountForStatus(status *gtsmodel.Status) (int, error) {
return ps.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
func (s *statusDB) GetFaveCountForStatus(status *gtsmodel.Status) (int, db.DBError) {
return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
}
func (ps *postgresService) StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, error) {
return ps.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
func (s *statusDB) StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (ps *postgresService) StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, error) {
return ps.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
func (s *statusDB) StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (ps *postgresService) StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, error) {
return ps.conn.Model(&gtsmodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
func (s *statusDB) StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
return s.conn.Model(&gtsmodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (ps *postgresService) StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, error) {
return ps.conn.Model(&gtsmodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
func (s *statusDB) StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
return s.conn.Model(&gtsmodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) {
func (s *statusDB) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, db.DBError) {
accounts := []*gtsmodel.Account{}
faves := []*gtsmodel.StatusFave{}
if err := ps.conn.Model(&faves).Where("status_id = ?", status.ID).Select(); err != nil {
if err := s.conn.Model(&faves).Where("status_id = ?", status.ID).Select(); err != nil {
if err == pg.ErrNoRows {
return accounts, nil // no rows just means nobody has faved this status, so that's fine
}
@ -145,7 +196,7 @@ func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.
for _, f := range faves {
acc := &gtsmodel.Account{}
if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
if err := s.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
if err == pg.ErrNoRows {
continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it
}
@ -156,11 +207,11 @@ func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.
return accounts, nil
}
func (ps *postgresService) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) {
func (s *statusDB) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, db.DBError) {
accounts := []*gtsmodel.Account{}
boosts := []*gtsmodel.Status{}
if err := ps.conn.Model(&boosts).Where("boost_of_id = ?", status.ID).Select(); err != nil {
if err := s.conn.Model(&boosts).Where("boost_of_id = ?", status.ID).Select(); err != nil {
if err == pg.ErrNoRows {
return accounts, nil // no rows just means nobody has boosted this status, so that's fine
}
@ -169,7 +220,7 @@ func (ps *postgresService) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmode
for _, f := range boosts {
acc := &gtsmodel.Account{}
if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
if err := s.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
if err == pg.ErrNoRows {
continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it
}

View file

@ -0,0 +1,86 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
import (
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type StatusTestSuite struct {
PGStandardTestSuite
}
func (suite *PGStandardTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts()
suite.testAttachments = testrig.NewTestAttachments()
suite.testStatuses = testrig.NewTestStatuses()
suite.testTags = testrig.NewTestTags()
suite.testMentions = testrig.NewTestMentions()
}
func (suite *PGStandardTestSuite) SetupTest() {
suite.config = testrig.NewTestConfig()
suite.db = testrig.NewTestDB()
suite.log = testrig.NewTestLog()
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
func (suite *PGStandardTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
}
func (suite *PGStandardTestSuite) TestGetStatusByID() {
status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.Nil(status.BoostOf)
suite.Nil(status.BoostOfAccount)
suite.Nil(status.InReplyTo)
suite.Nil(status.InReplyToAccount)
}
func (suite *PGStandardTestSuite) TestGetStatusByURI() {
status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.Nil(status.BoostOf)
suite.Nil(status.BoostOfAccount)
suite.Nil(status.InReplyTo)
suite.Nil(status.InReplyToAccount)
}
func TestStatusTestSuite(t *testing.T) {
suite.Run(t, new(PGStandardTestSuite))
}

View file

@ -19,16 +19,26 @@
package pg
import (
"context"
"sort"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
type timelineDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (t *timelineDB) GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.DBError) {
statuses := []*gtsmodel.Status{}
q := ps.conn.Model(&statuses)
q := t.conn.Model(&statuses)
q = q.ColumnExpr("status.*").
// Find out who accountID follows.
@ -74,22 +84,22 @@ func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID str
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return statuses, nil
}
func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
func (t *timelineDB) GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.DBError) {
statuses := []*gtsmodel.Status{}
q := ps.conn.Model(&statuses).
q := t.conn.Model(&statuses).
Where("visibility = ?", gtsmodel.VisibilityPublic).
Where("? IS NULL", pg.Ident("in_reply_to_id")).
Where("? IS NULL", pg.Ident("in_reply_to_uri")).
@ -119,13 +129,13 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return statuses, nil
@ -133,11 +143,11 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error) {
func (t *timelineDB) GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.DBError) {
faves := []*gtsmodel.StatusFave{}
fq := ps.conn.Model(&faves).
fq := t.conn.Model(&faves).
Where("account_id = ?", accountID).
Order("id DESC")
@ -156,13 +166,13 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st
err := fq.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, "", "", db.ErrNoEntries{}
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
}
if len(faves) == 0 {
return nil, "", "", db.ErrNoEntries{}
return nil, "", "", db.ErrNoEntries
}
// map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID
@ -175,16 +185,16 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st
}
statuses := []*gtsmodel.Status{}
err = ps.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, "", "", db.ErrNoEntries{}
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
}
if len(statuses) == 0 {
return nil, "", "", db.ErrNoEntries{}
return nil, "", "", db.ErrNoEntries
}
// arrange statuses by fave ID

View file

@ -1,73 +0,0 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"fmt"
"github.com/go-pg/pg/v10"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
func (ps *postgresService) Upsert(i interface{}, conflictColumn string) error {
if _, err := ps.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) UpdateByID(id string, i interface{}) error {
if _, err := ps.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error {
_, err := ps.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
return err
}
func (ps *postgresService) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) error {
q := ps.conn.Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", pg.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
}
}
q = q.Set("? = ?", pg.Safe(key), value)
_, err := q.Update()
return err
}

View file

@ -5,23 +5,23 @@ import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
type Relationship interface {
// Blocked checks whether a block exists in eiher direction between two accounts.
// That is, it returns true if account1 blocks account2, OR if account2 blocks account1.
Blocked(account1 string, account2 string) (bool, error)
Blocked(account1 string, account2 string) (bool, DBError)
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error)
GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, DBError)
// Follows returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error)
Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, DBError)
// FollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error)
FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, DBError)
// Mutuals returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error)
Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, DBError)
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
// In other words, it should create the follow, and delete the existing follow request.
//
// It will return the newly created follow for further processing.
AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error)
AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, DBError)
}

View file

@ -3,42 +3,48 @@ package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
type Status interface {
// GetStatusByID returns one status from the database, with all rel fields populated (if possible).
GetStatusByID(id string) (*gtsmodel.Status, DBError)
// GetStatusByURI returns one status from the database, with all rel fields populated (if possible).
GetStatusByURI(uri string) (*gtsmodel.Status, DBError)
// GetReplyCountForStatus returns the amount of replies recorded for a status, or an error if something goes wrong
GetReplyCountForStatus(status *gtsmodel.Status) (int, error)
GetReplyCountForStatus(status *gtsmodel.Status) (int, DBError)
// GetReblogCountForStatus returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong
GetReblogCountForStatus(status *gtsmodel.Status) (int, error)
GetReblogCountForStatus(status *gtsmodel.Status) (int, DBError)
// GetFaveCountForStatus returns the amount of faves/likes recorded for a status, or an error if something goes wrong
GetFaveCountForStatus(status *gtsmodel.Status) (int, error)
GetFaveCountForStatus(status *gtsmodel.Status) (int, DBError)
// StatusParents get the parent statuses of a given status.
//
// If onlyDirect is true, only the immediate parent will be returned.
StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error)
StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, DBError)
// StatusChildren gets the child statuses of a given status.
//
// If onlyDirect is true, only the immediate children will be returned.
StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error)
StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, DBError)
// StatusFavedBy checks if a given status has been faved by a given account ID
StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, error)
StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, DBError)
// StatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID
StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, error)
StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, DBError)
// StatusMutedBy checks if a given status has been muted by a given account ID
StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, error)
StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, DBError)
// StatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, error)
StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, DBError)
// WhoFavedStatus returns a slice of accounts who faved the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error)
WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, DBError)
// WhoBoostedStatus returns a slice of accounts who boosted the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error)
WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, DBError)
}

View file

@ -6,13 +6,13 @@ type Timeline interface {
// GetHomeTimelineForAccount returns a slice of statuses from accounts that are followed by the given account id.
//
// Statuses should be returned in descending order of when they were created (newest first).
GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, DBError)
// GetPublicTimelineForAccount fetches the account's PUBLIC timeline -- ie., posts and replies that are public.
// It will use the given filters and try to return as many statuses as possible up to the limit.
//
// Statuses should be returned in descending order of when they were created (newest first).
GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, DBError)
// GetFavedTimelineForAccount fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved.
// It will use the given filters and try to return as many statuses as possible up to the limit.
@ -21,5 +21,5 @@ type Timeline interface {
// In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created.
//
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error)
GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, DBError)
}

View file

@ -31,7 +31,7 @@ func (d *deref) blockedDomain(host string) (bool, error) {
return true, nil
}
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are no entries so there's no block
return false, nil
}

View file

@ -288,7 +288,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
attachmentIDs = append(attachmentIDs, maybeAttachment.ID)
continue
}
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// we have a real error
return fmt.Errorf("error checking db for existence of attachment with remote url %s: %s", a.RemoteURL, err)
}

View file

@ -113,7 +113,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error {
status.ID = statusID
if err := f.db.Put(status); err != nil {
if _, ok := err.(db.ErrAlreadyExists); ok {
if err == db.ErrAlreadyExists {
// the status already exists in the database, which means we've already handled everything else,
// so we can just return nil here and be done with it.
return nil

View file

@ -62,7 +62,7 @@ func (f *federatingDB) OutboxForInbox(c context.Context, inboxIRI *url.URL) (out
}
acct := &gtsmodel.Account{}
if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String())
}
return nil, fmt.Errorf("db error searching for actor with inbox %s", inboxIRI.String())

View file

@ -55,7 +55,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: uid}}, &gtsmodel.Status{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are no entries for this status
return false, nil
}
@ -72,7 +72,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
if err := f.db.GetLocalAccountByUsername(username, &gtsmodel.Account{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
}
@ -89,7 +89,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
if err := f.db.GetLocalAccountByUsername(username, &gtsmodel.Account{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
}
@ -106,7 +106,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
if err := f.db.GetLocalAccountByUsername(username, &gtsmodel.Account{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
}
@ -123,7 +123,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("error parsing like path for url %s: %s", id.String(), err)
}
if err := f.db.GetLocalAccountByUsername(username, &gtsmodel.Account{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
}
@ -131,7 +131,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)
}
if err := f.db.GetByID(likeID, &gtsmodel.StatusFave{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are no entries
return false, nil
}
@ -148,7 +148,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("error parsing block path for url %s: %s", id.String(), err)
}
if err := f.db.GetLocalAccountByUsername(username, &gtsmodel.Account{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
}
@ -156,7 +156,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)
}
if err := f.db.GetByID(blockID, &gtsmodel.Block{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are no entries
return false, nil
}

View file

@ -213,7 +213,7 @@ func (f *federatingDB) ActorForOutbox(c context.Context, outboxIRI *url.URL) (ac
}
acct := &gtsmodel.Account{}
if err := f.db.GetWhere([]db.Where{{Key: "outbox_uri", Value: outboxIRI.String()}}, acct); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to outbox %s", outboxIRI.String())
}
return nil, fmt.Errorf("db error searching for actor with outbox %s", outboxIRI.String())
@ -238,7 +238,7 @@ func (f *federatingDB) ActorForInbox(c context.Context, inboxIRI *url.URL) (acto
}
acct := &gtsmodel.Account{}
if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String())
}
return nil, fmt.Errorf("db error searching for actor with inbox %s", inboxIRI.String())

View file

@ -132,7 +132,7 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr
// authentication has passed, so add an instance entry for this instance if it hasn't been done already
i := &gtsmodel.Instance{}
if err := f.db.GetWhere([]db.Where{{Key: "domain", Value: publicKeyOwnerURI.Host, CaseInsensitive: true}}, i); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// there's been an actual error
return ctx, false, fmt.Errorf("error getting requesting account with public key id %s: %s", publicKeyOwnerURI.String(), err)
}
@ -202,8 +202,7 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
requestingAccount := &gtsmodel.Account{}
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: uri.String()}}, requestingAccount); err != nil {
_, ok := err.(db.ErrNoEntries)
if ok {
if err == db.ErrNoEntries {
// we don't have an entry for this account so it's not blocked
// TODO: allow a different default to be set for this behavior
continue

View file

@ -13,7 +13,7 @@ func (f *federator) blockedDomain(host string) (bool, error) {
return true, nil
}
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are no entries so there's no block
return false, nil
}

View file

@ -46,10 +46,12 @@ type Account struct {
// ID of the avatar as a media attachment
AvatarMediaAttachmentID string `pg:"type:CHAR(26)"`
AvatarMediaAttachment *MediaAttachment `pg:"rel:has-one"`
// For a non-local account, where can the header be fetched?
AvatarRemoteURL string
// ID of the header as a media attachment
HeaderMediaAttachmentID string `pg:"type:CHAR(26)"`
HeaderMediaAttachment *MediaAttachment `pg:"rel:has-one"`
// For a non-local account, where can the header be fetched?
HeaderRemoteURL string
// DisplayName for this account. Can be empty, then just the Username will be used for display purposes.

View file

@ -20,6 +20,7 @@ type Instance struct {
SuspendedAt time.Time
// ID of any existing domain block for this instance in the database
DomainBlockID string `pg:"type:CHAR(26)"`
DomainBlock *DomainBlock `pg:"rel:has-one"`
// Short description of this instance
ShortDescription string
// Longer description of this instance
@ -32,6 +33,7 @@ type Instance struct {
ContactAccountUsername string
// Contact account ID in the database for this instance
ContactAccountID string `pg:"type:CHAR(26)"`
ContactAccount *Account `pg:"rel:has-one"`
// Reputation score of this instance
Reputation int64 `pg:",notnull,default:0"`
// Version of the software used on this instance

View file

@ -47,24 +47,24 @@ type Status struct {
// is this status from a local account?
Local bool
// which account posted this status?
AccountID string `pg:"type:CHAR(26),notnull"`
Account *Account `pg:"rel:has-one"`
AccountID string `pg:"type:CHAR(26),notnull"`
Account *Account `pg:"rel:has-one"`
// AP uri of the owner of this status
AccountURI string
// id of the status this status is a reply to
InReplyToID string `pg:"type:CHAR(26)"`
InReplyTo *Status `pg:"-"`
InReplyToID string `pg:"type:CHAR(26)"`
InReplyTo *Status `pg:"rel:has-one"`
// AP uri of the status this status is a reply to
InReplyToURI string
// id of the account that this status replies to
InReplyToAccountID string `pg:"type:CHAR(26)"`
InReplyToAccount *Account `pg:"-"`
InReplyToAccountID string `pg:"type:CHAR(26)"`
InReplyToAccount *Account `pg:"rel:has-one"`
// id of the status this status is a boost of
BoostOfID string `pg:"type:CHAR(26)"`
BoostOf *Status `pg:"-"`
BoostOfID string `pg:"type:CHAR(26)"`
BoostOf *Status `pg:"rel:has-one"`
// id of the account that owns the boosted status
BoostOfAccountID string `pg:"type:CHAR(26)"`
BoostOfAccount *Account `pg:"-"`
BoostOfAccountID string `pg:"type:CHAR(26)"`
BoostOfAccount *Account `pg:"rel:has-one"`
// cw string for this status
ContentWarning string
// visibility entry for this status
@ -74,8 +74,8 @@ type Status struct {
// what language is this status written in?
Language string
// Which application was used to create this status?
CreatedWithApplicationID string `pg:"type:CHAR(26)"`
CreatedWithApplication *Application `pg:"-"`
CreatedWithApplicationID string `pg:"type:CHAR(26)"`
CreatedWithApplication *Application `pg:"rel:has-one"`
// advanced visibility for this status
VisibilityAdvanced *VisibilityAdvanced
// What is the activitystreams type of this status? See: https://www.w3.org/TR/activitystreams-vocabulary/#object-types
@ -103,13 +103,6 @@ type Status struct {
GTSEmojis []*Emoji `pg:"-"`
// MediaAttachments used in this status
GTSMediaAttachments []*MediaAttachment `pg:"-"`
// Status being replied to
// Account being replied to
// Status being boosted
// Account of the boosted status
}

View file

@ -27,11 +27,11 @@ import (
)
type clientStore struct {
db db.DB
db db.Basic
}
// NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend.
func NewClientStore(db db.DB) oauth2.ClientStore {
func NewClientStore(db db.Basic) oauth2.ClientStore {
pts := &clientStore{
db: db,
}

View file

@ -99,7 +99,7 @@ func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() {
// try to get the deleted client; we should get an error
deletedClient, err := cs.GetByID(context.Background(), suite.testClientID)
suite.Assert().Nil(deletedClient)
suite.Assert().EqualValues(db.ErrNoEntries{}, err)
suite.Assert().EqualValues(db.ErrNoEntries, err)
}
func TestPgClientStoreTestSuite(t *testing.T) {

View file

@ -66,7 +66,7 @@ type s struct {
}
// New returns a new oauth server that implements the Server interface
func New(database db.DB, log *logrus.Logger) Server {
func New(database db.Basic, log *logrus.Logger) Server {
ts := newTokenStore(context.Background(), database, log)
cs := NewClientStore(database)

View file

@ -34,7 +34,7 @@ import (
// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend.
type tokenStore struct {
oauth2.TokenStore
db db.DB
db db.Basic
log *logrus.Logger
}
@ -42,7 +42,7 @@ type tokenStore struct {
//
// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
// the tokens in the DB once per minute and deletes any that have expired.
func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.TokenStore {
func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2.TokenStore {
pts := &tokenStore{
db: db,
log: log,

View file

@ -33,7 +33,7 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
// make sure the target account actually exists in our db
targetAcct := &gtsmodel.Account{}
if err := p.db.GetByID(targetAccountID, targetAcct); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: account %s not found in the db: %s", targetAccountID, err))
}
}

View file

@ -42,7 +42,7 @@ func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apim
// make sure the target account actually exists in our db
targetAcct := &gtsmodel.Account{}
if err := p.db.GetByID(form.ID, targetAcct); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: account %s not found in the db: %s", form.ID, err))
}
}

View file

@ -135,7 +135,7 @@ selectStatusesLoop:
for {
statuses, err := p.db.GetAccountStatuses(account.ID, 20, false, maxID, false, false)
if err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// no statuses left for this instance so we're done
l.Infof("Delete: done iterating through statuses for account %s", account.Username)
break selectStatusesLoop
@ -158,7 +158,7 @@ selectStatusesLoop:
}
if err := p.db.DeleteByID(s.ID, s); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// actual error has occurred
l.Errorf("Delete: db error status %s for account %s: %s", s.ID, account.Username, err)
break selectStatusesLoop
@ -168,7 +168,7 @@ selectStatusesLoop:
// if there are any boosts of this status, delete them as well
boosts := []*gtsmodel.Status{}
if err := p.db.GetWhere([]db.Where{{Key: "boost_of_id", Value: s.ID}}, &boosts); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// an actual error has occurred
l.Errorf("Delete: db error selecting boosts of status %s for account %s: %s", s.ID, account.Username, err)
break selectStatusesLoop
@ -190,7 +190,7 @@ selectStatusesLoop:
}
if err := p.db.DeleteByID(b.ID, b); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// actual error has occurred
l.Errorf("Delete: db error deleting boost with id %s: %s", b.ID, err)
break selectStatusesLoop

View file

@ -30,7 +30,7 @@ import (
func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) {
targetAccount := &gtsmodel.Account{}
if err := p.db.GetByID(targetAccountID, targetAccount); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return nil, errors.New("account not found")
}
return nil, fmt.Errorf("db error: %s", err)

View file

@ -40,7 +40,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco
followers := []gtsmodel.Follow{}
accounts := []apimodel.Account{}
if err := p.db.GetAccountFollowers(targetAccountID, &followers, false); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return accounts, nil
}
return nil, gtserror.NewErrorInternalError(err)
@ -57,7 +57,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco
a := &gtsmodel.Account{}
if err := p.db.GetByID(f.AccountID, a); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
continue
}
return nil, gtserror.NewErrorInternalError(err)

View file

@ -40,7 +40,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco
following := []gtsmodel.Follow{}
accounts := []apimodel.Account{}
if err := p.db.GetAccountFollowing(targetAccountID, &following); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return accounts, nil
}
return nil, gtserror.NewErrorInternalError(err)
@ -57,7 +57,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco
a := &gtsmodel.Account{}
if err := p.db.GetByID(f.TargetAccountID, a); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
continue
}
return nil, gtserror.NewErrorInternalError(err)

View file

@ -30,7 +30,7 @@ import (
func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) {
targetAccount := &gtsmodel.Account{}
if err := p.db.GetByID(targetAccountID, targetAccount); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no entry found for account id %s", targetAccountID))
}
return nil, gtserror.NewErrorInternalError(err)
@ -39,7 +39,7 @@ func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccou
apiStatuses := []apimodel.Status{}
statuses, err := p.db.GetAccountStatuses(targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
if err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return apiStatuses, nil
}
return nil, gtserror.NewErrorInternalError(err)

View file

@ -31,7 +31,7 @@ func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccou
// make sure the target account actually exists in our db
targetAcct := &gtsmodel.Account{}
if err := p.db.GetByID(targetAccountID, targetAcct); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockRemove: account %s not found in the db: %s", targetAccountID, err))
}
}

View file

@ -40,7 +40,7 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco
// make sure the target account actually exists in our db
targetAcct := &gtsmodel.Account{}
if err := p.db.GetByID(targetAccountID, targetAcct); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("AccountFollowRemove: account %s not found in the db: %s", targetAccountID, err))
}
}

View file

@ -36,7 +36,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string,
domainBlock := &gtsmodel.DomainBlock{}
err := p.db.GetWhere([]db.Where{{Key: "domain", Value: domain, CaseInsensitive: true}}, domainBlock)
if err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// something went wrong in the DB
return nil, gtserror.NewErrorInternalError(fmt.Errorf("DomainBlockCreate: db error checking for existence of domain block %s: %s", domain, err))
}
@ -60,7 +60,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string,
// put the new block in the database
if err := p.db.Put(domainBlock); err != nil {
if _, ok := err.(db.ErrAlreadyExists); !ok {
if err != db.ErrNoEntries {
// there's a real error creating the block
return nil, gtserror.NewErrorInternalError(fmt.Errorf("DomainBlockCreate: db error putting new domain block %s: %s", domain, err))
}
@ -125,7 +125,7 @@ selectAccountsLoop:
for {
accounts, err := p.db.GetAccountsForInstance(block.Domain, maxID, limit)
if err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// no accounts left for this instance so we're done
l.Infof("domainBlockProcessSideEffects: done iterating through accounts for domain %s", block.Domain)
break selectAccountsLoop

View file

@ -32,7 +32,7 @@ func (p *processor) DomainBlockDelete(account *gtsmodel.Account, id string) (*ap
domainBlock := &gtsmodel.DomainBlock{}
if err := p.db.GetByID(id, domainBlock); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
}

View file

@ -31,7 +31,7 @@ func (p *processor) DomainBlockGet(account *gtsmodel.Account, id string, export
domainBlock := &gtsmodel.DomainBlock{}
if err := p.db.GetByID(id, domainBlock); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
}

View file

@ -29,7 +29,7 @@ func (p *processor) DomainBlocksGet(account *gtsmodel.Account, export bool) ([]*
domainBlocks := []*gtsmodel.DomainBlock{}
if err := p.db.GetAll(&domainBlocks); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
}

View file

@ -31,7 +31,7 @@ import (
func (p *processor) BlocksGet(authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) {
accounts, nextMaxID, prevMinID, err := p.db.GetAccountBlocks(authed.Account.ID, maxID, sinceID, limit)
if err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are just no entries
return &apimodel.BlocksResponse{
Accounts: []*apimodel.Account{},

View file

@ -29,7 +29,7 @@ import (
func (p *processor) FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) {
frs := []gtsmodel.FollowRequest{}
if err := p.db.GetAccountFollowRequests(auth.Account.ID, &frs); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
return nil, gtserror.NewErrorInternalError(err)
}
}

View file

@ -74,7 +74,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error {
// notification exists already so just continue
continue
}
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// there's a real error in the db
return fmt.Errorf("notifyStatus: error checking existence of notification for mention with id %s : %s", m.ID, err)
}

View file

@ -101,7 +101,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
incomingAnnounce.ID = incomingAnnounceID
if err := p.db.Put(incomingAnnounce); err != nil {
if _, ok := err.(db.ErrAlreadyExists); !ok {
if err != db.ErrNoEntries {
return fmt.Errorf("error adding dereferenced announce to the db: %s", err)
}
}

View file

@ -12,7 +12,7 @@ import (
func (p *processor) Delete(mediaAttachmentID string) gtserror.WithCode {
a := &gtsmodel.MediaAttachment{}
if err := p.db.GetByID(mediaAttachmentID, a); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// attachment already gone
return nil
}
@ -38,7 +38,7 @@ func (p *processor) Delete(mediaAttachmentID string) gtserror.WithCode {
// delete the attachment
if err := p.db.DeleteByID(mediaAttachmentID, a); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
errs = append(errs, fmt.Sprintf("remove attachment: %s", err))
}
}

View file

@ -31,7 +31,7 @@ import (
func (p *processor) GetMedia(account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) {
attachment := &gtsmodel.MediaAttachment{}
if err := p.db.GetByID(mediaAttachmentID, attachment); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// attachment doesn't exist
return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db"))
}

View file

@ -32,7 +32,7 @@ import (
func (p *processor) Update(account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) {
attachment := &gtsmodel.MediaAttachment{}
if err := p.db.GetByID(mediaAttachmentID, attachment); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// attachment doesn't exist
return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db"))
}

View file

@ -196,7 +196,7 @@ func (p *processor) searchAccountByMention(authed *oauth.Auth, mention string, r
return maybeAcct, nil
}
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// if it's not errNoEntries there's been a real database error so bail at this point
return nil, fmt.Errorf("searchAccountByMention: database error: %s", err)
}

View file

@ -19,7 +19,7 @@ func (p *processor) Context(account *gtsmodel.Account, targetStatusID string) (*
targetStatus := &gtsmodel.Status{}
if err := p.db.GetByID(targetStatusID, targetStatus); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(err)
}
return nil, gtserror.NewErrorInternalError(err)

View file

@ -15,7 +15,7 @@ func (p *processor) Delete(account *gtsmodel.Account, targetStatusID string) (*a
l.Tracef("going to search for target status %s", targetStatusID)
targetStatus := &gtsmodel.Status{}
if err := p.db.GetByID(targetStatusID, targetStatus); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
// status is already gone

View file

@ -57,7 +57,7 @@ func (p *processor) Unboost(account *gtsmodel.Account, application *gtsmodel.App
if err != nil {
// something went wrong in the db finding the boost
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error fetching existing boost from database: %s", err))
}
// we just don't have a boost

View file

@ -45,7 +45,7 @@ func (p *processor) Unfave(account *gtsmodel.Account, targetStatusID string) (*a
}
if err != nil {
// something went wrong in the db finding the fave
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error fetching existing fave from database: %s", err))
}
// we just don't have a fave

View file

@ -99,7 +99,7 @@ func (p *processor) ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, th
repliedAccount := &gtsmodel.Account{}
// check replied status exists + is replyable
if err := p.db.GetByID(form.InReplyToID, repliedStatus); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return fmt.Errorf("status with id %s not replyable because it doesn't exist", form.InReplyToID)
}
return fmt.Errorf("status with id %s not replyable: %s", form.InReplyToID, err)
@ -113,14 +113,14 @@ func (p *processor) ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, th
// check replied account is known to us
if err := p.db.GetByID(repliedStatus.AccountID, repliedAccount); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return fmt.Errorf("status with id %s not replyable because account id %s is not known", form.InReplyToID, repliedStatus.AccountID)
}
return fmt.Errorf("status with id %s not replyable: %s", form.InReplyToID, err)
}
// check if a block exists
if blocked, err := p.db.Blocked(thisAccountID, repliedAccount.ID); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
return fmt.Errorf("status with id %s not replyable: %s", form.InReplyToID, err)
}
} else if blocked {

View file

@ -76,7 +76,7 @@ func (p *processor) HomeTimelineGet(authed *oauth.Auth, maxID string, sinceID st
func (p *processor) PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) {
statuses, err := p.db.GetPublicTimelineForAccount(authed.Account.ID, maxID, sinceID, minID, limit, local)
if err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are just no entries left
return &apimodel.StatusTimelineResponse{
Statuses: []*apimodel.Status{},
@ -97,7 +97,7 @@ func (p *processor) PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID
func (p *processor) FavedTimelineGet(authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.StatusTimelineResponse, gtserror.WithCode) {
statuses, nextMaxID, prevMinID, err := p.db.GetFavedTimelineForAccount(authed.Account.ID, maxID, minID, limit)
if err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are just no entries left
return &apimodel.StatusTimelineResponse{
Statuses: []*apimodel.Status{},
@ -122,7 +122,7 @@ func (p *processor) filterPublicStatuses(authed *oauth.Auth, statuses []*gtsmode
for _, s := range statuses {
targetAccount := &gtsmodel.Account{}
if err := p.db.GetByID(s.AccountID, targetAccount); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
l.Debugf("filterPublicStatuses: skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
continue
}
@ -157,7 +157,7 @@ func (p *processor) filterFavedStatuses(authed *oauth.Auth, statuses []*gtsmodel
for _, s := range statuses {
targetAccount := &gtsmodel.Account{}
if err := p.db.GetByID(s.AccountID, targetAccount); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
l.Debugf("filterFavedStatuses: skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
continue
}

View file

@ -49,7 +49,7 @@ func useSession(cfg *config.Config, dbService db.DB, engine *gin.Engine) error {
// check if we have a saved router session already
routerSessions := []*gtsmodel.RouterSession{}
if err := dbService.GetAll(&routerSessions); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// proper error occurred
return err
}

View file

@ -52,7 +52,7 @@ grabloop:
for ; len(filtered) < amount && i < 5; i = i + 1 { // try the grabloop 5 times only
statuses, err := t.db.GetHomeTimelineForAccount(t.accountID, "", "", offsetStatus, amount, false)
if err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
break grabloop // we just don't have enough statuses left in the db so index what we've got and then bail
}
return fmt.Errorf("IndexBefore: error getting statuses from db: %s", err)
@ -132,7 +132,7 @@ grabloop:
l.Tracef("entering grabloop; i is %d; len(filtered) is %d", i, len(filtered))
statuses, err := t.db.GetHomeTimelineForAccount(t.accountID, offsetStatus, "", "", amount, false)
if err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
break grabloop // we just don't have enough statuses left in the db so index what we've got and then bail
}
return fmt.Errorf("IndexBehind: error getting statuses from db: %s", err)

View file

@ -95,7 +95,7 @@ prepareloop:
if preparing {
if err := t.prepare(entry.statusID); err != nil {
// there's been an error
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// it's a real error
return fmt.Errorf("PrepareBehind: error preparing status with id %s: %s", entry.statusID, err)
}
@ -150,7 +150,7 @@ prepareloop:
if preparing {
if err := t.prepare(entry.statusID); err != nil {
// there's been an error
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// it's a real error
return fmt.Errorf("PrepareBefore: error preparing status with id %s: %s", entry.statusID, err)
}
@ -205,7 +205,7 @@ prepareloop:
if err := t.prepare(entry.statusID); err != nil {
// there's been an error
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// it's a real error
return fmt.Errorf("PrepareFromTop: error preparing status with id %s: %s", entry.statusID, err)
}

View file

@ -44,7 +44,7 @@ func (c *converter) ASRepresentationToAccount(accountable ap.Accountable, update
// we already know this account so we can skip generating it
return acct, nil
}
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// we don't know the account and there's been a real error
return nil, fmt.Errorf("error getting account with uri %s from the database: %s", uri.String(), err)
}

View file

@ -40,7 +40,7 @@ func (c *converter) AccountToMastoSensitive(a *gtsmodel.Account) (*model.Account
// check pending follow requests aimed at this account
fr := []gtsmodel.FollowRequest{}
if err := c.db.GetAccountFollowRequests(a.ID, &fr); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
return nil, fmt.Errorf("error getting follow requests: %s", err)
}
}
@ -63,41 +63,27 @@ func (c *converter) AccountToMastoSensitive(a *gtsmodel.Account) (*model.Account
func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, error) {
// count followers
followers := []gtsmodel.Follow{}
if err := c.db.GetAccountFollowers(a.ID, &followers, false); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting followers: %s", err)
}
}
var followersCount int
if followers != nil {
followersCount = len(followers)
followersCount, err := c.db.CountAccountFollowers(a.ID, false)
if err != nil {
return nil, fmt.Errorf("error counting followers: %s", err)
}
// count following
following := []gtsmodel.Follow{}
if err := c.db.GetAccountFollowing(a.ID, &following); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting following: %s", err)
}
}
var followingCount int
if following != nil {
followingCount = len(following)
followingCount, err := c.db.CountAccountFollowing(a.ID, false)
if err != nil {
return nil, fmt.Errorf("error counting following: %s", err)
}
// count statuses
statusesCount, err := c.db.GetAccountStatusesCount(a.ID)
statusesCount, err := c.db.CountAccountStatuses(a.ID)
if err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting last statuses: %s", err)
}
return nil, fmt.Errorf("error getting last statuses: %s", err)
}
// check when the last status was
lastStatus := &gtsmodel.Status{}
if err := c.db.GetAccountLastStatus(a.ID, lastStatus); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
return nil, fmt.Errorf("error getting last status: %s", err)
}
}
@ -107,23 +93,20 @@ func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, e
}
// build the avatar and header URLs
avi := &gtsmodel.MediaAttachment{}
if err := c.db.GetAccountAvatar(avi, a.ID); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting avatar: %s", err)
}
}
aviURL := avi.URL
aviURLStatic := avi.Thumbnail.URL
header := &gtsmodel.MediaAttachment{}
if err := c.db.GetAccountHeader(header, a.ID); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
return nil, fmt.Errorf("error getting header: %s", err)
}
var aviURL string
var aviURLStatic string
if a.AvatarMediaAttachment != nil {
aviURL = a.AvatarMediaAttachment.URL
aviURLStatic = a.AvatarMediaAttachment.Thumbnail.URL
}
var headerURL string
var headerURLStatic string
if a.HeaderMediaAttachment != nil {
headerURL = a.HeaderMediaAttachment.URL
headerURLStatic = a.HeaderMediaAttachment.Thumbnail.URL
}
headerURL := header.URL
headerURLStatic := header.Thumbnail.URL
// get the fields set on this account
fields := []model.Field{}
@ -585,13 +568,10 @@ func (c *converter) InstanceToMasto(i *gtsmodel.Instance) (*model.Instance, erro
}
// get the instance account if it exists and just skip if it doesn't
ia := &gtsmodel.Account{}
if err := c.db.GetWhere([]db.Where{{Key: "username", Value: i.Domain}}, ia); err == nil {
// instance account exists, get the header for the account if it exists
attachment := &gtsmodel.MediaAttachment{}
if err := c.db.GetAccountHeader(attachment, ia.ID); err == nil {
// header exists, set it on the api model
mi.Thumbnail = attachment.URL
ia, err := c.db.GetInstanceAccount("")
if err == nil {
if ia.HeaderMediaAttachment != nil {
mi.Thumbnail = ia.HeaderMediaAttachment.URL
}
}

View file

@ -45,7 +45,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount
targetUser := &gtsmodel.User{}
if err := f.db.GetWhere([]db.Where{{Key: "account_id", Value: targetAccount.ID}}, targetUser); err != nil {
l.Debug("target user could not be selected")
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return false, nil
}
return false, fmt.Errorf("StatusVisible: db error selecting user for local target account %s: %s", targetAccount.ID, err)
@ -76,7 +76,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount
if err := f.db.GetWhere([]db.Where{{Key: "account_id", Value: requestingAccount.ID}}, requestingUser); err != nil {
// if the requesting account is local but doesn't have a corresponding user in the db this is a problem
l.Debug("requesting user could not be selected")
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return false, nil
}
return false, fmt.Errorf("StatusVisible: db error selecting user for local requesting account %s: %s", requestingAccount.ID, err)

View file

@ -112,7 +112,7 @@ func (f *filter) blockedDomain(host string) (bool, error) {
return true, nil
}
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are no entries so there's no block
return false, nil
}