mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-11-04 00:12:26 -06:00 
			
		
		
		
	
		
			
				
	
	
		
			1048 lines
		
	
	
	
		
			25 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			1048 lines
		
	
	
	
		
			25 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package schema
 | 
						|
 | 
						|
import (
 | 
						|
	"database/sql"
 | 
						|
	"fmt"
 | 
						|
	"reflect"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/jinzhu/inflection"
 | 
						|
 | 
						|
	"github.com/uptrace/bun/internal"
 | 
						|
	"github.com/uptrace/bun/internal/tagparser"
 | 
						|
)
 | 
						|
 | 
						|
const (
 | 
						|
	beforeAppendModelHookFlag internal.Flag = 1 << iota
 | 
						|
	beforeScanHookFlag
 | 
						|
	afterScanHookFlag
 | 
						|
	beforeScanRowHookFlag
 | 
						|
	afterScanRowHookFlag
 | 
						|
)
 | 
						|
 | 
						|
var (
 | 
						|
	baseModelType      = reflect.TypeOf((*BaseModel)(nil)).Elem()
 | 
						|
	tableNameInflector = inflection.Plural
 | 
						|
)
 | 
						|
 | 
						|
type BaseModel struct{}
 | 
						|
 | 
						|
// SetTableNameInflector overrides the default func that pluralizes
 | 
						|
// model name to get table name, e.g. my_article becomes my_articles.
 | 
						|
func SetTableNameInflector(fn func(string) string) {
 | 
						|
	tableNameInflector = fn
 | 
						|
}
 | 
						|
 | 
						|
// Table represents a SQL table created from Go struct.
 | 
						|
type Table struct {
 | 
						|
	dialect Dialect
 | 
						|
 | 
						|
	Type      reflect.Type
 | 
						|
	ZeroValue reflect.Value // reflect.Struct
 | 
						|
	ZeroIface interface{}   // struct pointer
 | 
						|
 | 
						|
	TypeName  string
 | 
						|
	ModelName string
 | 
						|
 | 
						|
	Name              string
 | 
						|
	SQLName           Safe
 | 
						|
	SQLNameForSelects Safe
 | 
						|
	Alias             string
 | 
						|
	SQLAlias          Safe
 | 
						|
 | 
						|
	allFields  []*Field // all fields including scanonly
 | 
						|
	Fields     []*Field // PKs + DataFields
 | 
						|
	PKs        []*Field
 | 
						|
	DataFields []*Field
 | 
						|
	relFields  []*Field
 | 
						|
 | 
						|
	FieldMap  map[string]*Field
 | 
						|
	StructMap map[string]*structField
 | 
						|
 | 
						|
	Relations map[string]*Relation
 | 
						|
	Unique    map[string][]*Field
 | 
						|
 | 
						|
	SoftDeleteField       *Field
 | 
						|
	UpdateSoftDeleteField func(fv reflect.Value, tm time.Time) error
 | 
						|
 | 
						|
	flags internal.Flag
 | 
						|
}
 | 
						|
 | 
						|
type structField struct {
 | 
						|
	Index []int
 | 
						|
	Table *Table
 | 
						|
}
 | 
						|
 | 
						|
func newTable(
 | 
						|
	dialect Dialect, typ reflect.Type, seen map[reflect.Type]*Table, canAddr bool,
 | 
						|
) *Table {
 | 
						|
	if table, ok := seen[typ]; ok {
 | 
						|
		return table
 | 
						|
	}
 | 
						|
 | 
						|
	table := new(Table)
 | 
						|
	seen[typ] = table
 | 
						|
 | 
						|
	table.dialect = dialect
 | 
						|
	table.Type = typ
 | 
						|
	table.ZeroValue = reflect.New(table.Type).Elem()
 | 
						|
	table.ZeroIface = reflect.New(table.Type).Interface()
 | 
						|
	table.TypeName = internal.ToExported(table.Type.Name())
 | 
						|
	table.ModelName = internal.Underscore(table.Type.Name())
 | 
						|
	tableName := tableNameInflector(table.ModelName)
 | 
						|
	table.setName(tableName)
 | 
						|
	table.Alias = table.ModelName
 | 
						|
	table.SQLAlias = table.quoteIdent(table.ModelName)
 | 
						|
 | 
						|
	table.Fields = make([]*Field, 0, typ.NumField())
 | 
						|
	table.FieldMap = make(map[string]*Field, typ.NumField())
 | 
						|
	table.processFields(typ, seen, canAddr)
 | 
						|
 | 
						|
	hooks := []struct {
 | 
						|
		typ  reflect.Type
 | 
						|
		flag internal.Flag
 | 
						|
	}{
 | 
						|
		{beforeAppendModelHookType, beforeAppendModelHookFlag},
 | 
						|
 | 
						|
		{beforeScanRowHookType, beforeScanRowHookFlag},
 | 
						|
		{afterScanRowHookType, afterScanRowHookFlag},
 | 
						|
	}
 | 
						|
 | 
						|
	typ = reflect.PtrTo(table.Type)
 | 
						|
	for _, hook := range hooks {
 | 
						|
		if typ.Implements(hook.typ) {
 | 
						|
			table.flags = table.flags.Set(hook.flag)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return table
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) init() {
 | 
						|
	for _, field := range t.relFields {
 | 
						|
		t.processRelation(field)
 | 
						|
	}
 | 
						|
	t.relFields = nil
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) processFields(
 | 
						|
	typ reflect.Type,
 | 
						|
	seen map[reflect.Type]*Table,
 | 
						|
	canAddr bool,
 | 
						|
) {
 | 
						|
	type embeddedField struct {
 | 
						|
		prefix     string
 | 
						|
		index      []int
 | 
						|
		unexported bool
 | 
						|
		subtable   *Table
 | 
						|
		subfield   *Field
 | 
						|
	}
 | 
						|
 | 
						|
	names := make(map[string]struct{})
 | 
						|
	embedded := make([]embeddedField, 0, 10)
 | 
						|
 | 
						|
	for i, n := 0, typ.NumField(); i < n; i++ {
 | 
						|
		sf := typ.Field(i)
 | 
						|
		unexported := sf.PkgPath != ""
 | 
						|
 | 
						|
		tagstr := sf.Tag.Get("bun")
 | 
						|
		if tagstr == "-" {
 | 
						|
			names[sf.Name] = struct{}{}
 | 
						|
			continue
 | 
						|
		}
 | 
						|
		tag := tagparser.Parse(tagstr)
 | 
						|
 | 
						|
		if unexported && !sf.Anonymous { // unexported
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		if sf.Anonymous {
 | 
						|
			if sf.Name == "BaseModel" && sf.Type == baseModelType {
 | 
						|
				t.processBaseModelField(sf)
 | 
						|
				continue
 | 
						|
			}
 | 
						|
 | 
						|
			sfType := sf.Type
 | 
						|
			if sfType.Kind() == reflect.Ptr {
 | 
						|
				sfType = sfType.Elem()
 | 
						|
			}
 | 
						|
 | 
						|
			if sfType.Kind() != reflect.Struct { // ignore unexported non-struct types
 | 
						|
				continue
 | 
						|
			}
 | 
						|
 | 
						|
			subtable := newTable(t.dialect, sfType, seen, canAddr)
 | 
						|
 | 
						|
			for _, subfield := range subtable.allFields {
 | 
						|
				embedded = append(embedded, embeddedField{
 | 
						|
					index:      sf.Index,
 | 
						|
					unexported: unexported,
 | 
						|
					subtable:   subtable,
 | 
						|
					subfield:   subfield,
 | 
						|
				})
 | 
						|
			}
 | 
						|
 | 
						|
			if tagstr != "" {
 | 
						|
				tag := tagparser.Parse(tagstr)
 | 
						|
				if tag.HasOption("inherit") || tag.HasOption("extend") {
 | 
						|
					t.Name = subtable.Name
 | 
						|
					t.TypeName = subtable.TypeName
 | 
						|
					t.SQLName = subtable.SQLName
 | 
						|
					t.SQLNameForSelects = subtable.SQLNameForSelects
 | 
						|
					t.Alias = subtable.Alias
 | 
						|
					t.SQLAlias = subtable.SQLAlias
 | 
						|
					t.ModelName = subtable.ModelName
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		if prefix, ok := tag.Option("embed"); ok {
 | 
						|
			fieldType := indirectType(sf.Type)
 | 
						|
			if fieldType.Kind() != reflect.Struct {
 | 
						|
				panic(fmt.Errorf("bun: embed %s.%s: got %s, wanted reflect.Struct",
 | 
						|
					t.TypeName, sf.Name, fieldType.Kind()))
 | 
						|
			}
 | 
						|
 | 
						|
			subtable := newTable(t.dialect, fieldType, seen, canAddr)
 | 
						|
			for _, subfield := range subtable.allFields {
 | 
						|
				embedded = append(embedded, embeddedField{
 | 
						|
					prefix:     prefix,
 | 
						|
					index:      sf.Index,
 | 
						|
					unexported: unexported,
 | 
						|
					subtable:   subtable,
 | 
						|
					subfield:   subfield,
 | 
						|
				})
 | 
						|
			}
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		field := t.newField(sf, tag)
 | 
						|
		t.addField(field)
 | 
						|
		names[field.Name] = struct{}{}
 | 
						|
 | 
						|
		if field.IndirectType.Kind() == reflect.Struct {
 | 
						|
			if t.StructMap == nil {
 | 
						|
				t.StructMap = make(map[string]*structField)
 | 
						|
			}
 | 
						|
			t.StructMap[field.Name] = &structField{
 | 
						|
				Index: field.Index,
 | 
						|
				Table: newTable(t.dialect, field.IndirectType, seen, canAddr),
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	// Only unambiguous embedded fields must be serialized.
 | 
						|
	ambiguousNames := make(map[string]int)
 | 
						|
	ambiguousTags := make(map[string]int)
 | 
						|
 | 
						|
	// Embedded types can never override a field that was already present at
 | 
						|
	// the top-level.
 | 
						|
	for name := range names {
 | 
						|
		ambiguousNames[name]++
 | 
						|
		ambiguousTags[name]++
 | 
						|
	}
 | 
						|
 | 
						|
	for _, f := range embedded {
 | 
						|
		ambiguousNames[f.prefix+f.subfield.Name]++
 | 
						|
		if !f.subfield.Tag.IsZero() {
 | 
						|
			ambiguousTags[f.prefix+f.subfield.Name]++
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	for _, embfield := range embedded {
 | 
						|
		subfield := embfield.subfield.Clone()
 | 
						|
 | 
						|
		if ambiguousNames[subfield.Name] > 1 &&
 | 
						|
			!(!subfield.Tag.IsZero() && ambiguousTags[subfield.Name] == 1) {
 | 
						|
			continue // ambiguous embedded field
 | 
						|
		}
 | 
						|
 | 
						|
		subfield.Index = makeIndex(embfield.index, subfield.Index)
 | 
						|
		if embfield.prefix != "" {
 | 
						|
			subfield.Name = embfield.prefix + subfield.Name
 | 
						|
			subfield.SQLName = t.quoteIdent(subfield.Name)
 | 
						|
		}
 | 
						|
		t.addField(subfield)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) setName(name string) {
 | 
						|
	t.Name = name
 | 
						|
	t.SQLName = t.quoteIdent(name)
 | 
						|
	t.SQLNameForSelects = t.quoteIdent(name)
 | 
						|
	if t.SQLAlias == "" {
 | 
						|
		t.Alias = name
 | 
						|
		t.SQLAlias = t.quoteIdent(name)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) String() string {
 | 
						|
	return "model=" + t.TypeName
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) CheckPKs() error {
 | 
						|
	if len(t.PKs) == 0 {
 | 
						|
		return fmt.Errorf("bun: %s does not have primary keys", t)
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) addField(field *Field) {
 | 
						|
	t.allFields = append(t.allFields, field)
 | 
						|
 | 
						|
	if field.Tag.HasOption("rel") || field.Tag.HasOption("m2m") {
 | 
						|
		t.relFields = append(t.relFields, field)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	if field.Tag.HasOption("join") {
 | 
						|
		internal.Warn.Printf(
 | 
						|
			`%s.%s "join" option must come together with "rel" option`,
 | 
						|
			t.TypeName, field.GoName,
 | 
						|
		)
 | 
						|
	}
 | 
						|
 | 
						|
	t.FieldMap[field.Name] = field
 | 
						|
	if altName, ok := field.Tag.Option("alt"); ok {
 | 
						|
		t.FieldMap[altName] = field
 | 
						|
	}
 | 
						|
 | 
						|
	if field.Tag.HasOption("scanonly") {
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	if _, ok := field.Tag.Options["soft_delete"]; ok {
 | 
						|
		t.SoftDeleteField = field
 | 
						|
		t.UpdateSoftDeleteField = softDeleteFieldUpdater(field)
 | 
						|
	}
 | 
						|
 | 
						|
	t.Fields = append(t.Fields, field)
 | 
						|
	if field.IsPK {
 | 
						|
		t.PKs = append(t.PKs, field)
 | 
						|
	} else {
 | 
						|
		t.DataFields = append(t.DataFields, field)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) LookupField(name string) *Field {
 | 
						|
	if field, ok := t.FieldMap[name]; ok {
 | 
						|
		return field
 | 
						|
	}
 | 
						|
 | 
						|
	table := t
 | 
						|
	var index []int
 | 
						|
	for {
 | 
						|
		structName, columnName, ok := strings.Cut(name, "__")
 | 
						|
		if !ok {
 | 
						|
			field, ok := table.FieldMap[name]
 | 
						|
			if !ok {
 | 
						|
				return nil
 | 
						|
			}
 | 
						|
			return field.WithIndex(index)
 | 
						|
		}
 | 
						|
		name = columnName
 | 
						|
 | 
						|
		strct := table.StructMap[structName]
 | 
						|
		if strct == nil {
 | 
						|
			return nil
 | 
						|
		}
 | 
						|
		table = strct.Table
 | 
						|
		index = append(index, strct.Index...)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) HasField(name string) bool {
 | 
						|
	_, ok := t.FieldMap[name]
 | 
						|
	return ok
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) Field(name string) (*Field, error) {
 | 
						|
	field, ok := t.FieldMap[name]
 | 
						|
	if !ok {
 | 
						|
		return nil, fmt.Errorf("bun: %s does not have column=%s", t, name)
 | 
						|
	}
 | 
						|
	return field, nil
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) fieldByGoName(name string) *Field {
 | 
						|
	for _, f := range t.allFields {
 | 
						|
		if f.GoName == name {
 | 
						|
			return f
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) processBaseModelField(f reflect.StructField) {
 | 
						|
	tag := tagparser.Parse(f.Tag.Get("bun"))
 | 
						|
 | 
						|
	if isKnownTableOption(tag.Name) {
 | 
						|
		internal.Warn.Printf(
 | 
						|
			"%s.%s tag name %q is also an option name, is it a mistake? Try table:%s.",
 | 
						|
			t.TypeName, f.Name, tag.Name, tag.Name,
 | 
						|
		)
 | 
						|
	}
 | 
						|
 | 
						|
	for name := range tag.Options {
 | 
						|
		if !isKnownTableOption(name) {
 | 
						|
			internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	if tag.Name != "" {
 | 
						|
		t.setName(tag.Name)
 | 
						|
	}
 | 
						|
 | 
						|
	if s, ok := tag.Option("table"); ok {
 | 
						|
		t.setName(s)
 | 
						|
	}
 | 
						|
 | 
						|
	if s, ok := tag.Option("select"); ok {
 | 
						|
		t.SQLNameForSelects = t.quoteTableName(s)
 | 
						|
	}
 | 
						|
 | 
						|
	if s, ok := tag.Option("alias"); ok {
 | 
						|
		t.Alias = s
 | 
						|
		t.SQLAlias = t.quoteIdent(s)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// nolint
 | 
						|
func (t *Table) newField(sf reflect.StructField, tag tagparser.Tag) *Field {
 | 
						|
	sqlName := internal.Underscore(sf.Name)
 | 
						|
	if tag.Name != "" && tag.Name != sqlName {
 | 
						|
		if isKnownFieldOption(tag.Name) {
 | 
						|
			internal.Warn.Printf(
 | 
						|
				"%s.%s tag name %q is also an option name, is it a mistake? Try column:%s.",
 | 
						|
				t.TypeName, sf.Name, tag.Name, tag.Name,
 | 
						|
			)
 | 
						|
		}
 | 
						|
		sqlName = tag.Name
 | 
						|
	}
 | 
						|
 | 
						|
	for name := range tag.Options {
 | 
						|
		if !isKnownFieldOption(name) {
 | 
						|
			internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, sf.Name, name)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	field := &Field{
 | 
						|
		StructField: sf,
 | 
						|
		IsPtr:       sf.Type.Kind() == reflect.Ptr,
 | 
						|
 | 
						|
		Tag:          tag,
 | 
						|
		IndirectType: indirectType(sf.Type),
 | 
						|
		Index:        sf.Index,
 | 
						|
 | 
						|
		Name:    sqlName,
 | 
						|
		GoName:  sf.Name,
 | 
						|
		SQLName: t.quoteIdent(sqlName),
 | 
						|
	}
 | 
						|
 | 
						|
	field.NotNull = tag.HasOption("notnull")
 | 
						|
	field.NullZero = tag.HasOption("nullzero")
 | 
						|
	if tag.HasOption("pk") {
 | 
						|
		field.IsPK = true
 | 
						|
		field.NotNull = true
 | 
						|
	}
 | 
						|
	if tag.HasOption("autoincrement") {
 | 
						|
		field.AutoIncrement = true
 | 
						|
		field.NullZero = true
 | 
						|
	}
 | 
						|
	if tag.HasOption("identity") {
 | 
						|
		field.Identity = true
 | 
						|
	}
 | 
						|
 | 
						|
	if v, ok := tag.Options["unique"]; ok {
 | 
						|
		var names []string
 | 
						|
		if len(v) == 1 {
 | 
						|
			// Split the value by comma, this will allow multiple names to be specified.
 | 
						|
			// We can use this to create multiple named unique constraints where a single column
 | 
						|
			// might be included in multiple constraints.
 | 
						|
			names = strings.Split(v[0], ",")
 | 
						|
		} else {
 | 
						|
			names = v
 | 
						|
		}
 | 
						|
 | 
						|
		for _, uniqueName := range names {
 | 
						|
			if t.Unique == nil {
 | 
						|
				t.Unique = make(map[string][]*Field)
 | 
						|
			}
 | 
						|
			t.Unique[uniqueName] = append(t.Unique[uniqueName], field)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if s, ok := tag.Option("default"); ok {
 | 
						|
		field.SQLDefault = s
 | 
						|
	}
 | 
						|
	if s, ok := field.Tag.Option("type"); ok {
 | 
						|
		field.UserSQLType = s
 | 
						|
	}
 | 
						|
	field.DiscoveredSQLType = DiscoverSQLType(field.IndirectType)
 | 
						|
	field.Append = FieldAppender(t.dialect, field)
 | 
						|
	field.Scan = FieldScanner(t.dialect, field)
 | 
						|
	field.IsZero = zeroChecker(field.StructField.Type)
 | 
						|
 | 
						|
	return field
 | 
						|
}
 | 
						|
 | 
						|
//---------------------------------------------------------------------------------------
 | 
						|
 | 
						|
func (t *Table) processRelation(field *Field) {
 | 
						|
	if rel, ok := field.Tag.Option("rel"); ok {
 | 
						|
		t.initRelation(field, rel)
 | 
						|
		return
 | 
						|
	}
 | 
						|
	if field.Tag.HasOption("m2m") {
 | 
						|
		t.addRelation(t.m2mRelation(field))
 | 
						|
		return
 | 
						|
	}
 | 
						|
	panic("not reached")
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) initRelation(field *Field, rel string) {
 | 
						|
	switch rel {
 | 
						|
	case "belongs-to":
 | 
						|
		t.addRelation(t.belongsToRelation(field))
 | 
						|
	case "has-one":
 | 
						|
		t.addRelation(t.hasOneRelation(field))
 | 
						|
	case "has-many":
 | 
						|
		t.addRelation(t.hasManyRelation(field))
 | 
						|
	default:
 | 
						|
		panic(fmt.Errorf("bun: unknown relation=%s on field=%s", rel, field.GoName))
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) addRelation(rel *Relation) {
 | 
						|
	if t.Relations == nil {
 | 
						|
		t.Relations = make(map[string]*Relation)
 | 
						|
	}
 | 
						|
	_, ok := t.Relations[rel.Field.GoName]
 | 
						|
	if ok {
 | 
						|
		panic(fmt.Errorf("%s already has %s", t, rel))
 | 
						|
	}
 | 
						|
	t.Relations[rel.Field.GoName] = rel
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) belongsToRelation(field *Field) *Relation {
 | 
						|
	joinTable := t.dialect.Tables().InProgress(field.IndirectType)
 | 
						|
	if err := joinTable.CheckPKs(); err != nil {
 | 
						|
		panic(err)
 | 
						|
	}
 | 
						|
 | 
						|
	rel := &Relation{
 | 
						|
		Type:      BelongsToRelation,
 | 
						|
		Field:     field,
 | 
						|
		JoinTable: joinTable,
 | 
						|
	}
 | 
						|
 | 
						|
	if field.Tag.HasOption("join_on") {
 | 
						|
		rel.Condition = field.Tag.Options["join_on"]
 | 
						|
	}
 | 
						|
 | 
						|
	rel.OnUpdate = "ON UPDATE NO ACTION"
 | 
						|
	if onUpdate, ok := field.Tag.Options["on_update"]; ok {
 | 
						|
		if len(onUpdate) > 1 {
 | 
						|
			panic(fmt.Errorf("bun: %s belongs-to %s: on_update option must be a single field", t.TypeName, field.GoName))
 | 
						|
		}
 | 
						|
 | 
						|
		rule := strings.ToUpper(onUpdate[0])
 | 
						|
		if !isKnownFKRule(rule) {
 | 
						|
			internal.Warn.Printf("bun: %s belongs-to %s: unknown on_update rule %s", t.TypeName, field.GoName, rule)
 | 
						|
		}
 | 
						|
 | 
						|
		s := fmt.Sprintf("ON UPDATE %s", rule)
 | 
						|
		rel.OnUpdate = s
 | 
						|
	}
 | 
						|
 | 
						|
	rel.OnDelete = "ON DELETE NO ACTION"
 | 
						|
	if onDelete, ok := field.Tag.Options["on_delete"]; ok {
 | 
						|
		if len(onDelete) > 1 {
 | 
						|
			panic(fmt.Errorf("bun: %s belongs-to %s: on_delete option must be a single field", t.TypeName, field.GoName))
 | 
						|
		}
 | 
						|
 | 
						|
		rule := strings.ToUpper(onDelete[0])
 | 
						|
		if !isKnownFKRule(rule) {
 | 
						|
			internal.Warn.Printf("bun: %s belongs-to %s: unknown on_delete rule %s", t.TypeName, field.GoName, rule)
 | 
						|
		}
 | 
						|
		s := fmt.Sprintf("ON DELETE %s", rule)
 | 
						|
		rel.OnDelete = s
 | 
						|
	}
 | 
						|
 | 
						|
	if join, ok := field.Tag.Options["join"]; ok {
 | 
						|
		baseColumns, joinColumns := parseRelationJoin(join)
 | 
						|
		for i, baseColumn := range baseColumns {
 | 
						|
			joinColumn := joinColumns[i]
 | 
						|
 | 
						|
			if f := t.FieldMap[baseColumn]; f != nil {
 | 
						|
				rel.BaseFields = append(rel.BaseFields, f)
 | 
						|
			} else {
 | 
						|
				panic(fmt.Errorf(
 | 
						|
					"bun: %s belongs-to %s: %s must have column %s",
 | 
						|
					t.TypeName, field.GoName, t.TypeName, baseColumn,
 | 
						|
				))
 | 
						|
			}
 | 
						|
 | 
						|
			if f := joinTable.FieldMap[joinColumn]; f != nil {
 | 
						|
				rel.JoinFields = append(rel.JoinFields, f)
 | 
						|
			} else {
 | 
						|
				panic(fmt.Errorf(
 | 
						|
					"bun: %s belongs-to %s: %s must have column %s",
 | 
						|
					t.TypeName, field.GoName, joinTable.TypeName, joinColumn,
 | 
						|
				))
 | 
						|
			}
 | 
						|
		}
 | 
						|
		return rel
 | 
						|
	}
 | 
						|
 | 
						|
	rel.JoinFields = joinTable.PKs
 | 
						|
	fkPrefix := internal.Underscore(field.GoName) + "_"
 | 
						|
	for _, joinPK := range joinTable.PKs {
 | 
						|
		fkName := fkPrefix + joinPK.Name
 | 
						|
		if fk := t.FieldMap[fkName]; fk != nil {
 | 
						|
			rel.BaseFields = append(rel.BaseFields, fk)
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		if fk := t.FieldMap[joinPK.Name]; fk != nil {
 | 
						|
			rel.BaseFields = append(rel.BaseFields, fk)
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		panic(fmt.Errorf(
 | 
						|
			"bun: %s belongs-to %s: %s must have column %s "+
 | 
						|
				"(to override, use join:base_column=join_column tag on %s field)",
 | 
						|
			t.TypeName, field.GoName, t.TypeName, fkName, field.GoName,
 | 
						|
		))
 | 
						|
	}
 | 
						|
	return rel
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) hasOneRelation(field *Field) *Relation {
 | 
						|
	if err := t.CheckPKs(); err != nil {
 | 
						|
		panic(err)
 | 
						|
	}
 | 
						|
 | 
						|
	joinTable := t.dialect.Tables().InProgress(field.IndirectType)
 | 
						|
	rel := &Relation{
 | 
						|
		Type:      HasOneRelation,
 | 
						|
		Field:     field,
 | 
						|
		JoinTable: joinTable,
 | 
						|
	}
 | 
						|
 | 
						|
	if field.Tag.HasOption("join_on") {
 | 
						|
		rel.Condition = field.Tag.Options["join_on"]
 | 
						|
	}
 | 
						|
 | 
						|
	if join, ok := field.Tag.Options["join"]; ok {
 | 
						|
		baseColumns, joinColumns := parseRelationJoin(join)
 | 
						|
		for i, baseColumn := range baseColumns {
 | 
						|
			if f := t.FieldMap[baseColumn]; f != nil {
 | 
						|
				rel.BaseFields = append(rel.BaseFields, f)
 | 
						|
			} else {
 | 
						|
				panic(fmt.Errorf(
 | 
						|
					"bun: %s has-one %s: %s must have column %s",
 | 
						|
					field.GoName, t.TypeName, t.TypeName, baseColumn,
 | 
						|
				))
 | 
						|
			}
 | 
						|
 | 
						|
			joinColumn := joinColumns[i]
 | 
						|
			if f := joinTable.FieldMap[joinColumn]; f != nil {
 | 
						|
				rel.JoinFields = append(rel.JoinFields, f)
 | 
						|
			} else {
 | 
						|
				panic(fmt.Errorf(
 | 
						|
					"bun: %s has-one %s: %s must have column %s",
 | 
						|
					field.GoName, t.TypeName, joinTable.TypeName, joinColumn,
 | 
						|
				))
 | 
						|
			}
 | 
						|
		}
 | 
						|
		return rel
 | 
						|
	}
 | 
						|
 | 
						|
	rel.BaseFields = t.PKs
 | 
						|
	fkPrefix := internal.Underscore(t.ModelName) + "_"
 | 
						|
	for _, pk := range t.PKs {
 | 
						|
		fkName := fkPrefix + pk.Name
 | 
						|
		if f := joinTable.FieldMap[fkName]; f != nil {
 | 
						|
			rel.JoinFields = append(rel.JoinFields, f)
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		if f := joinTable.FieldMap[pk.Name]; f != nil {
 | 
						|
			rel.JoinFields = append(rel.JoinFields, f)
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		panic(fmt.Errorf(
 | 
						|
			"bun: %s has-one %s: %s must have column %s "+
 | 
						|
				"(to override, use join:base_column=join_column tag on %s field)",
 | 
						|
			field.GoName, t.TypeName, joinTable.TypeName, fkName, field.GoName,
 | 
						|
		))
 | 
						|
	}
 | 
						|
	return rel
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) hasManyRelation(field *Field) *Relation {
 | 
						|
	if err := t.CheckPKs(); err != nil {
 | 
						|
		panic(err)
 | 
						|
	}
 | 
						|
	if field.IndirectType.Kind() != reflect.Slice {
 | 
						|
		panic(fmt.Errorf(
 | 
						|
			"bun: %s.%s has-many relation requires slice, got %q",
 | 
						|
			t.TypeName, field.GoName, field.IndirectType.Kind(),
 | 
						|
		))
 | 
						|
	}
 | 
						|
 | 
						|
	joinTable := t.dialect.Tables().InProgress(indirectType(field.IndirectType.Elem()))
 | 
						|
	polymorphicValue, isPolymorphic := field.Tag.Option("polymorphic")
 | 
						|
	rel := &Relation{
 | 
						|
		Type:      HasManyRelation,
 | 
						|
		Field:     field,
 | 
						|
		JoinTable: joinTable,
 | 
						|
	}
 | 
						|
 | 
						|
	if field.Tag.HasOption("join_on") {
 | 
						|
		rel.Condition = field.Tag.Options["join_on"]
 | 
						|
	}
 | 
						|
 | 
						|
	var polymorphicColumn string
 | 
						|
 | 
						|
	if join, ok := field.Tag.Options["join"]; ok {
 | 
						|
		baseColumns, joinColumns := parseRelationJoin(join)
 | 
						|
		for i, baseColumn := range baseColumns {
 | 
						|
			joinColumn := joinColumns[i]
 | 
						|
 | 
						|
			if isPolymorphic && baseColumn == "type" {
 | 
						|
				polymorphicColumn = joinColumn
 | 
						|
				continue
 | 
						|
			}
 | 
						|
 | 
						|
			if f := t.FieldMap[baseColumn]; f != nil {
 | 
						|
				rel.BaseFields = append(rel.BaseFields, f)
 | 
						|
			} else {
 | 
						|
				panic(fmt.Errorf(
 | 
						|
					"bun: %s has-many %s: %s must have column %s",
 | 
						|
					t.TypeName, field.GoName, t.TypeName, baseColumn,
 | 
						|
				))
 | 
						|
			}
 | 
						|
 | 
						|
			if f := joinTable.FieldMap[joinColumn]; f != nil {
 | 
						|
				rel.JoinFields = append(rel.JoinFields, f)
 | 
						|
			} else {
 | 
						|
				panic(fmt.Errorf(
 | 
						|
					"bun: %s has-many %s: %s must have column %s",
 | 
						|
					t.TypeName, field.GoName, joinTable.TypeName, joinColumn,
 | 
						|
				))
 | 
						|
			}
 | 
						|
		}
 | 
						|
	} else {
 | 
						|
		rel.BaseFields = t.PKs
 | 
						|
		fkPrefix := internal.Underscore(t.ModelName) + "_"
 | 
						|
		if isPolymorphic {
 | 
						|
			polymorphicColumn = fkPrefix + "type"
 | 
						|
		}
 | 
						|
 | 
						|
		for _, pk := range t.PKs {
 | 
						|
			joinColumn := fkPrefix + pk.Name
 | 
						|
			if fk := joinTable.FieldMap[joinColumn]; fk != nil {
 | 
						|
				rel.JoinFields = append(rel.JoinFields, fk)
 | 
						|
				continue
 | 
						|
			}
 | 
						|
 | 
						|
			if fk := joinTable.FieldMap[pk.Name]; fk != nil {
 | 
						|
				rel.JoinFields = append(rel.JoinFields, fk)
 | 
						|
				continue
 | 
						|
			}
 | 
						|
 | 
						|
			panic(fmt.Errorf(
 | 
						|
				"bun: %s has-many %s: %s must have column %s "+
 | 
						|
					"(to override, use join:base_column=join_column tag on the field %s)",
 | 
						|
				t.TypeName, field.GoName, joinTable.TypeName, joinColumn, field.GoName,
 | 
						|
			))
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	if isPolymorphic {
 | 
						|
		rel.PolymorphicField = joinTable.FieldMap[polymorphicColumn]
 | 
						|
		if rel.PolymorphicField == nil {
 | 
						|
			panic(fmt.Errorf(
 | 
						|
				"bun: %s has-many %s: %s must have polymorphic column %s",
 | 
						|
				t.TypeName, field.GoName, joinTable.TypeName, polymorphicColumn,
 | 
						|
			))
 | 
						|
		}
 | 
						|
 | 
						|
		if polymorphicValue == "" {
 | 
						|
			polymorphicValue = t.ModelName
 | 
						|
		}
 | 
						|
		rel.PolymorphicValue = polymorphicValue
 | 
						|
	}
 | 
						|
 | 
						|
	return rel
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) m2mRelation(field *Field) *Relation {
 | 
						|
	if field.IndirectType.Kind() != reflect.Slice {
 | 
						|
		panic(fmt.Errorf(
 | 
						|
			"bun: %s.%s m2m relation requires slice, got %q",
 | 
						|
			t.TypeName, field.GoName, field.IndirectType.Kind(),
 | 
						|
		))
 | 
						|
	}
 | 
						|
	joinTable := t.dialect.Tables().InProgress(indirectType(field.IndirectType.Elem()))
 | 
						|
 | 
						|
	if err := t.CheckPKs(); err != nil {
 | 
						|
		panic(err)
 | 
						|
	}
 | 
						|
	if err := joinTable.CheckPKs(); err != nil {
 | 
						|
		panic(err)
 | 
						|
	}
 | 
						|
 | 
						|
	m2mTableName, ok := field.Tag.Option("m2m")
 | 
						|
	if !ok {
 | 
						|
		panic(fmt.Errorf("bun: %s must have m2m tag option", field.GoName))
 | 
						|
	}
 | 
						|
 | 
						|
	m2mTable := t.dialect.Tables().ByName(m2mTableName)
 | 
						|
	if m2mTable == nil {
 | 
						|
		panic(fmt.Errorf(
 | 
						|
			"bun: can't find m2m %s table (use db.RegisterModel)",
 | 
						|
			m2mTableName,
 | 
						|
		))
 | 
						|
	}
 | 
						|
 | 
						|
	rel := &Relation{
 | 
						|
		Type:      ManyToManyRelation,
 | 
						|
		Field:     field,
 | 
						|
		JoinTable: joinTable,
 | 
						|
		M2MTable:  m2mTable,
 | 
						|
	}
 | 
						|
 | 
						|
	if field.Tag.HasOption("join_on") {
 | 
						|
		rel.Condition = field.Tag.Options["join_on"]
 | 
						|
	}
 | 
						|
 | 
						|
	var leftColumn, rightColumn string
 | 
						|
 | 
						|
	if join, ok := field.Tag.Options["join"]; ok {
 | 
						|
		left, right := parseRelationJoin(join)
 | 
						|
		leftColumn = left[0]
 | 
						|
		rightColumn = right[0]
 | 
						|
	} else {
 | 
						|
		leftColumn = t.TypeName
 | 
						|
		rightColumn = joinTable.TypeName
 | 
						|
	}
 | 
						|
 | 
						|
	leftField := m2mTable.fieldByGoName(leftColumn)
 | 
						|
	if leftField == nil {
 | 
						|
		panic(fmt.Errorf(
 | 
						|
			"bun: %s many-to-many %s: %s must have field %s "+
 | 
						|
				"(to override, use tag join:LeftField=RightField on field %s.%s",
 | 
						|
			t.TypeName, field.GoName, m2mTable.TypeName, leftColumn, t.TypeName, field.GoName,
 | 
						|
		))
 | 
						|
	}
 | 
						|
 | 
						|
	rightField := m2mTable.fieldByGoName(rightColumn)
 | 
						|
	if rightField == nil {
 | 
						|
		panic(fmt.Errorf(
 | 
						|
			"bun: %s many-to-many %s: %s must have field %s "+
 | 
						|
				"(to override, use tag join:LeftField=RightField on field %s.%s",
 | 
						|
			t.TypeName, field.GoName, m2mTable.TypeName, rightColumn, t.TypeName, field.GoName,
 | 
						|
		))
 | 
						|
	}
 | 
						|
 | 
						|
	leftRel := m2mTable.belongsToRelation(leftField)
 | 
						|
	rel.BaseFields = leftRel.JoinFields
 | 
						|
	rel.M2MBaseFields = leftRel.BaseFields
 | 
						|
 | 
						|
	rightRel := m2mTable.belongsToRelation(rightField)
 | 
						|
	rel.JoinFields = rightRel.JoinFields
 | 
						|
	rel.M2MJoinFields = rightRel.BaseFields
 | 
						|
 | 
						|
	return rel
 | 
						|
}
 | 
						|
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
 | 
						|
func (t *Table) Dialect() Dialect { return t.dialect }
 | 
						|
 | 
						|
func (t *Table) HasBeforeAppendModelHook() bool { return t.flags.Has(beforeAppendModelHookFlag) }
 | 
						|
 | 
						|
// DEPRECATED. Use HasBeforeScanRowHook.
 | 
						|
func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) }
 | 
						|
 | 
						|
// DEPRECATED. Use HasAfterScanRowHook.
 | 
						|
func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) }
 | 
						|
 | 
						|
func (t *Table) HasBeforeScanRowHook() bool { return t.flags.Has(beforeScanRowHookFlag) }
 | 
						|
func (t *Table) HasAfterScanRowHook() bool  { return t.flags.Has(afterScanRowHookFlag) }
 | 
						|
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
 | 
						|
func (t *Table) AppendNamedArg(
 | 
						|
	fmter Formatter, b []byte, name string, strct reflect.Value,
 | 
						|
) ([]byte, bool) {
 | 
						|
	if field, ok := t.FieldMap[name]; ok {
 | 
						|
		return field.AppendValue(fmter, b, strct), true
 | 
						|
	}
 | 
						|
	return b, false
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) quoteTableName(s string) Safe {
 | 
						|
	// Don't quote if table name contains placeholder (?) or parentheses.
 | 
						|
	if strings.IndexByte(s, '?') >= 0 ||
 | 
						|
		strings.IndexByte(s, '(') >= 0 ||
 | 
						|
		strings.IndexByte(s, ')') >= 0 {
 | 
						|
		return Safe(s)
 | 
						|
	}
 | 
						|
	return t.quoteIdent(s)
 | 
						|
}
 | 
						|
 | 
						|
func (t *Table) quoteIdent(s string) Safe {
 | 
						|
	return Safe(NewFormatter(t.dialect).AppendIdent(nil, s))
 | 
						|
}
 | 
						|
 | 
						|
func isKnownTableOption(name string) bool {
 | 
						|
	switch name {
 | 
						|
	case "table", "alias", "select":
 | 
						|
		return true
 | 
						|
	}
 | 
						|
	return false
 | 
						|
}
 | 
						|
 | 
						|
func isKnownFieldOption(name string) bool {
 | 
						|
	switch name {
 | 
						|
	case "column",
 | 
						|
		"alt",
 | 
						|
		"type",
 | 
						|
		"array",
 | 
						|
		"hstore",
 | 
						|
		"composite",
 | 
						|
		"json_use_number",
 | 
						|
		"msgpack",
 | 
						|
		"notnull",
 | 
						|
		"nullzero",
 | 
						|
		"default",
 | 
						|
		"unique",
 | 
						|
		"soft_delete",
 | 
						|
		"scanonly",
 | 
						|
		"skipupdate",
 | 
						|
 | 
						|
		"pk",
 | 
						|
		"autoincrement",
 | 
						|
		"rel",
 | 
						|
		"join",
 | 
						|
		"join_on",
 | 
						|
		"on_update",
 | 
						|
		"on_delete",
 | 
						|
		"m2m",
 | 
						|
		"polymorphic",
 | 
						|
		"identity":
 | 
						|
		return true
 | 
						|
	}
 | 
						|
	return false
 | 
						|
}
 | 
						|
 | 
						|
func isKnownFKRule(name string) bool {
 | 
						|
	switch name {
 | 
						|
	case "CASCADE",
 | 
						|
		"RESTRICT",
 | 
						|
		"SET NULL",
 | 
						|
		"SET DEFAULT":
 | 
						|
		return true
 | 
						|
	}
 | 
						|
	return false
 | 
						|
}
 | 
						|
 | 
						|
func parseRelationJoin(join []string) ([]string, []string) {
 | 
						|
	var ss []string
 | 
						|
	if len(join) == 1 {
 | 
						|
		ss = strings.Split(join[0], ",")
 | 
						|
	} else {
 | 
						|
		ss = join
 | 
						|
	}
 | 
						|
 | 
						|
	baseColumns := make([]string, len(ss))
 | 
						|
	joinColumns := make([]string, len(ss))
 | 
						|
	for i, s := range ss {
 | 
						|
		ss := strings.Split(strings.TrimSpace(s), "=")
 | 
						|
		if len(ss) != 2 {
 | 
						|
			panic(fmt.Errorf("can't parse relation join: %q", join))
 | 
						|
		}
 | 
						|
		baseColumns[i] = ss[0]
 | 
						|
		joinColumns[i] = ss[1]
 | 
						|
	}
 | 
						|
	return baseColumns, joinColumns
 | 
						|
}
 | 
						|
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
 | 
						|
func softDeleteFieldUpdater(field *Field) func(fv reflect.Value, tm time.Time) error {
 | 
						|
	typ := field.StructField.Type
 | 
						|
 | 
						|
	switch typ {
 | 
						|
	case timeType:
 | 
						|
		return func(fv reflect.Value, tm time.Time) error {
 | 
						|
			ptr := fv.Addr().Interface().(*time.Time)
 | 
						|
			*ptr = tm
 | 
						|
			return nil
 | 
						|
		}
 | 
						|
	case nullTimeType:
 | 
						|
		return func(fv reflect.Value, tm time.Time) error {
 | 
						|
			ptr := fv.Addr().Interface().(*sql.NullTime)
 | 
						|
			*ptr = sql.NullTime{Time: tm}
 | 
						|
			return nil
 | 
						|
		}
 | 
						|
	case nullIntType:
 | 
						|
		return func(fv reflect.Value, tm time.Time) error {
 | 
						|
			ptr := fv.Addr().Interface().(*sql.NullInt64)
 | 
						|
			*ptr = sql.NullInt64{Int64: tm.UnixNano()}
 | 
						|
			return nil
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	switch field.IndirectType.Kind() {
 | 
						|
	case reflect.Int64:
 | 
						|
		return func(fv reflect.Value, tm time.Time) error {
 | 
						|
			ptr := fv.Addr().Interface().(*int64)
 | 
						|
			*ptr = tm.UnixNano()
 | 
						|
			return nil
 | 
						|
		}
 | 
						|
	case reflect.Ptr:
 | 
						|
		typ = typ.Elem()
 | 
						|
	default:
 | 
						|
		return softDeleteFieldUpdaterFallback(field)
 | 
						|
	}
 | 
						|
 | 
						|
	switch typ { //nolint:gocritic
 | 
						|
	case timeType:
 | 
						|
		return func(fv reflect.Value, tm time.Time) error {
 | 
						|
			fv.Set(reflect.ValueOf(&tm))
 | 
						|
			return nil
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	switch typ.Kind() { //nolint:gocritic
 | 
						|
	case reflect.Int64:
 | 
						|
		return func(fv reflect.Value, tm time.Time) error {
 | 
						|
			utime := tm.UnixNano()
 | 
						|
			fv.Set(reflect.ValueOf(&utime))
 | 
						|
			return nil
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return softDeleteFieldUpdaterFallback(field)
 | 
						|
}
 | 
						|
 | 
						|
func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value, tm time.Time) error {
 | 
						|
	return func(fv reflect.Value, tm time.Time) error {
 | 
						|
		return field.ScanWithCheck(fv, tm)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func makeIndex(a, b []int) []int {
 | 
						|
	dest := make([]int, 0, len(a)+len(b))
 | 
						|
	dest = append(dest, a...)
 | 
						|
	dest = append(dest, b...)
 | 
						|
	return dest
 | 
						|
}
 |