diff --git a/internal/db/db.go b/internal/db/db.go index fbd13d729..641a2efcf 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -21,6 +21,7 @@ package db import ( "context" "fmt" + "net" "strings" "github.com/go-fed/activity/pub" @@ -145,6 +146,10 @@ type DB interface { // C) something went wrong in the db IsEmailAvailable(email string) error + // NewSignup creates a new user in the database with the given parameters, with an *unconfirmed* email address. + // By the time this function is called, it should be assumed that all the parameters have passed validation! + NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string) (*model.User, error) + /* USEFUL CONVERSION FUNCTIONS */ diff --git a/internal/db/model/account.go b/internal/db/model/account.go index 2954e386a..b4e8c2546 100644 --- a/internal/db/model/account.go +++ b/internal/db/model/account.go @@ -23,6 +23,7 @@ package model import ( + "crypto/rsa" "net/url" "time" ) @@ -82,6 +83,8 @@ type Account struct { SubscriptionExpiresAt time.Time `pg:"type:timestamp"` // Does this account identify itself as a bot? Bot bool + // What reason was given for signing up when this account was created? + Reason string /* PRIVACY SETTINGS @@ -123,9 +126,9 @@ type Account struct { Secret string // Privatekey for validating activitypub requests, will obviously only be defined for local accounts - PrivateKey string + PrivateKey *rsa.PrivateKey // Publickey for encoding activitypub requests, will be defined for both local and remote accounts - PublicKey string + PublicKey *rsa.PublicKey /* ADMIN FIELDS diff --git a/internal/db/model/domainblock.go b/internal/db/model/domainblock.go index 19c2c5fa5..e6e89bc20 100644 --- a/internal/db/model/domainblock.go +++ b/internal/db/model/domainblock.go @@ -35,13 +35,13 @@ type DomainBlock struct { // Account ID of the creator of this block CreatedByAccountID string `pg:",notnull"` // TODO: define this - Severity int + Severity int // Reject media from this domain? - RejectMedia bool + RejectMedia bool // Reject reports from this domain? - RejectReports bool + RejectReports bool // Private comment on this block, viewable to admins - PrivateComment string + PrivateComment string // Public comment on this block, viewable (optionally) by everyone - PublicComment string + PublicComment string } diff --git a/internal/db/pg.go b/internal/db/pg.go index 90c2d4687..bc0cc0501 100644 --- a/internal/db/pg.go +++ b/internal/db/pg.go @@ -20,8 +20,11 @@ package db import ( "context" + "crypto/rand" + "crypto/rsa" "errors" "fmt" + "net" "net/mail" "regexp" "strings" @@ -35,6 +38,7 @@ import ( "github.com/gotosocial/gotosocial/internal/db/model" "github.com/gotosocial/gotosocial/pkg/mastotypes" "github.com/sirupsen/logrus" + "golang.org/x/crypto/bcrypt" ) // postgresService satisfies the DB interface @@ -305,7 +309,6 @@ func (ps *postgresService) GetAccountByUserID(userID string, account *model.Acco return err } if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil { - fmt.Println(account) if err == pg.ErrNoRows { return ErrNoEntries{} } @@ -400,7 +403,7 @@ func (ps *postgresService) IsEmailAvailable(email string) error { // fail because we got an unexpected error return fmt.Errorf("db error: %s", err) } - + // check if this email is associated with an account already if err := ps.conn.Model(&model.Account{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil { // fail because we found something @@ -412,6 +415,43 @@ func (ps *postgresService) IsEmailAvailable(email string) error { return nil } +func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string) (*model.User, error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + ps.log.Errorf("error creating new rsa key: %s", err) + return nil, err + } + + a := &model.Account{ + Username: username, + DisplayName: username, + Reason: reason, + PrivateKey: key, + PublicKey: &key.PublicKey, + ActorType: "Person", + } + if _, err = ps.conn.Model(a).Insert(); err != nil { + return nil, err + } + + pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return nil, fmt.Errorf("error hashing password: %s", err) + } + u := &model.User{ + AccountID: a.ID, + EncryptedPassword: string(pw), + SignUpIP: signUpIP, + Locale: locale, + UnconfirmedEmail: email, + } + if _, err = ps.conn.Model(u).Insert(); err != nil { + return nil, err + } + + return u, nil +} + /* CONVERSION FUNCTIONS */ @@ -433,7 +473,6 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype } fields = append(fields, mField) } - fmt.Printf("fields: %+v", fields) // count followers followers := []model.Follow{} diff --git a/internal/db/pg_test.go b/internal/db/pg_test.go new file mode 100644 index 000000000..f9bd21c48 --- /dev/null +++ b/internal/db/pg_test.go @@ -0,0 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package db + +// TODO: write tests for postgres diff --git a/internal/db/pg-fed.go b/internal/db/pgfed.go similarity index 83% rename from internal/db/pg-fed.go rename to internal/db/pgfed.go index ec1957abc..de9bbd8ab 100644 --- a/internal/db/pg-fed.go +++ b/internal/db/pgfed.go @@ -1,3 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + package db import ( diff --git a/internal/db/pgfed_test.go b/internal/db/pgfed_test.go new file mode 100644 index 000000000..529d2efd0 --- /dev/null +++ b/internal/db/pgfed_test.go @@ -0,0 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package db + +// TODO: write tests for pgfed diff --git a/internal/module/account/account.go b/internal/module/account/account.go index c820f6618..8276a01f4 100644 --- a/internal/module/account/account.go +++ b/internal/module/account/account.go @@ -19,6 +19,8 @@ package account import ( + "fmt" + "net" "net/http" "github.com/gin-gonic/gin" @@ -26,9 +28,10 @@ import ( "github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/gotosocial/internal/db/model" "github.com/gotosocial/gotosocial/internal/module" - "github.com/gotosocial/gotosocial/internal/module/oauth" + "github.com/gotosocial/gotosocial/internal/oauth" "github.com/gotosocial/gotosocial/internal/router" "github.com/gotosocial/gotosocial/pkg/mastotypes" + "github.com/gotosocial/oauth2/v4" "github.com/sirupsen/logrus" ) @@ -39,9 +42,10 @@ const ( ) type accountModule struct { - config *config.Config - db db.DB - log *logrus.Logger + config *config.Config + db db.DB + oauthServer oauth.Server + log *logrus.Logger } // New returns a new account module @@ -60,15 +64,15 @@ func (m *accountModule) Route(r router.Router) error { return nil } +// accountCreatePOSTHandler handles create account requests, validates them, +// and puts them in the database if they're valid. +// It should be served as a POST at /api/v1/accounts func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) { - l := m.log.WithField("func", "AccountCreatePOSTHandler") - // TODO: check whether a valid app token has been presented!! - // See: https://docs.joinmastodon.org/methods/accounts/ - - l.Trace("checking if registration is open") - if !m.config.AccountsConfig.OpenRegistration { - l.Debug("account registration is closed, returning error to client") - c.JSON(http.StatusUnauthorized, gin.H{"error": "account registration is closed"}) + l := m.log.WithField("func", "accountCreatePOSTHandler") + authed, err := oauth.GetAuthed(c) + if err != nil { + l.Debugf("couldn't auth: %s", err) + c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) return } @@ -81,15 +85,34 @@ func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) { } l.Tracef("validating form %+v", form) - if err := validateCreateAccount(form, m.config.AccountsConfig.ReasonRequired, m.db); err != nil { + if err := validateCreateAccount(form, m.config.AccountsConfig, m.db); err != nil { l.Debugf("error validating form: %s", err) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } + + clientIP := c.ClientIP() + l.Tracef("attempting to parse client ip address %s", clientIP) + signUpIP := net.ParseIP(clientIP) + if signUpIP == nil { + l.Debugf("error validating sign up ip address %s", clientIP) + c.JSON(http.StatusBadRequest, gin.H{"error": "ip address could not be parsed from request"}) + return + } + + ti, err := m.accountCreate(form, signUpIP, authed.Token, authed.Application) + if err != nil { + l.Errorf("internal server error while creating new account: %s", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, ti) } // accountVerifyGETHandler serves a user's account details to them IF they reached this // handler while in possession of a valid token, according to the oauth middleware. +// It should be served as a GET at /api/v1/accounts/verify_credentials func (m *accountModule) accountVerifyGETHandler(c *gin.Context) { l := m.log.WithField("func", "AccountVerifyGETHandler") @@ -120,3 +143,39 @@ func (m *accountModule) accountVerifyGETHandler(c *gin.Context) { l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive) c.JSON(http.StatusOK, acctSensitive) } + +/* + HELPER FUNCTIONS +*/ + +// accountCreate does the dirty work of making an account and user in the database. +// It then returns a token to the caller, for use with the new account, as per the +// spec here: https://docs.joinmastodon.org/methods/accounts/ +func (m *accountModule) accountCreate(form *mastotypes.AccountCreateRequest, signUpIP net.IP, token oauth2.TokenInfo, app *model.Application) (*mastotypes.Token, error) { + l := m.log.WithField("func", "accountCreate") + + // don't store a reason if we don't require one + reason := form.Reason + if !m.config.AccountsConfig.ReasonRequired { + reason = "" + } + + l.Trace("creating new username and account") + user, err := m.db.NewSignup(form.Username, reason, m.config.AccountsConfig.RequireApproval, form.Email, form.Password, signUpIP, form.Locale) + if err != nil { + return nil, fmt.Errorf("error creating new signup in the database: %s", err) + } + + l.Tracef("generating a token for user %s with account %s and application %s", user.ID, user.AccountID, app.ID) + ti, err := m.oauthServer.GenerateUserAccessToken(token, app.ClientSecret, user.ID) + if err != nil { + return nil, fmt.Errorf("error creating new access token for user %s: %s", user.ID, err) + } + + return &mastotypes.Token{ + AccessToken: ti.GetCode(), + TokenType: "Bearer", + Scope: ti.GetScope(), + CreatedAt: ti.GetCodeCreateAt().Unix(), + }, nil +} diff --git a/internal/module/account/account_test.go b/internal/module/account/account_test.go index 6ba2558a5..2d9076dff 100644 --- a/internal/module/account/account_test.go +++ b/internal/module/account/account_test.go @@ -20,34 +20,33 @@ package account import ( "context" - "fmt" - "net/url" + "net/http/httptest" "testing" - "time" "github.com/gin-gonic/gin" - "github.com/google/uuid" "github.com/gotosocial/gotosocial/internal/config" "github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/gotosocial/internal/db/model" - "github.com/gotosocial/gotosocial/internal/module/oauth" - "github.com/gotosocial/gotosocial/internal/router" "github.com/sirupsen/logrus" "github.com/stretchr/testify/suite" - "golang.org/x/crypto/bcrypt" ) type AccountTestSuite struct { suite.Suite - db db.DB + log *logrus.Logger testAccountLocal *model.Account testAccountRemote *model.Account testUser *model.User - config *config.Config + db db.DB + accountModule *accountModule } // SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout func (suite *AccountTestSuite) SetupSuite() { + log := logrus.New() + log.SetLevel(logrus.TraceLevel) + suite.log = log + c := config.Empty() c.DBConfig = &config.DBConfig{ Type: "postgres", @@ -58,118 +57,126 @@ func (suite *AccountTestSuite) SetupSuite() { Database: "postgres", ApplicationName: "gotosocial", } - suite.config = c - encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) + database, err := db.New(context.Background(), c, log) if err != nil { - logrus.Panicf("error encrypting user pass: %s", err) + suite.FailNow(err.Error()) + } + suite.db = database + + suite.accountModule = &accountModule{ + config: c, + db: database, + log: log, } - localAvatar, err := url.Parse("https://localhost:8080/media/aaaaaaaaa.png") - if err != nil { - logrus.Panicf("error parsing localavatar url: %s", err) - } - localHeader, err := url.Parse("https://localhost:8080/media/ffffffffff.png") - if err != nil { - logrus.Panicf("error parsing localheader url: %s", err) - } + // encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) + // if err != nil { + // logrus.Panicf("error encrypting user pass: %s", err) + // } - acctID := uuid.NewString() - suite.testAccountLocal = &model.Account{ - ID: acctID, - Username: "local_account_of_some_kind", - AvatarRemoteURL: localAvatar, - HeaderRemoteURL: localHeader, - DisplayName: "michael caine", - Fields: []model.Field{ - { - Name: "come and ave a go", - Value: "if you think you're hard enough", - }, - { - Name: "website", - Value: "https://imdb.com", - VerifiedAt: time.Now(), - }, - }, - Note: "My name is Michael Caine and i'm a local user.", - Discoverable: true, - } + // localAvatar, err := url.Parse("https://localhost:8080/media/aaaaaaaaa.png") + // if err != nil { + // logrus.Panicf("error parsing localavatar url: %s", err) + // } + // localHeader, err := url.Parse("https://localhost:8080/media/ffffffffff.png") + // if err != nil { + // logrus.Panicf("error parsing localheader url: %s", err) + // } - avatarURL, err := url.Parse("http://example.org/accounts/avatars/000/207/122/original/089-1098-09.png") - if err != nil { - logrus.Panicf("error parsing avatarURL: %s", err) - } + // acctID := uuid.NewString() + // suite.testAccountLocal = &model.Account{ + // ID: acctID, + // Username: "local_account_of_some_kind", + // AvatarRemoteURL: localAvatar, + // HeaderRemoteURL: localHeader, + // DisplayName: "michael caine", + // Fields: []model.Field{ + // { + // Name: "come and ave a go", + // Value: "if you think you're hard enough", + // }, + // { + // Name: "website", + // Value: "https://imdb.com", + // VerifiedAt: time.Now(), + // }, + // }, + // Note: "My name is Michael Caine and i'm a local user.", + // Discoverable: true, + // } - headerURL, err := url.Parse("http://example.org/accounts/headers/000/207/122/original/111111111111.png") - if err != nil { - logrus.Panicf("error parsing avatarURL: %s", err) - } - suite.testAccountRemote = &model.Account{ - ID: uuid.NewString(), - Username: "neato_bombeato", - Domain: "example.org", + // avatarURL, err := url.Parse("http://example.org/accounts/avatars/000/207/122/original/089-1098-09.png") + // if err != nil { + // logrus.Panicf("error parsing avatarURL: %s", err) + // } - AvatarFileName: "avatar.png", - AvatarContentType: "image/png", - AvatarFileSize: 1024, - AvatarUpdatedAt: time.Now(), - AvatarRemoteURL: avatarURL, + // headerURL, err := url.Parse("http://example.org/accounts/headers/000/207/122/original/111111111111.png") + // if err != nil { + // logrus.Panicf("error parsing avatarURL: %s", err) + // } + // suite.testAccountRemote = &model.Account{ + // ID: uuid.NewString(), + // Username: "neato_bombeato", + // Domain: "example.org", - HeaderFileName: "avatar.png", - HeaderContentType: "image/png", - HeaderFileSize: 1024, - HeaderUpdatedAt: time.Now(), - HeaderRemoteURL: headerURL, + // AvatarFileName: "avatar.png", + // AvatarContentType: "image/png", + // AvatarFileSize: 1024, + // AvatarUpdatedAt: time.Now(), + // AvatarRemoteURL: avatarURL, - DisplayName: "one cool dude 420", - Fields: []model.Field{ - { - Name: "pronouns", - Value: "he/they", - }, - { - Name: "website", - Value: "https://imcool.edu", - VerifiedAt: time.Now(), - }, - }, - Note: "

I'm cool as heck!

", - Discoverable: true, - URI: "https://example.org/users/neato_bombeato", - URL: "https://example.org/@neato_bombeato", - LastWebfingeredAt: time.Now(), - InboxURL: "https://example.org/users/neato_bombeato/inbox", - OutboxURL: "https://example.org/users/neato_bombeato/outbox", - SharedInboxURL: "https://example.org/inbox", - FollowersURL: "https://example.org/users/neato_bombeato/followers", - FeaturedCollectionURL: "https://example.org/users/neato_bombeato/collections/featured", - } - suite.testUser = &model.User{ - ID: uuid.NewString(), - EncryptedPassword: string(encryptedPassword), - Email: "user@example.org", - AccountID: acctID, + // HeaderFileName: "avatar.png", + // HeaderContentType: "image/png", + // HeaderFileSize: 1024, + // HeaderUpdatedAt: time.Now(), + // HeaderRemoteURL: headerURL, + + // DisplayName: "one cool dude 420", + // Fields: []model.Field{ + // { + // Name: "pronouns", + // Value: "he/they", + // }, + // { + // Name: "website", + // Value: "https://imcool.edu", + // VerifiedAt: time.Now(), + // }, + // }, + // Note: "

I'm cool as heck!

", + // Discoverable: true, + // URI: "https://example.org/users/neato_bombeato", + // URL: "https://example.org/@neato_bombeato", + // LastWebfingeredAt: time.Now(), + // InboxURL: "https://example.org/users/neato_bombeato/inbox", + // OutboxURL: "https://example.org/users/neato_bombeato/outbox", + // SharedInboxURL: "https://example.org/inbox", + // FollowersURL: "https://example.org/users/neato_bombeato/followers", + // FeaturedCollectionURL: "https://example.org/users/neato_bombeato/collections/featured", + // } + // suite.testUser = &model.User{ + // ID: uuid.NewString(), + // EncryptedPassword: string(encryptedPassword), + // Email: "user@example.org", + // AccountID: acctID, + // } +} + +func (suite *AccountTestSuite) TearDownSuite() { + if err := suite.db.Stop(context.Background()); err != nil { + logrus.Panicf("error closing db connection: %s", err) } } -// SetupTest creates a postgres connection and creates the oauth_clients table before each test +// SetupTest creates a db connection and creates necessary tables before each test func (suite *AccountTestSuite) SetupTest() { - - log := logrus.New() - log.SetLevel(logrus.TraceLevel) - db, err := db.New(context.Background(), suite.config, log) - if err != nil { - logrus.Panicf("error creating database connection: %s", err) - } - - suite.db = db - models := []interface{}{ &model.User{}, &model.Account{}, &model.Follow{}, &model.Status{}, + &model.Application{}, } for _, m := range models { @@ -177,70 +184,31 @@ func (suite *AccountTestSuite) SetupTest() { logrus.Panicf("db connection error: %s", err) } } - - if err := suite.db.Put(suite.testAccountLocal); err != nil { - logrus.Panicf("could not insert test account into db: %s", err) - } - if err := suite.db.Put(suite.testUser); err != nil { - logrus.Panicf("could not insert test user into db: %s", err) - } - } -// TearDownTest drops the oauth_clients table and closes the pg connection after each test +// TearDownTest drops tables to make sure there's no data in the db func (suite *AccountTestSuite) TearDownTest() { models := []interface{}{ &model.User{}, &model.Account{}, &model.Follow{}, &model.Status{}, + &model.Application{}, } for _, m := range models { if err := suite.db.DropTable(m); err != nil { logrus.Panicf("error dropping table: %s", err) } } - if err := suite.db.Stop(context.Background()); err != nil { - logrus.Panicf("error closing db connection: %s", err) - } - suite.db = nil } -func (suite *AccountTestSuite) TestAPIInitialize() { - log := logrus.New() - log.SetLevel(logrus.TraceLevel) - - r, err := router.New(suite.config, log) - if err != nil { - suite.FailNow(fmt.Sprintf("error creating router: %s", err)) - } - - r.AttachMiddleware(func(c *gin.Context) { - account := &model.Account{} - if err := suite.db.GetAccountByUserID(suite.testUser.ID, account); err != nil || account == nil { - suite.T().Log(err) - suite.FailNowf("no account found for user %s, continuing with unauthenticated request: %+v", "", suite.testUser.ID, account) - fmt.Println(account) - return - } - - c.Set(oauth.SessionAuthorizedAccount, account) - c.Set(oauth.SessionAuthorizedUser, suite.testUser.ID) - }) - - acct := New(suite.config, suite.db, log) - if err := acct.Route(r); err != nil { - suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err)) - } - - r.Start() - defer func() { - if err := r.Stop(context.Background()); err != nil { - panic(fmt.Errorf("error stopping router: %s", err)) - } - }() - time.Sleep(10 * time.Second) - +func (suite *AccountTestSuite) TestAccountCreatePOSTHandler() { + // TODO: figure out how to test this properly + recorder := httptest.NewRecorder() + recorder.Header().Set("X-Forwarded-For", "127.0.0.1") + ctx, _ := gin.CreateTestContext(recorder) + // ctx.Set() + suite.accountModule.accountCreatePOSTHandler(ctx) } func TestAccountTestSuite(t *testing.T) { diff --git a/internal/module/account/validation.go b/internal/module/account/validation.go index f2a0ac9cd..a0ad09406 100644 --- a/internal/module/account/validation.go +++ b/internal/module/account/validation.go @@ -21,12 +21,17 @@ package account import ( "errors" + "github.com/gotosocial/gotosocial/internal/config" "github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/gotosocial/internal/util" "github.com/gotosocial/gotosocial/pkg/mastotypes" ) -func validateCreateAccount(form *mastotypes.AccountCreateRequest, reasonRequired bool, database db.DB) error { +func validateCreateAccount(form *mastotypes.AccountCreateRequest, c *config.AccountsConfig, database db.DB) error { + if !c.OpenRegistration { + return errors.New("registration is not open for this server") + } + if err := util.ValidateSignUpUsername(form.Username); err != nil { return err } @@ -47,7 +52,7 @@ func validateCreateAccount(form *mastotypes.AccountCreateRequest, reasonRequired return err } - if err := util.ValidateSignUpReason(form.Reason, reasonRequired); err != nil { + if err := util.ValidateSignUpReason(form.Reason, c.ReasonRequired); err != nil { return err } diff --git a/internal/module/app/app.go b/internal/module/app/app.go new file mode 100644 index 000000000..b6dd685d4 --- /dev/null +++ b/internal/module/app/app.go @@ -0,0 +1,140 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package app + +import ( + "fmt" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/gotosocial/gotosocial/internal/db" + "github.com/gotosocial/gotosocial/internal/db/model" + "github.com/gotosocial/gotosocial/internal/module" + "github.com/gotosocial/gotosocial/internal/oauth" + "github.com/gotosocial/gotosocial/internal/router" + "github.com/gotosocial/gotosocial/pkg/mastotypes" + "github.com/sirupsen/logrus" +) + +const appsPath = "/api/v1/apps" + +type appModule struct { + server oauth.Server + db db.DB + log *logrus.Logger +} + +// New returns a new auth module +func New(srv oauth.Server, db db.DB, log *logrus.Logger) module.ClientAPIModule { + return &appModule{ + server: srv, + db: db, + log: log, + } +} + +// Route satisfies the RESTAPIModule interface +func (m *appModule) Route(s router.Router) error { + s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler) + return nil +} + +// appsPOSTHandler should be served at https://example.org/api/v1/apps +// It is equivalent to: https://docs.joinmastodon.org/methods/apps/ +func (m *appModule) appsPOSTHandler(c *gin.Context) { + l := m.log.WithField("func", "AppsPOSTHandler") + l.Trace("entering AppsPOSTHandler") + + form := &mastotypes.ApplicationPOSTRequest{} + if err := c.ShouldBind(form); err != nil { + c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()}) + return + } + + // permitted length for most fields + permittedLength := 64 + // redirect can be a bit bigger because we probably need to encode data in the redirect uri + permittedRedirect := 256 + + // check lengths of fields before proceeding so the user can't spam huge entries into the database + if len(form.ClientName) > permittedLength { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)}) + return + } + if len(form.Website) > permittedLength { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)}) + return + } + if len(form.RedirectURIs) > permittedRedirect { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)}) + return + } + if len(form.Scopes) > permittedLength { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)}) + return + } + + // set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/ + var scopes string + if form.Scopes == "" { + scopes = "read" + } else { + scopes = form.Scopes + } + + // generate new IDs for this application and its associated client + clientID := uuid.NewString() + clientSecret := uuid.NewString() + vapidKey := uuid.NewString() + + // generate the application to put in the database + app := &model.Application{ + Name: form.ClientName, + Website: form.Website, + RedirectURI: form.RedirectURIs, + ClientID: clientID, + ClientSecret: clientSecret, + Scopes: scopes, + VapidKey: vapidKey, + } + + // chuck it in the db + if err := m.db.Put(app); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // now we need to model an oauth client from the application that the oauth library can use + oc := &oauth.Client{ + ID: clientID, + Secret: clientSecret, + Domain: form.RedirectURIs, + UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now + } + + // chuck it in the db + if err := m.db.Put(oc); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/ + c.JSON(http.StatusOK, app.ToMasto()) +} diff --git a/internal/module/app/app_test.go b/internal/module/app/app_test.go new file mode 100644 index 000000000..d45b04e74 --- /dev/null +++ b/internal/module/app/app_test.go @@ -0,0 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package app + +// TODO: write tests diff --git a/internal/module/oauth/README.md b/internal/module/auth/README.md similarity index 97% rename from internal/module/oauth/README.md rename to internal/module/auth/README.md index 3d8427302..96b2443c1 100644 --- a/internal/module/oauth/README.md +++ b/internal/module/auth/README.md @@ -1,4 +1,4 @@ -# oauth +# auth This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) functionality to the GoToSocial client API. diff --git a/internal/module/oauth/oauth.go b/internal/module/auth/auth.go similarity index 64% rename from internal/module/oauth/oauth.go rename to internal/module/auth/auth.go index 2014f7d16..922aab86b 100644 --- a/internal/module/oauth/oauth.go +++ b/internal/module/auth/auth.go @@ -16,57 +16,42 @@ along with this program. If not, see . */ -// Package oauth is a module that provides oauth functionality to a router. +// Package auth is a module that provides oauth functionality to a router. // It adds the following paths: -// /api/v1/apps // /auth/sign_in // /oauth/token // /oauth/authorize // It also includes the oauthTokenMiddleware, which can be attached to a router to authenticate every request by Bearer token. -package oauth +package auth import ( + "errors" "fmt" "net/http" "net/url" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" - "github.com/google/uuid" "github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/gotosocial/internal/db/model" "github.com/gotosocial/gotosocial/internal/module" + "github.com/gotosocial/gotosocial/internal/oauth" "github.com/gotosocial/gotosocial/internal/router" "github.com/gotosocial/gotosocial/pkg/mastotypes" - "github.com/gotosocial/oauth2/v4" - "github.com/gotosocial/oauth2/v4/errors" - "github.com/gotosocial/oauth2/v4/manage" - "github.com/gotosocial/oauth2/v4/server" "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" ) const ( - appsPath = "/api/v1/apps" authSignInPath = "/auth/sign_in" oauthTokenPath = "/oauth/token" oauthAuthorizePath = "/oauth/authorize" - // SessionAuthorizedUser is the key set in the gin context for the id of - // a User who has successfully passed Bearer token authorization. - // The interface returned from grabbing this key should be parsed as a string. - SessionAuthorizedUser = "authorized_user" - // SessionAuthorizedAccount is the key set in the gin context for the Account - // of a User who has successfully passed Bearer token authorization. - // The interface returned from grabbing this key should be parsed as a *gtsmodel.Account - SessionAuthorizedAccount = "authorized_account" ) -// oauthModule is an oauth2 oauthModule that satisfies the ClientAPIModule interface -type oauthModule struct { - oauthManager *manage.Manager - oauthServer *server.Server - db db.DB - log *logrus.Logger +type authModule struct { + server oauth.Server + db db.DB + log *logrus.Logger } type login struct { @@ -74,52 +59,17 @@ type login struct { Password string `form:"password"` } -// New returns a new oauth module -func New(ts oauth2.TokenStore, cs oauth2.ClientStore, db db.DB, log *logrus.Logger) module.ClientAPIModule { - manager := manage.NewDefaultManager() - manager.MapTokenStorage(ts) - manager.MapClientStorage(cs) - manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) - sc := &server.Config{ - TokenType: "Bearer", - // Must follow the spec. - AllowGetAccessRequest: false, - // Support only the non-implicit flow. - AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code}, - // Allow: - // - Authorization Code (for first & third parties) - AllowedGrantTypes: []oauth2.GrantType{ - oauth2.AuthorizationCode, - }, - AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain}, +// New returns a new auth module +func New(srv oauth.Server, db db.DB, log *logrus.Logger) module.ClientAPIModule { + return &authModule{ + server: srv, + db: db, + log: log, } - - srv := server.NewServer(sc, manager) - srv.SetInternalErrorHandler(func(err error) *errors.Response { - log.Errorf("internal oauth error: %s", err) - return nil - }) - - srv.SetResponseErrorHandler(func(re *errors.Response) { - log.Errorf("internal response error: %s", re.Error) - }) - - m := &oauthModule{ - oauthManager: manager, - oauthServer: srv, - db: db, - log: log, - } - - m.oauthServer.SetUserAuthorizationHandler(m.userAuthorizationHandler) - m.oauthServer.SetClientInfoHandler(server.ClientFormHandler) - return m } // Route satisfies the RESTAPIModule interface -func (m *oauthModule) Route(s router.Router) error { - s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler) - +func (m *authModule) Route(s router.Router) error { s.AttachHandler(http.MethodGet, authSignInPath, m.signInGETHandler) s.AttachHandler(http.MethodPost, authSignInPath, m.signInPOSTHandler) @@ -129,7 +79,6 @@ func (m *oauthModule) Route(s router.Router) error { s.AttachHandler(http.MethodPost, oauthAuthorizePath, m.authorizePOSTHandler) s.AttachMiddleware(m.oauthTokenMiddleware) - return nil } @@ -137,93 +86,10 @@ func (m *oauthModule) Route(s router.Router) error { MAIN HANDLERS -- serve these through a server/router */ -// appsPOSTHandler should be served at https://example.org/api/v1/apps -// It is equivalent to: https://docs.joinmastodon.org/methods/apps/ -func (m *oauthModule) appsPOSTHandler(c *gin.Context) { - l := m.log.WithField("func", "AppsPOSTHandler") - l.Trace("entering AppsPOSTHandler") - - form := &mastotypes.ApplicationPOSTRequest{} - if err := c.ShouldBind(form); err != nil { - c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()}) - return - } - - // permitted length for most fields - permittedLength := 64 - // redirect can be a bit bigger because we probably need to encode data in the redirect uri - permittedRedirect := 256 - - // check lengths of fields before proceeding so the user can't spam huge entries into the database - if len(form.ClientName) > permittedLength { - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)}) - return - } - if len(form.Website) > permittedLength { - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)}) - return - } - if len(form.RedirectURIs) > permittedRedirect { - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)}) - return - } - if len(form.Scopes) > permittedLength { - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)}) - return - } - - // set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/ - var scopes string - if form.Scopes == "" { - scopes = "read" - } else { - scopes = form.Scopes - } - - // generate new IDs for this application and its associated client - clientID := uuid.NewString() - clientSecret := uuid.NewString() - vapidKey := uuid.NewString() - - // generate the application to put in the database - app := &model.Application{ - Name: form.ClientName, - Website: form.Website, - RedirectURI: form.RedirectURIs, - ClientID: clientID, - ClientSecret: clientSecret, - Scopes: scopes, - VapidKey: vapidKey, - } - - // chuck it in the db - if err := m.db.Put(app); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // now we need to model an oauth client from the application that the oauth library can use - oc := &oauthClient{ - ID: clientID, - Secret: clientSecret, - Domain: form.RedirectURIs, - UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now - } - - // chuck it in the db - if err := m.db.Put(oc); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - // done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/ - c.JSON(http.StatusOK, app.ToMasto()) -} - // signInGETHandler should be served at https://example.org/auth/sign_in. // The idea is to present a sign in page to the user, where they can enter their username and password. // The form will then POST to the sign in page, which will be handled by SignInPOSTHandler -func (m *oauthModule) signInGETHandler(c *gin.Context) { +func (m *authModule) signInGETHandler(c *gin.Context) { m.log.WithField("func", "SignInGETHandler").Trace("serving sign in html") c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{}) } @@ -231,7 +97,7 @@ func (m *oauthModule) signInGETHandler(c *gin.Context) { // signInPOSTHandler should be served at https://example.org/auth/sign_in. // The idea is to present a sign in page to the user, where they can enter their username and password. // The handler will then redirect to the auth handler served at /auth -func (m *oauthModule) signInPOSTHandler(c *gin.Context) { +func (m *authModule) signInPOSTHandler(c *gin.Context) { l := m.log.WithField("func", "SignInPOSTHandler") s := sessions.Default(c) form := &login{} @@ -260,10 +126,10 @@ func (m *oauthModule) signInPOSTHandler(c *gin.Context) { // tokenPOSTHandler should be served as a POST at https://example.org/oauth/token // The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs. // See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token -func (m *oauthModule) tokenPOSTHandler(c *gin.Context) { +func (m *authModule) tokenPOSTHandler(c *gin.Context) { l := m.log.WithField("func", "TokenPOSTHandler") l.Trace("entered TokenPOSTHandler") - if err := m.oauthServer.HandleTokenRequest(c.Writer, c.Request); err != nil { + if err := m.server.HandleTokenRequest(c.Writer, c.Request); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } } @@ -271,7 +137,7 @@ func (m *oauthModule) tokenPOSTHandler(c *gin.Context) { // authorizeGETHandler should be served as GET at https://example.org/oauth/authorize // The idea here is to present an oauth authorize page to the user, with a button // that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user -func (m *oauthModule) authorizeGETHandler(c *gin.Context) { +func (m *authModule) authorizeGETHandler(c *gin.Context) { l := m.log.WithField("func", "AuthorizeGETHandler") s := sessions.Default(c) @@ -349,7 +215,7 @@ func (m *oauthModule) authorizeGETHandler(c *gin.Context) { // At this point we assume that the user has A) logged in and B) accepted that the app should act for them, // so we should proceed with the authentication flow and generate an oauth token for them if we can. // See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user -func (m *oauthModule) authorizePOSTHandler(c *gin.Context) { +func (m *authModule) authorizePOSTHandler(c *gin.Context) { l := m.log.WithField("func", "AuthorizePOSTHandler") s := sessions.Default(c) @@ -404,7 +270,7 @@ func (m *oauthModule) authorizePOSTHandler(c *gin.Context) { l.Tracef("values on request set to %+v", c.Request.Form) // and proceed with authorization using the oauth2 library - if err := m.oauthServer.HandleAuthorizeRequest(c.Writer, c.Request); err != nil { + if err := m.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) } } @@ -418,25 +284,50 @@ func (m *oauthModule) authorizePOSTHandler(c *gin.Context) { // the request. Then, it will look up the account for that user, and set that in the request too. // If user or account can't be found, then the handler won't *fail*, in case the server wants to allow // public requests that don't have a Bearer token set (eg., for public instance information and so on). -func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) { +func (m *authModule) oauthTokenMiddleware(c *gin.Context) { l := m.log.WithField("func", "ValidatePassword") l.Trace("entering OauthTokenMiddleware") - ti, err := m.oauthServer.ValidationBearerToken(c.Request) + ti, err := m.server.ValidationBearerToken(c.Request) if err != nil { l.Trace("no valid token presented: continuing with unauthenticated request") return } - l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope()) + c.Set(oauth.SessionAuthorizedToken, ti) + l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedToken, ti) - acct := &model.Account{} - if err := m.db.GetAccountByUserID(ti.GetUserID(), acct); err != nil || acct == nil { - l.Tracef("no account found for user %s, continuing with unauthenticated request", ti.GetUserID()) - return + // check for user-level token + if uid := ti.GetUserID(); uid != "" { + l.Tracef("authenticated user %s with bearer token, scope is %s", uid, ti.GetScope()) + + // fetch user's and account for this user id + user := &model.User{} + if err := m.db.GetByID(uid, user); err != nil || user == nil { + l.Warnf("no user found for validated uid %s", uid) + return + } + c.Set(oauth.SessionAuthorizedUser, user) + l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user) + + acct := &model.Account{} + if err := m.db.GetByID(user.AccountID, acct); err != nil || acct == nil { + l.Warnf("no account found for validated user %s", uid) + return + } + c.Set(oauth.SessionAuthorizedAccount, acct) + l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedAccount, acct) } - c.Set(SessionAuthorizedAccount, acct) - c.Set(SessionAuthorizedUser, ti.GetUserID()) + // check for application token + if cid := ti.GetClientID(); cid != "" { + l.Tracef("authenticated client %s with bearer token, scope is %s", cid, ti.GetScope()) + app := &model.Application{} + if err := m.db.GetWhere("client_id", cid, app); err != nil { + l.Tracef("no app found for client %s", cid) + } + c.Set(oauth.SessionAuthorizedApplication, app) + l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedApplication, app) + } } /* @@ -447,7 +338,7 @@ func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) { // The goal is to authenticate the password against the one for that email // address stored in the database. If OK, we return the userid (a uuid) for that user, // so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db. -func (m *oauthModule) validatePassword(email string, password string) (userid string, err error) { +func (m *authModule) validatePassword(email string, password string) (userid string, err error) { l := m.log.WithField("func", "ValidatePassword") // make sure an email/password was provided and bail if not @@ -487,18 +378,6 @@ func incorrectPassword() (string, error) { return "", errors.New("password/email combination was incorrect") } -// userAuthorizationHandler gets the user's ID from the 'userid' field of the request form, -// or redirects to the /auth/sign_in page, if this key is not present. -func (m *oauthModule) userAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) { - l := m.log.WithField("func", "UserAuthorizationHandler") - userID = r.FormValue("userid") - if userID == "" { - return "", errors.New("userid was empty, redirecting to sign in page") - } - l.Tracef("returning userID %s", userID) - return userID, err -} - // parseAuthForm parses the OAuthAuthorize form in the gin context, and stores // the values in the form into the session. func parseAuthForm(c *gin.Context, l *logrus.Entry) error { diff --git a/internal/module/oauth/oauth_test.go b/internal/module/auth/auth_test.go similarity index 87% rename from internal/module/oauth/oauth_test.go rename to internal/module/auth/auth_test.go index 7dcff0d88..fa5cb16b1 100644 --- a/internal/module/oauth/oauth_test.go +++ b/internal/module/auth/auth_test.go @@ -16,38 +16,38 @@ along with this program. If not, see . */ -package oauth +package auth import ( "context" "fmt" "testing" + "time" "github.com/google/uuid" "github.com/gotosocial/gotosocial/internal/config" "github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/gotosocial/internal/db/model" + "github.com/gotosocial/gotosocial/internal/oauth" "github.com/gotosocial/gotosocial/internal/router" - "github.com/gotosocial/oauth2/v4" "github.com/sirupsen/logrus" "github.com/stretchr/testify/suite" "golang.org/x/crypto/bcrypt" ) -type OauthTestSuite struct { +type AuthTestSuite struct { suite.Suite - tokenStore oauth2.TokenStore - clientStore oauth2.ClientStore + oauthServer oauth.Server db db.DB testAccount *model.Account testApplication *model.Application testUser *model.User - testClient *oauthClient + testClient *oauth.Client config *config.Config } // SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout -func (suite *OauthTestSuite) SetupSuite() { +func (suite *AuthTestSuite) SetupSuite() { c := config.Empty() // we're running on localhost without https so set the protocol to http c.Protocol = "http" @@ -84,7 +84,7 @@ func (suite *OauthTestSuite) SetupSuite() { Email: "user@example.org", AccountID: acctID, } - suite.testClient = &oauthClient{ + suite.testClient = &oauth.Client{ ID: "a-known-client-id", Secret: "some-secret", Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host), @@ -101,7 +101,7 @@ func (suite *OauthTestSuite) SetupSuite() { } // SetupTest creates a postgres connection and creates the oauth_clients table before each test -func (suite *OauthTestSuite) SetupTest() { +func (suite *AuthTestSuite) SetupTest() { log := logrus.New() log.SetLevel(logrus.TraceLevel) @@ -113,8 +113,8 @@ func (suite *OauthTestSuite) SetupTest() { suite.db = db models := []interface{}{ - &oauthClient{}, - &oauthToken{}, + &oauth.Client{}, + &oauth.Token{}, &model.User{}, &model.Account{}, &model.Application{}, @@ -126,8 +126,7 @@ func (suite *OauthTestSuite) SetupTest() { } } - suite.tokenStore = newTokenStore(context.Background(), suite.db, logrus.New()) - suite.clientStore = newClientStore(suite.db) + suite.oauthServer = oauth.New(suite.db, log) if err := suite.db.Put(suite.testAccount); err != nil { logrus.Panicf("could not insert test account into db: %s", err) @@ -145,10 +144,10 @@ func (suite *OauthTestSuite) SetupTest() { } // TearDownTest drops the oauth_clients table and closes the pg connection after each test -func (suite *OauthTestSuite) TearDownTest() { +func (suite *AuthTestSuite) TearDownTest() { models := []interface{}{ - &oauthClient{}, - &oauthToken{}, + &oauth.Client{}, + &oauth.Token{}, &model.User{}, &model.Account{}, &model.Application{}, @@ -164,7 +163,7 @@ func (suite *OauthTestSuite) TearDownTest() { suite.db = nil } -func (suite *OauthTestSuite) TestAPIInitialize() { +func (suite *AuthTestSuite) TestAPIInitialize() { log := logrus.New() log.SetLevel(logrus.TraceLevel) @@ -173,17 +172,18 @@ func (suite *OauthTestSuite) TestAPIInitialize() { suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err)) } - api := New(suite.tokenStore, suite.clientStore, suite.db, log) + api := New(suite.oauthServer, suite.db, log) if err := api.Route(r); err != nil { suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err)) } r.Start() + time.Sleep(60 * time.Second) if err := r.Stop(context.Background()); err != nil { suite.FailNow(fmt.Sprintf("error stopping router: %s", err)) } } -func TestOauthTestSuite(t *testing.T) { - suite.Run(t, new(OauthTestSuite)) +func TestAuthTestSuite(t *testing.T) { + suite.Run(t, new(AuthTestSuite)) } diff --git a/internal/module/oauth/clientstore.go b/internal/oauth/clientstore.go similarity index 94% rename from internal/module/oauth/clientstore.go rename to internal/oauth/clientstore.go index 45b518c61..1f9a1282b 100644 --- a/internal/module/oauth/clientstore.go +++ b/internal/oauth/clientstore.go @@ -38,7 +38,7 @@ func newClientStore(db db.DB) oauth2.ClientStore { } func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { - poc := &oauthClient{ + poc := &Client{ ID: clientID, } if err := cs.db.GetByID(clientID, poc); err != nil { @@ -48,7 +48,7 @@ func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.Cli } func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error { - poc := &oauthClient{ + poc := &Client{ ID: cli.GetID(), Secret: cli.GetSecret(), Domain: cli.GetDomain(), @@ -58,13 +58,13 @@ func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo } func (cs *clientStore) Delete(ctx context.Context, id string) error { - poc := &oauthClient{ + poc := &Client{ ID: id, } return cs.db.DeleteByID(id, poc) } -type oauthClient struct { +type Client struct { ID string Secret string Domain string diff --git a/internal/module/oauth/clientstore_test.go b/internal/oauth/clientstore_test.go similarity index 99% rename from internal/module/oauth/clientstore_test.go rename to internal/oauth/clientstore_test.go index 8401142c0..87995cbdd 100644 --- a/internal/module/oauth/clientstore_test.go +++ b/internal/oauth/clientstore_test.go @@ -69,7 +69,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() { suite.db = db models := []interface{}{ - &oauthClient{}, + &Client{}, } for _, m := range models { @@ -82,7 +82,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() { // TearDownTest drops the oauth_clients table and closes the pg connection after each test func (suite *PgClientStoreTestSuite) TearDownTest() { models := []interface{}{ - &oauthClient{}, + &Client{}, } for _, m := range models { if err := suite.db.DropTable(m); err != nil { diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go new file mode 100644 index 000000000..f7d2b570a --- /dev/null +++ b/internal/oauth/oauth.go @@ -0,0 +1,212 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package oauth + +import ( + "context" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/gotosocial/gotosocial/internal/db" + "github.com/gotosocial/gotosocial/internal/db/model" + "github.com/gotosocial/oauth2/v4" + "github.com/gotosocial/oauth2/v4/errors" + "github.com/gotosocial/oauth2/v4/manage" + "github.com/gotosocial/oauth2/v4/server" + "github.com/sirupsen/logrus" +) + +const ( + SessionAuthorizedToken = "authorized_token" + // SessionAuthorizedUser is the key set in the gin context for the id of + // a User who has successfully passed Bearer token authorization. + // The interface returned from grabbing this key should be parsed as a *gtsmodel.User + SessionAuthorizedUser = "authorized_user" + // SessionAuthorizedAccount is the key set in the gin context for the Account + // of a User who has successfully passed Bearer token authorization. + // The interface returned from grabbing this key should be parsed as a *gtsmodel.Account + SessionAuthorizedAccount = "authorized_account" + // SessionAuthorizedAccount is the key set in the gin context for the Application + // of a Client who has successfully passed Bearer token authorization. + // The interface returned from grabbing this key should be parsed as a *gtsmodel.Application + SessionAuthorizedApplication = "authorized_app" +) + +// Server wraps some oauth2 server functions in an interface, exposing only what is needed +type Server interface { + HandleTokenRequest(w http.ResponseWriter, r *http.Request) error + HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error + ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) + GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error) +} + +// s fulfils the Server interface using the underlying oauth2 server +type s struct { + server *server.Server + log *logrus.Logger +} + +type Authed struct { + Token oauth2.TokenInfo + Application *model.Application + User *model.User + Account *model.Account +} + +// GetAuthed is a convenience function for returning an Authed struct from a gin context. +// In essence, it tries to extract a token, application, user, and account from the context, +// and then sets them on a struct for convenience. +// +// If any are not present in the context, they will be set to nil on the returned Authed struct. +// +// If *ALL* are not present, then nil and an error will be returned. +// +// If something goes wrong during parsing, then nil and an error will be returned (consider this not authed). +func GetAuthed(c *gin.Context) (*Authed, error) { + ctx := c.Copy() + a := &Authed{} + var i interface{} + var ok bool + + i, ok = ctx.Get(SessionAuthorizedToken) + if ok { + parsed, ok := i.(oauth2.TokenInfo) + if !ok { + return nil, errors.New("could not parse token from session context") + } + a.Token = parsed + } + + i, ok = ctx.Get(SessionAuthorizedApplication) + if ok { + parsed, ok := i.(*model.Application) + if !ok { + return nil, errors.New("could not parse application from session context") + } + a.Application = parsed + } + + i, ok = ctx.Get(SessionAuthorizedUser) + if ok { + parsed, ok := i.(*model.User) + if !ok { + return nil, errors.New("could not parse user from session context") + } + a.User = parsed + } + + i, ok = ctx.Get(SessionAuthorizedAccount) + if ok { + parsed, ok := i.(*model.Account) + if !ok { + return nil, errors.New("could not parse account from session context") + } + a.Account = parsed + } + + if a.Token == nil && a.Application == nil && a.User == nil && a.Account == nil { + return nil, errors.New("not authorized") + } + + return a, nil +} + +// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function +func (s *s) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { + return s.server.HandleTokenRequest(w, r) +} + +// HandleAuthorizeRequest wraps the oauth2 library's HandleAuthorizeRequest function +func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { + return s.server.HandleAuthorizeRequest(w, r) +} + +// ValidationBearerToken wraps the oauth2 library's ValidationBearerToken function +func (s *s) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { + return s.server.ValidationBearerToken(r) +} + +// GenerateUserAccessToken shortcuts the normal oauth flow to create an user-level +// bearer token *without* requiring that user to log in. This is useful when we +// need to create a token for new users who haven't validated their email or logged in yet. +// +// The ti parameter refers to an existing Application token that was used to make the upstream +// request. This token needs to be validated and exist in database in order to create a new token. +func (s *s) GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error) { + + tgr := &oauth2.TokenGenerateRequest{ + ClientID: ti.GetClientID(), + ClientSecret: clientSecret, + UserID: userID, + RedirectURI: ti.GetRedirectURI(), + Scope: ti.GetScope(), + Code: ti.GetCode(), + CodeChallenge: ti.GetCodeChallenge(), + CodeChallengeMethod: ti.GetCodeChallengeMethod(), + } + + return s.server.Manager.GenerateAccessToken(context.Background(), oauth2.AuthorizationCode, tgr) +} + +func New(database db.DB, log *logrus.Logger) Server { + ts := newTokenStore(context.Background(), database, log) + cs := newClientStore(database) + + manager := manage.NewDefaultManager() + manager.MapTokenStorage(ts) + manager.MapClientStorage(cs) + manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) + sc := &server.Config{ + TokenType: "Bearer", + // Must follow the spec. + AllowGetAccessRequest: false, + // Support only the non-implicit flow. + AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code}, + // Allow: + // - Authorization Code (for first & third parties) + // - Client Credentials (for applications) + AllowedGrantTypes: []oauth2.GrantType{ + oauth2.AuthorizationCode, + oauth2.ClientCredentials, + }, + AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain}, + } + + srv := server.NewServer(sc, manager) + srv.SetInternalErrorHandler(func(err error) *errors.Response { + log.Errorf("internal oauth error: %s", err) + return nil + }) + + srv.SetResponseErrorHandler(func(re *errors.Response) { + log.Errorf("internal response error: %s", re.Error) + }) + + srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) { + userID := r.FormValue("userid") + if userID == "" { + return "", errors.New("userid was empty") + } + return userID, nil + }) + srv.SetClientInfoHandler(server.ClientFormHandler) + return &s{ + server: srv, + } +} diff --git a/internal/oauth/oauth_test.go b/internal/oauth/oauth_test.go new file mode 100644 index 000000000..594b9b5a9 --- /dev/null +++ b/internal/oauth/oauth_test.go @@ -0,0 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package oauth + +// TODO: write tests diff --git a/internal/module/oauth/tokenstore.go b/internal/oauth/tokenstore.go similarity index 92% rename from internal/module/oauth/tokenstore.go rename to internal/oauth/tokenstore.go index d8a6d5814..bd18bb8cd 100644 --- a/internal/module/oauth/tokenstore.go +++ b/internal/oauth/tokenstore.go @@ -70,7 +70,7 @@ func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.Tok func (pts *tokenStore) sweep() error { // select *all* tokens from the db // todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. - tokens := new([]*oauthToken) + tokens := new([]*Token) if err := pts.db.GetAll(tokens); err != nil { return err } @@ -106,22 +106,22 @@ func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error // RemoveByCode deletes a token from the DB based on the Code field func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error { - return pts.db.DeleteWhere("code", code, &oauthToken{}) + return pts.db.DeleteWhere("code", code, &Token{}) } // RemoveByAccess deletes a token from the DB based on the Access field func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { - return pts.db.DeleteWhere("access", access, &oauthToken{}) + return pts.db.DeleteWhere("access", access, &Token{}) } // RemoveByRefresh deletes a token from the DB based on the Refresh field func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { - return pts.db.DeleteWhere("refresh", refresh, &oauthToken{}) + return pts.db.DeleteWhere("refresh", refresh, &Token{}) } // GetByCode selects a token from the DB based on the Code field func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { - pgt := &oauthToken{ + pgt := &Token{ Code: code, } if err := pts.db.GetWhere("code", code, pgt); err != nil { @@ -132,7 +132,7 @@ func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.Token // GetByAccess selects a token from the DB based on the Access field func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { - pgt := &oauthToken{ + pgt := &Token{ Access: access, } if err := pts.db.GetWhere("access", access, pgt); err != nil { @@ -143,7 +143,7 @@ func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.T // GetByRefresh selects a token from the DB based on the Refresh field func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { - pgt := &oauthToken{ + pgt := &Token{ Refresh: refresh, } if err := pts.db.GetWhere("refresh", refresh, pgt); err != nil { @@ -156,7 +156,7 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2 The following models are basically helpers for the postgres token store implementation, they should only be used internally. */ -// oauthToken is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt. +// Token is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt. // // Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined, // and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and @@ -164,9 +164,9 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2 // // Note that this struct does *not* satisfy the token interface shown here: https://github.com/gotosocial/oauth2/blob/master/model.go#L22 // and implemented here: https://github.com/gotosocial/oauth2/blob/master/models/token.go. -// As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken +// As such, manual translation is always required between Token and the gotosocial *model.Token. The helper functions oauthTokenToPGToken // and pgTokenToOauthToken can be used for that. -type oauthToken struct { +type Token struct { ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"` ClientID string UserID string @@ -186,7 +186,7 @@ type oauthToken struct { } // oauthTokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres -func oauthTokenToPGToken(tkn *models.Token) *oauthToken { +func oauthTokenToPGToken(tkn *models.Token) *Token { now := time.Now() // For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's @@ -208,7 +208,7 @@ func oauthTokenToPGToken(tkn *models.Token) *oauthToken { rea = now.Add(tkn.RefreshExpiresIn) } - return &oauthToken{ + return &Token{ ClientID: tkn.ClientID, UserID: tkn.UserID, RedirectURI: tkn.RedirectURI, @@ -228,7 +228,7 @@ func oauthTokenToPGToken(tkn *models.Token) *oauthToken { } // pgTokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token -func pgTokenToOauthToken(pgt *oauthToken) *models.Token { +func pgTokenToOauthToken(pgt *Token) *models.Token { now := time.Now() return &models.Token{ diff --git a/internal/oauth/tokenstore_test.go b/internal/oauth/tokenstore_test.go new file mode 100644 index 000000000..594b9b5a9 --- /dev/null +++ b/internal/oauth/tokenstore_test.go @@ -0,0 +1,21 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package oauth + +// TODO: write tests diff --git a/internal/router/router.go b/internal/router/router.go index cab7fa7f8..afb50836b 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -36,7 +36,7 @@ import ( // Router provides the REST interface for gotosocial, using gin. type Router interface { // Attach a gin handler to the router with the given method and path - AttachHandler(method string, path string, handler gin.HandlerFunc) + AttachHandler(method string, path string, f gin.HandlerFunc) // Attach a gin middleware to the router that will be used globally AttachMiddleware(handler gin.HandlerFunc) // Start the router @@ -59,6 +59,8 @@ func (r *router) Start() { r.logger.Fatalf("listen: %s", err) } }() + // c := &gin.Context{} + // c.Get() } // Stop shuts down the router nicely diff --git a/pkg/mastotypes/token.go b/pkg/mastotypes/token.go new file mode 100644 index 000000000..c9ac1f177 --- /dev/null +++ b/pkg/mastotypes/token.go @@ -0,0 +1,31 @@ +/* + GoToSocial + Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package mastotypes + +// Token represents an OAuth token used for authenticating with the API and performing actions.. See https://docs.joinmastodon.org/entities/token/ +type Token struct { + // An OAuth token to be used for authorization. + AccessToken string `json:"access_token"` + // The OAuth token type. Mastodon uses Bearer tokens. + TokenType string `json:"token_type"` + // The OAuth scopes granted by this token, space-separated. + Scope string `json:"scope"` + // When the token was generated. (UNIX timestamp seconds) + CreatedAt int64 `json:"created_at"` +}