mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 12:12:25 -05:00 
			
		
		
		
	account update nearly working
This commit is contained in:
		
					parent
					
						
							
								362ccf5817
							
						
					
				
			
			
				commit
				
					
						c8ff849a02
					
				
			
		
					 5 changed files with 337 additions and 76 deletions
				
			
		|  | @ -92,6 +92,9 @@ type DB interface { | ||||||
| 	// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. | 	// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. | ||||||
| 	UpdateByID(id string, i interface{}) error | 	UpdateByID(id string, i interface{}) error | ||||||
| 
 | 
 | ||||||
|  | 	// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value. | ||||||
|  | 	UpdateOneByID(id string, key string, value interface{}, i interface{}) error | ||||||
|  | 
 | ||||||
| 	// DeleteByID removes i with id id. | 	// DeleteByID removes i with id id. | ||||||
| 	// If i didn't exist anyway, then no error should be returned. | 	// If i didn't exist anyway, then no error should be returned. | ||||||
| 	DeleteByID(id string, i interface{}) error | 	DeleteByID(id string, i interface{}) error | ||||||
|  | @ -156,7 +159,15 @@ type DB interface { | ||||||
| 	NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error) | 	NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error) | ||||||
| 
 | 
 | ||||||
| 	// SetHeaderOrAvatarForAccountID sets the header or avatar for the given accountID to the given media attachment. | 	// SetHeaderOrAvatarForAccountID sets the header or avatar for the given accountID to the given media attachment. | ||||||
| 	SetHeaderOrAvatarForAccountID(mediaAttachmen *model.MediaAttachment, accountID string) error | 	SetHeaderOrAvatarForAccountID(mediaAttachment *model.MediaAttachment, accountID string) error | ||||||
|  | 
 | ||||||
|  | 	// GetHeaderAvatarForAccountID gets the current avatar for the given account ID. | ||||||
|  | 	// The passed mediaAttachment pointer will be populated with the value of the avatar, if it exists. | ||||||
|  | 	GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error | ||||||
|  | 
 | ||||||
|  | 	// GetHeaderForAccountID gets the current header for the given account ID. | ||||||
|  | 	// The passed mediaAttachment pointer will be populated with the value of the header, if it exists. | ||||||
|  | 	GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error | ||||||
| 
 | 
 | ||||||
| 	/* | 	/* | ||||||
| 		USEFUL CONVERSION FUNCTIONS | 		USEFUL CONVERSION FUNCTIONS | ||||||
|  |  | ||||||
|  | @ -274,6 +274,11 @@ func (ps *postgresService) UpdateByID(id string, i interface{}) error { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error { | ||||||
|  | 	_, err := ps.conn.Model(i).Set("? = ?", key, value).Where("id = ?", id).Update() | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (ps *postgresService) DeleteByID(id string, i interface{}) error { | func (ps *postgresService) DeleteByID(id string, i interface{}) error { | ||||||
| 	if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil { | 	if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil { | ||||||
| 		if err == pg.ErrNoRows { | 		if err == pg.ErrNoRows { | ||||||
|  | @ -468,6 +473,26 @@ func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *model. | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (ps *postgresService) GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error { | ||||||
|  | 	if err := ps.conn.Model(header).Where("account_id = ?", accountID).Where("header = ?", true).Select(); err != nil { | ||||||
|  | 		if err == pg.ErrNoRows { | ||||||
|  | 			return ErrNoEntries{} | ||||||
|  | 		} | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (ps *postgresService) GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error { | ||||||
|  | 	if err := ps.conn.Model(avatar).Where("account_id = ?", accountID).Where("avatar = ?", true).Select(); err != nil { | ||||||
|  | 		if err == pg.ErrNoRows { | ||||||
|  | 			return ErrNoEntries{} | ||||||
|  | 		} | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| /* | /* | ||||||
| 	CONVERSION FUNCTIONS | 	CONVERSION FUNCTIONS | ||||||
| */ | */ | ||||||
|  | @ -478,18 +503,6 @@ func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *model. | ||||||
| // that the account actually belongs to. | // that the account actually belongs to. | ||||||
| func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotypes.Account, error) { | func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotypes.Account, error) { | ||||||
| 
 | 
 | ||||||
| 	fields := []mastotypes.Field{} |  | ||||||
| 	for _, f := range a.Fields { |  | ||||||
| 		mField := mastotypes.Field{ |  | ||||||
| 			Name:  f.Name, |  | ||||||
| 			Value: f.Value, |  | ||||||
| 		} |  | ||||||
| 		if !f.VerifiedAt.IsZero() { |  | ||||||
| 			mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339) |  | ||||||
| 		} |  | ||||||
| 		fields = append(fields, mField) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// count followers | 	// count followers | ||||||
| 	followers := []model.Follow{} | 	followers := []model.Follow{} | ||||||
| 	if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil { | 	if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil { | ||||||
|  | @ -538,6 +551,39 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype | ||||||
| 		lastStatusAt = lastStatus.CreatedAt.Format(time.RFC3339) | 		lastStatusAt = lastStatus.CreatedAt.Format(time.RFC3339) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// build the avatar and header URLs | ||||||
|  | 	avi := &model.MediaAttachment{} | ||||||
|  | 	if err := ps.GetAvatarForAccountID(avi, a.ID); err != nil { | ||||||
|  | 		if _, ok := err.(ErrNoEntries); !ok { | ||||||
|  | 			return nil, fmt.Errorf("error getting avatar: %s", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	aviURL := avi.File.Path | ||||||
|  | 	aviURLStatic := avi.Thumbnail.Path | ||||||
|  | 
 | ||||||
|  | 	header := &model.MediaAttachment{} | ||||||
|  | 	if err := ps.GetHeaderForAccountID(avi, a.ID); err != nil { | ||||||
|  | 		if _, ok := err.(ErrNoEntries); !ok { | ||||||
|  | 			return nil, fmt.Errorf("error getting header: %s", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	headerURL := header.File.Path | ||||||
|  | 	headerURLStatic := header.Thumbnail.Path | ||||||
|  | 
 | ||||||
|  | 	// get the fields set on this account | ||||||
|  | 	fields := []mastotypes.Field{} | ||||||
|  | 	for _, f := range a.Fields { | ||||||
|  | 		mField := mastotypes.Field{ | ||||||
|  | 			Name:  f.Name, | ||||||
|  | 			Value: f.Value, | ||||||
|  | 		} | ||||||
|  | 		if !f.VerifiedAt.IsZero() { | ||||||
|  | 			mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339) | ||||||
|  | 		} | ||||||
|  | 		fields = append(fields, mField) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// check pending follow requests aimed at this account | ||||||
| 	fr := []model.FollowRequest{} | 	fr := []model.FollowRequest{} | ||||||
| 	if err := ps.GetFollowRequestsForAccountID(a.ID, &fr); err != nil { | 	if err := ps.GetFollowRequestsForAccountID(a.ID, &fr); err != nil { | ||||||
| 		if _, ok := err.(ErrNoEntries); !ok { | 		if _, ok := err.(ErrNoEntries); !ok { | ||||||
|  | @ -549,6 +595,7 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype | ||||||
| 		frc = len(fr) | 		frc = len(fr) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// derive source from fields and other info | ||||||
| 	source := &mastotypes.Source{ | 	source := &mastotypes.Source{ | ||||||
| 		Privacy:             a.Privacy, | 		Privacy:             a.Privacy, | ||||||
| 		Sensitive:           a.Sensitive, | 		Sensitive:           a.Sensitive, | ||||||
|  | @ -567,17 +614,17 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype | ||||||
| 		Bot:            a.Bot, | 		Bot:            a.Bot, | ||||||
| 		CreatedAt:      a.CreatedAt.Format(time.RFC3339), | 		CreatedAt:      a.CreatedAt.Format(time.RFC3339), | ||||||
| 		Note:           a.Note, | 		Note:           a.Note, | ||||||
| 		URL:            a.URL, | 		URL:            a.URL, // TODO: set this during account creation | ||||||
| 		Avatar:         a.AvatarRemoteURL.String(), | 		Avatar:         aviURL, // TODO: build this url properly using host and protocol from config | ||||||
| 		AvatarStatic:   a.AvatarRemoteURL.String(), | 		AvatarStatic:   aviURLStatic, // TODO: build this url properly using host and protocol from config | ||||||
| 		Header:         a.HeaderRemoteURL.String(), | 		Header:         headerURL, // TODO: build this url properly using host and protocol from config | ||||||
| 		HeaderStatic:   a.HeaderRemoteURL.String(), | 		HeaderStatic:   headerURLStatic, // TODO: build this url properly using host and protocol from config | ||||||
| 		FollowersCount: followersCount, | 		FollowersCount: followersCount, | ||||||
| 		FollowingCount: followingCount, | 		FollowingCount: followingCount, | ||||||
| 		StatusesCount:  statusesCount, | 		StatusesCount:  statusesCount, | ||||||
| 		LastStatusAt:   lastStatusAt, | 		LastStatusAt:   lastStatusAt, | ||||||
| 		Source:         source, | 		Source:         source, | ||||||
| 		Emojis:         nil, | 		Emojis:         nil, // TODO: implement this | ||||||
| 		Fields:         fields, | 		Fields:         fields, | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -19,8 +19,11 @@ | ||||||
| package account | package account | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"bytes" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"mime/multipart" | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 
 | 
 | ||||||
|  | @ -39,8 +42,9 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
|  | 	idKey                 = "id" | ||||||
| 	basePath              = "/api/v1/accounts" | 	basePath              = "/api/v1/accounts" | ||||||
| 	basePathWithID        = basePath + "/:id" | 	basePathWithID        = basePath + "/:" + idKey | ||||||
| 	verifyPath            = basePath + "/verify_credentials" | 	verifyPath            = basePath + "/verify_credentials" | ||||||
| 	updateCredentialsPath = basePath + "/update_credentials" | 	updateCredentialsPath = basePath + "/update_credentials" | ||||||
| ) | ) | ||||||
|  | @ -144,6 +148,10 @@ func (m *accountModule) accountVerifyGETHandler(c *gin.Context) { | ||||||
| 
 | 
 | ||||||
| // accountUpdateCredentialsPATCHHandler allows a user to modify their account/profile settings. | // accountUpdateCredentialsPATCHHandler allows a user to modify their account/profile settings. | ||||||
| // It should be served as a PATCH at /api/v1/accounts/update_credentials | // It should be served as a PATCH at /api/v1/accounts/update_credentials | ||||||
|  | // | ||||||
|  | // TODO: this can be optimized massively by building up a picture of what we want the new account | ||||||
|  | // details to be, and then inserting it all in the database at once. As it is, we do queries one-by-one | ||||||
|  | // which is not gonna make the database very happy when lots of requests are going through. | ||||||
| func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) { | func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) { | ||||||
| 	l := m.log.WithField("func", "accountUpdateCredentialsPATCHHandler") | 	l := m.log.WithField("func", "accountUpdateCredentialsPATCHHandler") | ||||||
| 	authed, err := oauth.MustAuth(c, true, false, false, true) | 	authed, err := oauth.MustAuth(c, true, false, false, true) | ||||||
|  | @ -152,63 +160,180 @@ func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) { | ||||||
| 		c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) | 		c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  | 	l.Tracef("retrieved account %+v", authed.Account.ID) | ||||||
| 
 | 
 | ||||||
| 	l.Trace("parsing request form") | 	l.Trace("parsing request form") | ||||||
| 	form := &mastotypes.UpdateCredentialsRequest{} | 	form := &mastotypes.UpdateCredentialsRequest{} | ||||||
| 	if err := c.ShouldBind(form); err != nil || form == nil { | 	if err := c.ShouldBind(form); err != nil || form == nil { | ||||||
| 		l.Debugf("could not parse form from request: %s", err) | 		l.Debugf("could not parse form from request: %s", err) | ||||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": "missing one or more required form values"}) | 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// TODO: proper form validation | 	// if everything on the form is nil, then nothing has been set and we shouldn't continue | ||||||
| 
 | 	if form.Discoverable == nil && form.Bot == nil && form.DisplayName == nil && form.Note == nil && form.Avatar == nil && form.Header == nil && form.Locked == nil && form.Source == nil && form.FieldsAttributes == nil { | ||||||
| 	// TODO: tidy this code into subfunctions | 		l.Debugf("could not parse form from request") | ||||||
| 	if form.Header != nil && form.Header.Size != 0 { | 		c.JSON(http.StatusBadRequest, gin.H{"error": "empty form submitted"}) | ||||||
| 		if form.Header.Size > m.config.MediaConfig.MaxImageSize { | 		return | ||||||
| 			err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", form.Header.Size, m.config.MediaConfig.MaxImageSize) |  | ||||||
| 			l.Debugf("error processing header: %s", err) |  | ||||||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 		f, err := form.Header.Open() |  | ||||||
| 		if err != nil { |  | ||||||
| 			l.Debugf("error processing header: %s", err) |  | ||||||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("could not read provided header: %s", err)}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// extract the bytes |  | ||||||
| 		imageBytes := []byte{} |  | ||||||
| 		size, err := f.Read(imageBytes) |  | ||||||
| 		defer func(){ |  | ||||||
| 			if err := f.Close(); err != nil { |  | ||||||
| 				m.log.Errorf("error closing multipart file: %s", err) |  | ||||||
| 			} |  | ||||||
| 		}() |  | ||||||
| 		if err != nil || size == 0 { |  | ||||||
| 			l.Debugf("error processing header: %s", err) |  | ||||||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("could not read provided header: %s", err)}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// do the setting |  | ||||||
| 		headerInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(imageBytes, authed.Account.ID, "header") |  | ||||||
| 		if err != nil { |  | ||||||
| 			l.Debugf("error processing header: %s", err) |  | ||||||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 		l.Tracef("new header info for account %s is %+v", headerInfo) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	l.Tracef("retrieved account %+v", authed.Account.ID) | 	if form.Discoverable != nil { | ||||||
|  | 		if err := m.db.UpdateOneByID(authed.Account.ID, "discoverable", *form.Discoverable, &model.Account{}); err != nil { | ||||||
|  | 			l.Debugf("error updating discoverable: %s", err) | ||||||
|  | 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if form.Bot != nil { | ||||||
|  | 		if err := m.db.UpdateOneByID(authed.Account.ID, "bot", *form.Bot, &model.Account{}); err != nil { | ||||||
|  | 			l.Debugf("error updating bot: %s", err) | ||||||
|  | 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if form.DisplayName != nil { | ||||||
|  | 		if err := m.db.UpdateOneByID(authed.Account.ID, "display_name", *form.DisplayName, &model.Account{}); err != nil { | ||||||
|  | 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if form.Note != nil { | ||||||
|  | 		if err := m.db.UpdateOneByID(authed.Account.ID, "note", *form.Note, &model.Account{}); err != nil { | ||||||
|  | 			l.Debugf("error updating note: %s", err) | ||||||
|  | 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if form.Avatar != nil && form.Avatar.Size != 0 { | ||||||
|  | 		avatarInfo, err := m.UpdateAccountAvatar(form.Avatar, authed.Account.ID) | ||||||
|  | 		if err != nil { | ||||||
|  | 			l.Debugf("could not update avatar for account %s: %s", authed.Account.ID, err) | ||||||
|  | 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		l.Tracef("new avatar info for account %s is %+v", authed.Account.ID, avatarInfo) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if form.Header != nil && form.Header.Size != 0 { | ||||||
|  | 		headerInfo, err := m.UpdateAccountHeader(form.Header, authed.Account.ID) | ||||||
|  | 		if err != nil { | ||||||
|  | 			l.Debugf("could not update header for account %s: %s", authed.Account.ID, err) | ||||||
|  | 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		l.Tracef("new header info for account %s is %+v", authed.Account.ID, headerInfo) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if form.Locked != nil { | ||||||
|  | 		if err := m.db.UpdateOneByID(authed.Account.ID, "locked", *form.Locked, &model.Account{}); err != nil { | ||||||
|  | 			c.JSON(http.StatusInternalServerError, gin.H{"": err.Error()}) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if form.Source != nil { | ||||||
|  | 
 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if form.FieldsAttributes != nil { | ||||||
|  | 
 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// fetch the account with all updated values set | ||||||
|  | 	updatedAccount := &model.Account{} | ||||||
|  | 	if err := m.db.GetByID(authed.Account.ID, updatedAccount); err != nil { | ||||||
|  | 		l.Debugf("could not fetch updated account %s: %s", authed.Account.ID, err) | ||||||
|  | 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	acctSensitive, err := m.db.AccountToMastoSensitive(updatedAccount) | ||||||
|  | 	if err != nil { | ||||||
|  | 		l.Tracef("could not convert account into mastosensitive account: %s", err) | ||||||
|  | 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive) | ||||||
|  | 	c.JSON(http.StatusOK, acctSensitive) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* | /* | ||||||
| 	HELPER FUNCTIONS | 	HELPER FUNCTIONS | ||||||
| */ | */ | ||||||
| 
 | 
 | ||||||
|  | // TODO: try to combine the below two functions because this is a lot of code repetition. | ||||||
|  | 
 | ||||||
|  | // UpdateAccountAvatar does the dirty work of checking the avatar part of an account update form, | ||||||
|  | // parsing and checking the image, and doing the necessary updates in the database for this to become | ||||||
|  | // the account's new avatar image. | ||||||
|  | func (m *accountModule) UpdateAccountAvatar(avatar *multipart.FileHeader, accountID string) (*model.MediaAttachment, error) { | ||||||
|  | 	var err error | ||||||
|  | 	if avatar.Size > m.config.MediaConfig.MaxImageSize { | ||||||
|  | 		err = fmt.Errorf("avatar with size %d exceeded max image size of %d bytes", avatar.Size, m.config.MediaConfig.MaxImageSize) | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	f, err := avatar.Open() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("could not read provided avatar: %s", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// extract the bytes | ||||||
|  | 	buf := new(bytes.Buffer) | ||||||
|  | 	size, err := io.Copy(buf, f) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("could not read provided avatar: %s", err) | ||||||
|  | 	} | ||||||
|  | 	if size == 0 { | ||||||
|  | 		return nil, errors.New("could not read provided avatar: size 0 bytes") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// do the setting | ||||||
|  | 	avatarInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(buf.Bytes(), accountID, "avatar") | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("error processing avatar: %s", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return avatarInfo, f.Close() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // UpdateAccountHeader does the dirty work of checking the header part of an account update form, | ||||||
|  | // parsing and checking the image, and doing the necessary updates in the database for this to become | ||||||
|  | // the account's new header image. | ||||||
|  | func (m *accountModule) UpdateAccountHeader(header *multipart.FileHeader, accountID string) (*model.MediaAttachment, error) { | ||||||
|  | 	var err error | ||||||
|  | 	if header.Size > m.config.MediaConfig.MaxImageSize { | ||||||
|  | 		err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", header.Size, m.config.MediaConfig.MaxImageSize) | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	f, err := header.Open() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("could not read provided header: %s", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// extract the bytes | ||||||
|  | 	buf := new(bytes.Buffer) | ||||||
|  | 	size, err := io.Copy(buf, f) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("could not read provided header: %s", err) | ||||||
|  | 	} | ||||||
|  | 	if size == 0 { | ||||||
|  | 		return nil, errors.New("could not read provided header: size 0 bytes") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// do the setting | ||||||
|  | 	headerInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(buf.Bytes(), accountID, "header") | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("error processing header: %s", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return headerInfo, f.Close() | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // accountCreate does the dirty work of making an account and user in the database. | // accountCreate does the dirty work of making an account and user in the database. | ||||||
| // It then returns a token to the caller, for use with the new account, as per the | // It then returns a token to the caller, for use with the new account, as per the | ||||||
| // spec here: https://docs.joinmastodon.org/methods/accounts/ | // spec here: https://docs.joinmastodon.org/methods/accounts/ | ||||||
|  |  | ||||||
|  | @ -19,13 +19,17 @@ | ||||||
| package account | package account | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"bytes" | ||||||
| 	"context" | 	"context" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"io" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
|  | 	"mime/multipart" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"net/url" | 	"net/url" | ||||||
|  | 	"os" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | @ -40,6 +44,7 @@ import ( | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db/model" | 	"github.com/superseriousbusiness/gotosocial/internal/db/model" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/media" | 	"github.com/superseriousbusiness/gotosocial/internal/media" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | 	"github.com/superseriousbusiness/gotosocial/internal/oauth" | ||||||
|  | 	"github.com/superseriousbusiness/gotosocial/internal/storage" | ||||||
| 	"github.com/superseriousbusiness/gotosocial/pkg/mastotypes" | 	"github.com/superseriousbusiness/gotosocial/pkg/mastotypes" | ||||||
| 	"github.com/superseriousbusiness/oauth2/v4" | 	"github.com/superseriousbusiness/oauth2/v4" | ||||||
| 	"github.com/superseriousbusiness/oauth2/v4/models" | 	"github.com/superseriousbusiness/oauth2/v4/models" | ||||||
|  | @ -57,7 +62,8 @@ type AccountTestSuite struct { | ||||||
| 	testApplication      *model.Application | 	testApplication      *model.Application | ||||||
| 	testToken            oauth2.TokenInfo | 	testToken            oauth2.TokenInfo | ||||||
| 	mockOauthServer      *oauth.MockServer | 	mockOauthServer      *oauth.MockServer | ||||||
| 	mockMediaHandler     *media.MockMediaHandler | 	mockStorage          *storage.MockStorage | ||||||
|  | 	mediaHandler         media.MediaHandler | ||||||
| 	db                   db.DB | 	db                   db.DB | ||||||
| 	accountModule        *accountModule | 	accountModule        *accountModule | ||||||
| 	newUserFormHappyPath url.Values | 	newUserFormHappyPath url.Values | ||||||
|  | @ -74,6 +80,11 @@ func (suite *AccountTestSuite) SetupSuite() { | ||||||
| 	log.SetLevel(logrus.TraceLevel) | 	log.SetLevel(logrus.TraceLevel) | ||||||
| 	suite.log = log | 	suite.log = log | ||||||
| 
 | 
 | ||||||
|  | 	suite.testAccountLocal = &model.Account{ | ||||||
|  | 		ID:       uuid.NewString(), | ||||||
|  | 		Username: "test_user", | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// can use this test application throughout | 	// can use this test application throughout | ||||||
| 	suite.testApplication = &model.Application{ | 	suite.testApplication = &model.Application{ | ||||||
| 		ID:           "weeweeeeeeeeeeeeee", | 		ID:           "weeweeeeeeeeeeeeee", | ||||||
|  | @ -107,6 +118,9 @@ func (suite *AccountTestSuite) SetupSuite() { | ||||||
| 		Database:        "postgres", | 		Database:        "postgres", | ||||||
| 		ApplicationName: "gotosocial", | 		ApplicationName: "gotosocial", | ||||||
| 	} | 	} | ||||||
|  | 	c.MediaConfig = &config.MediaConfig{ | ||||||
|  | 		MaxImageSize: 2 << 20, | ||||||
|  | 	} | ||||||
| 	suite.config = c | 	suite.config = c | ||||||
| 
 | 
 | ||||||
| 	// use an actual database for this, because it's just easier than mocking one out | 	// use an actual database for this, because it's just easier than mocking one out | ||||||
|  | @ -130,11 +144,15 @@ func (suite *AccountTestSuite) SetupSuite() { | ||||||
| 		Code: "we're authorized now!", | 		Code: "we're authorized now!", | ||||||
| 	}, nil) | 	}, nil) | ||||||
| 
 | 
 | ||||||
| 	// mock the media handler because some handlers (eg update credentials) need to upload media (new header/avatar) | 	suite.mockStorage = &storage.MockStorage{} | ||||||
| 	suite.mockMediaHandler = &media.MockMediaHandler{} | 	// We don't need storage to do anything for these tests, so just simulate a success and do nothing -- we won't need to return anything from storage | ||||||
|  | 	suite.mockStorage.On("StoreFileAt", mock.AnythingOfType("string"), mock.AnythingOfType("[]uint8")).Return(nil) | ||||||
|  | 
 | ||||||
|  | 	// set a media handler because some handlers (eg update credentials) need to upload media (new header/avatar) | ||||||
|  | 	suite.mediaHandler = media.New(suite.config, suite.db, suite.mockStorage, log) | ||||||
| 
 | 
 | ||||||
| 	// and finally here's the thing we're actually testing! | 	// and finally here's the thing we're actually testing! | ||||||
| 	suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mockMediaHandler, suite.log).(*accountModule) | 	suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mediaHandler, suite.log).(*accountModule) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *AccountTestSuite) TearDownSuite() { | func (suite *AccountTestSuite) TearDownSuite() { | ||||||
|  | @ -150,9 +168,11 @@ func (suite *AccountTestSuite) SetupTest() { | ||||||
| 		&model.User{}, | 		&model.User{}, | ||||||
| 		&model.Account{}, | 		&model.Account{}, | ||||||
| 		&model.Follow{}, | 		&model.Follow{}, | ||||||
|  | 		&model.FollowRequest{}, | ||||||
| 		&model.Status{}, | 		&model.Status{}, | ||||||
| 		&model.Application{}, | 		&model.Application{}, | ||||||
| 		&model.EmailDomainBlock{}, | 		&model.EmailDomainBlock{}, | ||||||
|  | 		&model.MediaAttachment{}, | ||||||
| 	} | 	} | ||||||
| 	for _, m := range models { | 	for _, m := range models { | ||||||
| 		if err := suite.db.CreateTable(m); err != nil { | 		if err := suite.db.CreateTable(m); err != nil { | ||||||
|  | @ -186,9 +206,11 @@ func (suite *AccountTestSuite) TearDownTest() { | ||||||
| 		&model.User{}, | 		&model.User{}, | ||||||
| 		&model.Account{}, | 		&model.Account{}, | ||||||
| 		&model.Follow{}, | 		&model.Follow{}, | ||||||
|  | 		&model.FollowRequest{}, | ||||||
| 		&model.Status{}, | 		&model.Status{}, | ||||||
| 		&model.Application{}, | 		&model.Application{}, | ||||||
| 		&model.EmailDomainBlock{}, | 		&model.EmailDomainBlock{}, | ||||||
|  | 		&model.MediaAttachment{}, | ||||||
| 	} | 	} | ||||||
| 	for _, m := range models { | 	for _, m := range models { | ||||||
| 		if err := suite.db.DropTable(m); err != nil { | 		if err := suite.db.DropTable(m); err != nil { | ||||||
|  | @ -201,6 +223,10 @@ func (suite *AccountTestSuite) TearDownTest() { | ||||||
| 	ACTUAL TESTS | 	ACTUAL TESTS | ||||||
| */ | */ | ||||||
| 
 | 
 | ||||||
|  | /* | ||||||
|  | 	TESTING: AccountCreatePOSTHandler | ||||||
|  | */ | ||||||
|  | 
 | ||||||
| // TestAccountCreatePOSTHandlerSuccessful checks the happy path for an account creation request: all the fields provided are valid, | // TestAccountCreatePOSTHandlerSuccessful checks the happy path for an account creation request: all the fields provided are valid, | ||||||
| // and at the end of it a new user and account should be added into the database. | // and at the end of it a new user and account should be added into the database. | ||||||
| // | // | ||||||
|  | @ -455,6 +481,58 @@ func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerInsufficientReason() | ||||||
| 	assert.Equal(suite.T(), `{"error":"reason should be at least 40 chars but 'just cuz' was 8"}`, string(b)) | 	assert.Equal(suite.T(), `{"error":"reason should be at least 40 chars but 'just cuz' was 8"}`, string(b)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | /* | ||||||
|  | 	TESTING: AccountUpdateCredentialsPATCHHandler | ||||||
|  | */ | ||||||
|  | 
 | ||||||
|  | func (suite *AccountTestSuite) TestAccountUpdateCredentialsPATCHHandler() { | ||||||
|  | 
 | ||||||
|  | 	// put test local account in db | ||||||
|  | 	err := suite.db.Put(suite.testAccountLocal) | ||||||
|  | 	assert.NoError(suite.T(), err) | ||||||
|  | 
 | ||||||
|  | 	// attach avatar to request | ||||||
|  | 	aviFile, err := os.Open("../../media/test/test-jpeg.jpg") | ||||||
|  | 	assert.NoError(suite.T(), err) | ||||||
|  | 	body := &bytes.Buffer{} | ||||||
|  | 	writer := multipart.NewWriter(body) | ||||||
|  | 
 | ||||||
|  | 	part, err := writer.CreateFormFile("avatar", "test-jpeg.jpg") | ||||||
|  | 	assert.NoError(suite.T(), err) | ||||||
|  | 
 | ||||||
|  | 	_, err = io.Copy(part, aviFile) | ||||||
|  | 	assert.NoError(suite.T(), err) | ||||||
|  | 
 | ||||||
|  | 	err = aviFile.Close() | ||||||
|  | 	assert.NoError(suite.T(), err) | ||||||
|  | 
 | ||||||
|  | 	err = writer.Close() | ||||||
|  | 	assert.NoError(suite.T(), err) | ||||||
|  | 
 | ||||||
|  | 	// setup | ||||||
|  | 	recorder := httptest.NewRecorder() | ||||||
|  | 	ctx, _ := gin.CreateTestContext(recorder) | ||||||
|  | 	ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccountLocal) | ||||||
|  | 	ctx.Set(oauth.SessionAuthorizedToken, suite.testToken) | ||||||
|  | 	ctx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:8080/%s", updateCredentialsPath), body) // the endpoint we're hitting | ||||||
|  | 	ctx.Request.Header.Set("Content-Type", writer.FormDataContentType()) | ||||||
|  | 	suite.accountModule.accountUpdateCredentialsPATCHHandler(ctx) | ||||||
|  | 
 | ||||||
|  | 	// check response | ||||||
|  | 
 | ||||||
|  | 	// 1. we should have OK because our request was valid | ||||||
|  | 	suite.EqualValues(http.StatusOK, recorder.Code) | ||||||
|  | 
 | ||||||
|  | 	// 2. we should have an error message in the result body | ||||||
|  | 	result := recorder.Result() | ||||||
|  | 	defer result.Body.Close() | ||||||
|  | 	// TODO: implement proper checks here | ||||||
|  | 	// | ||||||
|  | 	// b, err := ioutil.ReadAll(result.Body) | ||||||
|  | 	// assert.NoError(suite.T(), err) | ||||||
|  | 	// assert.Equal(suite.T(), `{"error":"not authorized"}`, string(b)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestAccountTestSuite(t *testing.T) { | func TestAccountTestSuite(t *testing.T) { | ||||||
| 	suite.Run(t, new(AccountTestSuite)) | 	suite.Run(t, new(AccountTestSuite)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -92,40 +92,40 @@ type AccountCreateRequest struct { | ||||||
| // See https://docs.joinmastodon.org/methods/accounts/ | // See https://docs.joinmastodon.org/methods/accounts/ | ||||||
| type UpdateCredentialsRequest struct { | type UpdateCredentialsRequest struct { | ||||||
| 	// Whether the account should be shown in the profile directory. | 	// Whether the account should be shown in the profile directory. | ||||||
| 	Discoverable string `form:"discoverable"` | 	Discoverable *bool `form:"discoverable"` | ||||||
| 	// Whether the account has a bot flag. | 	// Whether the account has a bot flag. | ||||||
| 	Bot bool `form:"bot"` | 	Bot *bool `form:"bot"` | ||||||
| 	// The display name to use for the profile. | 	// The display name to use for the profile. | ||||||
| 	DisplayName string `form:"display_name"` | 	DisplayName *string `form:"display_name"` | ||||||
| 	// The account bio. | 	// The account bio. | ||||||
| 	Note string `form:"note"` | 	Note *string `form:"note"` | ||||||
| 	// Avatar image encoded using multipart/form-data | 	// Avatar image encoded using multipart/form-data | ||||||
| 	Avatar *multipart.FileHeader `form:"avatar"` | 	Avatar *multipart.FileHeader `form:"avatar"` | ||||||
| 	// Header image encoded using multipart/form-data | 	// Header image encoded using multipart/form-data | ||||||
| 	Header *multipart.FileHeader `form:"header"` | 	Header *multipart.FileHeader `form:"header"` | ||||||
| 	// Whether manual approval of follow requests is required. | 	// Whether manual approval of follow requests is required. | ||||||
| 	Locked bool `form:"locked"` | 	Locked *bool `form:"locked"` | ||||||
| 	// New Source values for this account | 	// New Source values for this account | ||||||
| 	Source *UpdateSource `form:"source"` | 	Source *UpdateSource `form:"source"` | ||||||
| 	// Profile metadata name and value | 	// Profile metadata name and value | ||||||
| 	FieldsAttributes []UpdateField `form:"fields_attributes"` | 	FieldsAttributes *[]UpdateField `form:"fields_attributes"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // UpdateSource is to be used specifically in an UpdateCredentialsRequest. | // UpdateSource is to be used specifically in an UpdateCredentialsRequest. | ||||||
| type UpdateSource struct { | type UpdateSource struct { | ||||||
| 	// Default post privacy for authored statuses. | 	// Default post privacy for authored statuses. | ||||||
| 	Privacy string `form:"privacy"` | 	Privacy *string `form:"privacy"` | ||||||
| 	// Whether to mark authored statuses as sensitive by default. | 	// Whether to mark authored statuses as sensitive by default. | ||||||
| 	Sensitive bool `form:"sensitive"` | 	Sensitive *bool `form:"sensitive"` | ||||||
| 	// Default language to use for authored statuses. (ISO 6391) | 	// Default language to use for authored statuses. (ISO 6391) | ||||||
| 	Language string `form:"language"` | 	Language *string `form:"language"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // UpdateField is to be used specifically in an UpdateCredentialsRequest. | // UpdateField is to be used specifically in an UpdateCredentialsRequest. | ||||||
| // By default, max 4 fields and 255 characters per property/value. | // By default, max 4 fields and 255 characters per property/value. | ||||||
| type UpdateField struct { | type UpdateField struct { | ||||||
| 	// Name of the field | 	// Name of the field | ||||||
| 	Name string `form:"name"` | 	Name *string `form:"name"` | ||||||
| 	// Value of the field | 	// Value of the field | ||||||
| 	Value string `form:"value"` | 	Value *string `form:"value"` | ||||||
| } | } | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue