✨ Short-circuit new goroutines if an error already occured
This commit is contained in:
parent
dc577bcb9c
commit
8640832a27
2 changed files with 55 additions and 7 deletions
29
errgroup.go
29
errgroup.go
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
type token struct{}
|
type token struct{}
|
||||||
|
|
@ -30,7 +31,7 @@ type Group struct {
|
||||||
sem chan token
|
sem chan token
|
||||||
|
|
||||||
errOnce sync.Once
|
errOnce sync.Once
|
||||||
err error
|
err atomic.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Group) done() {
|
func (g *Group) done() {
|
||||||
|
|
@ -50,14 +51,22 @@ func WithContext(ctx context.Context) (*Group, context.Context) {
|
||||||
return &Group{cancel: cancel}, ctx
|
return &Group{cancel: cancel}, ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *Group) error() error {
|
||||||
|
v := g.err.Load()
|
||||||
|
if v == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return v.(error)
|
||||||
|
}
|
||||||
|
|
||||||
// Wait blocks until all function calls from the Go method have returned, then
|
// Wait blocks until all function calls from the Go method have returned, then
|
||||||
// returns the first non-nil error (if any) from them.
|
// returns the first non-nil error (if any) from them.
|
||||||
func (g *Group) Wait() error {
|
func (g *Group) Wait() error {
|
||||||
g.wg.Wait()
|
g.wg.Wait()
|
||||||
if g.cancel != nil {
|
if g.cancel != nil {
|
||||||
g.cancel(g.err)
|
g.cancel(g.error())
|
||||||
}
|
}
|
||||||
return g.err
|
return g.error()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Go calls the given function in a new goroutine.
|
// Go calls the given function in a new goroutine.
|
||||||
|
|
@ -70,6 +79,9 @@ func (g *Group) Wait() error {
|
||||||
// cancel the associated Context, if any. The error will be returned
|
// cancel the associated Context, if any. The error will be returned
|
||||||
// by Wait.
|
// by Wait.
|
||||||
func (g *Group) Go(f func() error) {
|
func (g *Group) Go(f func() error) {
|
||||||
|
if g.error() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
if g.sem != nil {
|
if g.sem != nil {
|
||||||
g.sem <- token{}
|
g.sem <- token{}
|
||||||
}
|
}
|
||||||
|
|
@ -92,9 +104,9 @@ func (g *Group) Go(f func() error) {
|
||||||
|
|
||||||
if err := f(); err != nil {
|
if err := f(); err != nil {
|
||||||
g.errOnce.Do(func() {
|
g.errOnce.Do(func() {
|
||||||
g.err = err
|
g.err.Store(err)
|
||||||
if g.cancel != nil {
|
if g.cancel != nil {
|
||||||
g.cancel(g.err)
|
g.cancel(err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -106,6 +118,9 @@ func (g *Group) Go(f func() error) {
|
||||||
//
|
//
|
||||||
// The return value reports whether the goroutine was started.
|
// The return value reports whether the goroutine was started.
|
||||||
func (g *Group) TryGo(f func() error) bool {
|
func (g *Group) TryGo(f func() error) bool {
|
||||||
|
if g.error() != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
if g.sem != nil {
|
if g.sem != nil {
|
||||||
select {
|
select {
|
||||||
case g.sem <- token{}:
|
case g.sem <- token{}:
|
||||||
|
|
@ -121,9 +136,9 @@ func (g *Group) TryGo(f func() error) bool {
|
||||||
|
|
||||||
if err := f(); err != nil {
|
if err := f(); err != nil {
|
||||||
g.errOnce.Do(func() {
|
g.errOnce.Do(func() {
|
||||||
g.err = err
|
g.err.Store(err)
|
||||||
if g.cancel != nil {
|
if g.cancel != nil {
|
||||||
g.cancel(g.err)
|
g.cancel(err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -98,6 +98,38 @@ func ExampleGroup_parallel() {
|
||||||
// video result for "golang"
|
// video result for "golang"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FirstError demonstrates that g.Go becomes a no-op if a previous g.Go
|
||||||
|
// has returned an error.
|
||||||
|
func ExampleGroup_firstError() {
|
||||||
|
err1 := errors.New("errgroup_test: 1")
|
||||||
|
err2 := errors.New("errgroup_test: 2")
|
||||||
|
|
||||||
|
g := new(errgroup.Group)
|
||||||
|
|
||||||
|
ch := make(chan struct{})
|
||||||
|
|
||||||
|
g.Go(func() error {
|
||||||
|
fmt.Printf("Returning %s\n", err1)
|
||||||
|
ch <- struct{}{}
|
||||||
|
return err1
|
||||||
|
})
|
||||||
|
|
||||||
|
<-ch
|
||||||
|
|
||||||
|
g.Go(func() error {
|
||||||
|
// This should never run
|
||||||
|
fmt.Printf("Returning %s\n", err2)
|
||||||
|
return err2
|
||||||
|
})
|
||||||
|
|
||||||
|
err := g.Wait()
|
||||||
|
fmt.Printf("Got %s\n", err)
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// Returning errgroup_test: 1
|
||||||
|
// Got errgroup_test: 1
|
||||||
|
}
|
||||||
|
|
||||||
func TestZeroGroup(t *testing.T) {
|
func TestZeroGroup(t *testing.T) {
|
||||||
err1 := errors.New("errgroup_test: 1")
|
err1 := errors.New("errgroup_test: 1")
|
||||||
err2 := errors.New("errgroup_test: 2")
|
err2 := errors.New("errgroup_test: 2")
|
||||||
|
|
@ -256,6 +288,7 @@ func TestCancelCause(t *testing.T) {
|
||||||
{errs: []error{nil}, want: nil},
|
{errs: []error{nil}, want: nil},
|
||||||
{errs: []error{errDoom}, want: errDoom},
|
{errs: []error{errDoom}, want: errDoom},
|
||||||
{errs: []error{errDoom, nil}, want: errDoom},
|
{errs: []error{errDoom, nil}, want: errDoom},
|
||||||
|
{errs: []error{nil, errDoom}, want: errDoom},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue