diff --git a/semaphore/semaphore.go b/semaphore/semaphore.go index 7f096fe..30f632c 100644 --- a/semaphore/semaphore.go +++ b/semaphore/semaphore.go @@ -67,7 +67,12 @@ func (s *Weighted) Acquire(ctx context.Context, n int64) error { // fix up the queue, just pretend we didn't notice the cancelation. err = nil default: + isFront := s.waiters.Front() == elem s.waiters.Remove(elem) + // If we're at the front and there're extra tokens left, notify other waiters. + if isFront && s.size > s.cur { + s.notifyWaiters() + } } s.mu.Unlock() return err @@ -97,6 +102,11 @@ func (s *Weighted) Release(n int64) { s.mu.Unlock() panic("semaphore: released more than held") } + s.notifyWaiters() + s.mu.Unlock() +} + +func (s *Weighted) notifyWaiters() { for { next := s.waiters.Front() if next == nil { @@ -123,5 +133,4 @@ func (s *Weighted) Release(n int64) { s.waiters.Remove(next) close(w.ready) } - s.mu.Unlock() } diff --git a/semaphore/semaphore_test.go b/semaphore/semaphore_test.go index b5f8f13..6e8eca2 100644 --- a/semaphore/semaphore_test.go +++ b/semaphore/semaphore_test.go @@ -169,3 +169,34 @@ func TestLargeAcquireDoesntStarve(t *testing.T) { sem.Release(n) wg.Wait() } + +// translated from https://github.com/zhiqiangxu/util/blob/master/mutex/crwmutex_test.go#L43 +func TestAllocCancelDoesntStarve(t *testing.T) { + sem := semaphore.NewWeighted(10) + + // Block off a portion of the semaphore so that Acquire(_, 10) can eventually succeed. + sem.Acquire(context.Background(), 1) + + // In the background, Acquire(_, 10). + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + sem.Acquire(ctx, 10) + }() + + // Wait until the Acquire(_, 10) call blocks. + for sem.TryAcquire(1) { + sem.Release(1) + runtime.Gosched() + } + + // Now try to grab a read lock, and simultaneously unblock the Acquire(_, 10) call. + // Both Acquire calls should unblock and return, in either order. + go cancel() + + err := sem.Acquire(context.Background(), 1) + if err != nil { + t.Fatalf("Acquire(_, 1) failed unexpectedly: %v", err) + } + sem.Release(1) +}