mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 04:32:25 -05:00 
			
		
		
		
	
		
			
				
	
	
		
			329 lines
		
	
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			329 lines
		
	
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package runtime
 | |
| 
 | |
| import (
 | |
| 	"encoding/base64"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"net/url"
 | |
| 	"regexp"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
 | |
| 	"google.golang.org/genproto/protobuf/field_mask"
 | |
| 	"google.golang.org/grpc/grpclog"
 | |
| 	"google.golang.org/protobuf/proto"
 | |
| 	"google.golang.org/protobuf/reflect/protoreflect"
 | |
| 	"google.golang.org/protobuf/reflect/protoregistry"
 | |
| 	"google.golang.org/protobuf/types/known/durationpb"
 | |
| 	"google.golang.org/protobuf/types/known/timestamppb"
 | |
| 	"google.golang.org/protobuf/types/known/wrapperspb"
 | |
| )
 | |
| 
 | |
| var valuesKeyRegexp = regexp.MustCompile(`^(.*)\[(.*)\]$`)
 | |
| 
 | |
| var currentQueryParser QueryParameterParser = &defaultQueryParser{}
 | |
| 
 | |
| // QueryParameterParser defines interface for all query parameter parsers
 | |
| type QueryParameterParser interface {
 | |
| 	Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error
 | |
| }
 | |
| 
 | |
| // PopulateQueryParameters parses query parameters
 | |
| // into "msg" using current query parser
 | |
| func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
 | |
| 	return currentQueryParser.Parse(msg, values, filter)
 | |
| }
 | |
| 
 | |
| type defaultQueryParser struct{}
 | |
| 
 | |
| // Parse populates "values" into "msg".
 | |
| // A value is ignored if its key starts with one of the elements in "filter".
 | |
| func (*defaultQueryParser) Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
 | |
| 	for key, values := range values {
 | |
| 		match := valuesKeyRegexp.FindStringSubmatch(key)
 | |
| 		if len(match) == 3 {
 | |
| 			key = match[1]
 | |
| 			values = append([]string{match[2]}, values...)
 | |
| 		}
 | |
| 		fieldPath := strings.Split(key, ".")
 | |
| 		if filter.HasCommonPrefix(fieldPath) {
 | |
| 			continue
 | |
| 		}
 | |
| 		if err := populateFieldValueFromPath(msg.ProtoReflect(), fieldPath, values); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // PopulateFieldFromPath sets a value in a nested Protobuf structure.
 | |
| func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error {
 | |
| 	fieldPath := strings.Split(fieldPathString, ".")
 | |
| 	return populateFieldValueFromPath(msg.ProtoReflect(), fieldPath, []string{value})
 | |
| }
 | |
| 
 | |
| func populateFieldValueFromPath(msgValue protoreflect.Message, fieldPath []string, values []string) error {
 | |
| 	if len(fieldPath) < 1 {
 | |
| 		return errors.New("no field path")
 | |
| 	}
 | |
| 	if len(values) < 1 {
 | |
| 		return errors.New("no value provided")
 | |
| 	}
 | |
| 
 | |
| 	var fieldDescriptor protoreflect.FieldDescriptor
 | |
| 	for i, fieldName := range fieldPath {
 | |
| 		fields := msgValue.Descriptor().Fields()
 | |
| 
 | |
| 		// Get field by name
 | |
| 		fieldDescriptor = fields.ByName(protoreflect.Name(fieldName))
 | |
| 		if fieldDescriptor == nil {
 | |
| 			fieldDescriptor = fields.ByJSONName(fieldName)
 | |
| 			if fieldDescriptor == nil {
 | |
| 				// We're not returning an error here because this could just be
 | |
| 				// an extra query parameter that isn't part of the request.
 | |
| 				grpclog.Infof("field not found in %q: %q", msgValue.Descriptor().FullName(), strings.Join(fieldPath, "."))
 | |
| 				return nil
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		// If this is the last element, we're done
 | |
| 		if i == len(fieldPath)-1 {
 | |
| 			break
 | |
| 		}
 | |
| 
 | |
| 		// Only singular message fields are allowed
 | |
| 		if fieldDescriptor.Message() == nil || fieldDescriptor.Cardinality() == protoreflect.Repeated {
 | |
| 			return fmt.Errorf("invalid path: %q is not a message", fieldName)
 | |
| 		}
 | |
| 
 | |
| 		// Get the nested message
 | |
| 		msgValue = msgValue.Mutable(fieldDescriptor).Message()
 | |
| 	}
 | |
| 
 | |
| 	// Check if oneof already set
 | |
| 	if of := fieldDescriptor.ContainingOneof(); of != nil {
 | |
| 		if f := msgValue.WhichOneof(of); f != nil {
 | |
| 			return fmt.Errorf("field already set for oneof %q", of.FullName().Name())
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	switch {
 | |
| 	case fieldDescriptor.IsList():
 | |
| 		return populateRepeatedField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).List(), values)
 | |
| 	case fieldDescriptor.IsMap():
 | |
| 		return populateMapField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).Map(), values)
 | |
| 	}
 | |
| 
 | |
| 	if len(values) > 1 {
 | |
| 		return fmt.Errorf("too many values for field %q: %s", fieldDescriptor.FullName().Name(), strings.Join(values, ", "))
 | |
| 	}
 | |
| 
 | |
| 	return populateField(fieldDescriptor, msgValue, values[0])
 | |
| }
 | |
| 
 | |
| func populateField(fieldDescriptor protoreflect.FieldDescriptor, msgValue protoreflect.Message, value string) error {
 | |
| 	v, err := parseField(fieldDescriptor, value)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("parsing field %q: %w", fieldDescriptor.FullName().Name(), err)
 | |
| 	}
 | |
| 
 | |
| 	msgValue.Set(fieldDescriptor, v)
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func populateRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List, values []string) error {
 | |
| 	for _, value := range values {
 | |
| 		v, err := parseField(fieldDescriptor, value)
 | |
| 		if err != nil {
 | |
| 			return fmt.Errorf("parsing list %q: %w", fieldDescriptor.FullName().Name(), err)
 | |
| 		}
 | |
| 		list.Append(v)
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func populateMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map, values []string) error {
 | |
| 	if len(values) != 2 {
 | |
| 		return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fieldDescriptor.FullName())
 | |
| 	}
 | |
| 
 | |
| 	key, err := parseField(fieldDescriptor.MapKey(), values[0])
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("parsing map key %q: %w", fieldDescriptor.FullName().Name(), err)
 | |
| 	}
 | |
| 
 | |
| 	value, err := parseField(fieldDescriptor.MapValue(), values[1])
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("parsing map value %q: %w", fieldDescriptor.FullName().Name(), err)
 | |
| 	}
 | |
| 
 | |
| 	mp.Set(key.MapKey(), value)
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func parseField(fieldDescriptor protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) {
 | |
| 	switch fieldDescriptor.Kind() {
 | |
| 	case protoreflect.BoolKind:
 | |
| 		v, err := strconv.ParseBool(value)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		return protoreflect.ValueOfBool(v), nil
 | |
| 	case protoreflect.EnumKind:
 | |
| 		enum, err := protoregistry.GlobalTypes.FindEnumByName(fieldDescriptor.Enum().FullName())
 | |
| 		switch {
 | |
| 		case errors.Is(err, protoregistry.NotFound):
 | |
| 			return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fieldDescriptor.Enum().FullName())
 | |
| 		case err != nil:
 | |
| 			return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err)
 | |
| 		}
 | |
| 		// Look for enum by name
 | |
| 		v := enum.Descriptor().Values().ByName(protoreflect.Name(value))
 | |
| 		if v == nil {
 | |
| 			i, err := strconv.Atoi(value)
 | |
| 			if err != nil {
 | |
| 				return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
 | |
| 			}
 | |
| 			// Look for enum by number
 | |
| 			v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i))
 | |
| 			if v == nil {
 | |
| 				return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
 | |
| 			}
 | |
| 		}
 | |
| 		return protoreflect.ValueOfEnum(v.Number()), nil
 | |
| 	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
 | |
| 		v, err := strconv.ParseInt(value, 10, 32)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		return protoreflect.ValueOfInt32(int32(v)), nil
 | |
| 	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
 | |
| 		v, err := strconv.ParseInt(value, 10, 64)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		return protoreflect.ValueOfInt64(v), nil
 | |
| 	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
 | |
| 		v, err := strconv.ParseUint(value, 10, 32)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		return protoreflect.ValueOfUint32(uint32(v)), nil
 | |
| 	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
 | |
| 		v, err := strconv.ParseUint(value, 10, 64)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		return protoreflect.ValueOfUint64(v), nil
 | |
| 	case protoreflect.FloatKind:
 | |
| 		v, err := strconv.ParseFloat(value, 32)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		return protoreflect.ValueOfFloat32(float32(v)), nil
 | |
| 	case protoreflect.DoubleKind:
 | |
| 		v, err := strconv.ParseFloat(value, 64)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		return protoreflect.ValueOfFloat64(v), nil
 | |
| 	case protoreflect.StringKind:
 | |
| 		return protoreflect.ValueOfString(value), nil
 | |
| 	case protoreflect.BytesKind:
 | |
| 		v, err := base64.URLEncoding.DecodeString(value)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		return protoreflect.ValueOfBytes(v), nil
 | |
| 	case protoreflect.MessageKind, protoreflect.GroupKind:
 | |
| 		return parseMessage(fieldDescriptor.Message(), value)
 | |
| 	default:
 | |
| 		panic(fmt.Sprintf("unknown field kind: %v", fieldDescriptor.Kind()))
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func parseMessage(msgDescriptor protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) {
 | |
| 	var msg proto.Message
 | |
| 	switch msgDescriptor.FullName() {
 | |
| 	case "google.protobuf.Timestamp":
 | |
| 		if value == "null" {
 | |
| 			break
 | |
| 		}
 | |
| 		t, err := time.Parse(time.RFC3339Nano, value)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		msg = timestamppb.New(t)
 | |
| 	case "google.protobuf.Duration":
 | |
| 		if value == "null" {
 | |
| 			break
 | |
| 		}
 | |
| 		d, err := time.ParseDuration(value)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		msg = durationpb.New(d)
 | |
| 	case "google.protobuf.DoubleValue":
 | |
| 		v, err := strconv.ParseFloat(value, 64)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		msg = &wrapperspb.DoubleValue{Value: v}
 | |
| 	case "google.protobuf.FloatValue":
 | |
| 		v, err := strconv.ParseFloat(value, 32)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		msg = &wrapperspb.FloatValue{Value: float32(v)}
 | |
| 	case "google.protobuf.Int64Value":
 | |
| 		v, err := strconv.ParseInt(value, 10, 64)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		msg = &wrapperspb.Int64Value{Value: v}
 | |
| 	case "google.protobuf.Int32Value":
 | |
| 		v, err := strconv.ParseInt(value, 10, 32)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		msg = &wrapperspb.Int32Value{Value: int32(v)}
 | |
| 	case "google.protobuf.UInt64Value":
 | |
| 		v, err := strconv.ParseUint(value, 10, 64)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		msg = &wrapperspb.UInt64Value{Value: v}
 | |
| 	case "google.protobuf.UInt32Value":
 | |
| 		v, err := strconv.ParseUint(value, 10, 32)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		msg = &wrapperspb.UInt32Value{Value: uint32(v)}
 | |
| 	case "google.protobuf.BoolValue":
 | |
| 		v, err := strconv.ParseBool(value)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		msg = &wrapperspb.BoolValue{Value: v}
 | |
| 	case "google.protobuf.StringValue":
 | |
| 		msg = &wrapperspb.StringValue{Value: value}
 | |
| 	case "google.protobuf.BytesValue":
 | |
| 		v, err := base64.URLEncoding.DecodeString(value)
 | |
| 		if err != nil {
 | |
| 			return protoreflect.Value{}, err
 | |
| 		}
 | |
| 		msg = &wrapperspb.BytesValue{Value: v}
 | |
| 	case "google.protobuf.FieldMask":
 | |
| 		fm := &field_mask.FieldMask{}
 | |
| 		fm.Paths = append(fm.Paths, strings.Split(value, ",")...)
 | |
| 		msg = fm
 | |
| 	default:
 | |
| 		return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName()))
 | |
| 	}
 | |
| 
 | |
| 	return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
 | |
| }
 |