From b24cf3b9a0aefe6493c39f8a653872246beefccc Mon Sep 17 00:00:00 2001 From: Dan Jones Date: Mon, 8 Sep 2025 16:50:50 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Fix=20race=20conditions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Taskfile.yml | 1 + errgroup.go | 37 ++++++++---- errgroup_test.go | 144 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 171 insertions(+), 11 deletions(-) diff --git a/Taskfile.yml b/Taskfile.yml index 055cccd..8f64c1f 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -52,6 +52,7 @@ tasks: serve-report: desc: Serve the coverage report + deps: [coverage-report] sources: - build/cover.html cmds: diff --git a/errgroup.go b/errgroup.go index 1142170..313b360 100644 --- a/errgroup.go +++ b/errgroup.go @@ -28,15 +28,15 @@ type Group struct { wg sync.WaitGroup - sem chan token + sem atomic.Value errOnce sync.Once err atomic.Value } func (g *Group) done() { - if g.sem != nil { - <-g.sem + if sem := g.sema(); sem != nil { + <-sem } g.wg.Done() } @@ -51,6 +51,20 @@ func WithContext(ctx context.Context) (*Group, context.Context) { return &Group{cancel: cancel}, ctx } +type semContainer struct{ sem chan token } + +func (g *Group) sema() chan token { + v := g.sem.Load() + if v == nil { + return nil + } + return v.(semContainer).sem +} + +func (g *Group) setSema(ch chan token) { + g.sem.Store(semContainer{ch}) +} + func (g *Group) error() error { v := g.err.Load() if v == nil { @@ -94,8 +108,8 @@ func (g *Group) Go(f func() error) { if g.error() != nil { return } - if g.sem != nil { - g.sem <- token{} + if sem := g.sema(); sem != nil { + sem <- token{} } g.wg.Add(1) @@ -126,9 +140,9 @@ func (g *Group) TryGo(f func() error) bool { if g.error() != nil { return false } - if g.sem != nil { + if sem := g.sema(); sem != nil { select { - case g.sem <- token{}: + case sem <- token{}: // Note: this allows barging iff channels in general allow barging. default: return false @@ -159,12 +173,13 @@ func (egerr *ErrgroupLimitError) Error() string { // The limit must not be modified while any goroutines in the group are active. func (g *Group) SetLimit(n int) { if n < 0 { - g.sem = nil + g.setSema(nil) return } - if len(g.sem) != 0 { - var err error = &ErrgroupLimitError{len(g.sem)} + + if sem := g.sema(); sem != nil && len(sem) != 0 { + var err error = &ErrgroupLimitError{len(sem)} panic(err) } - g.sem = make(chan token, n) + g.setSema(make(chan token, n)) } diff --git a/errgroup_test.go b/errgroup_test.go index 32db177..7e07ef7 100644 --- a/errgroup_test.go +++ b/errgroup_test.go @@ -255,6 +255,45 @@ func TestTryGo(t *testing.T) { g.Wait() } +func TestTryGoError(t *testing.T) { + g := &errgroup.Group{} + ch := make(chan struct{}) + var count int32 = 0 + err1 := errors.New("group_test: err1") + err2 := errors.New("group_test: err2") + + ran := g.TryGo(func() error { + atomic.AddInt32(&count, 1) + ch <- struct{}{} + return err1 + }) + + if !ran { + t.Error("TryGo should succeed but failed first run") + } + + <-ch + + ran = g.TryGo(func() error { + atomic.AddInt32(&count, 1) + return err2 + }) + + if ran { + t.Error("TryGo should failed but succeeded second run") + } + + if count != 1 { + t.Errorf("TryGo should have run 1 time, but ran %d times", count) + } + + err := g.Wait() + + if err != err1 { + t.Errorf("g.Wait() = %v, wanted %v", err, err1) + } +} + func TestGoLimit(t *testing.T) { const limit = 10 @@ -277,6 +316,111 @@ func TestGoLimit(t *testing.T) { } } +func TestSetLimit(t *testing.T) { + g := &errgroup.Group{} + ch := make(chan struct{}) + var count int32 = 0 + + g.SetLimit(0) + ran := g.TryGo(func() error { + atomic.AddInt32(&count, 1) + return nil + }) + if ran { + t.Fatal("TryGo should fail but succeeded first run") + } + + g.SetLimit(-1) + ran = g.TryGo(func() error { + atomic.AddInt32(&count, 1) + ch <- struct{}{} + return nil + }) + if !ran { + t.Fatal("TryGo should succeed but failed second run") + } + + <-ch + + g.SetLimit(1) + ran = g.TryGo(func() error { + atomic.AddInt32(&count, 1) + ch <- struct{}{} + return nil + }) + if !ran { + t.Fatal("TryGo should succeed but failed third run") + } + ran = g.TryGo(func() error { + atomic.AddInt32(&count, 1) + return nil + }) + if ran { + t.Fatal("TryGo should fail but succeeded fourth run") + } + + <-ch + + if count != 2 { + t.Errorf("TryGo should have run 2 times, but ran %d times", count) + } + + err := g.Wait() + + if err != nil { + t.Errorf("g.Wait() = %v, wanted %v", err, nil) + } + +} + +func TestLimitPanic(t *testing.T) { + g := &errgroup.Group{} + ch := make(chan struct{}) + + g.SetLimit(2) + + ran := g.TryGo(func() error { + ch <- struct{}{} + return nil + }) + if !ran { + t.Fatal("TryGo should succeed but failed first run") + } + + var err error + func() { + defer func() { + rec := recover() + if rec == nil { + t.Fatal("SetLimit should have panicked, but didn't") + } + + er, ok := rec.(error) + if !ok { + t.Fatalf("SetLimit should have panicked with an error, %T received", rec) + } + + err = er + }() + g.SetLimit(5) + }() + + <-ch + + glimErr := &errgroup.ErrgroupLimitError{} + if !errors.As(err, &glimErr) { + t.Fatalf("panicked error should have been ErrgroupLimitError. %T received", err) + } + + if glimErr.Size != 1 { + t.Fatalf("ErrgroupLimitError should have had a size of 1. Got %d", glimErr.Size) + } + expErrMsg := "errgroup: modify limit while 1 goroutines in the group are still active" + if errStr := glimErr.Error(); errStr != expErrMsg { + t.Fatalf("error message should have been %s. Got %s", expErrMsg, errStr) + } +} + func TestCancelCause(t *testing.T) { errDoom := errors.New("group_test: doomed")