mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-11-03 21:22:25 -06:00 
			
		
		
		
	
		
			
				
	
	
		
			140 lines
		
	
	
	
		
			2.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			140 lines
		
	
	
	
		
			2.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package sqlite3
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"math"
 | 
						|
	"reflect"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
	"unsafe"
 | 
						|
 | 
						|
	"github.com/ncruces/go-sqlite3/internal/util"
 | 
						|
)
 | 
						|
 | 
						|
// Quote escapes and quotes a value
 | 
						|
// making it safe to embed in SQL text.
 | 
						|
// Strings with embedded NUL characters are truncated.
 | 
						|
//
 | 
						|
// https://sqlite.org/lang_corefunc.html#quote
 | 
						|
func Quote(value any) string {
 | 
						|
	switch v := value.(type) {
 | 
						|
	case nil:
 | 
						|
		return "NULL"
 | 
						|
	case bool:
 | 
						|
		if v {
 | 
						|
			return "1"
 | 
						|
		} else {
 | 
						|
			return "0"
 | 
						|
		}
 | 
						|
 | 
						|
	case int:
 | 
						|
		return strconv.Itoa(v)
 | 
						|
	case int64:
 | 
						|
		return strconv.FormatInt(v, 10)
 | 
						|
	case float64:
 | 
						|
		switch {
 | 
						|
		case math.IsNaN(v):
 | 
						|
			return "NULL"
 | 
						|
		case math.IsInf(v, 1):
 | 
						|
			return "9.0e999"
 | 
						|
		case math.IsInf(v, -1):
 | 
						|
			return "-9.0e999"
 | 
						|
		}
 | 
						|
		return strconv.FormatFloat(v, 'g', -1, 64)
 | 
						|
	case time.Time:
 | 
						|
		return "'" + v.Format(time.RFC3339Nano) + "'"
 | 
						|
 | 
						|
	case string:
 | 
						|
		if i := strings.IndexByte(v, 0); i >= 0 {
 | 
						|
			v = v[:i]
 | 
						|
		}
 | 
						|
 | 
						|
		buf := make([]byte, 2+len(v)+strings.Count(v, "'"))
 | 
						|
		buf[0] = '\''
 | 
						|
		i := 1
 | 
						|
		for _, b := range []byte(v) {
 | 
						|
			if b == '\'' {
 | 
						|
				buf[i] = b
 | 
						|
				i += 1
 | 
						|
			}
 | 
						|
			buf[i] = b
 | 
						|
			i += 1
 | 
						|
		}
 | 
						|
		buf[len(buf)-1] = '\''
 | 
						|
		return unsafe.String(&buf[0], len(buf))
 | 
						|
 | 
						|
	case []byte:
 | 
						|
		buf := make([]byte, 3+2*len(v))
 | 
						|
		buf[1] = '\''
 | 
						|
		buf[0] = 'x'
 | 
						|
		i := 2
 | 
						|
		for _, b := range v {
 | 
						|
			const hex = "0123456789ABCDEF"
 | 
						|
			buf[i+0] = hex[b/16]
 | 
						|
			buf[i+1] = hex[b%16]
 | 
						|
			i += 2
 | 
						|
		}
 | 
						|
		buf[len(buf)-1] = '\''
 | 
						|
		return unsafe.String(&buf[0], len(buf))
 | 
						|
 | 
						|
	case ZeroBlob:
 | 
						|
		buf := bytes.Repeat([]byte("0"), int(3+2*int64(v)))
 | 
						|
		buf[1] = '\''
 | 
						|
		buf[0] = 'x'
 | 
						|
		buf[len(buf)-1] = '\''
 | 
						|
		return unsafe.String(&buf[0], len(buf))
 | 
						|
	}
 | 
						|
 | 
						|
	v := reflect.ValueOf(value)
 | 
						|
	k := v.Kind()
 | 
						|
 | 
						|
	if k == reflect.Interface || k == reflect.Pointer {
 | 
						|
		if v.IsNil() {
 | 
						|
			return "NULL"
 | 
						|
		}
 | 
						|
		v = v.Elem()
 | 
						|
		k = v.Kind()
 | 
						|
	}
 | 
						|
 | 
						|
	switch {
 | 
						|
	case v.CanInt():
 | 
						|
		return strconv.FormatInt(v.Int(), 10)
 | 
						|
	case v.CanUint():
 | 
						|
		return strconv.FormatUint(v.Uint(), 10)
 | 
						|
	case v.CanFloat():
 | 
						|
		return Quote(v.Float())
 | 
						|
	case k == reflect.Bool:
 | 
						|
		return Quote(v.Bool())
 | 
						|
	case k == reflect.String:
 | 
						|
		return Quote(v.String())
 | 
						|
	case (k == reflect.Slice || k == reflect.Array && v.CanAddr()) &&
 | 
						|
		v.Type().Elem().Kind() == reflect.Uint8:
 | 
						|
		return Quote(v.Bytes())
 | 
						|
	}
 | 
						|
 | 
						|
	panic(util.ValueErr)
 | 
						|
}
 | 
						|
 | 
						|
// QuoteIdentifier escapes and quotes an identifier
 | 
						|
// making it safe to embed in SQL text.
 | 
						|
// Strings with embedded NUL characters panic.
 | 
						|
func QuoteIdentifier(id string) string {
 | 
						|
	if strings.IndexByte(id, 0) >= 0 {
 | 
						|
		panic(util.ValueErr)
 | 
						|
	}
 | 
						|
 | 
						|
	buf := make([]byte, 2+len(id)+strings.Count(id, `"`))
 | 
						|
	buf[0] = '"'
 | 
						|
	i := 1
 | 
						|
	for _, b := range []byte(id) {
 | 
						|
		if b == '"' {
 | 
						|
			buf[i] = b
 | 
						|
			i += 1
 | 
						|
		}
 | 
						|
		buf[i] = b
 | 
						|
		i += 1
 | 
						|
	}
 | 
						|
	buf[len(buf)-1] = '"'
 | 
						|
	return unsafe.String(&buf[0], len(buf))
 | 
						|
}
 |