Compare commits

...

5 commits

Author SHA1 Message Date
tobi
d212736165 boobs 2025-09-29 18:56:44 +02:00
tobi
6c04ae231c i'm adjusting the PR, pray i don't adjust it further 2025-09-29 16:48:46 +02:00
tobi
9e2fd4734b should be done poking now 2025-09-29 12:11:37 +02:00
tobi
408ddc367d whoops 2025-09-29 12:02:19 +02:00
tobi
228b41cb53 few more little tweaks 2025-09-29 11:59:35 +02:00

View file

@ -26,7 +26,6 @@ import (
"strings"
"time"
"code.superseriousbusiness.org/gotosocial/internal/db"
newmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/new"
oldmodel "code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/old"
"code.superseriousbusiness.org/gotosocial/internal/db/bundb/migrations/20250415111056_thread_all_statuses/util"
@ -62,44 +61,80 @@ func init() {
return gtserror.Newf("error adding statuses column thread_id_new: %w", err)
}
var sr statusRethreader
var updatedRowsTotal int64
var maxID string
var statuses []*oldmodel.Status
if err := doWALCheckpoint(ctx, db); err != nil {
return err
}
// Get a total count of all statuses before migration.
total, err := db.NewSelect().Table("statuses").Count(ctx)
// Get a total count of all
// statuses before migration.
totalStatuses, err := db.
NewSelect().
Table("statuses").
Count(ctx)
if err != nil {
return gtserror.Newf("error getting status table count: %w", err)
}
log.Warnf(ctx, "migrating %d statuses total, this may take a *long* time", totalStatuses)
// Start at largest
var sr statusRethreader
var updatedRowsTotal int64
var statuses []*oldmodel.Status
// Page starting at largest
// possible ULID value.
maxID = id.Highest
var maxID = id.Highest
log.Warnf(ctx, "migrating %d statuses, this may take a *long* time", total)
for {
start := time.Now()
// Open initial transaction.
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
for i := 1; ; i++ {
// Reset slice.
clear(statuses)
statuses = statuses[:0]
start := time.Now()
// Select IDs of next
// batch, paging down.
if err := db.NewSelect().
if err := tx.NewSelect().
Model(&statuses).
Column("id").
Where("? IS NULL", bun.Ident("in_reply_to_id")).
Where("? < ?", bun.Ident("id"), maxID).
OrderExpr("? DESC", bun.Ident("id")).
Limit(250).
Limit(100).
Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) {
return gtserror.Newf("error selecting unthreaded statuses: %w", err)
return gtserror.Newf("error selecting top-level statuses: %w", err)
}
// Every 50 loops, flush wal and begin new
// transaction, to avoid silly wal sizes.
if i%50 == 0 {
if err := tx.Commit(); err != nil {
return err
}
if err := doWALCheckpoint(ctx, db); err != nil {
return err
}
tx, err = db.BeginTx(ctx, nil)
if err != nil {
return err
}
}
// No more statuses!
l := len(statuses)
if l == 0 {
if err := tx.Commit(); err != nil {
return err
}
log.Info(ctx, "done migrating statuses!")
break
}
@ -107,52 +142,137 @@ func init() {
// Set next maxID value from statuses.
maxID = statuses[l-1].ID
// Rethread each selected status in a transaction.
// Rethread inside the transaction.
var updatedRowsThisBatch int64
if err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
for _, status := range statuses {
n, err := sr.rethreadStatus(ctx, tx, status)
if err != nil {
return gtserror.Newf("error rethreading status %s: %w", status.URI, err)
}
updatedRowsThisBatch += n
updatedRowsTotal += n
for _, status := range statuses {
n, err := sr.rethreadStatus(ctx, tx, status)
if err != nil {
return gtserror.Newf("error rethreading status %s: %w", status.URI, err)
}
return nil
}); err != nil {
return err
updatedRowsThisBatch += n
updatedRowsTotal += n
}
// Show current speed + percent migrated.
//
// Percent may end up wonky due to approximations
// and batching, so show a generic message at 100%.
// Show speed for this batch.
timeTaken := time.Since(start).Milliseconds()
msPerRow := float64(timeTaken) / float64(updatedRowsThisBatch)
rowsPerMs := float64(1) / float64(msPerRow)
rowsPerSecond := 1000 * rowsPerMs
percentDone := (float64(updatedRowsTotal) / float64(total)) * 100
if percentDone <= 100 {
log.Infof(
ctx,
"[updated %d total rows, now @ ~%.0f rows/s] done ~%.2f%% of statuses",
updatedRowsTotal, rowsPerSecond, percentDone,
)
} else {
log.Infof(
ctx,
"[updated %d total rows, now @ ~%.0f rows/s] almost done... ",
updatedRowsTotal, rowsPerSecond,
)
}
// Show percent migrated overall.
totalDone := (float64(updatedRowsTotal) / float64(totalStatuses)) * 100
log.Infof(
ctx,
"[~%.2f%% done; ~%.0f rows/s] paging top-level statuses",
totalDone, rowsPerSecond,
)
}
// Attempt to merge any sqlite write-ahead-log.
if err := doWALCheckpoint(ctx, db); err != nil {
return err
}
// Reset max ID.
maxID = id.Highest
// Create a temporary index on thread_id_new for stragglers.
log.Info(ctx, "creating temporary statuses thread_id_new index")
if _, err := db.NewCreateIndex().
Table("statuses").
Index("statuses_thread_id_new_idx").
Column("thread_id_new").
Exec(ctx); err != nil {
return gtserror.Newf("error creating new thread_id index: %w", err)
}
// Open a new transaction lads.
tx, err = db.BeginTx(ctx, nil)
if err != nil {
return err
}
for i := 1; ; i++ {
// Reset slice.
clear(statuses)
statuses = statuses[:0]
start := time.Now()
// Select IDs of stragglers for
// which we haven't set thread_id yet.
if err := tx.NewSelect().
Model(&statuses).
Column("id").
Where("? = ?", bun.Ident("thread_id_new"), id.Lowest).
Limit(500).
Scan(ctx); err != nil && !errors.Is(err, sql.ErrNoRows) {
return gtserror.Newf("error selecting unthreaded statuses: %w", err)
}
// Every 50 loops, flush wal and begin new
// transaction, to avoid silly wal sizes.
if i%50 == 0 {
if err := tx.Commit(); err != nil {
return err
}
if err := doWALCheckpoint(ctx, db); err != nil {
return err
}
tx, err = db.BeginTx(ctx, nil)
if err != nil {
return err
}
}
// No more statuses!
l := len(statuses)
if l == 0 {
if err := tx.Commit(); err != nil {
return err
}
log.Info(ctx, "done migrating statuses!")
break
}
// Rethread inside the transaction.
var updatedRowsThisBatch int64
for _, status := range statuses {
n, err := sr.rethreadStatus(ctx, tx, status)
if err != nil {
return gtserror.Newf("error rethreading status %s: %w", status.URI, err)
}
updatedRowsThisBatch += n
updatedRowsTotal += n
}
// Show speed for this batch.
timeTaken := time.Since(start).Milliseconds()
msPerRow := float64(timeTaken) / float64(updatedRowsThisBatch)
rowsPerMs := float64(1) / float64(msPerRow)
rowsPerSecond := 1000 * rowsPerMs
// Show percent migrated overall.
totalDone := (float64(updatedRowsTotal) / float64(totalStatuses)) * 100
log.Infof(
ctx,
"[~%.2f%% done; ~%.0f rows/s] cleaning up stragglers",
totalDone, rowsPerSecond,
)
}
log.Info(ctx, "dropping temporary thread_id_new index")
if _, err := db.NewDropIndex().
Index("statuses_thread_id_new_idx").
Exec(ctx); err != nil {
return gtserror.Newf("error dropping temporary thread_id_new index: %w", err)
}
log.Info(ctx, "dropping old thread_to_statuses table")
if _, err := db.NewDropTable().
Table("thread_to_statuses").
@ -265,17 +385,33 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu
// Ensure the passed status
// has up-to-date information.
// This may have changed from
// the initial batch selection
// to the rethreadStatus() call.
upToDateValues := make(map[string]any, 3)
if err := tx.NewSelect().
Model(status).
Column("in_reply_to_id", "thread_id").
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("in_reply_to_id", "thread_id", "thread_id_new").
Where("? = ?", bun.Ident("id"), status.ID).
Scan(ctx); err != nil {
Scan(ctx, &upToDateValues); err != nil {
return 0, gtserror.Newf("error selecting status: %w", err)
}
// If we've just threaded this status by setting
// thread_id_new, then by definition anything we
// could find from the entire thread must now be
// threaded, so we can save some database calls
// by skipping iterating up + down from here.
if v, ok := upToDateValues["thread_id_new"]; ok && v.(string) != id.Lowest {
log.Debug(ctx, "skipping just rethreaded status")
return 0, nil
}
// Set up-to-date values on the status.
if inReplyToID, ok := upToDateValues["in_reply_to_id"]; ok && inReplyToID != nil {
status.InReplyToID = inReplyToID.(string)
}
if threadID, ok := upToDateValues["thread_id"]; ok && threadID != nil {
status.ThreadID = threadID.(string)
}
// status and thread ID cursor
// index values. these are used
// to keep track of newly loaded
@ -328,6 +464,7 @@ func (sr *statusRethreader) rethreadStatus(ctx context.Context, tx bun.Tx, statu
// batch of statuses is already correctly
// threaded. Then we have nothing to do!
if sr.allThreaded && len(sr.threadIDs) == 1 {
log.Debug(ctx, "skipping just rethreaded thread")
return 0, nil
}
@ -479,7 +616,7 @@ func (sr *statusRethreader) getParents(ctx context.Context, tx bun.Tx) error {
Model(&parent).
Column("id", "in_reply_to_id", "thread_id").
Where("? = ?", bun.Ident("id"), id).
Scan(ctx); err != nil && err != db.ErrNoEntries {
Scan(ctx); err != nil && err != sql.ErrNoRows {
return err
}
@ -518,7 +655,7 @@ func (sr *statusRethreader) getChildren(ctx context.Context, tx bun.Tx, idx int)
Model(&sr.statuses).
Column("id", "thread_id").
Where("? = ?", bun.Ident("in_reply_to_id"), id).
Scan(ctx); err != nil && err != db.ErrNoEntries {
Scan(ctx); err != nil && err != sql.ErrNoRows {
return err
}
@ -557,7 +694,7 @@ func (sr *statusRethreader) getStragglers(ctx context.Context, tx bun.Tx, idx in
bun.Ident("id"),
bun.In(sr.statusIDs),
).
Scan(ctx); err != nil && err != db.ErrNoEntries {
Scan(ctx); err != nil && err != sql.ErrNoRows {
return err
}