bumps uptrace/bun dependencies to v1.2.6 (#3569)

This commit is contained in:
kim 2024-11-25 15:42:37 +00:00 committed by GitHub
commit 3fceb5fc1a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
68 changed files with 6517 additions and 194 deletions

View file

@ -25,24 +25,24 @@ func AppendBool(b []byte, v bool) []byte {
return append(b, "FALSE"...)
}
func AppendFloat32(b []byte, v float32) []byte {
return appendFloat(b, float64(v), 32)
func AppendFloat32(b []byte, num float32) []byte {
return appendFloat(b, float64(num), 32)
}
func AppendFloat64(b []byte, v float64) []byte {
return appendFloat(b, v, 64)
func AppendFloat64(b []byte, num float64) []byte {
return appendFloat(b, num, 64)
}
func appendFloat(b []byte, v float64, bitSize int) []byte {
func appendFloat(b []byte, num float64, bitSize int) []byte {
switch {
case math.IsNaN(v):
case math.IsNaN(num):
return append(b, "'NaN'"...)
case math.IsInf(v, 1):
case math.IsInf(num, 1):
return append(b, "'Infinity'"...)
case math.IsInf(v, -1):
case math.IsInf(num, -1):
return append(b, "'-Infinity'"...)
default:
return strconv.AppendFloat(b, v, 'f', -1, bitSize)
return strconv.AppendFloat(b, num, 'f', -1, bitSize)
}
}

View file

@ -31,5 +31,8 @@ const (
UpdateFromTable
MSSavepoint
GeneratedIdentity
CompositeIn // ... WHERE (A,B) IN ((N, NN), (N, NN)...)
CompositeIn // ... WHERE (A,B) IN ((N, NN), (N, NN)...)
UpdateOrderLimit // UPDATE ... ORDER BY ... LIMIT ...
DeleteOrderLimit // DELETE ... ORDER BY ... LIMIT ...
DeleteReturning
)

View file

@ -0,0 +1,245 @@
package pgdialect
import (
"fmt"
"strings"
"github.com/uptrace/bun"
"github.com/uptrace/bun/migrate"
"github.com/uptrace/bun/migrate/sqlschema"
"github.com/uptrace/bun/schema"
)
func (d *Dialect) NewMigrator(db *bun.DB, schemaName string) sqlschema.Migrator {
return &migrator{db: db, schemaName: schemaName, BaseMigrator: sqlschema.NewBaseMigrator(db)}
}
type migrator struct {
*sqlschema.BaseMigrator
db *bun.DB
schemaName string
}
var _ sqlschema.Migrator = (*migrator)(nil)
func (m *migrator) AppendSQL(b []byte, operation interface{}) (_ []byte, err error) {
fmter := m.db.Formatter()
// Append ALTER TABLE statement to the enclosed query bytes []byte.
appendAlterTable := func(query []byte, tableName string) []byte {
query = append(query, "ALTER TABLE "...)
query = m.appendFQN(fmter, query, tableName)
return append(query, " "...)
}
switch change := operation.(type) {
case *migrate.CreateTableOp:
return m.AppendCreateTable(b, change.Model)
case *migrate.DropTableOp:
return m.AppendDropTable(b, m.schemaName, change.TableName)
case *migrate.RenameTableOp:
b, err = m.renameTable(fmter, appendAlterTable(b, change.TableName), change)
case *migrate.RenameColumnOp:
b, err = m.renameColumn(fmter, appendAlterTable(b, change.TableName), change)
case *migrate.AddColumnOp:
b, err = m.addColumn(fmter, appendAlterTable(b, change.TableName), change)
case *migrate.DropColumnOp:
b, err = m.dropColumn(fmter, appendAlterTable(b, change.TableName), change)
case *migrate.AddPrimaryKeyOp:
b, err = m.addPrimaryKey(fmter, appendAlterTable(b, change.TableName), change.PrimaryKey)
case *migrate.ChangePrimaryKeyOp:
b, err = m.changePrimaryKey(fmter, appendAlterTable(b, change.TableName), change)
case *migrate.DropPrimaryKeyOp:
b, err = m.dropConstraint(fmter, appendAlterTable(b, change.TableName), change.PrimaryKey.Name)
case *migrate.AddUniqueConstraintOp:
b, err = m.addUnique(fmter, appendAlterTable(b, change.TableName), change)
case *migrate.DropUniqueConstraintOp:
b, err = m.dropConstraint(fmter, appendAlterTable(b, change.TableName), change.Unique.Name)
case *migrate.ChangeColumnTypeOp:
b, err = m.changeColumnType(fmter, appendAlterTable(b, change.TableName), change)
case *migrate.AddForeignKeyOp:
b, err = m.addForeignKey(fmter, appendAlterTable(b, change.TableName()), change)
case *migrate.DropForeignKeyOp:
b, err = m.dropConstraint(fmter, appendAlterTable(b, change.TableName()), change.ConstraintName)
default:
return nil, fmt.Errorf("append sql: unknown operation %T", change)
}
if err != nil {
return nil, fmt.Errorf("append sql: %w", err)
}
return b, nil
}
func (m *migrator) appendFQN(fmter schema.Formatter, b []byte, tableName string) []byte {
return fmter.AppendQuery(b, "?.?", bun.Ident(m.schemaName), bun.Ident(tableName))
}
func (m *migrator) renameTable(fmter schema.Formatter, b []byte, rename *migrate.RenameTableOp) (_ []byte, err error) {
b = append(b, "RENAME TO "...)
b = fmter.AppendName(b, rename.NewName)
return b, nil
}
func (m *migrator) renameColumn(fmter schema.Formatter, b []byte, rename *migrate.RenameColumnOp) (_ []byte, err error) {
b = append(b, "RENAME COLUMN "...)
b = fmter.AppendName(b, rename.OldName)
b = append(b, " TO "...)
b = fmter.AppendName(b, rename.NewName)
return b, nil
}
func (m *migrator) addColumn(fmter schema.Formatter, b []byte, add *migrate.AddColumnOp) (_ []byte, err error) {
b = append(b, "ADD COLUMN "...)
b = fmter.AppendName(b, add.ColumnName)
b = append(b, " "...)
b, err = add.Column.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
if add.Column.GetDefaultValue() != "" {
b = append(b, " DEFAULT "...)
b = append(b, add.Column.GetDefaultValue()...)
b = append(b, " "...)
}
if add.Column.GetIsIdentity() {
b = appendGeneratedAsIdentity(b)
}
return b, nil
}
func (m *migrator) dropColumn(fmter schema.Formatter, b []byte, drop *migrate.DropColumnOp) (_ []byte, err error) {
b = append(b, "DROP COLUMN "...)
b = fmter.AppendName(b, drop.ColumnName)
return b, nil
}
func (m *migrator) addPrimaryKey(fmter schema.Formatter, b []byte, pk sqlschema.PrimaryKey) (_ []byte, err error) {
b = append(b, "ADD PRIMARY KEY ("...)
b, _ = pk.Columns.AppendQuery(fmter, b)
b = append(b, ")"...)
return b, nil
}
func (m *migrator) changePrimaryKey(fmter schema.Formatter, b []byte, change *migrate.ChangePrimaryKeyOp) (_ []byte, err error) {
b, _ = m.dropConstraint(fmter, b, change.Old.Name)
b = append(b, ", "...)
b, _ = m.addPrimaryKey(fmter, b, change.New)
return b, nil
}
func (m *migrator) addUnique(fmter schema.Formatter, b []byte, change *migrate.AddUniqueConstraintOp) (_ []byte, err error) {
b = append(b, "ADD CONSTRAINT "...)
if change.Unique.Name != "" {
b = fmter.AppendName(b, change.Unique.Name)
} else {
// Default naming scheme for unique constraints in Postgres is <table>_<column>_key
b = fmter.AppendName(b, fmt.Sprintf("%s_%s_key", change.TableName, change.Unique.Columns))
}
b = append(b, " UNIQUE ("...)
b, _ = change.Unique.Columns.AppendQuery(fmter, b)
b = append(b, ")"...)
return b, nil
}
func (m *migrator) dropConstraint(fmter schema.Formatter, b []byte, name string) (_ []byte, err error) {
b = append(b, "DROP CONSTRAINT "...)
b = fmter.AppendName(b, name)
return b, nil
}
func (m *migrator) addForeignKey(fmter schema.Formatter, b []byte, add *migrate.AddForeignKeyOp) (_ []byte, err error) {
b = append(b, "ADD CONSTRAINT "...)
name := add.ConstraintName
if name == "" {
colRef := add.ForeignKey.From
columns := strings.Join(colRef.Column.Split(), "_")
name = fmt.Sprintf("%s_%s_fkey", colRef.TableName, columns)
}
b = fmter.AppendName(b, name)
b = append(b, " FOREIGN KEY ("...)
if b, err = add.ForeignKey.From.Column.AppendQuery(fmter, b); err != nil {
return b, err
}
b = append(b, ")"...)
b = append(b, " REFERENCES "...)
b = m.appendFQN(fmter, b, add.ForeignKey.To.TableName)
b = append(b, " ("...)
if b, err = add.ForeignKey.To.Column.AppendQuery(fmter, b); err != nil {
return b, err
}
b = append(b, ")"...)
return b, nil
}
func (m *migrator) changeColumnType(fmter schema.Formatter, b []byte, colDef *migrate.ChangeColumnTypeOp) (_ []byte, err error) {
// alterColumn never re-assigns err, so there is no need to check for err != nil after calling it
var i int
appendAlterColumn := func() {
if i > 0 {
b = append(b, ", "...)
}
b = append(b, "ALTER COLUMN "...)
b = fmter.AppendName(b, colDef.Column)
i++
}
got, want := colDef.From, colDef.To
inspector := m.db.Dialect().(sqlschema.InspectorDialect)
if !inspector.CompareType(want, got) {
appendAlterColumn()
b = append(b, " SET DATA TYPE "...)
if b, err = want.AppendQuery(fmter, b); err != nil {
return b, err
}
}
// Column must be declared NOT NULL before identity can be added.
// Although PG can resolve the order of operations itself, we make this explicit in the query.
if want.GetIsNullable() != got.GetIsNullable() {
appendAlterColumn()
if !want.GetIsNullable() {
b = append(b, " SET NOT NULL"...)
} else {
b = append(b, " DROP NOT NULL"...)
}
}
if want.GetIsIdentity() != got.GetIsIdentity() {
appendAlterColumn()
if !want.GetIsIdentity() {
b = append(b, " DROP IDENTITY"...)
} else {
b = append(b, " ADD"...)
b = appendGeneratedAsIdentity(b)
}
}
if want.GetDefaultValue() != got.GetDefaultValue() {
appendAlterColumn()
if want.GetDefaultValue() == "" {
b = append(b, " DROP DEFAULT"...)
} else {
b = append(b, " SET DEFAULT "...)
b = append(b, want.GetDefaultValue()...)
}
}
return b, nil
}

View file

@ -5,6 +5,7 @@ import (
"database/sql/driver"
"encoding/hex"
"fmt"
"math"
"reflect"
"strconv"
"time"
@ -159,7 +160,7 @@ func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte {
case int64:
return strconv.AppendInt(b, v, 10)
case float64:
return dialect.AppendFloat64(b, v)
return arrayAppendFloat64(b, v)
case bool:
return dialect.AppendBool(b, v)
case []byte:
@ -167,7 +168,10 @@ func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte {
case string:
return arrayAppendString(b, v)
case time.Time:
return fmter.Dialect().AppendTime(b, v)
b = append(b, '"')
b = appendTime(b, v)
b = append(b, '"')
return b
default:
err := fmt.Errorf("pgdialect: can't append %T", v)
return dialect.AppendError(b, err)
@ -288,7 +292,7 @@ func appendFloat64Slice(b []byte, floats []float64) []byte {
b = append(b, '{')
for _, n := range floats {
b = dialect.AppendFloat64(b, n)
b = arrayAppendFloat64(b, n)
b = append(b, ',')
}
if len(floats) > 0 {
@ -302,6 +306,19 @@ func appendFloat64Slice(b []byte, floats []float64) []byte {
return b
}
func arrayAppendFloat64(b []byte, num float64) []byte {
switch {
case math.IsNaN(num):
return append(b, "NaN"...)
case math.IsInf(num, 1):
return append(b, "Infinity"...)
case math.IsInf(num, -1):
return append(b, "-Infinity"...)
default:
return strconv.AppendFloat(b, num, 'f', -1, 64)
}
}
func appendTimeSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ts := v.Convert(sliceTimeType).Interface().([]time.Time)
return appendTimeSlice(fmter, b, ts)
@ -383,6 +400,10 @@ func arrayScanner(typ reflect.Type) schema.ScannerFunc {
}
}
if src == nil {
return nil
}
b, err := toBytes(src)
if err != nil {
return err
@ -553,7 +574,7 @@ func scanFloat64SliceValue(dest reflect.Value, src interface{}) error {
}
func scanFloat64Slice(src interface{}) ([]float64, error) {
if src == -1 {
if src == nil {
return nil, nil
}
@ -593,7 +614,7 @@ func toBytes(src interface{}) ([]byte, error) {
case []byte:
return src, nil
default:
return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src)
return nil, fmt.Errorf("pgdialect: got %T, wanted []byte or string", src)
}
}

View file

@ -10,6 +10,7 @@ import (
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/migrate/sqlschema"
"github.com/uptrace/bun/schema"
)
@ -29,6 +30,10 @@ type Dialect struct {
features feature.Feature
}
var _ schema.Dialect = (*Dialect)(nil)
var _ sqlschema.InspectorDialect = (*Dialect)(nil)
var _ sqlschema.MigratorDialect = (*Dialect)(nil)
func New() *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
@ -48,7 +53,8 @@ func New() *Dialect {
feature.InsertOnConflict |
feature.SelectExists |
feature.GeneratedIdentity |
feature.CompositeIn
feature.CompositeIn |
feature.DeleteReturning
return d
}
@ -118,5 +124,10 @@ func (d *Dialect) AppendUint64(b []byte, n uint64) []byte {
}
func (d *Dialect) AppendSequence(b []byte, _ *schema.Table, _ *schema.Field) []byte {
return appendGeneratedAsIdentity(b)
}
// appendGeneratedAsIdentity appends GENERATED BY DEFAULT AS IDENTITY to the column definition.
func appendGeneratedAsIdentity(b []byte) []byte {
return append(b, " GENERATED BY DEFAULT AS IDENTITY"...)
}

View file

@ -0,0 +1,297 @@
package pgdialect
import (
"context"
"strings"
"github.com/uptrace/bun"
"github.com/uptrace/bun/migrate/sqlschema"
orderedmap "github.com/wk8/go-ordered-map/v2"
)
type (
Schema = sqlschema.BaseDatabase
Table = sqlschema.BaseTable
Column = sqlschema.BaseColumn
)
func (d *Dialect) NewInspector(db *bun.DB, options ...sqlschema.InspectorOption) sqlschema.Inspector {
return newInspector(db, options...)
}
type Inspector struct {
sqlschema.InspectorConfig
db *bun.DB
}
var _ sqlschema.Inspector = (*Inspector)(nil)
func newInspector(db *bun.DB, options ...sqlschema.InspectorOption) *Inspector {
i := &Inspector{db: db}
sqlschema.ApplyInspectorOptions(&i.InspectorConfig, options...)
return i
}
func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Database, error) {
dbSchema := Schema{
Tables: orderedmap.New[string, sqlschema.Table](),
ForeignKeys: make(map[sqlschema.ForeignKey]string),
}
exclude := in.ExcludeTables
if len(exclude) == 0 {
// Avoid getting NOT IN (NULL) if bun.In() is called with an empty slice.
exclude = []string{""}
}
var tables []*InformationSchemaTable
if err := in.db.NewRaw(sqlInspectTables, in.SchemaName, bun.In(exclude)).Scan(ctx, &tables); err != nil {
return dbSchema, err
}
var fks []*ForeignKey
if err := in.db.NewRaw(sqlInspectForeignKeys, in.SchemaName, bun.In(exclude), bun.In(exclude)).Scan(ctx, &fks); err != nil {
return dbSchema, err
}
dbSchema.ForeignKeys = make(map[sqlschema.ForeignKey]string, len(fks))
for _, table := range tables {
var columns []*InformationSchemaColumn
if err := in.db.NewRaw(sqlInspectColumnsQuery, table.Schema, table.Name).Scan(ctx, &columns); err != nil {
return dbSchema, err
}
colDefs := orderedmap.New[string, sqlschema.Column]()
uniqueGroups := make(map[string][]string)
for _, c := range columns {
def := c.Default
if c.IsSerial || c.IsIdentity {
def = ""
} else if !c.IsDefaultLiteral {
def = strings.ToLower(def)
}
colDefs.Set(c.Name, &Column{
Name: c.Name,
SQLType: c.DataType,
VarcharLen: c.VarcharLen,
DefaultValue: def,
IsNullable: c.IsNullable,
IsAutoIncrement: c.IsSerial,
IsIdentity: c.IsIdentity,
})
for _, group := range c.UniqueGroups {
uniqueGroups[group] = append(uniqueGroups[group], c.Name)
}
}
var unique []sqlschema.Unique
for name, columns := range uniqueGroups {
unique = append(unique, sqlschema.Unique{
Name: name,
Columns: sqlschema.NewColumns(columns...),
})
}
var pk *sqlschema.PrimaryKey
if len(table.PrimaryKey.Columns) > 0 {
pk = &sqlschema.PrimaryKey{
Name: table.PrimaryKey.ConstraintName,
Columns: sqlschema.NewColumns(table.PrimaryKey.Columns...),
}
}
dbSchema.Tables.Set(table.Name, &Table{
Schema: table.Schema,
Name: table.Name,
Columns: colDefs,
PrimaryKey: pk,
UniqueConstraints: unique,
})
}
for _, fk := range fks {
dbSchema.ForeignKeys[sqlschema.ForeignKey{
From: sqlschema.NewColumnReference(fk.SourceTable, fk.SourceColumns...),
To: sqlschema.NewColumnReference(fk.TargetTable, fk.TargetColumns...),
}] = fk.ConstraintName
}
return dbSchema, nil
}
type InformationSchemaTable struct {
Schema string `bun:"table_schema,pk"`
Name string `bun:"table_name,pk"`
PrimaryKey PrimaryKey `bun:"embed:primary_key_"`
Columns []*InformationSchemaColumn `bun:"rel:has-many,join:table_schema=table_schema,join:table_name=table_name"`
}
type InformationSchemaColumn struct {
Schema string `bun:"table_schema"`
Table string `bun:"table_name"`
Name string `bun:"column_name"`
DataType string `bun:"data_type"`
VarcharLen int `bun:"varchar_len"`
IsArray bool `bun:"is_array"`
ArrayDims int `bun:"array_dims"`
Default string `bun:"default"`
IsDefaultLiteral bool `bun:"default_is_literal_expr"`
IsIdentity bool `bun:"is_identity"`
IndentityType string `bun:"identity_type"`
IsSerial bool `bun:"is_serial"`
IsNullable bool `bun:"is_nullable"`
UniqueGroups []string `bun:"unique_groups,array"`
}
type ForeignKey struct {
ConstraintName string `bun:"constraint_name"`
SourceSchema string `bun:"schema_name"`
SourceTable string `bun:"table_name"`
SourceColumns []string `bun:"columns,array"`
TargetSchema string `bun:"target_schema"`
TargetTable string `bun:"target_table"`
TargetColumns []string `bun:"target_columns,array"`
}
type PrimaryKey struct {
ConstraintName string `bun:"name"`
Columns []string `bun:"columns,array"`
}
const (
// sqlInspectTables retrieves all user-defined tables in the selected schema.
// Pass bun.In([]string{...}) to exclude tables from this inspection or bun.In([]string{''}) to include all results.
sqlInspectTables = `
SELECT
"t".table_schema,
"t".table_name,
pk.name AS primary_key_name,
pk.columns AS primary_key_columns
FROM information_schema.tables "t"
LEFT JOIN (
SELECT i.indrelid, "idx".relname AS "name", ARRAY_AGG("a".attname) AS "columns"
FROM pg_index i
JOIN pg_attribute "a"
ON "a".attrelid = i.indrelid
AND "a".attnum = ANY("i".indkey)
AND i.indisprimary
JOIN pg_class "idx" ON i.indexrelid = "idx".oid
GROUP BY 1, 2
) pk
ON ("t".table_schema || '.' || "t".table_name)::regclass = pk.indrelid
WHERE table_type = 'BASE TABLE'
AND "t".table_schema = ?
AND "t".table_schema NOT LIKE 'pg_%'
AND "table_name" NOT IN (?)
ORDER BY "t".table_schema, "t".table_name
`
// sqlInspectColumnsQuery retrieves column definitions for the specified table.
// Unlike sqlInspectTables and sqlInspectSchema, it should be passed to bun.NewRaw
// with additional args for table_schema and table_name.
sqlInspectColumnsQuery = `
SELECT
"c".table_schema,
"c".table_name,
"c".column_name,
"c".data_type,
"c".character_maximum_length::integer AS varchar_len,
"c".data_type = 'ARRAY' AS is_array,
COALESCE("c".array_dims, 0) AS array_dims,
CASE
WHEN "c".column_default ~ '^''.*''::.*$' THEN substring("c".column_default FROM '^''(.*)''::.*$')
ELSE "c".column_default
END AS "default",
"c".column_default ~ '^''.*''::.*$' OR "c".column_default ~ '^[0-9\.]+$' AS default_is_literal_expr,
"c".is_identity = 'YES' AS is_identity,
"c".column_default = format('nextval(''%s_%s_seq''::regclass)', "c".table_name, "c".column_name) AS is_serial,
COALESCE("c".identity_type, '') AS identity_type,
"c".is_nullable = 'YES' AS is_nullable,
"c"."unique_groups" AS unique_groups
FROM (
SELECT
"table_schema",
"table_name",
"column_name",
"c".data_type,
"c".character_maximum_length,
"c".column_default,
"c".is_identity,
"c".is_nullable,
att.array_dims,
att.identity_type,
att."unique_groups",
att."constraint_type"
FROM information_schema.columns "c"
LEFT JOIN (
SELECT
s.nspname AS "table_schema",
"t".relname AS "table_name",
"c".attname AS "column_name",
"c".attndims AS array_dims,
"c".attidentity AS identity_type,
ARRAY_AGG(con.conname) FILTER (WHERE con.contype = 'u') AS "unique_groups",
ARRAY_AGG(con.contype) AS "constraint_type"
FROM (
SELECT
conname,
contype,
connamespace,
conrelid,
conrelid AS attrelid,
UNNEST(conkey) AS attnum
FROM pg_constraint
) con
LEFT JOIN pg_attribute "c" USING (attrelid, attnum)
LEFT JOIN pg_namespace s ON s.oid = con.connamespace
LEFT JOIN pg_class "t" ON "t".oid = con.conrelid
GROUP BY 1, 2, 3, 4, 5
) att USING ("table_schema", "table_name", "column_name")
) "c"
WHERE "table_schema" = ? AND "table_name" = ?
ORDER BY "table_schema", "table_name", "column_name"
`
// sqlInspectForeignKeys get FK definitions for user-defined tables.
// Pass bun.In([]string{...}) to exclude tables from this inspection or bun.In([]string{''}) to include all results.
sqlInspectForeignKeys = `
WITH
"schemas" AS (
SELECT oid, nspname
FROM pg_namespace
),
"tables" AS (
SELECT oid, relnamespace, relname, relkind
FROM pg_class
),
"columns" AS (
SELECT attrelid, attname, attnum
FROM pg_attribute
WHERE attisdropped = false
)
SELECT DISTINCT
co.conname AS "constraint_name",
ss.nspname AS schema_name,
s.relname AS "table_name",
ARRAY_AGG(sc.attname) AS "columns",
ts.nspname AS target_schema,
"t".relname AS target_table,
ARRAY_AGG(tc.attname) AS target_columns
FROM pg_constraint co
LEFT JOIN "tables" s ON s.oid = co.conrelid
LEFT JOIN "schemas" ss ON ss.oid = s.relnamespace
LEFT JOIN "columns" sc ON sc.attrelid = s.oid AND sc.attnum = ANY(co.conkey)
LEFT JOIN "tables" t ON t.oid = co.confrelid
LEFT JOIN "schemas" ts ON ts.oid = "t".relnamespace
LEFT JOIN "columns" tc ON tc.attrelid = "t".oid AND tc.attnum = ANY(co.confkey)
WHERE co.contype = 'f'
AND co.conrelid IN (SELECT oid FROM pg_class WHERE relkind = 'r')
AND ARRAY_POSITION(co.conkey, sc.attnum) = ARRAY_POSITION(co.confkey, tc.attnum)
AND ss.nspname = ?
AND s.relname NOT IN (?) AND "t".relname NOT IN (?)
GROUP BY "constraint_name", "schema_name", "table_name", target_schema, target_table
`
)

View file

@ -5,18 +5,22 @@ import (
"encoding/json"
"net"
"reflect"
"strings"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/migrate/sqlschema"
"github.com/uptrace/bun/schema"
)
const (
// Date / Time
pgTypeTimestampTz = "TIMESTAMPTZ" // Timestamp with a time zone
pgTypeDate = "DATE" // Date
pgTypeTime = "TIME" // Time without a time zone
pgTypeTimeTz = "TIME WITH TIME ZONE" // Time with a time zone
pgTypeInterval = "INTERVAL" // Time Interval
pgTypeTimestamp = "TIMESTAMP" // Timestamp
pgTypeTimestampWithTz = "TIMESTAMP WITH TIME ZONE" // Timestamp with a time zone
pgTypeTimestampTz = "TIMESTAMPTZ" // Timestamp with a time zone (alias)
pgTypeDate = "DATE" // Date
pgTypeTime = "TIME" // Time without a time zone
pgTypeTimeTz = "TIME WITH TIME ZONE" // Time with a time zone
pgTypeInterval = "INTERVAL" // Time interval
// Network Addresses
pgTypeInet = "INET" // IPv4 or IPv6 hosts and networks
@ -28,6 +32,13 @@ const (
pgTypeSerial = "SERIAL" // 4 byte autoincrementing integer
pgTypeBigSerial = "BIGSERIAL" // 8 byte autoincrementing integer
// Character Types
pgTypeChar = "CHAR" // fixed length string (blank padded)
pgTypeCharacter = "CHARACTER" // alias for CHAR
pgTypeText = "TEXT" // variable length string without limit
pgTypeVarchar = "VARCHAR" // variable length string with optional limit
pgTypeCharacterVarying = "CHARACTER VARYING" // alias for VARCHAR
// Binary Data Types
pgTypeBytea = "BYTEA" // binary string
)
@ -43,6 +54,10 @@ func (d *Dialect) DefaultVarcharLen() int {
return 0
}
func (d *Dialect) DefaultSchema() string {
return "public"
}
func fieldSQLType(field *schema.Field) string {
if field.UserSQLType != "" {
return field.UserSQLType
@ -103,3 +118,62 @@ func sqlType(typ reflect.Type) string {
return sqlType
}
var (
char = newAliases(pgTypeChar, pgTypeCharacter)
varchar = newAliases(pgTypeVarchar, pgTypeCharacterVarying)
timestampTz = newAliases(sqltype.Timestamp, pgTypeTimestampTz, pgTypeTimestampWithTz)
)
func (d *Dialect) CompareType(col1, col2 sqlschema.Column) bool {
typ1, typ2 := strings.ToUpper(col1.GetSQLType()), strings.ToUpper(col2.GetSQLType())
if typ1 == typ2 {
return checkVarcharLen(col1, col2, d.DefaultVarcharLen())
}
switch {
case char.IsAlias(typ1) && char.IsAlias(typ2):
return checkVarcharLen(col1, col2, d.DefaultVarcharLen())
case varchar.IsAlias(typ1) && varchar.IsAlias(typ2):
return checkVarcharLen(col1, col2, d.DefaultVarcharLen())
case timestampTz.IsAlias(typ1) && timestampTz.IsAlias(typ2):
return true
}
return false
}
// checkVarcharLen returns true if columns have the same VarcharLen, or,
// if one specifies no VarcharLen and the other one has the default lenght for pgdialect.
// We assume that the types are otherwise equivalent and that any non-character column
// would have VarcharLen == 0;
func checkVarcharLen(col1, col2 sqlschema.Column, defaultLen int) bool {
vl1, vl2 := col1.GetVarcharLen(), col2.GetVarcharLen()
if vl1 == vl2 {
return true
}
if (vl1 == 0 && vl2 == defaultLen) || (vl1 == defaultLen && vl2 == 0) {
return true
}
return false
}
// typeAlias defines aliases for common data types. It is a lightweight string set implementation.
type typeAlias map[string]struct{}
// IsAlias checks if typ1 and typ2 are aliases of the same data type.
func (t typeAlias) IsAlias(typ string) bool {
_, ok := t[typ]
return ok
}
// newAliases creates a set of aliases.
func newAliases(aliases ...string) typeAlias {
types := make(typeAlias)
for _, a := range aliases {
types[a] = struct{}{}
}
return types
}

View file

@ -2,5 +2,5 @@ package pgdialect
// Version is the current release version.
func Version() string {
return "1.2.5"
return "1.2.6"
}

View file

@ -40,7 +40,8 @@ func New() *Dialect {
feature.TableNotExists |
feature.SelectExists |
feature.AutoIncrement |
feature.CompositeIn
feature.CompositeIn |
feature.DeleteReturning
return d
}
@ -96,9 +97,13 @@ func (d *Dialect) DefaultVarcharLen() int {
// AUTOINCREMENT is only valid for INTEGER PRIMARY KEY, and this method will be a noop for other columns.
//
// Because this is a valid construct:
//
// CREATE TABLE ("id" INTEGER PRIMARY KEY AUTOINCREMENT);
//
// and this is not:
//
// CREATE TABLE ("id" INTEGER AUTOINCREMENT, PRIMARY KEY ("id"));
//
// AppendSequence adds a primary key constraint as a *side-effect*. Callers should expect it to avoid building invalid SQL.
// SQLite also [does not support] AUTOINCREMENT column in composite primary keys.
//
@ -111,6 +116,13 @@ func (d *Dialect) AppendSequence(b []byte, table *schema.Table, field *schema.Fi
return b
}
// DefaultSchemaName is the "schema-name" of the main database.
// The details might differ from other dialects, but for all means and purposes
// "main" is the default schema in an SQLite database.
func (d *Dialect) DefaultSchema() string {
return "main"
}
func fieldSQLType(field *schema.Field) string {
switch field.DiscoveredSQLType {
case sqltype.SmallInt, sqltype.BigInt:

View file

@ -2,5 +2,5 @@ package sqlitedialect
// Version is the current release version.
func Version() string {
return "1.2.5"
return "1.2.6"
}