mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 14:32:24 -05:00 
			
		
		
		
	
		
			
				
	
	
		
			937 lines
		
	
	
	
		
			22 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			937 lines
		
	
	
	
		
			22 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package schema
 | |
| 
 | |
| import (
 | |
| 	"database/sql"
 | |
| 	"fmt"
 | |
| 	"reflect"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/jinzhu/inflection"
 | |
| 
 | |
| 	"github.com/uptrace/bun/internal"
 | |
| 	"github.com/uptrace/bun/internal/tagparser"
 | |
| )
 | |
| 
 | |
| const (
 | |
| 	beforeScanHookFlag internal.Flag = 1 << iota
 | |
| 	afterScanHookFlag
 | |
| )
 | |
| 
 | |
| var (
 | |
| 	baseModelType      = reflect.TypeOf((*BaseModel)(nil)).Elem()
 | |
| 	tableNameInflector = inflection.Plural
 | |
| )
 | |
| 
 | |
| type BaseModel struct{}
 | |
| 
 | |
| // SetTableNameInflector overrides the default func that pluralizes
 | |
| // model name to get table name, e.g. my_article becomes my_articles.
 | |
| func SetTableNameInflector(fn func(string) string) {
 | |
| 	tableNameInflector = fn
 | |
| }
 | |
| 
 | |
// Table represents a SQL table created from Go struct.
//
// A Table is built in stages: newTable derives names and scan-hook flags from
// the struct type, init1 collects columns (initFields) and init2 resolves
// relations (initRelations).
type Table struct {
	// dialect is used to quote identifiers and to look up related tables.
	dialect Dialect

	// Type is the model struct type; ZeroValue/ZeroIface are freshly
	// allocated zero values of it.
	Type      reflect.Type
	ZeroValue reflect.Value // reflect.Struct
	ZeroIface interface{}   // struct pointer

	// TypeName is the exported Go type name and ModelName the underscored
	// type name; both can be taken over from an embedded struct tagged
	// `inherit`.
	TypeName  string
	ModelName string

	// Name is the unquoted table name (the inflected ModelName unless the
	// BaseModel tag names the table) and SQLName its quoted form.
	// SQLNameForSelects defaults to SQLName and can be overridden with the
	// BaseModel "select" option; Alias/SQLAlias default to ModelName and can
	// be overridden with "alias".
	Name              string
	SQLName           Safe
	SQLNameForSelects Safe
	Alias             string
	SQLAlias          Safe

	// Column fields in declaration order; relation fields are removed from
	// Fields/DataFields by initRelations.
	Fields     []*Field // PKs + DataFields
	PKs        []*Field
	DataFields []*Field

	// FieldMap maps column names to fields. Besides regular columns it also
	// holds `alt` names, scanonly fields and inlined "<parent>__<child>"
	// entries. fieldsMapMu guards it for fieldWithLock readers and
	// inlineFields writers. NOTE(review): other accesses (addField, HasField,
	// Field, AppendNamedArg) do not take the lock.
	fieldsMapMu sync.RWMutex
	FieldMap    map[string]*Field

	// Relations is keyed by the relation field's Go name; Unique groups
	// fields by the value of their `unique` option ("" for unnamed).
	Relations map[string]*Relation
	Unique    map[string][]*Field

	// SoftDeleteField is the field tagged `soft_delete`, and
	// UpdateSoftDeleteField stores a deletion time into such a field value.
	SoftDeleteField       *Field
	UpdateSoftDeleteField func(fv reflect.Value, tm time.Time) error

	// allFields lists every field created by newField, including scanonly
	// fields; it is the source for inlining this table into others.
	allFields []*Field // read only

	// flags records which scan hooks the model implements.
	flags internal.Flag
}
 | |
| 
 | |
| func newTable(dialect Dialect, typ reflect.Type) *Table {
 | |
| 	t := new(Table)
 | |
| 	t.dialect = dialect
 | |
| 	t.Type = typ
 | |
| 	t.ZeroValue = reflect.New(t.Type).Elem()
 | |
| 	t.ZeroIface = reflect.New(t.Type).Interface()
 | |
| 	t.TypeName = internal.ToExported(t.Type.Name())
 | |
| 	t.ModelName = internal.Underscore(t.Type.Name())
 | |
| 	tableName := tableNameInflector(t.ModelName)
 | |
| 	t.setName(tableName)
 | |
| 	t.Alias = t.ModelName
 | |
| 	t.SQLAlias = t.quoteIdent(t.ModelName)
 | |
| 
 | |
| 	hooks := []struct {
 | |
| 		typ  reflect.Type
 | |
| 		flag internal.Flag
 | |
| 	}{
 | |
| 		{beforeScanHookType, beforeScanHookFlag},
 | |
| 		{afterScanHookType, afterScanHookFlag},
 | |
| 	}
 | |
| 
 | |
| 	typ = reflect.PtrTo(t.Type)
 | |
| 	for _, hook := range hooks {
 | |
| 		if typ.Implements(hook.typ) {
 | |
| 			t.flags = t.flags.Set(hook.flag)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return t
 | |
| }
 | |
| 
 | |
| func (t *Table) init1() {
 | |
| 	t.initFields()
 | |
| }
 | |
| 
 | |
| func (t *Table) init2() {
 | |
| 	t.initRelations()
 | |
| }
 | |
| 
 | |
| func (t *Table) setName(name string) {
 | |
| 	t.Name = name
 | |
| 	t.SQLName = t.quoteIdent(name)
 | |
| 	t.SQLNameForSelects = t.quoteIdent(name)
 | |
| 	if t.SQLAlias == "" {
 | |
| 		t.Alias = name
 | |
| 		t.SQLAlias = t.quoteIdent(name)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (t *Table) String() string {
 | |
| 	return "model=" + t.TypeName
 | |
| }
 | |
| 
 | |
| func (t *Table) CheckPKs() error {
 | |
| 	if len(t.PKs) == 0 {
 | |
| 		return fmt.Errorf("bun: %s does not have primary keys", t)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (t *Table) addField(field *Field) {
 | |
| 	t.Fields = append(t.Fields, field)
 | |
| 	if field.IsPK {
 | |
| 		t.PKs = append(t.PKs, field)
 | |
| 	} else {
 | |
| 		t.DataFields = append(t.DataFields, field)
 | |
| 	}
 | |
| 	t.FieldMap[field.Name] = field
 | |
| }
 | |
| 
 | |
| func (t *Table) removeField(field *Field) {
 | |
| 	t.Fields = removeField(t.Fields, field)
 | |
| 	if field.IsPK {
 | |
| 		t.PKs = removeField(t.PKs, field)
 | |
| 	} else {
 | |
| 		t.DataFields = removeField(t.DataFields, field)
 | |
| 	}
 | |
| 	delete(t.FieldMap, field.Name)
 | |
| }
 | |
| 
 | |
| func (t *Table) fieldWithLock(name string) *Field {
 | |
| 	t.fieldsMapMu.RLock()
 | |
| 	field := t.FieldMap[name]
 | |
| 	t.fieldsMapMu.RUnlock()
 | |
| 	return field
 | |
| }
 | |
| 
 | |
| func (t *Table) HasField(name string) bool {
 | |
| 	_, ok := t.FieldMap[name]
 | |
| 	return ok
 | |
| }
 | |
| 
 | |
| func (t *Table) Field(name string) (*Field, error) {
 | |
| 	field, ok := t.FieldMap[name]
 | |
| 	if !ok {
 | |
| 		return nil, fmt.Errorf("bun: %s does not have column=%s", t, name)
 | |
| 	}
 | |
| 	return field, nil
 | |
| }
 | |
| 
 | |
| func (t *Table) fieldByGoName(name string) *Field {
 | |
| 	for _, f := range t.allFields {
 | |
| 		if f.GoName == name {
 | |
| 			return f
 | |
| 		}
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (t *Table) initFields() {
 | |
| 	t.Fields = make([]*Field, 0, t.Type.NumField())
 | |
| 	t.FieldMap = make(map[string]*Field, t.Type.NumField())
 | |
| 	t.addFields(t.Type, nil)
 | |
| 
 | |
| 	if len(t.PKs) == 0 {
 | |
| 		for _, name := range []string{"id", "uuid", "pk_" + t.ModelName} {
 | |
| 			if field, ok := t.FieldMap[name]; ok {
 | |
| 				field.markAsPK()
 | |
| 				t.PKs = []*Field{field}
 | |
| 				t.DataFields = removeField(t.DataFields, field)
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if len(t.PKs) == 1 {
 | |
| 		pk := t.PKs[0]
 | |
| 		if pk.SQLDefault != "" {
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		switch pk.IndirectType.Kind() {
 | |
| 		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
 | |
| 			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
 | |
| 			pk.AutoIncrement = true
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (t *Table) addFields(typ reflect.Type, baseIndex []int) {
 | |
| 	for i := 0; i < typ.NumField(); i++ {
 | |
| 		f := typ.Field(i)
 | |
| 		unexported := f.PkgPath != ""
 | |
| 
 | |
| 		if unexported && !f.Anonymous { // unexported
 | |
| 			continue
 | |
| 		}
 | |
| 		if f.Tag.Get("bun") == "-" {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		// Make a copy so the slice is not shared between fields.
 | |
| 		index := make([]int, len(baseIndex))
 | |
| 		copy(index, baseIndex)
 | |
| 
 | |
| 		if f.Anonymous {
 | |
| 			if f.Name == "BaseModel" && f.Type == baseModelType {
 | |
| 				if len(index) == 0 {
 | |
| 					t.processBaseModelField(f)
 | |
| 				}
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			fieldType := indirectType(f.Type)
 | |
| 			if fieldType.Kind() != reflect.Struct {
 | |
| 				continue
 | |
| 			}
 | |
| 			t.addFields(fieldType, append(index, f.Index...))
 | |
| 
 | |
| 			tag := tagparser.Parse(f.Tag.Get("bun"))
 | |
| 			if _, inherit := tag.Options["inherit"]; inherit {
 | |
| 				embeddedTable := t.dialect.Tables().Ref(fieldType)
 | |
| 				t.TypeName = embeddedTable.TypeName
 | |
| 				t.SQLName = embeddedTable.SQLName
 | |
| 				t.SQLNameForSelects = embeddedTable.SQLNameForSelects
 | |
| 				t.Alias = embeddedTable.Alias
 | |
| 				t.SQLAlias = embeddedTable.SQLAlias
 | |
| 				t.ModelName = embeddedTable.ModelName
 | |
| 			}
 | |
| 
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		if field := t.newField(f, index); field != nil {
 | |
| 			t.addField(field)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (t *Table) processBaseModelField(f reflect.StructField) {
 | |
| 	tag := tagparser.Parse(f.Tag.Get("bun"))
 | |
| 
 | |
| 	if isKnownTableOption(tag.Name) {
 | |
| 		internal.Warn.Printf(
 | |
| 			"%s.%s tag name %q is also an option name; is it a mistake?",
 | |
| 			t.TypeName, f.Name, tag.Name,
 | |
| 		)
 | |
| 	}
 | |
| 
 | |
| 	for name := range tag.Options {
 | |
| 		if !isKnownTableOption(name) {
 | |
| 			internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if tag.Name != "" {
 | |
| 		t.setName(tag.Name)
 | |
| 	}
 | |
| 
 | |
| 	if s, ok := tag.Options["select"]; ok {
 | |
| 		t.SQLNameForSelects = t.quoteTableName(s)
 | |
| 	}
 | |
| 
 | |
| 	if s, ok := tag.Options["alias"]; ok {
 | |
| 		t.Alias = s
 | |
| 		t.SQLAlias = t.quoteIdent(s)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| //nolint
 | |
| func (t *Table) newField(f reflect.StructField, index []int) *Field {
 | |
| 	tag := tagparser.Parse(f.Tag.Get("bun"))
 | |
| 
 | |
| 	sqlName := internal.Underscore(f.Name)
 | |
| 	if tag.Name != "" {
 | |
| 		sqlName = tag.Name
 | |
| 	}
 | |
| 
 | |
| 	if tag.Name != sqlName && isKnownFieldOption(tag.Name) {
 | |
| 		internal.Warn.Printf(
 | |
| 			"%s.%s tag name %q is also an option name; is it a mistake?",
 | |
| 			t.TypeName, f.Name, tag.Name,
 | |
| 		)
 | |
| 	}
 | |
| 
 | |
| 	for name := range tag.Options {
 | |
| 		if !isKnownFieldOption(name) {
 | |
| 			internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	index = append(index, f.Index...)
 | |
| 	if field := t.fieldWithLock(sqlName); field != nil {
 | |
| 		if indexEqual(field.Index, index) {
 | |
| 			return field
 | |
| 		}
 | |
| 		t.removeField(field)
 | |
| 	}
 | |
| 
 | |
| 	field := &Field{
 | |
| 		StructField: f,
 | |
| 
 | |
| 		Tag:          tag,
 | |
| 		IndirectType: indirectType(f.Type),
 | |
| 		Index:        index,
 | |
| 
 | |
| 		Name:    sqlName,
 | |
| 		GoName:  f.Name,
 | |
| 		SQLName: t.quoteIdent(sqlName),
 | |
| 	}
 | |
| 
 | |
| 	field.NotNull = tag.HasOption("notnull")
 | |
| 	field.NullZero = tag.HasOption("nullzero")
 | |
| 	field.AutoIncrement = tag.HasOption("autoincrement")
 | |
| 	if tag.HasOption("pk") {
 | |
| 		field.markAsPK()
 | |
| 	}
 | |
| 	if tag.HasOption("allowzero") {
 | |
| 		if tag.HasOption("nullzero") {
 | |
| 			internal.Warn.Printf(
 | |
| 				"%s.%s: nullzero and allowzero options are mutually exclusive",
 | |
| 				t.TypeName, f.Name,
 | |
| 			)
 | |
| 		}
 | |
| 		field.NullZero = false
 | |
| 	}
 | |
| 
 | |
| 	if v, ok := tag.Options["unique"]; ok {
 | |
| 		// Split the value by comma, this will allow multiple names to be specified.
 | |
| 		// We can use this to create multiple named unique constraints where a single column
 | |
| 		// might be included in multiple constraints.
 | |
| 		for _, uniqueName := range strings.Split(v, ",") {
 | |
| 			if t.Unique == nil {
 | |
| 				t.Unique = make(map[string][]*Field)
 | |
| 			}
 | |
| 			t.Unique[uniqueName] = append(t.Unique[uniqueName], field)
 | |
| 		}
 | |
| 	}
 | |
| 	if s, ok := tag.Options["default"]; ok {
 | |
| 		field.SQLDefault = s
 | |
| 	}
 | |
| 	if s, ok := field.Tag.Options["type"]; ok {
 | |
| 		field.UserSQLType = s
 | |
| 	}
 | |
| 	field.DiscoveredSQLType = DiscoverSQLType(field.IndirectType)
 | |
| 	field.Append = t.dialect.FieldAppender(field)
 | |
| 	field.Scan = FieldScanner(t.dialect, field)
 | |
| 	field.IsZero = zeroChecker(field.StructField.Type)
 | |
| 
 | |
| 	if v, ok := tag.Options["alt"]; ok {
 | |
| 		t.FieldMap[v] = field
 | |
| 	}
 | |
| 
 | |
| 	t.allFields = append(t.allFields, field)
 | |
| 	if tag.HasOption("scanonly") {
 | |
| 		t.FieldMap[field.Name] = field
 | |
| 		if field.IndirectType.Kind() == reflect.Struct {
 | |
| 			t.inlineFields(field, nil)
 | |
| 		}
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	if _, ok := tag.Options["soft_delete"]; ok {
 | |
| 		field.NullZero = true
 | |
| 		t.SoftDeleteField = field
 | |
| 		t.UpdateSoftDeleteField = softDeleteFieldUpdater(field)
 | |
| 	}
 | |
| 
 | |
| 	return field
 | |
| }
 | |
| 
 | |
| //---------------------------------------------------------------------------------------
 | |
| 
 | |
| func (t *Table) initRelations() {
 | |
| 	for i := 0; i < len(t.Fields); {
 | |
| 		f := t.Fields[i]
 | |
| 		if t.tryRelation(f) {
 | |
| 			t.Fields = removeField(t.Fields, f)
 | |
| 			t.DataFields = removeField(t.DataFields, f)
 | |
| 		} else {
 | |
| 			i++
 | |
| 		}
 | |
| 
 | |
| 		if f.IndirectType.Kind() == reflect.Struct {
 | |
| 			t.inlineFields(f, nil)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (t *Table) tryRelation(field *Field) bool {
 | |
| 	if rel, ok := field.Tag.Options["rel"]; ok {
 | |
| 		t.initRelation(field, rel)
 | |
| 		return true
 | |
| 	}
 | |
| 	if field.Tag.HasOption("m2m") {
 | |
| 		t.addRelation(t.m2mRelation(field))
 | |
| 		return true
 | |
| 	}
 | |
| 
 | |
| 	if field.Tag.HasOption("join") {
 | |
| 		internal.Warn.Printf(
 | |
| 			`%s.%s option "join" requires a relation type`,
 | |
| 			t.TypeName, field.GoName,
 | |
| 		)
 | |
| 	}
 | |
| 
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| func (t *Table) initRelation(field *Field, rel string) {
 | |
| 	switch rel {
 | |
| 	case "belongs-to":
 | |
| 		t.addRelation(t.belongsToRelation(field))
 | |
| 	case "has-one":
 | |
| 		t.addRelation(t.hasOneRelation(field))
 | |
| 	case "has-many":
 | |
| 		t.addRelation(t.hasManyRelation(field))
 | |
| 	default:
 | |
| 		panic(fmt.Errorf("bun: unknown relation=%s on field=%s", rel, field.GoName))
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (t *Table) addRelation(rel *Relation) {
 | |
| 	if t.Relations == nil {
 | |
| 		t.Relations = make(map[string]*Relation)
 | |
| 	}
 | |
| 	_, ok := t.Relations[rel.Field.GoName]
 | |
| 	if ok {
 | |
| 		panic(fmt.Errorf("%s already has %s", t, rel))
 | |
| 	}
 | |
| 	t.Relations[rel.Field.GoName] = rel
 | |
| }
 | |
| 
 | |
| func (t *Table) belongsToRelation(field *Field) *Relation {
 | |
| 	joinTable := t.dialect.Tables().Ref(field.IndirectType)
 | |
| 	if err := joinTable.CheckPKs(); err != nil {
 | |
| 		panic(err)
 | |
| 	}
 | |
| 
 | |
| 	rel := &Relation{
 | |
| 		Type:      HasOneRelation,
 | |
| 		Field:     field,
 | |
| 		JoinTable: joinTable,
 | |
| 	}
 | |
| 
 | |
| 	if join, ok := field.Tag.Options["join"]; ok {
 | |
| 		baseColumns, joinColumns := parseRelationJoin(join)
 | |
| 		for i, baseColumn := range baseColumns {
 | |
| 			joinColumn := joinColumns[i]
 | |
| 
 | |
| 			if f := t.fieldWithLock(baseColumn); f != nil {
 | |
| 				rel.BaseFields = append(rel.BaseFields, f)
 | |
| 			} else {
 | |
| 				panic(fmt.Errorf(
 | |
| 					"bun: %s belongs-to %s: %s must have column %s",
 | |
| 					t.TypeName, field.GoName, t.TypeName, baseColumn,
 | |
| 				))
 | |
| 			}
 | |
| 
 | |
| 			if f := joinTable.fieldWithLock(joinColumn); f != nil {
 | |
| 				rel.JoinFields = append(rel.JoinFields, f)
 | |
| 			} else {
 | |
| 				panic(fmt.Errorf(
 | |
| 					"bun: %s belongs-to %s: %s must have column %s",
 | |
| 					t.TypeName, field.GoName, t.TypeName, baseColumn,
 | |
| 				))
 | |
| 			}
 | |
| 		}
 | |
| 		return rel
 | |
| 	}
 | |
| 
 | |
| 	rel.JoinFields = joinTable.PKs
 | |
| 	fkPrefix := internal.Underscore(field.GoName) + "_"
 | |
| 	for _, joinPK := range joinTable.PKs {
 | |
| 		fkName := fkPrefix + joinPK.Name
 | |
| 		if fk := t.fieldWithLock(fkName); fk != nil {
 | |
| 			rel.BaseFields = append(rel.BaseFields, fk)
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		if fk := t.fieldWithLock(joinPK.Name); fk != nil {
 | |
| 			rel.BaseFields = append(rel.BaseFields, fk)
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		panic(fmt.Errorf(
 | |
| 			"bun: %s belongs-to %s: %s must have column %s "+
 | |
| 				"(to override, use join:base_column=join_column tag on %s field)",
 | |
| 			t.TypeName, field.GoName, t.TypeName, fkName, field.GoName,
 | |
| 		))
 | |
| 	}
 | |
| 	return rel
 | |
| }
 | |
| 
 | |
| func (t *Table) hasOneRelation(field *Field) *Relation {
 | |
| 	if err := t.CheckPKs(); err != nil {
 | |
| 		panic(err)
 | |
| 	}
 | |
| 
 | |
| 	joinTable := t.dialect.Tables().Ref(field.IndirectType)
 | |
| 	rel := &Relation{
 | |
| 		Type:      BelongsToRelation,
 | |
| 		Field:     field,
 | |
| 		JoinTable: joinTable,
 | |
| 	}
 | |
| 
 | |
| 	if join, ok := field.Tag.Options["join"]; ok {
 | |
| 		baseColumns, joinColumns := parseRelationJoin(join)
 | |
| 		for i, baseColumn := range baseColumns {
 | |
| 			if f := t.fieldWithLock(baseColumn); f != nil {
 | |
| 				rel.BaseFields = append(rel.BaseFields, f)
 | |
| 			} else {
 | |
| 				panic(fmt.Errorf(
 | |
| 					"bun: %s has-one %s: %s must have column %s",
 | |
| 					field.GoName, t.TypeName, joinTable.TypeName, baseColumn,
 | |
| 				))
 | |
| 			}
 | |
| 
 | |
| 			joinColumn := joinColumns[i]
 | |
| 			if f := joinTable.fieldWithLock(joinColumn); f != nil {
 | |
| 				rel.JoinFields = append(rel.JoinFields, f)
 | |
| 			} else {
 | |
| 				panic(fmt.Errorf(
 | |
| 					"bun: %s has-one %s: %s must have column %s",
 | |
| 					field.GoName, t.TypeName, joinTable.TypeName, baseColumn,
 | |
| 				))
 | |
| 			}
 | |
| 		}
 | |
| 		return rel
 | |
| 	}
 | |
| 
 | |
| 	rel.BaseFields = t.PKs
 | |
| 	fkPrefix := internal.Underscore(t.ModelName) + "_"
 | |
| 	for _, pk := range t.PKs {
 | |
| 		fkName := fkPrefix + pk.Name
 | |
| 		if f := joinTable.fieldWithLock(fkName); f != nil {
 | |
| 			rel.JoinFields = append(rel.JoinFields, f)
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		if f := joinTable.fieldWithLock(pk.Name); f != nil {
 | |
| 			rel.JoinFields = append(rel.JoinFields, f)
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		panic(fmt.Errorf(
 | |
| 			"bun: %s has-one %s: %s must have column %s "+
 | |
| 				"(to override, use join:base_column=join_column tag on %s field)",
 | |
| 			field.GoName, t.TypeName, joinTable.TypeName, fkName, field.GoName,
 | |
| 		))
 | |
| 	}
 | |
| 	return rel
 | |
| }
 | |
| 
 | |
| func (t *Table) hasManyRelation(field *Field) *Relation {
 | |
| 	if err := t.CheckPKs(); err != nil {
 | |
| 		panic(err)
 | |
| 	}
 | |
| 	if field.IndirectType.Kind() != reflect.Slice {
 | |
| 		panic(fmt.Errorf(
 | |
| 			"bun: %s.%s has-many relation requires slice, got %q",
 | |
| 			t.TypeName, field.GoName, field.IndirectType.Kind(),
 | |
| 		))
 | |
| 	}
 | |
| 
 | |
| 	joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem()))
 | |
| 	polymorphicValue, isPolymorphic := field.Tag.Options["polymorphic"]
 | |
| 	rel := &Relation{
 | |
| 		Type:      HasManyRelation,
 | |
| 		Field:     field,
 | |
| 		JoinTable: joinTable,
 | |
| 	}
 | |
| 	var polymorphicColumn string
 | |
| 
 | |
| 	if join, ok := field.Tag.Options["join"]; ok {
 | |
| 		baseColumns, joinColumns := parseRelationJoin(join)
 | |
| 		for i, baseColumn := range baseColumns {
 | |
| 			joinColumn := joinColumns[i]
 | |
| 
 | |
| 			if isPolymorphic && baseColumn == "type" {
 | |
| 				polymorphicColumn = joinColumn
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			if f := t.fieldWithLock(baseColumn); f != nil {
 | |
| 				rel.BaseFields = append(rel.BaseFields, f)
 | |
| 			} else {
 | |
| 				panic(fmt.Errorf(
 | |
| 					"bun: %s has-one %s: %s must have column %s",
 | |
| 					t.TypeName, field.GoName, t.TypeName, baseColumn,
 | |
| 				))
 | |
| 			}
 | |
| 
 | |
| 			if f := joinTable.fieldWithLock(joinColumn); f != nil {
 | |
| 				rel.JoinFields = append(rel.JoinFields, f)
 | |
| 			} else {
 | |
| 				panic(fmt.Errorf(
 | |
| 					"bun: %s has-one %s: %s must have column %s",
 | |
| 					t.TypeName, field.GoName, t.TypeName, baseColumn,
 | |
| 				))
 | |
| 			}
 | |
| 		}
 | |
| 	} else {
 | |
| 		rel.BaseFields = t.PKs
 | |
| 		fkPrefix := internal.Underscore(t.ModelName) + "_"
 | |
| 		if isPolymorphic {
 | |
| 			polymorphicColumn = fkPrefix + "type"
 | |
| 		}
 | |
| 
 | |
| 		for _, pk := range t.PKs {
 | |
| 			joinColumn := fkPrefix + pk.Name
 | |
| 			if fk := joinTable.fieldWithLock(joinColumn); fk != nil {
 | |
| 				rel.JoinFields = append(rel.JoinFields, fk)
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			if fk := joinTable.fieldWithLock(pk.Name); fk != nil {
 | |
| 				rel.JoinFields = append(rel.JoinFields, fk)
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			panic(fmt.Errorf(
 | |
| 				"bun: %s has-many %s: %s must have column %s "+
 | |
| 					"(to override, use join:base_column=join_column tag on the field %s)",
 | |
| 				t.TypeName, field.GoName, joinTable.TypeName, joinColumn, field.GoName,
 | |
| 			))
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if isPolymorphic {
 | |
| 		rel.PolymorphicField = joinTable.fieldWithLock(polymorphicColumn)
 | |
| 		if rel.PolymorphicField == nil {
 | |
| 			panic(fmt.Errorf(
 | |
| 				"bun: %s has-many %s: %s must have polymorphic column %s",
 | |
| 				t.TypeName, field.GoName, joinTable.TypeName, polymorphicColumn,
 | |
| 			))
 | |
| 		}
 | |
| 
 | |
| 		if polymorphicValue == "" {
 | |
| 			polymorphicValue = t.ModelName
 | |
| 		}
 | |
| 		rel.PolymorphicValue = polymorphicValue
 | |
| 	}
 | |
| 
 | |
| 	return rel
 | |
| }
 | |
| 
 | |
| func (t *Table) m2mRelation(field *Field) *Relation {
 | |
| 	if field.IndirectType.Kind() != reflect.Slice {
 | |
| 		panic(fmt.Errorf(
 | |
| 			"bun: %s.%s m2m relation requires slice, got %q",
 | |
| 			t.TypeName, field.GoName, field.IndirectType.Kind(),
 | |
| 		))
 | |
| 	}
 | |
| 	joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem()))
 | |
| 
 | |
| 	if err := t.CheckPKs(); err != nil {
 | |
| 		panic(err)
 | |
| 	}
 | |
| 	if err := joinTable.CheckPKs(); err != nil {
 | |
| 		panic(err)
 | |
| 	}
 | |
| 
 | |
| 	m2mTableName, ok := field.Tag.Options["m2m"]
 | |
| 	if !ok {
 | |
| 		panic(fmt.Errorf("bun: %s must have m2m tag option", field.GoName))
 | |
| 	}
 | |
| 
 | |
| 	m2mTable := t.dialect.Tables().ByName(m2mTableName)
 | |
| 	if m2mTable == nil {
 | |
| 		panic(fmt.Errorf(
 | |
| 			"bun: can't find m2m %s table (use db.RegisterModel)",
 | |
| 			m2mTableName,
 | |
| 		))
 | |
| 	}
 | |
| 
 | |
| 	rel := &Relation{
 | |
| 		Type:      ManyToManyRelation,
 | |
| 		Field:     field,
 | |
| 		JoinTable: joinTable,
 | |
| 		M2MTable:  m2mTable,
 | |
| 	}
 | |
| 	var leftColumn, rightColumn string
 | |
| 
 | |
| 	if join, ok := field.Tag.Options["join"]; ok {
 | |
| 		left, right := parseRelationJoin(join)
 | |
| 		leftColumn = left[0]
 | |
| 		rightColumn = right[0]
 | |
| 	} else {
 | |
| 		leftColumn = t.TypeName
 | |
| 		rightColumn = joinTable.TypeName
 | |
| 	}
 | |
| 
 | |
| 	leftField := m2mTable.fieldByGoName(leftColumn)
 | |
| 	if leftField == nil {
 | |
| 		panic(fmt.Errorf(
 | |
| 			"bun: %s many-to-many %s: %s must have field %s "+
 | |
| 				"(to override, use tag join:LeftField=RightField on field %s.%s",
 | |
| 			t.TypeName, field.GoName, m2mTable.TypeName, leftColumn, t.TypeName, field.GoName,
 | |
| 		))
 | |
| 	}
 | |
| 
 | |
| 	rightField := m2mTable.fieldByGoName(rightColumn)
 | |
| 	if rightField == nil {
 | |
| 		panic(fmt.Errorf(
 | |
| 			"bun: %s many-to-many %s: %s must have field %s "+
 | |
| 				"(to override, use tag join:LeftField=RightField on field %s.%s",
 | |
| 			t.TypeName, field.GoName, m2mTable.TypeName, rightColumn, t.TypeName, field.GoName,
 | |
| 		))
 | |
| 	}
 | |
| 
 | |
| 	leftRel := m2mTable.belongsToRelation(leftField)
 | |
| 	rel.BaseFields = leftRel.JoinFields
 | |
| 	rel.M2MBaseFields = leftRel.BaseFields
 | |
| 
 | |
| 	rightRel := m2mTable.belongsToRelation(rightField)
 | |
| 	rel.JoinFields = rightRel.JoinFields
 | |
| 	rel.M2MJoinFields = rightRel.BaseFields
 | |
| 
 | |
| 	return rel
 | |
| }
 | |
| 
 | |
| func (t *Table) inlineFields(field *Field, seen map[reflect.Type]struct{}) {
 | |
| 	if seen == nil {
 | |
| 		seen = map[reflect.Type]struct{}{t.Type: {}}
 | |
| 	}
 | |
| 
 | |
| 	if _, ok := seen[field.IndirectType]; ok {
 | |
| 		return
 | |
| 	}
 | |
| 	seen[field.IndirectType] = struct{}{}
 | |
| 
 | |
| 	joinTable := t.dialect.Tables().Ref(field.IndirectType)
 | |
| 	for _, f := range joinTable.allFields {
 | |
| 		f = f.Clone()
 | |
| 		f.GoName = field.GoName + "_" + f.GoName
 | |
| 		f.Name = field.Name + "__" + f.Name
 | |
| 		f.SQLName = t.quoteIdent(f.Name)
 | |
| 		f.Index = appendNew(field.Index, f.Index...)
 | |
| 
 | |
| 		t.fieldsMapMu.Lock()
 | |
| 		if _, ok := t.FieldMap[f.Name]; !ok {
 | |
| 			t.FieldMap[f.Name] = f
 | |
| 		}
 | |
| 		t.fieldsMapMu.Unlock()
 | |
| 
 | |
| 		if f.IndirectType.Kind() != reflect.Struct {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		if _, ok := seen[f.IndirectType]; !ok {
 | |
| 			t.inlineFields(f, seen)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| //------------------------------------------------------------------------------
 | |
| 
 | |
| func (t *Table) Dialect() Dialect        { return t.dialect }
 | |
| func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) }
 | |
| func (t *Table) HasAfterScanHook() bool  { return t.flags.Has(afterScanHookFlag) }
 | |
| 
 | |
| //------------------------------------------------------------------------------
 | |
| 
 | |
| func (t *Table) AppendNamedArg(
 | |
| 	fmter Formatter, b []byte, name string, strct reflect.Value,
 | |
| ) ([]byte, bool) {
 | |
| 	if field, ok := t.FieldMap[name]; ok {
 | |
| 		return fmter.appendArg(b, field.Value(strct).Interface()), true
 | |
| 	}
 | |
| 	return b, false
 | |
| }
 | |
| 
 | |
| func (t *Table) quoteTableName(s string) Safe {
 | |
| 	// Don't quote if table name contains placeholder (?) or parentheses.
 | |
| 	if strings.IndexByte(s, '?') >= 0 ||
 | |
| 		strings.IndexByte(s, '(') >= 0 ||
 | |
| 		strings.IndexByte(s, ')') >= 0 {
 | |
| 		return Safe(s)
 | |
| 	}
 | |
| 	return t.quoteIdent(s)
 | |
| }
 | |
| 
 | |
| func (t *Table) quoteIdent(s string) Safe {
 | |
| 	return Safe(NewFormatter(t.dialect).AppendIdent(nil, s))
 | |
| }
 | |
| 
 | |
// appendNew returns a newly allocated slice holding dst followed by src; the
// result never shares storage with dst.
func appendNew(dst []int, src ...int) []int {
	out := make([]int, 0, len(dst)+len(src))
	out = append(out, dst...)
	return append(out, src...)
}
 | |
| 
 | |
// isKnownTableOption reports whether name is an option accepted on the
// BaseModel tag.
func isKnownTableOption(name string) bool {
	return name == "alias" || name == "select"
}
 | |
| 
 | |
// isKnownFieldOption reports whether name is an option accepted on a field's
// `bun` tag. Unknown options make newField log a warning, so every option
// newField handles must be listed here, including "alt".
func isKnownFieldOption(name string) bool {
	switch name {
	case "alias",
		"alt",
		"type",
		"array",
		"hstore",
		"composite",
		"json_use_number",
		"msgpack",
		"notnull",
		"nullzero",
		"allowzero",
		"default",
		"unique",
		"soft_delete",
		"scanonly",

		"pk",
		"autoincrement",
		"rel",
		"join",
		"m2m",
		"polymorphic":
		return true
	}
	return false
}
 | |
| 
 | |
| func removeField(fields []*Field, field *Field) []*Field {
 | |
| 	for i, f := range fields {
 | |
| 		if f == field {
 | |
| 			return append(fields[:i], fields[i+1:]...)
 | |
| 		}
 | |
| 	}
 | |
| 	return fields
 | |
| }
 | |
| 
 | |
// parseRelationJoin parses a `join` option of the form
// "base1=join1,base2=join2" into parallel slices of base and join columns.
// Whitespace around each pair is trimmed. It panics when a pair does not
// contain exactly one '='.
func parseRelationJoin(join string) ([]string, []string) {
	pairs := strings.Split(join, ",")
	baseColumns := make([]string, len(pairs))
	joinColumns := make([]string, len(pairs))
	for i, pair := range pairs {
		pair = strings.TrimSpace(pair)
		if strings.Count(pair, "=") != 1 {
			panic(fmt.Errorf("can't parse relation join: %q", join))
		}
		sep := strings.IndexByte(pair, '=')
		baseColumns[i] = pair[:sep]
		joinColumns[i] = pair[sep+1:]
	}
	return baseColumns, joinColumns
}
 | |
| 
 | |
| //------------------------------------------------------------------------------
 | |
| 
 | |
| func softDeleteFieldUpdater(field *Field) func(fv reflect.Value, tm time.Time) error {
 | |
| 	typ := field.StructField.Type
 | |
| 
 | |
| 	switch typ {
 | |
| 	case timeType:
 | |
| 		return func(fv reflect.Value, tm time.Time) error {
 | |
| 			ptr := fv.Addr().Interface().(*time.Time)
 | |
| 			*ptr = tm
 | |
| 			return nil
 | |
| 		}
 | |
| 	case nullTimeType:
 | |
| 		return func(fv reflect.Value, tm time.Time) error {
 | |
| 			ptr := fv.Addr().Interface().(*sql.NullTime)
 | |
| 			*ptr = sql.NullTime{Time: tm}
 | |
| 			return nil
 | |
| 		}
 | |
| 	case nullIntType:
 | |
| 		return func(fv reflect.Value, tm time.Time) error {
 | |
| 			ptr := fv.Addr().Interface().(*sql.NullInt64)
 | |
| 			*ptr = sql.NullInt64{Int64: tm.UnixNano()}
 | |
| 			return nil
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	switch field.IndirectType.Kind() {
 | |
| 	case reflect.Int64:
 | |
| 		return func(fv reflect.Value, tm time.Time) error {
 | |
| 			ptr := fv.Addr().Interface().(*int64)
 | |
| 			*ptr = tm.UnixNano()
 | |
| 			return nil
 | |
| 		}
 | |
| 	case reflect.Ptr:
 | |
| 		typ = typ.Elem()
 | |
| 	default:
 | |
| 		return softDeleteFieldUpdaterFallback(field)
 | |
| 	}
 | |
| 
 | |
| 	switch typ { //nolint:gocritic
 | |
| 	case timeType:
 | |
| 		return func(fv reflect.Value, tm time.Time) error {
 | |
| 			fv.Set(reflect.ValueOf(&tm))
 | |
| 			return nil
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	switch typ.Kind() { //nolint:gocritic
 | |
| 	case reflect.Int64:
 | |
| 		return func(fv reflect.Value, tm time.Time) error {
 | |
| 			utime := tm.UnixNano()
 | |
| 			fv.Set(reflect.ValueOf(&utime))
 | |
| 			return nil
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return softDeleteFieldUpdaterFallback(field)
 | |
| }
 | |
| 
 | |
| func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value, tm time.Time) error {
 | |
| 	return func(fv reflect.Value, tm time.Time) error {
 | |
| 		return field.ScanWithCheck(fv, tm)
 | |
| 	}
 | |
| }
 |