mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-30 20:12:26 -05:00 
			
		
		
		
	* start moving to bun * changing more stuff * more * and yet more * tests passing * seems stable now * more big changes * small fix * little fixes
		
			
				
	
	
		
			280 lines
		
	
	
	
		
			6.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			280 lines
		
	
	
	
		
			6.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package pgx
 | |
| 
 | |
| import (
 | |
| 	"database/sql/driver"
 | |
| 	"fmt"
 | |
| 	"math"
 | |
| 	"reflect"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/jackc/pgio"
 | |
| 	"github.com/jackc/pgtype"
 | |
| )
 | |
| 
 | |
| // PostgreSQL format codes
 | |
| const (
 | |
| 	TextFormatCode   = 0
 | |
| 	BinaryFormatCode = 1
 | |
| )
 | |
| 
 | |
| // SerializationError occurs on failure to encode or decode a value
 | |
| type SerializationError string
 | |
| 
 | |
| func (e SerializationError) Error() string {
 | |
| 	return string(e)
 | |
| }
 | |
| 
 | |
| func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, error) {
 | |
| 	if arg == nil {
 | |
| 		return nil, nil
 | |
| 	}
 | |
| 
 | |
| 	refVal := reflect.ValueOf(arg)
 | |
| 	if refVal.Kind() == reflect.Ptr && refVal.IsNil() {
 | |
| 		return nil, nil
 | |
| 	}
 | |
| 
 | |
| 	switch arg := arg.(type) {
 | |
| 
 | |
| 	// https://github.com/jackc/pgx/issues/409 Changed JSON and JSONB to surface
 | |
| 	// []byte to database/sql instead of string. But that caused problems with the
 | |
| 	// simple protocol because the driver.Valuer case got taken before the
 | |
| 	// pgtype.TextEncoder case. And driver.Valuer needed to be first in the usual
 | |
| 	// case because of https://github.com/jackc/pgx/issues/339. So instead we
 | |
| 	// special case JSON and JSONB.
 | |
| 	case *pgtype.JSON:
 | |
| 		buf, err := arg.EncodeText(ci, nil)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		if buf == nil {
 | |
| 			return nil, nil
 | |
| 		}
 | |
| 		return string(buf), nil
 | |
| 	case *pgtype.JSONB:
 | |
| 		buf, err := arg.EncodeText(ci, nil)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		if buf == nil {
 | |
| 			return nil, nil
 | |
| 		}
 | |
| 		return string(buf), nil
 | |
| 
 | |
| 	case driver.Valuer:
 | |
| 		return callValuerValue(arg)
 | |
| 	case pgtype.TextEncoder:
 | |
| 		buf, err := arg.EncodeText(ci, nil)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		if buf == nil {
 | |
| 			return nil, nil
 | |
| 		}
 | |
| 		return string(buf), nil
 | |
| 	case float32:
 | |
| 		return float64(arg), nil
 | |
| 	case float64:
 | |
| 		return arg, nil
 | |
| 	case bool:
 | |
| 		return arg, nil
 | |
| 	case time.Duration:
 | |
| 		return fmt.Sprintf("%d microsecond", int64(arg)/1000), nil
 | |
| 	case time.Time:
 | |
| 		return arg, nil
 | |
| 	case string:
 | |
| 		return arg, nil
 | |
| 	case []byte:
 | |
| 		return arg, nil
 | |
| 	case int8:
 | |
| 		return int64(arg), nil
 | |
| 	case int16:
 | |
| 		return int64(arg), nil
 | |
| 	case int32:
 | |
| 		return int64(arg), nil
 | |
| 	case int64:
 | |
| 		return arg, nil
 | |
| 	case int:
 | |
| 		return int64(arg), nil
 | |
| 	case uint8:
 | |
| 		return int64(arg), nil
 | |
| 	case uint16:
 | |
| 		return int64(arg), nil
 | |
| 	case uint32:
 | |
| 		return int64(arg), nil
 | |
| 	case uint64:
 | |
| 		if arg > math.MaxInt64 {
 | |
| 			return nil, fmt.Errorf("arg too big for int64: %v", arg)
 | |
| 		}
 | |
| 		return int64(arg), nil
 | |
| 	case uint:
 | |
| 		if uint64(arg) > math.MaxInt64 {
 | |
| 			return nil, fmt.Errorf("arg too big for int64: %v", arg)
 | |
| 		}
 | |
| 		return int64(arg), nil
 | |
| 	}
 | |
| 
 | |
| 	if dt, found := ci.DataTypeForValue(arg); found {
 | |
| 		v := dt.Value
 | |
| 		err := v.Set(arg)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		buf, err := v.(pgtype.TextEncoder).EncodeText(ci, nil)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		if buf == nil {
 | |
| 			return nil, nil
 | |
| 		}
 | |
| 		return string(buf), nil
 | |
| 	}
 | |
| 
 | |
| 	if refVal.Kind() == reflect.Ptr {
 | |
| 		arg = refVal.Elem().Interface()
 | |
| 		return convertSimpleArgument(ci, arg)
 | |
| 	}
 | |
| 
 | |
| 	if strippedArg, ok := stripNamedType(&refVal); ok {
 | |
| 		return convertSimpleArgument(ci, strippedArg)
 | |
| 	}
 | |
| 	return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg))
 | |
| }
 | |
| 
 | |
| func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32, arg interface{}) ([]byte, error) {
 | |
| 	if arg == nil {
 | |
| 		return pgio.AppendInt32(buf, -1), nil
 | |
| 	}
 | |
| 
 | |
| 	switch arg := arg.(type) {
 | |
| 	case pgtype.BinaryEncoder:
 | |
| 		sp := len(buf)
 | |
| 		buf = pgio.AppendInt32(buf, -1)
 | |
| 		argBuf, err := arg.EncodeBinary(ci, buf)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		if argBuf != nil {
 | |
| 			buf = argBuf
 | |
| 			pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
 | |
| 		}
 | |
| 		return buf, nil
 | |
| 	case pgtype.TextEncoder:
 | |
| 		sp := len(buf)
 | |
| 		buf = pgio.AppendInt32(buf, -1)
 | |
| 		argBuf, err := arg.EncodeText(ci, buf)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		if argBuf != nil {
 | |
| 			buf = argBuf
 | |
| 			pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
 | |
| 		}
 | |
| 		return buf, nil
 | |
| 	case string:
 | |
| 		buf = pgio.AppendInt32(buf, int32(len(arg)))
 | |
| 		buf = append(buf, arg...)
 | |
| 		return buf, nil
 | |
| 	}
 | |
| 
 | |
| 	refVal := reflect.ValueOf(arg)
 | |
| 
 | |
| 	if refVal.Kind() == reflect.Ptr {
 | |
| 		if refVal.IsNil() {
 | |
| 			return pgio.AppendInt32(buf, -1), nil
 | |
| 		}
 | |
| 		arg = refVal.Elem().Interface()
 | |
| 		return encodePreparedStatementArgument(ci, buf, oid, arg)
 | |
| 	}
 | |
| 
 | |
| 	if dt, ok := ci.DataTypeForOID(oid); ok {
 | |
| 		value := dt.Value
 | |
| 		err := value.Set(arg)
 | |
| 		if err != nil {
 | |
| 			{
 | |
| 				if arg, ok := arg.(driver.Valuer); ok {
 | |
| 					v, err := callValuerValue(arg)
 | |
| 					if err != nil {
 | |
| 						return nil, err
 | |
| 					}
 | |
| 					return encodePreparedStatementArgument(ci, buf, oid, v)
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 		sp := len(buf)
 | |
| 		buf = pgio.AppendInt32(buf, -1)
 | |
| 		argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		if argBuf != nil {
 | |
| 			buf = argBuf
 | |
| 			pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
 | |
| 		}
 | |
| 		return buf, nil
 | |
| 	}
 | |
| 
 | |
| 	if strippedArg, ok := stripNamedType(&refVal); ok {
 | |
| 		return encodePreparedStatementArgument(ci, buf, oid, strippedArg)
 | |
| 	}
 | |
| 	return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
 | |
| }
 | |
| 
 | |
| // chooseParameterFormatCode determines the correct format code for an
 | |
| // argument to a prepared statement. It defaults to TextFormatCode if no
 | |
| // determination can be made.
 | |
| func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid uint32, arg interface{}) int16 {
 | |
| 	switch arg := arg.(type) {
 | |
| 	case pgtype.ParamFormatPreferrer:
 | |
| 		return arg.PreferredParamFormat()
 | |
| 	case pgtype.BinaryEncoder:
 | |
| 		return BinaryFormatCode
 | |
| 	case string, *string, pgtype.TextEncoder:
 | |
| 		return TextFormatCode
 | |
| 	}
 | |
| 
 | |
| 	return ci.ParamFormatCodeForOID(oid)
 | |
| }
 | |
| 
 | |
| func stripNamedType(val *reflect.Value) (interface{}, bool) {
 | |
| 	switch val.Kind() {
 | |
| 	case reflect.Int:
 | |
| 		convVal := int(val.Int())
 | |
| 		return convVal, reflect.TypeOf(convVal) != val.Type()
 | |
| 	case reflect.Int8:
 | |
| 		convVal := int8(val.Int())
 | |
| 		return convVal, reflect.TypeOf(convVal) != val.Type()
 | |
| 	case reflect.Int16:
 | |
| 		convVal := int16(val.Int())
 | |
| 		return convVal, reflect.TypeOf(convVal) != val.Type()
 | |
| 	case reflect.Int32:
 | |
| 		convVal := int32(val.Int())
 | |
| 		return convVal, reflect.TypeOf(convVal) != val.Type()
 | |
| 	case reflect.Int64:
 | |
| 		convVal := int64(val.Int())
 | |
| 		return convVal, reflect.TypeOf(convVal) != val.Type()
 | |
| 	case reflect.Uint:
 | |
| 		convVal := uint(val.Uint())
 | |
| 		return convVal, reflect.TypeOf(convVal) != val.Type()
 | |
| 	case reflect.Uint8:
 | |
| 		convVal := uint8(val.Uint())
 | |
| 		return convVal, reflect.TypeOf(convVal) != val.Type()
 | |
| 	case reflect.Uint16:
 | |
| 		convVal := uint16(val.Uint())
 | |
| 		return convVal, reflect.TypeOf(convVal) != val.Type()
 | |
| 	case reflect.Uint32:
 | |
| 		convVal := uint32(val.Uint())
 | |
| 		return convVal, reflect.TypeOf(convVal) != val.Type()
 | |
| 	case reflect.Uint64:
 | |
| 		convVal := uint64(val.Uint())
 | |
| 		return convVal, reflect.TypeOf(convVal) != val.Type()
 | |
| 	case reflect.String:
 | |
| 		convVal := val.String()
 | |
| 		return convVal, reflect.TypeOf(convVal) != val.Type()
 | |
| 	}
 | |
| 
 | |
| 	return nil, false
 | |
| }
 |