mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 14:32:24 -05:00 
			
		
		
		
	account update nearly working
This commit is contained in:
		
					parent
					
						
							
								362ccf5817
							
						
					
				
			
			
				commit
				
					
						c8ff849a02
					
				
			
		
					 5 changed files with 337 additions and 76 deletions
				
			
		|  | @ -92,6 +92,9 @@ type DB interface { | |||
| 	// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. | ||||
| 	UpdateByID(id string, i interface{}) error | ||||
| 
 | ||||
| 	// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value. | ||||
| 	UpdateOneByID(id string, key string, value interface{}, i interface{}) error | ||||
| 
 | ||||
| 	// DeleteByID removes i with id id. | ||||
| 	// If i didn't exist anyway, then no error should be returned. | ||||
| 	DeleteByID(id string, i interface{}) error | ||||
|  | @ -156,7 +159,15 @@ type DB interface { | |||
| 	NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error) | ||||
| 
 | ||||
| 	// SetHeaderOrAvatarForAccountID sets the header or avatar for the given accountID to the given media attachment. | ||||
| 	SetHeaderOrAvatarForAccountID(mediaAttachmen *model.MediaAttachment, accountID string) error | ||||
| 	SetHeaderOrAvatarForAccountID(mediaAttachment *model.MediaAttachment, accountID string) error | ||||
| 
 | ||||
| 	// GetHeaderAvatarForAccountID gets the current avatar for the given account ID. | ||||
| 	// The passed mediaAttachment pointer will be populated with the value of the avatar, if it exists. | ||||
| 	GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error | ||||
| 
 | ||||
| 	// GetHeaderForAccountID gets the current header for the given account ID. | ||||
| 	// The passed mediaAttachment pointer will be populated with the value of the header, if it exists. | ||||
| 	GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error | ||||
| 
 | ||||
| 	/* | ||||
| 		USEFUL CONVERSION FUNCTIONS | ||||
|  |  | |||
|  | @ -274,6 +274,11 @@ func (ps *postgresService) UpdateByID(id string, i interface{}) error { | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error { | ||||
| 	_, err := ps.conn.Model(i).Set("? = ?", key, value).Where("id = ?", id).Update() | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (ps *postgresService) DeleteByID(id string, i interface{}) error { | ||||
| 	if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil { | ||||
| 		if err == pg.ErrNoRows { | ||||
|  | @ -468,6 +473,26 @@ func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *model. | |||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (ps *postgresService) GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error { | ||||
| 	if err := ps.conn.Model(header).Where("account_id = ?", accountID).Where("header = ?", true).Select(); err != nil { | ||||
| 		if err == pg.ErrNoRows { | ||||
| 			return ErrNoEntries{} | ||||
| 		} | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (ps *postgresService) GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error { | ||||
| 	if err := ps.conn.Model(avatar).Where("account_id = ?", accountID).Where("avatar = ?", true).Select(); err != nil { | ||||
| 		if err == pg.ErrNoRows { | ||||
| 			return ErrNoEntries{} | ||||
| 		} | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| /* | ||||
| 	CONVERSION FUNCTIONS | ||||
| */ | ||||
|  | @ -478,18 +503,6 @@ func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *model. | |||
| // that the account actually belongs to. | ||||
| func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotypes.Account, error) { | ||||
| 
 | ||||
| 	fields := []mastotypes.Field{} | ||||
| 	for _, f := range a.Fields { | ||||
| 		mField := mastotypes.Field{ | ||||
| 			Name:  f.Name, | ||||
| 			Value: f.Value, | ||||
| 		} | ||||
| 		if !f.VerifiedAt.IsZero() { | ||||
| 			mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339) | ||||
| 		} | ||||
| 		fields = append(fields, mField) | ||||
| 	} | ||||
| 
 | ||||
| 	// count followers | ||||
| 	followers := []model.Follow{} | ||||
| 	if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil { | ||||
|  | @ -538,6 +551,39 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype | |||
| 		lastStatusAt = lastStatus.CreatedAt.Format(time.RFC3339) | ||||
| 	} | ||||
| 
 | ||||
| 	// build the avatar and header URLs | ||||
| 	avi := &model.MediaAttachment{} | ||||
| 	if err := ps.GetAvatarForAccountID(avi, a.ID); err != nil { | ||||
| 		if _, ok := err.(ErrNoEntries); !ok { | ||||
| 			return nil, fmt.Errorf("error getting avatar: %s", err) | ||||
| 		} | ||||
| 	} | ||||
| 	aviURL := avi.File.Path | ||||
| 	aviURLStatic := avi.Thumbnail.Path | ||||
| 
 | ||||
| 	header := &model.MediaAttachment{} | ||||
| 	if err := ps.GetHeaderForAccountID(avi, a.ID); err != nil { | ||||
| 		if _, ok := err.(ErrNoEntries); !ok { | ||||
| 			return nil, fmt.Errorf("error getting header: %s", err) | ||||
| 		} | ||||
| 	} | ||||
| 	headerURL := header.File.Path | ||||
| 	headerURLStatic := header.Thumbnail.Path | ||||
| 
 | ||||
| 	// get the fields set on this account | ||||
| 	fields := []mastotypes.Field{} | ||||
| 	for _, f := range a.Fields { | ||||
| 		mField := mastotypes.Field{ | ||||
| 			Name:  f.Name, | ||||
| 			Value: f.Value, | ||||
| 		} | ||||
| 		if !f.VerifiedAt.IsZero() { | ||||
| 			mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339) | ||||
| 		} | ||||
| 		fields = append(fields, mField) | ||||
| 	} | ||||
| 
 | ||||
| 	// check pending follow requests aimed at this account | ||||
| 	fr := []model.FollowRequest{} | ||||
| 	if err := ps.GetFollowRequestsForAccountID(a.ID, &fr); err != nil { | ||||
| 		if _, ok := err.(ErrNoEntries); !ok { | ||||
|  | @ -549,6 +595,7 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype | |||
| 		frc = len(fr) | ||||
| 	} | ||||
| 
 | ||||
| 	// derive source from fields and other info | ||||
| 	source := &mastotypes.Source{ | ||||
| 		Privacy:             a.Privacy, | ||||
| 		Sensitive:           a.Sensitive, | ||||
|  | @ -567,17 +614,17 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype | |||
| 		Bot:            a.Bot, | ||||
| 		CreatedAt:      a.CreatedAt.Format(time.RFC3339), | ||||
| 		Note:           a.Note, | ||||
| 		URL:            a.URL, | ||||
| 		Avatar:         a.AvatarRemoteURL.String(), | ||||
| 		AvatarStatic:   a.AvatarRemoteURL.String(), | ||||
| 		Header:         a.HeaderRemoteURL.String(), | ||||
| 		HeaderStatic:   a.HeaderRemoteURL.String(), | ||||
| 		URL:            a.URL, // TODO: set this during account creation | ||||
| 		Avatar:         aviURL, // TODO: build this url properly using host and protocol from config | ||||
| 		AvatarStatic:   aviURLStatic, // TODO: build this url properly using host and protocol from config | ||||
| 		Header:         headerURL, // TODO: build this url properly using host and protocol from config | ||||
| 		HeaderStatic:   headerURLStatic, // TODO: build this url properly using host and protocol from config | ||||
| 		FollowersCount: followersCount, | ||||
| 		FollowingCount: followingCount, | ||||
| 		StatusesCount:  statusesCount, | ||||
| 		LastStatusAt:   lastStatusAt, | ||||
| 		Source:         source, | ||||
| 		Emojis:         nil, | ||||
| 		Emojis:         nil, // TODO: implement this | ||||
| 		Fields:         fields, | ||||
| 	}, nil | ||||
| } | ||||
|  |  | |||
|  | @ -19,8 +19,11 @@ | |||
| package account | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"mime/multipart" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 
 | ||||
|  | @ -39,8 +42,9 @@ import ( | |||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	idKey                 = "id" | ||||
| 	basePath              = "/api/v1/accounts" | ||||
| 	basePathWithID        = basePath + "/:id" | ||||
| 	basePathWithID        = basePath + "/:" + idKey | ||||
| 	verifyPath            = basePath + "/verify_credentials" | ||||
| 	updateCredentialsPath = basePath + "/update_credentials" | ||||
| ) | ||||
|  | @ -144,6 +148,10 @@ func (m *accountModule) accountVerifyGETHandler(c *gin.Context) { | |||
| 
 | ||||
| // accountUpdateCredentialsPATCHHandler allows a user to modify their account/profile settings. | ||||
| // It should be served as a PATCH at /api/v1/accounts/update_credentials | ||||
| // | ||||
| // TODO: this can be optimized massively by building up a picture of what we want the new account | ||||
| // details to be, and then inserting it all in the database at once. As it is, we do queries one-by-one | ||||
| // which is not gonna make the database very happy when lots of requests are going through. | ||||
| func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) { | ||||
| 	l := m.log.WithField("func", "accountUpdateCredentialsPATCHHandler") | ||||
| 	authed, err := oauth.MustAuth(c, true, false, false, true) | ||||
|  | @ -152,63 +160,180 @@ func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) { | |||
| 		c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 	l.Tracef("retrieved account %+v", authed.Account.ID) | ||||
| 
 | ||||
| 	l.Trace("parsing request form") | ||||
| 	form := &mastotypes.UpdateCredentialsRequest{} | ||||
| 	if err := c.ShouldBind(form); err != nil || form == nil { | ||||
| 		l.Debugf("could not parse form from request: %s", err) | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": "missing one or more required form values"}) | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// TODO: proper form validation | ||||
| 
 | ||||
| 	// TODO: tidy this code into subfunctions | ||||
| 	if form.Header != nil && form.Header.Size != 0 { | ||||
| 		if form.Header.Size > m.config.MediaConfig.MaxImageSize { | ||||
| 			err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", form.Header.Size, m.config.MediaConfig.MaxImageSize) | ||||
| 			l.Debugf("error processing header: %s", err) | ||||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||
| 			return | ||||
| 		} | ||||
| 		f, err := form.Header.Open() | ||||
| 		if err != nil { | ||||
| 			l.Debugf("error processing header: %s", err) | ||||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("could not read provided header: %s", err)}) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		// extract the bytes | ||||
| 		imageBytes := []byte{} | ||||
| 		size, err := f.Read(imageBytes) | ||||
| 		defer func(){ | ||||
| 			if err := f.Close(); err != nil { | ||||
| 				m.log.Errorf("error closing multipart file: %s", err) | ||||
| 			} | ||||
| 		}() | ||||
| 		if err != nil || size == 0 { | ||||
| 			l.Debugf("error processing header: %s", err) | ||||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("could not read provided header: %s", err)}) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		// do the setting | ||||
| 		headerInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(imageBytes, authed.Account.ID, "header") | ||||
| 		if err != nil { | ||||
| 			l.Debugf("error processing header: %s", err) | ||||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||
| 			return | ||||
| 		} | ||||
| 		l.Tracef("new header info for account %s is %+v", headerInfo) | ||||
| 	// if everything on the form is nil, then nothing has been set and we shouldn't continue | ||||
| 	if form.Discoverable == nil && form.Bot == nil && form.DisplayName == nil && form.Note == nil && form.Avatar == nil && form.Header == nil && form.Locked == nil && form.Source == nil && form.FieldsAttributes == nil { | ||||
| 		l.Debugf("could not parse form from request") | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": "empty form submitted"}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	l.Tracef("retrieved account %+v", authed.Account.ID) | ||||
| 	if form.Discoverable != nil { | ||||
| 		if err := m.db.UpdateOneByID(authed.Account.ID, "discoverable", *form.Discoverable, &model.Account{}); err != nil { | ||||
| 			l.Debugf("error updating discoverable: %s", err) | ||||
| 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if form.Bot != nil { | ||||
| 		if err := m.db.UpdateOneByID(authed.Account.ID, "bot", *form.Bot, &model.Account{}); err != nil { | ||||
| 			l.Debugf("error updating bot: %s", err) | ||||
| 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if form.DisplayName != nil { | ||||
| 		if err := m.db.UpdateOneByID(authed.Account.ID, "display_name", *form.DisplayName, &model.Account{}); err != nil { | ||||
| 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if form.Note != nil { | ||||
| 		if err := m.db.UpdateOneByID(authed.Account.ID, "note", *form.Note, &model.Account{}); err != nil { | ||||
| 			l.Debugf("error updating note: %s", err) | ||||
| 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if form.Avatar != nil && form.Avatar.Size != 0 { | ||||
| 		avatarInfo, err := m.UpdateAccountAvatar(form.Avatar, authed.Account.ID) | ||||
| 		if err != nil { | ||||
| 			l.Debugf("could not update avatar for account %s: %s", authed.Account.ID, err) | ||||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||
| 			return | ||||
| 		} | ||||
| 		l.Tracef("new avatar info for account %s is %+v", authed.Account.ID, avatarInfo) | ||||
| 	} | ||||
| 
 | ||||
| 	if form.Header != nil && form.Header.Size != 0 { | ||||
| 		headerInfo, err := m.UpdateAccountHeader(form.Header, authed.Account.ID) | ||||
| 		if err != nil { | ||||
| 			l.Debugf("could not update header for account %s: %s", authed.Account.ID, err) | ||||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||
| 			return | ||||
| 		} | ||||
| 		l.Tracef("new header info for account %s is %+v", authed.Account.ID, headerInfo) | ||||
| 	} | ||||
| 
 | ||||
| 	if form.Locked != nil { | ||||
| 		if err := m.db.UpdateOneByID(authed.Account.ID, "locked", *form.Locked, &model.Account{}); err != nil { | ||||
| 			c.JSON(http.StatusInternalServerError, gin.H{"": err.Error()}) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if form.Source != nil { | ||||
| 
 | ||||
| 	} | ||||
| 
 | ||||
| 	if form.FieldsAttributes != nil { | ||||
| 
 | ||||
| 	} | ||||
| 
 | ||||
| 	// fetch the account with all updated values set | ||||
| 	updatedAccount := &model.Account{} | ||||
| 	if err := m.db.GetByID(authed.Account.ID, updatedAccount); err != nil { | ||||
| 		l.Debugf("could not fetch updated account %s: %s", authed.Account.ID, err) | ||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	acctSensitive, err := m.db.AccountToMastoSensitive(updatedAccount) | ||||
| 	if err != nil { | ||||
| 		l.Tracef("could not convert account into mastosensitive account: %s", err) | ||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive) | ||||
| 	c.JSON(http.StatusOK, acctSensitive) | ||||
| } | ||||
| 
 | ||||
| /* | ||||
| 	HELPER FUNCTIONS | ||||
| */ | ||||
| 
 | ||||
| // TODO: try to combine the below two functions because this is a lot of code repetition. | ||||
| 
 | ||||
| // UpdateAccountAvatar does the dirty work of checking the avatar part of an account update form, | ||||
| // parsing and checking the image, and doing the necessary updates in the database for this to become | ||||
| // the account's new avatar image. | ||||
| func (m *accountModule) UpdateAccountAvatar(avatar *multipart.FileHeader, accountID string) (*model.MediaAttachment, error) { | ||||
| 	var err error | ||||
| 	if avatar.Size > m.config.MediaConfig.MaxImageSize { | ||||
| 		err = fmt.Errorf("avatar with size %d exceeded max image size of %d bytes", avatar.Size, m.config.MediaConfig.MaxImageSize) | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	f, err := avatar.Open() | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("could not read provided avatar: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// extract the bytes | ||||
| 	buf := new(bytes.Buffer) | ||||
| 	size, err := io.Copy(buf, f) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("could not read provided avatar: %s", err) | ||||
| 	} | ||||
| 	if size == 0 { | ||||
| 		return nil, errors.New("could not read provided avatar: size 0 bytes") | ||||
| 	} | ||||
| 
 | ||||
| 	// do the setting | ||||
| 	avatarInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(buf.Bytes(), accountID, "avatar") | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error processing avatar: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return avatarInfo, f.Close() | ||||
| } | ||||
| 
 | ||||
| // UpdateAccountHeader does the dirty work of checking the header part of an account update form, | ||||
| // parsing and checking the image, and doing the necessary updates in the database for this to become | ||||
| // the account's new header image. | ||||
| func (m *accountModule) UpdateAccountHeader(header *multipart.FileHeader, accountID string) (*model.MediaAttachment, error) { | ||||
| 	var err error | ||||
| 	if header.Size > m.config.MediaConfig.MaxImageSize { | ||||
| 		err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", header.Size, m.config.MediaConfig.MaxImageSize) | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	f, err := header.Open() | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("could not read provided header: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// extract the bytes | ||||
| 	buf := new(bytes.Buffer) | ||||
| 	size, err := io.Copy(buf, f) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("could not read provided header: %s", err) | ||||
| 	} | ||||
| 	if size == 0 { | ||||
| 		return nil, errors.New("could not read provided header: size 0 bytes") | ||||
| 	} | ||||
| 
 | ||||
| 	// do the setting | ||||
| 	headerInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(buf.Bytes(), accountID, "header") | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error processing header: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return headerInfo, f.Close() | ||||
| } | ||||
| 
 | ||||
| // accountCreate does the dirty work of making an account and user in the database. | ||||
| // It then returns a token to the caller, for use with the new account, as per the | ||||
| // spec here: https://docs.joinmastodon.org/methods/accounts/ | ||||
|  |  | |||
|  | @ -19,13 +19,17 @@ | |||
| package account | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| 	"mime/multipart" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"net/url" | ||||
| 	"os" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
|  | @ -40,6 +44,7 @@ import ( | |||
| 	"github.com/superseriousbusiness/gotosocial/internal/db/model" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/media" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/storage" | ||||
| 	"github.com/superseriousbusiness/gotosocial/pkg/mastotypes" | ||||
| 	"github.com/superseriousbusiness/oauth2/v4" | ||||
| 	"github.com/superseriousbusiness/oauth2/v4/models" | ||||
|  | @ -57,7 +62,8 @@ type AccountTestSuite struct { | |||
| 	testApplication      *model.Application | ||||
| 	testToken            oauth2.TokenInfo | ||||
| 	mockOauthServer      *oauth.MockServer | ||||
| 	mockMediaHandler     *media.MockMediaHandler | ||||
| 	mockStorage          *storage.MockStorage | ||||
| 	mediaHandler         media.MediaHandler | ||||
| 	db                   db.DB | ||||
| 	accountModule        *accountModule | ||||
| 	newUserFormHappyPath url.Values | ||||
|  | @ -74,6 +80,11 @@ func (suite *AccountTestSuite) SetupSuite() { | |||
| 	log.SetLevel(logrus.TraceLevel) | ||||
| 	suite.log = log | ||||
| 
 | ||||
| 	suite.testAccountLocal = &model.Account{ | ||||
| 		ID:       uuid.NewString(), | ||||
| 		Username: "test_user", | ||||
| 	} | ||||
| 
 | ||||
| 	// can use this test application throughout | ||||
| 	suite.testApplication = &model.Application{ | ||||
| 		ID:           "weeweeeeeeeeeeeeee", | ||||
|  | @ -107,6 +118,9 @@ func (suite *AccountTestSuite) SetupSuite() { | |||
| 		Database:        "postgres", | ||||
| 		ApplicationName: "gotosocial", | ||||
| 	} | ||||
| 	c.MediaConfig = &config.MediaConfig{ | ||||
| 		MaxImageSize: 2 << 20, | ||||
| 	} | ||||
| 	suite.config = c | ||||
| 
 | ||||
| 	// use an actual database for this, because it's just easier than mocking one out | ||||
|  | @ -130,11 +144,15 @@ func (suite *AccountTestSuite) SetupSuite() { | |||
| 		Code: "we're authorized now!", | ||||
| 	}, nil) | ||||
| 
 | ||||
| 	// mock the media handler because some handlers (eg update credentials) need to upload media (new header/avatar) | ||||
| 	suite.mockMediaHandler = &media.MockMediaHandler{} | ||||
| 	suite.mockStorage = &storage.MockStorage{} | ||||
| 	// We don't need storage to do anything for these tests, so just simulate a success and do nothing -- we won't need to return anything from storage | ||||
| 	suite.mockStorage.On("StoreFileAt", mock.AnythingOfType("string"), mock.AnythingOfType("[]uint8")).Return(nil) | ||||
| 
 | ||||
| 	// set a media handler because some handlers (eg update credentials) need to upload media (new header/avatar) | ||||
| 	suite.mediaHandler = media.New(suite.config, suite.db, suite.mockStorage, log) | ||||
| 
 | ||||
| 	// and finally here's the thing we're actually testing! | ||||
| 	suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mockMediaHandler, suite.log).(*accountModule) | ||||
| 	suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mediaHandler, suite.log).(*accountModule) | ||||
| } | ||||
| 
 | ||||
| func (suite *AccountTestSuite) TearDownSuite() { | ||||
|  | @ -150,9 +168,11 @@ func (suite *AccountTestSuite) SetupTest() { | |||
| 		&model.User{}, | ||||
| 		&model.Account{}, | ||||
| 		&model.Follow{}, | ||||
| 		&model.FollowRequest{}, | ||||
| 		&model.Status{}, | ||||
| 		&model.Application{}, | ||||
| 		&model.EmailDomainBlock{}, | ||||
| 		&model.MediaAttachment{}, | ||||
| 	} | ||||
| 	for _, m := range models { | ||||
| 		if err := suite.db.CreateTable(m); err != nil { | ||||
|  | @ -186,9 +206,11 @@ func (suite *AccountTestSuite) TearDownTest() { | |||
| 		&model.User{}, | ||||
| 		&model.Account{}, | ||||
| 		&model.Follow{}, | ||||
| 		&model.FollowRequest{}, | ||||
| 		&model.Status{}, | ||||
| 		&model.Application{}, | ||||
| 		&model.EmailDomainBlock{}, | ||||
| 		&model.MediaAttachment{}, | ||||
| 	} | ||||
| 	for _, m := range models { | ||||
| 		if err := suite.db.DropTable(m); err != nil { | ||||
|  | @ -201,6 +223,10 @@ func (suite *AccountTestSuite) TearDownTest() { | |||
| 	ACTUAL TESTS | ||||
| */ | ||||
| 
 | ||||
| /* | ||||
| 	TESTING: AccountCreatePOSTHandler | ||||
| */ | ||||
| 
 | ||||
| // TestAccountCreatePOSTHandlerSuccessful checks the happy path for an account creation request: all the fields provided are valid, | ||||
| // and at the end of it a new user and account should be added into the database. | ||||
| // | ||||
|  | @ -455,6 +481,58 @@ func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerInsufficientReason() | |||
| 	assert.Equal(suite.T(), `{"error":"reason should be at least 40 chars but 'just cuz' was 8"}`, string(b)) | ||||
| } | ||||
| 
 | ||||
| /* | ||||
| 	TESTING: AccountUpdateCredentialsPATCHHandler | ||||
| */ | ||||
| 
 | ||||
| func (suite *AccountTestSuite) TestAccountUpdateCredentialsPATCHHandler() { | ||||
| 
 | ||||
| 	// put test local account in db | ||||
| 	err := suite.db.Put(suite.testAccountLocal) | ||||
| 	assert.NoError(suite.T(), err) | ||||
| 
 | ||||
| 	// attach avatar to request | ||||
| 	aviFile, err := os.Open("../../media/test/test-jpeg.jpg") | ||||
| 	assert.NoError(suite.T(), err) | ||||
| 	body := &bytes.Buffer{} | ||||
| 	writer := multipart.NewWriter(body) | ||||
| 
 | ||||
| 	part, err := writer.CreateFormFile("avatar", "test-jpeg.jpg") | ||||
| 	assert.NoError(suite.T(), err) | ||||
| 
 | ||||
| 	_, err = io.Copy(part, aviFile) | ||||
| 	assert.NoError(suite.T(), err) | ||||
| 
 | ||||
| 	err = aviFile.Close() | ||||
| 	assert.NoError(suite.T(), err) | ||||
| 
 | ||||
| 	err = writer.Close() | ||||
| 	assert.NoError(suite.T(), err) | ||||
| 
 | ||||
| 	// setup | ||||
| 	recorder := httptest.NewRecorder() | ||||
| 	ctx, _ := gin.CreateTestContext(recorder) | ||||
| 	ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccountLocal) | ||||
| 	ctx.Set(oauth.SessionAuthorizedToken, suite.testToken) | ||||
| 	ctx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:8080/%s", updateCredentialsPath), body) // the endpoint we're hitting | ||||
| 	ctx.Request.Header.Set("Content-Type", writer.FormDataContentType()) | ||||
| 	suite.accountModule.accountUpdateCredentialsPATCHHandler(ctx) | ||||
| 
 | ||||
| 	// check response | ||||
| 
 | ||||
| 	// 1. we should have OK because our request was valid | ||||
| 	suite.EqualValues(http.StatusOK, recorder.Code) | ||||
| 
 | ||||
| 	// 2. we should have an error message in the result body | ||||
| 	result := recorder.Result() | ||||
| 	defer result.Body.Close() | ||||
| 	// TODO: implement proper checks here | ||||
| 	// | ||||
| 	// b, err := ioutil.ReadAll(result.Body) | ||||
| 	// assert.NoError(suite.T(), err) | ||||
| 	// assert.Equal(suite.T(), `{"error":"not authorized"}`, string(b)) | ||||
| } | ||||
| 
 | ||||
| func TestAccountTestSuite(t *testing.T) { | ||||
| 	suite.Run(t, new(AccountTestSuite)) | ||||
| } | ||||
|  |  | |||
|  | @ -92,40 +92,40 @@ type AccountCreateRequest struct { | |||
| // See https://docs.joinmastodon.org/methods/accounts/ | ||||
| type UpdateCredentialsRequest struct { | ||||
| 	// Whether the account should be shown in the profile directory. | ||||
| 	Discoverable string `form:"discoverable"` | ||||
| 	Discoverable *bool `form:"discoverable"` | ||||
| 	// Whether the account has a bot flag. | ||||
| 	Bot bool `form:"bot"` | ||||
| 	Bot *bool `form:"bot"` | ||||
| 	// The display name to use for the profile. | ||||
| 	DisplayName string `form:"display_name"` | ||||
| 	DisplayName *string `form:"display_name"` | ||||
| 	// The account bio. | ||||
| 	Note string `form:"note"` | ||||
| 	Note *string `form:"note"` | ||||
| 	// Avatar image encoded using multipart/form-data | ||||
| 	Avatar *multipart.FileHeader `form:"avatar"` | ||||
| 	// Header image encoded using multipart/form-data | ||||
| 	Header *multipart.FileHeader `form:"header"` | ||||
| 	// Whether manual approval of follow requests is required. | ||||
| 	Locked bool `form:"locked"` | ||||
| 	Locked *bool `form:"locked"` | ||||
| 	// New Source values for this account | ||||
| 	Source *UpdateSource `form:"source"` | ||||
| 	// Profile metadata name and value | ||||
| 	FieldsAttributes []UpdateField `form:"fields_attributes"` | ||||
| 	FieldsAttributes *[]UpdateField `form:"fields_attributes"` | ||||
| } | ||||
| 
 | ||||
| // UpdateSource is to be used specifically in an UpdateCredentialsRequest. | ||||
| type UpdateSource struct { | ||||
| 	// Default post privacy for authored statuses. | ||||
| 	Privacy string `form:"privacy"` | ||||
| 	Privacy *string `form:"privacy"` | ||||
| 	// Whether to mark authored statuses as sensitive by default. | ||||
| 	Sensitive bool `form:"sensitive"` | ||||
| 	Sensitive *bool `form:"sensitive"` | ||||
| 	// Default language to use for authored statuses. (ISO 6391) | ||||
| 	Language string `form:"language"` | ||||
| 	Language *string `form:"language"` | ||||
| } | ||||
| 
 | ||||
| // UpdateField is to be used specifically in an UpdateCredentialsRequest. | ||||
| // By default, max 4 fields and 255 characters per property/value. | ||||
| type UpdateField struct { | ||||
| 	// Name of the field | ||||
| 	Name string `form:"name"` | ||||
| 	Name *string `form:"name"` | ||||
| 	// Value of the field | ||||
| 	Value string `form:"value"` | ||||
| 	Value *string `form:"value"` | ||||
| } | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue