mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 06:52:26 -05:00 
			
		
		
		
	Oauth/token (#7)
* add host and protocol options * some fiddling * tidying up and comments * tick off /oauth/token * tidying a bit * tidying * go mod tidy * allow attaching middleware to server * add middleware * more user friendly * add comments * comments * store account + app * tidying * lots of restructuring * lint + tidy
This commit is contained in:
		
					parent
					
						
							
								4194f8d88f
							
						
					
				
			
			
				commit
				
					
						aa9ce272dc
					
				
			
		
					 30 changed files with 1346 additions and 977 deletions
				
			
		|  | @ -6,7 +6,7 @@ | ||||||
|     * [ ] /api/v1/apps/verify_credentials GET               (Verify an application works) |     * [ ] /api/v1/apps/verify_credentials GET               (Verify an application works) | ||||||
|     * [x] /oauth/authorize GET                              (Show authorize page to user) |     * [x] /oauth/authorize GET                              (Show authorize page to user) | ||||||
|     * [x] /oauth/authorize POST                             (Get an oauth access code for an app/user) |     * [x] /oauth/authorize POST                             (Get an oauth access code for an app/user) | ||||||
|     * [ ] /oauth/token POST                                 (Obtain a user-level access token) |     * [x] /oauth/token POST                                 (Obtain a user-level access token) | ||||||
|     * [ ] /oauth/revoke POST                                (Revoke a user-level access token) |     * [ ] /oauth/revoke POST                                (Revoke a user-level access token) | ||||||
|     * [x] /auth/sign_in GET                                 (Show form for user signin) |     * [x] /auth/sign_in GET                                 (Show form for user signin) | ||||||
|     * [x] /auth/sign_in POST                                (Validate username and password and sign user in) |     * [x] /auth/sign_in POST                                (Validate username and password and sign user in) | ||||||
|  |  | ||||||
|  | @ -58,6 +58,18 @@ func main() { | ||||||
| 				Value:   "", | 				Value:   "", | ||||||
| 				EnvVars: []string{envNames.ConfigPath}, | 				EnvVars: []string{envNames.ConfigPath}, | ||||||
| 			}, | 			}, | ||||||
|  | 			&cli.StringFlag{ | ||||||
|  | 				Name:    flagNames.Host, | ||||||
|  | 				Usage:   "Hostname to use for the server (eg., example.org, gotosocial.whatever.com)", | ||||||
|  | 				Value:   "localhost", | ||||||
|  | 				EnvVars: []string{envNames.Host}, | ||||||
|  | 			}, | ||||||
|  | 			&cli.StringFlag{ | ||||||
|  | 				Name:    flagNames.Protocol, | ||||||
|  | 				Usage:   "Protocol to use for the REST api of the server (only use http for debugging and tests!)", | ||||||
|  | 				Value:   "https", | ||||||
|  | 				EnvVars: []string{envNames.Protocol}, | ||||||
|  | 			}, | ||||||
| 
 | 
 | ||||||
| 			// DATABASE FLAGS | 			// DATABASE FLAGS | ||||||
| 			&cli.StringFlag{ | 			&cli.StringFlag{ | ||||||
|  |  | ||||||
|  | @ -28,6 +28,17 @@ logLevel: "info" | ||||||
| # Default: "gotosocial" | # Default: "gotosocial" | ||||||
| applicationName: "gotosocial" | applicationName: "gotosocial" | ||||||
| 
 | 
 | ||||||
|  | # String. Hostname/domain to use for the server. Defaults to localhost for local testing, | ||||||
|  | # but you should *definitely* change this when running for real, or your server won't work at all. | ||||||
|  | # Examples: ["example.org","some.server.com"] | ||||||
|  | # Default: "localhost" | ||||||
|  | host: "localhost" | ||||||
|  | 
 | ||||||
|  | # String. Protocol to use for the server. Only change to http for local testing! | ||||||
|  | # Options: ["http","https"] | ||||||
|  | # Default: "https" | ||||||
|  | protocol: "https" | ||||||
|  | 
 | ||||||
| # Config pertaining to the Gotosocial database connection | # Config pertaining to the Gotosocial database connection | ||||||
| db: | db: | ||||||
|   # String. Database type. |   # String. Database type. | ||||||
|  |  | ||||||
							
								
								
									
										2
									
								
								go.mod
									
										
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
										
									
									
									
								
							|  | @ -10,7 +10,7 @@ require ( | ||||||
| 	github.com/go-pg/pg/v10 v10.8.0 | 	github.com/go-pg/pg/v10 v10.8.0 | ||||||
| 	github.com/golang/mock v1.4.4 // indirect | 	github.com/golang/mock v1.4.4 // indirect | ||||||
| 	github.com/google/uuid v1.2.0 | 	github.com/google/uuid v1.2.0 | ||||||
| 	github.com/gotosocial/oauth2/v4 v4.2.1-0.20210318133800-45d321d259b3 | 	github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88 | ||||||
| 	github.com/onsi/ginkgo v1.15.0 // indirect | 	github.com/onsi/ginkgo v1.15.0 // indirect | ||||||
| 	github.com/onsi/gomega v1.10.5 // indirect | 	github.com/onsi/gomega v1.10.5 // indirect | ||||||
| 	github.com/sirupsen/logrus v1.8.0 | 	github.com/sirupsen/logrus v1.8.0 | ||||||
|  |  | ||||||
							
								
								
									
										4
									
								
								go.sum
									
										
									
									
									
								
							
							
						
						
									
										4
									
								
								go.sum
									
										
									
									
									
								
							|  | @ -103,8 +103,8 @@ github.com/gorilla/sessions v1.1.3 h1:uXoZdcdA5XdXF3QzuSlheVRUvjl+1rKY7zBXL68L9R | ||||||
| github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w= | github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w= | ||||||
| github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= | github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= | ||||||
| github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= | github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= | ||||||
| github.com/gotosocial/oauth2/v4 v4.2.1-0.20210318133800-45d321d259b3 h1:CKRz5d7mRum+UMR88Ue33tCYcej14WjUsB59C02DDqY= | github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88 h1:YJ//HmHOYJ4srm/LA6VPNjNisneMbY6TTM1xttV/ZQU= | ||||||
| github.com/gotosocial/oauth2/v4 v4.2.1-0.20210318133800-45d321d259b3/go.mod h1:zl5kwHf/atRUrY5yOyDnk49Us1Ygs0BzdW4jKAgoiP8= | github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88/go.mod h1:zl5kwHf/atRUrY5yOyDnk49Us1Ygs0BzdW4jKAgoiP8= | ||||||
| github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= | github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= | ||||||
| github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk= | github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk= | ||||||
| github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA= | github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA= | ||||||
|  |  | ||||||
|  | @ -1,87 +0,0 @@ | ||||||
| /* |  | ||||||
|    GoToSocial |  | ||||||
|    Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org |  | ||||||
| 
 |  | ||||||
|    This program is free software: you can redistribute it and/or modify |  | ||||||
|    it under the terms of the GNU Affero General Public License as published by |  | ||||||
|    the Free Software Foundation, either version 3 of the License, or |  | ||||||
|    (at your option) any later version. |  | ||||||
| 
 |  | ||||||
|    This program is distributed in the hope that it will be useful, |  | ||||||
|    but WITHOUT ANY WARRANTY; without even the implied warranty of |  | ||||||
|    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the |  | ||||||
|    GNU Affero General Public License for more details. |  | ||||||
| 
 |  | ||||||
|    You should have received a copy of the GNU Affero General Public License |  | ||||||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. |  | ||||||
| */ |  | ||||||
| 
 |  | ||||||
| package api |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 	"os" |  | ||||||
| 	"path/filepath" |  | ||||||
| 
 |  | ||||||
| 	"github.com/gin-contrib/sessions" |  | ||||||
| 	"github.com/gin-contrib/sessions/memstore" |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"github.com/gotosocial/gotosocial/internal/config" |  | ||||||
| 	"github.com/sirupsen/logrus" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| type Server interface { |  | ||||||
| 	AttachHandler(method string, path string, handler gin.HandlerFunc) |  | ||||||
| 	// AttachMiddleware(handler gin.HandlerFunc) |  | ||||||
| 	GetAPIGroup() *gin.RouterGroup |  | ||||||
| 	Start() |  | ||||||
| 	Stop() |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type AddsRoutes interface { |  | ||||||
| 	AddRoutes(s Server) error |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type server struct { |  | ||||||
| 	APIGroup *gin.RouterGroup |  | ||||||
| 	logger   *logrus.Logger |  | ||||||
| 	engine   *gin.Engine |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (s *server) GetAPIGroup() *gin.RouterGroup { |  | ||||||
| 	return s.APIGroup |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (s *server) Start() { |  | ||||||
| 	// todo: start gracefully |  | ||||||
| 	if err := s.engine.Run(); err != nil { |  | ||||||
| 		s.logger.Panicf("server error: %s", err) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (s *server) Stop() { |  | ||||||
| 	// todo: shut down gracefully |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (s *server) AttachHandler(method string, path string, handler gin.HandlerFunc) { |  | ||||||
| 	if method == "ANY" { |  | ||||||
| 		s.engine.Any(path, handler) |  | ||||||
| 	} else { |  | ||||||
| 		s.engine.Handle(method, path, handler) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func New(config *config.Config, logger *logrus.Logger) Server { |  | ||||||
| 	engine := gin.New() |  | ||||||
| 	store := memstore.NewStore([]byte("authentication-key"), []byte("encryption-keyencryption-key----")) |  | ||||||
| 	engine.Use(sessions.Sessions("gotosocial-session", store)) |  | ||||||
| 	cwd, _ := os.Getwd() |  | ||||||
| 	tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", config.TemplateConfig.BaseDir)) |  | ||||||
| 	logger.Debugf("loading templates from %s", tmPath) |  | ||||||
| 	engine.LoadHTMLGlob(tmPath) |  | ||||||
| 	return &server{ |  | ||||||
| 		APIGroup: engine.Group("/api").Group("/v1"), |  | ||||||
| 		logger:   logger, |  | ||||||
| 		engine:   engine, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  | @ -29,6 +29,8 @@ import ( | ||||||
| type Config struct { | type Config struct { | ||||||
| 	LogLevel        string          `yaml:"logLevel"` | 	LogLevel        string          `yaml:"logLevel"` | ||||||
| 	ApplicationName string          `yaml:"applicationName"` | 	ApplicationName string          `yaml:"applicationName"` | ||||||
|  | 	Host            string          `yaml:"host"` | ||||||
|  | 	Protocol        string          `yaml:"protocol"` | ||||||
| 	DBConfig        *DBConfig       `yaml:"db"` | 	DBConfig        *DBConfig       `yaml:"db"` | ||||||
| 	TemplateConfig  *TemplateConfig `yaml:"template"` | 	TemplateConfig  *TemplateConfig `yaml:"template"` | ||||||
| } | } | ||||||
|  | @ -97,6 +99,14 @@ func (c *Config) ParseCLIFlags(f KeyedFlags) { | ||||||
| 		c.ApplicationName = f.String(fn.ApplicationName) | 		c.ApplicationName = f.String(fn.ApplicationName) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	if c.Host == "" || f.IsSet(fn.Host) { | ||||||
|  | 		c.Host = f.String(fn.Host) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if c.Protocol == "" || f.IsSet(fn.Protocol) { | ||||||
|  | 		c.Protocol = f.String(fn.Protocol) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// db flags | 	// db flags | ||||||
| 	if c.DBConfig.Type == "" || f.IsSet(fn.DbType) { | 	if c.DBConfig.Type == "" || f.IsSet(fn.DbType) { | ||||||
| 		c.DBConfig.Type = f.String(fn.DbType) | 		c.DBConfig.Type = f.String(fn.DbType) | ||||||
|  | @ -142,6 +152,8 @@ type Flags struct { | ||||||
| 	LogLevel        string | 	LogLevel        string | ||||||
| 	ApplicationName string | 	ApplicationName string | ||||||
| 	ConfigPath      string | 	ConfigPath      string | ||||||
|  | 	Host            string | ||||||
|  | 	Protocol        string | ||||||
| 	DbType          string | 	DbType          string | ||||||
| 	DbAddress       string | 	DbAddress       string | ||||||
| 	DbPort          string | 	DbPort          string | ||||||
|  | @ -158,6 +170,8 @@ func GetFlagNames() Flags { | ||||||
| 		LogLevel:        "log-level", | 		LogLevel:        "log-level", | ||||||
| 		ApplicationName: "application-name", | 		ApplicationName: "application-name", | ||||||
| 		ConfigPath:      "config-path", | 		ConfigPath:      "config-path", | ||||||
|  | 		Host:            "host", | ||||||
|  | 		Protocol:        "protocol", | ||||||
| 		DbType:          "db-type", | 		DbType:          "db-type", | ||||||
| 		DbAddress:       "db-address", | 		DbAddress:       "db-address", | ||||||
| 		DbPort:          "db-port", | 		DbPort:          "db-port", | ||||||
|  | @ -175,6 +189,8 @@ func GetEnvNames() Flags { | ||||||
| 		LogLevel:        "GTS_LOG_LEVEL", | 		LogLevel:        "GTS_LOG_LEVEL", | ||||||
| 		ApplicationName: "GTS_APPLICATION_NAME", | 		ApplicationName: "GTS_APPLICATION_NAME", | ||||||
| 		ConfigPath:      "GTS_CONFIG_PATH", | 		ConfigPath:      "GTS_CONFIG_PATH", | ||||||
|  | 		Host:            "GTS_HOST", | ||||||
|  | 		Protocol:        "GTS_PROTOCOL", | ||||||
| 		DbType:          "GTS_DB_TYPE", | 		DbType:          "GTS_DB_TYPE", | ||||||
| 		DbAddress:       "GTS_DB_ADDRESS", | 		DbAddress:       "GTS_DB_ADDRESS", | ||||||
| 		DbPort:          "GTS_DB_PORT", | 		DbPort:          "GTS_DB_PORT", | ||||||
|  |  | ||||||
|  | @ -28,9 +28,10 @@ import ( | ||||||
| 
 | 
 | ||||||
| // Initialize will initialize the database given in the config for use with GoToSocial | // Initialize will initialize the database given in the config for use with GoToSocial | ||||||
| var Initialize action.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { | var Initialize action.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error { | ||||||
| 	db, err := New(ctx, c, log) | 	// db, err := New(ctx, c, log) | ||||||
| 	if err != nil { | 	// if err != nil { | ||||||
| 		return err | 	// 	return err | ||||||
| 	} | 	// } | ||||||
| 	return db.CreateSchema(ctx) | 	return nil | ||||||
|  | 	// return db.CreateSchema(ctx) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -30,30 +30,47 @@ import ( | ||||||
| 
 | 
 | ||||||
| const dbTypePostgres string = "POSTGRES" | const dbTypePostgres string = "POSTGRES" | ||||||
| 
 | 
 | ||||||
| // DB provides methods for interacting with an underlying database (for now, just postgres). | // DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres). | ||||||
| // The function mapping lines up with the DB interface described in go-fed. |  | ||||||
| // See here: https://github.com/go-fed/activity/blob/master/pub/database.go |  | ||||||
| type DB interface { | type DB interface { | ||||||
| 	/* | 	// Federation returns an interface that's compatible with go-fed, for performing federation storage/retrieval functions. | ||||||
| 		GO-FED DATABASE FUNCTIONS | 	// See: https://pkg.go.dev/github.com/go-fed/activity@v1.0.0/pub?utm_source=gopls#Database | ||||||
| 	*/ | 	Federation() pub.Database | ||||||
| 	pub.Database |  | ||||||
| 
 | 
 | ||||||
| 	/* | 	// CreateTable creates a table for the given interface | ||||||
| 		ANY ADDITIONAL DESIRED FUNCTIONS | 	CreateTable(i interface{}) error | ||||||
| 	*/ |  | ||||||
| 
 | 
 | ||||||
| 	// CreateSchema should populate the database with the required tables | 	// DropTable drops the table for the given interface | ||||||
| 	CreateSchema(context.Context) error | 	DropTable(i interface{}) error | ||||||
| 
 | 
 | ||||||
| 	// Stop should stop and close the database connection cleanly, returning an error if this is not possible | 	// Stop should stop and close the database connection cleanly, returning an error if this is not possible | ||||||
| 	Stop(context.Context) error | 	Stop(ctx context.Context) error | ||||||
| 
 | 
 | ||||||
| 	// IsHealthy should return nil if the database connection is healthy, or an error if not | 	// IsHealthy should return nil if the database connection is healthy, or an error if not | ||||||
| 	IsHealthy(context.Context) error | 	IsHealthy(ctx context.Context) error | ||||||
|  | 
 | ||||||
|  | 	// GetByID gets one entry by its id. | ||||||
|  | 	GetByID(id string, i interface{}) error | ||||||
|  | 
 | ||||||
|  | 	// GetWhere gets one entry where key = value | ||||||
|  | 	GetWhere(key string, value interface{}, i interface{}) error | ||||||
|  | 
 | ||||||
|  | 	// GetAll gets all entries of interface type i | ||||||
|  | 	GetAll(i interface{}) error | ||||||
|  | 
 | ||||||
|  | 	// Put stores i | ||||||
|  | 	Put(i interface{}) error | ||||||
|  | 
 | ||||||
|  | 	// Update by id updates i with id id | ||||||
|  | 	UpdateByID(id string, i interface{}) error | ||||||
|  | 
 | ||||||
|  | 	// Delete by id removes i with id id | ||||||
|  | 	DeleteByID(id string, i interface{}) error | ||||||
|  | 
 | ||||||
|  | 	// Delete where deletes i where key = value | ||||||
|  | 	DeleteWhere(key string, value interface{}, i interface{}) error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // New returns a new database service that satisfies the Service interface and, by extension, | // New returns a new database service that satisfies the DB interface and, by extension, | ||||||
| // the go-fed database interface described here: https://github.com/go-fed/activity/blob/master/pub/database.go | // the go-fed database interface described here: https://github.com/go-fed/activity/blob/master/pub/database.go | ||||||
| func New(ctx context.Context, c *config.Config, log *logrus.Logger) (DB, error) { | func New(ctx context.Context, c *config.Config, log *logrus.Logger) (DB, error) { | ||||||
| 	switch strings.ToUpper(c.DBConfig.Type) { | 	switch strings.ToUpper(c.DBConfig.Type) { | ||||||
|  |  | ||||||
							
								
								
									
										137
									
								
								internal/db/pg-fed.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										137
									
								
								internal/db/pg-fed.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,137 @@ | ||||||
|  | package db | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 	"net/url" | ||||||
|  | 	"sync" | ||||||
|  | 
 | ||||||
|  | 	"github.com/go-fed/activity/pub" | ||||||
|  | 	"github.com/go-fed/activity/streams" | ||||||
|  | 	"github.com/go-fed/activity/streams/vocab" | ||||||
|  | 	"github.com/go-pg/pg/v10" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type postgresFederation struct { | ||||||
|  | 	locks *sync.Map | ||||||
|  | 	conn  *pg.DB | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func newPostgresFederation(conn *pg.DB) pub.Database { | ||||||
|  | 	return &postgresFederation{ | ||||||
|  | 		locks: new(sync.Map), | ||||||
|  | 		conn:  conn, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* | ||||||
|  |    GO-FED DB INTERFACE-IMPLEMENTING FUNCTIONS | ||||||
|  | */ | ||||||
|  | func (pf *postgresFederation) Lock(ctx context.Context, id *url.URL) error { | ||||||
|  | 	// Before any other Database methods are called, the relevant `id` | ||||||
|  | 	// entries are locked to allow for fine-grained concurrency. | ||||||
|  | 
 | ||||||
|  | 	// Strategy: create a new lock, if stored, continue. Otherwise, lock the | ||||||
|  | 	// existing mutex. | ||||||
|  | 	mu := &sync.Mutex{} | ||||||
|  | 	mu.Lock() // Optimistically lock if we do store it. | ||||||
|  | 	i, loaded := pf.locks.LoadOrStore(id.String(), mu) | ||||||
|  | 	if loaded { | ||||||
|  | 		mu = i.(*sync.Mutex) | ||||||
|  | 		mu.Lock() | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) Unlock(ctx context.Context, id *url.URL) error { | ||||||
|  | 	// Once Go-Fed is done calling Database methods, the relevant `id` | ||||||
|  | 	// entries are unlocked. | ||||||
|  | 
 | ||||||
|  | 	i, ok := pf.locks.Load(id.String()) | ||||||
|  | 	if !ok { | ||||||
|  | 		return errors.New("missing an id in unlock") | ||||||
|  | 	} | ||||||
|  | 	mu := i.(*sync.Mutex) | ||||||
|  | 	mu.Unlock() | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) InboxContains(ctx context.Context, inbox *url.URL, id *url.URL) (bool, error) { | ||||||
|  | 	return false, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) GetInbox(ctx context.Context, inboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) { | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) SetInbox(ctx context.Context, inbox vocab.ActivityStreamsOrderedCollectionPage) error { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) Owns(ctx context.Context, id *url.URL) (owns bool, err error) { | ||||||
|  | 	return false, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) ActorForOutbox(ctx context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) { | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) { | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) OutboxForInbox(ctx context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) { | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) Exists(ctx context.Context, id *url.URL) (exists bool, err error) { | ||||||
|  | 	return false, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) Get(ctx context.Context, id *url.URL) (value vocab.Type, err error) { | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) Create(ctx context.Context, asType vocab.Type) error { | ||||||
|  | 	t, err := streams.NewTypeResolver() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if err := t.Resolve(ctx, asType); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	asType.GetTypeName() | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) Update(ctx context.Context, asType vocab.Type) error { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) Delete(ctx context.Context, id *url.URL) error { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) GetOutbox(ctx context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) { | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) SetOutbox(ctx context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) NewID(ctx context.Context, t vocab.Type) (id *url.URL, err error) { | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) Followers(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) Following(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pf *postgresFederation) Liked(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  | @ -22,30 +22,26 @@ import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/url" |  | ||||||
| 	"regexp" | 	"regexp" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"sync" |  | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/go-fed/activity/streams" | 	"github.com/go-fed/activity/pub" | ||||||
| 	"github.com/go-fed/activity/streams/vocab" |  | ||||||
| 	"github.com/go-pg/pg/extra/pgdebug" | 	"github.com/go-pg/pg/extra/pgdebug" | ||||||
| 	"github.com/go-pg/pg/v10" | 	"github.com/go-pg/pg/v10" | ||||||
| 	"github.com/go-pg/pg/v10/orm" | 	"github.com/go-pg/pg/v10/orm" | ||||||
| 	"github.com/gotosocial/gotosocial/internal/config" | 	"github.com/gotosocial/gotosocial/internal/config" | ||||||
| 	"github.com/gotosocial/gotosocial/internal/gtsmodel" | 	"github.com/gotosocial/gotosocial/internal/gtsmodel" | ||||||
| 	"github.com/gotosocial/oauth2/v4" |  | ||||||
| 	"github.com/sirupsen/logrus" | 	"github.com/sirupsen/logrus" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // postgresService satisfies the DB interface | ||||||
| type postgresService struct { | type postgresService struct { | ||||||
| 	config     *config.DBConfig | 	config       *config.DBConfig | ||||||
| 	conn       *pg.DB | 	conn         *pg.DB | ||||||
| 	log        *logrus.Entry | 	log          *logrus.Entry | ||||||
| 	cancel     context.CancelFunc | 	cancel       context.CancelFunc | ||||||
| 	locks      *sync.Map | 	federationDB pub.Database | ||||||
| 	tokenStore oauth2.TokenStore |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface. | // newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface. | ||||||
|  | @ -102,36 +98,20 @@ func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry | ||||||
| 		return nil, errors.New("db connection timeout") | 		return nil, errors.New("db connection timeout") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// acc := model.StubAccount() |  | ||||||
| 	// if _, err := conn.Model(acc).Returning("id").Insert(); err != nil { |  | ||||||
| 	// 	cancel() |  | ||||||
| 	// 	return nil, fmt.Errorf("db insert error: %s", err) |  | ||||||
| 	// } |  | ||||||
| 	// log.Infof("created account with id %s", acc.ID) |  | ||||||
| 
 |  | ||||||
| 	// note := &model.Note{ |  | ||||||
| 	// 	Visibility: &model.Visibility{ |  | ||||||
| 	// 		Local: true, |  | ||||||
| 	// 	}, |  | ||||||
| 	// 	CreatedAt: time.Now(), |  | ||||||
| 	// 	UpdatedAt: time.Now(), |  | ||||||
| 	// } |  | ||||||
| 	// if _, err := conn.WithContext(ctx).Model(note).Returning("id").Insert(); err != nil { |  | ||||||
| 	// 	cancel() |  | ||||||
| 	// 	return nil, fmt.Errorf("db insert error: %s", err) |  | ||||||
| 	// } |  | ||||||
| 	// log.Infof("created note with id %s", note.ID) |  | ||||||
| 
 |  | ||||||
| 	// we can confidently return this useable postgres service now | 	// we can confidently return this useable postgres service now | ||||||
| 	return &postgresService{ | 	return &postgresService{ | ||||||
| 		config: c.DBConfig, | 		config:       c.DBConfig, | ||||||
| 		conn:   conn, | 		conn:         conn, | ||||||
| 		log:    log, | 		log:          log, | ||||||
| 		cancel: cancel, | 		cancel:       cancel, | ||||||
| 		locks:  &sync.Map{}, | 		federationDB: newPostgresFederation(conn), | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (ps *postgresService) Federation() pub.Database { | ||||||
|  | 	return ps.federationDB | ||||||
|  | } | ||||||
|  | 
 | ||||||
| /* | /* | ||||||
| 	HANDY STUFF | 	HANDY STUFF | ||||||
| */ | */ | ||||||
|  | @ -187,118 +167,6 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) { | ||||||
| 	return options, nil | 	return options, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* |  | ||||||
|    GO-FED DB INTERFACE-IMPLEMENTING FUNCTIONS |  | ||||||
| */ |  | ||||||
| func (ps *postgresService) Lock(ctx context.Context, id *url.URL) error { |  | ||||||
| 	// Before any other Database methods are called, the relevant `id` |  | ||||||
| 	// entries are locked to allow for fine-grained concurrency. |  | ||||||
| 
 |  | ||||||
| 	// Strategy: create a new lock, if stored, continue. Otherwise, lock the |  | ||||||
| 	// existing mutex. |  | ||||||
| 	mu := &sync.Mutex{} |  | ||||||
| 	mu.Lock() // Optimistically lock if we do store it. |  | ||||||
| 	i, loaded := ps.locks.LoadOrStore(id.String(), mu) |  | ||||||
| 	if loaded { |  | ||||||
| 		mu = i.(*sync.Mutex) |  | ||||||
| 		mu.Lock() |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) Unlock(ctx context.Context, id *url.URL) error { |  | ||||||
| 	// Once Go-Fed is done calling Database methods, the relevant `id` |  | ||||||
| 	// entries are unlocked. |  | ||||||
| 
 |  | ||||||
| 	i, ok := ps.locks.Load(id.String()) |  | ||||||
| 	if !ok { |  | ||||||
| 		return errors.New("missing an id in unlock") |  | ||||||
| 	} |  | ||||||
| 	mu := i.(*sync.Mutex) |  | ||||||
| 	mu.Unlock() |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) InboxContains(ctx context.Context, inbox *url.URL, id *url.URL) (bool, error) { |  | ||||||
| 	return false, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) GetInbox(ctx context.Context, inboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) { |  | ||||||
| 	return nil, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) SetInbox(ctx context.Context, inbox vocab.ActivityStreamsOrderedCollectionPage) error { |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) Owns(ctx context.Context, id *url.URL) (owns bool, err error) { |  | ||||||
| 	return false, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) ActorForOutbox(ctx context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) { |  | ||||||
| 	return nil, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) { |  | ||||||
| 	return nil, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) OutboxForInbox(ctx context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) { |  | ||||||
| 	return nil, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) Exists(ctx context.Context, id *url.URL) (exists bool, err error) { |  | ||||||
| 	return false, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) Get(ctx context.Context, id *url.URL) (value vocab.Type, err error) { |  | ||||||
| 	return nil, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) Create(ctx context.Context, asType vocab.Type) error { |  | ||||||
| 	t, err := streams.NewTypeResolver() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	if err := t.Resolve(ctx, asType); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	asType.GetTypeName() |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) Update(ctx context.Context, asType vocab.Type) error { |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) Delete(ctx context.Context, id *url.URL) error { |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) GetOutbox(ctx context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) { |  | ||||||
| 	return nil, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) SetOutbox(ctx context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error { |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) NewID(ctx context.Context, t vocab.Type) (id *url.URL, err error) { |  | ||||||
| 	return nil, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) Followers(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { |  | ||||||
| 	return nil, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) Following(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { |  | ||||||
| 	return nil, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ps *postgresService) Liked(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) { |  | ||||||
| 	return nil, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /* | /* | ||||||
| 	EXTRA FUNCTIONS | 	EXTRA FUNCTIONS | ||||||
| */ | */ | ||||||
|  | @ -338,6 +206,46 @@ func (ps *postgresService) IsHealthy(ctx context.Context) error { | ||||||
| 	return ps.conn.Ping(ctx) | 	return ps.conn.Ping(ctx) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (ps *postgresService) TokenStore() oauth2.TokenStore { | func (ps *postgresService) CreateTable(i interface{}) error { | ||||||
| 	return ps.tokenStore | 	return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{ | ||||||
|  | 		IfNotExists: true, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (ps *postgresService) DropTable(i interface{}) error { | ||||||
|  | 	return ps.conn.Model(i).DropTable(&orm.DropTableOptions{ | ||||||
|  | 		IfExists: true, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (ps *postgresService) GetByID(id string, i interface{}) error { | ||||||
|  | 	return ps.conn.Model(i).Where("id = ?", id).Select() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error { | ||||||
|  | 	return ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Select() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (ps *postgresService) GetAll(i interface{}) error { | ||||||
|  | 	return ps.conn.Model(i).Select() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (ps *postgresService) Put(i interface{}) error { | ||||||
|  | 	_, err := ps.conn.Model(i).Insert(i) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (ps *postgresService) UpdateByID(id string, i interface{}) error { | ||||||
|  | 	_, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert() | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (ps *postgresService) DeleteByID(id string, i interface{}) error { | ||||||
|  | 	_, err := ps.conn.Model(i).Where("id = ?", id).Delete() | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error { | ||||||
|  | 	_, err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Delete() | ||||||
|  | 	return err | ||||||
| } | } | ||||||
|  | @ -16,5 +16,5 @@ | ||||||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. |    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
| */ | */ | ||||||
| 
 | 
 | ||||||
| // package email provides a service for interacting with an SMTP server | // Package email provides a service for interacting with an SMTP server | ||||||
| package email | package email | ||||||
|  |  | ||||||
|  | @ -30,11 +30,13 @@ import ( | ||||||
| 	"github.com/gotosocial/gotosocial/internal/db" | 	"github.com/gotosocial/gotosocial/internal/db" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // New returns a go-fed compatible federating actor | ||||||
| func New(db db.DB) pub.FederatingActor { | func New(db db.DB) pub.FederatingActor { | ||||||
| 	fa := &API{} | 	fa := &API{} | ||||||
| 	return pub.NewFederatingActor(fa, fa, db, fa) | 	return pub.NewFederatingActor(fa, fa, db.Federation(), fa) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // API implements several go-fed interfaces in one convenient location | ||||||
| type API struct { | type API struct { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -38,9 +38,9 @@ var Run action.GTSAction = func(ctx context.Context, c *config.Config, log *logr | ||||||
| 		return fmt.Errorf("error creating dbservice: %s", err) | 		return fmt.Errorf("error creating dbservice: %s", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := dbService.CreateSchema(ctx); err != nil { | 	// if err := dbService.CreateSchema(ctx); err != nil { | ||||||
| 		return fmt.Errorf("error creating dbschema: %s", err) | 	// 	return fmt.Errorf("error creating dbschema: %s", err) | ||||||
| 	} | 	// } | ||||||
| 
 | 
 | ||||||
| 	// catch shutdown signals from the operating system | 	// catch shutdown signals from the operating system | ||||||
| 	sigs := make(chan os.Signal, 1) | 	sigs := make(chan os.Signal, 1) | ||||||
|  |  | ||||||
|  | @ -22,10 +22,10 @@ import ( | ||||||
| 	"context" | 	"context" | ||||||
| 
 | 
 | ||||||
| 	"github.com/go-fed/activity/pub" | 	"github.com/go-fed/activity/pub" | ||||||
| 	"github.com/gotosocial/gotosocial/internal/api" |  | ||||||
| 	"github.com/gotosocial/gotosocial/internal/cache" | 	"github.com/gotosocial/gotosocial/internal/cache" | ||||||
| 	"github.com/gotosocial/gotosocial/internal/config" | 	"github.com/gotosocial/gotosocial/internal/config" | ||||||
| 	"github.com/gotosocial/gotosocial/internal/db" | 	"github.com/gotosocial/gotosocial/internal/db" | ||||||
|  | 	"github.com/gotosocial/gotosocial/internal/router" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type Gotosocial interface { | type Gotosocial interface { | ||||||
|  | @ -33,11 +33,11 @@ type Gotosocial interface { | ||||||
| 	Stop(context.Context) error | 	Stop(context.Context) error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func New(db db.DB, cache cache.Cache, clientAPI api.Server, federationAPI pub.FederatingActor, config *config.Config) (Gotosocial, error) { | func New(db db.DB, cache cache.Cache, apiRouter router.Router, federationAPI pub.FederatingActor, config *config.Config) (Gotosocial, error) { | ||||||
| 	return &gotosocial{ | 	return &gotosocial{ | ||||||
| 		db:            db, | 		db:            db, | ||||||
| 		cache:         cache, | 		cache:         cache, | ||||||
| 		clientAPI:     clientAPI, | 		apiRouter:     apiRouter, | ||||||
| 		federationAPI: federationAPI, | 		federationAPI: federationAPI, | ||||||
| 		config:        config, | 		config:        config, | ||||||
| 	}, nil | 	}, nil | ||||||
|  | @ -46,7 +46,7 @@ func New(db db.DB, cache cache.Cache, clientAPI api.Server, federationAPI pub.Fe | ||||||
| type gotosocial struct { | type gotosocial struct { | ||||||
| 	db            db.DB | 	db            db.DB | ||||||
| 	cache         cache.Cache | 	cache         cache.Cache | ||||||
| 	clientAPI     api.Server | 	apiRouter     router.Router | ||||||
| 	federationAPI pub.FederatingActor | 	federationAPI pub.FederatingActor | ||||||
| 	config        *config.Config | 	config        *config.Config | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -16,7 +16,7 @@ | ||||||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. |    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
| */ | */ | ||||||
| 
 | 
 | ||||||
| // package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database. | // Package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database. | ||||||
| // These types should never be serialized and/or sent out via public APIs, as they contain sensitive information. | // These types should never be serialized and/or sent out via public APIs, as they contain sensitive information. | ||||||
| // The annotation used on these structs is for handling them via the go-pg ORM. See here: https://pg.uptrace.dev/models/ | // The annotation used on these structs is for handling them via the go-pg ORM. See here: https://pg.uptrace.dev/models/ | ||||||
| package gtsmodel | package gtsmodel | ||||||
|  |  | ||||||
|  | @ -18,13 +18,38 @@ | ||||||
| 
 | 
 | ||||||
| package gtsmodel | package gtsmodel | ||||||
| 
 | 
 | ||||||
|  | import "github.com/gotosocial/gotosocial/pkg/mastotypes" | ||||||
|  | 
 | ||||||
|  | // Application represents an application that can perform actions on behalf of a user. | ||||||
|  | // It is used to authorize tokens etc, and is associated with an oauth client id in the database. | ||||||
| type Application struct { | type Application struct { | ||||||
| 	ID           string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"` | 	// id of this application in the db | ||||||
| 	Name         string | 	ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"` | ||||||
| 	Website      string | 	// name of the application given when it was created (eg., 'tusky') | ||||||
| 	RedirectURI  string `json:"redirect_uri"` | 	Name string | ||||||
| 	ClientID     string `json:"client_id"` | 	// website for the application given when it was created (eg., 'https://tusky.app') | ||||||
| 	ClientSecret string `json:"client_secret"` | 	Website string | ||||||
| 	Scopes       string `json:"scopes"` | 	// redirect uri requested by the application for oauth2 flow | ||||||
| 	VapidKey     string `json:"vapid_key"` | 	RedirectURI string | ||||||
|  | 	// id of the associated oauth client entity in the db | ||||||
|  | 	ClientID string | ||||||
|  | 	// secret of the associated oauth client entity in the db | ||||||
|  | 	ClientSecret string | ||||||
|  | 	// scopes requested when this app was created | ||||||
|  | 	Scopes string | ||||||
|  | 	// a vapid key generated for this app when it was created | ||||||
|  | 	VapidKey string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ToMastotype returns this application as a mastodon api type, ready for serialization | ||||||
|  | func (a *Application) ToMastotype() *mastotypes.Application { | ||||||
|  | 	return &mastotypes.Application{ | ||||||
|  | 		ID:           a.ID, | ||||||
|  | 		Name:         a.Name, | ||||||
|  | 		Website:      a.Website, | ||||||
|  | 		RedirectURI:  a.RedirectURI, | ||||||
|  | 		ClientID:     a.ClientID, | ||||||
|  | 		ClientSecret: a.ClientSecret, | ||||||
|  | 		VapidKey:     a.VapidKey, | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -20,25 +20,44 @@ package gtsmodel | ||||||
| 
 | 
 | ||||||
| import "time" | import "time" | ||||||
| 
 | 
 | ||||||
|  | // Status represents a user-created 'post' or 'status' in the database, either remote or local | ||||||
| type Status struct { | type Status struct { | ||||||
| 	ID             string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"` | 	// id of the status in the database | ||||||
| 	URI            string `pg:",unique"` | 	ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"` | ||||||
| 	URL            string `pg:",unique"` | 	// uri at which this status is reachable | ||||||
| 	Content        string | 	URI string `pg:",unique"` | ||||||
| 	CreatedAt      time.Time `pg:"type:timestamp,notnull,default:now()"` | 	// web url for viewing this status | ||||||
| 	UpdatedAt      time.Time `pg:"type:timestamp,notnull,default:now()"` | 	URL string `pg:",unique"` | ||||||
| 	Local          bool | 	// the html-formatted content of this status | ||||||
| 	AccountID      string | 	Content string | ||||||
| 	InReplyToID    string | 	// when was this status created? | ||||||
| 	BoostOfID      string | 	CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` | ||||||
|  | 	// when was this status updated? | ||||||
|  | 	UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` | ||||||
|  | 	// is this status from a local account? | ||||||
|  | 	Local bool | ||||||
|  | 	// which account posted this status? | ||||||
|  | 	AccountID string | ||||||
|  | 	// id of the status this status is a reply to | ||||||
|  | 	InReplyToID string | ||||||
|  | 	// id of the status this status is a boost of | ||||||
|  | 	BoostOfID string | ||||||
|  | 	// cw string for this status | ||||||
| 	ContentWarning string | 	ContentWarning string | ||||||
| 	Visibility     *Visibility | 	// visibility entry for this status | ||||||
|  | 	Visibility *Visibility | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Visibility represents the visibility granularity of a status. It is a combination of flags. | ||||||
| type Visibility struct { | type Visibility struct { | ||||||
| 	Direct    bool | 	// Is this status viewable as a direct message? | ||||||
|  | 	Direct bool | ||||||
|  | 	// Is this status viewable to followers? | ||||||
| 	Followers bool | 	Followers bool | ||||||
| 	Local     bool | 	// Is this status viewable on the local timeline? | ||||||
| 	Unlisted  bool | 	Local bool | ||||||
| 	Public    bool | 	// Is this status boostable but not shown on public timelines? | ||||||
|  | 	Unlisted bool | ||||||
|  | 	// Is this status shown on public and federated timelines? | ||||||
|  | 	Public bool | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -16,4 +16,22 @@ | ||||||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. |    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
| */ | */ | ||||||
| 
 | 
 | ||||||
| package api | package account | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"github.com/gotosocial/gotosocial/internal/module" | ||||||
|  | 	"github.com/gotosocial/gotosocial/internal/router" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type accountModule struct { | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // New returns a new account module | ||||||
|  | func New() module.ClientAPIModule { | ||||||
|  | 	return &accountModule{} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Route attaches all routes from this module to the given router | ||||||
|  | func (m *accountModule) Route(r router.Router) error { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
							
								
								
									
										29
									
								
								internal/module/module.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								internal/module/module.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,29 @@ | ||||||
|  | /* | ||||||
|  |    GoToSocial | ||||||
|  |    Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org | ||||||
|  | 
 | ||||||
|  |    This program is free software: you can redistribute it and/or modify | ||||||
|  |    it under the terms of the GNU Affero General Public License as published by | ||||||
|  |    the Free Software Foundation, either version 3 of the License, or | ||||||
|  |    (at your option) any later version. | ||||||
|  | 
 | ||||||
|  |    This program is distributed in the hope that it will be useful, | ||||||
|  |    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  |    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  |    GNU Affero General Public License for more details. | ||||||
|  | 
 | ||||||
|  |    You should have received a copy of the GNU Affero General Public License | ||||||
|  |    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | */ | ||||||
|  | 
 | ||||||
|  | // Package module is basically a wrapper for a lot of modules (in subdirectories) that satisfy the ClientAPIModule interface. | ||||||
|  | package module | ||||||
|  | 
 | ||||||
|  | import "github.com/gotosocial/gotosocial/internal/router" | ||||||
|  | 
 | ||||||
|  | // ClientAPIModule represents a chunk of code (usually contained in a single package) that adds a set | ||||||
|  | // of functionalities and side effects to a router, by mapping routes and handlers onto it--in other words, a REST API ;) | ||||||
|  | // A ClientAPIMpdule corresponds roughly to one main path of the gotosocial REST api, for example /api/v1/accounts/ or /oauth/ | ||||||
|  | type ClientAPIModule interface { | ||||||
|  | 	Route(s router.Router) error | ||||||
|  | } | ||||||
|  | @ -1,3 +1,5 @@ | ||||||
| # oauth | # oauth | ||||||
| 
 | 
 | ||||||
| This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) server functionality to the GoToSocial APIs. | This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) functionality to the GoToSocial client API. | ||||||
|  | 
 | ||||||
|  | It also provides a handler/middleware for attaching to the Gin engine for validating authenticated users. | ||||||
|  | @ -22,55 +22,47 @@ import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 
 | 
 | ||||||
| 	"github.com/go-pg/pg/v10" | 	"github.com/gotosocial/gotosocial/internal/db" | ||||||
| 	"github.com/gotosocial/oauth2/v4" | 	"github.com/gotosocial/oauth2/v4" | ||||||
| 	"github.com/gotosocial/oauth2/v4/models" | 	"github.com/gotosocial/oauth2/v4/models" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type pgClientStore struct { | type clientStore struct { | ||||||
| 	conn *pg.DB | 	db db.DB | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewPGClientStore(conn *pg.DB) oauth2.ClientStore { | func newClientStore(db db.DB) oauth2.ClientStore { | ||||||
| 	pts := &pgClientStore{ | 	pts := &clientStore{ | ||||||
| 		conn: conn, | 		db: db, | ||||||
| 	} | 	} | ||||||
| 	return pts | 	return pts | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (pcs *pgClientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { | func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { | ||||||
| 	poc := &oauthClient{ | 	poc := &oauthClient{ | ||||||
| 		ID: clientID, | 		ID: clientID, | ||||||
| 	} | 	} | ||||||
| 	if err := pcs.conn.WithContext(ctx).Model(poc).Where("id = ?", poc.ID).Select(); err != nil { | 	if err := cs.db.GetByID(clientID, poc); err != nil { | ||||||
| 		return nil, fmt.Errorf("error in clientstore getbyid searching for client %s: %s", clientID, err) | 		return nil, fmt.Errorf("database error: %s", err) | ||||||
| 	} | 	} | ||||||
| 	return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil | 	return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (pcs *pgClientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error { | func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error { | ||||||
| 	poc := &oauthClient{ | 	poc := &oauthClient{ | ||||||
| 		ID:     cli.GetID(), | 		ID:     cli.GetID(), | ||||||
| 		Secret: cli.GetSecret(), | 		Secret: cli.GetSecret(), | ||||||
| 		Domain: cli.GetDomain(), | 		Domain: cli.GetDomain(), | ||||||
| 		UserID: cli.GetUserID(), | 		UserID: cli.GetUserID(), | ||||||
| 	} | 	} | ||||||
| 	_, err := pcs.conn.WithContext(ctx).Model(poc).OnConflict("(id) DO UPDATE").Insert() | 	return cs.db.UpdateByID(id, poc) | ||||||
| 	if err != nil { |  | ||||||
| 		return fmt.Errorf("error in clientstore set: %s", err) |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (pcs *pgClientStore) Delete(ctx context.Context, id string) error { | func (cs *clientStore) Delete(ctx context.Context, id string) error { | ||||||
| 	poc := &oauthClient{ | 	poc := &oauthClient{ | ||||||
| 		ID: id, | 		ID: id, | ||||||
| 	} | 	} | ||||||
| 	_, err := pcs.conn.WithContext(ctx).Model(poc).Where("id = ?", poc.ID).Delete() | 	return cs.db.DeleteByID(id, poc) | ||||||
| 	if err != nil { |  | ||||||
| 		return fmt.Errorf("error in clientstore delete: %s", err) |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type oauthClient struct { | type oauthClient struct { | ||||||
|  | @ -1,11 +1,28 @@ | ||||||
|  | /* | ||||||
|  |    GoToSocial | ||||||
|  |    Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org | ||||||
|  | 
 | ||||||
|  |    This program is free software: you can redistribute it and/or modify | ||||||
|  |    it under the terms of the GNU Affero General Public License as published by | ||||||
|  |    the Free Software Foundation, either version 3 of the License, or | ||||||
|  |    (at your option) any later version. | ||||||
|  | 
 | ||||||
|  |    This program is distributed in the hope that it will be useful, | ||||||
|  |    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  |    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  |    GNU Affero General Public License for more details. | ||||||
|  | 
 | ||||||
|  |    You should have received a copy of the GNU Affero General Public License | ||||||
|  |    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | */ | ||||||
| package oauth | package oauth | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"github.com/go-pg/pg/v10" | 	"github.com/gotosocial/gotosocial/internal/config" | ||||||
| 	"github.com/go-pg/pg/v10/orm" | 	"github.com/gotosocial/gotosocial/internal/db" | ||||||
| 	"github.com/gotosocial/oauth2/v4/models" | 	"github.com/gotosocial/oauth2/v4/models" | ||||||
| 	"github.com/sirupsen/logrus" | 	"github.com/sirupsen/logrus" | ||||||
| 	"github.com/stretchr/testify/suite" | 	"github.com/stretchr/testify/suite" | ||||||
|  | @ -13,7 +30,7 @@ import ( | ||||||
| 
 | 
 | ||||||
| type PgClientStoreTestSuite struct { | type PgClientStoreTestSuite struct { | ||||||
| 	suite.Suite | 	suite.Suite | ||||||
| 	conn             *pg.DB | 	db               db.DB | ||||||
| 	testClientID     string | 	testClientID     string | ||||||
| 	testClientSecret string | 	testClientSecret string | ||||||
| 	testClientDomain string | 	testClientDomain string | ||||||
|  | @ -32,31 +49,55 @@ func (suite *PgClientStoreTestSuite) SetupSuite() { | ||||||
| 
 | 
 | ||||||
| // SetupTest creates a postgres connection and creates the oauth_clients table before each test | // SetupTest creates a postgres connection and creates the oauth_clients table before each test | ||||||
| func (suite *PgClientStoreTestSuite) SetupTest() { | func (suite *PgClientStoreTestSuite) SetupTest() { | ||||||
| 	suite.conn = pg.Connect(&pg.Options{}) | 	log := logrus.New() | ||||||
| 	if err := suite.conn.Ping(context.Background()); err != nil { | 	log.SetLevel(logrus.TraceLevel) | ||||||
| 		logrus.Panicf("db connection error: %s", err) | 	c := config.Empty() | ||||||
|  | 	c.DBConfig = &config.DBConfig{ | ||||||
|  | 		Type:            "postgres", | ||||||
|  | 		Address:         "localhost", | ||||||
|  | 		Port:            5432, | ||||||
|  | 		User:            "postgres", | ||||||
|  | 		Password:        "postgres", | ||||||
|  | 		Database:        "postgres", | ||||||
|  | 		ApplicationName: "gotosocial", | ||||||
| 	} | 	} | ||||||
| 	if err := suite.conn.Model(&oauthClient{}).CreateTable(&orm.CreateTableOptions{ | 	db, err := db.New(context.Background(), c, log) | ||||||
| 		IfNotExists: true, | 	if err != nil { | ||||||
| 	}); err != nil { | 		logrus.Panicf("error creating database connection: %s", err) | ||||||
| 		logrus.Panicf("db connection error: %s", err) | 	} | ||||||
|  | 
 | ||||||
|  | 	suite.db = db | ||||||
|  | 
 | ||||||
|  | 	models := []interface{}{ | ||||||
|  | 		&oauthClient{}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, m := range models { | ||||||
|  | 		if err := suite.db.CreateTable(m); err != nil { | ||||||
|  | 			logrus.Panicf("db connection error: %s", err) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // TearDownTest drops the oauth_clients table and closes the pg connection after each test | // TearDownTest drops the oauth_clients table and closes the pg connection after each test | ||||||
| func (suite *PgClientStoreTestSuite) TearDownTest() { | func (suite *PgClientStoreTestSuite) TearDownTest() { | ||||||
| 	if err := suite.conn.Model(&oauthClient{}).DropTable(&orm.DropTableOptions{}); err != nil { | 	models := []interface{}{ | ||||||
| 		logrus.Panicf("drop table error: %s", err) | 		&oauthClient{}, | ||||||
| 	} | 	} | ||||||
| 	if err := suite.conn.Close(); err != nil { | 	for _, m := range models { | ||||||
|  | 		if err := suite.db.DropTable(m); err != nil { | ||||||
|  | 			logrus.Panicf("error dropping table: %s", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if err := suite.db.Stop(context.Background()); err != nil { | ||||||
| 		logrus.Panicf("error closing db connection: %s", err) | 		logrus.Panicf("error closing db connection: %s", err) | ||||||
| 	} | 	} | ||||||
| 	suite.conn = nil | 	suite.db = nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() { | func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() { | ||||||
| 	// set a new client in the store | 	// set a new client in the store | ||||||
| 	cs := NewPGClientStore(suite.conn) | 	cs := newClientStore(suite.db) | ||||||
| 	if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil { | 	if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil { | ||||||
| 		suite.FailNow(err.Error()) | 		suite.FailNow(err.Error()) | ||||||
| 	} | 	} | ||||||
|  | @ -74,7 +115,7 @@ func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() { | ||||||
| 
 | 
 | ||||||
| func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() { | func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() { | ||||||
| 	// set a new client in the store | 	// set a new client in the store | ||||||
| 	cs := NewPGClientStore(suite.conn) | 	cs := newClientStore(suite.db) | ||||||
| 	if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil { | 	if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil { | ||||||
| 		suite.FailNow(err.Error()) | 		suite.FailNow(err.Error()) | ||||||
| 	} | 	} | ||||||
							
								
								
									
										510
									
								
								internal/module/oauth/oauth.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										510
									
								
								internal/module/oauth/oauth.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,510 @@ | ||||||
|  | /* | ||||||
|  |    GoToSocial | ||||||
|  |    Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org | ||||||
|  | 
 | ||||||
|  |    This program is free software: you can redistribute it and/or modify | ||||||
|  |    it under the terms of the GNU Affero General Public License as published by | ||||||
|  |    the Free Software Foundation, either version 3 of the License, or | ||||||
|  |    (at your option) any later version. | ||||||
|  | 
 | ||||||
|  |    This program is distributed in the hope that it will be useful, | ||||||
|  |    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  |    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  |    GNU Affero General Public License for more details. | ||||||
|  | 
 | ||||||
|  |    You should have received a copy of the GNU Affero General Public License | ||||||
|  |    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | */ | ||||||
|  | 
 | ||||||
|  | // Package oauth is a module that provides oauth functionality to a router. | ||||||
|  | // It adds the following paths: | ||||||
|  | //    /api/v1/apps | ||||||
|  | //    /auth/sign_in | ||||||
|  | //    /oauth/token | ||||||
|  | //    /oauth/authorize | ||||||
|  | // It also includes the oauthTokenMiddleware, which can be attached to a router to authenticate every request by Bearer token. | ||||||
|  | package oauth | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/url" | ||||||
|  | 
 | ||||||
|  | 	"github.com/gin-contrib/sessions" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/google/uuid" | ||||||
|  | 	"github.com/gotosocial/gotosocial/internal/db" | ||||||
|  | 	"github.com/gotosocial/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/gotosocial/gotosocial/internal/module" | ||||||
|  | 	"github.com/gotosocial/gotosocial/internal/router" | ||||||
|  | 	"github.com/gotosocial/gotosocial/pkg/mastotypes" | ||||||
|  | 	"github.com/gotosocial/oauth2/v4" | ||||||
|  | 	"github.com/gotosocial/oauth2/v4/errors" | ||||||
|  | 	"github.com/gotosocial/oauth2/v4/manage" | ||||||
|  | 	"github.com/gotosocial/oauth2/v4/server" | ||||||
|  | 	"github.com/sirupsen/logrus" | ||||||
|  | 	"golang.org/x/crypto/bcrypt" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	appsPath           = "/api/v1/apps" | ||||||
|  | 	authSignInPath     = "/auth/sign_in" | ||||||
|  | 	oauthTokenPath     = "/oauth/token" | ||||||
|  | 	oauthAuthorizePath = "/oauth/authorize" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // oauthModule is an oauth2 oauthModule that satisfies the ClientAPIModule interface | ||||||
|  | type oauthModule struct { | ||||||
|  | 	oauthManager *manage.Manager | ||||||
|  | 	oauthServer  *server.Server | ||||||
|  | 	db           db.DB | ||||||
|  | 	log          *logrus.Logger | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type login struct { | ||||||
|  | 	Email    string `form:"username"` | ||||||
|  | 	Password string `form:"password"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // New returns a new oauth module | ||||||
|  | func New(ts oauth2.TokenStore, cs oauth2.ClientStore, db db.DB, log *logrus.Logger) module.ClientAPIModule { | ||||||
|  | 	manager := manage.NewDefaultManager() | ||||||
|  | 	manager.MapTokenStorage(ts) | ||||||
|  | 	manager.MapClientStorage(cs) | ||||||
|  | 	manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) | ||||||
|  | 	sc := &server.Config{ | ||||||
|  | 		TokenType: "Bearer", | ||||||
|  | 		// Must follow the spec. | ||||||
|  | 		AllowGetAccessRequest: false, | ||||||
|  | 		// Support only the non-implicit flow. | ||||||
|  | 		AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code}, | ||||||
|  | 		// Allow: | ||||||
|  | 		// - Authorization Code (for first & third parties) | ||||||
|  | 		AllowedGrantTypes: []oauth2.GrantType{ | ||||||
|  | 			oauth2.AuthorizationCode, | ||||||
|  | 		}, | ||||||
|  | 		AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	srv := server.NewServer(sc, manager) | ||||||
|  | 	srv.SetInternalErrorHandler(func(err error) *errors.Response { | ||||||
|  | 		log.Errorf("internal oauth error: %s", err) | ||||||
|  | 		return nil | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	srv.SetResponseErrorHandler(func(re *errors.Response) { | ||||||
|  | 		log.Errorf("internal response error: %s", re.Error) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	m := &oauthModule{ | ||||||
|  | 		oauthManager: manager, | ||||||
|  | 		oauthServer:  srv, | ||||||
|  | 		db:           db, | ||||||
|  | 		log:          log, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	m.oauthServer.SetUserAuthorizationHandler(m.userAuthorizationHandler) | ||||||
|  | 	m.oauthServer.SetClientInfoHandler(server.ClientFormHandler) | ||||||
|  | 	return m | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Route satisfies the RESTAPIModule interface | ||||||
|  | func (m *oauthModule) Route(s router.Router) error { | ||||||
|  | 	s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler) | ||||||
|  | 
 | ||||||
|  | 	s.AttachHandler(http.MethodGet, authSignInPath, m.signInGETHandler) | ||||||
|  | 	s.AttachHandler(http.MethodPost, authSignInPath, m.signInPOSTHandler) | ||||||
|  | 
 | ||||||
|  | 	s.AttachHandler(http.MethodPost, oauthTokenPath, m.tokenPOSTHandler) | ||||||
|  | 
 | ||||||
|  | 	s.AttachHandler(http.MethodGet, oauthAuthorizePath, m.authorizeGETHandler) | ||||||
|  | 	s.AttachHandler(http.MethodPost, oauthAuthorizePath, m.authorizePOSTHandler) | ||||||
|  | 
 | ||||||
|  | 	s.AttachMiddleware(m.oauthTokenMiddleware) | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* | ||||||
|  | 	MAIN HANDLERS -- serve these through a server/router | ||||||
|  | */ | ||||||
|  | 
 | ||||||
|  | // appsPOSTHandler should be served at https://example.org/api/v1/apps | ||||||
|  | // It is equivalent to: https://docs.joinmastodon.org/methods/apps/ | ||||||
|  | func (m *oauthModule) appsPOSTHandler(c *gin.Context) { | ||||||
|  | 	l := m.log.WithField("func", "AppsPOSTHandler") | ||||||
|  | 	l.Trace("entering AppsPOSTHandler") | ||||||
|  | 
 | ||||||
|  | 	form := &mastotypes.ApplicationPOSTRequest{} | ||||||
|  | 	if err := c.ShouldBind(form); err != nil { | ||||||
|  | 		c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// permitted length for most fields | ||||||
|  | 	permittedLength := 64 | ||||||
|  | 	// redirect can be a bit bigger because we probably need to encode data in the redirect uri | ||||||
|  | 	permittedRedirect := 256 | ||||||
|  | 
 | ||||||
|  | 	// check lengths of fields before proceeding so the user can't spam huge entries into the database | ||||||
|  | 	if len(form.ClientName) > permittedLength { | ||||||
|  | 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if len(form.Website) > permittedLength { | ||||||
|  | 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if len(form.RedirectURIs) > permittedRedirect { | ||||||
|  | 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if len(form.Scopes) > permittedLength { | ||||||
|  | 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/ | ||||||
|  | 	var scopes string | ||||||
|  | 	if form.Scopes == "" { | ||||||
|  | 		scopes = "read" | ||||||
|  | 	} else { | ||||||
|  | 		scopes = form.Scopes | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// generate new IDs for this application and its associated client | ||||||
|  | 	clientID := uuid.NewString() | ||||||
|  | 	clientSecret := uuid.NewString() | ||||||
|  | 	vapidKey := uuid.NewString() | ||||||
|  | 
 | ||||||
|  | 	// generate the application to put in the database | ||||||
|  | 	app := >smodel.Application{ | ||||||
|  | 		Name:         form.ClientName, | ||||||
|  | 		Website:      form.Website, | ||||||
|  | 		RedirectURI:  form.RedirectURIs, | ||||||
|  | 		ClientID:     clientID, | ||||||
|  | 		ClientSecret: clientSecret, | ||||||
|  | 		Scopes:       scopes, | ||||||
|  | 		VapidKey:     vapidKey, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// chuck it in the db | ||||||
|  | 	if err := m.db.Put(app); err != nil { | ||||||
|  | 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// now we need to model an oauth client from the application that the oauth library can use | ||||||
|  | 	oc := &oauthClient{ | ||||||
|  | 		ID:     clientID, | ||||||
|  | 		Secret: clientSecret, | ||||||
|  | 		Domain: form.RedirectURIs, | ||||||
|  | 		UserID: "", // This client isn't yet associated with a specific user,  it's just an app client right now | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// chuck it in the db | ||||||
|  | 	if err := m.db.Put(oc); err != nil { | ||||||
|  | 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/ | ||||||
|  | 	c.JSON(http.StatusOK, app.ToMastotype()) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // signInGETHandler should be served at https://example.org/auth/sign_in. | ||||||
|  | // The idea is to present a sign in page to the user, where they can enter their username and password. | ||||||
|  | // The form will then POST to the sign in page, which will be handled by SignInPOSTHandler | ||||||
|  | func (m *oauthModule) signInGETHandler(c *gin.Context) { | ||||||
|  | 	m.log.WithField("func", "SignInGETHandler").Trace("serving sign in html") | ||||||
|  | 	c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // signInPOSTHandler should be served at https://example.org/auth/sign_in. | ||||||
|  | // The idea is to present a sign in page to the user, where they can enter their username and password. | ||||||
|  | // The handler will then redirect to the auth handler served at /auth | ||||||
|  | func (m *oauthModule) signInPOSTHandler(c *gin.Context) { | ||||||
|  | 	l := m.log.WithField("func", "SignInPOSTHandler") | ||||||
|  | 	s := sessions.Default(c) | ||||||
|  | 	form := &login{} | ||||||
|  | 	if err := c.ShouldBind(form); err != nil { | ||||||
|  | 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	l.Tracef("parsed form: %+v", form) | ||||||
|  | 
 | ||||||
|  | 	userid, err := m.validatePassword(form.Email, form.Password) | ||||||
|  | 	if err != nil { | ||||||
|  | 		c.String(http.StatusForbidden, err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	s.Set("userid", userid) | ||||||
|  | 	if err := s.Save(); err != nil { | ||||||
|  | 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	l.Trace("redirecting to auth page") | ||||||
|  | 	c.Redirect(http.StatusFound, oauthAuthorizePath) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // tokenPOSTHandler should be served as a POST at https://example.org/oauth/token | ||||||
|  | // The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs. | ||||||
|  | // See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token | ||||||
|  | func (m *oauthModule) tokenPOSTHandler(c *gin.Context) { | ||||||
|  | 	l := m.log.WithField("func", "TokenPOSTHandler") | ||||||
|  | 	l.Trace("entered TokenPOSTHandler") | ||||||
|  | 	if err := m.oauthServer.HandleTokenRequest(c.Writer, c.Request); err != nil { | ||||||
|  | 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // authorizeGETHandler should be served as GET at https://example.org/oauth/authorize | ||||||
|  | // The idea here is to present an oauth authorize page to the user, with a button | ||||||
|  | // that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user | ||||||
|  | func (m *oauthModule) authorizeGETHandler(c *gin.Context) { | ||||||
|  | 	l := m.log.WithField("func", "AuthorizeGETHandler") | ||||||
|  | 	s := sessions.Default(c) | ||||||
|  | 
 | ||||||
|  | 	// UserID will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow | ||||||
|  | 	// If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page. | ||||||
|  | 	userID, ok := s.Get("userid").(string) | ||||||
|  | 	if !ok || userID == "" { | ||||||
|  | 		l.Trace("userid was empty, parsing form then redirecting to sign in page") | ||||||
|  | 		if err := parseAuthForm(c, l); err != nil { | ||||||
|  | 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||||
|  | 		} else { | ||||||
|  | 			c.Redirect(http.StatusFound, authSignInPath) | ||||||
|  | 		} | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// We can use the client_id on the session to retrieve info about the app associated with the client_id | ||||||
|  | 	clientID, ok := s.Get("client_id").(string) | ||||||
|  | 	if !ok || clientID == "" { | ||||||
|  | 		c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	app := >smodel.Application{ | ||||||
|  | 		ClientID: clientID, | ||||||
|  | 	} | ||||||
|  | 	if err := m.db.GetWhere("client_id", app.ClientID, app); err != nil { | ||||||
|  | 		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// we can also use the userid of the user to fetch their username from the db to greet them nicely <3 | ||||||
|  | 	user := >smodel.User{ | ||||||
|  | 		ID: userID, | ||||||
|  | 	} | ||||||
|  | 	if err := m.db.GetByID(user.ID, user); err != nil { | ||||||
|  | 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	acct := >smodel.Account{ | ||||||
|  | 		ID: user.AccountID, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := m.db.GetByID(acct.ID, acct); err != nil { | ||||||
|  | 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Finally we should also get the redirect and scope of this particular request, as stored in the session. | ||||||
|  | 	redirect, ok := s.Get("redirect_uri").(string) | ||||||
|  | 	if !ok || redirect == "" { | ||||||
|  | 		c.JSON(http.StatusInternalServerError, gin.H{"error": "no redirect_uri found in session"}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	scope, ok := s.Get("scope").(string) | ||||||
|  | 	if !ok || scope == "" { | ||||||
|  | 		c.JSON(http.StatusInternalServerError, gin.H{"error": "no scope found in session"}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// the authorize template will display a form to the user where they can get some information | ||||||
|  | 	// about the app that's trying to authorize, and the scope of the request. | ||||||
|  | 	// They can then approve it if it looks OK to them, which will POST to the AuthorizePOSTHandler | ||||||
|  | 	l.Trace("serving authorize html") | ||||||
|  | 	c.HTML(http.StatusOK, "authorize.tmpl", gin.H{ | ||||||
|  | 		"appname":    app.Name, | ||||||
|  | 		"appwebsite": app.Website, | ||||||
|  | 		"redirect":   redirect, | ||||||
|  | 		"scope":      scope, | ||||||
|  | 		"user":       acct.Username, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // authorizePOSTHandler should be served as POST at https://example.org/oauth/authorize | ||||||
|  | // At this point we assume that the user has A) logged in and B) accepted that the app should act for them, | ||||||
|  | // so we should proceed with the authentication flow and generate an oauth token for them if we can. | ||||||
|  | // See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user | ||||||
|  | func (m *oauthModule) authorizePOSTHandler(c *gin.Context) { | ||||||
|  | 	l := m.log.WithField("func", "AuthorizePOSTHandler") | ||||||
|  | 	s := sessions.Default(c) | ||||||
|  | 
 | ||||||
|  | 	// At this point we know the user has said 'yes' to allowing the application and oauth client | ||||||
|  | 	// work for them, so we can set the | ||||||
|  | 
 | ||||||
|  | 	// We need to retrieve the original form submitted to the authorizeGEThandler, and | ||||||
|  | 	// recreate it on the request so that it can be used further by the oauth2 library. | ||||||
|  | 	// So first fetch all the values from the session. | ||||||
|  | 	forceLogin, ok := s.Get("force_login").(string) | ||||||
|  | 	if !ok { | ||||||
|  | 		c.JSON(http.StatusBadRequest, gin.H{"error": "session missing force_login"}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	responseType, ok := s.Get("response_type").(string) | ||||||
|  | 	if !ok || responseType == "" { | ||||||
|  | 		c.JSON(http.StatusBadRequest, gin.H{"error": "session missing response_type"}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	clientID, ok := s.Get("client_id").(string) | ||||||
|  | 	if !ok || clientID == "" { | ||||||
|  | 		c.JSON(http.StatusBadRequest, gin.H{"error": "session missing client_id"}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	redirectURI, ok := s.Get("redirect_uri").(string) | ||||||
|  | 	if !ok || redirectURI == "" { | ||||||
|  | 		c.JSON(http.StatusBadRequest, gin.H{"error": "session missing redirect_uri"}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	scope, ok := s.Get("scope").(string) | ||||||
|  | 	if !ok { | ||||||
|  | 		c.JSON(http.StatusBadRequest, gin.H{"error": "session missing scope"}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	userID, ok := s.Get("userid").(string) | ||||||
|  | 	if !ok { | ||||||
|  | 		c.JSON(http.StatusBadRequest, gin.H{"error": "session missing userid"}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	// we're done with the session so we can clear it now | ||||||
|  | 	s.Clear() | ||||||
|  | 
 | ||||||
|  | 	// now set the values on the request | ||||||
|  | 	values := url.Values{} | ||||||
|  | 	values.Set("force_login", forceLogin) | ||||||
|  | 	values.Set("response_type", responseType) | ||||||
|  | 	values.Set("client_id", clientID) | ||||||
|  | 	values.Set("redirect_uri", redirectURI) | ||||||
|  | 	values.Set("scope", scope) | ||||||
|  | 	values.Set("userid", userID) | ||||||
|  | 	c.Request.Form = values | ||||||
|  | 	l.Tracef("values on request set to %+v", c.Request.Form) | ||||||
|  | 
 | ||||||
|  | 	// and proceed with authorization using the oauth2 library | ||||||
|  | 	if err := m.oauthServer.HandleAuthorizeRequest(c.Writer, c.Request); err != nil { | ||||||
|  | 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* | ||||||
|  | 	MIDDLEWARE | ||||||
|  | */ | ||||||
|  | 
 | ||||||
|  | // oauthTokenMiddleware | ||||||
|  | func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) { | ||||||
|  | 	l := m.log.WithField("func", "ValidatePassword") | ||||||
|  | 	l.Trace("entering OauthTokenMiddleware") | ||||||
|  | 	if ti, err := m.oauthServer.ValidationBearerToken(c.Request); err == nil { | ||||||
|  | 		l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope()) | ||||||
|  | 		c.Set("authenticated_user", ti.GetUserID()) | ||||||
|  | 
 | ||||||
|  | 	} else { | ||||||
|  | 		l.Trace("continuing with unauthenticated request") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* | ||||||
|  | 	SUB-HANDLERS -- don't serve these directly, they should be attached to the oauth2 server or used inside handler funcs | ||||||
|  | */ | ||||||
|  | 
 | ||||||
|  | // validatePassword takes an email address and a password. | ||||||
|  | // The goal is to authenticate the password against the one for that email | ||||||
|  | // address stored in the database. If OK, we return the userid (a uuid) for that user, | ||||||
|  | // so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db. | ||||||
|  | func (m *oauthModule) validatePassword(email string, password string) (userid string, err error) { | ||||||
|  | 	l := m.log.WithField("func", "ValidatePassword") | ||||||
|  | 
 | ||||||
|  | 	// make sure an email/password was provided and bail if not | ||||||
|  | 	if email == "" || password == "" { | ||||||
|  | 		l.Debug("email or password was not provided") | ||||||
|  | 		return incorrectPassword() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// first we select the user from the database based on email address, bail if no user found for that email | ||||||
|  | 	gtsUser := >smodel.User{} | ||||||
|  | 
 | ||||||
|  | 	if err := m.db.GetWhere("email", email, gtsUser); err != nil { | ||||||
|  | 		l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err) | ||||||
|  | 		return incorrectPassword() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// make sure a password is actually set and bail if not | ||||||
|  | 	if gtsUser.EncryptedPassword == "" { | ||||||
|  | 		l.Warnf("encrypted password for user %s was empty for some reason", gtsUser.Email) | ||||||
|  | 		return incorrectPassword() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// compare the provided password with the encrypted one from the db, bail if they don't match | ||||||
|  | 	if err := bcrypt.CompareHashAndPassword([]byte(gtsUser.EncryptedPassword), []byte(password)); err != nil { | ||||||
|  | 		l.Debugf("password hash didn't match for user %s during login attempt: %s", gtsUser.Email, err) | ||||||
|  | 		return incorrectPassword() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// If we've made it this far the email/password is correct, so we can just return the id of the user. | ||||||
|  | 	userid = gtsUser.ID | ||||||
|  | 	l.Tracef("returning (%s, %s)", userid, err) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // incorrectPassword is just a little helper function to use in the ValidatePassword function | ||||||
|  | func incorrectPassword() (string, error) { | ||||||
|  | 	return "", errors.New("password/email combination was incorrect") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // userAuthorizationHandler gets the user's ID from the 'userid' field of the request form, | ||||||
|  | // or redirects to the /auth/sign_in page, if this key is not present. | ||||||
|  | func (m *oauthModule) userAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) { | ||||||
|  | 	l := m.log.WithField("func", "UserAuthorizationHandler") | ||||||
|  | 	userID = r.FormValue("userid") | ||||||
|  | 	if userID == "" { | ||||||
|  | 		return "", errors.New("userid was empty, redirecting to sign in page") | ||||||
|  | 	} | ||||||
|  | 	l.Tracef("returning userID %s", userID) | ||||||
|  | 	return userID, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // parseAuthForm parses the OAuthAuthorize form in the gin context, and stores | ||||||
|  | // the values in the form into the session. | ||||||
|  | func parseAuthForm(c *gin.Context, l *logrus.Entry) error { | ||||||
|  | 	s := sessions.Default(c) | ||||||
|  | 
 | ||||||
|  | 	// first make sure they've filled out the authorize form with the required values | ||||||
|  | 	form := &mastotypes.OAuthAuthorize{} | ||||||
|  | 	if err := c.ShouldBind(form); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	l.Tracef("parsed form: %+v", form) | ||||||
|  | 
 | ||||||
|  | 	// these fields are *required* so check 'em | ||||||
|  | 	if form.ResponseType == "" || form.ClientID == "" || form.RedirectURI == "" { | ||||||
|  | 		return errors.New("missing one of: response_type, client_id or redirect_uri") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// set default scope to read | ||||||
|  | 	if form.Scope == "" { | ||||||
|  | 		form.Scope = "read" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// save these values from the form so we can use them elsewhere in the session | ||||||
|  | 	s.Set("force_login", form.ForceLogin) | ||||||
|  | 	s.Set("response_type", form.ResponseType) | ||||||
|  | 	s.Set("client_id", form.ClientID) | ||||||
|  | 	s.Set("redirect_uri", form.RedirectURI) | ||||||
|  | 	s.Set("scope", form.Scope) | ||||||
|  | 	return s.Save() | ||||||
|  | } | ||||||
							
								
								
									
										191
									
								
								internal/module/oauth/oauth_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										191
									
								
								internal/module/oauth/oauth_test.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,191 @@ | ||||||
|  | /* | ||||||
|  |    GoToSocial | ||||||
|  |    Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org | ||||||
|  | 
 | ||||||
|  |    This program is free software: you can redistribute it and/or modify | ||||||
|  |    it under the terms of the GNU Affero General Public License as published by | ||||||
|  |    the Free Software Foundation, either version 3 of the License, or | ||||||
|  |    (at your option) any later version. | ||||||
|  | 
 | ||||||
|  |    This program is distributed in the hope that it will be useful, | ||||||
|  |    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  |    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  |    GNU Affero General Public License for more details. | ||||||
|  | 
 | ||||||
|  |    You should have received a copy of the GNU Affero General Public License | ||||||
|  |    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | */ | ||||||
|  | 
 | ||||||
|  | package oauth | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"fmt" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/google/uuid" | ||||||
|  | 	"github.com/gotosocial/gotosocial/internal/config" | ||||||
|  | 	"github.com/gotosocial/gotosocial/internal/db" | ||||||
|  | 	"github.com/gotosocial/gotosocial/internal/gtsmodel" | ||||||
|  | 	"github.com/gotosocial/gotosocial/internal/router" | ||||||
|  | 	"github.com/gotosocial/oauth2/v4" | ||||||
|  | 	"github.com/sirupsen/logrus" | ||||||
|  | 	"github.com/stretchr/testify/suite" | ||||||
|  | 	"golang.org/x/crypto/bcrypt" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type OauthTestSuite struct { | ||||||
|  | 	suite.Suite | ||||||
|  | 	tokenStore      oauth2.TokenStore | ||||||
|  | 	clientStore     oauth2.ClientStore | ||||||
|  | 	db              db.DB | ||||||
|  | 	testAccount     *gtsmodel.Account | ||||||
|  | 	testApplication *gtsmodel.Application | ||||||
|  | 	testUser        *gtsmodel.User | ||||||
|  | 	testClient      *oauthClient | ||||||
|  | 	config          *config.Config | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout | ||||||
|  | func (suite *OauthTestSuite) SetupSuite() { | ||||||
|  | 	c := config.Empty() | ||||||
|  | 	// we're running on localhost without https so set the protocol to http | ||||||
|  | 	c.Protocol = "http" | ||||||
|  | 	// just for testing | ||||||
|  | 	c.Host = "localhost:8080" | ||||||
|  | 	// because go tests are run within the test package directory, we need to fiddle with the templateconfig | ||||||
|  | 	// basedir in a way that we wouldn't normally have to do when running the binary, in order to make | ||||||
|  | 	// the templates actually load | ||||||
|  | 	c.TemplateConfig.BaseDir = "../../../web/template/" | ||||||
|  | 	c.DBConfig = &config.DBConfig{ | ||||||
|  | 		Type:            "postgres", | ||||||
|  | 		Address:         "localhost", | ||||||
|  | 		Port:            5432, | ||||||
|  | 		User:            "postgres", | ||||||
|  | 		Password:        "postgres", | ||||||
|  | 		Database:        "postgres", | ||||||
|  | 		ApplicationName: "gotosocial", | ||||||
|  | 	} | ||||||
|  | 	suite.config = c | ||||||
|  | 
 | ||||||
|  | 	encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logrus.Panicf("error encrypting user pass: %s", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	acctID := uuid.NewString() | ||||||
|  | 
 | ||||||
|  | 	suite.testAccount = >smodel.Account{ | ||||||
|  | 		ID:       acctID, | ||||||
|  | 		Username: "test_user", | ||||||
|  | 	} | ||||||
|  | 	suite.testUser = >smodel.User{ | ||||||
|  | 		EncryptedPassword: string(encryptedPassword), | ||||||
|  | 		Email:             "user@example.org", | ||||||
|  | 		AccountID:         acctID, | ||||||
|  | 	} | ||||||
|  | 	suite.testClient = &oauthClient{ | ||||||
|  | 		ID:     "a-known-client-id", | ||||||
|  | 		Secret: "some-secret", | ||||||
|  | 		Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host), | ||||||
|  | 	} | ||||||
|  | 	suite.testApplication = >smodel.Application{ | ||||||
|  | 		Name:         "a test application", | ||||||
|  | 		Website:      "https://some-application-website.com", | ||||||
|  | 		RedirectURI:  "http://localhost:8080", | ||||||
|  | 		ClientID:     "a-known-client-id", | ||||||
|  | 		ClientSecret: "some-secret", | ||||||
|  | 		Scopes:       "read", | ||||||
|  | 		VapidKey:     uuid.NewString(), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SetupTest creates a postgres connection and creates the oauth_clients table before each test | ||||||
|  | func (suite *OauthTestSuite) SetupTest() { | ||||||
|  | 
 | ||||||
|  | 	log := logrus.New() | ||||||
|  | 	log.SetLevel(logrus.TraceLevel) | ||||||
|  | 	db, err := db.New(context.Background(), suite.config, log) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logrus.Panicf("error creating database connection: %s", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	suite.db = db | ||||||
|  | 
 | ||||||
|  | 	models := []interface{}{ | ||||||
|  | 		&oauthClient{}, | ||||||
|  | 		&oauthToken{}, | ||||||
|  | 		>smodel.User{}, | ||||||
|  | 		>smodel.Account{}, | ||||||
|  | 		>smodel.Application{}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, m := range models { | ||||||
|  | 		if err := suite.db.CreateTable(m); err != nil { | ||||||
|  | 			logrus.Panicf("db connection error: %s", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	suite.tokenStore = newTokenStore(context.Background(), suite.db, logrus.New()) | ||||||
|  | 	suite.clientStore = newClientStore(suite.db) | ||||||
|  | 
 | ||||||
|  | 	if err := suite.db.Put(suite.testAccount); err != nil { | ||||||
|  | 		logrus.Panicf("could not insert test account into db: %s", err) | ||||||
|  | 	} | ||||||
|  | 	if err := suite.db.Put(suite.testUser); err != nil { | ||||||
|  | 		logrus.Panicf("could not insert test user into db: %s", err) | ||||||
|  | 	} | ||||||
|  | 	if err := suite.db.Put(suite.testClient); err != nil { | ||||||
|  | 		logrus.Panicf("could not insert test client into db: %s", err) | ||||||
|  | 	} | ||||||
|  | 	if err := suite.db.Put(suite.testApplication); err != nil { | ||||||
|  | 		logrus.Panicf("could not insert test application into db: %s", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // TearDownTest drops the oauth_clients table and closes the pg connection after each test | ||||||
|  | func (suite *OauthTestSuite) TearDownTest() { | ||||||
|  | 	models := []interface{}{ | ||||||
|  | 		&oauthClient{}, | ||||||
|  | 		&oauthToken{}, | ||||||
|  | 		>smodel.User{}, | ||||||
|  | 		>smodel.Account{}, | ||||||
|  | 		>smodel.Application{}, | ||||||
|  | 	} | ||||||
|  | 	for _, m := range models { | ||||||
|  | 		if err := suite.db.DropTable(m); err != nil { | ||||||
|  | 			logrus.Panicf("error dropping table: %s", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if err := suite.db.Stop(context.Background()); err != nil { | ||||||
|  | 		logrus.Panicf("error closing db connection: %s", err) | ||||||
|  | 	} | ||||||
|  | 	suite.db = nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (suite *OauthTestSuite) TestAPIInitialize() { | ||||||
|  | 	log := logrus.New() | ||||||
|  | 	log.SetLevel(logrus.TraceLevel) | ||||||
|  | 
 | ||||||
|  | 	r, err := router.New(suite.config, log) | ||||||
|  | 	if err != nil { | ||||||
|  | 		suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	api := New(suite.tokenStore, suite.clientStore, suite.db, log) | ||||||
|  | 	if err := api.Route(r); err != nil { | ||||||
|  | 		suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	go r.Start() | ||||||
|  | 	time.Sleep(60 * time.Second) | ||||||
|  | 	// http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=http://localhost:8080&scope=read | ||||||
|  | 	// curl -v -F client_id=a-known-client-id -F client_secret=some-secret -F redirect_uri=http://localhost:8080 -F code=[ INSERT CODE HERE ] -F grant_type=authorization_code localhost:8080/oauth/token | ||||||
|  | 	// curl -v -H "Authorization: Bearer [INSERT TOKEN HERE]" http://localhost:8080 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestOauthTestSuite(t *testing.T) { | ||||||
|  | 	suite.Run(t, new(OauthTestSuite)) | ||||||
|  | } | ||||||
|  | @ -24,31 +24,31 @@ import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/go-pg/pg/v10" | 	"github.com/gotosocial/gotosocial/internal/db" | ||||||
| 	"github.com/gotosocial/oauth2/v4" | 	"github.com/gotosocial/oauth2/v4" | ||||||
| 	"github.com/gotosocial/oauth2/v4/models" | 	"github.com/gotosocial/oauth2/v4/models" | ||||||
| 	"github.com/sirupsen/logrus" | 	"github.com/sirupsen/logrus" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // pgTokenStore is an implementation of oauth2.TokenStore, which uses Postgres as a storage backend. | // tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend. | ||||||
| type pgTokenStore struct { | type tokenStore struct { | ||||||
| 	oauth2.TokenStore | 	oauth2.TokenStore | ||||||
| 	conn *pg.DB | 	db  db.DB | ||||||
| 	log  *logrus.Logger | 	log *logrus.Logger | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // NewPGTokenStore returns a token store, using postgres, that satisfies the oauth2.TokenStore interface. | // newTokenStore returns a token store that satisfies the oauth2.TokenStore interface. | ||||||
| // | // | ||||||
| // In order to allow tokens to 'expire' (not really a thing in Postgres world), it will also set off a | // In order to allow tokens to 'expire', it will also set off a goroutine that iterates through | ||||||
| // goroutine that iterates through the tokens in the DB once per minute and deletes any that have expired. | // the tokens in the DB once per minute and deletes any that have expired. | ||||||
| func NewPGTokenStore(ctx context.Context, conn *pg.DB, log *logrus.Logger) oauth2.TokenStore { | func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.TokenStore { | ||||||
| 	pts := &pgTokenStore{ | 	pts := &tokenStore{ | ||||||
| 		conn: conn, | 		db:  db, | ||||||
| 		log:  log, | 		log: log, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// set the token store to clean out expired tokens once per minute, or return if we're done | 	// set the token store to clean out expired tokens once per minute, or return if we're done | ||||||
| 	go func(ctx context.Context, pts *pgTokenStore, log *logrus.Logger) { | 	go func(ctx context.Context, pts *tokenStore, log *logrus.Logger) { | ||||||
| 	cleanloop: | 	cleanloop: | ||||||
| 		for { | 		for { | ||||||
| 			select { | 			select { | ||||||
|  | @ -67,22 +67,22 @@ func NewPGTokenStore(ctx context.Context, conn *pg.DB, log *logrus.Logger) oauth | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so. | // sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so. | ||||||
| func (pts *pgTokenStore) sweep() error { | func (pts *tokenStore) sweep() error { | ||||||
| 	// select *all* tokens from the db | 	// select *all* tokens from the db | ||||||
| 	// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. | 	// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. | ||||||
| 	var tokens []oauthToken | 	tokens := new([]*oauthToken) | ||||||
| 	if err := pts.conn.Model(&tokens).Select(); err != nil { | 	if err := pts.db.GetAll(tokens); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// iterate through and remove expired tokens | 	// iterate through and remove expired tokens | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
| 	for _, pgt := range tokens { | 	for _, pgt := range *tokens { | ||||||
| 		// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So: | 		// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So: | ||||||
| 		// we only want to check if a token expired before now if the expiry time is *not zero*; | 		// we only want to check if a token expired before now if the expiry time is *not zero*; | ||||||
| 		// ie., if it's been explicity set. | 		// ie., if it's been explicity set. | ||||||
| 		if !pgt.CodeExpiresAt.IsZero() && pgt.CodeExpiresAt.Before(now) || !pgt.RefreshExpiresAt.IsZero() && pgt.RefreshExpiresAt.Before(now) || !pgt.AccessExpiresAt.IsZero() && pgt.AccessExpiresAt.Before(now) { | 		if !pgt.CodeExpiresAt.IsZero() && pgt.CodeExpiresAt.Before(now) || !pgt.RefreshExpiresAt.IsZero() && pgt.RefreshExpiresAt.Before(now) || !pgt.AccessExpiresAt.IsZero() && pgt.AccessExpiresAt.Before(now) { | ||||||
| 			if _, err := pts.conn.Model(&pgt).Delete(); err != nil { | 			if err := pts.db.DeleteByID(pgt.ID, &pgt); err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | @ -93,68 +93,61 @@ func (pts *pgTokenStore) sweep() error { | ||||||
| 
 | 
 | ||||||
| // Create creates and store the new token information. | // Create creates and store the new token information. | ||||||
| // For the original implementation, see https://github.com/gotosocial/oauth2/blob/master/store/token.go#L34 | // For the original implementation, see https://github.com/gotosocial/oauth2/blob/master/store/token.go#L34 | ||||||
| func (pts *pgTokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { | func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { | ||||||
| 	t, ok := info.(*models.Token) | 	t, ok := info.(*models.Token) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return errors.New("info param was not a models.Token") | 		return errors.New("info param was not a models.Token") | ||||||
| 	} | 	} | ||||||
| 	_, err := pts.conn.WithContext(ctx).Model(oauthTokenToPGToken(t)).Insert() | 	if err := pts.db.Put(oauthTokenToPGToken(t)); err != nil { | ||||||
| 	if err != nil { |  | ||||||
| 		return fmt.Errorf("error in tokenstore create: %s", err) | 		return fmt.Errorf("error in tokenstore create: %s", err) | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RemoveByCode deletes a token from the DB based on the Code field | // RemoveByCode deletes a token from the DB based on the Code field | ||||||
| func (pts *pgTokenStore) RemoveByCode(ctx context.Context, code string) error { | func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error { | ||||||
| 	_, err := pts.conn.Model(&oauthToken{}).Where("code = ?", code).Delete() | 	return pts.db.DeleteWhere("code", code, &oauthToken{}) | ||||||
| 	if err != nil { |  | ||||||
| 		return fmt.Errorf("error in tokenstore removebycode: %s", err) |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RemoveByAccess deletes a token from the DB based on the Access field | // RemoveByAccess deletes a token from the DB based on the Access field | ||||||
| func (pts *pgTokenStore) RemoveByAccess(ctx context.Context, access string) error { | func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { | ||||||
| 	_, err := pts.conn.Model(&oauthToken{}).Where("access = ?", access).Delete() | 	return pts.db.DeleteWhere("access", access, &oauthToken{}) | ||||||
| 	if err != nil { |  | ||||||
| 		return fmt.Errorf("error in tokenstore removebyaccess: %s", err) |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RemoveByRefresh deletes a token from the DB based on the Refresh field | // RemoveByRefresh deletes a token from the DB based on the Refresh field | ||||||
| func (pts *pgTokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { | func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { | ||||||
| 	_, err := pts.conn.Model(&oauthToken{}).Where("refresh = ?", refresh).Delete() | 	return pts.db.DeleteWhere("refresh", refresh, &oauthToken{}) | ||||||
| 	if err != nil { |  | ||||||
| 		return fmt.Errorf("error in tokenstore removebyrefresh: %s", err) |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetByCode selects a token from the DB based on the Code field | // GetByCode selects a token from the DB based on the Code field | ||||||
| func (pts *pgTokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { | func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { | ||||||
| 	pgt := &oauthToken{} | 	pgt := &oauthToken{ | ||||||
| 	if err := pts.conn.Model(pgt).Where("code = ?", code).Select(); err != nil { | 		Code: code, | ||||||
| 		return nil, fmt.Errorf("error in tokenstore getbycode: %s", err) | 	} | ||||||
|  | 	if err := pts.db.GetWhere("code", code, pgt); err != nil { | ||||||
|  | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	return pgTokenToOauthToken(pgt), nil | 	return pgTokenToOauthToken(pgt), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetByAccess selects a token from the DB based on the Access field | // GetByAccess selects a token from the DB based on the Access field | ||||||
| func (pts *pgTokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { | func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { | ||||||
| 	pgt := &oauthToken{} | 	pgt := &oauthToken{ | ||||||
| 	if err := pts.conn.Model(pgt).Where("access = ?", access).Select(); err != nil { | 		Access: access, | ||||||
| 		return nil, fmt.Errorf("error in tokenstore getbyaccess: %s", err) | 	} | ||||||
|  | 	if err := pts.db.GetWhere("access", access, pgt); err != nil { | ||||||
|  | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	return pgTokenToOauthToken(pgt), nil | 	return pgTokenToOauthToken(pgt), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetByRefresh selects a token from the DB based on the Refresh field | // GetByRefresh selects a token from the DB based on the Refresh field | ||||||
| func (pts *pgTokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { | func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { | ||||||
| 	pgt := &oauthToken{} | 	pgt := &oauthToken{ | ||||||
| 	if err := pts.conn.Model(pgt).Where("refresh = ?", refresh).Select(); err != nil { | 		Refresh: refresh, | ||||||
| 		return nil, fmt.Errorf("error in tokenstore getbyrefresh: %s", err) | 	} | ||||||
|  | 	if err := pts.db.GetWhere("refresh", refresh, pgt); err != nil { | ||||||
|  | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	return pgTokenToOauthToken(pgt), nil | 	return pgTokenToOauthToken(pgt), nil | ||||||
| } | } | ||||||
|  | @ -174,6 +167,7 @@ func (pts *pgTokenStore) GetByRefresh(ctx context.Context, refresh string) (oaut | ||||||
| // As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken | // As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken | ||||||
| // and pgTokenToOauthToken can be used for that. | // and pgTokenToOauthToken can be used for that. | ||||||
| type oauthToken struct { | type oauthToken struct { | ||||||
|  | 	ID                  string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"` | ||||||
| 	ClientID            string | 	ClientID            string | ||||||
| 	UserID              string | 	UserID              string | ||||||
| 	RedirectURI         string | 	RedirectURI         string | ||||||
|  | @ -1,446 +0,0 @@ | ||||||
| /* |  | ||||||
|    GoToSocial |  | ||||||
|    Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org |  | ||||||
| 
 |  | ||||||
|    This program is free software: you can redistribute it and/or modify |  | ||||||
|    it under the terms of the GNU Affero General Public License as published by |  | ||||||
|    the Free Software Foundation, either version 3 of the License, or |  | ||||||
|    (at your option) any later version. |  | ||||||
| 
 |  | ||||||
|    This program is distributed in the hope that it will be useful, |  | ||||||
|    but WITHOUT ANY WARRANTY; without even the implied warranty of |  | ||||||
|    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the |  | ||||||
|    GNU Affero General Public License for more details. |  | ||||||
| 
 |  | ||||||
|    You should have received a copy of the GNU Affero General Public License |  | ||||||
|    along with this program.  If not, see <http://www.gnu.org/licenses/>. |  | ||||||
| */ |  | ||||||
| 
 |  | ||||||
| package oauth |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 	"net/http" |  | ||||||
| 	"net/url" |  | ||||||
| 
 |  | ||||||
| 	"github.com/gin-contrib/sessions" |  | ||||||
| 	"github.com/gin-gonic/gin" |  | ||||||
| 	"github.com/go-pg/pg/v10" |  | ||||||
| 	"github.com/google/uuid" |  | ||||||
| 	"github.com/gotosocial/gotosocial/internal/api" |  | ||||||
| 	"github.com/gotosocial/gotosocial/internal/gtsmodel" |  | ||||||
| 	"github.com/gotosocial/gotosocial/pkg/mastotypes" |  | ||||||
| 	"github.com/gotosocial/oauth2/v4" |  | ||||||
| 	"github.com/gotosocial/oauth2/v4/errors" |  | ||||||
| 	"github.com/gotosocial/oauth2/v4/manage" |  | ||||||
| 	"github.com/gotosocial/oauth2/v4/server" |  | ||||||
| 	"github.com/sirupsen/logrus" |  | ||||||
| 	"golang.org/x/crypto/bcrypt" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| type API struct { |  | ||||||
| 	manager *manage.Manager |  | ||||||
| 	server  *server.Server |  | ||||||
| 	conn    *pg.DB |  | ||||||
| 	log     *logrus.Logger |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type login struct { |  | ||||||
| 	Email    string `form:"username"` |  | ||||||
| 	Password string `form:"password"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type code struct { |  | ||||||
| 	Code string `form:"code"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func New(ts oauth2.TokenStore, cs oauth2.ClientStore, conn *pg.DB, log *logrus.Logger) *API { |  | ||||||
| 	manager := manage.NewDefaultManager() |  | ||||||
| 	manager.MapTokenStorage(ts) |  | ||||||
| 	manager.MapClientStorage(cs) |  | ||||||
| 	manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) |  | ||||||
| 	sc := &server.Config{ |  | ||||||
| 		TokenType: "Bearer", |  | ||||||
| 		// Must follow the spec. |  | ||||||
| 		AllowGetAccessRequest: false, |  | ||||||
| 		// Support only the non-implicit flow. |  | ||||||
| 		AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code}, |  | ||||||
| 		// Allow: |  | ||||||
| 		// - Authorization Code (for first & third parties) |  | ||||||
| 		// - Refreshing Tokens |  | ||||||
| 		// |  | ||||||
| 		// Deny: |  | ||||||
| 		// - Resource owner secrets (password grant) |  | ||||||
| 		// - Client secrets |  | ||||||
| 		AllowedGrantTypes: []oauth2.GrantType{ |  | ||||||
| 			oauth2.AuthorizationCode, |  | ||||||
| 			oauth2.Refreshing, |  | ||||||
| 		}, |  | ||||||
| 		AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{ |  | ||||||
| 			oauth2.CodeChallengePlain, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	srv := server.NewServer(sc, manager) |  | ||||||
| 	srv.SetInternalErrorHandler(func(err error) *errors.Response { |  | ||||||
| 		log.Errorf("internal oauth error: %s", err) |  | ||||||
| 		return nil |  | ||||||
| 	}) |  | ||||||
| 
 |  | ||||||
| 	srv.SetResponseErrorHandler(func(re *errors.Response) { |  | ||||||
| 		log.Errorf("internal response error: %s", re.Error) |  | ||||||
| 	}) |  | ||||||
| 
 |  | ||||||
| 	api := &API{ |  | ||||||
| 		manager: manager, |  | ||||||
| 		server:  srv, |  | ||||||
| 		conn:    conn, |  | ||||||
| 		log:     log, |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	api.server.SetUserAuthorizationHandler(api.UserAuthorizationHandler) |  | ||||||
| 	api.server.SetClientInfoHandler(server.ClientFormHandler) |  | ||||||
| 	return api |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (a *API) AddRoutes(s api.Server) error { |  | ||||||
| 	s.AttachHandler(http.MethodPost, "/api/v1/apps", a.AppsPOSTHandler) |  | ||||||
| 
 |  | ||||||
| 	s.AttachHandler(http.MethodGet, "/auth/sign_in", a.SignInGETHandler) |  | ||||||
| 	s.AttachHandler(http.MethodPost, "/auth/sign_in", a.SignInPOSTHandler) |  | ||||||
| 
 |  | ||||||
| 	s.AttachHandler(http.MethodPost, "/oauth/token", a.TokenPOSTHandler) |  | ||||||
| 
 |  | ||||||
| 	s.AttachHandler(http.MethodGet, "/oauth/authorize", a.AuthorizeGETHandler) |  | ||||||
| 	s.AttachHandler(http.MethodPost, "/oauth/authorize", a.AuthorizePOSTHandler) |  | ||||||
| 
 |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func incorrectPassword() (string, error) { |  | ||||||
| 	return "", errors.New("password/email combination was incorrect") |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /* |  | ||||||
| 	MAIN HANDLERS -- serve these through a server/router |  | ||||||
| */ |  | ||||||
| 
 |  | ||||||
| // AppsPOSTHandler should be served at https://example.org/api/v1/apps |  | ||||||
| // It is equivalent to: https://docs.joinmastodon.org/methods/apps/ |  | ||||||
| func (a *API) AppsPOSTHandler(c *gin.Context) { |  | ||||||
| 	l := a.log.WithField("func", "AppsPOSTHandler") |  | ||||||
| 	l.Trace("entering AppsPOSTHandler") |  | ||||||
| 
 |  | ||||||
| 	form := &mastotypes.ApplicationPOSTRequest{} |  | ||||||
| 	if err := c.ShouldBind(form); err != nil { |  | ||||||
| 		c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()}) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// permitted length for most fields |  | ||||||
| 	permittedLength := 64 |  | ||||||
| 	// redirect can be a bit bigger because we probably need to encode data in the redirect uri |  | ||||||
| 	permittedRedirect := 256 |  | ||||||
| 
 |  | ||||||
| 	// check lengths of fields before proceeding so the user can't spam huge entries into the database |  | ||||||
| 	if len(form.ClientName) > permittedLength { |  | ||||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)}) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	if len(form.Website) > permittedLength { |  | ||||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)}) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	if len(form.RedirectURIs) > permittedRedirect { |  | ||||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)}) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	if len(form.Scopes) > permittedLength { |  | ||||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)}) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// set default 'read' for scopes if it's not set |  | ||||||
| 	var scopes string |  | ||||||
| 	if form.Scopes == "" { |  | ||||||
| 		scopes = "read" |  | ||||||
| 	} else { |  | ||||||
| 		scopes = form.Scopes |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// generate new IDs for this application and its associated client |  | ||||||
| 	clientID := uuid.NewString() |  | ||||||
| 	clientSecret := uuid.NewString() |  | ||||||
| 	vapidKey := uuid.NewString() |  | ||||||
| 
 |  | ||||||
| 	// generate the application to put in the database |  | ||||||
| 	app := >smodel.Application{ |  | ||||||
| 		Name:         form.ClientName, |  | ||||||
| 		Website:      form.Website, |  | ||||||
| 		RedirectURI:  form.RedirectURIs, |  | ||||||
| 		ClientID:     clientID, |  | ||||||
| 		ClientSecret: clientSecret, |  | ||||||
| 		Scopes:       scopes, |  | ||||||
| 		VapidKey:     vapidKey, |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// chuck it in the db |  | ||||||
| 	if _, err := a.conn.Model(app).Insert(); err != nil { |  | ||||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// now we need to model an oauth client from the application that the oauth library can use |  | ||||||
| 	oc := &oauthClient{ |  | ||||||
| 		ID:     clientID, |  | ||||||
| 		Secret: clientSecret, |  | ||||||
| 		Domain: form.RedirectURIs, |  | ||||||
| 		UserID: "", // This client isn't yet associated with a specific user,  it's just an app client right now |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// chuck it in the db |  | ||||||
| 	if _, err := a.conn.Model(oc).Insert(); err != nil { |  | ||||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/ |  | ||||||
| 	c.JSON(http.StatusOK, app) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // SignInGETHandler should be served at https://example.org/auth/sign_in. |  | ||||||
| // The idea is to present a sign in page to the user, where they can enter their username and password. |  | ||||||
| // The form will then POST to the sign in page, which will be handled by SignInPOSTHandler |  | ||||||
| func (a *API) SignInGETHandler(c *gin.Context) { |  | ||||||
| 	a.log.WithField("func", "SignInGETHandler").Trace("serving sign in html") |  | ||||||
| 	c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{}) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // SignInPOSTHandler should be served at https://example.org/auth/sign_in. |  | ||||||
| // The idea is to present a sign in page to the user, where they can enter their username and password. |  | ||||||
| // The handler will then redirect to the auth handler served at /auth |  | ||||||
| func (a *API) SignInPOSTHandler(c *gin.Context) { |  | ||||||
| 	l := a.log.WithField("func", "SignInPOSTHandler") |  | ||||||
| 	s := sessions.Default(c) |  | ||||||
| 	form := &login{} |  | ||||||
| 	if err := c.ShouldBind(form); err != nil { |  | ||||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	l.Tracef("parsed form: %+v", form) |  | ||||||
| 
 |  | ||||||
| 	userid, err := a.ValidatePassword(form.Email, form.Password) |  | ||||||
| 	if err != nil { |  | ||||||
| 		c.String(http.StatusForbidden, err.Error()) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	s.Set("username", userid) |  | ||||||
| 	if err := s.Save(); err != nil { |  | ||||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	l.Trace("redirecting to auth page") |  | ||||||
| 	c.Redirect(http.StatusFound, "/oauth/authorize") |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // TokenPOSTHandler should be served as a POST at https://example.org/oauth/token |  | ||||||
| // The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs. |  | ||||||
| // See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token |  | ||||||
| func (a *API) TokenPOSTHandler(c *gin.Context) { |  | ||||||
| 	l := a.log.WithField("func", "TokenHandler") |  | ||||||
| 	l.Trace("entered token handler, will now go to server.HandleTokenRequest") |  | ||||||
| 	if err := a.server.HandleTokenRequest(c.Writer, c.Request); err != nil { |  | ||||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // AuthorizeGETHandler should be served as GET at https://example.org/oauth/authorize |  | ||||||
| // The idea here is to present an oauth authorize page to the user, with a button |  | ||||||
| // that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user |  | ||||||
| func (a *API) AuthorizeGETHandler(c *gin.Context) { |  | ||||||
| 	l := a.log.WithField("func", "AuthorizeGETHandler") |  | ||||||
| 	s := sessions.Default(c) |  | ||||||
| 
 |  | ||||||
| 	// Username will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow |  | ||||||
| 	// If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page. |  | ||||||
| 	v := s.Get("username") |  | ||||||
| 	if username, ok := v.(string); !ok || username == "" { |  | ||||||
| 		l.Trace("username was empty, parsing form then redirecting to sign in page") |  | ||||||
| 
 |  | ||||||
| 		// first make sure they've filled out the authorize form with the required values |  | ||||||
| 		form := &mastotypes.OAuthAuthorize{} |  | ||||||
| 		if err := c.ShouldBind(form); err != nil { |  | ||||||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 		l.Tracef("parsed form: %+v", form) |  | ||||||
| 
 |  | ||||||
| 		// these fields are *required* so check 'em |  | ||||||
| 		if form.ResponseType == "" || form.ClientID == "" || form.RedirectURI == "" { |  | ||||||
| 			c.JSON(http.StatusBadRequest, gin.H{"error": "missing one of: response_type, client_id or redirect_uri"}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// save these values from the form so we can use them elsewhere in the session |  | ||||||
| 		s.Set("force_login", form.ForceLogin) |  | ||||||
| 		s.Set("response_type", form.ResponseType) |  | ||||||
| 		s.Set("client_id", form.ClientID) |  | ||||||
| 		s.Set("redirect_uri", form.RedirectURI) |  | ||||||
| 		s.Set("scope", form.Scope) |  | ||||||
| 		if err := s.Save(); err != nil { |  | ||||||
| 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// send them to the sign in page so we can tell who they are |  | ||||||
| 		c.Redirect(http.StatusFound, "/auth/sign_in") |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Check if we have a code already. If we do, it means the user used urn:ietf:wg:oauth:2.0:oob as their redirect URI |  | ||||||
| 	// and were sent here, which means they just want the code displayed so they can use it out of band. |  | ||||||
| 	code := &code{} |  | ||||||
| 	if err := c.Bind(code); err != nil { |  | ||||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// the authorize template will either: |  | ||||||
| 	// 1. Display the code to the user if they're already authorized and were redirected here because they selected urn:ietf:wg:oauth:2.0:oob. |  | ||||||
| 	// 2. Display a form where they can get some information about the app that's trying to authorize, and approve it, which will then go to AuthorizePOSTHandler |  | ||||||
| 	l.Trace("serving authorize html") |  | ||||||
| 	c.HTML(http.StatusOK, "authorize.tmpl", gin.H{ |  | ||||||
| 		"code": code.Code, |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // AuthorizePOSTHandler should be served as POST at https://example.org/oauth/authorize |  | ||||||
| // The idea here is to present an oauth authorize page to the user, with a button |  | ||||||
| // that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user |  | ||||||
| func (a *API) AuthorizePOSTHandler(c *gin.Context) { |  | ||||||
| 	l := a.log.WithField("func", "AuthorizePOSTHandler") |  | ||||||
| 	s := sessions.Default(c) |  | ||||||
| 
 |  | ||||||
| 	v := s.Get("username") |  | ||||||
| 	if username, ok := v.(string); !ok || username == "" { |  | ||||||
| 		c.JSON(http.StatusUnauthorized, gin.H{"error": "you are not signed in"}) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	values := url.Values{} |  | ||||||
| 
 |  | ||||||
| 	if v, ok := s.Get("force_login").(string); !ok { |  | ||||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": "session missing force_login"}) |  | ||||||
| 		return |  | ||||||
| 	} else { |  | ||||||
| 		values.Add("force_login", v) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if v, ok := s.Get("response_type").(string); !ok { |  | ||||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": "session missing response_type"}) |  | ||||||
| 		return |  | ||||||
| 	} else { |  | ||||||
| 		values.Add("response_type", v) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if v, ok := s.Get("client_id").(string); !ok { |  | ||||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": "session missing client_id"}) |  | ||||||
| 		return |  | ||||||
| 	} else { |  | ||||||
| 		values.Add("client_id", v) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if v, ok := s.Get("redirect_uri").(string); !ok { |  | ||||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": "session missing redirect_uri"}) |  | ||||||
| 		return |  | ||||||
| 	} else { |  | ||||||
| 		// todo: explain this little hack |  | ||||||
| 		if v == "urn:ietf:wg:oauth:2.0:oob" { |  | ||||||
| 			v = "http://localhost:8080/oauth/authorize" |  | ||||||
| 		} |  | ||||||
| 		values.Add("redirect_uri", v) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if v, ok := s.Get("scope").(string); !ok { |  | ||||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": "session missing scope"}) |  | ||||||
| 		return |  | ||||||
| 	} else { |  | ||||||
| 		values.Add("scope", v) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if v, ok := s.Get("username").(string); !ok { |  | ||||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": "session missing username"}) |  | ||||||
| 		return |  | ||||||
| 	} else { |  | ||||||
| 		values.Add("username", v) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	c.Request.Form = values |  | ||||||
| 	l.Tracef("values on request set to %+v", c.Request.Form) |  | ||||||
| 
 |  | ||||||
| 	if err := s.Save(); err != nil { |  | ||||||
| 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if err := a.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil { |  | ||||||
| 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /* |  | ||||||
| 	SUB-HANDLERS -- don't serve these directly, they should be attached to the oauth2 server |  | ||||||
| */ |  | ||||||
| 
 |  | ||||||
| // PasswordAuthorizationHandler takes a username (in this case, we use an email address) |  | ||||||
| // and a password. The goal is to authenticate the password against the one for that email |  | ||||||
| // address stored in the database. If OK, we return the userid (a uuid) for that user, |  | ||||||
| // so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db. |  | ||||||
| func (a *API) ValidatePassword(email string, password string) (userid string, err error) { |  | ||||||
| 	l := a.log.WithField("func", "PasswordAuthorizationHandler") |  | ||||||
| 
 |  | ||||||
| 	// make sure an email/password was provided and bail if not |  | ||||||
| 	if email == "" || password == "" { |  | ||||||
| 		l.Debug("email or password was not provided") |  | ||||||
| 		return incorrectPassword() |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// first we select the user from the database based on email address, bail if no user found for that email |  | ||||||
| 	gtsUser := >smodel.User{} |  | ||||||
| 	if err := a.conn.Model(gtsUser).Where("email = ?", email).Select(); err != nil { |  | ||||||
| 		l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err) |  | ||||||
| 		return incorrectPassword() |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// make sure a password is actually set and bail if not |  | ||||||
| 	if gtsUser.EncryptedPassword == "" { |  | ||||||
| 		l.Warnf("encrypted password for user %s was empty for some reason", gtsUser.Email) |  | ||||||
| 		return incorrectPassword() |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// compare the provided password with the encrypted one from the db, bail if they don't match |  | ||||||
| 	if err := bcrypt.CompareHashAndPassword([]byte(gtsUser.EncryptedPassword), []byte(password)); err != nil { |  | ||||||
| 		l.Debugf("password hash didn't match for user %s during login attempt: %s", gtsUser.Email, err) |  | ||||||
| 		return incorrectPassword() |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// If we've made it this far the email/password is correct, so we can just return the id of the user. |  | ||||||
| 	userid = gtsUser.ID |  | ||||||
| 	l.Tracef("returning (%s, %s)", userid, err) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // UserAuthorizationHandler gets the user's ID from the 'username' field of the request form, |  | ||||||
| // or redirects to the /auth/sign_in page, if this key is not present. |  | ||||||
| func (a *API) UserAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) { |  | ||||||
| 	l := a.log.WithField("func", "UserAuthorizationHandler") |  | ||||||
| 	userID = r.FormValue("username") |  | ||||||
| 	if userID == "" { |  | ||||||
| 		l.Trace("username was empty, redirecting to sign in page") |  | ||||||
| 		http.Redirect(w, r, "/auth/sign_in", http.StatusFound) |  | ||||||
| 		return "", nil |  | ||||||
| 	} |  | ||||||
| 	l.Tracef("returning (%s, %s)", userID, err) |  | ||||||
| 	return userID, err |  | ||||||
| } |  | ||||||
|  | @ -1,133 +0,0 @@ | ||||||
| package oauth |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"context" |  | ||||||
| 	"fmt" |  | ||||||
| 	"testing" |  | ||||||
| 	"time" |  | ||||||
| 
 |  | ||||||
| 	"github.com/go-pg/pg/v10" |  | ||||||
| 	"github.com/go-pg/pg/v10/orm" |  | ||||||
| 	"github.com/gotosocial/gotosocial/internal/api" |  | ||||||
| 	"github.com/gotosocial/gotosocial/internal/config" |  | ||||||
| 	"github.com/gotosocial/gotosocial/internal/gtsmodel" |  | ||||||
| 	"github.com/gotosocial/oauth2/v4" |  | ||||||
| 	"github.com/sirupsen/logrus" |  | ||||||
| 	"github.com/stretchr/testify/suite" |  | ||||||
| 	"golang.org/x/crypto/bcrypt" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| type OauthTestSuite struct { |  | ||||||
| 	suite.Suite |  | ||||||
| 	tokenStore  oauth2.TokenStore |  | ||||||
| 	clientStore oauth2.ClientStore |  | ||||||
| 	conn        *pg.DB |  | ||||||
| 	testAccount *gtsmodel.Account |  | ||||||
| 	testUser    *gtsmodel.User |  | ||||||
| 	testClient  *oauthClient |  | ||||||
| 	config      *config.Config |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| const () |  | ||||||
| 
 |  | ||||||
| // SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout |  | ||||||
| func (suite *OauthTestSuite) SetupSuite() { |  | ||||||
| 	encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("test-password"), bcrypt.DefaultCost) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logrus.Panicf("error encrypting user pass: %s", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	suite.testAccount = >smodel.Account{} |  | ||||||
| 	suite.testUser = >smodel.User{ |  | ||||||
| 		EncryptedPassword: string(encryptedPassword), |  | ||||||
| 		Email:             "user@localhost", |  | ||||||
| 		AccountID:         "some-account-id-it-doesn't-matter-really-since-this-user-doesn't-actually-have-an-account!", |  | ||||||
| 	} |  | ||||||
| 	suite.testClient = &oauthClient{ |  | ||||||
| 		ID:     "a-known-client-id", |  | ||||||
| 		Secret: "some-secret", |  | ||||||
| 		Domain: "http://localhost:8080", |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// because go tests are run within the test package directory, we need to fiddle with the templateconfig |  | ||||||
| 	// basedir in a way that we wouldn't normally have to do when running the binary, in order to make |  | ||||||
| 	// the templates actually load |  | ||||||
| 	c := config.Empty() |  | ||||||
| 	c.TemplateConfig.BaseDir = "../../web/template/" |  | ||||||
| 	suite.config = c |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // SetupTest creates a postgres connection and creates the oauth_clients table before each test |  | ||||||
| func (suite *OauthTestSuite) SetupTest() { |  | ||||||
| 	suite.conn = pg.Connect(&pg.Options{}) |  | ||||||
| 	if err := suite.conn.Ping(context.Background()); err != nil { |  | ||||||
| 		logrus.Panicf("db connection error: %s", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	models := []interface{}{ |  | ||||||
| 		&oauthClient{}, |  | ||||||
| 		&oauthToken{}, |  | ||||||
| 		>smodel.User{}, |  | ||||||
| 		>smodel.Account{}, |  | ||||||
| 		>smodel.Application{}, |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for _, m := range models { |  | ||||||
| 		if err := suite.conn.Model(m).CreateTable(&orm.CreateTableOptions{ |  | ||||||
| 			IfNotExists: true, |  | ||||||
| 		}); err != nil { |  | ||||||
| 			logrus.Panicf("db connection error: %s", err) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	suite.tokenStore = NewPGTokenStore(context.Background(), suite.conn, logrus.New()) |  | ||||||
| 	suite.clientStore = NewPGClientStore(suite.conn) |  | ||||||
| 
 |  | ||||||
| 	if _, err := suite.conn.Model(suite.testUser).Insert(); err != nil { |  | ||||||
| 		logrus.Panicf("could not insert test user into db: %s", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if _, err := suite.conn.Model(suite.testClient).Insert(); err != nil { |  | ||||||
| 		logrus.Panicf("could not insert test client into db: %s", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // TearDownTest drops the oauth_clients table and closes the pg connection after each test |  | ||||||
| func (suite *OauthTestSuite) TearDownTest() { |  | ||||||
| 	models := []interface{}{ |  | ||||||
| 		&oauthClient{}, |  | ||||||
| 		&oauthToken{}, |  | ||||||
| 		>smodel.User{}, |  | ||||||
| 		>smodel.Account{}, |  | ||||||
| 		>smodel.Application{}, |  | ||||||
| 	} |  | ||||||
| 	for _, m := range models { |  | ||||||
| 		if err := suite.conn.Model(m).DropTable(&orm.DropTableOptions{}); err != nil { |  | ||||||
| 			logrus.Panicf("drop table error: %s", err) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	if err := suite.conn.Close(); err != nil { |  | ||||||
| 		logrus.Panicf("error closing db connection: %s", err) |  | ||||||
| 	} |  | ||||||
| 	suite.conn = nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (suite *OauthTestSuite) TestAPIInitialize() { |  | ||||||
| 	log := logrus.New() |  | ||||||
| 	log.SetLevel(logrus.TraceLevel) |  | ||||||
| 
 |  | ||||||
| 	r := api.New(suite.config, log) |  | ||||||
| 	api := New(suite.tokenStore, suite.clientStore, suite.conn, log) |  | ||||||
| 	if err := api.AddRoutes(r); err != nil { |  | ||||||
| 		suite.FailNow(fmt.Sprintf("error initializing api: %s", err)) |  | ||||||
| 	} |  | ||||||
| 	go r.Start() |  | ||||||
| 	time.Sleep(30 * time.Second) |  | ||||||
| 	// http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=https://example.org |  | ||||||
| 	// http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=urn:ietf:wg:oauth:2.0:oob |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestOauthTestSuite(t *testing.T) { |  | ||||||
| 	suite.Run(t, new(OauthTestSuite)) |  | ||||||
| } |  | ||||||
							
								
								
									
										120
									
								
								internal/router/router.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										120
									
								
								internal/router/router.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,120 @@ | ||||||
|  | /* | ||||||
|  |    GoToSocial | ||||||
|  |    Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org | ||||||
|  | 
 | ||||||
|  |    This program is free software: you can redistribute it and/or modify | ||||||
|  |    it under the terms of the GNU Affero General Public License as published by | ||||||
|  |    the Free Software Foundation, either version 3 of the License, or | ||||||
|  |    (at your option) any later version. | ||||||
|  | 
 | ||||||
|  |    This program is distributed in the hope that it will be useful, | ||||||
|  |    but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  |    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | ||||||
|  |    GNU Affero General Public License for more details. | ||||||
|  | 
 | ||||||
|  |    You should have received a copy of the GNU Affero General Public License | ||||||
|  |    along with this program.  If not, see <http://www.gnu.org/licenses/>. | ||||||
|  | */ | ||||||
|  | 
 | ||||||
|  | package router | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"fmt" | ||||||
|  | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
|  | 
 | ||||||
|  | 	"github.com/gin-contrib/sessions" | ||||||
|  | 	"github.com/gin-contrib/sessions/memstore" | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/gotosocial/gotosocial/internal/config" | ||||||
|  | 	"github.com/sirupsen/logrus" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Router provides the REST interface for gotosocial, using gin. | ||||||
|  | type Router interface { | ||||||
|  | 	// Attach a gin handler to the router with the given method and path | ||||||
|  | 	AttachHandler(method string, path string, handler gin.HandlerFunc) | ||||||
|  | 	// Attach a gin middleware to the router that will be used globally | ||||||
|  | 	AttachMiddleware(handler gin.HandlerFunc) | ||||||
|  | 	// Start the router | ||||||
|  | 	Start() | ||||||
|  | 	// Stop the router | ||||||
|  | 	Stop() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // router fulfils the Router interface using gin and logrus | ||||||
|  | type router struct { | ||||||
|  | 	logger *logrus.Logger | ||||||
|  | 	engine *gin.Engine | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Start starts the router nicely | ||||||
|  | func (s *router) Start() { | ||||||
|  | 	// todo: start gracefully | ||||||
|  | 	if err := s.engine.Run(); err != nil { | ||||||
|  | 		s.logger.Panicf("server error: %s", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Stop shuts down the router nicely | ||||||
|  | func (s *router) Stop() { | ||||||
|  | 	// todo: shut down gracefully | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // AttachHandler attaches the given gin.HandlerFunc to the router with the specified method and path. | ||||||
|  | // If the path is set to ANY, then the handlerfunc will be used for ALL methods at its given path. | ||||||
|  | func (s *router) AttachHandler(method string, path string, handler gin.HandlerFunc) { | ||||||
|  | 	if method == "ANY" { | ||||||
|  | 		s.engine.Any(path, handler) | ||||||
|  | 	} else { | ||||||
|  | 		s.engine.Handle(method, path, handler) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // AttachMiddleware attaches a gin middleware to the router that will be used globally | ||||||
|  | func (s *router) AttachMiddleware(middleware gin.HandlerFunc) { | ||||||
|  | 	s.engine.Use(middleware) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // New returns a new Router with the specified configuration, using the given logrus logger. | ||||||
|  | func New(config *config.Config, logger *logrus.Logger) (Router, error) { | ||||||
|  | 	engine := gin.New() | ||||||
|  | 
 | ||||||
|  | 	// create a new session store middleware | ||||||
|  | 	store, err := sessionStore() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("error creating session store: %s", err) | ||||||
|  | 	} | ||||||
|  | 	engine.Use(sessions.Sessions("gotosocial-session", store)) | ||||||
|  | 
 | ||||||
|  | 	// load html templates for use by the router | ||||||
|  | 	cwd, err := os.Getwd() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("error getting current working directory: %s", err) | ||||||
|  | 	} | ||||||
|  | 	tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", config.TemplateConfig.BaseDir)) | ||||||
|  | 	logger.Debugf("loading templates from %s", tmPath) | ||||||
|  | 	engine.LoadHTMLGlob(tmPath) | ||||||
|  | 
 | ||||||
|  | 	return &router{ | ||||||
|  | 		logger: logger, | ||||||
|  | 		engine: engine, | ||||||
|  | 	}, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // sessionStore returns a new session store with a random auth and encryption key. | ||||||
|  | // This means that cookies using the store will be reset if gotosocial is restarted! | ||||||
|  | func sessionStore() (memstore.Store, error) { | ||||||
|  | 	auth := make([]byte, 32) | ||||||
|  | 	crypt := make([]byte, 32) | ||||||
|  | 
 | ||||||
|  | 	if _, err := rand.Read(auth); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if _, err := rand.Read(crypt); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return memstore.NewStore(auth, crypt), nil | ||||||
|  | } | ||||||
|  | @ -2,7 +2,7 @@ | ||||||
| <html lang="en"> | <html lang="en"> | ||||||
|   <head> |   <head> | ||||||
|     <meta charset="UTF-8" /> |     <meta charset="UTF-8" /> | ||||||
|     <title>Auth</title> |     <title>GoToSocial Authorization</title> | ||||||
|     <link |     <link | ||||||
|       rel="stylesheet" |       rel="stylesheet" | ||||||
|       href="//maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" |       href="//maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" | ||||||
|  | @ -11,13 +11,13 @@ | ||||||
|     <script src="//maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script> |     <script src="//maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script> | ||||||
|   </head> |   </head> | ||||||
| 
 | 
 | ||||||
| {{if len .code | eq 0 }} |  | ||||||
|   <body> |   <body> | ||||||
|     <div class="container"> |     <div class="container"> | ||||||
|       <div class="jumbotron"> |       <div class="jumbotron"> | ||||||
|         <form action="/oauth/authorize" method="POST"> |         <form action="/oauth/authorize" method="POST"> | ||||||
|           <h1>Authorize</h1> |           <h1>Hi {{.user}}!</h1> | ||||||
|           <p>The client would like to perform actions on your behalf.</p> |           <p>Application <b>{{.appname}}</b> {{if len .appwebsite | eq 0 | not}}({{.appwebsite}}) {{end}}would like to perform actions on your behalf, with scope <em>{{.scope}}</em>.</p> | ||||||
|  |           <p>The application will redirect to {{.redirect}} to continue.</p> | ||||||
|           <p> |           <p> | ||||||
|             <button |             <button | ||||||
|               type="submit" |               type="submit" | ||||||
|  | @ -31,14 +31,4 @@ | ||||||
|       </div> |       </div> | ||||||
|     </div> |     </div> | ||||||
|   </body> |   </body> | ||||||
| {{else}} |  | ||||||
|   <body> |  | ||||||
|     <div class="container"> |  | ||||||
|       <div class="jumbotron"> |  | ||||||
|         {{.code}} |  | ||||||
|       </div> |  | ||||||
|     </div> |  | ||||||
|   </body> |  | ||||||
| {{end}} |  | ||||||
| 
 |  | ||||||
| </html> | </html> | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue