update bun library -> v1.0.4

Signed-off-by: kim (grufwub) <grufwub@gmail.com>
This commit is contained in:
kim (grufwub) 2021-09-08 20:05:26 +01:00
commit bdcc090851
87 changed files with 538 additions and 503 deletions

View file

@ -2,7 +2,6 @@ package schema
import (
"database/sql/driver"
"encoding/json"
"fmt"
"net"
"reflect"
@ -14,16 +13,6 @@ import (
"github.com/uptrace/bun/internal"
)
var (
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem()
)
type (
AppenderFunc func(fmter Formatter, b []byte, v reflect.Value) []byte
CustomAppender func(typ reflect.Type) AppenderFunc
@ -60,6 +49,8 @@ var appenders = []AppenderFunc{
func Appender(typ reflect.Type, custom CustomAppender) AppenderFunc {
switch typ {
case bytesType:
return appendBytesValue
case timeType:
return appendTimeValue
case ipType:
@ -93,7 +84,9 @@ func Appender(typ reflect.Type, custom CustomAppender) AppenderFunc {
case reflect.Interface:
return ifaceAppenderFunc(typ, custom)
case reflect.Ptr:
return ptrAppenderFunc(typ, custom)
if fn := Appender(typ.Elem(), custom); fn != nil {
return PtrAppender(fn)
}
case reflect.Slice:
if typ.Elem().Kind() == reflect.Uint8 {
return appendBytesValue
@ -123,13 +116,12 @@ func ifaceAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc)
}
}
func ptrAppenderFunc(typ reflect.Type, custom func(reflect.Type) AppenderFunc) AppenderFunc {
appender := Appender(typ.Elem(), custom)
func PtrAppender(fn AppenderFunc) AppenderFunc {
return func(fmter Formatter, b []byte, v reflect.Value) []byte {
if v.IsNil() {
return dialect.AppendNull(b)
}
return appender(fmter, b, v.Elem())
return fn(fmter, b, v.Elem())
}
}

View file

@ -89,10 +89,10 @@ func (f Formatter) AppendQuery(dst []byte, query string, args ...interface{}) []
func (f Formatter) append(dst []byte, p *parser.Parser, args []interface{}) []byte {
var namedArgs NamedArgAppender
if len(args) == 1 {
var ok bool
namedArgs, ok = args[0].(NamedArgAppender)
if !ok {
namedArgs, _ = newStructArgs(f, args[0])
if v, ok := args[0].(NamedArgAppender); ok {
namedArgs = v
} else if v, ok := newStructArgs(f, args[0]); ok {
namedArgs = v
}
}

View file

@ -1,6 +1,23 @@
package schema
import "reflect"
import (
"database/sql/driver"
"encoding/json"
"net"
"reflect"
"time"
)
var (
bytesType = reflect.TypeOf((*[]byte)(nil)).Elem()
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
queryAppenderType = reflect.TypeOf((*QueryAppender)(nil)).Elem()
)
func indirectType(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {

View file

@ -7,10 +7,12 @@ import (
"net"
"reflect"
"strconv"
"strings"
"time"
"github.com/vmihailenco/msgpack/v5"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/extra/bunjson"
"github.com/uptrace/bun/internal"
)
@ -19,32 +21,35 @@ var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
type ScannerFunc func(dest reflect.Value, src interface{}) error
var scanners = []ScannerFunc{
reflect.Bool: scanBool,
reflect.Int: scanInt64,
reflect.Int8: scanInt64,
reflect.Int16: scanInt64,
reflect.Int32: scanInt64,
reflect.Int64: scanInt64,
reflect.Uint: scanUint64,
reflect.Uint8: scanUint64,
reflect.Uint16: scanUint64,
reflect.Uint32: scanUint64,
reflect.Uint64: scanUint64,
reflect.Uintptr: scanUint64,
reflect.Float32: scanFloat64,
reflect.Float64: scanFloat64,
reflect.Complex64: nil,
reflect.Complex128: nil,
reflect.Array: nil,
reflect.Chan: nil,
reflect.Func: nil,
reflect.Map: scanJSON,
reflect.Ptr: nil,
reflect.Slice: scanJSON,
reflect.String: scanString,
reflect.Struct: scanJSON,
reflect.UnsafePointer: nil,
var scanners []ScannerFunc
func init() {
scanners = []ScannerFunc{
reflect.Bool: scanBool,
reflect.Int: scanInt64,
reflect.Int8: scanInt64,
reflect.Int16: scanInt64,
reflect.Int32: scanInt64,
reflect.Int64: scanInt64,
reflect.Uint: scanUint64,
reflect.Uint8: scanUint64,
reflect.Uint16: scanUint64,
reflect.Uint32: scanUint64,
reflect.Uint64: scanUint64,
reflect.Uintptr: scanUint64,
reflect.Float32: scanFloat64,
reflect.Float64: scanFloat64,
reflect.Complex64: nil,
reflect.Complex128: nil,
reflect.Array: nil,
reflect.Interface: scanInterface,
reflect.Map: scanJSON,
reflect.Ptr: nil,
reflect.Slice: scanJSON,
reflect.String: scanString,
reflect.Struct: scanJSON,
reflect.UnsafePointer: nil,
}
}
func FieldScanner(dialect Dialect, field *Field) ScannerFunc {
@ -54,6 +59,12 @@ func FieldScanner(dialect Dialect, field *Field) ScannerFunc {
if field.Tag.HasOption("json_use_number") {
return scanJSONUseNumber
}
if field.StructField.Type.Kind() == reflect.Interface {
switch strings.ToUpper(field.UserSQLType) {
case sqltype.JSON, sqltype.JSONB:
return scanJSONIntoInterface
}
}
return dialect.Scanner(field.StructField.Type)
}
@ -62,7 +73,7 @@ func Scanner(typ reflect.Type) ScannerFunc {
if kind == reflect.Ptr {
if fn := Scanner(typ.Elem()); fn != nil {
return ptrScanner(fn)
return PtrScanner(fn)
}
}
@ -84,6 +95,8 @@ func Scanner(typ reflect.Type) ScannerFunc {
return scanIP
case ipNetType:
return scanIPNet
case bytesType:
return scanBytes
case jsonRawMessageType:
return scanJSONRawMessage
}
@ -196,6 +209,21 @@ func scanString(dest reflect.Value, src interface{}) error {
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanBytes(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
dest.SetBytes(nil)
return nil
case string:
dest.SetBytes([]byte(src))
return nil
case []byte:
dest.SetBytes(src)
return nil
}
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanTime(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
@ -352,7 +380,7 @@ func toBytes(src interface{}) ([]byte, error) {
}
}
func ptrScanner(fn ScannerFunc) ScannerFunc {
func PtrScanner(fn ScannerFunc) ScannerFunc {
return func(dest reflect.Value, src interface{}) error {
if src == nil {
if !dest.CanAddr() {
@ -383,6 +411,43 @@ func scanNull(dest reflect.Value) error {
return nil
}
func scanJSONIntoInterface(dest reflect.Value, src interface{}) error {
if dest.IsNil() {
if src == nil {
return nil
}
b, err := toBytes(src)
if err != nil {
return err
}
return bunjson.Unmarshal(b, dest.Addr().Interface())
}
dest = dest.Elem()
if fn := Scanner(dest.Type()); fn != nil {
return fn(dest, src)
}
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func scanInterface(dest reflect.Value, src interface{}) error {
if dest.IsNil() {
if src == nil {
return nil
}
dest.Set(reflect.ValueOf(src))
return nil
}
dest = dest.Elem()
if fn := Scanner(dest.Type()); fn != nil {
return fn(dest, src)
}
return fmt.Errorf("bun: can't scan %#v into %s", src, dest.Type())
}
func nilable(kind reflect.Kind) bool {
switch kind {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:

View file

@ -40,7 +40,7 @@ type QueryWithArgs struct {
var _ QueryAppender = QueryWithArgs{}
func SafeQuery(query string, args []interface{}) QueryWithArgs {
if query != "" && args == nil {
if args == nil {
args = make([]interface{}, 0)
}
return QueryWithArgs{Query: query, Args: args}

View file

@ -23,32 +23,29 @@ var (
)
var sqlTypes = []string{
reflect.Bool: sqltype.Boolean,
reflect.Int: sqltype.BigInt,
reflect.Int8: sqltype.SmallInt,
reflect.Int16: sqltype.SmallInt,
reflect.Int32: sqltype.Integer,
reflect.Int64: sqltype.BigInt,
reflect.Uint: sqltype.BigInt,
reflect.Uint8: sqltype.SmallInt,
reflect.Uint16: sqltype.SmallInt,
reflect.Uint32: sqltype.Integer,
reflect.Uint64: sqltype.BigInt,
reflect.Uintptr: sqltype.BigInt,
reflect.Float32: sqltype.Real,
reflect.Float64: sqltype.DoublePrecision,
reflect.Complex64: "",
reflect.Complex128: "",
reflect.Array: "",
reflect.Chan: "",
reflect.Func: "",
reflect.Interface: "",
reflect.Map: sqltype.VarChar,
reflect.Ptr: "",
reflect.Slice: sqltype.VarChar,
reflect.String: sqltype.VarChar,
reflect.Struct: sqltype.VarChar,
reflect.UnsafePointer: "",
reflect.Bool: sqltype.Boolean,
reflect.Int: sqltype.BigInt,
reflect.Int8: sqltype.SmallInt,
reflect.Int16: sqltype.SmallInt,
reflect.Int32: sqltype.Integer,
reflect.Int64: sqltype.BigInt,
reflect.Uint: sqltype.BigInt,
reflect.Uint8: sqltype.SmallInt,
reflect.Uint16: sqltype.SmallInt,
reflect.Uint32: sqltype.Integer,
reflect.Uint64: sqltype.BigInt,
reflect.Uintptr: sqltype.BigInt,
reflect.Float32: sqltype.Real,
reflect.Float64: sqltype.DoublePrecision,
reflect.Complex64: "",
reflect.Complex128: "",
reflect.Array: "",
reflect.Interface: "",
reflect.Map: sqltype.VarChar,
reflect.Ptr: "",
reflect.Slice: sqltype.VarChar,
reflect.String: sqltype.VarChar,
reflect.Struct: sqltype.VarChar,
}
func DiscoverSQLType(typ reflect.Type) string {

View file

@ -60,10 +60,9 @@ type Table struct {
Unique map[string][]*Field
SoftDeleteField *Field
UpdateSoftDeleteField func(fv reflect.Value) error
UpdateSoftDeleteField func(fv reflect.Value, tm time.Time) error
allFields []*Field // read only
skippedFields []*Field
allFields []*Field // read only
flags internal.Flag
}
@ -104,9 +103,7 @@ func (t *Table) init1() {
}
func (t *Table) init2() {
t.initInlines()
t.initRelations()
t.skippedFields = nil
}
func (t *Table) setName(name string) {
@ -207,15 +204,20 @@ func (t *Table) initFields() {
func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
unexported := f.PkgPath != ""
// Make a copy so slice is not shared between fields.
if unexported && !f.Anonymous { // unexported
continue
}
if f.Tag.Get("bun") == "-" {
continue
}
// Make a copy so the slice is not shared between fields.
index := make([]int, len(baseIndex))
copy(index, baseIndex)
if f.Anonymous {
if f.Tag.Get("bun") == "-" {
continue
}
if f.Name == "BaseModel" && f.Type == baseModelType {
if len(index) == 0 {
t.processBaseModelField(f)
@ -243,8 +245,7 @@ func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
continue
}
field := t.newField(f, index)
if field != nil {
if field := t.newField(f, index); field != nil {
t.addField(field)
}
}
@ -284,11 +285,10 @@ func (t *Table) processBaseModelField(f reflect.StructField) {
func (t *Table) newField(f reflect.StructField, index []int) *Field {
tag := tagparser.Parse(f.Tag.Get("bun"))
if f.PkgPath != "" {
return nil
}
sqlName := internal.Underscore(f.Name)
if tag.Name != "" {
sqlName = tag.Name
}
if tag.Name != sqlName && isKnownFieldOption(tag.Name) {
internal.Warn.Printf(
@ -303,11 +303,6 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field {
}
}
skip := tag.Name == "-"
if !skip && tag.Name != "" {
sqlName = tag.Name
}
index = append(index, f.Index...)
if field := t.fieldWithLock(sqlName); field != nil {
if indexEqual(field.Index, index) {
@ -371,9 +366,11 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field {
}
t.allFields = append(t.allFields, field)
if skip {
t.skippedFields = append(t.skippedFields, field)
if tag.HasOption("scanonly") {
t.FieldMap[field.Name] = field
if field.IndirectType.Kind() == reflect.Struct {
t.inlineFields(field, nil)
}
return nil
}
@ -386,14 +383,6 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field {
return field
}
func (t *Table) initInlines() {
for _, f := range t.skippedFields {
if f.IndirectType.Kind() == reflect.Struct {
t.inlineFields(f, nil)
}
}
}
//---------------------------------------------------------------------------------------
func (t *Table) initRelations() {
@ -745,17 +734,15 @@ func (t *Table) m2mRelation(field *Field) *Relation {
return rel
}
func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) {
if path == nil {
path = map[reflect.Type]struct{}{
t.Type: {},
}
func (t *Table) inlineFields(field *Field, seen map[reflect.Type]struct{}) {
if seen == nil {
seen = map[reflect.Type]struct{}{t.Type: {}}
}
if _, ok := path[field.IndirectType]; ok {
if _, ok := seen[field.IndirectType]; ok {
return
}
path[field.IndirectType] = struct{}{}
seen[field.IndirectType] = struct{}{}
joinTable := t.dialect.Tables().Ref(field.IndirectType)
for _, f := range joinTable.allFields {
@ -775,18 +762,15 @@ func (t *Table) inlineFields(field *Field, path map[reflect.Type]struct{}) {
continue
}
if _, ok := path[f.IndirectType]; !ok {
t.inlineFields(f, path)
if _, ok := seen[f.IndirectType]; !ok {
t.inlineFields(f, seen)
}
}
}
//------------------------------------------------------------------------------
func (t *Table) Dialect() Dialect { return t.dialect }
//------------------------------------------------------------------------------
func (t *Table) Dialect() Dialect { return t.dialect }
func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) }
func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) }
@ -845,6 +829,7 @@ func isKnownFieldOption(name string) bool {
"default",
"unique",
"soft_delete",
"scanonly",
"pk",
"autoincrement",
@ -883,35 +868,35 @@ func parseRelationJoin(join string) ([]string, []string) {
//------------------------------------------------------------------------------
func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error {
func softDeleteFieldUpdater(field *Field) func(fv reflect.Value, tm time.Time) error {
typ := field.StructField.Type
switch typ {
case timeType:
return func(fv reflect.Value) error {
return func(fv reflect.Value, tm time.Time) error {
ptr := fv.Addr().Interface().(*time.Time)
*ptr = time.Now()
*ptr = tm
return nil
}
case nullTimeType:
return func(fv reflect.Value) error {
return func(fv reflect.Value, tm time.Time) error {
ptr := fv.Addr().Interface().(*sql.NullTime)
*ptr = sql.NullTime{Time: time.Now()}
*ptr = sql.NullTime{Time: tm}
return nil
}
case nullIntType:
return func(fv reflect.Value) error {
return func(fv reflect.Value, tm time.Time) error {
ptr := fv.Addr().Interface().(*sql.NullInt64)
*ptr = sql.NullInt64{Int64: time.Now().UnixNano()}
*ptr = sql.NullInt64{Int64: tm.UnixNano()}
return nil
}
}
switch field.IndirectType.Kind() {
case reflect.Int64:
return func(fv reflect.Value) error {
return func(fv reflect.Value, tm time.Time) error {
ptr := fv.Addr().Interface().(*int64)
*ptr = time.Now().UnixNano()
*ptr = tm.UnixNano()
return nil
}
case reflect.Ptr:
@ -922,17 +907,16 @@ func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error {
switch typ { //nolint:gocritic
case timeType:
return func(fv reflect.Value) error {
now := time.Now()
fv.Set(reflect.ValueOf(&now))
return func(fv reflect.Value, tm time.Time) error {
fv.Set(reflect.ValueOf(&tm))
return nil
}
}
switch typ.Kind() { //nolint:gocritic
case reflect.Int64:
return func(fv reflect.Value) error {
utime := time.Now().UnixNano()
return func(fv reflect.Value, tm time.Time) error {
utime := tm.UnixNano()
fv.Set(reflect.ValueOf(&utime))
return nil
}
@ -941,8 +925,8 @@ func softDeleteFieldUpdater(field *Field) func(fv reflect.Value) error {
return softDeleteFieldUpdaterFallback(field)
}
func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value) error {
return func(fv reflect.Value) error {
return field.ScanWithCheck(fv, time.Now())
func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value, tm time.Time) error {
return func(fv reflect.Value, tm time.Time) error {
return field.ScanWithCheck(fv, tm)
}
}

View file

@ -67,6 +67,7 @@ func (t *Tables) Ref(typ reflect.Type) *Table {
}
func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table {
typ = indirectType(typ)
if typ.Kind() != reflect.Struct {
panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct))
}