diff --git a/internal/api/client/auth/callback.go b/internal/api/client/auth/callback.go
index 8bf2a50b5..a26838aa3 100644
--- a/internal/api/client/auth/callback.go
+++ b/internal/api/client/auth/callback.go
@@ -116,7 +116,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
return user, nil
}
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// we have an actual error in the database
return nil, fmt.Errorf("error checking database for email %s: %s", claims.Email, err)
}
@@ -128,7 +128,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email)
}
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// we have an actual error in the database
return nil, fmt.Errorf("error checking database for email %s: %s", claims.Email, err)
}
diff --git a/internal/api/security/signaturecheck.go b/internal/api/security/signaturecheck.go
index b852c92ab..a3a04180d 100644
--- a/internal/api/security/signaturecheck.go
+++ b/internal/api/security/signaturecheck.go
@@ -59,7 +59,7 @@ func (m *Module) blockedDomain(host string) (bool, error) {
return true, nil
}
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// there are no entries so there's no block
return false, nil
}
diff --git a/internal/db/account.go b/internal/db/account.go
index 9584291e2..82268fe87 100644
--- a/internal/db/account.go
+++ b/internal/db/account.go
@@ -6,60 +6,66 @@ type Account interface {
// GetAccountByUserID is a shortcut for the common action of fetching an account corresponding to a user ID.
// The given account pointer will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
- GetAccountByUserID(userID string, account *gtsmodel.Account) error
+ GetAccountByUserID(userID string, account *gtsmodel.Account) DBError
+
+ // GetAccountByID returns one account with the given ID, or an error if something goes wrong.
+ GetAccountByID(id string) (*gtsmodel.Account, DBError)
+
+ // GetAccountByURI returns one account with the given URI, or an error if something goes wrong.
+ GetAccountByURI(uri string) (*gtsmodel.Account, DBError)
// GetLocalAccountByUsername is a shortcut for the common action of fetching an account ON THIS INSTANCE
// according to its username, which should be unique.
// The given account pointer will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
- GetLocalAccountByUsername(username string, account *gtsmodel.Account) error
+ GetLocalAccountByUsername(username string, account *gtsmodel.Account) DBError
// GetAccountFollowRequests is a shortcut for the common action of fetching a list of follow requests targeting the given account ID.
// The given slice 'followRequests' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
- GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) error
+ GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) DBError
// GetAccountFollowing is a shortcut for the common action of fetching a list of accounts that accountID is following.
// The given slice 'following' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
- GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) error
+ GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) DBError
+
+ CountAccountFollowing(accountID string, localOnly bool) (int, DBError)
// GetAccountFollowers is a shortcut for the common action of fetching a list of accounts that accountID is followed by.
// The given slice 'followers' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
//
// If localOnly is set to true, then only followers from *this instance* will be returned.
- GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) error
+ GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) DBError
+
+ CountAccountFollowers(accountID string, localOnly bool) (int, DBError)
// GetAccountFaves is a shortcut for the common action of fetching a list of faves made by the given accountID.
// The given slice 'faves' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
- GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) error
+ GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) DBError
// GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID.
- GetAccountStatusesCount(accountID string) (int, error)
+ CountAccountStatuses(accountID string) (int, DBError)
// GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
// be very memory intensive so you probably shouldn't do this!
// In case of no entries, a 'no entries' error will be returned
- GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error)
+ GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, DBError)
- GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error)
+ GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, DBError)
// GetAccountLastStatus simply gets the most recent status by the given account.
// The given slice 'status' pointer will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
- GetAccountLastStatus(accountID string, status *gtsmodel.Status) error
+ GetAccountLastStatus(accountID string, status *gtsmodel.Status) DBError
// SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment.
- SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) error
+ SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) DBError
- // GetHeaderAvatarForAccountID gets the current avatar for the given account ID.
- // The passed mediaAttachment pointer will be populated with the value of the avatar, if it exists.
- GetAccountAvatar(avatar *gtsmodel.MediaAttachment, accountID string) error
-
- // GetAccountHeader gets the current header for the given account ID.
- // The passed mediaAttachment pointer will be populated with the value of the header, if it exists.
- GetAccountHeader(header *gtsmodel.MediaAttachment, accountID string) error
+ // GetInstanceAccount returns the instance account for the given domain.
+ // If domain is empty, this instance account will be returned.
+ GetInstanceAccount(domain string) (*gtsmodel.Account, DBError)
}
diff --git a/internal/db/admin.go b/internal/db/admin.go
index 84589a250..8e7e489d2 100644
--- a/internal/db/admin.go
+++ b/internal/db/admin.go
@@ -9,26 +9,26 @@ import (
type Admin interface {
// IsUsernameAvailable checks whether a given username is available on our domain.
// Returns an error if the username is already taken, or something went wrong in the db.
- IsUsernameAvailable(username string) error
+ IsUsernameAvailable(username string) DBError
// IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
// Return an error if:
// A) the email is already associated with an account
// B) we block signups from this email domain
// C) something went wrong in the db
- IsEmailAvailable(email string) error
+ IsEmailAvailable(email string) DBError
// NewSignup creates a new user in the database with the given parameters.
// By the time this function is called, it should be assumed that all the parameters have passed validation!
- NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, error)
+ NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, DBError)
// CreateInstanceAccount creates an account in the database with the same username as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
// This is needed for things like serving files that belong to the instance and not an individual user/account.
- CreateInstanceAccount() error
+ CreateInstanceAccount() DBError
// CreateInstanceInstance creates an instance in the database with the same domain as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'.
// This is needed for things like serving instance information through /api/v1/instance
- CreateInstanceInstance() error
+ CreateInstanceInstance() DBError
}
diff --git a/internal/db/basic.go b/internal/db/basic.go
index d63f7034b..5740149c2 100644
--- a/internal/db/basic.go
+++ b/internal/db/basic.go
@@ -5,60 +5,60 @@ import "context"
type Basic interface {
// CreateTable creates a table for the given interface.
// For implementations that don't use tables, this can just return nil.
- CreateTable(i interface{}) error
+ CreateTable(i interface{}) DBError
// DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil.
- DropTable(i interface{}) error
+ DropTable(i interface{}) DBError
// Stop should stop and close the database connection cleanly, returning an error if this is not possible.
// If the database implementation doesn't need to be stopped, this can just return nil.
- Stop(ctx context.Context) error
+ Stop(ctx context.Context) DBError
// IsHealthy should return nil if the database connection is healthy, or an error if not.
- IsHealthy(ctx context.Context) error
+ IsHealthy(ctx context.Context) DBError
// GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry,
// for other implementations (for example, in-memory) it might just be the key of a map.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
- GetByID(id string, i interface{}) error
+ GetByID(id string, i interface{}) DBError
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
// name of the key to select from.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
- GetWhere(where []Where, i interface{}) error
+ GetWhere(where []Where, i interface{}) DBError
// GetAll will try to get all entries of type i.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
- GetAll(i interface{}) error
+ GetAll(i interface{}) DBError
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- Put(i interface{}) error
+ Put(i interface{}) DBError
// Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/
// It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- Upsert(i interface{}, conflictColumn string) error
+ Upsert(i interface{}, conflictColumn string) DBError
// UpdateByID updates i with id id.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
- UpdateByID(id string, i interface{}) error
+ UpdateByID(id string, i interface{}) DBError
// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
- UpdateOneByID(id string, key string, value interface{}, i interface{}) error
+ UpdateOneByID(id string, key string, value interface{}, i interface{}) DBError
// UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
- UpdateWhere(where []Where, key string, value interface{}, i interface{}) error
+ UpdateWhere(where []Where, key string, value interface{}, i interface{}) DBError
// DeleteByID removes i with id id.
// If i didn't exist anyway, then no error should be returned.
- DeleteByID(id string, i interface{}) error
+ DeleteByID(id string, i interface{}) DBError
// DeleteWhere deletes i where key = value
// If i didn't exist anyway, then no error should be returned.
- DeleteWhere(where []Where, i interface{}) error
+ DeleteWhere(where []Where, i interface{}) DBError
}
diff --git a/internal/db/error.go b/internal/db/error.go
index 197c7bd68..d669f2513 100644
--- a/internal/db/error.go
+++ b/internal/db/error.go
@@ -18,16 +18,12 @@
package db
-// ErrNoEntries is to be returned from the DB interface when no entries are found for a given query.
-type ErrNoEntries struct{}
+import "fmt"
-func (e ErrNoEntries) Error() string {
- return "no entries"
-}
+type DBError error
-// ErrAlreadyExists is to be returned from the DB interface when an entry already exists for a given query or its constraints.
-type ErrAlreadyExists struct{}
-
-func (e ErrAlreadyExists) Error() string {
- return "already exists"
-}
+var (
+ ErrNoEntries DBError = fmt.Errorf("no entries")
+ ErrAlreadyExists DBError = fmt.Errorf("already exists")
+ ErrUnknown DBError = fmt.Errorf("unknown error")
+)
diff --git a/internal/db/instance.go b/internal/db/instance.go
index dccdb9793..152cc39b8 100644
--- a/internal/db/instance.go
+++ b/internal/db/instance.go
@@ -4,14 +4,14 @@ import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
type Instance interface {
// GetUserCountForInstance returns the number of known accounts registered with the given domain.
- GetUserCountForInstance(domain string) (int, error)
+ GetUserCountForInstance(domain string) (int, DBError)
// GetStatusCountForInstance returns the number of known statuses posted from the given domain.
- GetStatusCountForInstance(domain string) (int, error)
+ GetStatusCountForInstance(domain string) (int, DBError)
// GetDomainCountForInstance returns the number of known instances known that the given domain federates with.
- GetDomainCountForInstance(domain string) (int, error)
+ GetDomainCountForInstance(domain string) (int, DBError)
// GetAccountsForInstance returns a slice of accounts from the given instance, arranged by ID.
- GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error)
+ GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, DBError)
}
diff --git a/internal/db/notification.go b/internal/db/notification.go
index 6a3c0c29d..7360b4080 100644
--- a/internal/db/notification.go
+++ b/internal/db/notification.go
@@ -4,5 +4,5 @@ import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
type Notification interface {
// GetNotificationsForAccount returns a list of notifications that pertain to the given accountID.
- GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error)
+ GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, DBError)
}
diff --git a/internal/db/pg/account.go b/internal/db/pg/account.go
index 9ecaba486..9c1493fdd 100644
--- a/internal/db/pg/account.go
+++ b/internal/db/pg/account.go
@@ -1,62 +1,100 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
package pg
import (
+ "context"
"errors"
"fmt"
"github.com/go-pg/pg/v10"
+ "github.com/go-pg/pg/v10/orm"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (ps *postgresService) GetAccountHeader(header *gtsmodel.MediaAttachment, accountID string) error {
- acct := >smodel.Account{}
- if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
-
- if acct.HeaderMediaAttachmentID == "" {
- return db.ErrNoEntries{}
- }
-
- if err := ps.conn.Model(header).Where("id = ?", acct.HeaderMediaAttachmentID).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
- return nil
+type accountDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
}
-func (ps *postgresService) GetAccountAvatar(avatar *gtsmodel.MediaAttachment, accountID string) error {
- acct := >smodel.Account{}
- if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
-
- if acct.AvatarMediaAttachmentID == "" {
- return db.ErrNoEntries{}
- }
-
- if err := ps.conn.Model(avatar).Where("id = ?", acct.AvatarMediaAttachmentID).Select(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
- return nil
+func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query {
+ return a.conn.Model(account).
+ Relation("AvatarMediaAttachment").
+ Relation("HeaderMediaAttachment")
}
-func (ps *postgresService) GetAccountLastStatus(accountID string, status *gtsmodel.Status) error {
- if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
+func (a *accountDB) processResponse(account *gtsmodel.Account, err error) (*gtsmodel.Account, db.DBError) {
+ switch err {
+ case pg.ErrNoRows:
+ return nil, db.ErrNoEntries
+ case nil:
+ return account, nil
+ default:
+ return nil, err
+ }
+}
+
+func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.DBError) {
+ account := >smodel.Account{}
+
+ q := a.newAccountQ(account).
+ Where("account.id = ?", id)
+
+ return a.processResponse(account, q.Select())
+}
+
+func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.DBError) {
+ account := >smodel.Account{}
+
+ q := a.newAccountQ(account).
+ Where("account.uri = ?", uri)
+
+ return a.processResponse(account, q.Select())
+}
+
+func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.DBError) {
+ account := >smodel.Account{}
+
+ q := a.newAccountQ(account)
+
+ if domain == "" {
+ q = q.
+ Where("account.username = ?", domain).
+ Where("account.domain = ?", domain)
+ } else {
+ q = q.
+ Where("account.username = ?", domain).
+ Where("? IS NULL", pg.Ident("domain"))
+ }
+
+ return a.processResponse(account, q.Select())
+}
+
+func (a *accountDB) GetAccountLastStatus(accountID string, status *gtsmodel.Status) db.DBError {
+ if err := a.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
+ return db.ErrNoEntries
}
return err
}
@@ -64,7 +102,7 @@ func (ps *postgresService) GetAccountLastStatus(accountID string, status *gtsmod
}
-func (ps *postgresService) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) error {
+func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.DBError {
if mediaAttachment.Avatar && mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar")
}
@@ -79,47 +117,47 @@ func (ps *postgresService) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.Me
}
// TODO: there are probably more side effects here that need to be handled
- if _, err := ps.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
+ if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
return err
}
- if _, err := ps.conn.Model(>smodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
+ if _, err := a.conn.Model(>smodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
return err
}
return nil
}
-func (ps *postgresService) GetAccountByUserID(userID string, account *gtsmodel.Account) error {
+func (a *accountDB) GetAccountByUserID(userID string, account *gtsmodel.Account) db.DBError {
user := >smodel.User{
ID: userID,
}
- if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
+ if err := a.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
+ return db.ErrNoEntries
}
return err
}
- if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
+ if err := a.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
+ return db.ErrNoEntries
}
return err
}
return nil
}
-func (ps *postgresService) GetLocalAccountByUsername(username string, account *gtsmodel.Account) error {
- if err := ps.conn.Model(account).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select(); err != nil {
+func (a *accountDB) GetLocalAccountByUsername(username string, account *gtsmodel.Account) db.DBError {
+ if err := a.conn.Model(account).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select(); err != nil {
if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
+ return db.ErrNoEntries
}
return err
}
return nil
}
-func (ps *postgresService) GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) error {
- if err := ps.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil {
+func (a *accountDB) GetAccountFollowRequests(accountID string, followRequests *[]gtsmodel.FollowRequest) db.DBError {
+ if err := a.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
}
@@ -128,8 +166,8 @@ func (ps *postgresService) GetAccountFollowRequests(accountID string, followRequ
return nil
}
-func (ps *postgresService) GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) error {
- if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
+func (a *accountDB) GetAccountFollowing(accountID string, following *[]gtsmodel.Follow) db.DBError {
+ if err := a.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
}
@@ -138,9 +176,13 @@ func (ps *postgresService) GetAccountFollowing(accountID string, following *[]gt
return nil
}
-func (ps *postgresService) GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) error {
+func (a *accountDB) CountAccountFollowing(accountID string, localOnly bool) (int, db.DBError) {
+ return a.conn.Model(&[]*gtsmodel.Follow{}).Where("account_id = ?", accountID).Count()
+}
- q := ps.conn.Model(followers)
+func (a *accountDB) GetAccountFollowers(accountID string, followers *[]gtsmodel.Follow, localOnly bool) db.DBError {
+
+ q := a.conn.Model(followers)
if localOnly {
// for local accounts let's get where domain is null OR where domain is an empty string, just to be safe
@@ -168,8 +210,12 @@ func (ps *postgresService) GetAccountFollowers(accountID string, followers *[]gt
return nil
}
-func (ps *postgresService) GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) error {
- if err := ps.conn.Model(faves).Where("account_id = ?", accountID).Select(); err != nil {
+func (a *accountDB) CountAccountFollowers(accountID string, localOnly bool) (int, db.DBError) {
+ return a.conn.Model(&[]*gtsmodel.Follow{}).Where("target_account_id = ?", accountID).Count()
+}
+
+func (a *accountDB) GetAccountFaves(accountID string, faves *[]gtsmodel.StatusFave) db.DBError {
+ if err := a.conn.Model(faves).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
}
@@ -178,22 +224,15 @@ func (ps *postgresService) GetAccountFaves(accountID string, faves *[]gtsmodel.S
return nil
}
-func (ps *postgresService) GetAccountStatusesCount(accountID string) (int, error) {
- count, err := ps.conn.Model(>smodel.Status{}).Where("account_id = ?", accountID).Count()
- if err != nil {
- if err == pg.ErrNoRows {
- return 0, nil
- }
- return 0, err
- }
- return count, nil
+func (a *accountDB) CountAccountStatuses(accountID string) (int, db.DBError) {
+ return a.conn.Model(>smodel.Status{}).Where("account_id = ?", accountID).Count()
}
-func (ps *postgresService) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error) {
- ps.log.Debugf("getting statuses for account %s", accountID)
+func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.DBError) {
+ a.log.Debugf("getting statuses for account %s", accountID)
statuses := []*gtsmodel.Status{}
- q := ps.conn.Model(&statuses).Order("id DESC")
+ q := a.conn.Model(&statuses).Order("id DESC")
if accountID != "" {
q = q.Where("account_id = ?", accountID)
}
@@ -222,15 +261,57 @@ func (ps *postgresService) GetAccountStatuses(accountID string, limit int, exclu
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries{}
+ return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
- return nil, db.ErrNoEntries{}
+ return nil, db.ErrNoEntries
}
- ps.log.Debugf("returning statuses for account %s", accountID)
+ a.log.Debugf("returning statuses for account %s", accountID)
return statuses, nil
}
+
+func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.DBError) {
+ blocks := []*gtsmodel.Block{}
+
+ fq := a.conn.Model(&blocks).
+ Where("block.account_id = ?", accountID).
+ Relation("TargetAccount").
+ Order("block.id DESC")
+
+ if maxID != "" {
+ fq = fq.Where("block.id < ?", maxID)
+ }
+
+ if sinceID != "" {
+ fq = fq.Where("block.id > ?", sinceID)
+ }
+
+ if limit > 0 {
+ fq = fq.Limit(limit)
+ }
+
+ err := fq.Select()
+ if err != nil {
+ if err == pg.ErrNoRows {
+ return nil, "", "", db.ErrNoEntries
+ }
+ return nil, "", "", err
+ }
+
+ if len(blocks) == 0 {
+ return nil, "", "", db.ErrNoEntries
+ }
+
+ accounts := []*gtsmodel.Account{}
+ for _, b := range blocks {
+ accounts = append(accounts, b.TargetAccount)
+ }
+
+ nextMaxID := blocks[len(blocks)-1].ID
+ prevMinID := blocks[0].ID
+ return accounts, nextMaxID, prevMinID, nil
+}
diff --git a/internal/db/pg/admin.go b/internal/db/pg/admin.go
index 7bc614f14..fe3d48b54 100644
--- a/internal/db/pg/admin.go
+++ b/internal/db/pg/admin.go
@@ -1,6 +1,25 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
package pg
import (
+ "context"
"crypto/rand"
"crypto/rsa"
"fmt"
@@ -10,17 +29,27 @@ import (
"time"
"github.com/go-pg/pg/v10"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
"golang.org/x/crypto/bcrypt"
)
-func (ps *postgresService) IsUsernameAvailable(username string) error {
+type adminDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+}
+
+func (a *adminDB) IsUsernameAvailable(username string) db.DBError {
// if no error we fail because it means we found something
// if error but it's not pg.ErrNoRows then we fail
// if err is pg.ErrNoRows we're good, we found nothing so continue
- if err := ps.conn.Model(>smodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
+ if err := a.conn.Model(>smodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
return fmt.Errorf("username %s already in use", username)
} else if err != pg.ErrNoRows {
return fmt.Errorf("db error: %s", err)
@@ -28,7 +57,7 @@ func (ps *postgresService) IsUsernameAvailable(username string) error {
return nil
}
-func (ps *postgresService) IsEmailAvailable(email string) error {
+func (a *adminDB) IsEmailAvailable(email string) db.DBError {
// parse the domain from the email
m, err := mail.ParseAddress(email)
if err != nil {
@@ -37,7 +66,7 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked
- if err := ps.conn.Model(>smodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
+ if err := a.conn.Model(>smodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email domain %s is blocked", domain)
} else if err != pg.ErrNoRows {
@@ -46,7 +75,7 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
}
// check if this email is associated with a user already
- if err := ps.conn.Model(>smodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
+ if err := a.conn.Model(>smodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email %s already in use", email)
} else if err != pg.ErrNoRows {
@@ -56,16 +85,16 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
return nil
}
-func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, error) {
+func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.DBError) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
- ps.log.Errorf("error creating new rsa key: %s", err)
+ a.log.Errorf("error creating new rsa key: %s", err)
return nil, err
}
// if something went wrong while creating a user, we might already have an account, so check here first...
- a := >smodel.Account{}
- err = ps.conn.Model(a).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
+ acct := >smodel.Account{}
+ err = a.conn.Model(acct).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
if err != nil {
// there's been an actual error
if err != pg.ErrNoRows {
@@ -73,13 +102,13 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr
}
// we just don't have an account yet create one
- newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host)
+ newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
newAccountID, err := id.NewRandomULID()
if err != nil {
return nil, err
}
- a = >smodel.Account{
+ acct = >smodel.Account{
ID: newAccountID,
Username: username,
DisplayName: username,
@@ -96,7 +125,7 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
- if _, err = ps.conn.Model(a).Insert(); err != nil {
+ if _, err = a.conn.Model(acct).Insert(); err != nil {
return nil, err
}
}
@@ -113,7 +142,7 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr
u := >smodel.User{
ID: newUserID,
- AccountID: a.ID,
+ AccountID: acct.ID,
EncryptedPassword: string(pw),
SignUpIP: signUpIP.To4(),
Locale: locale,
@@ -132,18 +161,18 @@ func (ps *postgresService) NewSignup(username string, reason string, requireAppr
u.Moderator = true
}
- if _, err = ps.conn.Model(u).Insert(); err != nil {
+ if _, err = a.conn.Model(u).Insert(); err != nil {
return nil, err
}
return u, nil
}
-func (ps *postgresService) CreateInstanceAccount() error {
- username := ps.config.Host
+func (a *adminDB) CreateInstanceAccount() db.DBError {
+ username := a.config.Host
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
- ps.log.Errorf("error creating new rsa key: %s", err)
+ a.log.Errorf("error creating new rsa key: %s", err)
return err
}
@@ -152,10 +181,10 @@ func (ps *postgresService) CreateInstanceAccount() error {
return err
}
- newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host)
- a := >smodel.Account{
+ newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
+ acct := >smodel.Account{
ID: aID,
- Username: ps.config.Host,
+ Username: a.config.Host,
DisplayName: username,
URL: newAccountURIs.UserURL,
PrivateKey: key,
@@ -169,19 +198,19 @@ func (ps *postgresService) CreateInstanceAccount() error {
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
- inserted, err := ps.conn.Model(a).Where("username = ?", username).SelectOrInsert()
+ inserted, err := a.conn.Model(acct).Where("username = ?", username).SelectOrInsert()
if err != nil {
return err
}
if inserted {
- ps.log.Infof("created instance account %s with id %s", username, a.ID)
+ a.log.Infof("created instance account %s with id %s", username, acct.ID)
} else {
- ps.log.Infof("instance account %s already exists with id %s", username, a.ID)
+ a.log.Infof("instance account %s already exists with id %s", username, acct.ID)
}
return nil
}
-func (ps *postgresService) CreateInstanceInstance() error {
+func (a *adminDB) CreateInstanceInstance() db.DBError {
iID, err := id.NewRandomULID()
if err != nil {
return err
@@ -189,18 +218,18 @@ func (ps *postgresService) CreateInstanceInstance() error {
i := >smodel.Instance{
ID: iID,
- Domain: ps.config.Host,
- Title: ps.config.Host,
- URI: fmt.Sprintf("%s://%s", ps.config.Protocol, ps.config.Host),
+ Domain: a.config.Host,
+ Title: a.config.Host,
+ URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host),
}
- inserted, err := ps.conn.Model(i).Where("domain = ?", ps.config.Host).SelectOrInsert()
+ inserted, err := a.conn.Model(i).Where("domain = ?", a.config.Host).SelectOrInsert()
if err != nil {
return err
}
if inserted {
- ps.log.Infof("created instance instance %s with id %s", ps.config.Host, i.ID)
+ a.log.Infof("created instance instance %s with id %s", a.config.Host, i.ID)
} else {
- ps.log.Infof("instance instance %s already exists with id %s", ps.config.Host, i.ID)
+ a.log.Infof("instance instance %s already exists with id %s", a.config.Host, i.ID)
}
return nil
}
diff --git a/internal/db/pg/basic.go b/internal/db/pg/basic.go
index 36009120e..1debe3b74 100644
--- a/internal/db/pg/basic.go
+++ b/internal/db/pg/basic.go
@@ -1,25 +1,55 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
package pg
import (
+ "context"
"errors"
+ "fmt"
"strings"
"github.com/go-pg/pg/v10"
+ "github.com/go-pg/pg/v10/orm"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
-func (ps *postgresService) Put(i interface{}) error {
- _, err := ps.conn.Model(i).Insert(i)
+type basicDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+}
+
+func (b *basicDB) Put(i interface{}) db.DBError {
+ _, err := b.conn.Model(i).Insert(i)
if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
- return db.ErrAlreadyExists{}
+ return db.ErrAlreadyExists
}
return err
}
-func (ps *postgresService) GetByID(id string, i interface{}) error {
- if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil {
+func (b *basicDB) GetByID(id string, i interface{}) db.DBError {
+ if err := b.conn.Model(i).Where("id = ?", id).Select(); err != nil {
if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
+ return db.ErrNoEntries
}
return err
@@ -27,12 +57,12 @@ func (ps *postgresService) GetByID(id string, i interface{}) error {
return nil
}
-func (ps *postgresService) GetWhere(where []db.Where, i interface{}) error {
+func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.DBError {
if len(where) == 0 {
return errors.New("no queries provided")
}
- q := ps.conn.Model(i)
+ q := b.conn.Model(i)
for _, w := range where {
if w.Value == nil {
@@ -48,25 +78,25 @@ func (ps *postgresService) GetWhere(where []db.Where, i interface{}) error {
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
+ return db.ErrNoEntries
}
return err
}
return nil
}
-func (ps *postgresService) GetAll(i interface{}) error {
- if err := ps.conn.Model(i).Select(); err != nil {
+func (b *basicDB) GetAll(i interface{}) db.DBError {
+ if err := b.conn.Model(i).Select(); err != nil {
if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
+ return db.ErrNoEntries
}
return err
}
return nil
}
-func (ps *postgresService) DeleteByID(id string, i interface{}) error {
- if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
+func (b *basicDB) DeleteByID(id string, i interface{}) db.DBError {
+ if _, err := b.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
// if there are no rows *anyway* then that's fine
// just return err if there's an actual error
if err != pg.ErrNoRows {
@@ -76,12 +106,12 @@ func (ps *postgresService) DeleteByID(id string, i interface{}) error {
return nil
}
-func (ps *postgresService) DeleteWhere(where []db.Where, i interface{}) error {
+func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.DBError {
if len(where) == 0 {
return errors.New("no queries provided")
}
- q := ps.conn.Model(i)
+ q := b.conn.Model(i)
for _, w := range where {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
@@ -95,3 +125,76 @@ func (ps *postgresService) DeleteWhere(where []db.Where, i interface{}) error {
}
return nil
}
+
+func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.DBError {
+ if _, err := b.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
+ if err == pg.ErrNoRows {
+ return db.ErrNoEntries
+ }
+ return err
+ }
+ return nil
+}
+
+func (b *basicDB) UpdateByID(id string, i interface{}) db.DBError {
+ if _, err := b.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
+ if err == pg.ErrNoRows {
+ return db.ErrNoEntries
+ }
+ return err
+ }
+ return nil
+}
+
+func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.DBError {
+ _, err := b.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
+ return err
+}
+
+func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.DBError {
+ q := b.conn.Model(i)
+
+ for _, w := range where {
+ if w.Value == nil {
+ q = q.Where("? IS NULL", pg.Ident(w.Key))
+ } else {
+ if w.CaseInsensitive {
+ q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
+ } else {
+ q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
+ }
+ }
+ }
+
+ q = q.Set("? = ?", pg.Safe(key), value)
+
+ _, err := q.Update()
+
+ return err
+}
+
+func (b *basicDB) CreateTable(i interface{}) db.DBError {
+ return b.conn.Model(i).CreateTable(&orm.CreateTableOptions{
+ IfNotExists: true,
+ })
+}
+
+func (b *basicDB) DropTable(i interface{}) db.DBError {
+ return b.conn.Model(i).DropTable(&orm.DropTableOptions{
+ IfExists: true,
+ })
+}
+
+func (b *basicDB) IsHealthy(ctx context.Context) db.DBError {
+ return b.conn.Ping(ctx)
+}
+
+func (b *basicDB) Stop(ctx context.Context) db.DBError {
+ b.log.Info("closing db connection")
+ if err := b.conn.Close(); err != nil {
+ // only cancel if there's a problem closing the db
+ b.cancel()
+ return err
+ }
+ return nil
+}
diff --git a/internal/db/pg/instance.go b/internal/db/pg/instance.go
index c551b2a49..e2e92b23d 100644
--- a/internal/db/pg/instance.go
+++ b/internal/db/pg/instance.go
@@ -19,15 +19,26 @@
package pg
import (
+ "context"
+
"github.com/go-pg/pg/v10"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) {
- q := ps.conn.Model(&[]*gtsmodel.Account{})
+type instanceDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+}
- if domain == ps.config.Host {
+func (i *instanceDB) GetUserCountForInstance(domain string) (int, db.DBError) {
+ q := i.conn.Model(&[]*gtsmodel.Account{})
+
+ if domain == i.config.Host {
// if the domain is *this* domain, just count where the domain field is null
q = q.Where("? IS NULL", pg.Ident("domain"))
} else {
@@ -40,10 +51,10 @@ func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) {
return q.Count()
}
-func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error) {
- q := ps.conn.Model(&[]*gtsmodel.Status{})
+func (i *instanceDB) GetStatusCountForInstance(domain string) (int, db.DBError) {
+ q := i.conn.Model(&[]*gtsmodel.Status{})
- if domain == ps.config.Host {
+ if domain == i.config.Host {
// if the domain is *this* domain, just count where local is true
q = q.Where("local = ?", true)
} else {
@@ -55,10 +66,10 @@ func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error)
return q.Count()
}
-func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error) {
- q := ps.conn.Model(&[]*gtsmodel.Instance{})
+func (i *instanceDB) GetDomainCountForInstance(domain string) (int, db.DBError) {
+ q := i.conn.Model(&[]*gtsmodel.Instance{})
- if domain == ps.config.Host {
+ if domain == i.config.Host {
// if the domain is *this* domain, just count other instances it knows about
// exclude domains that are blocked
q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at"))
@@ -70,12 +81,12 @@ func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error)
return q.Count()
}
-func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error) {
- ps.log.Debug("GetAccountsForInstance")
+func (i *instanceDB) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.DBError) {
+ i.log.Debug("GetAccountsForInstance")
accounts := []*gtsmodel.Account{}
- q := ps.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
+ q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
if maxID != "" {
q = q.Where("id < ?", maxID)
@@ -88,13 +99,13 @@ func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, l
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries{}
+ return nil, db.ErrNoEntries
}
return nil, err
}
if len(accounts) == 0 {
- return nil, db.ErrNoEntries{}
+ return nil, db.ErrNoEntries
}
return accounts, nil
diff --git a/internal/db/pg/blocks.go b/internal/db/pg/notification.go
similarity index 53%
rename from internal/db/pg/blocks.go
rename to internal/db/pg/notification.go
index beada4e88..84359a981 100644
--- a/internal/db/pg/blocks.go
+++ b/internal/db/pg/notification.go
@@ -24,44 +24,30 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (ps *postgresService) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) {
- blocks := []*gtsmodel.Block{}
+func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.DBError) {
+ notifications := []*gtsmodel.Notification{}
- fq := ps.conn.Model(&blocks).
- Where("block.account_id = ?", accountID).
- Relation("TargetAccount").
- Order("block.id DESC")
+ q := ps.conn.Model(¬ifications).Where("target_account_id = ?", accountID)
if maxID != "" {
- fq = fq.Where("block.id < ?", maxID)
+ q = q.Where("id < ?", maxID)
}
if sinceID != "" {
- fq = fq.Where("block.id > ?", sinceID)
+ q = q.Where("id > ?", sinceID)
}
- if limit > 0 {
- fq = fq.Limit(limit)
+ if limit != 0 {
+ q = q.Limit(limit)
}
- err := fq.Select()
- if err != nil {
- if err == pg.ErrNoRows {
- return nil, "", "", db.ErrNoEntries{}
+ q = q.Order("created_at DESC")
+
+ if err := q.Select(); err != nil {
+ if err != pg.ErrNoRows {
+ return nil, err
}
- return nil, "", "", err
- }
- if len(blocks) == 0 {
- return nil, "", "", db.ErrNoEntries{}
}
-
- accounts := []*gtsmodel.Account{}
- for _, b := range blocks {
- accounts = append(accounts, b.TargetAccount)
- }
-
- nextMaxID := blocks[len(blocks)-1].ID
- prevMinID := blocks[0].ID
- return accounts, nextMaxID, prevMinID, nil
+ return notifications, nil
}
diff --git a/internal/db/pg/pg.go b/internal/db/pg/pg.go
index 0c3f7310c..a0a38038e 100644
--- a/internal/db/pg/pg.go
+++ b/internal/db/pg/pg.go
@@ -31,7 +31,6 @@ import (
"github.com/go-pg/pg/extra/pgdebug"
"github.com/go-pg/pg/v10"
- "github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
@@ -41,6 +40,14 @@ import (
// postgresService satisfies the DB interface
type postgresService struct {
+ db.Account
+ db.Admin
+ db.Basic
+ db.Instance
+ db.Notification
+ db.Relationship
+ db.Status
+ db.Timeline
config *config.Config
conn *pg.DB
log *logrus.Logger
@@ -85,6 +92,48 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
log.Infof("connected to postgres version: %s", version)
ps := &postgresService{
+ Account: &accountDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
+ Admin: &adminDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
+ Basic: &basicDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
+ Instance: &instanceDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
+ Relationship: &relationshipDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
+ Status: &statusDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
+ Timeline: &timelineDB{
+ config: c,
+ conn: conn,
+ log: log,
+ cancel: cancel,
+ },
config: c,
conn: conn,
log: log,
@@ -193,89 +242,6 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
return options, nil
}
-/*
- BASIC DB FUNCTIONALITY
-*/
-
-func (ps *postgresService) CreateTable(i interface{}) error {
- return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
- IfNotExists: true,
- })
-}
-
-func (ps *postgresService) DropTable(i interface{}) error {
- return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
- IfExists: true,
- })
-}
-
-func (ps *postgresService) Stop(ctx context.Context) error {
- ps.log.Info("closing db connection")
- if err := ps.conn.Close(); err != nil {
- // only cancel if there's a problem closing the db
- ps.cancel()
- return err
- }
- return nil
-}
-
-func (ps *postgresService) IsHealthy(ctx context.Context) error {
- return ps.conn.Ping(ctx)
-}
-
-func (ps *postgresService) CreateSchema(ctx context.Context) error {
- models := []interface{}{
- (*gtsmodel.Account)(nil),
- (*gtsmodel.Status)(nil),
- (*gtsmodel.User)(nil),
- }
- ps.log.Info("creating db schema")
-
- for _, model := range models {
- err := ps.conn.Model(model).CreateTable(&orm.CreateTableOptions{
- IfNotExists: true,
- })
- if err != nil {
- return err
- }
- }
-
- ps.log.Info("db schema created")
- return nil
-}
-
-/*
- HANDY SHORTCUTS
-*/
-
-func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error) {
- notifications := []*gtsmodel.Notification{}
-
- q := ps.conn.Model(¬ifications).Where("target_account_id = ?", accountID)
-
- if maxID != "" {
- q = q.Where("id < ?", maxID)
- }
-
- if sinceID != "" {
- q = q.Where("id > ?", sinceID)
- }
-
- if limit != 0 {
- q = q.Limit(limit)
- }
-
- q = q.Order("created_at DESC")
-
- if err := q.Select(); err != nil {
- if err != pg.ErrNoRows {
- return nil, err
- }
-
- }
- return notifications, nil
-}
-
/*
CONVERSION FUNCTIONS
*/
diff --git a/internal/db/pg/pg_test.go b/internal/db/pg/pg_test.go
new file mode 100644
index 000000000..c1e10abdf
--- /dev/null
+++ b/internal/db/pg/pg_test.go
@@ -0,0 +1,47 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package pg_test
+
+import (
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/oauth"
+)
+
+type PGStandardTestSuite struct {
+ // standard suite interfaces
+ suite.Suite
+ config *config.Config
+ db db.DB
+ log *logrus.Logger
+
+ // standard suite models
+ testTokens map[string]*oauth.Token
+ testClients map[string]*oauth.Client
+ testApplications map[string]*gtsmodel.Application
+ testUsers map[string]*gtsmodel.User
+ testAccounts map[string]*gtsmodel.Account
+ testAttachments map[string]*gtsmodel.MediaAttachment
+ testStatuses map[string]*gtsmodel.Status
+ testTags map[string]*gtsmodel.Tag
+ testMentions map[string]*gtsmodel.Mention
+}
diff --git a/internal/db/pg/relationship.go b/internal/db/pg/relationship.go
index ac628b87d..1835ca7a6 100644
--- a/internal/db/pg/relationship.go
+++ b/internal/db/pg/relationship.go
@@ -1,17 +1,45 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
package pg
import (
+ "context"
"fmt"
"github.com/go-pg/pg/v10"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (ps *postgresService) Blocked(account1 string, account2 string) (bool, error) {
+type relationshipDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+}
+
+func (r *relationshipDB) Blocked(account1 string, account2 string) (bool, db.DBError) {
// TODO: check domain blocks as well
var blocked bool
- if err := ps.conn.Model(>smodel.Block{}).
+ if err := r.conn.Model(>smodel.Block{}).
Where("account_id = ?", account1).Where("target_account_id = ?", account2).
WhereOr("target_account_id = ?", account1).Where("account_id = ?", account2).
Select(); err != nil {
@@ -25,83 +53,83 @@ func (ps *postgresService) Blocked(account1 string, account2 string) (bool, erro
return blocked, nil
}
-func (ps *postgresService) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) {
- r := >smodel.Relationship{
+func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.DBError) {
+ rel := >smodel.Relationship{
ID: targetAccount,
}
// check if the requesting account follows the target account
follow := >smodel.Follow{}
- if err := ps.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
+ if err := r.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
if err != pg.ErrNoRows {
// a proper error
return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
}
// no follow exists so these are all false
- r.Following = false
- r.ShowingReblogs = false
- r.Notifying = false
+ rel.Following = false
+ rel.ShowingReblogs = false
+ rel.Notifying = false
} else {
// follow exists so we can fill these fields out...
- r.Following = true
- r.ShowingReblogs = follow.ShowReblogs
- r.Notifying = follow.Notify
+ rel.Following = true
+ rel.ShowingReblogs = follow.ShowReblogs
+ rel.Notifying = follow.Notify
}
// check if the target account follows the requesting account
- followedBy, err := ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
+ followedBy, err := r.conn.Model(>smodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
}
- r.FollowedBy = followedBy
+ rel.FollowedBy = followedBy
// check if the requesting account blocks the target account
- blocking, err := ps.conn.Model(>smodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
+ blocking, err := r.conn.Model(>smodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
}
- r.Blocking = blocking
+ rel.Blocking = blocking
// check if the target account blocks the requesting account
- blockedBy, err := ps.conn.Model(>smodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
+ blockedBy, err := r.conn.Model(>smodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
- r.BlockedBy = blockedBy
+ rel.BlockedBy = blockedBy
// check if there's a pending following request from requesting account to target account
- requested, err := ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
+ requested, err := r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
- r.Requested = requested
+ rel.Requested = requested
- return r, nil
+ return rel, nil
}
-func (ps *postgresService) Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) {
+func (r *relationshipDB) Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.DBError) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
- return ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
+ return r.conn.Model(>smodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
}
-func (ps *postgresService) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) {
+func (r *relationshipDB) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.DBError) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
- return ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
+ return r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
}
-func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error) {
+func (r *relationshipDB) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.DBError) {
if account1 == nil || account2 == nil {
return false, nil
}
// make sure account 1 follows account 2
- f1, err := ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists()
+ f1, err := r.conn.Model(>smodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists()
if err != nil {
if err == pg.ErrNoRows {
return false, nil
@@ -110,7 +138,7 @@ func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmode
}
// make sure account 2 follows account 1
- f2, err := ps.conn.Model(>smodel.Follow{}).Where("account_id = ?", account2.ID).Where("target_account_id = ?", account1.ID).Exists()
+ f2, err := r.conn.Model(>smodel.Follow{}).Where("account_id = ?", account2.ID).Where("target_account_id = ?", account1.ID).Exists()
if err != nil {
if err == pg.ErrNoRows {
return false, nil
@@ -121,12 +149,12 @@ func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmode
return f1 && f2, nil
}
-func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
+func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.DBError) {
// make sure the original follow request exists
fr := >smodel.FollowRequest{}
- if err := ps.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
+ if err := r.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
if err == pg.ErrMultiRows {
- return nil, db.ErrNoEntries{}
+ return nil, db.ErrNoEntries
}
return nil, err
}
@@ -140,12 +168,12 @@ func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAcc
}
// if the follow already exists, just update the URI -- we don't need to do anything else
- if _, err := ps.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
+ if _, err := r.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
return nil, err
}
// now remove the follow request
- if _, err := ps.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
+ if _, err := r.conn.Model(>smodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
return nil, err
}
diff --git a/internal/db/pg/status.go b/internal/db/pg/status.go
index ab9243a90..450999b9a 100644
--- a/internal/db/pg/status.go
+++ b/internal/db/pg/status.go
@@ -20,39 +20,90 @@ package pg
import (
"container/list"
+ "context"
"errors"
"github.com/go-pg/pg/v10"
+ "github.com/go-pg/pg/v10/orm"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (ps *postgresService) StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) {
+type statusDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+}
+
+func (s *statusDB) newStatusQ(status *gtsmodel.Status) *orm.Query {
+ return s.conn.Model(status).
+ Relation("Account").
+ Relation("InReplyTo").
+ Relation("InReplyToAccount").
+ Relation("BoostOf").
+ Relation("BoostOfAccount").
+ Relation("CreatedWithApplication")
+}
+
+func (s *statusDB) processResponse(status *gtsmodel.Status, err error) (*gtsmodel.Status, db.DBError) {
+ switch err {
+ case pg.ErrNoRows:
+ return nil, db.ErrNoEntries
+ case nil:
+ return status, nil
+ default:
+ return nil, err
+ }
+}
+
+func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.DBError) {
+ status := >smodel.Status{}
+
+ q := s.newStatusQ(status).
+ Where("status.id = ?", id)
+
+ return s.processResponse(status, q.Select())
+}
+
+func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.DBError) {
+ status := >smodel.Status{}
+
+ q := s.newStatusQ(status).
+ Where("LOWER(status.uri) = LOWER(?)", uri)
+
+ return s.processResponse(status, q.Select())
+}
+
+func (s *statusDB) StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.DBError) {
parents := []*gtsmodel.Status{}
- ps.statusParent(status, &parents, onlyDirect)
+ s.statusParent(status, &parents, onlyDirect)
return parents, nil
}
-func (ps *postgresService) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
+func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
if status.InReplyToID == "" {
return
}
parentStatus := >smodel.Status{}
- if err := ps.conn.Model(parentStatus).Where("id = ?", status.InReplyToID).Select(); err == nil {
+ if err := s.conn.Model(parentStatus).Where("id = ?", status.InReplyToID).Select(); err == nil {
*foundStatuses = append(*foundStatuses, parentStatus)
}
if onlyDirect {
return
}
- ps.statusParent(parentStatus, foundStatuses, false)
+ s.statusParent(parentStatus, foundStatuses, false)
}
-func (ps *postgresService) StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) {
+func (s *statusDB) StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.DBError) {
foundStatuses := &list.List{}
foundStatuses.PushFront(status)
- ps.statusChildren(status, foundStatuses, onlyDirect, minID)
+ s.statusChildren(status, foundStatuses, onlyDirect, minID)
children := []*gtsmodel.Status{}
for e := foundStatuses.Front(); e != nil; e = e.Next() {
@@ -70,10 +121,10 @@ func (ps *postgresService) StatusChildren(status *gtsmodel.Status, onlyDirect bo
return children, nil
}
-func (ps *postgresService) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
+func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
immediateChildren := []*gtsmodel.Status{}
- q := ps.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
+ q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
if minID != "" {
q = q.Where("status.id > ?", minID)
}
@@ -100,43 +151,43 @@ func (ps *postgresService) statusChildren(status *gtsmodel.Status, foundStatuses
if onlyDirect {
return
}
- ps.statusChildren(child, foundStatuses, false, minID)
+ s.statusChildren(child, foundStatuses, false, minID)
}
}
-func (ps *postgresService) GetReplyCountForStatus(status *gtsmodel.Status) (int, error) {
- return ps.conn.Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
+func (s *statusDB) GetReplyCountForStatus(status *gtsmodel.Status) (int, db.DBError) {
+ return s.conn.Model(>smodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
}
-func (ps *postgresService) GetReblogCountForStatus(status *gtsmodel.Status) (int, error) {
- return ps.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
+func (s *statusDB) GetReblogCountForStatus(status *gtsmodel.Status) (int, db.DBError) {
+ return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
}
-func (ps *postgresService) GetFaveCountForStatus(status *gtsmodel.Status) (int, error) {
- return ps.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
+func (s *statusDB) GetFaveCountForStatus(status *gtsmodel.Status) (int, db.DBError) {
+ return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
}
-func (ps *postgresService) StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, error) {
- return ps.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
+func (s *statusDB) StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
+ return s.conn.Model(>smodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
-func (ps *postgresService) StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, error) {
- return ps.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
+func (s *statusDB) StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
+ return s.conn.Model(>smodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
-func (ps *postgresService) StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, error) {
- return ps.conn.Model(>smodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
+func (s *statusDB) StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
+ return s.conn.Model(>smodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
-func (ps *postgresService) StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, error) {
- return ps.conn.Model(>smodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
+func (s *statusDB) StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.DBError) {
+ return s.conn.Model(>smodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
-func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) {
+func (s *statusDB) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, db.DBError) {
accounts := []*gtsmodel.Account{}
faves := []*gtsmodel.StatusFave{}
- if err := ps.conn.Model(&faves).Where("status_id = ?", status.ID).Select(); err != nil {
+ if err := s.conn.Model(&faves).Where("status_id = ?", status.ID).Select(); err != nil {
if err == pg.ErrNoRows {
return accounts, nil // no rows just means nobody has faved this status, so that's fine
}
@@ -145,7 +196,7 @@ func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.
for _, f := range faves {
acc := >smodel.Account{}
- if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
+ if err := s.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
if err == pg.ErrNoRows {
continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it
}
@@ -156,11 +207,11 @@ func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.
return accounts, nil
}
-func (ps *postgresService) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) {
+func (s *statusDB) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, db.DBError) {
accounts := []*gtsmodel.Account{}
boosts := []*gtsmodel.Status{}
- if err := ps.conn.Model(&boosts).Where("boost_of_id = ?", status.ID).Select(); err != nil {
+ if err := s.conn.Model(&boosts).Where("boost_of_id = ?", status.ID).Select(); err != nil {
if err == pg.ErrNoRows {
return accounts, nil // no rows just means nobody has boosted this status, so that's fine
}
@@ -169,7 +220,7 @@ func (ps *postgresService) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmode
for _, f := range boosts {
acc := >smodel.Account{}
- if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
+ if err := s.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
if err == pg.ErrNoRows {
continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it
}
diff --git a/internal/db/pg/status_test.go b/internal/db/pg/status_test.go
new file mode 100644
index 000000000..a412eb55f
--- /dev/null
+++ b/internal/db/pg/status_test.go
@@ -0,0 +1,86 @@
+/*
+ GoToSocial
+ Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package pg_test
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type StatusTestSuite struct {
+ PGStandardTestSuite
+}
+
+func (suite *PGStandardTestSuite) SetupSuite() {
+ suite.testTokens = testrig.NewTestTokens()
+ suite.testClients = testrig.NewTestClients()
+ suite.testApplications = testrig.NewTestApplications()
+ suite.testUsers = testrig.NewTestUsers()
+ suite.testAccounts = testrig.NewTestAccounts()
+ suite.testAttachments = testrig.NewTestAttachments()
+ suite.testStatuses = testrig.NewTestStatuses()
+ suite.testTags = testrig.NewTestTags()
+ suite.testMentions = testrig.NewTestMentions()
+}
+
+func (suite *PGStandardTestSuite) SetupTest() {
+ suite.config = testrig.NewTestConfig()
+ suite.db = testrig.NewTestDB()
+ suite.log = testrig.NewTestLog()
+
+ testrig.StandardDBSetup(suite.db, suite.testAccounts)
+}
+
+func (suite *PGStandardTestSuite) TearDownTest() {
+ testrig.StandardDBTeardown(suite.db)
+}
+
+func (suite *PGStandardTestSuite) TestGetStatusByID() {
+ status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(status)
+ suite.NotNil(status.Account)
+ suite.NotNil(status.CreatedWithApplication)
+ suite.Nil(status.BoostOf)
+ suite.Nil(status.BoostOfAccount)
+ suite.Nil(status.InReplyTo)
+ suite.Nil(status.InReplyToAccount)
+}
+
+func (suite *PGStandardTestSuite) TestGetStatusByURI() {
+ status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(status)
+ suite.NotNil(status.Account)
+ suite.NotNil(status.CreatedWithApplication)
+ suite.Nil(status.BoostOf)
+ suite.Nil(status.BoostOfAccount)
+ suite.Nil(status.InReplyTo)
+ suite.Nil(status.InReplyToAccount)
+}
+
+func TestStatusTestSuite(t *testing.T) {
+ suite.Run(t, new(PGStandardTestSuite))
+}
diff --git a/internal/db/pg/timeline.go b/internal/db/pg/timeline.go
index 585ca3067..87e306da8 100644
--- a/internal/db/pg/timeline.go
+++ b/internal/db/pg/timeline.go
@@ -19,16 +19,26 @@
package pg
import (
+ "context"
"sort"
"github.com/go-pg/pg/v10"
+ "github.com/sirupsen/logrus"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
+type timelineDB struct {
+ config *config.Config
+ conn *pg.DB
+ log *logrus.Logger
+ cancel context.CancelFunc
+}
+
+func (t *timelineDB) GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.DBError) {
statuses := []*gtsmodel.Status{}
- q := ps.conn.Model(&statuses)
+ q := t.conn.Model(&statuses)
q = q.ColumnExpr("status.*").
// Find out who accountID follows.
@@ -74,22 +84,22 @@ func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID str
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries{}
+ return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
- return nil, db.ErrNoEntries{}
+ return nil, db.ErrNoEntries
}
return statuses, nil
}
-func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
+func (t *timelineDB) GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.DBError) {
statuses := []*gtsmodel.Status{}
- q := ps.conn.Model(&statuses).
+ q := t.conn.Model(&statuses).
Where("visibility = ?", gtsmodel.VisibilityPublic).
Where("? IS NULL", pg.Ident("in_reply_to_id")).
Where("? IS NULL", pg.Ident("in_reply_to_uri")).
@@ -119,13 +129,13 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
- return nil, db.ErrNoEntries{}
+ return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
- return nil, db.ErrNoEntries{}
+ return nil, db.ErrNoEntries
}
return statuses, nil
@@ -133,11 +143,11 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
-func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error) {
+func (t *timelineDB) GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.DBError) {
faves := []*gtsmodel.StatusFave{}
- fq := ps.conn.Model(&faves).
+ fq := t.conn.Model(&faves).
Where("account_id = ?", accountID).
Order("id DESC")
@@ -156,13 +166,13 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st
err := fq.Select()
if err != nil {
if err == pg.ErrNoRows {
- return nil, "", "", db.ErrNoEntries{}
+ return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
}
if len(faves) == 0 {
- return nil, "", "", db.ErrNoEntries{}
+ return nil, "", "", db.ErrNoEntries
}
// map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID
@@ -175,16 +185,16 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st
}
statuses := []*gtsmodel.Status{}
- err = ps.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
+ err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
if err != nil {
if err == pg.ErrNoRows {
- return nil, "", "", db.ErrNoEntries{}
+ return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
}
if len(statuses) == 0 {
- return nil, "", "", db.ErrNoEntries{}
+ return nil, "", "", db.ErrNoEntries
}
// arrange statuses by fave ID
diff --git a/internal/db/pg/update.go b/internal/db/pg/update.go
deleted file mode 100644
index f6bc70ad9..000000000
--- a/internal/db/pg/update.go
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- GoToSocial
- Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
-
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
-*/
-
-package pg
-
-import (
- "fmt"
-
- "github.com/go-pg/pg/v10"
- "github.com/superseriousbusiness/gotosocial/internal/db"
-)
-
-func (ps *postgresService) Upsert(i interface{}, conflictColumn string) error {
- if _, err := ps.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
- return nil
-}
-
-func (ps *postgresService) UpdateByID(id string, i interface{}) error {
- if _, err := ps.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
- if err == pg.ErrNoRows {
- return db.ErrNoEntries{}
- }
- return err
- }
- return nil
-}
-
-func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error {
- _, err := ps.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
- return err
-}
-
-func (ps *postgresService) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) error {
- q := ps.conn.Model(i)
-
- for _, w := range where {
- if w.Value == nil {
- q = q.Where("? IS NULL", pg.Ident(w.Key))
- } else {
- if w.CaseInsensitive {
- q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
- } else {
- q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
- }
- }
- }
-
- q = q.Set("? = ?", pg.Safe(key), value)
-
- _, err := q.Update()
-
- return err
-}
diff --git a/internal/db/relationship.go b/internal/db/relationship.go
index 1fa532fbb..28c891b99 100644
--- a/internal/db/relationship.go
+++ b/internal/db/relationship.go
@@ -5,23 +5,23 @@ import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
type Relationship interface {
// Blocked checks whether a block exists in eiher direction between two accounts.
// That is, it returns true if account1 blocks account2, OR if account2 blocks account1.
- Blocked(account1 string, account2 string) (bool, error)
+ Blocked(account1 string, account2 string) (bool, DBError)
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
- GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error)
+ GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, DBError)
// Follows returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
- Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error)
+ Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, DBError)
// FollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
- FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error)
+ FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, DBError)
// Mutuals returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
- Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error)
+ Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, DBError)
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
// In other words, it should create the follow, and delete the existing follow request.
//
// It will return the newly created follow for further processing.
- AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error)
+ AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, DBError)
}
diff --git a/internal/db/status.go b/internal/db/status.go
index 94db51d0a..91ee54a9c 100644
--- a/internal/db/status.go
+++ b/internal/db/status.go
@@ -3,42 +3,48 @@ package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
type Status interface {
+ // GetStatusByID returns one status from the database, with all rel fields populated (if possible).
+ GetStatusByID(id string) (*gtsmodel.Status, DBError)
+
+ // GetStatusByURI returns one status from the database, with all rel fields populated (if possible).
+ GetStatusByURI(uri string) (*gtsmodel.Status, DBError)
+
// GetReplyCountForStatus returns the amount of replies recorded for a status, or an error if something goes wrong
- GetReplyCountForStatus(status *gtsmodel.Status) (int, error)
+ GetReplyCountForStatus(status *gtsmodel.Status) (int, DBError)
// GetReblogCountForStatus returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong
- GetReblogCountForStatus(status *gtsmodel.Status) (int, error)
+ GetReblogCountForStatus(status *gtsmodel.Status) (int, DBError)
// GetFaveCountForStatus returns the amount of faves/likes recorded for a status, or an error if something goes wrong
- GetFaveCountForStatus(status *gtsmodel.Status) (int, error)
+ GetFaveCountForStatus(status *gtsmodel.Status) (int, DBError)
// StatusParents get the parent statuses of a given status.
//
// If onlyDirect is true, only the immediate parent will be returned.
- StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error)
+ StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, DBError)
// StatusChildren gets the child statuses of a given status.
//
// If onlyDirect is true, only the immediate children will be returned.
- StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error)
+ StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, DBError)
// StatusFavedBy checks if a given status has been faved by a given account ID
- StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, error)
+ StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, DBError)
// StatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID
- StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, error)
+ StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, DBError)
// StatusMutedBy checks if a given status has been muted by a given account ID
- StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, error)
+ StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, DBError)
// StatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
- StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, error)
+ StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, DBError)
// WhoFavedStatus returns a slice of accounts who faved the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
- WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error)
+ WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, DBError)
// WhoBoostedStatus returns a slice of accounts who boosted the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
- WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error)
+ WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, DBError)
}
diff --git a/internal/db/timeline.go b/internal/db/timeline.go
index 3ef34d0ea..737a78a31 100644
--- a/internal/db/timeline.go
+++ b/internal/db/timeline.go
@@ -6,13 +6,13 @@ type Timeline interface {
// GetHomeTimelineForAccount returns a slice of statuses from accounts that are followed by the given account id.
//
// Statuses should be returned in descending order of when they were created (newest first).
- GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
+ GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, DBError)
// GetPublicTimelineForAccount fetches the account's PUBLIC timeline -- ie., posts and replies that are public.
// It will use the given filters and try to return as many statuses as possible up to the limit.
//
// Statuses should be returned in descending order of when they were created (newest first).
- GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
+ GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, DBError)
// GetFavedTimelineForAccount fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved.
// It will use the given filters and try to return as many statuses as possible up to the limit.
@@ -21,5 +21,5 @@ type Timeline interface {
// In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created.
//
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
- GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error)
+ GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, DBError)
}
diff --git a/internal/federation/dereferencing/blocked.go b/internal/federation/dereferencing/blocked.go
index a66afbb60..c8a4c6ade 100644
--- a/internal/federation/dereferencing/blocked.go
+++ b/internal/federation/dereferencing/blocked.go
@@ -31,7 +31,7 @@ func (d *deref) blockedDomain(host string) (bool, error) {
return true, nil
}
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// there are no entries so there's no block
return false, nil
}
diff --git a/internal/federation/dereferencing/status.go b/internal/federation/dereferencing/status.go
index 9d5a0c7ca..c2d293329 100644
--- a/internal/federation/dereferencing/status.go
+++ b/internal/federation/dereferencing/status.go
@@ -288,7 +288,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
attachmentIDs = append(attachmentIDs, maybeAttachment.ID)
continue
}
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// we have a real error
return fmt.Errorf("error checking db for existence of attachment with remote url %s: %s", a.RemoteURL, err)
}
diff --git a/internal/federation/federatingdb/create.go b/internal/federation/federatingdb/create.go
index 2ac4890e8..d5287136a 100644
--- a/internal/federation/federatingdb/create.go
+++ b/internal/federation/federatingdb/create.go
@@ -113,7 +113,7 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error {
status.ID = statusID
if err := f.db.Put(status); err != nil {
- if _, ok := err.(db.ErrAlreadyExists); ok {
+ if err == db.ErrAlreadyExists {
// the status already exists in the database, which means we've already handled everything else,
// so we can just return nil here and be done with it.
return nil
diff --git a/internal/federation/federatingdb/outbox.go b/internal/federation/federatingdb/outbox.go
index 1568e0017..7985a31d0 100644
--- a/internal/federation/federatingdb/outbox.go
+++ b/internal/federation/federatingdb/outbox.go
@@ -62,7 +62,7 @@ func (f *federatingDB) OutboxForInbox(c context.Context, inboxIRI *url.URL) (out
}
acct := >smodel.Account{}
if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String())
}
return nil, fmt.Errorf("db error searching for actor with inbox %s", inboxIRI.String())
diff --git a/internal/federation/federatingdb/owns.go b/internal/federation/federatingdb/owns.go
index 51b20151a..3cb9603cd 100644
--- a/internal/federation/federatingdb/owns.go
+++ b/internal/federation/federatingdb/owns.go
@@ -55,7 +55,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: uid}}, >smodel.Status{}); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// there are no entries for this status
return false, nil
}
@@ -72,7 +72,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
if err := f.db.GetLocalAccountByUsername(username, >smodel.Account{}); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
}
@@ -89,7 +89,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
if err := f.db.GetLocalAccountByUsername(username, >smodel.Account{}); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
}
@@ -106,7 +106,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
if err := f.db.GetLocalAccountByUsername(username, >smodel.Account{}); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
}
@@ -123,7 +123,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("error parsing like path for url %s: %s", id.String(), err)
}
if err := f.db.GetLocalAccountByUsername(username, >smodel.Account{}); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
}
@@ -131,7 +131,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)
}
if err := f.db.GetByID(likeID, >smodel.StatusFave{}); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// there are no entries
return false, nil
}
@@ -148,7 +148,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("error parsing block path for url %s: %s", id.String(), err)
}
if err := f.db.GetLocalAccountByUsername(username, >smodel.Account{}); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
}
@@ -156,7 +156,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)
}
if err := f.db.GetByID(blockID, >smodel.Block{}); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// there are no entries
return false, nil
}
diff --git a/internal/federation/federatingdb/util.go b/internal/federation/federatingdb/util.go
index 28f4c5a21..9a3e60cf1 100644
--- a/internal/federation/federatingdb/util.go
+++ b/internal/federation/federatingdb/util.go
@@ -213,7 +213,7 @@ func (f *federatingDB) ActorForOutbox(c context.Context, outboxIRI *url.URL) (ac
}
acct := >smodel.Account{}
if err := f.db.GetWhere([]db.Where{{Key: "outbox_uri", Value: outboxIRI.String()}}, acct); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to outbox %s", outboxIRI.String())
}
return nil, fmt.Errorf("db error searching for actor with outbox %s", outboxIRI.String())
@@ -238,7 +238,7 @@ func (f *federatingDB) ActorForInbox(c context.Context, inboxIRI *url.URL) (acto
}
acct := >smodel.Account{}
if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String())
}
return nil, fmt.Errorf("db error searching for actor with inbox %s", inboxIRI.String())
diff --git a/internal/federation/federatingprotocol.go b/internal/federation/federatingprotocol.go
index 9e21b43bf..be988cda2 100644
--- a/internal/federation/federatingprotocol.go
+++ b/internal/federation/federatingprotocol.go
@@ -132,7 +132,7 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr
// authentication has passed, so add an instance entry for this instance if it hasn't been done already
i := >smodel.Instance{}
if err := f.db.GetWhere([]db.Where{{Key: "domain", Value: publicKeyOwnerURI.Host, CaseInsensitive: true}}, i); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// there's been an actual error
return ctx, false, fmt.Errorf("error getting requesting account with public key id %s: %s", publicKeyOwnerURI.String(), err)
}
@@ -202,8 +202,7 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
requestingAccount := >smodel.Account{}
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: uri.String()}}, requestingAccount); err != nil {
- _, ok := err.(db.ErrNoEntries)
- if ok {
+ if err == db.ErrNoEntries {
// we don't have an entry for this account so it's not blocked
// TODO: allow a different default to be set for this behavior
continue
diff --git a/internal/federation/util.go b/internal/federation/util.go
index de8654d32..7971f6172 100644
--- a/internal/federation/util.go
+++ b/internal/federation/util.go
@@ -13,7 +13,7 @@ func (f *federator) blockedDomain(host string) (bool, error) {
return true, nil
}
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// there are no entries so there's no block
return false, nil
}
diff --git a/internal/gtsmodel/account.go b/internal/gtsmodel/account.go
index e560601b8..369fd6b77 100644
--- a/internal/gtsmodel/account.go
+++ b/internal/gtsmodel/account.go
@@ -46,10 +46,12 @@ type Account struct {
// ID of the avatar as a media attachment
AvatarMediaAttachmentID string `pg:"type:CHAR(26)"`
+ AvatarMediaAttachment *MediaAttachment `pg:"rel:has-one"`
// For a non-local account, where can the header be fetched?
AvatarRemoteURL string
// ID of the header as a media attachment
HeaderMediaAttachmentID string `pg:"type:CHAR(26)"`
+ HeaderMediaAttachment *MediaAttachment `pg:"rel:has-one"`
// For a non-local account, where can the header be fetched?
HeaderRemoteURL string
// DisplayName for this account. Can be empty, then just the Username will be used for display purposes.
diff --git a/internal/gtsmodel/instance.go b/internal/gtsmodel/instance.go
index 857831ba3..e0eb1435a 100644
--- a/internal/gtsmodel/instance.go
+++ b/internal/gtsmodel/instance.go
@@ -20,6 +20,7 @@ type Instance struct {
SuspendedAt time.Time
// ID of any existing domain block for this instance in the database
DomainBlockID string `pg:"type:CHAR(26)"`
+ DomainBlock *DomainBlock `pg:"rel:has-one"`
// Short description of this instance
ShortDescription string
// Longer description of this instance
@@ -32,6 +33,7 @@ type Instance struct {
ContactAccountUsername string
// Contact account ID in the database for this instance
ContactAccountID string `pg:"type:CHAR(26)"`
+ ContactAccount *Account `pg:"rel:has-one"`
// Reputation score of this instance
Reputation int64 `pg:",notnull,default:0"`
// Version of the software used on this instance
diff --git a/internal/gtsmodel/status.go b/internal/gtsmodel/status.go
index 21b8fc794..2746284a2 100644
--- a/internal/gtsmodel/status.go
+++ b/internal/gtsmodel/status.go
@@ -47,24 +47,24 @@ type Status struct {
// is this status from a local account?
Local bool
// which account posted this status?
- AccountID string `pg:"type:CHAR(26),notnull"`
- Account *Account `pg:"rel:has-one"`
+ AccountID string `pg:"type:CHAR(26),notnull"`
+ Account *Account `pg:"rel:has-one"`
// AP uri of the owner of this status
AccountURI string
// id of the status this status is a reply to
- InReplyToID string `pg:"type:CHAR(26)"`
- InReplyTo *Status `pg:"-"`
+ InReplyToID string `pg:"type:CHAR(26)"`
+ InReplyTo *Status `pg:"rel:has-one"`
// AP uri of the status this status is a reply to
InReplyToURI string
// id of the account that this status replies to
- InReplyToAccountID string `pg:"type:CHAR(26)"`
- InReplyToAccount *Account `pg:"-"`
+ InReplyToAccountID string `pg:"type:CHAR(26)"`
+ InReplyToAccount *Account `pg:"rel:has-one"`
// id of the status this status is a boost of
- BoostOfID string `pg:"type:CHAR(26)"`
- BoostOf *Status `pg:"-"`
+ BoostOfID string `pg:"type:CHAR(26)"`
+ BoostOf *Status `pg:"rel:has-one"`
// id of the account that owns the boosted status
- BoostOfAccountID string `pg:"type:CHAR(26)"`
- BoostOfAccount *Account `pg:"-"`
+ BoostOfAccountID string `pg:"type:CHAR(26)"`
+ BoostOfAccount *Account `pg:"rel:has-one"`
// cw string for this status
ContentWarning string
// visibility entry for this status
@@ -74,8 +74,8 @@ type Status struct {
// what language is this status written in?
Language string
// Which application was used to create this status?
- CreatedWithApplicationID string `pg:"type:CHAR(26)"`
- CreatedWithApplication *Application `pg:"-"`
+ CreatedWithApplicationID string `pg:"type:CHAR(26)"`
+ CreatedWithApplication *Application `pg:"rel:has-one"`
// advanced visibility for this status
VisibilityAdvanced *VisibilityAdvanced
// What is the activitystreams type of this status? See: https://www.w3.org/TR/activitystreams-vocabulary/#object-types
@@ -103,13 +103,6 @@ type Status struct {
GTSEmojis []*Emoji `pg:"-"`
// MediaAttachments used in this status
GTSMediaAttachments []*MediaAttachment `pg:"-"`
- // Status being replied to
-
- // Account being replied to
-
- // Status being boosted
-
- // Account of the boosted status
}
diff --git a/internal/oauth/clientstore.go b/internal/oauth/clientstore.go
index 998f6784e..2e7e0ae88 100644
--- a/internal/oauth/clientstore.go
+++ b/internal/oauth/clientstore.go
@@ -27,11 +27,11 @@ import (
)
type clientStore struct {
- db db.DB
+ db db.Basic
}
// NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend.
-func NewClientStore(db db.DB) oauth2.ClientStore {
+func NewClientStore(db db.Basic) oauth2.ClientStore {
pts := &clientStore{
db: db,
}
diff --git a/internal/oauth/clientstore_test.go b/internal/oauth/clientstore_test.go
index c515ff513..fd3452405 100644
--- a/internal/oauth/clientstore_test.go
+++ b/internal/oauth/clientstore_test.go
@@ -99,7 +99,7 @@ func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() {
// try to get the deleted client; we should get an error
deletedClient, err := cs.GetByID(context.Background(), suite.testClientID)
suite.Assert().Nil(deletedClient)
- suite.Assert().EqualValues(db.ErrNoEntries{}, err)
+ suite.Assert().EqualValues(db.ErrNoEntries, err)
}
func TestPgClientStoreTestSuite(t *testing.T) {
diff --git a/internal/oauth/server.go b/internal/oauth/server.go
index 1289b18af..6d8f50064 100644
--- a/internal/oauth/server.go
+++ b/internal/oauth/server.go
@@ -66,7 +66,7 @@ type s struct {
}
// New returns a new oauth server that implements the Server interface
-func New(database db.DB, log *logrus.Logger) Server {
+func New(database db.Basic, log *logrus.Logger) Server {
ts := newTokenStore(context.Background(), database, log)
cs := NewClientStore(database)
diff --git a/internal/oauth/tokenstore.go b/internal/oauth/tokenstore.go
index 5f8e07882..4fd3183fc 100644
--- a/internal/oauth/tokenstore.go
+++ b/internal/oauth/tokenstore.go
@@ -34,7 +34,7 @@ import (
// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend.
type tokenStore struct {
oauth2.TokenStore
- db db.DB
+ db db.Basic
log *logrus.Logger
}
@@ -42,7 +42,7 @@ type tokenStore struct {
//
// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
// the tokens in the DB once per minute and deletes any that have expired.
-func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.TokenStore {
+func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2.TokenStore {
pts := &tokenStore{
db: db,
log: log,
diff --git a/internal/processing/account/createblock.go b/internal/processing/account/createblock.go
index 79ce03805..798e9324f 100644
--- a/internal/processing/account/createblock.go
+++ b/internal/processing/account/createblock.go
@@ -33,7 +33,7 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
// make sure the target account actually exists in our db
targetAcct := >smodel.Account{}
if err := p.db.GetByID(targetAccountID, targetAcct); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: account %s not found in the db: %s", targetAccountID, err))
}
}
diff --git a/internal/processing/account/createfollow.go b/internal/processing/account/createfollow.go
index e89db9d47..55d75fcc5 100644
--- a/internal/processing/account/createfollow.go
+++ b/internal/processing/account/createfollow.go
@@ -42,7 +42,7 @@ func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apim
// make sure the target account actually exists in our db
targetAcct := >smodel.Account{}
if err := p.db.GetByID(form.ID, targetAcct); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: account %s not found in the db: %s", form.ID, err))
}
}
diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go
index 4cadeaac6..e8840abae 100644
--- a/internal/processing/account/delete.go
+++ b/internal/processing/account/delete.go
@@ -135,7 +135,7 @@ selectStatusesLoop:
for {
statuses, err := p.db.GetAccountStatuses(account.ID, 20, false, maxID, false, false)
if err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// no statuses left for this instance so we're done
l.Infof("Delete: done iterating through statuses for account %s", account.Username)
break selectStatusesLoop
@@ -158,7 +158,7 @@ selectStatusesLoop:
}
if err := p.db.DeleteByID(s.ID, s); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// actual error has occurred
l.Errorf("Delete: db error status %s for account %s: %s", s.ID, account.Username, err)
break selectStatusesLoop
@@ -168,7 +168,7 @@ selectStatusesLoop:
// if there are any boosts of this status, delete them as well
boosts := []*gtsmodel.Status{}
if err := p.db.GetWhere([]db.Where{{Key: "boost_of_id", Value: s.ID}}, &boosts); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// an actual error has occurred
l.Errorf("Delete: db error selecting boosts of status %s for account %s: %s", s.ID, account.Username, err)
break selectStatusesLoop
@@ -190,7 +190,7 @@ selectStatusesLoop:
}
if err := p.db.DeleteByID(b.ID, b); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// actual error has occurred
l.Errorf("Delete: db error deleting boost with id %s: %s", b.ID, err)
break selectStatusesLoop
diff --git a/internal/processing/account/get.go b/internal/processing/account/get.go
index a70bf02bd..8cfd91cc2 100644
--- a/internal/processing/account/get.go
+++ b/internal/processing/account/get.go
@@ -30,7 +30,7 @@ import (
func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) {
targetAccount := >smodel.Account{}
if err := p.db.GetByID(targetAccountID, targetAccount); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
return nil, errors.New("account not found")
}
return nil, fmt.Errorf("db error: %s", err)
diff --git a/internal/processing/account/getfollowers.go b/internal/processing/account/getfollowers.go
index 66cd38f21..4f3617341 100644
--- a/internal/processing/account/getfollowers.go
+++ b/internal/processing/account/getfollowers.go
@@ -40,7 +40,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco
followers := []gtsmodel.Follow{}
accounts := []apimodel.Account{}
if err := p.db.GetAccountFollowers(targetAccountID, &followers, false); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
return accounts, nil
}
return nil, gtserror.NewErrorInternalError(err)
@@ -57,7 +57,7 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco
a := >smodel.Account{}
if err := p.db.GetByID(f.AccountID, a); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
continue
}
return nil, gtserror.NewErrorInternalError(err)
diff --git a/internal/processing/account/getfollowing.go b/internal/processing/account/getfollowing.go
index 57461cf23..55e3e25cc 100644
--- a/internal/processing/account/getfollowing.go
+++ b/internal/processing/account/getfollowing.go
@@ -40,7 +40,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco
following := []gtsmodel.Follow{}
accounts := []apimodel.Account{}
if err := p.db.GetAccountFollowing(targetAccountID, &following); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
return accounts, nil
}
return nil, gtserror.NewErrorInternalError(err)
@@ -57,7 +57,7 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco
a := >smodel.Account{}
if err := p.db.GetByID(f.TargetAccountID, a); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
continue
}
return nil, gtserror.NewErrorInternalError(err)
diff --git a/internal/processing/account/getstatuses.go b/internal/processing/account/getstatuses.go
index 27cabcc13..64ad71a03 100644
--- a/internal/processing/account/getstatuses.go
+++ b/internal/processing/account/getstatuses.go
@@ -30,7 +30,7 @@ import (
func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) {
targetAccount := >smodel.Account{}
if err := p.db.GetByID(targetAccountID, targetAccount); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no entry found for account id %s", targetAccountID))
}
return nil, gtserror.NewErrorInternalError(err)
@@ -39,7 +39,7 @@ func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccou
apiStatuses := []apimodel.Status{}
statuses, err := p.db.GetAccountStatuses(targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
if err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
return apiStatuses, nil
}
return nil, gtserror.NewErrorInternalError(err)
diff --git a/internal/processing/account/removeblock.go b/internal/processing/account/removeblock.go
index 03b0c6750..88a24eeef 100644
--- a/internal/processing/account/removeblock.go
+++ b/internal/processing/account/removeblock.go
@@ -31,7 +31,7 @@ func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccou
// make sure the target account actually exists in our db
targetAcct := >smodel.Account{}
if err := p.db.GetByID(targetAccountID, targetAcct); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockRemove: account %s not found in the db: %s", targetAccountID, err))
}
}
diff --git a/internal/processing/account/removefollow.go b/internal/processing/account/removefollow.go
index ef8994893..21a0eca55 100644
--- a/internal/processing/account/removefollow.go
+++ b/internal/processing/account/removefollow.go
@@ -40,7 +40,7 @@ func (p *processor) FollowRemove(requestingAccount *gtsmodel.Account, targetAcco
// make sure the target account actually exists in our db
targetAcct := >smodel.Account{}
if err := p.db.GetByID(targetAccountID, targetAcct); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("AccountFollowRemove: account %s not found in the db: %s", targetAccountID, err))
}
}
diff --git a/internal/processing/admin/createdomainblock.go b/internal/processing/admin/createdomainblock.go
index df02cef94..a58b2c9ad 100644
--- a/internal/processing/admin/createdomainblock.go
+++ b/internal/processing/admin/createdomainblock.go
@@ -36,7 +36,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string,
domainBlock := >smodel.DomainBlock{}
err := p.db.GetWhere([]db.Where{{Key: "domain", Value: domain, CaseInsensitive: true}}, domainBlock)
if err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// something went wrong in the DB
return nil, gtserror.NewErrorInternalError(fmt.Errorf("DomainBlockCreate: db error checking for existence of domain block %s: %s", domain, err))
}
@@ -60,7 +60,7 @@ func (p *processor) DomainBlockCreate(account *gtsmodel.Account, domain string,
// put the new block in the database
if err := p.db.Put(domainBlock); err != nil {
- if _, ok := err.(db.ErrAlreadyExists); !ok {
+ if err != db.ErrNoEntries {
// there's a real error creating the block
return nil, gtserror.NewErrorInternalError(fmt.Errorf("DomainBlockCreate: db error putting new domain block %s: %s", domain, err))
}
@@ -125,7 +125,7 @@ selectAccountsLoop:
for {
accounts, err := p.db.GetAccountsForInstance(block.Domain, maxID, limit)
if err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// no accounts left for this instance so we're done
l.Infof("domainBlockProcessSideEffects: done iterating through accounts for domain %s", block.Domain)
break selectAccountsLoop
diff --git a/internal/processing/admin/deletedomainblock.go b/internal/processing/admin/deletedomainblock.go
index b41fedd92..edb0a58f9 100644
--- a/internal/processing/admin/deletedomainblock.go
+++ b/internal/processing/admin/deletedomainblock.go
@@ -32,7 +32,7 @@ func (p *processor) DomainBlockDelete(account *gtsmodel.Account, id string) (*ap
domainBlock := >smodel.DomainBlock{}
if err := p.db.GetByID(id, domainBlock); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/admin/getdomainblock.go b/internal/processing/admin/getdomainblock.go
index 7d1f9e2ab..f74010627 100644
--- a/internal/processing/admin/getdomainblock.go
+++ b/internal/processing/admin/getdomainblock.go
@@ -31,7 +31,7 @@ func (p *processor) DomainBlockGet(account *gtsmodel.Account, id string, export
domainBlock := >smodel.DomainBlock{}
if err := p.db.GetByID(id, domainBlock); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/admin/getdomainblocks.go b/internal/processing/admin/getdomainblocks.go
index 5e2241412..f827d03fc 100644
--- a/internal/processing/admin/getdomainblocks.go
+++ b/internal/processing/admin/getdomainblocks.go
@@ -29,7 +29,7 @@ func (p *processor) DomainBlocksGet(account *gtsmodel.Account, export bool) ([]*
domainBlocks := []*gtsmodel.DomainBlock{}
if err := p.db.GetAll(&domainBlocks); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// something has gone really wrong
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/blocks.go b/internal/processing/blocks.go
index dbc2f77c7..809cbde8e 100644
--- a/internal/processing/blocks.go
+++ b/internal/processing/blocks.go
@@ -31,7 +31,7 @@ import (
func (p *processor) BlocksGet(authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) {
accounts, nextMaxID, prevMinID, err := p.db.GetAccountBlocks(authed.Account.ID, maxID, sinceID, limit)
if err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// there are just no entries
return &apimodel.BlocksResponse{
Accounts: []*apimodel.Account{},
diff --git a/internal/processing/followrequest.go b/internal/processing/followrequest.go
index 553a953ff..21ee740fc 100644
--- a/internal/processing/followrequest.go
+++ b/internal/processing/followrequest.go
@@ -29,7 +29,7 @@ import (
func (p *processor) FollowRequestsGet(auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) {
frs := []gtsmodel.FollowRequest{}
if err := p.db.GetAccountFollowRequests(auth.Account.ID, &frs); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
return nil, gtserror.NewErrorInternalError(err)
}
}
diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go
index 8e52db89f..a5d268613 100644
--- a/internal/processing/fromcommon.go
+++ b/internal/processing/fromcommon.go
@@ -74,7 +74,7 @@ func (p *processor) notifyStatus(status *gtsmodel.Status) error {
// notification exists already so just continue
continue
}
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// there's a real error in the db
return fmt.Errorf("notifyStatus: error checking existence of notification for mention with id %s : %s", m.ID, err)
}
diff --git a/internal/processing/fromfederator.go b/internal/processing/fromfederator.go
index 949a734c7..c75154d43 100644
--- a/internal/processing/fromfederator.go
+++ b/internal/processing/fromfederator.go
@@ -101,7 +101,7 @@ func (p *processor) processFromFederator(federatorMsg gtsmodel.FromFederator) er
incomingAnnounce.ID = incomingAnnounceID
if err := p.db.Put(incomingAnnounce); err != nil {
- if _, ok := err.(db.ErrAlreadyExists); !ok {
+ if err != db.ErrNoEntries {
return fmt.Errorf("error adding dereferenced announce to the db: %s", err)
}
}
diff --git a/internal/processing/media/delete.go b/internal/processing/media/delete.go
index 694d78ac3..b5ea8c806 100644
--- a/internal/processing/media/delete.go
+++ b/internal/processing/media/delete.go
@@ -12,7 +12,7 @@ import (
func (p *processor) Delete(mediaAttachmentID string) gtserror.WithCode {
a := >smodel.MediaAttachment{}
if err := p.db.GetByID(mediaAttachmentID, a); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// attachment already gone
return nil
}
@@ -38,7 +38,7 @@ func (p *processor) Delete(mediaAttachmentID string) gtserror.WithCode {
// delete the attachment
if err := p.db.DeleteByID(mediaAttachmentID, a); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
errs = append(errs, fmt.Sprintf("remove attachment: %s", err))
}
}
diff --git a/internal/processing/media/getmedia.go b/internal/processing/media/getmedia.go
index c36370225..380a54cc2 100644
--- a/internal/processing/media/getmedia.go
+++ b/internal/processing/media/getmedia.go
@@ -31,7 +31,7 @@ import (
func (p *processor) GetMedia(account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) {
attachment := >smodel.MediaAttachment{}
if err := p.db.GetByID(mediaAttachmentID, attachment); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// attachment doesn't exist
return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db"))
}
diff --git a/internal/processing/media/update.go b/internal/processing/media/update.go
index 28f3a26f6..89ed08ac1 100644
--- a/internal/processing/media/update.go
+++ b/internal/processing/media/update.go
@@ -32,7 +32,7 @@ import (
func (p *processor) Update(account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) {
attachment := >smodel.MediaAttachment{}
if err := p.db.GetByID(mediaAttachmentID, attachment); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// attachment doesn't exist
return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db"))
}
diff --git a/internal/processing/search.go b/internal/processing/search.go
index 737ad8f71..c1cde8714 100644
--- a/internal/processing/search.go
+++ b/internal/processing/search.go
@@ -196,7 +196,7 @@ func (p *processor) searchAccountByMention(authed *oauth.Auth, mention string, r
return maybeAcct, nil
}
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// if it's not errNoEntries there's been a real database error so bail at this point
return nil, fmt.Errorf("searchAccountByMention: database error: %s", err)
}
diff --git a/internal/processing/status/context.go b/internal/processing/status/context.go
index 32c528296..d342babb0 100644
--- a/internal/processing/status/context.go
+++ b/internal/processing/status/context.go
@@ -19,7 +19,7 @@ func (p *processor) Context(account *gtsmodel.Account, targetStatusID string) (*
targetStatus := >smodel.Status{}
if err := p.db.GetByID(targetStatusID, targetStatus); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(err)
}
return nil, gtserror.NewErrorInternalError(err)
diff --git a/internal/processing/status/delete.go b/internal/processing/status/delete.go
index 259038dee..c73646af5 100644
--- a/internal/processing/status/delete.go
+++ b/internal/processing/status/delete.go
@@ -15,7 +15,7 @@ func (p *processor) Delete(account *gtsmodel.Account, targetStatusID string) (*a
l.Tracef("going to search for target status %s", targetStatusID)
targetStatus := >smodel.Status{}
if err := p.db.GetByID(targetStatusID, targetStatus); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
}
// status is already gone
diff --git a/internal/processing/status/unboost.go b/internal/processing/status/unboost.go
index 3266c9de3..7a4ae5486 100644
--- a/internal/processing/status/unboost.go
+++ b/internal/processing/status/unboost.go
@@ -57,7 +57,7 @@ func (p *processor) Unboost(account *gtsmodel.Account, application *gtsmodel.App
if err != nil {
// something went wrong in the db finding the boost
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error fetching existing boost from database: %s", err))
}
// we just don't have a boost
diff --git a/internal/processing/status/unfave.go b/internal/processing/status/unfave.go
index b51daacb9..82f43ee1f 100644
--- a/internal/processing/status/unfave.go
+++ b/internal/processing/status/unfave.go
@@ -45,7 +45,7 @@ func (p *processor) Unfave(account *gtsmodel.Account, targetStatusID string) (*a
}
if err != nil {
// something went wrong in the db finding the fave
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error fetching existing fave from database: %s", err))
}
// we just don't have a fave
diff --git a/internal/processing/status/util.go b/internal/processing/status/util.go
index 3be53591b..d700016e6 100644
--- a/internal/processing/status/util.go
+++ b/internal/processing/status/util.go
@@ -99,7 +99,7 @@ func (p *processor) ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, th
repliedAccount := >smodel.Account{}
// check replied status exists + is replyable
if err := p.db.GetByID(form.InReplyToID, repliedStatus); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
return fmt.Errorf("status with id %s not replyable because it doesn't exist", form.InReplyToID)
}
return fmt.Errorf("status with id %s not replyable: %s", form.InReplyToID, err)
@@ -113,14 +113,14 @@ func (p *processor) ProcessReplyToID(form *apimodel.AdvancedStatusCreateForm, th
// check replied account is known to us
if err := p.db.GetByID(repliedStatus.AccountID, repliedAccount); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
return fmt.Errorf("status with id %s not replyable because account id %s is not known", form.InReplyToID, repliedStatus.AccountID)
}
return fmt.Errorf("status with id %s not replyable: %s", form.InReplyToID, err)
}
// check if a block exists
if blocked, err := p.db.Blocked(thisAccountID, repliedAccount.ID); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
return fmt.Errorf("status with id %s not replyable: %s", form.InReplyToID, err)
}
} else if blocked {
diff --git a/internal/processing/timeline.go b/internal/processing/timeline.go
index 18d0a6ac7..bceadda73 100644
--- a/internal/processing/timeline.go
+++ b/internal/processing/timeline.go
@@ -76,7 +76,7 @@ func (p *processor) HomeTimelineGet(authed *oauth.Auth, maxID string, sinceID st
func (p *processor) PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.StatusTimelineResponse, gtserror.WithCode) {
statuses, err := p.db.GetPublicTimelineForAccount(authed.Account.ID, maxID, sinceID, minID, limit, local)
if err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// there are just no entries left
return &apimodel.StatusTimelineResponse{
Statuses: []*apimodel.Status{},
@@ -97,7 +97,7 @@ func (p *processor) PublicTimelineGet(authed *oauth.Auth, maxID string, sinceID
func (p *processor) FavedTimelineGet(authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.StatusTimelineResponse, gtserror.WithCode) {
statuses, nextMaxID, prevMinID, err := p.db.GetFavedTimelineForAccount(authed.Account.ID, maxID, minID, limit)
if err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// there are just no entries left
return &apimodel.StatusTimelineResponse{
Statuses: []*apimodel.Status{},
@@ -122,7 +122,7 @@ func (p *processor) filterPublicStatuses(authed *oauth.Auth, statuses []*gtsmode
for _, s := range statuses {
targetAccount := >smodel.Account{}
if err := p.db.GetByID(s.AccountID, targetAccount); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
l.Debugf("filterPublicStatuses: skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
continue
}
@@ -157,7 +157,7 @@ func (p *processor) filterFavedStatuses(authed *oauth.Auth, statuses []*gtsmodel
for _, s := range statuses {
targetAccount := >smodel.Account{}
if err := p.db.GetByID(s.AccountID, targetAccount); err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
l.Debugf("filterFavedStatuses: skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
continue
}
diff --git a/internal/router/session.go b/internal/router/session.go
index 2b9be2f56..38810572f 100644
--- a/internal/router/session.go
+++ b/internal/router/session.go
@@ -49,7 +49,7 @@ func useSession(cfg *config.Config, dbService db.DB, engine *gin.Engine) error {
// check if we have a saved router session already
routerSessions := []*gtsmodel.RouterSession{}
if err := dbService.GetAll(&routerSessions); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// proper error occurred
return err
}
diff --git a/internal/timeline/index.go b/internal/timeline/index.go
index 1e1a9d7bb..0b28c1a19 100644
--- a/internal/timeline/index.go
+++ b/internal/timeline/index.go
@@ -52,7 +52,7 @@ grabloop:
for ; len(filtered) < amount && i < 5; i = i + 1 { // try the grabloop 5 times only
statuses, err := t.db.GetHomeTimelineForAccount(t.accountID, "", "", offsetStatus, amount, false)
if err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
break grabloop // we just don't have enough statuses left in the db so index what we've got and then bail
}
return fmt.Errorf("IndexBefore: error getting statuses from db: %s", err)
@@ -132,7 +132,7 @@ grabloop:
l.Tracef("entering grabloop; i is %d; len(filtered) is %d", i, len(filtered))
statuses, err := t.db.GetHomeTimelineForAccount(t.accountID, offsetStatus, "", "", amount, false)
if err != nil {
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
break grabloop // we just don't have enough statuses left in the db so index what we've got and then bail
}
return fmt.Errorf("IndexBehind: error getting statuses from db: %s", err)
diff --git a/internal/timeline/prepare.go b/internal/timeline/prepare.go
index 51846c816..20000b4e9 100644
--- a/internal/timeline/prepare.go
+++ b/internal/timeline/prepare.go
@@ -95,7 +95,7 @@ prepareloop:
if preparing {
if err := t.prepare(entry.statusID); err != nil {
// there's been an error
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// it's a real error
return fmt.Errorf("PrepareBehind: error preparing status with id %s: %s", entry.statusID, err)
}
@@ -150,7 +150,7 @@ prepareloop:
if preparing {
if err := t.prepare(entry.statusID); err != nil {
// there's been an error
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// it's a real error
return fmt.Errorf("PrepareBefore: error preparing status with id %s: %s", entry.statusID, err)
}
@@ -205,7 +205,7 @@ prepareloop:
if err := t.prepare(entry.statusID); err != nil {
// there's been an error
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// it's a real error
return fmt.Errorf("PrepareFromTop: error preparing status with id %s: %s", entry.statusID, err)
}
diff --git a/internal/typeutils/astointernal.go b/internal/typeutils/astointernal.go
index 47f8c24fc..d426b094e 100644
--- a/internal/typeutils/astointernal.go
+++ b/internal/typeutils/astointernal.go
@@ -44,7 +44,7 @@ func (c *converter) ASRepresentationToAccount(accountable ap.Accountable, update
// we already know this account so we can skip generating it
return acct, nil
}
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
// we don't know the account and there's been a real error
return nil, fmt.Errorf("error getting account with uri %s from the database: %s", uri.String(), err)
}
diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go
index 18270d90e..f8562c1ec 100644
--- a/internal/typeutils/internaltofrontend.go
+++ b/internal/typeutils/internaltofrontend.go
@@ -40,7 +40,7 @@ func (c *converter) AccountToMastoSensitive(a *gtsmodel.Account) (*model.Account
// check pending follow requests aimed at this account
fr := []gtsmodel.FollowRequest{}
if err := c.db.GetAccountFollowRequests(a.ID, &fr); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
return nil, fmt.Errorf("error getting follow requests: %s", err)
}
}
@@ -63,41 +63,27 @@ func (c *converter) AccountToMastoSensitive(a *gtsmodel.Account) (*model.Account
func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, error) {
// count followers
- followers := []gtsmodel.Follow{}
- if err := c.db.GetAccountFollowers(a.ID, &followers, false); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
- return nil, fmt.Errorf("error getting followers: %s", err)
- }
- }
- var followersCount int
- if followers != nil {
- followersCount = len(followers)
+ followersCount, err := c.db.CountAccountFollowers(a.ID, false)
+ if err != nil {
+ return nil, fmt.Errorf("error counting followers: %s", err)
}
// count following
- following := []gtsmodel.Follow{}
- if err := c.db.GetAccountFollowing(a.ID, &following); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
- return nil, fmt.Errorf("error getting following: %s", err)
- }
- }
- var followingCount int
- if following != nil {
- followingCount = len(following)
+ followingCount, err := c.db.CountAccountFollowing(a.ID, false)
+ if err != nil {
+ return nil, fmt.Errorf("error counting following: %s", err)
}
// count statuses
- statusesCount, err := c.db.GetAccountStatusesCount(a.ID)
+ statusesCount, err := c.db.CountAccountStatuses(a.ID)
if err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
- return nil, fmt.Errorf("error getting last statuses: %s", err)
- }
+ return nil, fmt.Errorf("error getting last statuses: %s", err)
}
// check when the last status was
lastStatus := >smodel.Status{}
if err := c.db.GetAccountLastStatus(a.ID, lastStatus); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
+ if err != db.ErrNoEntries {
return nil, fmt.Errorf("error getting last status: %s", err)
}
}
@@ -107,23 +93,20 @@ func (c *converter) AccountToMastoPublic(a *gtsmodel.Account) (*model.Account, e
}
// build the avatar and header URLs
- avi := >smodel.MediaAttachment{}
- if err := c.db.GetAccountAvatar(avi, a.ID); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
- return nil, fmt.Errorf("error getting avatar: %s", err)
- }
- }
- aviURL := avi.URL
- aviURLStatic := avi.Thumbnail.URL
- header := >smodel.MediaAttachment{}
- if err := c.db.GetAccountHeader(header, a.ID); err != nil {
- if _, ok := err.(db.ErrNoEntries); !ok {
- return nil, fmt.Errorf("error getting header: %s", err)
- }
+ var aviURL string
+ var aviURLStatic string
+ if a.AvatarMediaAttachment != nil {
+ aviURL = a.AvatarMediaAttachment.URL
+ aviURLStatic = a.AvatarMediaAttachment.Thumbnail.URL
+ }
+
+ var headerURL string
+ var headerURLStatic string
+ if a.HeaderMediaAttachment != nil {
+ headerURL = a.HeaderMediaAttachment.URL
+ headerURLStatic = a.HeaderMediaAttachment.Thumbnail.URL
}
- headerURL := header.URL
- headerURLStatic := header.Thumbnail.URL
// get the fields set on this account
fields := []model.Field{}
@@ -585,13 +568,10 @@ func (c *converter) InstanceToMasto(i *gtsmodel.Instance) (*model.Instance, erro
}
// get the instance account if it exists and just skip if it doesn't
- ia := >smodel.Account{}
- if err := c.db.GetWhere([]db.Where{{Key: "username", Value: i.Domain}}, ia); err == nil {
- // instance account exists, get the header for the account if it exists
- attachment := >smodel.MediaAttachment{}
- if err := c.db.GetAccountHeader(attachment, ia.ID); err == nil {
- // header exists, set it on the api model
- mi.Thumbnail = attachment.URL
+ ia, err := c.db.GetInstanceAccount("")
+ if err == nil {
+ if ia.HeaderMediaAttachment != nil {
+ mi.Thumbnail = ia.HeaderMediaAttachment.URL
}
}
diff --git a/internal/visibility/statusvisible.go b/internal/visibility/statusvisible.go
index dc6b74702..638725e2b 100644
--- a/internal/visibility/statusvisible.go
+++ b/internal/visibility/statusvisible.go
@@ -45,7 +45,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount
targetUser := >smodel.User{}
if err := f.db.GetWhere([]db.Where{{Key: "account_id", Value: targetAccount.ID}}, targetUser); err != nil {
l.Debug("target user could not be selected")
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
return false, nil
}
return false, fmt.Errorf("StatusVisible: db error selecting user for local target account %s: %s", targetAccount.ID, err)
@@ -76,7 +76,7 @@ func (f *filter) StatusVisible(targetStatus *gtsmodel.Status, requestingAccount
if err := f.db.GetWhere([]db.Where{{Key: "account_id", Value: requestingAccount.ID}}, requestingUser); err != nil {
// if the requesting account is local but doesn't have a corresponding user in the db this is a problem
l.Debug("requesting user could not be selected")
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
return false, nil
}
return false, fmt.Errorf("StatusVisible: db error selecting user for local requesting account %s: %s", requestingAccount.ID, err)
diff --git a/internal/visibility/util.go b/internal/visibility/util.go
index e7d5b4378..f004b4c23 100644
--- a/internal/visibility/util.go
+++ b/internal/visibility/util.go
@@ -112,7 +112,7 @@ func (f *filter) blockedDomain(host string) (bool, error) {
return true, nil
}
- if _, ok := err.(db.ErrNoEntries); ok {
+ if err == db.ErrNoEntries {
// there are no entries so there's no block
return false, nil
}