From db91fe5a946f80fe16a1ff165c1dc2e3e765626c Mon Sep 17 00:00:00 2001 From: kim Date: Tue, 8 Apr 2025 23:20:36 +0100 Subject: [PATCH] better concurrency safety in Clear() and Done() --- internal/cache/timeline/preload.go | 43 +++++++++++++++--------------- internal/cache/timeline/status.go | 6 ++--- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/internal/cache/timeline/preload.go b/internal/cache/timeline/preload.go index 7661e708f..35b03ffe5 100644 --- a/internal/cache/timeline/preload.go +++ b/internal/cache/timeline/preload.go @@ -60,7 +60,7 @@ func (p *preloader) Check() bool { // CheckPreload will safely check the preload state, // and if needed call the provided function. if a // preload is in progress, it will wait until complete. -func (p *preloader) CheckPreload(preload func()) { +func (p *preloader) CheckPreload(preload func(*any)) { for { // Get state ptr. ptr := p.p.Load() @@ -93,7 +93,7 @@ func (p *preloader) CheckPreload(preload func()) { // start attempts to start the given preload function, by // performing a CAS operation with 'old'. return is success. -func (p *preloader) start(old *any, preload func()) bool { +func (p *preloader) start(old *any, preload func(*any)) bool { // Optimistically setup a // new waitgroup to set as @@ -105,29 +105,24 @@ func (p *preloader) start(old *any, preload func()) bool { // Wrap waitgroup in // 'any' for pointer. new := any(&wg) + ptr := &new // Attempt CAS operation to claim start. - started := p.p.CompareAndSwap(old, &new) + started := p.p.CompareAndSwap(old, ptr) if !started { return false } // Start. - preload() + preload(ptr) return true } // done marks state as preloaded, // i.e. no more preload required. -func (p *preloader) Done() { - old := p.p.Swap(new(any)) - if old == nil { // was brand-new - return - } - switch t := (*old).(type) { - case *sync.WaitGroup: // was preloading - default: - log.Errorf(nil, "BUG: invalid preloader state: %#v", t) +func (p *preloader) Done(ptr *any) { + if !p.p.CompareAndSwap(ptr, new(any)) { + log.Errorf(nil, "BUG: invalid preloader state: %#v", (*p.p.Load())) } } @@ -137,17 +132,21 @@ func (p *preloader) Clear() { b := false a := any(b) for { - old := p.p.Swap(&a) - if old == nil { // was brand-new - return + // Load current ptr. + ptr := p.p.Load() + if ptr == nil { + return // was brand-new } - switch t := (*old).(type) { - case nil: // was preloaded + + // Check for a preload currently in progress. + if wg, _ := (*ptr).(*sync.WaitGroup); wg != nil { + wg.Wait() + continue + } + + // Try mark as needing preload. + if p.p.CompareAndSwap(ptr, &a) { return - case bool: // was cleared - return - case *sync.WaitGroup: // was preloading - t.Wait() } } } diff --git a/internal/cache/timeline/status.go b/internal/cache/timeline/status.go index 750d2b2f1..dc81dd391 100644 --- a/internal/cache/timeline/status.go +++ b/internal/cache/timeline/status.go @@ -196,14 +196,14 @@ func (t *StatusTimeline) Preload( n int, err error, ) { - t.preloader.CheckPreload(func() { + t.preloader.CheckPreload(func(ptr *any) { n, err = t.preload(loadPage, filter) if err != nil { return } - // Mark preloaded. - t.preloader.Done() + // Mark as preloaded. + t.preloader.Done(ptr) }) return }