🐛 Fix race conditions
This commit is contained in:
parent
80f455ae77
commit
b24cf3b9a0
3 changed files with 171 additions and 11 deletions
|
|
@ -52,6 +52,7 @@ tasks:
|
||||||
|
|
||||||
serve-report:
|
serve-report:
|
||||||
desc: Serve the coverage report
|
desc: Serve the coverage report
|
||||||
|
deps: [coverage-report]
|
||||||
sources:
|
sources:
|
||||||
- build/cover.html
|
- build/cover.html
|
||||||
cmds:
|
cmds:
|
||||||
|
|
|
||||||
37
errgroup.go
37
errgroup.go
|
|
@ -28,15 +28,15 @@ type Group struct {
|
||||||
|
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
|
|
||||||
sem chan token
|
sem atomic.Value
|
||||||
|
|
||||||
errOnce sync.Once
|
errOnce sync.Once
|
||||||
err atomic.Value
|
err atomic.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Group) done() {
|
func (g *Group) done() {
|
||||||
if g.sem != nil {
|
if sem := g.sema(); sem != nil {
|
||||||
<-g.sem
|
<-sem
|
||||||
}
|
}
|
||||||
g.wg.Done()
|
g.wg.Done()
|
||||||
}
|
}
|
||||||
|
|
@ -51,6 +51,20 @@ func WithContext(ctx context.Context) (*Group, context.Context) {
|
||||||
return &Group{cancel: cancel}, ctx
|
return &Group{cancel: cancel}, ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type semContainer struct{ sem chan token }
|
||||||
|
|
||||||
|
func (g *Group) sema() chan token {
|
||||||
|
v := g.sem.Load()
|
||||||
|
if v == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return v.(semContainer).sem
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Group) setSema(ch chan token) {
|
||||||
|
g.sem.Store(semContainer{ch})
|
||||||
|
}
|
||||||
|
|
||||||
func (g *Group) error() error {
|
func (g *Group) error() error {
|
||||||
v := g.err.Load()
|
v := g.err.Load()
|
||||||
if v == nil {
|
if v == nil {
|
||||||
|
|
@ -94,8 +108,8 @@ func (g *Group) Go(f func() error) {
|
||||||
if g.error() != nil {
|
if g.error() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if g.sem != nil {
|
if sem := g.sema(); sem != nil {
|
||||||
g.sem <- token{}
|
sem <- token{}
|
||||||
}
|
}
|
||||||
|
|
||||||
g.wg.Add(1)
|
g.wg.Add(1)
|
||||||
|
|
@ -126,9 +140,9 @@ func (g *Group) TryGo(f func() error) bool {
|
||||||
if g.error() != nil {
|
if g.error() != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if g.sem != nil {
|
if sem := g.sema(); sem != nil {
|
||||||
select {
|
select {
|
||||||
case g.sem <- token{}:
|
case sem <- token{}:
|
||||||
// Note: this allows barging iff channels in general allow barging.
|
// Note: this allows barging iff channels in general allow barging.
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
|
|
@ -159,12 +173,13 @@ func (egerr *ErrgroupLimitError) Error() string {
|
||||||
// The limit must not be modified while any goroutines in the group are active.
|
// The limit must not be modified while any goroutines in the group are active.
|
||||||
func (g *Group) SetLimit(n int) {
|
func (g *Group) SetLimit(n int) {
|
||||||
if n < 0 {
|
if n < 0 {
|
||||||
g.sem = nil
|
g.setSema(nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(g.sem) != 0 {
|
|
||||||
var err error = &ErrgroupLimitError{len(g.sem)}
|
if sem := g.sema(); sem != nil && len(sem) != 0 {
|
||||||
|
var err error = &ErrgroupLimitError{len(sem)}
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
g.sem = make(chan token, n)
|
g.setSema(make(chan token, n))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
144
errgroup_test.go
144
errgroup_test.go
|
|
@ -255,6 +255,45 @@ func TestTryGo(t *testing.T) {
|
||||||
g.Wait()
|
g.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTryGoError(t *testing.T) {
|
||||||
|
g := &errgroup.Group{}
|
||||||
|
ch := make(chan struct{})
|
||||||
|
var count int32 = 0
|
||||||
|
err1 := errors.New("group_test: err1")
|
||||||
|
err2 := errors.New("group_test: err2")
|
||||||
|
|
||||||
|
ran := g.TryGo(func() error {
|
||||||
|
atomic.AddInt32(&count, 1)
|
||||||
|
ch <- struct{}{}
|
||||||
|
return err1
|
||||||
|
})
|
||||||
|
|
||||||
|
if !ran {
|
||||||
|
t.Error("TryGo should succeed but failed first run")
|
||||||
|
}
|
||||||
|
|
||||||
|
<-ch
|
||||||
|
|
||||||
|
ran = g.TryGo(func() error {
|
||||||
|
atomic.AddInt32(&count, 1)
|
||||||
|
return err2
|
||||||
|
})
|
||||||
|
|
||||||
|
if ran {
|
||||||
|
t.Error("TryGo should failed but succeeded second run")
|
||||||
|
}
|
||||||
|
|
||||||
|
if count != 1 {
|
||||||
|
t.Errorf("TryGo should have run 1 time, but ran %d times", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := g.Wait()
|
||||||
|
|
||||||
|
if err != err1 {
|
||||||
|
t.Errorf("g.Wait() = %v, wanted %v", err, err1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGoLimit(t *testing.T) {
|
func TestGoLimit(t *testing.T) {
|
||||||
const limit = 10
|
const limit = 10
|
||||||
|
|
||||||
|
|
@ -277,6 +316,111 @@ func TestGoLimit(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetLimit(t *testing.T) {
|
||||||
|
g := &errgroup.Group{}
|
||||||
|
ch := make(chan struct{})
|
||||||
|
var count int32 = 0
|
||||||
|
|
||||||
|
g.SetLimit(0)
|
||||||
|
ran := g.TryGo(func() error {
|
||||||
|
atomic.AddInt32(&count, 1)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if ran {
|
||||||
|
t.Fatal("TryGo should fail but succeeded first run")
|
||||||
|
}
|
||||||
|
|
||||||
|
g.SetLimit(-1)
|
||||||
|
ran = g.TryGo(func() error {
|
||||||
|
atomic.AddInt32(&count, 1)
|
||||||
|
ch <- struct{}{}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if !ran {
|
||||||
|
t.Fatal("TryGo should succeed but failed second run")
|
||||||
|
}
|
||||||
|
|
||||||
|
<-ch
|
||||||
|
|
||||||
|
g.SetLimit(1)
|
||||||
|
ran = g.TryGo(func() error {
|
||||||
|
atomic.AddInt32(&count, 1)
|
||||||
|
ch <- struct{}{}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if !ran {
|
||||||
|
t.Fatal("TryGo should succeed but failed third run")
|
||||||
|
}
|
||||||
|
ran = g.TryGo(func() error {
|
||||||
|
atomic.AddInt32(&count, 1)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if ran {
|
||||||
|
t.Fatal("TryGo should fail but succeeded fourth run")
|
||||||
|
}
|
||||||
|
|
||||||
|
<-ch
|
||||||
|
|
||||||
|
if count != 2 {
|
||||||
|
t.Errorf("TryGo should have run 2 times, but ran %d times", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := g.Wait()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("g.Wait() = %v, wanted %v", err, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLimitPanic(t *testing.T) {
|
||||||
|
g := &errgroup.Group{}
|
||||||
|
ch := make(chan struct{})
|
||||||
|
|
||||||
|
g.SetLimit(2)
|
||||||
|
|
||||||
|
ran := g.TryGo(func() error {
|
||||||
|
ch <- struct{}{}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if !ran {
|
||||||
|
t.Fatal("TryGo should succeed but failed first run")
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
rec := recover()
|
||||||
|
if rec == nil {
|
||||||
|
t.Fatal("SetLimit should have panicked, but didn't")
|
||||||
|
}
|
||||||
|
|
||||||
|
er, ok := rec.(error)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("SetLimit should have panicked with an error, %T received", rec)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = er
|
||||||
|
}()
|
||||||
|
g.SetLimit(5)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-ch
|
||||||
|
|
||||||
|
glimErr := &errgroup.ErrgroupLimitError{}
|
||||||
|
if !errors.As(err, &glimErr) {
|
||||||
|
t.Fatalf("panicked error should have been ErrgroupLimitError. %T received", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if glimErr.Size != 1 {
|
||||||
|
t.Fatalf("ErrgroupLimitError should have had a size of 1. Got %d", glimErr.Size)
|
||||||
|
}
|
||||||
|
expErrMsg := "errgroup: modify limit while 1 goroutines in the group are still active"
|
||||||
|
if errStr := glimErr.Error(); errStr != expErrMsg {
|
||||||
|
t.Fatalf("error message should have been %s. Got %s", expErrMsg, errStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCancelCause(t *testing.T) {
|
func TestCancelCause(t *testing.T) {
|
||||||
errDoom := errors.New("group_test: doomed")
|
errDoom := errors.New("group_test: doomed")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue