mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-30 23:12:25 -05:00 
			
		
		
		
	chunking away at it
This commit is contained in:
		
					parent
					
						
							
								0a244be523
							
						
					
				
			
			
				commit
				
					
						f58f77bf1f
					
				
			
		
					 23 changed files with 860 additions and 394 deletions
				
			
		|  | @ -21,6 +21,7 @@ package db | |||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/go-fed/activity/pub" | ||||
|  | @ -145,6 +146,10 @@ type DB interface { | |||
| 	// C) something went wrong in the db | ||||
| 	IsEmailAvailable(email string) error | ||||
| 
 | ||||
| 	// NewSignup creates a new user in the database with the given parameters, with an *unconfirmed* email address. | ||||
| 	// By the time this function is called, it should be assumed that all the parameters have passed validation! | ||||
| 	NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string) (*model.User, error) | ||||
| 
 | ||||
| 	/* | ||||
| 		USEFUL CONVERSION FUNCTIONS | ||||
| 	*/ | ||||
|  |  | |||
|  | @ -23,6 +23,7 @@ | |||
| package model | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/rsa" | ||||
| 	"net/url" | ||||
| 	"time" | ||||
| ) | ||||
|  | @ -82,6 +83,8 @@ type Account struct { | |||
| 	SubscriptionExpiresAt time.Time `pg:"type:timestamp"` | ||||
| 	// Does this account identify itself as a bot? | ||||
| 	Bot bool | ||||
| 	// What reason was given for signing up when this account was created? | ||||
| 	Reason string | ||||
| 
 | ||||
| 	/* | ||||
| 		PRIVACY SETTINGS | ||||
|  | @ -123,9 +126,9 @@ type Account struct { | |||
| 
 | ||||
| 	Secret string | ||||
| 	// Privatekey for validating activitypub requests, will obviously only be defined for local accounts | ||||
| 	PrivateKey string | ||||
| 	PrivateKey *rsa.PrivateKey | ||||
| 	// Publickey for encoding activitypub requests, will be defined for both local and remote accounts | ||||
| 	PublicKey string | ||||
| 	PublicKey *rsa.PublicKey | ||||
| 
 | ||||
| 	/* | ||||
| 		ADMIN FIELDS | ||||
|  |  | |||
|  | @ -20,8 +20,11 @@ package db | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/rand" | ||||
| 	"crypto/rsa" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"net/mail" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
|  | @ -35,6 +38,7 @@ import ( | |||
| 	"github.com/gotosocial/gotosocial/internal/db/model" | ||||
| 	"github.com/gotosocial/gotosocial/pkg/mastotypes" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"golang.org/x/crypto/bcrypt" | ||||
| ) | ||||
| 
 | ||||
| // postgresService satisfies the DB interface | ||||
|  | @ -305,7 +309,6 @@ func (ps *postgresService) GetAccountByUserID(userID string, account *model.Acco | |||
| 		return err | ||||
| 	} | ||||
| 	if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil { | ||||
| 		fmt.Println(account) | ||||
| 		if err == pg.ErrNoRows { | ||||
| 			return ErrNoEntries{} | ||||
| 		} | ||||
|  | @ -412,6 +415,43 @@ func (ps *postgresService) IsEmailAvailable(email string) error { | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string) (*model.User, error) { | ||||
| 	key, err := rsa.GenerateKey(rand.Reader, 2048) | ||||
| 	if err != nil { | ||||
| 		ps.log.Errorf("error creating new rsa key: %s", err) | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	a := &model.Account{ | ||||
| 		Username:    username, | ||||
| 		DisplayName: username, | ||||
| 		Reason:      reason, | ||||
| 		PrivateKey:  key, | ||||
| 		PublicKey:   &key.PublicKey, | ||||
| 		ActorType:   "Person", | ||||
| 	} | ||||
| 	if _, err = ps.conn.Model(a).Insert(); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error hashing password: %s", err) | ||||
| 	} | ||||
| 	u := &model.User{ | ||||
| 		AccountID:         a.ID, | ||||
| 		EncryptedPassword: string(pw), | ||||
| 		SignUpIP:          signUpIP, | ||||
| 		Locale:            locale, | ||||
| 		UnconfirmedEmail:  email, | ||||
| 	} | ||||
| 	if _, err = ps.conn.Model(u).Insert(); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return u, nil | ||||
| } | ||||
| 
 | ||||
| /* | ||||
| 	CONVERSION FUNCTIONS | ||||
| */ | ||||
|  | @ -433,7 +473,6 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype | |||
| 		} | ||||
| 		fields = append(fields, mField) | ||||
| 	} | ||||
| 	fmt.Printf("fields: %+v", fields) | ||||
| 
 | ||||
| 	// count followers | ||||
| 	followers := []model.Follow{} | ||||
|  |  | |||
							
								
								
									
										21
									
								
								internal/db/pg_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/db/pg_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,21 @@ | |||
| /* | ||||
|    GoToSocial | ||||
|    Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org | ||||
| 
 | ||||
|    This program is free software: you can redistribute it and/or modify | ||||
|    it under the terms of the GNU Affero General Public License as published by | ||||
|    the Free Software Foundation, either version 3 of the License, or | ||||
|    (at your option) any later version. | ||||
| 
 | ||||
|    This program is distributed in the hope that it will be useful, | ||||
|    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
|    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
|    GNU Affero General Public License for more details. | ||||
| 
 | ||||
|    You should have received a copy of the GNU Affero General Public License | ||||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
| */ | ||||
| 
 | ||||
| package db | ||||
| 
 | ||||
| // TODO: write tests for postgres | ||||
|  | @ -1,3 +1,21 @@ | |||
| /* | ||||
|    GoToSocial | ||||
|    Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org | ||||
| 
 | ||||
|    This program is free software: you can redistribute it and/or modify | ||||
|    it under the terms of the GNU Affero General Public License as published by | ||||
|    the Free Software Foundation, either version 3 of the License, or | ||||
|    (at your option) any later version. | ||||
| 
 | ||||
|    This program is distributed in the hope that it will be useful, | ||||
|    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
|    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
|    GNU Affero General Public License for more details. | ||||
| 
 | ||||
|    You should have received a copy of the GNU Affero General Public License | ||||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
| */ | ||||
| 
 | ||||
| package db | ||||
| 
 | ||||
| import ( | ||||
							
								
								
									
										21
									
								
								internal/db/pgfed_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/db/pgfed_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,21 @@ | |||
| /* | ||||
|    GoToSocial | ||||
|    Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org | ||||
| 
 | ||||
|    This program is free software: you can redistribute it and/or modify | ||||
|    it under the terms of the GNU Affero General Public License as published by | ||||
|    the Free Software Foundation, either version 3 of the License, or | ||||
|    (at your option) any later version. | ||||
| 
 | ||||
|    This program is distributed in the hope that it will be useful, | ||||
|    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
|    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
|    GNU Affero General Public License for more details. | ||||
| 
 | ||||
|    You should have received a copy of the GNU Affero General Public License | ||||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
| */ | ||||
| 
 | ||||
| package db | ||||
| 
 | ||||
| // TODO: write tests for pgfed | ||||
|  | @ -19,6 +19,8 @@ | |||
| package account | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/gin-gonic/gin" | ||||
|  | @ -26,9 +28,10 @@ import ( | |||
| 	"github.com/gotosocial/gotosocial/internal/db" | ||||
| 	"github.com/gotosocial/gotosocial/internal/db/model" | ||||
| 	"github.com/gotosocial/gotosocial/internal/module" | ||||
| 	"github.com/gotosocial/gotosocial/internal/module/oauth" | ||||
| 	"github.com/gotosocial/gotosocial/internal/oauth" | ||||
| 	"github.com/gotosocial/gotosocial/internal/router" | ||||
| 	"github.com/gotosocial/gotosocial/pkg/mastotypes" | ||||
| 	"github.com/gotosocial/oauth2/v4" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| ) | ||||
| 
 | ||||
|  | @ -41,6 +44,7 @@ const ( | |||
| type accountModule struct { | ||||
| 	config      *config.Config | ||||
| 	db          db.DB | ||||
| 	oauthServer oauth.Server | ||||
| 	log         *logrus.Logger | ||||
| } | ||||
| 
 | ||||
|  | @ -60,15 +64,15 @@ func (m *accountModule) Route(r router.Router) error { | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // accountCreatePOSTHandler handles create account requests, validates them, | ||||
| // and puts them in the database if they're valid. | ||||
| // It should be served as a POST at /api/v1/accounts | ||||
| func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) { | ||||
| 	l := m.log.WithField("func", "AccountCreatePOSTHandler") | ||||
| 	// TODO: check whether a valid app token has been presented!! | ||||
| 	// See: https://docs.joinmastodon.org/methods/accounts/ | ||||
| 
 | ||||
| 	l.Trace("checking if registration is open") | ||||
| 	if !m.config.AccountsConfig.OpenRegistration { | ||||
| 		l.Debug("account registration is closed, returning error to client") | ||||
| 		c.JSON(http.StatusUnauthorized, gin.H{"error": "account registration is closed"}) | ||||
| 	l := m.log.WithField("func", "accountCreatePOSTHandler") | ||||
| 	authed, err := oauth.GetAuthed(c) | ||||
| 	if err != nil { | ||||
| 		l.Debugf("couldn't auth: %s", err) | ||||
| 		c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
|  | @ -81,15 +85,34 @@ func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) { | |||
| 	} | ||||
| 
 | ||||
| 	l.Tracef("validating form %+v", form) | ||||
| 	if err := validateCreateAccount(form, m.config.AccountsConfig.ReasonRequired, m.db); err != nil { | ||||
| 	if err := validateCreateAccount(form, m.config.AccountsConfig, m.db); err != nil { | ||||
| 		l.Debugf("error validating form: %s", err) | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	clientIP := c.ClientIP() | ||||
| 	l.Tracef("attempting to parse client ip address %s", clientIP) | ||||
| 	signUpIP := net.ParseIP(clientIP) | ||||
| 	if signUpIP == nil { | ||||
| 		l.Debugf("error validating sign up ip address %s", clientIP) | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": "ip address could not be parsed from request"}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	ti, err := m.accountCreate(form, signUpIP, authed.Token, authed.Application) | ||||
| 	if err != nil { | ||||
| 		l.Errorf("internal server error while creating new account: %s", err) | ||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	c.JSON(http.StatusOK, ti) | ||||
| } | ||||
| 
 | ||||
| // accountVerifyGETHandler serves a user's account details to them IF they reached this | ||||
| // handler while in possession of a valid token, according to the oauth middleware. | ||||
| // It should be served as a GET at /api/v1/accounts/verify_credentials | ||||
| func (m *accountModule) accountVerifyGETHandler(c *gin.Context) { | ||||
| 	l := m.log.WithField("func", "AccountVerifyGETHandler") | ||||
| 
 | ||||
|  | @ -120,3 +143,39 @@ func (m *accountModule) accountVerifyGETHandler(c *gin.Context) { | |||
| 	l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive) | ||||
| 	c.JSON(http.StatusOK, acctSensitive) | ||||
| } | ||||
| 
 | ||||
| /* | ||||
| 	HELPER FUNCTIONS | ||||
| */ | ||||
| 
 | ||||
| // accountCreate does the dirty work of making an account and user in the database. | ||||
| // It then returns a token to the caller, for use with the new account, as per the | ||||
| // spec here: https://docs.joinmastodon.org/methods/accounts/ | ||||
| func (m *accountModule) accountCreate(form *mastotypes.AccountCreateRequest, signUpIP net.IP, token oauth2.TokenInfo, app *model.Application) (*mastotypes.Token, error) { | ||||
| 	l := m.log.WithField("func", "accountCreate") | ||||
| 
 | ||||
| 	// don't store a reason if we don't require one | ||||
| 	reason := form.Reason | ||||
| 	if !m.config.AccountsConfig.ReasonRequired { | ||||
| 		reason = "" | ||||
| 	} | ||||
| 
 | ||||
| 	l.Trace("creating new username and account") | ||||
| 	user, err := m.db.NewSignup(form.Username, reason, m.config.AccountsConfig.RequireApproval, form.Email, form.Password, signUpIP, form.Locale) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error creating new signup in the database: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	l.Tracef("generating a token for user %s with account %s and application %s", user.ID, user.AccountID, app.ID) | ||||
| 	ti, err := m.oauthServer.GenerateUserAccessToken(token, app.ClientSecret, user.ID) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error creating new access token for user %s: %s", user.ID, err) | ||||
| 	} | ||||
| 
 | ||||
| 	return &mastotypes.Token{ | ||||
| 		AccessToken: ti.GetCode(), | ||||
| 		TokenType:   "Bearer", | ||||
| 		Scope:       ti.GetScope(), | ||||
| 		CreatedAt:   ti.GetCodeCreateAt().Unix(), | ||||
| 	}, nil | ||||
| } | ||||
|  |  | |||
|  | @ -20,34 +20,33 @@ package account | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/url" | ||||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/google/uuid" | ||||
| 	"github.com/gotosocial/gotosocial/internal/config" | ||||
| 	"github.com/gotosocial/gotosocial/internal/db" | ||||
| 	"github.com/gotosocial/gotosocial/internal/db/model" | ||||
| 	"github.com/gotosocial/gotosocial/internal/module/oauth" | ||||
| 	"github.com/gotosocial/gotosocial/internal/router" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"github.com/stretchr/testify/suite" | ||||
| 	"golang.org/x/crypto/bcrypt" | ||||
| ) | ||||
| 
 | ||||
| type AccountTestSuite struct { | ||||
| 	suite.Suite | ||||
| 	db                db.DB | ||||
| 	log               *logrus.Logger | ||||
| 	testAccountLocal  *model.Account | ||||
| 	testAccountRemote *model.Account | ||||
| 	testUser          *model.User | ||||
| 	config            *config.Config | ||||
| 	db                db.DB | ||||
| 	accountModule     *accountModule | ||||
| } | ||||
| 
 | ||||
| // SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout | ||||
| func (suite *AccountTestSuite) SetupSuite() { | ||||
| 	log := logrus.New() | ||||
| 	log.SetLevel(logrus.TraceLevel) | ||||
| 	suite.log = log | ||||
| 
 | ||||
| 	c := config.Empty() | ||||
| 	c.DBConfig = &config.DBConfig{ | ||||
| 		Type:            "postgres", | ||||
|  | @ -58,118 +57,126 @@ func (suite *AccountTestSuite) SetupSuite() { | |||
| 		Database:        "postgres", | ||||
| 		ApplicationName: "gotosocial", | ||||
| 	} | ||||
| 	suite.config = c | ||||
| 
 | ||||
| 	encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) | ||||
| 	database, err := db.New(context.Background(), c, log) | ||||
| 	if err != nil { | ||||
| 		logrus.Panicf("error encrypting user pass: %s", err) | ||||
| 		suite.FailNow(err.Error()) | ||||
| 	} | ||||
| 	suite.db = database | ||||
| 
 | ||||
| 	suite.accountModule = &accountModule{ | ||||
| 		config: c, | ||||
| 		db:     database, | ||||
| 		log:    log, | ||||
| 	} | ||||
| 
 | ||||
| 	localAvatar, err := url.Parse("https://localhost:8080/media/aaaaaaaaa.png") | ||||
| 	if err != nil { | ||||
| 		logrus.Panicf("error parsing localavatar url: %s", err) | ||||
| 	} | ||||
| 	localHeader, err := url.Parse("https://localhost:8080/media/ffffffffff.png") | ||||
| 	if err != nil { | ||||
| 		logrus.Panicf("error parsing localheader url: %s", err) | ||||
| 	} | ||||
| 	// encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) | ||||
| 	// if err != nil { | ||||
| 	// 	logrus.Panicf("error encrypting user pass: %s", err) | ||||
| 	// } | ||||
| 
 | ||||
| 	acctID := uuid.NewString() | ||||
| 	suite.testAccountLocal = &model.Account{ | ||||
| 		ID:              acctID, | ||||
| 		Username:        "local_account_of_some_kind", | ||||
| 		AvatarRemoteURL: localAvatar, | ||||
| 		HeaderRemoteURL: localHeader, | ||||
| 		DisplayName:     "michael caine", | ||||
| 		Fields: []model.Field{ | ||||
| 			{ | ||||
| 				Name:  "come and ave a go", | ||||
| 				Value: "if you think you're hard enough", | ||||
| 			}, | ||||
| 			{ | ||||
| 				Name:       "website", | ||||
| 				Value:      "https://imdb.com", | ||||
| 				VerifiedAt: time.Now(), | ||||
| 			}, | ||||
| 		}, | ||||
| 		Note:         "My name is Michael Caine and i'm a local user.", | ||||
| 		Discoverable: true, | ||||
| 	} | ||||
| 	// localAvatar, err := url.Parse("https://localhost:8080/media/aaaaaaaaa.png") | ||||
| 	// if err != nil { | ||||
| 	// 	logrus.Panicf("error parsing localavatar url: %s", err) | ||||
| 	// } | ||||
| 	// localHeader, err := url.Parse("https://localhost:8080/media/ffffffffff.png") | ||||
| 	// if err != nil { | ||||
| 	// 	logrus.Panicf("error parsing localheader url: %s", err) | ||||
| 	// } | ||||
| 
 | ||||
| 	avatarURL, err := url.Parse("http://example.org/accounts/avatars/000/207/122/original/089-1098-09.png") | ||||
| 	if err != nil { | ||||
| 		logrus.Panicf("error parsing avatarURL: %s", err) | ||||
| 	} | ||||
| 	// acctID := uuid.NewString() | ||||
| 	// suite.testAccountLocal = &model.Account{ | ||||
| 	// 	ID:              acctID, | ||||
| 	// 	Username:        "local_account_of_some_kind", | ||||
| 	// 	AvatarRemoteURL: localAvatar, | ||||
| 	// 	HeaderRemoteURL: localHeader, | ||||
| 	// 	DisplayName:     "michael caine", | ||||
| 	// 	Fields: []model.Field{ | ||||
| 	// 		{ | ||||
| 	// 			Name:  "come and ave a go", | ||||
| 	// 			Value: "if you think you're hard enough", | ||||
| 	// 		}, | ||||
| 	// 		{ | ||||
| 	// 			Name:       "website", | ||||
| 	// 			Value:      "https://imdb.com", | ||||
| 	// 			VerifiedAt: time.Now(), | ||||
| 	// 		}, | ||||
| 	// 	}, | ||||
| 	// 	Note:         "My name is Michael Caine and i'm a local user.", | ||||
| 	// 	Discoverable: true, | ||||
| 	// } | ||||
| 
 | ||||
| 	headerURL, err := url.Parse("http://example.org/accounts/headers/000/207/122/original/111111111111.png") | ||||
| 	if err != nil { | ||||
| 		logrus.Panicf("error parsing avatarURL: %s", err) | ||||
| 	} | ||||
| 	suite.testAccountRemote = &model.Account{ | ||||
| 		ID:       uuid.NewString(), | ||||
| 		Username: "neato_bombeato", | ||||
| 		Domain:   "example.org", | ||||
| 	// avatarURL, err := url.Parse("http://example.org/accounts/avatars/000/207/122/original/089-1098-09.png") | ||||
| 	// if err != nil { | ||||
| 	// 	logrus.Panicf("error parsing avatarURL: %s", err) | ||||
| 	// } | ||||
| 
 | ||||
| 		AvatarFileName:    "avatar.png", | ||||
| 		AvatarContentType: "image/png", | ||||
| 		AvatarFileSize:    1024, | ||||
| 		AvatarUpdatedAt:   time.Now(), | ||||
| 		AvatarRemoteURL:   avatarURL, | ||||
| 	// headerURL, err := url.Parse("http://example.org/accounts/headers/000/207/122/original/111111111111.png") | ||||
| 	// if err != nil { | ||||
| 	// 	logrus.Panicf("error parsing avatarURL: %s", err) | ||||
| 	// } | ||||
| 	// suite.testAccountRemote = &model.Account{ | ||||
| 	// 	ID:       uuid.NewString(), | ||||
| 	// 	Username: "neato_bombeato", | ||||
| 	// 	Domain:   "example.org", | ||||
| 
 | ||||
| 		HeaderFileName:    "avatar.png", | ||||
| 		HeaderContentType: "image/png", | ||||
| 		HeaderFileSize:    1024, | ||||
| 		HeaderUpdatedAt:   time.Now(), | ||||
| 		HeaderRemoteURL:   headerURL, | ||||
| 	// 	AvatarFileName:    "avatar.png", | ||||
| 	// 	AvatarContentType: "image/png", | ||||
| 	// 	AvatarFileSize:    1024, | ||||
| 	// 	AvatarUpdatedAt:   time.Now(), | ||||
| 	// 	AvatarRemoteURL:   avatarURL, | ||||
| 
 | ||||
| 		DisplayName: "one cool dude 420", | ||||
| 		Fields: []model.Field{ | ||||
| 			{ | ||||
| 				Name:  "pronouns", | ||||
| 				Value: "he/they", | ||||
| 			}, | ||||
| 			{ | ||||
| 				Name:       "website", | ||||
| 				Value:      "https://imcool.edu", | ||||
| 				VerifiedAt: time.Now(), | ||||
| 			}, | ||||
| 		}, | ||||
| 		Note:                  "<p>I'm cool as heck!</p>", | ||||
| 		Discoverable:          true, | ||||
| 		URI:                   "https://example.org/users/neato_bombeato", | ||||
| 		URL:                   "https://example.org/@neato_bombeato", | ||||
| 		LastWebfingeredAt:     time.Now(), | ||||
| 		InboxURL:              "https://example.org/users/neato_bombeato/inbox", | ||||
| 		OutboxURL:             "https://example.org/users/neato_bombeato/outbox", | ||||
| 		SharedInboxURL:        "https://example.org/inbox", | ||||
| 		FollowersURL:          "https://example.org/users/neato_bombeato/followers", | ||||
| 		FeaturedCollectionURL: "https://example.org/users/neato_bombeato/collections/featured", | ||||
| 	} | ||||
| 	suite.testUser = &model.User{ | ||||
| 		ID:                uuid.NewString(), | ||||
| 		EncryptedPassword: string(encryptedPassword), | ||||
| 		Email:             "user@example.org", | ||||
| 		AccountID:         acctID, | ||||
| 	// 	HeaderFileName:    "avatar.png", | ||||
| 	// 	HeaderContentType: "image/png", | ||||
| 	// 	HeaderFileSize:    1024, | ||||
| 	// 	HeaderUpdatedAt:   time.Now(), | ||||
| 	// 	HeaderRemoteURL:   headerURL, | ||||
| 
 | ||||
| 	// 	DisplayName: "one cool dude 420", | ||||
| 	// 	Fields: []model.Field{ | ||||
| 	// 		{ | ||||
| 	// 			Name:  "pronouns", | ||||
| 	// 			Value: "he/they", | ||||
| 	// 		}, | ||||
| 	// 		{ | ||||
| 	// 			Name:       "website", | ||||
| 	// 			Value:      "https://imcool.edu", | ||||
| 	// 			VerifiedAt: time.Now(), | ||||
| 	// 		}, | ||||
| 	// 	}, | ||||
| 	// 	Note:                  "<p>I'm cool as heck!</p>", | ||||
| 	// 	Discoverable:          true, | ||||
| 	// 	URI:                   "https://example.org/users/neato_bombeato", | ||||
| 	// 	URL:                   "https://example.org/@neato_bombeato", | ||||
| 	// 	LastWebfingeredAt:     time.Now(), | ||||
| 	// 	InboxURL:              "https://example.org/users/neato_bombeato/inbox", | ||||
| 	// 	OutboxURL:             "https://example.org/users/neato_bombeato/outbox", | ||||
| 	// 	SharedInboxURL:        "https://example.org/inbox", | ||||
| 	// 	FollowersURL:          "https://example.org/users/neato_bombeato/followers", | ||||
| 	// 	FeaturedCollectionURL: "https://example.org/users/neato_bombeato/collections/featured", | ||||
| 	// } | ||||
| 	// suite.testUser = &model.User{ | ||||
| 	// 	ID:                uuid.NewString(), | ||||
| 	// 	EncryptedPassword: string(encryptedPassword), | ||||
| 	// 	Email:             "user@example.org", | ||||
| 	// 	AccountID:         acctID, | ||||
| 	// } | ||||
| } | ||||
| 
 | ||||
| func (suite *AccountTestSuite) TearDownSuite() { | ||||
| 	if err := suite.db.Stop(context.Background()); err != nil { | ||||
| 		logrus.Panicf("error closing db connection: %s", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // SetupTest creates a postgres connection and creates the oauth_clients table before each test | ||||
| // SetupTest creates a db connection and creates necessary tables before each test | ||||
| func (suite *AccountTestSuite) SetupTest() { | ||||
| 
 | ||||
| 	log := logrus.New() | ||||
| 	log.SetLevel(logrus.TraceLevel) | ||||
| 	db, err := db.New(context.Background(), suite.config, log) | ||||
| 	if err != nil { | ||||
| 		logrus.Panicf("error creating database connection: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	suite.db = db | ||||
| 
 | ||||
| 	models := []interface{}{ | ||||
| 		&model.User{}, | ||||
| 		&model.Account{}, | ||||
| 		&model.Follow{}, | ||||
| 		&model.Status{}, | ||||
| 		&model.Application{}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, m := range models { | ||||
|  | @ -177,70 +184,31 @@ func (suite *AccountTestSuite) SetupTest() { | |||
| 			logrus.Panicf("db connection error: %s", err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if err := suite.db.Put(suite.testAccountLocal); err != nil { | ||||
| 		logrus.Panicf("could not insert test account into db: %s", err) | ||||
| 	} | ||||
| 	if err := suite.db.Put(suite.testUser); err != nil { | ||||
| 		logrus.Panicf("could not insert test user into db: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| // TearDownTest drops the oauth_clients table and closes the pg connection after each test | ||||
| // TearDownTest drops tables to make sure there's no data in the db | ||||
| func (suite *AccountTestSuite) TearDownTest() { | ||||
| 	models := []interface{}{ | ||||
| 		&model.User{}, | ||||
| 		&model.Account{}, | ||||
| 		&model.Follow{}, | ||||
| 		&model.Status{}, | ||||
| 		&model.Application{}, | ||||
| 	} | ||||
| 	for _, m := range models { | ||||
| 		if err := suite.db.DropTable(m); err != nil { | ||||
| 			logrus.Panicf("error dropping table: %s", err) | ||||
| 		} | ||||
| 	} | ||||
| 	if err := suite.db.Stop(context.Background()); err != nil { | ||||
| 		logrus.Panicf("error closing db connection: %s", err) | ||||
| 	} | ||||
| 	suite.db = nil | ||||
| } | ||||
| 
 | ||||
| func (suite *AccountTestSuite) TestAPIInitialize() { | ||||
| 	log := logrus.New() | ||||
| 	log.SetLevel(logrus.TraceLevel) | ||||
| 
 | ||||
| 	r, err := router.New(suite.config, log) | ||||
| 	if err != nil { | ||||
| 		suite.FailNow(fmt.Sprintf("error creating router: %s", err)) | ||||
| 	} | ||||
| 
 | ||||
| 	r.AttachMiddleware(func(c *gin.Context) { | ||||
| 		account := &model.Account{} | ||||
| 		if err := suite.db.GetAccountByUserID(suite.testUser.ID, account); err != nil || account == nil { | ||||
| 			suite.T().Log(err) | ||||
| 			suite.FailNowf("no account found for user %s, continuing with unauthenticated request: %+v", "", suite.testUser.ID, account) | ||||
| 			fmt.Println(account) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		c.Set(oauth.SessionAuthorizedAccount, account) | ||||
| 		c.Set(oauth.SessionAuthorizedUser, suite.testUser.ID) | ||||
| 	}) | ||||
| 
 | ||||
| 	acct := New(suite.config, suite.db, log) | ||||
| 	if err := acct.Route(r); err != nil { | ||||
| 		suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err)) | ||||
| 	} | ||||
| 
 | ||||
| 	r.Start() | ||||
| 	defer func() { | ||||
| 		if err := r.Stop(context.Background()); err != nil { | ||||
| 			panic(fmt.Errorf("error stopping router: %s", err)) | ||||
| 		} | ||||
| 	}() | ||||
| 	time.Sleep(10 * time.Second) | ||||
| 
 | ||||
| func (suite *AccountTestSuite) TestAccountCreatePOSTHandler() { | ||||
| 	// TODO: figure out how to test this properly | ||||
| 	recorder := httptest.NewRecorder() | ||||
| 	recorder.Header().Set("X-Forwarded-For", "127.0.0.1") | ||||
| 	ctx, _ := gin.CreateTestContext(recorder) | ||||
| 	// ctx.Set() | ||||
| 	suite.accountModule.accountCreatePOSTHandler(ctx) | ||||
| } | ||||
| 
 | ||||
| func TestAccountTestSuite(t *testing.T) { | ||||
|  |  | |||
|  | @ -21,12 +21,17 @@ package account | |||
| import ( | ||||
| 	"errors" | ||||
| 
 | ||||
| 	"github.com/gotosocial/gotosocial/internal/config" | ||||
| 	"github.com/gotosocial/gotosocial/internal/db" | ||||
| 	"github.com/gotosocial/gotosocial/internal/util" | ||||
| 	"github.com/gotosocial/gotosocial/pkg/mastotypes" | ||||
| ) | ||||
| 
 | ||||
| func validateCreateAccount(form *mastotypes.AccountCreateRequest, reasonRequired bool, database db.DB) error { | ||||
| func validateCreateAccount(form *mastotypes.AccountCreateRequest, c *config.AccountsConfig, database db.DB) error { | ||||
| 	if !c.OpenRegistration { | ||||
| 		return errors.New("registration is not open for this server") | ||||
| 	} | ||||
| 
 | ||||
| 	if err := util.ValidateSignUpUsername(form.Username); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | @ -47,7 +52,7 @@ func validateCreateAccount(form *mastotypes.AccountCreateRequest, reasonRequired | |||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	if err := util.ValidateSignUpReason(form.Reason, reasonRequired); err != nil { | ||||
| 	if err := util.ValidateSignUpReason(form.Reason, c.ReasonRequired); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										140
									
								
								internal/module/app/app.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										140
									
								
								internal/module/app/app.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,140 @@ | |||
| /* | ||||
|    GoToSocial | ||||
|    Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org | ||||
| 
 | ||||
|    This program is free software: you can redistribute it and/or modify | ||||
|    it under the terms of the GNU Affero General Public License as published by | ||||
|    the Free Software Foundation, either version 3 of the License, or | ||||
|    (at your option) any later version. | ||||
| 
 | ||||
|    This program is distributed in the hope that it will be useful, | ||||
|    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
|    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
|    GNU Affero General Public License for more details. | ||||
| 
 | ||||
|    You should have received a copy of the GNU Affero General Public License | ||||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
| */ | ||||
| 
 | ||||
| package app | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/google/uuid" | ||||
| 	"github.com/gotosocial/gotosocial/internal/db" | ||||
| 	"github.com/gotosocial/gotosocial/internal/db/model" | ||||
| 	"github.com/gotosocial/gotosocial/internal/module" | ||||
| 	"github.com/gotosocial/gotosocial/internal/oauth" | ||||
| 	"github.com/gotosocial/gotosocial/internal/router" | ||||
| 	"github.com/gotosocial/gotosocial/pkg/mastotypes" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| ) | ||||
| 
 | ||||
| const appsPath = "/api/v1/apps" | ||||
| 
 | ||||
| type appModule struct { | ||||
| 	server oauth.Server | ||||
| 	db     db.DB | ||||
| 	log    *logrus.Logger | ||||
| } | ||||
| 
 | ||||
| // New returns a new auth module | ||||
| func New(srv oauth.Server, db db.DB, log *logrus.Logger) module.ClientAPIModule { | ||||
| 	return &appModule{ | ||||
| 		server: srv, | ||||
| 		db:     db, | ||||
| 		log:    log, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Route satisfies the RESTAPIModule interface | ||||
| func (m *appModule) Route(s router.Router) error { | ||||
| 	s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // appsPOSTHandler should be served at https://example.org/api/v1/apps | ||||
| // It is equivalent to: https://docs.joinmastodon.org/methods/apps/ | ||||
| func (m *appModule) appsPOSTHandler(c *gin.Context) { | ||||
| 	l := m.log.WithField("func", "AppsPOSTHandler") | ||||
| 	l.Trace("entering AppsPOSTHandler") | ||||
| 
 | ||||
| 	form := &mastotypes.ApplicationPOSTRequest{} | ||||
| 	if err := c.ShouldBind(form); err != nil { | ||||
| 		c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// permitted length for most fields | ||||
| 	permittedLength := 64 | ||||
| 	// redirect can be a bit bigger because we probably need to encode data in the redirect uri | ||||
| 	permittedRedirect := 256 | ||||
| 
 | ||||
| 	// check lengths of fields before proceeding so the user can't spam huge entries into the database | ||||
| 	if len(form.ClientName) > permittedLength { | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)}) | ||||
| 		return | ||||
| 	} | ||||
| 	if len(form.Website) > permittedLength { | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)}) | ||||
| 		return | ||||
| 	} | ||||
| 	if len(form.RedirectURIs) > permittedRedirect { | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)}) | ||||
| 		return | ||||
| 	} | ||||
| 	if len(form.Scopes) > permittedLength { | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/ | ||||
| 	var scopes string | ||||
| 	if form.Scopes == "" { | ||||
| 		scopes = "read" | ||||
| 	} else { | ||||
| 		scopes = form.Scopes | ||||
| 	} | ||||
| 
 | ||||
| 	// generate new IDs for this application and its associated client | ||||
| 	clientID := uuid.NewString() | ||||
| 	clientSecret := uuid.NewString() | ||||
| 	vapidKey := uuid.NewString() | ||||
| 
 | ||||
| 	// generate the application to put in the database | ||||
| 	app := &model.Application{ | ||||
| 		Name:         form.ClientName, | ||||
| 		Website:      form.Website, | ||||
| 		RedirectURI:  form.RedirectURIs, | ||||
| 		ClientID:     clientID, | ||||
| 		ClientSecret: clientSecret, | ||||
| 		Scopes:       scopes, | ||||
| 		VapidKey:     vapidKey, | ||||
| 	} | ||||
| 
 | ||||
| 	// chuck it in the db | ||||
| 	if err := m.db.Put(app); err != nil { | ||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// now we need to model an oauth client from the application that the oauth library can use | ||||
| 	oc := &oauth.Client{ | ||||
| 		ID:     clientID, | ||||
| 		Secret: clientSecret, | ||||
| 		Domain: form.RedirectURIs, | ||||
| 		UserID: "", // This client isn't yet associated with a specific user,  it's just an app client right now | ||||
| 	} | ||||
| 
 | ||||
| 	// chuck it in the db | ||||
| 	if err := m.db.Put(oc); err != nil { | ||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/ | ||||
| 	c.JSON(http.StatusOK, app.ToMasto()) | ||||
| } | ||||
							
								
								
									
										21
									
								
								internal/module/app/app_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/module/app/app_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,21 @@ | |||
| /* | ||||
|    GoToSocial | ||||
|    Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org | ||||
| 
 | ||||
|    This program is free software: you can redistribute it and/or modify | ||||
|    it under the terms of the GNU Affero General Public License as published by | ||||
|    the Free Software Foundation, either version 3 of the License, or | ||||
|    (at your option) any later version. | ||||
| 
 | ||||
|    This program is distributed in the hope that it will be useful, | ||||
|    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
|    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
|    GNU Affero General Public License for more details. | ||||
| 
 | ||||
|    You should have received a copy of the GNU Affero General Public License | ||||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
| */ | ||||
| 
 | ||||
| package app | ||||
| 
 | ||||
| // TODO: write tests | ||||
|  | @ -1,4 +1,4 @@ | |||
| # oauth | ||||
| # auth | ||||
| 
 | ||||
| This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) functionality to the GoToSocial client API. | ||||
| 
 | ||||
|  | @ -16,55 +16,40 @@ | |||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
| */ | ||||
| 
 | ||||
| // Package oauth is a module that provides oauth functionality to a router. | ||||
| // Package auth is a module that provides oauth functionality to a router. | ||||
| // It adds the following paths: | ||||
| //    /api/v1/apps | ||||
| //    /auth/sign_in | ||||
| //    /oauth/token | ||||
| //    /oauth/authorize | ||||
| // It also includes the oauthTokenMiddleware, which can be attached to a router to authenticate every request by Bearer token. | ||||
| package oauth | ||||
| package auth | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 
 | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/google/uuid" | ||||
| 	"github.com/gotosocial/gotosocial/internal/db" | ||||
| 	"github.com/gotosocial/gotosocial/internal/db/model" | ||||
| 	"github.com/gotosocial/gotosocial/internal/module" | ||||
| 	"github.com/gotosocial/gotosocial/internal/oauth" | ||||
| 	"github.com/gotosocial/gotosocial/internal/router" | ||||
| 	"github.com/gotosocial/gotosocial/pkg/mastotypes" | ||||
| 	"github.com/gotosocial/oauth2/v4" | ||||
| 	"github.com/gotosocial/oauth2/v4/errors" | ||||
| 	"github.com/gotosocial/oauth2/v4/manage" | ||||
| 	"github.com/gotosocial/oauth2/v4/server" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"golang.org/x/crypto/bcrypt" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	appsPath           = "/api/v1/apps" | ||||
| 	authSignInPath     = "/auth/sign_in" | ||||
| 	oauthTokenPath     = "/oauth/token" | ||||
| 	oauthAuthorizePath = "/oauth/authorize" | ||||
| 	// SessionAuthorizedUser is the key set in the gin context for the id of | ||||
| 	// a User who has successfully passed Bearer token authorization. | ||||
| 	// The interface returned from grabbing this key should be parsed as a string. | ||||
| 	SessionAuthorizedUser = "authorized_user" | ||||
| 	// SessionAuthorizedAccount is the key set in the gin context for the Account | ||||
| 	// of a User who has successfully passed Bearer token authorization. | ||||
| 	// The interface returned from grabbing this key should be parsed as a *gtsmodel.Account | ||||
| 	SessionAuthorizedAccount = "authorized_account" | ||||
| ) | ||||
| 
 | ||||
| // oauthModule is an oauth2 oauthModule that satisfies the ClientAPIModule interface | ||||
| type oauthModule struct { | ||||
| 	oauthManager *manage.Manager | ||||
| 	oauthServer  *server.Server | ||||
| type authModule struct { | ||||
| 	server oauth.Server | ||||
| 	db     db.DB | ||||
| 	log    *logrus.Logger | ||||
| } | ||||
|  | @ -74,52 +59,17 @@ type login struct { | |||
| 	Password string `form:"password"` | ||||
| } | ||||
| 
 | ||||
| // New returns a new oauth module | ||||
| func New(ts oauth2.TokenStore, cs oauth2.ClientStore, db db.DB, log *logrus.Logger) module.ClientAPIModule { | ||||
| 	manager := manage.NewDefaultManager() | ||||
| 	manager.MapTokenStorage(ts) | ||||
| 	manager.MapClientStorage(cs) | ||||
| 	manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) | ||||
| 	sc := &server.Config{ | ||||
| 		TokenType: "Bearer", | ||||
| 		// Must follow the spec. | ||||
| 		AllowGetAccessRequest: false, | ||||
| 		// Support only the non-implicit flow. | ||||
| 		AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code}, | ||||
| 		// Allow: | ||||
| 		// - Authorization Code (for first & third parties) | ||||
| 		AllowedGrantTypes: []oauth2.GrantType{ | ||||
| 			oauth2.AuthorizationCode, | ||||
| 		}, | ||||
| 		AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain}, | ||||
| 	} | ||||
| 
 | ||||
| 	srv := server.NewServer(sc, manager) | ||||
| 	srv.SetInternalErrorHandler(func(err error) *errors.Response { | ||||
| 		log.Errorf("internal oauth error: %s", err) | ||||
| 		return nil | ||||
| 	}) | ||||
| 
 | ||||
| 	srv.SetResponseErrorHandler(func(re *errors.Response) { | ||||
| 		log.Errorf("internal response error: %s", re.Error) | ||||
| 	}) | ||||
| 
 | ||||
| 	m := &oauthModule{ | ||||
| 		oauthManager: manager, | ||||
| 		oauthServer:  srv, | ||||
| // New returns a new auth module | ||||
| func New(srv oauth.Server, db db.DB, log *logrus.Logger) module.ClientAPIModule { | ||||
| 	return &authModule{ | ||||
| 		server: srv, | ||||
| 		db:     db, | ||||
| 		log:    log, | ||||
| 	} | ||||
| 
 | ||||
| 	m.oauthServer.SetUserAuthorizationHandler(m.userAuthorizationHandler) | ||||
| 	m.oauthServer.SetClientInfoHandler(server.ClientFormHandler) | ||||
| 	return m | ||||
| } | ||||
| 
 | ||||
| // Route satisfies the RESTAPIModule interface | ||||
| func (m *oauthModule) Route(s router.Router) error { | ||||
| 	s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler) | ||||
| 
 | ||||
| func (m *authModule) Route(s router.Router) error { | ||||
| 	s.AttachHandler(http.MethodGet, authSignInPath, m.signInGETHandler) | ||||
| 	s.AttachHandler(http.MethodPost, authSignInPath, m.signInPOSTHandler) | ||||
| 
 | ||||
|  | @ -129,7 +79,6 @@ func (m *oauthModule) Route(s router.Router) error { | |||
| 	s.AttachHandler(http.MethodPost, oauthAuthorizePath, m.authorizePOSTHandler) | ||||
| 
 | ||||
| 	s.AttachMiddleware(m.oauthTokenMiddleware) | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
|  | @ -137,93 +86,10 @@ func (m *oauthModule) Route(s router.Router) error { | |||
| 	MAIN HANDLERS -- serve these through a server/router | ||||
| */ | ||||
| 
 | ||||
| // appsPOSTHandler should be served at https://example.org/api/v1/apps | ||||
| // It is equivalent to: https://docs.joinmastodon.org/methods/apps/ | ||||
| func (m *oauthModule) appsPOSTHandler(c *gin.Context) { | ||||
| 	l := m.log.WithField("func", "AppsPOSTHandler") | ||||
| 	l.Trace("entering AppsPOSTHandler") | ||||
| 
 | ||||
| 	form := &mastotypes.ApplicationPOSTRequest{} | ||||
| 	if err := c.ShouldBind(form); err != nil { | ||||
| 		c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// permitted length for most fields | ||||
| 	permittedLength := 64 | ||||
| 	// redirect can be a bit bigger because we probably need to encode data in the redirect uri | ||||
| 	permittedRedirect := 256 | ||||
| 
 | ||||
| 	// check lengths of fields before proceeding so the user can't spam huge entries into the database | ||||
| 	if len(form.ClientName) > permittedLength { | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)}) | ||||
| 		return | ||||
| 	} | ||||
| 	if len(form.Website) > permittedLength { | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)}) | ||||
| 		return | ||||
| 	} | ||||
| 	if len(form.RedirectURIs) > permittedRedirect { | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)}) | ||||
| 		return | ||||
| 	} | ||||
| 	if len(form.Scopes) > permittedLength { | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/ | ||||
| 	var scopes string | ||||
| 	if form.Scopes == "" { | ||||
| 		scopes = "read" | ||||
| 	} else { | ||||
| 		scopes = form.Scopes | ||||
| 	} | ||||
| 
 | ||||
| 	// generate new IDs for this application and its associated client | ||||
| 	clientID := uuid.NewString() | ||||
| 	clientSecret := uuid.NewString() | ||||
| 	vapidKey := uuid.NewString() | ||||
| 
 | ||||
| 	// generate the application to put in the database | ||||
| 	app := &model.Application{ | ||||
| 		Name:         form.ClientName, | ||||
| 		Website:      form.Website, | ||||
| 		RedirectURI:  form.RedirectURIs, | ||||
| 		ClientID:     clientID, | ||||
| 		ClientSecret: clientSecret, | ||||
| 		Scopes:       scopes, | ||||
| 		VapidKey:     vapidKey, | ||||
| 	} | ||||
| 
 | ||||
| 	// chuck it in the db | ||||
| 	if err := m.db.Put(app); err != nil { | ||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// now we need to model an oauth client from the application that the oauth library can use | ||||
| 	oc := &oauthClient{ | ||||
| 		ID:     clientID, | ||||
| 		Secret: clientSecret, | ||||
| 		Domain: form.RedirectURIs, | ||||
| 		UserID: "", // This client isn't yet associated with a specific user,  it's just an app client right now | ||||
| 	} | ||||
| 
 | ||||
| 	// chuck it in the db | ||||
| 	if err := m.db.Put(oc); err != nil { | ||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/ | ||||
| 	c.JSON(http.StatusOK, app.ToMasto()) | ||||
| } | ||||
| 
 | ||||
| // signInGETHandler should be served at https://example.org/auth/sign_in. | ||||
| // The idea is to present a sign in page to the user, where they can enter their username and password. | ||||
| // The form will then POST to the sign in page, which will be handled by SignInPOSTHandler | ||||
| func (m *oauthModule) signInGETHandler(c *gin.Context) { | ||||
| func (m *authModule) signInGETHandler(c *gin.Context) { | ||||
| 	m.log.WithField("func", "SignInGETHandler").Trace("serving sign in html") | ||||
| 	c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{}) | ||||
| } | ||||
|  | @ -231,7 +97,7 @@ func (m *oauthModule) signInGETHandler(c *gin.Context) { | |||
| // signInPOSTHandler should be served at https://example.org/auth/sign_in. | ||||
| // The idea is to present a sign in page to the user, where they can enter their username and password. | ||||
| // The handler will then redirect to the auth handler served at /auth | ||||
| func (m *oauthModule) signInPOSTHandler(c *gin.Context) { | ||||
| func (m *authModule) signInPOSTHandler(c *gin.Context) { | ||||
| 	l := m.log.WithField("func", "SignInPOSTHandler") | ||||
| 	s := sessions.Default(c) | ||||
| 	form := &login{} | ||||
|  | @ -260,10 +126,10 @@ func (m *oauthModule) signInPOSTHandler(c *gin.Context) { | |||
| // tokenPOSTHandler should be served as a POST at https://example.org/oauth/token | ||||
| // The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs. | ||||
| // See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token | ||||
| func (m *oauthModule) tokenPOSTHandler(c *gin.Context) { | ||||
| func (m *authModule) tokenPOSTHandler(c *gin.Context) { | ||||
| 	l := m.log.WithField("func", "TokenPOSTHandler") | ||||
| 	l.Trace("entered TokenPOSTHandler") | ||||
| 	if err := m.oauthServer.HandleTokenRequest(c.Writer, c.Request); err != nil { | ||||
| 	if err := m.server.HandleTokenRequest(c.Writer, c.Request); err != nil { | ||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||
| 	} | ||||
| } | ||||
|  | @ -271,7 +137,7 @@ func (m *oauthModule) tokenPOSTHandler(c *gin.Context) { | |||
| // authorizeGETHandler should be served as GET at https://example.org/oauth/authorize | ||||
| // The idea here is to present an oauth authorize page to the user, with a button | ||||
| // that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user | ||||
| func (m *oauthModule) authorizeGETHandler(c *gin.Context) { | ||||
| func (m *authModule) authorizeGETHandler(c *gin.Context) { | ||||
| 	l := m.log.WithField("func", "AuthorizeGETHandler") | ||||
| 	s := sessions.Default(c) | ||||
| 
 | ||||
|  | @ -349,7 +215,7 @@ func (m *oauthModule) authorizeGETHandler(c *gin.Context) { | |||
| // At this point we assume that the user has A) logged in and B) accepted that the app should act for them, | ||||
| // so we should proceed with the authentication flow and generate an oauth token for them if we can. | ||||
| // See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user | ||||
| func (m *oauthModule) authorizePOSTHandler(c *gin.Context) { | ||||
| func (m *authModule) authorizePOSTHandler(c *gin.Context) { | ||||
| 	l := m.log.WithField("func", "AuthorizePOSTHandler") | ||||
| 	s := sessions.Default(c) | ||||
| 
 | ||||
|  | @ -404,7 +270,7 @@ func (m *oauthModule) authorizePOSTHandler(c *gin.Context) { | |||
| 	l.Tracef("values on request set to %+v", c.Request.Form) | ||||
| 
 | ||||
| 	// and proceed with authorization using the oauth2 library | ||||
| 	if err := m.oauthServer.HandleAuthorizeRequest(c.Writer, c.Request); err != nil { | ||||
| 	if err := m.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil { | ||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||
| 	} | ||||
| } | ||||
|  | @ -418,25 +284,50 @@ func (m *oauthModule) authorizePOSTHandler(c *gin.Context) { | |||
| // the request. Then, it will look up the account for that user, and set that in the request too. | ||||
| // If user or account can't be found, then the handler won't *fail*, in case the server wants to allow | ||||
| // public requests that don't have a Bearer token set (eg., for public instance information and so on). | ||||
| func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) { | ||||
| func (m *authModule) oauthTokenMiddleware(c *gin.Context) { | ||||
| 	l := m.log.WithField("func", "ValidatePassword") | ||||
| 	l.Trace("entering OauthTokenMiddleware") | ||||
| 
 | ||||
| 	ti, err := m.oauthServer.ValidationBearerToken(c.Request) | ||||
| 	ti, err := m.server.ValidationBearerToken(c.Request) | ||||
| 	if err != nil { | ||||
| 		l.Trace("no valid token presented: continuing with unauthenticated request") | ||||
| 		return | ||||
| 	} | ||||
| 	l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope()) | ||||
| 	c.Set(oauth.SessionAuthorizedToken, ti) | ||||
| 	l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedToken, ti) | ||||
| 
 | ||||
| 	acct := &model.Account{} | ||||
| 	if err := m.db.GetAccountByUserID(ti.GetUserID(), acct); err != nil || acct == nil { | ||||
| 		l.Tracef("no account found for user %s, continuing with unauthenticated request", ti.GetUserID()) | ||||
| 	// check for user-level token | ||||
| 	if uid := ti.GetUserID(); uid != "" { | ||||
| 		l.Tracef("authenticated user %s with bearer token, scope is %s", uid, ti.GetScope()) | ||||
| 
 | ||||
| 		// fetch user's and account for this user id | ||||
| 		user := &model.User{} | ||||
| 		if err := m.db.GetByID(uid, user); err != nil || user == nil { | ||||
| 			l.Warnf("no user found for validated uid %s", uid) | ||||
| 			return | ||||
| 		} | ||||
| 		c.Set(oauth.SessionAuthorizedUser, user) | ||||
| 		l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user) | ||||
| 
 | ||||
| 	c.Set(SessionAuthorizedAccount, acct) | ||||
| 	c.Set(SessionAuthorizedUser, ti.GetUserID()) | ||||
| 		acct := &model.Account{} | ||||
| 		if err := m.db.GetByID(user.AccountID, acct); err != nil || acct == nil { | ||||
| 			l.Warnf("no account found for validated user %s", uid) | ||||
| 			return | ||||
| 		} | ||||
| 		c.Set(oauth.SessionAuthorizedAccount, acct) | ||||
| 		l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedAccount, acct) | ||||
| 	} | ||||
| 
 | ||||
| 	// check for application token | ||||
| 	if cid := ti.GetClientID(); cid != "" { | ||||
| 		l.Tracef("authenticated client %s with bearer token, scope is %s", cid, ti.GetScope()) | ||||
| 		app := &model.Application{} | ||||
| 		if err := m.db.GetWhere("client_id", cid, app); err != nil { | ||||
| 			l.Tracef("no app found for client %s", cid) | ||||
| 		} | ||||
| 		c.Set(oauth.SessionAuthorizedApplication, app) | ||||
| 		l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedApplication, app) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| /* | ||||
|  | @ -447,7 +338,7 @@ func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) { | |||
| // The goal is to authenticate the password against the one for that email | ||||
| // address stored in the database. If OK, we return the userid (a uuid) for that user, | ||||
| // so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db. | ||||
| func (m *oauthModule) validatePassword(email string, password string) (userid string, err error) { | ||||
| func (m *authModule) validatePassword(email string, password string) (userid string, err error) { | ||||
| 	l := m.log.WithField("func", "ValidatePassword") | ||||
| 
 | ||||
| 	// make sure an email/password was provided and bail if not | ||||
|  | @ -487,18 +378,6 @@ func incorrectPassword() (string, error) { | |||
| 	return "", errors.New("password/email combination was incorrect") | ||||
| } | ||||
| 
 | ||||
| // userAuthorizationHandler gets the user's ID from the 'userid' field of the request form, | ||||
| // or redirects to the /auth/sign_in page, if this key is not present. | ||||
| func (m *oauthModule) userAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) { | ||||
| 	l := m.log.WithField("func", "UserAuthorizationHandler") | ||||
| 	userID = r.FormValue("userid") | ||||
| 	if userID == "" { | ||||
| 		return "", errors.New("userid was empty, redirecting to sign in page") | ||||
| 	} | ||||
| 	l.Tracef("returning userID %s", userID) | ||||
| 	return userID, err | ||||
| } | ||||
| 
 | ||||
| // parseAuthForm parses the OAuthAuthorize form in the gin context, and stores | ||||
| // the values in the form into the session. | ||||
| func parseAuthForm(c *gin.Context, l *logrus.Entry) error { | ||||
|  | @ -16,38 +16,38 @@ | |||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
| */ | ||||
| 
 | ||||
| package oauth | ||||
| package auth | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/google/uuid" | ||||
| 	"github.com/gotosocial/gotosocial/internal/config" | ||||
| 	"github.com/gotosocial/gotosocial/internal/db" | ||||
| 	"github.com/gotosocial/gotosocial/internal/db/model" | ||||
| 	"github.com/gotosocial/gotosocial/internal/oauth" | ||||
| 	"github.com/gotosocial/gotosocial/internal/router" | ||||
| 	"github.com/gotosocial/oauth2/v4" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"github.com/stretchr/testify/suite" | ||||
| 	"golang.org/x/crypto/bcrypt" | ||||
| ) | ||||
| 
 | ||||
| type OauthTestSuite struct { | ||||
| type AuthTestSuite struct { | ||||
| 	suite.Suite | ||||
| 	tokenStore      oauth2.TokenStore | ||||
| 	clientStore     oauth2.ClientStore | ||||
| 	oauthServer     oauth.Server | ||||
| 	db              db.DB | ||||
| 	testAccount     *model.Account | ||||
| 	testApplication *model.Application | ||||
| 	testUser        *model.User | ||||
| 	testClient      *oauthClient | ||||
| 	testClient      *oauth.Client | ||||
| 	config          *config.Config | ||||
| } | ||||
| 
 | ||||
| // SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout | ||||
| func (suite *OauthTestSuite) SetupSuite() { | ||||
| func (suite *AuthTestSuite) SetupSuite() { | ||||
| 	c := config.Empty() | ||||
| 	// we're running on localhost without https so set the protocol to http | ||||
| 	c.Protocol = "http" | ||||
|  | @ -84,7 +84,7 @@ func (suite *OauthTestSuite) SetupSuite() { | |||
| 		Email:             "user@example.org", | ||||
| 		AccountID:         acctID, | ||||
| 	} | ||||
| 	suite.testClient = &oauthClient{ | ||||
| 	suite.testClient = &oauth.Client{ | ||||
| 		ID:     "a-known-client-id", | ||||
| 		Secret: "some-secret", | ||||
| 		Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host), | ||||
|  | @ -101,7 +101,7 @@ func (suite *OauthTestSuite) SetupSuite() { | |||
| } | ||||
| 
 | ||||
| // SetupTest creates a postgres connection and creates the oauth_clients table before each test | ||||
| func (suite *OauthTestSuite) SetupTest() { | ||||
| func (suite *AuthTestSuite) SetupTest() { | ||||
| 
 | ||||
| 	log := logrus.New() | ||||
| 	log.SetLevel(logrus.TraceLevel) | ||||
|  | @ -113,8 +113,8 @@ func (suite *OauthTestSuite) SetupTest() { | |||
| 	suite.db = db | ||||
| 
 | ||||
| 	models := []interface{}{ | ||||
| 		&oauthClient{}, | ||||
| 		&oauthToken{}, | ||||
| 		&oauth.Client{}, | ||||
| 		&oauth.Token{}, | ||||
| 		&model.User{}, | ||||
| 		&model.Account{}, | ||||
| 		&model.Application{}, | ||||
|  | @ -126,8 +126,7 @@ func (suite *OauthTestSuite) SetupTest() { | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	suite.tokenStore = newTokenStore(context.Background(), suite.db, logrus.New()) | ||||
| 	suite.clientStore = newClientStore(suite.db) | ||||
| 	suite.oauthServer = oauth.New(suite.db, log) | ||||
| 
 | ||||
| 	if err := suite.db.Put(suite.testAccount); err != nil { | ||||
| 		logrus.Panicf("could not insert test account into db: %s", err) | ||||
|  | @ -145,10 +144,10 @@ func (suite *OauthTestSuite) SetupTest() { | |||
| } | ||||
| 
 | ||||
| // TearDownTest drops the oauth_clients table and closes the pg connection after each test | ||||
| func (suite *OauthTestSuite) TearDownTest() { | ||||
| func (suite *AuthTestSuite) TearDownTest() { | ||||
| 	models := []interface{}{ | ||||
| 		&oauthClient{}, | ||||
| 		&oauthToken{}, | ||||
| 		&oauth.Client{}, | ||||
| 		&oauth.Token{}, | ||||
| 		&model.User{}, | ||||
| 		&model.Account{}, | ||||
| 		&model.Application{}, | ||||
|  | @ -164,7 +163,7 @@ func (suite *OauthTestSuite) TearDownTest() { | |||
| 	suite.db = nil | ||||
| } | ||||
| 
 | ||||
| func (suite *OauthTestSuite) TestAPIInitialize() { | ||||
| func (suite *AuthTestSuite) TestAPIInitialize() { | ||||
| 	log := logrus.New() | ||||
| 	log.SetLevel(logrus.TraceLevel) | ||||
| 
 | ||||
|  | @ -173,17 +172,18 @@ func (suite *OauthTestSuite) TestAPIInitialize() { | |||
| 		suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err)) | ||||
| 	} | ||||
| 
 | ||||
| 	api := New(suite.tokenStore, suite.clientStore, suite.db, log) | ||||
| 	api := New(suite.oauthServer, suite.db, log) | ||||
| 	if err := api.Route(r); err != nil { | ||||
| 		suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err)) | ||||
| 	} | ||||
| 
 | ||||
| 	r.Start() | ||||
| 	time.Sleep(60 * time.Second) | ||||
| 	if err := r.Stop(context.Background()); err != nil { | ||||
| 		suite.FailNow(fmt.Sprintf("error stopping router: %s", err)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOauthTestSuite(t *testing.T) { | ||||
| 	suite.Run(t, new(OauthTestSuite)) | ||||
| func TestAuthTestSuite(t *testing.T) { | ||||
| 	suite.Run(t, new(AuthTestSuite)) | ||||
| } | ||||
|  | @ -38,7 +38,7 @@ func newClientStore(db db.DB) oauth2.ClientStore { | |||
| } | ||||
| 
 | ||||
| func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { | ||||
| 	poc := &oauthClient{ | ||||
| 	poc := &Client{ | ||||
| 		ID: clientID, | ||||
| 	} | ||||
| 	if err := cs.db.GetByID(clientID, poc); err != nil { | ||||
|  | @ -48,7 +48,7 @@ func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.Cli | |||
| } | ||||
| 
 | ||||
| func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error { | ||||
| 	poc := &oauthClient{ | ||||
| 	poc := &Client{ | ||||
| 		ID:     cli.GetID(), | ||||
| 		Secret: cli.GetSecret(), | ||||
| 		Domain: cli.GetDomain(), | ||||
|  | @ -58,13 +58,13 @@ func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo | |||
| } | ||||
| 
 | ||||
| func (cs *clientStore) Delete(ctx context.Context, id string) error { | ||||
| 	poc := &oauthClient{ | ||||
| 	poc := &Client{ | ||||
| 		ID: id, | ||||
| 	} | ||||
| 	return cs.db.DeleteByID(id, poc) | ||||
| } | ||||
| 
 | ||||
| type oauthClient struct { | ||||
| type Client struct { | ||||
| 	ID     string | ||||
| 	Secret string | ||||
| 	Domain string | ||||
|  | @ -69,7 +69,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() { | |||
| 	suite.db = db | ||||
| 
 | ||||
| 	models := []interface{}{ | ||||
| 		&oauthClient{}, | ||||
| 		&Client{}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, m := range models { | ||||
|  | @ -82,7 +82,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() { | |||
| // TearDownTest drops the oauth_clients table and closes the pg connection after each test | ||||
| func (suite *PgClientStoreTestSuite) TearDownTest() { | ||||
| 	models := []interface{}{ | ||||
| 		&oauthClient{}, | ||||
| 		&Client{}, | ||||
| 	} | ||||
| 	for _, m := range models { | ||||
| 		if err := suite.db.DropTable(m); err != nil { | ||||
							
								
								
									
										212
									
								
								internal/oauth/oauth.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										212
									
								
								internal/oauth/oauth.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,212 @@ | |||
| /* | ||||
|    GoToSocial | ||||
|    Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org | ||||
| 
 | ||||
|    This program is free software: you can redistribute it and/or modify | ||||
|    it under the terms of the GNU Affero General Public License as published by | ||||
|    the Free Software Foundation, either version 3 of the License, or | ||||
|    (at your option) any later version. | ||||
| 
 | ||||
|    This program is distributed in the hope that it will be useful, | ||||
|    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
|    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
|    GNU Affero General Public License for more details. | ||||
| 
 | ||||
|    You should have received a copy of the GNU Affero General Public License | ||||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
| */ | ||||
| 
 | ||||
| package oauth | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/gotosocial/gotosocial/internal/db" | ||||
| 	"github.com/gotosocial/gotosocial/internal/db/model" | ||||
| 	"github.com/gotosocial/oauth2/v4" | ||||
| 	"github.com/gotosocial/oauth2/v4/errors" | ||||
| 	"github.com/gotosocial/oauth2/v4/manage" | ||||
| 	"github.com/gotosocial/oauth2/v4/server" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	SessionAuthorizedToken = "authorized_token" | ||||
| 	// SessionAuthorizedUser is the key set in the gin context for the id of | ||||
| 	// a User who has successfully passed Bearer token authorization. | ||||
| 	// The interface returned from grabbing this key should be parsed as a *gtsmodel.User | ||||
| 	SessionAuthorizedUser = "authorized_user" | ||||
| 	// SessionAuthorizedAccount is the key set in the gin context for the Account | ||||
| 	// of a User who has successfully passed Bearer token authorization. | ||||
| 	// The interface returned from grabbing this key should be parsed as a *gtsmodel.Account | ||||
| 	SessionAuthorizedAccount = "authorized_account" | ||||
| 	// SessionAuthorizedAccount is the key set in the gin context for the Application | ||||
| 	// of a Client who has successfully passed Bearer token authorization. | ||||
| 	// The interface returned from grabbing this key should be parsed as a *gtsmodel.Application | ||||
| 	SessionAuthorizedApplication = "authorized_app" | ||||
| ) | ||||
| 
 | ||||
| // Server wraps some oauth2 server functions in an interface, exposing only what is needed | ||||
| type Server interface { | ||||
| 	HandleTokenRequest(w http.ResponseWriter, r *http.Request) error | ||||
| 	HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error | ||||
| 	ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) | ||||
| 	GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error) | ||||
| } | ||||
| 
 | ||||
| // s fulfils the Server interface using the underlying oauth2 server | ||||
| type s struct { | ||||
| 	server *server.Server | ||||
| 	log    *logrus.Logger | ||||
| } | ||||
| 
 | ||||
| type Authed struct { | ||||
| 	Token       oauth2.TokenInfo | ||||
| 	Application *model.Application | ||||
| 	User        *model.User | ||||
| 	Account     *model.Account | ||||
| } | ||||
| 
 | ||||
| // GetAuthed is a convenience function for returning an Authed struct from a gin context. | ||||
| // In essence, it tries to extract a token, application, user, and account from the context, | ||||
| // and then sets them on a struct for convenience. | ||||
| // | ||||
| // If any are not present in the context, they will be set to nil on the returned Authed struct. | ||||
| // | ||||
| // If *ALL* are not present, then nil and an error will be returned. | ||||
| // | ||||
| // If something goes wrong during parsing, then nil and an error will be returned (consider this not authed). | ||||
| func GetAuthed(c *gin.Context) (*Authed, error) { | ||||
| 	ctx := c.Copy() | ||||
| 	a := &Authed{} | ||||
| 	var i interface{} | ||||
| 	var ok bool | ||||
| 
 | ||||
| 	i, ok = ctx.Get(SessionAuthorizedToken) | ||||
| 	if ok { | ||||
| 		parsed, ok := i.(oauth2.TokenInfo) | ||||
| 		if !ok { | ||||
| 			return nil, errors.New("could not parse token from session context") | ||||
| 		} | ||||
| 		a.Token = parsed | ||||
| 	} | ||||
| 
 | ||||
| 	i, ok = ctx.Get(SessionAuthorizedApplication) | ||||
| 	if ok { | ||||
| 		parsed, ok := i.(*model.Application) | ||||
| 		if !ok { | ||||
| 			return nil, errors.New("could not parse application from session context") | ||||
| 		} | ||||
| 		a.Application = parsed | ||||
| 	} | ||||
| 
 | ||||
| 	i, ok = ctx.Get(SessionAuthorizedUser) | ||||
| 	if ok { | ||||
| 		parsed, ok := i.(*model.User) | ||||
| 		if !ok { | ||||
| 			return nil, errors.New("could not parse user from session context") | ||||
| 		} | ||||
| 		a.User = parsed | ||||
| 	} | ||||
| 
 | ||||
| 	i, ok = ctx.Get(SessionAuthorizedAccount) | ||||
| 	if ok { | ||||
| 		parsed, ok := i.(*model.Account) | ||||
| 		if !ok { | ||||
| 			return nil, errors.New("could not parse account from session context") | ||||
| 		} | ||||
| 		a.Account = parsed | ||||
| 	} | ||||
| 
 | ||||
| 	if a.Token == nil && a.Application == nil && a.User == nil && a.Account == nil { | ||||
| 		return nil, errors.New("not authorized") | ||||
| 	} | ||||
| 
 | ||||
| 	return a, nil | ||||
| } | ||||
| 
 | ||||
| // HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function | ||||
| func (s *s) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { | ||||
| 	return s.server.HandleTokenRequest(w, r) | ||||
| } | ||||
| 
 | ||||
| // HandleAuthorizeRequest wraps the oauth2 library's HandleAuthorizeRequest function | ||||
| func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { | ||||
| 	return s.server.HandleAuthorizeRequest(w, r) | ||||
| } | ||||
| 
 | ||||
| // ValidationBearerToken wraps the oauth2 library's ValidationBearerToken function | ||||
| func (s *s) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { | ||||
| 	return s.server.ValidationBearerToken(r) | ||||
| } | ||||
| 
 | ||||
| // GenerateUserAccessToken shortcuts the normal oauth flow to create an user-level | ||||
| // bearer token *without* requiring that user to log in. This is useful when we | ||||
| // need to create a token for new users who haven't validated their email or logged in yet. | ||||
| // | ||||
| // The ti parameter refers to an existing Application token that was used to make the upstream | ||||
| // request. This token needs to be validated and exist in database in order to create a new token. | ||||
| func (s *s) GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error) { | ||||
| 
 | ||||
| 	tgr := &oauth2.TokenGenerateRequest{ | ||||
| 		ClientID:            ti.GetClientID(), | ||||
| 		ClientSecret:        clientSecret, | ||||
| 		UserID:              userID, | ||||
| 		RedirectURI:         ti.GetRedirectURI(), | ||||
| 		Scope:               ti.GetScope(), | ||||
| 		Code:                ti.GetCode(), | ||||
| 		CodeChallenge:       ti.GetCodeChallenge(), | ||||
| 		CodeChallengeMethod: ti.GetCodeChallengeMethod(), | ||||
| 	} | ||||
| 
 | ||||
| 	return s.server.Manager.GenerateAccessToken(context.Background(), oauth2.AuthorizationCode, tgr) | ||||
| } | ||||
| 
 | ||||
| func New(database db.DB, log *logrus.Logger) Server { | ||||
| 	ts := newTokenStore(context.Background(), database, log) | ||||
| 	cs := newClientStore(database) | ||||
| 
 | ||||
| 	manager := manage.NewDefaultManager() | ||||
| 	manager.MapTokenStorage(ts) | ||||
| 	manager.MapClientStorage(cs) | ||||
| 	manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) | ||||
| 	sc := &server.Config{ | ||||
| 		TokenType: "Bearer", | ||||
| 		// Must follow the spec. | ||||
| 		AllowGetAccessRequest: false, | ||||
| 		// Support only the non-implicit flow. | ||||
| 		AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code}, | ||||
| 		// Allow: | ||||
| 		// - Authorization Code (for first & third parties) | ||||
| 		// - Client Credentials (for applications) | ||||
| 		AllowedGrantTypes: []oauth2.GrantType{ | ||||
| 			oauth2.AuthorizationCode, | ||||
| 			oauth2.ClientCredentials, | ||||
| 		}, | ||||
| 		AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain}, | ||||
| 	} | ||||
| 
 | ||||
| 	srv := server.NewServer(sc, manager) | ||||
| 	srv.SetInternalErrorHandler(func(err error) *errors.Response { | ||||
| 		log.Errorf("internal oauth error: %s", err) | ||||
| 		return nil | ||||
| 	}) | ||||
| 
 | ||||
| 	srv.SetResponseErrorHandler(func(re *errors.Response) { | ||||
| 		log.Errorf("internal response error: %s", re.Error) | ||||
| 	}) | ||||
| 
 | ||||
| 	srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) { | ||||
| 		userID := r.FormValue("userid") | ||||
| 		if userID == "" { | ||||
| 			return "", errors.New("userid was empty") | ||||
| 		} | ||||
| 		return userID, nil | ||||
| 	}) | ||||
| 	srv.SetClientInfoHandler(server.ClientFormHandler) | ||||
| 	return &s{ | ||||
| 		server: srv, | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										21
									
								
								internal/oauth/oauth_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/oauth/oauth_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,21 @@ | |||
| /* | ||||
|    GoToSocial | ||||
|    Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org | ||||
| 
 | ||||
|    This program is free software: you can redistribute it and/or modify | ||||
|    it under the terms of the GNU Affero General Public License as published by | ||||
|    the Free Software Foundation, either version 3 of the License, or | ||||
|    (at your option) any later version. | ||||
| 
 | ||||
|    This program is distributed in the hope that it will be useful, | ||||
|    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
|    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
|    GNU Affero General Public License for more details. | ||||
| 
 | ||||
|    You should have received a copy of the GNU Affero General Public License | ||||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
| */ | ||||
| 
 | ||||
| package oauth | ||||
| 
 | ||||
| // TODO: write tests | ||||
|  | @ -70,7 +70,7 @@ func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.Tok | |||
| func (pts *tokenStore) sweep() error { | ||||
| 	// select *all* tokens from the db | ||||
| 	// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. | ||||
| 	tokens := new([]*oauthToken) | ||||
| 	tokens := new([]*Token) | ||||
| 	if err := pts.db.GetAll(tokens); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | @ -106,22 +106,22 @@ func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error | |||
| 
 | ||||
| // RemoveByCode deletes a token from the DB based on the Code field | ||||
| func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error { | ||||
| 	return pts.db.DeleteWhere("code", code, &oauthToken{}) | ||||
| 	return pts.db.DeleteWhere("code", code, &Token{}) | ||||
| } | ||||
| 
 | ||||
| // RemoveByAccess deletes a token from the DB based on the Access field | ||||
| func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { | ||||
| 	return pts.db.DeleteWhere("access", access, &oauthToken{}) | ||||
| 	return pts.db.DeleteWhere("access", access, &Token{}) | ||||
| } | ||||
| 
 | ||||
| // RemoveByRefresh deletes a token from the DB based on the Refresh field | ||||
| func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { | ||||
| 	return pts.db.DeleteWhere("refresh", refresh, &oauthToken{}) | ||||
| 	return pts.db.DeleteWhere("refresh", refresh, &Token{}) | ||||
| } | ||||
| 
 | ||||
| // GetByCode selects a token from the DB based on the Code field | ||||
| func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { | ||||
| 	pgt := &oauthToken{ | ||||
| 	pgt := &Token{ | ||||
| 		Code: code, | ||||
| 	} | ||||
| 	if err := pts.db.GetWhere("code", code, pgt); err != nil { | ||||
|  | @ -132,7 +132,7 @@ func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.Token | |||
| 
 | ||||
| // GetByAccess selects a token from the DB based on the Access field | ||||
| func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { | ||||
| 	pgt := &oauthToken{ | ||||
| 	pgt := &Token{ | ||||
| 		Access: access, | ||||
| 	} | ||||
| 	if err := pts.db.GetWhere("access", access, pgt); err != nil { | ||||
|  | @ -143,7 +143,7 @@ func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.T | |||
| 
 | ||||
| // GetByRefresh selects a token from the DB based on the Refresh field | ||||
| func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { | ||||
| 	pgt := &oauthToken{ | ||||
| 	pgt := &Token{ | ||||
| 		Refresh: refresh, | ||||
| 	} | ||||
| 	if err := pts.db.GetWhere("refresh", refresh, pgt); err != nil { | ||||
|  | @ -156,7 +156,7 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2 | |||
| 	The following models are basically helpers for the postgres token store implementation, they should only be used internally. | ||||
| */ | ||||
| 
 | ||||
| // oauthToken is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt. | ||||
| // Token is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt. | ||||
| // | ||||
| // Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined, | ||||
| // and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and | ||||
|  | @ -164,9 +164,9 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2 | |||
| // | ||||
| // Note that this struct does *not* satisfy the token interface shown here: https://github.com/gotosocial/oauth2/blob/master/model.go#L22 | ||||
| // and implemented here: https://github.com/gotosocial/oauth2/blob/master/models/token.go. | ||||
| // As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken | ||||
| // As such, manual translation is always required between Token and the gotosocial *model.Token. The helper functions oauthTokenToPGToken | ||||
| // and pgTokenToOauthToken can be used for that. | ||||
| type oauthToken struct { | ||||
| type Token struct { | ||||
| 	ID                  string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"` | ||||
| 	ClientID            string | ||||
| 	UserID              string | ||||
|  | @ -186,7 +186,7 @@ type oauthToken struct { | |||
| } | ||||
| 
 | ||||
| // oauthTokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres | ||||
| func oauthTokenToPGToken(tkn *models.Token) *oauthToken { | ||||
| func oauthTokenToPGToken(tkn *models.Token) *Token { | ||||
| 	now := time.Now() | ||||
| 
 | ||||
| 	// For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's | ||||
|  | @ -208,7 +208,7 @@ func oauthTokenToPGToken(tkn *models.Token) *oauthToken { | |||
| 		rea = now.Add(tkn.RefreshExpiresIn) | ||||
| 	} | ||||
| 
 | ||||
| 	return &oauthToken{ | ||||
| 	return &Token{ | ||||
| 		ClientID:            tkn.ClientID, | ||||
| 		UserID:              tkn.UserID, | ||||
| 		RedirectURI:         tkn.RedirectURI, | ||||
|  | @ -228,7 +228,7 @@ func oauthTokenToPGToken(tkn *models.Token) *oauthToken { | |||
| } | ||||
| 
 | ||||
| // pgTokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token | ||||
| func pgTokenToOauthToken(pgt *oauthToken) *models.Token { | ||||
| func pgTokenToOauthToken(pgt *Token) *models.Token { | ||||
| 	now := time.Now() | ||||
| 
 | ||||
| 	return &models.Token{ | ||||
							
								
								
									
										21
									
								
								internal/oauth/tokenstore_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/oauth/tokenstore_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,21 @@ | |||
| /* | ||||
|    GoToSocial | ||||
|    Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org | ||||
| 
 | ||||
|    This program is free software: you can redistribute it and/or modify | ||||
|    it under the terms of the GNU Affero General Public License as published by | ||||
|    the Free Software Foundation, either version 3 of the License, or | ||||
|    (at your option) any later version. | ||||
| 
 | ||||
|    This program is distributed in the hope that it will be useful, | ||||
|    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
|    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
|    GNU Affero General Public License for more details. | ||||
| 
 | ||||
|    You should have received a copy of the GNU Affero General Public License | ||||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
| */ | ||||
| 
 | ||||
| package oauth | ||||
| 
 | ||||
| // TODO: write tests | ||||
|  | @ -36,7 +36,7 @@ import ( | |||
| // Router provides the REST interface for gotosocial, using gin. | ||||
| type Router interface { | ||||
| 	// Attach a gin handler to the router with the given method and path | ||||
| 	AttachHandler(method string, path string, handler gin.HandlerFunc) | ||||
| 	AttachHandler(method string, path string, f gin.HandlerFunc) | ||||
| 	// Attach a gin middleware to the router that will be used globally | ||||
| 	AttachMiddleware(handler gin.HandlerFunc) | ||||
| 	// Start the router | ||||
|  | @ -59,6 +59,8 @@ func (r *router) Start() { | |||
| 			r.logger.Fatalf("listen: %s", err) | ||||
| 		} | ||||
| 	}() | ||||
| 	// c := &gin.Context{} | ||||
| 	// c.Get() | ||||
| } | ||||
| 
 | ||||
| // Stop shuts down the router nicely | ||||
|  |  | |||
							
								
								
									
										31
									
								
								pkg/mastotypes/token.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								pkg/mastotypes/token.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,31 @@ | |||
| /* | ||||
|    GoToSocial | ||||
|    Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org | ||||
| 
 | ||||
|    This program is free software: you can redistribute it and/or modify | ||||
|    it under the terms of the GNU Affero General Public License as published by | ||||
|    the Free Software Foundation, either version 3 of the License, or | ||||
|    (at your option) any later version. | ||||
| 
 | ||||
|    This program is distributed in the hope that it will be useful, | ||||
|    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
|    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||
|    GNU Affero General Public License for more details. | ||||
| 
 | ||||
|    You should have received a copy of the GNU Affero General Public License | ||||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||
| */ | ||||
| 
 | ||||
| package mastotypes | ||||
| 
 | ||||
| // Token represents an OAuth token used for authenticating with the API and performing actions.. See https://docs.joinmastodon.org/entities/token/ | ||||
| type Token struct { | ||||
| 	// An OAuth token to be used for authorization. | ||||
| 	AccessToken string `json:"access_token"` | ||||
| 	// The OAuth token type. Mastodon uses Bearer tokens. | ||||
| 	TokenType string `json:"token_type"` | ||||
| 	// The OAuth scopes granted by this token, space-separated. | ||||
| 	Scope string `json:"scope"` | ||||
| 	// When the token was generated. (UNIX timestamp seconds) | ||||
| 	CreatedAt int64 `json:"created_at"` | ||||
| } | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue