Short-circuit new goroutines if an error already occured

This commit is contained in:
Dan Jones 2025-09-07 23:21:44 -05:00
commit 8640832a27
2 changed files with 55 additions and 7 deletions

View file

@ -13,6 +13,7 @@ import (
"context" "context"
"fmt" "fmt"
"sync" "sync"
"sync/atomic"
) )
type token struct{} type token struct{}
@ -30,7 +31,7 @@ type Group struct {
sem chan token sem chan token
errOnce sync.Once errOnce sync.Once
err error err atomic.Value
} }
func (g *Group) done() { func (g *Group) done() {
@ -50,14 +51,22 @@ func WithContext(ctx context.Context) (*Group, context.Context) {
return &Group{cancel: cancel}, ctx return &Group{cancel: cancel}, ctx
} }
func (g *Group) error() error {
v := g.err.Load()
if v == nil {
return nil
}
return v.(error)
}
// Wait blocks until all function calls from the Go method have returned, then // Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them. // returns the first non-nil error (if any) from them.
func (g *Group) Wait() error { func (g *Group) Wait() error {
g.wg.Wait() g.wg.Wait()
if g.cancel != nil { if g.cancel != nil {
g.cancel(g.err) g.cancel(g.error())
} }
return g.err return g.error()
} }
// Go calls the given function in a new goroutine. // Go calls the given function in a new goroutine.
@ -70,6 +79,9 @@ func (g *Group) Wait() error {
// cancel the associated Context, if any. The error will be returned // cancel the associated Context, if any. The error will be returned
// by Wait. // by Wait.
func (g *Group) Go(f func() error) { func (g *Group) Go(f func() error) {
if g.error() != nil {
return
}
if g.sem != nil { if g.sem != nil {
g.sem <- token{} g.sem <- token{}
} }
@ -92,9 +104,9 @@ func (g *Group) Go(f func() error) {
if err := f(); err != nil { if err := f(); err != nil {
g.errOnce.Do(func() { g.errOnce.Do(func() {
g.err = err g.err.Store(err)
if g.cancel != nil { if g.cancel != nil {
g.cancel(g.err) g.cancel(err)
} }
}) })
} }
@ -106,6 +118,9 @@ func (g *Group) Go(f func() error) {
// //
// The return value reports whether the goroutine was started. // The return value reports whether the goroutine was started.
func (g *Group) TryGo(f func() error) bool { func (g *Group) TryGo(f func() error) bool {
if g.error() != nil {
return false
}
if g.sem != nil { if g.sem != nil {
select { select {
case g.sem <- token{}: case g.sem <- token{}:
@ -121,9 +136,9 @@ func (g *Group) TryGo(f func() error) bool {
if err := f(); err != nil { if err := f(); err != nil {
g.errOnce.Do(func() { g.errOnce.Do(func() {
g.err = err g.err.Store(err)
if g.cancel != nil { if g.cancel != nil {
g.cancel(g.err) g.cancel(err)
} }
}) })
} }

View file

@ -98,6 +98,38 @@ func ExampleGroup_parallel() {
// video result for "golang" // video result for "golang"
} }
// FirstError demonstrates that g.Go becomes a no-op if a previous g.Go
// has returned an error.
func ExampleGroup_firstError() {
err1 := errors.New("errgroup_test: 1")
err2 := errors.New("errgroup_test: 2")
g := new(errgroup.Group)
ch := make(chan struct{})
g.Go(func() error {
fmt.Printf("Returning %s\n", err1)
ch <- struct{}{}
return err1
})
<-ch
g.Go(func() error {
// This should never run
fmt.Printf("Returning %s\n", err2)
return err2
})
err := g.Wait()
fmt.Printf("Got %s\n", err)
// Output:
// Returning errgroup_test: 1
// Got errgroup_test: 1
}
func TestZeroGroup(t *testing.T) { func TestZeroGroup(t *testing.T) {
err1 := errors.New("errgroup_test: 1") err1 := errors.New("errgroup_test: 1")
err2 := errors.New("errgroup_test: 2") err2 := errors.New("errgroup_test: 2")
@ -256,6 +288,7 @@ func TestCancelCause(t *testing.T) {
{errs: []error{nil}, want: nil}, {errs: []error{nil}, want: nil},
{errs: []error{errDoom}, want: errDoom}, {errs: []error{errDoom}, want: errDoom},
{errs: []error{errDoom, nil}, want: errDoom}, {errs: []error{errDoom, nil}, want: errDoom},
{errs: []error{nil, errDoom}, want: errDoom},
} }
for _, tc := range cases { for _, tc := range cases {