From 15153ee0c8b448c77551b8915c09bd9bf94c53d6 Mon Sep 17 00:00:00 2001 From: tsmethurst Date: Tue, 17 Aug 2021 15:23:28 +0200 Subject: [PATCH] continue moving db stuff around --- internal/api/client/auth/callback.go | 4 +- internal/api/security/signaturecheck.go | 2 +- internal/db/account.go | 42 ++-- internal/db/admin.go | 10 +- internal/db/basic.go | 28 +-- internal/db/error.go | 18 +- internal/db/instance.go | 8 +- internal/db/notification.go | 2 +- internal/db/pg/account.go | 235 ++++++++++++------ internal/db/pg/admin.go | 89 ++++--- internal/db/pg/basic.go | 135 ++++++++-- internal/db/pg/instance.go | 39 +-- internal/db/pg/{blocks.go => notification.go} | 40 +-- internal/db/pg/pg.go | 134 ++++------ internal/db/pg/pg_test.go | 47 ++++ internal/db/pg/relationship.go | 92 ++++--- internal/db/pg/status.go | 111 ++++++--- internal/db/pg/status_test.go | 86 +++++++ internal/db/pg/timeline.go | 40 +-- internal/db/pg/update.go | 73 ------ internal/db/relationship.go | 12 +- internal/db/status.go | 28 ++- internal/db/timeline.go | 6 +- internal/federation/dereferencing/blocked.go | 2 +- internal/federation/dereferencing/status.go | 2 +- internal/federation/federatingdb/create.go | 2 +- internal/federation/federatingdb/outbox.go | 2 +- internal/federation/federatingdb/owns.go | 16 +- internal/federation/federatingdb/util.go | 4 +- internal/federation/federatingprotocol.go | 5 +- internal/federation/util.go | 2 +- internal/gtsmodel/account.go | 2 + internal/gtsmodel/instance.go | 2 + internal/gtsmodel/status.go | 31 +-- internal/oauth/clientstore.go | 4 +- internal/oauth/clientstore_test.go | 2 +- internal/oauth/server.go | 2 +- internal/oauth/tokenstore.go | 4 +- internal/processing/account/createblock.go | 2 +- internal/processing/account/createfollow.go | 2 +- internal/processing/account/delete.go | 8 +- internal/processing/account/get.go | 2 +- internal/processing/account/getfollowers.go | 4 +- internal/processing/account/getfollowing.go | 4 +- internal/processing/account/getstatuses.go | 4 +- internal/processing/account/removeblock.go | 2 +- internal/processing/account/removefollow.go | 2 +- .../processing/admin/createdomainblock.go | 6 +- .../processing/admin/deletedomainblock.go | 2 +- internal/processing/admin/getdomainblock.go | 2 +- internal/processing/admin/getdomainblocks.go | 2 +- internal/processing/blocks.go | 2 +- internal/processing/followrequest.go | 2 +- internal/processing/fromcommon.go | 2 +- internal/processing/fromfederator.go | 2 +- internal/processing/media/delete.go | 4 +- internal/processing/media/getmedia.go | 2 +- internal/processing/media/update.go | 2 +- internal/processing/search.go | 2 +- internal/processing/status/context.go | 2 +- internal/processing/status/delete.go | 2 +- internal/processing/status/unboost.go | 2 +- internal/processing/status/unfave.go | 2 +- internal/processing/status/util.go | 6 +- internal/processing/timeline.go | 8 +- internal/router/session.go | 2 +- internal/timeline/index.go | 4 +- internal/timeline/prepare.go | 6 +- internal/typeutils/astointernal.go | 2 +- internal/typeutils/internaltofrontend.go | 72 ++---- internal/visibility/statusvisible.go | 4 +- internal/visibility/util.go | 2 +- 72 files changed, 923 insertions(+), 614 deletions(-) rename internal/db/pg/{blocks.go => notification.go} (53%) create mode 100644 internal/db/pg/pg_test.go create mode 100644 internal/db/pg/status_test.go delete mode 100644 internal/db/pg/update.go diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go index 8bf2a50b5..a26838aa3 100644 --- a/internal/api/client/auth/callback.go +++ b/internal/api/client/auth/callback.go @@ -116,7 +116,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin return user, nil } - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // we have an actual error in the database return nil, fmt.Errorf("error checking database for email %s: %s", claims.Email, err) } @@ -128,7 +128,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email) } - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // we have an actual error in the database return nil, fmt.Errorf("error checking database for email %s: %s", claims.Email, err) } diff --git a/internal/api/security/signaturecheck.go b/internal/api/security/signaturecheck.go index b852c92ab..a3a04180d 100644 --- a/internal/api/security/signaturecheck.go +++ b/internal/api/security/signaturecheck.go @@ -59,7 +59,7 @@ func (m *Module) blockedDomain(host string) (bool, error) { return true, nil } - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // there are no entries so there's no block return false, nil } diff --git a/internal/db/account.go b/internal/db/account.go index 9584291e2..82268fe87 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -6,60 +6,66 @@ type Account interface { // GetAccountByUserID is a shortcut for the common action of fetching an account corresponding to a user ID. // The given account pointer will be set to the result of the query, whatever it is. // In case of no entries, a 'no entries' error will be returned - GetAccountByUserID(userID string, account *gtsmodel.Account) error + GetAccountByUserID(userID string, account *gtsmodel.Account) DBError + + // GetAccountByID returns one account with the given ID, or an error if something goes wrong. + GetAccountByID(id string) (*gtsmodel.Account, DBError) + + // GetAccountByURI returns one account with the given URI, or an error if something goes wrong. + GetAccountByURI(uri string) (*gtsmodel.Account, DBError) // GetLocalAccountByUsername is a shortcut for the common action of fetching an account ON THIS INSTANCE // according to its username, which should be unique. // The given account pointer will be set to the result of the query, whatever it is. // In case of no entries, a 'no entries' error will be returned - GetLocalAccountByUsername(username string, account *gtsmodel.Account) error + GetLocalAccountByUsername(username string, account *gtsmodel.Account) DBError // GetAccountFollowRequests is a shortcut for the common action of fetching a list of follow requests targeting the given account ID. // The given slice 'followRequests' will be set to the result of the query, whatever it is. // In case of no entries, a 'no entries' error will be returned - GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) error + GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) DBError // GetAccountFollowing is a shortcut for the common action of fetching a list of accounts that accountID is following. // The given slice 'following' will be set to the result of the query, whatever it is. // In case of no entries, a 'no entries' error will be returned - GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) error + GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) DBError + + CountAccountFollowing(accountID string, localOnly bool) (int, DBError) // GetAccountFollowers is a shortcut for the common action of fetching a list of accounts that accountID is followed by. // The given slice 'followers' will be set to the result of the query, whatever it is. // In case of no entries, a 'no entries' error will be returned // // If localOnly is set to true, then only followers from *this instance* will be returned. - GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) error + GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) DBError + + CountAccountFollowers(accountID string, localOnly bool) (int, DBError) // GetAccountFaves is a shortcut for the common action of fetching a list of faves made by the given accountID. // The given slice 'faves' will be set to the result of the query, whatever it is. // In case of no entries, a 'no entries' error will be returned - GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) error + GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) DBError // GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID. - GetAccountStatusesCount(accountID string) (int, error) + CountAccountStatuses(accountID string) (int, DBError) // GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided // then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can // be very memory intensive so you probably shouldn't do this! // In case of no entries, a 'no entries' error will be returned - GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error) + GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, DBError) - GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) + GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, DBError) // GetAccountLastStatus simply gets the most recent status by the given account. // The given slice 'status' pointer will be set to the result of the query, whatever it is. // In case of no entries, a 'no entries' error will be returned - GetAccountLastStatus(accountID string, status *gtsmodel.Status) error + GetAccountLastStatus(accountID string, status *gtsmodel.Status) DBError // SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment. - SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) error + SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) DBError - // GetHeaderAvatarForAccountID gets the current avatar for the given account ID. - // The passed mediaAttachment pointer will be populated with the value of the avatar, if it exists. - GetAccountAvatar(avatar *gtsmodel.MediaAttachment, accountID string) error - - // GetAccountHeader gets the current header for the given account ID. - // The passed mediaAttachment pointer will be populated with the value of the header, if it exists. - GetAccountHeader(header *gtsmodel.MediaAttachment, accountID string) error + // GetInstanceAccount returns the instance account for the given domain. + // If domain is empty, this instance account will be returned. + GetInstanceAccount(domain string) (*gtsmodel.Account, DBError) } diff --git a/internal/db/admin.go b/internal/db/admin.go index 84589a250..8e7e489d2 100644 --- a/internal/db/admin.go +++ b/internal/db/admin.go @@ -9,26 +9,26 @@ import ( type Admin interface { // IsUsernameAvailable checks whether a given username is available on our domain. // Returns an error if the username is already taken, or something went wrong in the db. - IsUsernameAvailable(username string) error + IsUsernameAvailable(username string) DBError // IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain. // Return an error if: // A) the email is already associated with an account // B) we block signups from this email domain // C) something went wrong in the db - IsEmailAvailable(email string) error + IsEmailAvailable(email string) DBError // NewSignup creates a new user in the database with the given parameters. // By the time this function is called, it should be assumed that all the parameters have passed validation! - NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, error) + NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, DBError) // CreateInstanceAccount creates an account in the database with the same username as the instance host value. // Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'. // This is needed for things like serving files that belong to the instance and not an individual user/account. - CreateInstanceAccount() error + CreateInstanceAccount() DBError // CreateInstanceInstance creates an instance in the database with the same domain as the instance host value. // Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'. // This is needed for things like serving instance information through /api/v1/instance - CreateInstanceInstance() error + CreateInstanceInstance() DBError } diff --git a/internal/db/basic.go b/internal/db/basic.go index d63f7034b..5740149c2 100644 --- a/internal/db/basic.go +++ b/internal/db/basic.go @@ -5,60 +5,60 @@ import "context" type Basic interface { // CreateTable creates a table for the given interface. // For implementations that don't use tables, this can just return nil. - CreateTable(i interface{}) error + CreateTable(i interface{}) DBError // DropTable drops the table for the given interface. // For implementations that don't use tables, this can just return nil. - DropTable(i interface{}) error + DropTable(i interface{}) DBError // Stop should stop and close the database connection cleanly, returning an error if this is not possible. // If the database implementation doesn't need to be stopped, this can just return nil. - Stop(ctx context.Context) error + Stop(ctx context.Context) DBError // IsHealthy should return nil if the database connection is healthy, or an error if not. - IsHealthy(ctx context.Context) error + IsHealthy(ctx context.Context) DBError // GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry, // for other implementations (for example, in-memory) it might just be the key of a map. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // In case of no entries, a 'no entries' error will be returned - GetByID(id string, i interface{}) error + GetByID(id string, i interface{}) DBError // GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the // name of the key to select from. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // In case of no entries, a 'no entries' error will be returned - GetWhere(where []Where, i interface{}) error + GetWhere(where []Where, i interface{}) DBError // GetAll will try to get all entries of type i. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. // In case of no entries, a 'no entries' error will be returned - GetAll(i interface{}) error + GetAll(i interface{}) DBError // Put simply stores i. It is up to the implementation to figure out how to store it, and using what key. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - Put(i interface{}) error + Put(i interface{}) DBError // Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/ // It is up to the implementation to figure out how to store it, and using what key. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - Upsert(i interface{}, conflictColumn string) error + Upsert(i interface{}, conflictColumn string) DBError // UpdateByID updates i with id id. // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. - UpdateByID(id string, i interface{}) error + UpdateByID(id string, i interface{}) DBError // UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value. - UpdateOneByID(id string, key string, value interface{}, i interface{}) error + UpdateOneByID(id string, key string, value interface{}, i interface{}) DBError // UpdateWhere updates column key of interface i with the given value, where the given parameters apply. - UpdateWhere(where []Where, key string, value interface{}, i interface{}) error + UpdateWhere(where []Where, key string, value interface{}, i interface{}) DBError // DeleteByID removes i with id id. // If i didn't exist anyway, then no error should be returned. - DeleteByID(id string, i interface{}) error + DeleteByID(id string, i interface{}) DBError // DeleteWhere deletes i where key = value // If i didn't exist anyway, then no error should be returned. - DeleteWhere(where []Where, i interface{}) error + DeleteWhere(where []Where, i interface{}) DBError } diff --git a/internal/db/error.go b/internal/db/error.go index 197c7bd68..d669f2513 100644 --- a/internal/db/error.go +++ b/internal/db/error.go @@ -18,16 +18,12 @@ package db -// ErrNoEntries is to be returned from the DB interface when no entries are found for a given query. -type ErrNoEntries struct{} +import "fmt" -func (e ErrNoEntries) Error() string { - return "no entries" -} +type DBError error -// ErrAlreadyExists is to be returned from the DB interface when an entry already exists for a given query or its constraints. -type ErrAlreadyExists struct{} - -func (e ErrAlreadyExists) Error() string { - return "already exists" -} +var ( + ErrNoEntries DBError = fmt.Errorf("no entries") + ErrAlreadyExists DBError = fmt.Errorf("already exists") + ErrUnknown DBError = fmt.Errorf("unknown error") +) diff --git a/internal/db/instance.go b/internal/db/instance.go index dccdb9793..152cc39b8 100644 --- a/internal/db/instance.go +++ b/internal/db/instance.go @@ -4,14 +4,14 @@ import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" type Instance interface { // GetUserCountForInstance returns the number of known accounts registered with the given domain. - GetUserCountForInstance(domain string) (int, error) + GetUserCountForInstance(domain string) (int, DBError) // GetStatusCountForInstance returns the number of known statuses posted from the given domain. - GetStatusCountForInstance(domain string) (int, error) + GetStatusCountForInstance(domain string) (int, DBError) // GetDomainCountForInstance returns the number of known instances known that the given domain federates with. - GetDomainCountForInstance(domain string) (int, error) + GetDomainCountForInstance(domain string) (int, DBError) // GetAccountsForInstance returns a slice of accounts from the given instance, arranged by ID. - GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error) + GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, DBError) } diff --git a/internal/db/notification.go b/internal/db/notification.go index 6a3c0c29d..7360b4080 100644 --- a/internal/db/notification.go +++ b/internal/db/notification.go @@ -4,5 +4,5 @@ import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" type Notification interface { // GetNotificationsForAccount returns a list of notifications that pertain to the given accountID. - GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error) + GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, DBError) } diff --git a/internal/db/pg/account.go b/internal/db/pg/account.go index 9ecaba486..9c1493fdd 100644 --- a/internal/db/pg/account.go +++ b/internal/db/pg/account.go @@ -1,62 +1,100 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + package pg import ( + "context" "errors" "fmt" "github.com/go-pg/pg/v10" + "github.com/go-pg/pg/v10/orm" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (ps *postgresService) GetAccountHeader(header *gtsmodel.MediaAttachment, accountID string) error { - acct := >smodel.Account{} - if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - - if acct.HeaderMediaAttachmentID == "" { - return db.ErrNoEntries{} - } - - if err := ps.conn.Model(header).Where("id = ?", acct.HeaderMediaAttachmentID).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - return nil +type accountDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc } -func (ps *postgresService) GetAccountAvatar(avatar *gtsmodel.MediaAttachment, accountID string) error { - acct := >smodel.Account{} - if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - - if acct.AvatarMediaAttachmentID == "" { - return db.ErrNoEntries{} - } - - if err := ps.conn.Model(avatar).Where("id = ?", acct.AvatarMediaAttachmentID).Select(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - return nil +func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query { + return a.conn.Model(account). + Relation("AvatarMediaAttachment"). + Relation("HeaderMediaAttachment") } -func (ps *postgresService) GetAccountLastStatus(accountID string, status *gtsmodel.Status) error { - if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil { +func (a *accountDB) processResponse(account *gtsmodel.Account, err error) (*gtsmodel.Account, db.DBError) { + switch err { + case pg.ErrNoRows: + return nil, db.ErrNoEntries + case nil: + return account, nil + default: + return nil, err + } +} + +func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.DBError) { + account := >smodel.Account{} + + q := a.newAccountQ(account). + Where("account.id = ?", id) + + return a.processResponse(account, q.Select()) +} + +func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.DBError) { + account := >smodel.Account{} + + q := a.newAccountQ(account). + Where("account.uri = ?", uri) + + return a.processResponse(account, q.Select()) +} + +func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.DBError) { + account := >smodel.Account{} + + q := a.newAccountQ(account) + + if domain == "" { + q = q. + Where("account.username = ?", domain). + Where("account.domain = ?", domain) + } else { + q = q. + Where("account.username = ?", domain). + Where("? IS NULL", pg.Ident("domain")) + } + + return a.processResponse(account, q.Select()) +} + +func (a *accountDB) GetAccountLastStatus(accountID string, status *gtsmodel.Status) db.DBError { + if err := a.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil { if err == pg.ErrNoRows { - return db.ErrNoEntries{} + return db.ErrNoEntries } return err } @@ -64,7 +102,7 @@ func (ps *postgresService) GetAccountLastStatus(accountID string, status *gtsmod } -func (ps *postgresService) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) error { +func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.DBError { if mediaAttachment.Avatar && mediaAttachment.Header { return errors.New("one media attachment cannot be both header and avatar") } @@ -79,47 +117,47 @@ func (ps *postgresService) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.Me } // TODO: there are probably more side effects here that need to be handled - if _, err := ps.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil { + if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil { return err } - if _, err := ps.conn.Model(>smodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil { + if _, err := a.conn.Model(>smodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil { return err } return nil } -func (ps *postgresService) GetAccountByUserID(userID string, account *gtsmodel.Account) error { +func (a *accountDB) GetAccountByUserID(userID string, account *gtsmodel.Account) db.DBError { user := >smodel.User{ ID: userID, } - if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil { + if err := a.conn.Model(user).Where("id = ?", userID).Select(); err != nil { if err == pg.ErrNoRows { - return db.ErrNoEntries{} + return db.ErrNoEntries } return err } - if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil { + if err := a.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil { if err == pg.ErrNoRows { - return db.ErrNoEntries{} + return db.ErrNoEntries } return err } return nil } -func (ps *postgresService) GetLocalAccountByUsername(username string, account *gtsmodel.Account) error { - if err := ps.conn.Model(account).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select(); err != nil { +func (a *accountDB) GetLocalAccountByUsername(username string, account *gtsmodel.Account) db.DBError { + if err := a.conn.Model(account).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select(); err != nil { if err == pg.ErrNoRows { - return db.ErrNoEntries{} + return db.ErrNoEntries } return err } return nil } -func (ps *postgresService) GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) error { - if err := ps.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil { +func (a *accountDB) GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) db.DBError { + if err := a.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil { if err == pg.ErrNoRows { return nil } @@ -128,8 +166,8 @@ func (ps *postgresService) GetAccountFollowRequests(accountID string, followRequ return nil } -func (ps *postgresService) GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) error { - if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil { +func (a *accountDB) GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) db.DBError { + if err := a.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil { if err == pg.ErrNoRows { return nil } @@ -138,9 +176,13 @@ func (ps *postgresService) GetAccountFollowing(accountID string, following *[]gt return nil } -func (ps *postgresService) GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) error { +func (a *accountDB) CountAccountFollowing(accountID string, localOnly bool) (int, db.DBError) { + return a.conn.Model(&[]*gtsmodel.Follow{}).Where("account_id = ?", accountID).Count() +} - q := ps.conn.Model(followers) +func (a *accountDB) GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) db.DBError { + + q := a.conn.Model(followers) if localOnly { // for local accounts let's get where domain is null OR where domain is an empty string, just to be safe @@ -168,8 +210,12 @@ func (ps *postgresService) GetAccountFollowers(accountID string, followers *[]gt return nil } -func (ps *postgresService) GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) error { - if err := ps.conn.Model(faves).Where("account_id = ?", accountID).Select(); err != nil { +func (a *accountDB) CountAccountFollowers(accountID string, localOnly bool) (int, db.DBError) { + return a.conn.Model(&[]*gtsmodel.Follow{}).Where("target_account_id = ?", accountID).Count() +} + +func (a *accountDB) GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) db.DBError { + if err := a.conn.Model(faves).Where("account_id = ?", accountID).Select(); err != nil { if err == pg.ErrNoRows { return nil } @@ -178,22 +224,15 @@ func (ps *postgresService) GetAccountFaves(accountID string, faves *[]gtsmodel.S return nil } -func (ps *postgresService) GetAccountStatusesCount(accountID string) (int, error) { - count, err := ps.conn.Model(>smodel.Status{}).Where("account_id = ?", accountID).Count() - if err != nil { - if err == pg.ErrNoRows { - return 0, nil - } - return 0, err - } - return count, nil +func (a *accountDB) CountAccountStatuses(accountID string) (int, db.DBError) { + return a.conn.Model(>smodel.Status{}).Where("account_id = ?", accountID).Count() } -func (ps *postgresService) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error) { - ps.log.Debugf("getting statuses for account %s", accountID) +func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.DBError) { + a.log.Debugf("getting statuses for account %s", accountID) statuses := []*gtsmodel.Status{} - q := ps.conn.Model(&statuses).Order("id DESC") + q := a.conn.Model(&statuses).Order("id DESC") if accountID != "" { q = q.Where("account_id = ?", accountID) } @@ -222,15 +261,57 @@ func (ps *postgresService) GetAccountStatuses(accountID string, limit int, exclu if err := q.Select(); err != nil { if err == pg.ErrNoRows { - return nil, db.ErrNoEntries{} + return nil, db.ErrNoEntries } return nil, err } if len(statuses) == 0 { - return nil, db.ErrNoEntries{} + return nil, db.ErrNoEntries } - ps.log.Debugf("returning statuses for account %s", accountID) + a.log.Debugf("returning statuses for account %s", accountID) return statuses, nil } + +func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.DBError) { + blocks := []*gtsmodel.Block{} + + fq := a.conn.Model(&blocks). + Where("block.account_id = ?", accountID). + Relation("TargetAccount"). + Order("block.id DESC") + + if maxID != "" { + fq = fq.Where("block.id < ?", maxID) + } + + if sinceID != "" { + fq = fq.Where("block.id > ?", sinceID) + } + + if limit > 0 { + fq = fq.Limit(limit) + } + + err := fq.Select() + if err != nil { + if err == pg.ErrNoRows { + return nil, "", "", db.ErrNoEntries + } + return nil, "", "", err + } + + if len(blocks) == 0 { + return nil, "", "", db.ErrNoEntries + } + + accounts := []*gtsmodel.Account{} + for _, b := range blocks { + accounts = append(accounts, b.TargetAccount) + } + + nextMaxID := blocks[len(blocks)-1].ID + prevMinID := blocks[0].ID + return accounts, nextMaxID, prevMinID, nil +} diff --git a/internal/db/pg/admin.go b/internal/db/pg/admin.go index 7bc614f14..fe3d48b54 100644 --- a/internal/db/pg/admin.go +++ b/internal/db/pg/admin.go @@ -1,6 +1,25 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + package pg import ( + "context" "crypto/rand" "crypto/rsa" "fmt" @@ -10,17 +29,27 @@ import ( "time" "github.com/go-pg/pg/v10" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/util" "golang.org/x/crypto/bcrypt" ) -func (ps *postgresService) IsUsernameAvailable(username string) error { +type adminDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc +} + +func (a *adminDB) IsUsernameAvailable(username string) db.DBError { // if no error we fail because it means we found something // if error but it's not pg.ErrNoRows then we fail // if err is pg.ErrNoRows we're good, we found nothing so continue - if err := ps.conn.Model(>smodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil { + if err := a.conn.Model(>smodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil { return fmt.Errorf("username %s already in use", username) } else if err != pg.ErrNoRows { return fmt.Errorf("db error: %s", err) @@ -28,7 +57,7 @@ func (ps *postgresService) IsUsernameAvailable(username string) error { return nil } -func (ps *postgresService) IsEmailAvailable(email string) error { +func (a *adminDB) IsEmailAvailable(email string) db.DBError { // parse the domain from the email m, err := mail.ParseAddress(email) if err != nil { @@ -37,7 +66,7 @@ func (ps *postgresService) IsEmailAvailable(email string) error { domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @ // check if the email domain is blocked - if err := ps.conn.Model(>smodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil { + if err := a.conn.Model(>smodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil { // fail because we found something return fmt.Errorf("email domain %s is blocked", domain) } else if err != pg.ErrNoRows { @@ -46,7 +75,7 @@ func (ps *postgresService) IsEmailAvailable(email string) error { } // check if this email is associated with a user already - if err := ps.conn.Model(>smodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil { + if err := a.conn.Model(>smodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil { // fail because we found something return fmt.Errorf("email %s already in use", email) } else if err != pg.ErrNoRows { @@ -56,16 +85,16 @@ func (ps *postgresService) IsEmailAvailable(email string) error { return nil } -func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, error) { +func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.DBError) { key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - ps.log.Errorf("error creating new rsa key: %s", err) + a.log.Errorf("error creating new rsa key: %s", err) return nil, err } // if something went wrong while creating a user, we might already have an account, so check here first... - a := >smodel.Account{} - err = ps.conn.Model(a).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select() + acct := >smodel.Account{} + err = a.conn.Model(acct).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select() if err != nil { // there's been an actual error if err != pg.ErrNoRows { @@ -73,13 +102,13 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr } // we just don't have an account yet create one - newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host) + newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host) newAccountID, err := id.NewRandomULID() if err != nil { return nil, err } - a = >smodel.Account{ + acct = >smodel.Account{ ID: newAccountID, Username: username, DisplayName: username, @@ -96,7 +125,7 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr FollowingURI: newAccountURIs.FollowingURI, FeaturedCollectionURI: newAccountURIs.CollectionURI, } - if _, err = ps.conn.Model(a).Insert(); err != nil { + if _, err = a.conn.Model(acct).Insert(); err != nil { return nil, err } } @@ -113,7 +142,7 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr u := >smodel.User{ ID: newUserID, - AccountID: a.ID, + AccountID: acct.ID, EncryptedPassword: string(pw), SignUpIP: signUpIP.To4(), Locale: locale, @@ -132,18 +161,18 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr u.Moderator = true } - if _, err = ps.conn.Model(u).Insert(); err != nil { + if _, err = a.conn.Model(u).Insert(); err != nil { return nil, err } return u, nil } -func (ps *postgresService) CreateInstanceAccount() error { - username := ps.config.Host +func (a *adminDB) CreateInstanceAccount() db.DBError { + username := a.config.Host key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - ps.log.Errorf("error creating new rsa key: %s", err) + a.log.Errorf("error creating new rsa key: %s", err) return err } @@ -152,10 +181,10 @@ func (ps *postgresService) CreateInstanceAccount() error { return err } - newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host) - a := >smodel.Account{ + newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host) + acct := >smodel.Account{ ID: aID, - Username: ps.config.Host, + Username: a.config.Host, DisplayName: username, URL: newAccountURIs.UserURL, PrivateKey: key, @@ -169,19 +198,19 @@ func (ps *postgresService) CreateInstanceAccount() error { FollowingURI: newAccountURIs.FollowingURI, FeaturedCollectionURI: newAccountURIs.CollectionURI, } - inserted, err := ps.conn.Model(a).Where("username = ?", username).SelectOrInsert() + inserted, err := a.conn.Model(acct).Where("username = ?", username).SelectOrInsert() if err != nil { return err } if inserted { - ps.log.Infof("created instance account %s with id %s", username, a.ID) + a.log.Infof("created instance account %s with id %s", username, acct.ID) } else { - ps.log.Infof("instance account %s already exists with id %s", username, a.ID) + a.log.Infof("instance account %s already exists with id %s", username, acct.ID) } return nil } -func (ps *postgresService) CreateInstanceInstance() error { +func (a *adminDB) CreateInstanceInstance() db.DBError { iID, err := id.NewRandomULID() if err != nil { return err @@ -189,18 +218,18 @@ func (ps *postgresService) CreateInstanceInstance() error { i := >smodel.Instance{ ID: iID, - Domain: ps.config.Host, - Title: ps.config.Host, - URI: fmt.Sprintf("%s://%s", ps.config.Protocol, ps.config.Host), + Domain: a.config.Host, + Title: a.config.Host, + URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host), } - inserted, err := ps.conn.Model(i).Where("domain = ?", ps.config.Host).SelectOrInsert() + inserted, err := a.conn.Model(i).Where("domain = ?", a.config.Host).SelectOrInsert() if err != nil { return err } if inserted { - ps.log.Infof("created instance instance %s with id %s", ps.config.Host, i.ID) + a.log.Infof("created instance instance %s with id %s", a.config.Host, i.ID) } else { - ps.log.Infof("instance instance %s already exists with id %s", ps.config.Host, i.ID) + a.log.Infof("instance instance %s already exists with id %s", a.config.Host, i.ID) } return nil } diff --git a/internal/db/pg/basic.go b/internal/db/pg/basic.go index 36009120e..1debe3b74 100644 --- a/internal/db/pg/basic.go +++ b/internal/db/pg/basic.go @@ -1,25 +1,55 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + package pg import ( + "context" "errors" + "fmt" "strings" "github.com/go-pg/pg/v10" + "github.com/go-pg/pg/v10/orm" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" ) -func (ps *postgresService) Put(i interface{}) error { - _, err := ps.conn.Model(i).Insert(i) +type basicDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc +} + +func (b *basicDB) Put(i interface{}) db.DBError { + _, err := b.conn.Model(i).Insert(i) if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") { - return db.ErrAlreadyExists{} + return db.ErrAlreadyExists } return err } -func (ps *postgresService) GetByID(id string, i interface{}) error { - if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil { +func (b *basicDB) GetByID(id string, i interface{}) db.DBError { + if err := b.conn.Model(i).Where("id = ?", id).Select(); err != nil { if err == pg.ErrNoRows { - return db.ErrNoEntries{} + return db.ErrNoEntries } return err @@ -27,12 +57,12 @@ func (ps *postgresService) GetByID(id string, i interface{}) error { return nil } -func (ps *postgresService) GetWhere(where []db.Where, i interface{}) error { +func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.DBError { if len(where) == 0 { return errors.New("no queries provided") } - q := ps.conn.Model(i) + q := b.conn.Model(i) for _, w := range where { if w.Value == nil { @@ -48,25 +78,25 @@ func (ps *postgresService) GetWhere(where []db.Where, i interface{}) error { if err := q.Select(); err != nil { if err == pg.ErrNoRows { - return db.ErrNoEntries{} + return db.ErrNoEntries } return err } return nil } -func (ps *postgresService) GetAll(i interface{}) error { - if err := ps.conn.Model(i).Select(); err != nil { +func (b *basicDB) GetAll(i interface{}) db.DBError { + if err := b.conn.Model(i).Select(); err != nil { if err == pg.ErrNoRows { - return db.ErrNoEntries{} + return db.ErrNoEntries } return err } return nil } -func (ps *postgresService) DeleteByID(id string, i interface{}) error { - if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil { +func (b *basicDB) DeleteByID(id string, i interface{}) db.DBError { + if _, err := b.conn.Model(i).Where("id = ?", id).Delete(); err != nil { // if there are no rows *anyway* then that's fine // just return err if there's an actual error if err != pg.ErrNoRows { @@ -76,12 +106,12 @@ func (ps *postgresService) DeleteByID(id string, i interface{}) error { return nil } -func (ps *postgresService) DeleteWhere(where []db.Where, i interface{}) error { +func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.DBError { if len(where) == 0 { return errors.New("no queries provided") } - q := ps.conn.Model(i) + q := b.conn.Model(i) for _, w := range where { q = q.Where("? = ?", pg.Safe(w.Key), w.Value) } @@ -95,3 +125,76 @@ func (ps *postgresService) DeleteWhere(where []db.Where, i interface{}) error { } return nil } + +func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.DBError { + if _, err := b.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil { + if err == pg.ErrNoRows { + return db.ErrNoEntries + } + return err + } + return nil +} + +func (b *basicDB) UpdateByID(id string, i interface{}) db.DBError { + if _, err := b.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil { + if err == pg.ErrNoRows { + return db.ErrNoEntries + } + return err + } + return nil +} + +func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.DBError { + _, err := b.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update() + return err +} + +func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.DBError { + q := b.conn.Model(i) + + for _, w := range where { + if w.Value == nil { + q = q.Where("? IS NULL", pg.Ident(w.Key)) + } else { + if w.CaseInsensitive { + q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value) + } else { + q = q.Where("? = ?", pg.Safe(w.Key), w.Value) + } + } + } + + q = q.Set("? = ?", pg.Safe(key), value) + + _, err := q.Update() + + return err +} + +func (b *basicDB) CreateTable(i interface{}) db.DBError { + return b.conn.Model(i).CreateTable(&orm.CreateTableOptions{ + IfNotExists: true, + }) +} + +func (b *basicDB) DropTable(i interface{}) db.DBError { + return b.conn.Model(i).DropTable(&orm.DropTableOptions{ + IfExists: true, + }) +} + +func (b *basicDB) IsHealthy(ctx context.Context) db.DBError { + return b.conn.Ping(ctx) +} + +func (b *basicDB) Stop(ctx context.Context) db.DBError { + b.log.Info("closing db connection") + if err := b.conn.Close(); err != nil { + // only cancel if there's a problem closing the db + b.cancel() + return err + } + return nil +} diff --git a/internal/db/pg/instance.go b/internal/db/pg/instance.go index c551b2a49..e2e92b23d 100644 --- a/internal/db/pg/instance.go +++ b/internal/db/pg/instance.go @@ -19,15 +19,26 @@ package pg import ( + "context" + "github.com/go-pg/pg/v10" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) { - q := ps.conn.Model(&[]*gtsmodel.Account{}) +type instanceDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc +} - if domain == ps.config.Host { +func (i *instanceDB) GetUserCountForInstance(domain string) (int, db.DBError) { + q := i.conn.Model(&[]*gtsmodel.Account{}) + + if domain == i.config.Host { // if the domain is *this* domain, just count where the domain field is null q = q.Where("? IS NULL", pg.Ident("domain")) } else { @@ -40,10 +51,10 @@ func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) { return q.Count() } -func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error) { - q := ps.conn.Model(&[]*gtsmodel.Status{}) +func (i *instanceDB) GetStatusCountForInstance(domain string) (int, db.DBError) { + q := i.conn.Model(&[]*gtsmodel.Status{}) - if domain == ps.config.Host { + if domain == i.config.Host { // if the domain is *this* domain, just count where local is true q = q.Where("local = ?", true) } else { @@ -55,10 +66,10 @@ func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error) return q.Count() } -func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error) { - q := ps.conn.Model(&[]*gtsmodel.Instance{}) +func (i *instanceDB) GetDomainCountForInstance(domain string) (int, db.DBError) { + q := i.conn.Model(&[]*gtsmodel.Instance{}) - if domain == ps.config.Host { + if domain == i.config.Host { // if the domain is *this* domain, just count other instances it knows about // exclude domains that are blocked q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at")) @@ -70,12 +81,12 @@ func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error) return q.Count() } -func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error) { - ps.log.Debug("GetAccountsForInstance") +func (i *instanceDB) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.DBError) { + i.log.Debug("GetAccountsForInstance") accounts := []*gtsmodel.Account{} - q := ps.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC") + q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC") if maxID != "" { q = q.Where("id < ?", maxID) @@ -88,13 +99,13 @@ func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, l err := q.Select() if err != nil { if err == pg.ErrNoRows { - return nil, db.ErrNoEntries{} + return nil, db.ErrNoEntries } return nil, err } if len(accounts) == 0 { - return nil, db.ErrNoEntries{} + return nil, db.ErrNoEntries } return accounts, nil diff --git a/internal/db/pg/blocks.go b/internal/db/pg/notification.go similarity index 53% rename from internal/db/pg/blocks.go rename to internal/db/pg/notification.go index beada4e88..84359a981 100644 --- a/internal/db/pg/blocks.go +++ b/internal/db/pg/notification.go @@ -24,44 +24,30 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (ps *postgresService) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) { - blocks := []*gtsmodel.Block{} +func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.DBError) { + notifications := []*gtsmodel.Notification{} - fq := ps.conn.Model(&blocks). - Where("block.account_id = ?", accountID). - Relation("TargetAccount"). - Order("block.id DESC") + q := ps.conn.Model(¬ifications).Where("target_account_id = ?", accountID) if maxID != "" { - fq = fq.Where("block.id < ?", maxID) + q = q.Where("id < ?", maxID) } if sinceID != "" { - fq = fq.Where("block.id > ?", sinceID) + q = q.Where("id > ?", sinceID) } - if limit > 0 { - fq = fq.Limit(limit) + if limit != 0 { + q = q.Limit(limit) } - err := fq.Select() - if err != nil { - if err == pg.ErrNoRows { - return nil, "", "", db.ErrNoEntries{} + q = q.Order("created_at DESC") + + if err := q.Select(); err != nil { + if err != pg.ErrNoRows { + return nil, err } - return nil, "", "", err - } - if len(blocks) == 0 { - return nil, "", "", db.ErrNoEntries{} } - - accounts := []*gtsmodel.Account{} - for _, b := range blocks { - accounts = append(accounts, b.TargetAccount) - } - - nextMaxID := blocks[len(blocks)-1].ID - prevMinID := blocks[0].ID - return accounts, nextMaxID, prevMinID, nil + return notifications, nil } diff --git a/internal/db/pg/pg.go b/internal/db/pg/pg.go index 0c3f7310c..a0a38038e 100644 --- a/internal/db/pg/pg.go +++ b/internal/db/pg/pg.go @@ -31,7 +31,6 @@ import ( "github.com/go-pg/pg/extra/pgdebug" "github.com/go-pg/pg/v10" - "github.com/go-pg/pg/v10/orm" "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -41,6 +40,14 @@ import ( // postgresService satisfies the DB interface type postgresService struct { + db.Account + db.Admin + db.Basic + db.Instance + db.Notification + db.Relationship + db.Status + db.Timeline config *config.Config conn *pg.DB log *logrus.Logger @@ -85,6 +92,48 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge log.Infof("connected to postgres version: %s", version) ps := &postgresService{ + Account: &accountDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, + Admin: &adminDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, + Basic: &basicDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, + Instance: &instanceDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, + Relationship: &relationshipDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, + Status: &statusDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, + Timeline: &timelineDB{ + config: c, + conn: conn, + log: log, + cancel: cancel, + }, config: c, conn: conn, log: log, @@ -193,89 +242,6 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) { return options, nil } -/* - BASIC DB FUNCTIONALITY -*/ - -func (ps *postgresService) CreateTable(i interface{}) error { - return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{ - IfNotExists: true, - }) -} - -func (ps *postgresService) DropTable(i interface{}) error { - return ps.conn.Model(i).DropTable(&orm.DropTableOptions{ - IfExists: true, - }) -} - -func (ps *postgresService) Stop(ctx context.Context) error { - ps.log.Info("closing db connection") - if err := ps.conn.Close(); err != nil { - // only cancel if there's a problem closing the db - ps.cancel() - return err - } - return nil -} - -func (ps *postgresService) IsHealthy(ctx context.Context) error { - return ps.conn.Ping(ctx) -} - -func (ps *postgresService) CreateSchema(ctx context.Context) error { - models := []interface{}{ - (*gtsmodel.Account)(nil), - (*gtsmodel.Status)(nil), - (*gtsmodel.User)(nil), - } - ps.log.Info("creating db schema") - - for _, model := range models { - err := ps.conn.Model(model).CreateTable(&orm.CreateTableOptions{ - IfNotExists: true, - }) - if err != nil { - return err - } - } - - ps.log.Info("db schema created") - return nil -} - -/* - HANDY SHORTCUTS -*/ - -func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error) { - notifications := []*gtsmodel.Notification{} - - q := ps.conn.Model(¬ifications).Where("target_account_id = ?", accountID) - - if maxID != "" { - q = q.Where("id < ?", maxID) - } - - if sinceID != "" { - q = q.Where("id > ?", sinceID) - } - - if limit != 0 { - q = q.Limit(limit) - } - - q = q.Order("created_at DESC") - - if err := q.Select(); err != nil { - if err != pg.ErrNoRows { - return nil, err - } - - } - return notifications, nil -} - /* CONVERSION FUNCTIONS */ diff --git a/internal/db/pg/pg_test.go b/internal/db/pg/pg_test.go new file mode 100644 index 000000000..c1e10abdf --- /dev/null +++ b/internal/db/pg/pg_test.go @@ -0,0 +1,47 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package pg_test + +import ( + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +type PGStandardTestSuite struct { + // standard suite interfaces + suite.Suite + config *config.Config + db db.DB + log *logrus.Logger + + // standard suite models + testTokens map[string]*oauth.Token + testClients map[string]*oauth.Client + testApplications map[string]*gtsmodel.Application + testUsers map[string]*gtsmodel.User + testAccounts map[string]*gtsmodel.Account + testAttachments map[string]*gtsmodel.MediaAttachment + testStatuses map[string]*gtsmodel.Status + testTags map[string]*gtsmodel.Tag + testMentions map[string]*gtsmodel.Mention +} diff --git a/internal/db/pg/relationship.go b/internal/db/pg/relationship.go index ac628b87d..1835ca7a6 100644 --- a/internal/db/pg/relationship.go +++ b/internal/db/pg/relationship.go @@ -1,17 +1,45 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + package pg import ( + "context" "fmt" "github.com/go-pg/pg/v10" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (ps *postgresService) Blocked(account1 string, account2 string) (bool, error) { +type relationshipDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc +} + +func (r *relationshipDB) Blocked(account1 string, account2 string) (bool, db.DBError) { // TODO: check domain blocks as well var blocked bool - if err := ps.conn.Model(>smodel.Block{}). + if err := r.conn.Model(>smodel.Block{}). Where("account_id = ?", account1).Where("target_account_id = ?", account2). WhereOr("target_account_id = ?", account1).Where("account_id = ?", account2). Select(); err != nil { @@ -25,83 +53,83 @@ func (ps *postgresService) Blocked(account1 string, account2 string) (bool, erro return blocked, nil } -func (ps *postgresService) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) { - r := >smodel.Relationship{ +func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.DBError) { + rel := >smodel.Relationship{ ID: targetAccount, } // check if the requesting account follows the target account follow := >smodel.Follow{} - if err := ps.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil { + if err := r.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil { if err != pg.ErrNoRows { // a proper error return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err) } // no follow exists so these are all false - r.Following = false - r.ShowingReblogs = false - r.Notifying = false + rel.Following = false + rel.ShowingReblogs = false + rel.Notifying = false } else { // follow exists so we can fill these fields out... - r.Following = true - r.ShowingReblogs = follow.ShowReblogs - r.Notifying = follow.Notify + rel.Following = true + rel.ShowingReblogs = follow.ShowReblogs + rel.Notifying = follow.Notify } // check if the target account follows the requesting account - followedBy, err := ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists() + followedBy, err := r.conn.Model(>smodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists() if err != nil { return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err) } - r.FollowedBy = followedBy + rel.FollowedBy = followedBy // check if the requesting account blocks the target account - blocking, err := ps.conn.Model(>smodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists() + blocking, err := r.conn.Model(>smodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists() if err != nil { return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err) } - r.Blocking = blocking + rel.Blocking = blocking // check if the target account blocks the requesting account - blockedBy, err := ps.conn.Model(>smodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists() + blockedBy, err := r.conn.Model(>smodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists() if err != nil { return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) } - r.BlockedBy = blockedBy + rel.BlockedBy = blockedBy // check if there's a pending following request from requesting account to target account - requested, err := ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists() + requested, err := r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists() if err != nil { return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err) } - r.Requested = requested + rel.Requested = requested - return r, nil + return rel, nil } -func (ps *postgresService) Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) { +func (r *relationshipDB) Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.DBError) { if sourceAccount == nil || targetAccount == nil { return false, nil } - return ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists() + return r.conn.Model(>smodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists() } -func (ps *postgresService) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) { +func (r *relationshipDB) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.DBError) { if sourceAccount == nil || targetAccount == nil { return false, nil } - return ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists() + return r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists() } -func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error) { +func (r *relationshipDB) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.DBError) { if account1 == nil || account2 == nil { return false, nil } // make sure account 1 follows account 2 - f1, err := ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists() + f1, err := r.conn.Model(>smodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists() if err != nil { if err == pg.ErrNoRows { return false, nil @@ -110,7 +138,7 @@ func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmode } // make sure account 2 follows account 1 - f2, err := ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", account2.ID).Where("target_account_id = ?", account1.ID).Exists() + f2, err := r.conn.Model(>smodel.Follow{}).Where("account_id = ?", account2.ID).Where("target_account_id = ?", account1.ID).Exists() if err != nil { if err == pg.ErrNoRows { return false, nil @@ -121,12 +149,12 @@ func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmode return f1 && f2, nil } -func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error) { +func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.DBError) { // make sure the original follow request exists fr := >smodel.FollowRequest{} - if err := ps.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil { + if err := r.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil { if err == pg.ErrMultiRows { - return nil, db.ErrNoEntries{} + return nil, db.ErrNoEntries } return nil, err } @@ -140,12 +168,12 @@ func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAcc } // if the follow already exists, just update the URI -- we don't need to do anything else - if _, err := ps.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil { + if _, err := r.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil { return nil, err } // now remove the follow request - if _, err := ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil { + if _, err := r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil { return nil, err } diff --git a/internal/db/pg/status.go b/internal/db/pg/status.go index ab9243a90..450999b9a 100644 --- a/internal/db/pg/status.go +++ b/internal/db/pg/status.go @@ -20,39 +20,90 @@ package pg import ( "container/list" + "context" "errors" "github.com/go-pg/pg/v10" + "github.com/go-pg/pg/v10/orm" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (ps *postgresService) StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) { +type statusDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc +} + +func (s *statusDB) newStatusQ(status *gtsmodel.Status) *orm.Query { + return s.conn.Model(status). + Relation("Account"). + Relation("InReplyTo"). + Relation("InReplyToAccount"). + Relation("BoostOf"). + Relation("BoostOfAccount"). + Relation("CreatedWithApplication") +} + +func (s *statusDB) processResponse(status *gtsmodel.Status, err error) (*gtsmodel.Status, db.DBError) { + switch err { + case pg.ErrNoRows: + return nil, db.ErrNoEntries + case nil: + return status, nil + default: + return nil, err + } +} + +func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.DBError) { + status := >smodel.Status{} + + q := s.newStatusQ(status). + Where("status.id = ?", id) + + return s.processResponse(status, q.Select()) +} + +func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.DBError) { + status := >smodel.Status{} + + q := s.newStatusQ(status). + Where("LOWER(status.uri) = LOWER(?)", uri) + + return s.processResponse(status, q.Select()) +} + +func (s *statusDB) StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.DBError) { parents := []*gtsmodel.Status{} - ps.statusParent(status, &parents, onlyDirect) + s.statusParent(status, &parents, onlyDirect) return parents, nil } -func (ps *postgresService) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) { +func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) { if status.InReplyToID == "" { return } parentStatus := >smodel.Status{} - if err := ps.conn.Model(parentStatus).Where("id = ?", status.InReplyToID).Select(); err == nil { + if err := s.conn.Model(parentStatus).Where("id = ?", status.InReplyToID).Select(); err == nil { *foundStatuses = append(*foundStatuses, parentStatus) } if onlyDirect { return } - ps.statusParent(parentStatus, foundStatuses, false) + s.statusParent(parentStatus, foundStatuses, false) } -func (ps *postgresService) StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) { +func (s *statusDB) StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.DBError) { foundStatuses := &list.List{} foundStatuses.PushFront(status) - ps.statusChildren(status, foundStatuses, onlyDirect, minID) + s.statusChildren(status, foundStatuses, onlyDirect, minID) children := []*gtsmodel.Status{} for e := foundStatuses.Front(); e != nil; e = e.Next() { @@ -70,10 +121,10 @@ func (ps *postgresService) StatusChildren(status *gtsmodel.Status, onlyDirect bo return children, nil } -func (ps *postgresService) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { +func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) { immediateChildren := []*gtsmodel.Status{} - q := ps.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID) + q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID) if minID != "" { q = q.Where("status.id > ?", minID) } @@ -100,43 +151,43 @@ func (ps *postgresService) statusChildren(status *gtsmodel.Status, foundStatuses if onlyDirect { return } - ps.statusChildren(child, foundStatuses, false, minID) + s.statusChildren(child, foundStatuses, false, minID) } } -func (ps *postgresService) GetReplyCountForStatus(status *gtsmodel.Status) (int, error) { - return ps.conn.Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count() +func (s *statusDB) GetReplyCountForStatus(status *gtsmodel.Status) (int, db.DBError) { + return s.conn.Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count() } -func (ps *postgresService) GetReblogCountForStatus(status *gtsmodel.Status) (int, error) { - return ps.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count() +func (s *statusDB) GetReblogCountForStatus(status *gtsmodel.Status) (int, db.DBError) { + return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count() } -func (ps *postgresService) GetFaveCountForStatus(status *gtsmodel.Status) (int, error) { - return ps.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count() +func (s *statusDB) GetFaveCountForStatus(status *gtsmodel.Status) (int, db.DBError) { + return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count() } -func (ps *postgresService) StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, error) { - return ps.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() +func (s *statusDB) StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) { + return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() } -func (ps *postgresService) StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, error) { - return ps.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists() +func (s *statusDB) StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) { + return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists() } -func (ps *postgresService) StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, error) { - return ps.conn.Model(>smodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() +func (s *statusDB) StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) { + return s.conn.Model(>smodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() } -func (ps *postgresService) StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, error) { - return ps.conn.Model(>smodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() +func (s *statusDB) StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) { + return s.conn.Model(>smodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists() } -func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) { +func (s *statusDB) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, db.DBError) { accounts := []*gtsmodel.Account{} faves := []*gtsmodel.StatusFave{} - if err := ps.conn.Model(&faves).Where("status_id = ?", status.ID).Select(); err != nil { + if err := s.conn.Model(&faves).Where("status_id = ?", status.ID).Select(); err != nil { if err == pg.ErrNoRows { return accounts, nil // no rows just means nobody has faved this status, so that's fine } @@ -145,7 +196,7 @@ func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel. for _, f := range faves { acc := >smodel.Account{} - if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil { + if err := s.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil { if err == pg.ErrNoRows { continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it } @@ -156,11 +207,11 @@ func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel. return accounts, nil } -func (ps *postgresService) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) { +func (s *statusDB) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, db.DBError) { accounts := []*gtsmodel.Account{} boosts := []*gtsmodel.Status{} - if err := ps.conn.Model(&boosts).Where("boost_of_id = ?", status.ID).Select(); err != nil { + if err := s.conn.Model(&boosts).Where("boost_of_id = ?", status.ID).Select(); err != nil { if err == pg.ErrNoRows { return accounts, nil // no rows just means nobody has boosted this status, so that's fine } @@ -169,7 +220,7 @@ func (ps *postgresService) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmode for _, f := range boosts { acc := >smodel.Account{} - if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil { + if err := s.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil { if err == pg.ErrNoRows { continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it } diff --git a/internal/db/pg/status_test.go b/internal/db/pg/status_test.go new file mode 100644 index 000000000..a412eb55f --- /dev/null +++ b/internal/db/pg/status_test.go @@ -0,0 +1,86 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package pg_test + +import ( + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type StatusTestSuite struct { + PGStandardTestSuite +} + +func (suite *PGStandardTestSuite) SetupSuite() { + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() + suite.testAttachments = testrig.NewTestAttachments() + suite.testStatuses = testrig.NewTestStatuses() + suite.testTags = testrig.NewTestTags() + suite.testMentions = testrig.NewTestMentions() +} + +func (suite *PGStandardTestSuite) SetupTest() { + suite.config = testrig.NewTestConfig() + suite.db = testrig.NewTestDB() + suite.log = testrig.NewTestLog() + + testrig.StandardDBSetup(suite.db, suite.testAccounts) +} + +func (suite *PGStandardTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) +} + +func (suite *PGStandardTestSuite) TestGetStatusByID() { + status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID) + if err != nil { + suite.FailNow(err.Error()) + } + suite.NotNil(status) + suite.NotNil(status.Account) + suite.NotNil(status.CreatedWithApplication) + suite.Nil(status.BoostOf) + suite.Nil(status.BoostOfAccount) + suite.Nil(status.InReplyTo) + suite.Nil(status.InReplyToAccount) +} + +func (suite *PGStandardTestSuite) TestGetStatusByURI() { + status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI) + if err != nil { + suite.FailNow(err.Error()) + } + suite.NotNil(status) + suite.NotNil(status.Account) + suite.NotNil(status.CreatedWithApplication) + suite.Nil(status.BoostOf) + suite.Nil(status.BoostOfAccount) + suite.Nil(status.InReplyTo) + suite.Nil(status.InReplyToAccount) +} + +func TestStatusTestSuite(t *testing.T) { + suite.Run(t, new(PGStandardTestSuite)) +} diff --git a/internal/db/pg/timeline.go b/internal/db/pg/timeline.go index 585ca3067..87e306da8 100644 --- a/internal/db/pg/timeline.go +++ b/internal/db/pg/timeline.go @@ -19,16 +19,26 @@ package pg import ( + "context" "sort" "github.com/go-pg/pg/v10" + "github.com/sirupsen/logrus" + "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) -func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) { +type timelineDB struct { + config *config.Config + conn *pg.DB + log *logrus.Logger + cancel context.CancelFunc +} + +func (t *timelineDB) GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.DBError) { statuses := []*gtsmodel.Status{} - q := ps.conn.Model(&statuses) + q := t.conn.Model(&statuses) q = q.ColumnExpr("status.*"). // Find out who accountID follows. @@ -74,22 +84,22 @@ func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID str err := q.Select() if err != nil { if err == pg.ErrNoRows { - return nil, db.ErrNoEntries{} + return nil, db.ErrNoEntries } return nil, err } if len(statuses) == 0 { - return nil, db.ErrNoEntries{} + return nil, db.ErrNoEntries } return statuses, nil } -func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) { +func (t *timelineDB) GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.DBError) { statuses := []*gtsmodel.Status{} - q := ps.conn.Model(&statuses). + q := t.conn.Model(&statuses). Where("visibility = ?", gtsmodel.VisibilityPublic). Where("? IS NULL", pg.Ident("in_reply_to_id")). Where("? IS NULL", pg.Ident("in_reply_to_uri")). @@ -119,13 +129,13 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s err := q.Select() if err != nil { if err == pg.ErrNoRows { - return nil, db.ErrNoEntries{} + return nil, db.ErrNoEntries } return nil, err } if len(statuses) == 0 { - return nil, db.ErrNoEntries{} + return nil, db.ErrNoEntries } return statuses, nil @@ -133,11 +143,11 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s // TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20! // It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds. -func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error) { +func (t *timelineDB) GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.DBError) { faves := []*gtsmodel.StatusFave{} - fq := ps.conn.Model(&faves). + fq := t.conn.Model(&faves). Where("account_id = ?", accountID). Order("id DESC") @@ -156,13 +166,13 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st err := fq.Select() if err != nil { if err == pg.ErrNoRows { - return nil, "", "", db.ErrNoEntries{} + return nil, "", "", db.ErrNoEntries } return nil, "", "", err } if len(faves) == 0 { - return nil, "", "", db.ErrNoEntries{} + return nil, "", "", db.ErrNoEntries } // map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID @@ -175,16 +185,16 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st } statuses := []*gtsmodel.Status{} - err = ps.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select() + err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select() if err != nil { if err == pg.ErrNoRows { - return nil, "", "", db.ErrNoEntries{} + return nil, "", "", db.ErrNoEntries } return nil, "", "", err } if len(statuses) == 0 { - return nil, "", "", db.ErrNoEntries{} + return nil, "", "", db.ErrNoEntries } // arrange statuses by fave ID diff --git a/internal/db/pg/update.go b/internal/db/pg/update.go deleted file mode 100644 index f6bc70ad9..000000000 --- a/internal/db/pg/update.go +++ /dev/null @@ -1,73 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see . -*/ - -package pg - -import ( - "fmt" - - "github.com/go-pg/pg/v10" - "github.com/superseriousbusiness/gotosocial/internal/db" -) - -func (ps *postgresService) Upsert(i interface{}, conflictColumn string) error { - if _, err := ps.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - return nil -} - -func (ps *postgresService) UpdateByID(id string, i interface{}) error { - if _, err := ps.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil { - if err == pg.ErrNoRows { - return db.ErrNoEntries{} - } - return err - } - return nil -} - -func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error { - _, err := ps.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update() - return err -} - -func (ps *postgresService) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) error { - q := ps.conn.Model(i) - - for _, w := range where { - if w.Value == nil { - q = q.Where("? IS NULL", pg.Ident(w.Key)) - } else { - if w.CaseInsensitive { - q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value) - } else { - q = q.Where("? = ?", pg.Safe(w.Key), w.Value) - } - } - } - - q = q.Set("? = ?", pg.Safe(key), value) - - _, err := q.Update() - - return err -} diff --git a/internal/db/relationship.go b/internal/db/relationship.go index 1fa532fbb..28c891b99 100644 --- a/internal/db/relationship.go +++ b/internal/db/relationship.go @@ -5,23 +5,23 @@ import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" type Relationship interface { // Blocked checks whether a block exists in eiher direction between two accounts. // That is, it returns true if account1 blocks account2, OR if account2 blocks account1. - Blocked(account1 string, account2 string) (bool, error) + Blocked(account1 string, account2 string) (bool, DBError) // GetRelationship retrieves the relationship of the targetAccount to the requestingAccount. - GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) + GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, DBError) // Follows returns true if sourceAccount follows target account, or an error if something goes wrong while finding out. - Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) + Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, DBError) // FollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out. - FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) + FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, DBError) // Mutuals returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out. - Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error) + Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, DBError) // AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table. // In other words, it should create the follow, and delete the existing follow request. // // It will return the newly created follow for further processing. - AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error) + AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, DBError) } diff --git a/internal/db/status.go b/internal/db/status.go index 94db51d0a..91ee54a9c 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -3,42 +3,48 @@ package db import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" type Status interface { + // GetStatusByID returns one status from the database, with all rel fields populated (if possible). + GetStatusByID(id string) (*gtsmodel.Status, DBError) + + // GetStatusByURI returns one status from the database, with all rel fields populated (if possible). + GetStatusByURI(uri string) (*gtsmodel.Status, DBError) + // GetReplyCountForStatus returns the amount of replies recorded for a status, or an error if something goes wrong - GetReplyCountForStatus(status *gtsmodel.Status) (int, error) + GetReplyCountForStatus(status *gtsmodel.Status) (int, DBError) // GetReblogCountForStatus returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong - GetReblogCountForStatus(status *gtsmodel.Status) (int, error) + GetReblogCountForStatus(status *gtsmodel.Status) (int, DBError) // GetFaveCountForStatus returns the amount of faves/likes recorded for a status, or an error if something goes wrong - GetFaveCountForStatus(status *gtsmodel.Status) (int, error) + GetFaveCountForStatus(status *gtsmodel.Status) (int, DBError) // StatusParents get the parent statuses of a given status. // // If onlyDirect is true, only the immediate parent will be returned. - StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) + StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, DBError) // StatusChildren gets the child statuses of a given status. // // If onlyDirect is true, only the immediate children will be returned. - StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) + StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, DBError) // StatusFavedBy checks if a given status has been faved by a given account ID - StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, error) + StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, DBError) // StatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID - StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, error) + StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, DBError) // StatusMutedBy checks if a given status has been muted by a given account ID - StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, error) + StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, DBError) // StatusBookmarkedBy checks if a given status has been bookmarked by a given account ID - StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, error) + StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, DBError) // WhoFavedStatus returns a slice of accounts who faved the given status. // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. - WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) + WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, DBError) // WhoBoostedStatus returns a slice of accounts who boosted the given status. // This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user. - WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) + WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, DBError) } diff --git a/internal/db/timeline.go b/internal/db/timeline.go index 3ef34d0ea..737a78a31 100644 --- a/internal/db/timeline.go +++ b/internal/db/timeline.go @@ -6,13 +6,13 @@ type Timeline interface { // GetHomeTimelineForAccount returns a slice of statuses from accounts that are followed by the given account id. // // Statuses should be returned in descending order of when they were created (newest first). - GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) + GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, DBError) // GetPublicTimelineForAccount fetches the account's PUBLIC timeline -- ie., posts and replies that are public. // It will use the given filters and try to return as many statuses as possible up to the limit. // // Statuses should be returned in descending order of when they were created (newest first). - GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) + GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, DBError) // GetFavedTimelineForAccount fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved. // It will use the given filters and try to return as many statuses as possible up to the limit. @@ -21,5 +21,5 @@ type Timeline interface { // In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created. // // Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers. - GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error) + GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, DBError) } diff --git a/internal/federation/dereferencing/blocked.go b/internal/federation/dereferencing/blocked.go index a66afbb60..c8a4c6ade 100644 --- a/internal/federation/dereferencing/blocked.go +++ b/internal/federation/dereferencing/blocked.go @@ -31,7 +31,7 @@ func (d *deref) blockedDomain(host string) (bool, error) { return true, nil } - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // there are no entries so there's no block return false, nil } diff --git a/internal/federation/dereferencing/status.go b/internal/federation/dereferencing/status.go index 9d5a0c7ca..c2d293329 100644 --- a/internal/federation/dereferencing/status.go +++ b/internal/federation/dereferencing/status.go @@ -288,7 +288,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername attachmentIDs = append(attachmentIDs, maybeAttachment.ID) continue } - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // we have a real error return fmt.Errorf("error checking db for existence of attachment with remote url %s: %s", a.RemoteURL, err) } diff --git a/internal/federation/federatingdb/create.go b/internal/federation/federatingdb/create.go index 2ac4890e8..d5287136a 100644 --- a/internal/federation/federatingdb/create.go +++ b/internal/federation/federatingdb/create.go @@ -113,7 +113,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error { status.ID = statusID if err := f.db.Put(status); err != nil { - if _, ok := err.(db.ErrAlreadyExists); ok { + if err == db.ErrAlreadyExists { // the status already exists in the database, which means we've already handled everything else, // so we can just return nil here and be done with it. return nil diff --git a/internal/federation/federatingdb/outbox.go b/internal/federation/federatingdb/outbox.go index 1568e0017..7985a31d0 100644 --- a/internal/federation/federatingdb/outbox.go +++ b/internal/federation/federatingdb/outbox.go @@ -62,7 +62,7 @@ func (f *federatingDB) OutboxForInbox(c context.Context, inboxIRI *url.URL) (out } acct := >smodel.Account{} if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String()) } return nil, fmt.Errorf("db error searching for actor with inbox %s", inboxIRI.String()) diff --git a/internal/federation/federatingdb/owns.go b/internal/federation/federatingdb/owns.go index 51b20151a..3cb9603cd 100644 --- a/internal/federation/federatingdb/owns.go +++ b/internal/federation/federatingdb/owns.go @@ -55,7 +55,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) { return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) } if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: uid}}, >smodel.Status{}); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // there are no entries for this status return false, nil } @@ -72,7 +72,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) { return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) } if err := f.db.GetLocalAccountByUsername(username, >smodel.Account{}); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // there are no entries for this username return false, nil } @@ -89,7 +89,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) { return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) } if err := f.db.GetLocalAccountByUsername(username, >smodel.Account{}); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // there are no entries for this username return false, nil } @@ -106,7 +106,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) { return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) } if err := f.db.GetLocalAccountByUsername(username, >smodel.Account{}); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // there are no entries for this username return false, nil } @@ -123,7 +123,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) { return false, fmt.Errorf("error parsing like path for url %s: %s", id.String(), err) } if err := f.db.GetLocalAccountByUsername(username, >smodel.Account{}); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // there are no entries for this username return false, nil } @@ -131,7 +131,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) { return false, fmt.Errorf("database error fetching account with username %s: %s", username, err) } if err := f.db.GetByID(likeID, >smodel.StatusFave{}); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // there are no entries return false, nil } @@ -148,7 +148,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) { return false, fmt.Errorf("error parsing block path for url %s: %s", id.String(), err) } if err := f.db.GetLocalAccountByUsername(username, >smodel.Account{}); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // there are no entries for this username return false, nil } @@ -156,7 +156,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) { return false, fmt.Errorf("database error fetching account with username %s: %s", username, err) } if err := f.db.GetByID(blockID, >smodel.Block{}); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // there are no entries return false, nil } diff --git a/internal/federation/federatingdb/util.go b/internal/federation/federatingdb/util.go index 28f4c5a21..9a3e60cf1 100644 --- a/internal/federation/federatingdb/util.go +++ b/internal/federation/federatingdb/util.go @@ -213,7 +213,7 @@ func (f *federatingDB) ActorForOutbox(c context.Context, outboxIRI *url.URL) (ac } acct := >smodel.Account{} if err := f.db.GetWhere([]db.Where{{Key: "outbox_uri", Value: outboxIRI.String()}}, acct); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { return nil, fmt.Errorf("no actor found that corresponds to outbox %s", outboxIRI.String()) } return nil, fmt.Errorf("db error searching for actor with outbox %s", outboxIRI.String()) @@ -238,7 +238,7 @@ func (f *federatingDB) ActorForInbox(c context.Context, inboxIRI *url.URL) (acto } acct := >smodel.Account{} if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String()) } return nil, fmt.Errorf("db error searching for actor with inbox %s", inboxIRI.String()) diff --git a/internal/federation/federatingprotocol.go b/internal/federation/federatingprotocol.go index 9e21b43bf..be988cda2 100644 --- a/internal/federation/federatingprotocol.go +++ b/internal/federation/federatingprotocol.go @@ -132,7 +132,7 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr // authentication has passed, so add an instance entry for this instance if it hasn't been done already i := >smodel.Instance{} if err := f.db.GetWhere([]db.Where{{Key: "domain", Value: publicKeyOwnerURI.Host, CaseInsensitive: true}}, i); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // there's been an actual error return ctx, false, fmt.Errorf("error getting requesting account with public key id %s: %s", publicKeyOwnerURI.String(), err) } @@ -202,8 +202,7 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er requestingAccount := >smodel.Account{} if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: uri.String()}}, requestingAccount); err != nil { - _, ok := err.(db.ErrNoEntries) - if ok { + if err == db.ErrNoEntries { // we don't have an entry for this account so it's not blocked // TODO: allow a different default to be set for this behavior continue diff --git a/internal/federation/util.go b/internal/federation/util.go index de8654d32..7971f6172 100644 --- a/internal/federation/util.go +++ b/internal/federation/util.go @@ -13,7 +13,7 @@ func (f *federator) blockedDomain(host string) (bool, error) { return true, nil } - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // there are no entries so there's no block return false, nil } diff --git a/internal/gtsmodel/account.go b/internal/gtsmodel/account.go index e560601b8..369fd6b77 100644 --- a/internal/gtsmodel/account.go +++ b/internal/gtsmodel/account.go @@ -46,10 +46,12 @@ type Account struct { // ID of the avatar as a media attachment AvatarMediaAttachmentID string `pg:"type:CHAR(26)"` + AvatarMediaAttachment *MediaAttachment `pg:"rel:has-one"` // For a non-local account, where can the header be fetched? AvatarRemoteURL string // ID of the header as a media attachment HeaderMediaAttachmentID string `pg:"type:CHAR(26)"` + HeaderMediaAttachment *MediaAttachment `pg:"rel:has-one"` // For a non-local account, where can the header be fetched? HeaderRemoteURL string // DisplayName for this account. Can be empty, then just the Username will be used for display purposes. diff --git a/internal/gtsmodel/instance.go b/internal/gtsmodel/instance.go index 857831ba3..e0eb1435a 100644 --- a/internal/gtsmodel/instance.go +++ b/internal/gtsmodel/instance.go @@ -20,6 +20,7 @@ type Instance struct { SuspendedAt time.Time // ID of any existing domain block for this instance in the database DomainBlockID string `pg:"type:CHAR(26)"` + DomainBlock *DomainBlock `pg:"rel:has-one"` // Short description of this instance ShortDescription string // Longer description of this instance @@ -32,6 +33,7 @@ type Instance struct { ContactAccountUsername string // Contact account ID in the database for this instance ContactAccountID string `pg:"type:CHAR(26)"` + ContactAccount *Account `pg:"rel:has-one"` // Reputation score of this instance Reputation int64 `pg:",notnull,default:0"` // Version of the software used on this instance diff --git a/internal/gtsmodel/status.go b/internal/gtsmodel/status.go index 21b8fc794..2746284a2 100644 --- a/internal/gtsmodel/status.go +++ b/internal/gtsmodel/status.go @@ -47,24 +47,24 @@ type Status struct { // is this status from a local account? Local bool // which account posted this status? - AccountID string `pg:"type:CHAR(26),notnull"` - Account *Account `pg:"rel:has-one"` + AccountID string `pg:"type:CHAR(26),notnull"` + Account *Account `pg:"rel:has-one"` // AP uri of the owner of this status AccountURI string // id of the status this status is a reply to - InReplyToID string `pg:"type:CHAR(26)"` - InReplyTo *Status `pg:"-"` + InReplyToID string `pg:"type:CHAR(26)"` + InReplyTo *Status `pg:"rel:has-one"` // AP uri of the status this status is a reply to InReplyToURI string // id of the account that this status replies to - InReplyToAccountID string `pg:"type:CHAR(26)"` - InReplyToAccount *Account `pg:"-"` + InReplyToAccountID string `pg:"type:CHAR(26)"` + InReplyToAccount *Account `pg:"rel:has-one"` // id of the status this status is a boost of - BoostOfID string `pg:"type:CHAR(26)"` - BoostOf *Status `pg:"-"` + BoostOfID string `pg:"type:CHAR(26)"` + BoostOf *Status `pg:"rel:has-one"` // id of the account that owns the boosted status - BoostOfAccountID string `pg:"type:CHAR(26)"` - BoostOfAccount *Account `pg:"-"` + BoostOfAccountID string `pg:"type:CHAR(26)"` + BoostOfAccount *Account `pg:"rel:has-one"` // cw string for this status ContentWarning string // visibility entry for this status @@ -74,8 +74,8 @@ type Status struct { // what language is this status written in? Language string // Which application was used to create this status? - CreatedWithApplicationID string `pg:"type:CHAR(26)"` - CreatedWithApplication *Application `pg:"-"` + CreatedWithApplicationID string `pg:"type:CHAR(26)"` + CreatedWithApplication *Application `pg:"rel:has-one"` // advanced visibility for this status VisibilityAdvanced *VisibilityAdvanced // What is the activitystreams type of this status? See: https://www.w3.org/TR/activitystreams-vocabulary/#object-types @@ -103,13 +103,6 @@ type Status struct { GTSEmojis []*Emoji `pg:"-"` // MediaAttachments used in this status GTSMediaAttachments []*MediaAttachment `pg:"-"` - // Status being replied to - - // Account being replied to - - // Status being boosted - - // Account of the boosted status } diff --git a/internal/oauth/clientstore.go b/internal/oauth/clientstore.go index 998f6784e..2e7e0ae88 100644 --- a/internal/oauth/clientstore.go +++ b/internal/oauth/clientstore.go @@ -27,11 +27,11 @@ import ( ) type clientStore struct { - db db.DB + db db.Basic } // NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend. -func NewClientStore(db db.DB) oauth2.ClientStore { +func NewClientStore(db db.Basic) oauth2.ClientStore { pts := &clientStore{ db: db, } diff --git a/internal/oauth/clientstore_test.go b/internal/oauth/clientstore_test.go index c515ff513..fd3452405 100644 --- a/internal/oauth/clientstore_test.go +++ b/internal/oauth/clientstore_test.go @@ -99,7 +99,7 @@ func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() { // try to get the deleted client; we should get an error deletedClient, err := cs.GetByID(context.Background(), suite.testClientID) suite.Assert().Nil(deletedClient) - suite.Assert().EqualValues(db.ErrNoEntries{}, err) + suite.Assert().EqualValues(db.ErrNoEntries, err) } func TestPgClientStoreTestSuite(t *testing.T) { diff --git a/internal/oauth/server.go b/internal/oauth/server.go index 1289b18af..6d8f50064 100644 --- a/internal/oauth/server.go +++ b/internal/oauth/server.go @@ -66,7 +66,7 @@ type s struct { } // New returns a new oauth server that implements the Server interface -func New(database db.DB, log *logrus.Logger) Server { +func New(database db.Basic, log *logrus.Logger) Server { ts := newTokenStore(context.Background(), database, log) cs := NewClientStore(database) diff --git a/internal/oauth/tokenstore.go b/internal/oauth/tokenstore.go index 5f8e07882..4fd3183fc 100644 --- a/internal/oauth/tokenstore.go +++ b/internal/oauth/tokenstore.go @@ -34,7 +34,7 @@ import ( // tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend. type tokenStore struct { oauth2.TokenStore - db db.DB + db db.Basic log *logrus.Logger } @@ -42,7 +42,7 @@ type tokenStore struct { // // In order to allow tokens to 'expire', it will also set off a goroutine that iterates through // the tokens in the DB once per minute and deletes any that have expired. -func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.TokenStore { +func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2.TokenStore { pts := &tokenStore{ db: db, log: log, diff --git a/internal/processing/account/createblock.go b/internal/processing/account/createblock.go index 79ce03805..798e9324f 100644 --- a/internal/processing/account/createblock.go +++ b/internal/processing/account/createblock.go @@ -33,7 +33,7 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou // make sure the target account actually exists in our db targetAcct := >smodel.Account{} if err := p.db.GetByID(targetAccountID, targetAcct); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: account %s not found in the db: %s", targetAccountID, err)) } } diff --git a/internal/processing/account/createfollow.go b/internal/processing/account/createfollow.go index e89db9d47..55d75fcc5 100644 --- a/internal/processing/account/createfollow.go +++ b/internal/processing/account/createfollow.go @@ -42,7 +42,7 @@ func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apim // make sure the target account actually exists in our db targetAcct := >smodel.Account{} if err := p.db.GetByID(form.ID, targetAcct); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: account %s not found in the db: %s", form.ID, err)) } } diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go index 4cadeaac6..e8840abae 100644 --- a/internal/processing/account/delete.go +++ b/internal/processing/account/delete.go @@ -135,7 +135,7 @@ selectStatusesLoop: for { statuses, err := p.db.GetAccountStatuses(account.ID, 20, false, maxID, false, false) if err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // no statuses left for this instance so we're done l.Infof("Delete: done iterating through statuses for account %s", account.Username) break selectStatusesLoop @@ -158,7 +158,7 @@ selectStatusesLoop: } if err := p.db.DeleteByID(s.ID, s); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // actual error has occurred l.Errorf("Delete: db error status %s for account %s: %s", s.ID, account.Username, err) break selectStatusesLoop @@ -168,7 +168,7 @@ selectStatusesLoop: // if there are any boosts of this status, delete them as well boosts := []*gtsmodel.Status{} if err := p.db.GetWhere([]db.Where{{Key: "boost_of_id", Value: s.ID}}, &boosts); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // an actual error has occurred l.Errorf("Delete: db error selecting boosts of status %s for account %s: %s", s.ID, account.Username, err) break selectStatusesLoop @@ -190,7 +190,7 @@ selectStatusesLoop: } if err := p.db.DeleteByID(b.ID, b); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // actual error has occurred l.Errorf("Delete: db error deleting boost with id %s: %s", b.ID, err) break selectStatusesLoop diff --git a/internal/processing/account/get.go b/internal/processing/account/get.go index a70bf02bd..8cfd91cc2 100644 --- a/internal/processing/account/get.go +++ b/internal/processing/account/get.go @@ -30,7 +30,7 @@ import ( func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) { targetAccount := >smodel.Account{} if err := p.db.GetByID(targetAccountID, targetAccount); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { return nil, errors.New("account not found") } return nil, fmt.Errorf("db error: %s", err) diff --git a/internal/processing/account/getfollowers.go b/internal/processing/account/getfollowers.go index 66cd38f21..4f3617341 100644 --- a/internal/processing/account/getfollowers.go +++ b/internal/processing/account/getfollowers.go @@ -40,7 +40,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco followers := []gtsmodel.Follow{} accounts := []apimodel.Account{} if err := p.db.GetAccountFollowers(targetAccountID, &followers, false); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { return accounts, nil } return nil, gtserror.NewErrorInternalError(err) @@ -57,7 +57,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco a := >smodel.Account{} if err := p.db.GetByID(f.AccountID, a); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { continue } return nil, gtserror.NewErrorInternalError(err) diff --git a/internal/processing/account/getfollowing.go b/internal/processing/account/getfollowing.go index 57461cf23..55e3e25cc 100644 --- a/internal/processing/account/getfollowing.go +++ b/internal/processing/account/getfollowing.go @@ -40,7 +40,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco following := []gtsmodel.Follow{} accounts := []apimodel.Account{} if err := p.db.GetAccountFollowing(targetAccountID, &following); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { return accounts, nil } return nil, gtserror.NewErrorInternalError(err) @@ -57,7 +57,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco a := >smodel.Account{} if err := p.db.GetByID(f.TargetAccountID, a); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { continue } return nil, gtserror.NewErrorInternalError(err) diff --git a/internal/processing/account/getstatuses.go b/internal/processing/account/getstatuses.go index 27cabcc13..64ad71a03 100644 --- a/internal/processing/account/getstatuses.go +++ b/internal/processing/account/getstatuses.go @@ -30,7 +30,7 @@ import ( func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) { targetAccount := >smodel.Account{} if err := p.db.GetByID(targetAccountID, targetAccount); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(fmt.Errorf("no entry found for account id %s", targetAccountID)) } return nil, gtserror.NewErrorInternalError(err) @@ -39,7 +39,7 @@ func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccou apiStatuses := []apimodel.Status{} statuses, err := p.db.GetAccountStatuses(targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly) if err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { return apiStatuses, nil } return nil, gtserror.NewErrorInternalError(err) diff --git a/internal/processing/account/removeblock.go b/internal/processing/account/removeblock.go index 03b0c6750..88a24eeef 100644 --- a/internal/processing/account/removeblock.go +++ b/internal/processing/account/removeblock.go @@ -31,7 +31,7 @@ func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccou // make sure the target account actually exists in our db targetAcct := >smodel.Account{} if err := p.db.GetByID(targetAccountID, targetAcct); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockRemove: account %s not found in the db: %s", targetAccountID, err)) } } diff --git a/internal/processing/account/removefollow.go b/internal/processing/account/removefollow.go index ef8994893..21a0eca55 100644 --- a/internal/processing/account/removefollow.go +++ b/internal/processing/account/removefollow.go @@ -40,7 +40,7 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco // make sure the target account actually exists in our db targetAcct := >smodel.Account{} if err := p.db.GetByID(targetAccountID, targetAcct); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(fmt.Errorf("AccountFollowRemove: account %s not found in the db: %s", targetAccountID, err)) } } diff --git a/internal/processing/admin/createdomainblock.go b/internal/processing/admin/createdomainblock.go index df02cef94..a58b2c9ad 100644 --- a/internal/processing/admin/createdomainblock.go +++ b/internal/processing/admin/createdomainblock.go @@ -36,7 +36,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string, domainBlock := >smodel.DomainBlock{} err := p.db.GetWhere([]db.Where{{Key: "domain", Value: domain, CaseInsensitive: true}}, domainBlock) if err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // something went wrong in the DB return nil, gtserror.NewErrorInternalError(fmt.Errorf("DomainBlockCreate: db error checking for existence of domain block %s: %s", domain, err)) } @@ -60,7 +60,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string, // put the new block in the database if err := p.db.Put(domainBlock); err != nil { - if _, ok := err.(db.ErrAlreadyExists); !ok { + if err != db.ErrNoEntries { // there's a real error creating the block return nil, gtserror.NewErrorInternalError(fmt.Errorf("DomainBlockCreate: db error putting new domain block %s: %s", domain, err)) } @@ -125,7 +125,7 @@ selectAccountsLoop: for { accounts, err := p.db.GetAccountsForInstance(block.Domain, maxID, limit) if err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // no accounts left for this instance so we're done l.Infof("domainBlockProcessSideEffects: done iterating through accounts for domain %s", block.Domain) break selectAccountsLoop diff --git a/internal/processing/admin/deletedomainblock.go b/internal/processing/admin/deletedomainblock.go index b41fedd92..edb0a58f9 100644 --- a/internal/processing/admin/deletedomainblock.go +++ b/internal/processing/admin/deletedomainblock.go @@ -32,7 +32,7 @@ func (p *processor) DomainBlockDelete(account *gtsmodel.Account, id string) (*ap domainBlock := >smodel.DomainBlock{} if err := p.db.GetByID(id, domainBlock); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // something has gone really wrong return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/admin/getdomainblock.go b/internal/processing/admin/getdomainblock.go index 7d1f9e2ab..f74010627 100644 --- a/internal/processing/admin/getdomainblock.go +++ b/internal/processing/admin/getdomainblock.go @@ -31,7 +31,7 @@ func (p *processor) DomainBlockGet(account *gtsmodel.Account, id string, export domainBlock := >smodel.DomainBlock{} if err := p.db.GetByID(id, domainBlock); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // something has gone really wrong return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/admin/getdomainblocks.go b/internal/processing/admin/getdomainblocks.go index 5e2241412..f827d03fc 100644 --- a/internal/processing/admin/getdomainblocks.go +++ b/internal/processing/admin/getdomainblocks.go @@ -29,7 +29,7 @@ func (p *processor) DomainBlocksGet(account *gtsmodel.Account, export bool) ([]* domainBlocks := []*gtsmodel.DomainBlock{} if err := p.db.GetAll(&domainBlocks); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // something has gone really wrong return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/blocks.go b/internal/processing/blocks.go index dbc2f77c7..809cbde8e 100644 --- a/internal/processing/blocks.go +++ b/internal/processing/blocks.go @@ -31,7 +31,7 @@ import ( func (p *processor) BlocksGet(authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) { accounts, nextMaxID, prevMinID, err := p.db.GetAccountBlocks(authed.Account.ID, maxID, sinceID, limit) if err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // there are just no entries return &apimodel.BlocksResponse{ Accounts: []*apimodel.Account{}, diff --git a/internal/processing/followrequest.go b/internal/processing/followrequest.go index 553a953ff..21ee740fc 100644 --- a/internal/processing/followrequest.go +++ b/internal/processing/followrequest.go @@ -29,7 +29,7 @@ import ( func (p *processor) FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) { frs := []gtsmodel.FollowRequest{} if err := p.db.GetAccountFollowRequests(auth.Account.ID, &frs); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { return nil, gtserror.NewErrorInternalError(err) } } diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go index 8e52db89f..a5d268613 100644 --- a/internal/processing/fromcommon.go +++ b/internal/processing/fromcommon.go @@ -74,7 +74,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error { // notification exists already so just continue continue } - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // there's a real error in the db return fmt.Errorf("notifyStatus: error checking existence of notification for mention with id %s : %s", m.ID, err) } diff --git a/internal/processing/fromfederator.go b/internal/processing/fromfederator.go index 949a734c7..c75154d43 100644 --- a/internal/processing/fromfederator.go +++ b/internal/processing/fromfederator.go @@ -101,7 +101,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er incomingAnnounce.ID = incomingAnnounceID if err := p.db.Put(incomingAnnounce); err != nil { - if _, ok := err.(db.ErrAlreadyExists); !ok { + if err != db.ErrNoEntries { return fmt.Errorf("error adding dereferenced announce to the db: %s", err) } } diff --git a/internal/processing/media/delete.go b/internal/processing/media/delete.go index 694d78ac3..b5ea8c806 100644 --- a/internal/processing/media/delete.go +++ b/internal/processing/media/delete.go @@ -12,7 +12,7 @@ import ( func (p *processor) Delete(mediaAttachmentID string) gtserror.WithCode { a := >smodel.MediaAttachment{} if err := p.db.GetByID(mediaAttachmentID, a); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // attachment already gone return nil } @@ -38,7 +38,7 @@ func (p *processor) Delete(mediaAttachmentID string) gtserror.WithCode { // delete the attachment if err := p.db.DeleteByID(mediaAttachmentID, a); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { errs = append(errs, fmt.Sprintf("remove attachment: %s", err)) } } diff --git a/internal/processing/media/getmedia.go b/internal/processing/media/getmedia.go index c36370225..380a54cc2 100644 --- a/internal/processing/media/getmedia.go +++ b/internal/processing/media/getmedia.go @@ -31,7 +31,7 @@ import ( func (p *processor) GetMedia(account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) { attachment := >smodel.MediaAttachment{} if err := p.db.GetByID(mediaAttachmentID, attachment); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // attachment doesn't exist return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db")) } diff --git a/internal/processing/media/update.go b/internal/processing/media/update.go index 28f3a26f6..89ed08ac1 100644 --- a/internal/processing/media/update.go +++ b/internal/processing/media/update.go @@ -32,7 +32,7 @@ import ( func (p *processor) Update(account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) { attachment := >smodel.MediaAttachment{} if err := p.db.GetByID(mediaAttachmentID, attachment); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // attachment doesn't exist return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db")) } diff --git a/internal/processing/search.go b/internal/processing/search.go index 737ad8f71..c1cde8714 100644 --- a/internal/processing/search.go +++ b/internal/processing/search.go @@ -196,7 +196,7 @@ func (p *processor) searchAccountByMention(authed *oauth.Auth, mention string, r return maybeAcct, nil } - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // if it's not errNoEntries there's been a real database error so bail at this point return nil, fmt.Errorf("searchAccountByMention: database error: %s", err) } diff --git a/internal/processing/status/context.go b/internal/processing/status/context.go index 32c528296..d342babb0 100644 --- a/internal/processing/status/context.go +++ b/internal/processing/status/context.go @@ -19,7 +19,7 @@ func (p *processor) Context(account *gtsmodel.Account, targetStatusID string) (* targetStatus := >smodel.Status{} if err := p.db.GetByID(targetStatusID, targetStatus); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(err) } return nil, gtserror.NewErrorInternalError(err) diff --git a/internal/processing/status/delete.go b/internal/processing/status/delete.go index 259038dee..c73646af5 100644 --- a/internal/processing/status/delete.go +++ b/internal/processing/status/delete.go @@ -15,7 +15,7 @@ func (p *processor) Delete(account *gtsmodel.Account, targetStatusID string) (*a l.Tracef("going to search for target status %s", targetStatusID) targetStatus := >smodel.Status{} if err := p.db.GetByID(targetStatusID, targetStatus); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } // status is already gone diff --git a/internal/processing/status/unboost.go b/internal/processing/status/unboost.go index 3266c9de3..7a4ae5486 100644 --- a/internal/processing/status/unboost.go +++ b/internal/processing/status/unboost.go @@ -57,7 +57,7 @@ func (p *processor) Unboost(account *gtsmodel.Account, application *gtsmodel.App if err != nil { // something went wrong in the db finding the boost - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error fetching existing boost from database: %s", err)) } // we just don't have a boost diff --git a/internal/processing/status/unfave.go b/internal/processing/status/unfave.go index b51daacb9..82f43ee1f 100644 --- a/internal/processing/status/unfave.go +++ b/internal/processing/status/unfave.go @@ -45,7 +45,7 @@ func (p *processor) Unfave(account *gtsmodel.Account, targetStatusID string) (*a } if err != nil { // something went wrong in the db finding the fave - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error fetching existing fave from database: %s", err)) } // we just don't have a fave diff --git a/internal/processing/status/util.go b/internal/processing/status/util.go index 3be53591b..d700016e6 100644 --- a/internal/processing/status/util.go +++ b/internal/processing/status/util.go @@ -99,7 +99,7 @@ func (p *processor) ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, th repliedAccount := >smodel.Account{} // check replied status exists + is replyable if err := p.db.GetByID(form.InReplyToID, repliedStatus); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { return fmt.Errorf("status with id %s not replyable because it doesn't exist", form.InReplyToID) } return fmt.Errorf("status with id %s not replyable: %s", form.InReplyToID, err) @@ -113,14 +113,14 @@ func (p *processor) ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, th // check replied account is known to us if err := p.db.GetByID(repliedStatus.AccountID, repliedAccount); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { return fmt.Errorf("status with id %s not replyable because account id %s is not known", form.InReplyToID, repliedStatus.AccountID) } return fmt.Errorf("status with id %s not replyable: %s", form.InReplyToID, err) } // check if a block exists if blocked, err := p.db.Blocked(thisAccountID, repliedAccount.ID); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { return fmt.Errorf("status with id %s not replyable: %s", form.InReplyToID, err) } } else if blocked { diff --git a/internal/processing/timeline.go b/internal/processing/timeline.go index 18d0a6ac7..bceadda73 100644 --- a/internal/processing/timeline.go +++ b/internal/processing/timeline.go @@ -76,7 +76,7 @@ func (p *processor) HomeTimelineGet(authed *oauth.Auth, maxID string, sinceID st func (p *processor) PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) { statuses, err := p.db.GetPublicTimelineForAccount(authed.Account.ID, maxID, sinceID, minID, limit, local) if err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // there are just no entries left return &apimodel.StatusTimelineResponse{ Statuses: []*apimodel.Status{}, @@ -97,7 +97,7 @@ func (p *processor) PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID func (p *processor) FavedTimelineGet(authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.StatusTimelineResponse, gtserror.WithCode) { statuses, nextMaxID, prevMinID, err := p.db.GetFavedTimelineForAccount(authed.Account.ID, maxID, minID, limit) if err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // there are just no entries left return &apimodel.StatusTimelineResponse{ Statuses: []*apimodel.Status{}, @@ -122,7 +122,7 @@ func (p *processor) filterPublicStatuses(authed *oauth.Auth, statuses []*gtsmode for _, s := range statuses { targetAccount := >smodel.Account{} if err := p.db.GetByID(s.AccountID, targetAccount); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { l.Debugf("filterPublicStatuses: skipping status %s because account %s can't be found in the db", s.ID, s.AccountID) continue } @@ -157,7 +157,7 @@ func (p *processor) filterFavedStatuses(authed *oauth.Auth, statuses []*gtsmodel for _, s := range statuses { targetAccount := >smodel.Account{} if err := p.db.GetByID(s.AccountID, targetAccount); err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { l.Debugf("filterFavedStatuses: skipping status %s because account %s can't be found in the db", s.ID, s.AccountID) continue } diff --git a/internal/router/session.go b/internal/router/session.go index 2b9be2f56..38810572f 100644 --- a/internal/router/session.go +++ b/internal/router/session.go @@ -49,7 +49,7 @@ func useSession(cfg *config.Config, dbService db.DB, engine *gin.Engine) error { // check if we have a saved router session already routerSessions := []*gtsmodel.RouterSession{} if err := dbService.GetAll(&routerSessions); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // proper error occurred return err } diff --git a/internal/timeline/index.go b/internal/timeline/index.go index 1e1a9d7bb..0b28c1a19 100644 --- a/internal/timeline/index.go +++ b/internal/timeline/index.go @@ -52,7 +52,7 @@ grabloop: for ; len(filtered) < amount && i < 5; i = i + 1 { // try the grabloop 5 times only statuses, err := t.db.GetHomeTimelineForAccount(t.accountID, "", "", offsetStatus, amount, false) if err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { break grabloop // we just don't have enough statuses left in the db so index what we've got and then bail } return fmt.Errorf("IndexBefore: error getting statuses from db: %s", err) @@ -132,7 +132,7 @@ grabloop: l.Tracef("entering grabloop; i is %d; len(filtered) is %d", i, len(filtered)) statuses, err := t.db.GetHomeTimelineForAccount(t.accountID, offsetStatus, "", "", amount, false) if err != nil { - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { break grabloop // we just don't have enough statuses left in the db so index what we've got and then bail } return fmt.Errorf("IndexBehind: error getting statuses from db: %s", err) diff --git a/internal/timeline/prepare.go b/internal/timeline/prepare.go index 51846c816..20000b4e9 100644 --- a/internal/timeline/prepare.go +++ b/internal/timeline/prepare.go @@ -95,7 +95,7 @@ prepareloop: if preparing { if err := t.prepare(entry.statusID); err != nil { // there's been an error - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // it's a real error return fmt.Errorf("PrepareBehind: error preparing status with id %s: %s", entry.statusID, err) } @@ -150,7 +150,7 @@ prepareloop: if preparing { if err := t.prepare(entry.statusID); err != nil { // there's been an error - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // it's a real error return fmt.Errorf("PrepareBefore: error preparing status with id %s: %s", entry.statusID, err) } @@ -205,7 +205,7 @@ prepareloop: if err := t.prepare(entry.statusID); err != nil { // there's been an error - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // it's a real error return fmt.Errorf("PrepareFromTop: error preparing status with id %s: %s", entry.statusID, err) } diff --git a/internal/typeutils/astointernal.go b/internal/typeutils/astointernal.go index 47f8c24fc..d426b094e 100644 --- a/internal/typeutils/astointernal.go +++ b/internal/typeutils/astointernal.go @@ -44,7 +44,7 @@ func (c *converter) ASRepresentationToAccount(accountable ap.Accountable, update // we already know this account so we can skip generating it return acct, nil } - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { // we don't know the account and there's been a real error return nil, fmt.Errorf("error getting account with uri %s from the database: %s", uri.String(), err) } diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go index 18270d90e..f8562c1ec 100644 --- a/internal/typeutils/internaltofrontend.go +++ b/internal/typeutils/internaltofrontend.go @@ -40,7 +40,7 @@ func (c *converter) AccountToMastoSensitive(a *gtsmodel.Account) (*model.Account // check pending follow requests aimed at this account fr := []gtsmodel.FollowRequest{} if err := c.db.GetAccountFollowRequests(a.ID, &fr); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { return nil, fmt.Errorf("error getting follow requests: %s", err) } } @@ -63,41 +63,27 @@ func (c *converter) AccountToMastoSensitive(a *gtsmodel.Account) (*model.Account func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, error) { // count followers - followers := []gtsmodel.Follow{} - if err := c.db.GetAccountFollowers(a.ID, &followers, false); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { - return nil, fmt.Errorf("error getting followers: %s", err) - } - } - var followersCount int - if followers != nil { - followersCount = len(followers) + followersCount, err := c.db.CountAccountFollowers(a.ID, false) + if err != nil { + return nil, fmt.Errorf("error counting followers: %s", err) } // count following - following := []gtsmodel.Follow{} - if err := c.db.GetAccountFollowing(a.ID, &following); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { - return nil, fmt.Errorf("error getting following: %s", err) - } - } - var followingCount int - if following != nil { - followingCount = len(following) + followingCount, err := c.db.CountAccountFollowing(a.ID, false) + if err != nil { + return nil, fmt.Errorf("error counting following: %s", err) } // count statuses - statusesCount, err := c.db.GetAccountStatusesCount(a.ID) + statusesCount, err := c.db.CountAccountStatuses(a.ID) if err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { - return nil, fmt.Errorf("error getting last statuses: %s", err) - } + return nil, fmt.Errorf("error getting last statuses: %s", err) } // check when the last status was lastStatus := >smodel.Status{} if err := c.db.GetAccountLastStatus(a.ID, lastStatus); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { + if err != db.ErrNoEntries { return nil, fmt.Errorf("error getting last status: %s", err) } } @@ -107,23 +93,20 @@ func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, e } // build the avatar and header URLs - avi := >smodel.MediaAttachment{} - if err := c.db.GetAccountAvatar(avi, a.ID); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { - return nil, fmt.Errorf("error getting avatar: %s", err) - } - } - aviURL := avi.URL - aviURLStatic := avi.Thumbnail.URL - header := >smodel.MediaAttachment{} - if err := c.db.GetAccountHeader(header, a.ID); err != nil { - if _, ok := err.(db.ErrNoEntries); !ok { - return nil, fmt.Errorf("error getting header: %s", err) - } + var aviURL string + var aviURLStatic string + if a.AvatarMediaAttachment != nil { + aviURL = a.AvatarMediaAttachment.URL + aviURLStatic = a.AvatarMediaAttachment.Thumbnail.URL + } + + var headerURL string + var headerURLStatic string + if a.HeaderMediaAttachment != nil { + headerURL = a.HeaderMediaAttachment.URL + headerURLStatic = a.HeaderMediaAttachment.Thumbnail.URL } - headerURL := header.URL - headerURLStatic := header.Thumbnail.URL // get the fields set on this account fields := []model.Field{} @@ -585,13 +568,10 @@ func (c *converter) InstanceToMasto(i *gtsmodel.Instance) (*model.Instance, erro } // get the instance account if it exists and just skip if it doesn't - ia := >smodel.Account{} - if err := c.db.GetWhere([]db.Where{{Key: "username", Value: i.Domain}}, ia); err == nil { - // instance account exists, get the header for the account if it exists - attachment := >smodel.MediaAttachment{} - if err := c.db.GetAccountHeader(attachment, ia.ID); err == nil { - // header exists, set it on the api model - mi.Thumbnail = attachment.URL + ia, err := c.db.GetInstanceAccount("") + if err == nil { + if ia.HeaderMediaAttachment != nil { + mi.Thumbnail = ia.HeaderMediaAttachment.URL } } diff --git a/internal/visibility/statusvisible.go b/internal/visibility/statusvisible.go index dc6b74702..638725e2b 100644 --- a/internal/visibility/statusvisible.go +++ b/internal/visibility/statusvisible.go @@ -45,7 +45,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount targetUser := >smodel.User{} if err := f.db.GetWhere([]db.Where{{Key: "account_id", Value: targetAccount.ID}}, targetUser); err != nil { l.Debug("target user could not be selected") - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { return false, nil } return false, fmt.Errorf("StatusVisible: db error selecting user for local target account %s: %s", targetAccount.ID, err) @@ -76,7 +76,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount if err := f.db.GetWhere([]db.Where{{Key: "account_id", Value: requestingAccount.ID}}, requestingUser); err != nil { // if the requesting account is local but doesn't have a corresponding user in the db this is a problem l.Debug("requesting user could not be selected") - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { return false, nil } return false, fmt.Errorf("StatusVisible: db error selecting user for local requesting account %s: %s", requestingAccount.ID, err) diff --git a/internal/visibility/util.go b/internal/visibility/util.go index e7d5b4378..f004b4c23 100644 --- a/internal/visibility/util.go +++ b/internal/visibility/util.go @@ -112,7 +112,7 @@ func (f *filter) blockedDomain(host string) (bool, error) { return true, nil } - if _, ok := err.(db.ErrNoEntries); ok { + if err == db.ErrNoEntries { // there are no entries so there's no block return false, nil }