Pg to bun (#148)

* start moving to bun

* changing more stuff

* more

* and yet more

* tests passing

* seems stable now

* more big changes

* small fix

* little fixes
This commit is contained in:
tobi 2021-08-25 15:34:33 +02:00 committed by GitHub
commit 2dc9fc1626
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
713 changed files with 98694 additions and 22704 deletions

93
vendor/github.com/uptrace/bun/schema/append.go generated vendored Normal file
View file

@ -0,0 +1,93 @@
package schema
import (
"reflect"
"strconv"
"strings"
"time"
"github.com/vmihailenco/msgpack/v5"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/internal"
)
func FieldAppender(dialect Dialect, field *Field) AppenderFunc {
if field.Tag.HasOption("msgpack") {
return appendMsgpack
}
switch strings.ToUpper(field.UserSQLType) {
case sqltype.JSON, sqltype.JSONB:
return AppendJSONValue
}
return dialect.Appender(field.StructField.Type)
}
func Append(fmter Formatter, b []byte, v interface{}, custom CustomAppender) []byte {
switch v := v.(type) {
case nil:
return dialect.AppendNull(b)
case bool:
return dialect.AppendBool(b, v)
case int:
return strconv.AppendInt(b, int64(v), 10)
case int32:
return strconv.AppendInt(b, int64(v), 10)
case int64:
return strconv.AppendInt(b, v, 10)
case uint:
return strconv.AppendUint(b, uint64(v), 10)
case uint32:
return strconv.AppendUint(b, uint64(v), 10)
case uint64:
return strconv.AppendUint(b, v, 10)
case float32:
return dialect.AppendFloat32(b, v)
case float64:
return dialect.AppendFloat64(b, v)
case string:
return dialect.AppendString(b, v)
case time.Time:
return dialect.AppendTime(b, v)
case []byte:
return dialect.AppendBytes(b, v)
case QueryAppender:
return AppendQueryAppender(fmter, b, v)
default:
vv := reflect.ValueOf(v)
if vv.Kind() == reflect.Ptr && vv.IsNil() {
return dialect.AppendNull(b)
}
appender := Appender(vv.Type(), custom)
return appender(fmter, b, vv)
}
}
func appendMsgpack(fmter Formatter, b []byte, v reflect.Value) []byte {
hexEnc := internal.NewHexEncoder(b)
enc := msgpack.GetEncoder()
defer msgpack.PutEncoder(enc)
enc.Reset(hexEnc)
if err := enc.EncodeValue(v); err != nil {
return dialect.AppendError(b, err)
}
if err := hexEnc.Close(); err != nil {
return dialect.AppendError(b, err)
}
return hexEnc.Bytes()
}
func AppendQueryAppender(fmter Formatter, b []byte, app QueryAppender) []byte {
bb, err := app.AppendQuery(fmter, b)
if err != nil {
return dialect.AppendError(b, err)
}
return bb
}

237
vendor/github.com/uptrace/bun/schema/append_value.go generated vendored Normal file
View file

@ -0,0 +1,237 @@
package schema
import (
"database/sql/driver"
"encoding/json"
"fmt"
"net"
"reflect"
"strconv"
"time"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/extra/bunjson"
"github.com/uptrace/bun/internal"
)
var (
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem()
)
type (
AppenderFunc func(fmter Formatter, b []byte, v reflect.Value) []byte
CustomAppender func(typ reflect.Type) AppenderFunc
)
var appenders = []AppenderFunc{
reflect.Bool: AppendBoolValue,
reflect.Int: AppendIntValue,
reflect.Int8: AppendIntValue,
reflect.Int16: AppendIntValue,
reflect.Int32: AppendIntValue,
reflect.Int64: AppendIntValue,
reflect.Uint: AppendUintValue,
reflect.Uint8: AppendUintValue,
reflect.Uint16: AppendUintValue,
reflect.Uint32: AppendUintValue,
reflect.Uint64: AppendUintValue,
reflect.Uintptr: nil,
reflect.Float32: AppendFloat32Value,
reflect.Float64: AppendFloat64Value,
reflect.Complex64: nil,
reflect.Complex128: nil,
reflect.Array: AppendJSONValue,
reflect.Chan: nil,
reflect.Func: nil,
reflect.Interface: nil,
reflect.Map: AppendJSONValue,
reflect.Ptr: nil,
reflect.Slice: AppendJSONValue,
reflect.String: AppendStringValue,
reflect.Struct: AppendJSONValue,
reflect.UnsafePointer: nil,
}
func Appender(typ reflect.Type, custom CustomAppender) AppenderFunc {
switch typ {
case timeType:
return appendTimeValue
case ipType:
return appendIPValue
case ipNetType:
return appendIPNetValue
case jsonRawMessageType:
return appendJSONRawMessageValue
}
if typ.Implements(queryAppenderType) {
return appendQueryAppenderValue
}
if typ.Implements(driverValuerType) {
return driverValueAppender(custom)
}
kind := typ.Kind()
if kind != reflect.Ptr {
ptr := reflect.PtrTo(typ)
if ptr.Implements(queryAppenderType) {
return addrAppender(appendQueryAppenderValue, custom)
}
if ptr.Implements(driverValuerType) {
return addrAppender(driverValueAppender(custom), custom)
}
}
switch kind {
case reflect.Interface:
return ifaceAppenderFunc(typ, custom)
case reflect.Ptr:
return ptrAppenderFunc(typ, custom)
case reflect.Slice:
if typ.Elem().Kind() == reflect.Uint8 {
return appendBytesValue
}
case reflect.Array:
if typ.Elem().Kind() == reflect.Uint8 {
return appendArrayBytesValue
}
}
if custom != nil {
if fn := custom(typ); fn != nil {
return fn
}
}
return appenders[typ.Kind()]
}
func ifaceAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc {
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
if v.IsNil() {
return dialect.AppendNull(b)
}
elem := v.Elem()
appender := Appender(elem.Type(), custom)
return appender(fmter, b, elem)
}
}
func ptrAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc {
appender := Appender(typ.Elem(), custom)
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
if v.IsNil() {
return dialect.AppendNull(b)
}
return appender(fmter, b, v.Elem())
}
}
func AppendBoolValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return dialect.AppendBool(b, v.Bool())
}
func AppendIntValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return strconv.AppendInt(b, v.Int(), 10)
}
func AppendUintValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return strconv.AppendUint(b, v.Uint(), 10)
}
func AppendFloat32Value(fmter Formatter, b []byte, v reflect.Value) []byte {
return dialect.AppendFloat32(b, float32(v.Float()))
}
func AppendFloat64Value(fmter Formatter, b []byte, v reflect.Value) []byte {
return dialect.AppendFloat64(b, float64(v.Float()))
}
func appendBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return dialect.AppendBytes(b, v.Bytes())
}
func appendArrayBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte {
if v.CanAddr() {
return dialect.AppendBytes(b, v.Slice(0, v.Len()).Bytes())
}
tmp := make([]byte, v.Len())
reflect.Copy(reflect.ValueOf(tmp), v)
b = dialect.AppendBytes(b, tmp)
return b
}
func AppendStringValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return dialect.AppendString(b, v.String())
}
func AppendJSONValue(fmter Formatter, b []byte, v reflect.Value) []byte {
bb, err := bunjson.Marshal(v.Interface())
if err != nil {
return dialect.AppendError(b, err)
}
if len(bb) > 0 && bb[len(bb)-1] == '\n' {
bb = bb[:len(bb)-1]
}
return dialect.AppendJSON(b, bb)
}
func appendTimeValue(fmter Formatter, b []byte, v reflect.Value) []byte {
tm := v.Interface().(time.Time)
return dialect.AppendTime(b, tm)
}
func appendIPValue(fmter Formatter, b []byte, v reflect.Value) []byte {
ip := v.Interface().(net.IP)
return dialect.AppendString(b, ip.String())
}
func appendIPNetValue(fmter Formatter, b []byte, v reflect.Value) []byte {
ipnet := v.Interface().(net.IPNet)
return dialect.AppendString(b, ipnet.String())
}
func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byte {
bytes := v.Bytes()
if bytes == nil {
return dialect.AppendNull(b)
}
return dialect.AppendString(b, internal.String(bytes))
}
func appendQueryAppenderValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return AppendQueryAppender(fmter, b, v.Interface().(QueryAppender))
}
func driverValueAppender(custom CustomAppender) AppenderFunc {
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
return appendDriverValue(fmter, b, v.Interface().(driver.Valuer), custom)
}
}
func appendDriverValue(fmter Formatter, b []byte, v driver.Valuer, custom CustomAppender) []byte {
value, err := v.Value()
if err != nil {
return dialect.AppendError(b, err)
}
return Append(fmter, b, value, custom)
}
func addrAppender(fn AppenderFunc, custom CustomAppender) AppenderFunc {
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
if !v.CanAddr() {
err := fmt.Errorf("bun: Append(nonaddressable %T)", v.Interface())
return dialect.AppendError(b, err)
}
return fn(fmter, b, v.Addr())
}
}

99
vendor/github.com/uptrace/bun/schema/dialect.go generated vendored Normal file
View file

@ -0,0 +1,99 @@
package schema
import (
"database/sql"
"reflect"
"sync"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/feature"
)
type Dialect interface {
Init(db *sql.DB)
Name() dialect.Name
Features() feature.Feature
Tables() *Tables
OnTable(table *Table)
IdentQuote() byte
Append(fmter Formatter, b []byte, v interface{}) []byte
Appender(typ reflect.Type) AppenderFunc
FieldAppender(field *Field) AppenderFunc
Scanner(typ reflect.Type) ScannerFunc
}
//------------------------------------------------------------------------------
type nopDialect struct {
tables *Tables
features feature.Feature
appenderMap sync.Map
scannerMap sync.Map
}
func newNopDialect() *nopDialect {
d := new(nopDialect)
d.tables = NewTables(d)
d.features = feature.Returning
return d
}
func (d *nopDialect) Init(*sql.DB) {}
func (d *nopDialect) Name() dialect.Name {
return dialect.Invalid
}
func (d *nopDialect) Features() feature.Feature {
return d.features
}
func (d *nopDialect) Tables() *Tables {
return d.tables
}
func (d *nopDialect) OnField(field *Field) {}
func (d *nopDialect) OnTable(table *Table) {}
func (d *nopDialect) IdentQuote() byte {
return '"'
}
func (d *nopDialect) Append(fmter Formatter, b []byte, v interface{}) []byte {
return Append(fmter, b, v, nil)
}
func (d *nopDialect) Appender(typ reflect.Type) AppenderFunc {
if v, ok := d.appenderMap.Load(typ); ok {
return v.(AppenderFunc)
}
fn := Appender(typ, nil)
if v, ok := d.appenderMap.LoadOrStore(typ, fn); ok {
return v.(AppenderFunc)
}
return fn
}
func (d *nopDialect) FieldAppender(field *Field) AppenderFunc {
return FieldAppender(d, field)
}
func (d *nopDialect) Scanner(typ reflect.Type) ScannerFunc {
if v, ok := d.scannerMap.Load(typ); ok {
return v.(ScannerFunc)
}
fn := Scanner(typ)
if v, ok := d.scannerMap.LoadOrStore(typ, fn); ok {
return v.(ScannerFunc)
}
return fn
}

117
vendor/github.com/uptrace/bun/schema/field.go generated vendored Normal file
View file

@ -0,0 +1,117 @@
package schema
import (
"fmt"
"reflect"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/internal/tagparser"
)
type Field struct {
StructField reflect.StructField
Tag tagparser.Tag
IndirectType reflect.Type
Index []int
Name string // SQL name, .e.g. id
SQLName Safe // escaped SQL name, e.g. "id"
GoName string // struct field name, e.g. Id
DiscoveredSQLType string
UserSQLType string
CreateTableSQLType string
SQLDefault string
OnDelete string
OnUpdate string
IsPK bool
NotNull bool
NullZero bool
AutoIncrement bool
Append AppenderFunc
Scan ScannerFunc
IsZero IsZeroerFunc
}
func (f *Field) String() string {
return f.Name
}
func (f *Field) Clone() *Field {
cp := *f
cp.Index = cp.Index[:len(f.Index):len(f.Index)]
return &cp
}
func (f *Field) Value(strct reflect.Value) reflect.Value {
return fieldByIndexAlloc(strct, f.Index)
}
func (f *Field) HasZeroValue(v reflect.Value) bool {
for _, idx := range f.Index {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return true
}
v = v.Elem()
}
v = v.Field(idx)
}
return f.IsZero(v)
}
func (f *Field) AppendValue(fmter Formatter, b []byte, strct reflect.Value) []byte {
fv, ok := fieldByIndex(strct, f.Index)
if !ok {
return dialect.AppendNull(b)
}
if f.NullZero && f.IsZero(fv) {
return dialect.AppendNull(b)
}
if f.Append == nil {
panic(fmt.Errorf("bun: AppendValue(unsupported %s)", fv.Type()))
}
return f.Append(fmter, b, fv)
}
func (f *Field) ScanWithCheck(fv reflect.Value, src interface{}) error {
if f.Scan == nil {
return fmt.Errorf("bun: Scan(unsupported %s)", f.IndirectType)
}
return f.Scan(fv, src)
}
func (f *Field) ScanValue(strct reflect.Value, src interface{}) error {
if src == nil {
if fv, ok := fieldByIndex(strct, f.Index); ok {
return f.ScanWithCheck(fv, src)
}
return nil
}
fv := fieldByIndexAlloc(strct, f.Index)
return f.ScanWithCheck(fv, src)
}
func (f *Field) markAsPK() {
f.IsPK = true
f.NotNull = true
f.NullZero = true
}
func indexEqual(ind1, ind2 []int) bool {
if len(ind1) != len(ind2) {
return false
}
for i, ind := range ind1 {
if ind != ind2[i] {
return false
}
}
return true
}

248
vendor/github.com/uptrace/bun/schema/formatter.go generated vendored Normal file
View file

@ -0,0 +1,248 @@
package schema
import (
"reflect"
"strconv"
"strings"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/internal/parser"
)
var nopFormatter = Formatter{
dialect: newNopDialect(),
}
type Formatter struct {
dialect Dialect
args *namedArgList
}
func NewFormatter(dialect Dialect) Formatter {
return Formatter{
dialect: dialect,
}
}
func NewNopFormatter() Formatter {
return nopFormatter
}
func (f Formatter) IsNop() bool {
return f.dialect.Name() == dialect.Invalid
}
func (f Formatter) Dialect() Dialect {
return f.dialect
}
func (f Formatter) IdentQuote() byte {
return f.dialect.IdentQuote()
}
func (f Formatter) AppendIdent(b []byte, ident string) []byte {
return dialect.AppendIdent(b, ident, f.IdentQuote())
}
func (f Formatter) AppendValue(b []byte, v reflect.Value) []byte {
if v.Kind() == reflect.Ptr && v.IsNil() {
return dialect.AppendNull(b)
}
appender := f.dialect.Appender(v.Type())
return appender(f, b, v)
}
func (f Formatter) HasFeature(feature feature.Feature) bool {
return f.dialect.Features().Has(feature)
}
func (f Formatter) WithArg(arg NamedArgAppender) Formatter {
return Formatter{
dialect: f.dialect,
args: f.args.WithArg(arg),
}
}
func (f Formatter) WithNamedArg(name string, value interface{}) Formatter {
return Formatter{
dialect: f.dialect,
args: f.args.WithArg(&namedArg{name: name, value: value}),
}
}
func (f Formatter) FormatQuery(query string, args ...interface{}) string {
if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 {
return query
}
return internal.String(f.AppendQuery(nil, query, args...))
}
func (f Formatter) AppendQuery(dst []byte, query string, args ...interface{}) []byte {
if f.IsNop() || (args == nil && f.args == nil) || strings.IndexByte(query, '?') == -1 {
return append(dst, query...)
}
return f.append(dst, parser.NewString(query), args)
}
func (f Formatter) append(dst []byte, p *parser.Parser, args []interface{}) []byte {
var namedArgs NamedArgAppender
if len(args) == 1 {
var ok bool
namedArgs, ok = args[0].(NamedArgAppender)
if !ok {
namedArgs, _ = newStructArgs(f, args[0])
}
}
var argIndex int
for p.Valid() {
b, ok := p.ReadSep('?')
if !ok {
dst = append(dst, b...)
continue
}
if len(b) > 0 && b[len(b)-1] == '\\' {
dst = append(dst, b[:len(b)-1]...)
dst = append(dst, '?')
continue
}
dst = append(dst, b...)
name, numeric := p.ReadIdentifier()
if name != "" {
if numeric {
idx, err := strconv.Atoi(name)
if err != nil {
goto restore_arg
}
if idx >= len(args) {
goto restore_arg
}
dst = f.appendArg(dst, args[idx])
continue
}
if namedArgs != nil {
dst, ok = namedArgs.AppendNamedArg(f, dst, name)
if ok {
continue
}
}
dst, ok = f.args.AppendNamedArg(f, dst, name)
if ok {
continue
}
restore_arg:
dst = append(dst, '?')
dst = append(dst, name...)
continue
}
if argIndex >= len(args) {
dst = append(dst, '?')
continue
}
arg := args[argIndex]
argIndex++
dst = f.appendArg(dst, arg)
}
return dst
}
func (f Formatter) appendArg(b []byte, arg interface{}) []byte {
switch arg := arg.(type) {
case QueryAppender:
bb, err := arg.AppendQuery(f, b)
if err != nil {
return dialect.AppendError(b, err)
}
return bb
default:
return f.dialect.Append(f, b, arg)
}
}
//------------------------------------------------------------------------------
type NamedArgAppender interface {
AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool)
}
//------------------------------------------------------------------------------
type namedArgList struct {
arg NamedArgAppender
next *namedArgList
}
func (l *namedArgList) WithArg(arg NamedArgAppender) *namedArgList {
return &namedArgList{
arg: arg,
next: l,
}
}
func (l *namedArgList) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
for l != nil && l.arg != nil {
if b, ok := l.arg.AppendNamedArg(fmter, b, name); ok {
return b, true
}
l = l.next
}
return b, false
}
//------------------------------------------------------------------------------
type namedArg struct {
name string
value interface{}
}
var _ NamedArgAppender = (*namedArg)(nil)
func (a *namedArg) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
if a.name == name {
return fmter.appendArg(b, a.value), true
}
return b, false
}
//------------------------------------------------------------------------------
var _ NamedArgAppender = (*structArgs)(nil)
type structArgs struct {
table *Table
strct reflect.Value
}
func newStructArgs(fmter Formatter, strct interface{}) (*structArgs, bool) {
v := reflect.ValueOf(strct)
if !v.IsValid() {
return nil, false
}
v = reflect.Indirect(v)
if v.Kind() != reflect.Struct {
return nil, false
}
return &structArgs{
table: fmter.Dialect().Tables().Get(v.Type()),
strct: v,
}, true
}
func (m *structArgs) AppendNamedArg(fmter Formatter, b []byte, name string) ([]byte, bool) {
return m.table.AppendNamedArg(fmter, b, name, m.strct)
}

20
vendor/github.com/uptrace/bun/schema/hook.go generated vendored Normal file
View file

@ -0,0 +1,20 @@
package schema
import (
"context"
"reflect"
)
type BeforeScanHook interface {
BeforeScan(context.Context) error
}
var beforeScanHookType = reflect.TypeOf((*BeforeScanHook)(nil)).Elem()
//------------------------------------------------------------------------------
type AfterScanHook interface {
AfterScan(context.Context) error
}
var afterScanHookType = reflect.TypeOf((*AfterScanHook)(nil)).Elem()

32
vendor/github.com/uptrace/bun/schema/relation.go generated vendored Normal file
View file

@ -0,0 +1,32 @@
package schema
import (
"fmt"
)
const (
InvalidRelation = iota
HasOneRelation
BelongsToRelation
HasManyRelation
ManyToManyRelation
)
type Relation struct {
Type int
Field *Field
JoinTable *Table
BaseFields []*Field
JoinFields []*Field
PolymorphicField *Field
PolymorphicValue string
M2MTable *Table
M2MBaseFields []*Field
M2MJoinFields []*Field
}
func (r *Relation) String() string {
return fmt.Sprintf("relation=%s", r.Field.GoName)
}

392
vendor/github.com/uptrace/bun/schema/scan.go generated vendored Normal file
View file

@ -0,0 +1,392 @@
package schema
import (
"bytes"
"database/sql"
"fmt"
"net"
"reflect"
"strconv"
"time"
"github.com/vmihailenco/msgpack/v5"
"github.com/uptrace/bun/extra/bunjson"
"github.com/uptrace/bun/internal"
)
var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
type ScannerFunc func(dest reflect.Value, src interface{}) error
var scanners = []ScannerFunc{
reflect.Bool: scanBool,
reflect.Int: scanInt64,
reflect.Int8: scanInt64,
reflect.Int16: scanInt64,
reflect.Int32: scanInt64,
reflect.Int64: scanInt64,
reflect.Uint: scanUint64,
reflect.Uint8: scanUint64,
reflect.Uint16: scanUint64,
reflect.Uint32: scanUint64,
reflect.Uint64: scanUint64,
reflect.Uintptr: scanUint64,
reflect.Float32: scanFloat64,
reflect.Float64: scanFloat64,
reflect.Complex64: nil,
reflect.Complex128: nil,
reflect.Array: nil,
reflect.Chan: nil,
reflect.Func: nil,
reflect.Map: scanJSON,
reflect.Ptr: nil,
reflect.Slice: scanJSON,
reflect.String: scanString,
reflect.Struct: scanJSON,
reflect.UnsafePointer: nil,
}
func FieldScanner(dialect Dialect, field *Field) ScannerFunc {
if field.Tag.HasOption("msgpack") {
return scanMsgpack
}
if field.Tag.HasOption("json_use_number") {
return scanJSONUseNumber
}
return dialect.Scanner(field.StructField.Type)
}
func Scanner(typ reflect.Type) ScannerFunc {
kind := typ.Kind()
if kind == reflect.Ptr {
if fn := Scanner(typ.Elem()); fn != nil {
return ptrScanner(fn)
}
}
if typ.Implements(scannerType) {
return scanScanner
}
if kind != reflect.Ptr {
ptr := reflect.PtrTo(typ)
if ptr.Implements(scannerType) {
return addrScanner(scanScanner)
}
}
switch typ {
case timeType:
return scanTime
case ipType:
return scanIP
case ipNetType:
return scanIPNet
case jsonRawMessageType:
return scanJSONRawMessage
}
return scanners[kind]
}
func scanBool(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
dest.SetBool(false)
return nil
case bool:
dest.SetBool(src)
return nil
case int64:
dest.SetBool(src != 0)
return nil
case []byte:
if len(src) == 1 {
dest.SetBool(src[0] != '0')
return nil
}
}
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanInt64(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
dest.SetInt(0)
return nil
case int64:
dest.SetInt(src)
return nil
case uint64:
dest.SetInt(int64(src))
return nil
case []byte:
n, err := strconv.ParseInt(internal.String(src), 10, 64)
if err != nil {
return err
}
dest.SetInt(n)
return nil
case string:
n, err := strconv.ParseInt(src, 10, 64)
if err != nil {
return err
}
dest.SetInt(n)
return nil
}
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanUint64(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
dest.SetUint(0)
return nil
case uint64:
dest.SetUint(src)
return nil
case int64:
dest.SetUint(uint64(src))
return nil
case []byte:
n, err := strconv.ParseUint(internal.String(src), 10, 64)
if err != nil {
return err
}
dest.SetUint(n)
return nil
}
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanFloat64(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
dest.SetFloat(0)
return nil
case float64:
dest.SetFloat(src)
return nil
case []byte:
f, err := strconv.ParseFloat(internal.String(src), 64)
if err != nil {
return err
}
dest.SetFloat(f)
return nil
}
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanString(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
dest.SetString("")
return nil
case string:
dest.SetString(src)
return nil
case []byte:
dest.SetString(string(src))
return nil
}
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanTime(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
destTime := dest.Addr().Interface().(*time.Time)
*destTime = time.Time{}
return nil
case time.Time:
destTime := dest.Addr().Interface().(*time.Time)
*destTime = src
return nil
case string:
srcTime, err := internal.ParseTime(src)
if err != nil {
return err
}
destTime := dest.Addr().Interface().(*time.Time)
*destTime = srcTime
return nil
case []byte:
srcTime, err := internal.ParseTime(internal.String(src))
if err != nil {
return err
}
destTime := dest.Addr().Interface().(*time.Time)
*destTime = srcTime
return nil
}
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanScanner(dest reflect.Value, src interface{}) error {
return dest.Interface().(sql.Scanner).Scan(src)
}
func scanMsgpack(dest reflect.Value, src interface{}) error {
if src == nil {
return scanNull(dest)
}
b, err := toBytes(src)
if err != nil {
return err
}
dec := msgpack.GetDecoder()
defer msgpack.PutDecoder(dec)
dec.Reset(bytes.NewReader(b))
return dec.DecodeValue(dest)
}
func scanJSON(dest reflect.Value, src interface{}) error {
if src == nil {
return scanNull(dest)
}
b, err := toBytes(src)
if err != nil {
return err
}
return bunjson.Unmarshal(b, dest.Addr().Interface())
}
func scanJSONUseNumber(dest reflect.Value, src interface{}) error {
if src == nil {
return scanNull(dest)
}
b, err := toBytes(src)
if err != nil {
return err
}
dec := bunjson.NewDecoder(bytes.NewReader(b))
dec.UseNumber()
return dec.Decode(dest.Addr().Interface())
}
func scanIP(dest reflect.Value, src interface{}) error {
if src == nil {
return scanNull(dest)
}
b, err := toBytes(src)
if err != nil {
return err
}
ip := net.ParseIP(internal.String(b))
if ip == nil {
return fmt.Errorf("bun: invalid ip: %q", b)
}
ptr := dest.Addr().Interface().(*net.IP)
*ptr = ip
return nil
}
func scanIPNet(dest reflect.Value, src interface{}) error {
if src == nil {
return scanNull(dest)
}
b, err := toBytes(src)
if err != nil {
return err
}
_, ipnet, err := net.ParseCIDR(internal.String(b))
if err != nil {
return err
}
ptr := dest.Addr().Interface().(*net.IPNet)
*ptr = *ipnet
return nil
}
func scanJSONRawMessage(dest reflect.Value, src interface{}) error {
if src == nil {
dest.SetBytes(nil)
return nil
}
b, err := toBytes(src)
if err != nil {
return err
}
dest.SetBytes(b)
return nil
}
func addrScanner(fn ScannerFunc) ScannerFunc {
return func(dest reflect.Value, src interface{}) error {
if !dest.CanAddr() {
return fmt.Errorf("bun: Scan(nonaddressable %T)", dest.Interface())
}
return fn(dest.Addr(), src)
}
}
func toBytes(src interface{}) ([]byte, error) {
switch src := src.(type) {
case string:
return internal.Bytes(src), nil
case []byte:
return src, nil
default:
return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src)
}
}
func ptrScanner(fn ScannerFunc) ScannerFunc {
return func(dest reflect.Value, src interface{}) error {
if src == nil {
if !dest.CanAddr() {
if dest.IsNil() {
return nil
}
return fn(dest.Elem(), src)
}
if !dest.IsNil() {
dest.Set(reflect.New(dest.Type().Elem()))
}
return nil
}
if dest.IsNil() {
dest.Set(reflect.New(dest.Type().Elem()))
}
return fn(dest.Elem(), src)
}
}
func scanNull(dest reflect.Value) error {
if nilable(dest.Kind()) && dest.IsNil() {
return nil
}
dest.Set(reflect.New(dest.Type()).Elem())
return nil
}
func nilable(kind reflect.Kind) bool {
switch kind {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return true
}
return false
}

76
vendor/github.com/uptrace/bun/schema/sqlfmt.go generated vendored Normal file
View file

@ -0,0 +1,76 @@
package schema
type QueryAppender interface {
AppendQuery(fmter Formatter, b []byte) ([]byte, error)
}
type ColumnsAppender interface {
AppendColumns(fmter Formatter, b []byte) ([]byte, error)
}
//------------------------------------------------------------------------------
// Safe represents a safe SQL query.
type Safe string
var _ QueryAppender = (*Safe)(nil)
func (s Safe) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
return append(b, s...), nil
}
//------------------------------------------------------------------------------
// Ident represents a SQL identifier, for example, table or column name.
type Ident string
var _ QueryAppender = (*Ident)(nil)
func (s Ident) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
return fmter.AppendIdent(b, string(s)), nil
}
//------------------------------------------------------------------------------
type QueryWithArgs struct {
Query string
Args []interface{}
}
var _ QueryAppender = QueryWithArgs{}
func SafeQuery(query string, args []interface{}) QueryWithArgs {
if query != "" && args == nil {
args = make([]interface{}, 0)
}
return QueryWithArgs{Query: query, Args: args}
}
func UnsafeIdent(ident string) QueryWithArgs {
return QueryWithArgs{Query: ident}
}
func (q QueryWithArgs) IsZero() bool {
return q.Query == "" && q.Args == nil
}
func (q QueryWithArgs) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
if q.Args == nil {
return fmter.AppendIdent(b, q.Query), nil
}
return fmter.AppendQuery(b, q.Query, q.Args...), nil
}
//------------------------------------------------------------------------------
type QueryWithSep struct {
QueryWithArgs
Sep string
}
func SafeQueryWithSep(query string, args []interface{}, sep string) QueryWithSep {
return QueryWithSep{
QueryWithArgs: SafeQuery(query, args),
Sep: sep,
}
}

129
vendor/github.com/uptrace/bun/schema/sqltype.go generated vendored Normal file
View file

@ -0,0 +1,129 @@
package schema
import (
"bytes"
"database/sql"
"encoding/json"
"fmt"
"reflect"
"time"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/internal"
)
var (
bunNullTimeType = reflect.TypeOf((*NullTime)(nil)).Elem()
nullTimeType = reflect.TypeOf((*sql.NullTime)(nil)).Elem()
nullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem()
nullFloatType = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem()
nullIntType = reflect.TypeOf((*sql.NullInt64)(nil)).Elem()
nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem()
)
var sqlTypes = []string{
reflect.Bool: sqltype.Boolean,
reflect.Int: sqltype.BigInt,
reflect.Int8: sqltype.SmallInt,
reflect.Int16: sqltype.SmallInt,
reflect.Int32: sqltype.Integer,
reflect.Int64: sqltype.BigInt,
reflect.Uint: sqltype.BigInt,
reflect.Uint8: sqltype.SmallInt,
reflect.Uint16: sqltype.SmallInt,
reflect.Uint32: sqltype.Integer,
reflect.Uint64: sqltype.BigInt,
reflect.Uintptr: sqltype.BigInt,
reflect.Float32: sqltype.Real,
reflect.Float64: sqltype.DoublePrecision,
reflect.Complex64: "",
reflect.Complex128: "",
reflect.Array: "",
reflect.Chan: "",
reflect.Func: "",
reflect.Interface: "",
reflect.Map: sqltype.VarChar,
reflect.Ptr: "",
reflect.Slice: sqltype.VarChar,
reflect.String: sqltype.VarChar,
reflect.Struct: sqltype.VarChar,
reflect.UnsafePointer: "",
}
func DiscoverSQLType(typ reflect.Type) string {
switch typ {
case timeType, nullTimeType, bunNullTimeType:
return sqltype.Timestamp
case nullBoolType:
return sqltype.Boolean
case nullFloatType:
return sqltype.DoublePrecision
case nullIntType:
return sqltype.BigInt
case nullStringType:
return sqltype.VarChar
}
return sqlTypes[typ.Kind()]
}
//------------------------------------------------------------------------------
var jsonNull = []byte("null")
// NullTime is a time.Time wrapper that marshals zero time as JSON null and SQL NULL.
type NullTime struct {
time.Time
}
var (
_ json.Marshaler = (*NullTime)(nil)
_ json.Unmarshaler = (*NullTime)(nil)
_ sql.Scanner = (*NullTime)(nil)
_ QueryAppender = (*NullTime)(nil)
)
func (tm NullTime) MarshalJSON() ([]byte, error) {
if tm.IsZero() {
return jsonNull, nil
}
return tm.Time.MarshalJSON()
}
func (tm *NullTime) UnmarshalJSON(b []byte) error {
if bytes.Equal(b, jsonNull) {
tm.Time = time.Time{}
return nil
}
return tm.Time.UnmarshalJSON(b)
}
func (tm NullTime) AppendQuery(fmter Formatter, b []byte) ([]byte, error) {
if tm.IsZero() {
return dialect.AppendNull(b), nil
}
return dialect.AppendTime(b, tm.Time), nil
}
func (tm *NullTime) Scan(src interface{}) error {
if src == nil {
tm.Time = time.Time{}
return nil
}
switch src := src.(type) {
case []byte:
newtm, err := internal.ParseTime(internal.String(src))
if err != nil {
return err
}
tm.Time = newtm
return nil
case time.Time:
tm.Time = src
return nil
default:
return fmt.Errorf("bun: can't scan %#v into NullTime", src)
}
}

948
vendor/github.com/uptrace/bun/schema/table.go generated vendored Normal file
View file

@ -0,0 +1,948 @@
package schema
import (
"database/sql"
"fmt"
"reflect"
"strings"
"sync"
"time"
"github.com/jinzhu/inflection"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/internal/tagparser"
)
const (
beforeScanHookFlag internal.Flag = 1 << iota
afterScanHookFlag
)
var (
baseModelType = reflect.TypeOf((*BaseModel)(nil)).Elem()
tableNameInflector = inflection.Plural
)
type BaseModel struct{}
// SetTableNameInflector overrides the default func that pluralizes
// model name to get table name, e.g. my_article becomes my_articles.
func SetTableNameInflector(fn func(string) string) {
tableNameInflector = fn
}
// Table represents a SQL table created from Go struct.
type Table struct {
dialect Dialect
Type reflect.Type
ZeroValue reflect.Value // reflect.Struct
ZeroIface interface{} // struct pointer
TypeName string
ModelName string
Name string
SQLName Safe
SQLNameForSelects Safe
Alias string
SQLAlias Safe
Fields []*Field // PKs + DataFields
PKs []*Field
DataFields []*Field
fieldsMapMu sync.RWMutex
FieldMap map[string]*Field
Relations map[string]*Relation
Unique map[string][]*Field
SoftDeleteField *Field
UpdateSoftDeleteField func(fv reflect.Value) error
allFields []*Field // read only
skippedFields []*Field
flags internal.Flag
}
func newTable(dialect Dialect, typ reflect.Type) *Table {
t := new(Table)
t.dialect = dialect
t.Type = typ
t.ZeroValue = reflect.New(t.Type).Elem()
t.ZeroIface = reflect.New(t.Type).Interface()
t.TypeName = internal.ToExported(t.Type.Name())
t.ModelName = internal.Underscore(t.Type.Name())
tableName := tableNameInflector(t.ModelName)
t.setName(tableName)
t.Alias = t.ModelName
t.SQLAlias = t.quoteIdent(t.ModelName)
hooks := []struct {
typ reflect.Type
flag internal.Flag
}{
{beforeScanHookType, beforeScanHookFlag},
{afterScanHookType, afterScanHookFlag},
}
typ = reflect.PtrTo(t.Type)
for _, hook := range hooks {
if typ.Implements(hook.typ) {
t.flags = t.flags.Set(hook.flag)
}
}
return t
}
func (t *Table) init1() {
t.initFields()
}
func (t *Table) init2() {
t.initInlines()
t.initRelations()
t.skippedFields = nil
}
func (t *Table) setName(name string) {
t.Name = name
t.SQLName = t.quoteIdent(name)
t.SQLNameForSelects = t.quoteIdent(name)
if t.SQLAlias == "" {
t.Alias = name
t.SQLAlias = t.quoteIdent(name)
}
}
func (t *Table) String() string {
return "model=" + t.TypeName
}
func (t *Table) CheckPKs() error {
if len(t.PKs) == 0 {
return fmt.Errorf("bun: %s does not have primary keys", t)
}
return nil
}
func (t *Table) addField(field *Field) {
t.Fields = append(t.Fields, field)
if field.IsPK {
t.PKs = append(t.PKs, field)
} else {
t.DataFields = append(t.DataFields, field)
}
t.FieldMap[field.Name] = field
}
func (t *Table) removeField(field *Field) {
t.Fields = removeField(t.Fields, field)
if field.IsPK {
t.PKs = removeField(t.PKs, field)
} else {
t.DataFields = removeField(t.DataFields, field)
}
delete(t.FieldMap, field.Name)
}
func (t *Table) fieldWithLock(name string) *Field {
t.fieldsMapMu.RLock()
field := t.FieldMap[name]
t.fieldsMapMu.RUnlock()
return field
}
func (t *Table) HasField(name string) bool {
_, ok := t.FieldMap[name]
return ok
}
func (t *Table) Field(name string) (*Field, error) {
field, ok := t.FieldMap[name]
if !ok {
return nil, fmt.Errorf("bun: %s does not have column=%s", t, name)
}
return field, nil
}
func (t *Table) fieldByGoName(name string) *Field {
for _, f := range t.allFields {
if f.GoName == name {
return f
}
}
return nil
}
func (t *Table) initFields() {
t.Fields = make([]*Field, 0, t.Type.NumField())
t.FieldMap = make(map[string]*Field, t.Type.NumField())
t.addFields(t.Type, nil)
if len(t.PKs) > 0 {
return
}
for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} {
if field, ok := t.FieldMap[name]; ok {
field.markAsPK()
t.PKs = []*Field{field}
t.DataFields = removeField(t.DataFields, field)
break
}
}
if len(t.PKs) == 1 {
switch t.PKs[0].IndirectType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
t.PKs[0].AutoIncrement = true
}
}
}
func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
// Make a copy so slice is not shared between fields.
index := make([]int, len(baseIndex))
copy(index, baseIndex)
if f.Anonymous {
if f.Tag.Get("bun") == "-" {
continue
}
if f.Name == "BaseModel" && f.Type == baseModelType {
if len(index) == 0 {
t.processBaseModelField(f)
}
continue
}
fieldType := indirectType(f.Type)
if fieldType.Kind() != reflect.Struct {
continue
}
t.addFields(fieldType, append(index, f.Index...))
tag := tagparser.Parse(f.Tag.Get("bun"))
if _, inherit := tag.Options["inherit"]; inherit {
embeddedTable := t.dialect.Tables().Ref(fieldType)
t.TypeName = embeddedTable.TypeName
t.SQLName = embeddedTable.SQLName
t.SQLNameForSelects = embeddedTable.SQLNameForSelects
t.Alias = embeddedTable.Alias
t.SQLAlias = embeddedTable.SQLAlias
t.ModelName = embeddedTable.ModelName
}
continue
}
field := t.newField(f, index)
if field != nil {
t.addField(field)
}
}
}
func (t *Table) processBaseModelField(f reflect.StructField) {
tag := tagparser.Parse(f.Tag.Get("bun"))
if isKnownTableOption(tag.Name) {
internal.Warn.Printf(
"%s.%s tag name %q is also an option name; is it a mistake?",
t.TypeName, f.Name, tag.Name,
)
}
for name := range tag.Options {
if !isKnownTableOption(name) {
internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name)
}
}
if tag.Name != "" {
t.setName(tag.Name)
}
if s, ok := tag.Options["select"]; ok {
t.SQLNameForSelects = t.quoteTableName(s)
}
if s, ok := tag.Options["alias"]; ok {
t.Alias = s
t.SQLAlias = t.quoteIdent(s)
}
}
//nolint
func (t *Table) newField(f reflect.StructField, index []int) *Field {
tag := tagparser.Parse(f.Tag.Get("bun"))
if f.PkgPath != "" {
return nil
}
sqlName := internal.Underscore(f.Name)
if tag.Name != sqlName && isKnownFieldOption(tag.Name) {
internal.Warn.Printf(
"%s.%s tag name %q is also an option name; is it a mistake?",
t.TypeName, f.Name, tag.Name,
)
}
for name := range tag.Options {
if !isKnownFieldOption(name) {
internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name)
}
}
skip := tag.Name == "-"
if !skip && tag.Name != "" {
sqlName = tag.Name
}
index = append(index, f.Index...)
if field := t.fieldWithLock(sqlName); field != nil {
if indexEqual(field.Index, index) {
return field
}
t.removeField(field)
}
field := &Field{
StructField: f,
Tag: tag,
IndirectType: indirectType(f.Type),
Index: index,
Name: sqlName,
GoName: f.Name,
SQLName: t.quoteIdent(sqlName),
}
field.NotNull = tag.HasOption("notnull")
field.NullZero = tag.HasOption("nullzero")
field.AutoIncrement = tag.HasOption("autoincrement")
if tag.HasOption("pk") {
field.markAsPK()
}
if tag.HasOption("allowzero") {
if tag.HasOption("nullzero") {
internal.Warn.Printf(
"%s.%s: nullzero and allowzero options are mutually exclusive",
t.TypeName, f.Name,
)
}
field.NullZero = false
}
if v, ok := tag.Options["unique"]; ok {
// Split the value by comma, this will allow multiple names to be specified.
// We can use this to create multiple named unique constraints where a single column
// might be included in multiple constraints.
for _, uniqueName := range strings.Split(v, ",") {
if t.Unique == nil {
t.Unique = make(map[string][]*Field)
}
t.Unique[uniqueName] = append(t.Unique[uniqueName], field)
}
}
if s, ok := tag.Options["default"]; ok {
field.SQLDefault = s
}
if s, ok := field.Tag.Options["type"]; ok {
field.UserSQLType = s
}
field.DiscoveredSQLType = DiscoverSQLType(field.IndirectType)
field.Append = t.dialect.FieldAppender(field)
field.Scan = FieldScanner(t.dialect, field)
field.IsZero = FieldZeroChecker(field)
if v, ok := tag.Options["alt"]; ok {
t.FieldMap[v] = field
}
t.allFields = append(t.allFields, field)
if skip {
t.skippedFields = append(t.skippedFields, field)
t.FieldMap[field.Name] = field
return nil
}
if _, ok := tag.Options["soft_delete"]; ok {
field.NullZero = true
t.SoftDeleteField = field
t.UpdateSoftDeleteField = softDeleteFieldUpdater(field)
}
return field
}
func (t *Table) initInlines() {
for _, f := range t.skippedFields {
if f.IndirectType.Kind() == reflect.Struct {
t.inlineFields(f, nil)
}
}
}
//---------------------------------------------------------------------------------------
func (t *Table) initRelations() {
for i := 0; i < len(t.Fields); {
f := t.Fields[i]
if t.tryRelation(f) {
t.Fields = removeField(t.Fields, f)
t.DataFields = removeField(t.DataFields, f)
} else {
i++
}
if f.IndirectType.Kind() == reflect.Struct {
t.inlineFields(f, nil)
}
}
}
func (t *Table) tryRelation(field *Field) bool {
if rel, ok := field.Tag.Options["rel"]; ok {
t.initRelation(field, rel)
return true
}
if field.Tag.HasOption("m2m") {
t.addRelation(t.m2mRelation(field))
return true
}
if field.Tag.HasOption("join") {
internal.Warn.Printf(
`%s.%s option "join" requires a relation type`,
t.TypeName, field.GoName,
)
}
return false
}
func (t *Table) initRelation(field *Field, rel string) {
switch rel {
case "belongs-to":
t.addRelation(t.belongsToRelation(field))
case "has-one":
t.addRelation(t.hasOneRelation(field))
case "has-many":
t.addRelation(t.hasManyRelation(field))
default:
panic(fmt.Errorf("bun: unknown relation=%s on field=%s", rel, field.GoName))
}
}
func (t *Table) addRelation(rel *Relation) {
if t.Relations == nil {
t.Relations = make(map[string]*Relation)
}
_, ok := t.Relations[rel.Field.GoName]
if ok {
panic(fmt.Errorf("%s already has %s", t, rel))
}
t.Relations[rel.Field.GoName] = rel
}
func (t *Table) belongsToRelation(field *Field) *Relation {
joinTable := t.dialect.Tables().Ref(field.IndirectType)
if err := joinTable.CheckPKs(); err != nil {
panic(err)
}
rel := &Relation{
Type: HasOneRelation,
Field: field,
JoinTable: joinTable,
}
if join, ok := field.Tag.Options["join"]; ok {
baseColumns, joinColumns := parseRelationJoin(join)
for i, baseColumn := range baseColumns {
joinColumn := joinColumns[i]
if f := t.fieldWithLock(baseColumn); f != nil {
rel.BaseFields = append(rel.BaseFields, f)
} else {
panic(fmt.Errorf(
"bun: %s belongs-to %s: %s must have column %s",
t.TypeName, field.GoName, t.TypeName, baseColumn,
))
}
if f := joinTable.fieldWithLock(joinColumn); f != nil {
rel.JoinFields = append(rel.JoinFields, f)
} else {
panic(fmt.Errorf(
"bun: %s belongs-to %s: %s must have column %s",
t.TypeName, field.GoName, t.TypeName, baseColumn,
))
}
}
return rel
}
rel.JoinFields = joinTable.PKs
fkPrefix := internal.Underscore(field.GoName) + "_"
for _, joinPK := range joinTable.PKs {
fkName := fkPrefix + joinPK.Name
if fk := t.fieldWithLock(fkName); fk != nil {
rel.BaseFields = append(rel.BaseFields, fk)
continue
}
if fk := t.fieldWithLock(joinPK.Name); fk != nil {
rel.BaseFields = append(rel.BaseFields, fk)
continue
}
panic(fmt.Errorf(
"bun: %s belongs-to %s: %s must have column %s "+
"(to override, use join:base_column=join_column tag on %s field)",
t.TypeName, field.GoName, t.TypeName, fkName, field.GoName,
))
}
return rel
}
func (t *Table) hasOneRelation(field *Field) *Relation {
if err := t.CheckPKs(); err != nil {
panic(err)
}
joinTable := t.dialect.Tables().Ref(field.IndirectType)
rel := &Relation{
Type: BelongsToRelation,
Field: field,
JoinTable: joinTable,
}
if join, ok := field.Tag.Options["join"]; ok {
baseColumns, joinColumns := parseRelationJoin(join)
for i, baseColumn := range baseColumns {
if f := t.fieldWithLock(baseColumn); f != nil {
rel.BaseFields = append(rel.BaseFields, f)
} else {
panic(fmt.Errorf(
"bun: %s has-one %s: %s must have column %s",
field.GoName, t.TypeName, joinTable.TypeName, baseColumn,
))
}
joinColumn := joinColumns[i]
if f := joinTable.fieldWithLock(joinColumn); f != nil {
rel.JoinFields = append(rel.JoinFields, f)
} else {
panic(fmt.Errorf(
"bun: %s has-one %s: %s must have column %s",
field.GoName, t.TypeName, joinTable.TypeName, baseColumn,
))
}
}
return rel
}
rel.BaseFields = t.PKs
fkPrefix := internal.Underscore(t.ModelName) + "_"
for _, pk := range t.PKs {
fkName := fkPrefix + pk.Name
if f := joinTable.fieldWithLock(fkName); f != nil {
rel.JoinFields = append(rel.JoinFields, f)
continue
}
if f := joinTable.fieldWithLock(pk.Name); f != nil {
rel.JoinFields = append(rel.JoinFields, f)
continue
}
panic(fmt.Errorf(
"bun: %s has-one %s: %s must have column %s "+
"(to override, use join:base_column=join_column tag on %s field)",
field.GoName, t.TypeName, joinTable.TypeName, fkName, field.GoName,
))
}
return rel
}
func (t *Table) hasManyRelation(field *Field) *Relation {
if err := t.CheckPKs(); err != nil {
panic(err)
}
if field.IndirectType.Kind() != reflect.Slice {
panic(fmt.Errorf(
"bun: %s.%s has-many relation requires slice, got %q",
t.TypeName, field.GoName, field.IndirectType.Kind(),
))
}
joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem()))
polymorphicValue, isPolymorphic := field.Tag.Options["polymorphic"]
rel := &Relation{
Type: HasManyRelation,
Field: field,
JoinTable: joinTable,
}
var polymorphicColumn string
if join, ok := field.Tag.Options["join"]; ok {
baseColumns, joinColumns := parseRelationJoin(join)
for i, baseColumn := range baseColumns {
joinColumn := joinColumns[i]
if isPolymorphic && baseColumn == "type" {
polymorphicColumn = joinColumn
continue
}
if f := t.fieldWithLock(baseColumn); f != nil {
rel.BaseFields = append(rel.BaseFields, f)
} else {
panic(fmt.Errorf(
"bun: %s has-one %s: %s must have column %s",
t.TypeName, field.GoName, t.TypeName, baseColumn,
))
}
if f := joinTable.fieldWithLock(joinColumn); f != nil {
rel.JoinFields = append(rel.JoinFields, f)
} else {
panic(fmt.Errorf(
"bun: %s has-one %s: %s must have column %s",
t.TypeName, field.GoName, t.TypeName, baseColumn,
))
}
}
} else {
rel.BaseFields = t.PKs
fkPrefix := internal.Underscore(t.ModelName) + "_"
if isPolymorphic {
polymorphicColumn = fkPrefix + "type"
}
for _, pk := range t.PKs {
joinColumn := fkPrefix + pk.Name
if fk := joinTable.fieldWithLock(joinColumn); fk != nil {
rel.JoinFields = append(rel.JoinFields, fk)
continue
}
if fk := joinTable.fieldWithLock(pk.Name); fk != nil {
rel.JoinFields = append(rel.JoinFields, fk)
continue
}
panic(fmt.Errorf(
"bun: %s has-many %s: %s must have column %s "+
"(to override, use join:base_column=join_column tag on the field %s)",
t.TypeName, field.GoName, joinTable.TypeName, joinColumn, field.GoName,
))
}
}
if isPolymorphic {
rel.PolymorphicField = joinTable.fieldWithLock(polymorphicColumn)
if rel.PolymorphicField == nil {
panic(fmt.Errorf(
"bun: %s has-many %s: %s must have polymorphic column %s",
t.TypeName, field.GoName, joinTable.TypeName, polymorphicColumn,
))
}
if polymorphicValue == "" {
polymorphicValue = t.ModelName
}
rel.PolymorphicValue = polymorphicValue
}
return rel
}
func (t *Table) m2mRelation(field *Field) *Relation {
if field.IndirectType.Kind() != reflect.Slice {
panic(fmt.Errorf(
"bun: %s.%s m2m relation requires slice, got %q",
t.TypeName, field.GoName, field.IndirectType.Kind(),
))
}
joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem()))
if err := t.CheckPKs(); err != nil {
panic(err)
}
if err := joinTable.CheckPKs(); err != nil {
panic(err)
}
m2mTableName, ok := field.Tag.Options["m2m"]
if !ok {
panic(fmt.Errorf("bun: %s must have m2m tag option", field.GoName))
}
m2mTable := t.dialect.Tables().ByName(m2mTableName)
if m2mTable == nil {
panic(fmt.Errorf(
"bun: can't find m2m %s table (use db.RegisterModel)",
m2mTableName,
))
}
rel := &Relation{
Type: ManyToManyRelation,
Field: field,
JoinTable: joinTable,
M2MTable: m2mTable,
}
var leftColumn, rightColumn string
if join, ok := field.Tag.Options["join"]; ok {
left, right := parseRelationJoin(join)
leftColumn = left[0]
rightColumn = right[0]
} else {
leftColumn = t.TypeName
rightColumn = joinTable.TypeName
}
leftField := m2mTable.fieldByGoName(leftColumn)
if leftField == nil {
panic(fmt.Errorf(
"bun: %s many-to-many %s: %s must have field %s "+
"(to override, use tag join:LeftField=RightField on field %s.%s",
t.TypeName, field.GoName, m2mTable.TypeName, leftColumn, t.TypeName, field.GoName,
))
}
rightField := m2mTable.fieldByGoName(rightColumn)
if rightField == nil {
panic(fmt.Errorf(
"bun: %s many-to-many %s: %s must have field %s "+
"(to override, use tag join:LeftField=RightField on field %s.%s",
t.TypeName, field.GoName, m2mTable.TypeName, rightColumn, t.TypeName, field.GoName,
))
}
leftRel := m2mTable.belongsToRelation(leftField)
rel.BaseFields = leftRel.JoinFields
rel.M2MBaseFields = leftRel.BaseFields
rightRel := m2mTable.belongsToRelation(rightField)
rel.JoinFields = rightRel.JoinFields
rel.M2MJoinFields = rightRel.BaseFields
return rel
}
func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) {
if path == nil {
path = map[reflect.Type]struct{}{
t.Type: {},
}
}
if _, ok := path[field.IndirectType]; ok {
return
}
path[field.IndirectType] = struct{}{}
joinTable := t.dialect.Tables().Ref(field.IndirectType)
for _, f := range joinTable.allFields {
f = f.Clone()
f.GoName = field.GoName + "_" + f.GoName
f.Name = field.Name + "__" + f.Name
f.SQLName = t.quoteIdent(f.Name)
f.Index = appendNew(field.Index, f.Index...)
t.fieldsMapMu.Lock()
if _, ok := t.FieldMap[f.Name]; !ok {
t.FieldMap[f.Name] = f
}
t.fieldsMapMu.Unlock()
if f.IndirectType.Kind() != reflect.Struct {
continue
}
if _, ok := path[f.IndirectType]; !ok {
t.inlineFields(f, path)
}
}
}
//------------------------------------------------------------------------------
func (t *Table) Dialect() Dialect { return t.dialect }
//------------------------------------------------------------------------------
func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) }
func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) }
//------------------------------------------------------------------------------
func (t *Table) AppendNamedArg(
fmter Formatter, b []byte, name string, strct reflect.Value,
) ([]byte, bool) {
if field, ok := t.FieldMap[name]; ok {
return fmter.appendArg(b, field.Value(strct).Interface()), true
}
return b, false
}
func (t *Table) quoteTableName(s string) Safe {
// Don't quote if table name contains placeholder (?) or parentheses.
if strings.IndexByte(s, '?') >= 0 ||
strings.IndexByte(s, '(') >= 0 ||
strings.IndexByte(s, ')') >= 0 {
return Safe(s)
}
return t.quoteIdent(s)
}
func (t *Table) quoteIdent(s string) Safe {
return Safe(NewFormatter(t.dialect).AppendIdent(nil, s))
}
func appendNew(dst []int, src ...int) []int {
cp := make([]int, len(dst)+len(src))
copy(cp, dst)
copy(cp[len(dst):], src)
return cp
}
func isKnownTableOption(name string) bool {
switch name {
case "alias", "select":
return true
}
return false
}
func isKnownFieldOption(name string) bool {
switch name {
case "alias",
"type",
"array",
"hstore",
"composite",
"json_use_number",
"msgpack",
"notnull",
"nullzero",
"allowzero",
"default",
"unique",
"soft_delete",
"pk",
"autoincrement",
"rel",
"join",
"m2m",
"polymorphic":
return true
}
return false
}
func removeField(fields []*Field, field *Field) []*Field {
for i, f := range fields {
if f == field {
return append(fields[:i], fields[i+1:]...)
}
}
return fields
}
func parseRelationJoin(join string) ([]string, []string) {
ss := strings.Split(join, ",")
baseColumns := make([]string, len(ss))
joinColumns := make([]string, len(ss))
for i, s := range ss {
ss := strings.Split(strings.TrimSpace(s), "=")
if len(ss) != 2 {
panic(fmt.Errorf("can't parse relation join: %q", join))
}
baseColumns[i] = ss[0]
joinColumns[i] = ss[1]
}
return baseColumns, joinColumns
}
//------------------------------------------------------------------------------
func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error {
typ := field.StructField.Type
switch typ {
case timeType:
return func(fv reflect.Value) error {
ptr := fv.Addr().Interface().(*time.Time)
*ptr = time.Now()
return nil
}
case nullTimeType:
return func(fv reflect.Value) error {
ptr := fv.Addr().Interface().(*sql.NullTime)
*ptr = sql.NullTime{Time: time.Now()}
return nil
}
case nullIntType:
return func(fv reflect.Value) error {
ptr := fv.Addr().Interface().(*sql.NullInt64)
*ptr = sql.NullInt64{Int64: time.Now().UnixNano()}
return nil
}
}
switch field.IndirectType.Kind() {
case reflect.Int64:
return func(fv reflect.Value) error {
ptr := fv.Addr().Interface().(*int64)
*ptr = time.Now().UnixNano()
return nil
}
case reflect.Ptr:
typ = typ.Elem()
default:
return softDeleteFieldUpdaterFallback(field)
}
switch typ { //nolint:gocritic
case timeType:
return func(fv reflect.Value) error {
now := time.Now()
fv.Set(reflect.ValueOf(&now))
return nil
}
}
switch typ.Kind() { //nolint:gocritic
case reflect.Int64:
return func(fv reflect.Value) error {
utime := time.Now().UnixNano()
fv.Set(reflect.ValueOf(&utime))
return nil
}
}
return softDeleteFieldUpdaterFallback(field)
}
func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value) error {
return func(fv reflect.Value) error {
return field.ScanWithCheck(fv, time.Now())
}
}

148
vendor/github.com/uptrace/bun/schema/tables.go generated vendored Normal file
View file

@ -0,0 +1,148 @@
package schema
import (
"fmt"
"reflect"
"sync"
)
type tableInProgress struct {
table *Table
init1Once sync.Once
init2Once sync.Once
}
func newTableInProgress(table *Table) *tableInProgress {
return &tableInProgress{
table: table,
}
}
func (inp *tableInProgress) init1() bool {
var inited bool
inp.init1Once.Do(func() {
inp.table.init1()
inited = true
})
return inited
}
func (inp *tableInProgress) init2() bool {
var inited bool
inp.init2Once.Do(func() {
inp.table.init2()
inited = true
})
return inited
}
type Tables struct {
dialect Dialect
tables sync.Map
mu sync.RWMutex
inProgress map[reflect.Type]*tableInProgress
}
func NewTables(dialect Dialect) *Tables {
return &Tables{
dialect: dialect,
inProgress: make(map[reflect.Type]*tableInProgress),
}
}
func (t *Tables) Register(models ...interface{}) {
for _, model := range models {
_ = t.Get(reflect.TypeOf(model).Elem())
}
}
func (t *Tables) Get(typ reflect.Type) *Table {
return t.table(typ, false)
}
func (t *Tables) Ref(typ reflect.Type) *Table {
return t.table(typ, true)
}
func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table {
if typ.Kind() != reflect.Struct {
panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct))
}
if v, ok := t.tables.Load(typ); ok {
return v.(*Table)
}
t.mu.Lock()
if v, ok := t.tables.Load(typ); ok {
t.mu.Unlock()
return v.(*Table)
}
var table *Table
inProgress := t.inProgress[typ]
if inProgress == nil {
table = newTable(t.dialect, typ)
inProgress = newTableInProgress(table)
t.inProgress[typ] = inProgress
} else {
table = inProgress.table
}
t.mu.Unlock()
inProgress.init1()
if allowInProgress {
return table
}
if inProgress.init2() {
t.mu.Lock()
delete(t.inProgress, typ)
t.tables.Store(typ, table)
t.mu.Unlock()
}
t.dialect.OnTable(table)
for _, field := range table.FieldMap {
if field.UserSQLType == "" {
field.UserSQLType = field.DiscoveredSQLType
}
if field.CreateTableSQLType == "" {
field.CreateTableSQLType = field.UserSQLType
}
}
return table
}
func (t *Tables) ByModel(name string) *Table {
var found *Table
t.tables.Range(func(key, value interface{}) bool {
t := value.(*Table)
if t.TypeName == name {
found = t
return false
}
return true
})
return found
}
func (t *Tables) ByName(name string) *Table {
var found *Table
t.tables.Range(func(key, value interface{}) bool {
t := value.(*Table)
if t.Name == name {
found = t
return false
}
return true
})
return found
}

53
vendor/github.com/uptrace/bun/schema/util.go generated vendored Normal file
View file

@ -0,0 +1,53 @@
package schema
import "reflect"
func indirectType(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t
}
func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, ok bool) {
if len(index) == 1 {
return v.Field(index[0]), true
}
for i, idx := range index {
if i > 0 {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return v, false
}
v = v.Elem()
}
}
v = v.Field(idx)
}
return v, true
}
func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value {
if len(index) == 1 {
return v.Field(index[0])
}
for i, idx := range index {
if i > 0 {
v = indirectNil(v)
}
v = v.Field(idx)
}
return v
}
func indirectNil(v reflect.Value) reflect.Value {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
return v
}

126
vendor/github.com/uptrace/bun/schema/zerochecker.go generated vendored Normal file
View file

@ -0,0 +1,126 @@
package schema
import (
"database/sql/driver"
"reflect"
)
var isZeroerType = reflect.TypeOf((*isZeroer)(nil)).Elem()
type isZeroer interface {
IsZero() bool
}
type IsZeroerFunc func(reflect.Value) bool
func FieldZeroChecker(field *Field) IsZeroerFunc {
return zeroChecker(field.IndirectType)
}
func zeroChecker(typ reflect.Type) IsZeroerFunc {
if typ.Implements(isZeroerType) {
return isZeroInterface
}
kind := typ.Kind()
if kind != reflect.Ptr {
ptr := reflect.PtrTo(typ)
if ptr.Implements(isZeroerType) {
return addrChecker(isZeroInterface)
}
}
switch kind {
case reflect.Array:
if typ.Elem().Kind() == reflect.Uint8 {
return isZeroBytes
}
return isZeroLen
case reflect.String:
return isZeroLen
case reflect.Bool:
return isZeroBool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return isZeroInt
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return isZeroUint
case reflect.Float32, reflect.Float64:
return isZeroFloat
case reflect.Interface, reflect.Ptr, reflect.Slice, reflect.Map:
return isNil
}
if typ.Implements(driverValuerType) {
return isZeroDriverValue
}
return notZero
}
func addrChecker(fn IsZeroerFunc) IsZeroerFunc {
return func(v reflect.Value) bool {
if !v.CanAddr() {
return false
}
return fn(v.Addr())
}
}
func isZeroInterface(v reflect.Value) bool {
if v.Kind() == reflect.Ptr && v.IsNil() {
return true
}
return v.Interface().(isZeroer).IsZero()
}
func isZeroDriverValue(v reflect.Value) bool {
if v.Kind() == reflect.Ptr {
return v.IsNil()
}
valuer := v.Interface().(driver.Valuer)
value, err := valuer.Value()
if err != nil {
return false
}
return value == nil
}
func isZeroLen(v reflect.Value) bool {
return v.Len() == 0
}
func isNil(v reflect.Value) bool {
return v.IsNil()
}
func isZeroBool(v reflect.Value) bool {
return !v.Bool()
}
func isZeroInt(v reflect.Value) bool {
return v.Int() == 0
}
func isZeroUint(v reflect.Value) bool {
return v.Uint() == 0
}
func isZeroFloat(v reflect.Value) bool {
return v.Float() == 0
}
func isZeroBytes(v reflect.Value) bool {
b := v.Slice(0, v.Len()).Bytes()
for _, c := range b {
if c != 0 {
return false
}
}
return true
}
func notZero(v reflect.Value) bool {
return false
}