Pg to bun (#148)

* start moving to bun

* changing more stuff

* more

* and yet more

* tests passing

* seems stable now

* more big changes

* small fix

* little fixes
This commit is contained in:
tobi 2021-08-25 15:34:33 +02:00 committed by GitHub
commit 2dc9fc1626
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
713 changed files with 98694 additions and 22704 deletions

View file

@ -43,13 +43,13 @@ type tokenStore struct {
// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
// the tokens in the DB once per minute and deletes any that have expired.
func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2.TokenStore {
pts := &tokenStore{
ts := &tokenStore{
db: db,
log: log,
}
// set the token store to clean out expired tokens once per minute, or return if we're done
go func(ctx context.Context, pts *tokenStore, log *logrus.Logger) {
go func(ctx context.Context, ts *tokenStore, log *logrus.Logger) {
cleanloop:
for {
select {
@ -58,32 +58,32 @@ func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2.
break cleanloop
case <-time.After(1 * time.Minute):
log.Trace("sweeping out old oauth entries broom broom")
if err := pts.sweep(); err != nil {
if err := ts.sweep(ctx); err != nil {
log.Errorf("error while sweeping oauth entries: %s", err)
}
}
}
}(ctx, pts, log)
return pts
}(ctx, ts, log)
return ts
}
// sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so.
func (pts *tokenStore) sweep() error {
func (ts *tokenStore) sweep(ctx context.Context) error {
// select *all* tokens from the db
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
tokens := new([]*Token)
if err := pts.db.GetAll(tokens); err != nil {
if err := ts.db.GetAll(ctx, tokens); err != nil {
return err
}
// iterate through and remove expired tokens
now := time.Now()
for _, pgt := range *tokens {
for _, dbt := range *tokens {
// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So:
// we only want to check if a token expired before now if the expiry time is *not zero*;
// ie., if it's been explicity set.
if !pgt.CodeExpiresAt.IsZero() && pgt.CodeExpiresAt.Before(now) || !pgt.RefreshExpiresAt.IsZero() && pgt.RefreshExpiresAt.Before(now) || !pgt.AccessExpiresAt.IsZero() && pgt.AccessExpiresAt.Before(now) {
if err := pts.db.DeleteByID(pgt.ID, pgt); err != nil {
if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) {
if err := ts.db.DeleteByID(ctx, dbt.ID, dbt); err != nil {
return err
}
}
@ -94,92 +94,92 @@ func (pts *tokenStore) sweep() error {
// Create creates and store the new token information.
// For the original implementation, see https://github.com/superseriousbusiness/oauth2/blob/master/store/token.go#L34
func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
t, ok := info.(*models.Token)
if !ok {
return errors.New("info param was not a models.Token")
}
pgt := TokenToPGToken(t)
if pgt.ID == "" {
pgtID, err := id.NewRandomULID()
dbt := TokenToDBToken(t)
if dbt.ID == "" {
dbtID, err := id.NewRandomULID()
if err != nil {
return err
}
pgt.ID = pgtID
dbt.ID = dbtID
}
if err := pts.db.Put(pgt); err != nil {
if err := ts.db.Put(ctx, dbt); err != nil {
return fmt.Errorf("error in tokenstore create: %s", err)
}
return nil
}
// RemoveByCode deletes a token from the DB based on the Code field
func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
return pts.db.DeleteWhere([]db.Where{{Key: "code", Value: code}}, &Token{})
func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
return ts.db.DeleteWhere(ctx, []db.Where{{Key: "code", Value: code}}, &Token{})
}
// RemoveByAccess deletes a token from the DB based on the Access field
func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
return pts.db.DeleteWhere([]db.Where{{Key: "access", Value: access}}, &Token{})
func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
return ts.db.DeleteWhere(ctx, []db.Where{{Key: "access", Value: access}}, &Token{})
}
// RemoveByRefresh deletes a token from the DB based on the Refresh field
func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
return pts.db.DeleteWhere([]db.Where{{Key: "refresh", Value: refresh}}, &Token{})
func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
return ts.db.DeleteWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, &Token{})
}
// GetByCode selects a token from the DB based on the Code field
func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
if code == "" {
return nil, nil
}
pgt := &Token{
dbt := &Token{
Code: code,
}
if err := pts.db.GetWhere([]db.Where{{Key: "code", Value: code}}, pgt); err != nil {
if err := ts.db.GetWhere(ctx, []db.Where{{Key: "code", Value: code}}, dbt); err != nil {
return nil, err
}
return TokenToOauthToken(pgt), nil
return DBTokenToToken(dbt), nil
}
// GetByAccess selects a token from the DB based on the Access field
func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
if access == "" {
return nil, nil
}
pgt := &Token{
dbt := &Token{
Access: access,
}
if err := pts.db.GetWhere([]db.Where{{Key: "access", Value: access}}, pgt); err != nil {
if err := ts.db.GetWhere(ctx, []db.Where{{Key: "access", Value: access}}, dbt); err != nil {
return nil, err
}
return TokenToOauthToken(pgt), nil
return DBTokenToToken(dbt), nil
}
// GetByRefresh selects a token from the DB based on the Refresh field
func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
if refresh == "" {
return nil, nil
}
pgt := &Token{
dbt := &Token{
Refresh: refresh,
}
if err := pts.db.GetWhere([]db.Where{{Key: "refresh", Value: refresh}}, pgt); err != nil {
if err := ts.db.GetWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, dbt); err != nil {
return nil, err
}
return TokenToOauthToken(pgt), nil
return DBTokenToToken(dbt), nil
}
/*
The following models are basically helpers for the postgres token store implementation, they should only be used internally.
The following models are basically helpers for the token store implementation, they should only be used internally.
*/
// Token is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt.
//
// Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined,
// and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and
// and tokens with expired TTLs are automatically removed. Since some databases don't have that feature, it's easier to set an expiry time and
// then periodically sweep out tokens when that time has passed.
//
// Note that this struct does *not* satisfy the token interface shown here: https://github.com/superseriousbusiness/oauth2/blob/master/model.go#L22
@ -187,26 +187,26 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2
// As such, manual translation is always required between Token and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
// and pgTokenToOauthToken can be used for that.
type Token struct {
ID string `pg:"type:CHAR(26),pk,notnull"`
ID string `bun:"type:CHAR(26),pk,notnull"`
ClientID string
UserID string
RedirectURI string
Scope string
Code string `pg:"default:'',pk"`
Code string `bun:"default:'',pk"`
CodeChallenge string
CodeChallengeMethod string
CodeCreateAt time.Time `pg:"type:timestamp"`
CodeExpiresAt time.Time `pg:"type:timestamp"`
Access string `pg:"default:'',pk"`
AccessCreateAt time.Time `pg:"type:timestamp"`
AccessExpiresAt time.Time `pg:"type:timestamp"`
Refresh string `pg:"default:'',pk"`
RefreshCreateAt time.Time `pg:"type:timestamp"`
RefreshExpiresAt time.Time `pg:"type:timestamp"`
CodeCreateAt time.Time `bun:",nullzero"`
CodeExpiresAt time.Time `bun:",nullzero"`
Access string `bun:"default:'',pk"`
AccessCreateAt time.Time `bun:",nullzero"`
AccessExpiresAt time.Time `bun:",nullzero"`
Refresh string `bun:"default:'',pk"`
RefreshCreateAt time.Time `bun:",nullzero"`
RefreshExpiresAt time.Time `bun:",nullzero"`
}
// TokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres
func TokenToPGToken(tkn *models.Token) *Token {
// TokenToDBToken is a lil util function that takes a gotosocial token and gives back a token for inserting into a database.
func TokenToDBToken(tkn *models.Token) *Token {
now := time.Now()
// For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's
@ -247,40 +247,40 @@ func TokenToPGToken(tkn *models.Token) *Token {
}
}
// TokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token
func TokenToOauthToken(pgt *Token) *models.Token {
// DBTokenToToken is a lil util function that takes a database token and gives back a gotosocial token
func DBTokenToToken(dbt *Token) *models.Token {
now := time.Now()
var codeExpiresIn time.Duration
if !pgt.CodeExpiresAt.IsZero() {
codeExpiresIn = pgt.CodeExpiresAt.Sub(now)
if !dbt.CodeExpiresAt.IsZero() {
codeExpiresIn = dbt.CodeExpiresAt.Sub(now)
}
var accessExpiresIn time.Duration
if !pgt.AccessExpiresAt.IsZero() {
accessExpiresIn = pgt.AccessExpiresAt.Sub(now)
if !dbt.AccessExpiresAt.IsZero() {
accessExpiresIn = dbt.AccessExpiresAt.Sub(now)
}
var refreshExpiresIn time.Duration
if !pgt.RefreshExpiresAt.IsZero() {
refreshExpiresIn = pgt.RefreshExpiresAt.Sub(now)
if !dbt.RefreshExpiresAt.IsZero() {
refreshExpiresIn = dbt.RefreshExpiresAt.Sub(now)
}
return &models.Token{
ClientID: pgt.ClientID,
UserID: pgt.UserID,
RedirectURI: pgt.RedirectURI,
Scope: pgt.Scope,
Code: pgt.Code,
CodeChallenge: pgt.CodeChallenge,
CodeChallengeMethod: pgt.CodeChallengeMethod,
CodeCreateAt: pgt.CodeCreateAt,
ClientID: dbt.ClientID,
UserID: dbt.UserID,
RedirectURI: dbt.RedirectURI,
Scope: dbt.Scope,
Code: dbt.Code,
CodeChallenge: dbt.CodeChallenge,
CodeChallengeMethod: dbt.CodeChallengeMethod,
CodeCreateAt: dbt.CodeCreateAt,
CodeExpiresIn: codeExpiresIn,
Access: pgt.Access,
AccessCreateAt: pgt.AccessCreateAt,
Access: dbt.Access,
AccessCreateAt: dbt.AccessCreateAt,
AccessExpiresIn: accessExpiresIn,
Refresh: pgt.Refresh,
RefreshCreateAt: pgt.RefreshCreateAt,
Refresh: dbt.Refresh,
RefreshCreateAt: dbt.RefreshCreateAt,
RefreshExpiresIn: refreshExpiresIn,
}
}