errgroup: propagate panic and Goexit through Wait
Recovered panic values are wrapped and saved in Group. Goexits are detected by a sentinel value set after the given function returns normally. Wait propagates the first instance of a panic or Goexit. Because code following a call to runtime.Goexit is never executed, a bool that is set to true only after f returns normally (that is, without calling runtime.Goexit) is used to determine whether to propagate runtime.Goexit. Fixes golang/go#53757 Change-Id: Ic6426fc014fd1c4368ebaceef5b0d6163770a099 Reviewed-on: https://go-review.googlesource.com/c/sync/+/644575 Reviewed-by: Sean Liao <sean@liao.dev> Auto-Submit: Alan Donovan <adonovan@google.com> Commit-Queue: Alan Donovan <adonovan@google.com> Reviewed-by: Alan Donovan <adonovan@google.com> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
parent
396f3a06ea
commit
506c70f973
2 changed files with 153 additions and 18 deletions
|
|
@ -12,6 +12,8 @@ package errgroup
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
)
|
||||
|
||||
|
|
@ -31,6 +33,10 @@ type Group struct {
|
|||
|
||||
errOnce sync.Once
|
||||
err error
|
||||
|
||||
mu sync.Mutex
|
||||
panicValue any // = PanicError | PanicValue; non-nil if some Group.Go goroutine panicked.
|
||||
abnormal bool // some Group.Go goroutine terminated abnormally (panic or goexit).
|
||||
}
|
||||
|
||||
func (g *Group) done() {
|
||||
|
|
@ -50,13 +56,22 @@ func WithContext(ctx context.Context) (*Group, context.Context) {
|
|||
return &Group{cancel: cancel}, ctx
|
||||
}
|
||||
|
||||
// Wait blocks until all function calls from the Go method have returned, then
|
||||
// returns the first non-nil error (if any) from them.
|
||||
// Wait blocks until all function calls from the Go method have returned
|
||||
// normally, then returns the first non-nil error (if any) from them.
|
||||
//
|
||||
// If any of the calls panics, Wait panics with a [PanicValue];
|
||||
// and if any of them calls [runtime.Goexit], Wait calls runtime.Goexit.
|
||||
func (g *Group) Wait() error {
|
||||
g.wg.Wait()
|
||||
if g.cancel != nil {
|
||||
g.cancel(g.err)
|
||||
}
|
||||
if g.panicValue != nil {
|
||||
panic(g.panicValue)
|
||||
}
|
||||
if g.abnormal {
|
||||
runtime.Goexit()
|
||||
}
|
||||
return g.err
|
||||
}
|
||||
|
||||
|
|
@ -65,18 +80,56 @@ func (g *Group) Wait() error {
|
|||
// It blocks until the new goroutine can be added without the number of
|
||||
// active goroutines in the group exceeding the configured limit.
|
||||
//
|
||||
// The first call to return a non-nil error cancels the group's context, if the
|
||||
// group was created by calling WithContext. The error will be returned by Wait.
|
||||
// It blocks until the new goroutine can be added without the number of
|
||||
// goroutines in the group exceeding the configured limit.
|
||||
//
|
||||
// The first goroutine in the group that returns a non-nil error, panics, or
|
||||
// invokes [runtime.Goexit] will cancel the associated Context, if any.
|
||||
func (g *Group) Go(f func() error) {
|
||||
if g.sem != nil {
|
||||
g.sem <- token{}
|
||||
}
|
||||
|
||||
g.add(f)
|
||||
}
|
||||
|
||||
func (g *Group) add(f func() error) {
|
||||
g.wg.Add(1)
|
||||
go func() {
|
||||
defer g.done()
|
||||
normalReturn := false
|
||||
defer func() {
|
||||
if normalReturn {
|
||||
return
|
||||
}
|
||||
v := recover()
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
if !g.abnormal {
|
||||
if g.cancel != nil {
|
||||
g.cancel(g.err)
|
||||
}
|
||||
g.abnormal = true
|
||||
}
|
||||
if v != nil && g.panicValue == nil {
|
||||
switch v := v.(type) {
|
||||
case error:
|
||||
g.panicValue = PanicError{
|
||||
Recovered: v,
|
||||
Stack: debug.Stack(),
|
||||
}
|
||||
default:
|
||||
g.panicValue = PanicValue{
|
||||
Recovered: v,
|
||||
Stack: debug.Stack(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if err := f(); err != nil {
|
||||
err := f()
|
||||
normalReturn = true
|
||||
if err != nil {
|
||||
g.errOnce.Do(func() {
|
||||
g.err = err
|
||||
if g.cancel != nil {
|
||||
|
|
@ -101,19 +154,7 @@ func (g *Group) TryGo(f func() error) bool {
|
|||
}
|
||||
}
|
||||
|
||||
g.wg.Add(1)
|
||||
go func() {
|
||||
defer g.done()
|
||||
|
||||
if err := f(); err != nil {
|
||||
g.errOnce.Do(func() {
|
||||
g.err = err
|
||||
if g.cancel != nil {
|
||||
g.cancel(g.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
g.add(f)
|
||||
return true
|
||||
}
|
||||
|
||||
|
|
@ -135,3 +176,33 @@ func (g *Group) SetLimit(n int) {
|
|||
}
|
||||
g.sem = make(chan token, n)
|
||||
}
|
||||
|
||||
// PanicError wraps an error recovered from an unhandled panic
// when calling a function passed to Go or TryGo.
type PanicError struct {
	Recovered error  // the error value the function panicked with
	Stack     []byte // result of call to [debug.Stack]
}

// Error reports the recovered error prefixed with the errgroup origin.
//
// The stack dump is deliberately left out, as is conventional for Error
// methods; callers that need it can read the Stack field.
func (pe PanicError) Error() string {
	inner := pe.Recovered
	return fmt.Sprintf("recovered from errgroup.Group: %v", inner)
}

// Unwrap returns the recovered error.
func (pe PanicError) Unwrap() error {
	return pe.Recovered
}
|
||||
|
||||
// PanicValue wraps a value that does not implement the error interface,
// recovered from an unhandled panic when calling a function passed to Go or
// TryGo.
type PanicValue struct {
	Recovered any    // the value the function panicked with
	Stack     []byte // result of call to [debug.Stack]
}

// String formats the recovered value and, when a stack was captured, appends
// the stack on a new line.
func (pv PanicValue) String() string {
	if len(pv.Stack) == 0 {
		return fmt.Sprintf("recovered from errgroup.Group: %v", pv.Recovered)
	}
	return fmt.Sprintf("recovered from errgroup.Group: %v\n%s", pv.Recovered, pv.Stack)
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -289,6 +290,69 @@ func TestCancelCause(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestPanic(t *testing.T) {
|
||||
t.Run("error", func(t *testing.T) {
|
||||
g := &errgroup.Group{}
|
||||
p := errors.New("")
|
||||
g.Go(func() error {
|
||||
panic(p)
|
||||
})
|
||||
defer func() {
|
||||
err := recover()
|
||||
if err == nil {
|
||||
t.Fatalf("should propagate panic through Wait")
|
||||
}
|
||||
pe, ok := err.(errgroup.PanicError)
|
||||
if !ok {
|
||||
t.Fatalf("type should is errgroup.PanicError, but is %T", err)
|
||||
}
|
||||
if pe.Recovered != p {
|
||||
t.Fatalf("got %v, want %v", pe.Recovered, p)
|
||||
}
|
||||
if !strings.Contains(string(pe.Stack), "TestPanic.func") {
|
||||
t.Log(string(pe.Stack))
|
||||
t.Fatalf("stack trace incomplete")
|
||||
}
|
||||
}()
|
||||
g.Wait()
|
||||
})
|
||||
t.Run("any", func(t *testing.T) {
|
||||
g := &errgroup.Group{}
|
||||
g.Go(func() error {
|
||||
panic(1)
|
||||
})
|
||||
defer func() {
|
||||
err := recover()
|
||||
if err == nil {
|
||||
t.Fatalf("should propagate panic through Wait")
|
||||
}
|
||||
pe, ok := err.(errgroup.PanicValue)
|
||||
if !ok {
|
||||
t.Fatalf("type should is errgroup.PanicValue, but is %T", err)
|
||||
}
|
||||
if pe.Recovered != 1 {
|
||||
t.Fatalf("got %v, want %v", pe.Recovered, 1)
|
||||
}
|
||||
if !strings.Contains(string(pe.Stack), "TestPanic.func") {
|
||||
t.Log(string(pe.Stack))
|
||||
t.Fatalf("stack trace incomplete")
|
||||
}
|
||||
}()
|
||||
g.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func TestGoexit(t *testing.T) {
|
||||
g := &errgroup.Group{}
|
||||
g.Go(func() error {
|
||||
t.Skip()
|
||||
t.Fatalf("Goexit fail")
|
||||
return nil
|
||||
})
|
||||
g.Wait()
|
||||
t.Fatalf("should call runtime.Goexit from Wait")
|
||||
}
|
||||
|
||||
func BenchmarkGo(b *testing.B) {
|
||||
fn := func() {}
|
||||
g := &errgroup.Group{}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue