bumps uptrace/bun dependencies to v1.2.6 (#3569)

This commit is contained in:
kim 2024-11-25 15:42:37 +00:00 committed by GitHub
commit 3fceb5fc1a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
68 changed files with 6517 additions and 194 deletions

View file

@ -1,3 +1,52 @@
## [1.2.6](https://github.com/uptrace/bun/compare/v1.2.5...v1.2.6) (2024-11-20)
### Bug Fixes
* append IDENTITY to ADD COLUMN statement if needed ([694f873](https://github.com/uptrace/bun/commit/694f873d61ed8d2f09032ae0c0dbec4b71c3719e))
* **ci:** prune stale should be executed at 3 AM every day ([0cedcb0](https://github.com/uptrace/bun/commit/0cedcb068229b63041a4f48de12bb767c8454048))
* cleanup after testUniqueRenamedTable ([b1ae32e](https://github.com/uptrace/bun/commit/b1ae32e9e9f45ff2a66e50bfd13bedcf6653d874))
* fix go.mod of oracledialect ([89e21ea](https://github.com/uptrace/bun/commit/89e21eab362c60511cca00890ae29551a2ba7c46))
* has many relationship with multiple columns ([1664b2c](https://github.com/uptrace/bun/commit/1664b2c07a5f6cfd3b6730e5005373686e9830a6))
* ignore case for type equivalence ([c3253a5](https://github.com/uptrace/bun/commit/c3253a5c59b078607db9e216ddc11afdef546e05))
* implement DefaultSchema for Oracle dialect ([d08fa40](https://github.com/uptrace/bun/commit/d08fa40cc87d67296a83a77448ea511531fc8cdd))
* **oracledialect:** add go.mod file so the dialect is released properly ([#1043](https://github.com/uptrace/bun/issues/1043)) ([1bb5597](https://github.com/uptrace/bun/commit/1bb5597f1a32f5d693101ef4a62e25d99f5b9db5))
* **oracledialect:** update go.mod by go mod tidy to fix tests ([7f90a15](https://github.com/uptrace/bun/commit/7f90a15c51a2482dda94226dd13b913d6b470a29))
* **pgdialect:** array value quoting ([892c416](https://github.com/uptrace/bun/commit/892c416272a8428c592896d65d3ad51a6f2356d8))
* remove schema name from t.Name during bun-schema inspection ([31ed582](https://github.com/uptrace/bun/commit/31ed58254ad08143d88684672acd33ce044ea5a9))
* rename column only if the name does not exist in 'target' ([fed6012](https://github.com/uptrace/bun/commit/fed6012d177e55b8320b31ef37fc02a0cbf0b9f5))
* support embed with tag Unique ([3acd6dd](https://github.com/uptrace/bun/commit/3acd6dd8546118d7b867ca796a5e56311edad070))
* update oracledialect/version.go in release.sh ([bcd070f](https://github.com/uptrace/bun/commit/bcd070f48a75d0092a5620261658c9c5994f0bf6))
* update schema.Field names ([9b810de](https://github.com/uptrace/bun/commit/9b810dee4b1a721efb82c913099f39f52c44eb57))
### Features
* add and drop columns ([3fdd5b8](https://github.com/uptrace/bun/commit/3fdd5b8f635f849a74e78c665274609f75245b19))
* add and drop IDENTITY ([dd83779](https://github.com/uptrace/bun/commit/dd837795c31490fd8816eec0e9833e79fafdda32))
* add support type for net/netip.addr and net/netip.prefix ([#1028](https://github.com/uptrace/bun/issues/1028)) ([95c4a8e](https://github.com/uptrace/bun/commit/95c4a8ebd634e1e99114727a7b157eeeb9297ee9))
* **automigrate:** detect renamed tables ([c03938f](https://github.com/uptrace/bun/commit/c03938ff5e9fa2f653e4c60668b1368357d2de10))
* change column type ([3cfd8c6](https://github.com/uptrace/bun/commit/3cfd8c62125786aaf6f493acc5b39f4d3db3d628))
* **ci:** support release on osx ([435510b](https://github.com/uptrace/bun/commit/435510b0a73b0d9e6d06e3e3c3f0fa4379e9ed8c))
* create sql migrations and apply them ([1bf7cfd](https://github.com/uptrace/bun/commit/1bf7cfd067e0e26ae212b0f7421e5abc6f67fb4f))
* create transactional migration files ([c3320f6](https://github.com/uptrace/bun/commit/c3320f624830dc2fe99af2c7cbe492b2a83f9e4a))
* detect Create/Drop table ([408859f](https://github.com/uptrace/bun/commit/408859f07be38236b39a00909cdce55d49f6f824))
* detect modified relations ([a918dc4](https://github.com/uptrace/bun/commit/a918dc472a33dd24c5fffd4d048bcf49f2e07a42))
* detect renamed columns ([886d0a5](https://github.com/uptrace/bun/commit/886d0a5b18aba272f1c86af2a2cf68ce4c8879f2))
* detect renamed tables ([8857bab](https://github.com/uptrace/bun/commit/8857bab54b94170d218633f3b210d379e4e51a21))
* enhance Apply method to accept multiple functions ([7823f2f](https://github.com/uptrace/bun/commit/7823f2f24c814e104dc59475156255c7b3b26144))
* implement fmt.Stringer queries ([5060e47](https://github.com/uptrace/bun/commit/5060e47db13451a982e48d0f14055a58ba60b472))
* improve FK handling ([a822fc5](https://github.com/uptrace/bun/commit/a822fc5f8ae547b7cd41e1ca35609d519d78598b))
* include target schema name in migration name ([ac8d221](https://github.com/uptrace/bun/commit/ac8d221e6443b469e794314c5fc189250fa542d5))
* **mariadb:** support RETURNING clause in DELETE statement ([b8dec9d](https://github.com/uptrace/bun/commit/b8dec9d9a06124696bd5ee2abbf33f19087174b6))
* migrate FKs ([4c1dfdb](https://github.com/uptrace/bun/commit/4c1dfdbe99c73d0c0f2d7b1f8b11adf30c6a41f7))
* **mysql:** support ORDER BY and LIMIT clauses in UPDATE and DELETE statements ([de71bed](https://github.com/uptrace/bun/commit/de71bed9252980648269af85b7a51cbc464ce710))
* support modifying primary keys ([a734629](https://github.com/uptrace/bun/commit/a734629fa285406038cbe4a50798626b5ac08539))
* support UNIQUE constraints ([3c4d5d2](https://github.com/uptrace/bun/commit/3c4d5d2c47be4652fb9b5cf1c6bd7b6c0a437287))
* use *bun.DB in MigratorDialect ([a8788bf](https://github.com/uptrace/bun/commit/a8788bf62cbcc954a08532c299c774262de7a81d))
## [1.2.5](https://github.com/uptrace/bun/compare/v1.2.3...v1.2.5) (2024-10-26)

View file

@ -6,7 +6,7 @@ test:
echo "go test in $${dir}"; \
(cd "$${dir}" && \
go test && \
env GOOS=linux GOARCH=386 go test && \
env GOOS=linux GOARCH=386 TZ= go test && \
go vet); \
done

View file

@ -4,6 +4,7 @@
[![PkgGoDev](https://pkg.go.dev/badge/github.com/uptrace/bun)](https://pkg.go.dev/github.com/uptrace/bun)
[![Documentation](https://img.shields.io/badge/bun-documentation-informational)](https://bun.uptrace.dev/)
[![Chat](https://discordapp.com/api/guilds/752070105847955518/widget.png)](https://discord.gg/rWtp5Aj)
[![Gurubase](https://img.shields.io/badge/Gurubase-Ask%20Bun%20Guru-006BFF)](https://gurubase.io/g/bun)
> Bun is brought to you by :star: [**uptrace/uptrace**](https://github.com/uptrace/uptrace). Uptrace
> is an open-source APM tool that supports distributed tracing, metrics, and logs. You can use it to

View file

@ -703,6 +703,5 @@ func (tx Tx) NewDropColumn() *DropColumnQuery {
//------------------------------------------------------------------------------
func (db *DB) makeQueryBytes() []byte {
// TODO: make this configurable?
return make([]byte, 0, 4096)
return internal.MakeQueryBytes()
}

View file

@ -25,24 +25,24 @@ func AppendBool(b []byte, v bool) []byte {
return append(b, "FALSE"...)
}
func AppendFloat32(b []byte, v float32) []byte {
return appendFloat(b, float64(v), 32)
func AppendFloat32(b []byte, num float32) []byte {
return appendFloat(b, float64(num), 32)
}
func AppendFloat64(b []byte, v float64) []byte {
return appendFloat(b, v, 64)
func AppendFloat64(b []byte, num float64) []byte {
return appendFloat(b, num, 64)
}
func appendFloat(b []byte, v float64, bitSize int) []byte {
func appendFloat(b []byte, num float64, bitSize int) []byte {
switch {
case math.IsNaN(v):
case math.IsNaN(num):
return append(b, "'NaN'"...)
case math.IsInf(v, 1):
case math.IsInf(num, 1):
return append(b, "'Infinity'"...)
case math.IsInf(v, -1):
case math.IsInf(num, -1):
return append(b, "'-Infinity'"...)
default:
return strconv.AppendFloat(b, v, 'f', -1, bitSize)
return strconv.AppendFloat(b, num, 'f', -1, bitSize)
}
}

View file

@ -31,5 +31,8 @@ const (
UpdateFromTable
MSSavepoint
GeneratedIdentity
CompositeIn // ... WHERE (A,B) IN ((N, NN), (N, NN)...)
CompositeIn // ... WHERE (A,B) IN ((N, NN), (N, NN)...)
UpdateOrderLimit // UPDATE ... ORDER BY ... LIMIT ...
DeleteOrderLimit // DELETE ... ORDER BY ... LIMIT ...
DeleteReturning
)

View file

@ -0,0 +1,245 @@
package pgdialect
import (
"fmt"
"strings"
"github.com/uptrace/bun"
"github.com/uptrace/bun/migrate"
"github.com/uptrace/bun/migrate/sqlschema"
"github.com/uptrace/bun/schema"
)
func (d *Dialect) NewMigrator(db *bun.DB, schemaName string) sqlschema.Migrator {
return &migrator{db: db, schemaName: schemaName, BaseMigrator: sqlschema.NewBaseMigrator(db)}
}
type migrator struct {
*sqlschema.BaseMigrator
db *bun.DB
schemaName string
}
var _ sqlschema.Migrator = (*migrator)(nil)
func (m *migrator) AppendSQL(b []byte, operation interface{}) (_ []byte, err error) {
fmter := m.db.Formatter()
// Append ALTER TABLE statement to the enclosed query bytes []byte.
appendAlterTable := func(query []byte, tableName string) []byte {
query = append(query, "ALTER TABLE "...)
query = m.appendFQN(fmter, query, tableName)
return append(query, " "...)
}
switch change := operation.(type) {
case *migrate.CreateTableOp:
return m.AppendCreateTable(b, change.Model)
case *migrate.DropTableOp:
return m.AppendDropTable(b, m.schemaName, change.TableName)
case *migrate.RenameTableOp:
b, err = m.renameTable(fmter, appendAlterTable(b, change.TableName), change)
case *migrate.RenameColumnOp:
b, err = m.renameColumn(fmter, appendAlterTable(b, change.TableName), change)
case *migrate.AddColumnOp:
b, err = m.addColumn(fmter, appendAlterTable(b, change.TableName), change)
case *migrate.DropColumnOp:
b, err = m.dropColumn(fmter, appendAlterTable(b, change.TableName), change)
case *migrate.AddPrimaryKeyOp:
b, err = m.addPrimaryKey(fmter, appendAlterTable(b, change.TableName), change.PrimaryKey)
case *migrate.ChangePrimaryKeyOp:
b, err = m.changePrimaryKey(fmter, appendAlterTable(b, change.TableName), change)
case *migrate.DropPrimaryKeyOp:
b, err = m.dropConstraint(fmter, appendAlterTable(b, change.TableName), change.PrimaryKey.Name)
case *migrate.AddUniqueConstraintOp:
b, err = m.addUnique(fmter, appendAlterTable(b, change.TableName), change)
case *migrate.DropUniqueConstraintOp:
b, err = m.dropConstraint(fmter, appendAlterTable(b, change.TableName), change.Unique.Name)
case *migrate.ChangeColumnTypeOp:
b, err = m.changeColumnType(fmter, appendAlterTable(b, change.TableName), change)
case *migrate.AddForeignKeyOp:
b, err = m.addForeignKey(fmter, appendAlterTable(b, change.TableName()), change)
case *migrate.DropForeignKeyOp:
b, err = m.dropConstraint(fmter, appendAlterTable(b, change.TableName()), change.ConstraintName)
default:
return nil, fmt.Errorf("append sql: unknown operation %T", change)
}
if err != nil {
return nil, fmt.Errorf("append sql: %w", err)
}
return b, nil
}
func (m *migrator) appendFQN(fmter schema.Formatter, b []byte, tableName string) []byte {
return fmter.AppendQuery(b, "?.?", bun.Ident(m.schemaName), bun.Ident(tableName))
}
func (m *migrator) renameTable(fmter schema.Formatter, b []byte, rename *migrate.RenameTableOp) (_ []byte, err error) {
b = append(b, "RENAME TO "...)
b = fmter.AppendName(b, rename.NewName)
return b, nil
}
func (m *migrator) renameColumn(fmter schema.Formatter, b []byte, rename *migrate.RenameColumnOp) (_ []byte, err error) {
b = append(b, "RENAME COLUMN "...)
b = fmter.AppendName(b, rename.OldName)
b = append(b, " TO "...)
b = fmter.AppendName(b, rename.NewName)
return b, nil
}
func (m *migrator) addColumn(fmter schema.Formatter, b []byte, add *migrate.AddColumnOp) (_ []byte, err error) {
b = append(b, "ADD COLUMN "...)
b = fmter.AppendName(b, add.ColumnName)
b = append(b, " "...)
b, err = add.Column.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
if add.Column.GetDefaultValue() != "" {
b = append(b, " DEFAULT "...)
b = append(b, add.Column.GetDefaultValue()...)
b = append(b, " "...)
}
if add.Column.GetIsIdentity() {
b = appendGeneratedAsIdentity(b)
}
return b, nil
}
func (m *migrator) dropColumn(fmter schema.Formatter, b []byte, drop *migrate.DropColumnOp) (_ []byte, err error) {
b = append(b, "DROP COLUMN "...)
b = fmter.AppendName(b, drop.ColumnName)
return b, nil
}
func (m *migrator) addPrimaryKey(fmter schema.Formatter, b []byte, pk sqlschema.PrimaryKey) (_ []byte, err error) {
b = append(b, "ADD PRIMARY KEY ("...)
b, _ = pk.Columns.AppendQuery(fmter, b)
b = append(b, ")"...)
return b, nil
}
func (m *migrator) changePrimaryKey(fmter schema.Formatter, b []byte, change *migrate.ChangePrimaryKeyOp) (_ []byte, err error) {
b, _ = m.dropConstraint(fmter, b, change.Old.Name)
b = append(b, ", "...)
b, _ = m.addPrimaryKey(fmter, b, change.New)
return b, nil
}
func (m *migrator) addUnique(fmter schema.Formatter, b []byte, change *migrate.AddUniqueConstraintOp) (_ []byte, err error) {
b = append(b, "ADD CONSTRAINT "...)
if change.Unique.Name != "" {
b = fmter.AppendName(b, change.Unique.Name)
} else {
// Default naming scheme for unique constraints in Postgres is <table>_<column>_key
b = fmter.AppendName(b, fmt.Sprintf("%s_%s_key", change.TableName, change.Unique.Columns))
}
b = append(b, " UNIQUE ("...)
b, _ = change.Unique.Columns.AppendQuery(fmter, b)
b = append(b, ")"...)
return b, nil
}
func (m *migrator) dropConstraint(fmter schema.Formatter, b []byte, name string) (_ []byte, err error) {
b = append(b, "DROP CONSTRAINT "...)
b = fmter.AppendName(b, name)
return b, nil
}
func (m *migrator) addForeignKey(fmter schema.Formatter, b []byte, add *migrate.AddForeignKeyOp) (_ []byte, err error) {
b = append(b, "ADD CONSTRAINT "...)
name := add.ConstraintName
if name == "" {
colRef := add.ForeignKey.From
columns := strings.Join(colRef.Column.Split(), "_")
name = fmt.Sprintf("%s_%s_fkey", colRef.TableName, columns)
}
b = fmter.AppendName(b, name)
b = append(b, " FOREIGN KEY ("...)
if b, err = add.ForeignKey.From.Column.AppendQuery(fmter, b); err != nil {
return b, err
}
b = append(b, ")"...)
b = append(b, " REFERENCES "...)
b = m.appendFQN(fmter, b, add.ForeignKey.To.TableName)
b = append(b, " ("...)
if b, err = add.ForeignKey.To.Column.AppendQuery(fmter, b); err != nil {
return b, err
}
b = append(b, ")"...)
return b, nil
}
func (m *migrator) changeColumnType(fmter schema.Formatter, b []byte, colDef *migrate.ChangeColumnTypeOp) (_ []byte, err error) {
// alterColumn never re-assigns err, so there is no need to check for err != nil after calling it
var i int
appendAlterColumn := func() {
if i > 0 {
b = append(b, ", "...)
}
b = append(b, "ALTER COLUMN "...)
b = fmter.AppendName(b, colDef.Column)
i++
}
got, want := colDef.From, colDef.To
inspector := m.db.Dialect().(sqlschema.InspectorDialect)
if !inspector.CompareType(want, got) {
appendAlterColumn()
b = append(b, " SET DATA TYPE "...)
if b, err = want.AppendQuery(fmter, b); err != nil {
return b, err
}
}
// Column must be declared NOT NULL before identity can be added.
// Although PG can resolve the order of operations itself, we make this explicit in the query.
if want.GetIsNullable() != got.GetIsNullable() {
appendAlterColumn()
if !want.GetIsNullable() {
b = append(b, " SET NOT NULL"...)
} else {
b = append(b, " DROP NOT NULL"...)
}
}
if want.GetIsIdentity() != got.GetIsIdentity() {
appendAlterColumn()
if !want.GetIsIdentity() {
b = append(b, " DROP IDENTITY"...)
} else {
b = append(b, " ADD"...)
b = appendGeneratedAsIdentity(b)
}
}
if want.GetDefaultValue() != got.GetDefaultValue() {
appendAlterColumn()
if want.GetDefaultValue() == "" {
b = append(b, " DROP DEFAULT"...)
} else {
b = append(b, " SET DEFAULT "...)
b = append(b, want.GetDefaultValue()...)
}
}
return b, nil
}

View file

@ -5,6 +5,7 @@ import (
"database/sql/driver"
"encoding/hex"
"fmt"
"math"
"reflect"
"strconv"
"time"
@ -159,7 +160,7 @@ func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte {
case int64:
return strconv.AppendInt(b, v, 10)
case float64:
return dialect.AppendFloat64(b, v)
return arrayAppendFloat64(b, v)
case bool:
return dialect.AppendBool(b, v)
case []byte:
@ -167,7 +168,10 @@ func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte {
case string:
return arrayAppendString(b, v)
case time.Time:
return fmter.Dialect().AppendTime(b, v)
b = append(b, '"')
b = appendTime(b, v)
b = append(b, '"')
return b
default:
err := fmt.Errorf("pgdialect: can't append %T", v)
return dialect.AppendError(b, err)
@ -288,7 +292,7 @@ func appendFloat64Slice(b []byte, floats []float64) []byte {
b = append(b, '{')
for _, n := range floats {
b = dialect.AppendFloat64(b, n)
b = arrayAppendFloat64(b, n)
b = append(b, ',')
}
if len(floats) > 0 {
@ -302,6 +306,19 @@ func appendFloat64Slice(b []byte, floats []float64) []byte {
return b
}
func arrayAppendFloat64(b []byte, num float64) []byte {
switch {
case math.IsNaN(num):
return append(b, "NaN"...)
case math.IsInf(num, 1):
return append(b, "Infinity"...)
case math.IsInf(num, -1):
return append(b, "-Infinity"...)
default:
return strconv.AppendFloat(b, num, 'f', -1, 64)
}
}
func appendTimeSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ts := v.Convert(sliceTimeType).Interface().([]time.Time)
return appendTimeSlice(fmter, b, ts)
@ -383,6 +400,10 @@ func arrayScanner(typ reflect.Type) schema.ScannerFunc {
}
}
if src == nil {
return nil
}
b, err := toBytes(src)
if err != nil {
return err
@ -553,7 +574,7 @@ func scanFloat64SliceValue(dest reflect.Value, src interface{}) error {
}
func scanFloat64Slice(src interface{}) ([]float64, error) {
if src == -1 {
if src == nil {
return nil, nil
}
@ -593,7 +614,7 @@ func toBytes(src interface{}) ([]byte, error) {
case []byte:
return src, nil
default:
return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src)
return nil, fmt.Errorf("pgdialect: got %T, wanted []byte or string", src)
}
}

View file

@ -10,6 +10,7 @@ import (
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/migrate/sqlschema"
"github.com/uptrace/bun/schema"
)
@ -29,6 +30,10 @@ type Dialect struct {
features feature.Feature
}
var _ schema.Dialect = (*Dialect)(nil)
var _ sqlschema.InspectorDialect = (*Dialect)(nil)
var _ sqlschema.MigratorDialect = (*Dialect)(nil)
func New() *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
@ -48,7 +53,8 @@ func New() *Dialect {
feature.InsertOnConflict |
feature.SelectExists |
feature.GeneratedIdentity |
feature.CompositeIn
feature.CompositeIn |
feature.DeleteReturning
return d
}
@ -118,5 +124,10 @@ func (d *Dialect) AppendUint64(b []byte, n uint64) []byte {
}
func (d *Dialect) AppendSequence(b []byte, _ *schema.Table, _ *schema.Field) []byte {
return appendGeneratedAsIdentity(b)
}
// appendGeneratedAsIdentity appends GENERATED BY DEFAULT AS IDENTITY to the column definition.
func appendGeneratedAsIdentity(b []byte) []byte {
return append(b, " GENERATED BY DEFAULT AS IDENTITY"...)
}

View file

@ -0,0 +1,297 @@
package pgdialect
import (
"context"
"strings"
"github.com/uptrace/bun"
"github.com/uptrace/bun/migrate/sqlschema"
orderedmap "github.com/wk8/go-ordered-map/v2"
)
type (
Schema = sqlschema.BaseDatabase
Table = sqlschema.BaseTable
Column = sqlschema.BaseColumn
)
func (d *Dialect) NewInspector(db *bun.DB, options ...sqlschema.InspectorOption) sqlschema.Inspector {
return newInspector(db, options...)
}
type Inspector struct {
sqlschema.InspectorConfig
db *bun.DB
}
var _ sqlschema.Inspector = (*Inspector)(nil)
func newInspector(db *bun.DB, options ...sqlschema.InspectorOption) *Inspector {
i := &Inspector{db: db}
sqlschema.ApplyInspectorOptions(&i.InspectorConfig, options...)
return i
}
func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Database, error) {
dbSchema := Schema{
Tables: orderedmap.New[string, sqlschema.Table](),
ForeignKeys: make(map[sqlschema.ForeignKey]string),
}
exclude := in.ExcludeTables
if len(exclude) == 0 {
// Avoid getting NOT IN (NULL) if bun.In() is called with an empty slice.
exclude = []string{""}
}
var tables []*InformationSchemaTable
if err := in.db.NewRaw(sqlInspectTables, in.SchemaName, bun.In(exclude)).Scan(ctx, &tables); err != nil {
return dbSchema, err
}
var fks []*ForeignKey
if err := in.db.NewRaw(sqlInspectForeignKeys, in.SchemaName, bun.In(exclude), bun.In(exclude)).Scan(ctx, &fks); err != nil {
return dbSchema, err
}
dbSchema.ForeignKeys = make(map[sqlschema.ForeignKey]string, len(fks))
for _, table := range tables {
var columns []*InformationSchemaColumn
if err := in.db.NewRaw(sqlInspectColumnsQuery, table.Schema, table.Name).Scan(ctx, &columns); err != nil {
return dbSchema, err
}
colDefs := orderedmap.New[string, sqlschema.Column]()
uniqueGroups := make(map[string][]string)
for _, c := range columns {
def := c.Default
if c.IsSerial || c.IsIdentity {
def = ""
} else if !c.IsDefaultLiteral {
def = strings.ToLower(def)
}
colDefs.Set(c.Name, &Column{
Name: c.Name,
SQLType: c.DataType,
VarcharLen: c.VarcharLen,
DefaultValue: def,
IsNullable: c.IsNullable,
IsAutoIncrement: c.IsSerial,
IsIdentity: c.IsIdentity,
})
for _, group := range c.UniqueGroups {
uniqueGroups[group] = append(uniqueGroups[group], c.Name)
}
}
var unique []sqlschema.Unique
for name, columns := range uniqueGroups {
unique = append(unique, sqlschema.Unique{
Name: name,
Columns: sqlschema.NewColumns(columns...),
})
}
var pk *sqlschema.PrimaryKey
if len(table.PrimaryKey.Columns) > 0 {
pk = &sqlschema.PrimaryKey{
Name: table.PrimaryKey.ConstraintName,
Columns: sqlschema.NewColumns(table.PrimaryKey.Columns...),
}
}
dbSchema.Tables.Set(table.Name, &Table{
Schema: table.Schema,
Name: table.Name,
Columns: colDefs,
PrimaryKey: pk,
UniqueConstraints: unique,
})
}
for _, fk := range fks {
dbSchema.ForeignKeys[sqlschema.ForeignKey{
From: sqlschema.NewColumnReference(fk.SourceTable, fk.SourceColumns...),
To: sqlschema.NewColumnReference(fk.TargetTable, fk.TargetColumns...),
}] = fk.ConstraintName
}
return dbSchema, nil
}
type InformationSchemaTable struct {
Schema string `bun:"table_schema,pk"`
Name string `bun:"table_name,pk"`
PrimaryKey PrimaryKey `bun:"embed:primary_key_"`
Columns []*InformationSchemaColumn `bun:"rel:has-many,join:table_schema=table_schema,join:table_name=table_name"`
}
type InformationSchemaColumn struct {
Schema string `bun:"table_schema"`
Table string `bun:"table_name"`
Name string `bun:"column_name"`
DataType string `bun:"data_type"`
VarcharLen int `bun:"varchar_len"`
IsArray bool `bun:"is_array"`
ArrayDims int `bun:"array_dims"`
Default string `bun:"default"`
IsDefaultLiteral bool `bun:"default_is_literal_expr"`
IsIdentity bool `bun:"is_identity"`
IndentityType string `bun:"identity_type"`
IsSerial bool `bun:"is_serial"`
IsNullable bool `bun:"is_nullable"`
UniqueGroups []string `bun:"unique_groups,array"`
}
type ForeignKey struct {
ConstraintName string `bun:"constraint_name"`
SourceSchema string `bun:"schema_name"`
SourceTable string `bun:"table_name"`
SourceColumns []string `bun:"columns,array"`
TargetSchema string `bun:"target_schema"`
TargetTable string `bun:"target_table"`
TargetColumns []string `bun:"target_columns,array"`
}
type PrimaryKey struct {
ConstraintName string `bun:"name"`
Columns []string `bun:"columns,array"`
}
const (
// sqlInspectTables retrieves all user-defined tables in the selected schema.
// Pass bun.In([]string{...}) to exclude tables from this inspection or bun.In([]string{''}) to include all results.
sqlInspectTables = `
SELECT
"t".table_schema,
"t".table_name,
pk.name AS primary_key_name,
pk.columns AS primary_key_columns
FROM information_schema.tables "t"
LEFT JOIN (
SELECT i.indrelid, "idx".relname AS "name", ARRAY_AGG("a".attname) AS "columns"
FROM pg_index i
JOIN pg_attribute "a"
ON "a".attrelid = i.indrelid
AND "a".attnum = ANY("i".indkey)
AND i.indisprimary
JOIN pg_class "idx" ON i.indexrelid = "idx".oid
GROUP BY 1, 2
) pk
ON ("t".table_schema || '.' || "t".table_name)::regclass = pk.indrelid
WHERE table_type = 'BASE TABLE'
AND "t".table_schema = ?
AND "t".table_schema NOT LIKE 'pg_%'
AND "table_name" NOT IN (?)
ORDER BY "t".table_schema, "t".table_name
`
// sqlInspectColumnsQuery retrieves column definitions for the specified table.
// Unlike sqlInspectTables and sqlInspectSchema, it should be passed to bun.NewRaw
// with additional args for table_schema and table_name.
sqlInspectColumnsQuery = `
SELECT
"c".table_schema,
"c".table_name,
"c".column_name,
"c".data_type,
"c".character_maximum_length::integer AS varchar_len,
"c".data_type = 'ARRAY' AS is_array,
COALESCE("c".array_dims, 0) AS array_dims,
CASE
WHEN "c".column_default ~ '^''.*''::.*$' THEN substring("c".column_default FROM '^''(.*)''::.*$')
ELSE "c".column_default
END AS "default",
"c".column_default ~ '^''.*''::.*$' OR "c".column_default ~ '^[0-9\.]+$' AS default_is_literal_expr,
"c".is_identity = 'YES' AS is_identity,
"c".column_default = format('nextval(''%s_%s_seq''::regclass)', "c".table_name, "c".column_name) AS is_serial,
COALESCE("c".identity_type, '') AS identity_type,
"c".is_nullable = 'YES' AS is_nullable,
"c"."unique_groups" AS unique_groups
FROM (
SELECT
"table_schema",
"table_name",
"column_name",
"c".data_type,
"c".character_maximum_length,
"c".column_default,
"c".is_identity,
"c".is_nullable,
att.array_dims,
att.identity_type,
att."unique_groups",
att."constraint_type"
FROM information_schema.columns "c"
LEFT JOIN (
SELECT
s.nspname AS "table_schema",
"t".relname AS "table_name",
"c".attname AS "column_name",
"c".attndims AS array_dims,
"c".attidentity AS identity_type,
ARRAY_AGG(con.conname) FILTER (WHERE con.contype = 'u') AS "unique_groups",
ARRAY_AGG(con.contype) AS "constraint_type"
FROM (
SELECT
conname,
contype,
connamespace,
conrelid,
conrelid AS attrelid,
UNNEST(conkey) AS attnum
FROM pg_constraint
) con
LEFT JOIN pg_attribute "c" USING (attrelid, attnum)
LEFT JOIN pg_namespace s ON s.oid = con.connamespace
LEFT JOIN pg_class "t" ON "t".oid = con.conrelid
GROUP BY 1, 2, 3, 4, 5
) att USING ("table_schema", "table_name", "column_name")
) "c"
WHERE "table_schema" = ? AND "table_name" = ?
ORDER BY "table_schema", "table_name", "column_name"
`
// sqlInspectForeignKeys get FK definitions for user-defined tables.
// Pass bun.In([]string{...}) to exclude tables from this inspection or bun.In([]string{''}) to include all results.
sqlInspectForeignKeys = `
WITH
"schemas" AS (
SELECT oid, nspname
FROM pg_namespace
),
"tables" AS (
SELECT oid, relnamespace, relname, relkind
FROM pg_class
),
"columns" AS (
SELECT attrelid, attname, attnum
FROM pg_attribute
WHERE attisdropped = false
)
SELECT DISTINCT
co.conname AS "constraint_name",
ss.nspname AS schema_name,
s.relname AS "table_name",
ARRAY_AGG(sc.attname) AS "columns",
ts.nspname AS target_schema,
"t".relname AS target_table,
ARRAY_AGG(tc.attname) AS target_columns
FROM pg_constraint co
LEFT JOIN "tables" s ON s.oid = co.conrelid
LEFT JOIN "schemas" ss ON ss.oid = s.relnamespace
LEFT JOIN "columns" sc ON sc.attrelid = s.oid AND sc.attnum = ANY(co.conkey)
LEFT JOIN "tables" t ON t.oid = co.confrelid
LEFT JOIN "schemas" ts ON ts.oid = "t".relnamespace
LEFT JOIN "columns" tc ON tc.attrelid = "t".oid AND tc.attnum = ANY(co.confkey)
WHERE co.contype = 'f'
AND co.conrelid IN (SELECT oid FROM pg_class WHERE relkind = 'r')
AND ARRAY_POSITION(co.conkey, sc.attnum) = ARRAY_POSITION(co.confkey, tc.attnum)
AND ss.nspname = ?
AND s.relname NOT IN (?) AND "t".relname NOT IN (?)
GROUP BY "constraint_name", "schema_name", "table_name", target_schema, target_table
`
)

View file

@ -5,18 +5,22 @@ import (
"encoding/json"
"net"
"reflect"
"strings"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/migrate/sqlschema"
"github.com/uptrace/bun/schema"
)
const (
// Date / Time
pgTypeTimestampTz = "TIMESTAMPTZ" // Timestamp with a time zone
pgTypeDate = "DATE" // Date
pgTypeTime = "TIME" // Time without a time zone
pgTypeTimeTz = "TIME WITH TIME ZONE" // Time with a time zone
pgTypeInterval = "INTERVAL" // Time Interval
pgTypeTimestamp = "TIMESTAMP" // Timestamp
pgTypeTimestampWithTz = "TIMESTAMP WITH TIME ZONE" // Timestamp with a time zone
pgTypeTimestampTz = "TIMESTAMPTZ" // Timestamp with a time zone (alias)
pgTypeDate = "DATE" // Date
pgTypeTime = "TIME" // Time without a time zone
pgTypeTimeTz = "TIME WITH TIME ZONE" // Time with a time zone
pgTypeInterval = "INTERVAL" // Time interval
// Network Addresses
pgTypeInet = "INET" // IPv4 or IPv6 hosts and networks
@ -28,6 +32,13 @@ const (
pgTypeSerial = "SERIAL" // 4 byte autoincrementing integer
pgTypeBigSerial = "BIGSERIAL" // 8 byte autoincrementing integer
// Character Types
pgTypeChar = "CHAR" // fixed length string (blank padded)
pgTypeCharacter = "CHARACTER" // alias for CHAR
pgTypeText = "TEXT" // variable length string without limit
pgTypeVarchar = "VARCHAR" // variable length string with optional limit
pgTypeCharacterVarying = "CHARACTER VARYING" // alias for VARCHAR
// Binary Data Types
pgTypeBytea = "BYTEA" // binary string
)
@ -43,6 +54,10 @@ func (d *Dialect) DefaultVarcharLen() int {
return 0
}
func (d *Dialect) DefaultSchema() string {
return "public"
}
func fieldSQLType(field *schema.Field) string {
if field.UserSQLType != "" {
return field.UserSQLType
@ -103,3 +118,62 @@ func sqlType(typ reflect.Type) string {
return sqlType
}
var (
char = newAliases(pgTypeChar, pgTypeCharacter)
varchar = newAliases(pgTypeVarchar, pgTypeCharacterVarying)
timestampTz = newAliases(sqltype.Timestamp, pgTypeTimestampTz, pgTypeTimestampWithTz)
)
func (d *Dialect) CompareType(col1, col2 sqlschema.Column) bool {
typ1, typ2 := strings.ToUpper(col1.GetSQLType()), strings.ToUpper(col2.GetSQLType())
if typ1 == typ2 {
return checkVarcharLen(col1, col2, d.DefaultVarcharLen())
}
switch {
case char.IsAlias(typ1) && char.IsAlias(typ2):
return checkVarcharLen(col1, col2, d.DefaultVarcharLen())
case varchar.IsAlias(typ1) && varchar.IsAlias(typ2):
return checkVarcharLen(col1, col2, d.DefaultVarcharLen())
case timestampTz.IsAlias(typ1) && timestampTz.IsAlias(typ2):
return true
}
return false
}
// checkVarcharLen returns true if columns have the same VarcharLen, or,
// if one specifies no VarcharLen and the other one has the default lenght for pgdialect.
// We assume that the types are otherwise equivalent and that any non-character column
// would have VarcharLen == 0;
func checkVarcharLen(col1, col2 sqlschema.Column, defaultLen int) bool {
vl1, vl2 := col1.GetVarcharLen(), col2.GetVarcharLen()
if vl1 == vl2 {
return true
}
if (vl1 == 0 && vl2 == defaultLen) || (vl1 == defaultLen && vl2 == 0) {
return true
}
return false
}
// typeAlias defines aliases for common data types. It is a lightweight string set implementation.
type typeAlias map[string]struct{}
// IsAlias checks if typ1 and typ2 are aliases of the same data type.
func (t typeAlias) IsAlias(typ string) bool {
_, ok := t[typ]
return ok
}
// newAliases creates a set of aliases.
func newAliases(aliases ...string) typeAlias {
types := make(typeAlias)
for _, a := range aliases {
types[a] = struct{}{}
}
return types
}

View file

@ -2,5 +2,5 @@ package pgdialect
// Version is the current release version.
func Version() string {
return "1.2.5"
return "1.2.6"
}

View file

@ -40,7 +40,8 @@ func New() *Dialect {
feature.TableNotExists |
feature.SelectExists |
feature.AutoIncrement |
feature.CompositeIn
feature.CompositeIn |
feature.DeleteReturning
return d
}
@ -96,9 +97,13 @@ func (d *Dialect) DefaultVarcharLen() int {
// AUTOINCREMENT is only valid for INTEGER PRIMARY KEY, and this method will be a noop for other columns.
//
// Because this is a valid construct:
//
// CREATE TABLE ("id" INTEGER PRIMARY KEY AUTOINCREMENT);
//
// and this is not:
//
// CREATE TABLE ("id" INTEGER AUTOINCREMENT, PRIMARY KEY ("id"));
//
// AppendSequence adds a primary key constraint as a *side-effect*. Callers should expect it to avoid building invalid SQL.
// SQLite also [does not support] AUTOINCREMENT column in composite primary keys.
//
@ -111,6 +116,13 @@ func (d *Dialect) AppendSequence(b []byte, table *schema.Table, field *schema.Fi
return b
}
// DefaultSchemaName is the "schema-name" of the main database.
// The details might differ from other dialects, but for all means and purposes
// "main" is the default schema in an SQLite database.
func (d *Dialect) DefaultSchema() string {
return "main"
}
func fieldSQLType(field *schema.Field) string {
switch field.DiscoveredSQLType {
case sqltype.SmallInt, sqltype.BigInt:

View file

@ -2,5 +2,5 @@ package sqlitedialect
// Version is the current release version.
func Version() string {
return "1.2.5"
return "1.2.6"
}

View file

@ -79,3 +79,9 @@ func indirectNil(v reflect.Value) reflect.Value {
}
return v
}
// MakeQueryBytes returns zero-length byte slice with capacity of 4096.
func MakeQueryBytes() []byte {
// TODO: make this configurable?
return make([]byte, 0, 4096)
}

429
vendor/github.com/uptrace/bun/migrate/auto.go generated vendored Normal file
View file

@ -0,0 +1,429 @@
package migrate
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"github.com/uptrace/bun"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/migrate/sqlschema"
"github.com/uptrace/bun/schema"
)
type AutoMigratorOption func(m *AutoMigrator)
// WithModel adds a bun.Model to the scope of migrations.
func WithModel(models ...interface{}) AutoMigratorOption {
return func(m *AutoMigrator) {
m.includeModels = append(m.includeModels, models...)
}
}
// WithExcludeTable tells the AutoMigrator to ignore a table in the database.
// This prevents AutoMigrator from dropping tables which may exist in the schema
// but which are not used by the application.
//
// Do not exclude tables included via WithModel, as BunModelInspector ignores this setting.
func WithExcludeTable(tables ...string) AutoMigratorOption {
return func(m *AutoMigrator) {
m.excludeTables = append(m.excludeTables, tables...)
}
}
// WithSchemaName changes the default database schema to migrate objects in.
func WithSchemaName(schemaName string) AutoMigratorOption {
return func(m *AutoMigrator) {
m.schemaName = schemaName
}
}
// WithTableNameAuto overrides default migrations table name.
func WithTableNameAuto(table string) AutoMigratorOption {
return func(m *AutoMigrator) {
m.table = table
m.migratorOpts = append(m.migratorOpts, WithTableName(table))
}
}
// WithLocksTableNameAuto overrides default migration locks table name.
func WithLocksTableNameAuto(table string) AutoMigratorOption {
return func(m *AutoMigrator) {
m.locksTable = table
m.migratorOpts = append(m.migratorOpts, WithLocksTableName(table))
}
}
// WithMarkAppliedOnSuccessAuto sets the migrator to only mark migrations as applied/unapplied
// when their up/down is successful.
func WithMarkAppliedOnSuccessAuto(enabled bool) AutoMigratorOption {
return func(m *AutoMigrator) {
m.migratorOpts = append(m.migratorOpts, WithMarkAppliedOnSuccess(enabled))
}
}
// WithMigrationsDirectoryAuto overrides the default directory for migration files.
func WithMigrationsDirectoryAuto(directory string) AutoMigratorOption {
return func(m *AutoMigrator) {
m.migrationsOpts = append(m.migrationsOpts, WithMigrationsDirectory(directory))
}
}
// AutoMigrator performs automated schema migrations.
//
// It is designed to be a drop-in replacement for some Migrator functionality and supports all existing
// configuration options.
// Similarly to Migrator, it has methods to create SQL migrations, write them to a file, and apply them.
// Unlike Migrator, it detects the differences between the state defined by bun models and the current
// database schema automatically.
//
// Usage:
// 1. Generate migrations and apply them au once with AutoMigrator.Migrate().
// 2. Create up- and down-SQL migration files and apply migrations using Migrator.Migrate().
//
// While both methods produce complete, reversible migrations (with entries in the database
// and SQL migration files), prefer creating migrations and applying them separately for
// any non-trivial cases to ensure AutoMigrator detects expected changes correctly.
//
// Limitations:
// - AutoMigrator only supports a subset of the possible ALTER TABLE modifications.
// - Some changes are not automatically reversible. For example, you would need to manually
// add a CREATE TABLE query to the .down migration file to revert a DROP TABLE migration.
// - Does not validate most dialect-specific constraints. For example, when changing column
// data type, make sure the data con be auto-casted to the new type.
// - Due to how the schema-state diff is calculated, it is not possible to rename a table and
// modify any of its columns' _data type_ in a single run. This will cause the AutoMigrator
// to drop and re-create the table under a different name; it is better to apply this change in 2 steps.
// Renaming a table and renaming its columns at the same time is possible.
// - Renaming table/column to an existing name, i.e. like this [A->B] [B->C], is not possible due to how
// AutoMigrator distinguishes "rename" and "unchanged" columns.
//
// Dialect must implement both sqlschema.Inspector and sqlschema.Migrator to be used with AutoMigrator.
type AutoMigrator struct {
db *bun.DB
// dbInspector creates the current state for the target database.
dbInspector sqlschema.Inspector
// modelInspector creates the desired state based on the model definitions.
modelInspector sqlschema.Inspector
// dbMigrator executes ALTER TABLE queries.
dbMigrator sqlschema.Migrator
table string // Migrations table (excluded from database inspection)
locksTable string // Migration locks table (excluded from database inspection)
// schemaName is the database schema considered for migration.
schemaName string
// includeModels define the migration scope.
includeModels []interface{}
// excludeTables are excluded from database inspection.
excludeTables []string
// diffOpts are passed to detector constructor.
diffOpts []diffOption
// migratorOpts are passed to Migrator constructor.
migratorOpts []MigratorOption
// migrationsOpts are passed to Migrations constructor.
migrationsOpts []MigrationsOption
}
func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, error) {
am := &AutoMigrator{
db: db,
table: defaultTable,
locksTable: defaultLocksTable,
schemaName: db.Dialect().DefaultSchema(),
}
for _, opt := range opts {
opt(am)
}
am.excludeTables = append(am.excludeTables, am.table, am.locksTable)
dbInspector, err := sqlschema.NewInspector(db, sqlschema.WithSchemaName(am.schemaName), sqlschema.WithExcludeTables(am.excludeTables...))
if err != nil {
return nil, err
}
am.dbInspector = dbInspector
am.diffOpts = append(am.diffOpts, withCompareTypeFunc(db.Dialect().(sqlschema.InspectorDialect).CompareType))
dbMigrator, err := sqlschema.NewMigrator(db, am.schemaName)
if err != nil {
return nil, err
}
am.dbMigrator = dbMigrator
tables := schema.NewTables(db.Dialect())
tables.Register(am.includeModels...)
am.modelInspector = sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(am.schemaName))
return am, nil
}
func (am *AutoMigrator) plan(ctx context.Context) (*changeset, error) {
var err error
got, err := am.dbInspector.Inspect(ctx)
if err != nil {
return nil, err
}
want, err := am.modelInspector.Inspect(ctx)
if err != nil {
return nil, err
}
changes := diff(got, want, am.diffOpts...)
if err := changes.ResolveDependencies(); err != nil {
return nil, fmt.Errorf("plan migrations: %w", err)
}
return changes, nil
}
// Migrate writes required changes to a new migration file and runs the migration.
// This will create and entry in the migrations table, making it possible to revert
// the changes with Migrator.Rollback(). MigrationOptions are passed on to Migrator.Migrate().
func (am *AutoMigrator) Migrate(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) {
migrations, _, err := am.createSQLMigrations(ctx, false)
if err != nil {
return nil, fmt.Errorf("auto migrate: %w", err)
}
migrator := NewMigrator(am.db, migrations, am.migratorOpts...)
if err := migrator.Init(ctx); err != nil {
return nil, fmt.Errorf("auto migrate: %w", err)
}
group, err := migrator.Migrate(ctx, opts...)
if err != nil {
return nil, fmt.Errorf("auto migrate: %w", err)
}
return group, nil
}
// CreateSQLMigration writes required changes to a new migration file.
// Use migrate.Migrator to apply the generated migrations.
func (am *AutoMigrator) CreateSQLMigrations(ctx context.Context) ([]*MigrationFile, error) {
_, files, err := am.createSQLMigrations(ctx, true)
return files, err
}
// CreateTxSQLMigration writes required changes to a new migration file making sure they will be executed
// in a transaction when applied. Use migrate.Migrator to apply the generated migrations.
func (am *AutoMigrator) CreateTxSQLMigrations(ctx context.Context) ([]*MigrationFile, error) {
_, files, err := am.createSQLMigrations(ctx, false)
return files, err
}
func (am *AutoMigrator) createSQLMigrations(ctx context.Context, transactional bool) (*Migrations, []*MigrationFile, error) {
changes, err := am.plan(ctx)
if err != nil {
return nil, nil, fmt.Errorf("create sql migrations: %w", err)
}
name, _ := genMigrationName(am.schemaName + "_auto")
migrations := NewMigrations(am.migrationsOpts...)
migrations.Add(Migration{
Name: name,
Up: changes.Up(am.dbMigrator),
Down: changes.Down(am.dbMigrator),
Comment: "Changes detected by bun.AutoMigrator",
})
// Append .tx.up.sql or .up.sql to migration name, dependin if it should be transactional.
fname := func(direction string) string {
return name + map[bool]string{true: ".tx.", false: "."}[transactional] + direction + ".sql"
}
up, err := am.createSQL(ctx, migrations, fname("up"), changes, transactional)
if err != nil {
return nil, nil, fmt.Errorf("create sql migration up: %w", err)
}
down, err := am.createSQL(ctx, migrations, fname("down"), changes.GetReverse(), transactional)
if err != nil {
return nil, nil, fmt.Errorf("create sql migration down: %w", err)
}
return migrations, []*MigrationFile{up, down}, nil
}
func (am *AutoMigrator) createSQL(_ context.Context, migrations *Migrations, fname string, changes *changeset, transactional bool) (*MigrationFile, error) {
var buf bytes.Buffer
if transactional {
buf.WriteString("SET statement_timeout = 0;")
}
if err := changes.WriteTo(&buf, am.dbMigrator); err != nil {
return nil, err
}
content := buf.Bytes()
fpath := filepath.Join(migrations.getDirectory(), fname)
if err := os.WriteFile(fpath, content, 0o644); err != nil {
return nil, err
}
mf := &MigrationFile{
Name: fname,
Path: fpath,
Content: string(content),
}
return mf, nil
}
// Func creates a MigrationFunc that applies all operations all the changeset.
func (c *changeset) Func(m sqlschema.Migrator) MigrationFunc {
return func(ctx context.Context, db *bun.DB) error {
return c.apply(ctx, db, m)
}
}
// GetReverse returns a new changeset with each operation in it "reversed" and in reverse order.
func (c *changeset) GetReverse() *changeset {
var reverse changeset
for i := len(c.operations) - 1; i >= 0; i-- {
reverse.Add(c.operations[i].GetReverse())
}
return &reverse
}
// Up is syntactic sugar.
func (c *changeset) Up(m sqlschema.Migrator) MigrationFunc {
return c.Func(m)
}
// Down is syntactic sugar.
func (c *changeset) Down(m sqlschema.Migrator) MigrationFunc {
return c.GetReverse().Func(m)
}
// apply generates SQL for each operation and executes it.
func (c *changeset) apply(ctx context.Context, db *bun.DB, m sqlschema.Migrator) error {
if len(c.operations) == 0 {
return nil
}
for _, op := range c.operations {
if _, isComment := op.(*comment); isComment {
continue
}
b := internal.MakeQueryBytes()
b, err := m.AppendSQL(b, op)
if err != nil {
return fmt.Errorf("apply changes: %w", err)
}
query := internal.String(b)
if _, err = db.ExecContext(ctx, query); err != nil {
return fmt.Errorf("apply changes: %w", err)
}
}
return nil
}
func (c *changeset) WriteTo(w io.Writer, m sqlschema.Migrator) error {
var err error
b := internal.MakeQueryBytes()
for _, op := range c.operations {
if c, isComment := op.(*comment); isComment {
b = append(b, "/*\n"...)
b = append(b, *c...)
b = append(b, "\n*/"...)
continue
}
b, err = m.AppendSQL(b, op)
if err != nil {
return fmt.Errorf("write changeset: %w", err)
}
b = append(b, ";\n"...)
}
if _, err := w.Write(b); err != nil {
return fmt.Errorf("write changeset: %w", err)
}
return nil
}
func (c *changeset) ResolveDependencies() error {
if len(c.operations) <= 1 {
return nil
}
const (
unvisited = iota
current
visited
)
status := make(map[Operation]int, len(c.operations))
for _, op := range c.operations {
status[op] = unvisited
}
var resolved []Operation
var nextOp Operation
var visit func(op Operation) error
next := func() bool {
for op, s := range status {
if s == unvisited {
nextOp = op
return true
}
}
return false
}
// visit iterates over c.operations until it finds all operations that depend on the current one
// or runs into cirtular dependency, in which case it will return an error.
visit = func(op Operation) error {
switch status[op] {
case visited:
return nil
case current:
// TODO: add details (circle) to the error message
return errors.New("detected circular dependency")
}
status[op] = current
for _, another := range c.operations {
if dop, hasDeps := another.(interface {
DependsOn(Operation) bool
}); another == op || !hasDeps || !dop.DependsOn(op) {
continue
}
if err := visit(another); err != nil {
return err
}
}
status[op] = visited
// Any dependent nodes would've already been added to the list by now, so we prepend.
resolved = append([]Operation{op}, resolved...)
return nil
}
for next() {
if err := visit(nextOp); err != nil {
return err
}
}
c.operations = resolved
return nil
}

411
vendor/github.com/uptrace/bun/migrate/diff.go generated vendored Normal file
View file

@ -0,0 +1,411 @@
package migrate
import (
"github.com/uptrace/bun/migrate/sqlschema"
)
// changeset is a set of changes to the database schema definition.
type changeset struct {
operations []Operation
}
// Add new operations to the changeset.
func (c *changeset) Add(op ...Operation) {
c.operations = append(c.operations, op...)
}
// diff calculates the diff between the current database schema and the target state.
// The changeset is not sorted -- the caller should resolve dependencies before applying the changes.
func diff(got, want sqlschema.Database, opts ...diffOption) *changeset {
d := newDetector(got, want, opts...)
return d.detectChanges()
}
func (d *detector) detectChanges() *changeset {
currentTables := d.current.GetTables()
targetTables := d.target.GetTables()
RenameCreate:
for wantName, wantTable := range targetTables.FromOldest() {
// A table with this name exists in the database. We assume that schema objects won't
// be renamed to an already existing name, nor do we support such cases.
// Simply check if the table definition has changed.
if haveTable, ok := currentTables.Get(wantName); ok {
d.detectColumnChanges(haveTable, wantTable, true)
d.detectConstraintChanges(haveTable, wantTable)
continue
}
// Find all renamed tables. We assume that renamed tables have the same signature.
for haveName, haveTable := range currentTables.FromOldest() {
if _, exists := targetTables.Get(haveName); !exists && d.canRename(haveTable, wantTable) {
d.changes.Add(&RenameTableOp{
TableName: haveTable.GetName(),
NewName: wantName,
})
d.refMap.RenameTable(haveTable.GetName(), wantName)
// Find renamed columns, if any, and check if constraints (PK, UNIQUE) have been updated.
// We need not check wantTable any further.
d.detectColumnChanges(haveTable, wantTable, false)
d.detectConstraintChanges(haveTable, wantTable)
currentTables.Delete(haveName)
continue RenameCreate
}
}
// If wantTable does not exist in the database and was not renamed
// then we need to create this table in the database.
additional := wantTable.(*sqlschema.BunTable)
d.changes.Add(&CreateTableOp{
TableName: wantTable.GetName(),
Model: additional.Model,
})
}
// Drop any remaining "current" tables which do not have a model.
for name, table := range currentTables.FromOldest() {
if _, keep := targetTables.Get(name); !keep {
d.changes.Add(&DropTableOp{
TableName: table.GetName(),
})
}
}
targetFKs := d.target.GetForeignKeys()
currentFKs := d.refMap.Deref()
for fk := range targetFKs {
if _, ok := currentFKs[fk]; !ok {
d.changes.Add(&AddForeignKeyOp{
ForeignKey: fk,
ConstraintName: "", // leave empty to let each dialect apply their convention
})
}
}
for fk, name := range currentFKs {
if _, ok := targetFKs[fk]; !ok {
d.changes.Add(&DropForeignKeyOp{
ConstraintName: name,
ForeignKey: fk,
})
}
}
return &d.changes
}
// detechColumnChanges finds renamed columns and, if checkType == true, columns with changed type.
func (d *detector) detectColumnChanges(current, target sqlschema.Table, checkType bool) {
currentColumns := current.GetColumns()
targetColumns := target.GetColumns()
ChangeRename:
for tName, tCol := range targetColumns.FromOldest() {
// This column exists in the database, so it hasn't been renamed, dropped, or added.
// Still, we should not delete(columns, thisColumn), because later we will need to
// check that we do not try to rename a column to an already a name that already exists.
if cCol, ok := currentColumns.Get(tName); ok {
if checkType && !d.equalColumns(cCol, tCol) {
d.changes.Add(&ChangeColumnTypeOp{
TableName: target.GetName(),
Column: tName,
From: cCol,
To: d.makeTargetColDef(cCol, tCol),
})
}
continue
}
// Column tName does not exist in the database -- it's been either renamed or added.
// Find renamed columns first.
for cName, cCol := range currentColumns.FromOldest() {
// Cannot rename if a column with this name already exists or the types differ.
if _, exists := targetColumns.Get(cName); exists || !d.equalColumns(tCol, cCol) {
continue
}
d.changes.Add(&RenameColumnOp{
TableName: target.GetName(),
OldName: cName,
NewName: tName,
})
d.refMap.RenameColumn(target.GetName(), cName, tName)
currentColumns.Delete(cName) // no need to check this column again
// Update primary key definition to avoid superficially recreating the constraint.
current.GetPrimaryKey().Columns.Replace(cName, tName)
continue ChangeRename
}
d.changes.Add(&AddColumnOp{
TableName: target.GetName(),
ColumnName: tName,
Column: tCol,
})
}
// Drop columns which do not exist in the target schema and were not renamed.
for cName, cCol := range currentColumns.FromOldest() {
if _, keep := targetColumns.Get(cName); !keep {
d.changes.Add(&DropColumnOp{
TableName: target.GetName(),
ColumnName: cName,
Column: cCol,
})
}
}
}
func (d *detector) detectConstraintChanges(current, target sqlschema.Table) {
Add:
for _, want := range target.GetUniqueConstraints() {
for _, got := range current.GetUniqueConstraints() {
if got.Equals(want) {
continue Add
}
}
d.changes.Add(&AddUniqueConstraintOp{
TableName: target.GetName(),
Unique: want,
})
}
Drop:
for _, got := range current.GetUniqueConstraints() {
for _, want := range target.GetUniqueConstraints() {
if got.Equals(want) {
continue Drop
}
}
d.changes.Add(&DropUniqueConstraintOp{
TableName: target.GetName(),
Unique: got,
})
}
targetPK := target.GetPrimaryKey()
currentPK := current.GetPrimaryKey()
// Detect primary key changes
if targetPK == nil && currentPK == nil {
return
}
switch {
case targetPK == nil && currentPK != nil:
d.changes.Add(&DropPrimaryKeyOp{
TableName: target.GetName(),
PrimaryKey: *currentPK,
})
case currentPK == nil && targetPK != nil:
d.changes.Add(&AddPrimaryKeyOp{
TableName: target.GetName(),
PrimaryKey: *targetPK,
})
case targetPK.Columns != currentPK.Columns:
d.changes.Add(&ChangePrimaryKeyOp{
TableName: target.GetName(),
Old: *currentPK,
New: *targetPK,
})
}
}
func newDetector(got, want sqlschema.Database, opts ...diffOption) *detector {
cfg := &detectorConfig{
cmpType: func(c1, c2 sqlschema.Column) bool {
return c1.GetSQLType() == c2.GetSQLType() && c1.GetVarcharLen() == c2.GetVarcharLen()
},
}
for _, opt := range opts {
opt(cfg)
}
return &detector{
current: got,
target: want,
refMap: newRefMap(got.GetForeignKeys()),
cmpType: cfg.cmpType,
}
}
type diffOption func(*detectorConfig)
func withCompareTypeFunc(f CompareTypeFunc) diffOption {
return func(cfg *detectorConfig) {
cfg.cmpType = f
}
}
// detectorConfig controls how differences in the model states are resolved.
type detectorConfig struct {
cmpType CompareTypeFunc
}
// detector may modify the passed database schemas, so it isn't safe to re-use them.
type detector struct {
// current state represents the existing database schema.
current sqlschema.Database
// target state represents the database schema defined in bun models.
target sqlschema.Database
changes changeset
refMap refMap
// cmpType determines column type equivalence.
// Default is direct comparison with '==' operator, which is inaccurate
// due to the existence of dialect-specific type aliases. The caller
// should pass a concrete InspectorDialect.EquuivalentType for robust comparison.
cmpType CompareTypeFunc
}
// canRename checks if t1 can be renamed to t2.
func (d detector) canRename(t1, t2 sqlschema.Table) bool {
return t1.GetSchema() == t2.GetSchema() && equalSignatures(t1, t2, d.equalColumns)
}
func (d detector) equalColumns(col1, col2 sqlschema.Column) bool {
return d.cmpType(col1, col2) &&
col1.GetDefaultValue() == col2.GetDefaultValue() &&
col1.GetIsNullable() == col2.GetIsNullable() &&
col1.GetIsAutoIncrement() == col2.GetIsAutoIncrement() &&
col1.GetIsIdentity() == col2.GetIsIdentity()
}
func (d detector) makeTargetColDef(current, target sqlschema.Column) sqlschema.Column {
// Avoid unneccessary type-change migrations if the types are equivalent.
if d.cmpType(current, target) {
target = &sqlschema.BaseColumn{
Name: target.GetName(),
DefaultValue: target.GetDefaultValue(),
IsNullable: target.GetIsNullable(),
IsAutoIncrement: target.GetIsAutoIncrement(),
IsIdentity: target.GetIsIdentity(),
SQLType: current.GetSQLType(),
VarcharLen: current.GetVarcharLen(),
}
}
return target
}
type CompareTypeFunc func(sqlschema.Column, sqlschema.Column) bool
// equalSignatures determines if two tables have the same "signature".
func equalSignatures(t1, t2 sqlschema.Table, eq CompareTypeFunc) bool {
sig1 := newSignature(t1, eq)
sig2 := newSignature(t2, eq)
return sig1.Equals(sig2)
}
// signature is a set of column definitions, which allows "relation/name-agnostic" comparison between them;
// meaning that two columns are considered equal if their types are the same.
type signature struct {
// underlying stores the number of occurences for each unique column type.
// It helps to account for the fact that a table might have multiple columns that have the same type.
underlying map[sqlschema.BaseColumn]int
eq CompareTypeFunc
}
func newSignature(t sqlschema.Table, eq CompareTypeFunc) signature {
s := signature{
underlying: make(map[sqlschema.BaseColumn]int),
eq: eq,
}
s.scan(t)
return s
}
// scan iterates over table's field and counts occurrences of each unique column definition.
func (s *signature) scan(t sqlschema.Table) {
for _, icol := range t.GetColumns().FromOldest() {
scanCol := icol.(*sqlschema.BaseColumn)
// This is slightly more expensive than if the columns could be compared directly
// and we always did s.underlying[col]++, but we get type-equivalence in return.
col, count := s.getCount(*scanCol)
if count == 0 {
s.underlying[*scanCol] = 1
} else {
s.underlying[col]++
}
}
}
// getCount uses CompareTypeFunc to find a column with the same (equivalent) SQL type
// and returns its count. Count 0 means there are no columns with of this type.
func (s *signature) getCount(keyCol sqlschema.BaseColumn) (key sqlschema.BaseColumn, count int) {
for col, cnt := range s.underlying {
if s.eq(&col, &keyCol) {
return col, cnt
}
}
return keyCol, 0
}
// Equals returns true if 2 signatures share an identical set of columns.
func (s *signature) Equals(other signature) bool {
if len(s.underlying) != len(other.underlying) {
return false
}
for col, count := range s.underlying {
if _, countOther := other.getCount(col); countOther != count {
return false
}
}
return true
}
// refMap is a utility for tracking superficial changes in foreign keys,
// which do not require any modificiation in the database.
// Modern SQL dialects automatically updated foreign key constraints whenever
// a column or a table is renamed. Detector can use refMap to ignore any
// differences in foreign keys which were caused by renamed column/table.
type refMap map[*sqlschema.ForeignKey]string
func newRefMap(fks map[sqlschema.ForeignKey]string) refMap {
rm := make(map[*sqlschema.ForeignKey]string)
for fk, name := range fks {
rm[&fk] = name
}
return rm
}
// RenameT updates table name in all foreign key definions which depend on it.
func (rm refMap) RenameTable(tableName string, newName string) {
for fk := range rm {
switch tableName {
case fk.From.TableName:
fk.From.TableName = newName
case fk.To.TableName:
fk.To.TableName = newName
}
}
}
// RenameColumn updates column name in all foreign key definions which depend on it.
func (rm refMap) RenameColumn(tableName string, column, newName string) {
for fk := range rm {
if tableName == fk.From.TableName {
fk.From.Column.Replace(column, newName)
}
if tableName == fk.To.TableName {
fk.To.Column.Replace(column, newName)
}
}
}
// Deref returns copies of ForeignKey values to a map.
func (rm refMap) Deref() map[sqlschema.ForeignKey]string {
out := make(map[sqlschema.ForeignKey]string)
for fk, name := range rm {
out[*fk] = name
}
return out
}

View file

@ -12,14 +12,21 @@ import (
"github.com/uptrace/bun"
)
const (
defaultTable = "bun_migrations"
defaultLocksTable = "bun_migration_locks"
)
type MigratorOption func(m *Migrator)
// WithTableName overrides default migrations table name.
func WithTableName(table string) MigratorOption {
return func(m *Migrator) {
m.table = table
}
}
// WithLocksTableName overrides default migration locks table name.
func WithLocksTableName(table string) MigratorOption {
return func(m *Migrator) {
m.locksTable = table
@ -27,7 +34,7 @@ func WithLocksTableName(table string) MigratorOption {
}
// WithMarkAppliedOnSuccess sets the migrator to only mark migrations as applied/unapplied
// when their up/down is successful
// when their up/down is successful.
func WithMarkAppliedOnSuccess(enabled bool) MigratorOption {
return func(m *Migrator) {
m.markAppliedOnSuccess = enabled
@ -52,8 +59,8 @@ func NewMigrator(db *bun.DB, migrations *Migrations, opts ...MigratorOption) *Mi
ms: migrations.ms,
table: "bun_migrations",
locksTable: "bun_migration_locks",
table: defaultTable,
locksTable: defaultLocksTable,
}
for _, opt := range opts {
opt(m)
@ -246,7 +253,7 @@ func (m *Migrator) CreateGoMigration(
opt(cfg)
}
name, err := m.genMigrationName(name)
name, err := genMigrationName(name)
if err != nil {
return nil, err
}
@ -269,7 +276,7 @@ func (m *Migrator) CreateGoMigration(
// CreateTxSQLMigration creates transactional up and down SQL migration files.
func (m *Migrator) CreateTxSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) {
name, err := m.genMigrationName(name)
name, err := genMigrationName(name)
if err != nil {
return nil, err
}
@ -289,7 +296,7 @@ func (m *Migrator) CreateTxSQLMigrations(ctx context.Context, name string) ([]*M
// CreateSQLMigrations creates up and down SQL migration files.
func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) {
name, err := m.genMigrationName(name)
name, err := genMigrationName(name)
if err != nil {
return nil, err
}
@ -307,7 +314,7 @@ func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*Mig
return []*MigrationFile{up, down}, nil
}
func (m *Migrator) createSQL(ctx context.Context, fname string, transactional bool) (*MigrationFile, error) {
func (m *Migrator) createSQL(_ context.Context, fname string, transactional bool) (*MigrationFile, error) {
fpath := filepath.Join(m.migrations.getDirectory(), fname)
template := sqlTemplate
@ -329,7 +336,7 @@ func (m *Migrator) createSQL(ctx context.Context, fname string, transactional bo
var nameRE = regexp.MustCompile(`^[0-9a-z_\-]+$`)
func (m *Migrator) genMigrationName(name string) (string, error) {
func genMigrationName(name string) (string, error) {
const timeFormat = "20060102150405"
if name == "" {

340
vendor/github.com/uptrace/bun/migrate/operations.go generated vendored Normal file
View file

@ -0,0 +1,340 @@
package migrate
import (
"fmt"
"github.com/uptrace/bun/migrate/sqlschema"
)
// Operation encapsulates the request to change a database definition
// and knowns which operation can revert it.
//
// It is useful to define "monolith" Operations whenever possible,
// even though they a dialect may require several distinct steps to apply them.
// For example, changing a primary key involves first dropping the old constraint
// before generating the new one. Yet, this is only an implementation detail and
// passing a higher-level ChangePrimaryKeyOp will give the dialect more information
// about the applied change.
//
// Some operations might be irreversible due to technical limitations. Returning
// a *comment from GetReverse() will add an explanatory note to the generate migation file.
//
// To declare dependency on another Operation, operations should implement
// { DependsOn(Operation) bool } interface, which Changeset will use to resolve dependencies.
type Operation interface {
GetReverse() Operation
}
// CreateTableOp creates a new table in the schema.
//
// It does not report dependency on any other migration and may be executed first.
// Make sure the dialect does not include FOREIGN KEY constraints in the CREATE TABLE
// statement, as those may potentially reference not-yet-existing columns/tables.
type CreateTableOp struct {
TableName string
Model interface{}
}
var _ Operation = (*CreateTableOp)(nil)
func (op *CreateTableOp) GetReverse() Operation {
return &DropTableOp{TableName: op.TableName}
}
// DropTableOp drops a database table. This operation is not reversible.
type DropTableOp struct {
TableName string
}
var _ Operation = (*DropTableOp)(nil)
func (op *DropTableOp) DependsOn(another Operation) bool {
drop, ok := another.(*DropForeignKeyOp)
return ok && drop.ForeignKey.DependsOnTable(op.TableName)
}
// GetReverse for a DropTable returns a no-op migration. Logically, CreateTable is the reverse,
// but DropTable does not have the table's definition to create one.
func (op *DropTableOp) GetReverse() Operation {
c := comment(fmt.Sprintf("WARNING: \"DROP TABLE %s\" cannot be reversed automatically because table definition is not available", op.TableName))
return &c
}
// RenameTableOp renames the table. Changing the "schema" part of the table's FQN (moving tables between schemas) is not allowed.
type RenameTableOp struct {
TableName string
NewName string
}
var _ Operation = (*RenameTableOp)(nil)
func (op *RenameTableOp) GetReverse() Operation {
return &RenameTableOp{
TableName: op.NewName,
NewName: op.TableName,
}
}
// RenameColumnOp renames a column in the table. If the changeset includes a rename operation
// for the column's table, it should be executed first.
type RenameColumnOp struct {
TableName string
OldName string
NewName string
}
var _ Operation = (*RenameColumnOp)(nil)
func (op *RenameColumnOp) GetReverse() Operation {
return &RenameColumnOp{
TableName: op.TableName,
OldName: op.NewName,
NewName: op.OldName,
}
}
func (op *RenameColumnOp) DependsOn(another Operation) bool {
rename, ok := another.(*RenameTableOp)
return ok && op.TableName == rename.NewName
}
// AddColumnOp adds a new column to the table.
type AddColumnOp struct {
TableName string
ColumnName string
Column sqlschema.Column
}
var _ Operation = (*AddColumnOp)(nil)
func (op *AddColumnOp) GetReverse() Operation {
return &DropColumnOp{
TableName: op.TableName,
ColumnName: op.ColumnName,
Column: op.Column,
}
}
// DropColumnOp drop a column from the table.
//
// While some dialects allow DROP CASCADE to drop dependent constraints,
// explicit handling on constraints is preferred for transparency and debugging.
// DropColumnOp depends on DropForeignKeyOp, DropPrimaryKeyOp, and ChangePrimaryKeyOp
// if any of the constraints is defined on this table.
type DropColumnOp struct {
TableName string
ColumnName string
Column sqlschema.Column
}
var _ Operation = (*DropColumnOp)(nil)
func (op *DropColumnOp) GetReverse() Operation {
return &AddColumnOp{
TableName: op.TableName,
ColumnName: op.ColumnName,
Column: op.Column,
}
}
func (op *DropColumnOp) DependsOn(another Operation) bool {
switch drop := another.(type) {
case *DropForeignKeyOp:
return drop.ForeignKey.DependsOnColumn(op.TableName, op.ColumnName)
case *DropPrimaryKeyOp:
return op.TableName == drop.TableName && drop.PrimaryKey.Columns.Contains(op.ColumnName)
case *ChangePrimaryKeyOp:
return op.TableName == drop.TableName && drop.Old.Columns.Contains(op.ColumnName)
}
return false
}
// AddForeignKey adds a new FOREIGN KEY constraint.
type AddForeignKeyOp struct {
ForeignKey sqlschema.ForeignKey
ConstraintName string
}
var _ Operation = (*AddForeignKeyOp)(nil)
func (op *AddForeignKeyOp) TableName() string {
return op.ForeignKey.From.TableName
}
func (op *AddForeignKeyOp) DependsOn(another Operation) bool {
switch another := another.(type) {
case *RenameTableOp:
return op.ForeignKey.DependsOnTable(another.TableName) || op.ForeignKey.DependsOnTable(another.NewName)
case *CreateTableOp:
return op.ForeignKey.DependsOnTable(another.TableName)
}
return false
}
func (op *AddForeignKeyOp) GetReverse() Operation {
return &DropForeignKeyOp{
ForeignKey: op.ForeignKey,
ConstraintName: op.ConstraintName,
}
}
// DropForeignKeyOp drops a FOREIGN KEY constraint.
type DropForeignKeyOp struct {
ForeignKey sqlschema.ForeignKey
ConstraintName string
}
var _ Operation = (*DropForeignKeyOp)(nil)
func (op *DropForeignKeyOp) TableName() string {
return op.ForeignKey.From.TableName
}
func (op *DropForeignKeyOp) GetReverse() Operation {
return &AddForeignKeyOp{
ForeignKey: op.ForeignKey,
ConstraintName: op.ConstraintName,
}
}
// AddUniqueConstraintOp adds new UNIQUE constraint to the table.
type AddUniqueConstraintOp struct {
TableName string
Unique sqlschema.Unique
}
var _ Operation = (*AddUniqueConstraintOp)(nil)
func (op *AddUniqueConstraintOp) GetReverse() Operation {
return &DropUniqueConstraintOp{
TableName: op.TableName,
Unique: op.Unique,
}
}
func (op *AddUniqueConstraintOp) DependsOn(another Operation) bool {
switch another := another.(type) {
case *AddColumnOp:
return op.TableName == another.TableName && op.Unique.Columns.Contains(another.ColumnName)
case *RenameTableOp:
return op.TableName == another.NewName
case *DropUniqueConstraintOp:
// We want to drop the constraint with the same name before adding this one.
return op.TableName == another.TableName && op.Unique.Name == another.Unique.Name
default:
return false
}
}
// DropUniqueConstraintOp drops a UNIQUE constraint.
type DropUniqueConstraintOp struct {
TableName string
Unique sqlschema.Unique
}
var _ Operation = (*DropUniqueConstraintOp)(nil)
func (op *DropUniqueConstraintOp) DependsOn(another Operation) bool {
if rename, ok := another.(*RenameTableOp); ok {
return op.TableName == rename.NewName
}
return false
}
func (op *DropUniqueConstraintOp) GetReverse() Operation {
return &AddUniqueConstraintOp{
TableName: op.TableName,
Unique: op.Unique,
}
}
// ChangeColumnTypeOp set a new data type for the column.
// The two types should be such that the data can be auto-casted from one to another.
// E.g. reducing VARCHAR lenght is not possible in most dialects.
// AutoMigrator does not enforce or validate these rules.
type ChangeColumnTypeOp struct {
TableName string
Column string
From sqlschema.Column
To sqlschema.Column
}
var _ Operation = (*ChangeColumnTypeOp)(nil)
func (op *ChangeColumnTypeOp) GetReverse() Operation {
return &ChangeColumnTypeOp{
TableName: op.TableName,
Column: op.Column,
From: op.To,
To: op.From,
}
}
// DropPrimaryKeyOp drops the table's PRIMARY KEY.
type DropPrimaryKeyOp struct {
TableName string
PrimaryKey sqlschema.PrimaryKey
}
var _ Operation = (*DropPrimaryKeyOp)(nil)
func (op *DropPrimaryKeyOp) GetReverse() Operation {
return &AddPrimaryKeyOp{
TableName: op.TableName,
PrimaryKey: op.PrimaryKey,
}
}
// AddPrimaryKeyOp adds a new PRIMARY KEY to the table.
type AddPrimaryKeyOp struct {
TableName string
PrimaryKey sqlschema.PrimaryKey
}
var _ Operation = (*AddPrimaryKeyOp)(nil)
func (op *AddPrimaryKeyOp) GetReverse() Operation {
return &DropPrimaryKeyOp{
TableName: op.TableName,
PrimaryKey: op.PrimaryKey,
}
}
func (op *AddPrimaryKeyOp) DependsOn(another Operation) bool {
switch another := another.(type) {
case *AddColumnOp:
return op.TableName == another.TableName && op.PrimaryKey.Columns.Contains(another.ColumnName)
}
return false
}
// ChangePrimaryKeyOp changes the PRIMARY KEY of the table.
type ChangePrimaryKeyOp struct {
TableName string
Old sqlschema.PrimaryKey
New sqlschema.PrimaryKey
}
var _ Operation = (*AddPrimaryKeyOp)(nil)
func (op *ChangePrimaryKeyOp) GetReverse() Operation {
return &ChangePrimaryKeyOp{
TableName: op.TableName,
Old: op.New,
New: op.Old,
}
}
// comment denotes an Operation that cannot be executed.
//
// Operations, which cannot be reversed due to current technical limitations,
// may return &comment with a helpful message from their GetReverse() method.
//
// Chnagelog should skip it when applying operations or output as log message,
// and write it as an SQL comment when creating migration files.
type comment string
var _ Operation = (*comment)(nil)
func (c *comment) GetReverse() Operation { return c }

View file

@ -0,0 +1,75 @@
package sqlschema
import (
"fmt"
"github.com/uptrace/bun/schema"
)
type Column interface {
GetName() string
GetSQLType() string
GetVarcharLen() int
GetDefaultValue() string
GetIsNullable() bool
GetIsAutoIncrement() bool
GetIsIdentity() bool
AppendQuery(schema.Formatter, []byte) ([]byte, error)
}
var _ Column = (*BaseColumn)(nil)
// BaseColumn is a base column definition that stores various attributes of a column.
//
// Dialects and only dialects can use it to implement the Column interface.
// Other packages must use the Column interface.
type BaseColumn struct {
Name string
SQLType string
VarcharLen int
DefaultValue string
IsNullable bool
IsAutoIncrement bool
IsIdentity bool
// TODO: add Precision and Cardinality for timestamps/bit-strings/floats and arrays respectively.
}
func (cd BaseColumn) GetName() string {
return cd.Name
}
func (cd BaseColumn) GetSQLType() string {
return cd.SQLType
}
func (cd BaseColumn) GetVarcharLen() int {
return cd.VarcharLen
}
func (cd BaseColumn) GetDefaultValue() string {
return cd.DefaultValue
}
func (cd BaseColumn) GetIsNullable() bool {
return cd.IsNullable
}
func (cd BaseColumn) GetIsAutoIncrement() bool {
return cd.IsAutoIncrement
}
func (cd BaseColumn) GetIsIdentity() bool {
return cd.IsIdentity
}
// AppendQuery appends full SQL data type.
func (c *BaseColumn) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
b = append(b, c.SQLType...)
if c.VarcharLen == 0 {
return b, nil
}
b = append(b, "("...)
b = append(b, fmt.Sprint(c.VarcharLen)...)
b = append(b, ")"...)
return b, nil
}

View file

@ -0,0 +1,127 @@
package sqlschema
import (
"slices"
"strings"
"github.com/uptrace/bun/schema"
orderedmap "github.com/wk8/go-ordered-map/v2"
)
type Database interface {
GetTables() *orderedmap.OrderedMap[string, Table]
GetForeignKeys() map[ForeignKey]string
}
var _ Database = (*BaseDatabase)(nil)
// BaseDatabase is a base database definition.
//
// Dialects and only dialects can use it to implement the Database interface.
// Other packages must use the Database interface.
type BaseDatabase struct {
Tables *orderedmap.OrderedMap[string, Table]
ForeignKeys map[ForeignKey]string
}
func (ds BaseDatabase) GetTables() *orderedmap.OrderedMap[string, Table] {
return ds.Tables
}
func (ds BaseDatabase) GetForeignKeys() map[ForeignKey]string {
return ds.ForeignKeys
}
type ForeignKey struct {
From ColumnReference
To ColumnReference
}
func NewColumnReference(tableName string, columns ...string) ColumnReference {
return ColumnReference{
TableName: tableName,
Column: NewColumns(columns...),
}
}
func (fk ForeignKey) DependsOnTable(tableName string) bool {
return fk.From.TableName == tableName || fk.To.TableName == tableName
}
func (fk ForeignKey) DependsOnColumn(tableName string, column string) bool {
return fk.DependsOnTable(tableName) &&
(fk.From.Column.Contains(column) || fk.To.Column.Contains(column))
}
// Columns is a hashable representation of []string used to define schema constraints that depend on multiple columns.
// Although having duplicated column references in these constraints is illegal, Columns neither validates nor enforces this constraint on the caller.
type Columns string
// NewColumns creates a composite column from a slice of column names.
func NewColumns(columns ...string) Columns {
slices.Sort(columns)
return Columns(strings.Join(columns, ","))
}
func (c *Columns) String() string {
return string(*c)
}
func (c *Columns) AppendQuery(fmter schema.Formatter, b []byte) ([]byte, error) {
return schema.Safe(*c).AppendQuery(fmter, b)
}
// Split returns a slice of column names that make up the composite.
func (c *Columns) Split() []string {
return strings.Split(c.String(), ",")
}
// ContainsColumns checks that columns in "other" are a subset of current colums.
func (c *Columns) ContainsColumns(other Columns) bool {
columns := c.Split()
Outer:
for _, check := range other.Split() {
for _, column := range columns {
if check == column {
continue Outer
}
}
return false
}
return true
}
// Contains checks that a composite column contains the current column.
func (c *Columns) Contains(other string) bool {
return c.ContainsColumns(Columns(other))
}
// Replace renames a column if it is part of the composite.
// If a composite consists of multiple columns, only one column will be renamed.
func (c *Columns) Replace(oldColumn, newColumn string) bool {
columns := c.Split()
for i, column := range columns {
if column == oldColumn {
columns[i] = newColumn
*c = NewColumns(columns...)
return true
}
}
return false
}
// Unique represents a unique constraint defined on 1 or more columns.
type Unique struct {
Name string
Columns Columns
}
// Equals checks that two unique constraint are the same, assuming both are defined for the same table.
func (u Unique) Equals(other Unique) bool {
return u.Columns == other.Columns
}
type ColumnReference struct {
TableName string
Column Columns
}

View file

@ -0,0 +1,241 @@
package sqlschema
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/uptrace/bun"
"github.com/uptrace/bun/schema"
orderedmap "github.com/wk8/go-ordered-map/v2"
)
type InspectorDialect interface {
schema.Dialect
// Inspector returns a new instance of Inspector for the dialect.
// Dialects MAY set their default InspectorConfig values in constructor
// but MUST apply InspectorOptions to ensure they can be overriden.
//
// Use ApplyInspectorOptions to reduce boilerplate.
NewInspector(db *bun.DB, options ...InspectorOption) Inspector
// CompareType returns true if col1 and co2 SQL types are equivalent,
// i.e. they might use dialect-specifc type aliases (SERIAL ~ SMALLINT)
// or specify the same VARCHAR length differently (VARCHAR(255) ~ VARCHAR).
CompareType(Column, Column) bool
}
// InspectorConfig controls the scope of migration by limiting the objects Inspector should return.
// Inspectors SHOULD use the configuration directly instead of copying it, or MAY choose to embed it,
// to make sure options are always applied correctly.
type InspectorConfig struct {
// SchemaName limits inspection to tables in a particular schema.
SchemaName string
// ExcludeTables from inspection.
ExcludeTables []string
}
// Inspector reads schema state.
type Inspector interface {
Inspect(ctx context.Context) (Database, error)
}
func WithSchemaName(schemaName string) InspectorOption {
return func(cfg *InspectorConfig) {
cfg.SchemaName = schemaName
}
}
// WithExcludeTables works in append-only mode, i.e. tables cannot be re-included.
func WithExcludeTables(tables ...string) InspectorOption {
return func(cfg *InspectorConfig) {
cfg.ExcludeTables = append(cfg.ExcludeTables, tables...)
}
}
// NewInspector creates a new database inspector, if the dialect supports it.
func NewInspector(db *bun.DB, options ...InspectorOption) (Inspector, error) {
dialect, ok := (db.Dialect()).(InspectorDialect)
if !ok {
return nil, fmt.Errorf("%s does not implement sqlschema.Inspector", db.Dialect().Name())
}
return &inspector{
Inspector: dialect.NewInspector(db, options...),
}, nil
}
func NewBunModelInspector(tables *schema.Tables, options ...InspectorOption) *BunModelInspector {
bmi := &BunModelInspector{
tables: tables,
}
ApplyInspectorOptions(&bmi.InspectorConfig, options...)
return bmi
}
type InspectorOption func(*InspectorConfig)
func ApplyInspectorOptions(cfg *InspectorConfig, options ...InspectorOption) {
for _, opt := range options {
opt(cfg)
}
}
// inspector is opaque pointer to a database inspector.
type inspector struct {
Inspector
}
// BunModelInspector creates the current project state from the passed bun.Models.
// Do not recycle BunModelInspector for different sets of models, as older models will not be de-registerred before the next run.
type BunModelInspector struct {
InspectorConfig
tables *schema.Tables
}
var _ Inspector = (*BunModelInspector)(nil)
func (bmi *BunModelInspector) Inspect(ctx context.Context) (Database, error) {
state := BunModelSchema{
BaseDatabase: BaseDatabase{
ForeignKeys: make(map[ForeignKey]string),
},
Tables: orderedmap.New[string, Table](),
}
for _, t := range bmi.tables.All() {
if t.Schema != bmi.SchemaName {
continue
}
columns := orderedmap.New[string, Column]()
for _, f := range t.Fields {
sqlType, length, err := parseLen(f.CreateTableSQLType)
if err != nil {
return nil, fmt.Errorf("parse length in %q: %w", f.CreateTableSQLType, err)
}
columns.Set(f.Name, &BaseColumn{
Name: f.Name,
SQLType: strings.ToLower(sqlType), // TODO(dyma): maybe this is not necessary after Column.Eq()
VarcharLen: length,
DefaultValue: exprToLower(f.SQLDefault),
IsNullable: !f.NotNull,
IsAutoIncrement: f.AutoIncrement,
IsIdentity: f.Identity,
})
}
var unique []Unique
for name, group := range t.Unique {
// Create a separate unique index for single-column unique constraints
// let each dialect apply the default naming convention.
if name == "" {
for _, f := range group {
unique = append(unique, Unique{Columns: NewColumns(f.Name)})
}
continue
}
// Set the name if it is a "unique group", in which case the user has provided the name.
var columns []string
for _, f := range group {
columns = append(columns, f.Name)
}
unique = append(unique, Unique{Name: name, Columns: NewColumns(columns...)})
}
var pk *PrimaryKey
if len(t.PKs) > 0 {
var columns []string
for _, f := range t.PKs {
columns = append(columns, f.Name)
}
pk = &PrimaryKey{Columns: NewColumns(columns...)}
}
// In cases where a table is defined in a non-default schema in the `bun:table` tag,
// schema.Table only extracts the name of the schema, but passes the entire tag value to t.Name
// for backwads-compatibility. For example, a bun model like this:
// type Model struct { bun.BaseModel `bun:"table:favourite.books` }
// produces
// schema.Table{ Schema: "favourite", Name: "favourite.books" }
tableName := strings.TrimPrefix(t.Name, t.Schema+".")
state.Tables.Set(tableName, &BunTable{
BaseTable: BaseTable{
Schema: t.Schema,
Name: tableName,
Columns: columns,
UniqueConstraints: unique,
PrimaryKey: pk,
},
Model: t.ZeroIface,
})
for _, rel := range t.Relations {
// These relations are nominal and do not need a foreign key to be declared in the current table.
// They will be either expressed as N:1 relations in an m2m mapping table, or will be referenced by the other table if it's a 1:N.
if rel.Type == schema.ManyToManyRelation ||
rel.Type == schema.HasManyRelation {
continue
}
var fromCols, toCols []string
for _, f := range rel.BasePKs {
fromCols = append(fromCols, f.Name)
}
for _, f := range rel.JoinPKs {
toCols = append(toCols, f.Name)
}
target := rel.JoinTable
state.ForeignKeys[ForeignKey{
From: NewColumnReference(t.Name, fromCols...),
To: NewColumnReference(target.Name, toCols...),
}] = ""
}
}
return state, nil
}
func parseLen(typ string) (string, int, error) {
paren := strings.Index(typ, "(")
if paren == -1 {
return typ, 0, nil
}
length, err := strconv.Atoi(typ[paren+1 : len(typ)-1])
if err != nil {
return typ, 0, err
}
return typ[:paren], length, nil
}
// exprToLower converts string to lowercase, if it does not contain a string literal 'lit'.
// Use it to ensure that user-defined default values in the models are always comparable
// to those returned by the database inspector, regardless of the case convention in individual drivers.
func exprToLower(s string) string {
if strings.HasPrefix(s, "'") && strings.HasSuffix(s, "'") {
return s
}
return strings.ToLower(s)
}
// BunModelSchema is the schema state derived from bun table models.
type BunModelSchema struct {
BaseDatabase
Tables *orderedmap.OrderedMap[string, Table]
}
func (ms BunModelSchema) GetTables() *orderedmap.OrderedMap[string, Table] {
return ms.Tables
}
// BunTable provides additional table metadata that is only accessible from scanning bun models.
type BunTable struct {
BaseTable
// Model stores the zero interface to the underlying Go struct.
Model interface{}
}

View file

@ -0,0 +1,49 @@
package sqlschema
import (
"fmt"
"github.com/uptrace/bun"
"github.com/uptrace/bun/schema"
)
type MigratorDialect interface {
schema.Dialect
NewMigrator(db *bun.DB, schemaName string) Migrator
}
type Migrator interface {
AppendSQL(b []byte, operation interface{}) ([]byte, error)
}
// migrator is a dialect-agnostic wrapper for sqlschema.MigratorDialect.
type migrator struct {
Migrator
}
func NewMigrator(db *bun.DB, schemaName string) (Migrator, error) {
md, ok := db.Dialect().(MigratorDialect)
if !ok {
return nil, fmt.Errorf("%q dialect does not implement sqlschema.Migrator", db.Dialect().Name())
}
return &migrator{
Migrator: md.NewMigrator(db, schemaName),
}, nil
}
// BaseMigrator can be embeded by dialect's Migrator implementations to re-use some of the existing bun queries.
type BaseMigrator struct {
db *bun.DB
}
func NewBaseMigrator(db *bun.DB) *BaseMigrator {
return &BaseMigrator{db: db}
}
func (m *BaseMigrator) AppendCreateTable(b []byte, model interface{}) ([]byte, error) {
return m.db.NewCreateTable().Model(model).AppendQuery(m.db.Formatter(), b)
}
func (m *BaseMigrator) AppendDropTable(b []byte, schemaName, tableName string) ([]byte, error) {
return m.db.NewDropTable().TableExpr("?.?", bun.Ident(schemaName), bun.Ident(tableName)).AppendQuery(m.db.Formatter(), b)
}

View file

@ -0,0 +1,60 @@
package sqlschema
import (
orderedmap "github.com/wk8/go-ordered-map/v2"
)
type Table interface {
GetSchema() string
GetName() string
GetColumns() *orderedmap.OrderedMap[string, Column]
GetPrimaryKey() *PrimaryKey
GetUniqueConstraints() []Unique
}
var _ Table = (*BaseTable)(nil)
// BaseTable is a base table definition.
//
// Dialects and only dialects can use it to implement the Table interface.
// Other packages must use the Table interface.
type BaseTable struct {
Schema string
Name string
// ColumnDefinitions map each column name to the column definition.
Columns *orderedmap.OrderedMap[string, Column]
// PrimaryKey holds the primary key definition.
// A nil value means that no primary key is defined for the table.
PrimaryKey *PrimaryKey
// UniqueConstraints defined on the table.
UniqueConstraints []Unique
}
// PrimaryKey represents a primary key constraint defined on 1 or more columns.
type PrimaryKey struct {
Name string
Columns Columns
}
func (td *BaseTable) GetSchema() string {
return td.Schema
}
func (td *BaseTable) GetName() string {
return td.Name
}
func (td *BaseTable) GetColumns() *orderedmap.OrderedMap[string, Column] {
return td.Columns
}
func (td *BaseTable) GetPrimaryKey() *PrimaryKey {
return td.PrimaryKey
}
func (td *BaseTable) GetUniqueConstraints() []Unique {
return td.UniqueConstraints
}

View file

@ -51,7 +51,7 @@ func (m *hasManyModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error
dest := makeDest(m, len(columns))
var n int
m.structKey = make([]interface{}, len(m.rel.JoinPKs))
for rows.Next() {
if m.sliceOfPtr {
m.strct = reflect.New(m.table.Type).Elem()
@ -59,9 +59,8 @@ func (m *hasManyModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, error
m.strct.Set(m.table.ZeroValue)
}
m.structInited = false
m.scanIndex = 0
m.structKey = m.structKey[:0]
if err := rows.Scan(dest...); err != nil {
return 0, err
}
@ -92,9 +91,9 @@ func (m *hasManyModel) Scan(src interface{}) error {
return err
}
for _, f := range m.rel.JoinPKs {
if f.Name == field.Name {
m.structKey = append(m.structKey, indirectFieldValue(field.Value(m.strct)))
for i, f := range m.rel.JoinPKs {
if f.Name == column {
m.structKey[i] = indirectFieldValue(field.Value(m.strct))
break
}
}

View file

@ -1,6 +1,6 @@
{
"name": "gobun",
"version": "1.2.5",
"version": "1.2.6",
"main": "index.js",
"repository": "git@github.com:uptrace/bun.git",
"author": "Vladimir Mihailenco <vladimir.webdev@gmail.com>",

View file

@ -6,6 +6,8 @@ import (
"database/sql/driver"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/uptrace/bun/dialect"
@ -1352,3 +1354,113 @@ func (ih *idxHintsQuery) bufIndexHint(
b = append(b, ")"...)
return b, nil
}
//------------------------------------------------------------------------------
type orderLimitOffsetQuery struct {
order []schema.QueryWithArgs
limit int32
offset int32
}
func (q *orderLimitOffsetQuery) addOrder(orders ...string) {
for _, order := range orders {
if order == "" {
continue
}
index := strings.IndexByte(order, ' ')
if index == -1 {
q.order = append(q.order, schema.UnsafeIdent(order))
continue
}
field := order[:index]
sort := order[index+1:]
switch strings.ToUpper(sort) {
case "ASC", "DESC", "ASC NULLS FIRST", "DESC NULLS FIRST",
"ASC NULLS LAST", "DESC NULLS LAST":
q.order = append(q.order, schema.SafeQuery("? ?", []interface{}{
Ident(field),
Safe(sort),
}))
default:
q.order = append(q.order, schema.UnsafeIdent(order))
}
}
}
func (q *orderLimitOffsetQuery) addOrderExpr(query string, args ...interface{}) {
q.order = append(q.order, schema.SafeQuery(query, args))
}
func (q *orderLimitOffsetQuery) appendOrder(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if len(q.order) > 0 {
b = append(b, " ORDER BY "...)
for i, f := range q.order {
if i > 0 {
b = append(b, ", "...)
}
b, err = f.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
// MSSQL: allows Limit() without Order() as per https://stackoverflow.com/a/36156953
if q.limit > 0 && fmter.Dialect().Name() == dialect.MSSQL {
return append(b, " ORDER BY _temp_sort"...), nil
}
return b, nil
}
func (q *orderLimitOffsetQuery) setLimit(n int) {
q.limit = int32(n)
}
func (q *orderLimitOffsetQuery) setOffset(n int) {
q.offset = int32(n)
}
func (q *orderLimitOffsetQuery) appendLimitOffset(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if fmter.Dialect().Features().Has(feature.OffsetFetch) {
if q.limit > 0 && q.offset > 0 {
b = append(b, " OFFSET "...)
b = strconv.AppendInt(b, int64(q.offset), 10)
b = append(b, " ROWS"...)
b = append(b, " FETCH NEXT "...)
b = strconv.AppendInt(b, int64(q.limit), 10)
b = append(b, " ROWS ONLY"...)
} else if q.limit > 0 {
b = append(b, " OFFSET 0 ROWS"...)
b = append(b, " FETCH NEXT "...)
b = strconv.AppendInt(b, int64(q.limit), 10)
b = append(b, " ROWS ONLY"...)
} else if q.offset > 0 {
b = append(b, " OFFSET "...)
b = strconv.AppendInt(b, int64(q.offset), 10)
b = append(b, " ROWS"...)
}
} else {
if q.limit > 0 {
b = append(b, " LIMIT "...)
b = strconv.AppendInt(b, int64(q.limit), 10)
}
if q.offset > 0 {
b = append(b, " OFFSET "...)
b = strconv.AppendInt(b, int64(q.offset), 10)
}
}
return b, nil
}

View file

@ -42,9 +42,12 @@ func (q *AddColumnQuery) Err(err error) *AddColumnQuery {
return q
}
func (q *AddColumnQuery) Apply(fn func(*AddColumnQuery) *AddColumnQuery) *AddColumnQuery {
if fn != nil {
return fn(q)
// Apply calls each function in fns, passing the AddColumnQuery as an argument.
func (q *AddColumnQuery) Apply(fns ...func(*AddColumnQuery) *AddColumnQuery) *AddColumnQuery {
for _, fn := range fns {
if fn != nil {
q = fn(q)
}
}
return q
}

View file

@ -40,9 +40,12 @@ func (q *DropColumnQuery) Err(err error) *DropColumnQuery {
return q
}
func (q *DropColumnQuery) Apply(fn func(*DropColumnQuery) *DropColumnQuery) *DropColumnQuery {
if fn != nil {
return fn(q)
// Apply calls each function in fns, passing the DropColumnQuery as an argument.
func (q *DropColumnQuery) Apply(fns ...func(*DropColumnQuery) *DropColumnQuery) *DropColumnQuery {
for _, fn := range fns {
if fn != nil {
q = fn(q)
}
}
return q
}

View file

@ -3,6 +3,7 @@ package bun
import (
"context"
"database/sql"
"errors"
"time"
"github.com/uptrace/bun/dialect/feature"
@ -12,6 +13,7 @@ import (
type DeleteQuery struct {
whereBaseQuery
orderLimitOffsetQuery
returningQuery
}
@ -44,10 +46,12 @@ func (q *DeleteQuery) Err(err error) *DeleteQuery {
return q
}
// Apply calls the fn passing the DeleteQuery as an argument.
func (q *DeleteQuery) Apply(fn func(*DeleteQuery) *DeleteQuery) *DeleteQuery {
if fn != nil {
return fn(q)
// Apply calls each function in fns, passing the DeleteQuery as an argument.
func (q *DeleteQuery) Apply(fns ...func(*DeleteQuery) *DeleteQuery) *DeleteQuery {
for _, fn := range fns {
if fn != nil {
q = fn(q)
}
}
return q
}
@ -120,17 +124,50 @@ func (q *DeleteQuery) WhereAllWithDeleted() *DeleteQuery {
return q
}
func (q *DeleteQuery) Order(orders ...string) *DeleteQuery {
if !q.hasFeature(feature.DeleteOrderLimit) {
q.err = errors.New("bun: order is not supported for current dialect")
return q
}
q.addOrder(orders...)
return q
}
func (q *DeleteQuery) OrderExpr(query string, args ...interface{}) *DeleteQuery {
if !q.hasFeature(feature.DeleteOrderLimit) {
q.err = errors.New("bun: order is not supported for current dialect")
return q
}
q.addOrderExpr(query, args...)
return q
}
func (q *DeleteQuery) ForceDelete() *DeleteQuery {
q.flags = q.flags.Set(forceDeleteFlag)
return q
}
// ------------------------------------------------------------------------------
func (q *DeleteQuery) Limit(n int) *DeleteQuery {
if !q.hasFeature(feature.DeleteOrderLimit) {
q.err = errors.New("bun: limit is not supported for current dialect")
return q
}
q.setLimit(n)
return q
}
//------------------------------------------------------------------------------
// Returning adds a RETURNING clause to the query.
//
// To suppress the auto-generated RETURNING clause, use `Returning("NULL")`.
func (q *DeleteQuery) Returning(query string, args ...interface{}) *DeleteQuery {
if !q.hasFeature(feature.DeleteReturning) {
q.err = errors.New("bun: returning is not supported for current dialect")
return q
}
q.addReturning(schema.SafeQuery(query, args))
return q
}
@ -203,7 +240,21 @@ func (q *DeleteQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
return nil, err
}
if q.hasFeature(feature.Returning) && q.hasReturning() {
if q.hasMultiTables() && (len(q.order) > 0 || q.limit > 0) {
return nil, errors.New("bun: can't use ORDER or LIMIT with multiple tables")
}
b, err = q.appendOrder(fmter, b)
if err != nil {
return nil, err
}
b, err = q.appendLimitOffset(fmter, b)
if err != nil {
return nil, err
}
if q.hasFeature(feature.DeleteReturning) && q.hasReturning() {
b = append(b, " RETURNING "...)
b, err = q.appendReturning(fmter, b)
if err != nil {
@ -265,7 +316,7 @@ func (q *DeleteQuery) scanOrExec(
return nil, err
}
useScan := hasDest || (q.hasReturning() && q.hasFeature(feature.Returning|feature.Output))
useScan := hasDest || (q.hasReturning() && q.hasFeature(feature.DeleteReturning|feature.Output))
var model Model
if useScan {

View file

@ -53,10 +53,12 @@ func (q *InsertQuery) Err(err error) *InsertQuery {
return q
}
// Apply calls the fn passing the SelectQuery as an argument.
func (q *InsertQuery) Apply(fn func(*InsertQuery) *InsertQuery) *InsertQuery {
if fn != nil {
return fn(q)
// Apply calls each function in fns, passing the InsertQuery as an argument.
func (q *InsertQuery) Apply(fns ...func(*InsertQuery) *InsertQuery) *InsertQuery {
for _, fn := range fns {
if fn != nil {
q = fn(q)
}
}
return q
}

View file

@ -50,10 +50,12 @@ func (q *MergeQuery) Err(err error) *MergeQuery {
return q
}
// Apply calls the fn passing the MergeQuery as an argument.
func (q *MergeQuery) Apply(fn func(*MergeQuery) *MergeQuery) *MergeQuery {
if fn != nil {
return fn(q)
// Apply calls each function in fns, passing the MergeQuery as an argument.
func (q *MergeQuery) Apply(fns ...func(*MergeQuery) *MergeQuery) *MergeQuery {
for _, fn := range fns {
if fn != nil {
q = fn(q)
}
}
return q
}

View file

@ -96,3 +96,12 @@ func (q *RawQuery) AppendQuery(fmter schema.Formatter, b []byte) ([]byte, error)
func (q *RawQuery) Operation() string {
return "SELECT"
}
func (q *RawQuery) String() string {
buf, err := q.AppendQuery(q.db.Formatter(), nil)
if err != nil {
panic(err)
}
return string(buf)
}

View file

@ -6,8 +6,6 @@ import (
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"github.com/uptrace/bun/dialect"
@ -25,14 +23,12 @@ type union struct {
type SelectQuery struct {
whereBaseQuery
idxHintsQuery
orderLimitOffsetQuery
distinctOn []schema.QueryWithArgs
joins []joinQuery
group []schema.QueryWithArgs
having []schema.QueryWithArgs
order []schema.QueryWithArgs
limit int32
offset int32
selFor schema.QueryWithArgs
union []union
@ -66,10 +62,12 @@ func (q *SelectQuery) Err(err error) *SelectQuery {
return q
}
// Apply calls the fn passing the SelectQuery as an argument.
func (q *SelectQuery) Apply(fn func(*SelectQuery) *SelectQuery) *SelectQuery {
if fn != nil {
return fn(q)
// Apply calls each function in fns, passing the SelectQuery as an argument.
func (q *SelectQuery) Apply(fns ...func(*SelectQuery) *SelectQuery) *SelectQuery {
for _, fn := range fns {
if fn != nil {
q = fn(q)
}
}
return q
}
@ -279,46 +277,22 @@ func (q *SelectQuery) Having(having string, args ...interface{}) *SelectQuery {
}
func (q *SelectQuery) Order(orders ...string) *SelectQuery {
for _, order := range orders {
if order == "" {
continue
}
index := strings.IndexByte(order, ' ')
if index == -1 {
q.order = append(q.order, schema.UnsafeIdent(order))
continue
}
field := order[:index]
sort := order[index+1:]
switch strings.ToUpper(sort) {
case "ASC", "DESC", "ASC NULLS FIRST", "DESC NULLS FIRST",
"ASC NULLS LAST", "DESC NULLS LAST":
q.order = append(q.order, schema.SafeQuery("? ?", []interface{}{
Ident(field),
Safe(sort),
}))
default:
q.order = append(q.order, schema.UnsafeIdent(order))
}
}
q.addOrder(orders...)
return q
}
func (q *SelectQuery) OrderExpr(query string, args ...interface{}) *SelectQuery {
q.order = append(q.order, schema.SafeQuery(query, args))
q.addOrderExpr(query, args...)
return q
}
func (q *SelectQuery) Limit(n int) *SelectQuery {
q.limit = int32(n)
q.setLimit(n)
return q
}
func (q *SelectQuery) Offset(n int) *SelectQuery {
q.offset = int32(n)
q.setOffset(n)
return q
}
@ -615,35 +589,9 @@ func (q *SelectQuery) appendQuery(
return nil, err
}
if fmter.Dialect().Features().Has(feature.OffsetFetch) {
if q.limit > 0 && q.offset > 0 {
b = append(b, " OFFSET "...)
b = strconv.AppendInt(b, int64(q.offset), 10)
b = append(b, " ROWS"...)
b = append(b, " FETCH NEXT "...)
b = strconv.AppendInt(b, int64(q.limit), 10)
b = append(b, " ROWS ONLY"...)
} else if q.limit > 0 {
b = append(b, " OFFSET 0 ROWS"...)
b = append(b, " FETCH NEXT "...)
b = strconv.AppendInt(b, int64(q.limit), 10)
b = append(b, " ROWS ONLY"...)
} else if q.offset > 0 {
b = append(b, " OFFSET "...)
b = strconv.AppendInt(b, int64(q.offset), 10)
b = append(b, " ROWS"...)
}
} else {
if q.limit > 0 {
b = append(b, " LIMIT "...)
b = strconv.AppendInt(b, int64(q.limit), 10)
}
if q.offset > 0 {
b = append(b, " OFFSET "...)
b = strconv.AppendInt(b, int64(q.offset), 10)
}
b, err = q.appendLimitOffset(fmter, b)
if err != nil {
return nil, err
}
if !q.selFor.IsZero() {
@ -782,31 +730,6 @@ func (q *SelectQuery) appendTables(fmter schema.Formatter, b []byte) (_ []byte,
return q.appendTablesWithAlias(fmter, b)
}
func (q *SelectQuery) appendOrder(fmter schema.Formatter, b []byte) (_ []byte, err error) {
if len(q.order) > 0 {
b = append(b, " ORDER BY "...)
for i, f := range q.order {
if i > 0 {
b = append(b, ", "...)
}
b, err = f.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
return b, nil
}
// MSSQL: allows Limit() without Order() as per https://stackoverflow.com/a/36156953
if q.limit > 0 && fmter.Dialect().Name() == dialect.MSSQL {
return append(b, " ORDER BY _temp_sort"...), nil
}
return b, nil
}
//------------------------------------------------------------------------------
func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) {

View file

@ -151,3 +151,12 @@ func (q *DropTableQuery) afterDropTableHook(ctx context.Context) error {
}
return nil
}
func (q *DropTableQuery) String() string {
buf, err := q.AppendQuery(q.db.Formatter(), nil)
if err != nil {
panic(err)
}
return string(buf)
}

View file

@ -15,6 +15,7 @@ import (
type UpdateQuery struct {
whereBaseQuery
orderLimitOffsetQuery
returningQuery
customValueQuery
setQuery
@ -53,10 +54,12 @@ func (q *UpdateQuery) Err(err error) *UpdateQuery {
return q
}
// Apply calls the fn passing the SelectQuery as an argument.
func (q *UpdateQuery) Apply(fn func(*UpdateQuery) *UpdateQuery) *UpdateQuery {
if fn != nil {
return fn(q)
// Apply calls each function in fns, passing the UpdateQuery as an argument.
func (q *UpdateQuery) Apply(fns ...func(*UpdateQuery) *UpdateQuery) *UpdateQuery {
for _, fn := range fns {
if fn != nil {
q = fn(q)
}
}
return q
}
@ -200,6 +203,34 @@ func (q *UpdateQuery) WhereAllWithDeleted() *UpdateQuery {
return q
}
// ------------------------------------------------------------------------------
func (q *UpdateQuery) Order(orders ...string) *UpdateQuery {
if !q.hasFeature(feature.UpdateOrderLimit) {
q.err = errors.New("bun: order is not supported for current dialect")
return q
}
q.addOrder(orders...)
return q
}
func (q *UpdateQuery) OrderExpr(query string, args ...interface{}) *UpdateQuery {
if !q.hasFeature(feature.UpdateOrderLimit) {
q.err = errors.New("bun: order is not supported for current dialect")
return q
}
q.addOrderExpr(query, args...)
return q
}
func (q *UpdateQuery) Limit(n int) *UpdateQuery {
if !q.hasFeature(feature.UpdateOrderLimit) {
q.err = errors.New("bun: limit is not supported for current dialect")
return q
}
q.setLimit(n)
return q
}
//------------------------------------------------------------------------------
// Returning adds a RETURNING clause to the query.
@ -278,6 +309,16 @@ func (q *UpdateQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
return nil, err
}
b, err = q.appendOrder(fmter, b)
if err != nil {
return nil, err
}
b, err = q.appendLimitOffset(fmter, b)
if err != nil {
return nil, err
}
if q.hasFeature(feature.Returning) && q.hasReturning() {
b = append(b, " RETURNING "...)
b, err = q.appendReturning(fmter, b)

View file

@ -39,6 +39,9 @@ type Dialect interface {
// is mandatory in queries that modify the schema (CREATE TABLE / ADD COLUMN, etc).
// Dialects that do not have such requirement may return 0, which should be interpreted so by the caller.
DefaultVarcharLen() int
// DefaultSchema should returns the name of the default database schema.
DefaultSchema() string
}
// ------------------------------------------------------------------------------
@ -185,3 +188,7 @@ func (d *nopDialect) DefaultVarcharLen() int {
func (d *nopDialect) AppendSequence(b []byte, _ *Table, _ *Field) []byte {
return b
}
func (d *nopDialect) DefaultSchema() string {
return "nop"
}

View file

@ -45,6 +45,7 @@ type Table struct {
TypeName string
ModelName string
Schema string
Name string
SQLName Safe
SQLNameForSelects Safe
@ -85,6 +86,7 @@ func (table *Table) init(dialect Dialect, typ reflect.Type, canAddr bool) {
table.setName(tableName)
table.Alias = table.ModelName
table.SQLAlias = table.quoteIdent(table.ModelName)
table.Schema = dialect.DefaultSchema()
table.Fields = make([]*Field, 0, typ.NumField())
table.FieldMap = make(map[string]*Field, typ.NumField())
@ -244,6 +246,31 @@ func (t *Table) processFields(typ reflect.Type, canAddr bool) {
subfield.SQLName = t.quoteIdent(subfield.Name)
}
t.addField(subfield)
if v, ok := subfield.Tag.Options["unique"]; ok {
t.addUnique(subfield, embfield.prefix, v)
}
}
}
func (t *Table) addUnique(field *Field, prefix string, tagOptions []string) {
var names []string
if len(tagOptions) == 1 {
// Split the value by comma, this will allow multiple names to be specified.
// We can use this to create multiple named unique constraints where a single column
// might be included in multiple constraints.
names = strings.Split(tagOptions[0], ",")
} else {
names = tagOptions
}
for _, uname := range names {
if t.Unique == nil {
t.Unique = make(map[string][]*Field)
}
if uname != "" && prefix != "" {
uname = prefix + uname
}
t.Unique[uname] = append(t.Unique[uname], field)
}
}
@ -371,10 +398,18 @@ func (t *Table) processBaseModelField(f reflect.StructField) {
}
if tag.Name != "" {
schema, _ := t.schemaFromTagName(tag.Name)
t.Schema = schema
// Eventually, we should only assign the "table" portion as the table name,
// which will also require a change in how the table name is appended to queries.
// Until that is done, set table name to tag.Name.
t.setName(tag.Name)
}
if s, ok := tag.Option("table"); ok {
schema, _ := t.schemaFromTagName(s)
t.Schema = schema
t.setName(s)
}
@ -388,6 +423,17 @@ func (t *Table) processBaseModelField(f reflect.StructField) {
}
}
// schemaFromTagName splits the bun.BaseModel tag name into schema and table name
// in case it is specified in the "schema"."table" format.
// Assume default schema if one isn't explicitly specified.
func (t *Table) schemaFromTagName(name string) (string, string) {
schema, table := t.dialect.DefaultSchema(), name
if schemaTable := strings.Split(name, "."); len(schemaTable) == 2 {
schema, table = schemaTable[0], schemaTable[1]
}
return schema, table
}
// nolint
func (t *Table) newField(sf reflect.StructField, tag tagparser.Tag) *Field {
sqlName := internal.Underscore(sf.Name)
@ -439,22 +485,7 @@ func (t *Table) newField(sf reflect.StructField, tag tagparser.Tag) *Field {
}
if v, ok := tag.Options["unique"]; ok {
var names []string
if len(v) == 1 {
// Split the value by comma, this will allow multiple names to be specified.
// We can use this to create multiple named unique constraints where a single column
// might be included in multiple constraints.
names = strings.Split(v[0], ",")
} else {
names = v
}
for _, uniqueName := range names {
if t.Unique == nil {
t.Unique = make(map[string][]*Field)
}
t.Unique[uniqueName] = append(t.Unique[uniqueName], field)
}
t.addUnique(field, "", v)
}
if s, ok := tag.Option("default"); ok {
field.SQLDefault = s

View file

@ -77,6 +77,7 @@ func (t *Tables) InProgress(typ reflect.Type) *Table {
return table
}
// ByModel gets the table by its Go name.
func (t *Tables) ByModel(name string) *Table {
var found *Table
t.tables.Range(func(typ reflect.Type, table *Table) bool {
@ -89,6 +90,7 @@ func (t *Tables) ByModel(name string) *Table {
return found
}
// ByName gets the table by its SQL name.
func (t *Tables) ByName(name string) *Table {
var found *Table
t.tables.Range(func(typ reflect.Type, table *Table) bool {
@ -100,3 +102,13 @@ func (t *Tables) ByName(name string) *Table {
})
return found
}
// All returns all registered tables.
func (t *Tables) All() []*Table {
var found []*Table
t.tables.Range(func(typ reflect.Type, table *Table) bool {
found = append(found, table)
return true
})
return found
}

View file

@ -2,5 +2,5 @@ package bun
// Version is the current release version.
func Version() string {
return "1.2.5"
return "1.2.6"
}