mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 07:42:26 -05:00 
			
		
		
		
	* start moving to bun * changing more stuff * more * and yet more * tests passing * seems stable now * more big changes * small fix * little fixes
		
			
				
	
	
		
			201 lines
		
	
	
	
		
			5.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			201 lines
		
	
	
	
		
			5.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package pgproto3
 | |
| 
 | |
| import (
 | |
| 	"encoding/binary"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| )
 | |
| 
 | |
| // Frontend acts as a client for the PostgreSQL wire protocol version 3.
 | |
| type Frontend struct {
 | |
| 	cr ChunkReader
 | |
| 	w  io.Writer
 | |
| 
 | |
| 	// Backend message flyweights
 | |
| 	authenticationOk                AuthenticationOk
 | |
| 	authenticationCleartextPassword AuthenticationCleartextPassword
 | |
| 	authenticationMD5Password       AuthenticationMD5Password
 | |
| 	authenticationSASL              AuthenticationSASL
 | |
| 	authenticationSASLContinue      AuthenticationSASLContinue
 | |
| 	authenticationSASLFinal         AuthenticationSASLFinal
 | |
| 	backendKeyData                  BackendKeyData
 | |
| 	bindComplete                    BindComplete
 | |
| 	closeComplete                   CloseComplete
 | |
| 	commandComplete                 CommandComplete
 | |
| 	copyBothResponse                CopyBothResponse
 | |
| 	copyData                        CopyData
 | |
| 	copyInResponse                  CopyInResponse
 | |
| 	copyOutResponse                 CopyOutResponse
 | |
| 	copyDone                        CopyDone
 | |
| 	dataRow                         DataRow
 | |
| 	emptyQueryResponse              EmptyQueryResponse
 | |
| 	errorResponse                   ErrorResponse
 | |
| 	functionCallResponse            FunctionCallResponse
 | |
| 	noData                          NoData
 | |
| 	noticeResponse                  NoticeResponse
 | |
| 	notificationResponse            NotificationResponse
 | |
| 	parameterDescription            ParameterDescription
 | |
| 	parameterStatus                 ParameterStatus
 | |
| 	parseComplete                   ParseComplete
 | |
| 	readyForQuery                   ReadyForQuery
 | |
| 	rowDescription                  RowDescription
 | |
| 	portalSuspended                 PortalSuspended
 | |
| 
 | |
| 	bodyLen    int
 | |
| 	msgType    byte
 | |
| 	partialMsg bool
 | |
| 	authType   uint32
 | |
| }
 | |
| 
 | |
| // NewFrontend creates a new Frontend.
 | |
| func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
 | |
| 	return &Frontend{cr: cr, w: w}
 | |
| }
 | |
| 
 | |
| // Send sends a message to the backend.
 | |
| func (f *Frontend) Send(msg FrontendMessage) error {
 | |
| 	_, err := f.w.Write(msg.Encode(nil))
 | |
| 	return err
 | |
| }
 | |
| 
 | |
// translateEOFtoErrUnexpectedEOF maps a bare io.EOF to io.ErrUnexpectedEOF
// and returns every other value, including nil, unchanged. The comparison is
// by identity, so an error that merely wraps io.EOF passes through as-is.
func translateEOFtoErrUnexpectedEOF(err error) error {
	switch err {
	case io.EOF:
		return io.ErrUnexpectedEOF
	default:
		return err
	}
}
 | |
| 
 | |
| // Receive receives a message from the backend. The returned message is only valid until the next call to Receive.
 | |
| func (f *Frontend) Receive() (BackendMessage, error) {
 | |
| 	if !f.partialMsg {
 | |
| 		header, err := f.cr.Next(5)
 | |
| 		if err != nil {
 | |
| 			return nil, translateEOFtoErrUnexpectedEOF(err)
 | |
| 		}
 | |
| 
 | |
| 		f.msgType = header[0]
 | |
| 		f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
 | |
| 		f.partialMsg = true
 | |
| 	}
 | |
| 
 | |
| 	msgBody, err := f.cr.Next(f.bodyLen)
 | |
| 	if err != nil {
 | |
| 		return nil, translateEOFtoErrUnexpectedEOF(err)
 | |
| 	}
 | |
| 
 | |
| 	f.partialMsg = false
 | |
| 
 | |
| 	var msg BackendMessage
 | |
| 	switch f.msgType {
 | |
| 	case '1':
 | |
| 		msg = &f.parseComplete
 | |
| 	case '2':
 | |
| 		msg = &f.bindComplete
 | |
| 	case '3':
 | |
| 		msg = &f.closeComplete
 | |
| 	case 'A':
 | |
| 		msg = &f.notificationResponse
 | |
| 	case 'c':
 | |
| 		msg = &f.copyDone
 | |
| 	case 'C':
 | |
| 		msg = &f.commandComplete
 | |
| 	case 'd':
 | |
| 		msg = &f.copyData
 | |
| 	case 'D':
 | |
| 		msg = &f.dataRow
 | |
| 	case 'E':
 | |
| 		msg = &f.errorResponse
 | |
| 	case 'G':
 | |
| 		msg = &f.copyInResponse
 | |
| 	case 'H':
 | |
| 		msg = &f.copyOutResponse
 | |
| 	case 'I':
 | |
| 		msg = &f.emptyQueryResponse
 | |
| 	case 'K':
 | |
| 		msg = &f.backendKeyData
 | |
| 	case 'n':
 | |
| 		msg = &f.noData
 | |
| 	case 'N':
 | |
| 		msg = &f.noticeResponse
 | |
| 	case 'R':
 | |
| 		var err error
 | |
| 		msg, err = f.findAuthenticationMessageType(msgBody)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 	case 's':
 | |
| 		msg = &f.portalSuspended
 | |
| 	case 'S':
 | |
| 		msg = &f.parameterStatus
 | |
| 	case 't':
 | |
| 		msg = &f.parameterDescription
 | |
| 	case 'T':
 | |
| 		msg = &f.rowDescription
 | |
| 	case 'V':
 | |
| 		msg = &f.functionCallResponse
 | |
| 	case 'W':
 | |
| 		msg = &f.copyBothResponse
 | |
| 	case 'Z':
 | |
| 		msg = &f.readyForQuery
 | |
| 	default:
 | |
| 		return nil, fmt.Errorf("unknown message type: %c", f.msgType)
 | |
| 	}
 | |
| 
 | |
| 	err = msg.Decode(msgBody)
 | |
| 	return msg, err
 | |
| }
 | |
| 
 | |
// Authentication message type codes, carried in the first 4 bytes of an
// Authentication ('R') message body. See src/include/libpq/pqcomm.h in the
// PostgreSQL source for the full list.
const (
	// AuthTypeOk is decoded by Frontend into AuthenticationOk.
	AuthTypeOk = 0
	// AuthTypeCleartextPassword is decoded into AuthenticationCleartextPassword.
	AuthTypeCleartextPassword = 3
	// AuthTypeMD5Password is decoded into AuthenticationMD5Password.
	AuthTypeMD5Password = 5
	// AuthTypeSCMCreds is recognized but not implemented by Frontend.
	AuthTypeSCMCreds = 6
	// AuthTypeGSS is recognized but not implemented by Frontend.
	AuthTypeGSS = 7
	// AuthTypeGSSCont is recognized but not implemented by Frontend.
	AuthTypeGSSCont = 8
	// AuthTypeSSPI is recognized but not implemented by Frontend.
	AuthTypeSSPI = 9
	// AuthTypeSASL is decoded into AuthenticationSASL.
	AuthTypeSASL = 10
	// AuthTypeSASLContinue is decoded into AuthenticationSASLContinue.
	AuthTypeSASLContinue = 11
	// AuthTypeSASLFinal is decoded into AuthenticationSASLFinal.
	AuthTypeSASLFinal = 12
)
 | |
| 
 | |
| func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
 | |
| 	if len(src) < 4 {
 | |
| 		return nil, errors.New("authentication message too short")
 | |
| 	}
 | |
| 	f.authType = binary.BigEndian.Uint32(src[:4])
 | |
| 
 | |
| 	switch f.authType {
 | |
| 	case AuthTypeOk:
 | |
| 		return &f.authenticationOk, nil
 | |
| 	case AuthTypeCleartextPassword:
 | |
| 		return &f.authenticationCleartextPassword, nil
 | |
| 	case AuthTypeMD5Password:
 | |
| 		return &f.authenticationMD5Password, nil
 | |
| 	case AuthTypeSCMCreds:
 | |
| 		return nil, errors.New("AuthTypeSCMCreds is unimplemented")
 | |
| 	case AuthTypeGSS:
 | |
| 		return nil, errors.New("AuthTypeGSS is unimplemented")
 | |
| 	case AuthTypeGSSCont:
 | |
| 		return nil, errors.New("AuthTypeGSSCont is unimplemented")
 | |
| 	case AuthTypeSSPI:
 | |
| 		return nil, errors.New("AuthTypeSSPI is unimplemented")
 | |
| 	case AuthTypeSASL:
 | |
| 		return &f.authenticationSASL, nil
 | |
| 	case AuthTypeSASLContinue:
 | |
| 		return &f.authenticationSASLContinue, nil
 | |
| 	case AuthTypeSASLFinal:
 | |
| 		return &f.authenticationSASLFinal, nil
 | |
| 	default:
 | |
| 		return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // GetAuthType returns the authType used in the current state of the frontend.
 | |
| // See SetAuthType for more information.
 | |
| func (f *Frontend) GetAuthType() uint32 {
 | |
| 	return f.authType
 | |
| }
 |