diff --git a/singleflight/singleflight.go b/singleflight/singleflight.go index 9a4f8d5..97a1aa4 100644 --- a/singleflight/singleflight.go +++ b/singleflight/singleflight.go @@ -17,6 +17,10 @@ type call struct { val interface{} err error + // forgotten indicates whether Forget was called with this call's key + // while the call was still in flight. + forgotten bool + // These fields are read and written with the singleflight // mutex held before the WaitGroup is done, and are read but // not written after the WaitGroup is done. @@ -94,7 +98,9 @@ func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) { c.wg.Done() g.mu.Lock() - delete(g.m, key) + if !c.forgotten { + delete(g.m, key) + } for _, ch := range c.chans { ch <- Result{c.val, c.err, c.dups > 0} } @@ -106,6 +112,9 @@ func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) { // an earlier call to complete. func (g *Group) Forget(key string) { g.mu.Lock() + if c, ok := g.m[key]; ok { + c.forgotten = true + } delete(g.m, key) g.mu.Unlock() } diff --git a/singleflight/singleflight_test.go b/singleflight/singleflight_test.go index 5e6f1b3..ad04037 100644 --- a/singleflight/singleflight_test.go +++ b/singleflight/singleflight_test.go @@ -85,3 +85,75 @@ func TestDoDupSuppress(t *testing.T) { t.Errorf("number of calls = %d; want over 0 and less than %d", got, n) } } + +// Test that singleflight behaves correctly after Forget called. +// See https://github.com/golang/go/issues/31420 +func TestForget(t *testing.T) { + var g Group + + var firstStarted, firstFinished sync.WaitGroup + + firstStarted.Add(1) + firstFinished.Add(1) + + firstCh := make(chan struct{}) + go func() { + g.Do("key", func() (i interface{}, e error) { + firstStarted.Done() + <-firstCh + firstFinished.Done() + return + }) + }() + + firstStarted.Wait() + g.Forget("key") // from this point no two function using same key should be executed concurrently + + var secondStarted int32 + var secondFinished int32 + var thirdStarted int32 + + secondCh := make(chan struct{}) + secondRunning := make(chan struct{}) + go func() { + g.Do("key", func() (i interface{}, e error) { + defer func() { + }() + atomic.AddInt32(&secondStarted, 1) + // Notify that we started + secondCh <- struct{}{} + // Wait other get above signal + <-secondRunning + <-secondCh + atomic.AddInt32(&secondFinished, 1) + return 2, nil + }) + }() + + close(firstCh) + firstFinished.Wait() // wait for first execution (which should not affect execution after Forget) + + <-secondCh + // Notify second that we got the signal that it started + secondRunning <- struct{}{} + if atomic.LoadInt32(&secondStarted) != 1 { + t.Fatal("Second execution should be executed due to usage of forget") + } + + if atomic.LoadInt32(&secondFinished) == 1 { + t.Fatal("Second execution should be still active") + } + + close(secondCh) + result, _, _ := g.Do("key", func() (i interface{}, e error) { + atomic.AddInt32(&thirdStarted, 1) + return 3, nil + }) + + if atomic.LoadInt32(&thirdStarted) != 0 { + t.Error("Third call should not be started because was started during second execution") + } + if result != 2 { + t.Errorf("We should receive result produced by second call, expected: 2, got %d", result) + } +}