mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-11-01 13:52:26 -05:00
Pg to bun (#148)
* start moving to bun * changing more stuff * more * and yet more * tests passing * seems stable now * more big changes * small fix * little fixes
This commit is contained in:
parent
071eca20ce
commit
2dc9fc1626
713 changed files with 98694 additions and 22704 deletions
93
vendor/github.com/uptrace/bun/schema/append.go
generated
vendored
Normal file
93
vendor/github.com/uptrace/bun/schema/append.go
generated
vendored
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
|
||||
"github.com/uptrace/bun/dialect"
|
||||
"github.com/uptrace/bun/dialect/sqltype"
|
||||
"github.com/uptrace/bun/internal"
|
||||
)
|
||||
|
||||
func FieldAppender(dialect Dialect, field *Field) AppenderFunc {
|
||||
if field.Tag.HasOption("msgpack") {
|
||||
return appendMsgpack
|
||||
}
|
||||
|
||||
switch strings.ToUpper(field.UserSQLType) {
|
||||
case sqltype.JSON, sqltype.JSONB:
|
||||
return AppendJSONValue
|
||||
}
|
||||
|
||||
return dialect.Appender(field.StructField.Type)
|
||||
}
|
||||
|
||||
func Append(fmter Formatter, b []byte, v interface{}, custom CustomAppender) []byte {
|
||||
switch v := v.(type) {
|
||||
case nil:
|
||||
return dialect.AppendNull(b)
|
||||
case bool:
|
||||
return dialect.AppendBool(b, v)
|
||||
case int:
|
||||
return strconv.AppendInt(b, int64(v), 10)
|
||||
case int32:
|
||||
return strconv.AppendInt(b, int64(v), 10)
|
||||
case int64:
|
||||
return strconv.AppendInt(b, v, 10)
|
||||
case uint:
|
||||
return strconv.AppendUint(b, uint64(v), 10)
|
||||
case uint32:
|
||||
return strconv.AppendUint(b, uint64(v), 10)
|
||||
case uint64:
|
||||
return strconv.AppendUint(b, v, 10)
|
||||
case float32:
|
||||
return dialect.AppendFloat32(b, v)
|
||||
case float64:
|
||||
return dialect.AppendFloat64(b, v)
|
||||
case string:
|
||||
return dialect.AppendString(b, v)
|
||||
case time.Time:
|
||||
return dialect.AppendTime(b, v)
|
||||
case []byte:
|
||||
return dialect.AppendBytes(b, v)
|
||||
case QueryAppender:
|
||||
return AppendQueryAppender(fmter, b, v)
|
||||
default:
|
||||
vv := reflect.ValueOf(v)
|
||||
if vv.Kind() == reflect.Ptr && vv.IsNil() {
|
||||
return dialect.AppendNull(b)
|
||||
}
|
||||
appender := Appender(vv.Type(), custom)
|
||||
return appender(fmter, b, vv)
|
||||
}
|
||||
}
|
||||
|
||||
func appendMsgpack(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
hexEnc := internal.NewHexEncoder(b)
|
||||
|
||||
enc := msgpack.GetEncoder()
|
||||
defer msgpack.PutEncoder(enc)
|
||||
|
||||
enc.Reset(hexEnc)
|
||||
if err := enc.EncodeValue(v); err != nil {
|
||||
return dialect.AppendError(b, err)
|
||||
}
|
||||
|
||||
if err := hexEnc.Close(); err != nil {
|
||||
return dialect.AppendError(b, err)
|
||||
}
|
||||
|
||||
return hexEnc.Bytes()
|
||||
}
|
||||
|
||||
func AppendQueryAppender(fmter Formatter, b []byte, app QueryAppender) []byte {
|
||||
bb, err := app.AppendQuery(fmter, b)
|
||||
if err != nil {
|
||||
return dialect.AppendError(b, err)
|
||||
}
|
||||
return bb
|
||||
}
|
||||
237
vendor/github.com/uptrace/bun/schema/append_value.go
generated
vendored
Normal file
237
vendor/github.com/uptrace/bun/schema/append_value.go
generated
vendored
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/bun/dialect"
|
||||
"github.com/uptrace/bun/extra/bunjson"
|
||||
"github.com/uptrace/bun/internal"
|
||||
)
|
||||
|
||||
var (
|
||||
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
|
||||
ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
|
||||
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
|
||||
jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
|
||||
|
||||
driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
|
||||
queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem()
|
||||
)
|
||||
|
||||
type (
|
||||
AppenderFunc func(fmter Formatter, b []byte, v reflect.Value) []byte
|
||||
CustomAppender func(typ reflect.Type) AppenderFunc
|
||||
)
|
||||
|
||||
var appenders = []AppenderFunc{
|
||||
reflect.Bool: AppendBoolValue,
|
||||
reflect.Int: AppendIntValue,
|
||||
reflect.Int8: AppendIntValue,
|
||||
reflect.Int16: AppendIntValue,
|
||||
reflect.Int32: AppendIntValue,
|
||||
reflect.Int64: AppendIntValue,
|
||||
reflect.Uint: AppendUintValue,
|
||||
reflect.Uint8: AppendUintValue,
|
||||
reflect.Uint16: AppendUintValue,
|
||||
reflect.Uint32: AppendUintValue,
|
||||
reflect.Uint64: AppendUintValue,
|
||||
reflect.Uintptr: nil,
|
||||
reflect.Float32: AppendFloat32Value,
|
||||
reflect.Float64: AppendFloat64Value,
|
||||
reflect.Complex64: nil,
|
||||
reflect.Complex128: nil,
|
||||
reflect.Array: AppendJSONValue,
|
||||
reflect.Chan: nil,
|
||||
reflect.Func: nil,
|
||||
reflect.Interface: nil,
|
||||
reflect.Map: AppendJSONValue,
|
||||
reflect.Ptr: nil,
|
||||
reflect.Slice: AppendJSONValue,
|
||||
reflect.String: AppendStringValue,
|
||||
reflect.Struct: AppendJSONValue,
|
||||
reflect.UnsafePointer: nil,
|
||||
}
|
||||
|
||||
func Appender(typ reflect.Type, custom CustomAppender) AppenderFunc {
|
||||
switch typ {
|
||||
case timeType:
|
||||
return appendTimeValue
|
||||
case ipType:
|
||||
return appendIPValue
|
||||
case ipNetType:
|
||||
return appendIPNetValue
|
||||
case jsonRawMessageType:
|
||||
return appendJSONRawMessageValue
|
||||
}
|
||||
|
||||
if typ.Implements(queryAppenderType) {
|
||||
return appendQueryAppenderValue
|
||||
}
|
||||
if typ.Implements(driverValuerType) {
|
||||
return driverValueAppender(custom)
|
||||
}
|
||||
|
||||
kind := typ.Kind()
|
||||
|
||||
if kind != reflect.Ptr {
|
||||
ptr := reflect.PtrTo(typ)
|
||||
if ptr.Implements(queryAppenderType) {
|
||||
return addrAppender(appendQueryAppenderValue, custom)
|
||||
}
|
||||
if ptr.Implements(driverValuerType) {
|
||||
return addrAppender(driverValueAppender(custom), custom)
|
||||
}
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case reflect.Interface:
|
||||
return ifaceAppenderFunc(typ, custom)
|
||||
case reflect.Ptr:
|
||||
return ptrAppenderFunc(typ, custom)
|
||||
case reflect.Slice:
|
||||
if typ.Elem().Kind() == reflect.Uint8 {
|
||||
return appendBytesValue
|
||||
}
|
||||
case reflect.Array:
|
||||
if typ.Elem().Kind() == reflect.Uint8 {
|
||||
return appendArrayBytesValue
|
||||
}
|
||||
}
|
||||
|
||||
if custom != nil {
|
||||
if fn := custom(typ); fn != nil {
|
||||
return fn
|
||||
}
|
||||
}
|
||||
return appenders[typ.Kind()]
|
||||
}
|
||||
|
||||
func ifaceAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc {
|
||||
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
if v.IsNil() {
|
||||
return dialect.AppendNull(b)
|
||||
}
|
||||
elem := v.Elem()
|
||||
appender := Appender(elem.Type(), custom)
|
||||
return appender(fmter, b, elem)
|
||||
}
|
||||
}
|
||||
|
||||
func ptrAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc {
|
||||
appender := Appender(typ.Elem(), custom)
|
||||
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
if v.IsNil() {
|
||||
return dialect.AppendNull(b)
|
||||
}
|
||||
return appender(fmter, b, v.Elem())
|
||||
}
|
||||
}
|
||||
|
||||
func AppendBoolValue(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
return dialect.AppendBool(b, v.Bool())
|
||||
}
|
||||
|
||||
func AppendIntValue(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
return strconv.AppendInt(b, v.Int(), 10)
|
||||
}
|
||||
|
||||
func AppendUintValue(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
return strconv.AppendUint(b, v.Uint(), 10)
|
||||
}
|
||||
|
||||
func AppendFloat32Value(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
return dialect.AppendFloat32(b, float32(v.Float()))
|
||||
}
|
||||
|
||||
func AppendFloat64Value(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
return dialect.AppendFloat64(b, float64(v.Float()))
|
||||
}
|
||||
|
||||
func appendBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
return dialect.AppendBytes(b, v.Bytes())
|
||||
}
|
||||
|
||||
func appendArrayBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
if v.CanAddr() {
|
||||
return dialect.AppendBytes(b, v.Slice(0, v.Len()).Bytes())
|
||||
}
|
||||
|
||||
tmp := make([]byte, v.Len())
|
||||
reflect.Copy(reflect.ValueOf(tmp), v)
|
||||
b = dialect.AppendBytes(b, tmp)
|
||||
return b
|
||||
}
|
||||
|
||||
func AppendStringValue(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
return dialect.AppendString(b, v.String())
|
||||
}
|
||||
|
||||
func AppendJSONValue(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
bb, err := bunjson.Marshal(v.Interface())
|
||||
if err != nil {
|
||||
return dialect.AppendError(b, err)
|
||||
}
|
||||
|
||||
if len(bb) > 0 && bb[len(bb)-1] == '\n' {
|
||||
bb = bb[:len(bb)-1]
|
||||
}
|
||||
|
||||
return dialect.AppendJSON(b, bb)
|
||||
}
|
||||
|
||||
func appendTimeValue(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
tm := v.Interface().(time.Time)
|
||||
return dialect.AppendTime(b, tm)
|
||||
}
|
||||
|
||||
func appendIPValue(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
ip := v.Interface().(net.IP)
|
||||
return dialect.AppendString(b, ip.String())
|
||||
}
|
||||
|
||||
func appendIPNetValue(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
ipnet := v.Interface().(net.IPNet)
|
||||
return dialect.AppendString(b, ipnet.String())
|
||||
}
|
||||
|
||||
func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
bytes := v.Bytes()
|
||||
if bytes == nil {
|
||||
return dialect.AppendNull(b)
|
||||
}
|
||||
return dialect.AppendString(b, internal.String(bytes))
|
||||
}
|
||||
|
||||
func appendQueryAppenderValue(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
return AppendQueryAppender(fmter, b, v.Interface().(QueryAppender))
|
||||
}
|
||||
|
||||
func driverValueAppender(custom CustomAppender) AppenderFunc {
|
||||
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
return appendDriverValue(fmter, b, v.Interface().(driver.Valuer), custom)
|
||||
}
|
||||
}
|
||||
|
||||
func appendDriverValue(fmter Formatter, b []byte, v driver.Valuer, custom CustomAppender) []byte {
|
||||
value, err := v.Value()
|
||||
if err != nil {
|
||||
return dialect.AppendError(b, err)
|
||||
}
|
||||
return Append(fmter, b, value, custom)
|
||||
}
|
||||
|
||||
func addrAppender(fn AppenderFunc, custom CustomAppender) AppenderFunc {
|
||||
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
|
||||
if !v.CanAddr() {
|
||||
err := fmt.Errorf("bun: Append(nonaddressable %T)", v.Interface())
|
||||
return dialect.AppendError(b, err)
|
||||
}
|
||||
return fn(fmter, b, v.Addr())
|
||||
}
|
||||
}
|
||||
99
vendor/github.com/uptrace/bun/schema/dialect.go
generated
vendored
Normal file
99
vendor/github.com/uptrace/bun/schema/dialect.go
generated
vendored
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/uptrace/bun/dialect"
|
||||
"github.com/uptrace/bun/dialect/feature"
|
||||
)
|
||||
|
||||
type Dialect interface {
|
||||
Init(db *sql.DB)
|
||||
|
||||
Name() dialect.Name
|
||||
Features() feature.Feature
|
||||
|
||||
Tables() *Tables
|
||||
OnTable(table *Table)
|
||||
|
||||
IdentQuote() byte
|
||||
Append(fmter Formatter, b []byte, v interface{}) []byte
|
||||
Appender(typ reflect.Type) AppenderFunc
|
||||
FieldAppender(field *Field) AppenderFunc
|
||||
Scanner(typ reflect.Type) ScannerFunc
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
type nopDialect struct {
|
||||
tables *Tables
|
||||
features feature.Feature
|
||||
|
||||
appenderMap sync.Map
|
||||
scannerMap sync.Map
|
||||
}
|
||||
|
||||
func newNopDialect() *nopDialect {
|
||||
d := new(nopDialect)
|
||||
d.tables = NewTables(d)
|
||||
d.features = feature.Returning
|
||||
return d
|
||||
}
|
||||
|
||||
func (d *nopDialect) Init(*sql.DB) {}
|
||||
|
||||
func (d *nopDialect) Name() dialect.Name {
|
||||
return dialect.Invalid
|
||||
}
|
||||
|
||||
func (d *nopDialect) Features() feature.Feature {
|
||||
return d.features
|
||||
}
|
||||
|
||||
func (d *nopDialect) Tables() *Tables {
|
||||
return d.tables
|
||||
}
|
||||
|
||||
func (d *nopDialect) OnField(field *Field) {}
|
||||
|
||||
func (d *nopDialect) OnTable(table *Table) {}
|
||||
|
||||
func (d *nopDialect) IdentQuote() byte {
|
||||
return '"'
|
||||
}
|
||||
|
||||
func (d *nopDialect) Append(fmter Formatter, b []byte, v interface{}) []byte {
|
||||
return Append(fmter, b, v, nil)
|
||||
}
|
||||
|
||||
func (d *nopDialect) Appender(typ reflect.Type) AppenderFunc {
|
||||
if v, ok := d.appenderMap.Load(typ); ok {
|
||||
return v.(AppenderFunc)
|
||||
}
|
||||
|
||||
fn := Appender(typ, nil)
|
||||
|
||||
if v, ok := d.appenderMap.LoadOrStore(typ, fn); ok {
|
||||
return v.(AppenderFunc)
|
||||
}
|
||||
return fn
|
||||
}
|
||||
|
||||
func (d *nopDialect) FieldAppender(field *Field) AppenderFunc {
|
||||
return FieldAppender(d, field)
|
||||
}
|
||||
|
||||
func (d *nopDialect) Scanner(typ reflect.Type) ScannerFunc {
|
||||
if v, ok := d.scannerMap.Load(typ); ok {
|
||||
return v.(ScannerFunc)
|
||||
}
|
||||
|
||||
fn := Scanner(typ)
|
||||
|
||||
if v, ok := d.scannerMap.LoadOrStore(typ, fn); ok {
|
||||
return v.(ScannerFunc)
|
||||
}
|
||||
return fn
|
||||
}
|
||||
117
vendor/github.com/uptrace/bun/schema/field.go
generated
vendored
Normal file
117
vendor/github.com/uptrace/bun/schema/field.go
generated
vendored
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/uptrace/bun/dialect"
|
||||
"github.com/uptrace/bun/internal/tagparser"
|
||||
)
|
||||
|
||||
type Field struct {
|
||||
StructField reflect.StructField
|
||||
|
||||
Tag tagparser.Tag
|
||||
IndirectType reflect.Type
|
||||
Index []int
|
||||
|
||||
Name string // SQL name, .e.g. id
|
||||
SQLName Safe // escaped SQL name, e.g. "id"
|
||||
GoName string // struct field name, e.g. Id
|
||||
|
||||
DiscoveredSQLType string
|
||||
UserSQLType string
|
||||
CreateTableSQLType string
|
||||
SQLDefault string
|
||||
|
||||
OnDelete string
|
||||
OnUpdate string
|
||||
|
||||
IsPK bool
|
||||
NotNull bool
|
||||
NullZero bool
|
||||
AutoIncrement bool
|
||||
|
||||
Append AppenderFunc
|
||||
Scan ScannerFunc
|
||||
IsZero IsZeroerFunc
|
||||
}
|
||||
|
||||
func (f *Field) String() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
func (f *Field) Clone() *Field {
|
||||
cp := *f
|
||||
cp.Index = cp.Index[:len(f.Index):len(f.Index)]
|
||||
return &cp
|
||||
}
|
||||
|
||||
func (f *Field) Value(strct reflect.Value) reflect.Value {
|
||||
return fieldByIndexAlloc(strct, f.Index)
|
||||
}
|
||||
|
||||
func (f *Field) HasZeroValue(v reflect.Value) bool {
|
||||
for _, idx := range f.Index {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
return true
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
v = v.Field(idx)
|
||||
}
|
||||
return f.IsZero(v)
|
||||
}
|
||||
|
||||
func (f *Field) AppendValue(fmter Formatter, b []byte, strct reflect.Value) []byte {
|
||||
fv, ok := fieldByIndex(strct, f.Index)
|
||||
if !ok {
|
||||
return dialect.AppendNull(b)
|
||||
}
|
||||
|
||||
if f.NullZero && f.IsZero(fv) {
|
||||
return dialect.AppendNull(b)
|
||||
}
|
||||
if f.Append == nil {
|
||||
panic(fmt.Errorf("bun: AppendValue(unsupported %s)", fv.Type()))
|
||||
}
|
||||
return f.Append(fmter, b, fv)
|
||||
}
|
||||
|
||||
func (f *Field) ScanWithCheck(fv reflect.Value, src interface{}) error {
|
||||
if f.Scan == nil {
|
||||
return fmt.Errorf("bun: Scan(unsupported %s)", f.IndirectType)
|
||||
}
|
||||
return f.Scan(fv, src)
|
||||
}
|
||||
|
||||
func (f *Field) ScanValue(strct reflect.Value, src interface{}) error {
|
||||
if src == nil {
|
||||
if fv, ok := fieldByIndex(strct, f.Index); ok {
|
||||
return f.ScanWithCheck(fv, src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
fv := fieldByIndexAlloc(strct, f.Index)
|
||||
return f.ScanWithCheck(fv, src)
|
||||
}
|
||||
|
||||
func (f *Field) markAsPK() {
|
||||
f.IsPK = true
|
||||
f.NotNull = true
|
||||
f.NullZero = true
|
||||
}
|
||||
|
||||
func indexEqual(ind1, ind2 []int) bool {
|
||||
if len(ind1) != len(ind2) {
|
||||
return false
|
||||
}
|
||||
for i, ind := range ind1 {
|
||||
if ind != ind2[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
248
vendor/github.com/uptrace/bun/schema/formatter.go
generated
vendored
Normal file
248
vendor/github.com/uptrace/bun/schema/formatter.go
generated
vendored
Normal file
|
|
@ -0,0 +1,248 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/uptrace/bun/dialect"
|
||||
"github.com/uptrace/bun/dialect/feature"
|
||||
"github.com/uptrace/bun/internal"
|
||||
"github.com/uptrace/bun/internal/parser"
|
||||
)
|
||||
|
||||
var nopFormatter = Formatter{
|
||||
dialect: newNopDialect(),
|
||||
}
|
||||
|
||||
type Formatter struct {
|
||||
dialect Dialect
|
||||
args *namedArgList
|
||||
}
|
||||
|
||||
func NewFormatter(dialect Dialect) Formatter {
|
||||
return Formatter{
|
||||
dialect: dialect,
|
||||
}
|
||||
}
|
||||
|
||||
func NewNopFormatter() Formatter {
|
||||
return nopFormatter
|
||||
}
|
||||
|
||||
func (f Formatter) IsNop() bool {
|
||||
return f.dialect.Name() == dialect.Invalid
|
||||
}
|
||||
|
||||
func (f Formatter) Dialect() Dialect {
|
||||
return f.dialect
|
||||
}
|
||||
|
||||
func (f Formatter) IdentQuote() byte {
|
||||
return f.dialect.IdentQuote()
|
||||
}
|
||||
|
||||
func (f Formatter) AppendIdent(b []byte, ident string) []byte {
|
||||
return dialect.AppendIdent(b, ident, f.IdentQuote())
|
||||
}
|
||||
|
||||
func (f Formatter) AppendValue(b []byte, v reflect.Value) []byte {
|
||||
if v.Kind() == reflect.Ptr && v.IsNil() {
|
||||
return dialect.AppendNull(b)
|
||||
}
|
||||
appender := f.dialect.Appender(v.Type())
|
||||
return appender(f, b, v)
|
||||
}
|
||||
|
||||
func (f Formatter) HasFeature(feature feature.Feature) bool {
|
||||
return f.dialect.Features().Has(feature)
|
||||
}
|
||||
|
||||
func (f Formatter) WithArg(arg NamedArgAppender) Formatter {
|
||||
return Formatter{
|
||||
dialect: f.dialect,
|
||||
args: f.args.WithArg(arg),
|
||||
}
|
||||
}
|
||||
|
||||
func (f Formatter) WithNamedArg(name string, value interface{}) Formatter {
|
||||
return Formatter{
|
||||
dialect: f.dialect,
|
||||
args: f.args.WithArg(&namedArg{name: name, value: value}),
|
||||
}
|
||||
}
|
||||
|
||||
func (f Formatter) FormatQuery(query string, args ...interface{}) string {
|
||||
if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 {
|
||||
return query
|
||||
}
|
||||
return internal.String(f.AppendQuery(nil, query, args...))
|
||||
}
|
||||
|
||||
func (f Formatter) AppendQuery(dst []byte, query string, args ...interface{}) []byte {
|
||||
if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 {
|
||||
return append(dst, query...)
|
||||
}
|
||||
return f.append(dst, parser.NewString(query), args)
|
||||
}
|
||||
|
||||
func (f Formatter) append(dst []byte, p *parser.Parser, args []interface{}) []byte {
|
||||
var namedArgs NamedArgAppender
|
||||
if len(args) == 1 {
|
||||
var ok bool
|
||||
namedArgs, ok = args[0].(NamedArgAppender)
|
||||
if !ok {
|
||||
namedArgs, _ = newStructArgs(f, args[0])
|
||||
}
|
||||
}
|
||||
|
||||
var argIndex int
|
||||
for p.Valid() {
|
||||
b, ok := p.ReadSep('?')
|
||||
if !ok {
|
||||
dst = append(dst, b...)
|
||||
continue
|
||||
}
|
||||
if len(b) > 0 && b[len(b)-1] == '\\' {
|
||||
dst = append(dst, b[:len(b)-1]...)
|
||||
dst = append(dst, '?')
|
||||
continue
|
||||
}
|
||||
dst = append(dst, b...)
|
||||
|
||||
name, numeric := p.ReadIdentifier()
|
||||
if name != "" {
|
||||
if numeric {
|
||||
idx, err := strconv.Atoi(name)
|
||||
if err != nil {
|
||||
goto restore_arg
|
||||
}
|
||||
|
||||
if idx >= len(args) {
|
||||
goto restore_arg
|
||||
}
|
||||
|
||||
dst = f.appendArg(dst, args[idx])
|
||||
continue
|
||||
}
|
||||
|
||||
if namedArgs != nil {
|
||||
dst, ok = namedArgs.AppendNamedArg(f, dst, name)
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
dst, ok = f.args.AppendNamedArg(f, dst, name)
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
|
||||
restore_arg:
|
||||
dst = append(dst, '?')
|
||||
dst = append(dst, name...)
|
||||
continue
|
||||
}
|
||||
|
||||
if argIndex >= len(args) {
|
||||
dst = append(dst, '?')
|
||||
continue
|
||||
}
|
||||
|
||||
arg := args[argIndex]
|
||||
argIndex++
|
||||
|
||||
dst = f.appendArg(dst, arg)
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (f Formatter) appendArg(b []byte, arg interface{}) []byte {
|
||||
switch arg := arg.(type) {
|
||||
case QueryAppender:
|
||||
bb, err := arg.AppendQuery(f, b)
|
||||
if err != nil {
|
||||
return dialect.AppendError(b, err)
|
||||
}
|
||||
return bb
|
||||
default:
|
||||
return f.dialect.Append(f, b, arg)
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
type NamedArgAppender interface {
|
||||
AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool)
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
type namedArgList struct {
|
||||
arg NamedArgAppender
|
||||
next *namedArgList
|
||||
}
|
||||
|
||||
func (l *namedArgList) WithArg(arg NamedArgAppender) *namedArgList {
|
||||
return &namedArgList{
|
||||
arg: arg,
|
||||
next: l,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *namedArgList) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
|
||||
for l != nil && l.arg != nil {
|
||||
if b, ok := l.arg.AppendNamedArg(fmter, b, name); ok {
|
||||
return b, true
|
||||
}
|
||||
l = l.next
|
||||
}
|
||||
return b, false
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
type namedArg struct {
|
||||
name string
|
||||
value interface{}
|
||||
}
|
||||
|
||||
var _ NamedArgAppender = (*namedArg)(nil)
|
||||
|
||||
func (a *namedArg) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
|
||||
if a.name == name {
|
||||
return fmter.appendArg(b, a.value), true
|
||||
}
|
||||
return b, false
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
var _ NamedArgAppender = (*structArgs)(nil)
|
||||
|
||||
type structArgs struct {
|
||||
table *Table
|
||||
strct reflect.Value
|
||||
}
|
||||
|
||||
func newStructArgs(fmter Formatter, strct interface{}) (*structArgs, bool) {
|
||||
v := reflect.ValueOf(strct)
|
||||
if !v.IsValid() {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
v = reflect.Indirect(v)
|
||||
if v.Kind() != reflect.Struct {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return &structArgs{
|
||||
table: fmter.Dialect().Tables().Get(v.Type()),
|
||||
strct: v,
|
||||
}, true
|
||||
}
|
||||
|
||||
func (m *structArgs) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
|
||||
return m.table.AppendNamedArg(fmter, b, name, m.strct)
|
||||
}
|
||||
20
vendor/github.com/uptrace/bun/schema/hook.go
generated
vendored
Normal file
20
vendor/github.com/uptrace/bun/schema/hook.go
generated
vendored
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type BeforeScanHook interface {
|
||||
BeforeScan(context.Context) error
|
||||
}
|
||||
|
||||
var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem()
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
type AfterScanHook interface {
|
||||
AfterScan(context.Context) error
|
||||
}
|
||||
|
||||
var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem()
|
||||
32
vendor/github.com/uptrace/bun/schema/relation.go
generated
vendored
Normal file
32
vendor/github.com/uptrace/bun/schema/relation.go
generated
vendored
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
InvalidRelation = iota
|
||||
HasOneRelation
|
||||
BelongsToRelation
|
||||
HasManyRelation
|
||||
ManyToManyRelation
|
||||
)
|
||||
|
||||
type Relation struct {
|
||||
Type int
|
||||
Field *Field
|
||||
JoinTable *Table
|
||||
BaseFields []*Field
|
||||
JoinFields []*Field
|
||||
|
||||
PolymorphicField *Field
|
||||
PolymorphicValue string
|
||||
|
||||
M2MTable *Table
|
||||
M2MBaseFields []*Field
|
||||
M2MJoinFields []*Field
|
||||
}
|
||||
|
||||
func (r *Relation) String() string {
|
||||
return fmt.Sprintf("relation=%s", r.Field.GoName)
|
||||
}
|
||||
392
vendor/github.com/uptrace/bun/schema/scan.go
generated
vendored
Normal file
392
vendor/github.com/uptrace/bun/schema/scan.go
generated
vendored
Normal file
|
|
@ -0,0 +1,392 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
|
||||
"github.com/uptrace/bun/extra/bunjson"
|
||||
"github.com/uptrace/bun/internal"
|
||||
)
|
||||
|
||||
var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
|
||||
|
||||
type ScannerFunc func(dest reflect.Value, src interface{}) error
|
||||
|
||||
var scanners = []ScannerFunc{
|
||||
reflect.Bool: scanBool,
|
||||
reflect.Int: scanInt64,
|
||||
reflect.Int8: scanInt64,
|
||||
reflect.Int16: scanInt64,
|
||||
reflect.Int32: scanInt64,
|
||||
reflect.Int64: scanInt64,
|
||||
reflect.Uint: scanUint64,
|
||||
reflect.Uint8: scanUint64,
|
||||
reflect.Uint16: scanUint64,
|
||||
reflect.Uint32: scanUint64,
|
||||
reflect.Uint64: scanUint64,
|
||||
reflect.Uintptr: scanUint64,
|
||||
reflect.Float32: scanFloat64,
|
||||
reflect.Float64: scanFloat64,
|
||||
reflect.Complex64: nil,
|
||||
reflect.Complex128: nil,
|
||||
reflect.Array: nil,
|
||||
reflect.Chan: nil,
|
||||
reflect.Func: nil,
|
||||
reflect.Map: scanJSON,
|
||||
reflect.Ptr: nil,
|
||||
reflect.Slice: scanJSON,
|
||||
reflect.String: scanString,
|
||||
reflect.Struct: scanJSON,
|
||||
reflect.UnsafePointer: nil,
|
||||
}
|
||||
|
||||
func FieldScanner(dialect Dialect, field *Field) ScannerFunc {
|
||||
if field.Tag.HasOption("msgpack") {
|
||||
return scanMsgpack
|
||||
}
|
||||
if field.Tag.HasOption("json_use_number") {
|
||||
return scanJSONUseNumber
|
||||
}
|
||||
return dialect.Scanner(field.StructField.Type)
|
||||
}
|
||||
|
||||
func Scanner(typ reflect.Type) ScannerFunc {
|
||||
kind := typ.Kind()
|
||||
|
||||
if kind == reflect.Ptr {
|
||||
if fn := Scanner(typ.Elem()); fn != nil {
|
||||
return ptrScanner(fn)
|
||||
}
|
||||
}
|
||||
|
||||
if typ.Implements(scannerType) {
|
||||
return scanScanner
|
||||
}
|
||||
|
||||
if kind != reflect.Ptr {
|
||||
ptr := reflect.PtrTo(typ)
|
||||
if ptr.Implements(scannerType) {
|
||||
return addrScanner(scanScanner)
|
||||
}
|
||||
}
|
||||
|
||||
switch typ {
|
||||
case timeType:
|
||||
return scanTime
|
||||
case ipType:
|
||||
return scanIP
|
||||
case ipNetType:
|
||||
return scanIPNet
|
||||
case jsonRawMessageType:
|
||||
return scanJSONRawMessage
|
||||
}
|
||||
|
||||
return scanners[kind]
|
||||
}
|
||||
|
||||
func scanBool(dest reflect.Value, src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case nil:
|
||||
dest.SetBool(false)
|
||||
return nil
|
||||
case bool:
|
||||
dest.SetBool(src)
|
||||
return nil
|
||||
case int64:
|
||||
dest.SetBool(src != 0)
|
||||
return nil
|
||||
case []byte:
|
||||
if len(src) == 1 {
|
||||
dest.SetBool(src[0] != '0')
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
|
||||
}
|
||||
|
||||
func scanInt64(dest reflect.Value, src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case nil:
|
||||
dest.SetInt(0)
|
||||
return nil
|
||||
case int64:
|
||||
dest.SetInt(src)
|
||||
return nil
|
||||
case uint64:
|
||||
dest.SetInt(int64(src))
|
||||
return nil
|
||||
case []byte:
|
||||
n, err := strconv.ParseInt(internal.String(src), 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dest.SetInt(n)
|
||||
return nil
|
||||
case string:
|
||||
n, err := strconv.ParseInt(src, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dest.SetInt(n)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
|
||||
}
|
||||
|
||||
func scanUint64(dest reflect.Value, src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case nil:
|
||||
dest.SetUint(0)
|
||||
return nil
|
||||
case uint64:
|
||||
dest.SetUint(src)
|
||||
return nil
|
||||
case int64:
|
||||
dest.SetUint(uint64(src))
|
||||
return nil
|
||||
case []byte:
|
||||
n, err := strconv.ParseUint(internal.String(src), 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dest.SetUint(n)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
|
||||
}
|
||||
|
||||
func scanFloat64(dest reflect.Value, src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case nil:
|
||||
dest.SetFloat(0)
|
||||
return nil
|
||||
case float64:
|
||||
dest.SetFloat(src)
|
||||
return nil
|
||||
case []byte:
|
||||
f, err := strconv.ParseFloat(internal.String(src), 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dest.SetFloat(f)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
|
||||
}
|
||||
|
||||
func scanString(dest reflect.Value, src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case nil:
|
||||
dest.SetString("")
|
||||
return nil
|
||||
case string:
|
||||
dest.SetString(src)
|
||||
return nil
|
||||
case []byte:
|
||||
dest.SetString(string(src))
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
|
||||
}
|
||||
|
||||
func scanTime(dest reflect.Value, src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case nil:
|
||||
destTime := dest.Addr().Interface().(*time.Time)
|
||||
*destTime = time.Time{}
|
||||
return nil
|
||||
case time.Time:
|
||||
destTime := dest.Addr().Interface().(*time.Time)
|
||||
*destTime = src
|
||||
return nil
|
||||
case string:
|
||||
srcTime, err := internal.ParseTime(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
destTime := dest.Addr().Interface().(*time.Time)
|
||||
*destTime = srcTime
|
||||
return nil
|
||||
case []byte:
|
||||
srcTime, err := internal.ParseTime(internal.String(src))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
destTime := dest.Addr().Interface().(*time.Time)
|
||||
*destTime = srcTime
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
|
||||
}
|
||||
|
||||
func scanScanner(dest reflect.Value, src interface{}) error {
|
||||
return dest.Interface().(sql.Scanner).Scan(src)
|
||||
}
|
||||
|
||||
func scanMsgpack(dest reflect.Value, src interface{}) error {
|
||||
if src == nil {
|
||||
return scanNull(dest)
|
||||
}
|
||||
|
||||
b, err := toBytes(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dec := msgpack.GetDecoder()
|
||||
defer msgpack.PutDecoder(dec)
|
||||
|
||||
dec.Reset(bytes.NewReader(b))
|
||||
return dec.DecodeValue(dest)
|
||||
}
|
||||
|
||||
func scanJSON(dest reflect.Value, src interface{}) error {
|
||||
if src == nil {
|
||||
return scanNull(dest)
|
||||
}
|
||||
|
||||
b, err := toBytes(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bunjson.Unmarshal(b, dest.Addr().Interface())
|
||||
}
|
||||
|
||||
func scanJSONUseNumber(dest reflect.Value, src interface{}) error {
|
||||
if src == nil {
|
||||
return scanNull(dest)
|
||||
}
|
||||
|
||||
b, err := toBytes(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dec := bunjson.NewDecoder(bytes.NewReader(b))
|
||||
dec.UseNumber()
|
||||
return dec.Decode(dest.Addr().Interface())
|
||||
}
|
||||
|
||||
func scanIP(dest reflect.Value, src interface{}) error {
|
||||
if src == nil {
|
||||
return scanNull(dest)
|
||||
}
|
||||
|
||||
b, err := toBytes(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ip := net.ParseIP(internal.String(b))
|
||||
if ip == nil {
|
||||
return fmt.Errorf("bun: invalid ip: %q", b)
|
||||
}
|
||||
|
||||
ptr := dest.Addr().Interface().(*net.IP)
|
||||
*ptr = ip
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func scanIPNet(dest reflect.Value, src interface{}) error {
|
||||
if src == nil {
|
||||
return scanNull(dest)
|
||||
}
|
||||
|
||||
b, err := toBytes(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, ipnet, err := net.ParseCIDR(internal.String(b))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ptr := dest.Addr().Interface().(*net.IPNet)
|
||||
*ptr = *ipnet
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func scanJSONRawMessage(dest reflect.Value, src interface{}) error {
|
||||
if src == nil {
|
||||
dest.SetBytes(nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
b, err := toBytes(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dest.SetBytes(b)
|
||||
return nil
|
||||
}
|
||||
|
||||
func addrScanner(fn ScannerFunc) ScannerFunc {
|
||||
return func(dest reflect.Value, src interface{}) error {
|
||||
if !dest.CanAddr() {
|
||||
return fmt.Errorf("bun: Scan(nonaddressable %T)", dest.Interface())
|
||||
}
|
||||
return fn(dest.Addr(), src)
|
||||
}
|
||||
}
|
||||
|
||||
func toBytes(src interface{}) ([]byte, error) {
|
||||
switch src := src.(type) {
|
||||
case string:
|
||||
return internal.Bytes(src), nil
|
||||
case []byte:
|
||||
return src, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src)
|
||||
}
|
||||
}
|
||||
|
||||
func ptrScanner(fn ScannerFunc) ScannerFunc {
|
||||
return func(dest reflect.Value, src interface{}) error {
|
||||
if src == nil {
|
||||
if !dest.CanAddr() {
|
||||
if dest.IsNil() {
|
||||
return nil
|
||||
}
|
||||
return fn(dest.Elem(), src)
|
||||
}
|
||||
|
||||
if !dest.IsNil() {
|
||||
dest.Set(reflect.New(dest.Type().Elem()))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if dest.IsNil() {
|
||||
dest.Set(reflect.New(dest.Type().Elem()))
|
||||
}
|
||||
return fn(dest.Elem(), src)
|
||||
}
|
||||
}
|
||||
|
||||
func scanNull(dest reflect.Value) error {
|
||||
if nilable(dest.Kind()) && dest.IsNil() {
|
||||
return nil
|
||||
}
|
||||
dest.Set(reflect.New(dest.Type()).Elem())
|
||||
return nil
|
||||
}
|
||||
|
||||
func nilable(kind reflect.Kind) bool {
|
||||
switch kind {
|
||||
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
76
vendor/github.com/uptrace/bun/schema/sqlfmt.go
generated
vendored
Normal file
76
vendor/github.com/uptrace/bun/schema/sqlfmt.go
generated
vendored
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
package schema
|
||||
|
||||
type QueryAppender interface {
|
||||
AppendQuery(fmter Formatter, b []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
type ColumnsAppender interface {
|
||||
AppendColumns(fmter Formatter, b []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// Safe represents a safe SQL query.
|
||||
type Safe string
|
||||
|
||||
var _ QueryAppender = (*Safe)(nil)
|
||||
|
||||
func (s Safe) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
|
||||
return append(b, s...), nil
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// Ident represents a SQL identifier, for example, table or column name.
|
||||
type Ident string
|
||||
|
||||
var _ QueryAppender = (*Ident)(nil)
|
||||
|
||||
func (s Ident) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
|
||||
return fmter.AppendIdent(b, string(s)), nil
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
type QueryWithArgs struct {
|
||||
Query string
|
||||
Args []interface{}
|
||||
}
|
||||
|
||||
var _ QueryAppender = QueryWithArgs{}
|
||||
|
||||
func SafeQuery(query string, args []interface{}) QueryWithArgs {
|
||||
if query != "" && args == nil {
|
||||
args = make([]interface{}, 0)
|
||||
}
|
||||
return QueryWithArgs{Query: query, Args: args}
|
||||
}
|
||||
|
||||
func UnsafeIdent(ident string) QueryWithArgs {
|
||||
return QueryWithArgs{Query: ident}
|
||||
}
|
||||
|
||||
func (q QueryWithArgs) IsZero() bool {
|
||||
return q.Query == "" && q.Args == nil
|
||||
}
|
||||
|
||||
func (q QueryWithArgs) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
|
||||
if q.Args == nil {
|
||||
return fmter.AppendIdent(b, q.Query), nil
|
||||
}
|
||||
return fmter.AppendQuery(b, q.Query, q.Args...), nil
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
type QueryWithSep struct {
|
||||
QueryWithArgs
|
||||
Sep string
|
||||
}
|
||||
|
||||
func SafeQueryWithSep(query string, args []interface{}, sep string) QueryWithSep {
|
||||
return QueryWithSep{
|
||||
QueryWithArgs: SafeQuery(query, args),
|
||||
Sep: sep,
|
||||
}
|
||||
}
|
||||
129
vendor/github.com/uptrace/bun/schema/sqltype.go
generated
vendored
Normal file
129
vendor/github.com/uptrace/bun/schema/sqltype.go
generated
vendored
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/bun/dialect"
|
||||
"github.com/uptrace/bun/dialect/sqltype"
|
||||
"github.com/uptrace/bun/internal"
|
||||
)
|
||||
|
||||
var (
|
||||
bunNullTimeType = reflect.TypeOf((*NullTime)(nil)).Elem()
|
||||
nullTimeType = reflect.TypeOf((*sql.NullTime)(nil)).Elem()
|
||||
nullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem()
|
||||
nullFloatType = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem()
|
||||
nullIntType = reflect.TypeOf((*sql.NullInt64)(nil)).Elem()
|
||||
nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem()
|
||||
)
|
||||
|
||||
var sqlTypes = []string{
|
||||
reflect.Bool: sqltype.Boolean,
|
||||
reflect.Int: sqltype.BigInt,
|
||||
reflect.Int8: sqltype.SmallInt,
|
||||
reflect.Int16: sqltype.SmallInt,
|
||||
reflect.Int32: sqltype.Integer,
|
||||
reflect.Int64: sqltype.BigInt,
|
||||
reflect.Uint: sqltype.BigInt,
|
||||
reflect.Uint8: sqltype.SmallInt,
|
||||
reflect.Uint16: sqltype.SmallInt,
|
||||
reflect.Uint32: sqltype.Integer,
|
||||
reflect.Uint64: sqltype.BigInt,
|
||||
reflect.Uintptr: sqltype.BigInt,
|
||||
reflect.Float32: sqltype.Real,
|
||||
reflect.Float64: sqltype.DoublePrecision,
|
||||
reflect.Complex64: "",
|
||||
reflect.Complex128: "",
|
||||
reflect.Array: "",
|
||||
reflect.Chan: "",
|
||||
reflect.Func: "",
|
||||
reflect.Interface: "",
|
||||
reflect.Map: sqltype.VarChar,
|
||||
reflect.Ptr: "",
|
||||
reflect.Slice: sqltype.VarChar,
|
||||
reflect.String: sqltype.VarChar,
|
||||
reflect.Struct: sqltype.VarChar,
|
||||
reflect.UnsafePointer: "",
|
||||
}
|
||||
|
||||
func DiscoverSQLType(typ reflect.Type) string {
|
||||
switch typ {
|
||||
case timeType, nullTimeType, bunNullTimeType:
|
||||
return sqltype.Timestamp
|
||||
case nullBoolType:
|
||||
return sqltype.Boolean
|
||||
case nullFloatType:
|
||||
return sqltype.DoublePrecision
|
||||
case nullIntType:
|
||||
return sqltype.BigInt
|
||||
case nullStringType:
|
||||
return sqltype.VarChar
|
||||
}
|
||||
return sqlTypes[typ.Kind()]
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
var jsonNull = []byte("null")
|
||||
|
||||
// NullTime is a time.Time wrapper that marshals zero time as JSON null and SQL NULL.
|
||||
type NullTime struct {
|
||||
time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
_ json.Marshaler = (*NullTime)(nil)
|
||||
_ json.Unmarshaler = (*NullTime)(nil)
|
||||
_ sql.Scanner = (*NullTime)(nil)
|
||||
_ QueryAppender = (*NullTime)(nil)
|
||||
)
|
||||
|
||||
func (tm NullTime) MarshalJSON() ([]byte, error) {
|
||||
if tm.IsZero() {
|
||||
return jsonNull, nil
|
||||
}
|
||||
return tm.Time.MarshalJSON()
|
||||
}
|
||||
|
||||
func (tm *NullTime) UnmarshalJSON(b []byte) error {
|
||||
if bytes.Equal(b, jsonNull) {
|
||||
tm.Time = time.Time{}
|
||||
return nil
|
||||
}
|
||||
return tm.Time.UnmarshalJSON(b)
|
||||
}
|
||||
|
||||
func (tm NullTime) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
|
||||
if tm.IsZero() {
|
||||
return dialect.AppendNull(b), nil
|
||||
}
|
||||
return dialect.AppendTime(b, tm.Time), nil
|
||||
}
|
||||
|
||||
func (tm *NullTime) Scan(src interface{}) error {
|
||||
if src == nil {
|
||||
tm.Time = time.Time{}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
newtm, err := internal.ParseTime(internal.String(src))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tm.Time = newtm
|
||||
return nil
|
||||
case time.Time:
|
||||
tm.Time = src
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("bun: can't scan %#v into NullTime", src)
|
||||
}
|
||||
}
|
||||
948
vendor/github.com/uptrace/bun/schema/table.go
generated
vendored
Normal file
948
vendor/github.com/uptrace/bun/schema/table.go
generated
vendored
Normal file
|
|
@ -0,0 +1,948 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/inflection"
|
||||
|
||||
"github.com/uptrace/bun/internal"
|
||||
"github.com/uptrace/bun/internal/tagparser"
|
||||
)
|
||||
|
||||
const (
|
||||
beforeScanHookFlag internal.Flag = 1 << iota
|
||||
afterScanHookFlag
|
||||
)
|
||||
|
||||
var (
|
||||
baseModelType = reflect.TypeOf((*BaseModel)(nil)).Elem()
|
||||
tableNameInflector = inflection.Plural
|
||||
)
|
||||
|
||||
type BaseModel struct{}
|
||||
|
||||
// SetTableNameInflector overrides the default func that pluralizes
|
||||
// model name to get table name, e.g. my_article becomes my_articles.
|
||||
func SetTableNameInflector(fn func(string) string) {
|
||||
tableNameInflector = fn
|
||||
}
|
||||
|
||||
// Table represents a SQL table created from Go struct.
|
||||
type Table struct {
|
||||
dialect Dialect
|
||||
|
||||
Type reflect.Type
|
||||
ZeroValue reflect.Value // reflect.Struct
|
||||
ZeroIface interface{} // struct pointer
|
||||
|
||||
TypeName string
|
||||
ModelName string
|
||||
|
||||
Name string
|
||||
SQLName Safe
|
||||
SQLNameForSelects Safe
|
||||
Alias string
|
||||
SQLAlias Safe
|
||||
|
||||
Fields []*Field // PKs + DataFields
|
||||
PKs []*Field
|
||||
DataFields []*Field
|
||||
|
||||
fieldsMapMu sync.RWMutex
|
||||
FieldMap map[string]*Field
|
||||
|
||||
Relations map[string]*Relation
|
||||
Unique map[string][]*Field
|
||||
|
||||
SoftDeleteField *Field
|
||||
UpdateSoftDeleteField func(fv reflect.Value) error
|
||||
|
||||
allFields []*Field // read only
|
||||
skippedFields []*Field
|
||||
|
||||
flags internal.Flag
|
||||
}
|
||||
|
||||
func newTable(dialect Dialect, typ reflect.Type) *Table {
|
||||
t := new(Table)
|
||||
t.dialect = dialect
|
||||
t.Type = typ
|
||||
t.ZeroValue = reflect.New(t.Type).Elem()
|
||||
t.ZeroIface = reflect.New(t.Type).Interface()
|
||||
t.TypeName = internal.ToExported(t.Type.Name())
|
||||
t.ModelName = internal.Underscore(t.Type.Name())
|
||||
tableName := tableNameInflector(t.ModelName)
|
||||
t.setName(tableName)
|
||||
t.Alias = t.ModelName
|
||||
t.SQLAlias = t.quoteIdent(t.ModelName)
|
||||
|
||||
hooks := []struct {
|
||||
typ reflect.Type
|
||||
flag internal.Flag
|
||||
}{
|
||||
{beforeScanHookType, beforeScanHookFlag},
|
||||
{afterScanHookType, afterScanHookFlag},
|
||||
}
|
||||
|
||||
typ = reflect.PtrTo(t.Type)
|
||||
for _, hook := range hooks {
|
||||
if typ.Implements(hook.typ) {
|
||||
t.flags = t.flags.Set(hook.flag)
|
||||
}
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *Table) init1() {
|
||||
t.initFields()
|
||||
}
|
||||
|
||||
func (t *Table) init2() {
|
||||
t.initInlines()
|
||||
t.initRelations()
|
||||
t.skippedFields = nil
|
||||
}
|
||||
|
||||
func (t *Table) setName(name string) {
|
||||
t.Name = name
|
||||
t.SQLName = t.quoteIdent(name)
|
||||
t.SQLNameForSelects = t.quoteIdent(name)
|
||||
if t.SQLAlias == "" {
|
||||
t.Alias = name
|
||||
t.SQLAlias = t.quoteIdent(name)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Table) String() string {
|
||||
return "model=" + t.TypeName
|
||||
}
|
||||
|
||||
func (t *Table) CheckPKs() error {
|
||||
if len(t.PKs) == 0 {
|
||||
return fmt.Errorf("bun: %s does not have primary keys", t)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Table) addField(field *Field) {
|
||||
t.Fields = append(t.Fields, field)
|
||||
if field.IsPK {
|
||||
t.PKs = append(t.PKs, field)
|
||||
} else {
|
||||
t.DataFields = append(t.DataFields, field)
|
||||
}
|
||||
t.FieldMap[field.Name] = field
|
||||
}
|
||||
|
||||
func (t *Table) removeField(field *Field) {
|
||||
t.Fields = removeField(t.Fields, field)
|
||||
if field.IsPK {
|
||||
t.PKs = removeField(t.PKs, field)
|
||||
} else {
|
||||
t.DataFields = removeField(t.DataFields, field)
|
||||
}
|
||||
delete(t.FieldMap, field.Name)
|
||||
}
|
||||
|
||||
func (t *Table) fieldWithLock(name string) *Field {
|
||||
t.fieldsMapMu.RLock()
|
||||
field := t.FieldMap[name]
|
||||
t.fieldsMapMu.RUnlock()
|
||||
return field
|
||||
}
|
||||
|
||||
func (t *Table) HasField(name string) bool {
|
||||
_, ok := t.FieldMap[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (t *Table) Field(name string) (*Field, error) {
|
||||
field, ok := t.FieldMap[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("bun: %s does not have column=%s", t, name)
|
||||
}
|
||||
return field, nil
|
||||
}
|
||||
|
||||
func (t *Table) fieldByGoName(name string) *Field {
|
||||
for _, f := range t.allFields {
|
||||
if f.GoName == name {
|
||||
return f
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Table) initFields() {
|
||||
t.Fields = make([]*Field, 0, t.Type.NumField())
|
||||
t.FieldMap = make(map[string]*Field, t.Type.NumField())
|
||||
t.addFields(t.Type, nil)
|
||||
|
||||
if len(t.PKs) > 0 {
|
||||
return
|
||||
}
|
||||
for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} {
|
||||
if field, ok := t.FieldMap[name]; ok {
|
||||
field.markAsPK()
|
||||
t.PKs = []*Field{field}
|
||||
t.DataFields = removeField(t.DataFields, field)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(t.PKs) == 1 {
|
||||
switch t.PKs[0].IndirectType.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
t.PKs[0].AutoIncrement = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
f := typ.Field(i)
|
||||
|
||||
// Make a copy so slice is not shared between fields.
|
||||
index := make([]int, len(baseIndex))
|
||||
copy(index, baseIndex)
|
||||
|
||||
if f.Anonymous {
|
||||
if f.Tag.Get("bun") == "-" {
|
||||
continue
|
||||
}
|
||||
if f.Name == "BaseModel" && f.Type == baseModelType {
|
||||
if len(index) == 0 {
|
||||
t.processBaseModelField(f)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
fieldType := indirectType(f.Type)
|
||||
if fieldType.Kind() != reflect.Struct {
|
||||
continue
|
||||
}
|
||||
t.addFields(fieldType, append(index, f.Index...))
|
||||
|
||||
tag := tagparser.Parse(f.Tag.Get("bun"))
|
||||
if _, inherit := tag.Options["inherit"]; inherit {
|
||||
embeddedTable := t.dialect.Tables().Ref(fieldType)
|
||||
t.TypeName = embeddedTable.TypeName
|
||||
t.SQLName = embeddedTable.SQLName
|
||||
t.SQLNameForSelects = embeddedTable.SQLNameForSelects
|
||||
t.Alias = embeddedTable.Alias
|
||||
t.SQLAlias = embeddedTable.SQLAlias
|
||||
t.ModelName = embeddedTable.ModelName
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
field := t.newField(f, index)
|
||||
if field != nil {
|
||||
t.addField(field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Table) processBaseModelField(f reflect.StructField) {
|
||||
tag := tagparser.Parse(f.Tag.Get("bun"))
|
||||
|
||||
if isKnownTableOption(tag.Name) {
|
||||
internal.Warn.Printf(
|
||||
"%s.%s tag name %q is also an option name; is it a mistake?",
|
||||
t.TypeName, f.Name, tag.Name,
|
||||
)
|
||||
}
|
||||
|
||||
for name := range tag.Options {
|
||||
if !isKnownTableOption(name) {
|
||||
internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name)
|
||||
}
|
||||
}
|
||||
|
||||
if tag.Name != "" {
|
||||
t.setName(tag.Name)
|
||||
}
|
||||
|
||||
if s, ok := tag.Options["select"]; ok {
|
||||
t.SQLNameForSelects = t.quoteTableName(s)
|
||||
}
|
||||
|
||||
if s, ok := tag.Options["alias"]; ok {
|
||||
t.Alias = s
|
||||
t.SQLAlias = t.quoteIdent(s)
|
||||
}
|
||||
}
|
||||
|
||||
//nolint
|
||||
func (t *Table) newField(f reflect.StructField, index []int) *Field {
|
||||
tag := tagparser.Parse(f.Tag.Get("bun"))
|
||||
|
||||
if f.PkgPath != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
sqlName := internal.Underscore(f.Name)
|
||||
|
||||
if tag.Name != sqlName && isKnownFieldOption(tag.Name) {
|
||||
internal.Warn.Printf(
|
||||
"%s.%s tag name %q is also an option name; is it a mistake?",
|
||||
t.TypeName, f.Name, tag.Name,
|
||||
)
|
||||
}
|
||||
|
||||
for name := range tag.Options {
|
||||
if !isKnownFieldOption(name) {
|
||||
internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name)
|
||||
}
|
||||
}
|
||||
|
||||
skip := tag.Name == "-"
|
||||
if !skip && tag.Name != "" {
|
||||
sqlName = tag.Name
|
||||
}
|
||||
|
||||
index = append(index, f.Index...)
|
||||
if field := t.fieldWithLock(sqlName); field != nil {
|
||||
if indexEqual(field.Index, index) {
|
||||
return field
|
||||
}
|
||||
t.removeField(field)
|
||||
}
|
||||
|
||||
field := &Field{
|
||||
StructField: f,
|
||||
|
||||
Tag: tag,
|
||||
IndirectType: indirectType(f.Type),
|
||||
Index: index,
|
||||
|
||||
Name: sqlName,
|
||||
GoName: f.Name,
|
||||
SQLName: t.quoteIdent(sqlName),
|
||||
}
|
||||
|
||||
field.NotNull = tag.HasOption("notnull")
|
||||
field.NullZero = tag.HasOption("nullzero")
|
||||
field.AutoIncrement = tag.HasOption("autoincrement")
|
||||
if tag.HasOption("pk") {
|
||||
field.markAsPK()
|
||||
}
|
||||
if tag.HasOption("allowzero") {
|
||||
if tag.HasOption("nullzero") {
|
||||
internal.Warn.Printf(
|
||||
"%s.%s: nullzero and allowzero options are mutually exclusive",
|
||||
t.TypeName, f.Name,
|
||||
)
|
||||
}
|
||||
field.NullZero = false
|
||||
}
|
||||
|
||||
if v, ok := tag.Options["unique"]; ok {
|
||||
// Split the value by comma, this will allow multiple names to be specified.
|
||||
// We can use this to create multiple named unique constraints where a single column
|
||||
// might be included in multiple constraints.
|
||||
for _, uniqueName := range strings.Split(v, ",") {
|
||||
if t.Unique == nil {
|
||||
t.Unique = make(map[string][]*Field)
|
||||
}
|
||||
t.Unique[uniqueName] = append(t.Unique[uniqueName], field)
|
||||
}
|
||||
}
|
||||
if s, ok := tag.Options["default"]; ok {
|
||||
field.SQLDefault = s
|
||||
}
|
||||
if s, ok := field.Tag.Options["type"]; ok {
|
||||
field.UserSQLType = s
|
||||
}
|
||||
field.DiscoveredSQLType = DiscoverSQLType(field.IndirectType)
|
||||
field.Append = t.dialect.FieldAppender(field)
|
||||
field.Scan = FieldScanner(t.dialect, field)
|
||||
field.IsZero = FieldZeroChecker(field)
|
||||
|
||||
if v, ok := tag.Options["alt"]; ok {
|
||||
t.FieldMap[v] = field
|
||||
}
|
||||
|
||||
t.allFields = append(t.allFields, field)
|
||||
if skip {
|
||||
t.skippedFields = append(t.skippedFields, field)
|
||||
t.FieldMap[field.Name] = field
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, ok := tag.Options["soft_delete"]; ok {
|
||||
field.NullZero = true
|
||||
t.SoftDeleteField = field
|
||||
t.UpdateSoftDeleteField = softDeleteFieldUpdater(field)
|
||||
}
|
||||
|
||||
return field
|
||||
}
|
||||
|
||||
func (t *Table) initInlines() {
|
||||
for _, f := range t.skippedFields {
|
||||
if f.IndirectType.Kind() == reflect.Struct {
|
||||
t.inlineFields(f, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//---------------------------------------------------------------------------------------
|
||||
|
||||
func (t *Table) initRelations() {
|
||||
for i := 0; i < len(t.Fields); {
|
||||
f := t.Fields[i]
|
||||
if t.tryRelation(f) {
|
||||
t.Fields = removeField(t.Fields, f)
|
||||
t.DataFields = removeField(t.DataFields, f)
|
||||
} else {
|
||||
i++
|
||||
}
|
||||
|
||||
if f.IndirectType.Kind() == reflect.Struct {
|
||||
t.inlineFields(f, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Table) tryRelation(field *Field) bool {
|
||||
if rel, ok := field.Tag.Options["rel"]; ok {
|
||||
t.initRelation(field, rel)
|
||||
return true
|
||||
}
|
||||
if field.Tag.HasOption("m2m") {
|
||||
t.addRelation(t.m2mRelation(field))
|
||||
return true
|
||||
}
|
||||
|
||||
if field.Tag.HasOption("join") {
|
||||
internal.Warn.Printf(
|
||||
`%s.%s option "join" requires a relation type`,
|
||||
t.TypeName, field.GoName,
|
||||
)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *Table) initRelation(field *Field, rel string) {
|
||||
switch rel {
|
||||
case "belongs-to":
|
||||
t.addRelation(t.belongsToRelation(field))
|
||||
case "has-one":
|
||||
t.addRelation(t.hasOneRelation(field))
|
||||
case "has-many":
|
||||
t.addRelation(t.hasManyRelation(field))
|
||||
default:
|
||||
panic(fmt.Errorf("bun: unknown relation=%s on field=%s", rel, field.GoName))
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Table) addRelation(rel *Relation) {
|
||||
if t.Relations == nil {
|
||||
t.Relations = make(map[string]*Relation)
|
||||
}
|
||||
_, ok := t.Relations[rel.Field.GoName]
|
||||
if ok {
|
||||
panic(fmt.Errorf("%s already has %s", t, rel))
|
||||
}
|
||||
t.Relations[rel.Field.GoName] = rel
|
||||
}
|
||||
|
||||
func (t *Table) belongsToRelation(field *Field) *Relation {
|
||||
joinTable := t.dialect.Tables().Ref(field.IndirectType)
|
||||
if err := joinTable.CheckPKs(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
rel := &Relation{
|
||||
Type: HasOneRelation,
|
||||
Field: field,
|
||||
JoinTable: joinTable,
|
||||
}
|
||||
|
||||
if join, ok := field.Tag.Options["join"]; ok {
|
||||
baseColumns, joinColumns := parseRelationJoin(join)
|
||||
for i, baseColumn := range baseColumns {
|
||||
joinColumn := joinColumns[i]
|
||||
|
||||
if f := t.fieldWithLock(baseColumn); f != nil {
|
||||
rel.BaseFields = append(rel.BaseFields, f)
|
||||
} else {
|
||||
panic(fmt.Errorf(
|
||||
"bun: %s belongs-to %s: %s must have column %s",
|
||||
t.TypeName, field.GoName, t.TypeName, baseColumn,
|
||||
))
|
||||
}
|
||||
|
||||
if f := joinTable.fieldWithLock(joinColumn); f != nil {
|
||||
rel.JoinFields = append(rel.JoinFields, f)
|
||||
} else {
|
||||
panic(fmt.Errorf(
|
||||
"bun: %s belongs-to %s: %s must have column %s",
|
||||
t.TypeName, field.GoName, t.TypeName, baseColumn,
|
||||
))
|
||||
}
|
||||
}
|
||||
return rel
|
||||
}
|
||||
|
||||
rel.JoinFields = joinTable.PKs
|
||||
fkPrefix := internal.Underscore(field.GoName) + "_"
|
||||
for _, joinPK := range joinTable.PKs {
|
||||
fkName := fkPrefix + joinPK.Name
|
||||
if fk := t.fieldWithLock(fkName); fk != nil {
|
||||
rel.BaseFields = append(rel.BaseFields, fk)
|
||||
continue
|
||||
}
|
||||
|
||||
if fk := t.fieldWithLock(joinPK.Name); fk != nil {
|
||||
rel.BaseFields = append(rel.BaseFields, fk)
|
||||
continue
|
||||
}
|
||||
|
||||
panic(fmt.Errorf(
|
||||
"bun: %s belongs-to %s: %s must have column %s "+
|
||||
"(to override, use join:base_column=join_column tag on %s field)",
|
||||
t.TypeName, field.GoName, t.TypeName, fkName, field.GoName,
|
||||
))
|
||||
}
|
||||
return rel
|
||||
}
|
||||
|
||||
func (t *Table) hasOneRelation(field *Field) *Relation {
|
||||
if err := t.CheckPKs(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
joinTable := t.dialect.Tables().Ref(field.IndirectType)
|
||||
rel := &Relation{
|
||||
Type: BelongsToRelation,
|
||||
Field: field,
|
||||
JoinTable: joinTable,
|
||||
}
|
||||
|
||||
if join, ok := field.Tag.Options["join"]; ok {
|
||||
baseColumns, joinColumns := parseRelationJoin(join)
|
||||
for i, baseColumn := range baseColumns {
|
||||
if f := t.fieldWithLock(baseColumn); f != nil {
|
||||
rel.BaseFields = append(rel.BaseFields, f)
|
||||
} else {
|
||||
panic(fmt.Errorf(
|
||||
"bun: %s has-one %s: %s must have column %s",
|
||||
field.GoName, t.TypeName, joinTable.TypeName, baseColumn,
|
||||
))
|
||||
}
|
||||
|
||||
joinColumn := joinColumns[i]
|
||||
if f := joinTable.fieldWithLock(joinColumn); f != nil {
|
||||
rel.JoinFields = append(rel.JoinFields, f)
|
||||
} else {
|
||||
panic(fmt.Errorf(
|
||||
"bun: %s has-one %s: %s must have column %s",
|
||||
field.GoName, t.TypeName, joinTable.TypeName, baseColumn,
|
||||
))
|
||||
}
|
||||
}
|
||||
return rel
|
||||
}
|
||||
|
||||
rel.BaseFields = t.PKs
|
||||
fkPrefix := internal.Underscore(t.ModelName) + "_"
|
||||
for _, pk := range t.PKs {
|
||||
fkName := fkPrefix + pk.Name
|
||||
if f := joinTable.fieldWithLock(fkName); f != nil {
|
||||
rel.JoinFields = append(rel.JoinFields, f)
|
||||
continue
|
||||
}
|
||||
|
||||
if f := joinTable.fieldWithLock(pk.Name); f != nil {
|
||||
rel.JoinFields = append(rel.JoinFields, f)
|
||||
continue
|
||||
}
|
||||
|
||||
panic(fmt.Errorf(
|
||||
"bun: %s has-one %s: %s must have column %s "+
|
||||
"(to override, use join:base_column=join_column tag on %s field)",
|
||||
field.GoName, t.TypeName, joinTable.TypeName, fkName, field.GoName,
|
||||
))
|
||||
}
|
||||
return rel
|
||||
}
|
||||
|
||||
func (t *Table) hasManyRelation(field *Field) *Relation {
|
||||
if err := t.CheckPKs(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if field.IndirectType.Kind() != reflect.Slice {
|
||||
panic(fmt.Errorf(
|
||||
"bun: %s.%s has-many relation requires slice, got %q",
|
||||
t.TypeName, field.GoName, field.IndirectType.Kind(),
|
||||
))
|
||||
}
|
||||
|
||||
joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem()))
|
||||
polymorphicValue, isPolymorphic := field.Tag.Options["polymorphic"]
|
||||
rel := &Relation{
|
||||
Type: HasManyRelation,
|
||||
Field: field,
|
||||
JoinTable: joinTable,
|
||||
}
|
||||
var polymorphicColumn string
|
||||
|
||||
if join, ok := field.Tag.Options["join"]; ok {
|
||||
baseColumns, joinColumns := parseRelationJoin(join)
|
||||
for i, baseColumn := range baseColumns {
|
||||
joinColumn := joinColumns[i]
|
||||
|
||||
if isPolymorphic && baseColumn == "type" {
|
||||
polymorphicColumn = joinColumn
|
||||
continue
|
||||
}
|
||||
|
||||
if f := t.fieldWithLock(baseColumn); f != nil {
|
||||
rel.BaseFields = append(rel.BaseFields, f)
|
||||
} else {
|
||||
panic(fmt.Errorf(
|
||||
"bun: %s has-one %s: %s must have column %s",
|
||||
t.TypeName, field.GoName, t.TypeName, baseColumn,
|
||||
))
|
||||
}
|
||||
|
||||
if f := joinTable.fieldWithLock(joinColumn); f != nil {
|
||||
rel.JoinFields = append(rel.JoinFields, f)
|
||||
} else {
|
||||
panic(fmt.Errorf(
|
||||
"bun: %s has-one %s: %s must have column %s",
|
||||
t.TypeName, field.GoName, t.TypeName, baseColumn,
|
||||
))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rel.BaseFields = t.PKs
|
||||
fkPrefix := internal.Underscore(t.ModelName) + "_"
|
||||
if isPolymorphic {
|
||||
polymorphicColumn = fkPrefix + "type"
|
||||
}
|
||||
|
||||
for _, pk := range t.PKs {
|
||||
joinColumn := fkPrefix + pk.Name
|
||||
if fk := joinTable.fieldWithLock(joinColumn); fk != nil {
|
||||
rel.JoinFields = append(rel.JoinFields, fk)
|
||||
continue
|
||||
}
|
||||
|
||||
if fk := joinTable.fieldWithLock(pk.Name); fk != nil {
|
||||
rel.JoinFields = append(rel.JoinFields, fk)
|
||||
continue
|
||||
}
|
||||
|
||||
panic(fmt.Errorf(
|
||||
"bun: %s has-many %s: %s must have column %s "+
|
||||
"(to override, use join:base_column=join_column tag on the field %s)",
|
||||
t.TypeName, field.GoName, joinTable.TypeName, joinColumn, field.GoName,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
if isPolymorphic {
|
||||
rel.PolymorphicField = joinTable.fieldWithLock(polymorphicColumn)
|
||||
if rel.PolymorphicField == nil {
|
||||
panic(fmt.Errorf(
|
||||
"bun: %s has-many %s: %s must have polymorphic column %s",
|
||||
t.TypeName, field.GoName, joinTable.TypeName, polymorphicColumn,
|
||||
))
|
||||
}
|
||||
|
||||
if polymorphicValue == "" {
|
||||
polymorphicValue = t.ModelName
|
||||
}
|
||||
rel.PolymorphicValue = polymorphicValue
|
||||
}
|
||||
|
||||
return rel
|
||||
}
|
||||
|
||||
func (t *Table) m2mRelation(field *Field) *Relation {
|
||||
if field.IndirectType.Kind() != reflect.Slice {
|
||||
panic(fmt.Errorf(
|
||||
"bun: %s.%s m2m relation requires slice, got %q",
|
||||
t.TypeName, field.GoName, field.IndirectType.Kind(),
|
||||
))
|
||||
}
|
||||
joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem()))
|
||||
|
||||
if err := t.CheckPKs(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := joinTable.CheckPKs(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
m2mTableName, ok := field.Tag.Options["m2m"]
|
||||
if !ok {
|
||||
panic(fmt.Errorf("bun: %s must have m2m tag option", field.GoName))
|
||||
}
|
||||
|
||||
m2mTable := t.dialect.Tables().ByName(m2mTableName)
|
||||
if m2mTable == nil {
|
||||
panic(fmt.Errorf(
|
||||
"bun: can't find m2m %s table (use db.RegisterModel)",
|
||||
m2mTableName,
|
||||
))
|
||||
}
|
||||
|
||||
rel := &Relation{
|
||||
Type: ManyToManyRelation,
|
||||
Field: field,
|
||||
JoinTable: joinTable,
|
||||
M2MTable: m2mTable,
|
||||
}
|
||||
var leftColumn, rightColumn string
|
||||
|
||||
if join, ok := field.Tag.Options["join"]; ok {
|
||||
left, right := parseRelationJoin(join)
|
||||
leftColumn = left[0]
|
||||
rightColumn = right[0]
|
||||
} else {
|
||||
leftColumn = t.TypeName
|
||||
rightColumn = joinTable.TypeName
|
||||
}
|
||||
|
||||
leftField := m2mTable.fieldByGoName(leftColumn)
|
||||
if leftField == nil {
|
||||
panic(fmt.Errorf(
|
||||
"bun: %s many-to-many %s: %s must have field %s "+
|
||||
"(to override, use tag join:LeftField=RightField on field %s.%s",
|
||||
t.TypeName, field.GoName, m2mTable.TypeName, leftColumn, t.TypeName, field.GoName,
|
||||
))
|
||||
}
|
||||
|
||||
rightField := m2mTable.fieldByGoName(rightColumn)
|
||||
if rightField == nil {
|
||||
panic(fmt.Errorf(
|
||||
"bun: %s many-to-many %s: %s must have field %s "+
|
||||
"(to override, use tag join:LeftField=RightField on field %s.%s",
|
||||
t.TypeName, field.GoName, m2mTable.TypeName, rightColumn, t.TypeName, field.GoName,
|
||||
))
|
||||
}
|
||||
|
||||
leftRel := m2mTable.belongsToRelation(leftField)
|
||||
rel.BaseFields = leftRel.JoinFields
|
||||
rel.M2MBaseFields = leftRel.BaseFields
|
||||
|
||||
rightRel := m2mTable.belongsToRelation(rightField)
|
||||
rel.JoinFields = rightRel.JoinFields
|
||||
rel.M2MJoinFields = rightRel.BaseFields
|
||||
|
||||
return rel
|
||||
}
|
||||
|
||||
func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) {
|
||||
if path == nil {
|
||||
path = map[reflect.Type]struct{}{
|
||||
t.Type: {},
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := path[field.IndirectType]; ok {
|
||||
return
|
||||
}
|
||||
path[field.IndirectType] = struct{}{}
|
||||
|
||||
joinTable := t.dialect.Tables().Ref(field.IndirectType)
|
||||
for _, f := range joinTable.allFields {
|
||||
f = f.Clone()
|
||||
f.GoName = field.GoName + "_" + f.GoName
|
||||
f.Name = field.Name + "__" + f.Name
|
||||
f.SQLName = t.quoteIdent(f.Name)
|
||||
f.Index = appendNew(field.Index, f.Index...)
|
||||
|
||||
t.fieldsMapMu.Lock()
|
||||
if _, ok := t.FieldMap[f.Name]; !ok {
|
||||
t.FieldMap[f.Name] = f
|
||||
}
|
||||
t.fieldsMapMu.Unlock()
|
||||
|
||||
if f.IndirectType.Kind() != reflect.Struct {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := path[f.IndirectType]; !ok {
|
||||
t.inlineFields(f, path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
func (t *Table) Dialect() Dialect { return t.dialect }
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) }
|
||||
func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) }
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
func (t *Table) AppendNamedArg(
|
||||
fmter Formatter, b []byte, name string, strct reflect.Value,
|
||||
) ([]byte, bool) {
|
||||
if field, ok := t.FieldMap[name]; ok {
|
||||
return fmter.appendArg(b, field.Value(strct).Interface()), true
|
||||
}
|
||||
return b, false
|
||||
}
|
||||
|
||||
func (t *Table) quoteTableName(s string) Safe {
|
||||
// Don't quote if table name contains placeholder (?) or parentheses.
|
||||
if strings.IndexByte(s, '?') >= 0 ||
|
||||
strings.IndexByte(s, '(') >= 0 ||
|
||||
strings.IndexByte(s, ')') >= 0 {
|
||||
return Safe(s)
|
||||
}
|
||||
return t.quoteIdent(s)
|
||||
}
|
||||
|
||||
func (t *Table) quoteIdent(s string) Safe {
|
||||
return Safe(NewFormatter(t.dialect).AppendIdent(nil, s))
|
||||
}
|
||||
|
||||
func appendNew(dst []int, src ...int) []int {
|
||||
cp := make([]int, len(dst)+len(src))
|
||||
copy(cp, dst)
|
||||
copy(cp[len(dst):], src)
|
||||
return cp
|
||||
}
|
||||
|
||||
func isKnownTableOption(name string) bool {
|
||||
switch name {
|
||||
case "alias", "select":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isKnownFieldOption(name string) bool {
|
||||
switch name {
|
||||
case "alias",
|
||||
"type",
|
||||
"array",
|
||||
"hstore",
|
||||
"composite",
|
||||
"json_use_number",
|
||||
"msgpack",
|
||||
"notnull",
|
||||
"nullzero",
|
||||
"allowzero",
|
||||
"default",
|
||||
"unique",
|
||||
"soft_delete",
|
||||
|
||||
"pk",
|
||||
"autoincrement",
|
||||
"rel",
|
||||
"join",
|
||||
"m2m",
|
||||
"polymorphic":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func removeField(fields []*Field, field *Field) []*Field {
|
||||
for i, f := range fields {
|
||||
if f == field {
|
||||
return append(fields[:i], fields[i+1:]...)
|
||||
}
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
func parseRelationJoin(join string) ([]string, []string) {
|
||||
ss := strings.Split(join, ",")
|
||||
baseColumns := make([]string, len(ss))
|
||||
joinColumns := make([]string, len(ss))
|
||||
for i, s := range ss {
|
||||
ss := strings.Split(strings.TrimSpace(s), "=")
|
||||
if len(ss) != 2 {
|
||||
panic(fmt.Errorf("can't parse relation join: %q", join))
|
||||
}
|
||||
baseColumns[i] = ss[0]
|
||||
joinColumns[i] = ss[1]
|
||||
}
|
||||
return baseColumns, joinColumns
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error {
|
||||
typ := field.StructField.Type
|
||||
|
||||
switch typ {
|
||||
case timeType:
|
||||
return func(fv reflect.Value) error {
|
||||
ptr := fv.Addr().Interface().(*time.Time)
|
||||
*ptr = time.Now()
|
||||
return nil
|
||||
}
|
||||
case nullTimeType:
|
||||
return func(fv reflect.Value) error {
|
||||
ptr := fv.Addr().Interface().(*sql.NullTime)
|
||||
*ptr = sql.NullTime{Time: time.Now()}
|
||||
return nil
|
||||
}
|
||||
case nullIntType:
|
||||
return func(fv reflect.Value) error {
|
||||
ptr := fv.Addr().Interface().(*sql.NullInt64)
|
||||
*ptr = sql.NullInt64{Int64: time.Now().UnixNano()}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
switch field.IndirectType.Kind() {
|
||||
case reflect.Int64:
|
||||
return func(fv reflect.Value) error {
|
||||
ptr := fv.Addr().Interface().(*int64)
|
||||
*ptr = time.Now().UnixNano()
|
||||
return nil
|
||||
}
|
||||
case reflect.Ptr:
|
||||
typ = typ.Elem()
|
||||
default:
|
||||
return softDeleteFieldUpdaterFallback(field)
|
||||
}
|
||||
|
||||
switch typ { //nolint:gocritic
|
||||
case timeType:
|
||||
return func(fv reflect.Value) error {
|
||||
now := time.Now()
|
||||
fv.Set(reflect.ValueOf(&now))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
switch typ.Kind() { //nolint:gocritic
|
||||
case reflect.Int64:
|
||||
return func(fv reflect.Value) error {
|
||||
utime := time.Now().UnixNano()
|
||||
fv.Set(reflect.ValueOf(&utime))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return softDeleteFieldUpdaterFallback(field)
|
||||
}
|
||||
|
||||
func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value) error {
|
||||
return func(fv reflect.Value) error {
|
||||
return field.ScanWithCheck(fv, time.Now())
|
||||
}
|
||||
}
|
||||
148
vendor/github.com/uptrace/bun/schema/tables.go
generated
vendored
Normal file
148
vendor/github.com/uptrace/bun/schema/tables.go
generated
vendored
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type tableInProgress struct {
|
||||
table *Table
|
||||
|
||||
init1Once sync.Once
|
||||
init2Once sync.Once
|
||||
}
|
||||
|
||||
func newTableInProgress(table *Table) *tableInProgress {
|
||||
return &tableInProgress{
|
||||
table: table,
|
||||
}
|
||||
}
|
||||
|
||||
func (inp *tableInProgress) init1() bool {
|
||||
var inited bool
|
||||
inp.init1Once.Do(func() {
|
||||
inp.table.init1()
|
||||
inited = true
|
||||
})
|
||||
return inited
|
||||
}
|
||||
|
||||
func (inp *tableInProgress) init2() bool {
|
||||
var inited bool
|
||||
inp.init2Once.Do(func() {
|
||||
inp.table.init2()
|
||||
inited = true
|
||||
})
|
||||
return inited
|
||||
}
|
||||
|
||||
type Tables struct {
|
||||
dialect Dialect
|
||||
tables sync.Map
|
||||
|
||||
mu sync.RWMutex
|
||||
inProgress map[reflect.Type]*tableInProgress
|
||||
}
|
||||
|
||||
func NewTables(dialect Dialect) *Tables {
|
||||
return &Tables{
|
||||
dialect: dialect,
|
||||
inProgress: make(map[reflect.Type]*tableInProgress),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tables) Register(models ...interface{}) {
|
||||
for _, model := range models {
|
||||
_ = t.Get(reflect.TypeOf(model).Elem())
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tables) Get(typ reflect.Type) *Table {
|
||||
return t.table(typ, false)
|
||||
}
|
||||
|
||||
func (t *Tables) Ref(typ reflect.Type) *Table {
|
||||
return t.table(typ, true)
|
||||
}
|
||||
|
||||
func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table {
|
||||
if typ.Kind() != reflect.Struct {
|
||||
panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct))
|
||||
}
|
||||
|
||||
if v, ok := t.tables.Load(typ); ok {
|
||||
return v.(*Table)
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
|
||||
if v, ok := t.tables.Load(typ); ok {
|
||||
t.mu.Unlock()
|
||||
return v.(*Table)
|
||||
}
|
||||
|
||||
var table *Table
|
||||
|
||||
inProgress := t.inProgress[typ]
|
||||
if inProgress == nil {
|
||||
table = newTable(t.dialect, typ)
|
||||
inProgress = newTableInProgress(table)
|
||||
t.inProgress[typ] = inProgress
|
||||
} else {
|
||||
table = inProgress.table
|
||||
}
|
||||
|
||||
t.mu.Unlock()
|
||||
|
||||
inProgress.init1()
|
||||
if allowInProgress {
|
||||
return table
|
||||
}
|
||||
|
||||
if inProgress.init2() {
|
||||
t.mu.Lock()
|
||||
delete(t.inProgress, typ)
|
||||
t.tables.Store(typ, table)
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
t.dialect.OnTable(table)
|
||||
|
||||
for _, field := range table.FieldMap {
|
||||
if field.UserSQLType == "" {
|
||||
field.UserSQLType = field.DiscoveredSQLType
|
||||
}
|
||||
if field.CreateTableSQLType == "" {
|
||||
field.CreateTableSQLType = field.UserSQLType
|
||||
}
|
||||
}
|
||||
|
||||
return table
|
||||
}
|
||||
|
||||
func (t *Tables) ByModel(name string) *Table {
|
||||
var found *Table
|
||||
t.tables.Range(func(key, value interface{}) bool {
|
||||
t := value.(*Table)
|
||||
if t.TypeName == name {
|
||||
found = t
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
return found
|
||||
}
|
||||
|
||||
func (t *Tables) ByName(name string) *Table {
|
||||
var found *Table
|
||||
t.tables.Range(func(key, value interface{}) bool {
|
||||
t := value.(*Table)
|
||||
if t.Name == name {
|
||||
found = t
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
return found
|
||||
}
|
||||
53
vendor/github.com/uptrace/bun/schema/util.go
generated
vendored
Normal file
53
vendor/github.com/uptrace/bun/schema/util.go
generated
vendored
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
package schema
|
||||
|
||||
import "reflect"
|
||||
|
||||
func indirectType(t reflect.Type) reflect.Type {
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, ok bool) {
|
||||
if len(index) == 1 {
|
||||
return v.Field(index[0]), true
|
||||
}
|
||||
|
||||
for i, idx := range index {
|
||||
if i > 0 {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
return v, false
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
}
|
||||
v = v.Field(idx)
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
|
||||
func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value {
|
||||
if len(index) == 1 {
|
||||
return v.Field(index[0])
|
||||
}
|
||||
|
||||
for i, idx := range index {
|
||||
if i > 0 {
|
||||
v = indirectNil(v)
|
||||
}
|
||||
v = v.Field(idx)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func indirectNil(v reflect.Value) reflect.Value {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
return v
|
||||
}
|
||||
126
vendor/github.com/uptrace/bun/schema/zerochecker.go
generated
vendored
Normal file
126
vendor/github.com/uptrace/bun/schema/zerochecker.go
generated
vendored
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
var isZeroerType = reflect.TypeOf((*isZeroer)(nil)).Elem()
|
||||
|
||||
type isZeroer interface {
|
||||
IsZero() bool
|
||||
}
|
||||
|
||||
type IsZeroerFunc func(reflect.Value) bool
|
||||
|
||||
func FieldZeroChecker(field *Field) IsZeroerFunc {
|
||||
return zeroChecker(field.IndirectType)
|
||||
}
|
||||
|
||||
func zeroChecker(typ reflect.Type) IsZeroerFunc {
|
||||
if typ.Implements(isZeroerType) {
|
||||
return isZeroInterface
|
||||
}
|
||||
|
||||
kind := typ.Kind()
|
||||
|
||||
if kind != reflect.Ptr {
|
||||
ptr := reflect.PtrTo(typ)
|
||||
if ptr.Implements(isZeroerType) {
|
||||
return addrChecker(isZeroInterface)
|
||||
}
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case reflect.Array:
|
||||
if typ.Elem().Kind() == reflect.Uint8 {
|
||||
return isZeroBytes
|
||||
}
|
||||
return isZeroLen
|
||||
case reflect.String:
|
||||
return isZeroLen
|
||||
case reflect.Bool:
|
||||
return isZeroBool
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return isZeroInt
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||
return isZeroUint
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return isZeroFloat
|
||||
case reflect.Interface, reflect.Ptr, reflect.Slice, reflect.Map:
|
||||
return isNil
|
||||
}
|
||||
|
||||
if typ.Implements(driverValuerType) {
|
||||
return isZeroDriverValue
|
||||
}
|
||||
|
||||
return notZero
|
||||
}
|
||||
|
||||
func addrChecker(fn IsZeroerFunc) IsZeroerFunc {
|
||||
return func(v reflect.Value) bool {
|
||||
if !v.CanAddr() {
|
||||
return false
|
||||
}
|
||||
return fn(v.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
func isZeroInterface(v reflect.Value) bool {
|
||||
if v.Kind() == reflect.Ptr && v.IsNil() {
|
||||
return true
|
||||
}
|
||||
return v.Interface().(isZeroer).IsZero()
|
||||
}
|
||||
|
||||
func isZeroDriverValue(v reflect.Value) bool {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
return v.IsNil()
|
||||
}
|
||||
|
||||
valuer := v.Interface().(driver.Valuer)
|
||||
value, err := valuer.Value()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return value == nil
|
||||
}
|
||||
|
||||
func isZeroLen(v reflect.Value) bool {
|
||||
return v.Len() == 0
|
||||
}
|
||||
|
||||
func isNil(v reflect.Value) bool {
|
||||
return v.IsNil()
|
||||
}
|
||||
|
||||
func isZeroBool(v reflect.Value) bool {
|
||||
return !v.Bool()
|
||||
}
|
||||
|
||||
func isZeroInt(v reflect.Value) bool {
|
||||
return v.Int() == 0
|
||||
}
|
||||
|
||||
func isZeroUint(v reflect.Value) bool {
|
||||
return v.Uint() == 0
|
||||
}
|
||||
|
||||
func isZeroFloat(v reflect.Value) bool {
|
||||
return v.Float() == 0
|
||||
}
|
||||
|
||||
func isZeroBytes(v reflect.Value) bool {
|
||||
b := v.Slice(0, v.Len()).Bytes()
|
||||
for _, c := range b {
|
||||
if c != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func notZero(v reflect.Value) bool {
|
||||
return false
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue