chunking away at it

This commit is contained in:
tsmethurst 2021-03-26 19:02:20 +01:00
commit f58f77bf1f
23 changed files with 860 additions and 394 deletions

View file

@ -21,6 +21,7 @@ package db
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"strings" "strings"
"github.com/go-fed/activity/pub" "github.com/go-fed/activity/pub"
@ -145,6 +146,10 @@ type DB interface {
// C) something went wrong in the db // C) something went wrong in the db
IsEmailAvailable(email string) error IsEmailAvailable(email string) error
// NewSignup creates a new user in the database with the given parameters, with an *unconfirmed* email address.
// By the time this function is called, it should be assumed that all the parameters have passed validation!
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string) (*model.User, error)
/* /*
USEFUL CONVERSION FUNCTIONS USEFUL CONVERSION FUNCTIONS
*/ */

View file

@ -23,6 +23,7 @@
package model package model
import ( import (
"crypto/rsa"
"net/url" "net/url"
"time" "time"
) )
@ -82,6 +83,8 @@ type Account struct {
SubscriptionExpiresAt time.Time `pg:"type:timestamp"` SubscriptionExpiresAt time.Time `pg:"type:timestamp"`
// Does this account identify itself as a bot? // Does this account identify itself as a bot?
Bot bool Bot bool
// What reason was given for signing up when this account was created?
Reason string
/* /*
PRIVACY SETTINGS PRIVACY SETTINGS
@ -123,9 +126,9 @@ type Account struct {
Secret string Secret string
// Privatekey for validating activitypub requests, will obviously only be defined for local accounts // Privatekey for validating activitypub requests, will obviously only be defined for local accounts
PrivateKey string PrivateKey *rsa.PrivateKey
// Publickey for encoding activitypub requests, will be defined for both local and remote accounts // Publickey for encoding activitypub requests, will be defined for both local and remote accounts
PublicKey string PublicKey *rsa.PublicKey
/* /*
ADMIN FIELDS ADMIN FIELDS

View file

@ -20,8 +20,11 @@ package db
import ( import (
"context" "context"
"crypto/rand"
"crypto/rsa"
"errors" "errors"
"fmt" "fmt"
"net"
"net/mail" "net/mail"
"regexp" "regexp"
"strings" "strings"
@ -35,6 +38,7 @@ import (
"github.com/gotosocial/gotosocial/internal/db/model" "github.com/gotosocial/gotosocial/internal/db/model"
"github.com/gotosocial/gotosocial/pkg/mastotypes" "github.com/gotosocial/gotosocial/pkg/mastotypes"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
) )
// postgresService satisfies the DB interface // postgresService satisfies the DB interface
@ -305,7 +309,6 @@ func (ps *postgresService) GetAccountByUserID(userID string, account *model.Acco
return err return err
} }
if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil { if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
fmt.Println(account)
if err == pg.ErrNoRows { if err == pg.ErrNoRows {
return ErrNoEntries{} return ErrNoEntries{}
} }
@ -412,6 +415,43 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
return nil return nil
} }
func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string) (*model.User, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
ps.log.Errorf("error creating new rsa key: %s", err)
return nil, err
}
a := &model.Account{
Username: username,
DisplayName: username,
Reason: reason,
PrivateKey: key,
PublicKey: &key.PublicKey,
ActorType: "Person",
}
if _, err = ps.conn.Model(a).Insert(); err != nil {
return nil, err
}
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("error hashing password: %s", err)
}
u := &model.User{
AccountID: a.ID,
EncryptedPassword: string(pw),
SignUpIP: signUpIP,
Locale: locale,
UnconfirmedEmail: email,
}
if _, err = ps.conn.Model(u).Insert(); err != nil {
return nil, err
}
return u, nil
}
/* /*
CONVERSION FUNCTIONS CONVERSION FUNCTIONS
*/ */
@ -433,7 +473,6 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype
} }
fields = append(fields, mField) fields = append(fields, mField)
} }
fmt.Printf("fields: %+v", fields)
// count followers // count followers
followers := []model.Follow{} followers := []model.Follow{}

21
internal/db/pg_test.go Normal file
View file

@ -0,0 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
// TODO: write tests for postgres

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db package db
import ( import (

21
internal/db/pgfed_test.go Normal file
View file

@ -0,0 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
// TODO: write tests for pgfed

View file

@ -19,6 +19,8 @@
package account package account
import ( import (
"fmt"
"net"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -26,9 +28,10 @@ import (
"github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/db/model" "github.com/gotosocial/gotosocial/internal/db/model"
"github.com/gotosocial/gotosocial/internal/module" "github.com/gotosocial/gotosocial/internal/module"
"github.com/gotosocial/gotosocial/internal/module/oauth" "github.com/gotosocial/gotosocial/internal/oauth"
"github.com/gotosocial/gotosocial/internal/router" "github.com/gotosocial/gotosocial/internal/router"
"github.com/gotosocial/gotosocial/pkg/mastotypes" "github.com/gotosocial/gotosocial/pkg/mastotypes"
"github.com/gotosocial/oauth2/v4"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -41,6 +44,7 @@ const (
type accountModule struct { type accountModule struct {
config *config.Config config *config.Config
db db.DB db db.DB
oauthServer oauth.Server
log *logrus.Logger log *logrus.Logger
} }
@ -60,15 +64,15 @@ func (m *accountModule) Route(r router.Router) error {
return nil return nil
} }
// accountCreatePOSTHandler handles create account requests, validates them,
// and puts them in the database if they're valid.
// It should be served as a POST at /api/v1/accounts
func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) { func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "AccountCreatePOSTHandler") l := m.log.WithField("func", "accountCreatePOSTHandler")
// TODO: check whether a valid app token has been presented!! authed, err := oauth.GetAuthed(c)
// See: https://docs.joinmastodon.org/methods/accounts/ if err != nil {
l.Debugf("couldn't auth: %s", err)
l.Trace("checking if registration is open") c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
if !m.config.AccountsConfig.OpenRegistration {
l.Debug("account registration is closed, returning error to client")
c.JSON(http.StatusUnauthorized, gin.H{"error": "account registration is closed"})
return return
} }
@ -81,15 +85,34 @@ func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) {
} }
l.Tracef("validating form %+v", form) l.Tracef("validating form %+v", form)
if err := validateCreateAccount(form, m.config.AccountsConfig.ReasonRequired, m.db); err != nil { if err := validateCreateAccount(form, m.config.AccountsConfig, m.db); err != nil {
l.Debugf("error validating form: %s", err) l.Debugf("error validating form: %s", err)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
clientIP := c.ClientIP()
l.Tracef("attempting to parse client ip address %s", clientIP)
signUpIP := net.ParseIP(clientIP)
if signUpIP == nil {
l.Debugf("error validating sign up ip address %s", clientIP)
c.JSON(http.StatusBadRequest, gin.H{"error": "ip address could not be parsed from request"})
return
}
ti, err := m.accountCreate(form, signUpIP, authed.Token, authed.Application)
if err != nil {
l.Errorf("internal server error while creating new account: %s", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, ti)
} }
// accountVerifyGETHandler serves a user's account details to them IF they reached this // accountVerifyGETHandler serves a user's account details to them IF they reached this
// handler while in possession of a valid token, according to the oauth middleware. // handler while in possession of a valid token, according to the oauth middleware.
// It should be served as a GET at /api/v1/accounts/verify_credentials
func (m *accountModule) accountVerifyGETHandler(c *gin.Context) { func (m *accountModule) accountVerifyGETHandler(c *gin.Context) {
l := m.log.WithField("func", "AccountVerifyGETHandler") l := m.log.WithField("func", "AccountVerifyGETHandler")
@ -120,3 +143,39 @@ func (m *accountModule) accountVerifyGETHandler(c *gin.Context) {
l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive) l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive)
c.JSON(http.StatusOK, acctSensitive) c.JSON(http.StatusOK, acctSensitive)
} }
/*
HELPER FUNCTIONS
*/
// accountCreate does the dirty work of making an account and user in the database.
// It then returns a token to the caller, for use with the new account, as per the
// spec here: https://docs.joinmastodon.org/methods/accounts/
func (m *accountModule) accountCreate(form *mastotypes.AccountCreateRequest, signUpIP net.IP, token oauth2.TokenInfo, app *model.Application) (*mastotypes.Token, error) {
l := m.log.WithField("func", "accountCreate")
// don't store a reason if we don't require one
reason := form.Reason
if !m.config.AccountsConfig.ReasonRequired {
reason = ""
}
l.Trace("creating new username and account")
user, err := m.db.NewSignup(form.Username, reason, m.config.AccountsConfig.RequireApproval, form.Email, form.Password, signUpIP, form.Locale)
if err != nil {
return nil, fmt.Errorf("error creating new signup in the database: %s", err)
}
l.Tracef("generating a token for user %s with account %s and application %s", user.ID, user.AccountID, app.ID)
ti, err := m.oauthServer.GenerateUserAccessToken(token, app.ClientSecret, user.ID)
if err != nil {
return nil, fmt.Errorf("error creating new access token for user %s: %s", user.ID, err)
}
return &mastotypes.Token{
AccessToken: ti.GetCode(),
TokenType: "Bearer",
Scope: ti.GetScope(),
CreatedAt: ti.GetCodeCreateAt().Unix(),
}, nil
}

View file

@ -20,34 +20,33 @@ package account
import ( import (
"context" "context"
"fmt" "net/http/httptest"
"net/url"
"testing" "testing"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gotosocial/gotosocial/internal/config" "github.com/gotosocial/gotosocial/internal/config"
"github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/db/model" "github.com/gotosocial/gotosocial/internal/db/model"
"github.com/gotosocial/gotosocial/internal/module/oauth"
"github.com/gotosocial/gotosocial/internal/router"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"golang.org/x/crypto/bcrypt"
) )
type AccountTestSuite struct { type AccountTestSuite struct {
suite.Suite suite.Suite
db db.DB log *logrus.Logger
testAccountLocal *model.Account testAccountLocal *model.Account
testAccountRemote *model.Account testAccountRemote *model.Account
testUser *model.User testUser *model.User
config *config.Config db db.DB
accountModule *accountModule
} }
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout // SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
func (suite *AccountTestSuite) SetupSuite() { func (suite *AccountTestSuite) SetupSuite() {
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
suite.log = log
c := config.Empty() c := config.Empty()
c.DBConfig = &config.DBConfig{ c.DBConfig = &config.DBConfig{
Type: "postgres", Type: "postgres",
@ -58,118 +57,126 @@ func (suite *AccountTestSuite) SetupSuite() {
Database: "postgres", Database: "postgres",
ApplicationName: "gotosocial", ApplicationName: "gotosocial",
} }
suite.config = c
encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) database, err := db.New(context.Background(), c, log)
if err != nil { if err != nil {
logrus.Panicf("error encrypting user pass: %s", err) suite.FailNow(err.Error())
}
suite.db = database
suite.accountModule = &accountModule{
config: c,
db: database,
log: log,
} }
localAvatar, err := url.Parse("https://localhost:8080/media/aaaaaaaaa.png") // encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
if err != nil { // if err != nil {
logrus.Panicf("error parsing localavatar url: %s", err) // logrus.Panicf("error encrypting user pass: %s", err)
} // }
localHeader, err := url.Parse("https://localhost:8080/media/ffffffffff.png")
if err != nil {
logrus.Panicf("error parsing localheader url: %s", err)
}
acctID := uuid.NewString() // localAvatar, err := url.Parse("https://localhost:8080/media/aaaaaaaaa.png")
suite.testAccountLocal = &model.Account{ // if err != nil {
ID: acctID, // logrus.Panicf("error parsing localavatar url: %s", err)
Username: "local_account_of_some_kind", // }
AvatarRemoteURL: localAvatar, // localHeader, err := url.Parse("https://localhost:8080/media/ffffffffff.png")
HeaderRemoteURL: localHeader, // if err != nil {
DisplayName: "michael caine", // logrus.Panicf("error parsing localheader url: %s", err)
Fields: []model.Field{ // }
{
Name: "come and ave a go",
Value: "if you think you're hard enough",
},
{
Name: "website",
Value: "https://imdb.com",
VerifiedAt: time.Now(),
},
},
Note: "My name is Michael Caine and i'm a local user.",
Discoverable: true,
}
avatarURL, err := url.Parse("http://example.org/accounts/avatars/000/207/122/original/089-1098-09.png") // acctID := uuid.NewString()
if err != nil { // suite.testAccountLocal = &model.Account{
logrus.Panicf("error parsing avatarURL: %s", err) // ID: acctID,
} // Username: "local_account_of_some_kind",
// AvatarRemoteURL: localAvatar,
// HeaderRemoteURL: localHeader,
// DisplayName: "michael caine",
// Fields: []model.Field{
// {
// Name: "come and ave a go",
// Value: "if you think you're hard enough",
// },
// {
// Name: "website",
// Value: "https://imdb.com",
// VerifiedAt: time.Now(),
// },
// },
// Note: "My name is Michael Caine and i'm a local user.",
// Discoverable: true,
// }
headerURL, err := url.Parse("http://example.org/accounts/headers/000/207/122/original/111111111111.png") // avatarURL, err := url.Parse("http://example.org/accounts/avatars/000/207/122/original/089-1098-09.png")
if err != nil { // if err != nil {
logrus.Panicf("error parsing avatarURL: %s", err) // logrus.Panicf("error parsing avatarURL: %s", err)
} // }
suite.testAccountRemote = &model.Account{
ID: uuid.NewString(),
Username: "neato_bombeato",
Domain: "example.org",
AvatarFileName: "avatar.png", // headerURL, err := url.Parse("http://example.org/accounts/headers/000/207/122/original/111111111111.png")
AvatarContentType: "image/png", // if err != nil {
AvatarFileSize: 1024, // logrus.Panicf("error parsing avatarURL: %s", err)
AvatarUpdatedAt: time.Now(), // }
AvatarRemoteURL: avatarURL, // suite.testAccountRemote = &model.Account{
// ID: uuid.NewString(),
// Username: "neato_bombeato",
// Domain: "example.org",
HeaderFileName: "avatar.png", // AvatarFileName: "avatar.png",
HeaderContentType: "image/png", // AvatarContentType: "image/png",
HeaderFileSize: 1024, // AvatarFileSize: 1024,
HeaderUpdatedAt: time.Now(), // AvatarUpdatedAt: time.Now(),
HeaderRemoteURL: headerURL, // AvatarRemoteURL: avatarURL,
DisplayName: "one cool dude 420", // HeaderFileName: "avatar.png",
Fields: []model.Field{ // HeaderContentType: "image/png",
{ // HeaderFileSize: 1024,
Name: "pronouns", // HeaderUpdatedAt: time.Now(),
Value: "he/they", // HeaderRemoteURL: headerURL,
},
{ // DisplayName: "one cool dude 420",
Name: "website", // Fields: []model.Field{
Value: "https://imcool.edu", // {
VerifiedAt: time.Now(), // Name: "pronouns",
}, // Value: "he/they",
}, // },
Note: "<p>I'm cool as heck!</p>", // {
Discoverable: true, // Name: "website",
URI: "https://example.org/users/neato_bombeato", // Value: "https://imcool.edu",
URL: "https://example.org/@neato_bombeato", // VerifiedAt: time.Now(),
LastWebfingeredAt: time.Now(), // },
InboxURL: "https://example.org/users/neato_bombeato/inbox", // },
OutboxURL: "https://example.org/users/neato_bombeato/outbox", // Note: "<p>I'm cool as heck!</p>",
SharedInboxURL: "https://example.org/inbox", // Discoverable: true,
FollowersURL: "https://example.org/users/neato_bombeato/followers", // URI: "https://example.org/users/neato_bombeato",
FeaturedCollectionURL: "https://example.org/users/neato_bombeato/collections/featured", // URL: "https://example.org/@neato_bombeato",
} // LastWebfingeredAt: time.Now(),
suite.testUser = &model.User{ // InboxURL: "https://example.org/users/neato_bombeato/inbox",
ID: uuid.NewString(), // OutboxURL: "https://example.org/users/neato_bombeato/outbox",
EncryptedPassword: string(encryptedPassword), // SharedInboxURL: "https://example.org/inbox",
Email: "user@example.org", // FollowersURL: "https://example.org/users/neato_bombeato/followers",
AccountID: acctID, // FeaturedCollectionURL: "https://example.org/users/neato_bombeato/collections/featured",
// }
// suite.testUser = &model.User{
// ID: uuid.NewString(),
// EncryptedPassword: string(encryptedPassword),
// Email: "user@example.org",
// AccountID: acctID,
// }
}
func (suite *AccountTestSuite) TearDownSuite() {
if err := suite.db.Stop(context.Background()); err != nil {
logrus.Panicf("error closing db connection: %s", err)
} }
} }
// SetupTest creates a postgres connection and creates the oauth_clients table before each test // SetupTest creates a db connection and creates necessary tables before each test
func (suite *AccountTestSuite) SetupTest() { func (suite *AccountTestSuite) SetupTest() {
log := logrus.New()
log.SetLevel(logrus.TraceLevel)
db, err := db.New(context.Background(), suite.config, log)
if err != nil {
logrus.Panicf("error creating database connection: %s", err)
}
suite.db = db
models := []interface{}{ models := []interface{}{
&model.User{}, &model.User{},
&model.Account{}, &model.Account{},
&model.Follow{}, &model.Follow{},
&model.Status{}, &model.Status{},
&model.Application{},
} }
for _, m := range models { for _, m := range models {
@ -177,70 +184,31 @@ func (suite *AccountTestSuite) SetupTest() {
logrus.Panicf("db connection error: %s", err) logrus.Panicf("db connection error: %s", err)
} }
} }
if err := suite.db.Put(suite.testAccountLocal); err != nil {
logrus.Panicf("could not insert test account into db: %s", err)
}
if err := suite.db.Put(suite.testUser); err != nil {
logrus.Panicf("could not insert test user into db: %s", err)
}
} }
// TearDownTest drops the oauth_clients table and closes the pg connection after each test // TearDownTest drops tables to make sure there's no data in the db
func (suite *AccountTestSuite) TearDownTest() { func (suite *AccountTestSuite) TearDownTest() {
models := []interface{}{ models := []interface{}{
&model.User{}, &model.User{},
&model.Account{}, &model.Account{},
&model.Follow{}, &model.Follow{},
&model.Status{}, &model.Status{},
&model.Application{},
} }
for _, m := range models { for _, m := range models {
if err := suite.db.DropTable(m); err != nil { if err := suite.db.DropTable(m); err != nil {
logrus.Panicf("error dropping table: %s", err) logrus.Panicf("error dropping table: %s", err)
} }
} }
if err := suite.db.Stop(context.Background()); err != nil {
logrus.Panicf("error closing db connection: %s", err)
}
suite.db = nil
} }
func (suite *AccountTestSuite) TestAPIInitialize() { func (suite *AccountTestSuite) TestAccountCreatePOSTHandler() {
log := logrus.New() // TODO: figure out how to test this properly
log.SetLevel(logrus.TraceLevel) recorder := httptest.NewRecorder()
recorder.Header().Set("X-Forwarded-For", "127.0.0.1")
r, err := router.New(suite.config, log) ctx, _ := gin.CreateTestContext(recorder)
if err != nil { // ctx.Set()
suite.FailNow(fmt.Sprintf("error creating router: %s", err)) suite.accountModule.accountCreatePOSTHandler(ctx)
}
r.AttachMiddleware(func(c *gin.Context) {
account := &model.Account{}
if err := suite.db.GetAccountByUserID(suite.testUser.ID, account); err != nil || account == nil {
suite.T().Log(err)
suite.FailNowf("no account found for user %s, continuing with unauthenticated request: %+v", "", suite.testUser.ID, account)
fmt.Println(account)
return
}
c.Set(oauth.SessionAuthorizedAccount, account)
c.Set(oauth.SessionAuthorizedUser, suite.testUser.ID)
})
acct := New(suite.config, suite.db, log)
if err := acct.Route(r); err != nil {
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
}
r.Start()
defer func() {
if err := r.Stop(context.Background()); err != nil {
panic(fmt.Errorf("error stopping router: %s", err))
}
}()
time.Sleep(10 * time.Second)
} }
func TestAccountTestSuite(t *testing.T) { func TestAccountTestSuite(t *testing.T) {

View file

@ -21,12 +21,17 @@ package account
import ( import (
"errors" "errors"
"github.com/gotosocial/gotosocial/internal/config"
"github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/util" "github.com/gotosocial/gotosocial/internal/util"
"github.com/gotosocial/gotosocial/pkg/mastotypes" "github.com/gotosocial/gotosocial/pkg/mastotypes"
) )
func validateCreateAccount(form *mastotypes.AccountCreateRequest, reasonRequired bool, database db.DB) error { func validateCreateAccount(form *mastotypes.AccountCreateRequest, c *config.AccountsConfig, database db.DB) error {
if !c.OpenRegistration {
return errors.New("registration is not open for this server")
}
if err := util.ValidateSignUpUsername(form.Username); err != nil { if err := util.ValidateSignUpUsername(form.Username); err != nil {
return err return err
} }
@ -47,7 +52,7 @@ func validateCreateAccount(form *mastotypes.AccountCreateRequest, reasonRequired
return err return err
} }
if err := util.ValidateSignUpReason(form.Reason, reasonRequired); err != nil { if err := util.ValidateSignUpReason(form.Reason, c.ReasonRequired); err != nil {
return err return err
} }

140
internal/module/app/app.go Normal file
View file

@ -0,0 +1,140 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package app
import (
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/db/model"
"github.com/gotosocial/gotosocial/internal/module"
"github.com/gotosocial/gotosocial/internal/oauth"
"github.com/gotosocial/gotosocial/internal/router"
"github.com/gotosocial/gotosocial/pkg/mastotypes"
"github.com/sirupsen/logrus"
)
const appsPath = "/api/v1/apps"
type appModule struct {
server oauth.Server
db db.DB
log *logrus.Logger
}
// New returns a new auth module
func New(srv oauth.Server, db db.DB, log *logrus.Logger) module.ClientAPIModule {
return &appModule{
server: srv,
db: db,
log: log,
}
}
// Route satisfies the RESTAPIModule interface
func (m *appModule) Route(s router.Router) error {
s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler)
return nil
}
// appsPOSTHandler should be served at https://example.org/api/v1/apps
// It is equivalent to: https://docs.joinmastodon.org/methods/apps/
func (m *appModule) appsPOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "AppsPOSTHandler")
l.Trace("entering AppsPOSTHandler")
form := &mastotypes.ApplicationPOSTRequest{}
if err := c.ShouldBind(form); err != nil {
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
return
}
// permitted length for most fields
permittedLength := 64
// redirect can be a bit bigger because we probably need to encode data in the redirect uri
permittedRedirect := 256
// check lengths of fields before proceeding so the user can't spam huge entries into the database
if len(form.ClientName) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)})
return
}
if len(form.Website) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)})
return
}
if len(form.RedirectURIs) > permittedRedirect {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)})
return
}
if len(form.Scopes) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)})
return
}
// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/
var scopes string
if form.Scopes == "" {
scopes = "read"
} else {
scopes = form.Scopes
}
// generate new IDs for this application and its associated client
clientID := uuid.NewString()
clientSecret := uuid.NewString()
vapidKey := uuid.NewString()
// generate the application to put in the database
app := &model.Application{
Name: form.ClientName,
Website: form.Website,
RedirectURI: form.RedirectURIs,
ClientID: clientID,
ClientSecret: clientSecret,
Scopes: scopes,
VapidKey: vapidKey,
}
// chuck it in the db
if err := m.db.Put(app); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// now we need to model an oauth client from the application that the oauth library can use
oc := &oauth.Client{
ID: clientID,
Secret: clientSecret,
Domain: form.RedirectURIs,
UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now
}
// chuck it in the db
if err := m.db.Put(oc); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/
c.JSON(http.StatusOK, app.ToMasto())
}

View file

@ -0,0 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package app
// TODO: write tests

View file

@ -1,4 +1,4 @@
# oauth # auth
This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) functionality to the GoToSocial client API. This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) functionality to the GoToSocial client API.

View file

@ -16,55 +16,40 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
// Package oauth is a module that provides oauth functionality to a router. // Package auth is a module that provides oauth functionality to a router.
// It adds the following paths: // It adds the following paths:
// /api/v1/apps
// /auth/sign_in // /auth/sign_in
// /oauth/token // /oauth/token
// /oauth/authorize // /oauth/authorize
// It also includes the oauthTokenMiddleware, which can be attached to a router to authenticate every request by Bearer token. // It also includes the oauthTokenMiddleware, which can be attached to a router to authenticate every request by Bearer token.
package oauth package auth
import ( import (
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/db/model" "github.com/gotosocial/gotosocial/internal/db/model"
"github.com/gotosocial/gotosocial/internal/module" "github.com/gotosocial/gotosocial/internal/module"
"github.com/gotosocial/gotosocial/internal/oauth"
"github.com/gotosocial/gotosocial/internal/router" "github.com/gotosocial/gotosocial/internal/router"
"github.com/gotosocial/gotosocial/pkg/mastotypes" "github.com/gotosocial/gotosocial/pkg/mastotypes"
"github.com/gotosocial/oauth2/v4"
"github.com/gotosocial/oauth2/v4/errors"
"github.com/gotosocial/oauth2/v4/manage"
"github.com/gotosocial/oauth2/v4/server"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
const ( const (
appsPath = "/api/v1/apps"
authSignInPath = "/auth/sign_in" authSignInPath = "/auth/sign_in"
oauthTokenPath = "/oauth/token" oauthTokenPath = "/oauth/token"
oauthAuthorizePath = "/oauth/authorize" oauthAuthorizePath = "/oauth/authorize"
// SessionAuthorizedUser is the key set in the gin context for the id of
// a User who has successfully passed Bearer token authorization.
// The interface returned from grabbing this key should be parsed as a string.
SessionAuthorizedUser = "authorized_user"
// SessionAuthorizedAccount is the key set in the gin context for the Account
// of a User who has successfully passed Bearer token authorization.
// The interface returned from grabbing this key should be parsed as a *gtsmodel.Account
SessionAuthorizedAccount = "authorized_account"
) )
// oauthModule is an oauth2 oauthModule that satisfies the ClientAPIModule interface type authModule struct {
type oauthModule struct { server oauth.Server
oauthManager *manage.Manager
oauthServer *server.Server
db db.DB db db.DB
log *logrus.Logger log *logrus.Logger
} }
@ -74,52 +59,17 @@ type login struct {
Password string `form:"password"` Password string `form:"password"`
} }
// New returns a new oauth module // New returns a new auth module
func New(ts oauth2.TokenStore, cs oauth2.ClientStore, db db.DB, log *logrus.Logger) module.ClientAPIModule { func New(srv oauth.Server, db db.DB, log *logrus.Logger) module.ClientAPIModule {
manager := manage.NewDefaultManager() return &authModule{
manager.MapTokenStorage(ts) server: srv,
manager.MapClientStorage(cs)
manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
sc := &server.Config{
TokenType: "Bearer",
// Must follow the spec.
AllowGetAccessRequest: false,
// Support only the non-implicit flow.
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
// Allow:
// - Authorization Code (for first & third parties)
AllowedGrantTypes: []oauth2.GrantType{
oauth2.AuthorizationCode,
},
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain},
}
srv := server.NewServer(sc, manager)
srv.SetInternalErrorHandler(func(err error) *errors.Response {
log.Errorf("internal oauth error: %s", err)
return nil
})
srv.SetResponseErrorHandler(func(re *errors.Response) {
log.Errorf("internal response error: %s", re.Error)
})
m := &oauthModule{
oauthManager: manager,
oauthServer: srv,
db: db, db: db,
log: log, log: log,
} }
m.oauthServer.SetUserAuthorizationHandler(m.userAuthorizationHandler)
m.oauthServer.SetClientInfoHandler(server.ClientFormHandler)
return m
} }
// Route satisfies the RESTAPIModule interface // Route satisfies the RESTAPIModule interface
func (m *oauthModule) Route(s router.Router) error { func (m *authModule) Route(s router.Router) error {
s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler)
s.AttachHandler(http.MethodGet, authSignInPath, m.signInGETHandler) s.AttachHandler(http.MethodGet, authSignInPath, m.signInGETHandler)
s.AttachHandler(http.MethodPost, authSignInPath, m.signInPOSTHandler) s.AttachHandler(http.MethodPost, authSignInPath, m.signInPOSTHandler)
@ -129,7 +79,6 @@ func (m *oauthModule) Route(s router.Router) error {
s.AttachHandler(http.MethodPost, oauthAuthorizePath, m.authorizePOSTHandler) s.AttachHandler(http.MethodPost, oauthAuthorizePath, m.authorizePOSTHandler)
s.AttachMiddleware(m.oauthTokenMiddleware) s.AttachMiddleware(m.oauthTokenMiddleware)
return nil return nil
} }
@ -137,93 +86,10 @@ func (m *oauthModule) Route(s router.Router) error {
MAIN HANDLERS -- serve these through a server/router MAIN HANDLERS -- serve these through a server/router
*/ */
// appsPOSTHandler should be served at https://example.org/api/v1/apps
// It is equivalent to: https://docs.joinmastodon.org/methods/apps/
func (m *oauthModule) appsPOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "AppsPOSTHandler")
l.Trace("entering AppsPOSTHandler")
form := &mastotypes.ApplicationPOSTRequest{}
if err := c.ShouldBind(form); err != nil {
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
return
}
// permitted length for most fields
permittedLength := 64
// redirect can be a bit bigger because we probably need to encode data in the redirect uri
permittedRedirect := 256
// check lengths of fields before proceeding so the user can't spam huge entries into the database
if len(form.ClientName) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)})
return
}
if len(form.Website) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)})
return
}
if len(form.RedirectURIs) > permittedRedirect {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)})
return
}
if len(form.Scopes) > permittedLength {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)})
return
}
// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/
var scopes string
if form.Scopes == "" {
scopes = "read"
} else {
scopes = form.Scopes
}
// generate new IDs for this application and its associated client
clientID := uuid.NewString()
clientSecret := uuid.NewString()
vapidKey := uuid.NewString()
// generate the application to put in the database
app := &model.Application{
Name: form.ClientName,
Website: form.Website,
RedirectURI: form.RedirectURIs,
ClientID: clientID,
ClientSecret: clientSecret,
Scopes: scopes,
VapidKey: vapidKey,
}
// chuck it in the db
if err := m.db.Put(app); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// now we need to model an oauth client from the application that the oauth library can use
oc := &oauthClient{
ID: clientID,
Secret: clientSecret,
Domain: form.RedirectURIs,
UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now
}
// chuck it in the db
if err := m.db.Put(oc); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/
c.JSON(http.StatusOK, app.ToMasto())
}
// signInGETHandler should be served at https://example.org/auth/sign_in. // signInGETHandler should be served at https://example.org/auth/sign_in.
// The idea is to present a sign in page to the user, where they can enter their username and password. // The idea is to present a sign in page to the user, where they can enter their username and password.
// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler // The form will then POST to the sign in page, which will be handled by SignInPOSTHandler
func (m *oauthModule) signInGETHandler(c *gin.Context) { func (m *authModule) signInGETHandler(c *gin.Context) {
m.log.WithField("func", "SignInGETHandler").Trace("serving sign in html") m.log.WithField("func", "SignInGETHandler").Trace("serving sign in html")
c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{}) c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{})
} }
@ -231,7 +97,7 @@ func (m *oauthModule) signInGETHandler(c *gin.Context) {
// signInPOSTHandler should be served at https://example.org/auth/sign_in. // signInPOSTHandler should be served at https://example.org/auth/sign_in.
// The idea is to present a sign in page to the user, where they can enter their username and password. // The idea is to present a sign in page to the user, where they can enter their username and password.
// The handler will then redirect to the auth handler served at /auth // The handler will then redirect to the auth handler served at /auth
func (m *oauthModule) signInPOSTHandler(c *gin.Context) { func (m *authModule) signInPOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "SignInPOSTHandler") l := m.log.WithField("func", "SignInPOSTHandler")
s := sessions.Default(c) s := sessions.Default(c)
form := &login{} form := &login{}
@ -260,10 +126,10 @@ func (m *oauthModule) signInPOSTHandler(c *gin.Context) {
// tokenPOSTHandler should be served as a POST at https://example.org/oauth/token // tokenPOSTHandler should be served as a POST at https://example.org/oauth/token
// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs. // The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs.
// See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token // See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token
func (m *oauthModule) tokenPOSTHandler(c *gin.Context) { func (m *authModule) tokenPOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "TokenPOSTHandler") l := m.log.WithField("func", "TokenPOSTHandler")
l.Trace("entered TokenPOSTHandler") l.Trace("entered TokenPOSTHandler")
if err := m.oauthServer.HandleTokenRequest(c.Writer, c.Request); err != nil { if err := m.server.HandleTokenRequest(c.Writer, c.Request); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
} }
} }
@ -271,7 +137,7 @@ func (m *oauthModule) tokenPOSTHandler(c *gin.Context) {
// authorizeGETHandler should be served as GET at https://example.org/oauth/authorize // authorizeGETHandler should be served as GET at https://example.org/oauth/authorize
// The idea here is to present an oauth authorize page to the user, with a button // The idea here is to present an oauth authorize page to the user, with a button
// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user // that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
func (m *oauthModule) authorizeGETHandler(c *gin.Context) { func (m *authModule) authorizeGETHandler(c *gin.Context) {
l := m.log.WithField("func", "AuthorizeGETHandler") l := m.log.WithField("func", "AuthorizeGETHandler")
s := sessions.Default(c) s := sessions.Default(c)
@ -349,7 +215,7 @@ func (m *oauthModule) authorizeGETHandler(c *gin.Context) {
// At this point we assume that the user has A) logged in and B) accepted that the app should act for them, // At this point we assume that the user has A) logged in and B) accepted that the app should act for them,
// so we should proceed with the authentication flow and generate an oauth token for them if we can. // so we should proceed with the authentication flow and generate an oauth token for them if we can.
// See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user // See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
func (m *oauthModule) authorizePOSTHandler(c *gin.Context) { func (m *authModule) authorizePOSTHandler(c *gin.Context) {
l := m.log.WithField("func", "AuthorizePOSTHandler") l := m.log.WithField("func", "AuthorizePOSTHandler")
s := sessions.Default(c) s := sessions.Default(c)
@ -404,7 +270,7 @@ func (m *oauthModule) authorizePOSTHandler(c *gin.Context) {
l.Tracef("values on request set to %+v", c.Request.Form) l.Tracef("values on request set to %+v", c.Request.Form)
// and proceed with authorization using the oauth2 library // and proceed with authorization using the oauth2 library
if err := m.oauthServer.HandleAuthorizeRequest(c.Writer, c.Request); err != nil { if err := m.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
} }
} }
@ -418,25 +284,50 @@ func (m *oauthModule) authorizePOSTHandler(c *gin.Context) {
// the request. Then, it will look up the account for that user, and set that in the request too. // the request. Then, it will look up the account for that user, and set that in the request too.
// If user or account can't be found, then the handler won't *fail*, in case the server wants to allow // If user or account can't be found, then the handler won't *fail*, in case the server wants to allow
// public requests that don't have a Bearer token set (eg., for public instance information and so on). // public requests that don't have a Bearer token set (eg., for public instance information and so on).
func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) { func (m *authModule) oauthTokenMiddleware(c *gin.Context) {
l := m.log.WithField("func", "ValidatePassword") l := m.log.WithField("func", "ValidatePassword")
l.Trace("entering OauthTokenMiddleware") l.Trace("entering OauthTokenMiddleware")
ti, err := m.oauthServer.ValidationBearerToken(c.Request) ti, err := m.server.ValidationBearerToken(c.Request)
if err != nil { if err != nil {
l.Trace("no valid token presented: continuing with unauthenticated request") l.Trace("no valid token presented: continuing with unauthenticated request")
return return
} }
l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope()) c.Set(oauth.SessionAuthorizedToken, ti)
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedToken, ti)
acct := &model.Account{} // check for user-level token
if err := m.db.GetAccountByUserID(ti.GetUserID(), acct); err != nil || acct == nil { if uid := ti.GetUserID(); uid != "" {
l.Tracef("no account found for user %s, continuing with unauthenticated request", ti.GetUserID()) l.Tracef("authenticated user %s with bearer token, scope is %s", uid, ti.GetScope())
// fetch user's and account for this user id
user := &model.User{}
if err := m.db.GetByID(uid, user); err != nil || user == nil {
l.Warnf("no user found for validated uid %s", uid)
return return
} }
c.Set(oauth.SessionAuthorizedUser, user)
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user)
c.Set(SessionAuthorizedAccount, acct) acct := &model.Account{}
c.Set(SessionAuthorizedUser, ti.GetUserID()) if err := m.db.GetByID(user.AccountID, acct); err != nil || acct == nil {
l.Warnf("no account found for validated user %s", uid)
return
}
c.Set(oauth.SessionAuthorizedAccount, acct)
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedAccount, acct)
}
// check for application token
if cid := ti.GetClientID(); cid != "" {
l.Tracef("authenticated client %s with bearer token, scope is %s", cid, ti.GetScope())
app := &model.Application{}
if err := m.db.GetWhere("client_id", cid, app); err != nil {
l.Tracef("no app found for client %s", cid)
}
c.Set(oauth.SessionAuthorizedApplication, app)
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedApplication, app)
}
} }
/* /*
@ -447,7 +338,7 @@ func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) {
// The goal is to authenticate the password against the one for that email // The goal is to authenticate the password against the one for that email
// address stored in the database. If OK, we return the userid (a uuid) for that user, // address stored in the database. If OK, we return the userid (a uuid) for that user,
// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db. // so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
func (m *oauthModule) validatePassword(email string, password string) (userid string, err error) { func (m *authModule) validatePassword(email string, password string) (userid string, err error) {
l := m.log.WithField("func", "ValidatePassword") l := m.log.WithField("func", "ValidatePassword")
// make sure an email/password was provided and bail if not // make sure an email/password was provided and bail if not
@ -487,18 +378,6 @@ func incorrectPassword() (string, error) {
return "", errors.New("password/email combination was incorrect") return "", errors.New("password/email combination was incorrect")
} }
// userAuthorizationHandler gets the user's ID from the 'userid' field of the request form,
// or redirects to the /auth/sign_in page, if this key is not present.
func (m *oauthModule) userAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) {
l := m.log.WithField("func", "UserAuthorizationHandler")
userID = r.FormValue("userid")
if userID == "" {
return "", errors.New("userid was empty, redirecting to sign in page")
}
l.Tracef("returning userID %s", userID)
return userID, err
}
// parseAuthForm parses the OAuthAuthorize form in the gin context, and stores // parseAuthForm parses the OAuthAuthorize form in the gin context, and stores
// the values in the form into the session. // the values in the form into the session.
func parseAuthForm(c *gin.Context, l *logrus.Entry) error { func parseAuthForm(c *gin.Context, l *logrus.Entry) error {

View file

@ -16,38 +16,38 @@
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package oauth package auth
import ( import (
"context" "context"
"fmt" "fmt"
"testing" "testing"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/gotosocial/gotosocial/internal/config" "github.com/gotosocial/gotosocial/internal/config"
"github.com/gotosocial/gotosocial/internal/db" "github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/db/model" "github.com/gotosocial/gotosocial/internal/db/model"
"github.com/gotosocial/gotosocial/internal/oauth"
"github.com/gotosocial/gotosocial/internal/router" "github.com/gotosocial/gotosocial/internal/router"
"github.com/gotosocial/oauth2/v4"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
type OauthTestSuite struct { type AuthTestSuite struct {
suite.Suite suite.Suite
tokenStore oauth2.TokenStore oauthServer oauth.Server
clientStore oauth2.ClientStore
db db.DB db db.DB
testAccount *model.Account testAccount *model.Account
testApplication *model.Application testApplication *model.Application
testUser *model.User testUser *model.User
testClient *oauthClient testClient *oauth.Client
config *config.Config config *config.Config
} }
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout // SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
func (suite *OauthTestSuite) SetupSuite() { func (suite *AuthTestSuite) SetupSuite() {
c := config.Empty() c := config.Empty()
// we're running on localhost without https so set the protocol to http // we're running on localhost without https so set the protocol to http
c.Protocol = "http" c.Protocol = "http"
@ -84,7 +84,7 @@ func (suite *OauthTestSuite) SetupSuite() {
Email: "user@example.org", Email: "user@example.org",
AccountID: acctID, AccountID: acctID,
} }
suite.testClient = &oauthClient{ suite.testClient = &oauth.Client{
ID: "a-known-client-id", ID: "a-known-client-id",
Secret: "some-secret", Secret: "some-secret",
Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host), Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host),
@ -101,7 +101,7 @@ func (suite *OauthTestSuite) SetupSuite() {
} }
// SetupTest creates a postgres connection and creates the oauth_clients table before each test // SetupTest creates a postgres connection and creates the oauth_clients table before each test
func (suite *OauthTestSuite) SetupTest() { func (suite *AuthTestSuite) SetupTest() {
log := logrus.New() log := logrus.New()
log.SetLevel(logrus.TraceLevel) log.SetLevel(logrus.TraceLevel)
@ -113,8 +113,8 @@ func (suite *OauthTestSuite) SetupTest() {
suite.db = db suite.db = db
models := []interface{}{ models := []interface{}{
&oauthClient{}, &oauth.Client{},
&oauthToken{}, &oauth.Token{},
&model.User{}, &model.User{},
&model.Account{}, &model.Account{},
&model.Application{}, &model.Application{},
@ -126,8 +126,7 @@ func (suite *OauthTestSuite) SetupTest() {
} }
} }
suite.tokenStore = newTokenStore(context.Background(), suite.db, logrus.New()) suite.oauthServer = oauth.New(suite.db, log)
suite.clientStore = newClientStore(suite.db)
if err := suite.db.Put(suite.testAccount); err != nil { if err := suite.db.Put(suite.testAccount); err != nil {
logrus.Panicf("could not insert test account into db: %s", err) logrus.Panicf("could not insert test account into db: %s", err)
@ -145,10 +144,10 @@ func (suite *OauthTestSuite) SetupTest() {
} }
// TearDownTest drops the oauth_clients table and closes the pg connection after each test // TearDownTest drops the oauth_clients table and closes the pg connection after each test
func (suite *OauthTestSuite) TearDownTest() { func (suite *AuthTestSuite) TearDownTest() {
models := []interface{}{ models := []interface{}{
&oauthClient{}, &oauth.Client{},
&oauthToken{}, &oauth.Token{},
&model.User{}, &model.User{},
&model.Account{}, &model.Account{},
&model.Application{}, &model.Application{},
@ -164,7 +163,7 @@ func (suite *OauthTestSuite) TearDownTest() {
suite.db = nil suite.db = nil
} }
func (suite *OauthTestSuite) TestAPIInitialize() { func (suite *AuthTestSuite) TestAPIInitialize() {
log := logrus.New() log := logrus.New()
log.SetLevel(logrus.TraceLevel) log.SetLevel(logrus.TraceLevel)
@ -173,17 +172,18 @@ func (suite *OauthTestSuite) TestAPIInitialize() {
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err)) suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
} }
api := New(suite.tokenStore, suite.clientStore, suite.db, log) api := New(suite.oauthServer, suite.db, log)
if err := api.Route(r); err != nil { if err := api.Route(r); err != nil {
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err)) suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
} }
r.Start() r.Start()
time.Sleep(60 * time.Second)
if err := r.Stop(context.Background()); err != nil { if err := r.Stop(context.Background()); err != nil {
suite.FailNow(fmt.Sprintf("error stopping router: %s", err)) suite.FailNow(fmt.Sprintf("error stopping router: %s", err))
} }
} }
func TestOauthTestSuite(t *testing.T) { func TestAuthTestSuite(t *testing.T) {
suite.Run(t, new(OauthTestSuite)) suite.Run(t, new(AuthTestSuite))
} }

View file

@ -38,7 +38,7 @@ func newClientStore(db db.DB) oauth2.ClientStore {
} }
func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
poc := &oauthClient{ poc := &Client{
ID: clientID, ID: clientID,
} }
if err := cs.db.GetByID(clientID, poc); err != nil { if err := cs.db.GetByID(clientID, poc); err != nil {
@ -48,7 +48,7 @@ func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.Cli
} }
func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error { func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error {
poc := &oauthClient{ poc := &Client{
ID: cli.GetID(), ID: cli.GetID(),
Secret: cli.GetSecret(), Secret: cli.GetSecret(),
Domain: cli.GetDomain(), Domain: cli.GetDomain(),
@ -58,13 +58,13 @@ func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo
} }
func (cs *clientStore) Delete(ctx context.Context, id string) error { func (cs *clientStore) Delete(ctx context.Context, id string) error {
poc := &oauthClient{ poc := &Client{
ID: id, ID: id,
} }
return cs.db.DeleteByID(id, poc) return cs.db.DeleteByID(id, poc)
} }
type oauthClient struct { type Client struct {
ID string ID string
Secret string Secret string
Domain string Domain string

View file

@ -69,7 +69,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() {
suite.db = db suite.db = db
models := []interface{}{ models := []interface{}{
&oauthClient{}, &Client{},
} }
for _, m := range models { for _, m := range models {
@ -82,7 +82,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() {
// TearDownTest drops the oauth_clients table and closes the pg connection after each test // TearDownTest drops the oauth_clients table and closes the pg connection after each test
func (suite *PgClientStoreTestSuite) TearDownTest() { func (suite *PgClientStoreTestSuite) TearDownTest() {
models := []interface{}{ models := []interface{}{
&oauthClient{}, &Client{},
} }
for _, m := range models { for _, m := range models {
if err := suite.db.DropTable(m); err != nil { if err := suite.db.DropTable(m); err != nil {

212
internal/oauth/oauth.go Normal file
View file

@ -0,0 +1,212 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package oauth
import (
"context"
"net/http"
"github.com/gin-gonic/gin"
"github.com/gotosocial/gotosocial/internal/db"
"github.com/gotosocial/gotosocial/internal/db/model"
"github.com/gotosocial/oauth2/v4"
"github.com/gotosocial/oauth2/v4/errors"
"github.com/gotosocial/oauth2/v4/manage"
"github.com/gotosocial/oauth2/v4/server"
"github.com/sirupsen/logrus"
)
const (
SessionAuthorizedToken = "authorized_token"
// SessionAuthorizedUser is the key set in the gin context for the id of
// a User who has successfully passed Bearer token authorization.
// The interface returned from grabbing this key should be parsed as a *gtsmodel.User
SessionAuthorizedUser = "authorized_user"
// SessionAuthorizedAccount is the key set in the gin context for the Account
// of a User who has successfully passed Bearer token authorization.
// The interface returned from grabbing this key should be parsed as a *gtsmodel.Account
SessionAuthorizedAccount = "authorized_account"
// SessionAuthorizedAccount is the key set in the gin context for the Application
// of a Client who has successfully passed Bearer token authorization.
// The interface returned from grabbing this key should be parsed as a *gtsmodel.Application
SessionAuthorizedApplication = "authorized_app"
)
// Server wraps some oauth2 server functions in an interface, exposing only what is needed
type Server interface {
HandleTokenRequest(w http.ResponseWriter, r *http.Request) error
HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error
ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error)
GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error)
}
// s fulfils the Server interface using the underlying oauth2 server
type s struct {
server *server.Server
log *logrus.Logger
}
type Authed struct {
Token oauth2.TokenInfo
Application *model.Application
User *model.User
Account *model.Account
}
// GetAuthed is a convenience function for returning an Authed struct from a gin context.
// In essence, it tries to extract a token, application, user, and account from the context,
// and then sets them on a struct for convenience.
//
// If any are not present in the context, they will be set to nil on the returned Authed struct.
//
// If *ALL* are not present, then nil and an error will be returned.
//
// If something goes wrong during parsing, then nil and an error will be returned (consider this not authed).
func GetAuthed(c *gin.Context) (*Authed, error) {
ctx := c.Copy()
a := &Authed{}
var i interface{}
var ok bool
i, ok = ctx.Get(SessionAuthorizedToken)
if ok {
parsed, ok := i.(oauth2.TokenInfo)
if !ok {
return nil, errors.New("could not parse token from session context")
}
a.Token = parsed
}
i, ok = ctx.Get(SessionAuthorizedApplication)
if ok {
parsed, ok := i.(*model.Application)
if !ok {
return nil, errors.New("could not parse application from session context")
}
a.Application = parsed
}
i, ok = ctx.Get(SessionAuthorizedUser)
if ok {
parsed, ok := i.(*model.User)
if !ok {
return nil, errors.New("could not parse user from session context")
}
a.User = parsed
}
i, ok = ctx.Get(SessionAuthorizedAccount)
if ok {
parsed, ok := i.(*model.Account)
if !ok {
return nil, errors.New("could not parse account from session context")
}
a.Account = parsed
}
if a.Token == nil && a.Application == nil && a.User == nil && a.Account == nil {
return nil, errors.New("not authorized")
}
return a, nil
}
// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function
func (s *s) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error {
return s.server.HandleTokenRequest(w, r)
}
// HandleAuthorizeRequest wraps the oauth2 library's HandleAuthorizeRequest function
func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error {
return s.server.HandleAuthorizeRequest(w, r)
}
// ValidationBearerToken wraps the oauth2 library's ValidationBearerToken function
func (s *s) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
return s.server.ValidationBearerToken(r)
}
// GenerateUserAccessToken shortcuts the normal oauth flow to create an user-level
// bearer token *without* requiring that user to log in. This is useful when we
// need to create a token for new users who haven't validated their email or logged in yet.
//
// The ti parameter refers to an existing Application token that was used to make the upstream
// request. This token needs to be validated and exist in database in order to create a new token.
func (s *s) GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error) {
tgr := &oauth2.TokenGenerateRequest{
ClientID: ti.GetClientID(),
ClientSecret: clientSecret,
UserID: userID,
RedirectURI: ti.GetRedirectURI(),
Scope: ti.GetScope(),
Code: ti.GetCode(),
CodeChallenge: ti.GetCodeChallenge(),
CodeChallengeMethod: ti.GetCodeChallengeMethod(),
}
return s.server.Manager.GenerateAccessToken(context.Background(), oauth2.AuthorizationCode, tgr)
}
func New(database db.DB, log *logrus.Logger) Server {
ts := newTokenStore(context.Background(), database, log)
cs := newClientStore(database)
manager := manage.NewDefaultManager()
manager.MapTokenStorage(ts)
manager.MapClientStorage(cs)
manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
sc := &server.Config{
TokenType: "Bearer",
// Must follow the spec.
AllowGetAccessRequest: false,
// Support only the non-implicit flow.
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
// Allow:
// - Authorization Code (for first & third parties)
// - Client Credentials (for applications)
AllowedGrantTypes: []oauth2.GrantType{
oauth2.AuthorizationCode,
oauth2.ClientCredentials,
},
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain},
}
srv := server.NewServer(sc, manager)
srv.SetInternalErrorHandler(func(err error) *errors.Response {
log.Errorf("internal oauth error: %s", err)
return nil
})
srv.SetResponseErrorHandler(func(re *errors.Response) {
log.Errorf("internal response error: %s", re.Error)
})
srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) {
userID := r.FormValue("userid")
if userID == "" {
return "", errors.New("userid was empty")
}
return userID, nil
})
srv.SetClientInfoHandler(server.ClientFormHandler)
return &s{
server: srv,
}
}

View file

@ -0,0 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package oauth
// TODO: write tests

View file

@ -70,7 +70,7 @@ func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.Tok
func (pts *tokenStore) sweep() error { func (pts *tokenStore) sweep() error {
// select *all* tokens from the db // select *all* tokens from the db
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. // todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
tokens := new([]*oauthToken) tokens := new([]*Token)
if err := pts.db.GetAll(tokens); err != nil { if err := pts.db.GetAll(tokens); err != nil {
return err return err
} }
@ -106,22 +106,22 @@ func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error
// RemoveByCode deletes a token from the DB based on the Code field // RemoveByCode deletes a token from the DB based on the Code field
func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error { func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
return pts.db.DeleteWhere("code", code, &oauthToken{}) return pts.db.DeleteWhere("code", code, &Token{})
} }
// RemoveByAccess deletes a token from the DB based on the Access field // RemoveByAccess deletes a token from the DB based on the Access field
func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
return pts.db.DeleteWhere("access", access, &oauthToken{}) return pts.db.DeleteWhere("access", access, &Token{})
} }
// RemoveByRefresh deletes a token from the DB based on the Refresh field // RemoveByRefresh deletes a token from the DB based on the Refresh field
func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
return pts.db.DeleteWhere("refresh", refresh, &oauthToken{}) return pts.db.DeleteWhere("refresh", refresh, &Token{})
} }
// GetByCode selects a token from the DB based on the Code field // GetByCode selects a token from the DB based on the Code field
func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
pgt := &oauthToken{ pgt := &Token{
Code: code, Code: code,
} }
if err := pts.db.GetWhere("code", code, pgt); err != nil { if err := pts.db.GetWhere("code", code, pgt); err != nil {
@ -132,7 +132,7 @@ func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.Token
// GetByAccess selects a token from the DB based on the Access field // GetByAccess selects a token from the DB based on the Access field
func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
pgt := &oauthToken{ pgt := &Token{
Access: access, Access: access,
} }
if err := pts.db.GetWhere("access", access, pgt); err != nil { if err := pts.db.GetWhere("access", access, pgt); err != nil {
@ -143,7 +143,7 @@ func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.T
// GetByRefresh selects a token from the DB based on the Refresh field // GetByRefresh selects a token from the DB based on the Refresh field
func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
pgt := &oauthToken{ pgt := &Token{
Refresh: refresh, Refresh: refresh,
} }
if err := pts.db.GetWhere("refresh", refresh, pgt); err != nil { if err := pts.db.GetWhere("refresh", refresh, pgt); err != nil {
@ -156,7 +156,7 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2
The following models are basically helpers for the postgres token store implementation, they should only be used internally. The following models are basically helpers for the postgres token store implementation, they should only be used internally.
*/ */
// oauthToken is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt. // Token is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt.
// //
// Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined, // Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined,
// and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and // and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and
@ -164,9 +164,9 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2
// //
// Note that this struct does *not* satisfy the token interface shown here: https://github.com/gotosocial/oauth2/blob/master/model.go#L22 // Note that this struct does *not* satisfy the token interface shown here: https://github.com/gotosocial/oauth2/blob/master/model.go#L22
// and implemented here: https://github.com/gotosocial/oauth2/blob/master/models/token.go. // and implemented here: https://github.com/gotosocial/oauth2/blob/master/models/token.go.
// As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken // As such, manual translation is always required between Token and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
// and pgTokenToOauthToken can be used for that. // and pgTokenToOauthToken can be used for that.
type oauthToken struct { type Token struct {
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"` ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"`
ClientID string ClientID string
UserID string UserID string
@ -186,7 +186,7 @@ type oauthToken struct {
} }
// oauthTokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres // oauthTokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres
func oauthTokenToPGToken(tkn *models.Token) *oauthToken { func oauthTokenToPGToken(tkn *models.Token) *Token {
now := time.Now() now := time.Now()
// For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's // For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's
@ -208,7 +208,7 @@ func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
rea = now.Add(tkn.RefreshExpiresIn) rea = now.Add(tkn.RefreshExpiresIn)
} }
return &oauthToken{ return &Token{
ClientID: tkn.ClientID, ClientID: tkn.ClientID,
UserID: tkn.UserID, UserID: tkn.UserID,
RedirectURI: tkn.RedirectURI, RedirectURI: tkn.RedirectURI,
@ -228,7 +228,7 @@ func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
} }
// pgTokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token // pgTokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token
func pgTokenToOauthToken(pgt *oauthToken) *models.Token { func pgTokenToOauthToken(pgt *Token) *models.Token {
now := time.Now() now := time.Now()
return &models.Token{ return &models.Token{

View file

@ -0,0 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package oauth
// TODO: write tests

View file

@ -36,7 +36,7 @@ import (
// Router provides the REST interface for gotosocial, using gin. // Router provides the REST interface for gotosocial, using gin.
type Router interface { type Router interface {
// Attach a gin handler to the router with the given method and path // Attach a gin handler to the router with the given method and path
AttachHandler(method string, path string, handler gin.HandlerFunc) AttachHandler(method string, path string, f gin.HandlerFunc)
// Attach a gin middleware to the router that will be used globally // Attach a gin middleware to the router that will be used globally
AttachMiddleware(handler gin.HandlerFunc) AttachMiddleware(handler gin.HandlerFunc)
// Start the router // Start the router
@ -59,6 +59,8 @@ func (r *router) Start() {
r.logger.Fatalf("listen: %s", err) r.logger.Fatalf("listen: %s", err)
} }
}() }()
// c := &gin.Context{}
// c.Get()
} }
// Stop shuts down the router nicely // Stop shuts down the router nicely

31
pkg/mastotypes/token.go Normal file
View file

@ -0,0 +1,31 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package mastotypes
// Token represents an OAuth token used for authenticating with the API and performing actions.. See https://docs.joinmastodon.org/entities/token/
type Token struct {
// An OAuth token to be used for authorization.
AccessToken string `json:"access_token"`
// The OAuth token type. Mastodon uses Bearer tokens.
TokenType string `json:"token_type"`
// The OAuth scopes granted by this token, space-separated.
Scope string `json:"scope"`
// When the token was generated. (UNIX timestamp seconds)
CreatedAt int64 `json:"created_at"`
}