mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 04:22:25 -05:00 
			
		
		
		
	
		
			
				
	
	
		
			138 lines
		
	
	
	
		
			2.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			138 lines
		
	
	
	
		
			2.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package pgdialect
 | |
| 
 | |
| import (
 | |
| 	"database/sql"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/uptrace/bun/internal"
 | |
| 	"github.com/uptrace/bun/schema"
 | |
| )
 | |
| 
 | |
type (
	// MultiRange is an ordered list of ranges. No Scan or AppendQuery
	// method is defined for it in this file.
	MultiRange[T any] []Range[T]

	// Range is a pair of values together with the bracket characters that
	// mark whether each end is inclusive or exclusive, mirroring the
	// "[lower,upper)" text form parsed by Scan.
	Range[T any] struct {
		Lower      T
		Upper      T
		LowerBound RangeBound
		UpperBound RangeBound
	}

	// RangeBound holds the raw bracket character of one end of a Range.
	RangeBound byte
)

// Bracket characters accepted for RangeBound.
const (
	// Left-hand (lower) brackets.
	RangeBoundInclusiveLeft = RangeBound('[')
	RangeBoundExclusiveLeft = RangeBound('(')

	// Right-hand (upper) brackets.
	RangeBoundInclusiveRight = RangeBound(']')
	RangeBoundExclusiveRight = RangeBound(')')
)
 | |
| 
 | |
| func NewRange[T any](lower, upper T) Range[T] {
 | |
| 	return Range[T]{
 | |
| 		Lower:      lower,
 | |
| 		Upper:      upper,
 | |
| 		LowerBound: RangeBoundInclusiveLeft,
 | |
| 		UpperBound: RangeBoundExclusiveRight,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| var _ sql.Scanner = (*Range[any])(nil)
 | |
| 
 | |
| func (r *Range[T]) Scan(anySrc any) (err error) {
 | |
| 	src, ok := anySrc.([]byte)
 | |
| 	if !ok {
 | |
| 		return fmt.Errorf("pgdialect: Range can't scan %T", anySrc)
 | |
| 	}
 | |
| 
 | |
| 	if len(src) == 0 {
 | |
| 		return io.ErrUnexpectedEOF
 | |
| 	}
 | |
| 	r.LowerBound = RangeBound(src[0])
 | |
| 	src = src[1:]
 | |
| 
 | |
| 	src, err = scanElem(&r.Lower, src)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if len(src) == 0 {
 | |
| 		return io.ErrUnexpectedEOF
 | |
| 	}
 | |
| 	if ch := src[0]; ch != ',' {
 | |
| 		return fmt.Errorf("got %q, wanted %q", ch, ',')
 | |
| 	}
 | |
| 	src = src[1:]
 | |
| 
 | |
| 	src, err = scanElem(&r.Upper, src)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if len(src) == 0 {
 | |
| 		return io.ErrUnexpectedEOF
 | |
| 	}
 | |
| 	r.UpperBound = RangeBound(src[0])
 | |
| 	src = src[1:]
 | |
| 
 | |
| 	if len(src) > 0 {
 | |
| 		return fmt.Errorf("unread data: %q", src)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| var _ schema.QueryAppender = (*Range[any])(nil)
 | |
| 
 | |
| func (r *Range[T]) AppendQuery(fmt schema.Formatter, buf []byte) ([]byte, error) {
 | |
| 	buf = append(buf, byte(r.LowerBound))
 | |
| 	buf = appendElem(buf, r.Lower)
 | |
| 	buf = append(buf, ',')
 | |
| 	buf = appendElem(buf, r.Upper)
 | |
| 	buf = append(buf, byte(r.UpperBound))
 | |
| 	return buf, nil
 | |
| }
 | |
| 
 | |
| func scanElem(ptr any, src []byte) ([]byte, error) {
 | |
| 	switch ptr := ptr.(type) {
 | |
| 	case *time.Time:
 | |
| 		src, str, err := readStringLiteral(src)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 		tm, err := internal.ParseTime(internal.String(str))
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		*ptr = tm
 | |
| 
 | |
| 		return src, nil
 | |
| 
 | |
| 	case sql.Scanner:
 | |
| 		src, str, err := readStringLiteral(src)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		if err := ptr.Scan(str); err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		return src, nil
 | |
| 
 | |
| 	default:
 | |
| 		panic(fmt.Errorf("unsupported range type: %T", ptr))
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func readStringLiteral(src []byte) ([]byte, []byte, error) {
 | |
| 	p := newParser(src)
 | |
| 
 | |
| 	if err := p.Skip('"'); err != nil {
 | |
| 		return nil, nil, err
 | |
| 	}
 | |
| 
 | |
| 	str, err := p.ReadSubstring('"')
 | |
| 	if err != nil {
 | |
| 		return nil, nil, err
 | |
| 	}
 | |
| 
 | |
| 	src = p.Remaining()
 | |
| 	return src, str, nil
 | |
| }
 |