mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-10-30 08:12:27 -05:00
account update nearly working
This commit is contained in:
parent
362ccf5817
commit
c8ff849a02
5 changed files with 337 additions and 76 deletions
|
|
@ -92,6 +92,9 @@ type DB interface {
|
||||||
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
|
||||||
UpdateByID(id string, i interface{}) error
|
UpdateByID(id string, i interface{}) error
|
||||||
|
|
||||||
|
// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
|
||||||
|
UpdateOneByID(id string, key string, value interface{}, i interface{}) error
|
||||||
|
|
||||||
// DeleteByID removes i with id id.
|
// DeleteByID removes i with id id.
|
||||||
// If i didn't exist anyway, then no error should be returned.
|
// If i didn't exist anyway, then no error should be returned.
|
||||||
DeleteByID(id string, i interface{}) error
|
DeleteByID(id string, i interface{}) error
|
||||||
|
|
@ -156,7 +159,15 @@ type DB interface {
|
||||||
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error)
|
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error)
|
||||||
|
|
||||||
// SetHeaderOrAvatarForAccountID sets the header or avatar for the given accountID to the given media attachment.
|
// SetHeaderOrAvatarForAccountID sets the header or avatar for the given accountID to the given media attachment.
|
||||||
SetHeaderOrAvatarForAccountID(mediaAttachmen *model.MediaAttachment, accountID string) error
|
SetHeaderOrAvatarForAccountID(mediaAttachment *model.MediaAttachment, accountID string) error
|
||||||
|
|
||||||
|
// GetHeaderAvatarForAccountID gets the current avatar for the given account ID.
|
||||||
|
// The passed mediaAttachment pointer will be populated with the value of the avatar, if it exists.
|
||||||
|
GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error
|
||||||
|
|
||||||
|
// GetHeaderForAccountID gets the current header for the given account ID.
|
||||||
|
// The passed mediaAttachment pointer will be populated with the value of the header, if it exists.
|
||||||
|
GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error
|
||||||
|
|
||||||
/*
|
/*
|
||||||
USEFUL CONVERSION FUNCTIONS
|
USEFUL CONVERSION FUNCTIONS
|
||||||
|
|
|
||||||
|
|
@ -274,6 +274,11 @@ func (ps *postgresService) UpdateByID(id string, i interface{}) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error {
|
||||||
|
_, err := ps.conn.Model(i).Set("? = ?", key, value).Where("id = ?", id).Update()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (ps *postgresService) DeleteByID(id string, i interface{}) error {
|
func (ps *postgresService) DeleteByID(id string, i interface{}) error {
|
||||||
if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
|
if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
|
||||||
if err == pg.ErrNoRows {
|
if err == pg.ErrNoRows {
|
||||||
|
|
@ -468,6 +473,26 @@ func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *model.
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error {
|
||||||
|
if err := ps.conn.Model(header).Where("account_id = ?", accountID).Where("header = ?", true).Select(); err != nil {
|
||||||
|
if err == pg.ErrNoRows {
|
||||||
|
return ErrNoEntries{}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error {
|
||||||
|
if err := ps.conn.Model(avatar).Where("account_id = ?", accountID).Where("avatar = ?", true).Select(); err != nil {
|
||||||
|
if err == pg.ErrNoRows {
|
||||||
|
return ErrNoEntries{}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
CONVERSION FUNCTIONS
|
CONVERSION FUNCTIONS
|
||||||
*/
|
*/
|
||||||
|
|
@ -478,18 +503,6 @@ func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *model.
|
||||||
// that the account actually belongs to.
|
// that the account actually belongs to.
|
||||||
func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotypes.Account, error) {
|
func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotypes.Account, error) {
|
||||||
|
|
||||||
fields := []mastotypes.Field{}
|
|
||||||
for _, f := range a.Fields {
|
|
||||||
mField := mastotypes.Field{
|
|
||||||
Name: f.Name,
|
|
||||||
Value: f.Value,
|
|
||||||
}
|
|
||||||
if !f.VerifiedAt.IsZero() {
|
|
||||||
mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339)
|
|
||||||
}
|
|
||||||
fields = append(fields, mField)
|
|
||||||
}
|
|
||||||
|
|
||||||
// count followers
|
// count followers
|
||||||
followers := []model.Follow{}
|
followers := []model.Follow{}
|
||||||
if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil {
|
if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil {
|
||||||
|
|
@ -538,6 +551,39 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype
|
||||||
lastStatusAt = lastStatus.CreatedAt.Format(time.RFC3339)
|
lastStatusAt = lastStatus.CreatedAt.Format(time.RFC3339)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// build the avatar and header URLs
|
||||||
|
avi := &model.MediaAttachment{}
|
||||||
|
if err := ps.GetAvatarForAccountID(avi, a.ID); err != nil {
|
||||||
|
if _, ok := err.(ErrNoEntries); !ok {
|
||||||
|
return nil, fmt.Errorf("error getting avatar: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
aviURL := avi.File.Path
|
||||||
|
aviURLStatic := avi.Thumbnail.Path
|
||||||
|
|
||||||
|
header := &model.MediaAttachment{}
|
||||||
|
if err := ps.GetHeaderForAccountID(avi, a.ID); err != nil {
|
||||||
|
if _, ok := err.(ErrNoEntries); !ok {
|
||||||
|
return nil, fmt.Errorf("error getting header: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
headerURL := header.File.Path
|
||||||
|
headerURLStatic := header.Thumbnail.Path
|
||||||
|
|
||||||
|
// get the fields set on this account
|
||||||
|
fields := []mastotypes.Field{}
|
||||||
|
for _, f := range a.Fields {
|
||||||
|
mField := mastotypes.Field{
|
||||||
|
Name: f.Name,
|
||||||
|
Value: f.Value,
|
||||||
|
}
|
||||||
|
if !f.VerifiedAt.IsZero() {
|
||||||
|
mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
fields = append(fields, mField)
|
||||||
|
}
|
||||||
|
|
||||||
|
// check pending follow requests aimed at this account
|
||||||
fr := []model.FollowRequest{}
|
fr := []model.FollowRequest{}
|
||||||
if err := ps.GetFollowRequestsForAccountID(a.ID, &fr); err != nil {
|
if err := ps.GetFollowRequestsForAccountID(a.ID, &fr); err != nil {
|
||||||
if _, ok := err.(ErrNoEntries); !ok {
|
if _, ok := err.(ErrNoEntries); !ok {
|
||||||
|
|
@ -549,6 +595,7 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype
|
||||||
frc = len(fr)
|
frc = len(fr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// derive source from fields and other info
|
||||||
source := &mastotypes.Source{
|
source := &mastotypes.Source{
|
||||||
Privacy: a.Privacy,
|
Privacy: a.Privacy,
|
||||||
Sensitive: a.Sensitive,
|
Sensitive: a.Sensitive,
|
||||||
|
|
@ -567,17 +614,17 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype
|
||||||
Bot: a.Bot,
|
Bot: a.Bot,
|
||||||
CreatedAt: a.CreatedAt.Format(time.RFC3339),
|
CreatedAt: a.CreatedAt.Format(time.RFC3339),
|
||||||
Note: a.Note,
|
Note: a.Note,
|
||||||
URL: a.URL,
|
URL: a.URL, // TODO: set this during account creation
|
||||||
Avatar: a.AvatarRemoteURL.String(),
|
Avatar: aviURL, // TODO: build this url properly using host and protocol from config
|
||||||
AvatarStatic: a.AvatarRemoteURL.String(),
|
AvatarStatic: aviURLStatic, // TODO: build this url properly using host and protocol from config
|
||||||
Header: a.HeaderRemoteURL.String(),
|
Header: headerURL, // TODO: build this url properly using host and protocol from config
|
||||||
HeaderStatic: a.HeaderRemoteURL.String(),
|
HeaderStatic: headerURLStatic, // TODO: build this url properly using host and protocol from config
|
||||||
FollowersCount: followersCount,
|
FollowersCount: followersCount,
|
||||||
FollowingCount: followingCount,
|
FollowingCount: followingCount,
|
||||||
StatusesCount: statusesCount,
|
StatusesCount: statusesCount,
|
||||||
LastStatusAt: lastStatusAt,
|
LastStatusAt: lastStatusAt,
|
||||||
Source: source,
|
Source: source,
|
||||||
Emojis: nil,
|
Emojis: nil, // TODO: implement this
|
||||||
Fields: fields,
|
Fields: fields,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,8 +19,11 @@
|
||||||
package account
|
package account
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
|
@ -39,8 +42,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
idKey = "id"
|
||||||
basePath = "/api/v1/accounts"
|
basePath = "/api/v1/accounts"
|
||||||
basePathWithID = basePath + "/:id"
|
basePathWithID = basePath + "/:" + idKey
|
||||||
verifyPath = basePath + "/verify_credentials"
|
verifyPath = basePath + "/verify_credentials"
|
||||||
updateCredentialsPath = basePath + "/update_credentials"
|
updateCredentialsPath = basePath + "/update_credentials"
|
||||||
)
|
)
|
||||||
|
|
@ -144,6 +148,10 @@ func (m *accountModule) accountVerifyGETHandler(c *gin.Context) {
|
||||||
|
|
||||||
// accountUpdateCredentialsPATCHHandler allows a user to modify their account/profile settings.
|
// accountUpdateCredentialsPATCHHandler allows a user to modify their account/profile settings.
|
||||||
// It should be served as a PATCH at /api/v1/accounts/update_credentials
|
// It should be served as a PATCH at /api/v1/accounts/update_credentials
|
||||||
|
//
|
||||||
|
// TODO: this can be optimized massively by building up a picture of what we want the new account
|
||||||
|
// details to be, and then inserting it all in the database at once. As it is, we do queries one-by-one
|
||||||
|
// which is not gonna make the database very happy when lots of requests are going through.
|
||||||
func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) {
|
func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) {
|
||||||
l := m.log.WithField("func", "accountUpdateCredentialsPATCHHandler")
|
l := m.log.WithField("func", "accountUpdateCredentialsPATCHHandler")
|
||||||
authed, err := oauth.MustAuth(c, true, false, false, true)
|
authed, err := oauth.MustAuth(c, true, false, false, true)
|
||||||
|
|
@ -152,63 +160,180 @@ func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) {
|
||||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
l.Tracef("retrieved account %+v", authed.Account.ID)
|
||||||
|
|
||||||
l.Trace("parsing request form")
|
l.Trace("parsing request form")
|
||||||
form := &mastotypes.UpdateCredentialsRequest{}
|
form := &mastotypes.UpdateCredentialsRequest{}
|
||||||
if err := c.ShouldBind(form); err != nil || form == nil {
|
if err := c.ShouldBind(form); err != nil || form == nil {
|
||||||
l.Debugf("could not parse form from request: %s", err)
|
l.Debugf("could not parse form from request: %s", err)
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing one or more required form values"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: proper form validation
|
// if everything on the form is nil, then nothing has been set and we shouldn't continue
|
||||||
|
if form.Discoverable == nil && form.Bot == nil && form.DisplayName == nil && form.Note == nil && form.Avatar == nil && form.Header == nil && form.Locked == nil && form.Source == nil && form.FieldsAttributes == nil {
|
||||||
// TODO: tidy this code into subfunctions
|
l.Debugf("could not parse form from request")
|
||||||
if form.Header != nil && form.Header.Size != 0 {
|
c.JSON(http.StatusBadRequest, gin.H{"error": "empty form submitted"})
|
||||||
if form.Header.Size > m.config.MediaConfig.MaxImageSize {
|
return
|
||||||
err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", form.Header.Size, m.config.MediaConfig.MaxImageSize)
|
|
||||||
l.Debugf("error processing header: %s", err)
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
f, err := form.Header.Open()
|
|
||||||
if err != nil {
|
|
||||||
l.Debugf("error processing header: %s", err)
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("could not read provided header: %s", err)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// extract the bytes
|
|
||||||
imageBytes := []byte{}
|
|
||||||
size, err := f.Read(imageBytes)
|
|
||||||
defer func(){
|
|
||||||
if err := f.Close(); err != nil {
|
|
||||||
m.log.Errorf("error closing multipart file: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if err != nil || size == 0 {
|
|
||||||
l.Debugf("error processing header: %s", err)
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("could not read provided header: %s", err)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// do the setting
|
|
||||||
headerInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(imageBytes, authed.Account.ID, "header")
|
|
||||||
if err != nil {
|
|
||||||
l.Debugf("error processing header: %s", err)
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
l.Tracef("new header info for account %s is %+v", headerInfo)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
l.Tracef("retrieved account %+v", authed.Account.ID)
|
if form.Discoverable != nil {
|
||||||
|
if err := m.db.UpdateOneByID(authed.Account.ID, "discoverable", *form.Discoverable, &model.Account{}); err != nil {
|
||||||
|
l.Debugf("error updating discoverable: %s", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if form.Bot != nil {
|
||||||
|
if err := m.db.UpdateOneByID(authed.Account.ID, "bot", *form.Bot, &model.Account{}); err != nil {
|
||||||
|
l.Debugf("error updating bot: %s", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if form.DisplayName != nil {
|
||||||
|
if err := m.db.UpdateOneByID(authed.Account.ID, "display_name", *form.DisplayName, &model.Account{}); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if form.Note != nil {
|
||||||
|
if err := m.db.UpdateOneByID(authed.Account.ID, "note", *form.Note, &model.Account{}); err != nil {
|
||||||
|
l.Debugf("error updating note: %s", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if form.Avatar != nil && form.Avatar.Size != 0 {
|
||||||
|
avatarInfo, err := m.UpdateAccountAvatar(form.Avatar, authed.Account.ID)
|
||||||
|
if err != nil {
|
||||||
|
l.Debugf("could not update avatar for account %s: %s", authed.Account.ID, err)
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
l.Tracef("new avatar info for account %s is %+v", authed.Account.ID, avatarInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
if form.Header != nil && form.Header.Size != 0 {
|
||||||
|
headerInfo, err := m.UpdateAccountHeader(form.Header, authed.Account.ID)
|
||||||
|
if err != nil {
|
||||||
|
l.Debugf("could not update header for account %s: %s", authed.Account.ID, err)
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
l.Tracef("new header info for account %s is %+v", authed.Account.ID, headerInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
if form.Locked != nil {
|
||||||
|
if err := m.db.UpdateOneByID(authed.Account.ID, "locked", *form.Locked, &model.Account{}); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if form.Source != nil {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if form.FieldsAttributes != nil {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetch the account with all updated values set
|
||||||
|
updatedAccount := &model.Account{}
|
||||||
|
if err := m.db.GetByID(authed.Account.ID, updatedAccount); err != nil {
|
||||||
|
l.Debugf("could not fetch updated account %s: %s", authed.Account.ID, err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
acctSensitive, err := m.db.AccountToMastoSensitive(updatedAccount)
|
||||||
|
if err != nil {
|
||||||
|
l.Tracef("could not convert account into mastosensitive account: %s", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive)
|
||||||
|
c.JSON(http.StatusOK, acctSensitive)
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
HELPER FUNCTIONS
|
HELPER FUNCTIONS
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
// TODO: try to combine the below two functions because this is a lot of code repetition.
|
||||||
|
|
||||||
|
// UpdateAccountAvatar does the dirty work of checking the avatar part of an account update form,
|
||||||
|
// parsing and checking the image, and doing the necessary updates in the database for this to become
|
||||||
|
// the account's new avatar image.
|
||||||
|
func (m *accountModule) UpdateAccountAvatar(avatar *multipart.FileHeader, accountID string) (*model.MediaAttachment, error) {
|
||||||
|
var err error
|
||||||
|
if avatar.Size > m.config.MediaConfig.MaxImageSize {
|
||||||
|
err = fmt.Errorf("avatar with size %d exceeded max image size of %d bytes", avatar.Size, m.config.MediaConfig.MaxImageSize)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
f, err := avatar.Open()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not read provided avatar: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract the bytes
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
size, err := io.Copy(buf, f)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not read provided avatar: %s", err)
|
||||||
|
}
|
||||||
|
if size == 0 {
|
||||||
|
return nil, errors.New("could not read provided avatar: size 0 bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
// do the setting
|
||||||
|
avatarInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(buf.Bytes(), accountID, "avatar")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error processing avatar: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return avatarInfo, f.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAccountHeader does the dirty work of checking the header part of an account update form,
|
||||||
|
// parsing and checking the image, and doing the necessary updates in the database for this to become
|
||||||
|
// the account's new header image.
|
||||||
|
func (m *accountModule) UpdateAccountHeader(header *multipart.FileHeader, accountID string) (*model.MediaAttachment, error) {
|
||||||
|
var err error
|
||||||
|
if header.Size > m.config.MediaConfig.MaxImageSize {
|
||||||
|
err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", header.Size, m.config.MediaConfig.MaxImageSize)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
f, err := header.Open()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not read provided header: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract the bytes
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
size, err := io.Copy(buf, f)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not read provided header: %s", err)
|
||||||
|
}
|
||||||
|
if size == 0 {
|
||||||
|
return nil, errors.New("could not read provided header: size 0 bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
// do the setting
|
||||||
|
headerInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(buf.Bytes(), accountID, "header")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error processing header: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return headerInfo, f.Close()
|
||||||
|
}
|
||||||
|
|
||||||
// accountCreate does the dirty work of making an account and user in the database.
|
// accountCreate does the dirty work of making an account and user in the database.
|
||||||
// It then returns a token to the caller, for use with the new account, as per the
|
// It then returns a token to the caller, for use with the new account, as per the
|
||||||
// spec here: https://docs.joinmastodon.org/methods/accounts/
|
// spec here: https://docs.joinmastodon.org/methods/accounts/
|
||||||
|
|
|
||||||
|
|
@ -19,13 +19,17 @@
|
||||||
package account
|
package account
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -40,6 +44,7 @@ import (
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
"github.com/superseriousbusiness/gotosocial/internal/db/model"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||||
|
"github.com/superseriousbusiness/gotosocial/internal/storage"
|
||||||
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
|
"github.com/superseriousbusiness/gotosocial/pkg/mastotypes"
|
||||||
"github.com/superseriousbusiness/oauth2/v4"
|
"github.com/superseriousbusiness/oauth2/v4"
|
||||||
"github.com/superseriousbusiness/oauth2/v4/models"
|
"github.com/superseriousbusiness/oauth2/v4/models"
|
||||||
|
|
@ -57,7 +62,8 @@ type AccountTestSuite struct {
|
||||||
testApplication *model.Application
|
testApplication *model.Application
|
||||||
testToken oauth2.TokenInfo
|
testToken oauth2.TokenInfo
|
||||||
mockOauthServer *oauth.MockServer
|
mockOauthServer *oauth.MockServer
|
||||||
mockMediaHandler *media.MockMediaHandler
|
mockStorage *storage.MockStorage
|
||||||
|
mediaHandler media.MediaHandler
|
||||||
db db.DB
|
db db.DB
|
||||||
accountModule *accountModule
|
accountModule *accountModule
|
||||||
newUserFormHappyPath url.Values
|
newUserFormHappyPath url.Values
|
||||||
|
|
@ -74,6 +80,11 @@ func (suite *AccountTestSuite) SetupSuite() {
|
||||||
log.SetLevel(logrus.TraceLevel)
|
log.SetLevel(logrus.TraceLevel)
|
||||||
suite.log = log
|
suite.log = log
|
||||||
|
|
||||||
|
suite.testAccountLocal = &model.Account{
|
||||||
|
ID: uuid.NewString(),
|
||||||
|
Username: "test_user",
|
||||||
|
}
|
||||||
|
|
||||||
// can use this test application throughout
|
// can use this test application throughout
|
||||||
suite.testApplication = &model.Application{
|
suite.testApplication = &model.Application{
|
||||||
ID: "weeweeeeeeeeeeeeee",
|
ID: "weeweeeeeeeeeeeeee",
|
||||||
|
|
@ -107,6 +118,9 @@ func (suite *AccountTestSuite) SetupSuite() {
|
||||||
Database: "postgres",
|
Database: "postgres",
|
||||||
ApplicationName: "gotosocial",
|
ApplicationName: "gotosocial",
|
||||||
}
|
}
|
||||||
|
c.MediaConfig = &config.MediaConfig{
|
||||||
|
MaxImageSize: 2 << 20,
|
||||||
|
}
|
||||||
suite.config = c
|
suite.config = c
|
||||||
|
|
||||||
// use an actual database for this, because it's just easier than mocking one out
|
// use an actual database for this, because it's just easier than mocking one out
|
||||||
|
|
@ -130,11 +144,15 @@ func (suite *AccountTestSuite) SetupSuite() {
|
||||||
Code: "we're authorized now!",
|
Code: "we're authorized now!",
|
||||||
}, nil)
|
}, nil)
|
||||||
|
|
||||||
// mock the media handler because some handlers (eg update credentials) need to upload media (new header/avatar)
|
suite.mockStorage = &storage.MockStorage{}
|
||||||
suite.mockMediaHandler = &media.MockMediaHandler{}
|
// We don't need storage to do anything for these tests, so just simulate a success and do nothing -- we won't need to return anything from storage
|
||||||
|
suite.mockStorage.On("StoreFileAt", mock.AnythingOfType("string"), mock.AnythingOfType("[]uint8")).Return(nil)
|
||||||
|
|
||||||
|
// set a media handler because some handlers (eg update credentials) need to upload media (new header/avatar)
|
||||||
|
suite.mediaHandler = media.New(suite.config, suite.db, suite.mockStorage, log)
|
||||||
|
|
||||||
// and finally here's the thing we're actually testing!
|
// and finally here's the thing we're actually testing!
|
||||||
suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mockMediaHandler, suite.log).(*accountModule)
|
suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mediaHandler, suite.log).(*accountModule)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *AccountTestSuite) TearDownSuite() {
|
func (suite *AccountTestSuite) TearDownSuite() {
|
||||||
|
|
@ -150,9 +168,11 @@ func (suite *AccountTestSuite) SetupTest() {
|
||||||
&model.User{},
|
&model.User{},
|
||||||
&model.Account{},
|
&model.Account{},
|
||||||
&model.Follow{},
|
&model.Follow{},
|
||||||
|
&model.FollowRequest{},
|
||||||
&model.Status{},
|
&model.Status{},
|
||||||
&model.Application{},
|
&model.Application{},
|
||||||
&model.EmailDomainBlock{},
|
&model.EmailDomainBlock{},
|
||||||
|
&model.MediaAttachment{},
|
||||||
}
|
}
|
||||||
for _, m := range models {
|
for _, m := range models {
|
||||||
if err := suite.db.CreateTable(m); err != nil {
|
if err := suite.db.CreateTable(m); err != nil {
|
||||||
|
|
@ -186,9 +206,11 @@ func (suite *AccountTestSuite) TearDownTest() {
|
||||||
&model.User{},
|
&model.User{},
|
||||||
&model.Account{},
|
&model.Account{},
|
||||||
&model.Follow{},
|
&model.Follow{},
|
||||||
|
&model.FollowRequest{},
|
||||||
&model.Status{},
|
&model.Status{},
|
||||||
&model.Application{},
|
&model.Application{},
|
||||||
&model.EmailDomainBlock{},
|
&model.EmailDomainBlock{},
|
||||||
|
&model.MediaAttachment{},
|
||||||
}
|
}
|
||||||
for _, m := range models {
|
for _, m := range models {
|
||||||
if err := suite.db.DropTable(m); err != nil {
|
if err := suite.db.DropTable(m); err != nil {
|
||||||
|
|
@ -201,6 +223,10 @@ func (suite *AccountTestSuite) TearDownTest() {
|
||||||
ACTUAL TESTS
|
ACTUAL TESTS
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
TESTING: AccountCreatePOSTHandler
|
||||||
|
*/
|
||||||
|
|
||||||
// TestAccountCreatePOSTHandlerSuccessful checks the happy path for an account creation request: all the fields provided are valid,
|
// TestAccountCreatePOSTHandlerSuccessful checks the happy path for an account creation request: all the fields provided are valid,
|
||||||
// and at the end of it a new user and account should be added into the database.
|
// and at the end of it a new user and account should be added into the database.
|
||||||
//
|
//
|
||||||
|
|
@ -455,6 +481,58 @@ func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerInsufficientReason()
|
||||||
assert.Equal(suite.T(), `{"error":"reason should be at least 40 chars but 'just cuz' was 8"}`, string(b))
|
assert.Equal(suite.T(), `{"error":"reason should be at least 40 chars but 'just cuz' was 8"}`, string(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
TESTING: AccountUpdateCredentialsPATCHHandler
|
||||||
|
*/
|
||||||
|
|
||||||
|
func (suite *AccountTestSuite) TestAccountUpdateCredentialsPATCHHandler() {
|
||||||
|
|
||||||
|
// put test local account in db
|
||||||
|
err := suite.db.Put(suite.testAccountLocal)
|
||||||
|
assert.NoError(suite.T(), err)
|
||||||
|
|
||||||
|
// attach avatar to request
|
||||||
|
aviFile, err := os.Open("../../media/test/test-jpeg.jpg")
|
||||||
|
assert.NoError(suite.T(), err)
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
writer := multipart.NewWriter(body)
|
||||||
|
|
||||||
|
part, err := writer.CreateFormFile("avatar", "test-jpeg.jpg")
|
||||||
|
assert.NoError(suite.T(), err)
|
||||||
|
|
||||||
|
_, err = io.Copy(part, aviFile)
|
||||||
|
assert.NoError(suite.T(), err)
|
||||||
|
|
||||||
|
err = aviFile.Close()
|
||||||
|
assert.NoError(suite.T(), err)
|
||||||
|
|
||||||
|
err = writer.Close()
|
||||||
|
assert.NoError(suite.T(), err)
|
||||||
|
|
||||||
|
// setup
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(recorder)
|
||||||
|
ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccountLocal)
|
||||||
|
ctx.Set(oauth.SessionAuthorizedToken, suite.testToken)
|
||||||
|
ctx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:8080/%s", updateCredentialsPath), body) // the endpoint we're hitting
|
||||||
|
ctx.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
suite.accountModule.accountUpdateCredentialsPATCHHandler(ctx)
|
||||||
|
|
||||||
|
// check response
|
||||||
|
|
||||||
|
// 1. we should have OK because our request was valid
|
||||||
|
suite.EqualValues(http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
// 2. we should have an error message in the result body
|
||||||
|
result := recorder.Result()
|
||||||
|
defer result.Body.Close()
|
||||||
|
// TODO: implement proper checks here
|
||||||
|
//
|
||||||
|
// b, err := ioutil.ReadAll(result.Body)
|
||||||
|
// assert.NoError(suite.T(), err)
|
||||||
|
// assert.Equal(suite.T(), `{"error":"not authorized"}`, string(b))
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccountTestSuite(t *testing.T) {
|
func TestAccountTestSuite(t *testing.T) {
|
||||||
suite.Run(t, new(AccountTestSuite))
|
suite.Run(t, new(AccountTestSuite))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -92,40 +92,40 @@ type AccountCreateRequest struct {
|
||||||
// See https://docs.joinmastodon.org/methods/accounts/
|
// See https://docs.joinmastodon.org/methods/accounts/
|
||||||
type UpdateCredentialsRequest struct {
|
type UpdateCredentialsRequest struct {
|
||||||
// Whether the account should be shown in the profile directory.
|
// Whether the account should be shown in the profile directory.
|
||||||
Discoverable string `form:"discoverable"`
|
Discoverable *bool `form:"discoverable"`
|
||||||
// Whether the account has a bot flag.
|
// Whether the account has a bot flag.
|
||||||
Bot bool `form:"bot"`
|
Bot *bool `form:"bot"`
|
||||||
// The display name to use for the profile.
|
// The display name to use for the profile.
|
||||||
DisplayName string `form:"display_name"`
|
DisplayName *string `form:"display_name"`
|
||||||
// The account bio.
|
// The account bio.
|
||||||
Note string `form:"note"`
|
Note *string `form:"note"`
|
||||||
// Avatar image encoded using multipart/form-data
|
// Avatar image encoded using multipart/form-data
|
||||||
Avatar *multipart.FileHeader `form:"avatar"`
|
Avatar *multipart.FileHeader `form:"avatar"`
|
||||||
// Header image encoded using multipart/form-data
|
// Header image encoded using multipart/form-data
|
||||||
Header *multipart.FileHeader `form:"header"`
|
Header *multipart.FileHeader `form:"header"`
|
||||||
// Whether manual approval of follow requests is required.
|
// Whether manual approval of follow requests is required.
|
||||||
Locked bool `form:"locked"`
|
Locked *bool `form:"locked"`
|
||||||
// New Source values for this account
|
// New Source values for this account
|
||||||
Source *UpdateSource `form:"source"`
|
Source *UpdateSource `form:"source"`
|
||||||
// Profile metadata name and value
|
// Profile metadata name and value
|
||||||
FieldsAttributes []UpdateField `form:"fields_attributes"`
|
FieldsAttributes *[]UpdateField `form:"fields_attributes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSource is to be used specifically in an UpdateCredentialsRequest.
|
// UpdateSource is to be used specifically in an UpdateCredentialsRequest.
|
||||||
type UpdateSource struct {
|
type UpdateSource struct {
|
||||||
// Default post privacy for authored statuses.
|
// Default post privacy for authored statuses.
|
||||||
Privacy string `form:"privacy"`
|
Privacy *string `form:"privacy"`
|
||||||
// Whether to mark authored statuses as sensitive by default.
|
// Whether to mark authored statuses as sensitive by default.
|
||||||
Sensitive bool `form:"sensitive"`
|
Sensitive *bool `form:"sensitive"`
|
||||||
// Default language to use for authored statuses. (ISO 6391)
|
// Default language to use for authored statuses. (ISO 6391)
|
||||||
Language string `form:"language"`
|
Language *string `form:"language"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateField is to be used specifically in an UpdateCredentialsRequest.
|
// UpdateField is to be used specifically in an UpdateCredentialsRequest.
|
||||||
// By default, max 4 fields and 255 characters per property/value.
|
// By default, max 4 fields and 255 characters per property/value.
|
||||||
type UpdateField struct {
|
type UpdateField struct {
|
||||||
// Name of the field
|
// Name of the field
|
||||||
Name string `form:"name"`
|
Name *string `form:"name"`
|
||||||
// Value of the field
|
// Value of the field
|
||||||
Value string `form:"value"`
|
Value *string `form:"value"`
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue