mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-30 18:42:26 -05:00 
			
		
		
		
	* start moving to bun * changing more stuff * more * and yet more * tests passing * seems stable now * more big changes * small fix * little fixes
		
			
				
	
	
		
			266 lines
		
	
	
	
		
			7.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			266 lines
		
	
	
	
		
			7.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // SCRAM-SHA-256 authentication
 | |
| //
 | |
| // Resources:
 | |
| //   https://tools.ietf.org/html/rfc5802
 | |
| //   https://tools.ietf.org/html/rfc8265
 | |
| //   https://www.postgresql.org/docs/current/sasl-authentication.html
 | |
| //
 | |
| // Inspiration drawn from other implementations:
 | |
| //   https://github.com/lib/pq/pull/608
 | |
| //   https://github.com/lib/pq/pull/788
 | |
| //   https://github.com/lib/pq/pull/833
 | |
| 
 | |
| package pgconn
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"crypto/hmac"
 | |
| 	"crypto/rand"
 | |
| 	"crypto/sha256"
 | |
| 	"encoding/base64"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"strconv"
 | |
| 
 | |
| 	"github.com/jackc/pgproto3/v2"
 | |
| 	"golang.org/x/crypto/pbkdf2"
 | |
| 	"golang.org/x/text/secure/precis"
 | |
| )
 | |
| 
 | |
const (
	// clientNonceLen is the number of random bytes drawn for the SCRAM client
	// nonce. Being a multiple of 3, it base64-encodes to exactly 24 characters.
	clientNonceLen = 18
)
 | |
| 
 | |
| // Perform SCRAM authentication.
 | |
| func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
 | |
| 	sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// Send client-first-message in a SASLInitialResponse
 | |
| 	saslInitialResponse := &pgproto3.SASLInitialResponse{
 | |
| 		AuthMechanism: "SCRAM-SHA-256",
 | |
| 		Data:          sc.clientFirstMessage(),
 | |
| 	}
 | |
| 	_, err = c.conn.Write(saslInitialResponse.Encode(nil))
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// Receive server-first-message payload in a AuthenticationSASLContinue.
 | |
| 	saslContinue, err := c.rxSASLContinue()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	err = sc.recvServerFirstMessage(saslContinue.Data)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// Send client-final-message in a SASLResponse
 | |
| 	saslResponse := &pgproto3.SASLResponse{
 | |
| 		Data: []byte(sc.clientFinalMessage()),
 | |
| 	}
 | |
| 	_, err = c.conn.Write(saslResponse.Encode(nil))
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// Receive server-final-message payload in a AuthenticationSASLFinal.
 | |
| 	saslFinal, err := c.rxSASLFinal()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	return sc.recvServerFinalMessage(saslFinal.Data)
 | |
| }
 | |
| 
 | |
| func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
 | |
| 	msg, err := c.receiveMessage()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue)
 | |
| 	if ok {
 | |
| 		return saslContinue, nil
 | |
| 	}
 | |
| 
 | |
| 	return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message")
 | |
| }
 | |
| 
 | |
| func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
 | |
| 	msg, err := c.receiveMessage()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal)
 | |
| 	if ok {
 | |
| 		return saslFinal, nil
 | |
| 	}
 | |
| 
 | |
| 	return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message")
 | |
| }
 | |
| 
 | |
// scramClient holds the client-side state of a single SCRAM-SHA-256
// exchange. Its fields are filled in step by step as the exchange proceeds.
type scramClient struct {
	// Set by newScramClient.
	serverAuthMechanisms []string // SASL mechanisms advertised by the server
	password             []byte   // SASLprep-normalized password, or the raw bytes if normalization failed
	clientNonce          []byte   // random client nonce, unpadded standard base64

	// Set by clientFirstMessage.
	clientFirstMessageBare []byte // client-first-message without the "n,," GS2 header

	// Set by recvServerFirstMessage.
	serverFirstMessage   []byte // server-first-message as received; part of the AuthMessage
	clientAndServerNonce []byte // the value of r=: client nonce followed by the server's nonce
	salt                 []byte // base64-decoded value of s=
	iterations           int    // value of i=, used as the PBKDF2 iteration count

	// Set by clientFinalMessage.
	saltedPassword []byte // PBKDF2-HMAC-SHA-256(password, salt, iterations), 32 bytes
	authMessage    []byte // client-first-bare "," server-first "," client-final-without-proof
}
 | |
| 
 | |
| func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
 | |
| 	sc := &scramClient{
 | |
| 		serverAuthMechanisms: serverAuthMechanisms,
 | |
| 	}
 | |
| 
 | |
| 	// Ensure server supports SCRAM-SHA-256
 | |
| 	hasScramSHA256 := false
 | |
| 	for _, mech := range sc.serverAuthMechanisms {
 | |
| 		if mech == "SCRAM-SHA-256" {
 | |
| 			hasScramSHA256 = true
 | |
| 			break
 | |
| 		}
 | |
| 	}
 | |
| 	if !hasScramSHA256 {
 | |
| 		return nil, errors.New("server does not support SCRAM-SHA-256")
 | |
| 	}
 | |
| 
 | |
| 	// precis.OpaqueString is equivalent to SASLprep for password.
 | |
| 	var err error
 | |
| 	sc.password, err = precis.OpaqueString.Bytes([]byte(password))
 | |
| 	if err != nil {
 | |
| 		// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
 | |
| 		sc.password = []byte(password)
 | |
| 	}
 | |
| 
 | |
| 	buf := make([]byte, clientNonceLen)
 | |
| 	_, err = rand.Read(buf)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
 | |
| 	base64.RawStdEncoding.Encode(sc.clientNonce, buf)
 | |
| 
 | |
| 	return sc, nil
 | |
| }
 | |
| 
 | |
| func (sc *scramClient) clientFirstMessage() []byte {
 | |
| 	sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
 | |
| 	return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
 | |
| }
 | |
| 
 | |
| func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
 | |
| 	sc.serverFirstMessage = serverFirstMessage
 | |
| 	buf := serverFirstMessage
 | |
| 	if !bytes.HasPrefix(buf, []byte("r=")) {
 | |
| 		return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
 | |
| 	}
 | |
| 	buf = buf[2:]
 | |
| 
 | |
| 	idx := bytes.IndexByte(buf, ',')
 | |
| 	if idx == -1 {
 | |
| 		return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
 | |
| 	}
 | |
| 	sc.clientAndServerNonce = buf[:idx]
 | |
| 	buf = buf[idx+1:]
 | |
| 
 | |
| 	if !bytes.HasPrefix(buf, []byte("s=")) {
 | |
| 		return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
 | |
| 	}
 | |
| 	buf = buf[2:]
 | |
| 
 | |
| 	idx = bytes.IndexByte(buf, ',')
 | |
| 	if idx == -1 {
 | |
| 		return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
 | |
| 	}
 | |
| 	saltStr := buf[:idx]
 | |
| 	buf = buf[idx+1:]
 | |
| 
 | |
| 	if !bytes.HasPrefix(buf, []byte("i=")) {
 | |
| 		return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
 | |
| 	}
 | |
| 	buf = buf[2:]
 | |
| 	iterationsStr := buf
 | |
| 
 | |
| 	var err error
 | |
| 	sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	sc.iterations, err = strconv.Atoi(string(iterationsStr))
 | |
| 	if err != nil || sc.iterations <= 0 {
 | |
| 		return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
 | |
| 		return errors.New("invalid SCRAM nonce: did not start with client nonce")
 | |
| 	}
 | |
| 
 | |
| 	if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
 | |
| 		return errors.New("invalid SCRAM nonce: did not include server nonce")
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (sc *scramClient) clientFinalMessage() string {
 | |
| 	clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
 | |
| 
 | |
| 	sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
 | |
| 	sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
 | |
| 
 | |
| 	clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
 | |
| 
 | |
| 	return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
 | |
| }
 | |
| 
 | |
| func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
 | |
| 	if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
 | |
| 		return errors.New("invalid SCRAM server-final-message received from server")
 | |
| 	}
 | |
| 
 | |
| 	serverSignature := serverFinalMessage[2:]
 | |
| 
 | |
| 	if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
 | |
| 		return errors.New("invalid SCRAM ServerSignature received from server")
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
// computeHMAC returns HMAC-SHA-256(key, msg).
func computeHMAC(key, msg []byte) []byte {
	h := hmac.New(sha256.New, key)
	h.Write(msg) // hash.Hash.Write never returns an error
	return h.Sum(nil)
}

// computeClientProof returns the base64 (standard, padded) SCRAM ClientProof:
//
//	ClientKey       := HMAC(SaltedPassword, "Client Key")
//	StoredKey       := SHA-256(ClientKey)
//	ClientSignature := HMAC(StoredKey, AuthMessage)
//	ClientProof     := ClientKey XOR ClientSignature
func computeClientProof(saltedPassword, authMessage []byte) []byte {
	clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
	storedKey := sha256.Sum256(clientKey)

	// computeHMAC returns a freshly allocated slice, so XOR in place.
	proof := computeHMAC(storedKey[:], authMessage)
	for i := range proof {
		proof[i] ^= clientKey[i]
	}

	return []byte(base64.StdEncoding.EncodeToString(proof))
}

// computeServerSignature returns the base64 (standard, padded) SCRAM
// ServerSignature, HMAC(HMAC(SaltedPassword, "Server Key"), AuthMessage).
func computeServerSignature(saltedPassword, authMessage []byte) []byte {
	serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
	return []byte(base64.StdEncoding.EncodeToString(computeHMAC(serverKey, authMessage)))
}
 |