errgroup: propagate panic and Goexit through Wait

Recovered panic values are wrapped and saved in Group.
Goexits are detected by a sentinel value set after the given function
returns normally. Wait propagates the first instance of a panic or
Goexit.

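Because code that follows a call to runtime.Goexit is never executed,
a bool is set to true only after f returns normally; if it is still
false when the goroutine's deferred function runs and no panic value is
recovered, f called runtime.Goexit, and Wait propagates it.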

Fixes golang/go#53757

Change-Id: Ic6426fc014fd1c4368ebaceef5b0d6163770a099
Reviewed-on: https://go-review.googlesource.com/c/sync/+/644575
Reviewed-by: Sean Liao <sean@liao.dev>
Auto-Submit: Alan Donovan <adonovan@google.com>
Commit-Queue: Alan Donovan <adonovan@google.com>
Reviewed-by: Alan Donovan <adonovan@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
qiulaidongfeng 2025-01-27 16:58:51 +08:00 committed by Gopher Robot
commit 506c70f973
2 changed files with 153 additions and 18 deletions

View file

@ -12,6 +12,8 @@ package errgroup
import (
"context"
"fmt"
"runtime"
"runtime/debug"
"sync"
)
@ -31,6 +33,10 @@ type Group struct {
errOnce sync.Once
err error
mu sync.Mutex
panicValue any // = PanicError | PanicValue; non-nil if some Group.Go goroutine panicked.
abnormal bool // some Group.Go goroutine terminated abnormally (panic or goexit).
}
func (g *Group) done() {
@ -50,13 +56,22 @@ func WithContext(ctx context.Context) (*Group, context.Context) {
return &Group{cancel: cancel}, ctx
}
// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
// Wait blocks until all function calls from the Go method have returned
// normally, then returns the first non-nil error (if any) from them.
//
// If any of the calls panics, Wait panics with a [PanicError] or a [PanicValue];
// and if any of them calls [runtime.Goexit], Wait calls runtime.Goexit.
func (g *Group) Wait() error {
	// Block on the group's WaitGroup before inspecting any recorded outcome.
	g.wg.Wait()
	if cancel := g.cancel; cancel != nil {
		cancel(g.err)
	}

	// Abnormal terminations take precedence over the ordinary error result:
	// a recorded panic value is re-raised first; otherwise a recorded
	// abnormal exit without a panic value is replayed as runtime.Goexit.
	switch {
	case g.panicValue != nil:
		panic(g.panicValue)
	case g.abnormal:
		runtime.Goexit()
	}
	return g.err
}
@ -65,18 +80,56 @@ func (g *Group) Wait() error {
// It blocks until the new goroutine can be added without the number of
// active goroutines in the group exceeding the configured limit.
//
// The first call to return a non-nil error cancels the group's context, if the
// group was created by calling WithContext. The error will be returned by Wait.
// It blocks until the new goroutine can be added without the number of
// goroutines in the group exceeding the configured limit.
//
// The first goroutine in the group that returns a non-nil error, panics, or
// invokes [runtime.Goexit] will cancel the associated Context, if any.
func (g *Group) Go(f func() error) {
	// When a semaphore channel is set, take a slot before starting the
	// goroutine; the send blocks for as long as the channel is full.
	if sem := g.sem; sem != nil {
		sem <- token{}
	}
	g.add(f)
}
func (g *Group) add(f func() error) {
g.wg.Add(1)
go func() {
defer g.done()
normalReturn := false
defer func() {
if normalReturn {
return
}
v := recover()
g.mu.Lock()
defer g.mu.Unlock()
if !g.abnormal {
if g.cancel != nil {
g.cancel(g.err)
}
g.abnormal = true
}
if v != nil && g.panicValue == nil {
switch v := v.(type) {
case error:
g.panicValue = PanicError{
Recovered: v,
Stack: debug.Stack(),
}
default:
g.panicValue = PanicValue{
Recovered: v,
Stack: debug.Stack(),
}
}
}
}()
if err := f(); err != nil {
err := f()
normalReturn = true
if err != nil {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
@ -101,19 +154,7 @@ func (g *Group) TryGo(f func() error) bool {
}
}
g.wg.Add(1)
go func() {
defer g.done()
if err := f(); err != nil {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel(g.err)
}
})
}
}()
g.add(f)
return true
}
@ -135,3 +176,33 @@ func (g *Group) SetLimit(n int) {
}
g.sem = make(chan token, n)
}
// PanicError carries an error value that was recovered from an unhandled
// panic in a function passed to Go or TryGo.
type PanicError struct {
	Recovered error  // the error value that was passed to panic
	Stack     []byte // result of call to [debug.Stack]
}

// Error reports the recovered error behind a fixed errgroup prefix.
// The stack dump is deliberately left out, as is conventional for Error
// methods; callers that need it can read the Stack field.
func (pe PanicError) Error() string {
	return fmt.Sprint("recovered from errgroup.Group: ", pe.Recovered)
}

// Unwrap returns the recovered error.
func (pe PanicError) Unwrap() error {
	return pe.Recovered
}
// PanicValue carries a non-error value that was recovered from an unhandled
// panic in a function passed to Go or TryGo.
type PanicValue struct {
	Recovered any    // the value that was passed to panic
	Stack     []byte // result of call to [debug.Stack]
}

// String formats the recovered value behind a fixed errgroup prefix and,
// when a stack dump was captured, appends it on the following line.
func (pv PanicValue) String() string {
	msg := fmt.Sprintf("recovered from errgroup.Group: %v", pv.Recovered)
	if len(pv.Stack) == 0 {
		return msg
	}
	return msg + "\n" + string(pv.Stack)
}

View file

@ -10,6 +10,7 @@ import (
"fmt"
"net/http"
"os"
"strings"
"sync/atomic"
"testing"
"time"
@ -289,6 +290,69 @@ func TestCancelCause(t *testing.T) {
}
}
// TestPanic checks that a panic raised inside a function passed to Group.Go
// surfaces as a panic from Wait. The "error" subtest panics with an error
// value and expects an errgroup.PanicError; the "any" subtest panics with a
// non-error value and expects an errgroup.PanicValue.
func TestPanic(t *testing.T) {
t.Run("error", func(t *testing.T) {
g := &errgroup.Group{}
p := errors.New("")
g.Go(func() error {
panic(p)
})
// Registered before Wait so that it can recover the panic Wait is
// expected to raise in this goroutine.
defer func() {
err := recover()
if err == nil {
t.Fatalf("should propagate panic through Wait")
}
pe, ok := err.(errgroup.PanicError)
if !ok {
t.Fatalf("type should is errgroup.PanicError, but is %T", err)
}
// The wrapped value must be the very error that was passed to panic.
if pe.Recovered != p {
t.Fatalf("got %v, want %v", pe.Recovered, p)
}
// The recorded stack is expected to mention a closure of this test.
if !strings.Contains(string(pe.Stack), "TestPanic.func") {
t.Log(string(pe.Stack))
t.Fatalf("stack trace incomplete")
}
}()
g.Wait()
})
t.Run("any", func(t *testing.T) {
g := &errgroup.Group{}
g.Go(func() error {
panic(1)
})
// Registered before Wait so that it can recover the panic Wait is
// expected to raise in this goroutine.
defer func() {
err := recover()
if err == nil {
t.Fatalf("should propagate panic through Wait")
}
pe, ok := err.(errgroup.PanicValue)
if !ok {
t.Fatalf("type should is errgroup.PanicValue, but is %T", err)
}
// The wrapped value must equal the int that was passed to panic.
if pe.Recovered != 1 {
t.Fatalf("got %v, want %v", pe.Recovered, 1)
}
// The recorded stack is expected to mention a closure of this test.
if !strings.Contains(string(pe.Stack), "TestPanic.func") {
t.Log(string(pe.Stack))
t.Fatalf("stack trace incomplete")
}
}()
g.Wait()
})
}
// TestGoexit checks that Wait does not return when a function passed to
// Group.Go terminates its goroutine with runtime.Goexit. t.Skip serves as
// the Goexit trigger: per the testing package documentation, SkipNow stops
// the calling goroutine via runtime.Goexit. Reaching either t.Fatalf
// therefore means the Goexit did not happen or was not propagated.
//
// NOTE(review): the testing docs say SkipNow must be called from the
// goroutine running the test; here it runs in a goroutine started by the
// group, and it also marks the test as skipped — confirm this is intended.
func TestGoexit(t *testing.T) {
g := &errgroup.Group{}
g.Go(func() error {
t.Skip()
// Only reached if t.Skip failed to end this goroutine.
t.Fatalf("Goexit fail")
return nil
})
g.Wait()
// Only reached if Wait returned instead of calling runtime.Goexit.
t.Fatalf("should call runtime.Goexit from Wait")
}
func BenchmarkGo(b *testing.B) {
fn := func() {}
g := &errgroup.Group{}