updates the enum migration to perform a singular update for all values, using an SQL case statement

This commit is contained in:
kim 2025-02-11 14:30:05 +00:00
commit bb8afd231d

View file

@ -35,6 +35,11 @@ import (
"github.com/uptrace/bun/schema" "github.com/uptrace/bun/schema"
) )
// formatQuery formats given query + args according to given bun dialect, using for debug logging.
func formatQuery(d schema.Dialect, query string, args ...any) string {
return schema.NewFormatter(d).FormatQuery(query, args...)
}
// convertEnums performs a transaction that converts // convertEnums performs a transaction that converts
// a table's column of our old-style enums (strings) to // a table's column of our old-style enums (strings) to
// more performant and space-saving integer types. // more performant and space-saving integer types.
@ -82,26 +87,31 @@ func convertEnums[OldType ~string, NewType ~int16](
return gtserror.Newf("error selecting total count: %w", err) return gtserror.Newf("error selecting total count: %w", err)
} }
var updated int var args []any
for old, new := range mapping { var qbuf byteutil.Buffer
// Update old to new values. // Prepare a singular UPDATE statement using
res, err := tx.NewUpdate(). // SET $newColumn = (CASE $column WHEN $old THEN $new ... END)
Table(table). qbuf.WriteString("UPDATE ? SET ? = (CASE ? ")
Where("? = ?", bun.Ident(column), old). args = append(args, bun.Ident(table))
Set("? = ?", bun.Ident(newColumn), new). args = append(args, bun.Ident(newColumn))
Exec(ctx) args = append(args, bun.Ident(column))
for old, new := range mapping {
qbuf.WriteString("WHEN ? THEN ? ")
args = append(args, old, new)
}
qbuf.WriteString("ELSE ? END)")
args = append(args, *defaultValue)
// Execute the prepared raw query with arguments.
res, err := tx.NewRaw(qbuf.String(), args...).Exec(ctx)
if err != nil { if err != nil {
return gtserror.Newf("error updating old column values: %w", err) return gtserror.Newf("error updating old column values: %w", err)
} }
// Count number items updated. // Count number items updated.
n, _ := res.RowsAffected() updated, _ := res.RowsAffected()
updated += int(n) if total != int(updated) {
}
// Check total updated.
if total != updated {
log.Warnf(ctx, "total=%d does not match updated=%d", total, updated) log.Warnf(ctx, "total=%d does not match updated=%d", total, updated)
} }