mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-12-30 00:36:14 -06:00
continue moving db stuff around
This commit is contained in:
parent
f409f7c65a
commit
15153ee0c8
72 changed files with 923 additions and 614 deletions
|
|
@ -116,7 +116,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
|
|||
return user, nil
|
||||
}
|
||||
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// we have an actual error in the database
|
||||
return nil, fmt.Errorf("error checking database for email %s: %s", claims.Email, err)
|
||||
}
|
||||
|
|
@ -128,7 +128,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
|
|||
return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email)
|
||||
}
|
||||
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// we have an actual error in the database
|
||||
return nil, fmt.Errorf("error checking database for email %s: %s", claims.Email, err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ func (m *Module) blockedDomain(host string) (bool, error) {
|
|||
return true, nil
|
||||
}
|
||||
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are no entries so there's no block
|
||||
return false, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,60 +6,66 @@ type Account interface {
|
|||
// GetAccountByUserID is a shortcut for the common action of fetching an account corresponding to a user ID.
|
||||
// The given account pointer will be set to the result of the query, whatever it is.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetAccountByUserID(userID string, account *gtsmodel.Account) error
|
||||
GetAccountByUserID(userID string, account *gtsmodel.Account) DBError
|
||||
|
||||
// GetAccountByID returns one account with the given ID, or an error if something goes wrong.
|
||||
GetAccountByID(id string) (*gtsmodel.Account, DBError)
|
||||
|
||||
// GetAccountByURI returns one account with the given URI, or an error if something goes wrong.
|
||||
GetAccountByURI(uri string) (*gtsmodel.Account, DBError)
|
||||
|
||||
// GetLocalAccountByUsername is a shortcut for the common action of fetching an account ON THIS INSTANCE
|
||||
// according to its username, which should be unique.
|
||||
// The given account pointer will be set to the result of the query, whatever it is.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetLocalAccountByUsername(username string, account *gtsmodel.Account) error
|
||||
GetLocalAccountByUsername(username string, account *gtsmodel.Account) DBError
|
||||
|
||||
// GetAccountFollowRequests is a shortcut for the common action of fetching a list of follow requests targeting the given account ID.
|
||||
// The given slice 'followRequests' will be set to the result of the query, whatever it is.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) error
|
||||
GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) DBError
|
||||
|
||||
// GetAccountFollowing is a shortcut for the common action of fetching a list of accounts that accountID is following.
|
||||
// The given slice 'following' will be set to the result of the query, whatever it is.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) error
|
||||
GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) DBError
|
||||
|
||||
CountAccountFollowing(accountID string, localOnly bool) (int, DBError)
|
||||
|
||||
// GetAccountFollowers is a shortcut for the common action of fetching a list of accounts that accountID is followed by.
|
||||
// The given slice 'followers' will be set to the result of the query, whatever it is.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
//
|
||||
// If localOnly is set to true, then only followers from *this instance* will be returned.
|
||||
GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) error
|
||||
GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) DBError
|
||||
|
||||
CountAccountFollowers(accountID string, localOnly bool) (int, DBError)
|
||||
|
||||
// GetAccountFaves is a shortcut for the common action of fetching a list of faves made by the given accountID.
|
||||
// The given slice 'faves' will be set to the result of the query, whatever it is.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) error
|
||||
GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) DBError
|
||||
|
||||
// GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID.
|
||||
GetAccountStatusesCount(accountID string) (int, error)
|
||||
CountAccountStatuses(accountID string) (int, DBError)
|
||||
|
||||
// GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided
|
||||
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
|
||||
// be very memory intensive so you probably shouldn't do this!
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error)
|
||||
GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, DBError)
|
||||
|
||||
GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error)
|
||||
GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, DBError)
|
||||
|
||||
// GetAccountLastStatus simply gets the most recent status by the given account.
|
||||
// The given slice 'status' pointer will be set to the result of the query, whatever it is.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetAccountLastStatus(accountID string, status *gtsmodel.Status) error
|
||||
GetAccountLastStatus(accountID string, status *gtsmodel.Status) DBError
|
||||
|
||||
// SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment.
|
||||
SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) error
|
||||
SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) DBError
|
||||
|
||||
// GetHeaderAvatarForAccountID gets the current avatar for the given account ID.
|
||||
// The passed mediaAttachment pointer will be populated with the value of the avatar, if it exists.
|
||||
GetAccountAvatar(avatar *gtsmodel.MediaAttachment, accountID string) error
|
||||
|
||||
// GetAccountHeader gets the current header for the given account ID.
|
||||
// The passed mediaAttachment pointer will be populated with the value of the header, if it exists.
|
||||
GetAccountHeader(header *gtsmodel.MediaAttachment, accountID string) error
|
||||
// GetInstanceAccount returns the instance account for the given domain.
|
||||
// If domain is empty, this instance account will be returned.
|
||||
GetInstanceAccount(domain string) (*gtsmodel.Account, DBError)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,26 +9,26 @@ import (
|
|||
type Admin interface {
|
||||
// IsUsernameAvailable checks whether a given username is available on our domain.
|
||||
// Returns an error if the username is already taken, or something went wrong in the db.
|
||||
IsUsernameAvailable(username string) error
|
||||
IsUsernameAvailable(username string) DBError
|
||||
|
||||
// IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
|
||||
// Return an error if:
|
||||
// A) the email is already associated with an account
|
||||
// B) we block signups from this email domain
|
||||
// C) something went wrong in the db
|
||||
IsEmailAvailable(email string) error
|
||||
IsEmailAvailable(email string) DBError
|
||||
|
||||
// NewSignup creates a new user in the database with the given parameters.
|
||||
// By the time this function is called, it should be assumed that all the parameters have passed validation!
|
||||
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, error)
|
||||
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, DBError)
|
||||
|
||||
// CreateInstanceAccount creates an account in the database with the same username as the instance host value.
|
||||
// Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
|
||||
// This is needed for things like serving files that belong to the instance and not an individual user/account.
|
||||
CreateInstanceAccount() error
|
||||
CreateInstanceAccount() DBError
|
||||
|
||||
// CreateInstanceInstance creates an instance in the database with the same domain as the instance host value.
|
||||
// Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'.
|
||||
// This is needed for things like serving instance information through /api/v1/instance
|
||||
CreateInstanceInstance() error
|
||||
CreateInstanceInstance() DBError
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,60 +5,60 @@ import "context"
|
|||
type Basic interface {
|
||||
// CreateTable creates a table for the given interface.
|
||||
// For implementations that don't use tables, this can just return nil.
|
||||
CreateTable(i interface{}) error
|
||||
CreateTable(i interface{}) DBError
|
||||
|
||||
// DropTable drops the table for the given interface.
|
||||
// For implementations that don't use tables, this can just return nil.
|
||||
DropTable(i interface{}) error
|
||||
DropTable(i interface{}) DBError
|
||||
|
||||
// Stop should stop and close the database connection cleanly, returning an error if this is not possible.
|
||||
// If the database implementation doesn't need to be stopped, this can just return nil.
|
||||
Stop(ctx context.Context) error
|
||||
Stop(ctx context.Context) DBError
|
||||
|
||||
// IsHealthy should return nil if the database connection is healthy, or an error if not.
|
||||
IsHealthy(ctx context.Context) error
|
||||
IsHealthy(ctx context.Context) DBError
|
||||
|
||||
// GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry,
|
||||
// for other implementations (for example, in-memory) it might just be the key of a map.
|
||||
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetByID(id string, i interface{}) error
|
||||
GetByID(id string, i interface{}) DBError
|
||||
|
||||
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
|
||||
// name of the key to select from.
|
||||
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetWhere(where []Where, i interface{}) error
|
||||
GetWhere(where []Where, i interface{}) DBError
|
||||
|
||||
// GetAll will try to get all entries of type i.
|
||||
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||
// In case of no entries, a 'no entries' error will be returned
|
||||
GetAll(i interface{}) error
|
||||
GetAll(i interface{}) DBError
|
||||
|
||||
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
|
||||
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||
Put(i interface{}) error
|
||||
Put(i interface{}) DBError
|
||||
|
||||
// Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/
|
||||
// It is up to the implementation to figure out how to store it, and using what key.
|
||||
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||
Upsert(i interface{}, conflictColumn string) error
|
||||
Upsert(i interface{}, conflictColumn string) DBError
|
||||
|
||||
// UpdateByID updates i with id id.
|
||||
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||
UpdateByID(id string, i interface{}) error
|
||||
UpdateByID(id string, i interface{}) DBError
|
||||
|
||||
// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
|
||||
UpdateOneByID(id string, key string, value interface{}, i interface{}) error
|
||||
UpdateOneByID(id string, key string, value interface{}, i interface{}) DBError
|
||||
|
||||
// UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
|
||||
UpdateWhere(where []Where, key string, value interface{}, i interface{}) error
|
||||
UpdateWhere(where []Where, key string, value interface{}, i interface{}) DBError
|
||||
|
||||
// DeleteByID removes i with id id.
|
||||
// If i didn't exist anyway, then no error should be returned.
|
||||
DeleteByID(id string, i interface{}) error
|
||||
DeleteByID(id string, i interface{}) DBError
|
||||
|
||||
// DeleteWhere deletes i where key = value
|
||||
// If i didn't exist anyway, then no error should be returned.
|
||||
DeleteWhere(where []Where, i interface{}) error
|
||||
DeleteWhere(where []Where, i interface{}) DBError
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,16 +18,12 @@
|
|||
|
||||
package db
|
||||
|
||||
// ErrNoEntries is to be returned from the DB interface when no entries are found for a given query.
|
||||
type ErrNoEntries struct{}
|
||||
import "fmt"
|
||||
|
||||
func (e ErrNoEntries) Error() string {
|
||||
return "no entries"
|
||||
}
|
||||
type DBError error
|
||||
|
||||
// ErrAlreadyExists is to be returned from the DB interface when an entry already exists for a given query or its constraints.
|
||||
type ErrAlreadyExists struct{}
|
||||
|
||||
func (e ErrAlreadyExists) Error() string {
|
||||
return "already exists"
|
||||
}
|
||||
var (
|
||||
ErrNoEntries DBError = fmt.Errorf("no entries")
|
||||
ErrAlreadyExists DBError = fmt.Errorf("already exists")
|
||||
ErrUnknown DBError = fmt.Errorf("unknown error")
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,14 +4,14 @@ import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
|||
|
||||
type Instance interface {
|
||||
// GetUserCountForInstance returns the number of known accounts registered with the given domain.
|
||||
GetUserCountForInstance(domain string) (int, error)
|
||||
GetUserCountForInstance(domain string) (int, DBError)
|
||||
|
||||
// GetStatusCountForInstance returns the number of known statuses posted from the given domain.
|
||||
GetStatusCountForInstance(domain string) (int, error)
|
||||
GetStatusCountForInstance(domain string) (int, DBError)
|
||||
|
||||
// GetDomainCountForInstance returns the number of known instances known that the given domain federates with.
|
||||
GetDomainCountForInstance(domain string) (int, error)
|
||||
GetDomainCountForInstance(domain string) (int, DBError)
|
||||
|
||||
// GetAccountsForInstance returns a slice of accounts from the given instance, arranged by ID.
|
||||
GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error)
|
||||
GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, DBError)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,5 +4,5 @@ import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
|||
|
||||
type Notification interface {
|
||||
// GetNotificationsForAccount returns a list of notifications that pertain to the given accountID.
|
||||
GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error)
|
||||
GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, DBError)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,62 +1,100 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package pg
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
)
|
||||
|
||||
func (ps *postgresService) GetAccountHeader(header *gtsmodel.MediaAttachment, accountID string) error {
|
||||
acct := >smodel.Account{}
|
||||
if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return db.ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if acct.HeaderMediaAttachmentID == "" {
|
||||
return db.ErrNoEntries{}
|
||||
}
|
||||
|
||||
if err := ps.conn.Model(header).Where("id = ?", acct.HeaderMediaAttachmentID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return db.ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
type accountDB struct {
|
||||
config *config.Config
|
||||
conn *pg.DB
|
||||
log *logrus.Logger
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetAccountAvatar(avatar *gtsmodel.MediaAttachment, accountID string) error {
|
||||
acct := >smodel.Account{}
|
||||
if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return db.ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if acct.AvatarMediaAttachmentID == "" {
|
||||
return db.ErrNoEntries{}
|
||||
}
|
||||
|
||||
if err := ps.conn.Model(avatar).Where("id = ?", acct.AvatarMediaAttachmentID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return db.ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query {
|
||||
return a.conn.Model(account).
|
||||
Relation("AvatarMediaAttachment").
|
||||
Relation("HeaderMediaAttachment")
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetAccountLastStatus(accountID string, status *gtsmodel.Status) error {
|
||||
if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
|
||||
func (a *accountDB) processResponse(account *gtsmodel.Account, err error) (*gtsmodel.Account, db.DBError) {
|
||||
switch err {
|
||||
case pg.ErrNoRows:
|
||||
return nil, db.ErrNoEntries
|
||||
case nil:
|
||||
return account, nil
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.DBError) {
|
||||
account := >smodel.Account{}
|
||||
|
||||
q := a.newAccountQ(account).
|
||||
Where("account.id = ?", id)
|
||||
|
||||
return a.processResponse(account, q.Select())
|
||||
}
|
||||
|
||||
func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.DBError) {
|
||||
account := >smodel.Account{}
|
||||
|
||||
q := a.newAccountQ(account).
|
||||
Where("account.uri = ?", uri)
|
||||
|
||||
return a.processResponse(account, q.Select())
|
||||
}
|
||||
|
||||
func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.DBError) {
|
||||
account := >smodel.Account{}
|
||||
|
||||
q := a.newAccountQ(account)
|
||||
|
||||
if domain == "" {
|
||||
q = q.
|
||||
Where("account.username = ?", domain).
|
||||
Where("account.domain = ?", domain)
|
||||
} else {
|
||||
q = q.
|
||||
Where("account.username = ?", domain).
|
||||
Where("? IS NULL", pg.Ident("domain"))
|
||||
}
|
||||
|
||||
return a.processResponse(account, q.Select())
|
||||
}
|
||||
|
||||
func (a *accountDB) GetAccountLastStatus(accountID string, status *gtsmodel.Status) db.DBError {
|
||||
if err := a.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return db.ErrNoEntries{}
|
||||
return db.ErrNoEntries
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
@ -64,7 +102,7 @@ func (ps *postgresService) GetAccountLastStatus(accountID string, status *gtsmod
|
|||
|
||||
}
|
||||
|
||||
func (ps *postgresService) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) error {
|
||||
func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.DBError {
|
||||
if mediaAttachment.Avatar && mediaAttachment.Header {
|
||||
return errors.New("one media attachment cannot be both header and avatar")
|
||||
}
|
||||
|
|
@ -79,47 +117,47 @@ func (ps *postgresService) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.Me
|
|||
}
|
||||
|
||||
// TODO: there are probably more side effects here that need to be handled
|
||||
if _, err := ps.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
|
||||
if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := ps.conn.Model(>smodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
|
||||
if _, err := a.conn.Model(>smodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetAccountByUserID(userID string, account *gtsmodel.Account) error {
|
||||
func (a *accountDB) GetAccountByUserID(userID string, account *gtsmodel.Account) db.DBError {
|
||||
user := >smodel.User{
|
||||
ID: userID,
|
||||
}
|
||||
if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
|
||||
if err := a.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return db.ErrNoEntries{}
|
||||
return db.ErrNoEntries
|
||||
}
|
||||
return err
|
||||
}
|
||||
if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
|
||||
if err := a.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return db.ErrNoEntries{}
|
||||
return db.ErrNoEntries
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetLocalAccountByUsername(username string, account *gtsmodel.Account) error {
|
||||
if err := ps.conn.Model(account).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select(); err != nil {
|
||||
func (a *accountDB) GetLocalAccountByUsername(username string, account *gtsmodel.Account) db.DBError {
|
||||
if err := a.conn.Model(account).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return db.ErrNoEntries{}
|
||||
return db.ErrNoEntries
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) error {
|
||||
if err := ps.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil {
|
||||
func (a *accountDB) GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) db.DBError {
|
||||
if err := a.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -128,8 +166,8 @@ func (ps *postgresService) GetAccountFollowRequests(accountID string, followRequ
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) error {
|
||||
if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
|
||||
func (a *accountDB) GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) db.DBError {
|
||||
if err := a.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -138,9 +176,13 @@ func (ps *postgresService) GetAccountFollowing(accountID string, following *[]gt
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) error {
|
||||
func (a *accountDB) CountAccountFollowing(accountID string, localOnly bool) (int, db.DBError) {
|
||||
return a.conn.Model(&[]*gtsmodel.Follow{}).Where("account_id = ?", accountID).Count()
|
||||
}
|
||||
|
||||
q := ps.conn.Model(followers)
|
||||
func (a *accountDB) GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) db.DBError {
|
||||
|
||||
q := a.conn.Model(followers)
|
||||
|
||||
if localOnly {
|
||||
// for local accounts let's get where domain is null OR where domain is an empty string, just to be safe
|
||||
|
|
@ -168,8 +210,12 @@ func (ps *postgresService) GetAccountFollowers(accountID string, followers *[]gt
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) error {
|
||||
if err := ps.conn.Model(faves).Where("account_id = ?", accountID).Select(); err != nil {
|
||||
func (a *accountDB) CountAccountFollowers(accountID string, localOnly bool) (int, db.DBError) {
|
||||
return a.conn.Model(&[]*gtsmodel.Follow{}).Where("target_account_id = ?", accountID).Count()
|
||||
}
|
||||
|
||||
func (a *accountDB) GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) db.DBError {
|
||||
if err := a.conn.Model(faves).Where("account_id = ?", accountID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -178,22 +224,15 @@ func (ps *postgresService) GetAccountFaves(accountID string, faves *[]gtsmodel.S
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetAccountStatusesCount(accountID string) (int, error) {
|
||||
count, err := ps.conn.Model(>smodel.Status{}).Where("account_id = ?", accountID).Count()
|
||||
if err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
func (a *accountDB) CountAccountStatuses(accountID string) (int, db.DBError) {
|
||||
return a.conn.Model(>smodel.Status{}).Where("account_id = ?", accountID).Count()
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error) {
|
||||
ps.log.Debugf("getting statuses for account %s", accountID)
|
||||
func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.DBError) {
|
||||
a.log.Debugf("getting statuses for account %s", accountID)
|
||||
statuses := []*gtsmodel.Status{}
|
||||
|
||||
q := ps.conn.Model(&statuses).Order("id DESC")
|
||||
q := a.conn.Model(&statuses).Order("id DESC")
|
||||
if accountID != "" {
|
||||
q = q.Where("account_id = ?", accountID)
|
||||
}
|
||||
|
|
@ -222,15 +261,57 @@ func (ps *postgresService) GetAccountStatuses(accountID string, limit int, exclu
|
|||
|
||||
if err := q.Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return nil, db.ErrNoEntries{}
|
||||
return nil, db.ErrNoEntries
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(statuses) == 0 {
|
||||
return nil, db.ErrNoEntries{}
|
||||
return nil, db.ErrNoEntries
|
||||
}
|
||||
|
||||
ps.log.Debugf("returning statuses for account %s", accountID)
|
||||
a.log.Debugf("returning statuses for account %s", accountID)
|
||||
return statuses, nil
|
||||
}
|
||||
|
||||
func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.DBError) {
|
||||
blocks := []*gtsmodel.Block{}
|
||||
|
||||
fq := a.conn.Model(&blocks).
|
||||
Where("block.account_id = ?", accountID).
|
||||
Relation("TargetAccount").
|
||||
Order("block.id DESC")
|
||||
|
||||
if maxID != "" {
|
||||
fq = fq.Where("block.id < ?", maxID)
|
||||
}
|
||||
|
||||
if sinceID != "" {
|
||||
fq = fq.Where("block.id > ?", sinceID)
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
fq = fq.Limit(limit)
|
||||
}
|
||||
|
||||
err := fq.Select()
|
||||
if err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return nil, "", "", db.ErrNoEntries
|
||||
}
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
if len(blocks) == 0 {
|
||||
return nil, "", "", db.ErrNoEntries
|
||||
}
|
||||
|
||||
accounts := []*gtsmodel.Account{}
|
||||
for _, b := range blocks {
|
||||
accounts = append(accounts, b.TargetAccount)
|
||||
}
|
||||
|
||||
nextMaxID := blocks[len(blocks)-1].ID
|
||||
prevMinID := blocks[0].ID
|
||||
return accounts, nextMaxID, prevMinID, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,25 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package pg
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
|
|
@ -10,17 +29,27 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func (ps *postgresService) IsUsernameAvailable(username string) error {
|
||||
type adminDB struct {
|
||||
config *config.Config
|
||||
conn *pg.DB
|
||||
log *logrus.Logger
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (a *adminDB) IsUsernameAvailable(username string) db.DBError {
|
||||
// if no error we fail because it means we found something
|
||||
// if error but it's not pg.ErrNoRows then we fail
|
||||
// if err is pg.ErrNoRows we're good, we found nothing so continue
|
||||
if err := ps.conn.Model(>smodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
|
||||
if err := a.conn.Model(>smodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
|
||||
return fmt.Errorf("username %s already in use", username)
|
||||
} else if err != pg.ErrNoRows {
|
||||
return fmt.Errorf("db error: %s", err)
|
||||
|
|
@ -28,7 +57,7 @@ func (ps *postgresService) IsUsernameAvailable(username string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) IsEmailAvailable(email string) error {
|
||||
func (a *adminDB) IsEmailAvailable(email string) db.DBError {
|
||||
// parse the domain from the email
|
||||
m, err := mail.ParseAddress(email)
|
||||
if err != nil {
|
||||
|
|
@ -37,7 +66,7 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
|
|||
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
|
||||
|
||||
// check if the email domain is blocked
|
||||
if err := ps.conn.Model(>smodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
|
||||
if err := a.conn.Model(>smodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
|
||||
// fail because we found something
|
||||
return fmt.Errorf("email domain %s is blocked", domain)
|
||||
} else if err != pg.ErrNoRows {
|
||||
|
|
@ -46,7 +75,7 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
|
|||
}
|
||||
|
||||
// check if this email is associated with a user already
|
||||
if err := ps.conn.Model(>smodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
|
||||
if err := a.conn.Model(>smodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
|
||||
// fail because we found something
|
||||
return fmt.Errorf("email %s already in use", email)
|
||||
} else if err != pg.ErrNoRows {
|
||||
|
|
@ -56,16 +85,16 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, error) {
|
||||
func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.DBError) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
ps.log.Errorf("error creating new rsa key: %s", err)
|
||||
a.log.Errorf("error creating new rsa key: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if something went wrong while creating a user, we might already have an account, so check here first...
|
||||
a := >smodel.Account{}
|
||||
err = ps.conn.Model(a).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
|
||||
acct := >smodel.Account{}
|
||||
err = a.conn.Model(acct).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
|
||||
if err != nil {
|
||||
// there's been an actual error
|
||||
if err != pg.ErrNoRows {
|
||||
|
|
@ -73,13 +102,13 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr
|
|||
}
|
||||
|
||||
// we just don't have an account yet create one
|
||||
newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host)
|
||||
newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
|
||||
newAccountID, err := id.NewRandomULID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a = >smodel.Account{
|
||||
acct = >smodel.Account{
|
||||
ID: newAccountID,
|
||||
Username: username,
|
||||
DisplayName: username,
|
||||
|
|
@ -96,7 +125,7 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr
|
|||
FollowingURI: newAccountURIs.FollowingURI,
|
||||
FeaturedCollectionURI: newAccountURIs.CollectionURI,
|
||||
}
|
||||
if _, err = ps.conn.Model(a).Insert(); err != nil {
|
||||
if _, err = a.conn.Model(acct).Insert(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
|
@ -113,7 +142,7 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr
|
|||
|
||||
u := >smodel.User{
|
||||
ID: newUserID,
|
||||
AccountID: a.ID,
|
||||
AccountID: acct.ID,
|
||||
EncryptedPassword: string(pw),
|
||||
SignUpIP: signUpIP.To4(),
|
||||
Locale: locale,
|
||||
|
|
@ -132,18 +161,18 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr
|
|||
u.Moderator = true
|
||||
}
|
||||
|
||||
if _, err = ps.conn.Model(u).Insert(); err != nil {
|
||||
if _, err = a.conn.Model(u).Insert(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) CreateInstanceAccount() error {
|
||||
username := ps.config.Host
|
||||
func (a *adminDB) CreateInstanceAccount() db.DBError {
|
||||
username := a.config.Host
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
ps.log.Errorf("error creating new rsa key: %s", err)
|
||||
a.log.Errorf("error creating new rsa key: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -152,10 +181,10 @@ func (ps *postgresService) CreateInstanceAccount() error {
|
|||
return err
|
||||
}
|
||||
|
||||
newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host)
|
||||
a := >smodel.Account{
|
||||
newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
|
||||
acct := >smodel.Account{
|
||||
ID: aID,
|
||||
Username: ps.config.Host,
|
||||
Username: a.config.Host,
|
||||
DisplayName: username,
|
||||
URL: newAccountURIs.UserURL,
|
||||
PrivateKey: key,
|
||||
|
|
@ -169,19 +198,19 @@ func (ps *postgresService) CreateInstanceAccount() error {
|
|||
FollowingURI: newAccountURIs.FollowingURI,
|
||||
FeaturedCollectionURI: newAccountURIs.CollectionURI,
|
||||
}
|
||||
inserted, err := ps.conn.Model(a).Where("username = ?", username).SelectOrInsert()
|
||||
inserted, err := a.conn.Model(acct).Where("username = ?", username).SelectOrInsert()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if inserted {
|
||||
ps.log.Infof("created instance account %s with id %s", username, a.ID)
|
||||
a.log.Infof("created instance account %s with id %s", username, acct.ID)
|
||||
} else {
|
||||
ps.log.Infof("instance account %s already exists with id %s", username, a.ID)
|
||||
a.log.Infof("instance account %s already exists with id %s", username, acct.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) CreateInstanceInstance() error {
|
||||
func (a *adminDB) CreateInstanceInstance() db.DBError {
|
||||
iID, err := id.NewRandomULID()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -189,18 +218,18 @@ func (ps *postgresService) CreateInstanceInstance() error {
|
|||
|
||||
i := >smodel.Instance{
|
||||
ID: iID,
|
||||
Domain: ps.config.Host,
|
||||
Title: ps.config.Host,
|
||||
URI: fmt.Sprintf("%s://%s", ps.config.Protocol, ps.config.Host),
|
||||
Domain: a.config.Host,
|
||||
Title: a.config.Host,
|
||||
URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host),
|
||||
}
|
||||
inserted, err := ps.conn.Model(i).Where("domain = ?", ps.config.Host).SelectOrInsert()
|
||||
inserted, err := a.conn.Model(i).Where("domain = ?", a.config.Host).SelectOrInsert()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if inserted {
|
||||
ps.log.Infof("created instance instance %s with id %s", ps.config.Host, i.ID)
|
||||
a.log.Infof("created instance instance %s with id %s", a.config.Host, i.ID)
|
||||
} else {
|
||||
ps.log.Infof("instance instance %s already exists with id %s", ps.config.Host, i.ID)
|
||||
a.log.Infof("instance instance %s already exists with id %s", a.config.Host, i.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,25 +1,55 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package pg
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
)
|
||||
|
||||
func (ps *postgresService) Put(i interface{}) error {
|
||||
_, err := ps.conn.Model(i).Insert(i)
|
||||
type basicDB struct {
|
||||
config *config.Config
|
||||
conn *pg.DB
|
||||
log *logrus.Logger
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (b *basicDB) Put(i interface{}) db.DBError {
|
||||
_, err := b.conn.Model(i).Insert(i)
|
||||
if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
|
||||
return db.ErrAlreadyExists{}
|
||||
return db.ErrAlreadyExists
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetByID(id string, i interface{}) error {
|
||||
if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil {
|
||||
func (b *basicDB) GetByID(id string, i interface{}) db.DBError {
|
||||
if err := b.conn.Model(i).Where("id = ?", id).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return db.ErrNoEntries{}
|
||||
return db.ErrNoEntries
|
||||
}
|
||||
return err
|
||||
|
||||
|
|
@ -27,12 +57,12 @@ func (ps *postgresService) GetByID(id string, i interface{}) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetWhere(where []db.Where, i interface{}) error {
|
||||
func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.DBError {
|
||||
if len(where) == 0 {
|
||||
return errors.New("no queries provided")
|
||||
}
|
||||
|
||||
q := ps.conn.Model(i)
|
||||
q := b.conn.Model(i)
|
||||
for _, w := range where {
|
||||
|
||||
if w.Value == nil {
|
||||
|
|
@ -48,25 +78,25 @@ func (ps *postgresService) GetWhere(where []db.Where, i interface{}) error {
|
|||
|
||||
if err := q.Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return db.ErrNoEntries{}
|
||||
return db.ErrNoEntries
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetAll(i interface{}) error {
|
||||
if err := ps.conn.Model(i).Select(); err != nil {
|
||||
func (b *basicDB) GetAll(i interface{}) db.DBError {
|
||||
if err := b.conn.Model(i).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return db.ErrNoEntries{}
|
||||
return db.ErrNoEntries
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) DeleteByID(id string, i interface{}) error {
|
||||
if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
|
||||
func (b *basicDB) DeleteByID(id string, i interface{}) db.DBError {
|
||||
if _, err := b.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
|
||||
// if there are no rows *anyway* then that's fine
|
||||
// just return err if there's an actual error
|
||||
if err != pg.ErrNoRows {
|
||||
|
|
@ -76,12 +106,12 @@ func (ps *postgresService) DeleteByID(id string, i interface{}) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) DeleteWhere(where []db.Where, i interface{}) error {
|
||||
func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.DBError {
|
||||
if len(where) == 0 {
|
||||
return errors.New("no queries provided")
|
||||
}
|
||||
|
||||
q := ps.conn.Model(i)
|
||||
q := b.conn.Model(i)
|
||||
for _, w := range where {
|
||||
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
|
||||
}
|
||||
|
|
@ -95,3 +125,76 @@ func (ps *postgresService) DeleteWhere(where []db.Where, i interface{}) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.DBError {
|
||||
if _, err := b.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return db.ErrNoEntries
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *basicDB) UpdateByID(id string, i interface{}) db.DBError {
|
||||
if _, err := b.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return db.ErrNoEntries
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.DBError {
|
||||
_, err := b.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.DBError {
|
||||
q := b.conn.Model(i)
|
||||
|
||||
for _, w := range where {
|
||||
if w.Value == nil {
|
||||
q = q.Where("? IS NULL", pg.Ident(w.Key))
|
||||
} else {
|
||||
if w.CaseInsensitive {
|
||||
q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
|
||||
} else {
|
||||
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
q = q.Set("? = ?", pg.Safe(key), value)
|
||||
|
||||
_, err := q.Update()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *basicDB) CreateTable(i interface{}) db.DBError {
|
||||
return b.conn.Model(i).CreateTable(&orm.CreateTableOptions{
|
||||
IfNotExists: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (b *basicDB) DropTable(i interface{}) db.DBError {
|
||||
return b.conn.Model(i).DropTable(&orm.DropTableOptions{
|
||||
IfExists: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (b *basicDB) IsHealthy(ctx context.Context) db.DBError {
|
||||
return b.conn.Ping(ctx)
|
||||
}
|
||||
|
||||
func (b *basicDB) Stop(ctx context.Context) db.DBError {
|
||||
b.log.Info("closing db connection")
|
||||
if err := b.conn.Close(); err != nil {
|
||||
// only cancel if there's a problem closing the db
|
||||
b.cancel()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,15 +19,26 @@
|
|||
package pg
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
)
|
||||
|
||||
func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) {
|
||||
q := ps.conn.Model(&[]*gtsmodel.Account{})
|
||||
type instanceDB struct {
|
||||
config *config.Config
|
||||
conn *pg.DB
|
||||
log *logrus.Logger
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
if domain == ps.config.Host {
|
||||
func (i *instanceDB) GetUserCountForInstance(domain string) (int, db.DBError) {
|
||||
q := i.conn.Model(&[]*gtsmodel.Account{})
|
||||
|
||||
if domain == i.config.Host {
|
||||
// if the domain is *this* domain, just count where the domain field is null
|
||||
q = q.Where("? IS NULL", pg.Ident("domain"))
|
||||
} else {
|
||||
|
|
@ -40,10 +51,10 @@ func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) {
|
|||
return q.Count()
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error) {
|
||||
q := ps.conn.Model(&[]*gtsmodel.Status{})
|
||||
func (i *instanceDB) GetStatusCountForInstance(domain string) (int, db.DBError) {
|
||||
q := i.conn.Model(&[]*gtsmodel.Status{})
|
||||
|
||||
if domain == ps.config.Host {
|
||||
if domain == i.config.Host {
|
||||
// if the domain is *this* domain, just count where local is true
|
||||
q = q.Where("local = ?", true)
|
||||
} else {
|
||||
|
|
@ -55,10 +66,10 @@ func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error)
|
|||
return q.Count()
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error) {
|
||||
q := ps.conn.Model(&[]*gtsmodel.Instance{})
|
||||
func (i *instanceDB) GetDomainCountForInstance(domain string) (int, db.DBError) {
|
||||
q := i.conn.Model(&[]*gtsmodel.Instance{})
|
||||
|
||||
if domain == ps.config.Host {
|
||||
if domain == i.config.Host {
|
||||
// if the domain is *this* domain, just count other instances it knows about
|
||||
// exclude domains that are blocked
|
||||
q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at"))
|
||||
|
|
@ -70,12 +81,12 @@ func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error)
|
|||
return q.Count()
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error) {
|
||||
ps.log.Debug("GetAccountsForInstance")
|
||||
func (i *instanceDB) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.DBError) {
|
||||
i.log.Debug("GetAccountsForInstance")
|
||||
|
||||
accounts := []*gtsmodel.Account{}
|
||||
|
||||
q := ps.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
|
||||
q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("id < ?", maxID)
|
||||
|
|
@ -88,13 +99,13 @@ func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, l
|
|||
err := q.Select()
|
||||
if err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return nil, db.ErrNoEntries{}
|
||||
return nil, db.ErrNoEntries
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(accounts) == 0 {
|
||||
return nil, db.ErrNoEntries{}
|
||||
return nil, db.ErrNoEntries
|
||||
}
|
||||
|
||||
return accounts, nil
|
||||
|
|
|
|||
|
|
@ -24,44 +24,30 @@ import (
|
|||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
)
|
||||
|
||||
func (ps *postgresService) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) {
|
||||
blocks := []*gtsmodel.Block{}
|
||||
func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.DBError) {
|
||||
notifications := []*gtsmodel.Notification{}
|
||||
|
||||
fq := ps.conn.Model(&blocks).
|
||||
Where("block.account_id = ?", accountID).
|
||||
Relation("TargetAccount").
|
||||
Order("block.id DESC")
|
||||
q := ps.conn.Model(¬ifications).Where("target_account_id = ?", accountID)
|
||||
|
||||
if maxID != "" {
|
||||
fq = fq.Where("block.id < ?", maxID)
|
||||
q = q.Where("id < ?", maxID)
|
||||
}
|
||||
|
||||
if sinceID != "" {
|
||||
fq = fq.Where("block.id > ?", sinceID)
|
||||
q = q.Where("id > ?", sinceID)
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
fq = fq.Limit(limit)
|
||||
if limit != 0 {
|
||||
q = q.Limit(limit)
|
||||
}
|
||||
|
||||
err := fq.Select()
|
||||
if err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return nil, "", "", db.ErrNoEntries{}
|
||||
q = q.Order("created_at DESC")
|
||||
|
||||
if err := q.Select(); err != nil {
|
||||
if err != pg.ErrNoRows {
|
||||
return nil, err
|
||||
}
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
if len(blocks) == 0 {
|
||||
return nil, "", "", db.ErrNoEntries{}
|
||||
}
|
||||
|
||||
accounts := []*gtsmodel.Account{}
|
||||
for _, b := range blocks {
|
||||
accounts = append(accounts, b.TargetAccount)
|
||||
}
|
||||
|
||||
nextMaxID := blocks[len(blocks)-1].ID
|
||||
prevMinID := blocks[0].ID
|
||||
return accounts, nextMaxID, prevMinID, nil
|
||||
return notifications, nil
|
||||
}
|
||||
|
|
@ -31,7 +31,6 @@ import (
|
|||
|
||||
"github.com/go-pg/pg/extra/pgdebug"
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
|
|
@ -41,6 +40,14 @@ import (
|
|||
|
||||
// postgresService satisfies the DB interface
|
||||
type postgresService struct {
|
||||
db.Account
|
||||
db.Admin
|
||||
db.Basic
|
||||
db.Instance
|
||||
db.Notification
|
||||
db.Relationship
|
||||
db.Status
|
||||
db.Timeline
|
||||
config *config.Config
|
||||
conn *pg.DB
|
||||
log *logrus.Logger
|
||||
|
|
@ -85,6 +92,48 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
|
|||
log.Infof("connected to postgres version: %s", version)
|
||||
|
||||
ps := &postgresService{
|
||||
Account: &accountDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
cancel: cancel,
|
||||
},
|
||||
Admin: &adminDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
cancel: cancel,
|
||||
},
|
||||
Basic: &basicDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
cancel: cancel,
|
||||
},
|
||||
Instance: &instanceDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
cancel: cancel,
|
||||
},
|
||||
Relationship: &relationshipDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
cancel: cancel,
|
||||
},
|
||||
Status: &statusDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
cancel: cancel,
|
||||
},
|
||||
Timeline: &timelineDB{
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
cancel: cancel,
|
||||
},
|
||||
config: c,
|
||||
conn: conn,
|
||||
log: log,
|
||||
|
|
@ -193,89 +242,6 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
|
|||
return options, nil
|
||||
}
|
||||
|
||||
/*
|
||||
BASIC DB FUNCTIONALITY
|
||||
*/
|
||||
|
||||
func (ps *postgresService) CreateTable(i interface{}) error {
|
||||
return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
|
||||
IfNotExists: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (ps *postgresService) DropTable(i interface{}) error {
|
||||
return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
|
||||
IfExists: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (ps *postgresService) Stop(ctx context.Context) error {
|
||||
ps.log.Info("closing db connection")
|
||||
if err := ps.conn.Close(); err != nil {
|
||||
// only cancel if there's a problem closing the db
|
||||
ps.cancel()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) IsHealthy(ctx context.Context) error {
|
||||
return ps.conn.Ping(ctx)
|
||||
}
|
||||
|
||||
func (ps *postgresService) CreateSchema(ctx context.Context) error {
|
||||
models := []interface{}{
|
||||
(*gtsmodel.Account)(nil),
|
||||
(*gtsmodel.Status)(nil),
|
||||
(*gtsmodel.User)(nil),
|
||||
}
|
||||
ps.log.Info("creating db schema")
|
||||
|
||||
for _, model := range models {
|
||||
err := ps.conn.Model(model).CreateTable(&orm.CreateTableOptions{
|
||||
IfNotExists: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ps.log.Info("db schema created")
|
||||
return nil
|
||||
}
|
||||
|
||||
/*
|
||||
HANDY SHORTCUTS
|
||||
*/
|
||||
|
||||
func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error) {
|
||||
notifications := []*gtsmodel.Notification{}
|
||||
|
||||
q := ps.conn.Model(¬ifications).Where("target_account_id = ?", accountID)
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("id < ?", maxID)
|
||||
}
|
||||
|
||||
if sinceID != "" {
|
||||
q = q.Where("id > ?", sinceID)
|
||||
}
|
||||
|
||||
if limit != 0 {
|
||||
q = q.Limit(limit)
|
||||
}
|
||||
|
||||
q = q.Order("created_at DESC")
|
||||
|
||||
if err := q.Select(); err != nil {
|
||||
if err != pg.ErrNoRows {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
}
|
||||
return notifications, nil
|
||||
}
|
||||
|
||||
/*
|
||||
CONVERSION FUNCTIONS
|
||||
*/
|
||||
|
|
|
|||
47
internal/db/pg/pg_test.go
Normal file
47
internal/db/pg/pg_test.go
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package pg_test
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
)
|
||||
|
||||
type PGStandardTestSuite struct {
|
||||
// standard suite interfaces
|
||||
suite.Suite
|
||||
config *config.Config
|
||||
db db.DB
|
||||
log *logrus.Logger
|
||||
|
||||
// standard suite models
|
||||
testTokens map[string]*oauth.Token
|
||||
testClients map[string]*oauth.Client
|
||||
testApplications map[string]*gtsmodel.Application
|
||||
testUsers map[string]*gtsmodel.User
|
||||
testAccounts map[string]*gtsmodel.Account
|
||||
testAttachments map[string]*gtsmodel.MediaAttachment
|
||||
testStatuses map[string]*gtsmodel.Status
|
||||
testTags map[string]*gtsmodel.Tag
|
||||
testMentions map[string]*gtsmodel.Mention
|
||||
}
|
||||
|
|
@ -1,17 +1,45 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package pg
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
)
|
||||
|
||||
func (ps *postgresService) Blocked(account1 string, account2 string) (bool, error) {
|
||||
type relationshipDB struct {
|
||||
config *config.Config
|
||||
conn *pg.DB
|
||||
log *logrus.Logger
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (r *relationshipDB) Blocked(account1 string, account2 string) (bool, db.DBError) {
|
||||
// TODO: check domain blocks as well
|
||||
var blocked bool
|
||||
if err := ps.conn.Model(>smodel.Block{}).
|
||||
if err := r.conn.Model(>smodel.Block{}).
|
||||
Where("account_id = ?", account1).Where("target_account_id = ?", account2).
|
||||
WhereOr("target_account_id = ?", account1).Where("account_id = ?", account2).
|
||||
Select(); err != nil {
|
||||
|
|
@ -25,83 +53,83 @@ func (ps *postgresService) Blocked(account1 string, account2 string) (bool, erro
|
|||
return blocked, nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) {
|
||||
r := >smodel.Relationship{
|
||||
func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.DBError) {
|
||||
rel := >smodel.Relationship{
|
||||
ID: targetAccount,
|
||||
}
|
||||
|
||||
// check if the requesting account follows the target account
|
||||
follow := >smodel.Follow{}
|
||||
if err := ps.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
|
||||
if err := r.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
|
||||
if err != pg.ErrNoRows {
|
||||
// a proper error
|
||||
return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
|
||||
}
|
||||
// no follow exists so these are all false
|
||||
r.Following = false
|
||||
r.ShowingReblogs = false
|
||||
r.Notifying = false
|
||||
rel.Following = false
|
||||
rel.ShowingReblogs = false
|
||||
rel.Notifying = false
|
||||
} else {
|
||||
// follow exists so we can fill these fields out...
|
||||
r.Following = true
|
||||
r.ShowingReblogs = follow.ShowReblogs
|
||||
r.Notifying = follow.Notify
|
||||
rel.Following = true
|
||||
rel.ShowingReblogs = follow.ShowReblogs
|
||||
rel.Notifying = follow.Notify
|
||||
}
|
||||
|
||||
// check if the target account follows the requesting account
|
||||
followedBy, err := ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
|
||||
followedBy, err := r.conn.Model(>smodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
|
||||
}
|
||||
r.FollowedBy = followedBy
|
||||
rel.FollowedBy = followedBy
|
||||
|
||||
// check if the requesting account blocks the target account
|
||||
blocking, err := ps.conn.Model(>smodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
|
||||
blocking, err := r.conn.Model(>smodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
|
||||
}
|
||||
r.Blocking = blocking
|
||||
rel.Blocking = blocking
|
||||
|
||||
// check if the target account blocks the requesting account
|
||||
blockedBy, err := ps.conn.Model(>smodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
|
||||
blockedBy, err := r.conn.Model(>smodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
|
||||
}
|
||||
r.BlockedBy = blockedBy
|
||||
rel.BlockedBy = blockedBy
|
||||
|
||||
// check if there's a pending following request from requesting account to target account
|
||||
requested, err := ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
|
||||
requested, err := r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
|
||||
}
|
||||
r.Requested = requested
|
||||
rel.Requested = requested
|
||||
|
||||
return r, nil
|
||||
return rel, nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) {
|
||||
func (r *relationshipDB) Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.DBError) {
|
||||
if sourceAccount == nil || targetAccount == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
|
||||
return r.conn.Model(>smodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
|
||||
}
|
||||
|
||||
func (ps *postgresService) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) {
|
||||
func (r *relationshipDB) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.DBError) {
|
||||
if sourceAccount == nil || targetAccount == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
|
||||
return r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
|
||||
}
|
||||
|
||||
func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error) {
|
||||
func (r *relationshipDB) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.DBError) {
|
||||
if account1 == nil || account2 == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// make sure account 1 follows account 2
|
||||
f1, err := ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists()
|
||||
f1, err := r.conn.Model(>smodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists()
|
||||
if err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return false, nil
|
||||
|
|
@ -110,7 +138,7 @@ func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmode
|
|||
}
|
||||
|
||||
// make sure account 2 follows account 1
|
||||
f2, err := ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", account2.ID).Where("target_account_id = ?", account1.ID).Exists()
|
||||
f2, err := r.conn.Model(>smodel.Follow{}).Where("account_id = ?", account2.ID).Where("target_account_id = ?", account1.ID).Exists()
|
||||
if err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return false, nil
|
||||
|
|
@ -121,12 +149,12 @@ func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmode
|
|||
return f1 && f2, nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
|
||||
func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.DBError) {
|
||||
// make sure the original follow request exists
|
||||
fr := >smodel.FollowRequest{}
|
||||
if err := ps.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
|
||||
if err := r.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
|
||||
if err == pg.ErrMultiRows {
|
||||
return nil, db.ErrNoEntries{}
|
||||
return nil, db.ErrNoEntries
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -140,12 +168,12 @@ func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAcc
|
|||
}
|
||||
|
||||
// if the follow already exists, just update the URI -- we don't need to do anything else
|
||||
if _, err := ps.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
|
||||
if _, err := r.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// now remove the follow request
|
||||
if _, err := ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
|
||||
if _, err := r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -20,39 +20,90 @@ package pg
|
|||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/go-pg/pg/v10/orm"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
)
|
||||
|
||||
func (ps *postgresService) StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) {
|
||||
type statusDB struct {
|
||||
config *config.Config
|
||||
conn *pg.DB
|
||||
log *logrus.Logger
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (s *statusDB) newStatusQ(status *gtsmodel.Status) *orm.Query {
|
||||
return s.conn.Model(status).
|
||||
Relation("Account").
|
||||
Relation("InReplyTo").
|
||||
Relation("InReplyToAccount").
|
||||
Relation("BoostOf").
|
||||
Relation("BoostOfAccount").
|
||||
Relation("CreatedWithApplication")
|
||||
}
|
||||
|
||||
func (s *statusDB) processResponse(status *gtsmodel.Status, err error) (*gtsmodel.Status, db.DBError) {
|
||||
switch err {
|
||||
case pg.ErrNoRows:
|
||||
return nil, db.ErrNoEntries
|
||||
case nil:
|
||||
return status, nil
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.DBError) {
|
||||
status := >smodel.Status{}
|
||||
|
||||
q := s.newStatusQ(status).
|
||||
Where("status.id = ?", id)
|
||||
|
||||
return s.processResponse(status, q.Select())
|
||||
}
|
||||
|
||||
func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.DBError) {
|
||||
status := >smodel.Status{}
|
||||
|
||||
q := s.newStatusQ(status).
|
||||
Where("LOWER(status.uri) = LOWER(?)", uri)
|
||||
|
||||
return s.processResponse(status, q.Select())
|
||||
}
|
||||
|
||||
func (s *statusDB) StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.DBError) {
|
||||
parents := []*gtsmodel.Status{}
|
||||
ps.statusParent(status, &parents, onlyDirect)
|
||||
s.statusParent(status, &parents, onlyDirect)
|
||||
|
||||
return parents, nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
|
||||
func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
|
||||
if status.InReplyToID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
parentStatus := >smodel.Status{}
|
||||
if err := ps.conn.Model(parentStatus).Where("id = ?", status.InReplyToID).Select(); err == nil {
|
||||
if err := s.conn.Model(parentStatus).Where("id = ?", status.InReplyToID).Select(); err == nil {
|
||||
*foundStatuses = append(*foundStatuses, parentStatus)
|
||||
}
|
||||
|
||||
if onlyDirect {
|
||||
return
|
||||
}
|
||||
ps.statusParent(parentStatus, foundStatuses, false)
|
||||
s.statusParent(parentStatus, foundStatuses, false)
|
||||
}
|
||||
|
||||
func (ps *postgresService) StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) {
|
||||
func (s *statusDB) StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.DBError) {
|
||||
foundStatuses := &list.List{}
|
||||
foundStatuses.PushFront(status)
|
||||
ps.statusChildren(status, foundStatuses, onlyDirect, minID)
|
||||
s.statusChildren(status, foundStatuses, onlyDirect, minID)
|
||||
|
||||
children := []*gtsmodel.Status{}
|
||||
for e := foundStatuses.Front(); e != nil; e = e.Next() {
|
||||
|
|
@ -70,10 +121,10 @@ func (ps *postgresService) StatusChildren(status *gtsmodel.Status, onlyDirect bo
|
|||
return children, nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
|
||||
func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
|
||||
immediateChildren := []*gtsmodel.Status{}
|
||||
|
||||
q := ps.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
|
||||
q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
|
||||
if minID != "" {
|
||||
q = q.Where("status.id > ?", minID)
|
||||
}
|
||||
|
|
@ -100,43 +151,43 @@ func (ps *postgresService) statusChildren(status *gtsmodel.Status, foundStatuses
|
|||
if onlyDirect {
|
||||
return
|
||||
}
|
||||
ps.statusChildren(child, foundStatuses, false, minID)
|
||||
s.statusChildren(child, foundStatuses, false, minID)
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetReplyCountForStatus(status *gtsmodel.Status) (int, error) {
|
||||
return ps.conn.Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
|
||||
func (s *statusDB) GetReplyCountForStatus(status *gtsmodel.Status) (int, db.DBError) {
|
||||
return s.conn.Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetReblogCountForStatus(status *gtsmodel.Status) (int, error) {
|
||||
return ps.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
|
||||
func (s *statusDB) GetReblogCountForStatus(status *gtsmodel.Status) (int, db.DBError) {
|
||||
return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetFaveCountForStatus(status *gtsmodel.Status) (int, error) {
|
||||
return ps.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
|
||||
func (s *statusDB) GetFaveCountForStatus(status *gtsmodel.Status) (int, db.DBError) {
|
||||
return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
|
||||
}
|
||||
|
||||
func (ps *postgresService) StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, error) {
|
||||
return ps.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
|
||||
func (s *statusDB) StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
|
||||
return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
|
||||
}
|
||||
|
||||
func (ps *postgresService) StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, error) {
|
||||
return ps.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
|
||||
func (s *statusDB) StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
|
||||
return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
|
||||
}
|
||||
|
||||
func (ps *postgresService) StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, error) {
|
||||
return ps.conn.Model(>smodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
|
||||
func (s *statusDB) StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
|
||||
return s.conn.Model(>smodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
|
||||
}
|
||||
|
||||
func (ps *postgresService) StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, error) {
|
||||
return ps.conn.Model(>smodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
|
||||
func (s *statusDB) StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
|
||||
return s.conn.Model(>smodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
|
||||
}
|
||||
|
||||
func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) {
|
||||
func (s *statusDB) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, db.DBError) {
|
||||
accounts := []*gtsmodel.Account{}
|
||||
|
||||
faves := []*gtsmodel.StatusFave{}
|
||||
if err := ps.conn.Model(&faves).Where("status_id = ?", status.ID).Select(); err != nil {
|
||||
if err := s.conn.Model(&faves).Where("status_id = ?", status.ID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return accounts, nil // no rows just means nobody has faved this status, so that's fine
|
||||
}
|
||||
|
|
@ -145,7 +196,7 @@ func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.
|
|||
|
||||
for _, f := range faves {
|
||||
acc := >smodel.Account{}
|
||||
if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
|
||||
if err := s.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it
|
||||
}
|
||||
|
|
@ -156,11 +207,11 @@ func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.
|
|||
return accounts, nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) {
|
||||
func (s *statusDB) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, db.DBError) {
|
||||
accounts := []*gtsmodel.Account{}
|
||||
|
||||
boosts := []*gtsmodel.Status{}
|
||||
if err := ps.conn.Model(&boosts).Where("boost_of_id = ?", status.ID).Select(); err != nil {
|
||||
if err := s.conn.Model(&boosts).Where("boost_of_id = ?", status.ID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return accounts, nil // no rows just means nobody has boosted this status, so that's fine
|
||||
}
|
||||
|
|
@ -169,7 +220,7 @@ func (ps *postgresService) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmode
|
|||
|
||||
for _, f := range boosts {
|
||||
acc := >smodel.Account{}
|
||||
if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
|
||||
if err := s.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it
|
||||
}
|
||||
|
|
|
|||
86
internal/db/pg/status_test.go
Normal file
86
internal/db/pg/status_test.go
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package pg_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
type StatusTestSuite struct {
|
||||
PGStandardTestSuite
|
||||
}
|
||||
|
||||
func (suite *PGStandardTestSuite) SetupSuite() {
|
||||
suite.testTokens = testrig.NewTestTokens()
|
||||
suite.testClients = testrig.NewTestClients()
|
||||
suite.testApplications = testrig.NewTestApplications()
|
||||
suite.testUsers = testrig.NewTestUsers()
|
||||
suite.testAccounts = testrig.NewTestAccounts()
|
||||
suite.testAttachments = testrig.NewTestAttachments()
|
||||
suite.testStatuses = testrig.NewTestStatuses()
|
||||
suite.testTags = testrig.NewTestTags()
|
||||
suite.testMentions = testrig.NewTestMentions()
|
||||
}
|
||||
|
||||
func (suite *PGStandardTestSuite) SetupTest() {
|
||||
suite.config = testrig.NewTestConfig()
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.log = testrig.NewTestLog()
|
||||
|
||||
testrig.StandardDBSetup(suite.db, suite.testAccounts)
|
||||
}
|
||||
|
||||
func (suite *PGStandardTestSuite) TearDownTest() {
|
||||
testrig.StandardDBTeardown(suite.db)
|
||||
}
|
||||
|
||||
func (suite *PGStandardTestSuite) TestGetStatusByID() {
|
||||
status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID)
|
||||
if err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
suite.NotNil(status)
|
||||
suite.NotNil(status.Account)
|
||||
suite.NotNil(status.CreatedWithApplication)
|
||||
suite.Nil(status.BoostOf)
|
||||
suite.Nil(status.BoostOfAccount)
|
||||
suite.Nil(status.InReplyTo)
|
||||
suite.Nil(status.InReplyToAccount)
|
||||
}
|
||||
|
||||
func (suite *PGStandardTestSuite) TestGetStatusByURI() {
|
||||
status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
|
||||
if err != nil {
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
suite.NotNil(status)
|
||||
suite.NotNil(status.Account)
|
||||
suite.NotNil(status.CreatedWithApplication)
|
||||
suite.Nil(status.BoostOf)
|
||||
suite.Nil(status.BoostOfAccount)
|
||||
suite.Nil(status.InReplyTo)
|
||||
suite.Nil(status.InReplyToAccount)
|
||||
}
|
||||
|
||||
func TestStatusTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(PGStandardTestSuite))
|
||||
}
|
||||
|
|
@ -19,16 +19,26 @@
|
|||
package pg
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
)
|
||||
|
||||
func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
|
||||
type timelineDB struct {
|
||||
config *config.Config
|
||||
conn *pg.DB
|
||||
log *logrus.Logger
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (t *timelineDB) GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.DBError) {
|
||||
statuses := []*gtsmodel.Status{}
|
||||
q := ps.conn.Model(&statuses)
|
||||
q := t.conn.Model(&statuses)
|
||||
|
||||
q = q.ColumnExpr("status.*").
|
||||
// Find out who accountID follows.
|
||||
|
|
@ -74,22 +84,22 @@ func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID str
|
|||
err := q.Select()
|
||||
if err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return nil, db.ErrNoEntries{}
|
||||
return nil, db.ErrNoEntries
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(statuses) == 0 {
|
||||
return nil, db.ErrNoEntries{}
|
||||
return nil, db.ErrNoEntries
|
||||
}
|
||||
|
||||
return statuses, nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
|
||||
func (t *timelineDB) GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.DBError) {
|
||||
statuses := []*gtsmodel.Status{}
|
||||
|
||||
q := ps.conn.Model(&statuses).
|
||||
q := t.conn.Model(&statuses).
|
||||
Where("visibility = ?", gtsmodel.VisibilityPublic).
|
||||
Where("? IS NULL", pg.Ident("in_reply_to_id")).
|
||||
Where("? IS NULL", pg.Ident("in_reply_to_uri")).
|
||||
|
|
@ -119,13 +129,13 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s
|
|||
err := q.Select()
|
||||
if err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return nil, db.ErrNoEntries{}
|
||||
return nil, db.ErrNoEntries
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(statuses) == 0 {
|
||||
return nil, db.ErrNoEntries{}
|
||||
return nil, db.ErrNoEntries
|
||||
}
|
||||
|
||||
return statuses, nil
|
||||
|
|
@ -133,11 +143,11 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s
|
|||
|
||||
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
|
||||
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
|
||||
func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error) {
|
||||
func (t *timelineDB) GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.DBError) {
|
||||
|
||||
faves := []*gtsmodel.StatusFave{}
|
||||
|
||||
fq := ps.conn.Model(&faves).
|
||||
fq := t.conn.Model(&faves).
|
||||
Where("account_id = ?", accountID).
|
||||
Order("id DESC")
|
||||
|
||||
|
|
@ -156,13 +166,13 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st
|
|||
err := fq.Select()
|
||||
if err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return nil, "", "", db.ErrNoEntries{}
|
||||
return nil, "", "", db.ErrNoEntries
|
||||
}
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
if len(faves) == 0 {
|
||||
return nil, "", "", db.ErrNoEntries{}
|
||||
return nil, "", "", db.ErrNoEntries
|
||||
}
|
||||
|
||||
// map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID
|
||||
|
|
@ -175,16 +185,16 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st
|
|||
}
|
||||
|
||||
statuses := []*gtsmodel.Status{}
|
||||
err = ps.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
|
||||
err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
|
||||
if err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return nil, "", "", db.ErrNoEntries{}
|
||||
return nil, "", "", db.ErrNoEntries
|
||||
}
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
if len(statuses) == 0 {
|
||||
return nil, "", "", db.ErrNoEntries{}
|
||||
return nil, "", "", db.ErrNoEntries
|
||||
}
|
||||
|
||||
// arrange statuses by fave ID
|
||||
|
|
|
|||
|
|
@ -1,73 +0,0 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package pg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-pg/pg/v10"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
)
|
||||
|
||||
func (ps *postgresService) Upsert(i interface{}, conflictColumn string) error {
|
||||
if _, err := ps.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return db.ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) UpdateByID(id string, i interface{}) error {
|
||||
if _, err := ps.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
|
||||
if err == pg.ErrNoRows {
|
||||
return db.ErrNoEntries{}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error {
|
||||
_, err := ps.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
|
||||
return err
|
||||
}
|
||||
|
||||
func (ps *postgresService) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) error {
|
||||
q := ps.conn.Model(i)
|
||||
|
||||
for _, w := range where {
|
||||
if w.Value == nil {
|
||||
q = q.Where("? IS NULL", pg.Ident(w.Key))
|
||||
} else {
|
||||
if w.CaseInsensitive {
|
||||
q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
|
||||
} else {
|
||||
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
q = q.Set("? = ?", pg.Safe(key), value)
|
||||
|
||||
_, err := q.Update()
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
@ -5,23 +5,23 @@ import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
|||
type Relationship interface {
|
||||
// Blocked checks whether a block exists in eiher direction between two accounts.
|
||||
// That is, it returns true if account1 blocks account2, OR if account2 blocks account1.
|
||||
Blocked(account1 string, account2 string) (bool, error)
|
||||
Blocked(account1 string, account2 string) (bool, DBError)
|
||||
|
||||
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
|
||||
GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error)
|
||||
GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, DBError)
|
||||
|
||||
// Follows returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
|
||||
Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error)
|
||||
Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, DBError)
|
||||
|
||||
// FollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
|
||||
FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error)
|
||||
FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, DBError)
|
||||
|
||||
// Mutuals returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
|
||||
Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error)
|
||||
Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, DBError)
|
||||
|
||||
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
|
||||
// In other words, it should create the follow, and delete the existing follow request.
|
||||
//
|
||||
// It will return the newly created follow for further processing.
|
||||
AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error)
|
||||
AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, DBError)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,42 +3,48 @@ package db
|
|||
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
|
||||
type Status interface {
|
||||
// GetStatusByID returns one status from the database, with all rel fields populated (if possible).
|
||||
GetStatusByID(id string) (*gtsmodel.Status, DBError)
|
||||
|
||||
// GetStatusByURI returns one status from the database, with all rel fields populated (if possible).
|
||||
GetStatusByURI(uri string) (*gtsmodel.Status, DBError)
|
||||
|
||||
// GetReplyCountForStatus returns the amount of replies recorded for a status, or an error if something goes wrong
|
||||
GetReplyCountForStatus(status *gtsmodel.Status) (int, error)
|
||||
GetReplyCountForStatus(status *gtsmodel.Status) (int, DBError)
|
||||
|
||||
// GetReblogCountForStatus returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong
|
||||
GetReblogCountForStatus(status *gtsmodel.Status) (int, error)
|
||||
GetReblogCountForStatus(status *gtsmodel.Status) (int, DBError)
|
||||
|
||||
// GetFaveCountForStatus returns the amount of faves/likes recorded for a status, or an error if something goes wrong
|
||||
GetFaveCountForStatus(status *gtsmodel.Status) (int, error)
|
||||
GetFaveCountForStatus(status *gtsmodel.Status) (int, DBError)
|
||||
|
||||
// StatusParents get the parent statuses of a given status.
|
||||
//
|
||||
// If onlyDirect is true, only the immediate parent will be returned.
|
||||
StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error)
|
||||
StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, DBError)
|
||||
|
||||
// StatusChildren gets the child statuses of a given status.
|
||||
//
|
||||
// If onlyDirect is true, only the immediate children will be returned.
|
||||
StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error)
|
||||
StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, DBError)
|
||||
|
||||
// StatusFavedBy checks if a given status has been faved by a given account ID
|
||||
StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, error)
|
||||
StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, DBError)
|
||||
|
||||
// StatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID
|
||||
StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, error)
|
||||
StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, DBError)
|
||||
|
||||
// StatusMutedBy checks if a given status has been muted by a given account ID
|
||||
StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, error)
|
||||
StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, DBError)
|
||||
|
||||
// StatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
|
||||
StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, error)
|
||||
StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, DBError)
|
||||
|
||||
// WhoFavedStatus returns a slice of accounts who faved the given status.
|
||||
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
|
||||
WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error)
|
||||
WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, DBError)
|
||||
|
||||
// WhoBoostedStatus returns a slice of accounts who boosted the given status.
|
||||
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
|
||||
WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error)
|
||||
WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, DBError)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,13 +6,13 @@ type Timeline interface {
|
|||
// GetHomeTimelineForAccount returns a slice of statuses from accounts that are followed by the given account id.
|
||||
//
|
||||
// Statuses should be returned in descending order of when they were created (newest first).
|
||||
GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
|
||||
GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, DBError)
|
||||
|
||||
// GetPublicTimelineForAccount fetches the account's PUBLIC timeline -- ie., posts and replies that are public.
|
||||
// It will use the given filters and try to return as many statuses as possible up to the limit.
|
||||
//
|
||||
// Statuses should be returned in descending order of when they were created (newest first).
|
||||
GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
|
||||
GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, DBError)
|
||||
|
||||
// GetFavedTimelineForAccount fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved.
|
||||
// It will use the given filters and try to return as many statuses as possible up to the limit.
|
||||
|
|
@ -21,5 +21,5 @@ type Timeline interface {
|
|||
// In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created.
|
||||
//
|
||||
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
|
||||
GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error)
|
||||
GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, DBError)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ func (d *deref) blockedDomain(host string) (bool, error) {
|
|||
return true, nil
|
||||
}
|
||||
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are no entries so there's no block
|
||||
return false, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -288,7 +288,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
|
|||
attachmentIDs = append(attachmentIDs, maybeAttachment.ID)
|
||||
continue
|
||||
}
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// we have a real error
|
||||
return fmt.Errorf("error checking db for existence of attachment with remote url %s: %s", a.RemoteURL, err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error {
|
|||
status.ID = statusID
|
||||
|
||||
if err := f.db.Put(status); err != nil {
|
||||
if _, ok := err.(db.ErrAlreadyExists); ok {
|
||||
if err == db.ErrAlreadyExists {
|
||||
// the status already exists in the database, which means we've already handled everything else,
|
||||
// so we can just return nil here and be done with it.
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ func (f *federatingDB) OutboxForInbox(c context.Context, inboxIRI *url.URL) (out
|
|||
}
|
||||
acct := >smodel.Account{}
|
||||
if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String())
|
||||
}
|
||||
return nil, fmt.Errorf("db error searching for actor with inbox %s", inboxIRI.String())
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
|
|||
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
|
||||
}
|
||||
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: uid}}, >smodel.Status{}); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are no entries for this status
|
||||
return false, nil
|
||||
}
|
||||
|
|
@ -72,7 +72,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
|
|||
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
|
||||
}
|
||||
if err := f.db.GetLocalAccountByUsername(username, >smodel.Account{}); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are no entries for this username
|
||||
return false, nil
|
||||
}
|
||||
|
|
@ -89,7 +89,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
|
|||
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
|
||||
}
|
||||
if err := f.db.GetLocalAccountByUsername(username, >smodel.Account{}); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are no entries for this username
|
||||
return false, nil
|
||||
}
|
||||
|
|
@ -106,7 +106,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
|
|||
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
|
||||
}
|
||||
if err := f.db.GetLocalAccountByUsername(username, >smodel.Account{}); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are no entries for this username
|
||||
return false, nil
|
||||
}
|
||||
|
|
@ -123,7 +123,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
|
|||
return false, fmt.Errorf("error parsing like path for url %s: %s", id.String(), err)
|
||||
}
|
||||
if err := f.db.GetLocalAccountByUsername(username, >smodel.Account{}); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are no entries for this username
|
||||
return false, nil
|
||||
}
|
||||
|
|
@ -131,7 +131,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
|
|||
return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)
|
||||
}
|
||||
if err := f.db.GetByID(likeID, >smodel.StatusFave{}); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are no entries
|
||||
return false, nil
|
||||
}
|
||||
|
|
@ -148,7 +148,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
|
|||
return false, fmt.Errorf("error parsing block path for url %s: %s", id.String(), err)
|
||||
}
|
||||
if err := f.db.GetLocalAccountByUsername(username, >smodel.Account{}); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are no entries for this username
|
||||
return false, nil
|
||||
}
|
||||
|
|
@ -156,7 +156,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
|
|||
return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)
|
||||
}
|
||||
if err := f.db.GetByID(blockID, >smodel.Block{}); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are no entries
|
||||
return false, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -213,7 +213,7 @@ func (f *federatingDB) ActorForOutbox(c context.Context, outboxIRI *url.URL) (ac
|
|||
}
|
||||
acct := >smodel.Account{}
|
||||
if err := f.db.GetWhere([]db.Where{{Key: "outbox_uri", Value: outboxIRI.String()}}, acct); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, fmt.Errorf("no actor found that corresponds to outbox %s", outboxIRI.String())
|
||||
}
|
||||
return nil, fmt.Errorf("db error searching for actor with outbox %s", outboxIRI.String())
|
||||
|
|
@ -238,7 +238,7 @@ func (f *federatingDB) ActorForInbox(c context.Context, inboxIRI *url.URL) (acto
|
|||
}
|
||||
acct := >smodel.Account{}
|
||||
if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String())
|
||||
}
|
||||
return nil, fmt.Errorf("db error searching for actor with inbox %s", inboxIRI.String())
|
||||
|
|
|
|||
|
|
@ -132,7 +132,7 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr
|
|||
// authentication has passed, so add an instance entry for this instance if it hasn't been done already
|
||||
i := >smodel.Instance{}
|
||||
if err := f.db.GetWhere([]db.Where{{Key: "domain", Value: publicKeyOwnerURI.Host, CaseInsensitive: true}}, i); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// there's been an actual error
|
||||
return ctx, false, fmt.Errorf("error getting requesting account with public key id %s: %s", publicKeyOwnerURI.String(), err)
|
||||
}
|
||||
|
|
@ -202,8 +202,7 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
|
|||
|
||||
requestingAccount := >smodel.Account{}
|
||||
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: uri.String()}}, requestingAccount); err != nil {
|
||||
_, ok := err.(db.ErrNoEntries)
|
||||
if ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// we don't have an entry for this account so it's not blocked
|
||||
// TODO: allow a different default to be set for this behavior
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ func (f *federator) blockedDomain(host string) (bool, error) {
|
|||
return true, nil
|
||||
}
|
||||
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are no entries so there's no block
|
||||
return false, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -46,10 +46,12 @@ type Account struct {
|
|||
|
||||
// ID of the avatar as a media attachment
|
||||
AvatarMediaAttachmentID string `pg:"type:CHAR(26)"`
|
||||
AvatarMediaAttachment *MediaAttachment `pg:"rel:has-one"`
|
||||
// For a non-local account, where can the header be fetched?
|
||||
AvatarRemoteURL string
|
||||
// ID of the header as a media attachment
|
||||
HeaderMediaAttachmentID string `pg:"type:CHAR(26)"`
|
||||
HeaderMediaAttachment *MediaAttachment `pg:"rel:has-one"`
|
||||
// For a non-local account, where can the header be fetched?
|
||||
HeaderRemoteURL string
|
||||
// DisplayName for this account. Can be empty, then just the Username will be used for display purposes.
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ type Instance struct {
|
|||
SuspendedAt time.Time
|
||||
// ID of any existing domain block for this instance in the database
|
||||
DomainBlockID string `pg:"type:CHAR(26)"`
|
||||
DomainBlock *DomainBlock `pg:"rel:has-one"`
|
||||
// Short description of this instance
|
||||
ShortDescription string
|
||||
// Longer description of this instance
|
||||
|
|
@ -32,6 +33,7 @@ type Instance struct {
|
|||
ContactAccountUsername string
|
||||
// Contact account ID in the database for this instance
|
||||
ContactAccountID string `pg:"type:CHAR(26)"`
|
||||
ContactAccount *Account `pg:"rel:has-one"`
|
||||
// Reputation score of this instance
|
||||
Reputation int64 `pg:",notnull,default:0"`
|
||||
// Version of the software used on this instance
|
||||
|
|
|
|||
|
|
@ -47,24 +47,24 @@ type Status struct {
|
|||
// is this status from a local account?
|
||||
Local bool
|
||||
// which account posted this status?
|
||||
AccountID string `pg:"type:CHAR(26),notnull"`
|
||||
Account *Account `pg:"rel:has-one"`
|
||||
AccountID string `pg:"type:CHAR(26),notnull"`
|
||||
Account *Account `pg:"rel:has-one"`
|
||||
// AP uri of the owner of this status
|
||||
AccountURI string
|
||||
// id of the status this status is a reply to
|
||||
InReplyToID string `pg:"type:CHAR(26)"`
|
||||
InReplyTo *Status `pg:"-"`
|
||||
InReplyToID string `pg:"type:CHAR(26)"`
|
||||
InReplyTo *Status `pg:"rel:has-one"`
|
||||
// AP uri of the status this status is a reply to
|
||||
InReplyToURI string
|
||||
// id of the account that this status replies to
|
||||
InReplyToAccountID string `pg:"type:CHAR(26)"`
|
||||
InReplyToAccount *Account `pg:"-"`
|
||||
InReplyToAccountID string `pg:"type:CHAR(26)"`
|
||||
InReplyToAccount *Account `pg:"rel:has-one"`
|
||||
// id of the status this status is a boost of
|
||||
BoostOfID string `pg:"type:CHAR(26)"`
|
||||
BoostOf *Status `pg:"-"`
|
||||
BoostOfID string `pg:"type:CHAR(26)"`
|
||||
BoostOf *Status `pg:"rel:has-one"`
|
||||
// id of the account that owns the boosted status
|
||||
BoostOfAccountID string `pg:"type:CHAR(26)"`
|
||||
BoostOfAccount *Account `pg:"-"`
|
||||
BoostOfAccountID string `pg:"type:CHAR(26)"`
|
||||
BoostOfAccount *Account `pg:"rel:has-one"`
|
||||
// cw string for this status
|
||||
ContentWarning string
|
||||
// visibility entry for this status
|
||||
|
|
@ -74,8 +74,8 @@ type Status struct {
|
|||
// what language is this status written in?
|
||||
Language string
|
||||
// Which application was used to create this status?
|
||||
CreatedWithApplicationID string `pg:"type:CHAR(26)"`
|
||||
CreatedWithApplication *Application `pg:"-"`
|
||||
CreatedWithApplicationID string `pg:"type:CHAR(26)"`
|
||||
CreatedWithApplication *Application `pg:"rel:has-one"`
|
||||
// advanced visibility for this status
|
||||
VisibilityAdvanced *VisibilityAdvanced
|
||||
// What is the activitystreams type of this status? See: https://www.w3.org/TR/activitystreams-vocabulary/#object-types
|
||||
|
|
@ -103,13 +103,6 @@ type Status struct {
|
|||
GTSEmojis []*Emoji `pg:"-"`
|
||||
// MediaAttachments used in this status
|
||||
GTSMediaAttachments []*MediaAttachment `pg:"-"`
|
||||
// Status being replied to
|
||||
|
||||
// Account being replied to
|
||||
|
||||
// Status being boosted
|
||||
|
||||
// Account of the boosted status
|
||||
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -27,11 +27,11 @@ import (
|
|||
)
|
||||
|
||||
type clientStore struct {
|
||||
db db.DB
|
||||
db db.Basic
|
||||
}
|
||||
|
||||
// NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend.
|
||||
func NewClientStore(db db.DB) oauth2.ClientStore {
|
||||
func NewClientStore(db db.Basic) oauth2.ClientStore {
|
||||
pts := &clientStore{
|
||||
db: db,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() {
|
|||
// try to get the deleted client; we should get an error
|
||||
deletedClient, err := cs.GetByID(context.Background(), suite.testClientID)
|
||||
suite.Assert().Nil(deletedClient)
|
||||
suite.Assert().EqualValues(db.ErrNoEntries{}, err)
|
||||
suite.Assert().EqualValues(db.ErrNoEntries, err)
|
||||
}
|
||||
|
||||
func TestPgClientStoreTestSuite(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ type s struct {
|
|||
}
|
||||
|
||||
// New returns a new oauth server that implements the Server interface
|
||||
func New(database db.DB, log *logrus.Logger) Server {
|
||||
func New(database db.Basic, log *logrus.Logger) Server {
|
||||
ts := newTokenStore(context.Background(), database, log)
|
||||
cs := NewClientStore(database)
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ import (
|
|||
// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend.
|
||||
type tokenStore struct {
|
||||
oauth2.TokenStore
|
||||
db db.DB
|
||||
db db.Basic
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
|
|
@ -42,7 +42,7 @@ type tokenStore struct {
|
|||
//
|
||||
// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
|
||||
// the tokens in the DB once per minute and deletes any that have expired.
|
||||
func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.TokenStore {
|
||||
func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2.TokenStore {
|
||||
pts := &tokenStore{
|
||||
db: db,
|
||||
log: log,
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
|
|||
// make sure the target account actually exists in our db
|
||||
targetAcct := >smodel.Account{}
|
||||
if err := p.db.GetByID(targetAccountID, targetAcct); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: account %s not found in the db: %s", targetAccountID, err))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apim
|
|||
// make sure the target account actually exists in our db
|
||||
targetAcct := >smodel.Account{}
|
||||
if err := p.db.GetByID(form.ID, targetAcct); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: account %s not found in the db: %s", form.ID, err))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ selectStatusesLoop:
|
|||
for {
|
||||
statuses, err := p.db.GetAccountStatuses(account.ID, 20, false, maxID, false, false)
|
||||
if err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// no statuses left for this instance so we're done
|
||||
l.Infof("Delete: done iterating through statuses for account %s", account.Username)
|
||||
break selectStatusesLoop
|
||||
|
|
@ -158,7 +158,7 @@ selectStatusesLoop:
|
|||
}
|
||||
|
||||
if err := p.db.DeleteByID(s.ID, s); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// actual error has occurred
|
||||
l.Errorf("Delete: db error status %s for account %s: %s", s.ID, account.Username, err)
|
||||
break selectStatusesLoop
|
||||
|
|
@ -168,7 +168,7 @@ selectStatusesLoop:
|
|||
// if there are any boosts of this status, delete them as well
|
||||
boosts := []*gtsmodel.Status{}
|
||||
if err := p.db.GetWhere([]db.Where{{Key: "boost_of_id", Value: s.ID}}, &boosts); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// an actual error has occurred
|
||||
l.Errorf("Delete: db error selecting boosts of status %s for account %s: %s", s.ID, account.Username, err)
|
||||
break selectStatusesLoop
|
||||
|
|
@ -190,7 +190,7 @@ selectStatusesLoop:
|
|||
}
|
||||
|
||||
if err := p.db.DeleteByID(b.ID, b); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// actual error has occurred
|
||||
l.Errorf("Delete: db error deleting boost with id %s: %s", b.ID, err)
|
||||
break selectStatusesLoop
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ import (
|
|||
func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) {
|
||||
targetAccount := >smodel.Account{}
|
||||
if err := p.db.GetByID(targetAccountID, targetAccount); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
return nil, fmt.Errorf("db error: %s", err)
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco
|
|||
followers := []gtsmodel.Follow{}
|
||||
accounts := []apimodel.Account{}
|
||||
if err := p.db.GetAccountFollowers(targetAccountID, &followers, false); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
return accounts, nil
|
||||
}
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
|
|
@ -57,7 +57,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco
|
|||
|
||||
a := >smodel.Account{}
|
||||
if err := p.db.GetByID(f.AccountID, a); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
continue
|
||||
}
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco
|
|||
following := []gtsmodel.Follow{}
|
||||
accounts := []apimodel.Account{}
|
||||
if err := p.db.GetAccountFollowing(targetAccountID, &following); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
return accounts, nil
|
||||
}
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
|
|
@ -57,7 +57,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco
|
|||
|
||||
a := >smodel.Account{}
|
||||
if err := p.db.GetByID(f.TargetAccountID, a); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
continue
|
||||
}
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ import (
|
|||
func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) {
|
||||
targetAccount := >smodel.Account{}
|
||||
if err := p.db.GetByID(targetAccountID, targetAccount); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no entry found for account id %s", targetAccountID))
|
||||
}
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
|
|
@ -39,7 +39,7 @@ func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccou
|
|||
apiStatuses := []apimodel.Status{}
|
||||
statuses, err := p.db.GetAccountStatuses(targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
|
||||
if err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
return apiStatuses, nil
|
||||
}
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccou
|
|||
// make sure the target account actually exists in our db
|
||||
targetAcct := >smodel.Account{}
|
||||
if err := p.db.GetByID(targetAccountID, targetAcct); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockRemove: account %s not found in the db: %s", targetAccountID, err))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco
|
|||
// make sure the target account actually exists in our db
|
||||
targetAcct := >smodel.Account{}
|
||||
if err := p.db.GetByID(targetAccountID, targetAcct); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("AccountFollowRemove: account %s not found in the db: %s", targetAccountID, err))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string,
|
|||
domainBlock := >smodel.DomainBlock{}
|
||||
err := p.db.GetWhere([]db.Where{{Key: "domain", Value: domain, CaseInsensitive: true}}, domainBlock)
|
||||
if err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// something went wrong in the DB
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("DomainBlockCreate: db error checking for existence of domain block %s: %s", domain, err))
|
||||
}
|
||||
|
|
@ -60,7 +60,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string,
|
|||
|
||||
// put the new block in the database
|
||||
if err := p.db.Put(domainBlock); err != nil {
|
||||
if _, ok := err.(db.ErrAlreadyExists); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// there's a real error creating the block
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("DomainBlockCreate: db error putting new domain block %s: %s", domain, err))
|
||||
}
|
||||
|
|
@ -125,7 +125,7 @@ selectAccountsLoop:
|
|||
for {
|
||||
accounts, err := p.db.GetAccountsForInstance(block.Domain, maxID, limit)
|
||||
if err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// no accounts left for this instance so we're done
|
||||
l.Infof("domainBlockProcessSideEffects: done iterating through accounts for domain %s", block.Domain)
|
||||
break selectAccountsLoop
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ func (p *processor) DomainBlockDelete(account *gtsmodel.Account, id string) (*ap
|
|||
domainBlock := >smodel.DomainBlock{}
|
||||
|
||||
if err := p.db.GetByID(id, domainBlock); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// something has gone really wrong
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ func (p *processor) DomainBlockGet(account *gtsmodel.Account, id string, export
|
|||
domainBlock := >smodel.DomainBlock{}
|
||||
|
||||
if err := p.db.GetByID(id, domainBlock); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// something has gone really wrong
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ func (p *processor) DomainBlocksGet(account *gtsmodel.Account, export bool) ([]*
|
|||
domainBlocks := []*gtsmodel.DomainBlock{}
|
||||
|
||||
if err := p.db.GetAll(&domainBlocks); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// something has gone really wrong
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ import (
|
|||
func (p *processor) BlocksGet(authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) {
|
||||
accounts, nextMaxID, prevMinID, err := p.db.GetAccountBlocks(authed.Account.ID, maxID, sinceID, limit)
|
||||
if err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are just no entries
|
||||
return &apimodel.BlocksResponse{
|
||||
Accounts: []*apimodel.Account{},
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ import (
|
|||
func (p *processor) FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) {
|
||||
frs := []gtsmodel.FollowRequest{}
|
||||
if err := p.db.GetAccountFollowRequests(auth.Account.ID, &frs); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error {
|
|||
// notification exists already so just continue
|
||||
continue
|
||||
}
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// there's a real error in the db
|
||||
return fmt.Errorf("notifyStatus: error checking existence of notification for mention with id %s : %s", m.ID, err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
|
|||
incomingAnnounce.ID = incomingAnnounceID
|
||||
|
||||
if err := p.db.Put(incomingAnnounce); err != nil {
|
||||
if _, ok := err.(db.ErrAlreadyExists); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
return fmt.Errorf("error adding dereferenced announce to the db: %s", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import (
|
|||
func (p *processor) Delete(mediaAttachmentID string) gtserror.WithCode {
|
||||
a := >smodel.MediaAttachment{}
|
||||
if err := p.db.GetByID(mediaAttachmentID, a); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// attachment already gone
|
||||
return nil
|
||||
}
|
||||
|
|
@ -38,7 +38,7 @@ func (p *processor) Delete(mediaAttachmentID string) gtserror.WithCode {
|
|||
|
||||
// delete the attachment
|
||||
if err := p.db.DeleteByID(mediaAttachmentID, a); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
errs = append(errs, fmt.Sprintf("remove attachment: %s", err))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ import (
|
|||
func (p *processor) GetMedia(account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) {
|
||||
attachment := >smodel.MediaAttachment{}
|
||||
if err := p.db.GetByID(mediaAttachmentID, attachment); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// attachment doesn't exist
|
||||
return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db"))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ import (
|
|||
func (p *processor) Update(account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) {
|
||||
attachment := >smodel.MediaAttachment{}
|
||||
if err := p.db.GetByID(mediaAttachmentID, attachment); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// attachment doesn't exist
|
||||
return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db"))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -196,7 +196,7 @@ func (p *processor) searchAccountByMention(authed *oauth.Auth, mention string, r
|
|||
return maybeAcct, nil
|
||||
}
|
||||
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// if it's not errNoEntries there's been a real database error so bail at this point
|
||||
return nil, fmt.Errorf("searchAccountByMention: database error: %s", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ func (p *processor) Context(account *gtsmodel.Account, targetStatusID string) (*
|
|||
|
||||
targetStatus := >smodel.Status{}
|
||||
if err := p.db.GetByID(targetStatusID, targetStatus); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(err)
|
||||
}
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ func (p *processor) Delete(account *gtsmodel.Account, targetStatusID string) (*a
|
|||
l.Tracef("going to search for target status %s", targetStatusID)
|
||||
targetStatus := >smodel.Status{}
|
||||
if err := p.db.GetByID(targetStatusID, targetStatus); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
|
||||
}
|
||||
// status is already gone
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ func (p *processor) Unboost(account *gtsmodel.Account, application *gtsmodel.App
|
|||
|
||||
if err != nil {
|
||||
// something went wrong in the db finding the boost
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error fetching existing boost from database: %s", err))
|
||||
}
|
||||
// we just don't have a boost
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ func (p *processor) Unfave(account *gtsmodel.Account, targetStatusID string) (*a
|
|||
}
|
||||
if err != nil {
|
||||
// something went wrong in the db finding the fave
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error fetching existing fave from database: %s", err))
|
||||
}
|
||||
// we just don't have a fave
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ func (p *processor) ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, th
|
|||
repliedAccount := >smodel.Account{}
|
||||
// check replied status exists + is replyable
|
||||
if err := p.db.GetByID(form.InReplyToID, repliedStatus); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
return fmt.Errorf("status with id %s not replyable because it doesn't exist", form.InReplyToID)
|
||||
}
|
||||
return fmt.Errorf("status with id %s not replyable: %s", form.InReplyToID, err)
|
||||
|
|
@ -113,14 +113,14 @@ func (p *processor) ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, th
|
|||
|
||||
// check replied account is known to us
|
||||
if err := p.db.GetByID(repliedStatus.AccountID, repliedAccount); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
return fmt.Errorf("status with id %s not replyable because account id %s is not known", form.InReplyToID, repliedStatus.AccountID)
|
||||
}
|
||||
return fmt.Errorf("status with id %s not replyable: %s", form.InReplyToID, err)
|
||||
}
|
||||
// check if a block exists
|
||||
if blocked, err := p.db.Blocked(thisAccountID, repliedAccount.ID); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
return fmt.Errorf("status with id %s not replyable: %s", form.InReplyToID, err)
|
||||
}
|
||||
} else if blocked {
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ func (p *processor) HomeTimelineGet(authed *oauth.Auth, maxID string, sinceID st
|
|||
func (p *processor) PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) {
|
||||
statuses, err := p.db.GetPublicTimelineForAccount(authed.Account.ID, maxID, sinceID, minID, limit, local)
|
||||
if err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are just no entries left
|
||||
return &apimodel.StatusTimelineResponse{
|
||||
Statuses: []*apimodel.Status{},
|
||||
|
|
@ -97,7 +97,7 @@ func (p *processor) PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID
|
|||
func (p *processor) FavedTimelineGet(authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.StatusTimelineResponse, gtserror.WithCode) {
|
||||
statuses, nextMaxID, prevMinID, err := p.db.GetFavedTimelineForAccount(authed.Account.ID, maxID, minID, limit)
|
||||
if err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are just no entries left
|
||||
return &apimodel.StatusTimelineResponse{
|
||||
Statuses: []*apimodel.Status{},
|
||||
|
|
@ -122,7 +122,7 @@ func (p *processor) filterPublicStatuses(authed *oauth.Auth, statuses []*gtsmode
|
|||
for _, s := range statuses {
|
||||
targetAccount := >smodel.Account{}
|
||||
if err := p.db.GetByID(s.AccountID, targetAccount); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
l.Debugf("filterPublicStatuses: skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
|
||||
continue
|
||||
}
|
||||
|
|
@ -157,7 +157,7 @@ func (p *processor) filterFavedStatuses(authed *oauth.Auth, statuses []*gtsmodel
|
|||
for _, s := range statuses {
|
||||
targetAccount := >smodel.Account{}
|
||||
if err := p.db.GetByID(s.AccountID, targetAccount); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
l.Debugf("filterFavedStatuses: skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
|
||||
continue
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ func useSession(cfg *config.Config, dbService db.DB, engine *gin.Engine) error {
|
|||
// check if we have a saved router session already
|
||||
routerSessions := []*gtsmodel.RouterSession{}
|
||||
if err := dbService.GetAll(&routerSessions); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// proper error occurred
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ grabloop:
|
|||
for ; len(filtered) < amount && i < 5; i = i + 1 { // try the grabloop 5 times only
|
||||
statuses, err := t.db.GetHomeTimelineForAccount(t.accountID, "", "", offsetStatus, amount, false)
|
||||
if err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
break grabloop // we just don't have enough statuses left in the db so index what we've got and then bail
|
||||
}
|
||||
return fmt.Errorf("IndexBefore: error getting statuses from db: %s", err)
|
||||
|
|
@ -132,7 +132,7 @@ grabloop:
|
|||
l.Tracef("entering grabloop; i is %d; len(filtered) is %d", i, len(filtered))
|
||||
statuses, err := t.db.GetHomeTimelineForAccount(t.accountID, offsetStatus, "", "", amount, false)
|
||||
if err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
break grabloop // we just don't have enough statuses left in the db so index what we've got and then bail
|
||||
}
|
||||
return fmt.Errorf("IndexBehind: error getting statuses from db: %s", err)
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ prepareloop:
|
|||
if preparing {
|
||||
if err := t.prepare(entry.statusID); err != nil {
|
||||
// there's been an error
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// it's a real error
|
||||
return fmt.Errorf("PrepareBehind: error preparing status with id %s: %s", entry.statusID, err)
|
||||
}
|
||||
|
|
@ -150,7 +150,7 @@ prepareloop:
|
|||
if preparing {
|
||||
if err := t.prepare(entry.statusID); err != nil {
|
||||
// there's been an error
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// it's a real error
|
||||
return fmt.Errorf("PrepareBefore: error preparing status with id %s: %s", entry.statusID, err)
|
||||
}
|
||||
|
|
@ -205,7 +205,7 @@ prepareloop:
|
|||
|
||||
if err := t.prepare(entry.statusID); err != nil {
|
||||
// there's been an error
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// it's a real error
|
||||
return fmt.Errorf("PrepareFromTop: error preparing status with id %s: %s", entry.statusID, err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ func (c *converter) ASRepresentationToAccount(accountable ap.Accountable, update
|
|||
// we already know this account so we can skip generating it
|
||||
return acct, nil
|
||||
}
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
// we don't know the account and there's been a real error
|
||||
return nil, fmt.Errorf("error getting account with uri %s from the database: %s", uri.String(), err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ func (c *converter) AccountToMastoSensitive(a *gtsmodel.Account) (*model.Account
|
|||
// check pending follow requests aimed at this account
|
||||
fr := []gtsmodel.FollowRequest{}
|
||||
if err := c.db.GetAccountFollowRequests(a.ID, &fr); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
return nil, fmt.Errorf("error getting follow requests: %s", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -63,41 +63,27 @@ func (c *converter) AccountToMastoSensitive(a *gtsmodel.Account) (*model.Account
|
|||
|
||||
func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, error) {
|
||||
// count followers
|
||||
followers := []gtsmodel.Follow{}
|
||||
if err := c.db.GetAccountFollowers(a.ID, &followers, false); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
return nil, fmt.Errorf("error getting followers: %s", err)
|
||||
}
|
||||
}
|
||||
var followersCount int
|
||||
if followers != nil {
|
||||
followersCount = len(followers)
|
||||
followersCount, err := c.db.CountAccountFollowers(a.ID, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error counting followers: %s", err)
|
||||
}
|
||||
|
||||
// count following
|
||||
following := []gtsmodel.Follow{}
|
||||
if err := c.db.GetAccountFollowing(a.ID, &following); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
return nil, fmt.Errorf("error getting following: %s", err)
|
||||
}
|
||||
}
|
||||
var followingCount int
|
||||
if following != nil {
|
||||
followingCount = len(following)
|
||||
followingCount, err := c.db.CountAccountFollowing(a.ID, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error counting following: %s", err)
|
||||
}
|
||||
|
||||
// count statuses
|
||||
statusesCount, err := c.db.GetAccountStatusesCount(a.ID)
|
||||
statusesCount, err := c.db.CountAccountStatuses(a.ID)
|
||||
if err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
return nil, fmt.Errorf("error getting last statuses: %s", err)
|
||||
}
|
||||
return nil, fmt.Errorf("error getting last statuses: %s", err)
|
||||
}
|
||||
|
||||
// check when the last status was
|
||||
lastStatus := >smodel.Status{}
|
||||
if err := c.db.GetAccountLastStatus(a.ID, lastStatus); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
if err != db.ErrNoEntries {
|
||||
return nil, fmt.Errorf("error getting last status: %s", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -107,23 +93,20 @@ func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, e
|
|||
}
|
||||
|
||||
// build the avatar and header URLs
|
||||
avi := >smodel.MediaAttachment{}
|
||||
if err := c.db.GetAccountAvatar(avi, a.ID); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
return nil, fmt.Errorf("error getting avatar: %s", err)
|
||||
}
|
||||
}
|
||||
aviURL := avi.URL
|
||||
aviURLStatic := avi.Thumbnail.URL
|
||||
|
||||
header := >smodel.MediaAttachment{}
|
||||
if err := c.db.GetAccountHeader(header, a.ID); err != nil {
|
||||
if _, ok := err.(db.ErrNoEntries); !ok {
|
||||
return nil, fmt.Errorf("error getting header: %s", err)
|
||||
}
|
||||
var aviURL string
|
||||
var aviURLStatic string
|
||||
if a.AvatarMediaAttachment != nil {
|
||||
aviURL = a.AvatarMediaAttachment.URL
|
||||
aviURLStatic = a.AvatarMediaAttachment.Thumbnail.URL
|
||||
}
|
||||
|
||||
var headerURL string
|
||||
var headerURLStatic string
|
||||
if a.HeaderMediaAttachment != nil {
|
||||
headerURL = a.HeaderMediaAttachment.URL
|
||||
headerURLStatic = a.HeaderMediaAttachment.Thumbnail.URL
|
||||
}
|
||||
headerURL := header.URL
|
||||
headerURLStatic := header.Thumbnail.URL
|
||||
|
||||
// get the fields set on this account
|
||||
fields := []model.Field{}
|
||||
|
|
@ -585,13 +568,10 @@ func (c *converter) InstanceToMasto(i *gtsmodel.Instance) (*model.Instance, erro
|
|||
}
|
||||
|
||||
// get the instance account if it exists and just skip if it doesn't
|
||||
ia := >smodel.Account{}
|
||||
if err := c.db.GetWhere([]db.Where{{Key: "username", Value: i.Domain}}, ia); err == nil {
|
||||
// instance account exists, get the header for the account if it exists
|
||||
attachment := >smodel.MediaAttachment{}
|
||||
if err := c.db.GetAccountHeader(attachment, ia.ID); err == nil {
|
||||
// header exists, set it on the api model
|
||||
mi.Thumbnail = attachment.URL
|
||||
ia, err := c.db.GetInstanceAccount("")
|
||||
if err == nil {
|
||||
if ia.HeaderMediaAttachment != nil {
|
||||
mi.Thumbnail = ia.HeaderMediaAttachment.URL
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount
|
|||
targetUser := >smodel.User{}
|
||||
if err := f.db.GetWhere([]db.Where{{Key: "account_id", Value: targetAccount.ID}}, targetUser); err != nil {
|
||||
l.Debug("target user could not be selected")
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("StatusVisible: db error selecting user for local target account %s: %s", targetAccount.ID, err)
|
||||
|
|
@ -76,7 +76,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount
|
|||
if err := f.db.GetWhere([]db.Where{{Key: "account_id", Value: requestingAccount.ID}}, requestingUser); err != nil {
|
||||
// if the requesting account is local but doesn't have a corresponding user in the db this is a problem
|
||||
l.Debug("requesting user could not be selected")
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("StatusVisible: db error selecting user for local requesting account %s: %s", requestingAccount.ID, err)
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ func (f *filter) blockedDomain(host string) (bool, error) {
|
|||
return true, nil
|
||||
}
|
||||
|
||||
if _, ok := err.(db.ErrNoEntries); ok {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are no entries so there's no block
|
||||
return false, nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue