mirror of
				https://github.com/superseriousbusiness/gotosocial.git
				synced 2025-11-03 22:22:25 -06:00 
			
		
		
		
	
		
			
				
	
	
		
			402 lines
		
	
	
	
		
			8.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			402 lines
		
	
	
	
		
			8.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package migrate
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"io/ioutil"
 | 
						|
	"log"
 | 
						|
	"path/filepath"
 | 
						|
	"regexp"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/uptrace/bun"
 | 
						|
)
 | 
						|
 | 
						|
// MigratorOption customizes a Migrator. NewMigrator invokes each option, in
// the order given, after the defaults have been set.
type MigratorOption func(migrator *Migrator)
 | 
						|
 | 
						|
func WithTableName(table string) MigratorOption {
 | 
						|
	return func(m *Migrator) {
 | 
						|
		m.table = table
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func WithLocksTableName(table string) MigratorOption {
 | 
						|
	return func(m *Migrator) {
 | 
						|
		m.locksTable = table
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
type Migrator struct {
 | 
						|
	db         *bun.DB
 | 
						|
	migrations *Migrations
 | 
						|
 | 
						|
	ms MigrationSlice
 | 
						|
 | 
						|
	table      string
 | 
						|
	locksTable string
 | 
						|
}
 | 
						|
 | 
						|
func NewMigrator(db *bun.DB, migrations *Migrations, opts ...MigratorOption) *Migrator {
 | 
						|
	m := &Migrator{
 | 
						|
		db:         db,
 | 
						|
		migrations: migrations,
 | 
						|
 | 
						|
		ms: migrations.ms,
 | 
						|
 | 
						|
		table:      "bun_migrations",
 | 
						|
		locksTable: "bun_migration_locks",
 | 
						|
	}
 | 
						|
	for _, opt := range opts {
 | 
						|
		opt(m)
 | 
						|
	}
 | 
						|
	return m
 | 
						|
}
 | 
						|
 | 
						|
func (m *Migrator) DB() *bun.DB {
 | 
						|
	return m.db
 | 
						|
}
 | 
						|
 | 
						|
// MigrationsWithStatus returns migrations with status in ascending order.
 | 
						|
func (m *Migrator) MigrationsWithStatus(ctx context.Context) (MigrationSlice, error) {
 | 
						|
	sorted := m.migrations.Sorted()
 | 
						|
 | 
						|
	applied, err := m.selectAppliedMigrations(ctx)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	appliedMap := migrationMap(applied)
 | 
						|
	for i := range sorted {
 | 
						|
		m1 := &sorted[i]
 | 
						|
		if m2, ok := appliedMap[m1.Name]; ok {
 | 
						|
			m1.ID = m2.ID
 | 
						|
			m1.GroupID = m2.GroupID
 | 
						|
			m1.MigratedAt = m2.MigratedAt
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return sorted, nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *Migrator) Init(ctx context.Context) error {
 | 
						|
	if _, err := m.db.NewCreateTable().
 | 
						|
		Model((*Migration)(nil)).
 | 
						|
		ModelTableExpr(m.table).
 | 
						|
		IfNotExists().
 | 
						|
		Exec(ctx); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	if _, err := m.db.NewCreateTable().
 | 
						|
		Model((*migrationLock)(nil)).
 | 
						|
		ModelTableExpr(m.locksTable).
 | 
						|
		IfNotExists().
 | 
						|
		Exec(ctx); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *Migrator) Reset(ctx context.Context) error {
 | 
						|
	if _, err := m.db.NewDropTable().
 | 
						|
		Model((*Migration)(nil)).
 | 
						|
		ModelTableExpr(m.table).
 | 
						|
		IfExists().
 | 
						|
		Exec(ctx); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	if _, err := m.db.NewDropTable().
 | 
						|
		Model((*migrationLock)(nil)).
 | 
						|
		ModelTableExpr(m.locksTable).
 | 
						|
		IfExists().
 | 
						|
		Exec(ctx); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	return m.Init(ctx)
 | 
						|
}
 | 
						|
 | 
						|
// Migrate runs unapplied migrations. If a migration fails, migrate immediately exits.
 | 
						|
func (m *Migrator) Migrate(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) {
 | 
						|
	cfg := newMigrationConfig(opts)
 | 
						|
 | 
						|
	if err := m.validate(); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	if err := m.Lock(ctx); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	defer m.Unlock(ctx) //nolint:errcheck
 | 
						|
 | 
						|
	migrations, err := m.MigrationsWithStatus(ctx)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	group := &MigrationGroup{
 | 
						|
		Migrations: migrations.Unapplied(),
 | 
						|
	}
 | 
						|
	if len(group.Migrations) == 0 {
 | 
						|
		return group, nil
 | 
						|
	}
 | 
						|
	group.ID = migrations.LastGroupID() + 1
 | 
						|
 | 
						|
	for i := range group.Migrations {
 | 
						|
		migration := &group.Migrations[i]
 | 
						|
		migration.GroupID = group.ID
 | 
						|
 | 
						|
		if !cfg.nop && migration.Up != nil {
 | 
						|
			if err := migration.Up(ctx, m.db); err != nil {
 | 
						|
				return group, err
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		if err := m.MarkApplied(ctx, migration); err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return group, nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *Migrator) Rollback(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) {
 | 
						|
	cfg := newMigrationConfig(opts)
 | 
						|
 | 
						|
	if err := m.validate(); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	if err := m.Lock(ctx); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	defer m.Unlock(ctx) //nolint:errcheck
 | 
						|
 | 
						|
	migrations, err := m.MigrationsWithStatus(ctx)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	lastGroup := migrations.LastGroup()
 | 
						|
 | 
						|
	for i := len(lastGroup.Migrations) - 1; i >= 0; i-- {
 | 
						|
		migration := &lastGroup.Migrations[i]
 | 
						|
 | 
						|
		if !cfg.nop && migration.Down != nil {
 | 
						|
			if err := migration.Down(ctx, m.db); err != nil {
 | 
						|
				return nil, err
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		if err := m.MarkUnapplied(ctx, migration); err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return lastGroup, nil
 | 
						|
}
 | 
						|
 | 
						|
type MigrationStatus struct {
 | 
						|
	Migrations    MigrationSlice
 | 
						|
	NewMigrations MigrationSlice
 | 
						|
	LastGroup     *MigrationGroup
 | 
						|
}
 | 
						|
 | 
						|
func (m *Migrator) Status(ctx context.Context) (*MigrationStatus, error) {
 | 
						|
	log.Printf(
 | 
						|
		"DEPRECATED: bun: replace Status(ctx) with " +
 | 
						|
			"MigrationsWithStatus(ctx)")
 | 
						|
 | 
						|
	migrations, err := m.MigrationsWithStatus(ctx)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	return &MigrationStatus{
 | 
						|
		Migrations:    migrations,
 | 
						|
		NewMigrations: migrations.Unapplied(),
 | 
						|
		LastGroup:     migrations.LastGroup(),
 | 
						|
	}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *Migrator) MarkCompleted(ctx context.Context) (*MigrationGroup, error) {
 | 
						|
	log.Printf(
 | 
						|
		"DEPRECATED: bun: replace MarkCompleted(ctx) with " +
 | 
						|
			"Migrate(ctx, migrate.WithNopMigration())")
 | 
						|
 | 
						|
	return m.Migrate(ctx, WithNopMigration())
 | 
						|
}
 | 
						|
 | 
						|
// goMigrationConfig collects the settings used by CreateGoMigration.
type goMigrationConfig struct {
	// packageName is the single value substituted into goTemplate.
	packageName string
}
 | 
						|
 | 
						|
// GoMigrationOption adjusts the configuration used by CreateGoMigration.
type GoMigrationOption func(config *goMigrationConfig)
 | 
						|
 | 
						|
func WithPackageName(name string) GoMigrationOption {
 | 
						|
	return func(cfg *goMigrationConfig) {
 | 
						|
		cfg.packageName = name
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// CreateGoMigration creates a Go migration file.
 | 
						|
func (m *Migrator) CreateGoMigration(
 | 
						|
	ctx context.Context, name string, opts ...GoMigrationOption,
 | 
						|
) (*MigrationFile, error) {
 | 
						|
	cfg := &goMigrationConfig{
 | 
						|
		packageName: "migrations",
 | 
						|
	}
 | 
						|
	for _, opt := range opts {
 | 
						|
		opt(cfg)
 | 
						|
	}
 | 
						|
 | 
						|
	name, err := m.genMigrationName(name)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	fname := name + ".go"
 | 
						|
	fpath := filepath.Join(m.migrations.getDirectory(), fname)
 | 
						|
	content := fmt.Sprintf(goTemplate, cfg.packageName)
 | 
						|
 | 
						|
	if err := ioutil.WriteFile(fpath, []byte(content), 0o644); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	mf := &MigrationFile{
 | 
						|
		Name:    fname,
 | 
						|
		Path:    fpath,
 | 
						|
		Content: content,
 | 
						|
	}
 | 
						|
	return mf, nil
 | 
						|
}
 | 
						|
 | 
						|
// CreateSQLMigrations creates an up and down SQL migration files.
 | 
						|
func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) {
 | 
						|
	name, err := m.genMigrationName(name)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	up, err := m.createSQL(ctx, name+".up.sql")
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	down, err := m.createSQL(ctx, name+".down.sql")
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return []*MigrationFile{up, down}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *Migrator) createSQL(ctx context.Context, fname string) (*MigrationFile, error) {
 | 
						|
	fpath := filepath.Join(m.migrations.getDirectory(), fname)
 | 
						|
 | 
						|
	if err := ioutil.WriteFile(fpath, []byte(sqlTemplate), 0o644); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	mf := &MigrationFile{
 | 
						|
		Name:    fname,
 | 
						|
		Path:    fpath,
 | 
						|
		Content: goTemplate,
 | 
						|
	}
 | 
						|
	return mf, nil
 | 
						|
}
 | 
						|
 | 
						|
var nameRE = regexp.MustCompile(`^[0-9a-z_\-]+$`)
 | 
						|
 | 
						|
func (m *Migrator) genMigrationName(name string) (string, error) {
 | 
						|
	const timeFormat = "20060102150405"
 | 
						|
 | 
						|
	if name == "" {
 | 
						|
		return "", errors.New("migrate: migration name can't be empty")
 | 
						|
	}
 | 
						|
	if !nameRE.MatchString(name) {
 | 
						|
		return "", fmt.Errorf("migrate: invalid migration name: %q", name)
 | 
						|
	}
 | 
						|
 | 
						|
	version := time.Now().UTC().Format(timeFormat)
 | 
						|
	return fmt.Sprintf("%s_%s", version, name), nil
 | 
						|
}
 | 
						|
 | 
						|
// MarkApplied marks the migration as applied (applied).
 | 
						|
func (m *Migrator) MarkApplied(ctx context.Context, migration *Migration) error {
 | 
						|
	_, err := m.db.NewInsert().Model(migration).
 | 
						|
		ModelTableExpr(m.table).
 | 
						|
		Exec(ctx)
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
// MarkUnapplied marks the migration as unapplied (new).
 | 
						|
func (m *Migrator) MarkUnapplied(ctx context.Context, migration *Migration) error {
 | 
						|
	_, err := m.db.NewDelete().
 | 
						|
		Model(migration).
 | 
						|
		ModelTableExpr(m.table).
 | 
						|
		Where("id = ?", migration.ID).
 | 
						|
		Exec(ctx)
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
// selectAppliedMigrations selects applied (applied) migrations in descending order.
 | 
						|
func (m *Migrator) selectAppliedMigrations(ctx context.Context) (MigrationSlice, error) {
 | 
						|
	var ms MigrationSlice
 | 
						|
	if err := m.db.NewSelect().
 | 
						|
		ColumnExpr("*").
 | 
						|
		Model(&ms).
 | 
						|
		ModelTableExpr(m.table).
 | 
						|
		Scan(ctx); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	return ms, nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *Migrator) formattedTableName(db *bun.DB) string {
 | 
						|
	return db.Formatter().FormatQuery(m.table)
 | 
						|
}
 | 
						|
 | 
						|
func (m *Migrator) validate() error {
 | 
						|
	if len(m.ms) == 0 {
 | 
						|
		return errors.New("migrate: there are no any migrations")
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
//------------------------------------------------------------------------------
 | 
						|
 | 
						|
// migrationLock is the model for a row in the locks table.
type migrationLock struct {
	// ID is the row identifier.
	ID int64
	// TableName holds the formatted migrations table name written by Lock.
	// The unique tag is what makes a second insert for the same table fail.
	TableName string `bun:",unique"`
}
 | 
						|
 | 
						|
func (m *Migrator) Lock(ctx context.Context) error {
 | 
						|
	lock := &migrationLock{
 | 
						|
		TableName: m.formattedTableName(m.db),
 | 
						|
	}
 | 
						|
	if _, err := m.db.NewInsert().
 | 
						|
		Model(lock).
 | 
						|
		ModelTableExpr(m.locksTable).
 | 
						|
		Exec(ctx); err != nil {
 | 
						|
		return fmt.Errorf("migrate: migrations table is already locked (%w)", err)
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m *Migrator) Unlock(ctx context.Context) error {
 | 
						|
	tableName := m.formattedTableName(m.db)
 | 
						|
	_, err := m.db.NewDelete().
 | 
						|
		Model((*migrationLock)(nil)).
 | 
						|
		ModelTableExpr(m.locksTable).
 | 
						|
		Where("? = ?", bun.Ident("table_name"), tableName).
 | 
						|
		Exec(ctx)
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
func migrationMap(ms MigrationSlice) map[string]*Migration {
 | 
						|
	mp := make(map[string]*Migration)
 | 
						|
	for i := range ms {
 | 
						|
		m := &ms[i]
 | 
						|
		mp[m.Name] = m
 | 
						|
	}
 | 
						|
	return mp
 | 
						|
}
 |