This commit is contained in:
tsmethurst 2021-08-19 15:40:33 +02:00
commit accc8971d1
53 changed files with 896 additions and 486 deletions

View file

@ -24,68 +24,69 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Account contains functions related to account getting/setting/creation.
type Account interface {
// GetAccountByID returns one account with the given ID, or an error if something goes wrong.
GetAccountByID(id string) (*gtsmodel.Account, DBError)
GetAccountByID(id string) (*gtsmodel.Account, Error)
// GetAccountByURI returns one account with the given URI, or an error if something goes wrong.
GetAccountByURI(uri string) (*gtsmodel.Account, DBError)
GetAccountByURI(uri string) (*gtsmodel.Account, Error)
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
GetAccountByURL(uri string) (*gtsmodel.Account, DBError)
GetAccountByURL(uri string) (*gtsmodel.Account, Error)
// GetLocalAccountByUsername is a shortcut for the common action of fetching an account ON THIS INSTANCE
// according to its username, which should be unique.
// The given account pointer will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetLocalAccountByUsername(username string) (*gtsmodel.Account, DBError)
GetLocalAccountByUsername(username string) (*gtsmodel.Account, Error)
// GetAccountFollowRequests is a shortcut for the common action of fetching a list of follow requests targeting the given account ID.
// The given slice 'followRequests' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) DBError
GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) Error
// GetAccountFollowing is a shortcut for the common action of fetching a list of accounts that accountID is following.
// The given slice 'following' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) DBError
GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) Error
CountAccountFollowing(accountID string, localOnly bool) (int, DBError)
CountAccountFollowing(accountID string, localOnly bool) (int, Error)
// GetAccountFollowers is a shortcut for the common action of fetching a list of accounts that accountID is followed by.
// The given slice 'followers' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
//
// If localOnly is set to true, then only followers from *this instance* will be returned.
GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) DBError
GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) Error
CountAccountFollowers(accountID string, localOnly bool) (int, DBError)
CountAccountFollowers(accountID string, localOnly bool) (int, Error)
// GetAccountFaves is a shortcut for the common action of fetching a list of faves made by the given accountID.
// The given slice 'faves' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) DBError
GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) Error
// GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID.
CountAccountStatuses(accountID string) (int, DBError)
CountAccountStatuses(accountID string) (int, Error)
// GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
// be very memory intensive so you probably shouldn't do this!
// In case of no entries, a 'no entries' error will be returned
GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, DBError)
GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error)
GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, DBError)
GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error)
// GetAccountLastPosted simply gets the timestamp of the most recent post by the account.
//
// The returned time will be zero if account has never posted anything.
GetAccountLastPosted(accountID string) (time.Time, DBError)
GetAccountLastPosted(accountID string) (time.Time, Error)
// SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment.
SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) DBError
SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error
// GetInstanceAccount returns the instance account for the given domain.
// If domain is empty, this instance account will be returned.
GetInstanceAccount(domain string) (*gtsmodel.Account, DBError)
GetInstanceAccount(domain string) (*gtsmodel.Account, Error)
}

View file

@ -24,29 +24,30 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Admin contains functions related to instance administration (new signups etc).
type Admin interface {
// IsUsernameAvailable checks whether a given username is available on our domain.
// Returns an error if the username is already taken, or something went wrong in the db.
IsUsernameAvailable(username string) DBError
IsUsernameAvailable(username string) Error
// IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
// Return an error if:
// A) the email is already associated with an account
// B) we block signups from this email domain
// C) something went wrong in the db
IsEmailAvailable(email string) DBError
IsEmailAvailable(email string) Error
// NewSignup creates a new user in the database with the given parameters.
// By the time this function is called, it should be assumed that all the parameters have passed validation!
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, DBError)
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error)
// CreateInstanceAccount creates an account in the database with the same username as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
// This is needed for things like serving files that belong to the instance and not an individual user/account.
CreateInstanceAccount() DBError
CreateInstanceAccount() Error
// CreateInstanceInstance creates an instance in the database with the same domain as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'.
// This is needed for things like serving instance information through /api/v1/instance
CreateInstanceInstance() DBError
CreateInstanceInstance() Error
}

View file

@ -20,67 +20,68 @@ package db
import "context"
// Basic wraps basic database functionality.
type Basic interface {
// CreateTable creates a table for the given interface.
// For implementations that don't use tables, this can just return nil.
CreateTable(i interface{}) DBError
CreateTable(i interface{}) Error
// DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil.
DropTable(i interface{}) DBError
DropTable(i interface{}) Error
// RegisterTable registers a table for use in many2many relations.
// For implementations that don't use tables, or many2many relations, this can just return nil.
RegisterTable(i interface{}) DBError
RegisterTable(i interface{}) Error
// Stop should stop and close the database connection cleanly, returning an error if this is not possible.
// If the database implementation doesn't need to be stopped, this can just return nil.
Stop(ctx context.Context) DBError
Stop(ctx context.Context) Error
// IsHealthy should return nil if the database connection is healthy, or an error if not.
IsHealthy(ctx context.Context) DBError
IsHealthy(ctx context.Context) Error
// GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry,
// for other implementations (for example, in-memory) it might just be the key of a map.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetByID(id string, i interface{}) DBError
GetByID(id string, i interface{}) Error
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
// name of the key to select from.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetWhere(where []Where, i interface{}) DBError
GetWhere(where []Where, i interface{}) Error
// GetAll will try to get all entries of type i.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetAll(i interface{}) DBError
GetAll(i interface{}) Error
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Put(i interface{}) DBError
Put(i interface{}) Error
// Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/
// It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Upsert(i interface{}, conflictColumn string) DBError
Upsert(i interface{}, conflictColumn string) Error
// UpdateByID updates i with id id.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
UpdateByID(id string, i interface{}) DBError
UpdateByID(id string, i interface{}) Error
// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
UpdateOneByID(id string, key string, value interface{}, i interface{}) DBError
UpdateOneByID(id string, key string, value interface{}, i interface{}) Error
// UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
UpdateWhere(where []Where, key string, value interface{}, i interface{}) DBError
UpdateWhere(where []Where, key string, value interface{}, i interface{}) Error
// DeleteByID removes i with id id.
// If i didn't exist anyway, then no error should be returned.
DeleteByID(id string, i interface{}) DBError
DeleteByID(id string, i interface{}) Error
// DeleteWhere deletes i where key = value
// If i didn't exist anyway, then no error should be returned.
DeleteWhere(where []Where, i interface{}) DBError
DeleteWhere(where []Where, i interface{}) Error
}

View file

@ -27,13 +27,12 @@ const (
DBTypePostgres string = "POSTGRES"
)
// DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres).
// Note that in all of the functions below, the passed interface should be a pointer or a slice, which will then be populated
// by whatever is returned from the database.
// DB provides methods for interacting with an underlying database or other storage mechanism.
type DB interface {
Account
Admin
Basic
Domain
Instance
Media
Mention

36
internal/db/domain.go Normal file
View file

@ -0,0 +1,36 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "net/url"
// Domain contains DB functions related to domains and domain blocks.
type Domain interface {
// IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`).
IsDomainBlocked(domain string) (bool, Error)
// AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found.
AreDomainsBlocked(domains []string) (bool, Error)
// IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`).
IsURIBlocked(uri *url.URL) (bool, Error)
// AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found.
AreURIsBlocked(uris []*url.URL) (bool, Error)
}

View file

@ -20,11 +20,16 @@ package db
import "fmt"
type DBError error
// Error denotes a database error.
type Error error
var (
ErrNoEntries DBError = fmt.Errorf("no entries")
ErrMultipleEntries DBError = fmt.Errorf("multiple entries")
ErrAlreadyExists DBError = fmt.Errorf("already exists")
ErrUnknown DBError = fmt.Errorf("unknown error")
// ErrNoEntries is returned when a caller expected an entry for a query, but none was found.
ErrNoEntries Error = fmt.Errorf("no entries")
// ErrMultipleEntries is returned when a caller expected ONE entry for a query, but multiples were found.
ErrMultipleEntries Error = fmt.Errorf("multiple entries")
// ErrAlreadyExists is returned when a caller tries to insert a database entry that already exists in the db.
ErrAlreadyExists Error = fmt.Errorf("already exists")
// ErrUnknown denotes an unknown database error.
ErrUnknown Error = fmt.Errorf("unknown error")
)

View file

@ -20,16 +20,17 @@ package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Instance contains functions for instance-level actions (counting instance users etc.).
type Instance interface {
// GetUserCountForInstance returns the number of known accounts registered with the given domain.
GetUserCountForInstance(domain string) (int, DBError)
GetUserCountForInstance(domain string) (int, Error)
// GetStatusCountForInstance returns the number of known statuses posted from the given domain.
GetStatusCountForInstance(domain string) (int, DBError)
GetStatusCountForInstance(domain string) (int, Error)
// GetDomainCountForInstance returns the number of known instances known that the given domain federates with.
GetDomainCountForInstance(domain string) (int, DBError)
GetDomainCountForInstance(domain string) (int, Error)
// GetAccountsForInstance returns a slice of accounts from the given instance, arranged by ID.
GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, DBError)
GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)
}

View file

@ -20,7 +20,8 @@ package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Media contains functions related to creating/getting/removing media attachments.
type Media interface {
// GetAttachmentByID gets a single attachment by its ID
GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, DBError)
GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, Error)
}

View file

@ -20,10 +20,11 @@ package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Mention contains functions for getting/creating mentions in the database.
type Mention interface {
// GetMention gets a single mention by ID
GetMention(id string) (*gtsmodel.Mention, DBError)
GetMention(id string) (*gtsmodel.Mention, Error)
// GetMentions gets multiple mentions.
GetMentions(ids []string) ([]*gtsmodel.Mention, DBError)
GetMentions(ids []string) ([]*gtsmodel.Mention, Error)
}

View file

@ -20,7 +20,8 @@ package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Notification contains functions for creating and getting notifications.
type Notification interface {
// GetNotificationsForAccount returns a list of notifications that pertain to the given accountID.
GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, DBError)
GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
}

View file

@ -45,7 +45,7 @@ func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query {
Relation("HeaderMediaAttachment")
}
func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.DBError) {
func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account).
@ -56,7 +56,7 @@ func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.DBError) {
return account, err
}
func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.DBError) {
func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account).
@ -67,7 +67,7 @@ func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.DBError)
return account, err
}
func (a *accountDB) GetAccountByURL(uri string) (*gtsmodel.Account, db.DBError) {
func (a *accountDB) GetAccountByURL(uri string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account).
@ -78,7 +78,7 @@ func (a *accountDB) GetAccountByURL(uri string) (*gtsmodel.Account, db.DBError)
return account, err
}
func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.DBError) {
func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account)
@ -98,7 +98,7 @@ func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.DBE
return account, err
}
func (a *accountDB) GetAccountLastPosted(accountID string) (time.Time, db.DBError) {
func (a *accountDB) GetAccountLastPosted(accountID string) (time.Time, db.Error) {
status := &gtsmodel.Status{}
q := a.conn.Model(status).
@ -112,7 +112,7 @@ func (a *accountDB) GetAccountLastPosted(accountID string) (time.Time, db.DBErro
return status.CreatedAt, err
}
func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.DBError {
func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
if mediaAttachment.Avatar && mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar")
}
@ -137,7 +137,7 @@ func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAtta
return nil
}
func (a *accountDB) GetLocalAccountByUsername(username string) (*gtsmodel.Account, db.DBError) {
func (a *accountDB) GetLocalAccountByUsername(username string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account).
@ -149,7 +149,7 @@ func (a *accountDB) GetLocalAccountByUsername(username string) (*gtsmodel.Accoun
return account, err
}
func (a *accountDB) GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) db.DBError {
func (a *accountDB) GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) db.Error {
if err := a.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
@ -159,7 +159,7 @@ func (a *accountDB) GetAccountFollowRequests(accountID string, followRequests *[
return nil
}
func (a *accountDB) GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) db.DBError {
func (a *accountDB) GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) db.Error {
if err := a.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
@ -169,11 +169,11 @@ func (a *accountDB) GetAccountFollowing(accountID string, following *[]gtsmodel.
return nil
}
func (a *accountDB) CountAccountFollowing(accountID string, localOnly bool) (int, db.DBError) {
func (a *accountDB) CountAccountFollowing(accountID string, localOnly bool) (int, db.Error) {
return a.conn.Model(&[]*gtsmodel.Follow{}).Where("account_id = ?", accountID).Count()
}
func (a *accountDB) GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) db.DBError {
func (a *accountDB) GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) db.Error {
q := a.conn.Model(followers)
@ -203,11 +203,11 @@ func (a *accountDB) GetAccountFollowers(accountID string, followers *[]gtsmodel.
return nil
}
func (a *accountDB) CountAccountFollowers(accountID string, localOnly bool) (int, db.DBError) {
func (a *accountDB) CountAccountFollowers(accountID string, localOnly bool) (int, db.Error) {
return a.conn.Model(&[]*gtsmodel.Follow{}).Where("target_account_id = ?", accountID).Count()
}
func (a *accountDB) GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) db.DBError {
func (a *accountDB) GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) db.Error {
if err := a.conn.Model(faves).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
@ -217,11 +217,11 @@ func (a *accountDB) GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFa
return nil
}
func (a *accountDB) CountAccountStatuses(accountID string) (int, db.DBError) {
func (a *accountDB) CountAccountStatuses(accountID string) (int, db.Error) {
return a.conn.Model(&gtsmodel.Status{}).Where("account_id = ?", accountID).Count()
}
func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.DBError) {
func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
a.log.Debugf("getting statuses for account %s", accountID)
statuses := []*gtsmodel.Status{}
@ -267,7 +267,7 @@ func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeRepli
return statuses, nil
}
func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.DBError) {
func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
blocks := []*gtsmodel.Block{}
fq := a.conn.Model(&blocks).

View file

@ -45,7 +45,7 @@ type adminDB struct {
cancel context.CancelFunc
}
func (a *adminDB) IsUsernameAvailable(username string) db.DBError {
func (a *adminDB) IsUsernameAvailable(username string) db.Error {
// if no error we fail because it means we found something
// if error but it's not pg.ErrNoRows then we fail
// if err is pg.ErrNoRows we're good, we found nothing so continue
@ -57,7 +57,7 @@ func (a *adminDB) IsUsernameAvailable(username string) db.DBError {
return nil
}
func (a *adminDB) IsEmailAvailable(email string) db.DBError {
func (a *adminDB) IsEmailAvailable(email string) db.Error {
// parse the domain from the email
m, err := mail.ParseAddress(email)
if err != nil {
@ -85,7 +85,7 @@ func (a *adminDB) IsEmailAvailable(email string) db.DBError {
return nil
}
func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.DBError) {
func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
a.log.Errorf("error creating new rsa key: %s", err)
@ -168,7 +168,7 @@ func (a *adminDB) NewSignup(username string, reason string, requireApproval bool
return u, nil
}
func (a *adminDB) CreateInstanceAccount() db.DBError {
func (a *adminDB) CreateInstanceAccount() db.Error {
username := a.config.Host
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
@ -210,7 +210,7 @@ func (a *adminDB) CreateInstanceAccount() db.DBError {
return nil
}
func (a *adminDB) CreateInstanceInstance() db.DBError {
func (a *adminDB) CreateInstanceInstance() db.Error {
iID, err := id.NewRandomULID()
if err != nil {
return err

View file

@ -38,7 +38,7 @@ type basicDB struct {
cancel context.CancelFunc
}
func (b *basicDB) Put(i interface{}) db.DBError {
func (b *basicDB) Put(i interface{}) db.Error {
_, err := b.conn.Model(i).Insert(i)
if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
@ -46,7 +46,7 @@ func (b *basicDB) Put(i interface{}) db.DBError {
return err
}
func (b *basicDB) GetByID(id string, i interface{}) db.DBError {
func (b *basicDB) GetByID(id string, i interface{}) db.Error {
if err := b.conn.Model(i).Where("id = ?", id).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
@ -57,7 +57,7 @@ func (b *basicDB) GetByID(id string, i interface{}) db.DBError {
return nil
}
func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.DBError {
func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.Error {
if len(where) == 0 {
return errors.New("no queries provided")
}
@ -85,7 +85,7 @@ func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.DBError {
return nil
}
func (b *basicDB) GetAll(i interface{}) db.DBError {
func (b *basicDB) GetAll(i interface{}) db.Error {
if err := b.conn.Model(i).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
@ -95,7 +95,7 @@ func (b *basicDB) GetAll(i interface{}) db.DBError {
return nil
}
func (b *basicDB) DeleteByID(id string, i interface{}) db.DBError {
func (b *basicDB) DeleteByID(id string, i interface{}) db.Error {
if _, err := b.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
// if there are no rows *anyway* then that's fine
// just return err if there's an actual error
@ -106,7 +106,7 @@ func (b *basicDB) DeleteByID(id string, i interface{}) db.DBError {
return nil
}
func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.DBError {
func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.Error {
if len(where) == 0 {
return errors.New("no queries provided")
}
@ -126,7 +126,7 @@ func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.DBError {
return nil
}
func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.DBError {
func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.Error {
if _, err := b.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
@ -136,7 +136,7 @@ func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.DBError {
return nil
}
func (b *basicDB) UpdateByID(id string, i interface{}) db.DBError {
func (b *basicDB) UpdateByID(id string, i interface{}) db.Error {
if _, err := b.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
@ -146,12 +146,12 @@ func (b *basicDB) UpdateByID(id string, i interface{}) db.DBError {
return nil
}
func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.DBError {
func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.Error {
_, err := b.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
return err
}
func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.DBError {
func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.Error {
q := b.conn.Model(i)
for _, w := range where {
@ -173,28 +173,28 @@ func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i
return err
}
func (b *basicDB) CreateTable(i interface{}) db.DBError {
func (b *basicDB) CreateTable(i interface{}) db.Error {
return b.conn.Model(i).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
}
func (b *basicDB) DropTable(i interface{}) db.DBError {
func (b *basicDB) DropTable(i interface{}) db.Error {
return b.conn.Model(i).DropTable(&orm.DropTableOptions{
IfExists: true,
})
}
func (b *basicDB) RegisterTable(i interface{}) db.DBError {
func (b *basicDB) RegisterTable(i interface{}) db.Error {
orm.RegisterTable(i)
return nil
}
func (b *basicDB) IsHealthy(ctx context.Context) db.DBError {
func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
return b.conn.Ping(ctx)
}
func (b *basicDB) Stop(ctx context.Context) db.DBError {
func (b *basicDB) Stop(ctx context.Context) db.Error {
b.log.Info("closing db connection")
if err := b.conn.Close(); err != nil {
// only cancel if there's a problem closing the db

83
internal/db/pg/domain.go Normal file
View file

@ -0,0 +1,83 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"net/url"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
type domainDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (d *domainDB) IsDomainBlocked(domain string) (bool, db.Error) {
if domain == "" {
return false, nil
}
blocked, err := d.conn.
Model(&gtsmodel.DomainBlock{}).
Where("LOWER(domain) = LOWER(?)", domain).
Exists()
err = processErrorResponse(err)
return blocked, err
}
func (d *domainDB) AreDomainsBlocked(domains []string) (bool, db.Error) {
// filter out any doubles
uniqueDomains := util.UniqueStrings(domains)
for _, domain := range uniqueDomains {
if blocked, err := d.IsDomainBlocked(domain); err != nil {
return false, err
} else if blocked {
return blocked, nil
}
}
// no blocks found
return false, nil
}
func (d *domainDB) IsURIBlocked(uri *url.URL) (bool, db.Error) {
domain := uri.Hostname()
return d.IsDomainBlocked(domain)
}
func (d *domainDB) AreURIsBlocked(uris []*url.URL) (bool, db.Error) {
domains := []string{}
for _, uri := range uris {
domains = append(domains, uri.Hostname())
}
return d.AreDomainsBlocked(domains)
}

View file

@ -35,7 +35,7 @@ type instanceDB struct {
cancel context.CancelFunc
}
func (i *instanceDB) GetUserCountForInstance(domain string) (int, db.DBError) {
func (i *instanceDB) GetUserCountForInstance(domain string) (int, db.Error) {
q := i.conn.Model(&[]*gtsmodel.Account{})
if domain == i.config.Host {
@ -51,7 +51,7 @@ func (i *instanceDB) GetUserCountForInstance(domain string) (int, db.DBError) {
return q.Count()
}
func (i *instanceDB) GetStatusCountForInstance(domain string) (int, db.DBError) {
func (i *instanceDB) GetStatusCountForInstance(domain string) (int, db.Error) {
q := i.conn.Model(&[]*gtsmodel.Status{})
if domain == i.config.Host {
@ -66,7 +66,7 @@ func (i *instanceDB) GetStatusCountForInstance(domain string) (int, db.DBError)
return q.Count()
}
func (i *instanceDB) GetDomainCountForInstance(domain string) (int, db.DBError) {
func (i *instanceDB) GetDomainCountForInstance(domain string) (int, db.Error) {
q := i.conn.Model(&[]*gtsmodel.Instance{})
if domain == i.config.Host {
@ -81,7 +81,7 @@ func (i *instanceDB) GetDomainCountForInstance(domain string) (int, db.DBError)
return q.Count()
}
func (i *instanceDB) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.DBError) {
func (i *instanceDB) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
i.log.Debug("GetAccountsForInstance")
accounts := []*gtsmodel.Account{}

View file

@ -41,7 +41,7 @@ func (m *mediaDB) newMediaQ(i interface{}) *orm.Query {
Relation("Account")
}
func (m *mediaDB) GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, db.DBError) {
func (m *mediaDB) GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, db.Error) {
attachment := &gtsmodel.MediaAttachment{}
q := m.newMediaQ(attachment).

View file

@ -24,6 +24,7 @@ import (
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -34,6 +35,36 @@ type mentionDB struct {
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
}
func (m *mentionDB) cacheMention(id string, mention *gtsmodel.Mention) {
if m.cache == nil {
m.cache = cache.New()
}
if err := m.cache.Store(id, mention); err != nil {
m.log.Panicf("mentionDB: error storing in cache: %s", err)
}
}
func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) {
if m.cache == nil {
m.cache = cache.New()
return nil, false
}
mI, err := m.cache.Fetch(id)
if err != nil || mI == nil {
return nil, false
}
mention, ok := mI.(*gtsmodel.Mention)
if !ok {
m.log.Panicf("mentionDB: cached interface with key %s was not a mention", id)
}
return mention, true
}
func (m *mentionDB) newMentionQ(i interface{}) *orm.Query {
@ -43,7 +74,11 @@ func (m *mentionDB) newMentionQ(i interface{}) *orm.Query {
Relation("TargetAccount")
}
func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.DBError) {
func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) {
if mention, cached := m.mentionCached(id); cached {
return mention, nil
}
mention := &gtsmodel.Mention{}
q := m.newMentionQ(mention).
@ -51,20 +86,23 @@ func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.DBError) {
err := processErrorResponse(q.Select())
if err == nil && mention != nil {
m.cacheMention(id, mention)
}
return mention, err
}
func (m *mentionDB) GetMentions(ids []string) ([]*gtsmodel.Mention, db.DBError) {
func (m *mentionDB) GetMentions(ids []string) ([]*gtsmodel.Mention, db.Error) {
mentions := []*gtsmodel.Mention{}
if len(ids) == 0 {
return mentions, nil
for _, i := range ids {
mention, err := m.GetMention(i)
if err != nil {
return nil, processErrorResponse(err)
}
mentions = append(mentions, mention)
}
q := m.newMentionQ(&mentions).
Where("mention.id in (?)", pg.In(ids))
err := processErrorResponse(q.Select())
return mentions, err
return mentions, nil
}

View file

@ -24,7 +24,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.DBError) {
func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
notifications := []*gtsmodel.Notification{}
q := ps.conn.Model(&notifications).Where("target_account_id = ?", accountID)

View file

@ -49,6 +49,7 @@ type postgresService struct {
db.Account
db.Admin
db.Basic
db.Domain
db.Instance
db.Media
db.Mention
@ -123,6 +124,12 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
log: log,
cancel: cancel,
},
Domain: &domainDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Instance: &instanceDB{
config: c,
conn: conn,

View file

@ -43,7 +43,7 @@ func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *orm.Query {
Relation("TargetAccount")
}
func (r *relationshipDB) processResponse(block *gtsmodel.Block, err error) (*gtsmodel.Block, db.DBError) {
func (r *relationshipDB) processResponse(block *gtsmodel.Block, err error) (*gtsmodel.Block, db.Error) {
switch err {
case pg.ErrNoRows:
return nil, db.ErrNoEntries
@ -54,7 +54,7 @@ func (r *relationshipDB) processResponse(block *gtsmodel.Block, err error) (*gts
}
}
func (r *relationshipDB) Blocked(account1 string, account2 string, eitherDirection bool) (bool, db.DBError) {
func (r *relationshipDB) Blocked(account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
q := r.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", account1).Where("target_account_id = ?", account2)
if eitherDirection {
@ -64,7 +64,7 @@ func (r *relationshipDB) Blocked(account1 string, account2 string, eitherDirecti
return q.Exists()
}
func (r *relationshipDB) GetBlock(account1 string, account2 string) (*gtsmodel.Block, db.DBError) {
func (r *relationshipDB) GetBlock(account1 string, account2 string) (*gtsmodel.Block, db.Error) {
block := &gtsmodel.Block{}
q := r.newBlockQ(block).
@ -74,7 +74,7 @@ func (r *relationshipDB) GetBlock(account1 string, account2 string) (*gtsmodel.B
return r.processResponse(block, q.Select())
}
func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.DBError) {
func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
rel := &gtsmodel.Relationship{
ID: targetAccount,
}
@ -128,7 +128,7 @@ func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount
return rel, nil
}
func (r *relationshipDB) Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.DBError) {
func (r *relationshipDB) Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
@ -136,7 +136,7 @@ func (r *relationshipDB) Follows(sourceAccount *gtsmodel.Account, targetAccount
return r.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
}
func (r *relationshipDB) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.DBError) {
func (r *relationshipDB) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
@ -144,7 +144,7 @@ func (r *relationshipDB) FollowRequested(sourceAccount *gtsmodel.Account, target
return r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
}
func (r *relationshipDB) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.DBError) {
func (r *relationshipDB) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
if account1 == nil || account2 == nil {
return false, nil
}
@ -170,7 +170,7 @@ func (r *relationshipDB) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.
return f1 && f2, nil
}
func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.DBError) {
func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
// make sure the original follow request exists
fr := &gtsmodel.FollowRequest{}
if err := r.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {

View file

@ -27,6 +27,7 @@ import (
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -37,6 +38,36 @@ type statusDB struct {
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
}
func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
if s.cache == nil {
s.cache = cache.New()
}
if err := s.cache.Store(id, status); err != nil {
s.log.Panicf("statusDB: error storing in cache: %s", err)
}
}
func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
if s.cache == nil {
s.cache = cache.New()
return nil, false
}
sI, err := s.cache.Fetch(id)
if err != nil || sI == nil {
return nil, false
}
status, ok := sI.(*gtsmodel.Status)
if !ok {
s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
}
return status, true
}
func (s *statusDB) newStatusQ(status interface{}) *orm.Query {
@ -60,7 +91,11 @@ func (s *statusDB) newFaveQ(faves interface{}) *orm.Query {
Relation("Status")
}
func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.DBError) {
func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(id); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
@ -68,10 +103,18 @@ func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.DBError) {
err := processErrorResponse(q.Select())
if err == nil && status != nil {
s.cacheStatus(id, status)
}
return status, err
}
func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.DBError) {
func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
@ -79,10 +122,18 @@ func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.DBError) {
err := processErrorResponse(q.Select())
if err == nil && status != nil {
s.cacheStatus(uri, status)
}
return status, err
}
func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.DBError) {
func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
@ -90,10 +141,14 @@ func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.DBError) {
err := processErrorResponse(q.Select())
if err == nil && status != nil {
s.cacheStatus(uri, status)
}
return status, err
}
func (s *statusDB) PutStatus(status *gtsmodel.Status) db.DBError {
func (s *statusDB) PutStatus(status *gtsmodel.Status) db.Error {
transaction := func(tx *pg.Tx) error {
// create links between this status and any emojis it uses
for _, i := range status.EmojiIDs {
@ -133,7 +188,7 @@ func (s *statusDB) PutStatus(status *gtsmodel.Status) db.DBError {
return processErrorResponse(s.conn.RunInTransaction(context.Background(), transaction))
}
func (s *statusDB) GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.DBError) {
func (s *statusDB) GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
parents := []*gtsmodel.Status{}
s.statusParent(status, &parents, onlyDirect)
@ -157,7 +212,7 @@ func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmo
s.statusParent(parentStatus, foundStatuses, false)
}
func (s *statusDB) GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.DBError) {
func (s *statusDB) GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
foundStatuses := &list.List{}
foundStatuses.PushFront(status)
s.statusChildren(status, foundStatuses, onlyDirect, minID)
@ -212,35 +267,35 @@ func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.L
}
}
func (s *statusDB) CountStatusReplies(status *gtsmodel.Status) (int, db.DBError) {
func (s *statusDB) CountStatusReplies(status *gtsmodel.Status) (int, db.Error) {
return s.conn.Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
}
func (s *statusDB) CountStatusReblogs(status *gtsmodel.Status) (int, db.DBError) {
func (s *statusDB) CountStatusReblogs(status *gtsmodel.Status) (int, db.Error) {
return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
}
func (s *statusDB) CountStatusFaves(status *gtsmodel.Status) (int, db.DBError) {
func (s *statusDB) CountStatusFaves(status *gtsmodel.Status) (int, db.Error) {
return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
}
func (s *statusDB) IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
func (s *statusDB) IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (s *statusDB) IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
func (s *statusDB) IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (s *statusDB) IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
func (s *statusDB) IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (s *statusDB) IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
func (s *statusDB) IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (s *statusDB) GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.DBError) {
func (s *statusDB) GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
faves := []*gtsmodel.StatusFave{}
q := s.newFaveQ(&faves).
@ -251,7 +306,7 @@ func (s *statusDB) GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFa
return faves, err
}
func (s *statusDB) GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, db.DBError) {
func (s *statusDB) GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
reblogs := []*gtsmodel.Status{}
q := s.newStatusQ(&reblogs).

View file

@ -19,7 +19,9 @@
package pg_test
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/testrig"
@ -94,6 +96,39 @@ func (suite *StatusTestSuite) TestGetStatusWithExtras() {
suite.NotEmpty(status.Emojis)
}
func (suite *StatusTestSuite) TestGetStatusWithMention() {
status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_2_status_5"].ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.NotEmpty(status.Mentions)
suite.NotEmpty(status.MentionIDs)
suite.NotNil(status.InReplyTo)
suite.NotNil(status.InReplyToAccount)
}
func (suite *StatusTestSuite) TestGetStatusTwice() {
before1 := time.Now()
_, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
suite.NoError(err)
after1 := time.Now()
duration1 := after1.Sub(before1)
fmt.Println(duration1.Nanoseconds())
before2 := time.Now()
_, err = suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
suite.NoError(err)
after2 := time.Now()
duration2 := after2.Sub(before2)
fmt.Println(duration2.Nanoseconds())
// second retrieval should be several orders faster since it will be cached now
suite.Less(duration2, duration1)
}
func TestStatusTestSuite(t *testing.T) {
suite.Run(t, new(StatusTestSuite))
}

View file

@ -36,7 +36,7 @@ type timelineDB struct {
cancel context.CancelFunc
}
func (t *timelineDB) GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.DBError) {
func (t *timelineDB) GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
q := t.conn.Model(&statuses)
@ -96,7 +96,7 @@ func (t *timelineDB) GetHomeTimelineForAccount(accountID string, maxID string, s
return statuses, nil
}
func (t *timelineDB) GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.DBError) {
func (t *timelineDB) GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
q := t.conn.Model(&statuses).
@ -143,7 +143,7 @@ func (t *timelineDB) GetPublicTimelineForAccount(accountID string, maxID string,
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
func (t *timelineDB) GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.DBError) {
func (t *timelineDB) GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
faves := []*gtsmodel.StatusFave{}

View file

@ -8,7 +8,7 @@ import (
)
// processErrorResponse parses the given error and returns an appropriate DBError.
func processErrorResponse(err error) db.DBError {
func processErrorResponse(err error) db.Error {
switch err {
case nil:
return nil

View file

@ -20,32 +20,33 @@ package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Relationship contains functions for getting or modifying the relationship between two accounts.
type Relationship interface {
// Blocked checks whether account 1 has a block in place against block2.
// If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1.
Blocked(account1 string, account2 string, eitherDirection bool) (bool, DBError)
Blocked(account1 string, account2 string, eitherDirection bool) (bool, Error)
// GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't.
//
// Because this is slower than Blocked, only use it if you need the actual Block struct for some reason,
// not if you're just checking for the existence of a block.
GetBlock(account1 string, account2 string) (*gtsmodel.Block, DBError)
GetBlock(account1 string, account2 string) (*gtsmodel.Block, Error)
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, DBError)
GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
// Follows returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, DBError)
Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// FollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, DBError)
FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// Mutuals returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, DBError)
Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
// In other words, it should create the follow, and delete the existing follow request.
//
// It will return the newly created follow for further processing.
AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, DBError)
AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
}

View file

@ -20,55 +20,56 @@ package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses.
type Status interface {
// GetStatusByID returns one status from the database, with all rel fields populated (if possible).
GetStatusByID(id string) (*gtsmodel.Status, DBError)
GetStatusByID(id string) (*gtsmodel.Status, Error)
// GetStatusByURI returns one status from the database, with all rel fields populated (if possible).
GetStatusByURI(uri string) (*gtsmodel.Status, DBError)
GetStatusByURI(uri string) (*gtsmodel.Status, Error)
// GetStatusByURL returns one status from the database, with all rel fields populated (if possible).
GetStatusByURL(uri string) (*gtsmodel.Status, DBError)
GetStatusByURL(uri string) (*gtsmodel.Status, Error)
// PutStatus stores one status in the database.
PutStatus(status *gtsmodel.Status) DBError
PutStatus(status *gtsmodel.Status) Error
// CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong
CountStatusReplies(status *gtsmodel.Status) (int, DBError)
CountStatusReplies(status *gtsmodel.Status) (int, Error)
// CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong
CountStatusReblogs(status *gtsmodel.Status) (int, DBError)
CountStatusReblogs(status *gtsmodel.Status) (int, Error)
// CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong
CountStatusFaves(status *gtsmodel.Status) (int, DBError)
CountStatusFaves(status *gtsmodel.Status) (int, Error)
// GetStatusParents get the parent statuses of a given status.
// GetStatusParents gets the parent statuses of a given status.
//
// If onlyDirect is true, only the immediate parent will be returned.
GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, DBError)
GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error)
// GetStatusChildren gets the child statuses of a given status.
//
// If onlyDirect is true, only the immediate children will be returned.
GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, DBError)
GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error)
// IsStatusFavedBy checks if a given status has been faved by a given account ID
IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, DBError)
IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID
IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, DBError)
IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusMutedBy checks if a given status has been muted by a given account ID
IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, DBError)
IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, DBError)
IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, Error)
// GetStatusFaves returns a slice of faves/likes of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, DBError)
GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error)
// GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, DBError)
GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, Error)
}

View file

@ -20,17 +20,18 @@ package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Timeline contains functionality for retrieving home/public/faved etc timelines for an account.
type Timeline interface {
// GetHomeTimelineForAccount returns a slice of statuses from accounts that are followed by the given account id.
//
// Statuses should be returned in descending order of when they were created (newest first).
GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, DBError)
GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
// GetPublicTimelineForAccount fetches the account's PUBLIC timeline -- ie., posts and replies that are public.
// It will use the given filters and try to return as many statuses as possible up to the limit.
//
// Statuses should be returned in descending order of when they were created (newest first).
GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, DBError)
GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
// GetFavedTimelineForAccount fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved.
// It will use the given filters and try to return as many statuses as possible up to the limit.
@ -39,5 +40,5 @@ type Timeline interface {
// In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created.
//
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, DBError)
GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
}