mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-10-31 01:02:25 -05:00 
			
		
		
		
	
		
			
				
	
	
		
			458 lines
		
	
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			458 lines
		
	
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package migrate
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"os"
 | |
| 	"path/filepath"
 | |
| 	"regexp"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/uptrace/bun"
 | |
| )
 | |
| 
 | |
// Table names a Migrator uses unless they are overridden through
// WithTableName or WithLocksTableName.
const (
	// defaultTable is the default name of the table that records migrations.
	defaultTable = "bun_migrations"
	// defaultLocksTable is the default name of the table that stores locks.
	defaultLocksTable = "bun_migration_locks"
)
 | |
| 
 | |
| type MigratorOption func(m *Migrator)
 | |
| 
 | |
| // WithTableName overrides default migrations table name.
 | |
| func WithTableName(table string) MigratorOption {
 | |
| 	return func(m *Migrator) {
 | |
| 		m.table = table
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // WithLocksTableName overrides default migration locks table name.
 | |
| func WithLocksTableName(table string) MigratorOption {
 | |
| 	return func(m *Migrator) {
 | |
| 		m.locksTable = table
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // WithMarkAppliedOnSuccess sets the migrator to only mark migrations as applied/unapplied
 | |
| // when their up/down is successful.
 | |
| func WithMarkAppliedOnSuccess(enabled bool) MigratorOption {
 | |
| 	return func(m *Migrator) {
 | |
| 		m.markAppliedOnSuccess = enabled
 | |
| 	}
 | |
| }
 | |
| 
 | |
| type Migrator struct {
 | |
| 	db         *bun.DB
 | |
| 	migrations *Migrations
 | |
| 
 | |
| 	ms MigrationSlice
 | |
| 
 | |
| 	table                string
 | |
| 	locksTable           string
 | |
| 	markAppliedOnSuccess bool
 | |
| }
 | |
| 
 | |
| func NewMigrator(db *bun.DB, migrations *Migrations, opts ...MigratorOption) *Migrator {
 | |
| 	m := &Migrator{
 | |
| 		db:         db,
 | |
| 		migrations: migrations,
 | |
| 
 | |
| 		ms: migrations.ms,
 | |
| 
 | |
| 		table:      defaultTable,
 | |
| 		locksTable: defaultLocksTable,
 | |
| 	}
 | |
| 	for _, opt := range opts {
 | |
| 		opt(m)
 | |
| 	}
 | |
| 	return m
 | |
| }
 | |
| 
 | |
| func (m *Migrator) DB() *bun.DB {
 | |
| 	return m.db
 | |
| }
 | |
| 
 | |
| // MigrationsWithStatus returns migrations with status in ascending order.
 | |
| func (m *Migrator) MigrationsWithStatus(ctx context.Context) (MigrationSlice, error) {
 | |
| 	sorted, _, err := m.migrationsWithStatus(ctx)
 | |
| 	return sorted, err
 | |
| }
 | |
| 
 | |
| func (m *Migrator) migrationsWithStatus(ctx context.Context) (MigrationSlice, int64, error) {
 | |
| 	sorted := m.migrations.Sorted()
 | |
| 
 | |
| 	applied, err := m.AppliedMigrations(ctx)
 | |
| 	if err != nil {
 | |
| 		return nil, 0, err
 | |
| 	}
 | |
| 
 | |
| 	appliedMap := migrationMap(applied)
 | |
| 	for i := range sorted {
 | |
| 		m1 := &sorted[i]
 | |
| 		if m2, ok := appliedMap[m1.Name]; ok {
 | |
| 			m1.ID = m2.ID
 | |
| 			m1.GroupID = m2.GroupID
 | |
| 			m1.MigratedAt = m2.MigratedAt
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return sorted, applied.LastGroupID(), nil
 | |
| }
 | |
| 
 | |
| func (m *Migrator) Init(ctx context.Context) error {
 | |
| 	if _, err := m.db.NewCreateTable().
 | |
| 		Model((*Migration)(nil)).
 | |
| 		ModelTableExpr(m.table).
 | |
| 		IfNotExists().
 | |
| 		Exec(ctx); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	if _, err := m.db.NewCreateTable().
 | |
| 		Model((*migrationLock)(nil)).
 | |
| 		ModelTableExpr(m.locksTable).
 | |
| 		IfNotExists().
 | |
| 		Exec(ctx); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (m *Migrator) Reset(ctx context.Context) error {
 | |
| 	if _, err := m.db.NewDropTable().
 | |
| 		Model((*Migration)(nil)).
 | |
| 		ModelTableExpr(m.table).
 | |
| 		IfExists().
 | |
| 		Exec(ctx); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	if _, err := m.db.NewDropTable().
 | |
| 		Model((*migrationLock)(nil)).
 | |
| 		ModelTableExpr(m.locksTable).
 | |
| 		IfExists().
 | |
| 		Exec(ctx); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	return m.Init(ctx)
 | |
| }
 | |
| 
 | |
| // Migrate runs unapplied migrations. If a migration fails, migrate immediately exits.
 | |
| func (m *Migrator) Migrate(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) {
 | |
| 	cfg := newMigrationConfig(opts)
 | |
| 
 | |
| 	if err := m.validate(); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	migrations, lastGroupID, err := m.migrationsWithStatus(ctx)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	migrations = migrations.Unapplied()
 | |
| 
 | |
| 	group := new(MigrationGroup)
 | |
| 	if len(migrations) == 0 {
 | |
| 		return group, nil
 | |
| 	}
 | |
| 	group.ID = lastGroupID + 1
 | |
| 
 | |
| 	for i := range migrations {
 | |
| 		migration := &migrations[i]
 | |
| 		migration.GroupID = group.ID
 | |
| 
 | |
| 		if !m.markAppliedOnSuccess {
 | |
| 			if err := m.MarkApplied(ctx, migration); err != nil {
 | |
| 				return group, err
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		group.Migrations = migrations[:i+1]
 | |
| 
 | |
| 		if !cfg.nop && migration.Up != nil {
 | |
| 			if err := migration.Up(ctx, m.db); err != nil {
 | |
| 				return group, err
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if m.markAppliedOnSuccess {
 | |
| 			if err := m.MarkApplied(ctx, migration); err != nil {
 | |
| 				return group, err
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return group, nil
 | |
| }
 | |
| 
 | |
| func (m *Migrator) Rollback(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) {
 | |
| 	cfg := newMigrationConfig(opts)
 | |
| 
 | |
| 	if err := m.validate(); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	migrations, err := m.MigrationsWithStatus(ctx)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	lastGroup := migrations.LastGroup()
 | |
| 
 | |
| 	for i := len(lastGroup.Migrations) - 1; i >= 0; i-- {
 | |
| 		migration := &lastGroup.Migrations[i]
 | |
| 
 | |
| 		if !m.markAppliedOnSuccess {
 | |
| 			if err := m.MarkUnapplied(ctx, migration); err != nil {
 | |
| 				return lastGroup, err
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if !cfg.nop && migration.Down != nil {
 | |
| 			if err := migration.Down(ctx, m.db); err != nil {
 | |
| 				return lastGroup, err
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if m.markAppliedOnSuccess {
 | |
| 			if err := m.MarkUnapplied(ctx, migration); err != nil {
 | |
| 				return lastGroup, err
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return lastGroup, nil
 | |
| }
 | |
| 
 | |
// goMigrationConfig collects the settings used by CreateGoMigration.
type goMigrationConfig struct {
	// packageName is the value substituted into goTemplate.
	packageName string
	// goTemplate is the fmt format string used to render the file content.
	goTemplate string
}
 | |
| 
 | |
| type GoMigrationOption func(cfg *goMigrationConfig)
 | |
| 
 | |
| func WithPackageName(name string) GoMigrationOption {
 | |
| 	return func(cfg *goMigrationConfig) {
 | |
| 		cfg.packageName = name
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func WithGoTemplate(template string) GoMigrationOption {
 | |
| 	return func(cfg *goMigrationConfig) {
 | |
| 		cfg.goTemplate = template
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // CreateGoMigration creates a Go migration file.
 | |
| func (m *Migrator) CreateGoMigration(
 | |
| 	ctx context.Context, name string, opts ...GoMigrationOption,
 | |
| ) (*MigrationFile, error) {
 | |
| 	cfg := &goMigrationConfig{
 | |
| 		packageName: "migrations",
 | |
| 		goTemplate:  goTemplate,
 | |
| 	}
 | |
| 	for _, opt := range opts {
 | |
| 		opt(cfg)
 | |
| 	}
 | |
| 
 | |
| 	name, err := genMigrationName(name)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	fname := name + ".go"
 | |
| 	fpath := filepath.Join(m.migrations.getDirectory(), fname)
 | |
| 	content := fmt.Sprintf(cfg.goTemplate, cfg.packageName)
 | |
| 
 | |
| 	if err := os.WriteFile(fpath, []byte(content), 0o644); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	mf := &MigrationFile{
 | |
| 		Name:    fname,
 | |
| 		Path:    fpath,
 | |
| 		Content: content,
 | |
| 	}
 | |
| 	return mf, nil
 | |
| }
 | |
| 
 | |
| // CreateTxSQLMigration creates transactional up and down SQL migration files.
 | |
| func (m *Migrator) CreateTxSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) {
 | |
| 	name, err := genMigrationName(name)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	up, err := m.createSQL(ctx, name+".tx.up.sql", true)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	down, err := m.createSQL(ctx, name+".tx.down.sql", true)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return []*MigrationFile{up, down}, nil
 | |
| }
 | |
| 
 | |
| // CreateSQLMigrations creates up and down SQL migration files.
 | |
| func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) {
 | |
| 	name, err := genMigrationName(name)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	up, err := m.createSQL(ctx, name+".up.sql", false)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	down, err := m.createSQL(ctx, name+".down.sql", false)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return []*MigrationFile{up, down}, nil
 | |
| }
 | |
| 
 | |
| func (m *Migrator) createSQL(_ context.Context, fname string, transactional bool) (*MigrationFile, error) {
 | |
| 	fpath := filepath.Join(m.migrations.getDirectory(), fname)
 | |
| 
 | |
| 	template := sqlTemplate
 | |
| 	if transactional {
 | |
| 		template = transactionalSQLTemplate
 | |
| 	}
 | |
| 
 | |
| 	if err := os.WriteFile(fpath, []byte(template), 0o644); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	mf := &MigrationFile{
 | |
| 		Name:    fname,
 | |
| 		Path:    fpath,
 | |
| 		Content: goTemplate,
 | |
| 	}
 | |
| 	return mf, nil
 | |
| }
 | |
| 
 | |
// nameRE accepts migration names made only of digits, lowercase ASCII
// letters, underscores and hyphens.
var nameRE = regexp.MustCompile(`^[0-9a-z_\-]+$`)

// genMigrationName validates name and returns it prefixed with the current
// UTC time formatted as YYYYMMDDhhmmss and an underscore. It returns an error
// when name is empty or contains characters rejected by nameRE.
func genMigrationName(name string) (string, error) {
	const timeFormat = "20060102150405"

	switch {
	case name == "":
		return "", errors.New("migrate: migration name can't be empty")
	case !nameRE.MatchString(name):
		return "", fmt.Errorf("migrate: invalid migration name: %q", name)
	}

	stamp := time.Now().UTC().Format(timeFormat)
	return stamp + "_" + name, nil
}
 | |
| 
 | |
| // MarkApplied marks the migration as applied (completed).
 | |
| func (m *Migrator) MarkApplied(ctx context.Context, migration *Migration) error {
 | |
| 	_, err := m.db.NewInsert().Model(migration).
 | |
| 		ModelTableExpr(m.table).
 | |
| 		Exec(ctx)
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| // MarkUnapplied marks the migration as unapplied (new).
 | |
| func (m *Migrator) MarkUnapplied(ctx context.Context, migration *Migration) error {
 | |
| 	_, err := m.db.NewDelete().
 | |
| 		Model(migration).
 | |
| 		ModelTableExpr(m.table).
 | |
| 		Where("id = ?", migration.ID).
 | |
| 		Exec(ctx)
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| func (m *Migrator) TruncateTable(ctx context.Context) error {
 | |
| 	_, err := m.db.NewTruncateTable().
 | |
| 		Model((*Migration)(nil)).
 | |
| 		ModelTableExpr(m.table).
 | |
| 		Exec(ctx)
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| // MissingMigrations returns applied migrations that can no longer be found.
 | |
| func (m *Migrator) MissingMigrations(ctx context.Context) (MigrationSlice, error) {
 | |
| 	applied, err := m.AppliedMigrations(ctx)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	existing := migrationMap(m.migrations.ms)
 | |
| 	for i := len(applied) - 1; i >= 0; i-- {
 | |
| 		m := &applied[i]
 | |
| 		if _, ok := existing[m.Name]; ok {
 | |
| 			applied = append(applied[:i], applied[i+1:]...)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return applied, nil
 | |
| }
 | |
| 
 | |
| // AppliedMigrations selects applied (applied) migrations in descending order.
 | |
| func (m *Migrator) AppliedMigrations(ctx context.Context) (MigrationSlice, error) {
 | |
| 	var ms MigrationSlice
 | |
| 	if err := m.db.NewSelect().
 | |
| 		ColumnExpr("*").
 | |
| 		Model(&ms).
 | |
| 		ModelTableExpr(m.table).
 | |
| 		Scan(ctx); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return ms, nil
 | |
| }
 | |
| 
 | |
| func (m *Migrator) formattedTableName(db *bun.DB) string {
 | |
| 	return db.Formatter().FormatQuery(m.table)
 | |
| }
 | |
| 
 | |
| func (m *Migrator) validate() error {
 | |
| 	if len(m.ms) == 0 {
 | |
| 		return errors.New("migrate: there are no migrations")
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| //------------------------------------------------------------------------------
 | |
| 
 | |
// migrationLock is the row model of the migration locks table.
type migrationLock struct {
	// ID is the auto-incremented primary key.
	ID int64 `bun:",pk,autoincrement"`
	// TableName holds the formatted migrations table name; the unique
	// constraint lets at most one row exist per table.
	TableName string `bun:",unique"`
}
 | |
| 
 | |
| func (m *Migrator) Lock(ctx context.Context) error {
 | |
| 	lock := &migrationLock{
 | |
| 		TableName: m.formattedTableName(m.db),
 | |
| 	}
 | |
| 	if _, err := m.db.NewInsert().
 | |
| 		Model(lock).
 | |
| 		ModelTableExpr(m.locksTable).
 | |
| 		Exec(ctx); err != nil {
 | |
| 		return fmt.Errorf("migrate: migrations table is already locked (%w)", err)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (m *Migrator) Unlock(ctx context.Context) error {
 | |
| 	tableName := m.formattedTableName(m.db)
 | |
| 	_, err := m.db.NewDelete().
 | |
| 		Model((*migrationLock)(nil)).
 | |
| 		ModelTableExpr(m.locksTable).
 | |
| 		Where("? = ?", bun.Ident("table_name"), tableName).
 | |
| 		Exec(ctx)
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| func migrationMap(ms MigrationSlice) map[string]*Migration {
 | |
| 	mp := make(map[string]*Migration)
 | |
| 	for i := range ms {
 | |
| 		m := &ms[i]
 | |
| 		mp[m.Name] = m
 | |
| 	}
 | |
| 	return mp
 | |
| }
 |