diff --git a/semaphore/semaphore.go b/semaphore/semaphore.go new file mode 100644 index 0000000..e9d2d79 --- /dev/null +++ b/semaphore/semaphore.go @@ -0,0 +1,131 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package semaphore provides a weighted semaphore implementation. +package semaphore // import "golang.org/x/sync/semaphore" + +import ( + "container/list" + "sync" + + // Use the old context because packages that depend on this one + // (e.g. cloud.google.com/go/...) must run on Go 1.6. + // TODO(jba): update to "context" when possible. + "golang.org/x/net/context" +) + +type waiter struct { + n int64 + ready chan<- struct{} // Closed when semaphore acquired. +} + +// NewWeighted creates a new weighted semaphore with the given +// maximum combined weight for concurrent access. +func NewWeighted(n int64) *Weighted { + w := &Weighted{size: n} + return w +} + +// Weighted provides a way to bound concurrent access to a resource. +// The callers can request access with a given weight. +type Weighted struct { + size int64 + cur int64 + mu sync.Mutex + waiters list.List +} + +// Acquire acquires the semaphore with a weight of n, blocking only until ctx +// is done. On success, returns nil. On failure, returns ctx.Err() and leaves +// the semaphore unchanged. +// +// If ctx is already done, Acquire may still succeed without blocking. +func (s *Weighted) Acquire(ctx context.Context, n int64) error { + s.mu.Lock() + if s.size-s.cur >= n && s.waiters.Len() == 0 { + s.cur += n + s.mu.Unlock() + return nil + } + + if n > s.size { + // Don't make other Acquire calls block on one that's doomed to fail. + s.mu.Unlock() + <-ctx.Done() + return ctx.Err() + } + + ready := make(chan struct{}) + w := waiter{n: n, ready: ready} + elem := s.waiters.PushBack(w) + s.mu.Unlock() + + select { + case <-ctx.Done(): + err := ctx.Err() + s.mu.Lock() + select { + case <-ready: + // Acquired the semaphore after we were canceled. Rather than trying to + // fix up the queue, just pretend we didn't notice the cancelation. + err = nil + default: + s.waiters.Remove(elem) + } + s.mu.Unlock() + return err + + case <-ready: + return nil + } +} + +// TryAcquire acquires the semaphore with a weight of n without blocking. +// On success, returns true. On failure, returns false and leaves the semaphore unchanged. +func (s *Weighted) TryAcquire(n int64) bool { + s.mu.Lock() + success := s.size-s.cur >= n && s.waiters.Len() == 0 + if success { + s.cur += n + } + s.mu.Unlock() + return success +} + +// Release releases the semaphore with a weight of n. +func (s *Weighted) Release(n int64) { + s.mu.Lock() + s.cur -= n + if s.cur < 0 { + s.mu.Unlock() + panic("semaphore: bad release") + } + for { + next := s.waiters.Front() + if next == nil { + break // No more waiters blocked. + } + + w := next.Value.(waiter) + if s.size-s.cur < w.n { + // Not enough tokens for the next waiter. We could keep going (to try to + // find a waiter with a smaller request), but under load that could cause + // starvation for large requests; instead, we leave all remaining waiters + // blocked. + // + // Consider a semaphore used as a read-write lock, with N tokens, N + // readers, and one writer. Each reader can Acquire(1) to obtain a read + // lock. The writer can Acquire(N) to obtain a write lock, excluding all + // of the readers. If we allow the readers to jump ahead in the queue, + // the writer will starve — there is always one token available for every + // reader. + break + } + + s.cur += w.n + s.waiters.Remove(next) + close(w.ready) + } + s.mu.Unlock() +} diff --git a/semaphore/semaphore_bench_test.go b/semaphore/semaphore_bench_test.go new file mode 100644 index 0000000..5bb2fb3 --- /dev/null +++ b/semaphore/semaphore_bench_test.go @@ -0,0 +1,130 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.7 + +package semaphore + +import ( + "fmt" + "testing" + + "golang.org/x/net/context" +) + +// weighted is an interface matching a subset of *Weighted. It allows +// alternate implementations for testing and benchmarking. +type weighted interface { + Acquire(context.Context, int64) error + TryAcquire(int64) bool + Release(int64) +} + +// semChan implements Weighted using a channel for +// comparing against the condition variable-based implementation. +type semChan chan struct{} + +func newSemChan(n int64) semChan { + return semChan(make(chan struct{}, n)) +} + +func (s semChan) Acquire(_ context.Context, n int64) error { + for i := int64(0); i < n; i++ { + s <- struct{}{} + } + return nil +} + +func (s semChan) TryAcquire(n int64) bool { + if int64(len(s))+n > int64(cap(s)) { + return false + } + + for i := int64(0); i < n; i++ { + s <- struct{}{} + } + return true +} + +func (s semChan) Release(n int64) { + for i := int64(0); i < n; i++ { + <-s + } +} + +// acquireN calls Acquire(size) on sem N times and then calls Release(size) N times. +func acquireN(b *testing.B, sem weighted, size int64, N int) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < N; j++ { + sem.Acquire(context.Background(), size) + } + for j := 0; j < N; j++ { + sem.Release(size) + } + } +} + +// tryAcquireN calls TryAcquire(size) on sem N times and then calls Release(size) N times. +func tryAcquireN(b *testing.B, sem weighted, size int64, N int) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < N; j++ { + if !sem.TryAcquire(size) { + b.Fatalf("TryAcquire(%v) = false, want true", size) + } + } + for j := 0; j < N; j++ { + sem.Release(size) + } + } +} + +func BenchmarkNewSeq(b *testing.B) { + for _, cap := range []int64{1, 128} { + b.Run(fmt.Sprintf("Weighted-%d", cap), func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = NewWeighted(cap) + } + }) + b.Run(fmt.Sprintf("semChan-%d", cap), func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = newSemChan(cap) + } + }) + } +} + +func BenchmarkAcquireSeq(b *testing.B) { + for _, c := range []struct { + cap, size int64 + N int + }{ + {1, 1, 1}, + {2, 1, 1}, + {16, 1, 1}, + {128, 1, 1}, + {2, 2, 1}, + {16, 2, 8}, + {128, 2, 64}, + {2, 1, 2}, + {16, 8, 2}, + {128, 64, 2}, + } { + for _, w := range []struct { + name string + w weighted + }{ + {"Weighted", NewWeighted(c.cap)}, + {"semChan", newSemChan(c.cap)}, + } { + b.Run(fmt.Sprintf("%s-acquire-%d-%d-%d", w.name, c.cap, c.size, c.N), func(b *testing.B) { + acquireN(b, w.w, c.size, c.N) + }) + b.Run(fmt.Sprintf("%s-tryAcquire-%d-%d-%d", w.name, c.cap, c.size, c.N), func(b *testing.B) { + tryAcquireN(b, w.w, c.size, c.N) + }) + } + } +} diff --git a/semaphore/semaphore_test.go b/semaphore/semaphore_test.go new file mode 100644 index 0000000..ad35dc2 --- /dev/null +++ b/semaphore/semaphore_test.go @@ -0,0 +1,168 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package semaphore + +import ( + "math/rand" + "runtime" + "sync" + "testing" + "time" + + "golang.org/x/net/context" + "golang.org/x/sync/errgroup" +) + +const maxSleep = 1 * time.Millisecond + +func HammerWeighted(sem *Weighted, n int64, loops int) { + for i := 0; i < loops; i++ { + sem.Acquire(context.Background(), n) + time.Sleep(time.Duration(rand.Int63n(int64(maxSleep/time.Nanosecond))) * time.Nanosecond) + sem.Release(n) + } +} + +func TestWeighted(t *testing.T) { + t.Parallel() + + n := runtime.GOMAXPROCS(0) + sem := NewWeighted(int64(n)) + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + HammerWeighted(sem, int64(i), 1000) + }() + } + wg.Wait() +} + +func TestWeightedPanic(t *testing.T) { + t.Parallel() + + defer func() { + if recover() == nil { + t.Fatal("release of an unacquired weighted semaphore did not panic") + } + }() + w := NewWeighted(1) + w.Release(1) +} + +func TestWeightedTryAcquire(t *testing.T) { + t.Parallel() + + ctx := context.Background() + sem := NewWeighted(2) + tries := []bool{} + sem.Acquire(ctx, 1) + tries = append(tries, sem.TryAcquire(1)) + tries = append(tries, sem.TryAcquire(1)) + + sem.Release(2) + + tries = append(tries, sem.TryAcquire(1)) + sem.Acquire(ctx, 1) + tries = append(tries, sem.TryAcquire(1)) + + want := []bool{true, false, true, false} + for i := range tries { + if tries[i] != want[i] { + t.Errorf("tries[%d]: got %t, want %t", i, tries[i], want[i]) + } + } +} + +func TestWeightedAcquire(t *testing.T) { + t.Parallel() + + ctx := context.Background() + sem := NewWeighted(2) + tryAcquire := func(n int64) bool { + ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) + defer cancel() + return sem.Acquire(ctx, n) == nil + } + + tries := []bool{} + sem.Acquire(ctx, 1) + tries = append(tries, tryAcquire(1)) + tries = append(tries, tryAcquire(1)) + + sem.Release(2) + + tries = append(tries, tryAcquire(1)) + sem.Acquire(ctx, 1) + tries = append(tries, tryAcquire(1)) + + want := []bool{true, false, true, false} + for i := range tries { + if tries[i] != want[i] { + t.Errorf("tries[%d]: got %t, want %t", i, tries[i], want[i]) + } + } +} + +func TestWeightedDoesntBlockIfTooBig(t *testing.T) { + t.Parallel() + + const n = 2 + sem := NewWeighted(n) + { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go sem.Acquire(ctx, n+1) + } + + g, ctx := errgroup.WithContext(context.Background()) + for i := n * 3; i > 0; i-- { + g.Go(func() error { + err := sem.Acquire(ctx, 1) + if err == nil { + time.Sleep(1 * time.Millisecond) + sem.Release(1) + } + return err + }) + } + if err := g.Wait(); err != nil { + t.Errorf("NewWeighted(%v) failed to AcquireCtx(_, 1) with AcquireCtx(_, %v) pending", n, n+1) + } +} + +// TestLargeAcquireDoesntStarve times out if a large call to Acquire starves. +// Merely returning from the test function indicates success. +func TestLargeAcquireDoesntStarve(t *testing.T) { + t.Parallel() + + ctx := context.Background() + n := int64(runtime.GOMAXPROCS(0)) + sem := NewWeighted(n) + running := true + + var wg sync.WaitGroup + wg.Add(int(n)) + for i := n; i > 0; i-- { + sem.Acquire(ctx, 1) + go func() { + defer func() { + sem.Release(1) + wg.Done() + }() + for running { + time.Sleep(1 * time.Millisecond) + sem.Release(1) + sem.Acquire(ctx, 1) + } + }() + } + + sem.Acquire(ctx, n) + running = false + sem.Release(n) + wg.Wait() +}