diff --git a/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go b/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go index bfa4dd84f..daf392ee6 100644 --- a/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go +++ b/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go @@ -26,7 +26,6 @@ import ( "strings" "time" - "code.superseriousbusiness.org/gotosocial/internal/db" newmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/new" oldmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/old" "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/util" @@ -62,44 +61,80 @@ func init() { return gtserror.Newf("error adding statuses column thread_id_new: %w", err) } - var sr statusRethreader - var updatedRowsTotal int64 - var maxID string - var statuses []*oldmodel.Status + if err := doWALCheckpoint(ctx, db); err != nil { + return err + } - // Get a total count of all statuses before migration. - total, err := db.NewSelect().Table("statuses").Count(ctx) + // Get a total count of all + // statuses before migration. + totalStatuses, err := db. + NewSelect(). + Table("statuses"). + Count(ctx) if err != nil { return gtserror.Newf("error getting status table count: %w", err) } + log.Warnf(ctx, "migrating %d statuses total, this may take a *long* time", totalStatuses) - // Start at largest + var sr statusRethreader + var updatedRowsTotal int64 + var statuses []*oldmodel.Status + + // Page starting at largest // possible ULID value. - maxID = id.Highest + var maxID = id.Highest - log.Warnf(ctx, "migrating %d statuses, this may take a *long* time", total) - for { - start := time.Now() + // Open initial transaction. + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + + for i := 1; ; i++ { // Reset slice. clear(statuses) statuses = statuses[:0] + start := time.Now() + // Select IDs of next // batch, paging down. - if err := db.NewSelect(). + if err := tx.NewSelect(). Model(&statuses). Column("id"). + Where("? IS NULL", bun.Ident("in_reply_to_id")). Where("? < ?", bun.Ident("id"), maxID). OrderExpr("? DESC", bun.Ident("id")). - Limit(250). + Limit(100). Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) { - return gtserror.Newf("error selecting unthreaded statuses: %w", err) + return gtserror.Newf("error selecting top-level statuses: %w", err) + } + + // Every 50 loops, flush wal and begin new + // transaction, to avoid silly wal sizes. + if i%50 == 0 { + if err := tx.Commit(); err != nil { + return err + } + + if err := doWALCheckpoint(ctx, db); err != nil { + return err + } + + tx, err = db.BeginTx(ctx, nil) + if err != nil { + return err + } } // No more statuses! l := len(statuses) if l == 0 { + if err := tx.Commit(); err != nil { + return err + } + log.Info(ctx, "done migrating statuses!") break } @@ -107,52 +142,137 @@ func init() { // Set next maxID value from statuses. maxID = statuses[l-1].ID - // Rethread each selected status in a transaction. + // Rethread inside the transaction. var updatedRowsThisBatch int64 - if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { - for _, status := range statuses { - n, err := sr.rethreadStatus(ctx, tx, status) - if err != nil { - return gtserror.Newf("error rethreading status %s: %w", status.URI, err) - } - updatedRowsThisBatch += n - updatedRowsTotal += n + for _, status := range statuses { + n, err := sr.rethreadStatus(ctx, tx, status) + if err != nil { + return gtserror.Newf("error rethreading status %s: %w", status.URI, err) } - - return nil - }); err != nil { - return err + updatedRowsThisBatch += n + updatedRowsTotal += n } - // Show current speed + percent migrated. - // - // Percent may end up wonky due to approximations - // and batching, so show a generic message at 100%. + // Show speed for this batch. timeTaken := time.Since(start).Milliseconds() msPerRow := float64(timeTaken) / float64(updatedRowsThisBatch) rowsPerMs := float64(1) / float64(msPerRow) rowsPerSecond := 1000 * rowsPerMs - percentDone := (float64(updatedRowsTotal) / float64(total)) * 100 - if percentDone <= 100 { - log.Infof( - ctx, - "[updated %d total rows, now @ ~%.0f rows/s] done ~%.2f%% of statuses", - updatedRowsTotal, rowsPerSecond, percentDone, - ) - } else { - log.Infof( - ctx, - "[updated %d total rows, now @ ~%.0f rows/s] almost done... ", - updatedRowsTotal, rowsPerSecond, - ) - } + + // Show percent migrated overall. + totalDone := (float64(updatedRowsTotal) / float64(totalStatuses)) * 100 + + log.Infof( + ctx, + "[~%.2f%% done; ~%.0f rows/s] paging top-level statuses", + totalDone, rowsPerSecond, + ) } - // Attempt to merge any sqlite write-ahead-log. if err := doWALCheckpoint(ctx, db); err != nil { return err } + // Reset max ID. + maxID = id.Highest + + // Create a temporary index on thread_id_new for stragglers. + log.Info(ctx, "creating temporary statuses thread_id_new index") + if _, err := db.NewCreateIndex(). + Table("statuses"). + Index("statuses_thread_id_new_idx"). + Column("thread_id_new"). + Exec(ctx); err != nil { + return gtserror.Newf("error creating new thread_id index: %w", err) + } + + // Open a new transaction lads. + tx, err = db.BeginTx(ctx, nil) + if err != nil { + return err + } + + for i := 1; ; i++ { + + // Reset slice. + clear(statuses) + statuses = statuses[:0] + + start := time.Now() + + // Select IDs of stragglers for + // which we haven't set thread_id yet. + if err := tx.NewSelect(). + Model(&statuses). + Column("id"). + Where("? = ?", bun.Ident("thread_id_new"), id.Lowest). + Limit(500). + Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) { + return gtserror.Newf("error selecting unthreaded statuses: %w", err) + } + + // Every 50 loops, flush wal and begin new + // transaction, to avoid silly wal sizes. + if i%50 == 0 { + if err := tx.Commit(); err != nil { + return err + } + + if err := doWALCheckpoint(ctx, db); err != nil { + return err + } + + tx, err = db.BeginTx(ctx, nil) + if err != nil { + return err + } + } + + // No more statuses! + l := len(statuses) + if l == 0 { + if err := tx.Commit(); err != nil { + return err + } + + log.Info(ctx, "done migrating statuses!") + break + } + + // Rethread inside the transaction. + var updatedRowsThisBatch int64 + for _, status := range statuses { + n, err := sr.rethreadStatus(ctx, tx, status) + if err != nil { + return gtserror.Newf("error rethreading status %s: %w", status.URI, err) + } + updatedRowsThisBatch += n + updatedRowsTotal += n + } + + // Show speed for this batch. + timeTaken := time.Since(start).Milliseconds() + msPerRow := float64(timeTaken) / float64(updatedRowsThisBatch) + rowsPerMs := float64(1) / float64(msPerRow) + rowsPerSecond := 1000 * rowsPerMs + + // Show percent migrated overall. + totalDone := (float64(updatedRowsTotal) / float64(totalStatuses)) * 100 + + log.Infof( + ctx, + "[~%.2f%% done; ~%.0f rows/s] cleaning up stragglers", + totalDone, rowsPerSecond, + ) + } + + log.Info(ctx, "dropping temporary thread_id_new index") + if _, err := db.NewDropIndex(). + Index("statuses_thread_id_new_idx"). + Exec(ctx); err != nil { + return gtserror.Newf("error dropping temporary thread_id_new index: %w", err) + } + log.Info(ctx, "dropping old thread_to_statuses table") if _, err := db.NewDropTable(). Table("thread_to_statuses"). @@ -265,17 +385,33 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu // Ensure the passed status // has up-to-date information. - // This may have changed from - // the initial batch selection - // to the rethreadStatus() call. + upToDateValues := make(map[string]any, 3) if err := tx.NewSelect(). - Model(status). - Column("in_reply_to_id", "thread_id"). + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + Column("in_reply_to_id", "thread_id", "thread_id_new"). Where("? = ?", bun.Ident("id"), status.ID). - Scan(ctx); err != nil { + Scan(ctx, &upToDateValues); err != nil { return 0, gtserror.Newf("error selecting status: %w", err) } + // If we've just threaded this status by setting + // thread_id_new, then by definition anything we + // could find from the entire thread must now be + // threaded, so we can save some database calls + // by skipping iterating up + down from here. + if v, ok := upToDateValues["thread_id_new"]; ok && v.(string) != id.Lowest { + log.Debug(ctx, "skipping just rethreaded status") + return 0, nil + } + + // Set up-to-date values on the status. + if inReplyToID, ok := upToDateValues["in_reply_to_id"]; ok && inReplyToID != nil { + status.InReplyToID = inReplyToID.(string) + } + if threadID, ok := upToDateValues["thread_id"]; ok && threadID != nil { + status.ThreadID = threadID.(string) + } + // status and thread ID cursor // index values. these are used // to keep track of newly loaded @@ -328,6 +464,7 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu // batch of statuses is already correctly // threaded. Then we have nothing to do! if sr.allThreaded && len(sr.threadIDs) == 1 { + log.Debug(ctx, "skipping just rethreaded thread") return 0, nil } @@ -479,7 +616,7 @@ func (sr *statusRethreader) getParents(ctx context.Context, tx bun.Tx) error { Model(&parent). Column("id", "in_reply_to_id", "thread_id"). Where("? = ?", bun.Ident("id"), id). - Scan(ctx); err != nil && err != db.ErrNoEntries { + Scan(ctx); err != nil && err != sql.ErrNoRows { return err } @@ -518,7 +655,7 @@ func (sr *statusRethreader) getChildren(ctx context.Context, tx bun.Tx, idx int) Model(&sr.statuses). Column("id", "thread_id"). Where("? = ?", bun.Ident("in_reply_to_id"), id). - Scan(ctx); err != nil && err != db.ErrNoEntries { + Scan(ctx); err != nil && err != sql.ErrNoRows { return err } @@ -557,7 +694,7 @@ func (sr *statusRethreader) getStragglers(ctx context.Context, tx bun.Tx, idx in bun.Ident("id"), bun.In(sr.statusIDs), ). - Scan(ctx); err != nil && err != db.ErrNoEntries { + Scan(ctx); err != nil && err != sql.ErrNoRows { return err }