diff --git a/internal/db/db.go b/internal/db/db.go index 186a9e674..2cd9c1562 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -92,6 +92,9 @@ type DB interface { // The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice. UpdateByID(id string, i interface{}) error + // UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value. + UpdateOneByID(id string, key string, value interface{}, i interface{}) error + // DeleteByID removes i with id id. // If i didn't exist anyway, then no error should be returned. DeleteByID(id string, i interface{}) error @@ -156,7 +159,15 @@ type DB interface { NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error) // SetHeaderOrAvatarForAccountID sets the header or avatar for the given accountID to the given media attachment. - SetHeaderOrAvatarForAccountID(mediaAttachmen *model.MediaAttachment, accountID string) error + SetHeaderOrAvatarForAccountID(mediaAttachment *model.MediaAttachment, accountID string) error + + // GetHeaderAvatarForAccountID gets the current avatar for the given account ID. + // The passed mediaAttachment pointer will be populated with the value of the avatar, if it exists. + GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error + + // GetHeaderForAccountID gets the current header for the given account ID. + // The passed mediaAttachment pointer will be populated with the value of the header, if it exists. + GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error /* USEFUL CONVERSION FUNCTIONS diff --git a/internal/db/pg.go b/internal/db/pg.go index c2471416d..8d6c4a763 100644 --- a/internal/db/pg.go +++ b/internal/db/pg.go @@ -274,6 +274,11 @@ func (ps *postgresService) UpdateByID(id string, i interface{}) error { return nil } +func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error { + _, err := ps.conn.Model(i).Set("? = ?", key, value).Where("id = ?", id).Update() + return err +} + func (ps *postgresService) DeleteByID(id string, i interface{}) error { if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil { if err == pg.ErrNoRows { @@ -468,6 +473,26 @@ func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *model. return err } +func (ps *postgresService) GetHeaderForAccountID(header *model.MediaAttachment, accountID string) error { + if err := ps.conn.Model(header).Where("account_id = ?", accountID).Where("header = ?", true).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil +} + +func (ps *postgresService) GetAvatarForAccountID(avatar *model.MediaAttachment, accountID string) error { + if err := ps.conn.Model(avatar).Where("account_id = ?", accountID).Where("avatar = ?", true).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil +} + /* CONVERSION FUNCTIONS */ @@ -478,18 +503,6 @@ func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *model. // that the account actually belongs to. func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotypes.Account, error) { - fields := []mastotypes.Field{} - for _, f := range a.Fields { - mField := mastotypes.Field{ - Name: f.Name, - Value: f.Value, - } - if !f.VerifiedAt.IsZero() { - mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339) - } - fields = append(fields, mField) - } - // count followers followers := []model.Follow{} if err := ps.GetFollowersByAccountID(a.ID, &followers); err != nil { @@ -538,6 +551,39 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype lastStatusAt = lastStatus.CreatedAt.Format(time.RFC3339) } + // build the avatar and header URLs + avi := &model.MediaAttachment{} + if err := ps.GetAvatarForAccountID(avi, a.ID); err != nil { + if _, ok := err.(ErrNoEntries); !ok { + return nil, fmt.Errorf("error getting avatar: %s", err) + } + } + aviURL := avi.File.Path + aviURLStatic := avi.Thumbnail.Path + + header := &model.MediaAttachment{} + if err := ps.GetHeaderForAccountID(avi, a.ID); err != nil { + if _, ok := err.(ErrNoEntries); !ok { + return nil, fmt.Errorf("error getting header: %s", err) + } + } + headerURL := header.File.Path + headerURLStatic := header.Thumbnail.Path + + // get the fields set on this account + fields := []mastotypes.Field{} + for _, f := range a.Fields { + mField := mastotypes.Field{ + Name: f.Name, + Value: f.Value, + } + if !f.VerifiedAt.IsZero() { + mField.VerifiedAt = f.VerifiedAt.Format(time.RFC3339) + } + fields = append(fields, mField) + } + + // check pending follow requests aimed at this account fr := []model.FollowRequest{} if err := ps.GetFollowRequestsForAccountID(a.ID, &fr); err != nil { if _, ok := err.(ErrNoEntries); !ok { @@ -549,6 +595,7 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype frc = len(fr) } + // derive source from fields and other info source := &mastotypes.Source{ Privacy: a.Privacy, Sensitive: a.Sensitive, @@ -567,17 +614,17 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype Bot: a.Bot, CreatedAt: a.CreatedAt.Format(time.RFC3339), Note: a.Note, - URL: a.URL, - Avatar: a.AvatarRemoteURL.String(), - AvatarStatic: a.AvatarRemoteURL.String(), - Header: a.HeaderRemoteURL.String(), - HeaderStatic: a.HeaderRemoteURL.String(), + URL: a.URL, // TODO: set this during account creation + Avatar: aviURL, // TODO: build this url properly using host and protocol from config + AvatarStatic: aviURLStatic, // TODO: build this url properly using host and protocol from config + Header: headerURL, // TODO: build this url properly using host and protocol from config + HeaderStatic: headerURLStatic, // TODO: build this url properly using host and protocol from config FollowersCount: followersCount, FollowingCount: followingCount, StatusesCount: statusesCount, LastStatusAt: lastStatusAt, Source: source, - Emojis: nil, + Emojis: nil, // TODO: implement this Fields: fields, }, nil } diff --git a/internal/module/account/account.go b/internal/module/account/account.go index 96e6428c1..d749f7981 100644 --- a/internal/module/account/account.go +++ b/internal/module/account/account.go @@ -19,8 +19,11 @@ package account import ( + "bytes" "errors" "fmt" + "io" + "mime/multipart" "net" "net/http" @@ -39,8 +42,9 @@ import ( ) const ( + idKey = "id" basePath = "/api/v1/accounts" - basePathWithID = basePath + "/:id" + basePathWithID = basePath + "/:" + idKey verifyPath = basePath + "/verify_credentials" updateCredentialsPath = basePath + "/update_credentials" ) @@ -144,6 +148,10 @@ func (m *accountModule) accountVerifyGETHandler(c *gin.Context) { // accountUpdateCredentialsPATCHHandler allows a user to modify their account/profile settings. // It should be served as a PATCH at /api/v1/accounts/update_credentials +// +// TODO: this can be optimized massively by building up a picture of what we want the new account +// details to be, and then inserting it all in the database at once. As it is, we do queries one-by-one +// which is not gonna make the database very happy when lots of requests are going through. func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) { l := m.log.WithField("func", "accountUpdateCredentialsPATCHHandler") authed, err := oauth.MustAuth(c, true, false, false, true) @@ -152,63 +160,180 @@ func (m *accountModule) accountUpdateCredentialsPATCHHandler(c *gin.Context) { c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) return } + l.Tracef("retrieved account %+v", authed.Account.ID) l.Trace("parsing request form") form := &mastotypes.UpdateCredentialsRequest{} if err := c.ShouldBind(form); err != nil || form == nil { l.Debugf("could not parse form from request: %s", err) - c.JSON(http.StatusBadRequest, gin.H{"error": "missing one or more required form values"}) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // TODO: proper form validation - - // TODO: tidy this code into subfunctions - if form.Header != nil && form.Header.Size != 0 { - if form.Header.Size > m.config.MediaConfig.MaxImageSize { - err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", form.Header.Size, m.config.MediaConfig.MaxImageSize) - l.Debugf("error processing header: %s", err) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - f, err := form.Header.Open() - if err != nil { - l.Debugf("error processing header: %s", err) - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("could not read provided header: %s", err)}) - return - } - - // extract the bytes - imageBytes := []byte{} - size, err := f.Read(imageBytes) - defer func(){ - if err := f.Close(); err != nil { - m.log.Errorf("error closing multipart file: %s", err) - } - }() - if err != nil || size == 0 { - l.Debugf("error processing header: %s", err) - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("could not read provided header: %s", err)}) - return - } - - // do the setting - headerInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(imageBytes, authed.Account.ID, "header") - if err != nil { - l.Debugf("error processing header: %s", err) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - l.Tracef("new header info for account %s is %+v", headerInfo) + // if everything on the form is nil, then nothing has been set and we shouldn't continue + if form.Discoverable == nil && form.Bot == nil && form.DisplayName == nil && form.Note == nil && form.Avatar == nil && form.Header == nil && form.Locked == nil && form.Source == nil && form.FieldsAttributes == nil { + l.Debugf("could not parse form from request") + c.JSON(http.StatusBadRequest, gin.H{"error": "empty form submitted"}) + return } - l.Tracef("retrieved account %+v", authed.Account.ID) + if form.Discoverable != nil { + if err := m.db.UpdateOneByID(authed.Account.ID, "discoverable", *form.Discoverable, &model.Account{}); err != nil { + l.Debugf("error updating discoverable: %s", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + + if form.Bot != nil { + if err := m.db.UpdateOneByID(authed.Account.ID, "bot", *form.Bot, &model.Account{}); err != nil { + l.Debugf("error updating bot: %s", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + + if form.DisplayName != nil { + if err := m.db.UpdateOneByID(authed.Account.ID, "display_name", *form.DisplayName, &model.Account{}); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + + if form.Note != nil { + if err := m.db.UpdateOneByID(authed.Account.ID, "note", *form.Note, &model.Account{}); err != nil { + l.Debugf("error updating note: %s", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + + if form.Avatar != nil && form.Avatar.Size != 0 { + avatarInfo, err := m.UpdateAccountAvatar(form.Avatar, authed.Account.ID) + if err != nil { + l.Debugf("could not update avatar for account %s: %s", authed.Account.ID, err) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + l.Tracef("new avatar info for account %s is %+v", authed.Account.ID, avatarInfo) + } + + if form.Header != nil && form.Header.Size != 0 { + headerInfo, err := m.UpdateAccountHeader(form.Header, authed.Account.ID) + if err != nil { + l.Debugf("could not update header for account %s: %s", authed.Account.ID, err) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + l.Tracef("new header info for account %s is %+v", authed.Account.ID, headerInfo) + } + + if form.Locked != nil { + if err := m.db.UpdateOneByID(authed.Account.ID, "locked", *form.Locked, &model.Account{}); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"": err.Error()}) + return + } + } + + if form.Source != nil { + + } + + if form.FieldsAttributes != nil { + + } + + // fetch the account with all updated values set + updatedAccount := &model.Account{} + if err := m.db.GetByID(authed.Account.ID, updatedAccount); err != nil { + l.Debugf("could not fetch updated account %s: %s", authed.Account.ID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + acctSensitive, err := m.db.AccountToMastoSensitive(updatedAccount) + if err != nil { + l.Tracef("could not convert account into mastosensitive account: %s", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive) + c.JSON(http.StatusOK, acctSensitive) } /* HELPER FUNCTIONS */ +// TODO: try to combine the below two functions because this is a lot of code repetition. + +// UpdateAccountAvatar does the dirty work of checking the avatar part of an account update form, +// parsing and checking the image, and doing the necessary updates in the database for this to become +// the account's new avatar image. +func (m *accountModule) UpdateAccountAvatar(avatar *multipart.FileHeader, accountID string) (*model.MediaAttachment, error) { + var err error + if avatar.Size > m.config.MediaConfig.MaxImageSize { + err = fmt.Errorf("avatar with size %d exceeded max image size of %d bytes", avatar.Size, m.config.MediaConfig.MaxImageSize) + return nil, err + } + f, err := avatar.Open() + if err != nil { + return nil, fmt.Errorf("could not read provided avatar: %s", err) + } + + // extract the bytes + buf := new(bytes.Buffer) + size, err := io.Copy(buf, f) + if err != nil { + return nil, fmt.Errorf("could not read provided avatar: %s", err) + } + if size == 0 { + return nil, errors.New("could not read provided avatar: size 0 bytes") + } + + // do the setting + avatarInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(buf.Bytes(), accountID, "avatar") + if err != nil { + return nil, fmt.Errorf("error processing avatar: %s", err) + } + + return avatarInfo, f.Close() +} + +// UpdateAccountHeader does the dirty work of checking the header part of an account update form, +// parsing and checking the image, and doing the necessary updates in the database for this to become +// the account's new header image. +func (m *accountModule) UpdateAccountHeader(header *multipart.FileHeader, accountID string) (*model.MediaAttachment, error) { + var err error + if header.Size > m.config.MediaConfig.MaxImageSize { + err = fmt.Errorf("header with size %d exceeded max image size of %d bytes", header.Size, m.config.MediaConfig.MaxImageSize) + return nil, err + } + f, err := header.Open() + if err != nil { + return nil, fmt.Errorf("could not read provided header: %s", err) + } + + // extract the bytes + buf := new(bytes.Buffer) + size, err := io.Copy(buf, f) + if err != nil { + return nil, fmt.Errorf("could not read provided header: %s", err) + } + if size == 0 { + return nil, errors.New("could not read provided header: size 0 bytes") + } + + // do the setting + headerInfo, err := m.mediaHandler.SetHeaderOrAvatarForAccountID(buf.Bytes(), accountID, "header") + if err != nil { + return nil, fmt.Errorf("error processing header: %s", err) + } + + return headerInfo, f.Close() +} + // accountCreate does the dirty work of making an account and user in the database. // It then returns a token to the caller, for use with the new account, as per the // spec here: https://docs.joinmastodon.org/methods/accounts/ diff --git a/internal/module/account/account_test.go b/internal/module/account/account_test.go index c515bccb4..293f5512d 100644 --- a/internal/module/account/account_test.go +++ b/internal/module/account/account_test.go @@ -19,13 +19,17 @@ package account import ( + "bytes" "context" "encoding/json" "fmt" + "io" "io/ioutil" + "mime/multipart" "net/http" "net/http/httptest" "net/url" + "os" "testing" "time" @@ -40,6 +44,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db/model" "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/pkg/mastotypes" "github.com/superseriousbusiness/oauth2/v4" "github.com/superseriousbusiness/oauth2/v4/models" @@ -57,7 +62,8 @@ type AccountTestSuite struct { testApplication *model.Application testToken oauth2.TokenInfo mockOauthServer *oauth.MockServer - mockMediaHandler *media.MockMediaHandler + mockStorage *storage.MockStorage + mediaHandler media.MediaHandler db db.DB accountModule *accountModule newUserFormHappyPath url.Values @@ -74,6 +80,11 @@ func (suite *AccountTestSuite) SetupSuite() { log.SetLevel(logrus.TraceLevel) suite.log = log + suite.testAccountLocal = &model.Account{ + ID: uuid.NewString(), + Username: "test_user", + } + // can use this test application throughout suite.testApplication = &model.Application{ ID: "weeweeeeeeeeeeeeee", @@ -107,6 +118,9 @@ func (suite *AccountTestSuite) SetupSuite() { Database: "postgres", ApplicationName: "gotosocial", } + c.MediaConfig = &config.MediaConfig{ + MaxImageSize: 2 << 20, + } suite.config = c // use an actual database for this, because it's just easier than mocking one out @@ -130,11 +144,15 @@ func (suite *AccountTestSuite) SetupSuite() { Code: "we're authorized now!", }, nil) - // mock the media handler because some handlers (eg update credentials) need to upload media (new header/avatar) - suite.mockMediaHandler = &media.MockMediaHandler{} + suite.mockStorage = &storage.MockStorage{} + // We don't need storage to do anything for these tests, so just simulate a success and do nothing -- we won't need to return anything from storage + suite.mockStorage.On("StoreFileAt", mock.AnythingOfType("string"), mock.AnythingOfType("[]uint8")).Return(nil) + + // set a media handler because some handlers (eg update credentials) need to upload media (new header/avatar) + suite.mediaHandler = media.New(suite.config, suite.db, suite.mockStorage, log) // and finally here's the thing we're actually testing! - suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mockMediaHandler, suite.log).(*accountModule) + suite.accountModule = New(suite.config, suite.db, suite.mockOauthServer, suite.mediaHandler, suite.log).(*accountModule) } func (suite *AccountTestSuite) TearDownSuite() { @@ -150,9 +168,11 @@ func (suite *AccountTestSuite) SetupTest() { &model.User{}, &model.Account{}, &model.Follow{}, + &model.FollowRequest{}, &model.Status{}, &model.Application{}, &model.EmailDomainBlock{}, + &model.MediaAttachment{}, } for _, m := range models { if err := suite.db.CreateTable(m); err != nil { @@ -186,9 +206,11 @@ func (suite *AccountTestSuite) TearDownTest() { &model.User{}, &model.Account{}, &model.Follow{}, + &model.FollowRequest{}, &model.Status{}, &model.Application{}, &model.EmailDomainBlock{}, + &model.MediaAttachment{}, } for _, m := range models { if err := suite.db.DropTable(m); err != nil { @@ -201,6 +223,10 @@ func (suite *AccountTestSuite) TearDownTest() { ACTUAL TESTS */ +/* + TESTING: AccountCreatePOSTHandler +*/ + // TestAccountCreatePOSTHandlerSuccessful checks the happy path for an account creation request: all the fields provided are valid, // and at the end of it a new user and account should be added into the database. // @@ -455,6 +481,58 @@ func (suite *AccountTestSuite) TestAccountCreatePOSTHandlerInsufficientReason() assert.Equal(suite.T(), `{"error":"reason should be at least 40 chars but 'just cuz' was 8"}`, string(b)) } +/* + TESTING: AccountUpdateCredentialsPATCHHandler +*/ + +func (suite *AccountTestSuite) TestAccountUpdateCredentialsPATCHHandler() { + + // put test local account in db + err := suite.db.Put(suite.testAccountLocal) + assert.NoError(suite.T(), err) + + // attach avatar to request + aviFile, err := os.Open("../../media/test/test-jpeg.jpg") + assert.NoError(suite.T(), err) + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + part, err := writer.CreateFormFile("avatar", "test-jpeg.jpg") + assert.NoError(suite.T(), err) + + _, err = io.Copy(part, aviFile) + assert.NoError(suite.T(), err) + + err = aviFile.Close() + assert.NoError(suite.T(), err) + + err = writer.Close() + assert.NoError(suite.T(), err) + + // setup + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccountLocal) + ctx.Set(oauth.SessionAuthorizedToken, suite.testToken) + ctx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("http://localhost:8080/%s", updateCredentialsPath), body) // the endpoint we're hitting + ctx.Request.Header.Set("Content-Type", writer.FormDataContentType()) + suite.accountModule.accountUpdateCredentialsPATCHHandler(ctx) + + // check response + + // 1. we should have OK because our request was valid + suite.EqualValues(http.StatusOK, recorder.Code) + + // 2. we should have an error message in the result body + result := recorder.Result() + defer result.Body.Close() + // TODO: implement proper checks here + // + // b, err := ioutil.ReadAll(result.Body) + // assert.NoError(suite.T(), err) + // assert.Equal(suite.T(), `{"error":"not authorized"}`, string(b)) +} + func TestAccountTestSuite(t *testing.T) { suite.Run(t, new(AccountTestSuite)) } diff --git a/pkg/mastotypes/account.go b/pkg/mastotypes/account.go index 6ab5b0430..3ddd3c517 100644 --- a/pkg/mastotypes/account.go +++ b/pkg/mastotypes/account.go @@ -92,40 +92,40 @@ type AccountCreateRequest struct { // See https://docs.joinmastodon.org/methods/accounts/ type UpdateCredentialsRequest struct { // Whether the account should be shown in the profile directory. - Discoverable string `form:"discoverable"` + Discoverable *bool `form:"discoverable"` // Whether the account has a bot flag. - Bot bool `form:"bot"` + Bot *bool `form:"bot"` // The display name to use for the profile. - DisplayName string `form:"display_name"` + DisplayName *string `form:"display_name"` // The account bio. - Note string `form:"note"` + Note *string `form:"note"` // Avatar image encoded using multipart/form-data Avatar *multipart.FileHeader `form:"avatar"` // Header image encoded using multipart/form-data Header *multipart.FileHeader `form:"header"` // Whether manual approval of follow requests is required. - Locked bool `form:"locked"` + Locked *bool `form:"locked"` // New Source values for this account Source *UpdateSource `form:"source"` // Profile metadata name and value - FieldsAttributes []UpdateField `form:"fields_attributes"` + FieldsAttributes *[]UpdateField `form:"fields_attributes"` } // UpdateSource is to be used specifically in an UpdateCredentialsRequest. type UpdateSource struct { // Default post privacy for authored statuses. - Privacy string `form:"privacy"` + Privacy *string `form:"privacy"` // Whether to mark authored statuses as sensitive by default. - Sensitive bool `form:"sensitive"` + Sensitive *bool `form:"sensitive"` // Default language to use for authored statuses. (ISO 6391) - Language string `form:"language"` + Language *string `form:"language"` } // UpdateField is to be used specifically in an UpdateCredentialsRequest. // By default, max 4 fields and 255 characters per property/value. type UpdateField struct { // Name of the field - Name string `form:"name"` + Name *string `form:"name"` // Value of the field - Value string `form:"value"` + Value *string `form:"value"` }