diff --git a/errgroup.go b/errgroup.go index 33e6269..c52cbb9 100644 --- a/errgroup.go +++ b/errgroup.go @@ -13,6 +13,7 @@ import ( "context" "fmt" "sync" + "sync/atomic" ) type token struct{} @@ -30,7 +31,7 @@ type Group struct { sem chan token errOnce sync.Once - err error + err atomic.Value } func (g *Group) done() { @@ -50,14 +51,22 @@ func WithContext(ctx context.Context) (*Group, context.Context) { return &Group{cancel: cancel}, ctx } +func (g *Group) error() error { + v := g.err.Load() + if v == nil { + return nil + } + return v.(error) +} + // Wait blocks until all function calls from the Go method have returned, then // returns the first non-nil error (if any) from them. func (g *Group) Wait() error { g.wg.Wait() if g.cancel != nil { - g.cancel(g.err) + g.cancel(g.error()) } - return g.err + return g.error() } // Go calls the given function in a new goroutine. @@ -70,6 +79,9 @@ func (g *Group) Wait() error { // cancel the associated Context, if any. The error will be returned // by Wait. func (g *Group) Go(f func() error) { + if g.error() != nil { + return + } if g.sem != nil { g.sem <- token{} } @@ -92,9 +104,9 @@ func (g *Group) Go(f func() error) { if err := f(); err != nil { g.errOnce.Do(func() { - g.err = err + g.err.Store(err) if g.cancel != nil { - g.cancel(g.err) + g.cancel(err) } }) } @@ -106,6 +118,9 @@ func (g *Group) Go(f func() error) { // // The return value reports whether the goroutine was started. func (g *Group) TryGo(f func() error) bool { + if g.error() != nil { + return false + } if g.sem != nil { select { case g.sem <- token{}: @@ -121,9 +136,9 @@ func (g *Group) TryGo(f func() error) bool { if err := f(); err != nil { g.errOnce.Do(func() { - g.err = err + g.err.Store(err) if g.cancel != nil { - g.cancel(g.err) + g.cancel(err) } }) } diff --git a/errgroup_test.go b/errgroup_test.go index 58b4d75..32db177 100644 --- a/errgroup_test.go +++ b/errgroup_test.go @@ -98,6 +98,38 @@ func ExampleGroup_parallel() { // video result for "golang" } +// FirstError demonstrates that g.Go becomes a no-op if a previous g.Go +// has returned an error. +func ExampleGroup_firstError() { + err1 := errors.New("errgroup_test: 1") + err2 := errors.New("errgroup_test: 2") + + g := new(errgroup.Group) + + ch := make(chan struct{}) + + g.Go(func() error { + fmt.Printf("Returning %s\n", err1) + ch <- struct{}{} + return err1 + }) + + <-ch + + g.Go(func() error { + // This should never run + fmt.Printf("Returning %s\n", err2) + return err2 + }) + + err := g.Wait() + fmt.Printf("Got %s\n", err) + + // Output: + // Returning errgroup_test: 1 + // Got errgroup_test: 1 +} + func TestZeroGroup(t *testing.T) { err1 := errors.New("errgroup_test: 1") err2 := errors.New("errgroup_test: 2") @@ -256,6 +288,7 @@ func TestCancelCause(t *testing.T) { {errs: []error{nil}, want: nil}, {errs: []error{errDoom}, want: errDoom}, {errs: []error{errDoom, nil}, want: errDoom}, + {errs: []error{nil, errDoom}, want: errDoom}, } for _, tc := range cases {