[chore] Use bulk updates + fewer loops in status rethreading migration (#4459)

This pull request tries to optimize our status rethreading migration by using bulk updates + avoiding unnecessary writes, and doing the migration in one top-level loop and one stragglers loop, without the extra loop to copy thread_id over.

On my machine it runs at about 2400 rows per second on Postgres, now, and about 9000 rows per second on SQLite.

Tried *many* different ways of doing this, with and without temporary indexes, with different batch and transaction sizes, etc., and this seems to be just about the most performant way of getting stuff done.

With the changes, a few minutes have been shaved off migration time testing on my development machine. *Hopefully* this will translate to more time shaved off when running on a vps with slower read/write speed and less processor power.

SQLite before:

```
real	20m58,446s
user	16m26,635s
sys	5m53,648s
```

SQLite after:

```
real	14m25,435s
user	12m47,449s
sys	2m27,898s
```

Postgres before:

```
real	28m25,307s
user	3m40,005s
sys	4m45,018s
```

Postgres after:

```
real	22m31,999s
user	3m46,674s
sys	4m39,592s
```

Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4459
Co-authored-by: tobi <tobi.smethurst@protonmail.com>
Co-committed-by: tobi <tobi.smethurst@protonmail.com>
This commit is contained in:
tobi 2025-10-03 12:28:55 +02:00 committed by kim
commit e7cd8bb43e
7 changed files with 429 additions and 271 deletions

View file

@ -24,13 +24,16 @@ import (
"reflect"
"slices"
"strings"
"time"
"code.superseriousbusiness.org/gotosocial/internal/db"
newmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/new"
oldmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/old"
"code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/util"
"code.superseriousbusiness.org/gotosocial/internal/gtserror"
"code.superseriousbusiness.org/gotosocial/internal/id"
"code.superseriousbusiness.org/gotosocial/internal/log"
"code.superseriousbusiness.org/gotosocial/internal/util/xslices"
"github.com/uptrace/bun"
)
@ -49,10 +52,26 @@ func init() {
"thread_id", "thread_id_new", 1)
var sr statusRethreader
var count int
var updatedTotal int64
var maxID string
var statuses []*oldmodel.Status
// Create thread_id_new already
// so we can populate it as we go.
log.Info(ctx, "creating statuses column thread_id_new")
if _, err := db.NewAddColumn().
Table("statuses").
ColumnExpr(newColDef).
Exec(ctx); err != nil {
return gtserror.Newf("error adding statuses column thread_id_new: %w", err)
}
// Try to merge the wal so we're
// not working on the wal file.
if err := doWALCheckpoint(ctx, db); err != nil {
return err
}
// Get a total count of all statuses before migration.
total, err := db.NewSelect().Table("statuses").Count(ctx)
if err != nil {
@ -63,74 +82,129 @@ func init() {
// possible ULID value.
maxID = id.Highest
log.Warn(ctx, "rethreading top-level statuses, this will take a *long* time")
for /* TOP LEVEL STATUS LOOP */ {
log.Warnf(ctx, "rethreading %d statuses, this will take a *long* time", total)
// Open initial transaction.
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
for i := 1; ; i++ {
// Reset slice.
clear(statuses)
statuses = statuses[:0]
// Select top-level statuses.
if err := db.NewSelect().
Model(&statuses).
Column("id", "thread_id").
batchStart := time.Now()
// Select top-level statuses.
if err := tx.NewSelect().
Model(&statuses).
Column("id").
// We specifically use in_reply_to_account_id instead of in_reply_to_id as
// they should both be set / unset in unison, but we specifically have an
// index on in_reply_to_account_id with ID ordering, unlike in_reply_to_id.
Where("? IS NULL", bun.Ident("in_reply_to_account_id")).
Where("? < ?", bun.Ident("id"), maxID).
OrderExpr("? DESC", bun.Ident("id")).
Limit(5000).
Limit(500).
Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) {
return gtserror.Newf("error selecting top level statuses: %w", err)
}
// Reached end of block.
if len(statuses) == 0 {
l := len(statuses)
if l == 0 {
// No more statuses!
//
// Transaction will be closed
// after leaving the loop.
break
} else if i%200 == 0 {
// Begin a new transaction every
// 200 batches (~100,000 statuses),
// to avoid massive commits.
// Close existing transaction.
if err := tx.Commit(); err != nil {
return err
}
// Try to flush the wal
// to avoid silly wal sizes.
if err := doWALCheckpoint(ctx, db); err != nil {
return err
}
// Open new transaction.
tx, err = db.BeginTx(ctx, nil)
if err != nil {
return err
}
}
// Set next maxID value from statuses.
maxID = statuses[len(statuses)-1].ID
// Rethread each selected batch of top-level statuses in a transaction.
if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Rethread each top-level status.
for _, status := range statuses {
n, err := sr.rethreadStatus(ctx, tx, status)
if err != nil {
return gtserror.Newf("error rethreading status %s: %w", status.URI, err)
}
count += n
// Rethread using the
// open transaction.
var updatedInBatch int64
for _, status := range statuses {
n, err := sr.rethreadStatus(ctx, tx, status, false)
if err != nil {
return gtserror.Newf("error rethreading status %s: %w", status.URI, err)
}
return nil
}); err != nil {
return err
updatedInBatch += n
updatedTotal += n
}
log.Infof(ctx, "[approx %d of %d] rethreading statuses (top-level)", count, total)
// Show speed for this batch.
timeTaken := time.Since(batchStart).Milliseconds()
msPerRow := float64(timeTaken) / float64(updatedInBatch)
rowsPerMs := float64(1) / float64(msPerRow)
rowsPerSecond := 1000 * rowsPerMs
// Show percent migrated overall.
totalDone := (float64(updatedTotal) / float64(total)) * 100
log.Infof(
ctx,
"[~%.2f%% done; ~%.0f rows/s] migrating threads",
totalDone, rowsPerSecond,
)
}
// Attempt to merge any sqlite write-ahead-log.
if err := doWALCheckpoint(ctx, db); err != nil {
// Close transaction.
if err := tx.Commit(); err != nil {
return err
}
log.Warn(ctx, "rethreading straggler statuses, this will take a *long* time")
for /* STRAGGLER STATUS LOOP */ {
// Create a partial index on thread_id_new to find stragglers.
// This index will be removed at the end of the migration.
log.Info(ctx, "creating temporary statuses thread_id_new index")
if _, err := db.NewCreateIndex().
Table("statuses").
Index("statuses_thread_id_new_idx").
Column("thread_id_new").
Where("? = ?", bun.Ident("thread_id_new"), id.Lowest).
Exec(ctx); err != nil {
return gtserror.Newf("error creating new thread_id index: %w", err)
}
for i := 1; ; i++ {
// Reset slice.
clear(statuses)
statuses = statuses[:0]
batchStart := time.Now()
// Select straggler statuses.
if err := db.NewSelect().
Model(&statuses).
Column("id", "in_reply_to_id", "thread_id").
Where("? IS NULL", bun.Ident("thread_id")).
Column("id").
Where("? = ?", bun.Ident("thread_id_new"), id.Lowest).
// We select in smaller batches for this part
// of the migration as there is a chance that
@ -138,7 +212,7 @@ func init() {
// part of the same thread, i.e. one call to
// rethreadStatus() may effect other statuses
// later in the slice.
Limit(1000).
Limit(250).
Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) {
return gtserror.Newf("error selecting straggler statuses: %w", err)
}
@ -149,23 +223,35 @@ func init() {
}
// Rethread each selected batch of straggler statuses in a transaction.
var updatedInBatch int64
if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Rethread each top-level status.
for _, status := range statuses {
n, err := sr.rethreadStatus(ctx, tx, status)
n, err := sr.rethreadStatus(ctx, tx, status, true)
if err != nil {
return gtserror.Newf("error rethreading status %s: %w", status.URI, err)
}
count += n
updatedInBatch += n
updatedTotal += n
}
return nil
}); err != nil {
return err
}
log.Infof(ctx, "[approx %d of %d] rethreading statuses (stragglers)", count, total)
// Show speed for this batch.
timeTaken := time.Since(batchStart).Milliseconds()
msPerRow := float64(timeTaken) / float64(updatedInBatch)
rowsPerMs := float64(1) / float64(msPerRow)
rowsPerSecond := 1000 * rowsPerMs
// Show percent migrated overall.
totalDone := (float64(updatedTotal) / float64(total)) * 100
log.Infof(
ctx,
"[~%.2f%% done; ~%.0f rows/s] migrating stragglers",
totalDone, rowsPerSecond,
)
}
// Attempt to merge any sqlite write-ahead-log.
@ -173,6 +259,13 @@ func init() {
return err
}
log.Info(ctx, "dropping temporary thread_id_new index")
if _, err := db.NewDropIndex().
Index("statuses_thread_id_new_idx").
Exec(ctx); err != nil {
return gtserror.Newf("error dropping temporary thread_id_new index: %w", err)
}
log.Info(ctx, "dropping old thread_to_statuses table")
if _, err := db.NewDropTable().
Table("thread_to_statuses").
@ -180,33 +273,6 @@ func init() {
return gtserror.Newf("error dropping old thread_to_statuses table: %w", err)
}
log.Info(ctx, "creating new statuses thread_id column")
if _, err := db.NewAddColumn().
Table("statuses").
ColumnExpr(newColDef).
Exec(ctx); err != nil {
return gtserror.Newf("error adding new thread_id column: %w", err)
}
log.Info(ctx, "setting thread_id_new = thread_id (this may take a while...)")
if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return batchUpdateByID(ctx, tx,
"statuses", // table
"id", // batchByCol
"UPDATE ? SET ? = ?", // updateQuery
[]any{bun.Ident("statuses"),
bun.Ident("thread_id_new"),
bun.Ident("thread_id")},
)
}); err != nil {
return err
}
// Attempt to merge any sqlite write-ahead-log.
if err := doWALCheckpoint(ctx, db); err != nil {
return err
}
log.Info(ctx, "dropping old statuses thread_id index")
if _, err := db.NewDropIndex().
Index("statuses_thread_id_idx").
@ -274,6 +340,11 @@ type statusRethreader struct {
// its contents are ephemeral.
statuses []*oldmodel.Status
// newThreadIDSet is used to track whether
// statuses in statusIDs have already have
// thread_id_new set on them.
newThreadIDSet map[string]struct{}
// seenIDs tracks the unique status and
// thread IDs we have seen, ensuring we
// don't append duplicates to statusIDs
@ -289,14 +360,15 @@ type statusRethreader struct {
}
// rethreadStatus is the main logic handler for statusRethreader{}. this is what gets called from the migration
// in order to trigger a status rethreading operation for the given status, returning total number rethreaded.
func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, status *oldmodel.Status) (int, error) {
// in order to trigger a status rethreading operation for the given status, returning total number of rows changed.
func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, status *oldmodel.Status, straggler bool) (int64, error) {
// Zero slice and
// map ptr values.
clear(sr.statusIDs)
clear(sr.threadIDs)
clear(sr.statuses)
clear(sr.newThreadIDSet)
clear(sr.seenIDs)
// Reset slices and values for use.
@ -305,6 +377,11 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu
sr.statuses = sr.statuses[:0]
sr.allThreaded = true
if sr.newThreadIDSet == nil {
// Allocate new hash set for newThreadIDSet.
sr.newThreadIDSet = make(map[string]struct{})
}
if sr.seenIDs == nil {
// Allocate new hash set for status IDs.
sr.seenIDs = make(map[string]struct{})
@ -317,12 +394,22 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu
// to the rethreadStatus() call.
if err := tx.NewSelect().
Model(status).
Column("in_reply_to_id", "thread_id").
Column("in_reply_to_id", "thread_id", "thread_id_new").
Where("? = ?", bun.Ident("id"), status.ID).
Scan(ctx); err != nil {
return 0, gtserror.Newf("error selecting status: %w", err)
}
// If we've just threaded this status by setting
// thread_id_new, then by definition anything we
// could find from the entire thread must now be
// threaded, so we can save some database calls
// by skipping iterating up + down from here.
if status.ThreadIDNew != id.Lowest {
log.Debugf(ctx, "skipping just rethreaded status: %s", status.ID)
return 0, nil
}
// status and thread ID cursor
// index values. these are used
// to keep track of newly loaded
@ -371,14 +458,14 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu
threadIdx = len(sr.threadIDs)
}
// Total number of
// statuses threaded.
total := len(sr.statusIDs)
// Check for the case where the entire
// batch of statuses is already correctly
// threaded. Then we have nothing to do!
if sr.allThreaded && len(sr.threadIDs) == 1 {
//
// Skip this check for straggler statuses
// that are part of broken threads.
if !straggler && sr.allThreaded && len(sr.threadIDs) == 1 {
log.Debug(ctx, "skipping just rethreaded thread")
return 0, nil
}
@ -417,36 +504,120 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu
}
}
// Update all the statuses to
// use determined thread_id.
if _, err := tx.NewUpdate().
Table("statuses").
Where("? IN (?)", bun.Ident("id"), bun.In(sr.statusIDs)).
Set("? = ?", bun.Ident("thread_id"), threadID).
Exec(ctx); err != nil {
var (
res sql.Result
err error
)
if len(sr.statusIDs) == 1 {
// If we're only updating one status
// we can use a simple update query.
res, err = tx.NewUpdate().
// Update the status model.
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
// Set the new thread ID, which we can use as
// an indication that we've migrated this batch.
Set("? = ?", bun.Ident("thread_id_new"), threadID).
// While we're here, also set old thread_id, as
// we'll use it for further rethreading purposes.
Set("? = ?", bun.Ident("thread_id"), threadID).
Where("? = ?", bun.Ident("status.id"), sr.statusIDs[0]).
Exec(ctx)
} else {
// If we're updating multiple statuses at once,
// build up a common table expression to update
// all statuses in this thread to use threadID.
//
// This ought to be a little more readable than
// using an "IN(*)" query, and PG or SQLite *may*
// be able to optimize it better.
//
// See:
//
// - https://sqlite.org/lang_with.html
// - https://www.postgresql.org/docs/current/queries-with.html
// - https://bun.uptrace.dev/guide/query-update.html#bulk-update
values := make([]*util.Status, 0, len(sr.statusIDs))
for _, statusID := range sr.statusIDs {
// Filter out statusIDs that have already had
// thread_id_new set, to avoid spurious writes.
if _, set := sr.newThreadIDSet[statusID]; !set {
values = append(values, &util.Status{
ID: statusID,
})
}
}
// Resulting query will look something like this:
//
// WITH "_data" ("id") AS (
// VALUES
// ('01JR6PZED0DDR2VZHQ8H87ZW98'),
// ('01JR6PZED0J91MJCAFDTCCCG8Q')
// )
// UPDATE "statuses" AS "status"
// SET
// "thread_id_new" = '01K6MGKX54BBJ3Y1FBPQY45E5P',
// "thread_id" = '01K6MGKX54BBJ3Y1FBPQY45E5P'
// FROM _data
// WHERE "status"."id" = "_data"."id"
res, err = tx.NewUpdate().
// Update the status model.
Model((*oldmodel.Status)(nil)).
// Provide the CTE values as "_data".
With("_data", tx.NewValues(&values)).
// Include `FROM _data` statement so we can use
// `_data` table in SET and WHERE components.
TableExpr("_data").
// Set the new thread ID, which we can use as
// an indication that we've migrated this batch.
Set("? = ?", bun.Ident("thread_id_new"), threadID).
// While we're here, also set old thread_id, as
// we'll use it for further rethreading purposes.
Set("? = ?", bun.Ident("thread_id"), threadID).
// "Join" to the CTE on status ID.
Where("? = ?", bun.Ident("status.id"), bun.Ident("_data.id")).
Exec(ctx)
}
if err != nil {
return 0, gtserror.Newf("error updating status thread ids: %w", err)
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return 0, gtserror.Newf("error counting rows affected: %w", err)
}
if len(sr.threadIDs) > 0 {
// Update any existing thread
// mutes to use latest thread_id.
// Dedupe thread IDs before query
// to avoid ludicrous "IN" clause.
threadIDs := sr.threadIDs
threadIDs = xslices.Deduplicate(threadIDs)
if _, err := tx.NewUpdate().
Table("thread_mutes").
Where("? IN (?)", bun.Ident("thread_id"), bun.In(sr.threadIDs)).
Where("? IN (?)", bun.Ident("thread_id"), bun.In(threadIDs)).
Set("? = ?", bun.Ident("thread_id"), threadID).
Exec(ctx); err != nil {
return 0, gtserror.Newf("error updating mute thread ids: %w", err)
}
}
return total, nil
return rowsAffected, nil
}
// append will append the given status to the internal tracking of statusRethreader{} for
// potential future operations, checking for uniqueness. it tracks the inReplyToID value
// for the next call to getParents(), it tracks the status ID for list of statuses that
// need updating, the thread ID for the list of thread links and mutes that need updating,
// and whether all the statuses all have a provided thread ID (i.e. allThreaded).
// may need updating, whether a new thread ID has been set for each status, the thread ID
// for the list of thread links and mutes that need updating, and whether all the statuses
// all have a provided thread ID (i.e. allThreaded).
func (sr *statusRethreader) append(status *oldmodel.Status) {
// Check if status already seen before.
@ -479,7 +650,14 @@ func (sr *statusRethreader) append(status *oldmodel.Status) {
}
// Add status ID to map of seen IDs.
sr.seenIDs[status.ID] = struct{}{}
mark := struct{}{}
sr.seenIDs[status.ID] = mark
// If new thread ID has already been
// set, add status ID to map of set IDs.
if status.ThreadIDNew != id.Lowest {
sr.newThreadIDSet[status.ID] = mark
}
}
func (sr *statusRethreader) getParents(ctx context.Context, tx bun.Tx) error {
@ -496,7 +674,7 @@ func (sr *statusRethreader) getParents(ctx context.Context, tx bun.Tx) error {
// Select next parent status.
if err := tx.NewSelect().
Model(&parent).
Column("id", "in_reply_to_id", "thread_id").
Column("id", "in_reply_to_id", "thread_id", "thread_id_new").
Where("? = ?", bun.Ident("id"), id).
Scan(ctx); err != nil && err != db.ErrNoEntries {
return err
@ -535,7 +713,7 @@ func (sr *statusRethreader) getChildren(ctx context.Context, tx bun.Tx, idx int)
// Select children of ID.
if err := tx.NewSelect().
Model(&sr.statuses).
Column("id", "thread_id").
Column("id", "thread_id", "thread_id_new").
Where("? = ?", bun.Ident("in_reply_to_id"), id).
Scan(ctx); err != nil && err != db.ErrNoEntries {
return err
@ -560,14 +738,19 @@ func (sr *statusRethreader) getStragglers(ctx context.Context, tx bun.Tx, idx in
clear(sr.statuses)
sr.statuses = sr.statuses[:0]
// Dedupe thread IDs before query
// to avoid ludicrous "IN" clause.
threadIDs := sr.threadIDs[idx:]
threadIDs = xslices.Deduplicate(threadIDs)
// Select stragglers that
// also have thread IDs.
if err := tx.NewSelect().
Model(&sr.statuses).
Column("id", "thread_id", "in_reply_to_id").
Column("id", "thread_id", "in_reply_to_id", "thread_id_new").
Where("? IN (?) AND ? NOT IN (?)",
bun.Ident("thread_id"),
bun.In(sr.threadIDs[idx:]),
bun.In(threadIDs),
bun.Ident("id"),
bun.In(sr.statusIDs),
).