mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-11-04 06:12:25 -06:00 
			
		
		
		
	* update dependencies, bump Go version to 1.19 * bump test image Go version * update golangci-lint * update gotosocial-drone-build * sign * linting, go fmt * update swagger docs * update swagger docs * whitespace * update contributing.md * fuckin whoopsie doopsie * linterino, linteroni * fix followrequest test not starting processor * fix other api/client tests not starting processor * fix remaining tests where processor not started * bump go-runners version * don't check last-webfingered-at, processor may have updated this * update swagger command * update bun to latest version * fix embed to work the same as before with new bun Signed-off-by: kim <grufwub@gmail.com> Co-authored-by: tsmethurst <tobi.smethurst@protonmail.com>
		
			
				
	
	
		
			206 lines
		
	
	
	
		
			5.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			206 lines
		
	
	
	
		
			5.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package pgproto3
 | 
						|
 | 
						|
import (
 | 
						|
	"encoding/binary"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"io"
 | 
						|
)
 | 
						|
 | 
						|
// Frontend acts as a client for the PostgreSQL wire protocol version 3.
 | 
						|
type Frontend struct {
 | 
						|
	cr ChunkReader
 | 
						|
	w  io.Writer
 | 
						|
 | 
						|
	// Backend message flyweights
 | 
						|
	authenticationOk                AuthenticationOk
 | 
						|
	authenticationCleartextPassword AuthenticationCleartextPassword
 | 
						|
	authenticationMD5Password       AuthenticationMD5Password
 | 
						|
	authenticationGSS               AuthenticationGSS
 | 
						|
	authenticationGSSContinue       AuthenticationGSSContinue
 | 
						|
	authenticationSASL              AuthenticationSASL
 | 
						|
	authenticationSASLContinue      AuthenticationSASLContinue
 | 
						|
	authenticationSASLFinal         AuthenticationSASLFinal
 | 
						|
	backendKeyData                  BackendKeyData
 | 
						|
	bindComplete                    BindComplete
 | 
						|
	closeComplete                   CloseComplete
 | 
						|
	commandComplete                 CommandComplete
 | 
						|
	copyBothResponse                CopyBothResponse
 | 
						|
	copyData                        CopyData
 | 
						|
	copyInResponse                  CopyInResponse
 | 
						|
	copyOutResponse                 CopyOutResponse
 | 
						|
	copyDone                        CopyDone
 | 
						|
	dataRow                         DataRow
 | 
						|
	emptyQueryResponse              EmptyQueryResponse
 | 
						|
	errorResponse                   ErrorResponse
 | 
						|
	functionCallResponse            FunctionCallResponse
 | 
						|
	noData                          NoData
 | 
						|
	noticeResponse                  NoticeResponse
 | 
						|
	notificationResponse            NotificationResponse
 | 
						|
	parameterDescription            ParameterDescription
 | 
						|
	parameterStatus                 ParameterStatus
 | 
						|
	parseComplete                   ParseComplete
 | 
						|
	readyForQuery                   ReadyForQuery
 | 
						|
	rowDescription                  RowDescription
 | 
						|
	portalSuspended                 PortalSuspended
 | 
						|
 | 
						|
	bodyLen    int
 | 
						|
	msgType    byte
 | 
						|
	partialMsg bool
 | 
						|
	authType   uint32
 | 
						|
}
 | 
						|
 | 
						|
// NewFrontend creates a new Frontend.
 | 
						|
func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
 | 
						|
	return &Frontend{cr: cr, w: w}
 | 
						|
}
 | 
						|
 | 
						|
// Send sends a message to the backend.
 | 
						|
func (f *Frontend) Send(msg FrontendMessage) error {
 | 
						|
	_, err := f.w.Write(msg.Encode(nil))
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
// translateEOFtoErrUnexpectedEOF maps io.EOF to io.ErrUnexpectedEOF and passes
// every other error (including nil and errors that merely wrap io.EOF) through
// unchanged. Receive applies it to reader errors because it only reads when it
// still needs bytes, so a clean EOF at that point means the stream ended early.
//
// Identity comparison is deliberate: io.EOF is specified to be returned
// unwrapped by readers.
func translateEOFtoErrUnexpectedEOF(err error) error {
	switch err {
	case io.EOF:
		return io.ErrUnexpectedEOF
	default:
		return err
	}
}
 | 
						|
 | 
						|
// Receive receives a message from the backend. The returned message is only valid until the next call to Receive.
 | 
						|
func (f *Frontend) Receive() (BackendMessage, error) {
 | 
						|
	if !f.partialMsg {
 | 
						|
		header, err := f.cr.Next(5)
 | 
						|
		if err != nil {
 | 
						|
			return nil, translateEOFtoErrUnexpectedEOF(err)
 | 
						|
		}
 | 
						|
 | 
						|
		f.msgType = header[0]
 | 
						|
		f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
 | 
						|
		f.partialMsg = true
 | 
						|
		if f.bodyLen < 0 {
 | 
						|
			return nil, errors.New("invalid message with negative body length received")
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	msgBody, err := f.cr.Next(f.bodyLen)
 | 
						|
	if err != nil {
 | 
						|
		return nil, translateEOFtoErrUnexpectedEOF(err)
 | 
						|
	}
 | 
						|
 | 
						|
	f.partialMsg = false
 | 
						|
 | 
						|
	var msg BackendMessage
 | 
						|
	switch f.msgType {
 | 
						|
	case '1':
 | 
						|
		msg = &f.parseComplete
 | 
						|
	case '2':
 | 
						|
		msg = &f.bindComplete
 | 
						|
	case '3':
 | 
						|
		msg = &f.closeComplete
 | 
						|
	case 'A':
 | 
						|
		msg = &f.notificationResponse
 | 
						|
	case 'c':
 | 
						|
		msg = &f.copyDone
 | 
						|
	case 'C':
 | 
						|
		msg = &f.commandComplete
 | 
						|
	case 'd':
 | 
						|
		msg = &f.copyData
 | 
						|
	case 'D':
 | 
						|
		msg = &f.dataRow
 | 
						|
	case 'E':
 | 
						|
		msg = &f.errorResponse
 | 
						|
	case 'G':
 | 
						|
		msg = &f.copyInResponse
 | 
						|
	case 'H':
 | 
						|
		msg = &f.copyOutResponse
 | 
						|
	case 'I':
 | 
						|
		msg = &f.emptyQueryResponse
 | 
						|
	case 'K':
 | 
						|
		msg = &f.backendKeyData
 | 
						|
	case 'n':
 | 
						|
		msg = &f.noData
 | 
						|
	case 'N':
 | 
						|
		msg = &f.noticeResponse
 | 
						|
	case 'R':
 | 
						|
		var err error
 | 
						|
		msg, err = f.findAuthenticationMessageType(msgBody)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
	case 's':
 | 
						|
		msg = &f.portalSuspended
 | 
						|
	case 'S':
 | 
						|
		msg = &f.parameterStatus
 | 
						|
	case 't':
 | 
						|
		msg = &f.parameterDescription
 | 
						|
	case 'T':
 | 
						|
		msg = &f.rowDescription
 | 
						|
	case 'V':
 | 
						|
		msg = &f.functionCallResponse
 | 
						|
	case 'W':
 | 
						|
		msg = &f.copyBothResponse
 | 
						|
	case 'Z':
 | 
						|
		msg = &f.readyForQuery
 | 
						|
	default:
 | 
						|
		return nil, fmt.Errorf("unknown message type: %c", f.msgType)
 | 
						|
	}
 | 
						|
 | 
						|
	err = msg.Decode(msgBody)
 | 
						|
	return msg, err
 | 
						|
}
 | 
						|
 | 
						|
// Authentication message type constants.
//
// Each value is the auth type code carried in the first four bytes of a
// backend Authentication ('R') message body. See src/include/libpq/pqcomm.h
// in the PostgreSQL sources for the complete list of constants.
//
// The constants are untyped so they compare directly against uint32 codes.
const (
	AuthTypeOk                = 0  // decoded as AuthenticationOk
	AuthTypeCleartextPassword = 3  // decoded as AuthenticationCleartextPassword
	AuthTypeMD5Password       = 5  // decoded as AuthenticationMD5Password
	AuthTypeSCMCreds          = 6  // not implemented by the frontend
	AuthTypeGSS               = 7  // decoded as AuthenticationGSS
	AuthTypeGSSCont           = 8  // decoded as AuthenticationGSSContinue
	AuthTypeSSPI              = 9  // not implemented by the frontend
	AuthTypeSASL              = 10 // decoded as AuthenticationSASL
	AuthTypeSASLContinue      = 11 // decoded as AuthenticationSASLContinue
	AuthTypeSASLFinal         = 12 // decoded as AuthenticationSASLFinal
)
 | 
						|
 | 
						|
func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
 | 
						|
	if len(src) < 4 {
 | 
						|
		return nil, errors.New("authentication message too short")
 | 
						|
	}
 | 
						|
	f.authType = binary.BigEndian.Uint32(src[:4])
 | 
						|
 | 
						|
	switch f.authType {
 | 
						|
	case AuthTypeOk:
 | 
						|
		return &f.authenticationOk, nil
 | 
						|
	case AuthTypeCleartextPassword:
 | 
						|
		return &f.authenticationCleartextPassword, nil
 | 
						|
	case AuthTypeMD5Password:
 | 
						|
		return &f.authenticationMD5Password, nil
 | 
						|
	case AuthTypeSCMCreds:
 | 
						|
		return nil, errors.New("AuthTypeSCMCreds is unimplemented")
 | 
						|
	case AuthTypeGSS:
 | 
						|
		return &f.authenticationGSS, nil
 | 
						|
	case AuthTypeGSSCont:
 | 
						|
		return &f.authenticationGSSContinue, nil
 | 
						|
	case AuthTypeSSPI:
 | 
						|
		return nil, errors.New("AuthTypeSSPI is unimplemented")
 | 
						|
	case AuthTypeSASL:
 | 
						|
		return &f.authenticationSASL, nil
 | 
						|
	case AuthTypeSASLContinue:
 | 
						|
		return &f.authenticationSASLContinue, nil
 | 
						|
	case AuthTypeSASLFinal:
 | 
						|
		return &f.authenticationSASLFinal, nil
 | 
						|
	default:
 | 
						|
		return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// GetAuthType returns the authType used in the current state of the frontend.
 | 
						|
// See SetAuthType for more information.
 | 
						|
func (f *Frontend) GetAuthType() uint32 {
 | 
						|
	return f.authType
 | 
						|
}
 |