From 0976fa681c295de5355f7a4d968b56cb9da8a76b Mon Sep 17 00:00:00 2001 From: Changkun Ou Date: Sat, 7 May 2022 17:30:52 +0200 Subject: [PATCH] x/sync/errgroup: add TryGo and SetLimit to control concurrency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This benchmark shows the difference between two implementations. Using explicit waiter with mutex (old, before PS3) or channel (new, since PS4). There is no significant difference at a measure: name old time/op new time/op delta Go-8 247ns ±10% 245ns ±10% ~ (p=0.571 n=5+10) name old alloc/op new alloc/op delta Go-8 48.0B ± 0% 40.0B ± 0% -16.67% (p=0.000 n=5+10) name old allocs/op new allocs/op delta Go-8 2.00 ± 0% 2.00 ± 0% ~ (all equal) Fixes golang/go#27837 Change-Id: I60247f1a2a1cdce2b180f10b409e37de8b82341e Reviewed-on: https://go-review.googlesource.com/c/sync/+/405174 Reviewed-by: Bryan Mills Reviewed-by: Heschi Kreinick TryBot-Result: Gopher Robot Run-TryBot: Changkun Ou Auto-Submit: Bryan Mills --- errgroup/errgroup.go | 69 ++++++++++++++++++++++++++++++- errgroup/errgroup_test.go | 86 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 2 deletions(-) diff --git a/errgroup/errgroup.go b/errgroup/errgroup.go index 9857fe5..1eab2fd 100644 --- a/errgroup/errgroup.go +++ b/errgroup/errgroup.go @@ -8,9 +8,12 @@ package errgroup import ( "context" + "fmt" "sync" ) +type token struct{} + // A Group is a collection of goroutines working on subtasks that are part of // the same overall task. // @@ -20,10 +23,19 @@ type Group struct { wg sync.WaitGroup + sem chan token + errOnce sync.Once err error } +func (g *Group) done() { + if g.sem != nil { + <-g.sem + } + g.wg.Done() +} + // WithContext returns a new Group and an associated Context derived from ctx. // // The derived Context is canceled the first time a function passed to Go @@ -45,14 +57,19 @@ func (g *Group) Wait() error { } // Go calls the given function in a new goroutine. +// It blocks until the new goroutine can be added without the number of +// active goroutines in the group exceeding the configured limit. // // The first call to return a non-nil error cancels the group; its error will be // returned by Wait. func (g *Group) Go(f func() error) { - g.wg.Add(1) + if g.sem != nil { + g.sem <- token{} + } + g.wg.Add(1) go func() { - defer g.wg.Done() + defer g.done() if err := f(); err != nil { g.errOnce.Do(func() { @@ -64,3 +81,51 @@ func (g *Group) Go(f func() error) { } }() } + +// TryGo calls the given function in a new goroutine only if the number of +// active goroutines in the group is currently below the configured limit. +// +// The return value reports whether the goroutine was started. +func (g *Group) TryGo(f func() error) bool { + if g.sem != nil { + select { + case g.sem <- token{}: + // Note: this allows barging iff channels in general allow barging. + default: + return false + } + } + + g.wg.Add(1) + go func() { + defer g.done() + + if err := f(); err != nil { + g.errOnce.Do(func() { + g.err = err + if g.cancel != nil { + g.cancel() + } + }) + } + }() + return true +} + +// SetLimit limits the number of active goroutines in this group to at most n. +// A negative value indicates no limit. +// +// Any subsequent call to the Go method will block until it can add an active +// goroutine without exceeding the configured limit. +// +// The limit must not be modified while any goroutines in the group are active. +func (g *Group) SetLimit(n int) { + if n < 0 { + g.sem = nil + return + } + if len(g.sem) != 0 { + panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem))) + } + g.sem = make(chan token, n) +} diff --git a/errgroup/errgroup_test.go b/errgroup/errgroup_test.go index 5a0b9cb..0358842 100644 --- a/errgroup/errgroup_test.go +++ b/errgroup/errgroup_test.go @@ -10,7 +10,9 @@ import ( "fmt" "net/http" "os" + "sync/atomic" "testing" + "time" "golang.org/x/sync/errgroup" ) @@ -174,3 +176,87 @@ func TestWithContext(t *testing.T) { } } } + +func TestTryGo(t *testing.T) { + g := &errgroup.Group{} + n := 42 + g.SetLimit(42) + ch := make(chan struct{}) + fn := func() error { + ch <- struct{}{} + return nil + } + for i := 0; i < n; i++ { + if !g.TryGo(fn) { + t.Fatalf("TryGo should succeed but got fail at %d-th call.", i) + } + } + if g.TryGo(fn) { + t.Fatalf("TryGo is expected to fail but succeeded.") + } + go func() { + for i := 0; i < n; i++ { + <-ch + } + }() + g.Wait() + + if !g.TryGo(fn) { + t.Fatalf("TryGo should success but got fail after all goroutines.") + } + go func() { <-ch }() + g.Wait() + + // Switch limit. + g.SetLimit(1) + if !g.TryGo(fn) { + t.Fatalf("TryGo should success but got failed.") + } + if g.TryGo(fn) { + t.Fatalf("TryGo should fail but succeeded.") + } + go func() { <-ch }() + g.Wait() + + // Block all calls. + g.SetLimit(0) + for i := 0; i < 1<<10; i++ { + if g.TryGo(fn) { + t.Fatalf("TryGo should fail but got succeded.") + } + } + g.Wait() +} + +func TestGoLimit(t *testing.T) { + const limit = 10 + + g := &errgroup.Group{} + g.SetLimit(limit) + var active int32 + for i := 0; i <= 1<<10; i++ { + g.Go(func() error { + n := atomic.AddInt32(&active, 1) + if n > limit { + return fmt.Errorf("saw %d active goroutines; want ≤ %d", n, limit) + } + time.Sleep(1 * time.Microsecond) // Give other goroutines a chance to increment active. + atomic.AddInt32(&active, -1) + return nil + }) + } + if err := g.Wait(); err != nil { + t.Fatal(err) + } +} + +func BenchmarkGo(b *testing.B) { + fn := func() {} + g := &errgroup.Group{} + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + g.Go(func() error { fn(); return nil }) + } + g.Wait() +}