mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-12-17 01:13:00 -06:00
chunking away at it
This commit is contained in:
parent
0a244be523
commit
f58f77bf1f
23 changed files with 860 additions and 394 deletions
|
|
@ -21,6 +21,7 @@ package db
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/go-fed/activity/pub"
|
"github.com/go-fed/activity/pub"
|
||||||
|
|
@ -145,6 +146,10 @@ type DB interface {
|
||||||
// C) something went wrong in the db
|
// C) something went wrong in the db
|
||||||
IsEmailAvailable(email string) error
|
IsEmailAvailable(email string) error
|
||||||
|
|
||||||
|
// NewSignup creates a new user in the database with the given parameters, with an *unconfirmed* email address.
|
||||||
|
// By the time this function is called, it should be assumed that all the parameters have passed validation!
|
||||||
|
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string) (*model.User, error)
|
||||||
|
|
||||||
/*
|
/*
|
||||||
USEFUL CONVERSION FUNCTIONS
|
USEFUL CONVERSION FUNCTIONS
|
||||||
*/
|
*/
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@
|
||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rsa"
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
@ -82,6 +83,8 @@ type Account struct {
|
||||||
SubscriptionExpiresAt time.Time `pg:"type:timestamp"`
|
SubscriptionExpiresAt time.Time `pg:"type:timestamp"`
|
||||||
// Does this account identify itself as a bot?
|
// Does this account identify itself as a bot?
|
||||||
Bot bool
|
Bot bool
|
||||||
|
// What reason was given for signing up when this account was created?
|
||||||
|
Reason string
|
||||||
|
|
||||||
/*
|
/*
|
||||||
PRIVACY SETTINGS
|
PRIVACY SETTINGS
|
||||||
|
|
@ -123,9 +126,9 @@ type Account struct {
|
||||||
|
|
||||||
Secret string
|
Secret string
|
||||||
// Privatekey for validating activitypub requests, will obviously only be defined for local accounts
|
// Privatekey for validating activitypub requests, will obviously only be defined for local accounts
|
||||||
PrivateKey string
|
PrivateKey *rsa.PrivateKey
|
||||||
// Publickey for encoding activitypub requests, will be defined for both local and remote accounts
|
// Publickey for encoding activitypub requests, will be defined for both local and remote accounts
|
||||||
PublicKey string
|
PublicKey *rsa.PublicKey
|
||||||
|
|
||||||
/*
|
/*
|
||||||
ADMIN FIELDS
|
ADMIN FIELDS
|
||||||
|
|
|
||||||
|
|
@ -35,13 +35,13 @@ type DomainBlock struct {
|
||||||
// Account ID of the creator of this block
|
// Account ID of the creator of this block
|
||||||
CreatedByAccountID string `pg:",notnull"`
|
CreatedByAccountID string `pg:",notnull"`
|
||||||
// TODO: define this
|
// TODO: define this
|
||||||
Severity int
|
Severity int
|
||||||
// Reject media from this domain?
|
// Reject media from this domain?
|
||||||
RejectMedia bool
|
RejectMedia bool
|
||||||
// Reject reports from this domain?
|
// Reject reports from this domain?
|
||||||
RejectReports bool
|
RejectReports bool
|
||||||
// Private comment on this block, viewable to admins
|
// Private comment on this block, viewable to admins
|
||||||
PrivateComment string
|
PrivateComment string
|
||||||
// Public comment on this block, viewable (optionally) by everyone
|
// Public comment on this block, viewable (optionally) by everyone
|
||||||
PublicComment string
|
PublicComment string
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,8 +20,11 @@ package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/mail"
|
"net/mail"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
@ -35,6 +38,7 @@ import (
|
||||||
"github.com/gotosocial/gotosocial/internal/db/model"
|
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||||
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// postgresService satisfies the DB interface
|
// postgresService satisfies the DB interface
|
||||||
|
|
@ -305,7 +309,6 @@ func (ps *postgresService) GetAccountByUserID(userID string, account *model.Acco
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
|
if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
|
||||||
fmt.Println(account)
|
|
||||||
if err == pg.ErrNoRows {
|
if err == pg.ErrNoRows {
|
||||||
return ErrNoEntries{}
|
return ErrNoEntries{}
|
||||||
}
|
}
|
||||||
|
|
@ -412,6 +415,43 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string) (*model.User, error) {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
ps.log.Errorf("error creating new rsa key: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
a := &model.Account{
|
||||||
|
Username: username,
|
||||||
|
DisplayName: username,
|
||||||
|
Reason: reason,
|
||||||
|
PrivateKey: key,
|
||||||
|
PublicKey: &key.PublicKey,
|
||||||
|
ActorType: "Person",
|
||||||
|
}
|
||||||
|
if _, err = ps.conn.Model(a).Insert(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error hashing password: %s", err)
|
||||||
|
}
|
||||||
|
u := &model.User{
|
||||||
|
AccountID: a.ID,
|
||||||
|
EncryptedPassword: string(pw),
|
||||||
|
SignUpIP: signUpIP,
|
||||||
|
Locale: locale,
|
||||||
|
UnconfirmedEmail: email,
|
||||||
|
}
|
||||||
|
if _, err = ps.conn.Model(u).Insert(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
CONVERSION FUNCTIONS
|
CONVERSION FUNCTIONS
|
||||||
*/
|
*/
|
||||||
|
|
@ -433,7 +473,6 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype
|
||||||
}
|
}
|
||||||
fields = append(fields, mField)
|
fields = append(fields, mField)
|
||||||
}
|
}
|
||||||
fmt.Printf("fields: %+v", fields)
|
|
||||||
|
|
||||||
// count followers
|
// count followers
|
||||||
followers := []model.Follow{}
|
followers := []model.Follow{}
|
||||||
|
|
|
||||||
21
internal/db/pg_test.go
Normal file
21
internal/db/pg_test.go
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
/*
|
||||||
|
GoToSocial
|
||||||
|
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package db
|
||||||
|
|
||||||
|
// TODO: write tests for postgres
|
||||||
|
|
@ -1,3 +1,21 @@
|
||||||
|
/*
|
||||||
|
GoToSocial
|
||||||
|
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
21
internal/db/pgfed_test.go
Normal file
21
internal/db/pgfed_test.go
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
/*
|
||||||
|
GoToSocial
|
||||||
|
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package db
|
||||||
|
|
||||||
|
// TODO: write tests for pgfed
|
||||||
|
|
@ -19,6 +19,8 @@
|
||||||
package account
|
package account
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
@ -26,9 +28,10 @@ import (
|
||||||
"github.com/gotosocial/gotosocial/internal/db"
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
"github.com/gotosocial/gotosocial/internal/db/model"
|
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||||
"github.com/gotosocial/gotosocial/internal/module"
|
"github.com/gotosocial/gotosocial/internal/module"
|
||||||
"github.com/gotosocial/gotosocial/internal/module/oauth"
|
"github.com/gotosocial/gotosocial/internal/oauth"
|
||||||
"github.com/gotosocial/gotosocial/internal/router"
|
"github.com/gotosocial/gotosocial/internal/router"
|
||||||
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||||
|
"github.com/gotosocial/oauth2/v4"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -39,9 +42,10 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
type accountModule struct {
|
type accountModule struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
db db.DB
|
db db.DB
|
||||||
log *logrus.Logger
|
oauthServer oauth.Server
|
||||||
|
log *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a new account module
|
// New returns a new account module
|
||||||
|
|
@ -60,15 +64,15 @@ func (m *accountModule) Route(r router.Router) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// accountCreatePOSTHandler handles create account requests, validates them,
|
||||||
|
// and puts them in the database if they're valid.
|
||||||
|
// It should be served as a POST at /api/v1/accounts
|
||||||
func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) {
|
func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) {
|
||||||
l := m.log.WithField("func", "AccountCreatePOSTHandler")
|
l := m.log.WithField("func", "accountCreatePOSTHandler")
|
||||||
// TODO: check whether a valid app token has been presented!!
|
authed, err := oauth.GetAuthed(c)
|
||||||
// See: https://docs.joinmastodon.org/methods/accounts/
|
if err != nil {
|
||||||
|
l.Debugf("couldn't auth: %s", err)
|
||||||
l.Trace("checking if registration is open")
|
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||||
if !m.config.AccountsConfig.OpenRegistration {
|
|
||||||
l.Debug("account registration is closed, returning error to client")
|
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "account registration is closed"})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -81,15 +85,34 @@ func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
l.Tracef("validating form %+v", form)
|
l.Tracef("validating form %+v", form)
|
||||||
if err := validateCreateAccount(form, m.config.AccountsConfig.ReasonRequired, m.db); err != nil {
|
if err := validateCreateAccount(form, m.config.AccountsConfig, m.db); err != nil {
|
||||||
l.Debugf("error validating form: %s", err)
|
l.Debugf("error validating form: %s", err)
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
clientIP := c.ClientIP()
|
||||||
|
l.Tracef("attempting to parse client ip address %s", clientIP)
|
||||||
|
signUpIP := net.ParseIP(clientIP)
|
||||||
|
if signUpIP == nil {
|
||||||
|
l.Debugf("error validating sign up ip address %s", clientIP)
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "ip address could not be parsed from request"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ti, err := m.accountCreate(form, signUpIP, authed.Token, authed.Application)
|
||||||
|
if err != nil {
|
||||||
|
l.Errorf("internal server error while creating new account: %s", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, ti)
|
||||||
}
|
}
|
||||||
|
|
||||||
// accountVerifyGETHandler serves a user's account details to them IF they reached this
|
// accountVerifyGETHandler serves a user's account details to them IF they reached this
|
||||||
// handler while in possession of a valid token, according to the oauth middleware.
|
// handler while in possession of a valid token, according to the oauth middleware.
|
||||||
|
// It should be served as a GET at /api/v1/accounts/verify_credentials
|
||||||
func (m *accountModule) accountVerifyGETHandler(c *gin.Context) {
|
func (m *accountModule) accountVerifyGETHandler(c *gin.Context) {
|
||||||
l := m.log.WithField("func", "AccountVerifyGETHandler")
|
l := m.log.WithField("func", "AccountVerifyGETHandler")
|
||||||
|
|
||||||
|
|
@ -120,3 +143,39 @@ func (m *accountModule) accountVerifyGETHandler(c *gin.Context) {
|
||||||
l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive)
|
l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive)
|
||||||
c.JSON(http.StatusOK, acctSensitive)
|
c.JSON(http.StatusOK, acctSensitive)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
HELPER FUNCTIONS
|
||||||
|
*/
|
||||||
|
|
||||||
|
// accountCreate does the dirty work of making an account and user in the database.
|
||||||
|
// It then returns a token to the caller, for use with the new account, as per the
|
||||||
|
// spec here: https://docs.joinmastodon.org/methods/accounts/
|
||||||
|
func (m *accountModule) accountCreate(form *mastotypes.AccountCreateRequest, signUpIP net.IP, token oauth2.TokenInfo, app *model.Application) (*mastotypes.Token, error) {
|
||||||
|
l := m.log.WithField("func", "accountCreate")
|
||||||
|
|
||||||
|
// don't store a reason if we don't require one
|
||||||
|
reason := form.Reason
|
||||||
|
if !m.config.AccountsConfig.ReasonRequired {
|
||||||
|
reason = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Trace("creating new username and account")
|
||||||
|
user, err := m.db.NewSignup(form.Username, reason, m.config.AccountsConfig.RequireApproval, form.Email, form.Password, signUpIP, form.Locale)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error creating new signup in the database: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Tracef("generating a token for user %s with account %s and application %s", user.ID, user.AccountID, app.ID)
|
||||||
|
ti, err := m.oauthServer.GenerateUserAccessToken(token, app.ClientSecret, user.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error creating new access token for user %s: %s", user.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &mastotypes.Token{
|
||||||
|
AccessToken: ti.GetCode(),
|
||||||
|
TokenType: "Bearer",
|
||||||
|
Scope: ti.GetScope(),
|
||||||
|
CreatedAt: ti.GetCodeCreateAt().Unix(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,34 +20,33 @@ package account
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"net/http/httptest"
|
||||||
"net/url"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/gotosocial/gotosocial/internal/config"
|
"github.com/gotosocial/gotosocial/internal/config"
|
||||||
"github.com/gotosocial/gotosocial/internal/db"
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
"github.com/gotosocial/gotosocial/internal/db/model"
|
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||||
"github.com/gotosocial/gotosocial/internal/module/oauth"
|
|
||||||
"github.com/gotosocial/gotosocial/internal/router"
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type AccountTestSuite struct {
|
type AccountTestSuite struct {
|
||||||
suite.Suite
|
suite.Suite
|
||||||
db db.DB
|
log *logrus.Logger
|
||||||
testAccountLocal *model.Account
|
testAccountLocal *model.Account
|
||||||
testAccountRemote *model.Account
|
testAccountRemote *model.Account
|
||||||
testUser *model.User
|
testUser *model.User
|
||||||
config *config.Config
|
db db.DB
|
||||||
|
accountModule *accountModule
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
|
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
|
||||||
func (suite *AccountTestSuite) SetupSuite() {
|
func (suite *AccountTestSuite) SetupSuite() {
|
||||||
|
log := logrus.New()
|
||||||
|
log.SetLevel(logrus.TraceLevel)
|
||||||
|
suite.log = log
|
||||||
|
|
||||||
c := config.Empty()
|
c := config.Empty()
|
||||||
c.DBConfig = &config.DBConfig{
|
c.DBConfig = &config.DBConfig{
|
||||||
Type: "postgres",
|
Type: "postgres",
|
||||||
|
|
@ -58,118 +57,126 @@ func (suite *AccountTestSuite) SetupSuite() {
|
||||||
Database: "postgres",
|
Database: "postgres",
|
||||||
ApplicationName: "gotosocial",
|
ApplicationName: "gotosocial",
|
||||||
}
|
}
|
||||||
suite.config = c
|
|
||||||
|
|
||||||
encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
|
database, err := db.New(context.Background(), c, log)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Panicf("error encrypting user pass: %s", err)
|
suite.FailNow(err.Error())
|
||||||
|
}
|
||||||
|
suite.db = database
|
||||||
|
|
||||||
|
suite.accountModule = &accountModule{
|
||||||
|
config: c,
|
||||||
|
db: database,
|
||||||
|
log: log,
|
||||||
}
|
}
|
||||||
|
|
||||||
localAvatar, err := url.Parse("https://localhost:8080/media/aaaaaaaaa.png")
|
// encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
logrus.Panicf("error parsing localavatar url: %s", err)
|
// logrus.Panicf("error encrypting user pass: %s", err)
|
||||||
}
|
// }
|
||||||
localHeader, err := url.Parse("https://localhost:8080/media/ffffffffff.png")
|
|
||||||
if err != nil {
|
|
||||||
logrus.Panicf("error parsing localheader url: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
acctID := uuid.NewString()
|
// localAvatar, err := url.Parse("https://localhost:8080/media/aaaaaaaaa.png")
|
||||||
suite.testAccountLocal = &model.Account{
|
// if err != nil {
|
||||||
ID: acctID,
|
// logrus.Panicf("error parsing localavatar url: %s", err)
|
||||||
Username: "local_account_of_some_kind",
|
// }
|
||||||
AvatarRemoteURL: localAvatar,
|
// localHeader, err := url.Parse("https://localhost:8080/media/ffffffffff.png")
|
||||||
HeaderRemoteURL: localHeader,
|
// if err != nil {
|
||||||
DisplayName: "michael caine",
|
// logrus.Panicf("error parsing localheader url: %s", err)
|
||||||
Fields: []model.Field{
|
// }
|
||||||
{
|
|
||||||
Name: "come and ave a go",
|
|
||||||
Value: "if you think you're hard enough",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "website",
|
|
||||||
Value: "https://imdb.com",
|
|
||||||
VerifiedAt: time.Now(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Note: "My name is Michael Caine and i'm a local user.",
|
|
||||||
Discoverable: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
avatarURL, err := url.Parse("http://example.org/accounts/avatars/000/207/122/original/089-1098-09.png")
|
// acctID := uuid.NewString()
|
||||||
if err != nil {
|
// suite.testAccountLocal = &model.Account{
|
||||||
logrus.Panicf("error parsing avatarURL: %s", err)
|
// ID: acctID,
|
||||||
}
|
// Username: "local_account_of_some_kind",
|
||||||
|
// AvatarRemoteURL: localAvatar,
|
||||||
|
// HeaderRemoteURL: localHeader,
|
||||||
|
// DisplayName: "michael caine",
|
||||||
|
// Fields: []model.Field{
|
||||||
|
// {
|
||||||
|
// Name: "come and ave a go",
|
||||||
|
// Value: "if you think you're hard enough",
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// Name: "website",
|
||||||
|
// Value: "https://imdb.com",
|
||||||
|
// VerifiedAt: time.Now(),
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// Note: "My name is Michael Caine and i'm a local user.",
|
||||||
|
// Discoverable: true,
|
||||||
|
// }
|
||||||
|
|
||||||
headerURL, err := url.Parse("http://example.org/accounts/headers/000/207/122/original/111111111111.png")
|
// avatarURL, err := url.Parse("http://example.org/accounts/avatars/000/207/122/original/089-1098-09.png")
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
logrus.Panicf("error parsing avatarURL: %s", err)
|
// logrus.Panicf("error parsing avatarURL: %s", err)
|
||||||
}
|
// }
|
||||||
suite.testAccountRemote = &model.Account{
|
|
||||||
ID: uuid.NewString(),
|
|
||||||
Username: "neato_bombeato",
|
|
||||||
Domain: "example.org",
|
|
||||||
|
|
||||||
AvatarFileName: "avatar.png",
|
// headerURL, err := url.Parse("http://example.org/accounts/headers/000/207/122/original/111111111111.png")
|
||||||
AvatarContentType: "image/png",
|
// if err != nil {
|
||||||
AvatarFileSize: 1024,
|
// logrus.Panicf("error parsing avatarURL: %s", err)
|
||||||
AvatarUpdatedAt: time.Now(),
|
// }
|
||||||
AvatarRemoteURL: avatarURL,
|
// suite.testAccountRemote = &model.Account{
|
||||||
|
// ID: uuid.NewString(),
|
||||||
|
// Username: "neato_bombeato",
|
||||||
|
// Domain: "example.org",
|
||||||
|
|
||||||
HeaderFileName: "avatar.png",
|
// AvatarFileName: "avatar.png",
|
||||||
HeaderContentType: "image/png",
|
// AvatarContentType: "image/png",
|
||||||
HeaderFileSize: 1024,
|
// AvatarFileSize: 1024,
|
||||||
HeaderUpdatedAt: time.Now(),
|
// AvatarUpdatedAt: time.Now(),
|
||||||
HeaderRemoteURL: headerURL,
|
// AvatarRemoteURL: avatarURL,
|
||||||
|
|
||||||
DisplayName: "one cool dude 420",
|
// HeaderFileName: "avatar.png",
|
||||||
Fields: []model.Field{
|
// HeaderContentType: "image/png",
|
||||||
{
|
// HeaderFileSize: 1024,
|
||||||
Name: "pronouns",
|
// HeaderUpdatedAt: time.Now(),
|
||||||
Value: "he/they",
|
// HeaderRemoteURL: headerURL,
|
||||||
},
|
|
||||||
{
|
// DisplayName: "one cool dude 420",
|
||||||
Name: "website",
|
// Fields: []model.Field{
|
||||||
Value: "https://imcool.edu",
|
// {
|
||||||
VerifiedAt: time.Now(),
|
// Name: "pronouns",
|
||||||
},
|
// Value: "he/they",
|
||||||
},
|
// },
|
||||||
Note: "<p>I'm cool as heck!</p>",
|
// {
|
||||||
Discoverable: true,
|
// Name: "website",
|
||||||
URI: "https://example.org/users/neato_bombeato",
|
// Value: "https://imcool.edu",
|
||||||
URL: "https://example.org/@neato_bombeato",
|
// VerifiedAt: time.Now(),
|
||||||
LastWebfingeredAt: time.Now(),
|
// },
|
||||||
InboxURL: "https://example.org/users/neato_bombeato/inbox",
|
// },
|
||||||
OutboxURL: "https://example.org/users/neato_bombeato/outbox",
|
// Note: "<p>I'm cool as heck!</p>",
|
||||||
SharedInboxURL: "https://example.org/inbox",
|
// Discoverable: true,
|
||||||
FollowersURL: "https://example.org/users/neato_bombeato/followers",
|
// URI: "https://example.org/users/neato_bombeato",
|
||||||
FeaturedCollectionURL: "https://example.org/users/neato_bombeato/collections/featured",
|
// URL: "https://example.org/@neato_bombeato",
|
||||||
}
|
// LastWebfingeredAt: time.Now(),
|
||||||
suite.testUser = &model.User{
|
// InboxURL: "https://example.org/users/neato_bombeato/inbox",
|
||||||
ID: uuid.NewString(),
|
// OutboxURL: "https://example.org/users/neato_bombeato/outbox",
|
||||||
EncryptedPassword: string(encryptedPassword),
|
// SharedInboxURL: "https://example.org/inbox",
|
||||||
Email: "user@example.org",
|
// FollowersURL: "https://example.org/users/neato_bombeato/followers",
|
||||||
AccountID: acctID,
|
// FeaturedCollectionURL: "https://example.org/users/neato_bombeato/collections/featured",
|
||||||
|
// }
|
||||||
|
// suite.testUser = &model.User{
|
||||||
|
// ID: uuid.NewString(),
|
||||||
|
// EncryptedPassword: string(encryptedPassword),
|
||||||
|
// Email: "user@example.org",
|
||||||
|
// AccountID: acctID,
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *AccountTestSuite) TearDownSuite() {
|
||||||
|
if err := suite.db.Stop(context.Background()); err != nil {
|
||||||
|
logrus.Panicf("error closing db connection: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupTest creates a postgres connection and creates the oauth_clients table before each test
|
// SetupTest creates a db connection and creates necessary tables before each test
|
||||||
func (suite *AccountTestSuite) SetupTest() {
|
func (suite *AccountTestSuite) SetupTest() {
|
||||||
|
|
||||||
log := logrus.New()
|
|
||||||
log.SetLevel(logrus.TraceLevel)
|
|
||||||
db, err := db.New(context.Background(), suite.config, log)
|
|
||||||
if err != nil {
|
|
||||||
logrus.Panicf("error creating database connection: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
suite.db = db
|
|
||||||
|
|
||||||
models := []interface{}{
|
models := []interface{}{
|
||||||
&model.User{},
|
&model.User{},
|
||||||
&model.Account{},
|
&model.Account{},
|
||||||
&model.Follow{},
|
&model.Follow{},
|
||||||
&model.Status{},
|
&model.Status{},
|
||||||
|
&model.Application{},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, m := range models {
|
for _, m := range models {
|
||||||
|
|
@ -177,70 +184,31 @@ func (suite *AccountTestSuite) SetupTest() {
|
||||||
logrus.Panicf("db connection error: %s", err)
|
logrus.Panicf("db connection error: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := suite.db.Put(suite.testAccountLocal); err != nil {
|
|
||||||
logrus.Panicf("could not insert test account into db: %s", err)
|
|
||||||
}
|
|
||||||
if err := suite.db.Put(suite.testUser); err != nil {
|
|
||||||
logrus.Panicf("could not insert test user into db: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
|
// TearDownTest drops tables to make sure there's no data in the db
|
||||||
func (suite *AccountTestSuite) TearDownTest() {
|
func (suite *AccountTestSuite) TearDownTest() {
|
||||||
models := []interface{}{
|
models := []interface{}{
|
||||||
&model.User{},
|
&model.User{},
|
||||||
&model.Account{},
|
&model.Account{},
|
||||||
&model.Follow{},
|
&model.Follow{},
|
||||||
&model.Status{},
|
&model.Status{},
|
||||||
|
&model.Application{},
|
||||||
}
|
}
|
||||||
for _, m := range models {
|
for _, m := range models {
|
||||||
if err := suite.db.DropTable(m); err != nil {
|
if err := suite.db.DropTable(m); err != nil {
|
||||||
logrus.Panicf("error dropping table: %s", err)
|
logrus.Panicf("error dropping table: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := suite.db.Stop(context.Background()); err != nil {
|
|
||||||
logrus.Panicf("error closing db connection: %s", err)
|
|
||||||
}
|
|
||||||
suite.db = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *AccountTestSuite) TestAPIInitialize() {
|
func (suite *AccountTestSuite) TestAccountCreatePOSTHandler() {
|
||||||
log := logrus.New()
|
// TODO: figure out how to test this properly
|
||||||
log.SetLevel(logrus.TraceLevel)
|
recorder := httptest.NewRecorder()
|
||||||
|
recorder.Header().Set("X-Forwarded-For", "127.0.0.1")
|
||||||
r, err := router.New(suite.config, log)
|
ctx, _ := gin.CreateTestContext(recorder)
|
||||||
if err != nil {
|
// ctx.Set()
|
||||||
suite.FailNow(fmt.Sprintf("error creating router: %s", err))
|
suite.accountModule.accountCreatePOSTHandler(ctx)
|
||||||
}
|
|
||||||
|
|
||||||
r.AttachMiddleware(func(c *gin.Context) {
|
|
||||||
account := &model.Account{}
|
|
||||||
if err := suite.db.GetAccountByUserID(suite.testUser.ID, account); err != nil || account == nil {
|
|
||||||
suite.T().Log(err)
|
|
||||||
suite.FailNowf("no account found for user %s, continuing with unauthenticated request: %+v", "", suite.testUser.ID, account)
|
|
||||||
fmt.Println(account)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Set(oauth.SessionAuthorizedAccount, account)
|
|
||||||
c.Set(oauth.SessionAuthorizedUser, suite.testUser.ID)
|
|
||||||
})
|
|
||||||
|
|
||||||
acct := New(suite.config, suite.db, log)
|
|
||||||
if err := acct.Route(r); err != nil {
|
|
||||||
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Start()
|
|
||||||
defer func() {
|
|
||||||
if err := r.Stop(context.Background()); err != nil {
|
|
||||||
panic(fmt.Errorf("error stopping router: %s", err))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
time.Sleep(10 * time.Second)
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountTestSuite(t *testing.T) {
|
func TestAccountTestSuite(t *testing.T) {
|
||||||
|
|
|
||||||
|
|
@ -21,12 +21,17 @@ package account
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
|
"github.com/gotosocial/gotosocial/internal/config"
|
||||||
"github.com/gotosocial/gotosocial/internal/db"
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
"github.com/gotosocial/gotosocial/internal/util"
|
"github.com/gotosocial/gotosocial/internal/util"
|
||||||
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
func validateCreateAccount(form *mastotypes.AccountCreateRequest, reasonRequired bool, database db.DB) error {
|
func validateCreateAccount(form *mastotypes.AccountCreateRequest, c *config.AccountsConfig, database db.DB) error {
|
||||||
|
if !c.OpenRegistration {
|
||||||
|
return errors.New("registration is not open for this server")
|
||||||
|
}
|
||||||
|
|
||||||
if err := util.ValidateSignUpUsername(form.Username); err != nil {
|
if err := util.ValidateSignUpUsername(form.Username); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -47,7 +52,7 @@ func validateCreateAccount(form *mastotypes.AccountCreateRequest, reasonRequired
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := util.ValidateSignUpReason(form.Reason, reasonRequired); err != nil {
|
if err := util.ValidateSignUpReason(form.Reason, c.ReasonRequired); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
140
internal/module/app/app.go
Normal file
140
internal/module/app/app.go
Normal file
|
|
@ -0,0 +1,140 @@
|
||||||
|
/*
|
||||||
|
GoToSocial
|
||||||
|
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/module"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/oauth"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/router"
|
||||||
|
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const appsPath = "/api/v1/apps"
|
||||||
|
|
||||||
|
type appModule struct {
|
||||||
|
server oauth.Server
|
||||||
|
db db.DB
|
||||||
|
log *logrus.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a new auth module
|
||||||
|
func New(srv oauth.Server, db db.DB, log *logrus.Logger) module.ClientAPIModule {
|
||||||
|
return &appModule{
|
||||||
|
server: srv,
|
||||||
|
db: db,
|
||||||
|
log: log,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route satisfies the RESTAPIModule interface
|
||||||
|
func (m *appModule) Route(s router.Router) error {
|
||||||
|
s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// appsPOSTHandler should be served at https://example.org/api/v1/apps
|
||||||
|
// It is equivalent to: https://docs.joinmastodon.org/methods/apps/
|
||||||
|
func (m *appModule) appsPOSTHandler(c *gin.Context) {
|
||||||
|
l := m.log.WithField("func", "AppsPOSTHandler")
|
||||||
|
l.Trace("entering AppsPOSTHandler")
|
||||||
|
|
||||||
|
form := &mastotypes.ApplicationPOSTRequest{}
|
||||||
|
if err := c.ShouldBind(form); err != nil {
|
||||||
|
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// permitted length for most fields
|
||||||
|
permittedLength := 64
|
||||||
|
// redirect can be a bit bigger because we probably need to encode data in the redirect uri
|
||||||
|
permittedRedirect := 256
|
||||||
|
|
||||||
|
// check lengths of fields before proceeding so the user can't spam huge entries into the database
|
||||||
|
if len(form.ClientName) > permittedLength {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(form.Website) > permittedLength {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(form.RedirectURIs) > permittedRedirect {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(form.Scopes) > permittedLength {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/
|
||||||
|
var scopes string
|
||||||
|
if form.Scopes == "" {
|
||||||
|
scopes = "read"
|
||||||
|
} else {
|
||||||
|
scopes = form.Scopes
|
||||||
|
}
|
||||||
|
|
||||||
|
// generate new IDs for this application and its associated client
|
||||||
|
clientID := uuid.NewString()
|
||||||
|
clientSecret := uuid.NewString()
|
||||||
|
vapidKey := uuid.NewString()
|
||||||
|
|
||||||
|
// generate the application to put in the database
|
||||||
|
app := &model.Application{
|
||||||
|
Name: form.ClientName,
|
||||||
|
Website: form.Website,
|
||||||
|
RedirectURI: form.RedirectURIs,
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
Scopes: scopes,
|
||||||
|
VapidKey: vapidKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
// chuck it in the db
|
||||||
|
if err := m.db.Put(app); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// now we need to model an oauth client from the application that the oauth library can use
|
||||||
|
oc := &oauth.Client{
|
||||||
|
ID: clientID,
|
||||||
|
Secret: clientSecret,
|
||||||
|
Domain: form.RedirectURIs,
|
||||||
|
UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now
|
||||||
|
}
|
||||||
|
|
||||||
|
// chuck it in the db
|
||||||
|
if err := m.db.Put(oc); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/
|
||||||
|
c.JSON(http.StatusOK, app.ToMasto())
|
||||||
|
}
|
||||||
21
internal/module/app/app_test.go
Normal file
21
internal/module/app/app_test.go
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
/*
|
||||||
|
GoToSocial
|
||||||
|
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package app
|
||||||
|
|
||||||
|
// TODO: write tests
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
# oauth
|
# auth
|
||||||
|
|
||||||
This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) functionality to the GoToSocial client API.
|
This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) functionality to the GoToSocial client API.
|
||||||
|
|
||||||
|
|
@ -16,57 +16,42 @@
|
||||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Package oauth is a module that provides oauth functionality to a router.
|
// Package auth is a module that provides oauth functionality to a router.
|
||||||
// It adds the following paths:
|
// It adds the following paths:
|
||||||
// /api/v1/apps
|
|
||||||
// /auth/sign_in
|
// /auth/sign_in
|
||||||
// /oauth/token
|
// /oauth/token
|
||||||
// /oauth/authorize
|
// /oauth/authorize
|
||||||
// It also includes the oauthTokenMiddleware, which can be attached to a router to authenticate every request by Bearer token.
|
// It also includes the oauthTokenMiddleware, which can be attached to a router to authenticate every request by Bearer token.
|
||||||
package oauth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/gin-contrib/sessions"
|
"github.com/gin-contrib/sessions"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/gotosocial/gotosocial/internal/db"
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
"github.com/gotosocial/gotosocial/internal/db/model"
|
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||||
"github.com/gotosocial/gotosocial/internal/module"
|
"github.com/gotosocial/gotosocial/internal/module"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/oauth"
|
||||||
"github.com/gotosocial/gotosocial/internal/router"
|
"github.com/gotosocial/gotosocial/internal/router"
|
||||||
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||||
"github.com/gotosocial/oauth2/v4"
|
|
||||||
"github.com/gotosocial/oauth2/v4/errors"
|
|
||||||
"github.com/gotosocial/oauth2/v4/manage"
|
|
||||||
"github.com/gotosocial/oauth2/v4/server"
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
appsPath = "/api/v1/apps"
|
|
||||||
authSignInPath = "/auth/sign_in"
|
authSignInPath = "/auth/sign_in"
|
||||||
oauthTokenPath = "/oauth/token"
|
oauthTokenPath = "/oauth/token"
|
||||||
oauthAuthorizePath = "/oauth/authorize"
|
oauthAuthorizePath = "/oauth/authorize"
|
||||||
// SessionAuthorizedUser is the key set in the gin context for the id of
|
|
||||||
// a User who has successfully passed Bearer token authorization.
|
|
||||||
// The interface returned from grabbing this key should be parsed as a string.
|
|
||||||
SessionAuthorizedUser = "authorized_user"
|
|
||||||
// SessionAuthorizedAccount is the key set in the gin context for the Account
|
|
||||||
// of a User who has successfully passed Bearer token authorization.
|
|
||||||
// The interface returned from grabbing this key should be parsed as a *gtsmodel.Account
|
|
||||||
SessionAuthorizedAccount = "authorized_account"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// oauthModule is an oauth2 oauthModule that satisfies the ClientAPIModule interface
|
type authModule struct {
|
||||||
type oauthModule struct {
|
server oauth.Server
|
||||||
oauthManager *manage.Manager
|
db db.DB
|
||||||
oauthServer *server.Server
|
log *logrus.Logger
|
||||||
db db.DB
|
|
||||||
log *logrus.Logger
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type login struct {
|
type login struct {
|
||||||
|
|
@ -74,52 +59,17 @@ type login struct {
|
||||||
Password string `form:"password"`
|
Password string `form:"password"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a new oauth module
|
// New returns a new auth module
|
||||||
func New(ts oauth2.TokenStore, cs oauth2.ClientStore, db db.DB, log *logrus.Logger) module.ClientAPIModule {
|
func New(srv oauth.Server, db db.DB, log *logrus.Logger) module.ClientAPIModule {
|
||||||
manager := manage.NewDefaultManager()
|
return &authModule{
|
||||||
manager.MapTokenStorage(ts)
|
server: srv,
|
||||||
manager.MapClientStorage(cs)
|
db: db,
|
||||||
manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
|
log: log,
|
||||||
sc := &server.Config{
|
|
||||||
TokenType: "Bearer",
|
|
||||||
// Must follow the spec.
|
|
||||||
AllowGetAccessRequest: false,
|
|
||||||
// Support only the non-implicit flow.
|
|
||||||
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
|
|
||||||
// Allow:
|
|
||||||
// - Authorization Code (for first & third parties)
|
|
||||||
AllowedGrantTypes: []oauth2.GrantType{
|
|
||||||
oauth2.AuthorizationCode,
|
|
||||||
},
|
|
||||||
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
srv := server.NewServer(sc, manager)
|
|
||||||
srv.SetInternalErrorHandler(func(err error) *errors.Response {
|
|
||||||
log.Errorf("internal oauth error: %s", err)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
srv.SetResponseErrorHandler(func(re *errors.Response) {
|
|
||||||
log.Errorf("internal response error: %s", re.Error)
|
|
||||||
})
|
|
||||||
|
|
||||||
m := &oauthModule{
|
|
||||||
oauthManager: manager,
|
|
||||||
oauthServer: srv,
|
|
||||||
db: db,
|
|
||||||
log: log,
|
|
||||||
}
|
|
||||||
|
|
||||||
m.oauthServer.SetUserAuthorizationHandler(m.userAuthorizationHandler)
|
|
||||||
m.oauthServer.SetClientInfoHandler(server.ClientFormHandler)
|
|
||||||
return m
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Route satisfies the RESTAPIModule interface
|
// Route satisfies the RESTAPIModule interface
|
||||||
func (m *oauthModule) Route(s router.Router) error {
|
func (m *authModule) Route(s router.Router) error {
|
||||||
s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler)
|
|
||||||
|
|
||||||
s.AttachHandler(http.MethodGet, authSignInPath, m.signInGETHandler)
|
s.AttachHandler(http.MethodGet, authSignInPath, m.signInGETHandler)
|
||||||
s.AttachHandler(http.MethodPost, authSignInPath, m.signInPOSTHandler)
|
s.AttachHandler(http.MethodPost, authSignInPath, m.signInPOSTHandler)
|
||||||
|
|
||||||
|
|
@ -129,7 +79,6 @@ func (m *oauthModule) Route(s router.Router) error {
|
||||||
s.AttachHandler(http.MethodPost, oauthAuthorizePath, m.authorizePOSTHandler)
|
s.AttachHandler(http.MethodPost, oauthAuthorizePath, m.authorizePOSTHandler)
|
||||||
|
|
||||||
s.AttachMiddleware(m.oauthTokenMiddleware)
|
s.AttachMiddleware(m.oauthTokenMiddleware)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -137,93 +86,10 @@ func (m *oauthModule) Route(s router.Router) error {
|
||||||
MAIN HANDLERS -- serve these through a server/router
|
MAIN HANDLERS -- serve these through a server/router
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// appsPOSTHandler should be served at https://example.org/api/v1/apps
|
|
||||||
// It is equivalent to: https://docs.joinmastodon.org/methods/apps/
|
|
||||||
func (m *oauthModule) appsPOSTHandler(c *gin.Context) {
|
|
||||||
l := m.log.WithField("func", "AppsPOSTHandler")
|
|
||||||
l.Trace("entering AppsPOSTHandler")
|
|
||||||
|
|
||||||
form := &mastotypes.ApplicationPOSTRequest{}
|
|
||||||
if err := c.ShouldBind(form); err != nil {
|
|
||||||
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// permitted length for most fields
|
|
||||||
permittedLength := 64
|
|
||||||
// redirect can be a bit bigger because we probably need to encode data in the redirect uri
|
|
||||||
permittedRedirect := 256
|
|
||||||
|
|
||||||
// check lengths of fields before proceeding so the user can't spam huge entries into the database
|
|
||||||
if len(form.ClientName) > permittedLength {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(form.Website) > permittedLength {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(form.RedirectURIs) > permittedRedirect {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(form.Scopes) > permittedLength {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/
|
|
||||||
var scopes string
|
|
||||||
if form.Scopes == "" {
|
|
||||||
scopes = "read"
|
|
||||||
} else {
|
|
||||||
scopes = form.Scopes
|
|
||||||
}
|
|
||||||
|
|
||||||
// generate new IDs for this application and its associated client
|
|
||||||
clientID := uuid.NewString()
|
|
||||||
clientSecret := uuid.NewString()
|
|
||||||
vapidKey := uuid.NewString()
|
|
||||||
|
|
||||||
// generate the application to put in the database
|
|
||||||
app := &model.Application{
|
|
||||||
Name: form.ClientName,
|
|
||||||
Website: form.Website,
|
|
||||||
RedirectURI: form.RedirectURIs,
|
|
||||||
ClientID: clientID,
|
|
||||||
ClientSecret: clientSecret,
|
|
||||||
Scopes: scopes,
|
|
||||||
VapidKey: vapidKey,
|
|
||||||
}
|
|
||||||
|
|
||||||
// chuck it in the db
|
|
||||||
if err := m.db.Put(app); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// now we need to model an oauth client from the application that the oauth library can use
|
|
||||||
oc := &oauthClient{
|
|
||||||
ID: clientID,
|
|
||||||
Secret: clientSecret,
|
|
||||||
Domain: form.RedirectURIs,
|
|
||||||
UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now
|
|
||||||
}
|
|
||||||
|
|
||||||
// chuck it in the db
|
|
||||||
if err := m.db.Put(oc); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/
|
|
||||||
c.JSON(http.StatusOK, app.ToMasto())
|
|
||||||
}
|
|
||||||
|
|
||||||
// signInGETHandler should be served at https://example.org/auth/sign_in.
|
// signInGETHandler should be served at https://example.org/auth/sign_in.
|
||||||
// The idea is to present a sign in page to the user, where they can enter their username and password.
|
// The idea is to present a sign in page to the user, where they can enter their username and password.
|
||||||
// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler
|
// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler
|
||||||
func (m *oauthModule) signInGETHandler(c *gin.Context) {
|
func (m *authModule) signInGETHandler(c *gin.Context) {
|
||||||
m.log.WithField("func", "SignInGETHandler").Trace("serving sign in html")
|
m.log.WithField("func", "SignInGETHandler").Trace("serving sign in html")
|
||||||
c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{})
|
c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{})
|
||||||
}
|
}
|
||||||
|
|
@ -231,7 +97,7 @@ func (m *oauthModule) signInGETHandler(c *gin.Context) {
|
||||||
// signInPOSTHandler should be served at https://example.org/auth/sign_in.
|
// signInPOSTHandler should be served at https://example.org/auth/sign_in.
|
||||||
// The idea is to present a sign in page to the user, where they can enter their username and password.
|
// The idea is to present a sign in page to the user, where they can enter their username and password.
|
||||||
// The handler will then redirect to the auth handler served at /auth
|
// The handler will then redirect to the auth handler served at /auth
|
||||||
func (m *oauthModule) signInPOSTHandler(c *gin.Context) {
|
func (m *authModule) signInPOSTHandler(c *gin.Context) {
|
||||||
l := m.log.WithField("func", "SignInPOSTHandler")
|
l := m.log.WithField("func", "SignInPOSTHandler")
|
||||||
s := sessions.Default(c)
|
s := sessions.Default(c)
|
||||||
form := &login{}
|
form := &login{}
|
||||||
|
|
@ -260,10 +126,10 @@ func (m *oauthModule) signInPOSTHandler(c *gin.Context) {
|
||||||
// tokenPOSTHandler should be served as a POST at https://example.org/oauth/token
|
// tokenPOSTHandler should be served as a POST at https://example.org/oauth/token
|
||||||
// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs.
|
// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs.
|
||||||
// See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token
|
// See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token
|
||||||
func (m *oauthModule) tokenPOSTHandler(c *gin.Context) {
|
func (m *authModule) tokenPOSTHandler(c *gin.Context) {
|
||||||
l := m.log.WithField("func", "TokenPOSTHandler")
|
l := m.log.WithField("func", "TokenPOSTHandler")
|
||||||
l.Trace("entered TokenPOSTHandler")
|
l.Trace("entered TokenPOSTHandler")
|
||||||
if err := m.oauthServer.HandleTokenRequest(c.Writer, c.Request); err != nil {
|
if err := m.server.HandleTokenRequest(c.Writer, c.Request); err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -271,7 +137,7 @@ func (m *oauthModule) tokenPOSTHandler(c *gin.Context) {
|
||||||
// authorizeGETHandler should be served as GET at https://example.org/oauth/authorize
|
// authorizeGETHandler should be served as GET at https://example.org/oauth/authorize
|
||||||
// The idea here is to present an oauth authorize page to the user, with a button
|
// The idea here is to present an oauth authorize page to the user, with a button
|
||||||
// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
|
// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
|
||||||
func (m *oauthModule) authorizeGETHandler(c *gin.Context) {
|
func (m *authModule) authorizeGETHandler(c *gin.Context) {
|
||||||
l := m.log.WithField("func", "AuthorizeGETHandler")
|
l := m.log.WithField("func", "AuthorizeGETHandler")
|
||||||
s := sessions.Default(c)
|
s := sessions.Default(c)
|
||||||
|
|
||||||
|
|
@ -349,7 +215,7 @@ func (m *oauthModule) authorizeGETHandler(c *gin.Context) {
|
||||||
// At this point we assume that the user has A) logged in and B) accepted that the app should act for them,
|
// At this point we assume that the user has A) logged in and B) accepted that the app should act for them,
|
||||||
// so we should proceed with the authentication flow and generate an oauth token for them if we can.
|
// so we should proceed with the authentication flow and generate an oauth token for them if we can.
|
||||||
// See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
|
// See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
|
||||||
func (m *oauthModule) authorizePOSTHandler(c *gin.Context) {
|
func (m *authModule) authorizePOSTHandler(c *gin.Context) {
|
||||||
l := m.log.WithField("func", "AuthorizePOSTHandler")
|
l := m.log.WithField("func", "AuthorizePOSTHandler")
|
||||||
s := sessions.Default(c)
|
s := sessions.Default(c)
|
||||||
|
|
||||||
|
|
@ -404,7 +270,7 @@ func (m *oauthModule) authorizePOSTHandler(c *gin.Context) {
|
||||||
l.Tracef("values on request set to %+v", c.Request.Form)
|
l.Tracef("values on request set to %+v", c.Request.Form)
|
||||||
|
|
||||||
// and proceed with authorization using the oauth2 library
|
// and proceed with authorization using the oauth2 library
|
||||||
if err := m.oauthServer.HandleAuthorizeRequest(c.Writer, c.Request); err != nil {
|
if err := m.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -418,25 +284,50 @@ func (m *oauthModule) authorizePOSTHandler(c *gin.Context) {
|
||||||
// the request. Then, it will look up the account for that user, and set that in the request too.
|
// the request. Then, it will look up the account for that user, and set that in the request too.
|
||||||
// If user or account can't be found, then the handler won't *fail*, in case the server wants to allow
|
// If user or account can't be found, then the handler won't *fail*, in case the server wants to allow
|
||||||
// public requests that don't have a Bearer token set (eg., for public instance information and so on).
|
// public requests that don't have a Bearer token set (eg., for public instance information and so on).
|
||||||
func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) {
|
func (m *authModule) oauthTokenMiddleware(c *gin.Context) {
|
||||||
l := m.log.WithField("func", "ValidatePassword")
|
l := m.log.WithField("func", "ValidatePassword")
|
||||||
l.Trace("entering OauthTokenMiddleware")
|
l.Trace("entering OauthTokenMiddleware")
|
||||||
|
|
||||||
ti, err := m.oauthServer.ValidationBearerToken(c.Request)
|
ti, err := m.server.ValidationBearerToken(c.Request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Trace("no valid token presented: continuing with unauthenticated request")
|
l.Trace("no valid token presented: continuing with unauthenticated request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope())
|
c.Set(oauth.SessionAuthorizedToken, ti)
|
||||||
|
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedToken, ti)
|
||||||
|
|
||||||
acct := &model.Account{}
|
// check for user-level token
|
||||||
if err := m.db.GetAccountByUserID(ti.GetUserID(), acct); err != nil || acct == nil {
|
if uid := ti.GetUserID(); uid != "" {
|
||||||
l.Tracef("no account found for user %s, continuing with unauthenticated request", ti.GetUserID())
|
l.Tracef("authenticated user %s with bearer token, scope is %s", uid, ti.GetScope())
|
||||||
return
|
|
||||||
|
// fetch user's and account for this user id
|
||||||
|
user := &model.User{}
|
||||||
|
if err := m.db.GetByID(uid, user); err != nil || user == nil {
|
||||||
|
l.Warnf("no user found for validated uid %s", uid)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set(oauth.SessionAuthorizedUser, user)
|
||||||
|
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user)
|
||||||
|
|
||||||
|
acct := &model.Account{}
|
||||||
|
if err := m.db.GetByID(user.AccountID, acct); err != nil || acct == nil {
|
||||||
|
l.Warnf("no account found for validated user %s", uid)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set(oauth.SessionAuthorizedAccount, acct)
|
||||||
|
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedAccount, acct)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Set(SessionAuthorizedAccount, acct)
|
// check for application token
|
||||||
c.Set(SessionAuthorizedUser, ti.GetUserID())
|
if cid := ti.GetClientID(); cid != "" {
|
||||||
|
l.Tracef("authenticated client %s with bearer token, scope is %s", cid, ti.GetScope())
|
||||||
|
app := &model.Application{}
|
||||||
|
if err := m.db.GetWhere("client_id", cid, app); err != nil {
|
||||||
|
l.Tracef("no app found for client %s", cid)
|
||||||
|
}
|
||||||
|
c.Set(oauth.SessionAuthorizedApplication, app)
|
||||||
|
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedApplication, app)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
@ -447,7 +338,7 @@ func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) {
|
||||||
// The goal is to authenticate the password against the one for that email
|
// The goal is to authenticate the password against the one for that email
|
||||||
// address stored in the database. If OK, we return the userid (a uuid) for that user,
|
// address stored in the database. If OK, we return the userid (a uuid) for that user,
|
||||||
// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
|
// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
|
||||||
func (m *oauthModule) validatePassword(email string, password string) (userid string, err error) {
|
func (m *authModule) validatePassword(email string, password string) (userid string, err error) {
|
||||||
l := m.log.WithField("func", "ValidatePassword")
|
l := m.log.WithField("func", "ValidatePassword")
|
||||||
|
|
||||||
// make sure an email/password was provided and bail if not
|
// make sure an email/password was provided and bail if not
|
||||||
|
|
@ -487,18 +378,6 @@ func incorrectPassword() (string, error) {
|
||||||
return "", errors.New("password/email combination was incorrect")
|
return "", errors.New("password/email combination was incorrect")
|
||||||
}
|
}
|
||||||
|
|
||||||
// userAuthorizationHandler gets the user's ID from the 'userid' field of the request form,
|
|
||||||
// or redirects to the /auth/sign_in page, if this key is not present.
|
|
||||||
func (m *oauthModule) userAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) {
|
|
||||||
l := m.log.WithField("func", "UserAuthorizationHandler")
|
|
||||||
userID = r.FormValue("userid")
|
|
||||||
if userID == "" {
|
|
||||||
return "", errors.New("userid was empty, redirecting to sign in page")
|
|
||||||
}
|
|
||||||
l.Tracef("returning userID %s", userID)
|
|
||||||
return userID, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseAuthForm parses the OAuthAuthorize form in the gin context, and stores
|
// parseAuthForm parses the OAuthAuthorize form in the gin context, and stores
|
||||||
// the values in the form into the session.
|
// the values in the form into the session.
|
||||||
func parseAuthForm(c *gin.Context, l *logrus.Entry) error {
|
func parseAuthForm(c *gin.Context, l *logrus.Entry) error {
|
||||||
|
|
@ -16,38 +16,38 @@
|
||||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package oauth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/gotosocial/gotosocial/internal/config"
|
"github.com/gotosocial/gotosocial/internal/config"
|
||||||
"github.com/gotosocial/gotosocial/internal/db"
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
"github.com/gotosocial/gotosocial/internal/db/model"
|
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/oauth"
|
||||||
"github.com/gotosocial/gotosocial/internal/router"
|
"github.com/gotosocial/gotosocial/internal/router"
|
||||||
"github.com/gotosocial/oauth2/v4"
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OauthTestSuite struct {
|
type AuthTestSuite struct {
|
||||||
suite.Suite
|
suite.Suite
|
||||||
tokenStore oauth2.TokenStore
|
oauthServer oauth.Server
|
||||||
clientStore oauth2.ClientStore
|
|
||||||
db db.DB
|
db db.DB
|
||||||
testAccount *model.Account
|
testAccount *model.Account
|
||||||
testApplication *model.Application
|
testApplication *model.Application
|
||||||
testUser *model.User
|
testUser *model.User
|
||||||
testClient *oauthClient
|
testClient *oauth.Client
|
||||||
config *config.Config
|
config *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
|
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
|
||||||
func (suite *OauthTestSuite) SetupSuite() {
|
func (suite *AuthTestSuite) SetupSuite() {
|
||||||
c := config.Empty()
|
c := config.Empty()
|
||||||
// we're running on localhost without https so set the protocol to http
|
// we're running on localhost without https so set the protocol to http
|
||||||
c.Protocol = "http"
|
c.Protocol = "http"
|
||||||
|
|
@ -84,7 +84,7 @@ func (suite *OauthTestSuite) SetupSuite() {
|
||||||
Email: "user@example.org",
|
Email: "user@example.org",
|
||||||
AccountID: acctID,
|
AccountID: acctID,
|
||||||
}
|
}
|
||||||
suite.testClient = &oauthClient{
|
suite.testClient = &oauth.Client{
|
||||||
ID: "a-known-client-id",
|
ID: "a-known-client-id",
|
||||||
Secret: "some-secret",
|
Secret: "some-secret",
|
||||||
Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host),
|
Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host),
|
||||||
|
|
@ -101,7 +101,7 @@ func (suite *OauthTestSuite) SetupSuite() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupTest creates a postgres connection and creates the oauth_clients table before each test
|
// SetupTest creates a postgres connection and creates the oauth_clients table before each test
|
||||||
func (suite *OauthTestSuite) SetupTest() {
|
func (suite *AuthTestSuite) SetupTest() {
|
||||||
|
|
||||||
log := logrus.New()
|
log := logrus.New()
|
||||||
log.SetLevel(logrus.TraceLevel)
|
log.SetLevel(logrus.TraceLevel)
|
||||||
|
|
@ -113,8 +113,8 @@ func (suite *OauthTestSuite) SetupTest() {
|
||||||
suite.db = db
|
suite.db = db
|
||||||
|
|
||||||
models := []interface{}{
|
models := []interface{}{
|
||||||
&oauthClient{},
|
&oauth.Client{},
|
||||||
&oauthToken{},
|
&oauth.Token{},
|
||||||
&model.User{},
|
&model.User{},
|
||||||
&model.Account{},
|
&model.Account{},
|
||||||
&model.Application{},
|
&model.Application{},
|
||||||
|
|
@ -126,8 +126,7 @@ func (suite *OauthTestSuite) SetupTest() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
suite.tokenStore = newTokenStore(context.Background(), suite.db, logrus.New())
|
suite.oauthServer = oauth.New(suite.db, log)
|
||||||
suite.clientStore = newClientStore(suite.db)
|
|
||||||
|
|
||||||
if err := suite.db.Put(suite.testAccount); err != nil {
|
if err := suite.db.Put(suite.testAccount); err != nil {
|
||||||
logrus.Panicf("could not insert test account into db: %s", err)
|
logrus.Panicf("could not insert test account into db: %s", err)
|
||||||
|
|
@ -145,10 +144,10 @@ func (suite *OauthTestSuite) SetupTest() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
|
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
|
||||||
func (suite *OauthTestSuite) TearDownTest() {
|
func (suite *AuthTestSuite) TearDownTest() {
|
||||||
models := []interface{}{
|
models := []interface{}{
|
||||||
&oauthClient{},
|
&oauth.Client{},
|
||||||
&oauthToken{},
|
&oauth.Token{},
|
||||||
&model.User{},
|
&model.User{},
|
||||||
&model.Account{},
|
&model.Account{},
|
||||||
&model.Application{},
|
&model.Application{},
|
||||||
|
|
@ -164,7 +163,7 @@ func (suite *OauthTestSuite) TearDownTest() {
|
||||||
suite.db = nil
|
suite.db = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *OauthTestSuite) TestAPIInitialize() {
|
func (suite *AuthTestSuite) TestAPIInitialize() {
|
||||||
log := logrus.New()
|
log := logrus.New()
|
||||||
log.SetLevel(logrus.TraceLevel)
|
log.SetLevel(logrus.TraceLevel)
|
||||||
|
|
||||||
|
|
@ -173,17 +172,18 @@ func (suite *OauthTestSuite) TestAPIInitialize() {
|
||||||
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
|
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
api := New(suite.tokenStore, suite.clientStore, suite.db, log)
|
api := New(suite.oauthServer, suite.db, log)
|
||||||
if err := api.Route(r); err != nil {
|
if err := api.Route(r); err != nil {
|
||||||
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
|
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Start()
|
r.Start()
|
||||||
|
time.Sleep(60 * time.Second)
|
||||||
if err := r.Stop(context.Background()); err != nil {
|
if err := r.Stop(context.Background()); err != nil {
|
||||||
suite.FailNow(fmt.Sprintf("error stopping router: %s", err))
|
suite.FailNow(fmt.Sprintf("error stopping router: %s", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOauthTestSuite(t *testing.T) {
|
func TestAuthTestSuite(t *testing.T) {
|
||||||
suite.Run(t, new(OauthTestSuite))
|
suite.Run(t, new(AuthTestSuite))
|
||||||
}
|
}
|
||||||
|
|
@ -38,7 +38,7 @@ func newClientStore(db db.DB) oauth2.ClientStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
|
func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
|
||||||
poc := &oauthClient{
|
poc := &Client{
|
||||||
ID: clientID,
|
ID: clientID,
|
||||||
}
|
}
|
||||||
if err := cs.db.GetByID(clientID, poc); err != nil {
|
if err := cs.db.GetByID(clientID, poc); err != nil {
|
||||||
|
|
@ -48,7 +48,7 @@ func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.Cli
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error {
|
func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error {
|
||||||
poc := &oauthClient{
|
poc := &Client{
|
||||||
ID: cli.GetID(),
|
ID: cli.GetID(),
|
||||||
Secret: cli.GetSecret(),
|
Secret: cli.GetSecret(),
|
||||||
Domain: cli.GetDomain(),
|
Domain: cli.GetDomain(),
|
||||||
|
|
@ -58,13 +58,13 @@ func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *clientStore) Delete(ctx context.Context, id string) error {
|
func (cs *clientStore) Delete(ctx context.Context, id string) error {
|
||||||
poc := &oauthClient{
|
poc := &Client{
|
||||||
ID: id,
|
ID: id,
|
||||||
}
|
}
|
||||||
return cs.db.DeleteByID(id, poc)
|
return cs.db.DeleteByID(id, poc)
|
||||||
}
|
}
|
||||||
|
|
||||||
type oauthClient struct {
|
type Client struct {
|
||||||
ID string
|
ID string
|
||||||
Secret string
|
Secret string
|
||||||
Domain string
|
Domain string
|
||||||
|
|
@ -69,7 +69,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() {
|
||||||
suite.db = db
|
suite.db = db
|
||||||
|
|
||||||
models := []interface{}{
|
models := []interface{}{
|
||||||
&oauthClient{},
|
&Client{},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, m := range models {
|
for _, m := range models {
|
||||||
|
|
@ -82,7 +82,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() {
|
||||||
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
|
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
|
||||||
func (suite *PgClientStoreTestSuite) TearDownTest() {
|
func (suite *PgClientStoreTestSuite) TearDownTest() {
|
||||||
models := []interface{}{
|
models := []interface{}{
|
||||||
&oauthClient{},
|
&Client{},
|
||||||
}
|
}
|
||||||
for _, m := range models {
|
for _, m := range models {
|
||||||
if err := suite.db.DropTable(m); err != nil {
|
if err := suite.db.DropTable(m); err != nil {
|
||||||
212
internal/oauth/oauth.go
Normal file
212
internal/oauth/oauth.go
Normal file
|
|
@ -0,0 +1,212 @@
|
||||||
|
/*
|
||||||
|
GoToSocial
|
||||||
|
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package oauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||||
|
"github.com/gotosocial/oauth2/v4"
|
||||||
|
"github.com/gotosocial/oauth2/v4/errors"
|
||||||
|
"github.com/gotosocial/oauth2/v4/manage"
|
||||||
|
"github.com/gotosocial/oauth2/v4/server"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
SessionAuthorizedToken = "authorized_token"
|
||||||
|
// SessionAuthorizedUser is the key set in the gin context for the id of
|
||||||
|
// a User who has successfully passed Bearer token authorization.
|
||||||
|
// The interface returned from grabbing this key should be parsed as a *gtsmodel.User
|
||||||
|
SessionAuthorizedUser = "authorized_user"
|
||||||
|
// SessionAuthorizedAccount is the key set in the gin context for the Account
|
||||||
|
// of a User who has successfully passed Bearer token authorization.
|
||||||
|
// The interface returned from grabbing this key should be parsed as a *gtsmodel.Account
|
||||||
|
SessionAuthorizedAccount = "authorized_account"
|
||||||
|
// SessionAuthorizedAccount is the key set in the gin context for the Application
|
||||||
|
// of a Client who has successfully passed Bearer token authorization.
|
||||||
|
// The interface returned from grabbing this key should be parsed as a *gtsmodel.Application
|
||||||
|
SessionAuthorizedApplication = "authorized_app"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server wraps some oauth2 server functions in an interface, exposing only what is needed
|
||||||
|
type Server interface {
|
||||||
|
HandleTokenRequest(w http.ResponseWriter, r *http.Request) error
|
||||||
|
HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error
|
||||||
|
ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error)
|
||||||
|
GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// s fulfils the Server interface using the underlying oauth2 server
|
||||||
|
type s struct {
|
||||||
|
server *server.Server
|
||||||
|
log *logrus.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
type Authed struct {
|
||||||
|
Token oauth2.TokenInfo
|
||||||
|
Application *model.Application
|
||||||
|
User *model.User
|
||||||
|
Account *model.Account
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthed is a convenience function for returning an Authed struct from a gin context.
|
||||||
|
// In essence, it tries to extract a token, application, user, and account from the context,
|
||||||
|
// and then sets them on a struct for convenience.
|
||||||
|
//
|
||||||
|
// If any are not present in the context, they will be set to nil on the returned Authed struct.
|
||||||
|
//
|
||||||
|
// If *ALL* are not present, then nil and an error will be returned.
|
||||||
|
//
|
||||||
|
// If something goes wrong during parsing, then nil and an error will be returned (consider this not authed).
|
||||||
|
func GetAuthed(c *gin.Context) (*Authed, error) {
|
||||||
|
ctx := c.Copy()
|
||||||
|
a := &Authed{}
|
||||||
|
var i interface{}
|
||||||
|
var ok bool
|
||||||
|
|
||||||
|
i, ok = ctx.Get(SessionAuthorizedToken)
|
||||||
|
if ok {
|
||||||
|
parsed, ok := i.(oauth2.TokenInfo)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("could not parse token from session context")
|
||||||
|
}
|
||||||
|
a.Token = parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
i, ok = ctx.Get(SessionAuthorizedApplication)
|
||||||
|
if ok {
|
||||||
|
parsed, ok := i.(*model.Application)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("could not parse application from session context")
|
||||||
|
}
|
||||||
|
a.Application = parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
i, ok = ctx.Get(SessionAuthorizedUser)
|
||||||
|
if ok {
|
||||||
|
parsed, ok := i.(*model.User)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("could not parse user from session context")
|
||||||
|
}
|
||||||
|
a.User = parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
i, ok = ctx.Get(SessionAuthorizedAccount)
|
||||||
|
if ok {
|
||||||
|
parsed, ok := i.(*model.Account)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("could not parse account from session context")
|
||||||
|
}
|
||||||
|
a.Account = parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.Token == nil && a.Application == nil && a.User == nil && a.Account == nil {
|
||||||
|
return nil, errors.New("not authorized")
|
||||||
|
}
|
||||||
|
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function
|
||||||
|
func (s *s) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
return s.server.HandleTokenRequest(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleAuthorizeRequest wraps the oauth2 library's HandleAuthorizeRequest function
|
||||||
|
func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
return s.server.HandleAuthorizeRequest(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidationBearerToken wraps the oauth2 library's ValidationBearerToken function
|
||||||
|
func (s *s) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
|
||||||
|
return s.server.ValidationBearerToken(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateUserAccessToken shortcuts the normal oauth flow to create an user-level
|
||||||
|
// bearer token *without* requiring that user to log in. This is useful when we
|
||||||
|
// need to create a token for new users who haven't validated their email or logged in yet.
|
||||||
|
//
|
||||||
|
// The ti parameter refers to an existing Application token that was used to make the upstream
|
||||||
|
// request. This token needs to be validated and exist in database in order to create a new token.
|
||||||
|
func (s *s) GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error) {
|
||||||
|
|
||||||
|
tgr := &oauth2.TokenGenerateRequest{
|
||||||
|
ClientID: ti.GetClientID(),
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
UserID: userID,
|
||||||
|
RedirectURI: ti.GetRedirectURI(),
|
||||||
|
Scope: ti.GetScope(),
|
||||||
|
Code: ti.GetCode(),
|
||||||
|
CodeChallenge: ti.GetCodeChallenge(),
|
||||||
|
CodeChallengeMethod: ti.GetCodeChallengeMethod(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.server.Manager.GenerateAccessToken(context.Background(), oauth2.AuthorizationCode, tgr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(database db.DB, log *logrus.Logger) Server {
|
||||||
|
ts := newTokenStore(context.Background(), database, log)
|
||||||
|
cs := newClientStore(database)
|
||||||
|
|
||||||
|
manager := manage.NewDefaultManager()
|
||||||
|
manager.MapTokenStorage(ts)
|
||||||
|
manager.MapClientStorage(cs)
|
||||||
|
manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
|
||||||
|
sc := &server.Config{
|
||||||
|
TokenType: "Bearer",
|
||||||
|
// Must follow the spec.
|
||||||
|
AllowGetAccessRequest: false,
|
||||||
|
// Support only the non-implicit flow.
|
||||||
|
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
|
||||||
|
// Allow:
|
||||||
|
// - Authorization Code (for first & third parties)
|
||||||
|
// - Client Credentials (for applications)
|
||||||
|
AllowedGrantTypes: []oauth2.GrantType{
|
||||||
|
oauth2.AuthorizationCode,
|
||||||
|
oauth2.ClientCredentials,
|
||||||
|
},
|
||||||
|
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain},
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := server.NewServer(sc, manager)
|
||||||
|
srv.SetInternalErrorHandler(func(err error) *errors.Response {
|
||||||
|
log.Errorf("internal oauth error: %s", err)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
srv.SetResponseErrorHandler(func(re *errors.Response) {
|
||||||
|
log.Errorf("internal response error: %s", re.Error)
|
||||||
|
})
|
||||||
|
|
||||||
|
srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||||
|
userID := r.FormValue("userid")
|
||||||
|
if userID == "" {
|
||||||
|
return "", errors.New("userid was empty")
|
||||||
|
}
|
||||||
|
return userID, nil
|
||||||
|
})
|
||||||
|
srv.SetClientInfoHandler(server.ClientFormHandler)
|
||||||
|
return &s{
|
||||||
|
server: srv,
|
||||||
|
}
|
||||||
|
}
|
||||||
21
internal/oauth/oauth_test.go
Normal file
21
internal/oauth/oauth_test.go
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
/*
|
||||||
|
GoToSocial
|
||||||
|
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package oauth
|
||||||
|
|
||||||
|
// TODO: write tests
|
||||||
|
|
@ -70,7 +70,7 @@ func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.Tok
|
||||||
func (pts *tokenStore) sweep() error {
|
func (pts *tokenStore) sweep() error {
|
||||||
// select *all* tokens from the db
|
// select *all* tokens from the db
|
||||||
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
|
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
|
||||||
tokens := new([]*oauthToken)
|
tokens := new([]*Token)
|
||||||
if err := pts.db.GetAll(tokens); err != nil {
|
if err := pts.db.GetAll(tokens); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -106,22 +106,22 @@ func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error
|
||||||
|
|
||||||
// RemoveByCode deletes a token from the DB based on the Code field
|
// RemoveByCode deletes a token from the DB based on the Code field
|
||||||
func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
|
func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
|
||||||
return pts.db.DeleteWhere("code", code, &oauthToken{})
|
return pts.db.DeleteWhere("code", code, &Token{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveByAccess deletes a token from the DB based on the Access field
|
// RemoveByAccess deletes a token from the DB based on the Access field
|
||||||
func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
|
func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
|
||||||
return pts.db.DeleteWhere("access", access, &oauthToken{})
|
return pts.db.DeleteWhere("access", access, &Token{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveByRefresh deletes a token from the DB based on the Refresh field
|
// RemoveByRefresh deletes a token from the DB based on the Refresh field
|
||||||
func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
|
func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
|
||||||
return pts.db.DeleteWhere("refresh", refresh, &oauthToken{})
|
return pts.db.DeleteWhere("refresh", refresh, &Token{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByCode selects a token from the DB based on the Code field
|
// GetByCode selects a token from the DB based on the Code field
|
||||||
func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
|
func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
|
||||||
pgt := &oauthToken{
|
pgt := &Token{
|
||||||
Code: code,
|
Code: code,
|
||||||
}
|
}
|
||||||
if err := pts.db.GetWhere("code", code, pgt); err != nil {
|
if err := pts.db.GetWhere("code", code, pgt); err != nil {
|
||||||
|
|
@ -132,7 +132,7 @@ func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.Token
|
||||||
|
|
||||||
// GetByAccess selects a token from the DB based on the Access field
|
// GetByAccess selects a token from the DB based on the Access field
|
||||||
func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
|
func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
|
||||||
pgt := &oauthToken{
|
pgt := &Token{
|
||||||
Access: access,
|
Access: access,
|
||||||
}
|
}
|
||||||
if err := pts.db.GetWhere("access", access, pgt); err != nil {
|
if err := pts.db.GetWhere("access", access, pgt); err != nil {
|
||||||
|
|
@ -143,7 +143,7 @@ func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.T
|
||||||
|
|
||||||
// GetByRefresh selects a token from the DB based on the Refresh field
|
// GetByRefresh selects a token from the DB based on the Refresh field
|
||||||
func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
|
func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
|
||||||
pgt := &oauthToken{
|
pgt := &Token{
|
||||||
Refresh: refresh,
|
Refresh: refresh,
|
||||||
}
|
}
|
||||||
if err := pts.db.GetWhere("refresh", refresh, pgt); err != nil {
|
if err := pts.db.GetWhere("refresh", refresh, pgt); err != nil {
|
||||||
|
|
@ -156,7 +156,7 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2
|
||||||
The following models are basically helpers for the postgres token store implementation, they should only be used internally.
|
The following models are basically helpers for the postgres token store implementation, they should only be used internally.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// oauthToken is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt.
|
// Token is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt.
|
||||||
//
|
//
|
||||||
// Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined,
|
// Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined,
|
||||||
// and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and
|
// and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and
|
||||||
|
|
@ -164,9 +164,9 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2
|
||||||
//
|
//
|
||||||
// Note that this struct does *not* satisfy the token interface shown here: https://github.com/gotosocial/oauth2/blob/master/model.go#L22
|
// Note that this struct does *not* satisfy the token interface shown here: https://github.com/gotosocial/oauth2/blob/master/model.go#L22
|
||||||
// and implemented here: https://github.com/gotosocial/oauth2/blob/master/models/token.go.
|
// and implemented here: https://github.com/gotosocial/oauth2/blob/master/models/token.go.
|
||||||
// As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
|
// As such, manual translation is always required between Token and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
|
||||||
// and pgTokenToOauthToken can be used for that.
|
// and pgTokenToOauthToken can be used for that.
|
||||||
type oauthToken struct {
|
type Token struct {
|
||||||
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"`
|
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"`
|
||||||
ClientID string
|
ClientID string
|
||||||
UserID string
|
UserID string
|
||||||
|
|
@ -186,7 +186,7 @@ type oauthToken struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// oauthTokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres
|
// oauthTokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres
|
||||||
func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
|
func oauthTokenToPGToken(tkn *models.Token) *Token {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
// For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's
|
// For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's
|
||||||
|
|
@ -208,7 +208,7 @@ func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
|
||||||
rea = now.Add(tkn.RefreshExpiresIn)
|
rea = now.Add(tkn.RefreshExpiresIn)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &oauthToken{
|
return &Token{
|
||||||
ClientID: tkn.ClientID,
|
ClientID: tkn.ClientID,
|
||||||
UserID: tkn.UserID,
|
UserID: tkn.UserID,
|
||||||
RedirectURI: tkn.RedirectURI,
|
RedirectURI: tkn.RedirectURI,
|
||||||
|
|
@ -228,7 +228,7 @@ func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
|
||||||
}
|
}
|
||||||
|
|
||||||
// pgTokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token
|
// pgTokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token
|
||||||
func pgTokenToOauthToken(pgt *oauthToken) *models.Token {
|
func pgTokenToOauthToken(pgt *Token) *models.Token {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
return &models.Token{
|
return &models.Token{
|
||||||
21
internal/oauth/tokenstore_test.go
Normal file
21
internal/oauth/tokenstore_test.go
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
/*
|
||||||
|
GoToSocial
|
||||||
|
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package oauth
|
||||||
|
|
||||||
|
// TODO: write tests
|
||||||
|
|
@ -36,7 +36,7 @@ import (
|
||||||
// Router provides the REST interface for gotosocial, using gin.
|
// Router provides the REST interface for gotosocial, using gin.
|
||||||
type Router interface {
|
type Router interface {
|
||||||
// Attach a gin handler to the router with the given method and path
|
// Attach a gin handler to the router with the given method and path
|
||||||
AttachHandler(method string, path string, handler gin.HandlerFunc)
|
AttachHandler(method string, path string, f gin.HandlerFunc)
|
||||||
// Attach a gin middleware to the router that will be used globally
|
// Attach a gin middleware to the router that will be used globally
|
||||||
AttachMiddleware(handler gin.HandlerFunc)
|
AttachMiddleware(handler gin.HandlerFunc)
|
||||||
// Start the router
|
// Start the router
|
||||||
|
|
@ -59,6 +59,8 @@ func (r *router) Start() {
|
||||||
r.logger.Fatalf("listen: %s", err)
|
r.logger.Fatalf("listen: %s", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
// c := &gin.Context{}
|
||||||
|
// c.Get()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop shuts down the router nicely
|
// Stop shuts down the router nicely
|
||||||
|
|
|
||||||
31
pkg/mastotypes/token.go
Normal file
31
pkg/mastotypes/token.go
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
/*
|
||||||
|
GoToSocial
|
||||||
|
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package mastotypes
|
||||||
|
|
||||||
|
// Token represents an OAuth token used for authenticating with the API and performing actions.. See https://docs.joinmastodon.org/entities/token/
|
||||||
|
type Token struct {
|
||||||
|
// An OAuth token to be used for authorization.
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
// The OAuth token type. Mastodon uses Bearer tokens.
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
// The OAuth scopes granted by this token, space-separated.
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
// When the token was generated. (UNIX timestamp seconds)
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue