better concurrency safety in Clear() and Done()

This commit is contained in:
kim 2025-04-08 23:20:36 +01:00
commit db91fe5a94
2 changed files with 24 additions and 25 deletions

View file

@ -60,7 +60,7 @@ func (p *preloader) Check() bool {
// CheckPreload will safely check the preload state, // CheckPreload will safely check the preload state,
// and if needed call the provided function. if a // and if needed call the provided function. if a
// preload is in progress, it will wait until complete. // preload is in progress, it will wait until complete.
func (p *preloader) CheckPreload(preload func()) { func (p *preloader) CheckPreload(preload func(*any)) {
for { for {
// Get state ptr. // Get state ptr.
ptr := p.p.Load() ptr := p.p.Load()
@ -93,7 +93,7 @@ func (p *preloader) CheckPreload(preload func()) {
// start attempts to start the given preload function, by // start attempts to start the given preload function, by
// performing a CAS operation with 'old'. return is success. // performing a CAS operation with 'old'. return is success.
func (p *preloader) start(old *any, preload func()) bool { func (p *preloader) start(old *any, preload func(*any)) bool {
// Optimistically setup a // Optimistically setup a
// new waitgroup to set as // new waitgroup to set as
@ -105,29 +105,24 @@ func (p *preloader) start(old *any, preload func()) bool {
// Wrap waitgroup in // Wrap waitgroup in
// 'any' for pointer. // 'any' for pointer.
new := any(&wg) new := any(&wg)
ptr := &new
// Attempt CAS operation to claim start. // Attempt CAS operation to claim start.
started := p.p.CompareAndSwap(old, &new) started := p.p.CompareAndSwap(old, ptr)
if !started { if !started {
return false return false
} }
// Start. // Start.
preload() preload(ptr)
return true return true
} }
// done marks state as preloaded, // done marks state as preloaded,
// i.e. no more preload required. // i.e. no more preload required.
func (p *preloader) Done() { func (p *preloader) Done(ptr *any) {
old := p.p.Swap(new(any)) if !p.p.CompareAndSwap(ptr, new(any)) {
if old == nil { // was brand-new log.Errorf(nil, "BUG: invalid preloader state: %#v", (*p.p.Load()))
return
}
switch t := (*old).(type) {
case *sync.WaitGroup: // was preloading
default:
log.Errorf(nil, "BUG: invalid preloader state: %#v", t)
} }
} }
@ -137,17 +132,21 @@ func (p *preloader) Clear() {
b := false b := false
a := any(b) a := any(b)
for { for {
old := p.p.Swap(&a) // Load current ptr.
if old == nil { // was brand-new ptr := p.p.Load()
return if ptr == nil {
return // was brand-new
} }
switch t := (*old).(type) {
case nil: // was preloaded // Check for a preload currently in progress.
if wg, _ := (*ptr).(*sync.WaitGroup); wg != nil {
wg.Wait()
continue
}
// Try mark as needing preload.
if p.p.CompareAndSwap(ptr, &a) {
return return
case bool: // was cleared
return
case *sync.WaitGroup: // was preloading
t.Wait()
} }
} }
} }

View file

@ -196,14 +196,14 @@ func (t *StatusTimeline) Preload(
n int, n int,
err error, err error,
) { ) {
t.preloader.CheckPreload(func() { t.preloader.CheckPreload(func(ptr *any) {
n, err = t.preload(loadPage, filter) n, err = t.preload(loadPage, filter)
if err != nil { if err != nil {
return return
} }
// Mark preloaded. // Mark as preloaded.
t.preloader.Done() t.preloader.Done(ptr)
}) })
return return
} }