diff --git a/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go b/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go index daf392ee6..bfa4dd84f 100644 --- a/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go +++ b/internal/db/bundb/migrations/20250415111056_thread_all_statuses.go @@ -26,6 +26,7 @@ import ( "strings" "time" + "code.superseriousbusiness.org/gotosocial/internal/db" newmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/new" oldmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/old" "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/util" @@ -61,80 +62,44 @@ func init() { return gtserror.Newf("error adding statuses column thread_id_new: %w", err) } - if err := doWALCheckpoint(ctx, db); err != nil { - return err - } + var sr statusRethreader + var updatedRowsTotal int64 + var maxID string + var statuses []*oldmodel.Status - // Get a total count of all - // statuses before migration. - totalStatuses, err := db. - NewSelect(). - Table("statuses"). - Count(ctx) + // Get a total count of all statuses before migration. + total, err := db.NewSelect().Table("statuses").Count(ctx) if err != nil { return gtserror.Newf("error getting status table count: %w", err) } - log.Warnf(ctx, "migrating %d statuses total, this may take a *long* time", totalStatuses) - var sr statusRethreader - var updatedRowsTotal int64 - var statuses []*oldmodel.Status - - // Page starting at largest + // Start at largest // possible ULID value. - var maxID = id.Highest + maxID = id.Highest - // Open initial transaction. - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return err - } - - for i := 1; ; i++ { + log.Warnf(ctx, "migrating %d statuses, this may take a *long* time", total) + for { + start := time.Now() // Reset slice. clear(statuses) statuses = statuses[:0] - start := time.Now() - // Select IDs of next // batch, paging down. - if err := tx.NewSelect(). + if err := db.NewSelect(). Model(&statuses). Column("id"). - Where("? IS NULL", bun.Ident("in_reply_to_id")). Where("? < ?", bun.Ident("id"), maxID). OrderExpr("? DESC", bun.Ident("id")). - Limit(100). + Limit(250). Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) { - return gtserror.Newf("error selecting top-level statuses: %w", err) - } - - // Every 50 loops, flush wal and begin new - // transaction, to avoid silly wal sizes. - if i%50 == 0 { - if err := tx.Commit(); err != nil { - return err - } - - if err := doWALCheckpoint(ctx, db); err != nil { - return err - } - - tx, err = db.BeginTx(ctx, nil) - if err != nil { - return err - } + return gtserror.Newf("error selecting unthreaded statuses: %w", err) } // No more statuses! l := len(statuses) if l == 0 { - if err := tx.Commit(); err != nil { - return err - } - log.Info(ctx, "done migrating statuses!") break } @@ -142,137 +107,52 @@ func init() { // Set next maxID value from statuses. maxID = statuses[l-1].ID - // Rethread inside the transaction. + // Rethread each selected status in a transaction. var updatedRowsThisBatch int64 - for _, status := range statuses { - n, err := sr.rethreadStatus(ctx, tx, status) - if err != nil { - return gtserror.Newf("error rethreading status %s: %w", status.URI, err) + if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + for _, status := range statuses { + n, err := sr.rethreadStatus(ctx, tx, status) + if err != nil { + return gtserror.Newf("error rethreading status %s: %w", status.URI, err) + } + updatedRowsThisBatch += n + updatedRowsTotal += n } - updatedRowsThisBatch += n - updatedRowsTotal += n + + return nil + }); err != nil { + return err } - // Show speed for this batch. + // Show current speed + percent migrated. + // + // Percent may end up wonky due to approximations + // and batching, so show a generic message at 100%. timeTaken := time.Since(start).Milliseconds() msPerRow := float64(timeTaken) / float64(updatedRowsThisBatch) rowsPerMs := float64(1) / float64(msPerRow) rowsPerSecond := 1000 * rowsPerMs - - // Show percent migrated overall. - totalDone := (float64(updatedRowsTotal) / float64(totalStatuses)) * 100 - - log.Infof( - ctx, - "[~%.2f%% done; ~%.0f rows/s] paging top-level statuses", - totalDone, rowsPerSecond, - ) + percentDone := (float64(updatedRowsTotal) / float64(total)) * 100 + if percentDone <= 100 { + log.Infof( + ctx, + "[updated %d total rows, now @ ~%.0f rows/s] done ~%.2f%% of statuses", + updatedRowsTotal, rowsPerSecond, percentDone, + ) + } else { + log.Infof( + ctx, + "[updated %d total rows, now @ ~%.0f rows/s] almost done... ", + updatedRowsTotal, rowsPerSecond, + ) + } } + // Attempt to merge any sqlite write-ahead-log. if err := doWALCheckpoint(ctx, db); err != nil { return err } - // Reset max ID. - maxID = id.Highest - - // Create a temporary index on thread_id_new for stragglers. - log.Info(ctx, "creating temporary statuses thread_id_new index") - if _, err := db.NewCreateIndex(). - Table("statuses"). - Index("statuses_thread_id_new_idx"). - Column("thread_id_new"). - Exec(ctx); err != nil { - return gtserror.Newf("error creating new thread_id index: %w", err) - } - - // Open a new transaction lads. - tx, err = db.BeginTx(ctx, nil) - if err != nil { - return err - } - - for i := 1; ; i++ { - - // Reset slice. - clear(statuses) - statuses = statuses[:0] - - start := time.Now() - - // Select IDs of stragglers for - // which we haven't set thread_id yet. - if err := tx.NewSelect(). - Model(&statuses). - Column("id"). - Where("? = ?", bun.Ident("thread_id_new"), id.Lowest). - Limit(500). - Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) { - return gtserror.Newf("error selecting unthreaded statuses: %w", err) - } - - // Every 50 loops, flush wal and begin new - // transaction, to avoid silly wal sizes. - if i%50 == 0 { - if err := tx.Commit(); err != nil { - return err - } - - if err := doWALCheckpoint(ctx, db); err != nil { - return err - } - - tx, err = db.BeginTx(ctx, nil) - if err != nil { - return err - } - } - - // No more statuses! - l := len(statuses) - if l == 0 { - if err := tx.Commit(); err != nil { - return err - } - - log.Info(ctx, "done migrating statuses!") - break - } - - // Rethread inside the transaction. - var updatedRowsThisBatch int64 - for _, status := range statuses { - n, err := sr.rethreadStatus(ctx, tx, status) - if err != nil { - return gtserror.Newf("error rethreading status %s: %w", status.URI, err) - } - updatedRowsThisBatch += n - updatedRowsTotal += n - } - - // Show speed for this batch. - timeTaken := time.Since(start).Milliseconds() - msPerRow := float64(timeTaken) / float64(updatedRowsThisBatch) - rowsPerMs := float64(1) / float64(msPerRow) - rowsPerSecond := 1000 * rowsPerMs - - // Show percent migrated overall. - totalDone := (float64(updatedRowsTotal) / float64(totalStatuses)) * 100 - - log.Infof( - ctx, - "[~%.2f%% done; ~%.0f rows/s] cleaning up stragglers", - totalDone, rowsPerSecond, - ) - } - - log.Info(ctx, "dropping temporary thread_id_new index") - if _, err := db.NewDropIndex(). - Index("statuses_thread_id_new_idx"). - Exec(ctx); err != nil { - return gtserror.Newf("error dropping temporary thread_id_new index: %w", err) - } - log.Info(ctx, "dropping old thread_to_statuses table") if _, err := db.NewDropTable(). Table("thread_to_statuses"). @@ -385,33 +265,17 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu // Ensure the passed status // has up-to-date information. - upToDateValues := make(map[string]any, 3) + // This may have changed from + // the initial batch selection + // to the rethreadStatus() call. if err := tx.NewSelect(). - TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). - Column("in_reply_to_id", "thread_id", "thread_id_new"). + Model(status). + Column("in_reply_to_id", "thread_id"). Where("? = ?", bun.Ident("id"), status.ID). - Scan(ctx, &upToDateValues); err != nil { + Scan(ctx); err != nil { return 0, gtserror.Newf("error selecting status: %w", err) } - // If we've just threaded this status by setting - // thread_id_new, then by definition anything we - // could find from the entire thread must now be - // threaded, so we can save some database calls - // by skipping iterating up + down from here. - if v, ok := upToDateValues["thread_id_new"]; ok && v.(string) != id.Lowest { - log.Debug(ctx, "skipping just rethreaded status") - return 0, nil - } - - // Set up-to-date values on the status. - if inReplyToID, ok := upToDateValues["in_reply_to_id"]; ok && inReplyToID != nil { - status.InReplyToID = inReplyToID.(string) - } - if threadID, ok := upToDateValues["thread_id"]; ok && threadID != nil { - status.ThreadID = threadID.(string) - } - // status and thread ID cursor // index values. these are used // to keep track of newly loaded @@ -464,7 +328,6 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu // batch of statuses is already correctly // threaded. Then we have nothing to do! if sr.allThreaded && len(sr.threadIDs) == 1 { - log.Debug(ctx, "skipping just rethreaded thread") return 0, nil } @@ -616,7 +479,7 @@ func (sr *statusRethreader) getParents(ctx context.Context, tx bun.Tx) error { Model(&parent). Column("id", "in_reply_to_id", "thread_id"). Where("? = ?", bun.Ident("id"), id). - Scan(ctx); err != nil && err != sql.ErrNoRows { + Scan(ctx); err != nil && err != db.ErrNoEntries { return err } @@ -655,7 +518,7 @@ func (sr *statusRethreader) getChildren(ctx context.Context, tx bun.Tx, idx int) Model(&sr.statuses). Column("id", "thread_id"). Where("? = ?", bun.Ident("in_reply_to_id"), id). - Scan(ctx); err != nil && err != sql.ErrNoRows { + Scan(ctx); err != nil && err != db.ErrNoEntries { return err } @@ -694,7 +557,7 @@ func (sr *statusRethreader) getStragglers(ctx context.Context, tx bun.Tx, idx in bun.Ident("id"), bun.In(sr.statusIDs), ). - Scan(ctx); err != nil && err != sql.ErrNoRows { + Scan(ctx); err != nil && err != db.ErrNoEntries { return err }